xref: /aosp_15_r20/external/crosvm/base/src/sys/unix/net.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2018 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::cmp::Ordering;
6 use std::convert::TryFrom;
7 use std::ffi::OsString;
8 use std::fs::remove_file;
9 use std::io;
10 use std::mem;
11 use std::mem::size_of;
12 use std::net::Ipv4Addr;
13 use std::net::Ipv6Addr;
14 use std::net::SocketAddr;
15 use std::net::SocketAddrV4;
16 use std::net::SocketAddrV6;
17 use std::net::TcpListener;
18 use std::net::TcpStream;
19 use std::net::ToSocketAddrs;
20 use std::ops::Deref;
21 use std::os::fd::OwnedFd;
22 use std::os::unix::ffi::OsStringExt;
23 use std::path::Path;
24 use std::path::PathBuf;
25 use std::ptr::null_mut;
26 use std::time::Duration;
27 use std::time::Instant;
28 
29 use libc::c_int;
30 use libc::recvfrom;
31 use libc::sa_family_t;
32 use libc::sockaddr;
33 use libc::sockaddr_in;
34 use libc::sockaddr_in6;
35 use libc::socklen_t;
36 use libc::AF_INET;
37 use libc::AF_INET6;
38 use libc::MSG_PEEK;
39 use libc::MSG_TRUNC;
40 use log::warn;
41 use serde::Deserialize;
42 use serde::Serialize;
43 
44 use crate::descriptor::AsRawDescriptor;
45 use crate::descriptor::FromRawDescriptor;
46 use crate::descriptor::IntoRawDescriptor;
47 use crate::handle_eintr_errno;
48 use crate::sys::sockaddr_un;
49 use crate::sys::sockaddrv4_to_lib_c;
50 use crate::sys::sockaddrv6_to_lib_c;
51 use crate::Error;
52 use crate::RawDescriptor;
53 use crate::SafeDescriptor;
54 
/// Tags a socket address as IPv4 or IPv6 so callers can handle both families.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
    V4,
    V6,
}

