xref: /aosp_15_r20/external/cronet/net/dns/dns_query.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/dns_query.h"
6 
7 #include <optional>
8 #include <string_view>
9 #include <utility>
10 
11 #include "base/big_endian.h"
12 #include "base/containers/span.h"
13 #include "base/containers/span_writer.h"
14 #include "base/logging.h"
15 #include "base/memory/ptr_util.h"
16 #include "base/numerics/byte_conversions.h"
17 #include "base/numerics/safe_conversions.h"
18 #include "base/sys_byteorder.h"
19 #include "net/base/io_buffer.h"
20 #include "net/dns/dns_names_util.h"
21 #include "net/dns/opt_record_rdata.h"
22 #include "net/dns/public/dns_protocol.h"
23 #include "net/dns/record_rdata.h"
24 
25 namespace net {
26 
27 namespace {
28 
29 const size_t kHeaderSize = sizeof(dns_protocol::Header);
30 
// Size of the fixed part of an OPT RR: root name (1) + TYPE (2) + CLASS (2) +
// TTL (4) + RDLENGTH (2) bytes.
// https://tools.ietf.org/html/rfc6891#section-6.1.2
constexpr size_t kOptRRFixedSize = 11;

// Requestor's UDP payload size, written into the CLASS field of the OPT RR.
// https://tools.ietf.org/html/rfc6891#section-6.2.5
// TODO(robpercival): Determine a good value for this programmatically.
constexpr uint16_t kMaxUdpPayloadSize = 4096;
38 
// Returns the wire size of one question entry whose encoded QNAME occupies
// |qname_size| bytes: the name followed by the 16-bit QTYPE and QCLASS fields.
size_t QuestionSize(size_t qname_size) {
  constexpr size_t kTypeAndClassSize = 2 * sizeof(uint16_t);
  return kTypeAndClassSize + qname_size;
}
43 
44 // Buffer size of Opt record for |rdata| (does not include Opt record or RData
45 // added for padding).
OptRecordSize(const OptRecordRdata * rdata)46 size_t OptRecordSize(const OptRecordRdata* rdata) {
47   return rdata == nullptr ? 0 : kOptRRFixedSize + rdata->buf().size();
48 }
49 
50 // Padding size includes Opt header for the padding.  Does not include OptRecord
51 // header (kOptRRFixedSize) even when added just for padding.
DeterminePaddingSize(size_t unpadded_size,DnsQuery::PaddingStrategy padding_strategy)52 size_t DeterminePaddingSize(size_t unpadded_size,
53                             DnsQuery::PaddingStrategy padding_strategy) {
54   switch (padding_strategy) {
55     case DnsQuery::PaddingStrategy::NONE:
56       return 0;
57     case DnsQuery::PaddingStrategy::BLOCK_LENGTH_128:
58       size_t padding_size = OptRecordRdata::Opt::kHeaderSize;
59       size_t remainder = (padding_size + unpadded_size) % 128;
60       padding_size += (128 - remainder) % 128;
61       DCHECK_EQ((unpadded_size + padding_size) % 128, 0u);
62       return padding_size;
63   }
64 }
65 
AddPaddingIfNecessary(const OptRecordRdata * opt_rdata,DnsQuery::PaddingStrategy padding_strategy,size_t no_opt_buffer_size)66 std::unique_ptr<OptRecordRdata> AddPaddingIfNecessary(
67     const OptRecordRdata* opt_rdata,
68     DnsQuery::PaddingStrategy padding_strategy,
69     size_t no_opt_buffer_size) {
70   // If no input OPT record rdata and no padding, no OPT record rdata needed.
71   if (!opt_rdata && padding_strategy == DnsQuery::PaddingStrategy::NONE)
72     return nullptr;
73 
74   std::unique_ptr<OptRecordRdata> merged_opt_rdata;
75   if (opt_rdata) {
76     merged_opt_rdata = OptRecordRdata::Create(
77         std::string_view(opt_rdata->buf().data(), opt_rdata->buf().size()));
78   } else {
79     merged_opt_rdata = std::make_unique<OptRecordRdata>();
80   }
81   DCHECK(merged_opt_rdata);
82 
83   size_t unpadded_size =
84       no_opt_buffer_size + OptRecordSize(merged_opt_rdata.get());
85   size_t padding_size = DeterminePaddingSize(unpadded_size, padding_strategy);
86 
87   if (padding_size > 0) {
88     // |opt_rdata| must not already contain padding if DnsQuery is to add
89     // padding.
90     DCHECK(!merged_opt_rdata->ContainsOptCode(dns_protocol::kEdnsPadding));
91     // OPT header is the minimum amount of padding.
92     DCHECK(padding_size >= OptRecordRdata::Opt::kHeaderSize);
93 
94     merged_opt_rdata->AddOpt(std::make_unique<OptRecordRdata::PaddingOpt>(
95         padding_size - OptRecordRdata::Opt::kHeaderSize));
96   }
97 
98   return merged_opt_rdata;
99 }
100 
101 }  // namespace
102 
103 // DNS query consists of a 12-byte header followed by a question section.
104 // For details, see RFC 1035 section 4.1.1.  This header template sets RD
105 // bit, which directs the name server to pursue query recursively, and sets
106 // the QDCOUNT to 1, meaning the question section has a single entry.
DnsQuery(uint16_t id,base::span<const uint8_t> qname,uint16_t qtype,const OptRecordRdata * opt_rdata,PaddingStrategy padding_strategy)107 DnsQuery::DnsQuery(uint16_t id,
108                    base::span<const uint8_t> qname,
109                    uint16_t qtype,
110                    const OptRecordRdata* opt_rdata,
111                    PaddingStrategy padding_strategy)
112     : qname_size_(qname.size()) {
113 #if DCHECK_IS_ON()
114   std::optional<std::string> dotted_name =
115       dns_names_util::NetworkToDottedName(qname);
116   DCHECK(dotted_name && !dotted_name.value().empty());
117 #endif  // DCHECK_IS_ON()
118 
119   size_t buffer_size = kHeaderSize + QuestionSize(qname_size_);
120   std::unique_ptr<OptRecordRdata> merged_opt_rdata =
121       AddPaddingIfNecessary(opt_rdata, padding_strategy, buffer_size);
122   if (merged_opt_rdata)
123     buffer_size += OptRecordSize(merged_opt_rdata.get());
124 
125   io_buffer_ = base::MakeRefCounted<IOBufferWithSize>(buffer_size);
126 
127   dns_protocol::Header* header = header_in_io_buffer();
128   *header = {};
129   header->id = base::HostToNet16(id);
130   header->flags = base::HostToNet16(dns_protocol::kFlagRD);
131   header->qdcount = base::HostToNet16(1);
132 
133   // Write question section after the header.
134   auto writer = base::SpanWriter(
135       base::as_writable_bytes(io_buffer_->span()).subspan(kHeaderSize));
136   writer.Write(qname);
137   writer.WriteU16BigEndian(qtype);
138   writer.WriteU16BigEndian(dns_protocol::kClassIN);
139 
140   if (merged_opt_rdata) {
141     DCHECK_NE(merged_opt_rdata->OptCount(), 0u);
142 
143     header->arcount = base::HostToNet16(1);
144     // Write OPT pseudo-resource record.
145     writer.WriteU8BigEndian(0);  // empty domain name (root domain)
146     writer.WriteU16BigEndian(OptRecordRdata::kType);  // type
147     writer.WriteU16BigEndian(kMaxUdpPayloadSize);     // class
148     // ttl (next 3 fields)
149     writer.WriteU8BigEndian(0);  // rcode does not apply to requests
150     writer.WriteU8BigEndian(0);  // version
151     // TODO(robpercival): Set "DNSSEC OK" flag if/when DNSSEC is supported:
152     // https://tools.ietf.org/html/rfc3225#section-3
153     writer.WriteU16BigEndian(0);  // flags
154 
155     // rdata
156     writer.WriteU16BigEndian(merged_opt_rdata->buf().size());  // rdata length
157     writer.Write(base::as_byte_span(merged_opt_rdata->buf()));
158   }
159 }
160 
DnsQuery(scoped_refptr<IOBufferWithSize> buffer)161 DnsQuery::DnsQuery(scoped_refptr<IOBufferWithSize> buffer)
162     : io_buffer_(std::move(buffer)) {}
163 
// Copy constructor: deep-copies |query|'s packet into a buffer owned by this
// object, so the two queries never share bytes.
DnsQuery::DnsQuery(const DnsQuery& query) {
  CopyFrom(query);
}

// Copy assignment: deep-copies |query|'s packet. Self-assignment is handled
// explicitly as a no-op, so correctness does not depend on CopyFrom()
// preserving the source bytes while it replaces |io_buffer_|.
DnsQuery& DnsQuery::operator=(const DnsQuery& query) {
  if (this != &query) {
    CopyFrom(query);
  }
  return *this;
}

DnsQuery::~DnsQuery() = default;
174 
CloneWithNewId(uint16_t id) const175 std::unique_ptr<DnsQuery> DnsQuery::CloneWithNewId(uint16_t id) const {
176   return base::WrapUnique(new DnsQuery(*this, id));
177 }
178 
Parse(size_t valid_bytes)179 bool DnsQuery::Parse(size_t valid_bytes) {
180   if (io_buffer_ == nullptr || io_buffer_->span().empty()) {
181     return false;
182   }
183   auto reader =
184       base::SpanReader(base::as_bytes(io_buffer_->span()).first(valid_bytes));
185   dns_protocol::Header header;
186   if (!ReadHeader(&reader, &header)) {
187     return false;
188   }
189   if (header.flags & dns_protocol::kFlagResponse) {
190     return false;
191   }
192   if (header.qdcount != 1) {
193     VLOG(1) << "Not supporting parsing a DNS query with multiple (or zero) "
194                "questions.";
195     return false;
196   }
197   std::string qname;
198   if (!ReadName(&reader, &qname)) {
199     return false;
200   }
201   uint16_t qtype;
202   uint16_t qclass;
203   if (!reader.ReadU16BigEndian(qtype) || !reader.ReadU16BigEndian(qclass) ||
204       qclass != dns_protocol::kClassIN) {
205     return false;
206   }
207   // |io_buffer_| now contains the raw packet of a valid DNS query, we just
208   // need to properly initialize |qname_size_|.
209   qname_size_ = qname.size();
210   return true;
211 }
212 
id() const213 uint16_t DnsQuery::id() const {
214   return base::NetToHost16(header_in_io_buffer()->id);
215 }
216 
qname() const217 base::span<const uint8_t> DnsQuery::qname() const {
218   return base::as_bytes(io_buffer_->span()).subspan(kHeaderSize, qname_size_);
219 }
220 
qtype() const221 uint16_t DnsQuery::qtype() const {
222   return base::U16FromBigEndian(base::as_bytes(io_buffer_->span())
223                                     .subspan(kHeaderSize + qname_size_)
224                                     .first<2u>());
225 }
226 
question() const227 std::string_view DnsQuery::question() const {
228   auto s = io_buffer_->span().subspan(kHeaderSize, QuestionSize(qname_size_));
229   return std::string_view(s.begin(), s.end());
230 }
231 
question_size() const232 size_t DnsQuery::question_size() const {
233   return QuestionSize(qname_size_);
234 }
235 
set_flags(uint16_t flags)236 void DnsQuery::set_flags(uint16_t flags) {
237   header_in_io_buffer()->flags = flags;
238 }
239 
// Copy-with-new-ID constructor used by CloneWithNewId(). Deep-copies |orig|,
// then overwrites the header ID, storing |id| in network byte order.
DnsQuery::DnsQuery(const DnsQuery& orig, uint16_t id) {
  CopyFrom(orig);
  dns_protocol::Header* header = header_in_io_buffer();
  header->id = base::HostToNet16(id);
}
244 
CopyFrom(const DnsQuery & orig)245 void DnsQuery::CopyFrom(const DnsQuery& orig) {
246   qname_size_ = orig.qname_size_;
247   io_buffer_ = base::MakeRefCounted<IOBufferWithSize>(orig.io_buffer()->size());
248   io_buffer_->span().copy_from(orig.io_buffer()->span());
249 }
250 
ReadHeader(base::SpanReader<const uint8_t> * reader,dns_protocol::Header * header)251 bool DnsQuery::ReadHeader(base::SpanReader<const uint8_t>* reader,
252                           dns_protocol::Header* header) {
253   return (reader->ReadU16BigEndian(header->id) &&
254           reader->ReadU16BigEndian(header->flags) &&
255           reader->ReadU16BigEndian(header->qdcount) &&
256           reader->ReadU16BigEndian(header->ancount) &&
257           reader->ReadU16BigEndian(header->nscount) &&
258           reader->ReadU16BigEndian(header->arcount));
259 }
260 
ReadName(base::SpanReader<const uint8_t> * reader,std::string * out)261 bool DnsQuery::ReadName(base::SpanReader<const uint8_t>* reader,
262                         std::string* out) {
263   DCHECK(out != nullptr);
264   out->clear();
265   out->reserve(dns_protocol::kMaxNameLength + 1);
266   uint8_t label_length;
267   if (!reader->ReadU8BigEndian(label_length)) {
268     return false;
269   }
270   while (label_length) {
271     if (out->size() + 1 + label_length > dns_protocol::kMaxNameLength) {
272       return false;
273     }
274 
275     out->push_back(static_cast<char>(label_length));
276 
277     std::optional<base::span<const uint8_t>> label = reader->Read(label_length);
278     if (!label) {
279       return false;
280     }
281     out->append(base::as_string_view(*label));
282 
283     if (!reader->ReadU8BigEndian(label_length)) {
284       return false;
285     }
286   }
287   DCHECK_LE(out->size(), static_cast<size_t>(dns_protocol::kMaxNameLength));
288   out->append(1, '\0');
289   return true;
290 }
291 
292 }  // namespace net
293