xref: /aosp_15_r20/external/cronet/net/dns/host_resolver_mdns_task.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/host_resolver_mdns_task.h"
6 
7 #include <utility>
8 
9 #include "base/check_op.h"
10 #include "base/functional/bind.h"
11 #include "base/location.h"
12 #include "base/memory/raw_ptr.h"
13 #include "base/notreached.h"
14 #include "base/ranges/algorithm.h"
15 #include "base/strings/string_util.h"
16 #include "base/task/sequenced_task_runner.h"
17 #include "net/base/ip_endpoint.h"
18 #include "net/base/net_errors.h"
19 #include "net/dns/dns_util.h"
20 #include "net/dns/public/dns_protocol.h"
21 #include "net/dns/public/dns_query_type.h"
22 #include "net/dns/record_parsed.h"
23 #include "net/dns/record_rdata.h"
24 
25 namespace net {
26 
27 namespace {
ParseHostnameResult(const std::string & host,uint16_t port)28 HostCache::Entry ParseHostnameResult(const std::string& host, uint16_t port) {
29   // Filter out root domain. Depending on the type, it either means no-result
30   // or is simply not a result important to any expected Chrome usecases.
31   if (host.empty()) {
32     return HostCache::Entry(ERR_NAME_NOT_RESOLVED,
33                             HostCache::Entry::SOURCE_UNKNOWN);
34   }
35   return HostCache::Entry(OK,
36                           std::vector<HostPortPair>({HostPortPair(host, port)}),
37                           HostCache::Entry::SOURCE_UNKNOWN);
38 }
39 }  // namespace
40 
41 class HostResolverMdnsTask::Transaction {
42  public:
Transaction(DnsQueryType query_type,HostResolverMdnsTask * task)43   Transaction(DnsQueryType query_type, HostResolverMdnsTask* task)
44       : query_type_(query_type),
45         results_(ERR_IO_PENDING, HostCache::Entry::SOURCE_UNKNOWN),
46         task_(task) {}
47 
Start()48   void Start() {
49     DCHECK_CALLED_ON_VALID_SEQUENCE(task_->sequence_checker_);
50 
51     // Should not be completed or running yet.
52     DCHECK_EQ(ERR_IO_PENDING, results_.error());
53     DCHECK(!async_transaction_);
54 
55     // TODO(crbug.com/926300): Use |allow_cached_response| to set the
56     // QUERY_CACHE flag or not.
57     int flags = MDnsTransaction::SINGLE_RESULT | MDnsTransaction::QUERY_CACHE |
58                 MDnsTransaction::QUERY_NETWORK;
59     // If |this| is destroyed, destruction of |internal_transaction_| should
60     // cancel and prevent invocation of OnComplete.
61     std::unique_ptr<MDnsTransaction> inner_transaction =
62         task_->mdns_client_->CreateTransaction(
63             DnsQueryTypeToQtype(query_type_), task_->hostname_, flags,
64             base::BindRepeating(&HostResolverMdnsTask::Transaction::OnComplete,
65                                 base::Unretained(this)));
66 
67     // Side effect warning: Start() may finish and invoke callbacks inline.
68     bool start_result = inner_transaction->Start();
69 
70     if (!start_result)
71       task_->Complete(true /* post_needed */);
72     else if (results_.error() == ERR_IO_PENDING)
73       async_transaction_ = std::move(inner_transaction);
74   }
75 
IsDone() const76   bool IsDone() const { return results_.error() != ERR_IO_PENDING; }
IsError() const77   bool IsError() const {
78     return IsDone() && results_.error() != OK &&
79            results_.error() != ERR_NAME_NOT_RESOLVED;
80   }
results() const81   const HostCache::Entry& results() const { return results_; }
82 
Cancel()83   void Cancel() {
84     DCHECK_CALLED_ON_VALID_SEQUENCE(task_->sequence_checker_);
85     DCHECK_EQ(ERR_IO_PENDING, results_.error());
86 
87     results_ = HostCache::Entry(ERR_FAILED, HostCache::Entry::SOURCE_UNKNOWN);
88     async_transaction_ = nullptr;
89   }
90 
91  private:
OnComplete(MDnsTransaction::Result result,const RecordParsed * parsed)92   void OnComplete(MDnsTransaction::Result result, const RecordParsed* parsed) {
93     DCHECK_CALLED_ON_VALID_SEQUENCE(task_->sequence_checker_);
94     DCHECK_EQ(ERR_IO_PENDING, results_.error());
95 
96     int error = ERR_UNEXPECTED;
97     switch (result) {
98       case MDnsTransaction::RESULT_RECORD:
99         DCHECK(parsed);
100         error = OK;
101         break;
102       case MDnsTransaction::RESULT_NO_RESULTS:
103       case MDnsTransaction::RESULT_NSEC:
104         error = ERR_NAME_NOT_RESOLVED;
105         break;
106       default:
107         // No other results should be possible with the request flags used.
108         NOTREACHED();
109     }
110 
111     results_ = HostResolverMdnsTask::ParseResult(error, query_type_, parsed,
112                                                  task_->hostname_);
113 
114     // If we don't have a saved async_transaction, it means OnComplete was
115     // invoked inline in MDnsTransaction::Start. Callbacks will need to be
116     // invoked via post.
117     task_->CheckCompletion(!async_transaction_);
118   }
119 
120   const DnsQueryType query_type_;
121 
122   // ERR_IO_PENDING until transaction completes (or is cancelled).
123   HostCache::Entry results_;
124 
125   // Not saved until MDnsTransaction::Start completes to differentiate inline
126   // completion.
127   std::unique_ptr<MDnsTransaction> async_transaction_;
128 
129   // Back pointer. Expected to destroy |this| before destroying itself.
130   const raw_ptr<HostResolverMdnsTask> task_;
131 };
132 
HostResolverMdnsTask(MDnsClient * mdns_client,std::string hostname,DnsQueryTypeSet query_types)133 HostResolverMdnsTask::HostResolverMdnsTask(MDnsClient* mdns_client,
134                                            std::string hostname,
135                                            DnsQueryTypeSet query_types)
136     : mdns_client_(mdns_client), hostname_(std::move(hostname)) {
137   CHECK(!query_types.empty());
138   DCHECK(!query_types.Has(DnsQueryType::UNSPECIFIED));
139 
140   static constexpr DnsQueryTypeSet kUnwantedQueries = {DnsQueryType::HTTPS};
141 
142   for (DnsQueryType query_type : Difference(query_types, kUnwantedQueries)) {
143     transactions_.emplace_back(query_type, this);
144   }
145   CHECK(!transactions_.empty()) << "Only unwanted query types supplied.";
146 }
147 
~HostResolverMdnsTask()148 HostResolverMdnsTask::~HostResolverMdnsTask() {
149   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
150   transactions_.clear();
151 }
152 
Start(base::OnceClosure completion_closure)153 void HostResolverMdnsTask::Start(base::OnceClosure completion_closure) {
154   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
155   DCHECK(!completion_closure_);
156   DCHECK(mdns_client_);
157 
158   completion_closure_ = std::move(completion_closure);
159 
160   for (auto& transaction : transactions_) {
161     // Only start transaction if it is not already marked done. A transaction
162     // could be marked done before starting if it is preemptively canceled by
163     // a previously started transaction finishing with an error.
164     if (!transaction.IsDone())
165       transaction.Start();
166   }
167 }
168 
GetResults() const169 HostCache::Entry HostResolverMdnsTask::GetResults() const {
170   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
171   DCHECK(!transactions_.empty());
172   DCHECK(!completion_closure_);
173   DCHECK(base::ranges::all_of(transactions_,
174                               [](const Transaction& t) { return t.IsDone(); }));
175 
176   auto found_error =
177       base::ranges::find_if(transactions_, &Transaction::IsError);
178   if (found_error != transactions_.end()) {
179     return found_error->results();
180   }
181 
182   HostCache::Entry combined_results = transactions_.front().results();
183   for (auto it = ++transactions_.begin(); it != transactions_.end(); ++it) {
184     combined_results = HostCache::Entry::MergeEntries(
185         std::move(combined_results), it->results());
186   }
187 
188   return combined_results;
189 }
190 
191 // static
ParseResult(int error,DnsQueryType query_type,const RecordParsed * parsed,const std::string & expected_hostname)192 HostCache::Entry HostResolverMdnsTask::ParseResult(
193     int error,
194     DnsQueryType query_type,
195     const RecordParsed* parsed,
196     const std::string& expected_hostname) {
197   if (error != OK) {
198     return HostCache::Entry(error, HostCache::Entry::SOURCE_UNKNOWN);
199   }
200   DCHECK(parsed);
201 
202   // Expected to be validated by MDnsClient.
203   DCHECK_EQ(DnsQueryTypeToQtype(query_type), parsed->type());
204   DCHECK(base::EqualsCaseInsensitiveASCII(expected_hostname, parsed->name()));
205 
206   switch (query_type) {
207     case DnsQueryType::UNSPECIFIED:
208       // Should create two separate transactions with specified type.
209     case DnsQueryType::HTTPS:
210       // Not supported.
211       // TODO([email protected]): Consider support for HTTPS in mDNS if it
212       // is ever decided to support HTTPS via non-DoH.
213       NOTREACHED();
214       return HostCache::Entry(ERR_FAILED, HostCache::Entry::SOURCE_UNKNOWN);
215     case DnsQueryType::A:
216       return HostCache::Entry(
217           OK, {IPEndPoint(parsed->rdata<net::ARecordRdata>()->address(), 0)},
218           /*aliases=*/{}, HostCache::Entry::SOURCE_UNKNOWN);
219     case DnsQueryType::AAAA:
220       return HostCache::Entry(
221           OK, {IPEndPoint(parsed->rdata<net::AAAARecordRdata>()->address(), 0)},
222           /*aliases=*/{}, HostCache::Entry::SOURCE_UNKNOWN);
223     case DnsQueryType::TXT:
224       return HostCache::Entry(OK, parsed->rdata<net::TxtRecordRdata>()->texts(),
225                               HostCache::Entry::SOURCE_UNKNOWN);
226     case DnsQueryType::PTR:
227       return ParseHostnameResult(parsed->rdata<PtrRecordRdata>()->ptrdomain(),
228                                  0 /* port */);
229     case DnsQueryType::SRV:
230       return ParseHostnameResult(parsed->rdata<SrvRecordRdata>()->target(),
231                                  parsed->rdata<SrvRecordRdata>()->port());
232   }
233 }
234 
CheckCompletion(bool post_needed)235 void HostResolverMdnsTask::CheckCompletion(bool post_needed) {
236   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
237 
238   // Finish immediately if any transactions completed with an error.
239   if (base::ranges::any_of(transactions_,
240                            [](const Transaction& t) { return t.IsError(); })) {
241     Complete(post_needed);
242     return;
243   }
244 
245   if (base::ranges::all_of(transactions_,
246                            [](const Transaction& t) { return t.IsDone(); })) {
247     Complete(post_needed);
248     return;
249   }
250 }
251 
Complete(bool post_needed)252 void HostResolverMdnsTask::Complete(bool post_needed) {
253   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
254 
255   // Cancel any incomplete async transactions.
256   for (auto& transaction : transactions_) {
257     if (!transaction.IsDone())
258       transaction.Cancel();
259   }
260 
261   if (post_needed) {
262     base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
263         FROM_HERE, base::BindOnce(
264                        [](base::WeakPtr<HostResolverMdnsTask> task) {
265                          if (task)
266                            std::move(task->completion_closure_).Run();
267                        },
268                        weak_ptr_factory_.GetWeakPtr()));
269   } else {
270     std::move(completion_closure_).Run();
271   }
272 }
273 
274 }  // namespace net
275