xref: /aosp_15_r20/external/cronet/net/dns/mdns_cache.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/mdns_cache.h"
6 
7 #include <algorithm>
8 #include <tuple>
9 #include <utility>
10 
11 #include "base/containers/contains.h"
12 #include "base/strings/string_number_conversions.h"
13 #include "base/strings/string_util.h"
14 #include "net/dns/public/dns_protocol.h"
15 #include "net/dns/record_parsed.h"
16 #include "net/dns/record_rdata.h"
17 
18 // TODO(noamsml): Recursive CNAME closure (backwards and forwards).
19 
20 namespace net {
21 
namespace {

// Initial value of |entry_limit_|; IsCacheOverfilled() compares the number of
// cached entries against that limit.
constexpr size_t kDefaultEntryLimit = 100'000;

// The effective TTL given to records with a nominal zero TTL.
// Allows time for hosts to send updated records, as detailed in RFC 6762
// Section 10.1.
constexpr unsigned kZeroTTLSeconds = 1;

}  // namespace
30 
Key(unsigned type,const std::string & name,const std::string & optional)31 MDnsCache::Key::Key(unsigned type,
32                     const std::string& name,
33                     const std::string& optional)
34     : type_(type),
35       name_lowercase_(base::ToLowerASCII(name)),
36       optional_(optional) {}
37 
38 MDnsCache::Key::Key(const MDnsCache::Key& other) = default;
39 
40 MDnsCache::Key& MDnsCache::Key::operator=(const MDnsCache::Key& other) =
41     default;
42 
43 MDnsCache::Key::~Key() = default;
44 
operator <(const MDnsCache::Key & other) const45 bool MDnsCache::Key::operator<(const MDnsCache::Key& other) const {
46   return std::tie(name_lowercase_, type_, optional_) <
47          std::tie(other.name_lowercase_, other.type_, other.optional_);
48 }
49 
operator ==(const MDnsCache::Key & key) const50 bool MDnsCache::Key::operator==(const MDnsCache::Key& key) const {
51   return type_ == key.type_ && name_lowercase_ == key.name_lowercase_ &&
52          optional_ == key.optional_;
53 }
54 
55 // static
CreateFor(const RecordParsed * record)56 MDnsCache::Key MDnsCache::Key::CreateFor(const RecordParsed* record) {
57   return Key(record->type(),
58              record->name(),
59              GetOptionalFieldForRecord(record));
60 }
61 
MDnsCache()62 MDnsCache::MDnsCache() : entry_limit_(kDefaultEntryLimit) {}
63 
64 MDnsCache::~MDnsCache() = default;
65 
LookupKey(const Key & key)66 const RecordParsed* MDnsCache::LookupKey(const Key& key) {
67   auto found = mdns_cache_.find(key);
68   if (found != mdns_cache_.end()) {
69     return found->second.get();
70   }
71   return nullptr;
72 }
73 
UpdateDnsRecord(std::unique_ptr<const RecordParsed> record)74 MDnsCache::UpdateType MDnsCache::UpdateDnsRecord(
75     std::unique_ptr<const RecordParsed> record) {
76   Key cache_key = Key::CreateFor(record.get());
77 
78   // Ignore "goodbye" packets for records not in cache.
79   if (record->ttl() == 0 && !base::Contains(mdns_cache_, cache_key)) {
80     return NoChange;
81   }
82 
83   base::Time new_expiration = GetEffectiveExpiration(record.get());
84   if (next_expiration_ != base::Time())
85     new_expiration = std::min(new_expiration, next_expiration_);
86 
87   std::pair<RecordMap::iterator, bool> insert_result =
88       mdns_cache_.emplace(cache_key, nullptr);
89   UpdateType type = NoChange;
90   if (insert_result.second) {
91     type = RecordAdded;
92   } else {
93     if (record->ttl() != 0 &&
94         !record->IsEqual(insert_result.first->second.get(), true)) {
95       type = RecordChanged;
96     }
97   }
98 
99   insert_result.first->second = std::move(record);
100   next_expiration_ = new_expiration;
101   return type;
102 }
103 
CleanupRecords(base::Time now,const RecordRemovedCallback & record_removed_callback)104 void MDnsCache::CleanupRecords(
105     base::Time now,
106     const RecordRemovedCallback& record_removed_callback) {
107   base::Time next_expiration;
108 
109   // TODO(crbug.com/946688): Make overfill pruning more intelligent than a bulk
110   // clearing of everything.
111   bool clear_cache = IsCacheOverfilled();
112 
113   // We are guaranteed that |next_expiration_| will be at or before the next
114   // expiration. This allows clients to eagrely call CleanupRecords with
115   // impunity.
116   if (now < next_expiration_ && !clear_cache)
117     return;
118 
119   for (auto i = mdns_cache_.begin(); i != mdns_cache_.end();) {
120     base::Time expiration = GetEffectiveExpiration(i->second.get());
121     if (clear_cache || now >= expiration) {
122       record_removed_callback.Run(i->second.get());
123       i = mdns_cache_.erase(i);
124     } else {
125       if (next_expiration == base::Time() ||  expiration < next_expiration) {
126         next_expiration = expiration;
127       }
128       ++i;
129     }
130   }
131 
132   next_expiration_ = next_expiration;
133 }
134 
FindDnsRecords(unsigned type,const std::string & name,std::vector<const RecordParsed * > * results,base::Time now) const135 void MDnsCache::FindDnsRecords(unsigned type,
136                                const std::string& name,
137                                std::vector<const RecordParsed*>* results,
138                                base::Time now) const {
139   DCHECK(results);
140   results->clear();
141 
142   const std::string name_lowercase = base::ToLowerASCII(name);
143   auto i = mdns_cache_.lower_bound(Key(type, name, ""));
144   for (; i != mdns_cache_.end(); ++i) {
145     if (i->first.name_lowercase() != name_lowercase ||
146         (type != 0 && i->first.type() != type)) {
147       break;
148     }
149 
150     const RecordParsed* record = i->second.get();
151 
152     // Records are deleted only upon request.
153     if (now >= GetEffectiveExpiration(record)) continue;
154 
155     results->push_back(record);
156   }
157 }
158 
RemoveRecord(const RecordParsed * record)159 std::unique_ptr<const RecordParsed> MDnsCache::RemoveRecord(
160     const RecordParsed* record) {
161   Key key = Key::CreateFor(record);
162   auto found = mdns_cache_.find(key);
163 
164   if (found != mdns_cache_.end() && found->second.get() == record) {
165     std::unique_ptr<const RecordParsed> result = std::move(found->second);
166     mdns_cache_.erase(key);
167     return result;
168   }
169 
170   return nullptr;
171 }
172 
IsCacheOverfilled() const173 bool MDnsCache::IsCacheOverfilled() const {
174   return mdns_cache_.size() > entry_limit_;
175 }
176 
177 // static
GetOptionalFieldForRecord(const RecordParsed * record)178 std::string MDnsCache::GetOptionalFieldForRecord(const RecordParsed* record) {
179   switch (record->type()) {
180     case PtrRecordRdata::kType: {
181       const PtrRecordRdata* rdata = record->rdata<PtrRecordRdata>();
182       return rdata->ptrdomain();
183     }
184     default:  // Most records are considered unique for our purposes
185       return "";
186   }
187 }
188 
189 // static
GetEffectiveExpiration(const RecordParsed * record)190 base::Time MDnsCache::GetEffectiveExpiration(const RecordParsed* record) {
191   base::TimeDelta ttl;
192 
193   if (record->ttl()) {
194     ttl = base::Seconds(record->ttl());
195   } else {
196     ttl = base::Seconds(kZeroTTLSeconds);
197   }
198 
199   return record->time_created() + ttl;
200 }
201 
202 }  // namespace net
203