1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/dnssd/impl/dns_data_graph.h"
6
7 #include <utility>
8
9 #include "discovery/dnssd/impl/conversion_layer.h"
10 #include "discovery/dnssd/impl/instance_key.h"
11
12 namespace openscreen {
13 namespace discovery {
14 namespace {
15
CreateEndpoint(const DomainName & domain,const absl::optional<ARecordRdata> & a,const absl::optional<AAAARecordRdata> & aaaa,const SrvRecordRdata & srv,const TxtRecordRdata & txt,NetworkInterfaceIndex network_interface)16 ErrorOr<DnsSdInstanceEndpoint> CreateEndpoint(
17 const DomainName& domain,
18 const absl::optional<ARecordRdata>& a,
19 const absl::optional<AAAARecordRdata>& aaaa,
20 const SrvRecordRdata& srv,
21 const TxtRecordRdata& txt,
22 NetworkInterfaceIndex network_interface) {
23 // Create the user-visible TXT record representation.
24 ErrorOr<DnsSdTxtRecord> txt_or_error = CreateFromDnsTxt(txt);
25 if (txt_or_error.is_error()) {
26 return txt_or_error.error();
27 }
28
29 InstanceKey instance_id(domain);
30 std::vector<IPEndpoint> endpoints;
31 if (a.has_value()) {
32 endpoints.push_back({a.value().ipv4_address(), srv.port()});
33 }
34 if (aaaa.has_value()) {
35 endpoints.push_back({aaaa.value().ipv6_address(), srv.port()});
36 }
37
38 return DnsSdInstanceEndpoint(
39 instance_id.instance_id(), instance_id.service_id(),
40 instance_id.domain_id(), std::move(txt_or_error.value()),
41 network_interface, std::move(endpoints));
42 }
43
44 class DnsDataGraphImpl : public DnsDataGraph {
45 public:
46 using DnsDataGraph::DomainChangeCallback;
47
DnsDataGraphImpl(NetworkInterfaceIndex network_interface)48 explicit DnsDataGraphImpl(NetworkInterfaceIndex network_interface)
49 : network_interface_(network_interface) {}
50 DnsDataGraphImpl(const DnsDataGraphImpl& other) = delete;
51 DnsDataGraphImpl(DnsDataGraphImpl&& other) = delete;
52
~DnsDataGraphImpl()53 ~DnsDataGraphImpl() override { is_dtor_running_ = true; }
54
55 DnsDataGraphImpl& operator=(const DnsDataGraphImpl& rhs) = delete;
56 DnsDataGraphImpl& operator=(DnsDataGraphImpl&& rhs) = delete;
57
58 // DnsDataGraph overrides.
59 void StartTracking(const DomainName& domain,
60 DomainChangeCallback on_start_tracking) override;
61
62 void StopTracking(const DomainName& domain,
63 DomainChangeCallback on_stop_tracking) override;
64
65 std::vector<ErrorOr<DnsSdInstanceEndpoint>> CreateEndpoints(
66 DomainGroup domain_group,
67 const DomainName& name) const override;
68
69 Error ApplyDataRecordChange(MdnsRecord record,
70 RecordChangedEvent event,
71 DomainChangeCallback on_start_tracking,
72 DomainChangeCallback on_stop_tracking) override;
73
GetTrackedDomainCount() const74 size_t GetTrackedDomainCount() const override { return nodes_.size(); }
75
IsTracked(const DomainName & name) const76 bool IsTracked(const DomainName& name) const override {
77 return nodes_.find(name) != nodes_.end();
78 }
79
80 private:
81 class NodeLifetimeHandler;
82
83 using ScopedCallbackHandler = std::unique_ptr<NodeLifetimeHandler>;
84
85 // A single node of the graph represented by this type.
86 class Node {
87 public:
88 // NOE: This class is non-copyable, non-movable because either operation
89 // would invalidate the pointer references or bidirectional edge states
90 // maintained by instances of this class.
91 Node(DomainName name, DnsDataGraphImpl* graph);
92 Node(const Node& other) = delete;
93 Node(Node&& other) = delete;
94
95 ~Node();
96
97 Node& operator=(const Node& rhs) = delete;
98 Node& operator=(Node&& rhs) = delete;
99
100 // Applies a record change for this node.
101 Error ApplyDataRecordChange(MdnsRecord record, RecordChangedEvent event);
102
103 // Returns the first rdata of a record with type matching |type| in this
104 // node's |records_|, or absl::nullopt if no such record exists.
105 template <typename T>
GetRdata(DnsType type)106 absl::optional<T> GetRdata(DnsType type) {
107 auto it = FindRecord(type);
108 if (it == records_.end()) {
109 return absl::nullopt;
110 } else {
111 return std::cref(absl::get<T>(it->rdata()));
112 }
113 }
114
name() const115 const DomainName& name() const { return name_; }
parents() const116 const std::vector<Node*>& parents() const { return parents_; }
children() const117 const std::vector<Node*>& children() const { return children_; }
records() const118 const std::vector<MdnsRecord>& records() const { return records_; }
119
120 private:
121 // Adds or removes an edge in |graph_|.
122 // NOTE: The same edge may be added multiple times, and one call to remove
123 // is needed for every such call.
124 void AddChild(Node* child);
125 void RemoveChild(Node* child);
126
127 // Applies the specified change to domain |child| for this node.
128 void ApplyChildChange(DomainName child_name, RecordChangedEvent event);
129
130 // Finds an iterator to the record of the provided type, or to
131 // records_.end() if no such record exists.
132 std::vector<MdnsRecord>::iterator FindRecord(DnsType type);
133
134 // The domain with which the data records stored in this node are
135 // associated.
136 const DomainName name_;
137
138 // Currently extant mDNS Records at |name_|.
139 std::vector<MdnsRecord> records_;
140
141 // Nodes which contain records pointing to this node's |name|.
142 std::vector<Node*> parents_;
143
144 // Nodes containing records pointed to by the records in this node.
145 std::vector<Node*> children_;
146
147 // Graph containing this node.
148 DnsDataGraphImpl* graph_;
149 };
150
151 // Wrapper to handle the creation and deletion callbacks. When the object is
152 // created, it sets the callback to use, and erases the callback when it goes
153 // out of scope. This class allows all node creations to complete before
154 // calling the user-provided callback to ensure there are no race-conditions.
155 class NodeLifetimeHandler {
156 public:
157 NodeLifetimeHandler(DomainChangeCallback* callback_ptr,
158 DomainChangeCallback callback);
159
160 // NOTE: The copy and delete ctors and operators must be deleted because
161 // they would invalidate the pointer logic used here.
162 NodeLifetimeHandler(const NodeLifetimeHandler& other) = delete;
163 NodeLifetimeHandler(NodeLifetimeHandler&& other) = delete;
164
165 ~NodeLifetimeHandler();
166
167 NodeLifetimeHandler operator=(const NodeLifetimeHandler& other) = delete;
168 NodeLifetimeHandler operator=(NodeLifetimeHandler&& other) = delete;
169
170 private:
171 std::vector<DomainName> domains_changed;
172
173 DomainChangeCallback* callback_ptr_;
174 DomainChangeCallback callback_;
175 };
176
177 // Helpers to create the ScopedCallbackHandlers for creation and deletion
178 // callbacks.
179 ScopedCallbackHandler GetScopedCreationHandler(
180 DomainChangeCallback creation_callback);
181 ScopedCallbackHandler GetScopedDeletionHandler(
182 DomainChangeCallback deletion_callback);
183
184 // Determines whether the provided node has the necessary records to be a
185 // valid node at the specified domain level.
186 static bool IsValidAddressNode(Node* node);
187 static bool IsValidSrvAndTxtNode(Node* node);
188
189 // Calculates the set of DnsSdInstanceEndpoints associated with the PTR
190 // records present at the given |node|.
191 std::vector<ErrorOr<DnsSdInstanceEndpoint>> CalculatePtrRecordEndpoints(
192 Node* node) const;
193
194 // Denotes whether the dtor for this instance has been called. This is
195 // required for validation of Node instance functionality. See the
196 // implementation of DnsDataGraph::Node::~Node() for more details.
197 bool is_dtor_running_ = false;
198
199 // Map from domain name to the node containing all records associated with the
200 // name.
201 std::map<DomainName, std::unique_ptr<Node>> nodes_;
202
203 const NetworkInterfaceIndex network_interface_;
204
205 // The methods to be called when a domain name either starts or stops being
206 // referenced. These will only be set when a record change is ongoing, and act
207 // as a single source of truth for the creation and deletion callbacks that
208 // should be used during that operation.
209 DomainChangeCallback on_node_creation_;
210 DomainChangeCallback on_node_deletion_;
211 };
212
Node(DomainName name,DnsDataGraphImpl * graph)213 DnsDataGraphImpl::Node::Node(DomainName name, DnsDataGraphImpl* graph)
214 : name_(std::move(name)), graph_(graph) {
215 OSP_DCHECK(graph_);
216
217 graph_->on_node_creation_(name_);
218 }
219
~Node()220 DnsDataGraphImpl::Node::~Node() {
221 // A node should only be deleted when it has no parents. The only case where
222 // a deletion can occur when parents are still extant is during destruction of
223 // the holding graph. In that case, the state of the graph no longer matters
224 // and all nodes will be deleted, so no need to consider the child pointers.
225 if (!graph_->is_dtor_running_) {
226 auto it = std::find_if(parents_.begin(), parents_.end(),
227 [this](Node* parent) { return parent != this; });
228 OSP_DCHECK(it == parents_.end());
229
230 // Erase all childrens' parent pointers to this node.
231 for (Node* child : children_) {
232 RemoveChild(child);
233 }
234
235 OSP_DCHECK(graph_->on_node_deletion_);
236 graph_->on_node_deletion_(name_);
237 }
238 }
239
ApplyDataRecordChange(MdnsRecord record,RecordChangedEvent event)240 Error DnsDataGraphImpl::Node::ApplyDataRecordChange(MdnsRecord record,
241 RecordChangedEvent event) {
242 OSP_DCHECK(record.name() == name_);
243
244 // The child domain to which the changed record points, or none. This is only
245 // applicable for PTR and SRV records, and is empty in all other cases.
246 DomainName child_name;
247
248 // The location of the current record. In the case of PTR records, multiple
249 // records are allowed for the same domain. In all other cases, this is not
250 // valid.
251 std::vector<MdnsRecord>::iterator it;
252
253 if (record.dns_type() == DnsType::kPTR) {
254 child_name = absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
255 it = std::find_if(records_.begin(), records_.end(),
256 [record](const MdnsRecord& rhs) {
257 return record.IsReannouncementOf(rhs);
258 });
259 } else {
260 if (record.dns_type() == DnsType::kSRV) {
261 child_name = absl::get<SrvRecordRdata>(record.rdata()).target();
262 }
263 it = FindRecord(record.dns_type());
264 }
265
266 // Validate that the requested change is allowed and apply it.
267 switch (event) {
268 case RecordChangedEvent::kCreated:
269 if (it != records_.end()) {
270 return Error::Code::kItemAlreadyExists;
271 }
272 records_.push_back(std::move(record));
273 break;
274
275 case RecordChangedEvent::kUpdated:
276 if (it == records_.end()) {
277 return Error::Code::kItemNotFound;
278 }
279 *it = std::move(record);
280 break;
281
282 case RecordChangedEvent::kExpired:
283 if (it == records_.end()) {
284 return Error::Code::kItemNotFound;
285 }
286 records_.erase(it);
287 break;
288 }
289
290 // Apply any required edge changes to the graph. This is only applicable if
291 // a |child| was found earlier. Note that the same child can be added multiple
292 // times to the |children_| vector, which simplifies the code dramatically.
293 if (!child_name.empty()) {
294 ApplyChildChange(std::move(child_name), event);
295 }
296
297 return Error::None();
298 }
299
ApplyChildChange(DomainName child_name,RecordChangedEvent event)300 void DnsDataGraphImpl::Node::ApplyChildChange(DomainName child_name,
301 RecordChangedEvent event) {
302 if (event == RecordChangedEvent::kCreated) {
303 const auto pair =
304 graph_->nodes_.emplace(child_name, std::unique_ptr<Node>());
305 if (pair.second) {
306 auto new_node = std::make_unique<Node>(std::move(child_name), graph_);
307 pair.first->second.swap(new_node);
308 }
309
310 AddChild(pair.first->second.get());
311 } else if (event == RecordChangedEvent::kExpired) {
312 const auto it = graph_->nodes_.find(child_name);
313 if (it == graph_->nodes_.end()) {
314 OSP_LOG_WARN << "Unable to find child_name=" << child_name.ToString();
315 } else {
316 RemoveChild(it->second.get());
317 }
318 }
319 }
320
AddChild(Node * child)321 void DnsDataGraphImpl::Node::AddChild(Node* child) {
322 OSP_DCHECK(child);
323 children_.push_back(child);
324 child->parents_.push_back(this);
325 }
326
RemoveChild(Node * child)327 void DnsDataGraphImpl::Node::RemoveChild(Node* child) {
328 OSP_DCHECK(child);
329
330 auto it = std::find(children_.begin(), children_.end(), child);
331 OSP_DCHECK(it != children_.end());
332 children_.erase(it);
333
334 it = std::find(child->parents_.begin(), child->parents_.end(), this);
335 OSP_DCHECK(it != child->parents_.end());
336 child->parents_.erase(it);
337
338 // If the node has been orphaned, remove it.
339 it = std::find_if(child->parents_.begin(), child->parents_.end(),
340 [child](Node* parent) { return parent != child; });
341 if (it == child->parents_.end()) {
342 DomainName child_name = child->name();
343 const size_t count = graph_->nodes_.erase(child_name);
344 OSP_DCHECK(child == this || count);
345 }
346 }
347
FindRecord(DnsType type)348 std::vector<MdnsRecord>::iterator DnsDataGraphImpl::Node::FindRecord(
349 DnsType type) {
350 return std::find_if(
351 records_.begin(), records_.end(),
352 [type](const MdnsRecord& record) { return record.dns_type() == type; });
353 }
354
NodeLifetimeHandler(DomainChangeCallback * callback_ptr,DomainChangeCallback callback)355 DnsDataGraphImpl::NodeLifetimeHandler::NodeLifetimeHandler(
356 DomainChangeCallback* callback_ptr,
357 DomainChangeCallback callback)
358 : callback_ptr_(callback_ptr), callback_(callback) {
359 OSP_DCHECK(callback_ptr_);
360 OSP_DCHECK(callback);
361 OSP_DCHECK(*callback_ptr_ == nullptr);
362 *callback_ptr = [this](DomainName domain) {
363 domains_changed.push_back(std::move(domain));
364 };
365 }
366
~NodeLifetimeHandler()367 DnsDataGraphImpl::NodeLifetimeHandler::~NodeLifetimeHandler() {
368 *callback_ptr_ = nullptr;
369 for (DomainName& domain : domains_changed) {
370 callback_(domain);
371 }
372 }
373
374 DnsDataGraphImpl::ScopedCallbackHandler
GetScopedCreationHandler(DomainChangeCallback creation_callback)375 DnsDataGraphImpl::GetScopedCreationHandler(
376 DomainChangeCallback creation_callback) {
377 return std::make_unique<NodeLifetimeHandler>(&on_node_creation_,
378 std::move(creation_callback));
379 }
380
381 DnsDataGraphImpl::ScopedCallbackHandler
GetScopedDeletionHandler(DomainChangeCallback deletion_callback)382 DnsDataGraphImpl::GetScopedDeletionHandler(
383 DomainChangeCallback deletion_callback) {
384 return std::make_unique<NodeLifetimeHandler>(&on_node_deletion_,
385 std::move(deletion_callback));
386 }
387
StartTracking(const DomainName & domain,DomainChangeCallback on_start_tracking)388 void DnsDataGraphImpl::StartTracking(const DomainName& domain,
389 DomainChangeCallback on_start_tracking) {
390 ScopedCallbackHandler creation_handler =
391 GetScopedCreationHandler(std::move(on_start_tracking));
392
393 auto pair = nodes_.emplace(domain, std::make_unique<Node>(domain, this));
394
395 OSP_DCHECK(pair.second);
396 OSP_DCHECK(nodes_.find(domain) != nodes_.end());
397 }
398
StopTracking(const DomainName & domain,DomainChangeCallback on_stop_tracking)399 void DnsDataGraphImpl::StopTracking(const DomainName& domain,
400 DomainChangeCallback on_stop_tracking) {
401 ScopedCallbackHandler deletion_handler =
402 GetScopedDeletionHandler(std::move(on_stop_tracking));
403
404 auto it = nodes_.find(domain);
405 OSP_CHECK(it != nodes_.end());
406 OSP_DCHECK(it->second->parents().empty());
407 it->second.reset();
408 const size_t erased_count = nodes_.erase(domain);
409 OSP_DCHECK(erased_count);
410 }
411
ApplyDataRecordChange(MdnsRecord record,RecordChangedEvent event,DomainChangeCallback on_start_tracking,DomainChangeCallback on_stop_tracking)412 Error DnsDataGraphImpl::ApplyDataRecordChange(
413 MdnsRecord record,
414 RecordChangedEvent event,
415 DomainChangeCallback on_start_tracking,
416 DomainChangeCallback on_stop_tracking) {
417 ScopedCallbackHandler creation_handler =
418 GetScopedCreationHandler(std::move(on_start_tracking));
419 ScopedCallbackHandler deletion_handler =
420 GetScopedDeletionHandler(std::move(on_stop_tracking));
421
422 auto it = nodes_.find(record.name());
423 if (it == nodes_.end()) {
424 return Error::Code::kOperationCancelled;
425 }
426
427 const auto result =
428 it->second->ApplyDataRecordChange(std::move(record), event);
429
430 return result;
431 }
432
CreateEndpoints(DomainGroup domain_group,const DomainName & name) const433 std::vector<ErrorOr<DnsSdInstanceEndpoint>> DnsDataGraphImpl::CreateEndpoints(
434 DomainGroup domain_group,
435 const DomainName& name) const {
436 const auto it = nodes_.find(name);
437 if (it == nodes_.end()) {
438 return {};
439 }
440 Node* target_node = it->second.get();
441
442 // NOTE: One of these will contain no more than one element, so iterating over
443 // them both will be fast.
444 std::vector<Node*> srv_and_txt_record_nodes;
445 std::vector<Node*> address_record_nodes;
446
447 switch (domain_group) {
448 case DomainGroup::kAddress:
449 if (!IsValidAddressNode(target_node)) {
450 return {};
451 }
452
453 address_record_nodes.push_back(target_node);
454 srv_and_txt_record_nodes = target_node->parents();
455 break;
456
457 case DomainGroup::kSrvAndTxt:
458 if (!IsValidSrvAndTxtNode(target_node)) {
459 return {};
460 }
461
462 srv_and_txt_record_nodes.push_back(target_node);
463 address_record_nodes = target_node->children();
464 break;
465
466 case DomainGroup::kPtr:
467 return CalculatePtrRecordEndpoints(target_node);
468
469 default:
470 return {};
471 }
472
473 // Iterate across all node pairs and create all possible DnsSdInstanceEndpoint
474 // objects.
475 std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints;
476 for (Node* srv_and_txt : srv_and_txt_record_nodes) {
477 for (Node* address : address_record_nodes) {
478 // First, there has to be a SRV record present (to provide the port
479 // number), and the target of that SRV record has to be the node where the
480 // address records are sourced from.
481 const absl::optional<SrvRecordRdata> srv =
482 srv_and_txt->GetRdata<SrvRecordRdata>(DnsType::kSRV);
483 if (!srv.has_value() || srv.value().target() != address->name()) {
484 continue;
485 }
486
487 // Next, a TXT record must be present to provide additional connection
488 // information about the service per RFC 6763.
489 const absl::optional<TxtRecordRdata> txt =
490 srv_and_txt->GetRdata<TxtRecordRdata>(DnsType::kTXT);
491 if (!txt.has_value()) {
492 continue;
493 }
494
495 // Last, at least one address record must be present to provide an
496 // endpoint for this instance.
497 const absl::optional<ARecordRdata> a =
498 address->GetRdata<ARecordRdata>(DnsType::kA);
499 const absl::optional<AAAARecordRdata> aaaa =
500 address->GetRdata<AAAARecordRdata>(DnsType::kAAAA);
501 if (!a.has_value() && !aaaa.has_value()) {
502 continue;
503 }
504
505 // Then use the above info to create an endpoint object. If an error
506 // occurs, this is only related to the one endpoint and its possible that
507 // other endpoints may still be valid, so only the one endpoint is treated
508 // as failing. For instance, a bad TXT record for service A will not
509 // affect the endpoints for service B.
510 ErrorOr<DnsSdInstanceEndpoint> endpoint =
511 CreateEndpoint(srv_and_txt->name(), a, aaaa, srv.value(), txt.value(),
512 network_interface_);
513 endpoints.push_back(std::move(endpoint));
514 }
515 }
516
517 return endpoints;
518 }
519
520 // static
IsValidAddressNode(Node * node)521 bool DnsDataGraphImpl::IsValidAddressNode(Node* node) {
522 const absl::optional<ARecordRdata> a =
523 node->GetRdata<ARecordRdata>(DnsType::kA);
524 const absl::optional<AAAARecordRdata> aaaa =
525 node->GetRdata<AAAARecordRdata>(DnsType::kAAAA);
526 return a.has_value() || aaaa.has_value();
527 }
528
529 // static
IsValidSrvAndTxtNode(Node * node)530 bool DnsDataGraphImpl::IsValidSrvAndTxtNode(Node* node) {
531 const absl::optional<SrvRecordRdata> srv =
532 node->GetRdata<SrvRecordRdata>(DnsType::kSRV);
533 const absl::optional<TxtRecordRdata> txt =
534 node->GetRdata<TxtRecordRdata>(DnsType::kTXT);
535
536 return srv.has_value() && txt.has_value();
537 }
538
539 std::vector<ErrorOr<DnsSdInstanceEndpoint>>
CalculatePtrRecordEndpoints(Node * node) const540 DnsDataGraphImpl::CalculatePtrRecordEndpoints(Node* node) const {
541 // PTR records aren't actually part of the generated endpoint objects, so
542 // call this method recursively on all children and
543 std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints;
544 for (const MdnsRecord& record : node->records()) {
545 if (record.dns_type() != DnsType::kPTR) {
546 continue;
547 }
548
549 const DomainName domain =
550 absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
551 const Node* child = nodes_.find(domain)->second.get();
552 std::vector<ErrorOr<DnsSdInstanceEndpoint>> child_endpoints =
553 CreateEndpoints(DomainGroup::kSrvAndTxt, child->name());
554 for (auto& endpoint_or_error : child_endpoints) {
555 endpoints.push_back(std::move(endpoint_or_error));
556 }
557 }
558 return endpoints;
559 }
560
561 } // namespace
562
563 DnsDataGraph::~DnsDataGraph() = default;
564
565 // static
Create(NetworkInterfaceIndex network_interface)566 std::unique_ptr<DnsDataGraph> DnsDataGraph::Create(
567 NetworkInterfaceIndex network_interface) {
568 return std::make_unique<DnsDataGraphImpl>(network_interface);
569 }
570
571 // static
GetDomainGroup(DnsType type)572 DnsDataGraphImpl::DomainGroup DnsDataGraph::GetDomainGroup(DnsType type) {
573 switch (type) {
574 case DnsType::kA:
575 case DnsType::kAAAA:
576 return DnsDataGraphImpl::DomainGroup::kAddress;
577 case DnsType::kSRV:
578 case DnsType::kTXT:
579 return DnsDataGraphImpl::DomainGroup::kSrvAndTxt;
580 case DnsType::kPTR:
581 return DnsDataGraphImpl::DomainGroup::kPtr;
582 default:
583 OSP_NOTREACHED();
584 }
585 }
586
587 // static
GetDomainGroup(const MdnsRecord record)588 DnsDataGraphImpl::DomainGroup DnsDataGraph::GetDomainGroup(
589 const MdnsRecord record) {
590 return GetDomainGroup(record.dns_type());
591 }
592
593 } // namespace discovery
594 } // namespace openscreen
595