1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/mdns_client_impl.h"
6
7 #include <algorithm>
8 #include <cstdint>
9 #include <memory>
10 #include <optional>
11 #include <utility>
12 #include <vector>
13
14 #include "base/containers/fixed_flat_set.h"
15 #include "base/functional/bind.h"
16 #include "base/location.h"
17 #include "base/metrics/histogram_functions.h"
18 #include "base/observer_list.h"
19 #include "base/ranges/algorithm.h"
20 #include "base/strings/string_util.h"
21 #include "base/task/single_thread_task_runner.h"
22 #include "base/time/clock.h"
23 #include "base/time/default_clock.h"
24 #include "base/time/time.h"
25 #include "base/timer/timer.h"
26 #include "net/base/net_errors.h"
27 #include "net/base/rand_callback.h"
28 #include "net/dns/dns_names_util.h"
29 #include "net/dns/public/dns_protocol.h"
30 #include "net/dns/public/util.h"
31 #include "net/dns/record_rdata.h"
32 #include "net/socket/datagram_socket.h"
33
34 // TODO(gene): Remove this temporary method of disabling NSEC support once it
35 // becomes clear whether this feature should be
36 // supported. http://crbug.com/255232
37 #define ENABLE_NSEC
38
39 namespace net {
40
41 namespace {
42
// An active listener (one on which |SetActiveRefresh(true)| was called)
// re-queries the network twice before a record expires: once when 85% of the
// record's original TTL has elapsed, and again when 95% has elapsed.
constexpr double kListenerRefreshRatio1 = 0.85;
constexpr double kListenerRefreshRatio2 = 0.95;
49
// Kind of mDNS query counted by RecordQueryMetric().
//
// These values are persisted to logs. Entries should not be renumbered and
// numeric values should never be reused.
// NOTE(review): the lowercase type name departs from the PascalCase naming
// used for other types in this file; renaming requires updating every use.
enum class mdnsQueryType {
  kInitial = 0,  // Initial mDNS query sent.
  kRefresh = 1,  // Refresh mDNS query sent.
  // Largest enumerator; presumably consumed by base::UmaHistogramEnumeration()
  // to size the histogram, so keep it aliased to the last real value.
  kMaxValue = kRefresh,
};
57
RecordQueryMetric(mdnsQueryType query_type,std::string_view host)58 void RecordQueryMetric(mdnsQueryType query_type, std::string_view host) {
59 constexpr auto kPrintScanServices = base::MakeFixedFlatSet<std::string_view>({
60 "_ipps._tcp.local",
61 "_ipp._tcp.local",
62 "_pdl-datastream._tcp.local",
63 "_printer._tcp.local",
64 "_print._sub._ipps._tcp.local",
65 "_print._sub._ipp._tcp.local",
66 "_scanner._tcp.local",
67 "_uscans._tcp.local",
68 "_uscan._tcp.local",
69 });
70
71 if (host.ends_with("_googlecast._tcp.local")) {
72 base::UmaHistogramEnumeration("Network.Mdns.Googlecast", query_type);
73 } else if (base::ranges::any_of(kPrintScanServices,
74 [&host](std::string_view service) {
75 return host.ends_with(service);
76 })) {
77 base::UmaHistogramEnumeration("Network.Mdns.PrintScan", query_type);
78 } else {
79 base::UmaHistogramEnumeration("Network.Mdns.Other", query_type);
80 }
81 }
82
83 } // namespace
84
CreateSockets(std::vector<std::unique_ptr<DatagramServerSocket>> * sockets)85 void MDnsSocketFactoryImpl::CreateSockets(
86 std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) {
87 InterfaceIndexFamilyList interfaces(GetMDnsInterfacesToBind());
88 for (const auto& interface : interfaces) {
89 DCHECK(interface.second == ADDRESS_FAMILY_IPV4 ||
90 interface.second == ADDRESS_FAMILY_IPV6);
91 std::unique_ptr<DatagramServerSocket> socket(
92 CreateAndBindMDnsSocket(interface.second, interface.first, net_log_));
93 if (socket)
94 sockets->push_back(std::move(socket));
95 }
96 }
97
SocketHandler(std::unique_ptr<DatagramServerSocket> socket,MDnsConnection * connection)98 MDnsConnection::SocketHandler::SocketHandler(
99 std::unique_ptr<DatagramServerSocket> socket,
100 MDnsConnection* connection)
101 : socket_(std::move(socket)),
102 connection_(connection),
103 response_(dns_protocol::kMaxMulticastSize) {}
104
105 MDnsConnection::SocketHandler::~SocketHandler() = default;
106
Start()107 int MDnsConnection::SocketHandler::Start() {
108 IPEndPoint end_point;
109 int rv = socket_->GetLocalAddress(&end_point);
110 if (rv != OK)
111 return rv;
112 DCHECK(end_point.GetFamily() == ADDRESS_FAMILY_IPV4 ||
113 end_point.GetFamily() == ADDRESS_FAMILY_IPV6);
114 multicast_addr_ = dns_util::GetMdnsGroupEndPoint(end_point.GetFamily());
115 return DoLoop(0);
116 }
117
DoLoop(int rv)118 int MDnsConnection::SocketHandler::DoLoop(int rv) {
119 do {
120 if (rv > 0)
121 connection_->OnDatagramReceived(&response_, recv_addr_, rv);
122
123 rv = socket_->RecvFrom(
124 response_.io_buffer(), response_.io_buffer_size(), &recv_addr_,
125 base::BindOnce(&MDnsConnection::SocketHandler::OnDatagramReceived,
126 base::Unretained(this)));
127 } while (rv > 0);
128
129 if (rv != ERR_IO_PENDING)
130 return rv;
131
132 return OK;
133 }
134
OnDatagramReceived(int rv)135 void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) {
136 if (rv >= OK)
137 rv = DoLoop(rv);
138
139 if (rv != OK)
140 connection_->PostOnError(this, rv);
141 }
142
Send(const scoped_refptr<IOBuffer> & buffer,unsigned size)143 void MDnsConnection::SocketHandler::Send(const scoped_refptr<IOBuffer>& buffer,
144 unsigned size) {
145 if (send_in_progress_) {
146 send_queue_.emplace(buffer, size);
147 return;
148 }
149 int rv =
150 socket_->SendTo(buffer.get(), size, multicast_addr_,
151 base::BindOnce(&MDnsConnection::SocketHandler::SendDone,
152 base::Unretained(this)));
153 if (rv == ERR_IO_PENDING) {
154 send_in_progress_ = true;
155 } else if (rv < OK) {
156 connection_->PostOnError(this, rv);
157 }
158 }
159
SendDone(int rv)160 void MDnsConnection::SocketHandler::SendDone(int rv) {
161 DCHECK(send_in_progress_);
162 send_in_progress_ = false;
163 if (rv != OK)
164 connection_->PostOnError(this, rv);
165 while (!send_in_progress_ && !send_queue_.empty()) {
166 std::pair<scoped_refptr<IOBuffer>, unsigned> buffer = send_queue_.front();
167 send_queue_.pop();
168 Send(buffer.first, buffer.second);
169 }
170 }
171
MDnsConnection(MDnsConnection::Delegate * delegate)172 MDnsConnection::MDnsConnection(MDnsConnection::Delegate* delegate)
173 : delegate_(delegate) {}
174
175 MDnsConnection::~MDnsConnection() = default;
176
Init(MDnsSocketFactory * socket_factory)177 int MDnsConnection::Init(MDnsSocketFactory* socket_factory) {
178 std::vector<std::unique_ptr<DatagramServerSocket>> sockets;
179 socket_factory->CreateSockets(&sockets);
180
181 for (std::unique_ptr<DatagramServerSocket>& socket : sockets) {
182 socket_handlers_.push_back(std::make_unique<MDnsConnection::SocketHandler>(
183 std::move(socket), this));
184 }
185
186 // All unbound sockets need to be bound before processing untrusted input.
187 // This is done for security reasons, so that an attacker can't get an unbound
188 // socket.
189 int last_failure = ERR_FAILED;
190 for (size_t i = 0; i < socket_handlers_.size();) {
191 int rv = socket_handlers_[i]->Start();
192 if (rv != OK) {
193 last_failure = rv;
194 socket_handlers_.erase(socket_handlers_.begin() + i);
195 VLOG(1) << "Start failed, socket=" << i << ", error=" << rv;
196 } else {
197 ++i;
198 }
199 }
200 VLOG(1) << "Sockets ready:" << socket_handlers_.size();
201 DCHECK_NE(ERR_IO_PENDING, last_failure);
202 return socket_handlers_.empty() ? last_failure : OK;
203 }
204
Send(const scoped_refptr<IOBuffer> & buffer,unsigned size)205 void MDnsConnection::Send(const scoped_refptr<IOBuffer>& buffer,
206 unsigned size) {
207 for (std::unique_ptr<SocketHandler>& handler : socket_handlers_)
208 handler->Send(buffer, size);
209 }
210
PostOnError(SocketHandler * loop,int rv)211 void MDnsConnection::PostOnError(SocketHandler* loop, int rv) {
212 int id = 0;
213 for (const auto& it : socket_handlers_) {
214 if (it.get() == loop)
215 break;
216 id++;
217 }
218 VLOG(1) << "Socket error. id=" << id << ", error=" << rv;
219 // Post to allow deletion of this object by delegate.
220 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
221 FROM_HERE, base::BindOnce(&MDnsConnection::OnError,
222 weak_ptr_factory_.GetWeakPtr(), rv));
223 }
224
OnError(int rv)225 void MDnsConnection::OnError(int rv) {
226 // TODO(noamsml): Specific handling of intermittent errors that can be handled
227 // in the connection.
228 delegate_->OnConnectionError(rv);
229 }
230
OnDatagramReceived(DnsResponse * response,const IPEndPoint & recv_addr,int bytes_read)231 void MDnsConnection::OnDatagramReceived(
232 DnsResponse* response,
233 const IPEndPoint& recv_addr,
234 int bytes_read) {
235 // TODO(noamsml): More sophisticated error handling.
236 DCHECK_GT(bytes_read, 0);
237 delegate_->HandlePacket(response, bytes_read);
238 }
239
Core(base::Clock * clock,base::OneShotTimer * timer)240 MDnsClientImpl::Core::Core(base::Clock* clock, base::OneShotTimer* timer)
241 : clock_(clock),
242 cleanup_timer_(timer),
243 connection_(
244 std::make_unique<MDnsConnection>((MDnsConnection::Delegate*)this)) {
245 DCHECK(cleanup_timer_);
246 DCHECK(!cleanup_timer_->IsRunning());
247 }
248
~Core()249 MDnsClientImpl::Core::~Core() {
250 cleanup_timer_->Stop();
251 }
252
Init(MDnsSocketFactory * socket_factory)253 int MDnsClientImpl::Core::Init(MDnsSocketFactory* socket_factory) {
254 CHECK(!cleanup_timer_->IsRunning());
255 return connection_->Init(socket_factory);
256 }
257
SendQuery(uint16_t rrtype,const std::string & name)258 bool MDnsClientImpl::Core::SendQuery(uint16_t rrtype, const std::string& name) {
259 std::optional<std::vector<uint8_t>> name_dns =
260 dns_names_util::DottedNameToNetwork(name);
261 if (!name_dns.has_value())
262 return false;
263
264 DnsQuery query(0, name_dns.value(), rrtype);
265 query.set_flags(0); // Remove the RD flag from the query. It is unneeded.
266
267 connection_->Send(query.io_buffer(), query.io_buffer()->size());
268 return true;
269 }
270
HandlePacket(DnsResponse * response,int bytes_read)271 void MDnsClientImpl::Core::HandlePacket(DnsResponse* response,
272 int bytes_read) {
273 unsigned offset;
274 // Note: We store cache keys rather than record pointers to avoid
275 // erroneous behavior in case a packet contains multiple exclusive
276 // records with the same type and name.
277 std::map<MDnsCache::Key, MDnsCache::UpdateType> update_keys;
278 DCHECK_GT(bytes_read, 0);
279 if (!response->InitParseWithoutQuery(bytes_read)) {
280 DVLOG(1) << "Could not understand an mDNS packet.";
281 return; // Message is unreadable.
282 }
283
284 // TODO(noamsml): duplicate query suppression.
285 if (!(response->flags() & dns_protocol::kFlagResponse))
286 return; // Message is a query. ignore it.
287
288 DnsRecordParser parser = response->Parser();
289 unsigned answer_count = response->answer_count() +
290 response->additional_answer_count();
291
292 for (unsigned i = 0; i < answer_count; i++) {
293 offset = parser.GetOffset();
294 std::unique_ptr<const RecordParsed> record =
295 RecordParsed::CreateFrom(&parser, clock_->Now());
296
297 if (!record) {
298 DVLOG(1) << "Could not understand an mDNS record.";
299
300 if (offset == parser.GetOffset()) {
301 DVLOG(1) << "Abandoned parsing the rest of the packet.";
302 return; // The parser did not advance, abort reading the packet.
303 } else {
304 continue; // We may be able to extract other records from the packet.
305 }
306 }
307
308 if ((record->klass() & dns_protocol::kMDnsClassMask) !=
309 dns_protocol::kClassIN) {
310 DVLOG(1) << "Received an mDNS record with non-IN class. Ignoring.";
311 continue; // Ignore all records not in the IN class.
312 }
313
314 MDnsCache::Key update_key = MDnsCache::Key::CreateFor(record.get());
315 MDnsCache::UpdateType update = cache_.UpdateDnsRecord(std::move(record));
316
317 // Cleanup time may have changed.
318 ScheduleCleanup(cache_.next_expiration());
319
320 update_keys.emplace(update_key, update);
321 }
322
323 for (const auto& update_key : update_keys) {
324 const RecordParsed* record = cache_.LookupKey(update_key.first);
325 if (!record)
326 continue;
327
328 if (record->type() == dns_protocol::kTypeNSEC) {
329 #if defined(ENABLE_NSEC)
330 NotifyNsecRecord(record);
331 #endif
332 } else {
333 AlertListeners(update_key.second,
334 ListenerKey(record->name(), record->type()), record);
335 }
336 }
337 }
338
NotifyNsecRecord(const RecordParsed * record)339 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) {
340 DCHECK_EQ(dns_protocol::kTypeNSEC, record->type());
341 const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>();
342 DCHECK(rdata);
343
344 // Remove all cached records matching the nonexistent RR types.
345 std::vector<const RecordParsed*> records_to_remove;
346
347 cache_.FindDnsRecords(0, record->name(), &records_to_remove, clock_->Now());
348
349 for (const auto* record_to_remove : records_to_remove) {
350 if (record_to_remove->type() == dns_protocol::kTypeNSEC)
351 continue;
352 if (!rdata->GetBit(record_to_remove->type())) {
353 std::unique_ptr<const RecordParsed> record_removed =
354 cache_.RemoveRecord(record_to_remove);
355 DCHECK(record_removed);
356 OnRecordRemoved(record_removed.get());
357 }
358 }
359
360 // Alert all listeners waiting for the nonexistent RR types.
361 ListenerKey key(record->name(), 0);
362 auto i = listeners_.upper_bound(key);
363 for (; i != listeners_.end() &&
364 i->first.name_lowercase() == key.name_lowercase();
365 i++) {
366 if (!rdata->GetBit(i->first.type())) {
367 for (auto& observer : *i->second)
368 observer.AlertNsecRecord();
369 }
370 }
371 }
372
OnConnectionError(int error)373 void MDnsClientImpl::Core::OnConnectionError(int error) {
374 // TODO(noamsml): On connection error, recreate connection and flush cache.
375 VLOG(1) << "MDNS OnConnectionError (code: " << error << ")";
376 }
377
ListenerKey(const std::string & name,uint16_t type)378 MDnsClientImpl::Core::ListenerKey::ListenerKey(const std::string& name,
379 uint16_t type)
380 : name_lowercase_(base::ToLowerASCII(name)), type_(type) {}
381
operator <(const MDnsClientImpl::Core::ListenerKey & key) const382 bool MDnsClientImpl::Core::ListenerKey::operator<(
383 const MDnsClientImpl::Core::ListenerKey& key) const {
384 if (name_lowercase_ == key.name_lowercase_)
385 return type_ < key.type_;
386 return name_lowercase_ < key.name_lowercase_;
387 }
388
AlertListeners(MDnsCache::UpdateType update_type,const ListenerKey & key,const RecordParsed * record)389 void MDnsClientImpl::Core::AlertListeners(
390 MDnsCache::UpdateType update_type,
391 const ListenerKey& key,
392 const RecordParsed* record) {
393 auto listener_map_iterator = listeners_.find(key);
394 if (listener_map_iterator == listeners_.end()) return;
395
396 for (auto& observer : *listener_map_iterator->second)
397 observer.HandleRecordUpdate(update_type, record);
398 }
399
AddListener(MDnsListenerImpl * listener)400 void MDnsClientImpl::Core::AddListener(
401 MDnsListenerImpl* listener) {
402 ListenerKey key(listener->GetName(), listener->GetType());
403
404 auto& observer_list = listeners_[key];
405 if (!observer_list)
406 observer_list = std::make_unique<ObserverListType>();
407
408 observer_list->AddObserver(listener);
409 }
410
RemoveListener(MDnsListenerImpl * listener)411 void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) {
412 ListenerKey key(listener->GetName(), listener->GetType());
413 auto observer_list_iterator = listeners_.find(key);
414
415 DCHECK(observer_list_iterator != listeners_.end());
416 DCHECK(observer_list_iterator->second->HasObserver(listener));
417
418 observer_list_iterator->second->RemoveObserver(listener);
419
420 // Remove the observer list from the map if it is empty
421 if (observer_list_iterator->second->empty()) {
422 // Schedule the actual removal for later in case the listener removal
423 // happens while iterating over the observer list.
424 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
425 FROM_HERE, base::BindOnce(&MDnsClientImpl::Core::CleanupObserverList,
426 AsWeakPtr(), key));
427 }
428 }
429
CleanupObserverList(const ListenerKey & key)430 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) {
431 auto found = listeners_.find(key);
432 if (found != listeners_.end() && found->second->empty()) {
433 listeners_.erase(found);
434 }
435 }
436
ScheduleCleanup(base::Time cleanup)437 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) {
438 // If cache is overfilled. Force an immediate cleanup.
439 if (cache_.IsCacheOverfilled())
440 cleanup = clock_->Now();
441
442 // Cleanup is already scheduled, no need to do anything.
443 if (cleanup == scheduled_cleanup_) {
444 return;
445 }
446 scheduled_cleanup_ = cleanup;
447
448 // This cancels the previously scheduled cleanup.
449 cleanup_timer_->Stop();
450
451 // If |cleanup| is empty, then no cleanup necessary.
452 if (cleanup != base::Time()) {
453 cleanup_timer_->Start(FROM_HERE,
454 std::max(base::TimeDelta(), cleanup - clock_->Now()),
455 base::BindOnce(&MDnsClientImpl::Core::DoCleanup,
456 base::Unretained(this)));
457 }
458 }
459
DoCleanup()460 void MDnsClientImpl::Core::DoCleanup() {
461 cache_.CleanupRecords(
462 clock_->Now(), base::BindRepeating(&MDnsClientImpl::Core::OnRecordRemoved,
463 base::Unretained(this)));
464
465 ScheduleCleanup(cache_.next_expiration());
466 }
467
OnRecordRemoved(const RecordParsed * record)468 void MDnsClientImpl::Core::OnRecordRemoved(
469 const RecordParsed* record) {
470 AlertListeners(MDnsCache::RecordRemoved,
471 ListenerKey(record->name(), record->type()), record);
472 }
473
QueryCache(uint16_t rrtype,const std::string & name,std::vector<const RecordParsed * > * records) const474 void MDnsClientImpl::Core::QueryCache(
475 uint16_t rrtype,
476 const std::string& name,
477 std::vector<const RecordParsed*>* records) const {
478 cache_.FindDnsRecords(rrtype, name, records, clock_->Now());
479 }
480
MDnsClientImpl()481 MDnsClientImpl::MDnsClientImpl()
482 : clock_(base::DefaultClock::GetInstance()),
483 cleanup_timer_(std::make_unique<base::OneShotTimer>()) {}
484
MDnsClientImpl(base::Clock * clock,std::unique_ptr<base::OneShotTimer> timer)485 MDnsClientImpl::MDnsClientImpl(base::Clock* clock,
486 std::unique_ptr<base::OneShotTimer> timer)
487 : clock_(clock), cleanup_timer_(std::move(timer)) {}
488
~MDnsClientImpl()489 MDnsClientImpl::~MDnsClientImpl() {
490 StopListening();
491 }
492
StartListening(MDnsSocketFactory * socket_factory)493 int MDnsClientImpl::StartListening(MDnsSocketFactory* socket_factory) {
494 DCHECK(!core_.get());
495 core_ = std::make_unique<Core>(clock_, cleanup_timer_.get());
496 int rv = core_->Init(socket_factory);
497 if (rv != OK) {
498 DCHECK_NE(ERR_IO_PENDING, rv);
499 core_.reset();
500 }
501 return rv;
502 }
503
StopListening()504 void MDnsClientImpl::StopListening() {
505 core_.reset();
506 }
507
IsListening() const508 bool MDnsClientImpl::IsListening() const {
509 return core_.get() != nullptr;
510 }
511
CreateListener(uint16_t rrtype,const std::string & name,MDnsListener::Delegate * delegate)512 std::unique_ptr<MDnsListener> MDnsClientImpl::CreateListener(
513 uint16_t rrtype,
514 const std::string& name,
515 MDnsListener::Delegate* delegate) {
516 return std::make_unique<MDnsListenerImpl>(rrtype, name, clock_, delegate,
517 this);
518 }
519
CreateTransaction(uint16_t rrtype,const std::string & name,int flags,const MDnsTransaction::ResultCallback & callback)520 std::unique_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction(
521 uint16_t rrtype,
522 const std::string& name,
523 int flags,
524 const MDnsTransaction::ResultCallback& callback) {
525 return std::make_unique<MDnsTransactionImpl>(rrtype, name, flags, callback,
526 this);
527 }
528
MDnsListenerImpl(uint16_t rrtype,const std::string & name,base::Clock * clock,MDnsListener::Delegate * delegate,MDnsClientImpl * client)529 MDnsListenerImpl::MDnsListenerImpl(uint16_t rrtype,
530 const std::string& name,
531 base::Clock* clock,
532 MDnsListener::Delegate* delegate,
533 MDnsClientImpl* client)
534 : rrtype_(rrtype),
535 name_(name),
536 clock_(clock),
537 client_(client),
538 delegate_(delegate) {}
539
~MDnsListenerImpl()540 MDnsListenerImpl::~MDnsListenerImpl() {
541 if (started_) {
542 DCHECK(client_->core());
543 client_->core()->RemoveListener(this);
544 }
545 }
546
Start()547 bool MDnsListenerImpl::Start() {
548 DCHECK(!started_);
549
550 started_ = true;
551
552 DCHECK(client_->core());
553 client_->core()->AddListener(this);
554
555 return true;
556 }
557
SetActiveRefresh(bool active_refresh)558 void MDnsListenerImpl::SetActiveRefresh(bool active_refresh) {
559 active_refresh_ = active_refresh;
560
561 if (started_) {
562 if (!active_refresh_) {
563 next_refresh_.Cancel();
564 } else if (last_update_ != base::Time()) {
565 ScheduleNextRefresh();
566 }
567 }
568 }
569
GetName() const570 const std::string& MDnsListenerImpl::GetName() const {
571 return name_;
572 }
573
GetType() const574 uint16_t MDnsListenerImpl::GetType() const {
575 return rrtype_;
576 }
577
HandleRecordUpdate(MDnsCache::UpdateType update_type,const RecordParsed * record)578 void MDnsListenerImpl::HandleRecordUpdate(MDnsCache::UpdateType update_type,
579 const RecordParsed* record) {
580 DCHECK(started_);
581
582 if (update_type != MDnsCache::RecordRemoved) {
583 ttl_ = record->ttl();
584 last_update_ = record->time_created();
585
586 ScheduleNextRefresh();
587 }
588
589 if (update_type != MDnsCache::NoChange) {
590 MDnsListener::UpdateType update_external;
591
592 switch (update_type) {
593 case MDnsCache::RecordAdded:
594 update_external = MDnsListener::RECORD_ADDED;
595 break;
596 case MDnsCache::RecordChanged:
597 update_external = MDnsListener::RECORD_CHANGED;
598 break;
599 case MDnsCache::RecordRemoved:
600 update_external = MDnsListener::RECORD_REMOVED;
601 break;
602 case MDnsCache::NoChange:
603 default:
604 NOTREACHED();
605 // Dummy assignment to suppress compiler warning.
606 update_external = MDnsListener::RECORD_CHANGED;
607 break;
608 }
609
610 delegate_->OnRecordUpdate(update_external, record);
611 }
612 }
613
AlertNsecRecord()614 void MDnsListenerImpl::AlertNsecRecord() {
615 DCHECK(started_);
616 delegate_->OnNsecRecord(name_, rrtype_);
617 }
618
ScheduleNextRefresh()619 void MDnsListenerImpl::ScheduleNextRefresh() {
620 DCHECK(last_update_ != base::Time());
621
622 if (!active_refresh_)
623 return;
624
625 // A zero TTL is a goodbye packet and should not be refreshed.
626 if (ttl_ == 0) {
627 next_refresh_.Cancel();
628 return;
629 }
630
631 next_refresh_.Reset(
632 base::BindRepeating(&MDnsListenerImpl::DoRefresh, AsWeakPtr()));
633
634 // Schedule refreshes at both 85% and 95% of the original TTL. These will both
635 // be canceled and rescheduled if the record's TTL is updated due to a
636 // response being received.
637 base::Time next_refresh1 =
638 last_update_ +
639 base::Milliseconds(static_cast<int>(base::Time::kMillisecondsPerSecond *
640 kListenerRefreshRatio1 * ttl_));
641
642 base::Time next_refresh2 =
643 last_update_ +
644 base::Milliseconds(static_cast<int>(base::Time::kMillisecondsPerSecond *
645 kListenerRefreshRatio2 * ttl_));
646
647 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
648 FROM_HERE, next_refresh_.callback(), next_refresh1 - clock_->Now());
649
650 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
651 FROM_HERE, next_refresh_.callback(), next_refresh2 - clock_->Now());
652 }
653
DoRefresh()654 void MDnsListenerImpl::DoRefresh() {
655 RecordQueryMetric(mdnsQueryType::kRefresh, name_);
656 client_->core()->SendQuery(rrtype_, name_);
657 }
658
MDnsTransactionImpl(uint16_t rrtype,const std::string & name,int flags,const MDnsTransaction::ResultCallback & callback,MDnsClientImpl * client)659 MDnsTransactionImpl::MDnsTransactionImpl(
660 uint16_t rrtype,
661 const std::string& name,
662 int flags,
663 const MDnsTransaction::ResultCallback& callback,
664 MDnsClientImpl* client)
665 : rrtype_(rrtype),
666 name_(name),
667 callback_(callback),
668 client_(client),
669 flags_(flags) {
670 DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_);
671 DCHECK(flags_ & MDnsTransaction::QUERY_CACHE ||
672 flags_ & MDnsTransaction::QUERY_NETWORK);
673 }
674
~MDnsTransactionImpl()675 MDnsTransactionImpl::~MDnsTransactionImpl() {
676 timeout_.Cancel();
677 }
678
Start()679 bool MDnsTransactionImpl::Start() {
680 DCHECK(!started_);
681 started_ = true;
682
683 base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr();
684 if (flags_ & MDnsTransaction::QUERY_CACHE) {
685 ServeRecordsFromCache();
686
687 if (!weak_this || !is_active()) return true;
688 }
689
690 if (flags_ & MDnsTransaction::QUERY_NETWORK) {
691 return QueryAndListen();
692 }
693
694 // If this is a cache only query, signal that the transaction is over
695 // immediately.
696 SignalTransactionOver();
697 return true;
698 }
699
GetName() const700 const std::string& MDnsTransactionImpl::GetName() const {
701 return name_;
702 }
703
GetType() const704 uint16_t MDnsTransactionImpl::GetType() const {
705 return rrtype_;
706 }
707
CacheRecordFound(const RecordParsed * record)708 void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) {
709 DCHECK(started_);
710 OnRecordUpdate(MDnsListener::RECORD_ADDED, record);
711 }
712
TriggerCallback(MDnsTransaction::Result result,const RecordParsed * record)713 void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result,
714 const RecordParsed* record) {
715 DCHECK(started_);
716 if (!is_active()) return;
717
718 // Ensure callback is run after touching all class state, so that
719 // the callback can delete the transaction.
720 MDnsTransaction::ResultCallback callback = callback_;
721
722 // Reset the transaction if it expects a single result, or if the result
723 // is a final one (everything except for a record).
724 if (flags_ & MDnsTransaction::SINGLE_RESULT ||
725 result != MDnsTransaction::RESULT_RECORD) {
726 Reset();
727 }
728
729 callback.Run(result, record);
730 }
731
Reset()732 void MDnsTransactionImpl::Reset() {
733 callback_.Reset();
734 listener_.reset();
735 timeout_.Cancel();
736 }
737
OnRecordUpdate(MDnsListener::UpdateType update,const RecordParsed * record)738 void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update,
739 const RecordParsed* record) {
740 DCHECK(started_);
741 if (update == MDnsListener::RECORD_ADDED ||
742 update == MDnsListener::RECORD_CHANGED)
743 TriggerCallback(MDnsTransaction::RESULT_RECORD, record);
744 }
745
SignalTransactionOver()746 void MDnsTransactionImpl::SignalTransactionOver() {
747 DCHECK(started_);
748 if (flags_ & MDnsTransaction::SINGLE_RESULT) {
749 TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, nullptr);
750 } else {
751 TriggerCallback(MDnsTransaction::RESULT_DONE, nullptr);
752 }
753 }
754
ServeRecordsFromCache()755 void MDnsTransactionImpl::ServeRecordsFromCache() {
756 std::vector<const RecordParsed*> records;
757 base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr();
758
759 if (client_->core()) {
760 client_->core()->QueryCache(rrtype_, name_, &records);
761 for (auto i = records.begin(); i != records.end() && weak_this; ++i) {
762 weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, *i);
763 }
764
765 #if defined(ENABLE_NSEC)
766 if (records.empty()) {
767 DCHECK(weak_this);
768 client_->core()->QueryCache(dns_protocol::kTypeNSEC, name_, &records);
769 if (!records.empty()) {
770 const NsecRecordRdata* rdata =
771 records.front()->rdata<NsecRecordRdata>();
772 DCHECK(rdata);
773 if (!rdata->GetBit(rrtype_))
774 weak_this->TriggerCallback(MDnsTransaction::RESULT_NSEC, nullptr);
775 }
776 }
777 #endif
778 }
779 }
780
QueryAndListen()781 bool MDnsTransactionImpl::QueryAndListen() {
782 listener_ = client_->CreateListener(rrtype_, name_, this);
783 if (!listener_->Start())
784 return false;
785
786 DCHECK(client_->core());
787 RecordQueryMetric(mdnsQueryType::kInitial, name_);
788 if (!client_->core()->SendQuery(rrtype_, name_))
789 return false;
790
791 timeout_.Reset(
792 base::BindOnce(&MDnsTransactionImpl::SignalTransactionOver, AsWeakPtr()));
793 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
794 FROM_HERE, timeout_.callback(), kTransactionTimeout);
795
796 return true;
797 }
798
OnNsecRecord(const std::string & name,unsigned type)799 void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) {
800 TriggerCallback(RESULT_NSEC, nullptr);
801 }
802
OnCachePurged()803 void MDnsTransactionImpl::OnCachePurged() {
804 // TODO(noamsml): Cache purge situations not yet implemented
805 }
806
807 } // namespace net
808