1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/mdns_cache.h"
6
7 #include <algorithm>
8 #include <tuple>
9 #include <utility>
10
11 #include "base/containers/contains.h"
12 #include "base/strings/string_number_conversions.h"
13 #include "base/strings/string_util.h"
14 #include "net/dns/public/dns_protocol.h"
15 #include "net/dns/record_parsed.h"
16 #include "net/dns/record_rdata.h"
17
18 // TODO(noamsml): Recursive CNAME closure (backwards and forwards).
19
20 namespace net {
21
namespace {

// Entry limit installed by the MDnsCache constructor. IsCacheOverfilled()
// reports true once the number of cached records exceeds this value.
constexpr size_t kDefaultEntryLimit = 100'000;

// The effective TTL given to records with a nominal zero TTL.
// Allows time for hosts to send updated records, as detailed in RFC 6762
// Section 10.1.
constexpr unsigned kZeroTTLSeconds = 1;

}  // namespace
30
Key(unsigned type,const std::string & name,const std::string & optional)31 MDnsCache::Key::Key(unsigned type,
32 const std::string& name,
33 const std::string& optional)
34 : type_(type),
35 name_lowercase_(base::ToLowerASCII(name)),
36 optional_(optional) {}
37
38 MDnsCache::Key::Key(const MDnsCache::Key& other) = default;
39
40 MDnsCache::Key& MDnsCache::Key::operator=(const MDnsCache::Key& other) =
41 default;
42
43 MDnsCache::Key::~Key() = default;
44
operator <(const MDnsCache::Key & other) const45 bool MDnsCache::Key::operator<(const MDnsCache::Key& other) const {
46 return std::tie(name_lowercase_, type_, optional_) <
47 std::tie(other.name_lowercase_, other.type_, other.optional_);
48 }
49
operator ==(const MDnsCache::Key & key) const50 bool MDnsCache::Key::operator==(const MDnsCache::Key& key) const {
51 return type_ == key.type_ && name_lowercase_ == key.name_lowercase_ &&
52 optional_ == key.optional_;
53 }
54
55 // static
CreateFor(const RecordParsed * record)56 MDnsCache::Key MDnsCache::Key::CreateFor(const RecordParsed* record) {
57 return Key(record->type(),
58 record->name(),
59 GetOptionalFieldForRecord(record));
60 }
61
MDnsCache()62 MDnsCache::MDnsCache() : entry_limit_(kDefaultEntryLimit) {}
63
64 MDnsCache::~MDnsCache() = default;
65
LookupKey(const Key & key)66 const RecordParsed* MDnsCache::LookupKey(const Key& key) {
67 auto found = mdns_cache_.find(key);
68 if (found != mdns_cache_.end()) {
69 return found->second.get();
70 }
71 return nullptr;
72 }
73
UpdateDnsRecord(std::unique_ptr<const RecordParsed> record)74 MDnsCache::UpdateType MDnsCache::UpdateDnsRecord(
75 std::unique_ptr<const RecordParsed> record) {
76 Key cache_key = Key::CreateFor(record.get());
77
78 // Ignore "goodbye" packets for records not in cache.
79 if (record->ttl() == 0 && !base::Contains(mdns_cache_, cache_key)) {
80 return NoChange;
81 }
82
83 base::Time new_expiration = GetEffectiveExpiration(record.get());
84 if (next_expiration_ != base::Time())
85 new_expiration = std::min(new_expiration, next_expiration_);
86
87 std::pair<RecordMap::iterator, bool> insert_result =
88 mdns_cache_.emplace(cache_key, nullptr);
89 UpdateType type = NoChange;
90 if (insert_result.second) {
91 type = RecordAdded;
92 } else {
93 if (record->ttl() != 0 &&
94 !record->IsEqual(insert_result.first->second.get(), true)) {
95 type = RecordChanged;
96 }
97 }
98
99 insert_result.first->second = std::move(record);
100 next_expiration_ = new_expiration;
101 return type;
102 }
103
CleanupRecords(base::Time now,const RecordRemovedCallback & record_removed_callback)104 void MDnsCache::CleanupRecords(
105 base::Time now,
106 const RecordRemovedCallback& record_removed_callback) {
107 base::Time next_expiration;
108
109 // TODO(crbug.com/946688): Make overfill pruning more intelligent than a bulk
110 // clearing of everything.
111 bool clear_cache = IsCacheOverfilled();
112
113 // We are guaranteed that |next_expiration_| will be at or before the next
114 // expiration. This allows clients to eagrely call CleanupRecords with
115 // impunity.
116 if (now < next_expiration_ && !clear_cache)
117 return;
118
119 for (auto i = mdns_cache_.begin(); i != mdns_cache_.end();) {
120 base::Time expiration = GetEffectiveExpiration(i->second.get());
121 if (clear_cache || now >= expiration) {
122 record_removed_callback.Run(i->second.get());
123 i = mdns_cache_.erase(i);
124 } else {
125 if (next_expiration == base::Time() || expiration < next_expiration) {
126 next_expiration = expiration;
127 }
128 ++i;
129 }
130 }
131
132 next_expiration_ = next_expiration;
133 }
134
FindDnsRecords(unsigned type,const std::string & name,std::vector<const RecordParsed * > * results,base::Time now) const135 void MDnsCache::FindDnsRecords(unsigned type,
136 const std::string& name,
137 std::vector<const RecordParsed*>* results,
138 base::Time now) const {
139 DCHECK(results);
140 results->clear();
141
142 const std::string name_lowercase = base::ToLowerASCII(name);
143 auto i = mdns_cache_.lower_bound(Key(type, name, ""));
144 for (; i != mdns_cache_.end(); ++i) {
145 if (i->first.name_lowercase() != name_lowercase ||
146 (type != 0 && i->first.type() != type)) {
147 break;
148 }
149
150 const RecordParsed* record = i->second.get();
151
152 // Records are deleted only upon request.
153 if (now >= GetEffectiveExpiration(record)) continue;
154
155 results->push_back(record);
156 }
157 }
158
RemoveRecord(const RecordParsed * record)159 std::unique_ptr<const RecordParsed> MDnsCache::RemoveRecord(
160 const RecordParsed* record) {
161 Key key = Key::CreateFor(record);
162 auto found = mdns_cache_.find(key);
163
164 if (found != mdns_cache_.end() && found->second.get() == record) {
165 std::unique_ptr<const RecordParsed> result = std::move(found->second);
166 mdns_cache_.erase(key);
167 return result;
168 }
169
170 return nullptr;
171 }
172
IsCacheOverfilled() const173 bool MDnsCache::IsCacheOverfilled() const {
174 return mdns_cache_.size() > entry_limit_;
175 }
176
177 // static
GetOptionalFieldForRecord(const RecordParsed * record)178 std::string MDnsCache::GetOptionalFieldForRecord(const RecordParsed* record) {
179 switch (record->type()) {
180 case PtrRecordRdata::kType: {
181 const PtrRecordRdata* rdata = record->rdata<PtrRecordRdata>();
182 return rdata->ptrdomain();
183 }
184 default: // Most records are considered unique for our purposes
185 return "";
186 }
187 }
188
189 // static
GetEffectiveExpiration(const RecordParsed * record)190 base::Time MDnsCache::GetEffectiveExpiration(const RecordParsed* record) {
191 base::TimeDelta ttl;
192
193 if (record->ttl()) {
194 ttl = base::Seconds(record->ttl());
195 } else {
196 ttl = base::Seconds(kZeroTTLSeconds);
197 }
198
199 return record->time_created() + ttl;
200 }
201
202 } // namespace net
203