1 // Copyright 2023 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/host_resolver_cache.h"
6
7 #include <cstddef>
8 #include <memory>
9 #include <optional>
10 #include <string>
11 #include <string_view>
12 #include <utility>
13 #include <vector>
14
15 #include "base/check_op.h"
16 #include "base/numerics/safe_conversions.h"
17 #include "base/time/clock.h"
18 #include "base/time/time.h"
19 #include "net/base/network_anonymization_key.h"
20 #include "net/dns/host_resolver_internal_result.h"
21 #include "net/dns/public/dns_query_type.h"
22 #include "net/dns/public/host_resolver_source.h"
23 #include "url/third_party/mozilla/url_parse.h"
24 #include "url/url_canon.h"
25 #include "url/url_canon_stdstring.h"
26
27 namespace net {
28
namespace {

// Dictionary keys shared by Serialize()/SerializeEntries(), RestoreFromValue()
// and SerializeForLogging(). Listed alphabetically.
constexpr std::string_view kEntriesKey{"entries"};
constexpr std::string_view kMaxEntriesKey{"max_entries"};
constexpr std::string_view kNakKey{"network_anonymization_key"};
constexpr std::string_view kResultKey{"result"};
constexpr std::string_view kSecureKey{"secure"};
constexpr std::string_view kSourceKey{"source"};
constexpr std::string_view kStalenessGenerationKey{"staleness_generation"};

}  // namespace
40
// Out-of-line defaulted destructor for HostResolverCache::Key.
HostResolverCache::Key::~Key() = default;
42
StaleLookupResult(const HostResolverInternalResult & result,std::optional<base::TimeDelta> expired_by,bool stale_by_generation)43 HostResolverCache::StaleLookupResult::StaleLookupResult(
44 const HostResolverInternalResult& result,
45 std::optional<base::TimeDelta> expired_by,
46 bool stale_by_generation)
47 : result(result),
48 expired_by(expired_by),
49 stale_by_generation(stale_by_generation) {}
50
HostResolverCache(size_t max_results,const base::Clock & clock,const base::TickClock & tick_clock)51 HostResolverCache::HostResolverCache(size_t max_results,
52 const base::Clock& clock,
53 const base::TickClock& tick_clock)
54 : max_entries_(max_results), clock_(clock), tick_clock_(tick_clock) {
55 DCHECK_GT(max_entries_, 0u);
56 }
57
// Out-of-line defaulted destructor.
HostResolverCache::~HostResolverCache() = default;

// Defaulted (member-wise) move construction.
HostResolverCache::HostResolverCache(HostResolverCache&&) = default;

// Defaulted (member-wise) move assignment.
HostResolverCache& HostResolverCache::operator=(HostResolverCache&&) = default;
63
Lookup(std::string_view domain_name,const NetworkAnonymizationKey & network_anonymization_key,DnsQueryType query_type,HostResolverSource source,std::optional<bool> secure) const64 const HostResolverInternalResult* HostResolverCache::Lookup(
65 std::string_view domain_name,
66 const NetworkAnonymizationKey& network_anonymization_key,
67 DnsQueryType query_type,
68 HostResolverSource source,
69 std::optional<bool> secure) const {
70 std::vector<EntryMap::const_iterator> candidates = LookupInternal(
71 domain_name, network_anonymization_key, query_type, source, secure);
72
73 // Get the most secure, last-matching (which is first in the vector returned
74 // by LookupInternal()) non-expired result.
75 base::TimeTicks now_ticks = tick_clock_->NowTicks();
76 base::Time now = clock_->Now();
77 HostResolverInternalResult* most_secure_result = nullptr;
78 for (const EntryMap::const_iterator& candidate : candidates) {
79 DCHECK(candidate->second.result->timed_expiration().has_value());
80
81 if (candidate->second.IsStale(now, now_ticks, staleness_generation_)) {
82 continue;
83 }
84
85 // If the candidate is secure, or all results are insecure, no need to check
86 // any more.
87 if (candidate->second.secure || !secure.value_or(true)) {
88 return candidate->second.result.get();
89 } else if (most_secure_result == nullptr) {
90 most_secure_result = candidate->second.result.get();
91 }
92 }
93
94 return most_secure_result;
95 }
96
97 std::optional<HostResolverCache::StaleLookupResult>
LookupStale(std::string_view domain_name,const NetworkAnonymizationKey & network_anonymization_key,DnsQueryType query_type,HostResolverSource source,std::optional<bool> secure) const98 HostResolverCache::LookupStale(
99 std::string_view domain_name,
100 const NetworkAnonymizationKey& network_anonymization_key,
101 DnsQueryType query_type,
102 HostResolverSource source,
103 std::optional<bool> secure) const {
104 std::vector<EntryMap::const_iterator> candidates = LookupInternal(
105 domain_name, network_anonymization_key, query_type, source, secure);
106
107 // Get the least expired, most secure result.
108 base::TimeTicks now_ticks = tick_clock_->NowTicks();
109 base::Time now = clock_->Now();
110 const Entry* best_match = nullptr;
111 base::TimeDelta best_match_time_until_expiration;
112 for (const EntryMap::const_iterator& candidate : candidates) {
113 DCHECK(candidate->second.result->timed_expiration().has_value());
114
115 base::TimeDelta candidate_time_until_expiration =
116 candidate->second.TimeUntilExpiration(now, now_ticks);
117
118 if (!candidate->second.IsStale(now, now_ticks, staleness_generation_) &&
119 (candidate->second.secure || !secure.value_or(true))) {
120 // If a non-stale candidate is secure, or all results are insecure, no
121 // need to check any more.
122 best_match = &candidate->second;
123 best_match_time_until_expiration = candidate_time_until_expiration;
124 break;
125 } else if (best_match == nullptr ||
126 (!candidate->second.IsStale(now, now_ticks,
127 staleness_generation_) &&
128 best_match->IsStale(now, now_ticks, staleness_generation_)) ||
129 candidate->second.staleness_generation >
130 best_match->staleness_generation ||
131 (candidate->second.staleness_generation ==
132 best_match->staleness_generation &&
133 candidate_time_until_expiration >
134 best_match_time_until_expiration) ||
135 (candidate->second.staleness_generation ==
136 best_match->staleness_generation &&
137 candidate_time_until_expiration ==
138 best_match_time_until_expiration &&
139 candidate->second.secure && !best_match->secure)) {
140 best_match = &candidate->second;
141 best_match_time_until_expiration = candidate_time_until_expiration;
142 }
143 }
144
145 if (best_match == nullptr) {
146 return std::nullopt;
147 } else {
148 std::optional<base::TimeDelta> expired_by;
149 if (best_match_time_until_expiration.is_negative()) {
150 expired_by = best_match_time_until_expiration.magnitude();
151 }
152 return StaleLookupResult(
153 *best_match->result, expired_by,
154 best_match->staleness_generation != staleness_generation_);
155 }
156 }
157
Set(std::unique_ptr<HostResolverInternalResult> result,const NetworkAnonymizationKey & network_anonymization_key,HostResolverSource source,bool secure)158 void HostResolverCache::Set(
159 std::unique_ptr<HostResolverInternalResult> result,
160 const NetworkAnonymizationKey& network_anonymization_key,
161 HostResolverSource source,
162 bool secure) {
163 Set(std::move(result), network_anonymization_key, source, secure,
164 /*replace_existing=*/true, staleness_generation_);
165 }
166
MakeAllResultsStale()167 void HostResolverCache::MakeAllResultsStale() {
168 ++staleness_generation_;
169 }
170
Serialize() const171 base::Value HostResolverCache::Serialize() const {
172 // Do not serialize any entries without a persistable anonymization key
173 // because it is required to store and restore entries with the correct
174 // annonymization key. A non-persistable anonymization key is typically used
175 // for short-lived contexts, and associated entries are not expected to be
176 // useful after persistence to disk anyway.
177 return SerializeEntries(/*serialize_staleness_generation=*/false,
178 /*require_persistable_anonymization_key=*/true);
179 }
180
RestoreFromValue(const base::Value & value)181 bool HostResolverCache::RestoreFromValue(const base::Value& value) {
182 const base::Value::List* list = value.GetIfList();
183 if (!list) {
184 return false;
185 }
186
187 for (const base::Value& list_value : *list) {
188 // Simply stop on reaching max size rather than attempting to figure out if
189 // any current entries should be evicted over the deserialized entries.
190 if (entries_.size() == max_entries_) {
191 return true;
192 }
193
194 const base::Value::Dict* dict = list_value.GetIfDict();
195 if (!dict) {
196 return false;
197 }
198
199 const base::Value* anonymization_key_value = dict->Find(kNakKey);
200 NetworkAnonymizationKey anonymization_key;
201 if (!anonymization_key_value ||
202 !NetworkAnonymizationKey::FromValue(*anonymization_key_value,
203 &anonymization_key)) {
204 return false;
205 }
206
207 const base::Value* source_value = dict->Find(kSourceKey);
208 std::optional<HostResolverSource> source =
209 source_value == nullptr ? std::nullopt
210 : HostResolverSourceFromValue(*source_value);
211 if (!source.has_value()) {
212 return false;
213 }
214
215 std::optional<bool> secure = dict->FindBool(kSecureKey);
216 if (!secure.has_value()) {
217 return false;
218 }
219
220 const base::Value* result_value = dict->Find(kResultKey);
221 std::unique_ptr<HostResolverInternalResult> result =
222 result_value == nullptr
223 ? nullptr
224 : HostResolverInternalResult::FromValue(*result_value);
225 if (!result || !result->timed_expiration().has_value()) {
226 return false;
227 }
228
229 // `staleness_generation_ - 1` to make entry stale-by-generation.
230 Set(std::move(result), anonymization_key, source.value(), secure.value(),
231 /*replace_existing=*/false, staleness_generation_ - 1);
232 }
233
234 CHECK_LE(entries_.size(), max_entries_);
235 return true;
236 }
237
SerializeForLogging() const238 base::Value HostResolverCache::SerializeForLogging() const {
239 base::Value::Dict dict;
240
241 dict.Set(kMaxEntriesKey, base::checked_cast<int>(max_entries_));
242 dict.Set(kStalenessGenerationKey, staleness_generation_);
243
244 // Include entries with non-persistable anonymization keys, so the log can
245 // contain all entries. Restoring from this serialization is not supported.
246 dict.Set(kEntriesKey,
247 SerializeEntries(/*serialize_staleness_generation=*/true,
248 /*require_persistable_anonymization_key=*/false));
249
250 return base::Value(std::move(dict));
251 }
252
Entry(std::unique_ptr<HostResolverInternalResult> result,HostResolverSource source,bool secure,int staleness_generation)253 HostResolverCache::Entry::Entry(
254 std::unique_ptr<HostResolverInternalResult> result,
255 HostResolverSource source,
256 bool secure,
257 int staleness_generation)
258 : result(std::move(result)),
259 source(source),
260 secure(secure),
261 staleness_generation(staleness_generation) {}
262
// Out-of-line defaulted destructor.
HostResolverCache::Entry::~Entry() = default;

// Defaulted (member-wise) move construction.
HostResolverCache::Entry::Entry(Entry&&) = default;

// Defaulted (member-wise) move assignment.
HostResolverCache::Entry& HostResolverCache::Entry::operator=(Entry&&) =
    default;
269
IsStale(base::Time now,base::TimeTicks now_ticks,int current_staleness_generation) const270 bool HostResolverCache::Entry::IsStale(base::Time now,
271 base::TimeTicks now_ticks,
272 int current_staleness_generation) const {
273 return staleness_generation != current_staleness_generation ||
274 TimeUntilExpiration(now, now_ticks).is_negative();
275 }
276
TimeUntilExpiration(base::Time now,base::TimeTicks now_ticks) const277 base::TimeDelta HostResolverCache::Entry::TimeUntilExpiration(
278 base::Time now,
279 base::TimeTicks now_ticks) const {
280 if (result->expiration().has_value()) {
281 return result->expiration().value() - now_ticks;
282 } else {
283 DCHECK(result->timed_expiration().has_value());
284 return result->timed_expiration().value() - now;
285 }
286 }
287
288 std::vector<HostResolverCache::EntryMap::const_iterator>
LookupInternal(std::string_view domain_name,const NetworkAnonymizationKey & network_anonymization_key,DnsQueryType query_type,HostResolverSource source,std::optional<bool> secure) const289 HostResolverCache::LookupInternal(
290 std::string_view domain_name,
291 const NetworkAnonymizationKey& network_anonymization_key,
292 DnsQueryType query_type,
293 HostResolverSource source,
294 std::optional<bool> secure) const {
295 auto matches = std::vector<EntryMap::const_iterator>();
296
297 if (entries_.empty()) {
298 return matches;
299 }
300
301 std::string canonicalized;
302 url::StdStringCanonOutput output(&canonicalized);
303 url::CanonHostInfo host_info;
304
305 url::CanonicalizeHostVerbose(domain_name.data(),
306 url::Component(0, domain_name.size()), &output,
307 &host_info);
308
309 // For performance, when canonicalization can't canonicalize, minimize string
310 // copies and just reuse the input StringPiece. This optimization prevents
311 // easily reusing a MaybeCanoncalize util with similar code.
312 std::string_view lookup_name = domain_name;
313 if (host_info.family == url::CanonHostInfo::Family::NEUTRAL) {
314 output.Complete();
315 lookup_name = canonicalized;
316 }
317
318 auto range = entries_.equal_range(
319 KeyRef{lookup_name, raw_ref(network_anonymization_key)});
320 if (range.first == entries_.cend() || range.second == entries_.cbegin() ||
321 range.first == range.second) {
322 return matches;
323 }
324
325 // Iterate in reverse order to return most-recently-added entry first.
326 auto it = --range.second;
327 while (true) {
328 if ((query_type == DnsQueryType::UNSPECIFIED ||
329 it->second.result->query_type() == DnsQueryType::UNSPECIFIED ||
330 query_type == it->second.result->query_type()) &&
331 (source == HostResolverSource::ANY || source == it->second.source) &&
332 (!secure.has_value() || secure.value() == it->second.secure)) {
333 matches.push_back(it);
334 }
335
336 if (it == range.first) {
337 break;
338 }
339 --it;
340 }
341
342 return matches;
343 }
344
Set(std::unique_ptr<HostResolverInternalResult> result,const NetworkAnonymizationKey & network_anonymization_key,HostResolverSource source,bool secure,bool replace_existing,int staleness_generation)345 void HostResolverCache::Set(
346 std::unique_ptr<HostResolverInternalResult> result,
347 const NetworkAnonymizationKey& network_anonymization_key,
348 HostResolverSource source,
349 bool secure,
350 bool replace_existing,
351 int staleness_generation) {
352 DCHECK(result);
353 // Result must have at least a timed expiration to be a cacheable result.
354 DCHECK(result->timed_expiration().has_value());
355
356 std::vector<EntryMap::const_iterator> matches =
357 LookupInternal(result->domain_name(), network_anonymization_key,
358 result->query_type(), source, secure);
359
360 if (!matches.empty() && !replace_existing) {
361 // Matches already present that are not to be replaced.
362 return;
363 }
364
365 for (const EntryMap::const_iterator& match : matches) {
366 entries_.erase(match);
367 }
368
369 std::string domain_name = result->domain_name();
370 entries_.emplace(
371 Key(std::move(domain_name), network_anonymization_key),
372 Entry(std::move(result), source, secure, staleness_generation));
373
374 if (entries_.size() > max_entries_) {
375 EvictEntries();
376 }
377 }
378
379 // Remove all stale entries, or if none stale, the soonest-to-expire,
380 // least-secure entry.
EvictEntries()381 void HostResolverCache::EvictEntries() {
382 base::TimeTicks now_ticks = tick_clock_->NowTicks();
383 base::Time now = clock_->Now();
384
385 bool stale_found = false;
386 base::TimeDelta soonest_time_till_expriation = base::TimeDelta::Max();
387 std::optional<EntryMap::const_iterator> best_for_removal;
388
389 auto it = entries_.cbegin();
390 while (it != entries_.cend()) {
391 if (it->second.IsStale(now, now_ticks, staleness_generation_)) {
392 stale_found = true;
393 it = entries_.erase(it);
394 } else {
395 base::TimeDelta time_till_expiration =
396 it->second.TimeUntilExpiration(now, now_ticks);
397
398 if (!best_for_removal.has_value() ||
399 time_till_expiration < soonest_time_till_expriation ||
400 (time_till_expiration == soonest_time_till_expriation &&
401 best_for_removal.value()->second.secure && !it->second.secure)) {
402 soonest_time_till_expriation = time_till_expiration;
403 best_for_removal = it;
404 }
405
406 ++it;
407 }
408 }
409
410 if (!stale_found) {
411 CHECK(best_for_removal.has_value());
412 entries_.erase(best_for_removal.value());
413 }
414
415 CHECK_LE(entries_.size(), max_entries_);
416 }
417
SerializeEntries(bool serialize_staleness_generation,bool require_persistable_anonymization_key) const418 base::Value HostResolverCache::SerializeEntries(
419 bool serialize_staleness_generation,
420 bool require_persistable_anonymization_key) const {
421 base::Value::List list;
422
423 for (const auto& [key, entry] : entries_) {
424 base::Value::Dict dict;
425
426 if (serialize_staleness_generation) {
427 dict.Set(kStalenessGenerationKey, entry.staleness_generation);
428 }
429
430 base::Value anonymization_key_value;
431 if (!key.network_anonymization_key.ToValue(&anonymization_key_value)) {
432 if (require_persistable_anonymization_key) {
433 continue;
434 } else {
435 // If the caller doesn't care about anonymization keys that can be
436 // serialized and restored, construct a serialization just for the sake
437 // of logging information.
438 anonymization_key_value =
439 base::Value("Non-persistable network anonymization key: " +
440 key.network_anonymization_key.ToDebugString());
441 }
442 }
443
444 dict.Set(kNakKey, std::move(anonymization_key_value));
445 dict.Set(kSourceKey, ToValue(entry.source));
446 dict.Set(kSecureKey, entry.secure);
447 dict.Set(kResultKey, entry.result->ToValue());
448
449 list.Append(std::move(dict));
450 }
451
452 return base::Value(std::move(list));
453 }
454
455 } // namespace net
456