xref: /aosp_15_r20/external/cronet/net/dns/mock_host_resolver.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/mock_host_resolver.h"
6 
7 #include <stdint.h>
8 
9 #include <memory>
10 #include <optional>
11 #include <string>
12 #include <string_view>
13 #include <utility>
14 #include <vector>
15 
16 #include "base/check_op.h"
17 #include "base/functional/bind.h"
18 #include "base/functional/callback_helpers.h"
19 #include "base/location.h"
20 #include "base/logging.h"
21 #include "base/memory/ptr_util.h"
22 #include "base/memory/raw_ptr.h"
23 #include "base/memory/ref_counted.h"
24 #include "base/no_destructor.h"
25 #include "base/notreached.h"
26 #include "base/strings/pattern.h"
27 #include "base/strings/string_split.h"
28 #include "base/strings/string_util.h"
29 #include "base/task/single_thread_task_runner.h"
30 #include "base/threading/platform_thread.h"
31 #include "base/time/default_tick_clock.h"
32 #include "base/time/tick_clock.h"
33 #include "base/time/time.h"
34 #include "base/types/optional_util.h"
35 #include "build/build_config.h"
36 #include "net/base/address_family.h"
37 #include "net/base/address_list.h"
38 #include "net/base/host_port_pair.h"
39 #include "net/base/ip_address.h"
40 #include "net/base/ip_endpoint.h"
41 #include "net/base/net_errors.h"
42 #include "net/base/net_export.h"
43 #include "net/base/network_anonymization_key.h"
44 #include "net/base/test_completion_callback.h"
45 #include "net/dns/dns_alias_utility.h"
46 #include "net/dns/dns_names_util.h"
47 #include "net/dns/dns_util.h"
48 #include "net/dns/host_cache.h"
49 #include "net/dns/host_resolver.h"
50 #include "net/dns/host_resolver_manager.h"
51 #include "net/dns/host_resolver_system_task.h"
52 #include "net/dns/https_record_rdata.h"
53 #include "net/dns/public/dns_query_type.h"
54 #include "net/dns/public/host_resolver_results.h"
55 #include "net/dns/public/host_resolver_source.h"
56 #include "net/dns/public/mdns_listener_update_type.h"
57 #include "net/dns/public/resolve_error_info.h"
58 #include "net/dns/public/secure_dns_policy.h"
59 #include "net/log/net_log_with_source.h"
60 #include "net/url_request/url_request_context.h"
61 #include "third_party/abseil-cpp/absl/types/variant.h"
62 #include "url/scheme_host_port.h"
63 
64 #if BUILDFLAG(IS_WIN)
65 #include "net/base/winsock_init.h"
66 #endif
67 
68 namespace net {
69 
70 namespace {
71 
72 // Cache size for the MockCachingHostResolver.
73 const unsigned kMaxCacheEntries = 100;
74 // TTL for the successful resolutions. Failures are not cached.
75 const unsigned kCacheEntryTTLSeconds = 60;
76 
GetCacheHost(const HostResolver::Host & endpoint)77 absl::variant<url::SchemeHostPort, std::string> GetCacheHost(
78     const HostResolver::Host& endpoint) {
79   if (endpoint.HasScheme()) {
80     return endpoint.AsSchemeHostPort();
81   }
82 
83   return endpoint.GetHostname();
84 }
85 
CreateCacheEntry(std::string_view canonical_name,const std::vector<HostResolverEndpointResult> & endpoint_results,const std::set<std::string> & aliases)86 std::optional<HostCache::Entry> CreateCacheEntry(
87     std::string_view canonical_name,
88     const std::vector<HostResolverEndpointResult>& endpoint_results,
89     const std::set<std::string>& aliases) {
90   std::optional<std::vector<net::IPEndPoint>> ip_endpoints;
91   std::multimap<HttpsRecordPriority, ConnectionEndpointMetadata>
92       endpoint_metadatas;
93   for (const auto& endpoint_result : endpoint_results) {
94     if (!ip_endpoints) {
95       ip_endpoints = endpoint_result.ip_endpoints;
96     } else {
97       // TODO(crbug.com/1264933): Support caching different IP endpoints
98       // resutls.
99       CHECK(*ip_endpoints == endpoint_result.ip_endpoints)
100           << "Currently caching MockHostResolver only supports same IP "
101              "endpoints results.";
102     }
103 
104     if (!endpoint_result.metadata.supported_protocol_alpns.empty()) {
105       endpoint_metadatas.emplace(/*priority=*/1, endpoint_result.metadata);
106     }
107   }
108   DCHECK(ip_endpoints);
109   auto endpoint_entry = HostCache::Entry(OK, *ip_endpoints, aliases,
110                                          HostCache::Entry::SOURCE_UNKNOWN);
111   endpoint_entry.set_canonical_names(std::set{std::string(canonical_name)});
112   if (endpoint_metadatas.empty()) {
113     return endpoint_entry;
114   }
115   return HostCache::Entry::MergeEntries(
116       HostCache::Entry(OK, std::move(endpoint_metadatas),
117                        HostCache::Entry::SOURCE_UNKNOWN),
118       endpoint_entry);
119 }
120 }  // namespace
121 
ParseAddressList(std::string_view host_list,std::vector<net::IPEndPoint> * ip_endpoints)122 int ParseAddressList(std::string_view host_list,
123                      std::vector<net::IPEndPoint>* ip_endpoints) {
124   ip_endpoints->clear();
125   for (std::string_view address : base::SplitStringPiece(
126            host_list, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL)) {
127     IPAddress ip_address;
128     if (!ip_address.AssignFromIPLiteral(address)) {
129       LOG(WARNING) << "Not a supported IP literal: " << address;
130       return ERR_UNEXPECTED;
131     }
132     ip_endpoints->push_back(IPEndPoint(ip_address, 0));
133   }
134   return OK;
135 }
136 
137 class MockHostResolverBase::RequestImpl
138     : public HostResolver::ResolveHostRequest {
139  public:
RequestImpl(Host request_endpoint,const NetworkAnonymizationKey & network_anonymization_key,const std::optional<ResolveHostParameters> & optional_parameters,base::WeakPtr<MockHostResolverBase> resolver)140   RequestImpl(Host request_endpoint,
141               const NetworkAnonymizationKey& network_anonymization_key,
142               const std::optional<ResolveHostParameters>& optional_parameters,
143               base::WeakPtr<MockHostResolverBase> resolver)
144       : request_endpoint_(std::move(request_endpoint)),
145         network_anonymization_key_(network_anonymization_key),
146         parameters_(optional_parameters ? optional_parameters.value()
147                                         : ResolveHostParameters()),
148         priority_(parameters_.initial_priority),
149         host_resolver_flags_(ParametersToHostResolverFlags(parameters_)),
150         resolve_error_info_(ResolveErrorInfo(ERR_IO_PENDING)),
151         resolver_(resolver) {}
152 
153   RequestImpl(const RequestImpl&) = delete;
154   RequestImpl& operator=(const RequestImpl&) = delete;
155 
~RequestImpl()156   ~RequestImpl() override {
157     if (id_ > 0) {
158       if (resolver_)
159         resolver_->DetachRequest(id_);
160       id_ = 0;
161       resolver_ = nullptr;
162     }
163   }
164 
DetachFromResolver()165   void DetachFromResolver() {
166     id_ = 0;
167     resolver_ = nullptr;
168   }
169 
Start(CompletionOnceCallback callback)170   int Start(CompletionOnceCallback callback) override {
171     DCHECK(callback);
172     // Start() may only be called once per request.
173     DCHECK_EQ(0u, id_);
174     DCHECK(!complete_);
175     DCHECK(!callback_);
176     // Parent HostResolver must still be alive to call Start().
177     DCHECK(resolver_);
178 
179     int rv = resolver_->Resolve(this);
180     DCHECK(!complete_);
181     if (rv == ERR_IO_PENDING) {
182       DCHECK_GT(id_, 0u);
183       callback_ = std::move(callback);
184     } else {
185       DCHECK_EQ(0u, id_);
186       complete_ = true;
187     }
188 
189     return rv;
190   }
191 
GetAddressResults() const192   const AddressList* GetAddressResults() const override {
193     DCHECK(complete_);
194     return base::OptionalToPtr(address_results_);
195   }
196 
GetEndpointResults() const197   const std::vector<HostResolverEndpointResult>* GetEndpointResults()
198       const override {
199     DCHECK(complete_);
200     return base::OptionalToPtr(endpoint_results_);
201   }
202 
GetTextResults() const203   const std::vector<std::string>* GetTextResults() const override {
204     DCHECK(complete_);
205     static const base::NoDestructor<std::vector<std::string>> empty_result;
206     return empty_result.get();
207   }
208 
GetHostnameResults() const209   const std::vector<HostPortPair>* GetHostnameResults() const override {
210     DCHECK(complete_);
211     static const base::NoDestructor<std::vector<HostPortPair>> empty_result;
212     return empty_result.get();
213   }
214 
GetDnsAliasResults() const215   const std::set<std::string>* GetDnsAliasResults() const override {
216     DCHECK(complete_);
217     return base::OptionalToPtr(fixed_up_dns_alias_results_);
218   }
219 
GetResolveErrorInfo() const220   net::ResolveErrorInfo GetResolveErrorInfo() const override {
221     DCHECK(complete_);
222     return resolve_error_info_;
223   }
224 
GetStaleInfo() const225   const std::optional<HostCache::EntryStaleness>& GetStaleInfo()
226       const override {
227     DCHECK(complete_);
228     return staleness_;
229   }
230 
ChangeRequestPriority(RequestPriority priority)231   void ChangeRequestPriority(RequestPriority priority) override {
232     priority_ = priority;
233   }
234 
SetError(int error)235   void SetError(int error) {
236     // Should only be called before request is marked completed.
237     DCHECK(!complete_);
238     resolve_error_info_ = ResolveErrorInfo(error);
239   }
240 
241   // Sets `endpoint_results_`, `fixed_up_dns_alias_results_`,
242   // `address_results_` and `staleness_` after fixing them up.
243   // Also sets `error` to OK.
SetEndpointResults(std::vector<HostResolverEndpointResult> endpoint_results,std::set<std::string> aliases,std::optional<HostCache::EntryStaleness> staleness)244   void SetEndpointResults(
245       std::vector<HostResolverEndpointResult> endpoint_results,
246       std::set<std::string> aliases,
247       std::optional<HostCache::EntryStaleness> staleness) {
248     DCHECK(!complete_);
249     DCHECK(!endpoint_results_);
250     DCHECK(!parameters_.is_speculative);
251 
252     endpoint_results_ = std::move(endpoint_results);
253     for (auto& result : *endpoint_results_) {
254       result.ip_endpoints = FixupEndPoints(result.ip_endpoints);
255     }
256 
257     fixed_up_dns_alias_results_ = FixupAliases(aliases);
258 
259     // `HostResolver` implementations are expected to provide an `AddressList`
260     // result whenever `HostResolverEndpointResult` is also available.
261     address_results_ = EndpointResultToAddressList(
262         *endpoint_results_, *fixed_up_dns_alias_results_);
263 
264     staleness_ = std::move(staleness);
265 
266     SetError(OK);
267   }
268 
OnAsyncCompleted(size_t id,int error)269   void OnAsyncCompleted(size_t id, int error) {
270     DCHECK_EQ(id_, id);
271     id_ = 0;
272 
273     // Check that error information has been set and that the top-level error
274     // code is valid.
275     DCHECK(resolve_error_info_.error != ERR_IO_PENDING);
276     DCHECK(error == OK || error == ERR_NAME_NOT_RESOLVED ||
277            error == ERR_DNS_NAME_HTTPS_ONLY);
278 
279     DCHECK(!complete_);
280     complete_ = true;
281 
282     DCHECK(callback_);
283     std::move(callback_).Run(error);
284   }
285 
request_endpoint() const286   const Host& request_endpoint() const { return request_endpoint_; }
287 
network_anonymization_key() const288   const NetworkAnonymizationKey& network_anonymization_key() const {
289     return network_anonymization_key_;
290   }
291 
parameters() const292   const ResolveHostParameters& parameters() const { return parameters_; }
293 
host_resolver_flags() const294   int host_resolver_flags() const { return host_resolver_flags_; }
295 
id()296   size_t id() { return id_; }
297 
priority() const298   RequestPriority priority() const { return priority_; }
299 
set_id(size_t id)300   void set_id(size_t id) {
301     DCHECK_GT(id, 0u);
302     DCHECK_EQ(0u, id_);
303 
304     id_ = id;
305   }
306 
complete()307   bool complete() { return complete_; }
308 
309   // Similar get GetAddressResults() and GetResolveErrorInfo(), but only exposed
310   // through the HostResolver::ResolveHostRequest interface, and don't have the
311   // DCHECKs that `complete_` is true.
address_results() const312   const std::optional<AddressList>& address_results() const {
313     return address_results_;
314   }
resolve_error_info() const315   ResolveErrorInfo resolve_error_info() const { return resolve_error_info_; }
316 
317  private:
FixupEndPoints(const std::vector<IPEndPoint> & endpoints)318   std::vector<IPEndPoint> FixupEndPoints(
319       const std::vector<IPEndPoint>& endpoints) {
320     std::vector<IPEndPoint> corrected;
321     for (const IPEndPoint& endpoint : endpoints) {
322       DCHECK_NE(endpoint.GetFamily(), ADDRESS_FAMILY_UNSPECIFIED);
323       if (parameters_.dns_query_type == DnsQueryType::UNSPECIFIED ||
324           parameters_.dns_query_type ==
325               AddressFamilyToDnsQueryType(endpoint.GetFamily())) {
326         if (endpoint.port() == 0) {
327           corrected.emplace_back(endpoint.address(),
328                                  request_endpoint_.GetPort());
329         } else {
330           corrected.push_back(endpoint);
331         }
332       }
333     }
334     return corrected;
335   }
FixupAliases(const std::set<std::string> aliases)336   std::set<std::string> FixupAliases(const std::set<std::string> aliases) {
337     if (aliases.empty())
338       return std::set<std::string>{
339           std::string(request_endpoint_.GetHostnameWithoutBrackets())};
340     return aliases;
341   }
342 
343   const Host request_endpoint_;
344   const NetworkAnonymizationKey network_anonymization_key_;
345   const ResolveHostParameters parameters_;
346   RequestPriority priority_;
347   int host_resolver_flags_;
348 
349   std::optional<AddressList> address_results_;
350   std::optional<std::vector<HostResolverEndpointResult>> endpoint_results_;
351   std::optional<std::set<std::string>> fixed_up_dns_alias_results_;
352   std::optional<HostCache::EntryStaleness> staleness_;
353   ResolveErrorInfo resolve_error_info_;
354 
355   // Used while stored with the resolver for async resolution.  Otherwise 0.
356   size_t id_ = 0;
357 
358   CompletionOnceCallback callback_;
359   // Use a WeakPtr as the resolver may be destroyed while there are still
360   // outstanding request objects.
361   base::WeakPtr<MockHostResolverBase> resolver_;
362   bool complete_ = false;
363 };
364 
365 class MockHostResolverBase::ProbeRequestImpl
366     : public HostResolver::ProbeRequest {
367  public:
ProbeRequestImpl(base::WeakPtr<MockHostResolverBase> resolver)368   explicit ProbeRequestImpl(base::WeakPtr<MockHostResolverBase> resolver)
369       : resolver_(std::move(resolver)) {}
370 
371   ProbeRequestImpl(const ProbeRequestImpl&) = delete;
372   ProbeRequestImpl& operator=(const ProbeRequestImpl&) = delete;
373 
~ProbeRequestImpl()374   ~ProbeRequestImpl() override {
375     if (resolver_) {
376       resolver_->state_->ClearDohProbeRequestIfMatching(this);
377     }
378   }
379 
Start()380   int Start() override {
381     DCHECK(resolver_);
382     resolver_->state_->set_doh_probe_request(this);
383 
384     return ERR_IO_PENDING;
385   }
386 
387  private:
388   base::WeakPtr<MockHostResolverBase> resolver_;
389 };
390 
391 class MockHostResolverBase::MdnsListenerImpl
392     : public HostResolver::MdnsListener {
393  public:
MdnsListenerImpl(const HostPortPair & host,DnsQueryType query_type,base::WeakPtr<MockHostResolverBase> resolver)394   MdnsListenerImpl(const HostPortPair& host,
395                    DnsQueryType query_type,
396                    base::WeakPtr<MockHostResolverBase> resolver)
397       : host_(host), query_type_(query_type), resolver_(resolver) {
398     DCHECK_NE(DnsQueryType::UNSPECIFIED, query_type_);
399     DCHECK(resolver_);
400   }
401 
~MdnsListenerImpl()402   ~MdnsListenerImpl() override {
403     if (resolver_)
404       resolver_->RemoveCancelledListener(this);
405   }
406 
Start(Delegate * delegate)407   int Start(Delegate* delegate) override {
408     DCHECK(delegate);
409     DCHECK(!delegate_);
410     DCHECK(resolver_);
411 
412     delegate_ = delegate;
413     resolver_->AddListener(this);
414 
415     return OK;
416   }
417 
TriggerAddressResult(MdnsListenerUpdateType update_type,IPEndPoint address)418   void TriggerAddressResult(MdnsListenerUpdateType update_type,
419                             IPEndPoint address) {
420     delegate_->OnAddressResult(update_type, query_type_, std::move(address));
421   }
422 
TriggerTextResult(MdnsListenerUpdateType update_type,std::vector<std::string> text_records)423   void TriggerTextResult(MdnsListenerUpdateType update_type,
424                          std::vector<std::string> text_records) {
425     delegate_->OnTextResult(update_type, query_type_, std::move(text_records));
426   }
427 
TriggerHostnameResult(MdnsListenerUpdateType update_type,HostPortPair host)428   void TriggerHostnameResult(MdnsListenerUpdateType update_type,
429                              HostPortPair host) {
430     delegate_->OnHostnameResult(update_type, query_type_, std::move(host));
431   }
432 
TriggerUnhandledResult(MdnsListenerUpdateType update_type)433   void TriggerUnhandledResult(MdnsListenerUpdateType update_type) {
434     delegate_->OnUnhandledResult(update_type, query_type_);
435   }
436 
host() const437   const HostPortPair& host() const { return host_; }
query_type() const438   DnsQueryType query_type() const { return query_type_; }
439 
440  private:
441   const HostPortPair host_;
442   const DnsQueryType query_type_;
443 
444   raw_ptr<Delegate> delegate_ = nullptr;
445 
446   // Use a WeakPtr as the resolver may be destroyed while there are still
447   // outstanding listener objects.
448   base::WeakPtr<MockHostResolverBase> resolver_;
449 };
450 
451 MockHostResolverBase::RuleResolver::RuleKey::RuleKey() = default;
452 
453 MockHostResolverBase::RuleResolver::RuleKey::~RuleKey() = default;
454 
455 MockHostResolverBase::RuleResolver::RuleKey::RuleKey(const RuleKey&) = default;
456 
457 MockHostResolverBase::RuleResolver::RuleKey&
458 MockHostResolverBase::RuleResolver::RuleKey::operator=(const RuleKey&) =
459     default;
460 
461 MockHostResolverBase::RuleResolver::RuleKey::RuleKey(RuleKey&&) = default;
462 
463 MockHostResolverBase::RuleResolver::RuleKey&
464 MockHostResolverBase::RuleResolver::RuleKey::operator=(RuleKey&&) = default;
465 
466 MockHostResolverBase::RuleResolver::RuleResult::RuleResult() = default;
467 
RuleResult(std::vector<HostResolverEndpointResult> endpoints,std::set<std::string> aliases)468 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(
469     std::vector<HostResolverEndpointResult> endpoints,
470     std::set<std::string> aliases)
471     : endpoints(std::move(endpoints)), aliases(std::move(aliases)) {}
472 
473 MockHostResolverBase::RuleResolver::RuleResult::~RuleResult() = default;
474 
475 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(const RuleResult&) =
476     default;
477 
478 MockHostResolverBase::RuleResolver::RuleResult&
479 MockHostResolverBase::RuleResolver::RuleResult::operator=(const RuleResult&) =
480     default;
481 
482 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(RuleResult&&) =
483     default;
484 
485 MockHostResolverBase::RuleResolver::RuleResult&
486 MockHostResolverBase::RuleResolver::RuleResult::operator=(RuleResult&&) =
487     default;
488 
RuleResolver(std::optional<RuleResultOrError> default_result)489 MockHostResolverBase::RuleResolver::RuleResolver(
490     std::optional<RuleResultOrError> default_result)
491     : default_result_(std::move(default_result)) {}
492 
493 MockHostResolverBase::RuleResolver::~RuleResolver() = default;
494 
495 MockHostResolverBase::RuleResolver::RuleResolver(const RuleResolver&) = default;
496 
497 MockHostResolverBase::RuleResolver&
498 MockHostResolverBase::RuleResolver::operator=(const RuleResolver&) = default;
499 
500 MockHostResolverBase::RuleResolver::RuleResolver(RuleResolver&&) = default;
501 
502 MockHostResolverBase::RuleResolver&
503 MockHostResolverBase::RuleResolver::operator=(RuleResolver&&) = default;
504 
505 const MockHostResolverBase::RuleResolver::RuleResultOrError&
Resolve(const Host & request_endpoint,DnsQueryTypeSet request_types,HostResolverSource request_source) const506 MockHostResolverBase::RuleResolver::Resolve(
507     const Host& request_endpoint,
508     DnsQueryTypeSet request_types,
509     HostResolverSource request_source) const {
510   for (const auto& rule : rules_) {
511     const RuleKey& key = rule.first;
512     const RuleResultOrError& result = rule.second;
513 
514     if (absl::holds_alternative<RuleKey::NoScheme>(key.scheme) &&
515         request_endpoint.HasScheme()) {
516       continue;
517     }
518 
519     if (key.port.has_value() &&
520         key.port.value() != request_endpoint.GetPort()) {
521       continue;
522     }
523 
524     DCHECK(!key.query_type.has_value() ||
525            key.query_type.value() != DnsQueryType::UNSPECIFIED);
526     if (key.query_type.has_value() &&
527         !request_types.Has(key.query_type.value())) {
528       continue;
529     }
530 
531     if (key.query_source.has_value() &&
532         request_source != key.query_source.value()) {
533       continue;
534     }
535 
536     if (absl::holds_alternative<RuleKey::Scheme>(key.scheme) &&
537         (!request_endpoint.HasScheme() ||
538          request_endpoint.GetScheme() !=
539              absl::get<RuleKey::Scheme>(key.scheme))) {
540       continue;
541     }
542 
543     if (!base::MatchPattern(request_endpoint.GetHostnameWithoutBrackets(),
544                             key.hostname_pattern)) {
545       continue;
546     }
547 
548     return result;
549   }
550 
551   if (default_result_)
552     return default_result_.value();
553 
554   NOTREACHED() << "Request " << request_endpoint.GetHostname()
555                << " did not match any MockHostResolver rules.";
556   static const RuleResultOrError kUnexpected = ERR_UNEXPECTED;
557   return kUnexpected;
558 }
559 
ClearRules()560 void MockHostResolverBase::RuleResolver::ClearRules() {
561   rules_.clear();
562 }
563 
564 // static
565 MockHostResolverBase::RuleResolver::RuleResultOrError
GetLocalhostResult()566 MockHostResolverBase::RuleResolver::GetLocalhostResult() {
567   HostResolverEndpointResult endpoint;
568   endpoint.ip_endpoints = {IPEndPoint(IPAddress::IPv4Localhost(), /*port=*/0)};
569   return RuleResult(std::vector{endpoint});
570 }
571 
AddRule(RuleKey key,RuleResultOrError result)572 void MockHostResolverBase::RuleResolver::AddRule(RuleKey key,
573                                                  RuleResultOrError result) {
574   // Literals are always resolved to themselves by MockHostResolverBase,
575   // consequently we do not support remapping them.
576   IPAddress ip_address;
577   DCHECK(!ip_address.AssignFromIPLiteral(key.hostname_pattern));
578 
579   CHECK(rules_.emplace(std::move(key), std::move(result)).second)
580       << "Duplicate rule key";
581 }
582 
AddRule(RuleKey key,std::string_view ip_literal)583 void MockHostResolverBase::RuleResolver::AddRule(RuleKey key,
584                                                  std::string_view ip_literal) {
585   std::vector<HostResolverEndpointResult> endpoints;
586   endpoints.emplace_back();
587   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
588   AddRule(std::move(key), RuleResult(std::move(endpoints)));
589 }
590 
AddRule(std::string_view hostname_pattern,RuleResultOrError result)591 void MockHostResolverBase::RuleResolver::AddRule(
592     std::string_view hostname_pattern,
593     RuleResultOrError result) {
594   RuleKey key;
595   key.hostname_pattern = std::string(hostname_pattern);
596   AddRule(std::move(key), std::move(result));
597 }
598 
AddRule(std::string_view hostname_pattern,std::string_view ip_literal)599 void MockHostResolverBase::RuleResolver::AddRule(
600     std::string_view hostname_pattern,
601     std::string_view ip_literal) {
602   std::vector<HostResolverEndpointResult> endpoints;
603   endpoints.emplace_back();
604   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
605   AddRule(hostname_pattern, RuleResult(std::move(endpoints)));
606 }
607 
AddRule(std::string_view hostname_pattern,Error error)608 void MockHostResolverBase::RuleResolver::AddRule(
609     std::string_view hostname_pattern,
610     Error error) {
611   RuleKey key;
612   key.hostname_pattern = std::string(hostname_pattern);
613 
614   AddRule(std::move(key), error);
615 }
616 
AddIPLiteralRule(std::string_view hostname_pattern,std::string_view ip_literal,std::string_view canonical_name)617 void MockHostResolverBase::RuleResolver::AddIPLiteralRule(
618     std::string_view hostname_pattern,
619     std::string_view ip_literal,
620     std::string_view canonical_name) {
621   RuleKey key;
622   key.hostname_pattern = std::string(hostname_pattern);
623 
624   std::set<std::string> aliases;
625   if (!canonical_name.empty())
626     aliases.emplace(canonical_name);
627 
628   std::vector<HostResolverEndpointResult> endpoints;
629   endpoints.emplace_back();
630   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
631   AddRule(std::move(key), RuleResult(std::move(endpoints), std::move(aliases)));
632 }
633 
AddIPLiteralRuleWithDnsAliases(std::string_view hostname_pattern,std::string_view ip_literal,std::vector<std::string> dns_aliases)634 void MockHostResolverBase::RuleResolver::AddIPLiteralRuleWithDnsAliases(
635     std::string_view hostname_pattern,
636     std::string_view ip_literal,
637     std::vector<std::string> dns_aliases) {
638   std::vector<HostResolverEndpointResult> endpoints;
639   endpoints.emplace_back();
640   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
641   AddRule(hostname_pattern,
642           RuleResult(
643               std::move(endpoints),
644               std::set<std::string>(dns_aliases.begin(), dns_aliases.end())));
645 }
646 
AddIPLiteralRuleWithDnsAliases(std::string_view hostname_pattern,std::string_view ip_literal,std::set<std::string> dns_aliases)647 void MockHostResolverBase::RuleResolver::AddIPLiteralRuleWithDnsAliases(
648     std::string_view hostname_pattern,
649     std::string_view ip_literal,
650     std::set<std::string> dns_aliases) {
651   std::vector<std::string> aliases_vector;
652   base::ranges::move(dns_aliases, std::back_inserter(aliases_vector));
653 
654   AddIPLiteralRuleWithDnsAliases(hostname_pattern, ip_literal,
655                                  std::move(aliases_vector));
656 }
657 
AddSimulatedFailure(std::string_view hostname_pattern)658 void MockHostResolverBase::RuleResolver::AddSimulatedFailure(
659     std::string_view hostname_pattern) {
660   AddRule(hostname_pattern, ERR_NAME_NOT_RESOLVED);
661 }
662 
AddSimulatedTimeoutFailure(std::string_view hostname_pattern)663 void MockHostResolverBase::RuleResolver::AddSimulatedTimeoutFailure(
664     std::string_view hostname_pattern) {
665   AddRule(hostname_pattern, ERR_DNS_TIMED_OUT);
666 }
667 
AddRuleWithFlags(std::string_view host_pattern,std::string_view ip_literal,HostResolverFlags,std::vector<std::string> dns_aliases)668 void MockHostResolverBase::RuleResolver::AddRuleWithFlags(
669     std::string_view host_pattern,
670     std::string_view ip_literal,
671     HostResolverFlags /*flags*/,
672     std::vector<std::string> dns_aliases) {
673   std::vector<HostResolverEndpointResult> endpoints;
674   endpoints.emplace_back();
675   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
676   AddRule(host_pattern, RuleResult(std::move(endpoints),
677                                    std::set<std::string>(dns_aliases.begin(),
678                                                          dns_aliases.end())));
679 }
680 
681 MockHostResolverBase::State::State() = default;
682 MockHostResolverBase::State::~State() = default;
683 
~MockHostResolverBase()684 MockHostResolverBase::~MockHostResolverBase() {
685   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
686 
687   // Sanity check that pending requests are always cleaned up, by waiting for
688   // completion, manually cancelling, or calling OnShutdown().
689   DCHECK(!state_->has_pending_requests());
690 }
691 
OnShutdown()692 void MockHostResolverBase::OnShutdown() {
693   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
694 
695   // Cancel all pending requests.
696   for (auto& request : state_->mutable_requests()) {
697     request.second->DetachFromResolver();
698   }
699   state_->mutable_requests().clear();
700 
701   // Prevent future requests by clearing resolution rules and the cache.
702   rule_resolver_.ClearRules();
703   cache_ = nullptr;
704 
705   state_->ClearDohProbeRequest();
706 }
707 
708 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(url::SchemeHostPort host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,std::optional<ResolveHostParameters> optional_parameters)709 MockHostResolverBase::CreateRequest(
710     url::SchemeHostPort host,
711     NetworkAnonymizationKey network_anonymization_key,
712     NetLogWithSource net_log,
713     std::optional<ResolveHostParameters> optional_parameters) {
714   return std::make_unique<RequestImpl>(Host(std::move(host)),
715                                        network_anonymization_key,
716                                        optional_parameters, AsWeakPtr());
717 }
718 
719 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(const HostPortPair & host,const NetworkAnonymizationKey & network_anonymization_key,const NetLogWithSource & source_net_log,const std::optional<ResolveHostParameters> & optional_parameters)720 MockHostResolverBase::CreateRequest(
721     const HostPortPair& host,
722     const NetworkAnonymizationKey& network_anonymization_key,
723     const NetLogWithSource& source_net_log,
724     const std::optional<ResolveHostParameters>& optional_parameters) {
725   return std::make_unique<RequestImpl>(Host(host), network_anonymization_key,
726                                        optional_parameters, AsWeakPtr());
727 }
728 
729 std::unique_ptr<HostResolver::ServiceEndpointRequest>
CreateServiceEndpointRequest(Host host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,ResolveHostParameters parameters)730 MockHostResolverBase::CreateServiceEndpointRequest(
731     Host host,
732     NetworkAnonymizationKey network_anonymization_key,
733     NetLogWithSource net_log,
734     ResolveHostParameters parameters) {
735   NOTIMPLEMENTED();
736   return nullptr;
737 }
738 
739 std::unique_ptr<HostResolver::ProbeRequest>
CreateDohProbeRequest()740 MockHostResolverBase::CreateDohProbeRequest() {
741   return std::make_unique<ProbeRequestImpl>(AsWeakPtr());
742 }
743 
744 std::unique_ptr<HostResolver::MdnsListener>
CreateMdnsListener(const HostPortPair & host,DnsQueryType query_type)745 MockHostResolverBase::CreateMdnsListener(const HostPortPair& host,
746                                          DnsQueryType query_type) {
747   return std::make_unique<MdnsListenerImpl>(host, query_type, AsWeakPtr());
748 }
749 
GetHostCache()750 HostCache* MockHostResolverBase::GetHostCache() {
751   return cache_.get();
752 }
753 
LoadIntoCache(absl::variant<url::SchemeHostPort,HostPortPair> endpoint,const NetworkAnonymizationKey & network_anonymization_key,const std::optional<ResolveHostParameters> & optional_parameters)754 int MockHostResolverBase::LoadIntoCache(
755     absl::variant<url::SchemeHostPort, HostPortPair> endpoint,
756     const NetworkAnonymizationKey& network_anonymization_key,
757     const std::optional<ResolveHostParameters>& optional_parameters) {
758   return LoadIntoCache(Host(std::move(endpoint)), network_anonymization_key,
759                        optional_parameters);
760 }
761 
LoadIntoCache(const Host & endpoint,const NetworkAnonymizationKey & network_anonymization_key,const std::optional<ResolveHostParameters> & optional_parameters)762 int MockHostResolverBase::LoadIntoCache(
763     const Host& endpoint,
764     const NetworkAnonymizationKey& network_anonymization_key,
765     const std::optional<ResolveHostParameters>& optional_parameters) {
766   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
767   DCHECK(cache_);
768 
769   ResolveHostParameters parameters =
770       optional_parameters.value_or(ResolveHostParameters());
771 
772   std::vector<HostResolverEndpointResult> endpoints;
773   std::set<std::string> aliases;
774   std::optional<HostCache::EntryStaleness> stale_info;
775   int rv = ResolveFromIPLiteralOrCache(
776       endpoint, network_anonymization_key, parameters.dns_query_type,
777       ParametersToHostResolverFlags(parameters), parameters.source,
778       parameters.cache_usage, &endpoints, &aliases, &stale_info);
779   if (rv != ERR_DNS_CACHE_MISS) {
780     // Request already in cache (or IP literal). No need to load it.
781     return rv;
782   }
783 
784   // Just like the real resolver, refuse to do anything with invalid
785   // hostnames.
786   if (!dns_names_util::IsValidDnsName(endpoint.GetHostnameWithoutBrackets()))
787     return ERR_NAME_NOT_RESOLVED;
788 
789   RequestImpl request(endpoint, network_anonymization_key, optional_parameters,
790                       AsWeakPtr());
791   return DoSynchronousResolution(request);
792 }
793 
ResolveAllPending()794 void MockHostResolverBase::ResolveAllPending() {
795   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
796   DCHECK(ondemand_mode_);
797   for (auto& [id, request] : state_->mutable_requests()) {
798     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
799         FROM_HERE,
800         base::BindOnce(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
801   }
802 }
803 
last_id()804 size_t MockHostResolverBase::last_id() {
805   if (!has_pending_requests())
806     return 0;
807   return state_->mutable_requests().rbegin()->first;
808 }
809 
ResolveNow(size_t id)810 void MockHostResolverBase::ResolveNow(size_t id) {
811   auto it = state_->mutable_requests().find(id);
812   if (it == state_->mutable_requests().end())
813     return;  // was canceled
814 
815   RequestImpl* req = it->second;
816   state_->mutable_requests().erase(it);
817 
818   int error = DoSynchronousResolution(*req);
819   req->OnAsyncCompleted(id, error);
820 }
821 
DetachRequest(size_t id)822 void MockHostResolverBase::DetachRequest(size_t id) {
823   auto it = state_->mutable_requests().find(id);
824   CHECK(it != state_->mutable_requests().end());
825   state_->mutable_requests().erase(it);
826 }
827 
request_host(size_t id)828 std::string_view MockHostResolverBase::request_host(size_t id) {
829   DCHECK(request(id));
830   return request(id)->request_endpoint().GetHostnameWithoutBrackets();
831 }
832 
request_priority(size_t id)833 RequestPriority MockHostResolverBase::request_priority(size_t id) {
834   DCHECK(request(id));
835   return request(id)->priority();
836 }
837 
838 const NetworkAnonymizationKey&
request_network_anonymization_key(size_t id)839 MockHostResolverBase::request_network_anonymization_key(size_t id) {
840   DCHECK(request(id));
841   return request(id)->network_anonymization_key();
842 }
843 
ResolveOnlyRequestNow()844 void MockHostResolverBase::ResolveOnlyRequestNow() {
845   DCHECK_EQ(1u, state_->mutable_requests().size());
846   ResolveNow(state_->mutable_requests().begin()->first);
847 }
848 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const IPEndPoint & address_result)849 void MockHostResolverBase::TriggerMdnsListeners(
850     const HostPortPair& host,
851     DnsQueryType query_type,
852     MdnsListenerUpdateType update_type,
853     const IPEndPoint& address_result) {
854   for (MdnsListenerImpl* listener : listeners_) {
855     if (listener->host() == host && listener->query_type() == query_type)
856       listener->TriggerAddressResult(update_type, address_result);
857   }
858 }
859 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const std::vector<std::string> & text_result)860 void MockHostResolverBase::TriggerMdnsListeners(
861     const HostPortPair& host,
862     DnsQueryType query_type,
863     MdnsListenerUpdateType update_type,
864     const std::vector<std::string>& text_result) {
865   for (MdnsListenerImpl* listener : listeners_) {
866     if (listener->host() == host && listener->query_type() == query_type)
867       listener->TriggerTextResult(update_type, text_result);
868   }
869 }
870 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const HostPortPair & host_result)871 void MockHostResolverBase::TriggerMdnsListeners(
872     const HostPortPair& host,
873     DnsQueryType query_type,
874     MdnsListenerUpdateType update_type,
875     const HostPortPair& host_result) {
876   for (MdnsListenerImpl* listener : listeners_) {
877     if (listener->host() == host && listener->query_type() == query_type)
878       listener->TriggerHostnameResult(update_type, host_result);
879   }
880 }
881 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type)882 void MockHostResolverBase::TriggerMdnsListeners(
883     const HostPortPair& host,
884     DnsQueryType query_type,
885     MdnsListenerUpdateType update_type) {
886   for (MdnsListenerImpl* listener : listeners_) {
887     if (listener->host() == host && listener->query_type() == query_type)
888       listener->TriggerUnhandledResult(update_type);
889   }
890 }
891 
request(size_t id)892 MockHostResolverBase::RequestImpl* MockHostResolverBase::request(size_t id) {
893   RequestMap::iterator request = state_->mutable_requests().find(id);
894   CHECK(request != state_->mutable_requests().end());
895   CHECK_EQ(request->second->id(), id);
896   return (*request).second;
897 }
898 
899 // start id from 1 to distinguish from NULL RequestHandle
MockHostResolverBase(bool use_caching,int cache_invalidation_num,RuleResolver rule_resolver)900 MockHostResolverBase::MockHostResolverBase(bool use_caching,
901                                            int cache_invalidation_num,
902                                            RuleResolver rule_resolver)
903     : rule_resolver_(std::move(rule_resolver)),
904       initial_cache_invalidation_num_(cache_invalidation_num),
905       tick_clock_(base::DefaultTickClock::GetInstance()),
906       state_(base::MakeRefCounted<State>()) {
907   if (use_caching)
908     cache_ = std::make_unique<HostCache>(kMaxCacheEntries);
909   else
910     DCHECK_GE(0, cache_invalidation_num);
911 }
912 
Resolve(RequestImpl * request)913 int MockHostResolverBase::Resolve(RequestImpl* request) {
914   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
915 
916   last_request_priority_ = request->parameters().initial_priority;
917   last_request_network_anonymization_key_ =
918       request->network_anonymization_key();
919   last_secure_dns_policy_ = request->parameters().secure_dns_policy;
920   state_->IncrementNumResolve();
921   std::vector<HostResolverEndpointResult> endpoints;
922   std::set<std::string> aliases;
923   std::optional<HostCache::EntryStaleness> stale_info;
924   // TODO(crbug.com/1264933): Allow caching `ConnectionEndpoint` results.
925   int rv = ResolveFromIPLiteralOrCache(
926       request->request_endpoint(), request->network_anonymization_key(),
927       request->parameters().dns_query_type, request->host_resolver_flags(),
928       request->parameters().source, request->parameters().cache_usage,
929       &endpoints, &aliases, &stale_info);
930 
931   if (rv == OK && !request->parameters().is_speculative) {
932     request->SetEndpointResults(std::move(endpoints), std::move(aliases),
933                                 std::move(stale_info));
934   } else {
935     request->SetError(rv);
936   }
937 
938   if (rv != ERR_DNS_CACHE_MISS ||
939       request->parameters().source == HostResolverSource::LOCAL_ONLY) {
940     return SquashErrorCode(rv);
941   }
942 
943   // Just like the real resolver, refuse to do anything with invalid
944   // hostnames.
945   if (!dns_names_util::IsValidDnsName(
946           request->request_endpoint().GetHostnameWithoutBrackets())) {
947     request->SetError(ERR_NAME_NOT_RESOLVED);
948     return ERR_NAME_NOT_RESOLVED;
949   }
950 
951   if (synchronous_mode_)
952     return DoSynchronousResolution(*request);
953 
954   // Store the request for asynchronous resolution
955   size_t id = next_request_id_++;
956   request->set_id(id);
957   state_->mutable_requests()[id] = request;
958 
959   if (!ondemand_mode_) {
960     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
961         FROM_HERE,
962         base::BindOnce(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
963   }
964 
965   return ERR_IO_PENDING;
966 }
967 
ResolveFromIPLiteralOrCache(const Host & endpoint,const NetworkAnonymizationKey & network_anonymization_key,DnsQueryType dns_query_type,HostResolverFlags flags,HostResolverSource source,HostResolver::ResolveHostParameters::CacheUsage cache_usage,std::vector<HostResolverEndpointResult> * out_endpoints,std::set<std::string> * out_aliases,std::optional<HostCache::EntryStaleness> * out_stale_info)968 int MockHostResolverBase::ResolveFromIPLiteralOrCache(
969     const Host& endpoint,
970     const NetworkAnonymizationKey& network_anonymization_key,
971     DnsQueryType dns_query_type,
972     HostResolverFlags flags,
973     HostResolverSource source,
974     HostResolver::ResolveHostParameters::CacheUsage cache_usage,
975     std::vector<HostResolverEndpointResult>* out_endpoints,
976     std::set<std::string>* out_aliases,
977     std::optional<HostCache::EntryStaleness>* out_stale_info) {
978   DCHECK(out_endpoints);
979   DCHECK(out_aliases);
980   DCHECK(out_stale_info);
981   out_endpoints->clear();
982   out_aliases->clear();
983   *out_stale_info = std::nullopt;
984 
985   IPAddress ip_address;
986   if (ip_address.AssignFromIPLiteral(endpoint.GetHostnameWithoutBrackets())) {
987     const DnsQueryType desired_address_query =
988         AddressFamilyToDnsQueryType(GetAddressFamily(ip_address));
989     DCHECK_NE(desired_address_query, DnsQueryType::UNSPECIFIED);
990 
991     // This matches the behavior HostResolverImpl.
992     if (dns_query_type != DnsQueryType::UNSPECIFIED &&
993         dns_query_type != desired_address_query) {
994       return ERR_NAME_NOT_RESOLVED;
995     }
996 
997     *out_endpoints = std::vector<HostResolverEndpointResult>(1);
998     (*out_endpoints)[0].ip_endpoints.emplace_back(ip_address,
999                                                   endpoint.GetPort());
1000     if (flags & HOST_RESOLVER_CANONNAME)
1001       *out_aliases = {ip_address.ToString()};
1002     return OK;
1003   }
1004 
1005   std::vector<IPEndPoint> localhost_endpoints;
1006   // Immediately resolve any "localhost" or recognized similar names.
1007   if (IsAddressType(dns_query_type) &&
1008       ResolveLocalHostname(endpoint.GetHostnameWithoutBrackets(),
1009                            &localhost_endpoints)) {
1010     *out_endpoints = std::vector<HostResolverEndpointResult>(1);
1011     (*out_endpoints)[0].ip_endpoints = localhost_endpoints;
1012     return OK;
1013   }
1014   int rv = ERR_DNS_CACHE_MISS;
1015   bool cache_allowed =
1016       cache_usage == HostResolver::ResolveHostParameters::CacheUsage::ALLOWED ||
1017       cache_usage ==
1018           HostResolver::ResolveHostParameters::CacheUsage::STALE_ALLOWED;
1019   if (cache_.get() && cache_allowed) {
1020     // Local-only requests search the cache for non-local-only results.
1021     HostResolverSource effective_source =
1022         source == HostResolverSource::LOCAL_ONLY ? HostResolverSource::ANY
1023                                                  : source;
1024     HostCache::Key key(GetCacheHost(endpoint), dns_query_type, flags,
1025                        effective_source, network_anonymization_key);
1026     const std::pair<const HostCache::Key, HostCache::Entry>* cache_result;
1027     HostCache::EntryStaleness stale_info = HostCache::kNotStale;
1028     if (cache_usage ==
1029         HostResolver::ResolveHostParameters::CacheUsage::STALE_ALLOWED) {
1030       cache_result = cache_->LookupStale(key, tick_clock_->NowTicks(),
1031                                          &stale_info, true /* ignore_secure */);
1032     } else {
1033       cache_result = cache_->Lookup(key, tick_clock_->NowTicks(),
1034                                     true /* ignore_secure */);
1035     }
1036     if (cache_result) {
1037       rv = cache_result->second.error();
1038       if (rv == OK) {
1039         *out_endpoints = cache_result->second.GetEndpoints();
1040 
1041         *out_aliases = cache_result->second.aliases();
1042         *out_stale_info = std::move(stale_info);
1043       }
1044 
1045       auto cache_invalidation_iterator = cache_invalidation_nums_.find(key);
1046       if (cache_invalidation_iterator != cache_invalidation_nums_.end()) {
1047         DCHECK_LE(1, cache_invalidation_iterator->second);
1048         cache_invalidation_iterator->second--;
1049         if (cache_invalidation_iterator->second == 0) {
1050           HostCache::Entry new_entry(cache_result->second);
1051           cache_->Set(key, new_entry, tick_clock_->NowTicks(),
1052                       base::TimeDelta());
1053           cache_invalidation_nums_.erase(cache_invalidation_iterator);
1054         }
1055       }
1056     }
1057   }
1058   return rv;
1059 }
1060 
DoSynchronousResolution(RequestImpl & request)1061 int MockHostResolverBase::DoSynchronousResolution(RequestImpl& request) {
1062   state_->IncrementNumNonLocalResolves();
1063 
1064   const RuleResolver::RuleResultOrError& result = rule_resolver_.Resolve(
1065       request.request_endpoint(), {request.parameters().dns_query_type},
1066       request.parameters().source);
1067 
1068   int error = ERR_UNEXPECTED;
1069   std::optional<HostCache::Entry> cache_entry;
1070   if (absl::holds_alternative<RuleResolver::RuleResult>(result)) {
1071     const auto& rule_result = absl::get<RuleResolver::RuleResult>(result);
1072     const auto& endpoint_results = rule_result.endpoints;
1073     const auto& aliases = rule_result.aliases;
1074     request.SetEndpointResults(endpoint_results, aliases,
1075                                /*staleness=*/std::nullopt);
1076     // TODO(crbug.com/1264933): Change `error` on empty results?
1077     error = OK;
1078     if (cache_.get()) {
1079       cache_entry = CreateCacheEntry(request.request_endpoint().GetHostname(),
1080                                      endpoint_results, aliases);
1081     }
1082   } else {
1083     DCHECK(absl::holds_alternative<RuleResolver::ErrorResult>(result));
1084     error = absl::get<RuleResolver::ErrorResult>(result);
1085     request.SetError(error);
1086     if (cache_.get()) {
1087       cache_entry.emplace(error, HostCache::Entry::SOURCE_UNKNOWN);
1088     }
1089   }
1090   if (cache_.get() && cache_entry.has_value()) {
1091     HostCache::Key key(
1092         GetCacheHost(request.request_endpoint()),
1093         request.parameters().dns_query_type, request.host_resolver_flags(),
1094         request.parameters().source, request.network_anonymization_key());
1095     // Storing a failure with TTL 0 so that it overwrites previous value.
1096     base::TimeDelta ttl;
1097     if (error == OK) {
1098       ttl = base::Seconds(kCacheEntryTTLSeconds);
1099       if (initial_cache_invalidation_num_ > 0)
1100         cache_invalidation_nums_[key] = initial_cache_invalidation_num_;
1101     }
1102     cache_->Set(key, cache_entry.value(), tick_clock_->NowTicks(), ttl);
1103   }
1104 
1105   return SquashErrorCode(error);
1106 }
1107 
AddListener(MdnsListenerImpl * listener)1108 void MockHostResolverBase::AddListener(MdnsListenerImpl* listener) {
1109   listeners_.insert(listener);
1110 }
1111 
RemoveCancelledListener(MdnsListenerImpl * listener)1112 void MockHostResolverBase::RemoveCancelledListener(MdnsListenerImpl* listener) {
1113   listeners_.erase(listener);
1114 }
1115 
MockHostResolverFactory(MockHostResolverBase::RuleResolver rules,bool use_caching,int cache_invalidation_num)1116 MockHostResolverFactory::MockHostResolverFactory(
1117     MockHostResolverBase::RuleResolver rules,
1118     bool use_caching,
1119     int cache_invalidation_num)
1120     : rules_(std::move(rules)),
1121       use_caching_(use_caching),
1122       cache_invalidation_num_(cache_invalidation_num) {}
1123 
1124 MockHostResolverFactory::~MockHostResolverFactory() = default;
1125 
CreateResolver(HostResolverManager * manager,std::string_view host_mapping_rules,bool enable_caching)1126 std::unique_ptr<HostResolver> MockHostResolverFactory::CreateResolver(
1127     HostResolverManager* manager,
1128     std::string_view host_mapping_rules,
1129     bool enable_caching) {
1130   DCHECK(host_mapping_rules.empty());
1131 
1132   // Explicit new to access private constructor.
1133   auto resolver = base::WrapUnique(new MockHostResolverBase(
1134       enable_caching && use_caching_, cache_invalidation_num_, rules_));
1135   return resolver;
1136 }
1137 
CreateStandaloneResolver(NetLog * net_log,const HostResolver::ManagerOptions & options,std::string_view host_mapping_rules,bool enable_caching)1138 std::unique_ptr<HostResolver> MockHostResolverFactory::CreateStandaloneResolver(
1139     NetLog* net_log,
1140     const HostResolver::ManagerOptions& options,
1141     std::string_view host_mapping_rules,
1142     bool enable_caching) {
1143   return CreateResolver(nullptr, host_mapping_rules, enable_caching);
1144 }
1145 
1146 //-----------------------------------------------------------------------------
1147 
Rule(ResolverType resolver_type,std::string_view host_pattern,AddressFamily address_family,HostResolverFlags host_resolver_flags,std::string_view replacement,std::vector<std::string> dns_aliases,int latency_ms)1148 RuleBasedHostResolverProc::Rule::Rule(ResolverType resolver_type,
1149                                       std::string_view host_pattern,
1150                                       AddressFamily address_family,
1151                                       HostResolverFlags host_resolver_flags,
1152                                       std::string_view replacement,
1153                                       std::vector<std::string> dns_aliases,
1154                                       int latency_ms)
1155     : resolver_type(resolver_type),
1156       host_pattern(host_pattern),
1157       address_family(address_family),
1158       host_resolver_flags(host_resolver_flags),
1159       replacement(replacement),
1160       dns_aliases(std::move(dns_aliases)),
1161       latency_ms(latency_ms) {
1162   DCHECK(this->dns_aliases != std::vector<std::string>({""}));
1163 }
1164 
1165 RuleBasedHostResolverProc::Rule::Rule(const Rule& other) = default;
1166 
1167 RuleBasedHostResolverProc::Rule::~Rule() = default;
1168 
RuleBasedHostResolverProc(scoped_refptr<HostResolverProc> previous,bool allow_fallback)1169 RuleBasedHostResolverProc::RuleBasedHostResolverProc(
1170     scoped_refptr<HostResolverProc> previous,
1171     bool allow_fallback)
1172     : HostResolverProc(std::move(previous), allow_fallback) {}
1173 
AddRule(std::string_view host_pattern,std::string_view replacement)1174 void RuleBasedHostResolverProc::AddRule(std::string_view host_pattern,
1175                                         std::string_view replacement) {
1176   AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1177                           replacement);
1178 }
1179 
AddRuleForAddressFamily(std::string_view host_pattern,AddressFamily address_family,std::string_view replacement)1180 void RuleBasedHostResolverProc::AddRuleForAddressFamily(
1181     std::string_view host_pattern,
1182     AddressFamily address_family,
1183     std::string_view replacement) {
1184   DCHECK(!replacement.empty());
1185   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1186   Rule rule(Rule::kResolverTypeSystem, host_pattern, address_family, flags,
1187             replacement, {} /* dns_aliases */, 0);
1188   AddRuleInternal(rule);
1189 }
1190 
AddRuleWithFlags(std::string_view host_pattern,std::string_view replacement,HostResolverFlags flags,std::vector<std::string> dns_aliases)1191 void RuleBasedHostResolverProc::AddRuleWithFlags(
1192     std::string_view host_pattern,
1193     std::string_view replacement,
1194     HostResolverFlags flags,
1195     std::vector<std::string> dns_aliases) {
1196   DCHECK(!replacement.empty());
1197   Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1198             flags, replacement, std::move(dns_aliases), 0);
1199   AddRuleInternal(rule);
1200 }
1201 
AddIPLiteralRule(std::string_view host_pattern,std::string_view ip_literal,std::string_view canonical_name)1202 void RuleBasedHostResolverProc::AddIPLiteralRule(
1203     std::string_view host_pattern,
1204     std::string_view ip_literal,
1205     std::string_view canonical_name) {
1206   // Literals are always resolved to themselves by HostResolverImpl,
1207   // consequently we do not support remapping them.
1208   IPAddress ip_address;
1209   DCHECK(!ip_address.AssignFromIPLiteral(host_pattern));
1210   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1211   std::vector<std::string> aliases;
1212   if (!canonical_name.empty()) {
1213     flags |= HOST_RESOLVER_CANONNAME;
1214     aliases.emplace_back(canonical_name);
1215   }
1216 
1217   Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
1218             ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, std::move(aliases),
1219             0);
1220   AddRuleInternal(rule);
1221 }
1222 
AddIPLiteralRuleWithDnsAliases(std::string_view host_pattern,std::string_view ip_literal,std::vector<std::string> dns_aliases)1223 void RuleBasedHostResolverProc::AddIPLiteralRuleWithDnsAliases(
1224     std::string_view host_pattern,
1225     std::string_view ip_literal,
1226     std::vector<std::string> dns_aliases) {
1227   // Literals are always resolved to themselves by HostResolverImpl,
1228   // consequently we do not support remapping them.
1229   IPAddress ip_address;
1230   DCHECK(!ip_address.AssignFromIPLiteral(host_pattern));
1231   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1232   if (!dns_aliases.empty())
1233     flags |= HOST_RESOLVER_CANONNAME;
1234 
1235   Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
1236             ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal,
1237             std::move(dns_aliases), 0);
1238   AddRuleInternal(rule);
1239 }
1240 
AddRuleWithLatency(std::string_view host_pattern,std::string_view replacement,int latency_ms)1241 void RuleBasedHostResolverProc::AddRuleWithLatency(
1242     std::string_view host_pattern,
1243     std::string_view replacement,
1244     int latency_ms) {
1245   DCHECK(!replacement.empty());
1246   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1247   Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1248             flags, replacement, /*dns_aliases=*/{}, latency_ms);
1249   AddRuleInternal(rule);
1250 }
1251 
AllowDirectLookup(std::string_view host_pattern)1252 void RuleBasedHostResolverProc::AllowDirectLookup(
1253     std::string_view host_pattern) {
1254   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1255   Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1256             flags, std::string(), /*dns_aliases=*/{}, 0);
1257   AddRuleInternal(rule);
1258 }
1259 
AddSimulatedFailure(std::string_view host_pattern,HostResolverFlags flags)1260 void RuleBasedHostResolverProc::AddSimulatedFailure(
1261     std::string_view host_pattern,
1262     HostResolverFlags flags) {
1263   Rule rule(Rule::kResolverTypeFail, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1264             flags, std::string(), /*dns_aliases=*/{}, 0);
1265   AddRuleInternal(rule);
1266 }
1267 
AddSimulatedTimeoutFailure(std::string_view host_pattern,HostResolverFlags flags)1268 void RuleBasedHostResolverProc::AddSimulatedTimeoutFailure(
1269     std::string_view host_pattern,
1270     HostResolverFlags flags) {
1271   Rule rule(Rule::kResolverTypeFailTimeout, host_pattern,
1272             ADDRESS_FAMILY_UNSPECIFIED, flags, std::string(),
1273             /*dns_aliases=*/{}, 0);
1274   AddRuleInternal(rule);
1275 }
1276 
ClearRules()1277 void RuleBasedHostResolverProc::ClearRules() {
1278   CHECK(modifications_allowed_);
1279   base::AutoLock lock(rule_lock_);
1280   rules_.clear();
1281 }
1282 
DisableModifications()1283 void RuleBasedHostResolverProc::DisableModifications() {
1284   modifications_allowed_ = false;
1285 }
1286 
GetRules()1287 RuleBasedHostResolverProc::RuleList RuleBasedHostResolverProc::GetRules() {
1288   RuleList rv;
1289   {
1290     base::AutoLock lock(rule_lock_);
1291     rv = rules_;
1292   }
1293   return rv;
1294 }
1295 
NumResolvesForHostPattern(std::string_view host_pattern)1296 size_t RuleBasedHostResolverProc::NumResolvesForHostPattern(
1297     std::string_view host_pattern) {
1298   base::AutoLock lock(rule_lock_);
1299   return num_resolves_per_host_pattern_[host_pattern];
1300 }
1301 
Resolve(const std::string & host,AddressFamily address_family,HostResolverFlags host_resolver_flags,AddressList * addrlist,int * os_error)1302 int RuleBasedHostResolverProc::Resolve(const std::string& host,
1303                                        AddressFamily address_family,
1304                                        HostResolverFlags host_resolver_flags,
1305                                        AddressList* addrlist,
1306                                        int* os_error) {
1307   base::AutoLock lock(rule_lock_);
1308   RuleList::iterator r;
1309   for (r = rules_.begin(); r != rules_.end(); ++r) {
1310     bool matches_address_family =
1311         r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
1312         r->address_family == address_family;
1313     // Ignore HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6, since it should
1314     // have no impact on whether a rule matches.
1315     HostResolverFlags flags =
1316         host_resolver_flags & ~HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
1317     // Flags match if all of the bitflags in host_resolver_flags are enabled
1318     // in the rule's host_resolver_flags. However, the rule may have additional
1319     // flags specified, in which case the flags should still be considered a
1320     // match.
1321     bool matches_flags = (r->host_resolver_flags & flags) == flags;
1322     if (matches_flags && matches_address_family &&
1323         base::MatchPattern(host, r->host_pattern)) {
1324       num_resolves_per_host_pattern_[r->host_pattern]++;
1325 
1326       if (r->latency_ms != 0) {
1327         base::PlatformThread::Sleep(base::Milliseconds(r->latency_ms));
1328       }
1329 
1330       // Remap to a new host.
1331       const std::string& effective_host =
1332           r->replacement.empty() ? host : r->replacement;
1333 
1334       // Apply the resolving function to the remapped hostname.
1335       switch (r->resolver_type) {
1336         case Rule::kResolverTypeFail:
1337           return ERR_NAME_NOT_RESOLVED;
1338         case Rule::kResolverTypeFailTimeout:
1339           return ERR_DNS_TIMED_OUT;
1340         case Rule::kResolverTypeSystem:
1341           EnsureSystemHostResolverCallReady();
1342           return SystemHostResolverCall(effective_host, address_family,
1343                                         host_resolver_flags, addrlist,
1344                                         os_error);
1345         case Rule::kResolverTypeIPLiteral: {
1346           AddressList raw_addr_list;
1347           std::vector<std::string> aliases;
1348           aliases = (!r->dns_aliases.empty())
1349                         ? r->dns_aliases
1350                         : std::vector<std::string>({host});
1351           std::vector<net::IPEndPoint> ip_endpoints;
1352           int result = ParseAddressList(effective_host, &ip_endpoints);
1353           // Filter out addresses with the wrong family.
1354           *addrlist = AddressList();
1355           for (const auto& address : ip_endpoints) {
1356             if (address_family == ADDRESS_FAMILY_UNSPECIFIED ||
1357                 address_family == address.GetFamily()) {
1358               addrlist->push_back(address);
1359             }
1360           }
1361           addrlist->SetDnsAliases(aliases);
1362 
1363           if (result == OK && addrlist->empty())
1364             return ERR_NAME_NOT_RESOLVED;
1365           return result;
1366         }
1367         default:
1368           NOTREACHED();
1369           return ERR_UNEXPECTED;
1370       }
1371     }
1372   }
1373 
1374   return ResolveUsingPrevious(host, address_family, host_resolver_flags,
1375                               addrlist, os_error);
1376 }
1377 
1378 RuleBasedHostResolverProc::~RuleBasedHostResolverProc() = default;
1379 
AddRuleInternal(const Rule & rule)1380 void RuleBasedHostResolverProc::AddRuleInternal(const Rule& rule) {
1381   Rule fixed_rule = rule;
1382   // SystemResolverProc expects valid DNS addresses.
1383   // So for kResolverTypeSystem rules:
1384   // * CHECK that replacement is empty (empty domain names mean use a direct
1385   //   lookup) or a valid DNS name (which includes IP addresses).
1386   // * If the replacement is an IP address, switch to an IP literal rule.
1387   if (fixed_rule.resolver_type == Rule::kResolverTypeSystem) {
1388     CHECK(fixed_rule.replacement.empty() ||
1389           dns_names_util::IsValidDnsName(fixed_rule.replacement));
1390 
1391     IPAddress ip_address;
1392     bool valid_address = ip_address.AssignFromIPLiteral(fixed_rule.replacement);
1393     if (valid_address) {
1394       fixed_rule.resolver_type = Rule::kResolverTypeIPLiteral;
1395     }
1396   }
1397 
1398   CHECK(modifications_allowed_);
1399   base::AutoLock lock(rule_lock_);
1400   rules_.push_back(fixed_rule);
1401 }
1402 
CreateCatchAllHostResolverProc()1403 scoped_refptr<RuleBasedHostResolverProc> CreateCatchAllHostResolverProc() {
1404   auto catchall =
1405       base::MakeRefCounted<RuleBasedHostResolverProc>(/*previous=*/nullptr,
1406                                                       /*allow_fallback=*/false);
1407   // Note that IPv6 lookups fail.
1408   catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
1409 
1410   // Next add a rules-based layer that the test controls.
1411   return base::MakeRefCounted<RuleBasedHostResolverProc>(
1412       std::move(catchall), /*allow_fallback=*/false);
1413 }
1414 
1415 //-----------------------------------------------------------------------------
1416 
1417 // Implementation of ResolveHostRequest that tracks cancellations when the
1418 // request is destroyed after being started.
1419 class HangingHostResolver::RequestImpl
1420     : public HostResolver::ResolveHostRequest,
1421       public HostResolver::ProbeRequest {
1422  public:
RequestImpl(base::WeakPtr<HangingHostResolver> resolver)1423   explicit RequestImpl(base::WeakPtr<HangingHostResolver> resolver)
1424       : resolver_(resolver) {}
1425 
1426   RequestImpl(const RequestImpl&) = delete;
1427   RequestImpl& operator=(const RequestImpl&) = delete;
1428 
~RequestImpl()1429   ~RequestImpl() override {
1430     if (is_running_ && resolver_)
1431       resolver_->state_->IncrementNumCancellations();
1432   }
1433 
Start(CompletionOnceCallback callback)1434   int Start(CompletionOnceCallback callback) override { return Start(); }
1435 
Start()1436   int Start() override {
1437     DCHECK(resolver_);
1438     is_running_ = true;
1439     return ERR_IO_PENDING;
1440   }
1441 
GetAddressResults() const1442   const AddressList* GetAddressResults() const override {
1443     base::ImmediateCrash();
1444   }
1445 
GetEndpointResults() const1446   const std::vector<HostResolverEndpointResult>* GetEndpointResults()
1447       const override {
1448     base::ImmediateCrash();
1449   }
1450 
GetTextResults() const1451   const std::vector<std::string>* GetTextResults() const override {
1452     base::ImmediateCrash();
1453   }
1454 
GetHostnameResults() const1455   const std::vector<HostPortPair>* GetHostnameResults() const override {
1456     base::ImmediateCrash();
1457   }
1458 
GetDnsAliasResults() const1459   const std::set<std::string>* GetDnsAliasResults() const override {
1460     base::ImmediateCrash();
1461   }
1462 
GetResolveErrorInfo() const1463   net::ResolveErrorInfo GetResolveErrorInfo() const override {
1464     base::ImmediateCrash();
1465   }
1466 
GetStaleInfo() const1467   const std::optional<HostCache::EntryStaleness>& GetStaleInfo()
1468       const override {
1469     base::ImmediateCrash();
1470   }
1471 
ChangeRequestPriority(RequestPriority priority)1472   void ChangeRequestPriority(RequestPriority priority) override {}
1473 
1474  private:
1475   // Use a WeakPtr as the resolver may be destroyed while there are still
1476   // outstanding request objects.
1477   base::WeakPtr<HangingHostResolver> resolver_;
1478   bool is_running_ = false;
1479 };
1480 
1481 HangingHostResolver::State::State() = default;
1482 HangingHostResolver::State::~State() = default;
1483 
HangingHostResolver()1484 HangingHostResolver::HangingHostResolver()
1485     : state_(base::MakeRefCounted<State>()) {}
1486 
1487 HangingHostResolver::~HangingHostResolver() = default;
1488 
OnShutdown()1489 void HangingHostResolver::OnShutdown() {
1490   shutting_down_ = true;
1491 }
1492 
1493 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(url::SchemeHostPort host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,std::optional<ResolveHostParameters> optional_parameters)1494 HangingHostResolver::CreateRequest(
1495     url::SchemeHostPort host,
1496     NetworkAnonymizationKey network_anonymization_key,
1497     NetLogWithSource net_log,
1498     std::optional<ResolveHostParameters> optional_parameters) {
1499   // TODO(crbug.com/1206799): Propagate scheme and make affect behavior.
1500   return CreateRequest(HostPortPair::FromSchemeHostPort(host),
1501                        network_anonymization_key, net_log, optional_parameters);
1502 }
1503 
1504 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(const HostPortPair & host,const NetworkAnonymizationKey & network_anonymization_key,const NetLogWithSource & source_net_log,const std::optional<ResolveHostParameters> & optional_parameters)1505 HangingHostResolver::CreateRequest(
1506     const HostPortPair& host,
1507     const NetworkAnonymizationKey& network_anonymization_key,
1508     const NetLogWithSource& source_net_log,
1509     const std::optional<ResolveHostParameters>& optional_parameters) {
1510   last_host_ = host;
1511   last_network_anonymization_key_ = network_anonymization_key;
1512 
1513   if (shutting_down_)
1514     return CreateFailingRequest(ERR_CONTEXT_SHUT_DOWN);
1515 
1516   if (optional_parameters &&
1517       optional_parameters.value().source == HostResolverSource::LOCAL_ONLY) {
1518     return CreateFailingRequest(ERR_DNS_CACHE_MISS);
1519   }
1520 
1521   return std::make_unique<RequestImpl>(weak_ptr_factory_.GetWeakPtr());
1522 }
1523 
1524 std::unique_ptr<HostResolver::ServiceEndpointRequest>
CreateServiceEndpointRequest(Host host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,ResolveHostParameters parameters)1525 HangingHostResolver::CreateServiceEndpointRequest(
1526     Host host,
1527     NetworkAnonymizationKey network_anonymization_key,
1528     NetLogWithSource net_log,
1529     ResolveHostParameters parameters) {
1530   NOTIMPLEMENTED();
1531   return nullptr;
1532 }
1533 
1534 std::unique_ptr<HostResolver::ProbeRequest>
CreateDohProbeRequest()1535 HangingHostResolver::CreateDohProbeRequest() {
1536   if (shutting_down_)
1537     return CreateFailingProbeRequest(ERR_CONTEXT_SHUT_DOWN);
1538 
1539   return std::make_unique<RequestImpl>(weak_ptr_factory_.GetWeakPtr());
1540 }
1541 
SetRequestContext(URLRequestContext * url_request_context)1542 void HangingHostResolver::SetRequestContext(
1543     URLRequestContext* url_request_context) {}
1544 
1545 //-----------------------------------------------------------------------------
1546 
1547 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() = default;
1548 
ScopedDefaultHostResolverProc(HostResolverProc * proc)1549 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
1550     HostResolverProc* proc) {
1551   Init(proc);
1552 }
1553 
~ScopedDefaultHostResolverProc()1554 ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
1555   HostResolverProc* old_proc =
1556       HostResolverProc::SetDefault(previous_proc_.get());
1557   // The lifetimes of multiple instances must be nested.
1558   CHECK_EQ(old_proc, current_proc_.get());
1559 }
1560 
Init(HostResolverProc * proc)1561 void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
1562   current_proc_ = proc;
1563   previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
1564   current_proc_->SetLastProc(previous_proc_);
1565 }
1566 
1567 }  // namespace net
1568