1 // Copyright 2015 The Rust Project Developers.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8
9 use std::cmp::min;
10 use std::io::{self, IoSlice};
11 use std::marker::PhantomData;
12 use std::mem::{self, size_of, MaybeUninit};
13 use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
14 use std::os::windows::io::{
15 AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
16 };
17 use std::path::Path;
18 use std::sync::Once;
19 use std::time::{Duration, Instant};
20 use std::{process, ptr, slice};
21
22 use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT};
23 #[cfg(feature = "all")]
24 use windows_sys::Win32::Networking::WinSock::SO_PROTOCOL_INFOW;
25 use windows_sys::Win32::Networking::WinSock::{
26 self, tcp_keepalive, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0,
27 POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND, SIO_KEEPALIVE_VALS,
28 SOCKET_ERROR, WSABUF, WSAEMSGSIZE, WSAESHUTDOWN, WSAPOLLFD, WSAPROTOCOL_INFOW,
29 WSA_FLAG_NO_HANDLE_INHERIT, WSA_FLAG_OVERLAPPED,
30 };
31 use windows_sys::Win32::System::Threading::INFINITE;
32
33 use crate::{MsgHdr, RecvFlags, SockAddr, TcpKeepalive, Type};
34
/// The C `int` type as seen by the WinSock API.
// Lowercase on purpose to mirror the C type name, hence the lint allowance.
#[allow(non_camel_case_types)]
pub(crate) type c_int = std::os::raw::c_int;
37
/// Fake MSG_TRUNC flag for the [`RecvFlags`] struct.
///
/// The flag is enabled when a `WSARecv[From]` call returns `WSAEMSGSIZE`. The
/// value of the flag is defined by us. Within this module it is only ever
/// stored in a `RecvFlags` value (by the vectored receive functions), never
/// passed to WinSock.
pub(crate) const MSG_TRUNC: c_int = 0x01;
43
44 // Used in `Domain`.
45 pub(crate) const AF_INET: c_int = windows_sys::Win32::Networking::WinSock::AF_INET as c_int;
46 pub(crate) const AF_INET6: c_int = windows_sys::Win32::Networking::WinSock::AF_INET6 as c_int;
47 pub(crate) const AF_UNIX: c_int = windows_sys::Win32::Networking::WinSock::AF_UNIX as c_int;
48 pub(crate) const AF_UNSPEC: c_int = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as c_int;
49 // Used in `Type`.
50 pub(crate) const SOCK_STREAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_STREAM as c_int;
51 pub(crate) const SOCK_DGRAM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_DGRAM as c_int;
52 pub(crate) const SOCK_RAW: c_int = windows_sys::Win32::Networking::WinSock::SOCK_RAW as c_int;
53 const SOCK_RDM: c_int = windows_sys::Win32::Networking::WinSock::SOCK_RDM as c_int;
54 pub(crate) const SOCK_SEQPACKET: c_int =
55 windows_sys::Win32::Networking::WinSock::SOCK_SEQPACKET as c_int;
56 // Used in `Protocol`.
57 pub(crate) use windows_sys::Win32::Networking::WinSock::{
58 IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_TCP, IPPROTO_UDP,
59 };
60 // Used in `SockAddr`.
61 pub(crate) use windows_sys::Win32::Networking::WinSock::{
62 SOCKADDR as sockaddr, SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6,
63 SOCKADDR_STORAGE as sockaddr_storage,
64 };
65 #[allow(non_camel_case_types)]
66 pub(crate) type sa_family_t = windows_sys::Win32::Networking::WinSock::ADDRESS_FAMILY;
67 #[allow(non_camel_case_types)]
68 pub(crate) type socklen_t = windows_sys::Win32::Networking::WinSock::socklen_t;
69 // Used in `Socket`.
70 #[cfg(feature = "all")]
71 pub(crate) use windows_sys::Win32::Networking::WinSock::IP_HDRINCL;
72 pub(crate) use windows_sys::Win32::Networking::WinSock::{
73 IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_MREQ as Ipv6Mreq,
74 IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_RECVTCLASS,
75 IPV6_UNICAST_HOPS, IPV6_V6ONLY, IP_ADD_MEMBERSHIP, IP_ADD_SOURCE_MEMBERSHIP,
76 IP_DROP_MEMBERSHIP, IP_DROP_SOURCE_MEMBERSHIP, IP_MREQ as IpMreq,
77 IP_MREQ_SOURCE as IpMreqSource, IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL,
78 IP_RECVTOS, IP_TOS, IP_TTL, LINGER as linger, MSG_OOB, MSG_PEEK, SO_BROADCAST, SO_ERROR,
79 SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE, SO_RCVBUF, SO_RCVTIMEO, SO_REUSEADDR, SO_SNDBUF,
80 SO_SNDTIMEO, SO_TYPE, TCP_NODELAY,
81 };
82 pub(crate) const IPPROTO_IP: c_int = windows_sys::Win32::Networking::WinSock::IPPROTO_IP as c_int;
83 pub(crate) const SOL_SOCKET: c_int = windows_sys::Win32::Networking::WinSock::SOL_SOCKET as c_int;
84
/// Type used in set/getsockopt to retrieve the `TCP_NODELAY` option.
///
/// NOTE: <https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-getsockopt>
/// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a
/// `BOOL` (alias for `c_int`, 4 bytes), however in practice this turns out to
/// be false (or misleading) as a `BOOLEAN` (`c_uchar`, 1 byte) is returned by
/// `getsockopt`.
// NOTE(review): presumably used as the `T` of `getsockopt`/`setsockopt` for
// such boolean options elsewhere in the crate; `getsockopt` debug-asserts that
// the returned length equals `size_of::<T>()`.
pub(crate) type Bool = windows_sys::Win32::Foundation::BOOLEAN;
93
/// Maximum size of a buffer passed to system call like `recv` and `send`.
///
/// Those calls take the buffer length as a `c_int`, so longer buffers are
/// capped to this value (via `min`) before the call; the call then simply
/// transfers at most this many bytes.
const MAX_BUF_LEN: usize = c_int::MAX as usize;
96
/// Helper macro to execute a system call that returns an `io::Result`.
///
/// `syscall!(name(args...), err_test, err_value)` calls
/// `windows_sys::Win32::Networking::WinSock::name(args...)` inside `unsafe`,
/// then evaluates `err_test(&result, &err_value)`. If that is `true` the call
/// is treated as failed and `io::Error::last_os_error()` is returned;
/// otherwise the raw return value is wrapped in `Ok`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{
        // Some call sites are already inside `unsafe` (e.g. `unsafe fn
        // getsockopt`), which would otherwise warn about a nested block.
        #[allow(unused_unsafe)]
        let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) };
        if $err_test(&res, &$err_value) {
            Err(io::Error::last_os_error())
        } else {
            Ok(res)
        }
    }};
}
109
// `Debug` for `Domain`: the listed constants are the known address families.
// (`impl_debug!` is defined elsewhere in the crate; presumably it prints the
// matching constant's name.)
impl_debug!(
    crate::Domain,
    self::AF_INET,
    self::AF_INET6,
    self::AF_UNIX,
    self::AF_UNSPEC,
);
117
118 /// Windows only API.
119 impl Type {
120 /// Our custom flag to set `WSA_FLAG_NO_HANDLE_INHERIT` on socket creation.
121 /// Trying to mimic `Type::cloexec` on windows.
122 const NO_INHERIT: c_int = 1 << ((size_of::<c_int>() * 8) - 1); // Last bit.
123
124 /// Set `WSA_FLAG_NO_HANDLE_INHERIT` on the socket.
125 #[cfg(feature = "all")]
126 #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
no_inherit(self) -> Type127 pub const fn no_inherit(self) -> Type {
128 self._no_inherit()
129 }
130
_no_inherit(self) -> Type131 pub(crate) const fn _no_inherit(self) -> Type {
132 Type(self.0 | Type::NO_INHERIT)
133 }
134 }
135
// `Debug` for `Type`: the listed constants are the known socket types.
// NOTE(review): a `Type` carrying the private `NO_INHERIT` marker bit matches
// none of these values.
impl_debug!(
    crate::Type,
    self::SOCK_STREAM,
    self::SOCK_DGRAM,
    self::SOCK_RAW,
    self::SOCK_RDM,
    self::SOCK_SEQPACKET,
);
144
// `Debug` for `Protocol`: the listed WinSock constants are the known protocols.
impl_debug!(
    crate::Protocol,
    WinSock::IPPROTO_ICMP,
    WinSock::IPPROTO_ICMPV6,
    WinSock::IPPROTO_TCP,
    WinSock::IPPROTO_UDP,
);
152
153 impl std::fmt::Debug for RecvFlags {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 f.debug_struct("RecvFlags")
156 .field("is_truncated", &self.is_truncated())
157 .finish()
158 }
159 }
160
/// A mutable, possibly-uninitialised byte buffer stored as a WinSock `WSABUF`.
///
/// `#[repr(transparent)]` gives this type exactly the layout of `WSABUF`; the
/// vectored receive functions in this file rely on that when they pass a slice
/// of these to WinSock via `bufs.as_mut_ptr().cast()` (presumably through the
/// crate-level `MaybeUninitSlice` wrapper).
#[repr(transparent)]
pub struct MaybeUninitSlice<'a> {
    // Pointer/length pair handed to WinSock.
    vec: WSABUF,
    // Ties the raw pointer in `vec` to the lifetime of the borrowed buffer.
    _lifetime: PhantomData<&'a mut [MaybeUninit<u8>]>,
}
166
// SAFETY: `MaybeUninitSlice` is just a pointer/length pair taken from a
// `&'a mut [MaybeUninit<u8>]` (tracked by `_lifetime`), so moving it to another
// thread is equivalent to moving that mutable borrow. The `WSABUF` holds a raw
// pointer, which is presumably why the auto-trait impl has to be written by hand.
unsafe impl<'a> Send for MaybeUninitSlice<'a> {}

