1 // Copyright 2015 The Rust Project Developers.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 
9 use std::cmp::min;
10 use std::io::{self, IoSlice};
11 use std::marker::PhantomData;
12 use std::mem::{self, size_of, MaybeUninit};
13 use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
14 use std::os::windows::io::{
15     AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
16 };
17 use std::path::Path;
18 use std::sync::Once;
19 use std::time::{Duration, Instant};
20 use std::{process, ptr, slice};
21 
22 use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT};
23 #[cfg(feature = "all")]
24 use windows_sys::Win32::Networking::WinSock::SO_PROTOCOL_INFOW;
25 use windows_sys::Win32::Networking::WinSock::{
26     self, tcp_keepalive, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0,
27     POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND, SIO_KEEPALIVE_VALS,
28     SOCKET_ERROR, WSABUF, WSAEMSGSIZE, WSAESHUTDOWN, WSAPOLLFD, WSAPROTOCOL_INFOW,
29     WSA_FLAG_NO_HANDLE_INHERIT, WSA_FLAG_OVERLAPPED,
30 };
31 use windows_sys::Win32::System::Threading::INFINITE;
32 
33 use crate::{MsgHdr, RecvFlags, SockAddr, TcpKeepalive, Type};
34 
/// C's `int`, the integer type used throughout this module for WinSock
/// arguments such as address families, socket types and option levels.
#[allow(non_camel_case_types)]
pub(crate) type c_int = std::os::raw::c_int;

/// Fake MSG_TRUNC flag for the [`RecvFlags`] struct.
///
/// The flag is set by this module when a `WSARecv`/`WSARecvFrom` call fails
/// with `WSAEMSGSIZE` (the data did not fit the supplied buffers). The value
/// of the flag is chosen by us, not taken from the WinSock headers.
pub(crate) const MSG_TRUNC: c_int = 0x01;

// Used in `Domain`. WinSock's constants are re-typed as `c_int` so they match
// the `family` parameter of `socket`.
pub(crate) const AF_INET: c_int = windows_sys::Win32::Networking::WinSock::AF_INET as c_int;
pub(crate) const AF_INET6: c_int = windows_sys::Win32::Networking::WinSock::AF_INET6 as c_int;
pub(crate) const AF_UNIX: c_int = windows_sys::Win32::Networking::WinSock::AF_UNIX as c_int;
pub(crate) const AF_UNSPEC: c_int = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as c_int;
// Used in `Type`, re-typed as `c_int` to match the `ty` parameter of `socket`.
pub(crate) const SOCK_STREAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_STREAM as c_int;
pub(crate) const SOCK_DGRAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_DGRAM as c_int;
pub(crate) const SOCK_RAW: c_int = windows_sys::Win32::Networking::WinSock::SOCK_RAW as c_int;
// Not `pub(crate)`: only referenced by the `Type` debug impl in this module.
const SOCK_RDM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_RDM as c_int;
pub(crate) const SOCK_SEQPACKET: c_int =
    windows_sys::Win32::Networking::WinSock::SOCK_SEQPACKET as c_int;
