xref: /aosp_15_r20/external/cronet/net/socket/udp_socket_win.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/udp_socket_win.h"
6 
7 #include <winsock2.h>
8 
9 #include <mstcpip.h>
10 
11 #include <memory>
12 
13 #include "base/check_op.h"
14 #include "base/functional/bind.h"
15 #include "base/functional/callback.h"
16 #include "base/lazy_instance.h"
17 #include "base/memory/raw_ptr.h"
18 #include "base/metrics/histogram_functions.h"
19 #include "base/metrics/histogram_macros.h"
20 #include "base/notreached.h"
21 #include "base/rand_util.h"
22 #include "base/task/thread_pool.h"
23 #include "net/base/io_buffer.h"
24 #include "net/base/ip_address.h"
25 #include "net/base/ip_endpoint.h"
26 #include "net/base/net_errors.h"
27 #include "net/base/network_activity_monitor.h"
28 #include "net/base/network_change_notifier.h"
29 #include "net/base/sockaddr_storage.h"
30 #include "net/base/winsock_init.h"
31 #include "net/base/winsock_util.h"
32 #include "net/log/net_log.h"
33 #include "net/log/net_log_event_type.h"
34 #include "net/log/net_log_source.h"
35 #include "net/log/net_log_source_type.h"
36 #include "net/socket/socket_descriptor.h"
37 #include "net/socket/socket_options.h"
38 #include "net/socket/socket_tag.h"
39 #include "net/socket/udp_net_log_parameters.h"
40 #include "net/traffic_annotation/network_traffic_annotation.h"
41 
42 namespace net {
43 
44 // This class encapsulates all the state that has to be preserved as long as
45 // there is a network IO operation in progress. If the owner UDPSocketWin
46 // is destroyed while an operation is in progress, the Core is detached and it
47 // lives until the operation completes and the OS doesn't reference any resource
48 // declared on this class anymore.
49 class UDPSocketWin::Core : public base::RefCounted<Core> {
50  public:
51   explicit Core(UDPSocketWin* socket);
52 
53   Core(const Core&) = delete;
54   Core& operator=(const Core&) = delete;
55 
56   // Start watching for the end of a read or write operation.
57   void WatchForRead();
58   void WatchForWrite();
59 
60   // The UDPSocketWin is going away.
Detach()61   void Detach() { socket_ = nullptr; }
62 
63   // The separate OVERLAPPED variables for asynchronous operation.
64   OVERLAPPED read_overlapped_;
65   OVERLAPPED write_overlapped_;
66 
67   // The buffers used in Read() and Write().
68   scoped_refptr<IOBuffer> read_iobuffer_;
69   scoped_refptr<IOBuffer> write_iobuffer_;
70   // The struct for packet metadata passed to WSARecvMsg().
71   std::unique_ptr<WSAMSG> read_message_ = nullptr;
72   // Big enough for IP_ECN or IPV6_ECN, nothing more.
73   char read_control_buffer_[WSA_CMSG_SPACE(sizeof(int))];
74 
75   // The address storage passed to WSARecvFrom().
76   SockaddrStorage recv_addr_storage_;
77 
78  private:
79   friend class base::RefCounted<Core>;
80 
81   class ReadDelegate : public base::win::ObjectWatcher::Delegate {
82    public:
ReadDelegate(Core * core)83     explicit ReadDelegate(Core* core) : core_(core) {}
84     ~ReadDelegate() override = default;
85 
86     // base::ObjectWatcher::Delegate methods:
87     void OnObjectSignaled(HANDLE object) override;
88 
89    private:
90     const raw_ptr<Core> core_;
91   };
92 
93   class WriteDelegate : public base::win::ObjectWatcher::Delegate {
94    public:
WriteDelegate(Core * core)95     explicit WriteDelegate(Core* core) : core_(core) {}
96     ~WriteDelegate() override = default;
97 
98     // base::ObjectWatcher::Delegate methods:
99     void OnObjectSignaled(HANDLE object) override;
100 
101    private:
102     const raw_ptr<Core> core_;
103   };
104 
105   ~Core();
106 
107   // The socket that created this object.
108   raw_ptr<UDPSocketWin> socket_;
109 
110   // |reader_| handles the signals from |read_watcher_|.
111   ReadDelegate reader_;
112   // |writer_| handles the signals from |write_watcher_|.
113   WriteDelegate writer_;
114 
115   // |read_watcher_| watches for events from Read().
116   base::win::ObjectWatcher read_watcher_;
117   // |write_watcher_| watches for events from Write();
118   base::win::ObjectWatcher write_watcher_;
119 };
120 
Core(UDPSocketWin * socket)121 UDPSocketWin::Core::Core(UDPSocketWin* socket)
122     : socket_(socket),
123       reader_(this),
124       writer_(this) {
125   memset(&read_overlapped_, 0, sizeof(read_overlapped_));
126   memset(&write_overlapped_, 0, sizeof(write_overlapped_));
127 
128   read_overlapped_.hEvent = WSACreateEvent();
129   write_overlapped_.hEvent = WSACreateEvent();
130 }
131 
~Core()132 UDPSocketWin::Core::~Core() {
133   // Make sure the message loop is not watching this object anymore.
134   read_watcher_.StopWatching();
135   write_watcher_.StopWatching();
136 
137   WSACloseEvent(read_overlapped_.hEvent);
138   memset(&read_overlapped_, 0xaf, sizeof(read_overlapped_));
139   WSACloseEvent(write_overlapped_.hEvent);
140   memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_));
141 }
142 
WatchForRead()143 void UDPSocketWin::Core::WatchForRead() {
144   // We grab an extra reference because there is an IO operation in progress.
145   // Balanced in ReadDelegate::OnObjectSignaled().
146   AddRef();
147   read_watcher_.StartWatchingOnce(read_overlapped_.hEvent, &reader_);
148 }
149 
WatchForWrite()150 void UDPSocketWin::Core::WatchForWrite() {
151   // We grab an extra reference because there is an IO operation in progress.
152   // Balanced in WriteDelegate::OnObjectSignaled().
153   AddRef();
154   write_watcher_.StartWatchingOnce(write_overlapped_.hEvent, &writer_);
155 }
156 
OnObjectSignaled(HANDLE object)157 void UDPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) {
158   DCHECK_EQ(object, core_->read_overlapped_.hEvent);
159   if (core_->socket_)
160     core_->socket_->DidCompleteRead();
161 
162   core_->Release();
163 }
164 
OnObjectSignaled(HANDLE object)165 void UDPSocketWin::Core::WriteDelegate::OnObjectSignaled(HANDLE object) {
166   DCHECK_EQ(object, core_->write_overlapped_.hEvent);
167   if (core_->socket_)
168     core_->socket_->DidCompleteWrite();
169 
170   core_->Release();
171 }
172 //-----------------------------------------------------------------------------
173 
QwaveApi()174 QwaveApi::QwaveApi() {
175   HMODULE qwave = LoadLibrary(L"qwave.dll");
176   if (!qwave)
177     return;
178   create_handle_func_ =
179       (CreateHandleFn)GetProcAddress(qwave, "QOSCreateHandle");
180   close_handle_func_ =
181       (CloseHandleFn)GetProcAddress(qwave, "QOSCloseHandle");
182   add_socket_to_flow_func_ =
183       (AddSocketToFlowFn)GetProcAddress(qwave, "QOSAddSocketToFlow");
184   remove_socket_from_flow_func_ =
185       (RemoveSocketFromFlowFn)GetProcAddress(qwave, "QOSRemoveSocketFromFlow");
186   set_flow_func_ = (SetFlowFn)GetProcAddress(qwave, "QOSSetFlow");
187 
188   if (create_handle_func_ && close_handle_func_ &&
189       add_socket_to_flow_func_ && remove_socket_from_flow_func_ &&
190       set_flow_func_) {
191     qwave_supported_ = true;
192   }
193 }
194 
GetDefault()195 QwaveApi* QwaveApi::GetDefault() {
196   static base::LazyInstance<QwaveApi>::Leaky lazy_qwave =
197       LAZY_INSTANCE_INITIALIZER;
198   return lazy_qwave.Pointer();
199 }
200 
qwave_supported() const201 bool QwaveApi::qwave_supported() const {
202   return qwave_supported_;
203 }
204 
OnFatalError()205 void QwaveApi::OnFatalError() {
206   // Disable everything moving forward.
207   qwave_supported_ = false;
208 }
209 
CreateHandle(PQOS_VERSION version,PHANDLE handle)210 BOOL QwaveApi::CreateHandle(PQOS_VERSION version, PHANDLE handle) {
211   return create_handle_func_(version, handle);
212 }
213 
CloseHandle(HANDLE handle)214 BOOL QwaveApi::CloseHandle(HANDLE handle) {
215   return close_handle_func_(handle);
216 }
217 
AddSocketToFlow(HANDLE handle,SOCKET socket,PSOCKADDR addr,QOS_TRAFFIC_TYPE traffic_type,DWORD flags,PQOS_FLOWID flow_id)218 BOOL QwaveApi::AddSocketToFlow(HANDLE handle,
219                                SOCKET socket,
220                                PSOCKADDR addr,
221                                QOS_TRAFFIC_TYPE traffic_type,
222                                DWORD flags,
223                                PQOS_FLOWID flow_id) {
224   return add_socket_to_flow_func_(handle, socket, addr, traffic_type, flags,
225                                   flow_id);
226 }
227 
RemoveSocketFromFlow(HANDLE handle,SOCKET socket,QOS_FLOWID flow_id,DWORD reserved)228 BOOL QwaveApi::RemoveSocketFromFlow(HANDLE handle,
229                                     SOCKET socket,
230                                     QOS_FLOWID flow_id,
231                                     DWORD reserved) {
232   return remove_socket_from_flow_func_(handle, socket, flow_id, reserved);
233 }
234 
SetFlow(HANDLE handle,QOS_FLOWID flow_id,QOS_SET_FLOW op,ULONG size,PVOID data,DWORD reserved,LPOVERLAPPED overlapped)235 BOOL QwaveApi::SetFlow(HANDLE handle,
236                        QOS_FLOWID flow_id,
237                        QOS_SET_FLOW op,
238                        ULONG size,
239                        PVOID data,
240                        DWORD reserved,
241                        LPOVERLAPPED overlapped) {
242   return set_flow_func_(handle, flow_id, op, size, data, reserved, overlapped);
243 }
244 
245 //-----------------------------------------------------------------------------
246 
UDPSocketWin(DatagramSocket::BindType bind_type,net::NetLog * net_log,const net::NetLogSource & source)247 UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
248                            net::NetLog* net_log,
249                            const net::NetLogSource& source)
250     : socket_(INVALID_SOCKET),
251       socket_options_(SOCKET_OPTION_MULTICAST_LOOP),
252       net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::UDP_SOCKET)) {
253   EnsureWinsockInit();
254   net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE, source);
255 }
256 
UDPSocketWin(DatagramSocket::BindType bind_type,NetLogWithSource source_net_log)257 UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
258                            NetLogWithSource source_net_log)
259     : socket_(INVALID_SOCKET),
260       socket_options_(SOCKET_OPTION_MULTICAST_LOOP),
261       net_log_(source_net_log) {
262   EnsureWinsockInit();
263   net_log_.BeginEventReferencingSource(NetLogEventType::SOCKET_ALIVE,
264                                        net_log_.source());
265 }
266 
~UDPSocketWin()267 UDPSocketWin::~UDPSocketWin() {
268   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
269   Close();
270   net_log_.EndEvent(NetLogEventType::SOCKET_ALIVE);
271 }
272 
Open(AddressFamily address_family)273 int UDPSocketWin::Open(AddressFamily address_family) {
274   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
275   DCHECK_EQ(socket_, INVALID_SOCKET);
276 
277   auto owned_socket_count = TryAcquireGlobalUDPSocketCount();
278   if (owned_socket_count.empty())
279     return ERR_INSUFFICIENT_RESOURCES;
280 
281   owned_socket_count_ = std::move(owned_socket_count);
282   addr_family_ = ConvertAddressFamily(address_family);
283   socket_ = CreatePlatformSocket(addr_family_, SOCK_DGRAM, IPPROTO_UDP);
284   if (socket_ == INVALID_SOCKET) {
285     owned_socket_count_.Reset();
286     return MapSystemError(WSAGetLastError());
287   }
288   ConfigureOpenedSocket();
289   return OK;
290 }
291 
AdoptOpenedSocket(AddressFamily address_family,SOCKET socket)292 int UDPSocketWin::AdoptOpenedSocket(AddressFamily address_family,
293                                     SOCKET socket) {
294   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
295   auto owned_socket_count = TryAcquireGlobalUDPSocketCount();
296   if (owned_socket_count.empty()) {
297     return ERR_INSUFFICIENT_RESOURCES;
298   }
299 
300   owned_socket_count_ = std::move(owned_socket_count);
301   addr_family_ = ConvertAddressFamily(address_family);
302   socket_ = socket;
303   ConfigureOpenedSocket();
304   return OK;
305 }
306 
ConfigureOpenedSocket()307 void UDPSocketWin::ConfigureOpenedSocket() {
308   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
309   if (!use_non_blocking_io_) {
310     core_ = base::MakeRefCounted<Core>(this);
311   } else {
312     read_write_event_.Set(WSACreateEvent());
313     WSAEventSelect(socket_, read_write_event_.Get(), FD_READ | FD_WRITE);
314   }
315 }
316 
Close()317 void UDPSocketWin::Close() {
318   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
319 
320   owned_socket_count_.Reset();
321 
322   if (socket_ == INVALID_SOCKET)
323     return;
324 
325   // Remove socket_ from the QoS subsystem before we invalidate it.
326   dscp_manager_ = nullptr;
327 
328   // Zero out any pending read/write callback state.
329   read_callback_.Reset();
330   recv_from_address_ = nullptr;
331   write_callback_.Reset();
332 
333   base::TimeTicks start_time = base::TimeTicks::Now();
334   closesocket(socket_);
335   UMA_HISTOGRAM_TIMES("Net.UDPSocketWinClose",
336                       base::TimeTicks::Now() - start_time);
337   socket_ = INVALID_SOCKET;
338   addr_family_ = 0;
339   is_connected_ = false;
340 
341   // Release buffers to free up memory.
342   read_iobuffer_ = nullptr;
343   read_iobuffer_len_ = 0;
344   write_iobuffer_ = nullptr;
345   write_iobuffer_len_ = 0;
346 
347   read_write_watcher_.StopWatching();
348   read_write_event_.Close();
349 
350   event_pending_.InvalidateWeakPtrs();
351 
352   if (core_) {
353     core_->Detach();
354     core_ = nullptr;
355   }
356 }
357 
GetPeerAddress(IPEndPoint * address) const358 int UDPSocketWin::GetPeerAddress(IPEndPoint* address) const {
359   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
360   DCHECK(address);
361   if (!is_connected())
362     return ERR_SOCKET_NOT_CONNECTED;
363 
364   // TODO(szym): Simplify. http://crbug.com/126152
365   if (!remote_address_.get()) {
366     SockaddrStorage storage;
367     if (getpeername(socket_, storage.addr, &storage.addr_len))
368       return MapSystemError(WSAGetLastError());
369     auto remote_address = std::make_unique<IPEndPoint>();
370     if (!remote_address->FromSockAddr(storage.addr, storage.addr_len))
371       return ERR_ADDRESS_INVALID;
372     remote_address_ = std::move(remote_address);
373   }
374 
375   *address = *remote_address_;
376   return OK;
377 }
378 
GetLocalAddress(IPEndPoint * address) const379 int UDPSocketWin::GetLocalAddress(IPEndPoint* address) const {
380   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
381   DCHECK(address);
382   if (!is_connected())
383     return ERR_SOCKET_NOT_CONNECTED;
384 
385   // TODO(szym): Simplify. http://crbug.com/126152
386   if (!local_address_.get()) {
387     SockaddrStorage storage;
388     if (getsockname(socket_, storage.addr, &storage.addr_len))
389       return MapSystemError(WSAGetLastError());
390     auto local_address = std::make_unique<IPEndPoint>();
391     if (!local_address->FromSockAddr(storage.addr, storage.addr_len))
392       return ERR_ADDRESS_INVALID;
393     local_address_ = std::move(local_address);
394     net_log_.AddEvent(NetLogEventType::UDP_LOCAL_ADDRESS, [&] {
395       return CreateNetLogUDPConnectParams(*local_address_,
396                                           handles::kInvalidNetworkHandle);
397     });
398   }
399 
400   *address = *local_address_;
401   return OK;
402 }
403 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)404 int UDPSocketWin::Read(IOBuffer* buf,
405                        int buf_len,
406                        CompletionOnceCallback callback) {
407   return RecvFrom(buf, buf_len, nullptr, std::move(callback));
408 }
409 
RecvFrom(IOBuffer * buf,int buf_len,IPEndPoint * address,CompletionOnceCallback callback)410 int UDPSocketWin::RecvFrom(IOBuffer* buf,
411                            int buf_len,
412                            IPEndPoint* address,
413                            CompletionOnceCallback callback) {
414   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
415   DCHECK_NE(INVALID_SOCKET, socket_);
416   CHECK(read_callback_.is_null());
417   DCHECK(!recv_from_address_);
418   DCHECK(!callback.is_null());  // Synchronous operation not supported.
419   DCHECK_GT(buf_len, 0);
420 
421   int nread = core_ ? InternalRecvFromOverlapped(buf, buf_len, address)
422                     : InternalRecvFromNonBlocking(buf, buf_len, address);
423   if (nread != ERR_IO_PENDING)
424     return nread;
425 
426   read_callback_ = std::move(callback);
427   recv_from_address_ = address;
428   return ERR_IO_PENDING;
429 }
430 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)431 int UDPSocketWin::Write(
432     IOBuffer* buf,
433     int buf_len,
434     CompletionOnceCallback callback,
435     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
436   return SendToOrWrite(buf, buf_len, remote_address_.get(),
437                        std::move(callback));
438 }
439 
SendTo(IOBuffer * buf,int buf_len,const IPEndPoint & address,CompletionOnceCallback callback)440 int UDPSocketWin::SendTo(IOBuffer* buf,
441                          int buf_len,
442                          const IPEndPoint& address,
443                          CompletionOnceCallback callback) {
444   if (dscp_manager_) {
445     // Alert DscpManager in case this is a new remote address.  Failure to
446     // apply Dscp code is never fatal.
447     int rv = dscp_manager_->PrepareForSend(address);
448     if (rv != OK)
449       net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_SEND_ERROR, rv);
450   }
451   return SendToOrWrite(buf, buf_len, &address, std::move(callback));
452 }
453 
SendToOrWrite(IOBuffer * buf,int buf_len,const IPEndPoint * address,CompletionOnceCallback callback)454 int UDPSocketWin::SendToOrWrite(IOBuffer* buf,
455                                 int buf_len,
456                                 const IPEndPoint* address,
457                                 CompletionOnceCallback callback) {
458   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
459   DCHECK_NE(INVALID_SOCKET, socket_);
460   CHECK(write_callback_.is_null());
461   DCHECK(!callback.is_null());  // Synchronous operation not supported.
462   DCHECK_GT(buf_len, 0);
463   DCHECK(!send_to_address_.get());
464 
465   int nwrite = core_ ? InternalSendToOverlapped(buf, buf_len, address)
466                      : InternalSendToNonBlocking(buf, buf_len, address);
467   if (nwrite != ERR_IO_PENDING)
468     return nwrite;
469 
470   if (address)
471     send_to_address_ = std::make_unique<IPEndPoint>(*address);
472   write_callback_ = std::move(callback);
473   return ERR_IO_PENDING;
474 }
475 
Connect(const IPEndPoint & address)476 int UDPSocketWin::Connect(const IPEndPoint& address) {
477   DCHECK_NE(socket_, INVALID_SOCKET);
478   net_log_.BeginEvent(NetLogEventType::UDP_CONNECT, [&] {
479     return CreateNetLogUDPConnectParams(address,
480                                         handles::kInvalidNetworkHandle);
481   });
482   int rv = SetMulticastOptions();
483   if (rv != OK)
484     return rv;
485   rv = InternalConnect(address);
486   net_log_.EndEventWithNetErrorCode(NetLogEventType::UDP_CONNECT, rv);
487   is_connected_ = (rv == OK);
488   return rv;
489 }
490 
InternalConnect(const IPEndPoint & address)491 int UDPSocketWin::InternalConnect(const IPEndPoint& address) {
492   DCHECK(!is_connected());
493   DCHECK(!remote_address_.get());
494 
495   // Always do a random bind.
496   // Ignore failures, which may happen if the socket was already bound.
497   DWORD randomize_port_value = 1;
498   setsockopt(socket_, SOL_SOCKET, SO_RANDOMIZE_PORT,
499              reinterpret_cast<const char*>(&randomize_port_value),
500              sizeof(randomize_port_value));
501 
502   SockaddrStorage storage;
503   if (!address.ToSockAddr(storage.addr, &storage.addr_len))
504     return ERR_ADDRESS_INVALID;
505 
506   int rv = connect(socket_, storage.addr, storage.addr_len);
507   if (rv < 0)
508     return MapSystemError(WSAGetLastError());
509 
510   remote_address_ = std::make_unique<IPEndPoint>(address);
511 
512   if (dscp_manager_)
513     dscp_manager_->PrepareForSend(*remote_address_.get());
514 
515   return rv;
516 }
517 
Bind(const IPEndPoint & address)518 int UDPSocketWin::Bind(const IPEndPoint& address) {
519   DCHECK_NE(socket_, INVALID_SOCKET);
520   DCHECK(!is_connected());
521 
522   int rv = SetMulticastOptions();
523   if (rv < 0)
524     return rv;
525 
526   rv = DoBind(address);
527   if (rv < 0)
528     return rv;
529 
530   local_address_.reset();
531   is_connected_ = true;
532   return rv;
533 }
534 
BindToNetwork(handles::NetworkHandle network)535 int UDPSocketWin::BindToNetwork(handles::NetworkHandle network) {
536   NOTIMPLEMENTED();
537   return ERR_NOT_IMPLEMENTED;
538 }
539 
SetReceiveBufferSize(int32_t size)540 int UDPSocketWin::SetReceiveBufferSize(int32_t size) {
541   DCHECK_NE(socket_, INVALID_SOCKET);
542   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
543   int rv = SetSocketReceiveBufferSize(socket_, size);
544 
545   if (rv != 0)
546     return MapSystemError(WSAGetLastError());
547 
548   // According to documentation, setsockopt may succeed, but we need to check
549   // the results via getsockopt to be sure it works on Windows.
550   int32_t actual_size = 0;
551   int option_size = sizeof(actual_size);
552   rv = getsockopt(socket_, SOL_SOCKET, SO_RCVBUF,
553                   reinterpret_cast<char*>(&actual_size), &option_size);
554   if (rv != 0)
555     return MapSystemError(WSAGetLastError());
556   if (actual_size >= size)
557     return OK;
558   UMA_HISTOGRAM_CUSTOM_COUNTS("Net.SocketUnchangeableReceiveBuffer",
559                               actual_size, 1000, 1000000, 50);
560   return ERR_SOCKET_RECEIVE_BUFFER_SIZE_UNCHANGEABLE;
561 }
562 
SetSendBufferSize(int32_t size)563 int UDPSocketWin::SetSendBufferSize(int32_t size) {
564   DCHECK_NE(socket_, INVALID_SOCKET);
565   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
566   int rv = SetSocketSendBufferSize(socket_, size);
567   if (rv != 0)
568     return MapSystemError(WSAGetLastError());
569   // According to documentation, setsockopt may succeed, but we need to check
570   // the results via getsockopt to be sure it works on Windows.
571   int32_t actual_size = 0;
572   int option_size = sizeof(actual_size);
573   rv = getsockopt(socket_, SOL_SOCKET, SO_SNDBUF,
574                   reinterpret_cast<char*>(&actual_size), &option_size);
575   if (rv != 0)
576     return MapSystemError(WSAGetLastError());
577   if (actual_size >= size)
578     return OK;
579   UMA_HISTOGRAM_CUSTOM_COUNTS("Net.SocketUnchangeableSendBuffer",
580                               actual_size, 1000, 1000000, 50);
581   return ERR_SOCKET_SEND_BUFFER_SIZE_UNCHANGEABLE;
582 }
583 
SetDoNotFragment()584 int UDPSocketWin::SetDoNotFragment() {
585   DCHECK_NE(socket_, INVALID_SOCKET);
586   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
587 
588   if (addr_family_ == AF_INET6)
589     return OK;
590 
591   DWORD val = 1;
592   int rv = setsockopt(socket_, IPPROTO_IP, IP_DONTFRAGMENT,
593                       reinterpret_cast<const char*>(&val), sizeof(val));
594   return rv == 0 ? OK : MapSystemError(WSAGetLastError());
595 }
596 
GetRecvMsgPointer()597 LPFN_WSARECVMSG UDPSocketWin::GetRecvMsgPointer() {
598   LPFN_WSARECVMSG rv;
599   GUID message_code = WSAID_WSARECVMSG;
600   DWORD size;
601   if (WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, &message_code,
602                sizeof(message_code), &rv, sizeof(rv), &size, NULL,
603                NULL) == SOCKET_ERROR) {
604     return nullptr;
605   }
606   return rv;
607 }
608 
GetSendMsgPointer()609 LPFN_WSASENDMSG UDPSocketWin::GetSendMsgPointer() {
610   LPFN_WSASENDMSG rv;
611   GUID message_code = WSAID_WSASENDMSG;
612   DWORD size;
613   if (WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, &message_code,
614                sizeof(message_code), &rv, sizeof(rv), &size, NULL,
615                NULL) == SOCKET_ERROR) {
616     return nullptr;
617   }
618   return rv;
619 }
620 
SetRecvTos()621 int UDPSocketWin::SetRecvTos() {
622   DCHECK_NE(socket_, INVALID_SOCKET);
623   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
624   int rv = WSASetRecvIPEcn(socket_, TRUE);
625   if (rv != 0) {
626     int os_error = WSAGetLastError();
627     int result = MapSystemError(os_error);
628     LogRead(result, nullptr, nullptr);
629     return result;
630   }
631   wsa_recv_msg_ = GetRecvMsgPointer();
632   if (wsa_recv_msg_ == nullptr) {
633     int os_error = WSAGetLastError();
634     int result = MapSystemError(os_error);
635     LogRead(result, nullptr, nullptr);
636     return result;
637   }
638   report_ecn_ = true;
639   return 0;
640 }
641 
SetMsgConfirm(bool confirm)642 void UDPSocketWin::SetMsgConfirm(bool confirm) {}
643 
AllowAddressReuse()644 int UDPSocketWin::AllowAddressReuse() {
645   DCHECK_NE(socket_, INVALID_SOCKET);
646   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
647   DCHECK(!is_connected());
648 
649   BOOL true_value = TRUE;
650   int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
651                       reinterpret_cast<const char*>(&true_value),
652                       sizeof(true_value));
653   return rv == 0 ? OK : MapSystemError(WSAGetLastError());
654 }
655 
SetBroadcast(bool broadcast)656 int UDPSocketWin::SetBroadcast(bool broadcast) {
657   DCHECK_NE(socket_, INVALID_SOCKET);
658   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
659 
660   BOOL value = broadcast ? TRUE : FALSE;
661   int rv = setsockopt(socket_, SOL_SOCKET, SO_BROADCAST,
662                       reinterpret_cast<const char*>(&value), sizeof(value));
663   return rv == 0 ? OK : MapSystemError(WSAGetLastError());
664 }
665 
AllowAddressSharingForMulticast()666 int UDPSocketWin::AllowAddressSharingForMulticast() {
667   // When proper multicast groups are used, Windows further defines the address
668   // resuse option (SO_REUSEADDR) to ensure all listening sockets can receive
669   // all incoming messages for the multicast group.
670   return AllowAddressReuse();
671 }
672 
DoReadCallback(int rv)673 void UDPSocketWin::DoReadCallback(int rv) {
674   DCHECK_NE(rv, ERR_IO_PENDING);
675   DCHECK(!read_callback_.is_null());
676 
677   // since Run may result in Read being called, clear read_callback_ up front.
678   std::move(read_callback_).Run(rv);
679 }
680 
DoWriteCallback(int rv)681 void UDPSocketWin::DoWriteCallback(int rv) {
682   DCHECK_NE(rv, ERR_IO_PENDING);
683   DCHECK(!write_callback_.is_null());
684 
685   // since Run may result in Write being called, clear write_callback_ up front.
686   std::move(write_callback_).Run(rv);
687 }
688 
DidCompleteRead()689 void UDPSocketWin::DidCompleteRead() {
690   DWORD num_bytes, flags;
691   BOOL ok = WSAGetOverlappedResult(socket_, &core_->read_overlapped_,
692                                    &num_bytes, FALSE, &flags);
693   WSAResetEvent(core_->read_overlapped_.hEvent);
694   int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
695   // Convert address.
696   IPEndPoint address;
697   IPEndPoint* address_to_log = nullptr;
698   if (result >= 0) {
699     if (address.FromSockAddr(core_->recv_addr_storage_.addr,
700                              core_->recv_addr_storage_.addr_len)) {
701       if (recv_from_address_)
702         *recv_from_address_ = address;
703       address_to_log = &address;
704     } else {
705       result = ERR_ADDRESS_INVALID;
706     }
707     if (core_->read_message_ != nullptr) {
708       SetLastTosFromWSAMSG(*core_->read_message_);
709     }
710   }
711   LogRead(result, core_->read_iobuffer_->data(), address_to_log);
712   core_->read_iobuffer_ = nullptr;
713   core_->read_message_ = nullptr;
714   recv_from_address_ = nullptr;
715   DoReadCallback(result);
716 }
717 
DidCompleteWrite()718 void UDPSocketWin::DidCompleteWrite() {
719   DWORD num_bytes, flags;
720   BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_,
721                                    &num_bytes, FALSE, &flags);
722   WSAResetEvent(core_->write_overlapped_.hEvent);
723   int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
724   LogWrite(result, core_->write_iobuffer_->data(), send_to_address_.get());
725 
726   send_to_address_.reset();
727   core_->write_iobuffer_ = nullptr;
728   DoWriteCallback(result);
729 }
730 
OnObjectSignaled(HANDLE object)731 void UDPSocketWin::OnObjectSignaled(HANDLE object) {
732   DCHECK(object == read_write_event_.Get());
733   WSANETWORKEVENTS network_events;
734   int os_error = 0;
735   int rv =
736       WSAEnumNetworkEvents(socket_, read_write_event_.Get(), &network_events);
737   // Protects against trying to call the write callback if the read callback
738   // either closes or destroys |this|.
739   base::WeakPtr<UDPSocketWin> event_pending = event_pending_.GetWeakPtr();
740   if (rv == SOCKET_ERROR) {
741     os_error = WSAGetLastError();
742     rv = MapSystemError(os_error);
743 
744     if (read_iobuffer_) {
745       read_iobuffer_ = nullptr;
746       read_iobuffer_len_ = 0;
747       recv_from_address_ = nullptr;
748       DoReadCallback(rv);
749     }
750 
751     // Socket may have been closed or destroyed here.
752     if (event_pending && write_iobuffer_) {
753       write_iobuffer_ = nullptr;
754       write_iobuffer_len_ = 0;
755       send_to_address_.reset();
756       DoWriteCallback(rv);
757     }
758     return;
759   }
760 
761   if ((network_events.lNetworkEvents & FD_READ) && read_iobuffer_)
762     OnReadSignaled();
763   if (!event_pending)
764     return;
765 
766   if ((network_events.lNetworkEvents & FD_WRITE) && write_iobuffer_)
767     OnWriteSignaled();
768   if (!event_pending)
769     return;
770 
771   // There's still pending read / write. Watch for further events.
772   if (read_iobuffer_ || write_iobuffer_)
773     WatchForReadWrite();
774 }
775 
OnReadSignaled()776 void UDPSocketWin::OnReadSignaled() {
777   int rv = InternalRecvFromNonBlocking(read_iobuffer_.get(), read_iobuffer_len_,
778                                        recv_from_address_);
779   if (rv == ERR_IO_PENDING)
780     return;
781   read_iobuffer_ = nullptr;
782   read_iobuffer_len_ = 0;
783   recv_from_address_ = nullptr;
784   DoReadCallback(rv);
785 }
786 
OnWriteSignaled()787 void UDPSocketWin::OnWriteSignaled() {
788   int rv = InternalSendToNonBlocking(write_iobuffer_.get(), write_iobuffer_len_,
789                                      send_to_address_.get());
790   if (rv == ERR_IO_PENDING)
791     return;
792   write_iobuffer_ = nullptr;
793   write_iobuffer_len_ = 0;
794   send_to_address_.reset();
795   DoWriteCallback(rv);
796 }
797 
WatchForReadWrite()798 void UDPSocketWin::WatchForReadWrite() {
799   if (read_write_watcher_.IsWatching())
800     return;
801   bool watched =
802       read_write_watcher_.StartWatchingOnce(read_write_event_.Get(), this);
803   DCHECK(watched);
804 }
805 
LogRead(int result,const char * bytes,const IPEndPoint * address) const806 void UDPSocketWin::LogRead(int result,
807                            const char* bytes,
808                            const IPEndPoint* address) const {
809   if (result < 0) {
810     net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_RECEIVE_ERROR,
811                                       result);
812     return;
813   }
814 
815   if (net_log_.IsCapturing()) {
816     NetLogUDPDataTransfer(net_log_, NetLogEventType::UDP_BYTES_RECEIVED, result,
817                           bytes, address);
818   }
819 
820   activity_monitor::IncrementBytesReceived(result);
821 }
822 
LogWrite(int result,const char * bytes,const IPEndPoint * address) const823 void UDPSocketWin::LogWrite(int result,
824                             const char* bytes,
825                             const IPEndPoint* address) const {
826   if (result < 0) {
827     net_log_.AddEventWithNetErrorCode(NetLogEventType::UDP_SEND_ERROR, result);
828     return;
829   }
830 
831   if (net_log_.IsCapturing()) {
832     NetLogUDPDataTransfer(net_log_, NetLogEventType::UDP_BYTES_SENT, result,
833                           bytes, address);
834   }
835 }
836 
PopulateWSAMSG(WSAMSG & message,SockaddrStorage & storage,WSABUF * data_buffer,WSABUF & control_buffer,bool send)837 void UDPSocketWin::PopulateWSAMSG(WSAMSG& message,
838                                   SockaddrStorage& storage,
839                                   WSABUF* data_buffer,
840                                   WSABUF& control_buffer,
841                                   bool send) {
842   bool is_ipv6 = addr_family_ == AF_INET6;
843   message.name = storage.addr;
844   message.namelen = storage.addr_len;
845   message.lpBuffers = data_buffer;
846   message.dwBufferCount = 1;
847   message.Control.buf = control_buffer.buf;
848   message.dwFlags = 0;
849   if (send) {
850     message.Control.len = 0;
851     WSACMSGHDR* cmsg;
852     message.Control.len += WSA_CMSG_SPACE(sizeof(int));
853     cmsg = WSA_CMSG_FIRSTHDR(&message);
854     cmsg->cmsg_len = WSA_CMSG_LEN(sizeof(int));
855     cmsg->cmsg_level = is_ipv6 ? IPPROTO_IPV6 : IPPROTO_IP;
856     cmsg->cmsg_type = is_ipv6 ? IPV6_ECN : IP_ECN;
857     *(int*)WSA_CMSG_DATA(cmsg) = static_cast<int>(send_ecn_);
858   } else {
859     message.Control.len = control_buffer.len;
860   }
861 }
862 
SetLastTosFromWSAMSG(WSAMSG & message)863 void UDPSocketWin::SetLastTosFromWSAMSG(WSAMSG& message) {
864   int ecn = 0;
865   for (WSACMSGHDR* cmsg = WSA_CMSG_FIRSTHDR(&message); cmsg != NULL;
866        cmsg = WSA_CMSG_NXTHDR(&message, cmsg)) {
867     if ((cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_ECN) ||
868         (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_ECN)) {
869       ecn = *(int*)WSA_CMSG_DATA(cmsg);
870       break;
871     }
872   }
873   last_tos_.ecn = static_cast<EcnCodePoint>(ecn);
874 }
875 
InternalRecvFromOverlapped(IOBuffer * buf,int buf_len,IPEndPoint * address)876 int UDPSocketWin::InternalRecvFromOverlapped(IOBuffer* buf,
877                                              int buf_len,
878                                              IPEndPoint* address) {
879   DCHECK(!core_->read_iobuffer_.get());
880   DCHECK(!core_->read_message_.get());
881   SockaddrStorage& storage = core_->recv_addr_storage_;
882   storage.addr_len = sizeof(storage.addr_storage);
883 
884   WSABUF read_buffer;
885   read_buffer.buf = buf->data();
886   read_buffer.len = buf_len;
887 
888   DWORD flags = 0;
889   DWORD num;
890   CHECK_NE(INVALID_SOCKET, socket_);
891   int rv;
892   std::unique_ptr<WSAMSG> message;
893   if (report_ecn_) {
894     WSABUF control_buffer;
895     control_buffer.buf = core_->read_control_buffer_;
896     control_buffer.len = sizeof(core_->read_control_buffer_);
897     message = std::make_unique<WSAMSG>();
898     if (message == nullptr) {
899       return WSA_NOT_ENOUGH_MEMORY;
900     }
901     PopulateWSAMSG(*message, storage, &read_buffer, control_buffer, false);
902     rv = wsa_recv_msg_(socket_, message.get(), &num, &core_->read_overlapped_,
903                        nullptr);
904     if (rv == 0) {
905       SetLastTosFromWSAMSG(*message);
906     }
907   } else {
908     rv = WSARecvFrom(socket_, &read_buffer, 1, &num, &flags, storage.addr,
909                      &storage.addr_len, &core_->read_overlapped_, nullptr);
910   }
911   if (rv == 0) {
912     if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) {
913       int result = num;
914       // Convert address.
915       IPEndPoint address_storage;
916       IPEndPoint* address_to_log = nullptr;
917       if (result >= 0) {
918         if (address_storage.FromSockAddr(core_->recv_addr_storage_.addr,
919                                          core_->recv_addr_storage_.addr_len)) {
920           if (address)
921             *address = address_storage;
922           address_to_log = &address_storage;
923         } else {
924           result = ERR_ADDRESS_INVALID;
925         }
926       }
927       LogRead(result, buf->data(), address_to_log);
928       return result;
929     }
930   } else {
931     int os_error = WSAGetLastError();
932     if (os_error != WSA_IO_PENDING) {
933       int result = MapSystemError(os_error);
934       LogRead(result, nullptr, nullptr);
935       return result;
936     }
937   }
938   core_->WatchForRead();
939   core_->read_iobuffer_ = buf;
940   core_->read_message_ = std::move(message);
941   return ERR_IO_PENDING;
942 }
943 
InternalSendToOverlapped(IOBuffer * buf,int buf_len,const IPEndPoint * address)944 int UDPSocketWin::InternalSendToOverlapped(IOBuffer* buf,
945                                            int buf_len,
946                                            const IPEndPoint* address) {
947   DCHECK(!core_->write_iobuffer_.get());
948   SockaddrStorage storage;
949   struct sockaddr* addr = storage.addr;
950   // Convert address.
951   if (!address) {
952     addr = nullptr;
953     storage.addr_len = 0;
954   } else {
955     if (!address->ToSockAddr(addr, &storage.addr_len)) {
956       int result = ERR_ADDRESS_INVALID;
957       LogWrite(result, nullptr, nullptr);
958       return result;
959     }
960   }
961 
962   WSABUF write_buffer;
963   write_buffer.buf = buf->data();
964   write_buffer.len = buf_len;
965 
966   DWORD flags = 0;
967   DWORD num;
968   int rv;
969   if (send_ecn_ != ECN_NOT_ECT) {
970     WSABUF control_buffer;
971     char raw_control_buffer[WSA_CMSG_SPACE(sizeof(int))];
972     control_buffer.buf = raw_control_buffer;
973     control_buffer.len = sizeof(raw_control_buffer);
974     WSAMSG message;
975     PopulateWSAMSG(message, storage, &write_buffer, control_buffer, true);
976     rv = wsa_send_msg_(socket_, &message, flags, &num,
977                        &core_->write_overlapped_, nullptr);
978   } else {
979     rv = WSASendTo(socket_, &write_buffer, 1, &num, flags, addr,
980                    storage.addr_len, &core_->write_overlapped_, nullptr);
981   }
982   if (rv == 0) {
983     if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) {
984       int result = num;
985       LogWrite(result, buf->data(), address);
986       return result;
987     }
988   } else {
989     int os_error = WSAGetLastError();
990     if (os_error != WSA_IO_PENDING) {
991       int result = MapSystemError(os_error);
992       LogWrite(result, nullptr, nullptr);
993       return result;
994     }
995   }
996 
997   core_->WatchForWrite();
998   core_->write_iobuffer_ = buf;
999   return ERR_IO_PENDING;
1000 }
1001 
InternalRecvFromNonBlocking(IOBuffer * buf,int buf_len,IPEndPoint * address)1002 int UDPSocketWin::InternalRecvFromNonBlocking(IOBuffer* buf,
1003                                               int buf_len,
1004                                               IPEndPoint* address) {
1005   DCHECK(!read_iobuffer_ || read_iobuffer_.get() == buf);
1006   SockaddrStorage storage;
1007   storage.addr_len = sizeof(storage.addr_storage);
1008 
1009   CHECK_NE(INVALID_SOCKET, socket_);
1010 
1011   int rv;
1012   if (report_ecn_) {
1013     WSABUF read_buffer;
1014     read_buffer.buf = buf->data();
1015     read_buffer.len = buf_len;
1016     WSABUF control_buffer;
1017     char raw_control_buffer[WSA_CMSG_SPACE(sizeof(INT))];
1018     control_buffer.buf = raw_control_buffer;
1019     control_buffer.len = sizeof(raw_control_buffer);
1020     WSAMSG message;
1021     DWORD bytes_read;
1022     PopulateWSAMSG(message, storage, &read_buffer, control_buffer, false);
1023     rv = wsa_recv_msg_(socket_, &message, &bytes_read, nullptr, nullptr);
1024     SetLastTosFromWSAMSG(message);
1025     if (rv == 0) {
1026       rv = bytes_read;  // WSARecvMsg() returns zero on delivery, but recvfrom
1027                         // returns the number of bytes received.
1028     }
1029   } else {
1030     rv = recvfrom(socket_, buf->data(), buf_len, 0, storage.addr,
1031                   &storage.addr_len);
1032   }
1033   if (rv == SOCKET_ERROR) {
1034     int os_error = WSAGetLastError();
1035     if (os_error == WSAEWOULDBLOCK) {
1036       read_iobuffer_ = buf;
1037       read_iobuffer_len_ = buf_len;
1038       WatchForReadWrite();
1039       return ERR_IO_PENDING;
1040     }
1041     rv = MapSystemError(os_error);
1042     LogRead(rv, nullptr, nullptr);
1043     return rv;
1044   }
1045   IPEndPoint address_storage;
1046   IPEndPoint* address_to_log = nullptr;
1047   if (rv >= 0) {
1048     if (address_storage.FromSockAddr(storage.addr, storage.addr_len)) {
1049       if (address)
1050         *address = address_storage;
1051       address_to_log = &address_storage;
1052     } else {
1053       rv = ERR_ADDRESS_INVALID;
1054     }
1055   }
1056   LogRead(rv, buf->data(), address_to_log);
1057   return rv;
1058 }
1059 
InternalSendToNonBlocking(IOBuffer * buf,int buf_len,const IPEndPoint * address)1060 int UDPSocketWin::InternalSendToNonBlocking(IOBuffer* buf,
1061                                             int buf_len,
1062                                             const IPEndPoint* address) {
1063   DCHECK(!write_iobuffer_ || write_iobuffer_.get() == buf);
1064   SockaddrStorage storage;
1065   struct sockaddr* addr = storage.addr;
1066   // Convert address.
1067   if (address) {
1068     if (!address->ToSockAddr(addr, &storage.addr_len)) {
1069       int result = ERR_ADDRESS_INVALID;
1070       LogWrite(result, nullptr, nullptr);
1071       return result;
1072     }
1073   } else {
1074     addr = nullptr;
1075     storage.addr_len = 0;
1076   }
1077 
1078   int rv;
1079   if (send_ecn_ != ECN_NOT_ECT) {
1080     char raw_control_buffer[WSA_CMSG_SPACE(sizeof(INT))];
1081     WSABUF write_buffer;
1082     write_buffer.buf = buf->data();
1083     write_buffer.len = buf_len;
1084     WSABUF control_buffer;
1085     control_buffer.buf = raw_control_buffer;
1086     control_buffer.len = sizeof(raw_control_buffer);
1087     WSAMSG message;
1088     DWORD bytes_read;
1089     PopulateWSAMSG(message, storage, &write_buffer, control_buffer, true);
1090     rv = wsa_send_msg_(socket_, &message, 0, &bytes_read, nullptr, nullptr);
1091     if (rv == 0) {
1092       rv = bytes_read;
1093     }
1094   } else {
1095     rv = sendto(socket_, buf->data(), buf_len, 0, addr, storage.addr_len);
1096   }
1097   if (rv == SOCKET_ERROR) {
1098     int os_error = WSAGetLastError();
1099     if (os_error == WSAEWOULDBLOCK) {
1100       write_iobuffer_ = buf;
1101       write_iobuffer_len_ = buf_len;
1102       WatchForReadWrite();
1103       return ERR_IO_PENDING;
1104     }
1105     rv = MapSystemError(os_error);
1106     LogWrite(rv, nullptr, nullptr);
1107     return rv;
1108   }
1109   LogWrite(rv, buf->data(), address);
1110   return rv;
1111 }
1112 
SetMulticastOptions()1113 int UDPSocketWin::SetMulticastOptions() {
1114   if (!(socket_options_ & SOCKET_OPTION_MULTICAST_LOOP)) {
1115     DWORD loop = 0;
1116     int protocol_level =
1117         addr_family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6;
1118     int option =
1119         addr_family_ == AF_INET ? IP_MULTICAST_LOOP: IPV6_MULTICAST_LOOP;
1120     int rv = setsockopt(socket_, protocol_level, option,
1121                         reinterpret_cast<const char*>(&loop), sizeof(loop));
1122     if (rv < 0)
1123       return MapSystemError(WSAGetLastError());
1124   }
1125   if (multicast_time_to_live_ != 1) {
1126     DWORD hops = multicast_time_to_live_;
1127     int protocol_level =
1128         addr_family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6;
1129     int option =
1130         addr_family_ == AF_INET ? IP_MULTICAST_TTL: IPV6_MULTICAST_HOPS;
1131     int rv = setsockopt(socket_, protocol_level, option,
1132                         reinterpret_cast<const char*>(&hops), sizeof(hops));
1133     if (rv < 0)
1134       return MapSystemError(WSAGetLastError());
1135   }
1136   if (multicast_interface_ != 0) {
1137     switch (addr_family_) {
1138       case AF_INET: {
1139         in_addr address;
1140         address.s_addr = htonl(multicast_interface_);
1141         int rv = setsockopt(socket_, IPPROTO_IP, IP_MULTICAST_IF,
1142                             reinterpret_cast<const char*>(&address),
1143                             sizeof(address));
1144         if (rv)
1145           return MapSystemError(WSAGetLastError());
1146         break;
1147       }
1148       case AF_INET6: {
1149         uint32_t interface_index = multicast_interface_;
1150         int rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_MULTICAST_IF,
1151                             reinterpret_cast<const char*>(&interface_index),
1152                             sizeof(interface_index));
1153         if (rv)
1154           return MapSystemError(WSAGetLastError());
1155         break;
1156       }
1157       default:
1158         NOTREACHED() << "Invalid address family";
1159         return ERR_ADDRESS_INVALID;
1160     }
1161   }
1162   return OK;
1163 }
1164 
DoBind(const IPEndPoint & address)1165 int UDPSocketWin::DoBind(const IPEndPoint& address) {
1166   SockaddrStorage storage;
1167   if (!address.ToSockAddr(storage.addr, &storage.addr_len))
1168     return ERR_ADDRESS_INVALID;
1169   int rv = bind(socket_, storage.addr, storage.addr_len);
1170   if (rv == 0)
1171     return OK;
1172   int last_error = WSAGetLastError();
1173   // Map some codes that are special to bind() separately.
1174   // * WSAEACCES: If a port is already bound to a socket, WSAEACCES may be
1175   //   returned instead of WSAEADDRINUSE, depending on whether the socket
1176   //   option SO_REUSEADDR or SO_EXCLUSIVEADDRUSE is set and whether the
1177   //   conflicting socket is owned by a different user account. See the MSDN
1178   //   page "Using SO_REUSEADDR and SO_EXCLUSIVEADDRUSE" for the gory details.
1179   if (last_error == WSAEACCES || last_error == WSAEADDRNOTAVAIL)
1180     return ERR_ADDRESS_IN_USE;
1181   return MapSystemError(last_error);
1182 }
1183 
GetQwaveApi() const1184 QwaveApi* UDPSocketWin::GetQwaveApi() const {
1185   return QwaveApi::GetDefault();
1186 }
1187 
JoinGroup(const IPAddress & group_address) const1188 int UDPSocketWin::JoinGroup(const IPAddress& group_address) const {
1189   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1190   if (!is_connected())
1191     return ERR_SOCKET_NOT_CONNECTED;
1192 
1193   switch (group_address.size()) {
1194     case IPAddress::kIPv4AddressSize: {
1195       if (addr_family_ != AF_INET)
1196         return ERR_ADDRESS_INVALID;
1197       ip_mreq mreq;
1198       mreq.imr_interface.s_addr = htonl(multicast_interface_);
1199       memcpy(&mreq.imr_multiaddr, group_address.bytes().data(),
1200              IPAddress::kIPv4AddressSize);
1201       int rv = setsockopt(socket_, IPPROTO_IP, IP_ADD_MEMBERSHIP,
1202                           reinterpret_cast<const char*>(&mreq),
1203                           sizeof(mreq));
1204       if (rv)
1205         return MapSystemError(WSAGetLastError());
1206       return OK;
1207     }
1208     case IPAddress::kIPv6AddressSize: {
1209       if (addr_family_ != AF_INET6)
1210         return ERR_ADDRESS_INVALID;
1211       ipv6_mreq mreq;
1212       mreq.ipv6mr_interface = multicast_interface_;
1213       memcpy(&mreq.ipv6mr_multiaddr, group_address.bytes().data(),
1214              IPAddress::kIPv6AddressSize);
1215       int rv = setsockopt(socket_, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
1216                           reinterpret_cast<const char*>(&mreq),
1217                           sizeof(mreq));
1218       if (rv)
1219         return MapSystemError(WSAGetLastError());
1220       return OK;
1221     }
1222     default:
1223       NOTREACHED() << "Invalid address family";
1224       return ERR_ADDRESS_INVALID;
1225   }
1226 }
1227 
LeaveGroup(const IPAddress & group_address) const1228 int UDPSocketWin::LeaveGroup(const IPAddress& group_address) const {
1229   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1230   if (!is_connected())
1231     return ERR_SOCKET_NOT_CONNECTED;
1232 
1233   switch (group_address.size()) {
1234     case IPAddress::kIPv4AddressSize: {
1235       if (addr_family_ != AF_INET)
1236         return ERR_ADDRESS_INVALID;
1237       ip_mreq mreq;
1238       mreq.imr_interface.s_addr = htonl(multicast_interface_);
1239       memcpy(&mreq.imr_multiaddr, group_address.bytes().data(),
1240              IPAddress::kIPv4AddressSize);
1241       int rv = setsockopt(socket_, IPPROTO_IP, IP_DROP_MEMBERSHIP,
1242                           reinterpret_cast<const char*>(&mreq), sizeof(mreq));
1243       if (rv)
1244         return MapSystemError(WSAGetLastError());
1245       return OK;
1246     }
1247     case IPAddress::kIPv6AddressSize: {
1248       if (addr_family_ != AF_INET6)
1249         return ERR_ADDRESS_INVALID;
1250       ipv6_mreq mreq;
1251       mreq.ipv6mr_interface = multicast_interface_;
1252       memcpy(&mreq.ipv6mr_multiaddr, group_address.bytes().data(),
1253              IPAddress::kIPv6AddressSize);
1254       int rv = setsockopt(socket_, IPPROTO_IPV6, IP_DROP_MEMBERSHIP,
1255                           reinterpret_cast<const char*>(&mreq), sizeof(mreq));
1256       if (rv)
1257         return MapSystemError(WSAGetLastError());
1258       return OK;
1259     }
1260     default:
1261       NOTREACHED() << "Invalid address family";
1262       return ERR_ADDRESS_INVALID;
1263   }
1264 }
1265 
SetMulticastInterface(uint32_t interface_index)1266 int UDPSocketWin::SetMulticastInterface(uint32_t interface_index) {
1267   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1268   if (is_connected())
1269     return ERR_SOCKET_IS_CONNECTED;
1270   multicast_interface_ = interface_index;
1271   return OK;
1272 }
1273 
SetMulticastTimeToLive(int time_to_live)1274 int UDPSocketWin::SetMulticastTimeToLive(int time_to_live) {
1275   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1276   if (is_connected())
1277     return ERR_SOCKET_IS_CONNECTED;
1278 
1279   if (time_to_live < 0 || time_to_live > 255)
1280     return ERR_INVALID_ARGUMENT;
1281   multicast_time_to_live_ = time_to_live;
1282   return OK;
1283 }
1284 
SetMulticastLoopbackMode(bool loopback)1285 int UDPSocketWin::SetMulticastLoopbackMode(bool loopback) {
1286   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1287   if (is_connected())
1288     return ERR_SOCKET_IS_CONNECTED;
1289 
1290   if (loopback)
1291     socket_options_ |= SOCKET_OPTION_MULTICAST_LOOP;
1292   else
1293     socket_options_ &= ~SOCKET_OPTION_MULTICAST_LOOP;
1294   return OK;
1295 }
1296 
DscpToTrafficType(DiffServCodePoint dscp)1297 QOS_TRAFFIC_TYPE DscpToTrafficType(DiffServCodePoint dscp) {
1298   QOS_TRAFFIC_TYPE traffic_type = QOSTrafficTypeBestEffort;
1299   switch (dscp) {
1300     case DSCP_CS0:
1301       traffic_type = QOSTrafficTypeBestEffort;
1302       break;
1303     case DSCP_CS1:
1304       traffic_type = QOSTrafficTypeBackground;
1305       break;
1306     case DSCP_AF11:
1307     case DSCP_AF12:
1308     case DSCP_AF13:
1309     case DSCP_CS2:
1310     case DSCP_AF21:
1311     case DSCP_AF22:
1312     case DSCP_AF23:
1313     case DSCP_CS3:
1314     case DSCP_AF31:
1315     case DSCP_AF32:
1316     case DSCP_AF33:
1317     case DSCP_CS4:
1318       traffic_type = QOSTrafficTypeExcellentEffort;
1319       break;
1320     case DSCP_AF41:
1321     case DSCP_AF42:
1322     case DSCP_AF43:
1323     case DSCP_CS5:
1324       traffic_type = QOSTrafficTypeAudioVideo;
1325       break;
1326     case DSCP_EF:
1327     case DSCP_CS6:
1328       traffic_type = QOSTrafficTypeVoice;
1329       break;
1330     case DSCP_CS7:
1331       traffic_type = QOSTrafficTypeControl;
1332       break;
1333     case DSCP_NO_CHANGE:
1334       NOTREACHED();
1335       break;
1336   }
1337   return traffic_type;
1338 }
1339 
SetDiffServCodePoint(DiffServCodePoint dscp)1340 int UDPSocketWin::SetDiffServCodePoint(DiffServCodePoint dscp) {
1341   return SetTos(dscp, ECN_NO_CHANGE);
1342 }
1343 
SetTos(DiffServCodePoint dscp,EcnCodePoint ecn)1344 int UDPSocketWin::SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) {
1345   if (!is_connected())
1346     return ERR_SOCKET_NOT_CONNECTED;
1347 
1348   if (dscp != DSCP_NO_CHANGE) {
1349     QwaveApi* api = GetQwaveApi();
1350 
1351     if (!api->qwave_supported()) {
1352       return ERR_NOT_IMPLEMENTED;
1353     }
1354 
1355     if (!dscp_manager_) {
1356       dscp_manager_ = std::make_unique<DscpManager>(api, socket_);
1357     }
1358 
1359     dscp_manager_->Set(dscp);
1360     if (remote_address_) {
1361       int rv = dscp_manager_->PrepareForSend(*remote_address_.get());
1362       if (rv != OK) {
1363         return rv;
1364       }
1365     }
1366   }
1367   if (ecn == ECN_NO_CHANGE) {
1368     return OK;
1369   }
1370   if (wsa_send_msg_ == nullptr) {
1371     wsa_send_msg_ = GetSendMsgPointer();
1372   }
1373   send_ecn_ = ecn;
1374   return OK;
1375 }
1376 
SetIPv6Only(bool ipv6_only)1377 int UDPSocketWin::SetIPv6Only(bool ipv6_only) {
1378   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1379   if (is_connected()) {
1380     return ERR_SOCKET_IS_CONNECTED;
1381   }
1382   return net::SetIPv6Only(socket_, ipv6_only);
1383 }
1384 
DetachFromThread()1385 void UDPSocketWin::DetachFromThread() {
1386   DETACH_FROM_THREAD(thread_checker_);
1387 }
1388 
UseNonBlockingIO()1389 void UDPSocketWin::UseNonBlockingIO() {
1390   DCHECK(!core_);
1391   use_non_blocking_io_ = true;
1392 }
1393 
ApplySocketTag(const SocketTag & tag)1394 void UDPSocketWin::ApplySocketTag(const SocketTag& tag) {
1395   // Windows does not support any specific SocketTags so fail if any non-default
1396   // tag is applied.
1397   CHECK(tag == SocketTag());
1398 }
1399 
DscpManager(QwaveApi * api,SOCKET socket)1400 DscpManager::DscpManager(QwaveApi* api, SOCKET socket)
1401     : api_(api), socket_(socket) {
1402   RequestHandle();
1403 }
1404 
~DscpManager()1405 DscpManager::~DscpManager() {
1406   if (!qos_handle_)
1407     return;
1408 
1409   if (flow_id_ != 0)
1410     api_->RemoveSocketFromFlow(qos_handle_, NULL, flow_id_, 0);
1411 
1412   api_->CloseHandle(qos_handle_);
1413 }
1414 
Set(DiffServCodePoint dscp)1415 void DscpManager::Set(DiffServCodePoint dscp) {
1416   if (dscp == DSCP_NO_CHANGE || dscp == dscp_value_)
1417     return;
1418 
1419   dscp_value_ = dscp;
1420 
1421   // TODO(zstein): We could reuse the flow when the value changes
1422   // by calling QOSSetFlow with the new traffic type and dscp value.
1423   if (flow_id_ != 0 && qos_handle_) {
1424     api_->RemoveSocketFromFlow(qos_handle_, NULL, flow_id_, 0);
1425     configured_.clear();
1426     flow_id_ = 0;
1427   }
1428 }
1429 
PrepareForSend(const IPEndPoint & remote_address)1430 int DscpManager::PrepareForSend(const IPEndPoint& remote_address) {
1431   if (dscp_value_ == DSCP_NO_CHANGE) {
1432     // No DSCP value has been set.
1433     return OK;
1434   }
1435 
1436   if (!api_->qwave_supported())
1437     return ERR_NOT_IMPLEMENTED;
1438 
1439   if (!qos_handle_)
1440     return ERR_INVALID_HANDLE;  // The closest net error to try again later.
1441 
1442   if (configured_.find(remote_address) != configured_.end())
1443     return OK;
1444 
1445   SockaddrStorage storage;
1446   if (!remote_address.ToSockAddr(storage.addr, &storage.addr_len))
1447     return ERR_ADDRESS_INVALID;
1448 
1449   // We won't try this address again if we get an error.
1450   configured_.emplace(remote_address);
1451 
1452   // We don't need to call SetFlow if we already have a qos flow.
1453   bool new_flow = flow_id_ == 0;
1454 
1455   const QOS_TRAFFIC_TYPE traffic_type = DscpToTrafficType(dscp_value_);
1456 
1457   if (!api_->AddSocketToFlow(qos_handle_, socket_, storage.addr, traffic_type,
1458                              QOS_NON_ADAPTIVE_FLOW, &flow_id_)) {
1459     DWORD err = ::GetLastError();
1460     if (err == ERROR_DEVICE_REINITIALIZATION_NEEDED) {
1461       // Reset. PrepareForSend is called for every packet.  Once RequestHandle
1462       // completes asynchronously the next PrepareForSend call will re-register
1463       // the address with the new QoS Handle.  In the meantime, sends will
1464       // continue without DSCP.
1465       RequestHandle();
1466       configured_.clear();
1467       flow_id_ = 0;
1468       return ERR_INVALID_HANDLE;
1469     }
1470     return MapSystemError(err);
1471   }
1472 
1473   if (new_flow) {
1474     DWORD buf = dscp_value_;
1475     // This requires admin rights, and may fail, if so we ignore it
1476     // as AddSocketToFlow should still do *approximately* the right thing.
1477     api_->SetFlow(qos_handle_, flow_id_, QOSSetOutgoingDSCPValue, sizeof(buf),
1478                   &buf, 0, nullptr);
1479   }
1480 
1481   return OK;
1482 }
1483 
RequestHandle()1484 void DscpManager::RequestHandle() {
1485   if (handle_is_initializing_)
1486     return;
1487 
1488   if (qos_handle_) {
1489     api_->CloseHandle(qos_handle_);
1490     qos_handle_ = nullptr;
1491   }
1492 
1493   handle_is_initializing_ = true;
1494   base::ThreadPool::PostTaskAndReplyWithResult(
1495       FROM_HERE, {base::MayBlock()},
1496       base::BindOnce(&DscpManager::DoCreateHandle, api_),
1497       base::BindOnce(&DscpManager::OnHandleCreated, api_,
1498                      weak_ptr_factory_.GetWeakPtr()));
1499 }
1500 
DoCreateHandle(QwaveApi * api)1501 HANDLE DscpManager::DoCreateHandle(QwaveApi* api) {
1502   QOS_VERSION version;
1503   version.MajorVersion = 1;
1504   version.MinorVersion = 0;
1505 
1506   HANDLE handle = nullptr;
1507 
1508   // No access to net_log_ so swallow any errors here.
1509   api->CreateHandle(&version, &handle);
1510   return handle;
1511 }
1512 
OnHandleCreated(QwaveApi * api,base::WeakPtr<DscpManager> dscp_manager,HANDLE handle)1513 void DscpManager::OnHandleCreated(QwaveApi* api,
1514                                   base::WeakPtr<DscpManager> dscp_manager,
1515                                   HANDLE handle) {
1516   if (!handle)
1517     api->OnFatalError();
1518 
1519   if (!dscp_manager) {
1520     api->CloseHandle(handle);
1521     return;
1522   }
1523 
1524   DCHECK(dscp_manager->handle_is_initializing_);
1525   DCHECK(!dscp_manager->qos_handle_);
1526 
1527   dscp_manager->qos_handle_ = handle;
1528   dscp_manager->handle_is_initializing_ = false;
1529 }
1530 
1531 }  // namespace net
1532