xref: /aosp_15_r20/external/cronet/net/dns/dns_task_results_manager.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/dns_task_results_manager.h"
6 
7 #include <algorithm>
8 #include <map>
9 #include <memory>
10 #include <set>
11 #include <string>
12 #include <vector>
13 
14 #include "base/memory/raw_ptr.h"
15 #include "base/strings/string_number_conversions.h"
16 #include "base/time/time.h"
17 #include "base/timer/timer.h"
18 #include "net/base/connection_endpoint_metadata.h"
19 #include "net/base/ip_endpoint.h"
20 #include "net/base/net_errors.h"
21 #include "net/dns/host_resolver.h"
22 #include "net/dns/host_resolver_dns_task.h"
23 #include "net/dns/host_resolver_internal_result.h"
24 #include "net/dns/https_record_rdata.h"
25 #include "net/dns/public/dns_query_type.h"
26 #include "net/dns/public/host_resolver_results.h"
27 #include "net/log/net_log_event_type.h"
28 #include "net/log/net_log_with_source.h"
29 #include "third_party/abseil-cpp/absl/types/variant.h"
30 #include "url/scheme_host_port.h"
31 
32 namespace net {
33 
34 namespace {
35 
36 // Prioritize with-ipv6 over ipv4-only.
CompareServiceEndpointAddresses(const ServiceEndpoint & a,const ServiceEndpoint & b)37 bool CompareServiceEndpointAddresses(const ServiceEndpoint& a,
38                                      const ServiceEndpoint& b) {
39   const bool a_has_ipv6 = !a.ipv6_endpoints.empty();
40   const bool b_has_ipv6 = !b.ipv6_endpoints.empty();
41   if ((a_has_ipv6 && b_has_ipv6) || (!a_has_ipv6 && !b_has_ipv6)) {
42     return false;
43   }
44 
45   if (b_has_ipv6) {
46     return false;
47   }
48 
49   return true;
50 }
51 
52 // Prioritize with-metadata, with-ipv6 over ipv4-only.
53 // TODO(crbug.com/41493696): Consider which fields should be prioritized. We
54 // may want to have different sorting algorithms and choose one via config.
CompareServiceEndpoint(const ServiceEndpoint & a,const ServiceEndpoint & b)55 bool CompareServiceEndpoint(const ServiceEndpoint& a,
56                             const ServiceEndpoint& b) {
57   const bool a_has_metadata = a.metadata != ConnectionEndpointMetadata();
58   const bool b_has_metadata = b.metadata != ConnectionEndpointMetadata();
59   if (a_has_metadata && b_has_metadata) {
60     return CompareServiceEndpointAddresses(a, b);
61   }
62 
63   if (a_has_metadata) {
64     return true;
65   }
66 
67   if (b_has_metadata) {
68     return false;
69   }
70 
71   return CompareServiceEndpointAddresses(a, b);
72 }
73 
74 }  // namespace
75 
76 // Holds service endpoint results per domain name.
77 struct DnsTaskResultsManager::PerDomainResult {
78   PerDomainResult() = default;
79   ~PerDomainResult() = default;
80 
81   PerDomainResult(PerDomainResult&&) = default;
82   PerDomainResult& operator=(PerDomainResult&&) = default;
83   PerDomainResult(const PerDomainResult&) = delete;
84   PerDomainResult& operator=(const PerDomainResult&) = delete;
85 
86   std::vector<IPEndPoint> ipv4_endpoints;
87   std::vector<IPEndPoint> ipv6_endpoints;
88 
89   std::multimap<HttpsRecordPriority, ConnectionEndpointMetadata> metadatas;
90 };
91 
DnsTaskResultsManager(Delegate * delegate,HostResolver::Host host,DnsQueryTypeSet query_types,const NetLogWithSource & net_log)92 DnsTaskResultsManager::DnsTaskResultsManager(Delegate* delegate,
93                                              HostResolver::Host host,
94                                              DnsQueryTypeSet query_types,
95                                              const NetLogWithSource& net_log)
96     : delegate_(delegate),
97       host_(std::move(host)),
98       query_types_(query_types),
99       net_log_(net_log) {
100   CHECK(delegate_);
101 }
102 
103 DnsTaskResultsManager::~DnsTaskResultsManager() = default;
104 
ProcessDnsTransactionResults(DnsQueryType query_type,const std::set<std::unique_ptr<HostResolverInternalResult>> & results)105 void DnsTaskResultsManager::ProcessDnsTransactionResults(
106     DnsQueryType query_type,
107     const std::set<std::unique_ptr<HostResolverInternalResult>>& results) {
108   CHECK(query_types_.Has(query_type));
109 
110   bool should_update_endpoints = false;
111   bool should_notify = false;
112 
113   if (query_type == DnsQueryType::HTTPS) {
114     // Chrome does not yet support HTTPS follow-up queries so metadata is
115     // considered ready when the HTTPS response is received.
116     CHECK(!is_metadata_ready_);
117     is_metadata_ready_ = true;
118     should_notify = true;
119   }
120 
121   if (query_type == DnsQueryType::AAAA) {
122     aaaa_response_received_ = true;
123     if (resolution_delay_timer_.IsRunning()) {
124       resolution_delay_timer_.Stop();
125       RecordResolutionDelayResult(/*timedout=*/false);
126       // Need to update endpoints when there are IPv4 addresses.
127       if (HasIpv4Addresses()) {
128         should_update_endpoints = true;
129       }
130     }
131   }
132 
133   for (auto& result : results) {
134     aliases_.insert(result->domain_name());
135 
136     switch (result->type()) {
137       case HostResolverInternalResult::Type::kData: {
138         PerDomainResult& per_domain_result =
139             GetOrCreatePerDomainResult(result->domain_name());
140         if (query_type == DnsQueryType::A) {
141           for (const auto& ip_endpoint : result->AsData().endpoints()) {
142             CHECK(ip_endpoint.address().IsIPv4());
143             CHECK_EQ(ip_endpoint.port(), 0);
144             per_domain_result.ipv4_endpoints.emplace_back(ip_endpoint.address(),
145                                                           host_.GetPort());
146           }
147         } else if (query_type == DnsQueryType::AAAA) {
148           for (const auto& ip_endpoint : result->AsData().endpoints()) {
149             CHECK(ip_endpoint.address().IsIPv6());
150             CHECK_EQ(ip_endpoint.port(), 0);
151             per_domain_result.ipv6_endpoints.emplace_back(ip_endpoint.address(),
152                                                           host_.GetPort());
153           }
154         } else {
155           // TODO(crbug.com/41493696): This will eventually need to handle
156           // DnsQueryType::HTTPS to support getting ipv{4,6}hints.
157           NOTREACHED() << "Unexpected query type: "
158                        << kDnsQueryTypes.at(query_type);
159         }
160 
161         should_update_endpoints |= !result->AsData().endpoints().empty();
162 
163         break;
164       }
165       case HostResolverInternalResult::Type::kMetadata: {
166         CHECK_EQ(query_type, DnsQueryType::HTTPS);
167         for (auto [priority, metadata] : result->AsMetadata().metadatas()) {
168           // Associate the metadata with the target name instead of the domain
169           // name since the metadata is for the target name.
170           PerDomainResult& per_domain_result =
171               GetOrCreatePerDomainResult(metadata.target_name);
172           per_domain_result.metadatas.emplace(priority, metadata);
173         }
174 
175         should_update_endpoints |= !result->AsMetadata().metadatas().empty();
176 
177         break;
178       }
179       case net::HostResolverInternalResult::Type::kAlias:
180         aliases_.insert(result->AsAlias().alias_target());
181 
182         break;
183       case net::HostResolverInternalResult::Type::kError:
184         // Need to update endpoints when AAAA response is NODATA but A response
185         // has at least one valid address.
186         // TODO(crbug.com/41493696): Revisit how to handle errors other than
187         // NODATA. Currently we just ignore errors here and defer
188         // HostResolverManager::Job to create an error result and notify the
189         // error to the corresponding requests. This means that if the
190         // connection layer has already attempted a connection using an
191         // intermediate endpoint, the error might not be treated as fatal. We
192         // may want to have a different semantics.
193         PerDomainResult& per_domain_result =
194             GetOrCreatePerDomainResult(result->domain_name());
195         if (query_type == DnsQueryType::AAAA &&
196             result->AsError().error() == ERR_NAME_NOT_RESOLVED &&
197             !per_domain_result.ipv4_endpoints.empty()) {
198           CHECK(per_domain_result.ipv6_endpoints.empty());
199           should_update_endpoints = true;
200         }
201 
202         break;
203     }
204   }
205 
206   const bool waiting_for_aaaa_response =
207       query_types_.Has(DnsQueryType::AAAA) && !aaaa_response_received_;
208   if (waiting_for_aaaa_response) {
209     if (query_type == DnsQueryType::A && should_update_endpoints) {
210       // A is responded, start the resolution delay timer.
211       CHECK(!resolution_delay_timer_.IsRunning());
212       resolution_delay_start_time_ = base::TimeTicks::Now();
213       net_log_.BeginEvent(
214           NetLogEventType::HOST_RESOLVER_SERVICE_ENDPOINTS_RESOLUTION_DELAY);
215       // Safe to unretain since `this` owns the timer.
216       resolution_delay_timer_.Start(
217           FROM_HERE, kResolutionDelay,
218           base::BindOnce(&DnsTaskResultsManager::OnAaaaResolutionTimedout,
219                          base::Unretained(this)));
220     }
221 
222     return;
223   }
224 
225   if (should_update_endpoints) {
226     UpdateEndpoints();
227     return;
228   }
229 
230   if (should_notify && !current_endpoints_.empty()) {
231     delegate_->OnServiceEndpointsUpdated();
232   }
233 }
234 
GetCurrentEndpoints() const235 const std::vector<ServiceEndpoint>& DnsTaskResultsManager::GetCurrentEndpoints()
236     const {
237   return current_endpoints_;
238 }
239 
GetAliases() const240 const std::set<std::string>& DnsTaskResultsManager::GetAliases() const {
241   return aliases_;
242 }
243 
IsMetadataReady() const244 bool DnsTaskResultsManager::IsMetadataReady() const {
245   return !query_types_.Has(DnsQueryType::HTTPS) || is_metadata_ready_;
246 }
247 
248 DnsTaskResultsManager::PerDomainResult&
GetOrCreatePerDomainResult(const std::string & domain_name)249 DnsTaskResultsManager::GetOrCreatePerDomainResult(
250     const std::string& domain_name) {
251   auto it = per_domain_results_.find(domain_name);
252   if (it == per_domain_results_.end()) {
253     it = per_domain_results_.try_emplace(it, domain_name,
254                                          std::make_unique<PerDomainResult>());
255   }
256   return *it->second;
257 }
258 
OnAaaaResolutionTimedout()259 void DnsTaskResultsManager::OnAaaaResolutionTimedout() {
260   CHECK(!aaaa_response_received_);
261   RecordResolutionDelayResult(/*timedout=*/true);
262   UpdateEndpoints();
263 }
264 
UpdateEndpoints()265 void DnsTaskResultsManager::UpdateEndpoints() {
266   std::vector<ServiceEndpoint> new_endpoints;
267 
268   for (const auto& [domain_name, per_domain_result] : per_domain_results_) {
269     if (per_domain_result->ipv4_endpoints.empty() &&
270         per_domain_result->ipv6_endpoints.empty()) {
271       continue;
272     }
273 
274     if (per_domain_result->metadatas.empty()) {
275       ServiceEndpoint endpoint;
276       endpoint.ipv4_endpoints = per_domain_result->ipv4_endpoints;
277       endpoint.ipv6_endpoints = per_domain_result->ipv6_endpoints;
278       new_endpoints.emplace_back(std::move(endpoint));
279     } else {
280       for (const auto& [_, metadata] : per_domain_result->metadatas) {
281         ServiceEndpoint endpoint;
282         endpoint.ipv4_endpoints = per_domain_result->ipv4_endpoints;
283         endpoint.ipv6_endpoints = per_domain_result->ipv6_endpoints;
284         // TODO(crbug.com/41493696): Just adding per-domain metadata does not
285         // work properly when the target name of HTTPS is an alias, e.g:
286         //   example.com.     60 IN CNAME svc.example.com.
287         //   svc.example.com. 60 IN AAAA  2001:db8::1
288         //   svc.example.com. 60 IN HTTPS 1 example.com alpn="h2"
289         // In this case, svc.example.com should have metadata with alpn="h2" but
290         // the current logic doesn't do that. To handle it correctly we need to
291         // go though an alias tree for the domain name.
292         endpoint.metadata = metadata;
293         new_endpoints.emplace_back(std::move(endpoint));
294       }
295     }
296   }
297 
298   // TODO(crbug.com/41493696): Determine how to handle non-SVCB connection
299   // fallback. See https://datatracker.ietf.org/doc/html/rfc9460#section-3-8
300   // HostCache::Entry::GetEndpoints() appends a final non-alternative endpoint
301   // at the end to ensure that the connection layer can fall back to non-SVCB
302   // connection. For ServiceEndpoint request API, the current plan is to handle
303   // non-SVCB connection fallback in the connection layer. The approach might
304   // not work when Chrome tries to support HTTPS follow-up queries and aliases.
305 
306   // Stable sort preserves metadata priorities.
307   std::stable_sort(new_endpoints.begin(), new_endpoints.end(),
308                    CompareServiceEndpoint);
309   current_endpoints_ = std::move(new_endpoints);
310 
311   if (current_endpoints_.empty()) {
312     return;
313   }
314 
315   net_log_.AddEvent(NetLogEventType::HOST_RESOLVER_SERVICE_ENDPOINTS_UPDATED,
316                     [&] {
317                       base::Value::Dict dict;
318                       base::Value::List endpoints;
319                       for (const auto& endpoint : current_endpoints_) {
320                         endpoints.Append(endpoint.ToValue());
321                       }
322                       dict.Set("endpoints", std::move(endpoints));
323                       return dict;
324                     });
325 
326   delegate_->OnServiceEndpointsUpdated();
327 }
328 
HasIpv4Addresses()329 bool DnsTaskResultsManager::HasIpv4Addresses() {
330   for (const auto& [_, per_domain_result] : per_domain_results_) {
331     if (!per_domain_result->ipv4_endpoints.empty()) {
332       return true;
333     }
334   }
335   return false;
336 }
337 
RecordResolutionDelayResult(bool timedout)338 void DnsTaskResultsManager::RecordResolutionDelayResult(bool timedout) {
339   net_log_.EndEvent(
340       NetLogEventType::HOST_RESOLVER_SERVICE_ENDPOINTS_RESOLUTION_DELAY, [&]() {
341         base::TimeDelta elapsed =
342             base::TimeTicks::Now() - resolution_delay_start_time_;
343         base::Value::Dict dict;
344         dict.Set("timedout", timedout);
345         dict.Set("elapsed", base::NumberToString(elapsed.InMilliseconds()));
346         return dict;
347       });
348 }
349 
350 }  // namespace net
351