impl InetVersion {
    /// Returns the IP version of `s`.
    pub fn from_sockaddr(s: &SocketAddr) -> Self {
        if s.is_ipv4() {
            InetVersion::V4
        } else {
            InetVersion::V6
        }
    }
}
70 
71 impl From<InetVersion> for sa_family_t {
from(v: InetVersion) -> sa_family_t72     fn from(v: InetVersion) -> sa_family_t {
73         match v {
74             InetVersion::V4 => AF_INET as sa_family_t,
75             InetVersion::V6 => AF_INET6 as sa_family_t,
76         }
77     }
78 }
79 
socket( domain: c_int, sock_type: c_int, protocol: c_int, ) -> io::Result<SafeDescriptor>80 pub(in crate::sys) fn socket(
81     domain: c_int,
82     sock_type: c_int,
83     protocol: c_int,
84 ) -> io::Result<SafeDescriptor> {
85     // SAFETY:
86     // Safe socket initialization since we handle the returned error.
87     match unsafe { libc::socket(domain, sock_type, protocol) } {
88         -1 => Err(io::Error::last_os_error()),
89         // SAFETY:
90         // Safe because we own the file descriptor.
91         fd => Ok(unsafe { SafeDescriptor::from_raw_descriptor(fd) }),
92     }
93 }
94 
socketpair( domain: c_int, sock_type: c_int, protocol: c_int, ) -> io::Result<(SafeDescriptor, SafeDescriptor)>95 pub(in crate::sys) fn socketpair(
96     domain: c_int,
97     sock_type: c_int,
98     protocol: c_int,
99 ) -> io::Result<(SafeDescriptor, SafeDescriptor)> {
100     let mut fds = [0, 0];
101     // SAFETY:
102     // Safe because we give enough space to store all the fds and we check the return value.
103     match unsafe { libc::socketpair(domain, sock_type, protocol, fds.as_mut_ptr()) } {
104         -1 => Err(io::Error::last_os_error()),
105         _ => Ok(
106             // SAFETY:
107             // Safe because we own the file descriptors.
108             unsafe {
109                 (
110                     SafeDescriptor::from_raw_descriptor(fds[0]),
111                     SafeDescriptor::from_raw_descriptor(fds[1]),
112                 )
113             },
114         ),
115     }
116 }
117 
118 /// A TCP socket.
119 ///
120 /// Do not use this class unless you need to change socket options or query the
121 /// state of the socket prior to calling listen or connect. Instead use either TcpStream or
122 /// TcpListener.
123 #[derive(Debug)]
124 pub struct TcpSocket {
125     pub(in crate::sys) inet_version: InetVersion,
126     pub(in crate::sys) descriptor: SafeDescriptor,
127 }
128 
129 impl TcpSocket {
bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()>130     pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
131         let sockaddr = addr
132             .to_socket_addrs()
133             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
134             .next()
135             .unwrap();
136 
137         let ret = match sockaddr {
138             SocketAddr::V4(a) => {
139                 let sin = sockaddrv4_to_lib_c(&a);
140                 // SAFETY:
141                 // Safe because this doesn't modify any memory and we check the return value.
142                 unsafe {
143                     libc::bind(
144                         self.as_raw_descriptor(),
145                         &sin as *const sockaddr_in as *const sockaddr,
146                         size_of::<sockaddr_in>() as socklen_t,
147                     )
148                 }
149             }
150             SocketAddr::V6(a) => {
151                 let sin6 = sockaddrv6_to_lib_c(&a);
152                 // SAFETY:
153                 // Safe because this doesn't modify any memory and we check the return value.
154                 unsafe {
155                     libc::bind(
156                         self.as_raw_descriptor(),
157                         &sin6 as *const sockaddr_in6 as *const sockaddr,
158                         size_of::<sockaddr_in6>() as socklen_t,
159                     )
160                 }
161             }
162         };
163         if ret < 0 {
164             let bind_err = io::Error::last_os_error();
165             Err(bind_err)
166         } else {
167             Ok(())
168         }
169     }
170 
connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream>171     pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
172         let sockaddr = addr
173             .to_socket_addrs()
174             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
175             .next()
176             .unwrap();
177 
178         let ret = match sockaddr {
179             SocketAddr::V4(a) => {
180                 let sin = sockaddrv4_to_lib_c(&a);
181                 // SAFETY:
182                 // Safe because this doesn't modify any memory and we check the return value.
183                 unsafe {
184                     libc::connect(
185                         self.as_raw_descriptor(),
186                         &sin as *const sockaddr_in as *const sockaddr,
187                         size_of::<sockaddr_in>() as socklen_t,
188                     )
189                 }
190             }
191             SocketAddr::V6(a) => {
192                 let sin6 = sockaddrv6_to_lib_c(&a);
193                 // SAFETY:
194                 // Safe because this doesn't modify any memory and we check the return value.
195                 unsafe {
196                     libc::connect(
197                         self.as_raw_descriptor(),
198                         &sin6 as *const sockaddr_in6 as *const sockaddr,
199                         size_of::<sockaddr_in>() as socklen_t,
200                     )
201                 }
202             }
203         };
204 
205         if ret < 0 {
206             let connect_err = io::Error::last_os_error();
207             Err(connect_err)
208         } else {
209             Ok(TcpStream::from(self.descriptor))
210         }
211     }
212 
listen(self) -> io::Result<TcpListener>213     pub fn listen(self) -> io::Result<TcpListener> {
214         // SAFETY:
215         // Safe because this doesn't modify any memory and we check the return value.
216         let ret = unsafe { libc::listen(self.as_raw_descriptor(), 1) };
217         if ret < 0 {
218             let listen_err = io::Error::last_os_error();
219             Err(listen_err)
220         } else {
221             Ok(TcpListener::from(self.descriptor))
222         }
223     }
224 
225     /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u16>226     pub fn local_port(&self) -> io::Result<u16> {
227         match self.inet_version {
228             InetVersion::V4 => {
229                 let mut sin = sockaddrv4_to_lib_c(&SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0));
230 
231                 let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
232                 // SAFETY:
233                 // Safe because we give a valid pointer for addrlen and check the length.
234                 let ret = unsafe {
235                     // Get the socket address that was actually bound.
236                     libc::getsockname(
237                         self.as_raw_descriptor(),
238                         &mut sin as *mut sockaddr_in as *mut sockaddr,
239                         &mut addrlen as *mut socklen_t,
240                     )
241                 };
242                 if ret < 0 {
243                     let getsockname_err = io::Error::last_os_error();
244                     Err(getsockname_err)
245                 } else {
246                     // If this doesn't match, it's not safe to get the port out of the sockaddr.
247                     assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
248 
249                     Ok(u16::from_be(sin.sin_port))
250                 }
251             }
252             InetVersion::V6 => {
253                 let mut sin6 = sockaddrv6_to_lib_c(&SocketAddrV6::new(
254                     Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
255                     0,
256                     0,
257                     0,
258                 ));
259 
260                 let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
261                 // SAFETY:
262                 // Safe because we give a valid pointer for addrlen and check the length.
263                 let ret = unsafe {
264                     // Get the socket address that was actually bound.
265                     libc::getsockname(
266                         self.as_raw_descriptor(),
267                         &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
268                         &mut addrlen as *mut socklen_t,
269                     )
270                 };
271                 if ret < 0 {
272                     let getsockname_err = io::Error::last_os_error();
273                     Err(getsockname_err)
274                 } else {
275                     // If this doesn't match, it's not safe to get the port out of the sockaddr.
276                     assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
277 
278                     Ok(u16::from_be(sin6.sin6_port))
279                 }
280             }
281         }
282     }
283 }
284 
285 impl AsRawDescriptor for TcpSocket {
as_raw_descriptor(&self) -> RawDescriptor286     fn as_raw_descriptor(&self) -> RawDescriptor {
287         self.descriptor.as_raw_descriptor()
288     }
289 }
290 
291 // Offset of sun_path in structure sockaddr_un.
sun_path_offset() -> usize292 pub(in crate::sys) fn sun_path_offset() -> usize {
293     // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
294     #[allow(clippy::zero_ptr)]
295     let addr = 0 as *const libc::sockaddr_un;
296     // SAFETY:
297     // Safe because we only use the dereference to create a pointer to the desired field in
298     // calculating the offset.
299     unsafe { &(*addr).sun_path as *const _ as usize }
300 }
301 
302 /// A Unix `SOCK_SEQPACKET` socket point to given `path`
303 #[derive(Debug, Serialize, Deserialize)]
304 pub struct UnixSeqpacket(SafeDescriptor);
305 
306 impl UnixSeqpacket {
307     /// Open a `SOCK_SEQPACKET` connection to socket named by `path`.
308     ///
309     /// # Arguments
310     /// * `path` - Path to `SOCK_SEQPACKET` socket
311     ///
312     /// # Returns
313     /// A `UnixSeqpacket` structure point to the socket
314     ///
315     /// # Errors
316     /// Return `io::Error` when error occurs.
connect<P: AsRef<Path>>(path: P) -> io::Result<Self>317     pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
318         let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
319         let (addr, len) = sockaddr_un(path.as_ref())?;
320         // SAFETY:
321         // Safe connect since we handle the error and use the right length generated from
322         // `sockaddr_un`.
323         unsafe {
324             let ret = libc::connect(
325                 descriptor.as_raw_descriptor(),
326                 &addr as *const _ as *const _,
327                 len,
328             );
329             if ret < 0 {
330                 return Err(io::Error::last_os_error());
331             }
332         }
333         Ok(UnixSeqpacket(descriptor))
334     }
335 
336     /// Clone the underlying FD.
try_clone(&self) -> io::Result<Self>337     pub fn try_clone(&self) -> io::Result<Self> {
338         Ok(Self(self.0.try_clone()?))
339     }
340 
341     /// Gets the number of bytes that can be read from this socket without blocking.
get_readable_bytes(&self) -> io::Result<usize>342     pub fn get_readable_bytes(&self) -> io::Result<usize> {
343         let mut byte_count = 0i32;
344         // SAFETY:
345         // Safe because self has valid raw descriptor and return value are checked.
346         let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONREAD, &mut byte_count) };
347         if ret < 0 {
348             Err(io::Error::last_os_error())
349         } else {
350             Ok(byte_count as usize)
351         }
352     }
353 
354     /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
355     /// respecting the blocking and timeout settings of the underlying socket.
next_packet_size(&self) -> io::Result<usize>356     pub fn next_packet_size(&self) -> io::Result<usize> {
357         #[cfg(not(debug_assertions))]
358         let buf = null_mut();
359         // Work around for qemu's syscall translation which will reject null pointers in recvfrom.
360         // This only matters for running the unit tests for a non-native architecture. See the
361         // upstream thread for the qemu fix:
362         // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
363         #[cfg(debug_assertions)]
364         let buf = &mut 0 as *mut _ as *mut _;
365 
366         // SAFETY:
367         // This form of recvfrom doesn't modify any data because all null pointers are used. We only
368         // use the return value and check for errors on an FD owned by this structure.
369         let ret = unsafe {
370             recvfrom(
371                 self.as_raw_descriptor(),
372                 buf,
373                 0,
374                 MSG_TRUNC | MSG_PEEK,
375                 null_mut(),
376                 null_mut(),
377             )
378         };
379         if ret < 0 {
380             Err(io::Error::last_os_error())
381         } else {
382             Ok(ret as usize)
383         }
384     }
385 
386     /// Write data from a given buffer to the socket fd
387     ///
388     /// # Arguments
389     /// * `buf` - A reference to the data buffer.
390     ///
391     /// # Returns
392     /// * `usize` - The size of bytes written to the buffer.
393     ///
394     /// # Errors
395     /// Returns error when `libc::write` failed.
send(&self, buf: &[u8]) -> io::Result<usize>396     pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
397         // SAFETY:
398         // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
399         unsafe {
400             let ret = libc::write(
401                 self.as_raw_descriptor(),
402                 buf.as_ptr() as *const _,
403                 buf.len(),
404             );
405             if ret < 0 {
406                 Err(io::Error::last_os_error())
407             } else {
408                 Ok(ret as usize)
409             }
410         }
411     }
412 
413     /// Read data from the socket fd to a given buffer
414     ///
415     /// # Arguments
416     /// * `buf` - A mut reference to the data buffer.
417     ///
418     /// # Returns
419     /// * `usize` - The size of bytes read to the buffer.
420     ///
421     /// # Errors
422     /// Returns error when `libc::read` failed.
recv(&self, buf: &mut [u8]) -> io::Result<usize>423     pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
424         // SAFETY:
425         // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
426         unsafe {
427             let ret = libc::read(
428                 self.as_raw_descriptor(),
429                 buf.as_mut_ptr() as *mut _,
430                 buf.len(),
431             );
432             if ret < 0 {
433                 Err(io::Error::last_os_error())
434             } else {
435                 Ok(ret as usize)
436             }
437         }
438     }
439 
440     /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
441     ///
442     /// # Arguments
443     /// * `buf` - A mut reference to a `Vec` to resize and read into.
444     ///
445     /// # Errors
446     /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()>447     pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
448         let packet_size = self.next_packet_size()?;
449         buf.resize(packet_size, 0);
450         let read_bytes = self.recv(buf)?;
451         buf.resize(read_bytes, 0);
452         Ok(())
453     }
454 
455     /// Read data from the socket fd to a new `Vec`.
456     ///
457     /// # Returns
458     /// * `vec` - A new `Vec` with the entire received packet.
459     ///
460     /// # Errors
461     /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_as_vec(&self) -> io::Result<Vec<u8>>462     pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
463         let mut buf = Vec::new();
464         self.recv_to_vec(&mut buf)?;
465         Ok(buf)
466     }
467 
468     #[allow(clippy::useless_conversion)]
set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()>469     fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
470         let timeval = match timeout {
471             Some(t) => {
472                 if t.as_secs() == 0 && t.subsec_micros() == 0 {
473                     return Err(io::Error::new(
474                         io::ErrorKind::InvalidInput,
475                         "zero timeout duration is invalid",
476                     ));
477                 }
478                 // subsec_micros fits in i32 because it is defined to be less than one million.
479                 let nsec = t.subsec_micros() as i32;
480                 libc::timeval {
481                     tv_sec: t.as_secs() as libc::time_t,
482                     tv_usec: libc::suseconds_t::from(nsec),
483                 }
484             }
485             None => libc::timeval {
486                 tv_sec: 0,
487                 tv_usec: 0,
488             },
489         };
490         // SAFETY:
491         // Safe because we own the fd, and the length of the pointer's data is the same as the
492         // passed in length parameter. The level argument is valid, the kind is assumed to be valid,
493         // and the return value is checked.
494         let ret = unsafe {
495             libc::setsockopt(
496                 self.as_raw_descriptor(),
497                 libc::SOL_SOCKET,
498                 kind,
499                 &timeval as *const libc::timeval as *const libc::c_void,
500                 mem::size_of::<libc::timeval>() as libc::socklen_t,
501             )
502         };
503         if ret < 0 {
504             Err(io::Error::last_os_error())
505         } else {
506             Ok(())
507         }
508     }
509 
510     /// Sets or removes the timeout for read/recv operations on this socket.
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>511     pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
512         self.set_timeout(timeout, libc::SO_RCVTIMEO)
513     }
514 
515     /// Sets or removes the timeout for write/send operations on this socket.
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>516     pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
517         self.set_timeout(timeout, libc::SO_SNDTIMEO)
518     }
519 
520     /// Sets the blocking mode for this socket.
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>521     pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
522         let mut nonblocking = nonblocking as libc::c_int;
523         // SAFETY:
524         // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
525         // and does not continue holding the file descriptor after the call.
526         let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
527         if ret < 0 {
528             Err(io::Error::last_os_error())
529         } else {
530             Ok(())
531         }
532     }
533 }
534 
535 impl From<UnixSeqpacket> for SafeDescriptor {
from(s: UnixSeqpacket) -> Self536     fn from(s: UnixSeqpacket) -> Self {
537         s.0
538     }
539 }
540 
541 impl From<SafeDescriptor> for UnixSeqpacket {
from(s: SafeDescriptor) -> Self542     fn from(s: SafeDescriptor) -> Self {
543         Self(s)
544     }
545 }
546 
547 impl FromRawDescriptor for UnixSeqpacket {
from_raw_descriptor(descriptor: RawDescriptor) -> Self548     unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self {
549         Self(SafeDescriptor::from_raw_descriptor(descriptor))
550     }
551 }
552 
553 impl AsRawDescriptor for UnixSeqpacket {
as_raw_descriptor(&self) -> RawDescriptor554     fn as_raw_descriptor(&self) -> RawDescriptor {
555         self.0.as_raw_descriptor()
556     }
557 }
558 
559 impl IntoRawDescriptor for UnixSeqpacket {
into_raw_descriptor(self) -> RawDescriptor560     fn into_raw_descriptor(self) -> RawDescriptor {
561         self.0.into_raw_descriptor()
562     }
563 }
564 
565 impl io::Read for UnixSeqpacket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>566     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
567         self.recv(buf)
568     }
569 }
570 
571 impl io::Write for UnixSeqpacket {
write(&mut self, buf: &[u8]) -> io::Result<usize>572     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
573         self.send(buf)
574     }
575 
flush(&mut self) -> io::Result<()>576     fn flush(&mut self) -> io::Result<()> {
577         Ok(())
578     }
579 }
580 
581 /// Like a `UnixListener` but for accepting `UnixSeqpacket` type sockets.
582 pub struct UnixSeqpacketListener {
583     descriptor: SafeDescriptor,
584     no_path: bool,
585 }
586 
587 impl UnixSeqpacketListener {
588     /// Creates a new `UnixSeqpacketListener` bound to the given path.
bind<P: AsRef<Path>>(path: P) -> io::Result<Self>589     pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
590         if path.as_ref().starts_with("/proc/self/fd/") {
591             let fd = path
592                 .as_ref()
593                 .file_name()
594                 .expect("Failed to get fd filename")
595                 .to_str()
596                 .expect("fd filename should be unicode")
597                 .parse::<i32>()
598                 .expect("fd should be an integer");
599             let mut result: c_int = 0;
600             let mut result_len = size_of::<c_int>() as libc::socklen_t;
601             // SAFETY: Safe because fd and other args are valid and the return value is checked.
602             let ret = unsafe {
603                 libc::getsockopt(
604                     fd,
605                     libc::SOL_SOCKET,
606                     libc::SO_ACCEPTCONN,
607                     &mut result as *mut _ as *mut libc::c_void,
608                     &mut result_len,
609                 )
610             };
611             if ret < 0 {
612                 return Err(io::Error::last_os_error());
613             }
614             if result != 1 {
615                 return Err(io::Error::new(
616                     io::ErrorKind::InvalidInput,
617                     "specified descriptor is not a listening socket",
618                 ));
619             }
620             // SAFETY:
621             // Safe because we validated the socket file descriptor.
622             let descriptor = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
623             return Ok(UnixSeqpacketListener {
624                 descriptor,
625                 no_path: true,
626             });
627         }
628 
629         let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
630         let (addr, len) = sockaddr_un(path.as_ref())?;
631 
632         // SAFETY:
633         // Safe connect since we handle the error and use the right length generated from
634         // `sockaddr_un`.
635         unsafe {
636             let ret = handle_eintr_errno!(libc::bind(
637                 descriptor.as_raw_descriptor(),
638                 &addr as *const _ as *const _,
639                 len
640             ));
641             if ret < 0 {
642                 return Err(io::Error::last_os_error());
643             }
644             let ret = handle_eintr_errno!(libc::listen(descriptor.as_raw_descriptor(), 128));
645             if ret < 0 {
646                 return Err(io::Error::last_os_error());
647             }
648         }
649         Ok(UnixSeqpacketListener {
650             descriptor,
651             no_path: false,
652         })
653     }
654 
accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket>655     pub fn accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket> {
656         let start = Instant::now();
657 
658         loop {
659             let mut fds = libc::pollfd {
660                 fd: self.as_raw_descriptor(),
661                 events: libc::POLLIN,
662                 revents: 0,
663             };
664             let elapsed = Instant::now().saturating_duration_since(start);
665             let remaining = timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO);
666             let cur_timeout_ms = i32::try_from(remaining.as_millis()).unwrap_or(i32::MAX);
667             // SAFETY:
668             // Safe because we give a valid pointer to a list (of 1) FD and we check
669             // the return value.
670             match unsafe { libc::poll(&mut fds, 1, cur_timeout_ms) }.cmp(&0) {
671                 Ordering::Greater => return self.accept(),
672                 Ordering::Equal => return Err(io::Error::from_raw_os_error(libc::ETIMEDOUT)),
673                 Ordering::Less => {
674                     if Error::last() != Error::new(libc::EINTR) {
675                         return Err(io::Error::last_os_error());
676                     }
677                 }
678             }
679         }
680     }
681 
682     /// Gets the path that this listener is bound to.
path(&self) -> io::Result<PathBuf>683     pub fn path(&self) -> io::Result<PathBuf> {
684         let mut addr = sockaddr_un(Path::new(""))?.0;
685         if self.no_path {
686             return Err(io::Error::new(
687                 io::ErrorKind::InvalidInput,
688                 "socket has no path",
689             ));
690         }
691         let sun_path_offset = (&addr.sun_path as *const _ as usize
692             - &addr.sun_family as *const _ as usize)
693             as libc::socklen_t;
694         let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
695         // SAFETY:
696         // Safe because the length given matches the length of the data of the given pointer, and we
697         // check the return value.
698         let ret = unsafe {
699             handle_eintr_errno!(libc::getsockname(
700                 self.as_raw_descriptor(),
701                 &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
702                 &mut len
703             ))
704         };
705         if ret < 0 {
706             return Err(io::Error::last_os_error());
707         }
708         if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
709             || addr.sun_path[0] == 0
710             || len < 1 + sun_path_offset
711         {
712             return Err(io::Error::new(
713                 io::ErrorKind::InvalidInput,
714                 "getsockname on socket returned invalid value",
715             ));
716         }
717 
718         let path_os_str = OsString::from_vec(
719             addr.sun_path[..(len - sun_path_offset - 1) as usize]
720                 .iter()
721                 .map(|&c| c as _)
722                 .collect(),
723         );
724         Ok(path_os_str.into())
725     }
726 
727     /// Sets the blocking mode for this socket.
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>728     pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
729         let mut nonblocking = nonblocking as libc::c_int;
730         // SAFETY:
731         // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
732         // and does not continue holding the file descriptor after the call.
733         let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
734         if ret < 0 {
735             Err(io::Error::last_os_error())
736         } else {
737             Ok(())
738         }
739     }
740 }
741 
742 impl AsRawDescriptor for UnixSeqpacketListener {
as_raw_descriptor(&self) -> RawDescriptor743     fn as_raw_descriptor(&self) -> RawDescriptor {
744         self.descriptor.as_raw_descriptor()
745     }
746 }
747 
748 impl From<UnixSeqpacketListener> for OwnedFd {
from(val: UnixSeqpacketListener) -> Self749     fn from(val: UnixSeqpacketListener) -> Self {
750         val.descriptor.into()
751     }
752 }
753 
754 /// Used to attempt to clean up a `UnixSeqpacketListener` after it is dropped.
755 pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
756 
757 impl AsRawDescriptor for UnlinkUnixSeqpacketListener {
as_raw_descriptor(&self) -> RawDescriptor758     fn as_raw_descriptor(&self) -> RawDescriptor {
759         self.0.as_raw_descriptor()
760     }
761 }
762 
763 impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
as_ref(&self) -> &UnixSeqpacketListener764     fn as_ref(&self) -> &UnixSeqpacketListener {
765         &self.0
766     }
767 }
768 
769 impl Deref for UnlinkUnixSeqpacketListener {
770     type Target = UnixSeqpacketListener;
deref(&self) -> &Self::Target771     fn deref(&self) -> &Self::Target {
772         &self.0
773     }
774 }
775 
776 impl Drop for UnlinkUnixSeqpacketListener {
drop(&mut self)777     fn drop(&mut self) {
778         if let Ok(path) = self.0.path() {
779             if let Err(e) = remove_file(path) {
780                 warn!("failed to remove control socket file: {:?}", e);
781             }
782         }
783     }
784 }
785 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sockaddr_un_zero_length_input() {
        sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_long_input_err() {
        // 108 bytes plus the NUL terminator no longer fits in `sun_path`.
        let too_long = "a".repeat(108);
        assert!(sockaddr_un(Path::new(&too_long)).is_err());
    }

    #[test]
    fn sockaddr_un_long_input_pass() {
        let longest = "a".repeat(107);
        sockaddr_un(Path::new(&longest)).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_len_check() {
        let (_, len) = sockaddr_un(Path::new(&"a".repeat(50))).expect("sockaddr_un failed");
        // Header bytes + 50 path bytes + the NUL terminator.
        assert_eq!(len, (sun_path_offset() + 50 + 1) as u32);
    }

    #[test]
    #[allow(clippy::unnecessary_cast)]
    // c_char is u8 on aarch64 and i8 on x86, so clippy's suggested fix of changing
    // `'a' as libc::c_char` below to `b'a'` won't work everywhere.
    #[allow(clippy::char_lit_as_u8)]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let (addr, len) =
            sockaddr_un(Path::new(&"a".repeat(path_size))).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // Expect `path_size` copies of 'a' followed by zeros, over a 108-byte reference.
        let expected: Vec<libc::c_char> = (0..108)
            .map(|i| if i < path_size { 'a' as libc::c_char } else { 0 })
            .collect();
        for (actual, wanted) in addr.sun_path.iter().zip(expected.iter()) {
            assert_eq!(actual, wanted);
        }
    }
}
835