1 // Copyright 2013 The Chromium Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_DNS_MDNS_CLIENT_IMPL_H_ 6 #define NET_DNS_MDNS_CLIENT_IMPL_H_ 7 8 #include <stdint.h> 9 10 #include <map> 11 #include <memory> 12 #include <string> 13 #include <utility> 14 #include <vector> 15 16 #include "base/cancelable_callback.h" 17 #include "base/containers/queue.h" 18 #include "base/gtest_prod_util.h" 19 #include "base/memory/raw_ptr.h" 20 #include "base/observer_list.h" 21 #include "base/time/time.h" 22 #include "net/base/io_buffer.h" 23 #include "net/base/ip_endpoint.h" 24 #include "net/base/net_export.h" 25 #include "net/dns/mdns_cache.h" 26 #include "net/dns/mdns_client.h" 27 #include "net/socket/datagram_server_socket.h" 28 #include "net/socket/udp_server_socket.h" 29 #include "net/socket/udp_socket.h" 30 31 namespace base { 32 class Clock; 33 class OneShotTimer; 34 } // namespace base 35 36 namespace net { 37 38 class NetLog; 39 40 class MDnsSocketFactoryImpl : public MDnsSocketFactory { 41 public: MDnsSocketFactoryImpl()42 MDnsSocketFactoryImpl() : net_log_(nullptr) {} MDnsSocketFactoryImpl(NetLog * net_log)43 explicit MDnsSocketFactoryImpl(NetLog* net_log) : net_log_(net_log) {} 44 45 MDnsSocketFactoryImpl(const MDnsSocketFactoryImpl&) = delete; 46 MDnsSocketFactoryImpl& operator=(const MDnsSocketFactoryImpl&) = delete; 47 48 ~MDnsSocketFactoryImpl() override = default; 49 50 void CreateSockets( 51 std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) override; 52 53 private: 54 const raw_ptr<NetLog> net_log_; 55 }; 56 57 // A connection to the network for multicast DNS clients. It reads data into 58 // DnsResponse objects and alerts the delegate that a packet has been received. 59 class NET_EXPORT_PRIVATE MDnsConnection { 60 public: 61 class Delegate { 62 public: 63 // Handle an mDNS packet buffered in |response| with a size of |bytes_read|. 64 virtual void HandlePacket(DnsResponse* response, int bytes_read) = 0; 65 virtual void OnConnectionError(int error) = 0; 66 virtual ~Delegate() = default; 67 }; 68 69 explicit MDnsConnection(MDnsConnection::Delegate* delegate); 70 71 MDnsConnection(const MDnsConnection&) = delete; 72 MDnsConnection& operator=(const MDnsConnection&) = delete; 73 74 virtual ~MDnsConnection(); 75 76 // Succeeds if at least one of the socket handlers succeeded. 77 int Init(MDnsSocketFactory* socket_factory); 78 void Send(const scoped_refptr<IOBuffer>& buffer, unsigned size); 79 80 private: 81 class SocketHandler { 82 public: 83 SocketHandler(std::unique_ptr<DatagramServerSocket> socket, 84 MDnsConnection* connection); 85 86 SocketHandler(const SocketHandler&) = delete; 87 SocketHandler& operator=(const SocketHandler&) = delete; 88 89 ~SocketHandler(); 90 91 int Start(); 92 void Send(const scoped_refptr<IOBuffer>& buffer, unsigned size); 93 94 private: 95 int DoLoop(int rv); 96 void OnDatagramReceived(int rv); 97 98 // Callback for when sending a query has finished. 99 void SendDone(int rv); 100 101 std::unique_ptr<DatagramServerSocket> socket_; 102 raw_ptr<MDnsConnection> connection_; 103 IPEndPoint recv_addr_; 104 DnsResponse response_; 105 IPEndPoint multicast_addr_; 106 bool send_in_progress_ = false; 107 base::queue<std::pair<scoped_refptr<IOBuffer>, unsigned>> send_queue_; 108 }; 109 110 // Callback for handling a datagram being received on either ipv4 or ipv6. 111 void OnDatagramReceived(DnsResponse* response, 112 const IPEndPoint& recv_addr, 113 int bytes_read); 114 115 void PostOnError(SocketHandler* loop, int rv); 116 void OnError(int rv); 117 118 // Only socket handlers which successfully bound and started are kept. 119 std::vector<std::unique_ptr<SocketHandler>> socket_handlers_; 120 121 raw_ptr<Delegate> delegate_; 122 123 base::WeakPtrFactory<MDnsConnection> weak_ptr_factory_{this}; 124 }; 125 126 class MDnsListenerImpl; 127 128 class NET_EXPORT_PRIVATE MDnsClientImpl : public MDnsClient { 129 public: 130 // The core object exists while the MDnsClient is listening, and is deleted 131 // whenever the number of listeners reaches zero. The deletion happens 132 // asychronously, so destroying the last listener does not immediately 133 // invalidate the core. 134 class Core : public base::SupportsWeakPtr<Core>, MDnsConnection::Delegate { 135 public: 136 Core(base::Clock* clock, base::OneShotTimer* timer); 137 138 Core(const Core&) = delete; 139 Core& operator=(const Core&) = delete; 140 141 ~Core() override; 142 143 // Initialize the core. 144 int Init(MDnsSocketFactory* socket_factory); 145 146 // Send a query with a specific rrtype and name. Returns true on success. 147 bool SendQuery(uint16_t rrtype, const std::string& name); 148 149 // Add/remove a listener to the list of listeners. 150 void AddListener(MDnsListenerImpl* listener); 151 void RemoveListener(MDnsListenerImpl* listener); 152 153 // Query the cache for records of a specific type and name. 154 void QueryCache(uint16_t rrtype, 155 const std::string& name, 156 std::vector<const RecordParsed*>* records) const; 157 158 // Parse the response and alert relevant listeners. 159 void HandlePacket(DnsResponse* response, int bytes_read) override; 160 161 void OnConnectionError(int error) override; 162 cache_for_testing()163 MDnsCache* cache_for_testing() { return &cache_; } 164 165 private: 166 FRIEND_TEST_ALL_PREFIXES(MDnsTest, CacheCleanupWithShortTTL); 167 168 class ListenerKey { 169 public: 170 ListenerKey(const std::string& name, uint16_t type); 171 ListenerKey(const ListenerKey&) = default; 172 ListenerKey(ListenerKey&&) = default; 173 bool operator<(const ListenerKey& key) const; name_lowercase()174 const std::string& name_lowercase() const { return name_lowercase_; } type()175 uint16_t type() const { return type_; } 176 177 private: 178 std::string name_lowercase_; 179 uint16_t type_; 180 }; 181 typedef base::ObserverList<MDnsListenerImpl>::Unchecked ObserverListType; 182 typedef std::map<ListenerKey, std::unique_ptr<ObserverListType>> 183 ListenerMap; 184 185 // Alert listeners of an update to the cache. 186 void AlertListeners(MDnsCache::UpdateType update_type, 187 const ListenerKey& key, const RecordParsed* record); 188 189 // Schedule a cache cleanup to a specific time, cancelling other cleanups. 190 void ScheduleCleanup(base::Time cleanup); 191 192 // Clean up the cache and schedule a new cleanup. 193 void DoCleanup(); 194 195 // Callback for when a record is removed from the cache. 196 void OnRecordRemoved(const RecordParsed* record); 197 198 void NotifyNsecRecord(const RecordParsed* record); 199 200 // Delete and erase the observer list for |key|. Only deletes the observer 201 // list if is empty. 202 void CleanupObserverList(const ListenerKey& key); 203 204 ListenerMap listeners_; 205 206 MDnsCache cache_; 207 208 raw_ptr<base::Clock> clock_; 209 raw_ptr<base::OneShotTimer> cleanup_timer_; 210 base::Time scheduled_cleanup_; 211 212 std::unique_ptr<MDnsConnection> connection_; 213 }; 214 215 MDnsClientImpl(); 216 217 // Test constructor, takes a mock clock and mock timer. 218 MDnsClientImpl(base::Clock* clock, 219 std::unique_ptr<base::OneShotTimer> cleanup_timer); 220 221 MDnsClientImpl(const MDnsClientImpl&) = delete; 222 MDnsClientImpl& operator=(const MDnsClientImpl&) = delete; 223 224 ~MDnsClientImpl() override; 225 226 // MDnsClient implementation: 227 std::unique_ptr<MDnsListener> CreateListener( 228 uint16_t rrtype, 229 const std::string& name, 230 MDnsListener::Delegate* delegate) override; 231 232 std::unique_ptr<MDnsTransaction> CreateTransaction( 233 uint16_t rrtype, 234 const std::string& name, 235 int flags, 236 const MDnsTransaction::ResultCallback& callback) override; 237 238 int StartListening(MDnsSocketFactory* socket_factory) override; 239 void StopListening() override; 240 bool IsListening() const override; 241 core()242 Core* core() { return core_.get(); } 243 244 private: 245 raw_ptr<base::Clock> clock_; 246 std::unique_ptr<base::OneShotTimer> cleanup_timer_; 247 248 std::unique_ptr<Core> core_; 249 }; 250 251 class MDnsListenerImpl : public MDnsListener, 252 public base::SupportsWeakPtr<MDnsListenerImpl> { 253 public: 254 MDnsListenerImpl(uint16_t rrtype, 255 const std::string& name, 256 base::Clock* clock, 257 MDnsListener::Delegate* delegate, 258 MDnsClientImpl* client); 259 260 MDnsListenerImpl(const MDnsListenerImpl&) = delete; 261 MDnsListenerImpl& operator=(const MDnsListenerImpl&) = delete; 262 263 ~MDnsListenerImpl() override; 264 265 // MDnsListener implementation: 266 bool Start() override; 267 268 // Actively refresh any received records. 269 void SetActiveRefresh(bool active_refresh) override; 270 271 const std::string& GetName() const override; 272 273 uint16_t GetType() const override; 274 delegate()275 MDnsListener::Delegate* delegate() { return delegate_; } 276 277 // Alert the delegate of a record update. 278 void HandleRecordUpdate(MDnsCache::UpdateType update_type, 279 const RecordParsed* record_parsed); 280 281 // Alert the delegate of the existence of an Nsec record. 282 void AlertNsecRecord(); 283 284 private: 285 void ScheduleNextRefresh(); 286 void DoRefresh(); 287 288 uint16_t rrtype_; 289 std::string name_; 290 raw_ptr<base::Clock> clock_; 291 raw_ptr<MDnsClientImpl> client_; 292 raw_ptr<MDnsListener::Delegate> delegate_; 293 294 base::Time last_update_; 295 uint32_t ttl_; 296 bool started_ = false; 297 bool active_refresh_ = false; 298 299 base::CancelableRepeatingClosure next_refresh_; 300 }; 301 302 class MDnsTransactionImpl : public base::SupportsWeakPtr<MDnsTransactionImpl>, 303 public MDnsTransaction, 304 public MDnsListener::Delegate { 305 public: 306 MDnsTransactionImpl(uint16_t rrtype, 307 const std::string& name, 308 int flags, 309 const MDnsTransaction::ResultCallback& callback, 310 MDnsClientImpl* client); 311 312 MDnsTransactionImpl(const MDnsTransactionImpl&) = delete; 313 MDnsTransactionImpl& operator=(const MDnsTransactionImpl&) = delete; 314 315 ~MDnsTransactionImpl() override; 316 317 // MDnsTransaction implementation: 318 bool Start() override; 319 320 const std::string& GetName() const override; 321 uint16_t GetType() const override; 322 323 // MDnsListener::Delegate implementation: 324 void OnRecordUpdate(MDnsListener::UpdateType update, 325 const RecordParsed* record) override; 326 void OnNsecRecord(const std::string& name, unsigned type) override; 327 328 void OnCachePurged() override; 329 330 private: is_active()331 bool is_active() { return !callback_.is_null(); } 332 333 void Reset(); 334 335 // Trigger the callback and reset all related variables. 336 void TriggerCallback(MDnsTransaction::Result result, 337 const RecordParsed* record); 338 339 // Internal callback for when a cache record is found. 340 void CacheRecordFound(const RecordParsed* record); 341 342 // Signal the transactionis over and release all related resources. 343 void SignalTransactionOver(); 344 345 // Reads records from the cache and calls the callback for every 346 // record read. 347 void ServeRecordsFromCache(); 348 349 // Send a query to the network and set up a timeout to time out the 350 // transaction. Returns false if it fails to start listening on the network 351 // or if it fails to send a query. 352 bool QueryAndListen(); 353 354 uint16_t rrtype_; 355 std::string name_; 356 MDnsTransaction::ResultCallback callback_; 357 358 std::unique_ptr<MDnsListener> listener_; 359 base::CancelableOnceCallback<void()> timeout_; 360 361 raw_ptr<MDnsClientImpl> client_; 362 363 bool started_ = false; 364 int flags_; 365 }; 366 367 } // namespace net 368 #endif // NET_DNS_MDNS_CLIENT_IMPL_H_ 369