xref: /aosp_15_r20/external/cronet/net/dns/mdns_client_impl.h (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef NET_DNS_MDNS_CLIENT_IMPL_H_
6 #define NET_DNS_MDNS_CLIENT_IMPL_H_
7 
#include <stdint.h>

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/cancelable_callback.h"
#include "base/containers/queue.h"
#include "base/gtest_prod_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/observer_list.h"
#include "base/time/time.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_export.h"
#include "net/dns/mdns_cache.h"
#include "net/dns/mdns_client.h"
#include "net/socket/datagram_server_socket.h"
#include "net/socket/udp_server_socket.h"
#include "net/socket/udp_socket.h"
29 #include "net/socket/udp_socket.h"
30 
// //base types that this header only refers to through pointers (raw pointers,
// raw_ptr or std::unique_ptr), so forward declarations are sufficient here.
namespace base {
class Clock;
class OneShotTimer;
}  // namespace base
35 
36 namespace net {
37 
38 class NetLog;
39 
40 class MDnsSocketFactoryImpl : public MDnsSocketFactory {
41  public:
MDnsSocketFactoryImpl()42   MDnsSocketFactoryImpl() : net_log_(nullptr) {}
MDnsSocketFactoryImpl(NetLog * net_log)43   explicit MDnsSocketFactoryImpl(NetLog* net_log) : net_log_(net_log) {}
44 
45   MDnsSocketFactoryImpl(const MDnsSocketFactoryImpl&) = delete;
46   MDnsSocketFactoryImpl& operator=(const MDnsSocketFactoryImpl&) = delete;
47 
48   ~MDnsSocketFactoryImpl() override = default;
49 
50   void CreateSockets(
51       std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override;
52 
53  private:
54   const raw_ptr<NetLog> net_log_;
55 };
56 
57 // A connection to the network for multicast DNS clients. It reads data into
58 // DnsResponse objects and alerts the delegate that a packet has been received.
59 class NET_EXPORT_PRIVATE MDnsConnection {
60  public:
61   class Delegate {
62    public:
63     // Handle an mDNS packet buffered in |response| with a size of |bytes_read|.
64     virtual void HandlePacket(DnsResponse* response, int bytes_read) = 0;
65     virtual void OnConnectionError(int error) = 0;
66     virtual ~Delegate() = default;
67   };
68 
69   explicit MDnsConnection(MDnsConnection::Delegate* delegate);
70 
71   MDnsConnection(const MDnsConnection&) = delete;
72   MDnsConnection& operator=(const MDnsConnection&) = delete;
73 
74   virtual ~MDnsConnection();
75 
76   // Succeeds if at least one of the socket handlers succeeded.
77   int Init(MDnsSocketFactory* socket_factory);
78   void Send(const scoped_refptr<IOBuffer>& buffer, unsigned size);
79 
80  private:
81   class SocketHandler {
82    public:
83     SocketHandler(std::unique_ptr<DatagramServerSocket> socket,
84                   MDnsConnection* connection);
85 
86     SocketHandler(const SocketHandler&) = delete;
87     SocketHandler& operator=(const SocketHandler&) = delete;
88 
89     ~SocketHandler();
90 
91     int Start();
92     void Send(const scoped_refptr<IOBuffer>& buffer, unsigned size);
93 
94    private:
95     int DoLoop(int rv);
96     void OnDatagramReceived(int rv);
97 
98     // Callback for when sending a query has finished.
99     void SendDone(int rv);
100 
101     std::unique_ptr<DatagramServerSocket> socket_;
102     raw_ptr<MDnsConnection> connection_;
103     IPEndPoint recv_addr_;
104     DnsResponse response_;
105     IPEndPoint multicast_addr_;
106     bool send_in_progress_ = false;
107     base::queue<std::pair<scoped_refptr<IOBuffer>, unsigned>> send_queue_;
108   };
109 
110   // Callback for handling a datagram being received on either ipv4 or ipv6.
111   void OnDatagramReceived(DnsResponse* response,
112                           const IPEndPoint& recv_addr,
113                           int bytes_read);
114 
115   void PostOnError(SocketHandler* loop, int rv);
116   void OnError(int rv);
117 
118   // Only socket handlers which successfully bound and started are kept.
119   std::vector<std::unique_ptr<SocketHandler>> socket_handlers_;
120 
121   raw_ptr<Delegate> delegate_;
122 
123   base::WeakPtrFactory<MDnsConnection> weak_ptr_factory_{this};
124 };
125 
126 class MDnsListenerImpl;
127 
128 class NET_EXPORT_PRIVATE MDnsClientImpl : public MDnsClient {
129  public:
130   // The core object exists while the MDnsClient is listening, and is deleted
131   // whenever the number of listeners reaches zero. The deletion happens
132   // asychronously, so destroying the last listener does not immediately
133   // invalidate the core.
134   class Core : public base::SupportsWeakPtr<Core>, MDnsConnection::Delegate {
135    public:
136     Core(base::Clock* clock, base::OneShotTimer* timer);
137 
138     Core(const Core&) = delete;
139     Core& operator=(const Core&) = delete;
140 
141     ~Core() override;
142 
143     // Initialize the core.
144     int Init(MDnsSocketFactory* socket_factory);
145 
146     // Send a query with a specific rrtype and name. Returns true on success.
147     bool SendQuery(uint16_t rrtype, const std::string& name);
148 
149     // Add/remove a listener to the list of listeners.
150     void AddListener(MDnsListenerImpl* listener);
151     void RemoveListener(MDnsListenerImpl* listener);
152 
153     // Query the cache for records of a specific type and name.
154     void QueryCache(uint16_t rrtype,
155                     const std::string& name,
156                     std::vector<const RecordParsed*>* records) const;
157 
158     // Parse the response and alert relevant listeners.
159     void HandlePacket(DnsResponse* response, int bytes_read) override;
160 
161     void OnConnectionError(int error) override;
162 
cache_for_testing()163     MDnsCache* cache_for_testing() { return &cache_; }
164 
165    private:
166     FRIEND_TEST_ALL_PREFIXES(MDnsTest, CacheCleanupWithShortTTL);
167 
168     class ListenerKey {
169      public:
170       ListenerKey(const std::string& name, uint16_t type);
171       ListenerKey(const ListenerKey&) = default;
172       ListenerKey(ListenerKey&&) = default;
173       bool operator<(const ListenerKey& key) const;
name_lowercase()174       const std::string& name_lowercase() const { return name_lowercase_; }
type()175       uint16_t type() const { return type_; }
176 
177      private:
178       std::string name_lowercase_;
179       uint16_t type_;
180     };
181     typedef base::ObserverList<MDnsListenerImpl>::Unchecked ObserverListType;
182     typedef std::map<ListenerKey, std::unique_ptr<ObserverListType>>
183         ListenerMap;
184 
185     // Alert listeners of an update to the cache.
186     void AlertListeners(MDnsCache::UpdateType update_type,
187                         const ListenerKey& key, const RecordParsed* record);
188 
189     // Schedule a cache cleanup to a specific time, cancelling other cleanups.
190     void ScheduleCleanup(base::Time cleanup);
191 
192     // Clean up the cache and schedule a new cleanup.
193     void DoCleanup();
194 
195     // Callback for when a record is removed from the cache.
196     void OnRecordRemoved(const RecordParsed* record);
197 
198     void NotifyNsecRecord(const RecordParsed* record);
199 
200     // Delete and erase the observer list for |key|. Only deletes the observer
201     // list if is empty.
202     void CleanupObserverList(const ListenerKey& key);
203 
204     ListenerMap listeners_;
205 
206     MDnsCache cache_;
207 
208     raw_ptr<base::Clock> clock_;
209     raw_ptr<base::OneShotTimer> cleanup_timer_;
210     base::Time scheduled_cleanup_;
211 
212     std::unique_ptr<MDnsConnection> connection_;
213   };
214 
215   MDnsClientImpl();
216 
217   // Test constructor, takes a mock clock and mock timer.
218   MDnsClientImpl(base::Clock* clock,
219                  std::unique_ptr<base::OneShotTimer> cleanup_timer);
220 
221   MDnsClientImpl(const MDnsClientImpl&) = delete;
222   MDnsClientImpl& operator=(const MDnsClientImpl&) = delete;
223 
224   ~MDnsClientImpl() override;
225 
226   // MDnsClient implementation:
227   std::unique_ptr<MDnsListener> CreateListener(
228       uint16_t rrtype,
229       const std::string& name,
230       MDnsListener::Delegate* delegate) override;
231 
232   std::unique_ptr<MDnsTransaction> CreateTransaction(
233       uint16_t rrtype,
234       const std::string& name,
235       int flags,
236       const MDnsTransaction::ResultCallback& callback) override;
237 
238   int StartListening(MDnsSocketFactory* socket_factory) override;
239   void StopListening() override;
240   bool IsListening() const override;
241 
core()242   Core* core() { return core_.get(); }
243 
244  private:
245   raw_ptr<base::Clock> clock_;
246   std::unique_ptr<base::OneShotTimer> cleanup_timer_;
247 
248   std::unique_ptr<Core> core_;
249 };
250 
251 class MDnsListenerImpl : public MDnsListener,
252                          public base::SupportsWeakPtr<MDnsListenerImpl> {
253  public:
254   MDnsListenerImpl(uint16_t rrtype,
255                    const std::string& name,
256                    base::Clock* clock,
257                    MDnsListener::Delegate* delegate,
258                    MDnsClientImpl* client);
259 
260   MDnsListenerImpl(const MDnsListenerImpl&) = delete;
261   MDnsListenerImpl& operator=(const MDnsListenerImpl&) = delete;
262 
263   ~MDnsListenerImpl() override;
264 
265   // MDnsListener implementation:
266   bool Start() override;
267 
268   // Actively refresh any received records.
269   void SetActiveRefresh(bool active_refresh) override;
270 
271   const std::string& GetName() const override;
272 
273   uint16_t GetType() const override;
274 
delegate()275   MDnsListener::Delegate* delegate() { return delegate_; }
276 
277   // Alert the delegate of a record update.
278   void HandleRecordUpdate(MDnsCache::UpdateType update_type,
279                           const RecordParsed* record_parsed);
280 
281   // Alert the delegate of the existence of an Nsec record.
282   void AlertNsecRecord();
283 
284  private:
285   void ScheduleNextRefresh();
286   void DoRefresh();
287 
288   uint16_t rrtype_;
289   std::string name_;
290   raw_ptr<base::Clock> clock_;
291   raw_ptr<MDnsClientImpl> client_;
292   raw_ptr<MDnsListener::Delegate> delegate_;
293 
294   base::Time last_update_;
295   uint32_t ttl_;
296   bool started_ = false;
297   bool active_refresh_ = false;
298 
299   base::CancelableRepeatingClosure next_refresh_;
300 };
301 
302 class MDnsTransactionImpl : public base::SupportsWeakPtr<MDnsTransactionImpl>,
303                             public MDnsTransaction,
304                             public MDnsListener::Delegate {
305  public:
306   MDnsTransactionImpl(uint16_t rrtype,
307                       const std::string& name,
308                       int flags,
309                       const MDnsTransaction::ResultCallback& callback,
310                       MDnsClientImpl* client);
311 
312   MDnsTransactionImpl(const MDnsTransactionImpl&) = delete;
313   MDnsTransactionImpl& operator=(const MDnsTransactionImpl&) = delete;
314 
315   ~MDnsTransactionImpl() override;
316 
317   // MDnsTransaction implementation:
318   bool Start() override;
319 
320   const std::string& GetName() const override;
321   uint16_t GetType() const override;
322 
323   // MDnsListener::Delegate implementation:
324   void OnRecordUpdate(MDnsListener::UpdateType update,
325                       const RecordParsed* record) override;
326   void OnNsecRecord(const std::string& name, unsigned type) override;
327 
328   void OnCachePurged() override;
329 
330  private:
is_active()331   bool is_active() { return !callback_.is_null(); }
332 
333   void Reset();
334 
335   // Trigger the callback and reset all related variables.
336   void TriggerCallback(MDnsTransaction::Result result,
337                        const RecordParsed* record);
338 
339   // Internal callback for when a cache record is found.
340   void CacheRecordFound(const RecordParsed* record);
341 
342   // Signal the transactionis over and release all related resources.
343   void SignalTransactionOver();
344 
345   // Reads records from the cache and calls the callback for every
346   // record read.
347   void ServeRecordsFromCache();
348 
349   // Send a query to the network and set up a timeout to time out the
350   // transaction. Returns false if it fails to start listening on the network
351   // or if it fails to send a query.
352   bool QueryAndListen();
353 
354   uint16_t rrtype_;
355   std::string name_;
356   MDnsTransaction::ResultCallback callback_;
357 
358   std::unique_ptr<MDnsListener> listener_;
359   base::CancelableOnceCallback<void()> timeout_;
360 
361   raw_ptr<MDnsClientImpl> client_;
362 
363   bool started_ = false;
364   int flags_;
365 };
366 
367 }  // namespace net
368 #endif  // NET_DNS_MDNS_CLIENT_IMPL_H_
369