// SAFETY: through `&MaybeUninitSlice` the only access is `as_slice`, which
// yields a shared `&[MaybeUninit<u8>]`; mutation requires `&mut self`.
unsafe impl<'a> Sync for MaybeUninitSlice<'a> {}
170
171 impl<'a> MaybeUninitSlice<'a> {
new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a>172 pub fn new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a> {
173 assert!(buf.len() <= u32::MAX as usize);
174 MaybeUninitSlice {
175 vec: WSABUF {
176 len: buf.len() as u32,
177 buf: buf.as_mut_ptr().cast(),
178 },
179 _lifetime: PhantomData,
180 }
181 }
182
as_slice(&self) -> &[MaybeUninit<u8>]183 pub fn as_slice(&self) -> &[MaybeUninit<u8>] {
184 unsafe { slice::from_raw_parts(self.vec.buf.cast(), self.vec.len as usize) }
185 }
186
as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>]187 pub fn as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>] {
188 unsafe { slice::from_raw_parts_mut(self.vec.buf.cast(), self.vec.len as usize) }
189 }
190 }
191
192 // Used in `MsgHdr`.
193 pub(crate) use windows_sys::Win32::Networking::WinSock::WSAMSG as msghdr;
194
set_msghdr_name(msg: &mut msghdr, name: &SockAddr)195 pub(crate) fn set_msghdr_name(msg: &mut msghdr, name: &SockAddr) {
196 msg.name = name.as_ptr() as *mut _;
197 msg.namelen = name.len();
198 }
199
set_msghdr_iov(msg: &mut msghdr, ptr: *mut WSABUF, len: usize)200 pub(crate) fn set_msghdr_iov(msg: &mut msghdr, ptr: *mut WSABUF, len: usize) {
201 msg.lpBuffers = ptr;
202 msg.dwBufferCount = min(len, u32::MAX as usize) as u32;
203 }
204
set_msghdr_control(msg: &mut msghdr, ptr: *mut u8, len: usize)205 pub(crate) fn set_msghdr_control(msg: &mut msghdr, ptr: *mut u8, len: usize) {
206 msg.Control.buf = ptr;
207 msg.Control.len = len as u32;
208 }
209
set_msghdr_flags(msg: &mut msghdr, flags: c_int)210 pub(crate) fn set_msghdr_flags(msg: &mut msghdr, flags: c_int) {
211 msg.dwFlags = flags as u32;
212 }
213
msghdr_flags(msg: &msghdr) -> RecvFlags214 pub(crate) fn msghdr_flags(msg: &msghdr) -> RecvFlags {
215 RecvFlags(msg.dwFlags as c_int)
216 }
217
init()218 fn init() {
219 static INIT: Once = Once::new();
220
221 INIT.call_once(|| {
222 // Initialize winsock through the standard library by just creating a
223 // dummy socket. Whether this is successful or not we drop the result as
224 // libstd will be sure to have initialized winsock.
225 let _ = net::UdpSocket::bind("127.0.0.1:34254");
226 });
227 }
228
229 pub(crate) type Socket = windows_sys::Win32::Networking::WinSock::SOCKET;
230
socket_from_raw(socket: Socket) -> crate::socket::Inner231 pub(crate) unsafe fn socket_from_raw(socket: Socket) -> crate::socket::Inner {
232 crate::socket::Inner::from_raw_socket(socket as RawSocket)
233 }
234
socket_as_raw(socket: &crate::socket::Inner) -> Socket235 pub(crate) fn socket_as_raw(socket: &crate::socket::Inner) -> Socket {
236 socket.as_raw_socket() as Socket
237 }
238
socket_into_raw(socket: crate::socket::Inner) -> Socket239 pub(crate) fn socket_into_raw(socket: crate::socket::Inner) -> Socket {
240 socket.into_raw_socket() as Socket
241 }
242
socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket>243 pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket> {
244 init();
245
246 // Check if we set our custom flag.
247 let flags = if ty & Type::NO_INHERIT != 0 {
248 ty = ty & !Type::NO_INHERIT;
249 WSA_FLAG_NO_HANDLE_INHERIT
250 } else {
251 0
252 };
253
254 syscall!(
255 WSASocketW(
256 family,
257 ty,
258 protocol,
259 ptr::null_mut(),
260 0,
261 WSA_FLAG_OVERLAPPED | flags,
262 ),
263 PartialEq::eq,
264 INVALID_SOCKET
265 )
266 }
267
bind(socket: Socket, addr: &SockAddr) -> io::Result<()>268 pub(crate) fn bind(socket: Socket, addr: &SockAddr) -> io::Result<()> {
269 syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
270 }
271
connect(socket: Socket, addr: &SockAddr) -> io::Result<()>272 pub(crate) fn connect(socket: Socket, addr: &SockAddr) -> io::Result<()> {
273 syscall!(connect(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
274 }
275
poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()>276 pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
277 let start = Instant::now();
278
279 let mut fd_array = WSAPOLLFD {
280 fd: socket.as_raw(),
281 events: (POLLRDNORM | POLLWRNORM) as i16,
282 revents: 0,
283 };
284
285 loop {
286 let elapsed = start.elapsed();
287 if elapsed >= timeout {
288 return Err(io::ErrorKind::TimedOut.into());
289 }
290
291 let timeout = (timeout - elapsed).as_millis();
292 let timeout = clamp(timeout, 1, c_int::MAX as u128) as c_int;
293
294 match syscall!(
295 WSAPoll(&mut fd_array, 1, timeout),
296 PartialEq::eq,
297 SOCKET_ERROR
298 ) {
299 Ok(0) => return Err(io::ErrorKind::TimedOut.into()),
300 Ok(_) => {
301 // Error or hang up indicates an error (or failure to connect).
302 if (fd_array.revents & POLLERR as i16) != 0
303 || (fd_array.revents & POLLHUP as i16) != 0
304 {
305 match socket.take_error() {
306 Ok(Some(err)) => return Err(err),
307 Ok(None) => {
308 return Err(io::Error::new(
309 io::ErrorKind::Other,
310 "no error set after POLLHUP",
311 ))
312 }
313 Err(err) => return Err(err),
314 }
315 }
316 return Ok(());
317 }
318 // Got interrupted, try again.
319 Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
320 Err(err) => return Err(err),
321 }
322 }
323 }
324
/// Restricts `value` to the inclusive range `min..=max`.
///
/// Thin wrapper around the standard library's [`Ord::clamp`] (stable since
/// Rust 1.50). This file already relies on the newer `OwnedSocket`/
/// `BorrowedSocket` types, so no hand-rolled version is needed. Like
/// `Ord::clamp`, it panics if `min > max`.
fn clamp<T>(value: T, min: T, max: T) -> T
where
    T: Ord,
{
    value.clamp(min, max)
}
338
listen(socket: Socket, backlog: c_int) -> io::Result<()>339 pub(crate) fn listen(socket: Socket, backlog: c_int) -> io::Result<()> {
340 syscall!(listen(socket, backlog), PartialEq::ne, 0).map(|_| ())
341 }
342
accept(socket: Socket) -> io::Result<(Socket, SockAddr)>343 pub(crate) fn accept(socket: Socket) -> io::Result<(Socket, SockAddr)> {
344 // Safety: `accept` initialises the `SockAddr` for us.
345 unsafe {
346 SockAddr::try_init(|storage, len| {
347 syscall!(
348 accept(socket, storage.cast(), len),
349 PartialEq::eq,
350 INVALID_SOCKET
351 )
352 })
353 }
354 }
355
getsockname(socket: Socket) -> io::Result<SockAddr>356 pub(crate) fn getsockname(socket: Socket) -> io::Result<SockAddr> {
357 // Safety: `getsockname` initialises the `SockAddr` for us.
358 unsafe {
359 SockAddr::try_init(|storage, len| {
360 syscall!(
361 getsockname(socket, storage.cast(), len),
362 PartialEq::eq,
363 SOCKET_ERROR
364 )
365 })
366 }
367 .map(|(_, addr)| addr)
368 }
369
getpeername(socket: Socket) -> io::Result<SockAddr>370 pub(crate) fn getpeername(socket: Socket) -> io::Result<SockAddr> {
371 // Safety: `getpeername` initialises the `SockAddr` for us.
372 unsafe {
373 SockAddr::try_init(|storage, len| {
374 syscall!(
375 getpeername(socket, storage.cast(), len),
376 PartialEq::eq,
377 SOCKET_ERROR
378 )
379 })
380 }
381 .map(|(_, addr)| addr)
382 }
383
try_clone(socket: Socket) -> io::Result<Socket>384 pub(crate) fn try_clone(socket: Socket) -> io::Result<Socket> {
385 let mut info: MaybeUninit<WSAPROTOCOL_INFOW> = MaybeUninit::uninit();
386 syscall!(
387 // NOTE: `process.id` is the same as `GetCurrentProcessId`.
388 WSADuplicateSocketW(socket, process::id(), info.as_mut_ptr()),
389 PartialEq::eq,
390 SOCKET_ERROR
391 )?;
392 // Safety: `WSADuplicateSocketW` intialised `info` for us.
393 let mut info = unsafe { info.assume_init() };
394
395 syscall!(
396 WSASocketW(
397 info.iAddressFamily,
398 info.iSocketType,
399 info.iProtocol,
400 &mut info,
401 0,
402 WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT,
403 ),
404 PartialEq::eq,
405 INVALID_SOCKET
406 )
407 }
408
set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()>409 pub(crate) fn set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()> {
410 let mut nonblocking = if nonblocking { 1 } else { 0 };
411 ioctlsocket(socket, FIONBIO, &mut nonblocking)
412 }
413
shutdown(socket: Socket, how: Shutdown) -> io::Result<()>414 pub(crate) fn shutdown(socket: Socket, how: Shutdown) -> io::Result<()> {
415 let how = match how {
416 Shutdown::Write => SD_SEND,
417 Shutdown::Read => SD_RECEIVE,
418 Shutdown::Both => SD_BOTH,
419 } as i32;
420 syscall!(shutdown(socket, how), PartialEq::eq, SOCKET_ERROR).map(|_| ())
421 }
422
recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize>423 pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
424 let res = syscall!(
425 recv(
426 socket,
427 buf.as_mut_ptr().cast(),
428 min(buf.len(), MAX_BUF_LEN) as c_int,
429 flags,
430 ),
431 PartialEq::eq,
432 SOCKET_ERROR
433 );
434 match res {
435 Ok(n) => Ok(n as usize),
436 Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok(0),
437 Err(err) => Err(err),
438 }
439 }
440
recv_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags)>441 pub(crate) fn recv_vectored(
442 socket: Socket,
443 bufs: &mut [crate::MaybeUninitSlice<'_>],
444 flags: c_int,
445 ) -> io::Result<(usize, RecvFlags)> {
446 let mut nread = 0;
447 let mut flags = flags as u32;
448 let res = syscall!(
449 WSARecv(
450 socket,
451 bufs.as_mut_ptr().cast(),
452 min(bufs.len(), u32::MAX as usize) as u32,
453 &mut nread,
454 &mut flags,
455 ptr::null_mut(),
456 None,
457 ),
458 PartialEq::eq,
459 SOCKET_ERROR
460 );
461 match res {
462 Ok(_) => Ok((nread as usize, RecvFlags(0))),
463 Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok((0, RecvFlags(0))),
464 Err(ref err) if err.raw_os_error() == Some(WSAEMSGSIZE as i32) => {
465 Ok((nread as usize, RecvFlags(MSG_TRUNC)))
466 }
467 Err(err) => Err(err),
468 }
469 }
470
recv_from( socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int, ) -> io::Result<(usize, SockAddr)>471 pub(crate) fn recv_from(
472 socket: Socket,
473 buf: &mut [MaybeUninit<u8>],
474 flags: c_int,
475 ) -> io::Result<(usize, SockAddr)> {
476 // Safety: `recvfrom` initialises the `SockAddr` for us.
477 unsafe {
478 SockAddr::try_init(|storage, addrlen| {
479 let res = syscall!(
480 recvfrom(
481 socket,
482 buf.as_mut_ptr().cast(),
483 min(buf.len(), MAX_BUF_LEN) as c_int,
484 flags,
485 storage.cast(),
486 addrlen,
487 ),
488 PartialEq::eq,
489 SOCKET_ERROR
490 );
491 match res {
492 Ok(n) => Ok(n as usize),
493 Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok(0),
494 Err(err) => Err(err),
495 }
496 })
497 }
498 }
499
peek_sender(socket: Socket) -> io::Result<SockAddr>500 pub(crate) fn peek_sender(socket: Socket) -> io::Result<SockAddr> {
501 // Safety: `recvfrom` initialises the `SockAddr` for us.
502 let ((), sender) = unsafe {
503 SockAddr::try_init(|storage, addrlen| {
504 let res = syscall!(
505 recvfrom(
506 socket,
507 // Windows *appears* not to care if you pass a null pointer.
508 ptr::null_mut(),
509 0,
510 MSG_PEEK,
511 storage.cast(),
512 addrlen,
513 ),
514 PartialEq::eq,
515 SOCKET_ERROR
516 );
517 match res {
518 Ok(_n) => Ok(()),
519 Err(e) => match e.raw_os_error() {
520 Some(code) if code == (WSAESHUTDOWN as i32) || code == (WSAEMSGSIZE as i32) => {
521 Ok(())
522 }
523 _ => Err(e),
524 },
525 }
526 })
527 }?;
528
529 Ok(sender)
530 }
531
recv_from_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags, SockAddr)>532 pub(crate) fn recv_from_vectored(
533 socket: Socket,
534 bufs: &mut [crate::MaybeUninitSlice<'_>],
535 flags: c_int,
536 ) -> io::Result<(usize, RecvFlags, SockAddr)> {
537 // Safety: `recvfrom` initialises the `SockAddr` for us.
538 unsafe {
539 SockAddr::try_init(|storage, addrlen| {
540 let mut nread = 0;
541 let mut flags = flags as u32;
542 let res = syscall!(
543 WSARecvFrom(
544 socket,
545 bufs.as_mut_ptr().cast(),
546 min(bufs.len(), u32::MAX as usize) as u32,
547 &mut nread,
548 &mut flags,
549 storage.cast(),
550 addrlen,
551 ptr::null_mut(),
552 None,
553 ),
554 PartialEq::eq,
555 SOCKET_ERROR
556 );
557 match res {
558 Ok(_) => Ok((nread as usize, RecvFlags(0))),
559 Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => {
560 Ok((nread as usize, RecvFlags(0)))
561 }
562 Err(ref err) if err.raw_os_error() == Some(WSAEMSGSIZE as i32) => {
563 Ok((nread as usize, RecvFlags(MSG_TRUNC)))
564 }
565 Err(err) => Err(err),
566 }
567 })
568 }
569 .map(|((n, recv_flags), addr)| (n, recv_flags, addr))
570 }
571
send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize>572 pub(crate) fn send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
573 syscall!(
574 send(
575 socket,
576 buf.as_ptr().cast(),
577 min(buf.len(), MAX_BUF_LEN) as c_int,
578 flags,
579 ),
580 PartialEq::eq,
581 SOCKET_ERROR
582 )
583 .map(|n| n as usize)
584 }
585
send_vectored( socket: Socket, bufs: &[IoSlice<'_>], flags: c_int, ) -> io::Result<usize>586 pub(crate) fn send_vectored(
587 socket: Socket,
588 bufs: &[IoSlice<'_>],
589 flags: c_int,
590 ) -> io::Result<usize> {
591 let mut nsent = 0;
592 syscall!(
593 WSASend(
594 socket,
595 // FIXME: From the `WSASend` docs [1]:
596 // > For a Winsock application, once the WSASend function is called,
597 // > the system owns these buffers and the application may not
598 // > access them.
599 //
600 // So what we're doing is actually UB as `bufs` needs to be `&mut
601 // [IoSlice<'_>]`.
602 //
603 // Tracking issue: https://github.com/rust-lang/socket2-rs/issues/129.
604 //
605 // NOTE: `send_to_vectored` has the same problem.
606 //
607 // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend
608 bufs.as_ptr() as *mut _,
609 min(bufs.len(), u32::MAX as usize) as u32,
610 &mut nsent,
611 flags as u32,
612 std::ptr::null_mut(),
613 None,
614 ),
615 PartialEq::eq,
616 SOCKET_ERROR
617 )
618 .map(|_| nsent as usize)
619 }
620
send_to( socket: Socket, buf: &[u8], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>621 pub(crate) fn send_to(
622 socket: Socket,
623 buf: &[u8],
624 addr: &SockAddr,
625 flags: c_int,
626 ) -> io::Result<usize> {
627 syscall!(
628 sendto(
629 socket,
630 buf.as_ptr().cast(),
631 min(buf.len(), MAX_BUF_LEN) as c_int,
632 flags,
633 addr.as_ptr(),
634 addr.len(),
635 ),
636 PartialEq::eq,
637 SOCKET_ERROR
638 )
639 .map(|n| n as usize)
640 }
641
send_to_vectored( socket: Socket, bufs: &[IoSlice<'_>], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>642 pub(crate) fn send_to_vectored(
643 socket: Socket,
644 bufs: &[IoSlice<'_>],
645 addr: &SockAddr,
646 flags: c_int,
647 ) -> io::Result<usize> {
648 let mut nsent = 0;
649 syscall!(
650 WSASendTo(
651 socket,
652 // FIXME: Same problem as in `send_vectored`.
653 bufs.as_ptr() as *mut _,
654 bufs.len().min(u32::MAX as usize) as u32,
655 &mut nsent,
656 flags as u32,
657 addr.as_ptr(),
658 addr.len(),
659 ptr::null_mut(),
660 None,
661 ),
662 PartialEq::eq,
663 SOCKET_ERROR
664 )
665 .map(|_| nsent as usize)
666 }
667
sendmsg(socket: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io::Result<usize>668 pub(crate) fn sendmsg(socket: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io::Result<usize> {
669 let mut nsent = 0;
670 syscall!(
671 WSASendMsg(
672 socket,
673 &msg.inner,
674 flags as u32,
675 &mut nsent,
676 ptr::null_mut(),
677 None,
678 ),
679 PartialEq::eq,
680 SOCKET_ERROR
681 )
682 .map(|_| nsent as usize)
683 }
684
685 /// Wrapper around `getsockopt` to deal with platform specific timeouts.
timeout_opt(fd: Socket, lvl: c_int, name: i32) -> io::Result<Option<Duration>>686 pub(crate) fn timeout_opt(fd: Socket, lvl: c_int, name: i32) -> io::Result<Option<Duration>> {
687 unsafe { getsockopt(fd, lvl, name).map(from_ms) }
688 }
689
from_ms(duration: u32) -> Option<Duration>690 fn from_ms(duration: u32) -> Option<Duration> {
691 if duration == 0 {
692 None
693 } else {
694 let secs = duration / 1000;
695 let nsec = (duration % 1000) * 1000000;
696 Some(Duration::new(secs as u64, nsec as u32))
697 }
698 }
699
700 /// Wrapper around `setsockopt` to deal with platform specific timeouts.
set_timeout_opt( socket: Socket, level: c_int, optname: i32, duration: Option<Duration>, ) -> io::Result<()>701 pub(crate) fn set_timeout_opt(
702 socket: Socket,
703 level: c_int,
704 optname: i32,
705 duration: Option<Duration>,
706 ) -> io::Result<()> {
707 let duration = into_ms(duration);
708 unsafe { setsockopt(socket, level, optname, duration) }
709 }
710
into_ms(duration: Option<Duration>) -> u32711 fn into_ms(duration: Option<Duration>) -> u32 {
712 // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the
713 // timeouts in windows APIs are typically u32 milliseconds. To translate, we
714 // have two pieces to take care of:
715 //
716 // * Nanosecond precision is rounded up
717 // * Greater than u32::MAX milliseconds (50 days) is rounded up to
718 // INFINITE (never time out).
719 duration.map_or(0, |duration| {
720 min(duration.as_millis(), INFINITE as u128) as u32
721 })
722 }
723
set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()>724 pub(crate) fn set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()> {
725 let mut keepalive = tcp_keepalive {
726 onoff: 1,
727 keepalivetime: into_ms(keepalive.time),
728 keepaliveinterval: into_ms(keepalive.interval),
729 };
730 let mut out = 0;
731 syscall!(
732 WSAIoctl(
733 socket,
734 SIO_KEEPALIVE_VALS,
735 &mut keepalive as *mut _ as *mut _,
736 size_of::<tcp_keepalive>() as _,
737 ptr::null_mut(),
738 0,
739 &mut out,
740 ptr::null_mut(),
741 None,
742 ),
743 PartialEq::eq,
744 SOCKET_ERROR
745 )
746 .map(|_| ())
747 }
748
749 /// Caller must ensure `T` is the correct type for `level` and `optname`.
750 // NOTE: `optname` is actually `i32`, but all constants are `u32`.
getsockopt<T>(socket: Socket, level: c_int, optname: i32) -> io::Result<T>751 pub(crate) unsafe fn getsockopt<T>(socket: Socket, level: c_int, optname: i32) -> io::Result<T> {
752 let mut optval: MaybeUninit<T> = MaybeUninit::uninit();
753 let mut optlen = mem::size_of::<T>() as c_int;
754 syscall!(
755 getsockopt(
756 socket,
757 level as i32,
758 optname,
759 optval.as_mut_ptr().cast(),
760 &mut optlen,
761 ),
762 PartialEq::eq,
763 SOCKET_ERROR
764 )
765 .map(|_| {
766 debug_assert_eq!(optlen as usize, mem::size_of::<T>());
767 // Safety: `getsockopt` initialised `optval` for us.
768 optval.assume_init()
769 })
770 }
771
772 /// Caller must ensure `T` is the correct type for `level` and `optname`.
773 // NOTE: `optname` is actually `i32`, but all constants are `u32`.
setsockopt<T>( socket: Socket, level: c_int, optname: i32, optval: T, ) -> io::Result<()>774 pub(crate) unsafe fn setsockopt<T>(
775 socket: Socket,
776 level: c_int,
777 optname: i32,
778 optval: T,
779 ) -> io::Result<()> {
780 syscall!(
781 setsockopt(
782 socket,
783 level as i32,
784 optname,
785 (&optval as *const T).cast(),
786 mem::size_of::<T>() as c_int,
787 ),
788 PartialEq::eq,
789 SOCKET_ERROR
790 )
791 .map(|_| ())
792 }
793
ioctlsocket(socket: Socket, cmd: i32, payload: &mut u32) -> io::Result<()>794 fn ioctlsocket(socket: Socket, cmd: i32, payload: &mut u32) -> io::Result<()> {
795 syscall!(
796 ioctlsocket(socket, cmd, payload),
797 PartialEq::eq,
798 SOCKET_ERROR
799 )
800 .map(|_| ())
801 }
802
to_in_addr(addr: &Ipv4Addr) -> IN_ADDR803 pub(crate) fn to_in_addr(addr: &Ipv4Addr) -> IN_ADDR {
804 IN_ADDR {
805 S_un: IN_ADDR_0 {
806 // `S_un` is stored as BE on all machines, and the array is in BE
807 // order. So the native endian conversion method is used so that
808 // it's never swapped.
809 S_addr: u32::from_ne_bytes(addr.octets()),
810 },
811 }
812 }
813
from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr814 pub(crate) fn from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr {
815 Ipv4Addr::from(unsafe { in_addr.S_un.S_addr }.to_ne_bytes())
816 }
817
to_in6_addr(addr: &Ipv6Addr) -> IN6_ADDR818 pub(crate) fn to_in6_addr(addr: &Ipv6Addr) -> IN6_ADDR {
819 IN6_ADDR {
820 u: IN6_ADDR_0 {
821 Byte: addr.octets(),
822 },
823 }
824 }
825
from_in6_addr(addr: IN6_ADDR) -> Ipv6Addr826 pub(crate) fn from_in6_addr(addr: IN6_ADDR) -> Ipv6Addr {
827 Ipv6Addr::from(unsafe { addr.u.Byte })
828 }
829
to_mreqn( multiaddr: &Ipv4Addr, interface: &crate::socket::InterfaceIndexOrAddress, ) -> IpMreq830 pub(crate) fn to_mreqn(
831 multiaddr: &Ipv4Addr,
832 interface: &crate::socket::InterfaceIndexOrAddress,
833 ) -> IpMreq {
834 IpMreq {
835 imr_multiaddr: to_in_addr(multiaddr),
836 // Per https://docs.microsoft.com/en-us/windows/win32/api/ws2ipdef/ns-ws2ipdef-ip_mreq#members:
837 //
838 // imr_interface
839 //
840 // The local IPv4 address of the interface or the interface index on
841 // which the multicast group should be joined or dropped. This value is
842 // in network byte order. If this member specifies an IPv4 address of
843 // 0.0.0.0, the default IPv4 multicast interface is used.
844 //
845 // To use an interface index of 1 would be the same as an IP address of
846 // 0.0.0.1.
847 imr_interface: match interface {
848 crate::socket::InterfaceIndexOrAddress::Index(interface) => {
849 to_in_addr(&(*interface).into())
850 }
851 crate::socket::InterfaceIndexOrAddress::Address(interface) => to_in_addr(interface),
852 },
853 }
854 }
855
/// Builds a `SockAddr` for a Unix-domain (`AF_UNIX`) socket at `path`.
///
/// # Errors
///
/// Returns `InvalidInput` if `path` is not valid UTF-8, or if it is too long
/// to fit in `sun_path` together with a terminating NUL byte.
// NOTE(review): this lint allowance appears to have no effect, since the
// function is not an `unsafe fn`.
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
    // SAFETY: a `sockaddr_storage` of all zeros is valid.
    let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
    let len = {
        // Reinterpret the zeroed storage as a `SOCKADDR_UN`. This assumes
        // `sockaddr_storage` is at least as large and as aligned as
        // `SOCKADDR_UN` — TODO confirm with a static assertion.
        let storage: &mut windows_sys::Win32::Networking::WinSock::SOCKADDR_UN =
            unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() };

        // Windows expects a UTF-8 path here even though Windows paths are
        // usually UCS-2 encoded. If Rust exposed OsStr's Wtf8 encoded
        // buffer, this could be used directly, relying on Windows to
        // validate the path, but Rust hides this implementation detail.
        //
        // See <https://github.com/rust-lang/rust/pull/95290>.
        let bytes = path
            .to_str()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "path must be valid UTF-8"))?
            .as_bytes();

        // Windows appears to allow non-null-terminated paths, but this is
        // not documented, so do not rely on it yet.
        //
        // See <https://github.com/rust-lang/socket2/issues/331>.
        if bytes.len() >= storage.sun_path.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "path must be shorter than SUN_LEN",
            ));
        }

        storage.sun_family = crate::sys::AF_UNIX as sa_family_t;
        // `storage` was initialized to zero above, so the path is
        // already null terminated.
        storage.sun_path[..bytes.len()].copy_from_slice(bytes);

        // Address length = offset of `sun_path` within the struct, plus the
        // path bytes, plus the terminating NUL.
        let base = storage as *const _ as usize;
        let path = &storage.sun_path as *const _ as usize;
        let sun_path_offset = path - base;
        sun_path_offset + bytes.len() + 1
    };
    Ok(unsafe { SockAddr::new(storage, len as socklen_t) })
}
898
899 /// Windows only API.
900 impl crate::Socket {
901 /// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`.
902 #[cfg(feature = "all")]
903 #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
set_no_inherit(&self, no_inherit: bool) -> io::Result<()>904 pub fn set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
905 self._set_no_inherit(no_inherit)
906 }
907
_set_no_inherit(&self, no_inherit: bool) -> io::Result<()>908 pub(crate) fn _set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
909 // NOTE: can't use `syscall!` because it expects the function in the
910 // `windows_sys::Win32::Networking::WinSock::` path.
911 let res = unsafe {
912 SetHandleInformation(
913 self.as_raw() as HANDLE,
914 HANDLE_FLAG_INHERIT,
915 !no_inherit as _,
916 )
917 };
918 if res == 0 {
919 // Zero means error.
920 Err(io::Error::last_os_error())
921 } else {
922 Ok(())
923 }
924 }
925
926 /// Returns the [`Protocol`] of this socket by checking the `SO_PROTOCOL_INFOW`
927 /// option on this socket.
928 ///
929 /// [`Protocol`]: crate::Protocol
930 #[cfg(feature = "all")]
protocol(&self) -> io::Result<Option<crate::Protocol>>931 pub fn protocol(&self) -> io::Result<Option<crate::Protocol>> {
932 let info = unsafe {
933 getsockopt::<WSAPROTOCOL_INFOW>(self.as_raw(), SOL_SOCKET, SO_PROTOCOL_INFOW)?
934 };
935 match info.iProtocol {
936 0 => Ok(None),
937 p => Ok(Some(crate::Protocol::from(p))),
938 }
939 }
940 }
941
942 #[cfg_attr(docsrs, doc(cfg(windows)))]
943 impl AsSocket for crate::Socket {
as_socket(&self) -> BorrowedSocket<'_>944 fn as_socket(&self) -> BorrowedSocket<'_> {
945 // SAFETY: lifetime is bound by self.
946 unsafe { BorrowedSocket::borrow_raw(self.as_raw() as RawSocket) }
947 }
948 }
949
950 #[cfg_attr(docsrs, doc(cfg(windows)))]
951 impl AsRawSocket for crate::Socket {
as_raw_socket(&self) -> RawSocket952 fn as_raw_socket(&self) -> RawSocket {
953 self.as_raw() as RawSocket
954 }
955 }
956
957 #[cfg_attr(docsrs, doc(cfg(windows)))]
958 impl From<crate::Socket> for OwnedSocket {
from(sock: crate::Socket) -> OwnedSocket959 fn from(sock: crate::Socket) -> OwnedSocket {
960 // SAFETY: sock.into_raw() always returns a valid fd.
961 unsafe { OwnedSocket::from_raw_socket(sock.into_raw() as RawSocket) }
962 }
963 }
964
965 #[cfg_attr(docsrs, doc(cfg(windows)))]
966 impl IntoRawSocket for crate::Socket {
into_raw_socket(self) -> RawSocket967 fn into_raw_socket(self) -> RawSocket {
968 self.into_raw() as RawSocket
969 }
970 }
971
972 #[cfg_attr(docsrs, doc(cfg(windows)))]
973 impl From<OwnedSocket> for crate::Socket {
from(fd: OwnedSocket) -> crate::Socket974 fn from(fd: OwnedSocket) -> crate::Socket {
975 // SAFETY: `OwnedFd` ensures the fd is valid.
976 unsafe { crate::Socket::from_raw_socket(fd.into_raw_socket()) }
977 }
978 }
979
980 #[cfg_attr(docsrs, doc(cfg(windows)))]
981 impl FromRawSocket for crate::Socket {
from_raw_socket(socket: RawSocket) -> crate::Socket982 unsafe fn from_raw_socket(socket: RawSocket) -> crate::Socket {
983 crate::Socket::from_raw(socket as Socket)
984 }
985 }
986
#[test]
fn in_addr_convertion() {
    // `S_addr` is compared as a native-endian `u32`; these expectations
    // assume a little-endian target, where the first octet is the low byte.
    let cases = [
        (Ipv4Addr::new(127, 0, 0, 1), 0x0100_007Fu32),
        (Ipv4Addr::new(127, 34, 4, 12), 0x0C04_227Fu32),
    ];
    for (ip, expected) in cases {
        let raw = to_in_addr(&ip);
        assert_eq!(unsafe { raw.S_un.S_addr }, expected);
        assert_eq!(from_in_addr(raw), ip);
    }
}
1002
#[test]
fn in6_addr_convertion() {
    let ip = Ipv6Addr::new(0x2000, 1, 2, 3, 4, 5, 6, 7);
    // Each 16-bit segment is expected in network (big-endian) byte order.
    let expected = [0x2000u16, 1, 2, 3, 4, 5, 6, 7].map(u16::to_be);
    let raw = to_in6_addr(&ip);
    assert_eq!(unsafe { raw.u.Word }, expected);
    assert_eq!(from_in6_addr(raw), ip);
}
1020