56 // Used in `Protocol`.
57 pub(crate) use windows_sys::Win32::Networking::WinSock::{
58     IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_TCP, IPPROTO_UDP,
59 };
60 // Used in `SockAddr`.
61 pub(crate) use windows_sys::Win32::Networking::WinSock::{
62     SOCKADDR as sockaddr, SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6,
63     SOCKADDR_STORAGE as sockaddr_storage,
64 };
// POSIX-style names for WinSock's address-family and address-length types,
// presumably so code shared with the Unix backend can use one spelling.
#[allow(non_camel_case_types)]
pub(crate) type sa_family_t = windows_sys::Win32::Networking::WinSock::ADDRESS_FAMILY;
#[allow(non_camel_case_types)]
pub(crate) type socklen_t = windows_sys::Win32::Networking::WinSock::socklen_t;
69 // Used in `Socket`.
70 #[cfg(feature = "all")]
71 pub(crate) use windows_sys::Win32::Networking::WinSock::IP_HDRINCL;
72 pub(crate) use windows_sys::Win32::Networking::WinSock::{
73     IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_MREQ as Ipv6Mreq,
74     IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_RECVTCLASS,
75     IPV6_UNICAST_HOPS, IPV6_V6ONLY, IP_ADD_MEMBERSHIP, IP_ADD_SOURCE_MEMBERSHIP,
76     IP_DROP_MEMBERSHIP, IP_DROP_SOURCE_MEMBERSHIP, IP_MREQ as IpMreq,
77     IP_MREQ_SOURCE as IpMreqSource, IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL,
78     IP_RECVTOS, IP_TOS, IP_TTL, LINGER as linger, MSG_OOB, MSG_PEEK, SO_BROADCAST, SO_ERROR,
79     SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE, SO_RCVBUF, SO_RCVTIMEO, SO_REUSEADDR, SO_SNDBUF,
80     SO_SNDTIMEO, SO_TYPE, TCP_NODELAY,
81 };
// Option levels, re-typed as `c_int` to match the `level` parameter of
// `getsockopt`/`setsockopt` in this module.
pub(crate) const IPPROTO_IP: c_int = windows_sys::Win32::Networking::WinSock::IPPROTO_IP as c_int;
pub(crate) const SOL_SOCKET: c_int = windows_sys::Win32::Networking::WinSock::SOL_SOCKET as c_int;
84 
/// Type used in set/getsockopt to retrieve the `TCP_NODELAY` option.
///
/// NOTE: <https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-getsockopt>
/// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a
/// `BOOL` (alias for `c_int`, 4 bytes), however in practice this turns out to
/// be false (or misleading) as a `BOOLEAN` (`c_uchar`, 1 byte) is returned by
/// `getsockopt`.
///
/// Using the 1-byte type keeps the `optlen == size_of::<T>()` debug assertion
/// in `getsockopt` satisfied for these options.
pub(crate) type Bool = windows_sys::Win32::Foundation::BOOLEAN;
93 
/// Maximum size of a buffer passed to system calls like `recv` and `send`.
///
/// Those calls take the length as a `c_int`, so callers clamp longer buffers
/// with `min(buf.len(), MAX_BUF_LEN)` before casting.
const MAX_BUF_LEN: usize = c_int::MAX as usize;
96 
/// Helper macro to execute a system call that returns an `io::Result`.
///
/// `syscall!(func(args..), err_test, err_value)` calls
/// `windows_sys::Win32::Networking::WinSock::func(args..)` inside an `unsafe`
/// block, then evaluates `err_test(&result, &err_value)`. When that is `true`
/// the call is treated as failed and `io::Error::last_os_error()` is
/// returned; otherwise the raw result is returned in `Ok`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{
        // Some call sites already sit inside an `unsafe` block (e.g. within
        // `SockAddr::try_init`), which would make this block "unused".
        #[allow(unused_unsafe)]
        let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) };
        if $err_test(&res, &$err_value) {
            Err(io::Error::last_os_error())
        } else {
            Ok(res)
        }
    }};
}
109 
// `Debug` for `Domain` over the address families known on Windows; presumably
// prints the name of the matching constant (see the `impl_debug!` definition).
impl_debug!(
    crate::Domain,
    self::AF_INET,
    self::AF_INET6,
    self::AF_UNIX,
    self::AF_UNSPEC,
);
117 
118 /// Windows only API.
119 impl Type {
120     /// Our custom flag to set `WSA_FLAG_NO_HANDLE_INHERIT` on socket creation.
121     /// Trying to mimic `Type::cloexec` on windows.
122     const NO_INHERIT: c_int = 1 << ((size_of::<c_int>() * 8) - 1); // Last bit.
123 
124     /// Set `WSA_FLAG_NO_HANDLE_INHERIT` on the socket.
125     #[cfg(feature = "all")]
126     #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
no_inherit(self) -> Type127     pub const fn no_inherit(self) -> Type {
128         self._no_inherit()
129     }
130 
_no_inherit(self) -> Type131     pub(crate) const fn _no_inherit(self) -> Type {
132         Type(self.0 | Type::NO_INHERIT)
133     }
134 }
135 
// `Debug` for `Type` over the socket types known on Windows; presumably prints
// the name of the matching constant (see the `impl_debug!` definition).
impl_debug!(
    crate::Type,
    self::SOCK_STREAM,
    self::SOCK_DGRAM,
    self::SOCK_RAW,
    self::SOCK_RDM,
    self::SOCK_SEQPACKET,
);
144 
// `Debug` for `Protocol` over the protocols re-exported above; presumably
// prints the name of the matching constant (see the `impl_debug!` definition).
impl_debug!(
    crate::Protocol,
    WinSock::IPPROTO_ICMP,
    WinSock::IPPROTO_ICMPV6,
    WinSock::IPPROTO_TCP,
    WinSock::IPPROTO_UDP,
);
152 
153 impl std::fmt::Debug for RecvFlags {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result154     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155         f.debug_struct("RecvFlags")
156             .field("is_truncated", &self.is_truncated())
157             .finish()
158     }
159 }
160 
161 #[repr(transparent)]
162 pub struct MaybeUninitSlice<'a> {
163     vec: WSABUF,
164     _lifetime: PhantomData<&'a mut [MaybeUninit<u8>]>,
165 }
166 
167 unsafe impl<'a> Send for MaybeUninitSlice<'a> {}
168 
169 unsafe impl<'a> Sync for MaybeUninitSlice<'a> {}
170 
171 impl<'a> MaybeUninitSlice<'a> {
new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a>172     pub fn new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a> {
173         assert!(buf.len() <= u32::MAX as usize);
174         MaybeUninitSlice {
175             vec: WSABUF {
176                 len: buf.len() as u32,
177                 buf: buf.as_mut_ptr().cast(),
178             },
179             _lifetime: PhantomData,
180         }
181     }
182 
as_slice(&self) -> &[MaybeUninit<u8>]183     pub fn as_slice(&self) -> &[MaybeUninit<u8>] {
184         unsafe { slice::from_raw_parts(self.vec.buf.cast(), self.vec.len as usize) }
185     }
186 
as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>]187     pub fn as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>] {
188         unsafe { slice::from_raw_parts_mut(self.vec.buf.cast(), self.vec.len as usize) }
189     }
190 }
191 
192 // Used in `MsgHdr`.
193 pub(crate) use windows_sys::Win32::Networking::WinSock::WSAMSG as msghdr;
194 
set_msghdr_name(msg: &mut msghdr, name: &SockAddr)195 pub(crate) fn set_msghdr_name(msg: &mut msghdr, name: &SockAddr) {
196     msg.name = name.as_ptr() as *mut _;
197     msg.namelen = name.len();
198 }
199 
set_msghdr_iov(msg: &mut msghdr, ptr: *mut WSABUF, len: usize)200 pub(crate) fn set_msghdr_iov(msg: &mut msghdr, ptr: *mut WSABUF, len: usize) {
201     msg.lpBuffers = ptr;
202     msg.dwBufferCount = min(len, u32::MAX as usize) as u32;
203 }
204 
set_msghdr_control(msg: &mut msghdr, ptr: *mut u8, len: usize)205 pub(crate) fn set_msghdr_control(msg: &mut msghdr, ptr: *mut u8, len: usize) {
206     msg.Control.buf = ptr;
207     msg.Control.len = len as u32;
208 }
209 
set_msghdr_flags(msg: &mut msghdr, flags: c_int)210 pub(crate) fn set_msghdr_flags(msg: &mut msghdr, flags: c_int) {
211     msg.dwFlags = flags as u32;
212 }
213 
msghdr_flags(msg: &msghdr) -> RecvFlags214 pub(crate) fn msghdr_flags(msg: &msghdr) -> RecvFlags {
215     RecvFlags(msg.dwFlags as c_int)
216 }
217 
/// Makes sure WinSock has been initialised before any WinSock call.
///
/// Instead of initialising WinSock directly, this relies on the standard
/// library: creating any std socket makes libstd initialise WinSock. The
/// work runs at most once per process.
fn init() {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        // Create a dummy socket and drop it. Whether the bind succeeds does
        // not matter, because libstd will have initialised WinSock either way.
        // Port 0 lets the OS pick a free ephemeral port, so we never (even
        // briefly) occupy or race for a fixed port another program may need.
        let _ = net::UdpSocket::bind("127.0.0.1:0");
    });
}
228 
229 pub(crate) type Socket = windows_sys::Win32::Networking::WinSock::SOCKET;
230 
socket_from_raw(socket: Socket) -> crate::socket::Inner231 pub(crate) unsafe fn socket_from_raw(socket: Socket) -> crate::socket::Inner {
232     crate::socket::Inner::from_raw_socket(socket as RawSocket)
233 }
234 
socket_as_raw(socket: &crate::socket::Inner) -> Socket235 pub(crate) fn socket_as_raw(socket: &crate::socket::Inner) -> Socket {
236     socket.as_raw_socket() as Socket
237 }
238 
socket_into_raw(socket: crate::socket::Inner) -> Socket239 pub(crate) fn socket_into_raw(socket: crate::socket::Inner) -> Socket {
240     socket.into_raw_socket() as Socket
241 }
242 
socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket>243 pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket> {
244     init();
245 
246     // Check if we set our custom flag.
247     let flags = if ty & Type::NO_INHERIT != 0 {
248         ty = ty & !Type::NO_INHERIT;
249         WSA_FLAG_NO_HANDLE_INHERIT
250     } else {
251         0
252     };
253 
254     syscall!(
255         WSASocketW(
256             family,
257             ty,
258             protocol,
259             ptr::null_mut(),
260             0,
261             WSA_FLAG_OVERLAPPED | flags,
262         ),
263         PartialEq::eq,
264         INVALID_SOCKET
265     )
266 }
267 
bind(socket: Socket, addr: &SockAddr) -> io::Result<()>268 pub(crate) fn bind(socket: Socket, addr: &SockAddr) -> io::Result<()> {
269     syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
270 }
271 
connect(socket: Socket, addr: &SockAddr) -> io::Result<()>272 pub(crate) fn connect(socket: Socket, addr: &SockAddr) -> io::Result<()> {
273     syscall!(connect(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
274 }
275 
poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()>276 pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
277     let start = Instant::now();
278 
279     let mut fd_array = WSAPOLLFD {
280         fd: socket.as_raw(),
281         events: (POLLRDNORM | POLLWRNORM) as i16,
282         revents: 0,
283     };
284 
285     loop {
286         let elapsed = start.elapsed();
287         if elapsed >= timeout {
288             return Err(io::ErrorKind::TimedOut.into());
289         }
290 
291         let timeout = (timeout - elapsed).as_millis();
292         let timeout = clamp(timeout, 1, c_int::MAX as u128) as c_int;
293 
294         match syscall!(
295             WSAPoll(&mut fd_array, 1, timeout),
296             PartialEq::eq,
297             SOCKET_ERROR
298         ) {
299             Ok(0) => return Err(io::ErrorKind::TimedOut.into()),
300             Ok(_) => {
301                 // Error or hang up indicates an error (or failure to connect).
302                 if (fd_array.revents & POLLERR as i16) != 0
303                     || (fd_array.revents & POLLHUP as i16) != 0
304                 {
305                     match socket.take_error() {
306                         Ok(Some(err)) => return Err(err),
307                         Ok(None) => {
308                             return Err(io::Error::new(
309                                 io::ErrorKind::Other,
310                                 "no error set after POLLHUP",
311                             ))
312                         }
313                         Err(err) => return Err(err),
314                     }
315                 }
316                 return Ok(());
317             }
318             // Got interrupted, try again.
319             Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
320             Err(err) => return Err(err),
321         }
322     }
323 }
324 
/// Restricts `value` to the inclusive range `min..=max`.
///
/// Thin wrapper over [`Ord::clamp`] (stable since Rust 1.50), which replaces
/// the previous hand-rolled comparison chain.
///
/// # Panics
///
/// Panics if `min > max`. All callers pass constant, well-ordered bounds.
fn clamp<T>(value: T, min: T, max: T) -> T
where
    T: Ord,
{
    Ord::clamp(value, min, max)
}
338 
listen(socket: Socket, backlog: c_int) -> io::Result<()>339 pub(crate) fn listen(socket: Socket, backlog: c_int) -> io::Result<()> {
340     syscall!(listen(socket, backlog), PartialEq::ne, 0).map(|_| ())
341 }
342 
accept(socket: Socket) -> io::Result<(Socket, SockAddr)>343 pub(crate) fn accept(socket: Socket) -> io::Result<(Socket, SockAddr)> {
344     // Safety: `accept` initialises the `SockAddr` for us.
345     unsafe {
346         SockAddr::try_init(|storage, len| {
347             syscall!(
348                 accept(socket, storage.cast(), len),
349                 PartialEq::eq,
350                 INVALID_SOCKET
351             )
352         })
353     }
354 }
355 
getsockname(socket: Socket) -> io::Result<SockAddr>356 pub(crate) fn getsockname(socket: Socket) -> io::Result<SockAddr> {
357     // Safety: `getsockname` initialises the `SockAddr` for us.
358     unsafe {
359         SockAddr::try_init(|storage, len| {
360             syscall!(
361                 getsockname(socket, storage.cast(), len),
362                 PartialEq::eq,
363                 SOCKET_ERROR
364             )
365         })
366     }
367     .map(|(_, addr)| addr)
368 }
369 
getpeername(socket: Socket) -> io::Result<SockAddr>370 pub(crate) fn getpeername(socket: Socket) -> io::Result<SockAddr> {
371     // Safety: `getpeername` initialises the `SockAddr` for us.
372     unsafe {
373         SockAddr::try_init(|storage, len| {
374             syscall!(
375                 getpeername(socket, storage.cast(), len),
376                 PartialEq::eq,
377                 SOCKET_ERROR
378             )
379         })
380     }
381     .map(|(_, addr)| addr)
382 }
383 
try_clone(socket: Socket) -> io::Result<Socket>384 pub(crate) fn try_clone(socket: Socket) -> io::Result<Socket> {
385     let mut info: MaybeUninit<WSAPROTOCOL_INFOW> = MaybeUninit::uninit();
386     syscall!(
387         // NOTE: `process.id` is the same as `GetCurrentProcessId`.
388         WSADuplicateSocketW(socket, process::id(), info.as_mut_ptr()),
389         PartialEq::eq,
390         SOCKET_ERROR
391     )?;
392     // Safety: `WSADuplicateSocketW` intialised `info` for us.
393     let mut info = unsafe { info.assume_init() };
394 
395     syscall!(
396         WSASocketW(
397             info.iAddressFamily,
398             info.iSocketType,
399             info.iProtocol,
400             &mut info,
401             0,
402             WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT,
403         ),
404         PartialEq::eq,
405         INVALID_SOCKET
406     )
407 }
408 
set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()>409 pub(crate) fn set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()> {
410     let mut nonblocking = if nonblocking { 1 } else { 0 };
411     ioctlsocket(socket, FIONBIO, &mut nonblocking)
412 }
413 
shutdown(socket: Socket, how: Shutdown) -> io::Result<()>414 pub(crate) fn shutdown(socket: Socket, how: Shutdown) -> io::Result<()> {
415     let how = match how {
416         Shutdown::Write => SD_SEND,
417         Shutdown::Read => SD_RECEIVE,
418         Shutdown::Both => SD_BOTH,
419     } as i32;
420     syscall!(shutdown(socket, how), PartialEq::eq, SOCKET_ERROR).map(|_| ())
421 }
422 
recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize>423 pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
424     let res = syscall!(
425         recv(
426             socket,
427             buf.as_mut_ptr().cast(),
428             min(buf.len(), MAX_BUF_LEN) as c_int,
429             flags,
430         ),
431         PartialEq::eq,
432         SOCKET_ERROR
433     );
434     match res {
435         Ok(n) => Ok(n as usize),
436         Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok(0),
437         Err(err) => Err(err),
438     }
439 }
440 
recv_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags)>441 pub(crate) fn recv_vectored(
442     socket: Socket,
443     bufs: &mut [crate::MaybeUninitSlice<'_>],
444     flags: c_int,
445 ) -> io::Result<(usize, RecvFlags)> {
446     let mut nread = 0;
447     let mut flags = flags as u32;
448     let res = syscall!(
449         WSARecv(
450             socket,
451             bufs.as_mut_ptr().cast(),
452             min(bufs.len(), u32::MAX as usize) as u32,
453             &mut nread,
454             &mut flags,
455             ptr::null_mut(),
456             None,
457         ),
458         PartialEq::eq,
459         SOCKET_ERROR
460     );
461     match res {
462         Ok(_) => Ok((nread as usize, RecvFlags(0))),
463         Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok((0, RecvFlags(0))),
464         Err(ref err) if err.raw_os_error() == Some(WSAEMSGSIZE as i32) => {
465             Ok((nread as usize, RecvFlags(MSG_TRUNC)))
466         }
467         Err(err) => Err(err),
468     }
469 }
470 
recv_from( socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int, ) -> io::Result<(usize, SockAddr)>471 pub(crate) fn recv_from(
472     socket: Socket,
473     buf: &mut [MaybeUninit<u8>],
474     flags: c_int,
475 ) -> io::Result<(usize, SockAddr)> {
476     // Safety: `recvfrom` initialises the `SockAddr` for us.
477     unsafe {
478         SockAddr::try_init(|storage, addrlen| {
479             let res = syscall!(
480                 recvfrom(
481                     socket,
482                     buf.as_mut_ptr().cast(),
483                     min(buf.len(), MAX_BUF_LEN) as c_int,
484                     flags,
485                     storage.cast(),
486                     addrlen,
487                 ),
488                 PartialEq::eq,
489                 SOCKET_ERROR
490             );
491             match res {
492                 Ok(n) => Ok(n as usize),
493                 Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok(0),
494                 Err(err) => Err(err),
495             }
496         })
497     }
498 }
499 
peek_sender(socket: Socket) -> io::Result<SockAddr>500 pub(crate) fn peek_sender(socket: Socket) -> io::Result<SockAddr> {
501     // Safety: `recvfrom` initialises the `SockAddr` for us.
502     let ((), sender) = unsafe {
503         SockAddr::try_init(|storage, addrlen| {
504             let res = syscall!(
505                 recvfrom(
506                     socket,
507                     // Windows *appears* not to care if you pass a null pointer.
508                     ptr::null_mut(),
509                     0,
510                     MSG_PEEK,
511                     storage.cast(),
512                     addrlen,
513                 ),
514                 PartialEq::eq,
515                 SOCKET_ERROR
516             );
517             match res {
518                 Ok(_n) => Ok(()),
519                 Err(e) => match e.raw_os_error() {
520                     Some(code) if code == (WSAESHUTDOWN as i32) || code == (WSAEMSGSIZE as i32) => {
521                         Ok(())
522                     }
523                     _ => Err(e),
524                 },
525             }
526         })
527     }?;
528 
529     Ok(sender)
530 }
531 
recv_from_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags, SockAddr)>532 pub(crate) fn recv_from_vectored(
533     socket: Socket,
534     bufs: &mut [crate::MaybeUninitSlice<'_>],
535     flags: c_int,
536 ) -> io::Result<(usize, RecvFlags, SockAddr)> {
537     // Safety: `recvfrom` initialises the `SockAddr` for us.
538     unsafe {
539         SockAddr::try_init(|storage, addrlen| {
540             let mut nread = 0;
541             let mut flags = flags as u32;
542             let res = syscall!(
543                 WSARecvFrom(
544                     socket,
545                     bufs.as_mut_ptr().cast(),
546                     min(bufs.len(), u32::MAX as usize) as u32,
547                     &mut nread,
548                     &mut flags,
549                     storage.cast(),
550                     addrlen,
551                     ptr::null_mut(),
552                     None,
553                 ),
554                 PartialEq::eq,
555                 SOCKET_ERROR
556             );
557             match res {
558                 Ok(_) => Ok((nread as usize, RecvFlags(0))),
559                 Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => {
560                     Ok((nread as usize, RecvFlags(0)))
561                 }
562                 Err(ref err) if err.raw_os_error() == Some(WSAEMSGSIZE as i32) => {
563                     Ok((nread as usize, RecvFlags(MSG_TRUNC)))
564                 }
565                 Err(err) => Err(err),
566             }
567         })
568     }
569     .map(|((n, recv_flags), addr)| (n, recv_flags, addr))
570 }
571 
send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize>572 pub(crate) fn send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
573     syscall!(
574         send(
575             socket,
576             buf.as_ptr().cast(),
577             min(buf.len(), MAX_BUF_LEN) as c_int,
578             flags,
579         ),
580         PartialEq::eq,
581         SOCKET_ERROR
582     )
583     .map(|n| n as usize)
584 }
585 
send_vectored( socket: Socket, bufs: &[IoSlice<'_>], flags: c_int, ) -> io::Result<usize>586 pub(crate) fn send_vectored(
587     socket: Socket,
588     bufs: &[IoSlice<'_>],
589     flags: c_int,
590 ) -> io::Result<usize> {
591     let mut nsent = 0;
592     syscall!(
593         WSASend(
594             socket,
595             // FIXME: From the `WSASend` docs [1]:
596             // > For a Winsock application, once the WSASend function is called,
597             // > the system owns these buffers and the application may not
598             // > access them.
599             //
600             // So what we're doing is actually UB as `bufs` needs to be `&mut
601             // [IoSlice<'_>]`.
602             //
603             // Tracking issue: https://github.com/rust-lang/socket2-rs/issues/129.
604             //
605             // NOTE: `send_to_vectored` has the same problem.
606             //
607             // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend
608             bufs.as_ptr() as *mut _,
609             min(bufs.len(), u32::MAX as usize) as u32,
610             &mut nsent,
611             flags as u32,
612             std::ptr::null_mut(),
613             None,
614         ),
615         PartialEq::eq,
616         SOCKET_ERROR
617     )
618     .map(|_| nsent as usize)
619 }
620 
send_to( socket: Socket, buf: &[u8], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>621 pub(crate) fn send_to(
622     socket: Socket,
623     buf: &[u8],
624     addr: &SockAddr,
625     flags: c_int,
626 ) -> io::Result<usize> {
627     syscall!(
628         sendto(
629             socket,
630             buf.as_ptr().cast(),
631             min(buf.len(), MAX_BUF_LEN) as c_int,
632             flags,
633             addr.as_ptr(),
634             addr.len(),
635         ),
636         PartialEq::eq,
637         SOCKET_ERROR
638     )
639     .map(|n| n as usize)
640 }
641 
send_to_vectored( socket: Socket, bufs: &[IoSlice<'_>], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>642 pub(crate) fn send_to_vectored(
643     socket: Socket,
644     bufs: &[IoSlice<'_>],
645     addr: &SockAddr,
646     flags: c_int,
647 ) -> io::Result<usize> {
648     let mut nsent = 0;
649     syscall!(
650         WSASendTo(
651             socket,
652             // FIXME: Same problem as in `send_vectored`.
653             bufs.as_ptr() as *mut _,
654             bufs.len().min(u32::MAX as usize) as u32,
655             &mut nsent,
656             flags as u32,
657             addr.as_ptr(),
658             addr.len(),
659             ptr::null_mut(),
660             None,
661         ),
662         PartialEq::eq,
663         SOCKET_ERROR
664     )
665     .map(|_| nsent as usize)
666 }
667 
sendmsg(socket: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io::Result<usize>668 pub(crate) fn sendmsg(socket: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io::Result<usize> {
669     let mut nsent = 0;
670     syscall!(
671         WSASendMsg(
672             socket,
673             &msg.inner,
674             flags as u32,
675             &mut nsent,
676             ptr::null_mut(),
677             None,
678         ),
679         PartialEq::eq,
680         SOCKET_ERROR
681     )
682     .map(|_| nsent as usize)
683 }
684 
685 /// Wrapper around `getsockopt` to deal with platform specific timeouts.
timeout_opt(fd: Socket, lvl: c_int, name: i32) -> io::Result<Option<Duration>>686 pub(crate) fn timeout_opt(fd: Socket, lvl: c_int, name: i32) -> io::Result<Option<Duration>> {
687     unsafe { getsockopt(fd, lvl, name).map(from_ms) }
688 }
689 
/// Converts a Windows millisecond timeout to a `Duration`, mapping `0` to `None`.
fn from_ms(duration: u32) -> Option<Duration> {
    (duration != 0).then(|| Duration::from_millis(u64::from(duration)))
}
699 
700 /// Wrapper around `setsockopt` to deal with platform specific timeouts.
set_timeout_opt( socket: Socket, level: c_int, optname: i32, duration: Option<Duration>, ) -> io::Result<()>701 pub(crate) fn set_timeout_opt(
702     socket: Socket,
703     level: c_int,
704     optname: i32,
705     duration: Option<Duration>,
706 ) -> io::Result<()> {
707     let duration = into_ms(duration);
708     unsafe { setsockopt(socket, level, optname, duration) }
709 }
710 
into_ms(duration: Option<Duration>) -> u32711 fn into_ms(duration: Option<Duration>) -> u32 {
712     // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the
713     // timeouts in windows APIs are typically u32 milliseconds. To translate, we
714     // have two pieces to take care of:
715     //
716     // * Nanosecond precision is rounded up
717     // * Greater than u32::MAX milliseconds (50 days) is rounded up to
718     //   INFINITE (never time out).
719     duration.map_or(0, |duration| {
720         min(duration.as_millis(), INFINITE as u128) as u32
721     })
722 }
723 
set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()>724 pub(crate) fn set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()> {
725     let mut keepalive = tcp_keepalive {
726         onoff: 1,
727         keepalivetime: into_ms(keepalive.time),
728         keepaliveinterval: into_ms(keepalive.interval),
729     };
730     let mut out = 0;
731     syscall!(
732         WSAIoctl(
733             socket,
734             SIO_KEEPALIVE_VALS,
735             &mut keepalive as *mut _ as *mut _,
736             size_of::<tcp_keepalive>() as _,
737             ptr::null_mut(),
738             0,
739             &mut out,
740             ptr::null_mut(),
741             None,
742         ),
743         PartialEq::eq,
744         SOCKET_ERROR
745     )
746     .map(|_| ())
747 }
748 
749 /// Caller must ensure `T` is the correct type for `level` and `optname`.
750 // NOTE: `optname` is actually `i32`, but all constants are `u32`.
getsockopt<T>(socket: Socket, level: c_int, optname: i32) -> io::Result<T>751 pub(crate) unsafe fn getsockopt<T>(socket: Socket, level: c_int, optname: i32) -> io::Result<T> {
752     let mut optval: MaybeUninit<T> = MaybeUninit::uninit();
753     let mut optlen = mem::size_of::<T>() as c_int;
754     syscall!(
755         getsockopt(
756             socket,
757             level as i32,
758             optname,
759             optval.as_mut_ptr().cast(),
760             &mut optlen,
761         ),
762         PartialEq::eq,
763         SOCKET_ERROR
764     )
765     .map(|_| {
766         debug_assert_eq!(optlen as usize, mem::size_of::<T>());
767         // Safety: `getsockopt` initialised `optval` for us.
768         optval.assume_init()
769     })
770 }
771 
772 /// Caller must ensure `T` is the correct type for `level` and `optname`.
773 // NOTE: `optname` is actually `i32`, but all constants are `u32`.
setsockopt<T>( socket: Socket, level: c_int, optname: i32, optval: T, ) -> io::Result<()>774 pub(crate) unsafe fn setsockopt<T>(
775     socket: Socket,
776     level: c_int,
777     optname: i32,
778     optval: T,
779 ) -> io::Result<()> {
780     syscall!(
781         setsockopt(
782             socket,
783             level as i32,
784             optname,
785             (&optval as *const T).cast(),
786             mem::size_of::<T>() as c_int,
787         ),
788         PartialEq::eq,
789         SOCKET_ERROR
790     )
791     .map(|_| ())
792 }
793 
ioctlsocket(socket: Socket, cmd: i32, payload: &mut u32) -> io::Result<()>794 fn ioctlsocket(socket: Socket, cmd: i32, payload: &mut u32) -> io::Result<()> {
795     syscall!(
796         ioctlsocket(socket, cmd, payload),
797         PartialEq::eq,
798         SOCKET_ERROR
799     )
800     .map(|_| ())
801 }
802 
to_in_addr(addr: &Ipv4Addr) -> IN_ADDR803 pub(crate) fn to_in_addr(addr: &Ipv4Addr) -> IN_ADDR {
804     IN_ADDR {
805         S_un: IN_ADDR_0 {
806             // `S_un` is stored as BE on all machines, and the array is in BE
807             // order. So the native endian conversion method is used so that
808             // it's never swapped.
809             S_addr: u32::from_ne_bytes(addr.octets()),
810         },
811     }
812 }
813 
from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr814 pub(crate) fn from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr {
815     Ipv4Addr::from(unsafe { in_addr.S_un.S_addr }.to_ne_bytes())
816 }
817 
to_in6_addr(addr: &Ipv6Addr) -> IN6_ADDR818 pub(crate) fn to_in6_addr(addr: &Ipv6Addr) -> IN6_ADDR {
819     IN6_ADDR {
820         u: IN6_ADDR_0 {
821             Byte: addr.octets(),
822         },
823     }
824 }
825 
from_in6_addr(addr: IN6_ADDR) -> Ipv6Addr826 pub(crate) fn from_in6_addr(addr: IN6_ADDR) -> Ipv6Addr {
827     Ipv6Addr::from(unsafe { addr.u.Byte })
828 }
829 
to_mreqn( multiaddr: &Ipv4Addr, interface: &crate::socket::InterfaceIndexOrAddress, ) -> IpMreq830 pub(crate) fn to_mreqn(
831     multiaddr: &Ipv4Addr,
832     interface: &crate::socket::InterfaceIndexOrAddress,
833 ) -> IpMreq {
834     IpMreq {
835         imr_multiaddr: to_in_addr(multiaddr),
836         // Per https://docs.microsoft.com/en-us/windows/win32/api/ws2ipdef/ns-ws2ipdef-ip_mreq#members:
837         //
838         // imr_interface
839         //
840         // The local IPv4 address of the interface or the interface index on
841         // which the multicast group should be joined or dropped. This value is
842         // in network byte order. If this member specifies an IPv4 address of
843         // 0.0.0.0, the default IPv4 multicast interface is used.
844         //
845         // To use an interface index of 1 would be the same as an IP address of
846         // 0.0.0.1.
847         imr_interface: match interface {
848             crate::socket::InterfaceIndexOrAddress::Index(interface) => {
849                 to_in_addr(&(*interface).into())
850             }
851             crate::socket::InterfaceIndexOrAddress::Address(interface) => to_in_addr(interface),
852         },
853     }
854 }
855 
856 #[allow(unsafe_op_in_unsafe_fn)]
unix_sockaddr(path: &Path) -> io::Result<SockAddr>857 pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
858     // SAFETY: a `sockaddr_storage` of all zeros is valid.
859     let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
860     let len = {
861         let storage: &mut windows_sys::Win32::Networking::WinSock::SOCKADDR_UN =
862             unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() };
863 
864         // Windows expects a UTF-8 path here even though Windows paths are
865         // usually UCS-2 encoded. If Rust exposed OsStr's Wtf8 encoded
866         // buffer, this could be used directly, relying on Windows to
867         // validate the path, but Rust hides this implementation detail.
868         //
869         // See <https://github.com/rust-lang/rust/pull/95290>.
870         let bytes = path
871             .to_str()
872             .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "path must be valid UTF-8"))?
873             .as_bytes();
874 
875         // Windows appears to allow non-null-terminated paths, but this is
876         // not documented, so do not rely on it yet.
877         //
878         // See <https://github.com/rust-lang/socket2/issues/331>.
879         if bytes.len() >= storage.sun_path.len() {
880             return Err(io::Error::new(
881                 io::ErrorKind::InvalidInput,
882                 "path must be shorter than SUN_LEN",
883             ));
884         }
885 
886         storage.sun_family = crate::sys::AF_UNIX as sa_family_t;
887         // `storage` was initialized to zero above, so the path is
888         // already null terminated.
889         storage.sun_path[..bytes.len()].copy_from_slice(bytes);
890 
891         let base = storage as *const _ as usize;
892         let path = &storage.sun_path as *const _ as usize;
893         let sun_path_offset = path - base;
894         sun_path_offset + bytes.len() + 1
895     };
896     Ok(unsafe { SockAddr::new(storage, len as socklen_t) })
897 }
898 
899 /// Windows only API.
900 impl crate::Socket {
901     /// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`.
902     #[cfg(feature = "all")]
903     #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
set_no_inherit(&self, no_inherit: bool) -> io::Result<()>904     pub fn set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
905         self._set_no_inherit(no_inherit)
906     }
907 
_set_no_inherit(&self, no_inherit: bool) -> io::Result<()>908     pub(crate) fn _set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
909         // NOTE: can't use `syscall!` because it expects the function in the
910         // `windows_sys::Win32::Networking::WinSock::` path.
911         let res = unsafe {
912             SetHandleInformation(
913                 self.as_raw() as HANDLE,
914                 HANDLE_FLAG_INHERIT,
915                 !no_inherit as _,
916             )
917         };
918         if res == 0 {
919             // Zero means error.
920             Err(io::Error::last_os_error())
921         } else {
922             Ok(())
923         }
924     }
925 
926     /// Returns the [`Protocol`] of this socket by checking the `SO_PROTOCOL_INFOW`
927     /// option on this socket.
928     ///
929     /// [`Protocol`]: crate::Protocol
930     #[cfg(feature = "all")]
protocol(&self) -> io::Result<Option<crate::Protocol>>931     pub fn protocol(&self) -> io::Result<Option<crate::Protocol>> {
932         let info = unsafe {
933             getsockopt::<WSAPROTOCOL_INFOW>(self.as_raw(), SOL_SOCKET, SO_PROTOCOL_INFOW)?
934         };
935         match info.iProtocol {
936             0 => Ok(None),
937             p => Ok(Some(crate::Protocol::from(p))),
938         }
939     }
940 }
941 
942 #[cfg_attr(docsrs, doc(cfg(windows)))]
943 impl AsSocket for crate::Socket {
as_socket(&self) -> BorrowedSocket<'_>944     fn as_socket(&self) -> BorrowedSocket<'_> {
945         // SAFETY: lifetime is bound by self.
946         unsafe { BorrowedSocket::borrow_raw(self.as_raw() as RawSocket) }
947     }
948 }
949 
950 #[cfg_attr(docsrs, doc(cfg(windows)))]
951 impl AsRawSocket for crate::Socket {
as_raw_socket(&self) -> RawSocket952     fn as_raw_socket(&self) -> RawSocket {
953         self.as_raw() as RawSocket
954     }
955 }
956 
957 #[cfg_attr(docsrs, doc(cfg(windows)))]
958 impl From<crate::Socket> for OwnedSocket {
from(sock: crate::Socket) -> OwnedSocket959     fn from(sock: crate::Socket) -> OwnedSocket {
960         // SAFETY: sock.into_raw() always returns a valid fd.
961         unsafe { OwnedSocket::from_raw_socket(sock.into_raw() as RawSocket) }
962     }
963 }
964 
965 #[cfg_attr(docsrs, doc(cfg(windows)))]
966 impl IntoRawSocket for crate::Socket {
into_raw_socket(self) -> RawSocket967     fn into_raw_socket(self) -> RawSocket {
968         self.into_raw() as RawSocket
969     }
970 }
971 
972 #[cfg_attr(docsrs, doc(cfg(windows)))]
973 impl From<OwnedSocket> for crate::Socket {
from(fd: OwnedSocket) -> crate::Socket974     fn from(fd: OwnedSocket) -> crate::Socket {
975         // SAFETY: `OwnedFd` ensures the fd is valid.
976         unsafe { crate::Socket::from_raw_socket(fd.into_raw_socket()) }
977     }
978 }
979 
980 #[cfg_attr(docsrs, doc(cfg(windows)))]
981 impl FromRawSocket for crate::Socket {
from_raw_socket(socket: RawSocket) -> crate::Socket982     unsafe fn from_raw_socket(socket: RawSocket) -> crate::Socket {
983         crate::Socket::from_raw(socket as Socket)
984     }
985 }
986 
#[test]
fn in_addr_convertion() {
    // Pairs of (address, expected raw `S_addr`). The expectations assume a
    // little-endian target, where the first octet lands in the lowest byte.
    let cases = [
        (Ipv4Addr::new(127, 0, 0, 1), u32::from_le_bytes([127, 0, 0, 1])),
        (Ipv4Addr::new(127, 34, 4, 12), u32::from_le_bytes([127, 34, 4, 12])),
    ];
    for (ip, expected) in cases {
        let raw = to_in_addr(&ip);
        // SAFETY: `to_in_addr` initialises the union through `S_addr`.
        assert_eq!(unsafe { raw.S_un.S_addr }, expected);
        assert_eq!(from_in_addr(raw), ip);
    }
}
1002 
#[test]
fn in6_addr_convertion() {
    let ip = Ipv6Addr::new(0x2000, 1, 2, 3, 4, 5, 6, 7);
    let raw = to_in6_addr(&ip);
    // Viewed as 16-bit words, every segment must be in network (big-endian)
    // byte order.
    let mut want = [0x2000u16, 1, 2, 3, 4, 5, 6, 7];
    want.iter_mut().for_each(|word| *word = word.to_be());
    assert_eq!(unsafe { raw.u.Word }, want);
    assert_eq!(from_in6_addr(raw), ip);
}
1020