xref: /aosp_15_r20/external/cronet/net/dns/dns_test_util.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "net/dns/dns_test_util.h"

#include <array>
#include <cstdint>
#include <cstring>
#include <map>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "base/check.h"
#include "base/containers/span.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/memory/weak_ptr.h"
#include "base/numerics/byte_conversions.h"
#include "base/numerics/safe_conversions.h"
#include "base/ranges/algorithm.h"
#include "base/strings/strcat.h"
#include "base/sys_byteorder.h"
#include "base/task/single_thread_task_runner.h"
#include "base/test/test_timeouts.h"
#include "base/threading/thread_restrictions.h"
#include "base/time/time.h"
#include "base/types/optional_util.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/dns/address_sorter.h"
#include "net/dns/dns_hosts.h"
#include "net/dns/dns_names_util.h"
#include "net/dns/dns_query.h"
#include "net/dns/dns_session.h"
#include "net/dns/mock_host_resolver.h"
#include "net/dns/public/dns_over_https_server_config.h"
#include "net/dns/resolve_context.h"
#include "testing/gmock/include/gmock/gmock-matchers.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/scheme_host_port.h"
43 
44 namespace net {
45 namespace {
46 
47 const uint8_t kMalformedResponseHeader[] = {
48     // Header
49     0x00, 0x14,  // Arbitrary ID
50     0x81, 0x80,  // Standard query response, RA, no error
51     0x00, 0x01,  // 1 question
52     0x00, 0x01,  // 1 RR (answers)
53     0x00, 0x00,  // 0 authority RRs
54     0x00, 0x00,  // 0 additional RRs
55 };
56 
57 // Create a response containing a valid question (as would normally be validated
58 // in DnsTransaction) but completely missing a header-declared answer.
CreateMalformedResponse(std::string hostname,uint16_t type)59 DnsResponse CreateMalformedResponse(std::string hostname, uint16_t type) {
60   std::optional<std::vector<uint8_t>> dns_name =
61       dns_names_util::DottedNameToNetwork(hostname);
62   CHECK(dns_name.has_value());
63   DnsQuery query(/*id=*/0x14, dns_name.value(), type);
64 
65   // Build response to simulate the barebones validation DnsResponse applies to
66   // responses received from the network.
67   auto buffer = base::MakeRefCounted<IOBufferWithSize>(
68       sizeof(kMalformedResponseHeader) + query.question().size());
69   memcpy(buffer->data(), kMalformedResponseHeader,
70          sizeof(kMalformedResponseHeader));
71   memcpy(buffer->data() + sizeof(kMalformedResponseHeader),
72          query.question().data(), query.question().size());
73 
74   DnsResponse response(buffer, buffer->size());
75   CHECK(response.InitParseWithoutQuery(buffer->size()));
76 
77   return response;
78 }
79 
80 class MockAddressSorter : public AddressSorter {
81  public:
82   ~MockAddressSorter() override = default;
Sort(const std::vector<IPEndPoint> & endpoints,CallbackType callback) const83   void Sort(const std::vector<IPEndPoint>& endpoints,
84             CallbackType callback) const override {
85     // Do nothing.
86     std::move(callback).Run(true, endpoints);
87   }
88 };
89 
90 }  // namespace
91 
CreateValidDnsConfig()92 DnsConfig CreateValidDnsConfig() {
93   IPAddress dns_ip(192, 168, 1, 0);
94   DnsConfig config;
95   config.nameservers.emplace_back(dns_ip, dns_protocol::kDefaultPort);
96   config.doh_config =
97       *DnsOverHttpsConfig::FromString("https://dns.example.com/");
98   config.secure_dns_mode = SecureDnsMode::kOff;
99   EXPECT_TRUE(config.IsValid());
100   return config;
101 }
102 
BuildTestDnsRecord(std::string name,uint16_t type,std::string rdata,base::TimeDelta ttl)103 DnsResourceRecord BuildTestDnsRecord(std::string name,
104                                      uint16_t type,
105                                      std::string rdata,
106                                      base::TimeDelta ttl) {
107   DCHECK(!name.empty());
108 
109   DnsResourceRecord record;
110   record.name = std::move(name);
111   record.type = type;
112   record.klass = dns_protocol::kClassIN;
113   record.ttl = ttl.InSeconds();
114 
115   if (!rdata.empty())
116     record.SetOwnedRdata(std::move(rdata));
117 
118   return record;
119 }
120 
BuildTestCnameRecord(std::string name,std::string_view canonical_name,base::TimeDelta ttl)121 DnsResourceRecord BuildTestCnameRecord(std::string name,
122                                        std::string_view canonical_name,
123                                        base::TimeDelta ttl) {
124   DCHECK(!name.empty());
125   DCHECK(!canonical_name.empty());
126 
127   std::optional<std::vector<uint8_t>> rdata =
128       dns_names_util::DottedNameToNetwork(canonical_name);
129   CHECK(rdata.has_value());
130 
131   return BuildTestDnsRecord(
132       std::move(name), dns_protocol::kTypeCNAME,
133       std::string(reinterpret_cast<char*>(rdata.value().data()),
134                   rdata.value().size()),
135       ttl);
136 }
137 
BuildTestAddressRecord(std::string name,const IPAddress & ip,base::TimeDelta ttl)138 DnsResourceRecord BuildTestAddressRecord(std::string name,
139                                          const IPAddress& ip,
140                                          base::TimeDelta ttl) {
141   DCHECK(!name.empty());
142   DCHECK(ip.IsValid());
143 
144   return BuildTestDnsRecord(
145       std::move(name),
146       ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA,
147       net::IPAddressToPackedString(ip), ttl);
148 }
149 
BuildTestTextRecord(std::string name,std::vector<std::string> text_strings,base::TimeDelta ttl)150 DnsResourceRecord BuildTestTextRecord(std::string name,
151                                       std::vector<std::string> text_strings,
152                                       base::TimeDelta ttl) {
153   DCHECK(!text_strings.empty());
154 
155   std::string rdata;
156   for (const std::string& text_string : text_strings) {
157     DCHECK(!text_string.empty());
158 
159     rdata += base::checked_cast<unsigned char>(text_string.size());
160     rdata += text_string;
161   }
162 
163   return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeTXT,
164                             std::move(rdata), ttl);
165 }
166 
BuildTestHttpsAliasRecord(std::string name,std::string_view alias_name,base::TimeDelta ttl)167 DnsResourceRecord BuildTestHttpsAliasRecord(std::string name,
168                                             std::string_view alias_name,
169                                             base::TimeDelta ttl) {
170   DCHECK(!name.empty());
171 
172   std::string rdata("\000\000", 2);
173 
174   std::optional<std::vector<uint8_t>> alias_domain =
175       dns_names_util::DottedNameToNetwork(alias_name);
176   CHECK(alias_domain.has_value());
177   rdata.append(reinterpret_cast<char*>(alias_domain.value().data()),
178                alias_domain.value().size());
179 
180   return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps,
181                             std::move(rdata), ttl);
182 }
183 
BuildTestHttpsServiceAlpnParam(const std::vector<std::string> & alpns)184 std::pair<uint16_t, std::string> BuildTestHttpsServiceAlpnParam(
185     const std::vector<std::string>& alpns) {
186   std::string param_value;
187 
188   for (const std::string& alpn : alpns) {
189     CHECK(!alpn.empty());
190     param_value.append(
191         1, static_cast<char>(base::checked_cast<uint8_t>(alpn.size())));
192     param_value.append(alpn);
193   }
194 
195   return std::pair(dns_protocol::kHttpsServiceParamKeyAlpn,
196                    std::move(param_value));
197 }
198 
BuildTestHttpsServiceEchConfigParam(base::span<const uint8_t> ech_config_list)199 std::pair<uint16_t, std::string> BuildTestHttpsServiceEchConfigParam(
200     base::span<const uint8_t> ech_config_list) {
201   return std::pair(
202       dns_protocol::kHttpsServiceParamKeyEchConfig,
203       std::string(reinterpret_cast<const char*>(ech_config_list.data()),
204                   ech_config_list.size()));
205 }
206 
BuildTestHttpsServiceMandatoryParam(std::vector<uint16_t> param_key_list)207 std::pair<uint16_t, std::string> BuildTestHttpsServiceMandatoryParam(
208     std::vector<uint16_t> param_key_list) {
209   base::ranges::sort(param_key_list);
210 
211   std::string value;
212   for (uint16_t param_key : param_key_list) {
213     std::array<uint8_t, 2> num_buffer = base::U16ToBigEndian(param_key);
214     value.append(num_buffer.begin(), num_buffer.end());
215   }
216 
217   return std::pair(dns_protocol::kHttpsServiceParamKeyMandatory,
218                    std::move(value));
219 }
220 
BuildTestHttpsServicePortParam(uint16_t port)221 std::pair<uint16_t, std::string> BuildTestHttpsServicePortParam(uint16_t port) {
222   std::array<uint8_t, 2> buffer = base::U16ToBigEndian(port);
223   return std::pair(dns_protocol::kHttpsServiceParamKeyPort,
224                    std::string(buffer.begin(), buffer.end()));
225 }
226 
BuildTestHttpsServiceRecord(std::string name,uint16_t priority,std::string_view service_name,const std::map<uint16_t,std::string> & params,base::TimeDelta ttl)227 DnsResourceRecord BuildTestHttpsServiceRecord(
228     std::string name,
229     uint16_t priority,
230     std::string_view service_name,
231     const std::map<uint16_t, std::string>& params,
232     base::TimeDelta ttl) {
233   DCHECK(!name.empty());
234   DCHECK_NE(priority, 0);
235 
236   std::string rdata;
237 
238   {
239     std::array<uint8_t, 2> buf = base::U16ToBigEndian(priority);
240     rdata.append(buf.begin(), buf.end());
241   }
242 
243   std::optional<std::vector<uint8_t>> service_domain;
244   if (service_name == ".") {
245     // HTTPS records have special behavior for `service_name == "."` (that it
246     // will be treated as if the service name is the same as the record owner
247     // name), so allow such inputs despite normally being disallowed for
248     // Chrome-encoded DNS names.
249     service_domain = std::vector<uint8_t>{0};
250   } else {
251     service_domain = dns_names_util::DottedNameToNetwork(service_name);
252   }
253   CHECK(service_domain.has_value());
254   rdata.append(reinterpret_cast<char*>(service_domain.value().data()),
255                service_domain.value().size());
256 
257   for (auto& param : params) {
258     {
259       std::array<uint8_t, 2> buf = base::U16ToBigEndian(param.first);
260       rdata.append(buf.begin(), buf.end());
261     }
262     {
263       std::array<uint8_t, 2> buf = base::U16ToBigEndian(
264           base::checked_cast<uint16_t>(param.second.size()));
265       rdata.append(buf.begin(), buf.end());
266     }
267     rdata.append(param.second);
268   }
269 
270   return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps,
271                             std::move(rdata), ttl);
272 }
273 
BuildTestDnsResponse(std::string name,uint16_t type,const std::vector<DnsResourceRecord> & answers,const std::vector<DnsResourceRecord> & authority,const std::vector<DnsResourceRecord> & additional,uint8_t rcode)274 DnsResponse BuildTestDnsResponse(
275     std::string name,
276     uint16_t type,
277     const std::vector<DnsResourceRecord>& answers,
278     const std::vector<DnsResourceRecord>& authority,
279     const std::vector<DnsResourceRecord>& additional,
280     uint8_t rcode) {
281   DCHECK(!name.empty());
282 
283   std::optional<std::vector<uint8_t>> dns_name =
284       dns_names_util::DottedNameToNetwork(name);
285   CHECK(dns_name.has_value());
286 
287   std::optional<DnsQuery> query(std::in_place, 0, dns_name.value(), type);
288   return DnsResponse(0, true /* is_authoritative */, answers,
289                      authority /* authority_records */,
290                      additional /* additional_records */, query, rcode,
291                      false /* validate_records */);
292 }
293 
BuildTestDnsAddressResponse(std::string name,const IPAddress & ip,std::string answer_name)294 DnsResponse BuildTestDnsAddressResponse(std::string name,
295                                         const IPAddress& ip,
296                                         std::string answer_name) {
297   DCHECK(ip.IsValid());
298 
299   if (answer_name.empty())
300     answer_name = name;
301 
302   std::vector<DnsResourceRecord> answers = {
303       BuildTestAddressRecord(std::move(answer_name), ip)};
304 
305   return BuildTestDnsResponse(
306       std::move(name),
307       ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers);
308 }
309 
BuildTestDnsAddressResponseWithCname(std::string name,const IPAddress & ip,std::string cannonname,std::string answer_name)310 DnsResponse BuildTestDnsAddressResponseWithCname(std::string name,
311                                                  const IPAddress& ip,
312                                                  std::string cannonname,
313                                                  std::string answer_name) {
314   DCHECK(ip.IsValid());
315   DCHECK(!cannonname.empty());
316 
317   if (answer_name.empty())
318     answer_name = name;
319 
320   std::optional<std::vector<uint8_t>> cname_rdata =
321       dns_names_util::DottedNameToNetwork(cannonname);
322   CHECK(cname_rdata.has_value());
323 
324   std::vector<DnsResourceRecord> answers = {
325       BuildTestDnsRecord(
326           std::move(answer_name), dns_protocol::kTypeCNAME,
327           std::string(reinterpret_cast<char*>(cname_rdata.value().data()),
328                       cname_rdata.value().size())),
329       BuildTestAddressRecord(std::move(cannonname), ip)};
330 
331   return BuildTestDnsResponse(
332       std::move(name),
333       ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers);
334 }
335 
BuildTestDnsTextResponse(std::string name,std::vector<std::vector<std::string>> text_records,std::string answer_name)336 DnsResponse BuildTestDnsTextResponse(
337     std::string name,
338     std::vector<std::vector<std::string>> text_records,
339     std::string answer_name) {
340   if (answer_name.empty())
341     answer_name = name;
342 
343   std::vector<DnsResourceRecord> answers;
344   for (std::vector<std::string>& text_record : text_records) {
345     answers.push_back(BuildTestTextRecord(answer_name, std::move(text_record)));
346   }
347 
348   return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeTXT, answers);
349 }
350 
BuildTestDnsPointerResponse(std::string name,std::vector<std::string> pointer_names,std::string answer_name)351 DnsResponse BuildTestDnsPointerResponse(std::string name,
352                                         std::vector<std::string> pointer_names,
353                                         std::string answer_name) {
354   if (answer_name.empty())
355     answer_name = name;
356 
357   std::vector<DnsResourceRecord> answers;
358   for (std::string& pointer_name : pointer_names) {
359     std::optional<std::vector<uint8_t>> rdata =
360         dns_names_util::DottedNameToNetwork(pointer_name);
361     CHECK(rdata.has_value());
362 
363     answers.push_back(BuildTestDnsRecord(
364         answer_name, dns_protocol::kTypePTR,
365         std::string(reinterpret_cast<char*>(rdata.value().data()),
366                     rdata.value().size())));
367   }
368 
369   return BuildTestDnsResponse(std::move(name), dns_protocol::kTypePTR, answers);
370 }
371 
BuildTestDnsServiceResponse(std::string name,std::vector<TestServiceRecord> service_records,std::string answer_name)372 DnsResponse BuildTestDnsServiceResponse(
373     std::string name,
374     std::vector<TestServiceRecord> service_records,
375     std::string answer_name) {
376   if (answer_name.empty())
377     answer_name = name;
378 
379   std::vector<DnsResourceRecord> answers;
380   for (TestServiceRecord& service_record : service_records) {
381     std::string rdata;
382     {
383       std::array<uint8_t, 2> buf =
384           base::U16ToBigEndian(service_record.priority);
385       rdata.append(buf.begin(), buf.end());
386     }
387     {
388       std::array<uint8_t, 2> buf = base::U16ToBigEndian(service_record.weight);
389       rdata.append(buf.begin(), buf.end());
390     }
391     {
392       std::array<uint8_t, 2> buf = base::U16ToBigEndian(service_record.port);
393       rdata.append(buf.begin(), buf.end());
394     }
395 
396     std::optional<std::vector<uint8_t>> dns_name =
397         dns_names_util::DottedNameToNetwork(service_record.target);
398     CHECK(dns_name.has_value());
399     rdata.append(reinterpret_cast<char*>(dns_name.value().data()),
400                  dns_name.value().size());
401 
402     answers.push_back(BuildTestDnsRecord(answer_name, dns_protocol::kTypeSRV,
403                                          std::move(rdata), base::Hours(5)));
404   }
405 
406   return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeSRV, answers);
407 }
408 
Result(ResultType type,std::optional<DnsResponse> response,std::optional<int> net_error)409 MockDnsClientRule::Result::Result(ResultType type,
410                                   std::optional<DnsResponse> response,
411                                   std::optional<int> net_error)
412     : type(type), response(std::move(response)), net_error(net_error) {}
413 
Result(DnsResponse response)414 MockDnsClientRule::Result::Result(DnsResponse response)
415     : type(ResultType::kOk),
416       response(std::move(response)),
417       net_error(std::nullopt) {}
418 
419 MockDnsClientRule::Result::Result(Result&&) = default;
420 
421 MockDnsClientRule::Result& MockDnsClientRule::Result::operator=(Result&&) =
422     default;
423 
424 MockDnsClientRule::Result::~Result() = default;
425 
MockDnsClientRule(const std::string & prefix,uint16_t qtype,bool secure,Result result,bool delay,URLRequestContext * context)426 MockDnsClientRule::MockDnsClientRule(const std::string& prefix,
427                                      uint16_t qtype,
428                                      bool secure,
429                                      Result result,
430                                      bool delay,
431                                      URLRequestContext* context)
432     : result(std::move(result)),
433       prefix(prefix),
434       qtype(qtype),
435       secure(secure),
436       delay(delay),
437       context(context) {}
438 
439 MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& rule) = default;
440 
441 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
442 class MockDnsTransactionFactory::MockTransaction
443     : public DnsTransaction,
444       public base::SupportsWeakPtr<MockTransaction> {
445  public:
MockTransaction(const MockDnsClientRuleList & rules,std::string hostname,uint16_t qtype,bool secure,bool force_doh_server_available,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)446   MockTransaction(const MockDnsClientRuleList& rules,
447                   std::string hostname,
448                   uint16_t qtype,
449                   bool secure,
450                   bool force_doh_server_available,
451                   SecureDnsMode secure_dns_mode,
452                   ResolveContext* resolve_context,
453                   bool fast_timeout)
454       : hostname_(std::move(hostname)), qtype_(qtype) {
455     // Do not allow matching any rules if transaction is secure and no DoH
456     // servers are available.
457     if (!secure || force_doh_server_available ||
458         resolve_context->NumAvailableDohServers(
459             resolve_context->current_session_for_testing()) > 0) {
460       // Find the relevant rule which matches |qtype|, |secure|, prefix of
461       // |hostname_|, and |url_request_context| (iff the rule context is not
462       // null).
463       for (const auto& rule : rules) {
464         const std::string& prefix = rule.prefix;
465         if ((rule.qtype == qtype) && (rule.secure == secure) &&
466             (hostname_.size() >= prefix.size()) &&
467             (hostname_.compare(0, prefix.size(), prefix) == 0) &&
468             (!rule.context ||
469              rule.context == resolve_context->url_request_context())) {
470           const MockDnsClientRule::Result* result = &rule.result;
471           result_ = MockDnsClientRule::Result(result->type);
472           result_.net_error = result->net_error;
473           delayed_ = rule.delay;
474 
475           // Generate a DnsResponse when not provided with the rule.
476           std::vector<DnsResourceRecord> authority_records;
477           std::optional<std::vector<uint8_t>> dns_name =
478               dns_names_util::DottedNameToNetwork(hostname_);
479           CHECK(dns_name.has_value());
480           std::optional<DnsQuery> query(std::in_place, /*id=*/22,
481                                         dns_name.value(), qtype_);
482           switch (result->type) {
483             case MockDnsClientRule::ResultType::kNoDomain:
484             case MockDnsClientRule::ResultType::kEmpty:
485               DCHECK(!result->response);  // Not expected to be provided.
486               authority_records = {BuildTestDnsRecord(
487                   hostname_, dns_protocol::kTypeSOA, "fake rdata")};
488               result_.response = DnsResponse(
489                   22 /* id */, false /* is_authoritative */,
490                   std::vector<DnsResourceRecord>() /* answers */,
491                   authority_records,
492                   std::vector<DnsResourceRecord>() /* additional_records */,
493                   query,
494                   result->type == MockDnsClientRule::ResultType::kNoDomain
495                       ? dns_protocol::kRcodeNXDOMAIN
496                       : 0);
497               break;
498             case MockDnsClientRule::ResultType::kFail:
499               if (result->response)
500                 SetResponse(result);
501               break;
502             case MockDnsClientRule::ResultType::kTimeout:
503               DCHECK(!result->response);  // Not expected to be provided.
504               break;
505             case MockDnsClientRule::ResultType::kSlow:
506               if (!fast_timeout)
507                 SetResponse(result);
508               break;
509             case MockDnsClientRule::ResultType::kOk:
510               SetResponse(result);
511               break;
512             case MockDnsClientRule::ResultType::kMalformed:
513               DCHECK(!result->response);  // Not expected to be provided.
514               result_.response = CreateMalformedResponse(hostname_, qtype_);
515               break;
516             case MockDnsClientRule::ResultType::kUnexpected:
517               if (!delayed_) {
518                 // Assume a delayed kUnexpected transaction is only an issue if
519                 // allowed to complete.
520                 ADD_FAILURE()
521                     << "Unexpected DNS transaction created for hostname "
522                     << hostname_;
523               }
524               break;
525           }
526 
527           break;
528         }
529       }
530     }
531   }
532 
GetHostname() const533   const std::string& GetHostname() const override { return hostname_; }
534 
GetType() const535   uint16_t GetType() const override { return qtype_; }
536 
Start(ResponseCallback callback)537   void Start(ResponseCallback callback) override {
538     CHECK(!callback.is_null());
539     CHECK(callback_.is_null());
540     EXPECT_FALSE(started_);
541 
542     callback_ = std::move(callback);
543     started_ = true;
544     if (delayed_)
545       return;
546     // Using WeakPtr to cleanly cancel when transaction is destroyed.
547     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
548         FROM_HERE, base::BindOnce(&MockTransaction::Finish, AsWeakPtr()));
549   }
550 
FinishDelayedTransaction()551   void FinishDelayedTransaction() {
552     EXPECT_TRUE(delayed_);
553     delayed_ = false;
554     Finish();
555   }
556 
delayed() const557   bool delayed() const { return delayed_; }
558 
559  private:
SetResponse(const MockDnsClientRule::Result * result)560   void SetResponse(const MockDnsClientRule::Result* result) {
561     if (result->response) {
562       // Copy response in case |result| is destroyed before the transaction
563       // completes.
564       auto buffer_copy = base::MakeRefCounted<IOBufferWithSize>(
565           result->response->io_buffer_size());
566       memcpy(buffer_copy->data(), result->response->io_buffer()->data(),
567              result->response->io_buffer_size());
568       result_.response = DnsResponse(std::move(buffer_copy),
569                                      result->response->io_buffer_size());
570       CHECK(result_.response->InitParseWithoutQuery(
571           result->response->io_buffer_size()));
572     } else {
573       // Generated response only available for address types.
574       DCHECK(qtype_ == dns_protocol::kTypeA ||
575              qtype_ == dns_protocol::kTypeAAAA);
576       result_.response = BuildTestDnsAddressResponse(
577           hostname_, qtype_ == dns_protocol::kTypeA
578                          ? IPAddress::IPv4Localhost()
579                          : IPAddress::IPv6Localhost());
580     }
581   }
582 
Finish()583   void Finish() {
584     switch (result_.type) {
585       case MockDnsClientRule::ResultType::kNoDomain:
586       case MockDnsClientRule::ResultType::kFail: {
587         int error = result_.net_error.value_or(ERR_NAME_NOT_RESOLVED);
588         DCHECK_NE(error, OK);
589         std::move(callback_).Run(error, base::OptionalToPtr(result_.response));
590         break;
591       }
592       case MockDnsClientRule::ResultType::kEmpty:
593       case MockDnsClientRule::ResultType::kOk:
594       case MockDnsClientRule::ResultType::kMalformed:
595         DCHECK(!result_.net_error.has_value());
596         std::move(callback_).Run(OK, base::OptionalToPtr(result_.response));
597         break;
598       case MockDnsClientRule::ResultType::kTimeout:
599         DCHECK(!result_.net_error.has_value());
600         std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
601         break;
602       case MockDnsClientRule::ResultType::kSlow:
603         if (result_.response) {
604           std::move(callback_).Run(
605               result_.net_error.value_or(OK),
606               result_.response ? &result_.response.value() : nullptr);
607         } else {
608           DCHECK(!result_.net_error.has_value());
609           std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
610         }
611         break;
612       case MockDnsClientRule::ResultType::kUnexpected:
613         ADD_FAILURE() << "Unexpected DNS transaction completed for hostname "
614                       << hostname_;
615         break;
616     }
617   }
618 
SetRequestPriority(RequestPriority priority)619   void SetRequestPriority(RequestPriority priority) override {}
620 
621   MockDnsClientRule::Result result_{MockDnsClientRule::ResultType::kFail};
622   const std::string hostname_;
623   const uint16_t qtype_;
624   ResponseCallback callback_;
625   bool started_ = false;
626   bool delayed_ = false;
627 };
628 
629 class MockDnsTransactionFactory::MockDohProbeRunner : public DnsProbeRunner {
630  public:
MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)631   explicit MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)
632       : factory_(std::move(factory)) {}
633 
~MockDohProbeRunner()634   ~MockDohProbeRunner() override {
635     if (factory_)
636       factory_->running_doh_probe_runners_.erase(this);
637   }
638 
Start(bool network_change)639   void Start(bool network_change) override {
640     DCHECK(factory_);
641     factory_->running_doh_probe_runners_.insert(this);
642   }
643 
GetDelayUntilNextProbeForTest(size_t doh_server_index) const644   base::TimeDelta GetDelayUntilNextProbeForTest(
645       size_t doh_server_index) const override {
646     NOTREACHED();
647     return base::TimeDelta();
648   }
649 
650  private:
651   base::WeakPtr<MockDnsTransactionFactory> factory_;
652 };
653 
MockDnsTransactionFactory(MockDnsClientRuleList rules)654 MockDnsTransactionFactory::MockDnsTransactionFactory(
655     MockDnsClientRuleList rules)
656     : rules_(std::move(rules)) {}
657 
658 MockDnsTransactionFactory::~MockDnsTransactionFactory() = default;
659 
CreateTransaction(std::string hostname,uint16_t qtype,const NetLogWithSource &,bool secure,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)660 std::unique_ptr<DnsTransaction> MockDnsTransactionFactory::CreateTransaction(
661     std::string hostname,
662     uint16_t qtype,
663     const NetLogWithSource&,
664     bool secure,
665     SecureDnsMode secure_dns_mode,
666     ResolveContext* resolve_context,
667     bool fast_timeout) {
668   std::unique_ptr<MockTransaction> transaction =
669       std::make_unique<MockTransaction>(rules_, std::move(hostname), qtype,
670                                         secure, force_doh_server_available_,
671                                         secure_dns_mode, resolve_context,
672                                         fast_timeout);
673   if (transaction->delayed())
674     delayed_transactions_.push_back(transaction->AsWeakPtr());
675   return transaction;
676 }
677 
CreateDohProbeRunner(ResolveContext * resolve_context)678 std::unique_ptr<DnsProbeRunner> MockDnsTransactionFactory::CreateDohProbeRunner(
679     ResolveContext* resolve_context) {
680   return std::make_unique<MockDohProbeRunner>(weak_ptr_factory_.GetWeakPtr());
681 }
682 
AddEDNSOption(std::unique_ptr<OptRecordRdata::Opt> opt)683 void MockDnsTransactionFactory::AddEDNSOption(
684     std::unique_ptr<OptRecordRdata::Opt> opt) {}
685 
GetSecureDnsModeForTest()686 SecureDnsMode MockDnsTransactionFactory::GetSecureDnsModeForTest() {
687   return SecureDnsMode::kAutomatic;
688 }
689 
CompleteDelayedTransactions()690 void MockDnsTransactionFactory::CompleteDelayedTransactions() {
691   DelayedTransactionList old_delayed_transactions;
692   old_delayed_transactions.swap(delayed_transactions_);
693   for (auto& old_delayed_transaction : old_delayed_transactions) {
694     if (old_delayed_transaction.get())
695       old_delayed_transaction->FinishDelayedTransaction();
696   }
697 }
698 
CompleteOneDelayedTransactionOfType(DnsQueryType type)699 bool MockDnsTransactionFactory::CompleteOneDelayedTransactionOfType(
700     DnsQueryType type) {
701   for (base::WeakPtr<MockTransaction>& t : delayed_transactions_) {
702     if (t && t->GetType() == DnsQueryTypeToQtype(type)) {
703       t->FinishDelayedTransaction();
704       t.reset();
705       return true;
706     }
707   }
708   return false;
709 }
710 
MockDnsClient(DnsConfig config,MockDnsClientRuleList rules)711 MockDnsClient::MockDnsClient(DnsConfig config, MockDnsClientRuleList rules)
712     : config_(std::move(config)),
713       factory_(std::make_unique<MockDnsTransactionFactory>(std::move(rules))),
714       address_sorter_(std::make_unique<MockAddressSorter>()) {
715   effective_config_ = BuildEffectiveConfig();
716   session_ = BuildSession();
717 }
718 
719 MockDnsClient::~MockDnsClient() = default;
720 
CanUseSecureDnsTransactions() const721 bool MockDnsClient::CanUseSecureDnsTransactions() const {
722   const DnsConfig* config = GetEffectiveConfig();
723   return config && config->IsValid() && !config->doh_config.servers().empty();
724 }
725 
CanUseInsecureDnsTransactions() const726 bool MockDnsClient::CanUseInsecureDnsTransactions() const {
727   const DnsConfig* config = GetEffectiveConfig();
728   return config && config->IsValid() && insecure_enabled_ &&
729          !config->dns_over_tls_active;
730 }
731 
CanQueryAdditionalTypesViaInsecureDns() const732 bool MockDnsClient::CanQueryAdditionalTypesViaInsecureDns() const {
733   DCHECK(CanUseInsecureDnsTransactions());
734   return additional_types_enabled_;
735 }
736 
SetInsecureEnabled(bool enabled,bool additional_types_enabled)737 void MockDnsClient::SetInsecureEnabled(bool enabled,
738                                        bool additional_types_enabled) {
739   insecure_enabled_ = enabled;
740   additional_types_enabled_ = additional_types_enabled;
741 }
742 
FallbackFromSecureTransactionPreferred(ResolveContext * context) const743 bool MockDnsClient::FallbackFromSecureTransactionPreferred(
744     ResolveContext* context) const {
745   bool doh_server_available =
746       force_doh_server_available_ ||
747       context->NumAvailableDohServers(session_.get()) > 0;
748   return !CanUseSecureDnsTransactions() || !doh_server_available;
749 }
750 
FallbackFromInsecureTransactionPreferred() const751 bool MockDnsClient::FallbackFromInsecureTransactionPreferred() const {
752   return !CanUseInsecureDnsTransactions() ||
753          fallback_failures_ >= max_fallback_failures_;
754 }
755 
SetSystemConfig(std::optional<DnsConfig> system_config)756 bool MockDnsClient::SetSystemConfig(std::optional<DnsConfig> system_config) {
757   if (ignore_system_config_changes_)
758     return false;
759 
760   std::optional<DnsConfig> before = effective_config_;
761   config_ = std::move(system_config);
762   effective_config_ = BuildEffectiveConfig();
763   session_ = BuildSession();
764   return before != effective_config_;
765 }
766 
SetConfigOverrides(DnsConfigOverrides config_overrides)767 bool MockDnsClient::SetConfigOverrides(DnsConfigOverrides config_overrides) {
768   std::optional<DnsConfig> before = effective_config_;
769   overrides_ = std::move(config_overrides);
770   effective_config_ = BuildEffectiveConfig();
771   session_ = BuildSession();
772   return before != effective_config_;
773 }
774 
ReplaceCurrentSession()775 void MockDnsClient::ReplaceCurrentSession() {
776   // Noop if no current effective config.
777   session_ = BuildSession();
778 }
779 
GetCurrentSession()780 DnsSession* MockDnsClient::GetCurrentSession() {
781   return session_.get();
782 }
783 
GetEffectiveConfig() const784 const DnsConfig* MockDnsClient::GetEffectiveConfig() const {
785   return effective_config_.has_value() ? &effective_config_.value() : nullptr;
786 }
787 
GetDnsConfigAsValueForNetLog() const788 base::Value::Dict MockDnsClient::GetDnsConfigAsValueForNetLog() const {
789   // This is just a stub implementation that never produces a meaningful value.
790   return base::Value::Dict();
791 }
792 
GetHosts() const793 const DnsHosts* MockDnsClient::GetHosts() const {
794   const DnsConfig* config = GetEffectiveConfig();
795   if (!config)
796     return nullptr;
797 
798   return &config->hosts;
799 }
800 
GetTransactionFactory()801 DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
802   return GetEffectiveConfig() ? factory_.get() : nullptr;
803 }
804 
GetAddressSorter()805 AddressSorter* MockDnsClient::GetAddressSorter() {
806   return GetEffectiveConfig() ? address_sorter_.get() : nullptr;
807 }
808 
IncrementInsecureFallbackFailures()809 void MockDnsClient::IncrementInsecureFallbackFailures() {
810   ++fallback_failures_;
811 }
812 
ClearInsecureFallbackFailures()813 void MockDnsClient::ClearInsecureFallbackFailures() {
814   fallback_failures_ = 0;
815 }
816 
GetSystemConfigForTesting() const817 std::optional<DnsConfig> MockDnsClient::GetSystemConfigForTesting() const {
818   return config_;
819 }
820 
GetConfigOverridesForTesting() const821 DnsConfigOverrides MockDnsClient::GetConfigOverridesForTesting() const {
822   return overrides_;
823 }
824 
SetTransactionFactoryForTesting(std::unique_ptr<DnsTransactionFactory> factory)825 void MockDnsClient::SetTransactionFactoryForTesting(
826     std::unique_ptr<DnsTransactionFactory> factory) {
827   NOTREACHED();
828 }
829 
SetAddressSorterForTesting(std::unique_ptr<AddressSorter> address_sorter)830 void MockDnsClient::SetAddressSorterForTesting(
831     std::unique_ptr<AddressSorter> address_sorter) {
832   address_sorter_ = std::move(address_sorter);
833 }
834 
GetPresetAddrs(const url::SchemeHostPort & endpoint) const835 std::optional<std::vector<IPEndPoint>> MockDnsClient::GetPresetAddrs(
836     const url::SchemeHostPort& endpoint) const {
837   EXPECT_THAT(preset_endpoint_, testing::Optional(endpoint));
838   return preset_addrs_;
839 }
840 
CompleteDelayedTransactions()841 void MockDnsClient::CompleteDelayedTransactions() {
842   factory_->CompleteDelayedTransactions();
843 }
844 
CompleteOneDelayedTransactionOfType(DnsQueryType type)845 bool MockDnsClient::CompleteOneDelayedTransactionOfType(DnsQueryType type) {
846   return factory_->CompleteOneDelayedTransactionOfType(type);
847 }
848 
SetForceDohServerAvailable(bool available)849 void MockDnsClient::SetForceDohServerAvailable(bool available) {
850   force_doh_server_available_ = available;
851   factory_->set_force_doh_server_available(available);
852 }
853 
BuildEffectiveConfig()854 std::optional<DnsConfig> MockDnsClient::BuildEffectiveConfig() {
855   if (overrides_.OverridesEverything())
856     return overrides_.ApplyOverrides(DnsConfig());
857   if (!config_ || !config_.value().IsValid())
858     return std::nullopt;
859 
860   return overrides_.ApplyOverrides(config_.value());
861 }
862 
BuildSession()863 scoped_refptr<DnsSession> MockDnsClient::BuildSession() {
864   if (!effective_config_)
865     return nullptr;
866 
867   // Session not expected to be used for anything that will actually require
868   // random numbers.
869   auto null_random_callback =
870       base::BindRepeating([](int, int) -> int { base::ImmediateCrash(); });
871 
872   return base::MakeRefCounted<DnsSession>(
873       effective_config_.value(), null_random_callback, nullptr /* net_log */);
874 }
875 
MockHostResolverProc()876 MockHostResolverProc::MockHostResolverProc()
877     : HostResolverProc(nullptr),
878       requests_waiting_(&lock_),
879       slots_available_(&lock_) {}
880 
881 MockHostResolverProc::~MockHostResolverProc() = default;
882 
WaitFor(unsigned count)883 bool MockHostResolverProc::WaitFor(unsigned count) {
884   base::AutoLock lock(lock_);
885   base::Time start_time = base::Time::Now();
886   while (num_requests_waiting_ < count) {
887     requests_waiting_.TimedWait(TestTimeouts::action_timeout());
888     if (base::Time::Now() > start_time + TestTimeouts::action_timeout()) {
889       return false;
890     }
891   }
892   return true;
893 }
894 
SignalMultiple(unsigned count)895 void MockHostResolverProc::SignalMultiple(unsigned count) {
896   base::AutoLock lock(lock_);
897   num_slots_available_ += count;
898   slots_available_.Broadcast();
899 }
900 
SignalAll()901 void MockHostResolverProc::SignalAll() {
902   base::AutoLock lock(lock_);
903   num_slots_available_ = num_requests_waiting_;
904   slots_available_.Broadcast();
905 }
906 
AddRule(const std::string & hostname,AddressFamily family,const AddressList & result,HostResolverFlags flags)907 void MockHostResolverProc::AddRule(const std::string& hostname,
908                                    AddressFamily family,
909                                    const AddressList& result,
910                                    HostResolverFlags flags) {
911   base::AutoLock lock(lock_);
912   rules_[ResolveKey(hostname, family, flags)] = result;
913 }
914 
AddRule(const std::string & hostname,AddressFamily family,const std::string & ip_list,HostResolverFlags flags,const std::string & canonical_name)915 void MockHostResolverProc::AddRule(const std::string& hostname,
916                                    AddressFamily family,
917                                    const std::string& ip_list,
918                                    HostResolverFlags flags,
919                                    const std::string& canonical_name) {
920   AddressList result;
921   std::vector<std::string> dns_aliases;
922   if (canonical_name != "") {
923     dns_aliases = {canonical_name};
924   }
925   int rv = ParseAddressList(ip_list, &result.endpoints());
926   result.SetDnsAliases(dns_aliases);
927   DCHECK_EQ(OK, rv);
928   AddRule(hostname, family, result, flags);
929 }
930 
AddRuleForAllFamilies(const std::string & hostname,const std::string & ip_list,HostResolverFlags flags,const std::string & canonical_name)931 void MockHostResolverProc::AddRuleForAllFamilies(
932     const std::string& hostname,
933     const std::string& ip_list,
934     HostResolverFlags flags,
935     const std::string& canonical_name) {
936   AddressList result;
937   std::vector<std::string> dns_aliases;
938   if (canonical_name != "") {
939     dns_aliases = {canonical_name};
940   }
941   int rv = ParseAddressList(ip_list, &result.endpoints());
942   result.SetDnsAliases(dns_aliases);
943   DCHECK_EQ(OK, rv);
944   AddRule(hostname, ADDRESS_FAMILY_UNSPECIFIED, result, flags);
945   AddRule(hostname, ADDRESS_FAMILY_IPV4, result, flags);
946   AddRule(hostname, ADDRESS_FAMILY_IPV6, result, flags);
947 }
948 
Resolve(const std::string & hostname,AddressFamily address_family,HostResolverFlags host_resolver_flags,AddressList * addrlist,int * os_error)949 int MockHostResolverProc::Resolve(const std::string& hostname,
950                                   AddressFamily address_family,
951                                   HostResolverFlags host_resolver_flags,
952                                   AddressList* addrlist,
953                                   int* os_error) {
954   base::AutoLock lock(lock_);
955   capture_list_.emplace_back(hostname, address_family, host_resolver_flags);
956   ++num_requests_waiting_;
957   requests_waiting_.Broadcast();
958   {
959     base::ScopedAllowBaseSyncPrimitivesForTesting
960         scoped_allow_base_sync_primitives;
961     while (!num_slots_available_) {
962       slots_available_.Wait();
963     }
964   }
965   DCHECK_GT(num_requests_waiting_, 0u);
966   --num_slots_available_;
967   --num_requests_waiting_;
968   if (rules_.empty()) {
969     int rv = ParseAddressList("127.0.0.1", &addrlist->endpoints());
970     DCHECK_EQ(OK, rv);
971     return OK;
972   }
973   ResolveKey key(hostname, address_family, host_resolver_flags);
974   if (rules_.count(key) == 0) {
975     return ERR_NAME_NOT_RESOLVED;
976   }
977   *addrlist = rules_[key];
978   return OK;
979 }
980 
GetCaptureList() const981 MockHostResolverProc::CaptureList MockHostResolverProc::GetCaptureList() const {
982   CaptureList copy;
983   {
984     base::AutoLock lock(lock_);
985     copy = capture_list_;
986   }
987   return copy;
988 }
989 
ClearCaptureList()990 void MockHostResolverProc::ClearCaptureList() {
991   base::AutoLock lock(lock_);
992   capture_list_.clear();
993 }
994 
HasBlockedRequests() const995 bool MockHostResolverProc::HasBlockedRequests() const {
996   base::AutoLock lock(lock_);
997   return num_requests_waiting_ > num_slots_available_;
998 }
999 
1000 }  // namespace net
1001