xref: /aosp_15_r20/external/cronet/net/dns/host_resolver_cache.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2023 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/host_resolver_cache.h"
6 
7 #include <cstddef>
8 #include <memory>
9 #include <optional>
10 #include <string>
11 #include <string_view>
12 #include <utility>
13 #include <vector>
14 
15 #include "base/check_op.h"
16 #include "base/numerics/safe_conversions.h"
17 #include "base/time/clock.h"
18 #include "base/time/time.h"
19 #include "net/base/network_anonymization_key.h"
20 #include "net/dns/host_resolver_internal_result.h"
21 #include "net/dns/public/dns_query_type.h"
22 #include "net/dns/public/host_resolver_source.h"
23 #include "url/third_party/mozilla/url_parse.h"
24 #include "url/url_canon.h"
25 #include "url/url_canon_stdstring.h"
26 
27 namespace net {
28 
namespace {

// Keys of each per-entry dictionary written by SerializeEntries() and read
// back by RestoreFromValue().
constexpr std::string_view kNakKey("network_anonymization_key");
constexpr std::string_view kSourceKey("source");
constexpr std::string_view kSecureKey("secure");
constexpr std::string_view kResultKey("result");

// Keys used by SerializeForLogging(). kStalenessGenerationKey is also written
// per entry when SerializeEntries() is asked to include staleness generations.
constexpr std::string_view kMaxEntriesKey("max_entries");
constexpr std::string_view kEntriesKey("entries");
constexpr std::string_view kStalenessGenerationKey("staleness_generation");

}  // namespace
40 
// Out-of-line defaulted destructor for the key type of `entries_` (see the
// `Key(...)` construction in the private Set()).
41 HostResolverCache::Key::~Key() = default;
42 
// Result type of LookupStale(): the chosen cached `result`, how long ago it
// expired (`expired_by`, left unset when it has not yet expired), and whether
// it belongs to an older staleness generation (`stale_by_generation`).
StaleLookupResult(const HostResolverInternalResult & result,std::optional<base::TimeDelta> expired_by,bool stale_by_generation)43 HostResolverCache::StaleLookupResult::StaleLookupResult(
44     const HostResolverInternalResult& result,
45     std::optional<base::TimeDelta> expired_by,
46     bool stale_by_generation)
47     : result(result),
48       expired_by(expired_by),
49       stale_by_generation(stale_by_generation) {}
50 
// Creates a cache holding at most `max_results` entries (DCHECKed non-zero).
// `clock` supplies wall-clock time and `tick_clock` monotonic time; both are
// consulted when deciding whether entries have expired.
// NOTE(review): both clocks are retained in members, so they presumably must
// outlive the cache — confirm against the header.
HostResolverCache(size_t max_results,const base::Clock & clock,const base::TickClock & tick_clock)51 HostResolverCache::HostResolverCache(size_t max_results,
52                                      const base::Clock& clock,
53                                      const base::TickClock& tick_clock)
54     : max_entries_(max_results), clock_(clock), tick_clock_(tick_clock) {
55   DCHECK_GT(max_entries_, 0u);
56 }
57 
// Defaulted destructor and move operations, defined out of line. The cache is
// movable; no copy operations are defined in this file.
58 HostResolverCache::~HostResolverCache() = default;
59 
60 HostResolverCache::HostResolverCache(HostResolverCache&&) = default;
61 
62 HostResolverCache& HostResolverCache::operator=(HostResolverCache&&) = default;
63 
Lookup(std::string_view domain_name,const NetworkAnonymizationKey & network_anonymization_key,DnsQueryType query_type,HostResolverSource source,std::optional<bool> secure) const64 const HostResolverInternalResult* HostResolverCache::Lookup(
65     std::string_view domain_name,
66     const NetworkAnonymizationKey& network_anonymization_key,
67     DnsQueryType query_type,
68     HostResolverSource source,
69     std::optional<bool> secure) const {
70   std::vector<EntryMap::const_iterator> candidates = LookupInternal(
71       domain_name, network_anonymization_key, query_type, source, secure);
72 
73   // Get the most secure, last-matching (which is first in the vector returned
74   // by LookupInternal()) non-expired result.
75   base::TimeTicks now_ticks = tick_clock_->NowTicks();
76   base::Time now = clock_->Now();
77   HostResolverInternalResult* most_secure_result = nullptr;
78   for (const EntryMap::const_iterator& candidate : candidates) {
79     DCHECK(candidate->second.result->timed_expiration().has_value());
80 
81     if (candidate->second.IsStale(now, now_ticks, staleness_generation_)) {
82       continue;
83     }
84 
85     // If the candidate is secure, or all results are insecure, no need to check
86     // any more.
87     if (candidate->second.secure || !secure.value_or(true)) {
88       return candidate->second.result.get();
89     } else if (most_secure_result == nullptr) {
90       most_secure_result = candidate->second.result.get();
91     }
92   }
93 
94   return most_secure_result;
95 }
96 
97 std::optional<HostResolverCache::StaleLookupResult>
LookupStale(std::string_view domain_name,const NetworkAnonymizationKey & network_anonymization_key,DnsQueryType query_type,HostResolverSource source,std::optional<bool> secure) const98 HostResolverCache::LookupStale(
99     std::string_view domain_name,
100     const NetworkAnonymizationKey& network_anonymization_key,
101     DnsQueryType query_type,
102     HostResolverSource source,
103     std::optional<bool> secure) const {
104   std::vector<EntryMap::const_iterator> candidates = LookupInternal(
105       domain_name, network_anonymization_key, query_type, source, secure);
106 
107   // Get the least expired, most secure result.
108   base::TimeTicks now_ticks = tick_clock_->NowTicks();
109   base::Time now = clock_->Now();
110   const Entry* best_match = nullptr;
111   base::TimeDelta best_match_time_until_expiration;
112   for (const EntryMap::const_iterator& candidate : candidates) {
113     DCHECK(candidate->second.result->timed_expiration().has_value());
114 
115     base::TimeDelta candidate_time_until_expiration =
116         candidate->second.TimeUntilExpiration(now, now_ticks);
117 
118     if (!candidate->second.IsStale(now, now_ticks, staleness_generation_) &&
119         (candidate->second.secure || !secure.value_or(true))) {
120       // If a non-stale candidate is secure, or all results are insecure, no
121       // need to check any more.
122       best_match = &candidate->second;
123       best_match_time_until_expiration = candidate_time_until_expiration;
124       break;
125     } else if (best_match == nullptr ||
126                (!candidate->second.IsStale(now, now_ticks,
127                                            staleness_generation_) &&
128                 best_match->IsStale(now, now_ticks, staleness_generation_)) ||
129                candidate->second.staleness_generation >
130                    best_match->staleness_generation ||
131                (candidate->second.staleness_generation ==
132                     best_match->staleness_generation &&
133                 candidate_time_until_expiration >
134                     best_match_time_until_expiration) ||
135                (candidate->second.staleness_generation ==
136                     best_match->staleness_generation &&
137                 candidate_time_until_expiration ==
138                     best_match_time_until_expiration &&
139                 candidate->second.secure && !best_match->secure)) {
140       best_match = &candidate->second;
141       best_match_time_until_expiration = candidate_time_until_expiration;
142     }
143   }
144 
145   if (best_match == nullptr) {
146     return std::nullopt;
147   } else {
148     std::optional<base::TimeDelta> expired_by;
149     if (best_match_time_until_expiration.is_negative()) {
150       expired_by = best_match_time_until_expiration.magnitude();
151     }
152     return StaleLookupResult(
153         *best_match->result, expired_by,
154         best_match->staleness_generation != staleness_generation_);
155   }
156 }
157 
Set(std::unique_ptr<HostResolverInternalResult> result,const NetworkAnonymizationKey & network_anonymization_key,HostResolverSource source,bool secure)158 void HostResolverCache::Set(
159     std::unique_ptr<HostResolverInternalResult> result,
160     const NetworkAnonymizationKey& network_anonymization_key,
161     HostResolverSource source,
162     bool secure) {
163   Set(std::move(result), network_anonymization_key, source, secure,
164       /*replace_existing=*/true, staleness_generation_);
165 }
166 
// Marks every current entry stale without removing any of them. An entry is
// stale whenever its staleness generation differs from the cache's current one
// (see Entry::IsStale()), so advancing the counter invalidates all entries at
// once. Such entries can still be returned by LookupStale(), flagged as
// stale-by-generation.
MakeAllResultsStale()167 void HostResolverCache::MakeAllResultsStale() {
168   ++staleness_generation_;
169 }
170 
Serialize() const171 base::Value HostResolverCache::Serialize() const {
172   // Do not serialize any entries without a persistable anonymization key
173   // because it is required to store and restore entries with the correct
174   // annonymization key. A non-persistable anonymization key is typically used
175   // for short-lived contexts, and associated entries are not expected to be
176   // useful after persistence to disk anyway.
177   return SerializeEntries(/*serialize_staleness_generation=*/false,
178                           /*require_persistable_anonymization_key=*/true);
179 }
180 
RestoreFromValue(const base::Value & value)181 bool HostResolverCache::RestoreFromValue(const base::Value& value) {
182   const base::Value::List* list = value.GetIfList();
183   if (!list) {
184     return false;
185   }
186 
187   for (const base::Value& list_value : *list) {
188     // Simply stop on reaching max size rather than attempting to figure out if
189     // any current entries should be evicted over the deserialized entries.
190     if (entries_.size() == max_entries_) {
191       return true;
192     }
193 
194     const base::Value::Dict* dict = list_value.GetIfDict();
195     if (!dict) {
196       return false;
197     }
198 
199     const base::Value* anonymization_key_value = dict->Find(kNakKey);
200     NetworkAnonymizationKey anonymization_key;
201     if (!anonymization_key_value ||
202         !NetworkAnonymizationKey::FromValue(*anonymization_key_value,
203                                             &anonymization_key)) {
204       return false;
205     }
206 
207     const base::Value* source_value = dict->Find(kSourceKey);
208     std::optional<HostResolverSource> source =
209         source_value == nullptr ? std::nullopt
210                                 : HostResolverSourceFromValue(*source_value);
211     if (!source.has_value()) {
212       return false;
213     }
214 
215     std::optional<bool> secure = dict->FindBool(kSecureKey);
216     if (!secure.has_value()) {
217       return false;
218     }
219 
220     const base::Value* result_value = dict->Find(kResultKey);
221     std::unique_ptr<HostResolverInternalResult> result =
222         result_value == nullptr
223             ? nullptr
224             : HostResolverInternalResult::FromValue(*result_value);
225     if (!result || !result->timed_expiration().has_value()) {
226       return false;
227     }
228 
229     // `staleness_generation_ - 1` to make entry stale-by-generation.
230     Set(std::move(result), anonymization_key, source.value(), secure.value(),
231         /*replace_existing=*/false, staleness_generation_ - 1);
232   }
233 
234   CHECK_LE(entries_.size(), max_entries_);
235   return true;
236 }
237 
SerializeForLogging() const238 base::Value HostResolverCache::SerializeForLogging() const {
239   base::Value::Dict dict;
240 
241   dict.Set(kMaxEntriesKey, base::checked_cast<int>(max_entries_));
242   dict.Set(kStalenessGenerationKey, staleness_generation_);
243 
244   // Include entries with non-persistable anonymization keys, so the log can
245   // contain all entries. Restoring from this serialization is not supported.
246   dict.Set(kEntriesKey,
247            SerializeEntries(/*serialize_staleness_generation=*/true,
248                             /*require_persistable_anonymization_key=*/false));
249 
250   return base::Value(std::move(dict));
251 }
252 
// A cached value: the owned `result`, the `source` it came from, whether it was
// obtained securely, and the staleness generation it was stored under (compared
// against the cache's current generation in IsStale()).
Entry(std::unique_ptr<HostResolverInternalResult> result,HostResolverSource source,bool secure,int staleness_generation)253 HostResolverCache::Entry::Entry(
254     std::unique_ptr<HostResolverInternalResult> result,
255     HostResolverSource source,
256     bool secure,
257     int staleness_generation)
258     : result(std::move(result)),
259       source(source),
260       secure(secure),
261       staleness_generation(staleness_generation) {}
262 
// Defaulted destructor and move operations for Entry, defined out of line.
// Entry owns its result via std::unique_ptr, so it is moved (not copied) into
// `entries_`.
263 HostResolverCache::Entry::~Entry() = default;
264 
265 HostResolverCache::Entry::Entry(Entry&&) = default;
266 
267 HostResolverCache::Entry& HostResolverCache::Entry::operator=(Entry&&) =
268     default;
269 
IsStale(base::Time now,base::TimeTicks now_ticks,int current_staleness_generation) const270 bool HostResolverCache::Entry::IsStale(base::Time now,
271                                        base::TimeTicks now_ticks,
272                                        int current_staleness_generation) const {
273   return staleness_generation != current_staleness_generation ||
274          TimeUntilExpiration(now, now_ticks).is_negative();
275 }
276 
TimeUntilExpiration(base::Time now,base::TimeTicks now_ticks) const277 base::TimeDelta HostResolverCache::Entry::TimeUntilExpiration(
278     base::Time now,
279     base::TimeTicks now_ticks) const {
280   if (result->expiration().has_value()) {
281     return result->expiration().value() - now_ticks;
282   } else {
283     DCHECK(result->timed_expiration().has_value());
284     return result->timed_expiration().value() - now;
285   }
286 }
287 
288 std::vector<HostResolverCache::EntryMap::const_iterator>
LookupInternal(std::string_view domain_name,const NetworkAnonymizationKey & network_anonymization_key,DnsQueryType query_type,HostResolverSource source,std::optional<bool> secure) const289 HostResolverCache::LookupInternal(
290     std::string_view domain_name,
291     const NetworkAnonymizationKey& network_anonymization_key,
292     DnsQueryType query_type,
293     HostResolverSource source,
294     std::optional<bool> secure) const {
295   auto matches = std::vector<EntryMap::const_iterator>();
296 
297   if (entries_.empty()) {
298     return matches;
299   }
300 
301   std::string canonicalized;
302   url::StdStringCanonOutput output(&canonicalized);
303   url::CanonHostInfo host_info;
304 
305   url::CanonicalizeHostVerbose(domain_name.data(),
306                                url::Component(0, domain_name.size()), &output,
307                                &host_info);
308 
309   // For performance, when canonicalization can't canonicalize, minimize string
310   // copies and just reuse the input StringPiece. This optimization prevents
311   // easily reusing a MaybeCanoncalize util with similar code.
312   std::string_view lookup_name = domain_name;
313   if (host_info.family == url::CanonHostInfo::Family::NEUTRAL) {
314     output.Complete();
315     lookup_name = canonicalized;
316   }
317 
318   auto range = entries_.equal_range(
319       KeyRef{lookup_name, raw_ref(network_anonymization_key)});
320   if (range.first == entries_.cend() || range.second == entries_.cbegin() ||
321       range.first == range.second) {
322     return matches;
323   }
324 
325   // Iterate in reverse order to return most-recently-added entry first.
326   auto it = --range.second;
327   while (true) {
328     if ((query_type == DnsQueryType::UNSPECIFIED ||
329          it->second.result->query_type() == DnsQueryType::UNSPECIFIED ||
330          query_type == it->second.result->query_type()) &&
331         (source == HostResolverSource::ANY || source == it->second.source) &&
332         (!secure.has_value() || secure.value() == it->second.secure)) {
333       matches.push_back(it);
334     }
335 
336     if (it == range.first) {
337       break;
338     }
339     --it;
340   }
341 
342   return matches;
343 }
344 
Set(std::unique_ptr<HostResolverInternalResult> result,const NetworkAnonymizationKey & network_anonymization_key,HostResolverSource source,bool secure,bool replace_existing,int staleness_generation)345 void HostResolverCache::Set(
346     std::unique_ptr<HostResolverInternalResult> result,
347     const NetworkAnonymizationKey& network_anonymization_key,
348     HostResolverSource source,
349     bool secure,
350     bool replace_existing,
351     int staleness_generation) {
352   DCHECK(result);
353   // Result must have at least a timed expiration to be a cacheable result.
354   DCHECK(result->timed_expiration().has_value());
355 
356   std::vector<EntryMap::const_iterator> matches =
357       LookupInternal(result->domain_name(), network_anonymization_key,
358                      result->query_type(), source, secure);
359 
360   if (!matches.empty() && !replace_existing) {
361     // Matches already present that are not to be replaced.
362     return;
363   }
364 
365   for (const EntryMap::const_iterator& match : matches) {
366     entries_.erase(match);
367   }
368 
369   std::string domain_name = result->domain_name();
370   entries_.emplace(
371       Key(std::move(domain_name), network_anonymization_key),
372       Entry(std::move(result), source, secure, staleness_generation));
373 
374   if (entries_.size() > max_entries_) {
375     EvictEntries();
376   }
377 }
378 
379 // Remove all stale entries, or if none stale, the soonest-to-expire,
380 // least-secure entry.
EvictEntries()381 void HostResolverCache::EvictEntries() {
382   base::TimeTicks now_ticks = tick_clock_->NowTicks();
383   base::Time now = clock_->Now();
384 
385   bool stale_found = false;
386   base::TimeDelta soonest_time_till_expriation = base::TimeDelta::Max();
387   std::optional<EntryMap::const_iterator> best_for_removal;
388 
389   auto it = entries_.cbegin();
390   while (it != entries_.cend()) {
391     if (it->second.IsStale(now, now_ticks, staleness_generation_)) {
392       stale_found = true;
393       it = entries_.erase(it);
394     } else {
395       base::TimeDelta time_till_expiration =
396           it->second.TimeUntilExpiration(now, now_ticks);
397 
398       if (!best_for_removal.has_value() ||
399           time_till_expiration < soonest_time_till_expriation ||
400           (time_till_expiration == soonest_time_till_expriation &&
401            best_for_removal.value()->second.secure && !it->second.secure)) {
402         soonest_time_till_expriation = time_till_expiration;
403         best_for_removal = it;
404       }
405 
406       ++it;
407     }
408   }
409 
410   if (!stale_found) {
411     CHECK(best_for_removal.has_value());
412     entries_.erase(best_for_removal.value());
413   }
414 
415   CHECK_LE(entries_.size(), max_entries_);
416 }
417 
SerializeEntries(bool serialize_staleness_generation,bool require_persistable_anonymization_key) const418 base::Value HostResolverCache::SerializeEntries(
419     bool serialize_staleness_generation,
420     bool require_persistable_anonymization_key) const {
421   base::Value::List list;
422 
423   for (const auto& [key, entry] : entries_) {
424     base::Value::Dict dict;
425 
426     if (serialize_staleness_generation) {
427       dict.Set(kStalenessGenerationKey, entry.staleness_generation);
428     }
429 
430     base::Value anonymization_key_value;
431     if (!key.network_anonymization_key.ToValue(&anonymization_key_value)) {
432       if (require_persistable_anonymization_key) {
433         continue;
434       } else {
435         // If the caller doesn't care about anonymization keys that can be
436         // serialized and restored, construct a serialization just for the sake
437         // of logging information.
438         anonymization_key_value =
439             base::Value("Non-persistable network anonymization key: " +
440                         key.network_anonymization_key.ToDebugString());
441       }
442     }
443 
444     dict.Set(kNakKey, std::move(anonymization_key_value));
445     dict.Set(kSourceKey, ToValue(entry.source));
446     dict.Set(kSecureKey, entry.secure);
447     dict.Set(kResultKey, entry.result->ToValue());
448 
449     list.Append(std::move(dict));
450   }
451 
452   return base::Value(std::move(list));
453 }
454 
455 }  // namespace net
456