1 // Copyright 2018 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::cmp::Ordering;
6 use std::convert::TryFrom;
7 use std::ffi::OsString;
8 use std::fs::remove_file;
9 use std::io;
10 use std::mem;
11 use std::mem::size_of;
12 use std::net::Ipv4Addr;
13 use std::net::Ipv6Addr;
14 use std::net::SocketAddr;
15 use std::net::SocketAddrV4;
16 use std::net::SocketAddrV6;
17 use std::net::TcpListener;
18 use std::net::TcpStream;
19 use std::net::ToSocketAddrs;
20 use std::ops::Deref;
21 use std::os::fd::OwnedFd;
22 use std::os::unix::ffi::OsStringExt;
23 use std::path::Path;
24 use std::path::PathBuf;
25 use std::ptr::null_mut;
26 use std::time::Duration;
27 use std::time::Instant;
28
29 use libc::c_int;
30 use libc::recvfrom;
31 use libc::sa_family_t;
32 use libc::sockaddr;
33 use libc::sockaddr_in;
34 use libc::sockaddr_in6;
35 use libc::socklen_t;
36 use libc::AF_INET;
37 use libc::AF_INET6;
38 use libc::MSG_PEEK;
39 use libc::MSG_TRUNC;
40 use log::warn;
41 use serde::Deserialize;
42 use serde::Serialize;
43
44 use crate::descriptor::AsRawDescriptor;
45 use crate::descriptor::FromRawDescriptor;
46 use crate::descriptor::IntoRawDescriptor;
47 use crate::handle_eintr_errno;
48 use crate::sys::sockaddr_un;
49 use crate::sys::sockaddrv4_to_lib_c;
50 use crate::sys::sockaddrv6_to_lib_c;
51 use crate::Error;
52 use crate::RawDescriptor;
53 use crate::SafeDescriptor;
54
/// Distinguishes between the two Internet Protocol families, IPv4 and IPv6.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
    /// Internet Protocol version 4.
    V4,
    /// Internet Protocol version 6.
    V6,
}

impl InetVersion {
    /// Returns the IP family that the socket address `s` belongs to.
    pub fn from_sockaddr(s: &SocketAddr) -> Self {
        if s.is_ipv4() {
            InetVersion::V4
        } else {
            InetVersion::V6
        }
    }
}
70
71 impl From<InetVersion> for sa_family_t {
from(v: InetVersion) -> sa_family_t72 fn from(v: InetVersion) -> sa_family_t {
73 match v {
74 InetVersion::V4 => AF_INET as sa_family_t,
75 InetVersion::V6 => AF_INET6 as sa_family_t,
76 }
77 }
78 }
79
socket( domain: c_int, sock_type: c_int, protocol: c_int, ) -> io::Result<SafeDescriptor>80 pub(in crate::sys) fn socket(
81 domain: c_int,
82 sock_type: c_int,
83 protocol: c_int,
84 ) -> io::Result<SafeDescriptor> {
85 // SAFETY:
86 // Safe socket initialization since we handle the returned error.
87 match unsafe { libc::socket(domain, sock_type, protocol) } {
88 -1 => Err(io::Error::last_os_error()),
89 // SAFETY:
90 // Safe because we own the file descriptor.
91 fd => Ok(unsafe { SafeDescriptor::from_raw_descriptor(fd) }),
92 }
93 }
94
socketpair( domain: c_int, sock_type: c_int, protocol: c_int, ) -> io::Result<(SafeDescriptor, SafeDescriptor)>95 pub(in crate::sys) fn socketpair(
96 domain: c_int,
97 sock_type: c_int,
98 protocol: c_int,
99 ) -> io::Result<(SafeDescriptor, SafeDescriptor)> {
100 let mut fds = [0, 0];
101 // SAFETY:
102 // Safe because we give enough space to store all the fds and we check the return value.
103 match unsafe { libc::socketpair(domain, sock_type, protocol, fds.as_mut_ptr()) } {
104 -1 => Err(io::Error::last_os_error()),
105 _ => Ok(
106 // SAFETY:
107 // Safe because we own the file descriptors.
108 unsafe {
109 (
110 SafeDescriptor::from_raw_descriptor(fds[0]),
111 SafeDescriptor::from_raw_descriptor(fds[1]),
112 )
113 },
114 ),
115 }
116 }
117
118 /// A TCP socket.
119 ///
120 /// Do not use this class unless you need to change socket options or query the
121 /// state of the socket prior to calling listen or connect. Instead use either TcpStream or
122 /// TcpListener.
123 #[derive(Debug)]
124 pub struct TcpSocket {
125 pub(in crate::sys) inet_version: InetVersion,
126 pub(in crate::sys) descriptor: SafeDescriptor,
127 }
128
129 impl TcpSocket {
bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()>130 pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
131 let sockaddr = addr
132 .to_socket_addrs()
133 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
134 .next()
135 .unwrap();
136
137 let ret = match sockaddr {
138 SocketAddr::V4(a) => {
139 let sin = sockaddrv4_to_lib_c(&a);
140 // SAFETY:
141 // Safe because this doesn't modify any memory and we check the return value.
142 unsafe {
143 libc::bind(
144 self.as_raw_descriptor(),
145 &sin as *const sockaddr_in as *const sockaddr,
146 size_of::<sockaddr_in>() as socklen_t,
147 )
148 }
149 }
150 SocketAddr::V6(a) => {
151 let sin6 = sockaddrv6_to_lib_c(&a);
152 // SAFETY:
153 // Safe because this doesn't modify any memory and we check the return value.
154 unsafe {
155 libc::bind(
156 self.as_raw_descriptor(),
157 &sin6 as *const sockaddr_in6 as *const sockaddr,
158 size_of::<sockaddr_in6>() as socklen_t,
159 )
160 }
161 }
162 };
163 if ret < 0 {
164 let bind_err = io::Error::last_os_error();
165 Err(bind_err)
166 } else {
167 Ok(())
168 }
169 }
170
connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream>171 pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
172 let sockaddr = addr
173 .to_socket_addrs()
174 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
175 .next()
176 .unwrap();
177
178 let ret = match sockaddr {
179 SocketAddr::V4(a) => {
180 let sin = sockaddrv4_to_lib_c(&a);
181 // SAFETY:
182 // Safe because this doesn't modify any memory and we check the return value.
183 unsafe {
184 libc::connect(
185 self.as_raw_descriptor(),
186 &sin as *const sockaddr_in as *const sockaddr,
187 size_of::<sockaddr_in>() as socklen_t,
188 )
189 }
190 }
191 SocketAddr::V6(a) => {
192 let sin6 = sockaddrv6_to_lib_c(&a);
193 // SAFETY:
194 // Safe because this doesn't modify any memory and we check the return value.
195 unsafe {
196 libc::connect(
197 self.as_raw_descriptor(),
198 &sin6 as *const sockaddr_in6 as *const sockaddr,
199 size_of::<sockaddr_in>() as socklen_t,
200 )
201 }
202 }
203 };
204
205 if ret < 0 {
206 let connect_err = io::Error::last_os_error();
207 Err(connect_err)
208 } else {
209 Ok(TcpStream::from(self.descriptor))
210 }
211 }
212
listen(self) -> io::Result<TcpListener>213 pub fn listen(self) -> io::Result<TcpListener> {
214 // SAFETY:
215 // Safe because this doesn't modify any memory and we check the return value.
216 let ret = unsafe { libc::listen(self.as_raw_descriptor(), 1) };
217 if ret < 0 {
218 let listen_err = io::Error::last_os_error();
219 Err(listen_err)
220 } else {
221 Ok(TcpListener::from(self.descriptor))
222 }
223 }
224
225 /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u16>226 pub fn local_port(&self) -> io::Result<u16> {
227 match self.inet_version {
228 InetVersion::V4 => {
229 let mut sin = sockaddrv4_to_lib_c(&SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0));
230
231 let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
232 // SAFETY:
233 // Safe because we give a valid pointer for addrlen and check the length.
234 let ret = unsafe {
235 // Get the socket address that was actually bound.
236 libc::getsockname(
237 self.as_raw_descriptor(),
238 &mut sin as *mut sockaddr_in as *mut sockaddr,
239 &mut addrlen as *mut socklen_t,
240 )
241 };
242 if ret < 0 {
243 let getsockname_err = io::Error::last_os_error();
244 Err(getsockname_err)
245 } else {
246 // If this doesn't match, it's not safe to get the port out of the sockaddr.
247 assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
248
249 Ok(u16::from_be(sin.sin_port))
250 }
251 }
252 InetVersion::V6 => {
253 let mut sin6 = sockaddrv6_to_lib_c(&SocketAddrV6::new(
254 Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
255 0,
256 0,
257 0,
258 ));
259
260 let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
261 // SAFETY:
262 // Safe because we give a valid pointer for addrlen and check the length.
263 let ret = unsafe {
264 // Get the socket address that was actually bound.
265 libc::getsockname(
266 self.as_raw_descriptor(),
267 &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
268 &mut addrlen as *mut socklen_t,
269 )
270 };
271 if ret < 0 {
272 let getsockname_err = io::Error::last_os_error();
273 Err(getsockname_err)
274 } else {
275 // If this doesn't match, it's not safe to get the port out of the sockaddr.
276 assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
277
278 Ok(u16::from_be(sin6.sin6_port))
279 }
280 }
281 }
282 }
283 }
284
285 impl AsRawDescriptor for TcpSocket {
as_raw_descriptor(&self) -> RawDescriptor286 fn as_raw_descriptor(&self) -> RawDescriptor {
287 self.descriptor.as_raw_descriptor()
288 }
289 }
290
291 // Offset of sun_path in structure sockaddr_un.
sun_path_offset() -> usize292 pub(in crate::sys) fn sun_path_offset() -> usize {
293 // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
294 #[allow(clippy::zero_ptr)]
295 let addr = 0 as *const libc::sockaddr_un;
296 // SAFETY:
297 // Safe because we only use the dereference to create a pointer to the desired field in
298 // calculating the offset.
299 unsafe { &(*addr).sun_path as *const _ as usize }
300 }
301
302 /// A Unix `SOCK_SEQPACKET` socket point to given `path`
303 #[derive(Debug, Serialize, Deserialize)]
304 pub struct UnixSeqpacket(SafeDescriptor);
305
306 impl UnixSeqpacket {
307 /// Open a `SOCK_SEQPACKET` connection to socket named by `path`.
308 ///
309 /// # Arguments
310 /// * `path` - Path to `SOCK_SEQPACKET` socket
311 ///
312 /// # Returns
313 /// A `UnixSeqpacket` structure point to the socket
314 ///
315 /// # Errors
316 /// Return `io::Error` when error occurs.
connect<P: AsRef<Path>>(path: P) -> io::Result<Self>317 pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
318 let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
319 let (addr, len) = sockaddr_un(path.as_ref())?;
320 // SAFETY:
321 // Safe connect since we handle the error and use the right length generated from
322 // `sockaddr_un`.
323 unsafe {
324 let ret = libc::connect(
325 descriptor.as_raw_descriptor(),
326 &addr as *const _ as *const _,
327 len,
328 );
329 if ret < 0 {
330 return Err(io::Error::last_os_error());
331 }
332 }
333 Ok(UnixSeqpacket(descriptor))
334 }
335
336 /// Clone the underlying FD.
try_clone(&self) -> io::Result<Self>337 pub fn try_clone(&self) -> io::Result<Self> {
338 Ok(Self(self.0.try_clone()?))
339 }
340
341 /// Gets the number of bytes that can be read from this socket without blocking.
get_readable_bytes(&self) -> io::Result<usize>342 pub fn get_readable_bytes(&self) -> io::Result<usize> {
343 let mut byte_count = 0i32;
344 // SAFETY:
345 // Safe because self has valid raw descriptor and return value are checked.
346 let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONREAD, &mut byte_count) };
347 if ret < 0 {
348 Err(io::Error::last_os_error())
349 } else {
350 Ok(byte_count as usize)
351 }
352 }
353
354 /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
355 /// respecting the blocking and timeout settings of the underlying socket.
next_packet_size(&self) -> io::Result<usize>356 pub fn next_packet_size(&self) -> io::Result<usize> {
357 #[cfg(not(debug_assertions))]
358 let buf = null_mut();
359 // Work around for qemu's syscall translation which will reject null pointers in recvfrom.
360 // This only matters for running the unit tests for a non-native architecture. See the
361 // upstream thread for the qemu fix:
362 // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
363 #[cfg(debug_assertions)]
364 let buf = &mut 0 as *mut _ as *mut _;
365
366 // SAFETY:
367 // This form of recvfrom doesn't modify any data because all null pointers are used. We only
368 // use the return value and check for errors on an FD owned by this structure.
369 let ret = unsafe {
370 recvfrom(
371 self.as_raw_descriptor(),
372 buf,
373 0,
374 MSG_TRUNC | MSG_PEEK,
375 null_mut(),
376 null_mut(),
377 )
378 };
379 if ret < 0 {
380 Err(io::Error::last_os_error())
381 } else {
382 Ok(ret as usize)
383 }
384 }
385
386 /// Write data from a given buffer to the socket fd
387 ///
388 /// # Arguments
389 /// * `buf` - A reference to the data buffer.
390 ///
391 /// # Returns
392 /// * `usize` - The size of bytes written to the buffer.
393 ///
394 /// # Errors
395 /// Returns error when `libc::write` failed.
send(&self, buf: &[u8]) -> io::Result<usize>396 pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
397 // SAFETY:
398 // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
399 unsafe {
400 let ret = libc::write(
401 self.as_raw_descriptor(),
402 buf.as_ptr() as *const _,
403 buf.len(),
404 );
405 if ret < 0 {
406 Err(io::Error::last_os_error())
407 } else {
408 Ok(ret as usize)
409 }
410 }
411 }
412
413 /// Read data from the socket fd to a given buffer
414 ///
415 /// # Arguments
416 /// * `buf` - A mut reference to the data buffer.
417 ///
418 /// # Returns
419 /// * `usize` - The size of bytes read to the buffer.
420 ///
421 /// # Errors
422 /// Returns error when `libc::read` failed.
recv(&self, buf: &mut [u8]) -> io::Result<usize>423 pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
424 // SAFETY:
425 // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
426 unsafe {
427 let ret = libc::read(
428 self.as_raw_descriptor(),
429 buf.as_mut_ptr() as *mut _,
430 buf.len(),
431 );
432 if ret < 0 {
433 Err(io::Error::last_os_error())
434 } else {
435 Ok(ret as usize)
436 }
437 }
438 }
439
440 /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
441 ///
442 /// # Arguments
443 /// * `buf` - A mut reference to a `Vec` to resize and read into.
444 ///
445 /// # Errors
446 /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()>447 pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
448 let packet_size = self.next_packet_size()?;
449 buf.resize(packet_size, 0);
450 let read_bytes = self.recv(buf)?;
451 buf.resize(read_bytes, 0);
452 Ok(())
453 }
454
455 /// Read data from the socket fd to a new `Vec`.
456 ///
457 /// # Returns
458 /// * `vec` - A new `Vec` with the entire received packet.
459 ///
460 /// # Errors
461 /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_as_vec(&self) -> io::Result<Vec<u8>>462 pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
463 let mut buf = Vec::new();
464 self.recv_to_vec(&mut buf)?;
465 Ok(buf)
466 }
467
468 #[allow(clippy::useless_conversion)]
set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()>469 fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
470 let timeval = match timeout {
471 Some(t) => {
472 if t.as_secs() == 0 && t.subsec_micros() == 0 {
473 return Err(io::Error::new(
474 io::ErrorKind::InvalidInput,
475 "zero timeout duration is invalid",
476 ));
477 }
478 // subsec_micros fits in i32 because it is defined to be less than one million.
479 let nsec = t.subsec_micros() as i32;
480 libc::timeval {
481 tv_sec: t.as_secs() as libc::time_t,
482 tv_usec: libc::suseconds_t::from(nsec),
483 }
484 }
485 None => libc::timeval {
486 tv_sec: 0,
487 tv_usec: 0,
488 },
489 };
490 // SAFETY:
491 // Safe because we own the fd, and the length of the pointer's data is the same as the
492 // passed in length parameter. The level argument is valid, the kind is assumed to be valid,
493 // and the return value is checked.
494 let ret = unsafe {
495 libc::setsockopt(
496 self.as_raw_descriptor(),
497 libc::SOL_SOCKET,
498 kind,
499 &timeval as *const libc::timeval as *const libc::c_void,
500 mem::size_of::<libc::timeval>() as libc::socklen_t,
501 )
502 };
503 if ret < 0 {
504 Err(io::Error::last_os_error())
505 } else {
506 Ok(())
507 }
508 }
509
510 /// Sets or removes the timeout for read/recv operations on this socket.
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>511 pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
512 self.set_timeout(timeout, libc::SO_RCVTIMEO)
513 }
514
515 /// Sets or removes the timeout for write/send operations on this socket.
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>516 pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
517 self.set_timeout(timeout, libc::SO_SNDTIMEO)
518 }
519
520 /// Sets the blocking mode for this socket.
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>521 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
522 let mut nonblocking = nonblocking as libc::c_int;
523 // SAFETY:
524 // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
525 // and does not continue holding the file descriptor after the call.
526 let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
527 if ret < 0 {
528 Err(io::Error::last_os_error())
529 } else {
530 Ok(())
531 }
532 }
533 }
534
535 impl From<UnixSeqpacket> for SafeDescriptor {
from(s: UnixSeqpacket) -> Self536 fn from(s: UnixSeqpacket) -> Self {
537 s.0
538 }
539 }
540
541 impl From<SafeDescriptor> for UnixSeqpacket {
from(s: SafeDescriptor) -> Self542 fn from(s: SafeDescriptor) -> Self {
543 Self(s)
544 }
545 }
546
547 impl FromRawDescriptor for UnixSeqpacket {
from_raw_descriptor(descriptor: RawDescriptor) -> Self548 unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self {
549 Self(SafeDescriptor::from_raw_descriptor(descriptor))
550 }
551 }
552
553 impl AsRawDescriptor for UnixSeqpacket {
as_raw_descriptor(&self) -> RawDescriptor554 fn as_raw_descriptor(&self) -> RawDescriptor {
555 self.0.as_raw_descriptor()
556 }
557 }
558
559 impl IntoRawDescriptor for UnixSeqpacket {
into_raw_descriptor(self) -> RawDescriptor560 fn into_raw_descriptor(self) -> RawDescriptor {
561 self.0.into_raw_descriptor()
562 }
563 }
564
565 impl io::Read for UnixSeqpacket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>566 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
567 self.recv(buf)
568 }
569 }
570
571 impl io::Write for UnixSeqpacket {
write(&mut self, buf: &[u8]) -> io::Result<usize>572 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
573 self.send(buf)
574 }
575
flush(&mut self) -> io::Result<()>576 fn flush(&mut self) -> io::Result<()> {
577 Ok(())
578 }
579 }
580
581 /// Like a `UnixListener` but for accepting `UnixSeqpacket` type sockets.
582 pub struct UnixSeqpacketListener {
583 descriptor: SafeDescriptor,
584 no_path: bool,
585 }
586
587 impl UnixSeqpacketListener {
588 /// Creates a new `UnixSeqpacketListener` bound to the given path.
bind<P: AsRef<Path>>(path: P) -> io::Result<Self>589 pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
590 if path.as_ref().starts_with("/proc/self/fd/") {
591 let fd = path
592 .as_ref()
593 .file_name()
594 .expect("Failed to get fd filename")
595 .to_str()
596 .expect("fd filename should be unicode")
597 .parse::<i32>()
598 .expect("fd should be an integer");
599 let mut result: c_int = 0;
600 let mut result_len = size_of::<c_int>() as libc::socklen_t;
601 // SAFETY: Safe because fd and other args are valid and the return value is checked.
602 let ret = unsafe {
603 libc::getsockopt(
604 fd,
605 libc::SOL_SOCKET,
606 libc::SO_ACCEPTCONN,
607 &mut result as *mut _ as *mut libc::c_void,
608 &mut result_len,
609 )
610 };
611 if ret < 0 {
612 return Err(io::Error::last_os_error());
613 }
614 if result != 1 {
615 return Err(io::Error::new(
616 io::ErrorKind::InvalidInput,
617 "specified descriptor is not a listening socket",
618 ));
619 }
620 // SAFETY:
621 // Safe because we validated the socket file descriptor.
622 let descriptor = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
623 return Ok(UnixSeqpacketListener {
624 descriptor,
625 no_path: true,
626 });
627 }
628
629 let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
630 let (addr, len) = sockaddr_un(path.as_ref())?;
631
632 // SAFETY:
633 // Safe connect since we handle the error and use the right length generated from
634 // `sockaddr_un`.
635 unsafe {
636 let ret = handle_eintr_errno!(libc::bind(
637 descriptor.as_raw_descriptor(),
638 &addr as *const _ as *const _,
639 len
640 ));
641 if ret < 0 {
642 return Err(io::Error::last_os_error());
643 }
644 let ret = handle_eintr_errno!(libc::listen(descriptor.as_raw_descriptor(), 128));
645 if ret < 0 {
646 return Err(io::Error::last_os_error());
647 }
648 }
649 Ok(UnixSeqpacketListener {
650 descriptor,
651 no_path: false,
652 })
653 }
654
accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket>655 pub fn accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket> {
656 let start = Instant::now();
657
658 loop {
659 let mut fds = libc::pollfd {
660 fd: self.as_raw_descriptor(),
661 events: libc::POLLIN,
662 revents: 0,
663 };
664 let elapsed = Instant::now().saturating_duration_since(start);
665 let remaining = timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO);
666 let cur_timeout_ms = i32::try_from(remaining.as_millis()).unwrap_or(i32::MAX);
667 // SAFETY:
668 // Safe because we give a valid pointer to a list (of 1) FD and we check
669 // the return value.
670 match unsafe { libc::poll(&mut fds, 1, cur_timeout_ms) }.cmp(&0) {
671 Ordering::Greater => return self.accept(),
672 Ordering::Equal => return Err(io::Error::from_raw_os_error(libc::ETIMEDOUT)),
673 Ordering::Less => {
674 if Error::last() != Error::new(libc::EINTR) {
675 return Err(io::Error::last_os_error());
676 }
677 }
678 }
679 }
680 }
681
682 /// Gets the path that this listener is bound to.
path(&self) -> io::Result<PathBuf>683 pub fn path(&self) -> io::Result<PathBuf> {
684 let mut addr = sockaddr_un(Path::new(""))?.0;
685 if self.no_path {
686 return Err(io::Error::new(
687 io::ErrorKind::InvalidInput,
688 "socket has no path",
689 ));
690 }
691 let sun_path_offset = (&addr.sun_path as *const _ as usize
692 - &addr.sun_family as *const _ as usize)
693 as libc::socklen_t;
694 let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
695 // SAFETY:
696 // Safe because the length given matches the length of the data of the given pointer, and we
697 // check the return value.
698 let ret = unsafe {
699 handle_eintr_errno!(libc::getsockname(
700 self.as_raw_descriptor(),
701 &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
702 &mut len
703 ))
704 };
705 if ret < 0 {
706 return Err(io::Error::last_os_error());
707 }
708 if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
709 || addr.sun_path[0] == 0
710 || len < 1 + sun_path_offset
711 {
712 return Err(io::Error::new(
713 io::ErrorKind::InvalidInput,
714 "getsockname on socket returned invalid value",
715 ));
716 }
717
718 let path_os_str = OsString::from_vec(
719 addr.sun_path[..(len - sun_path_offset - 1) as usize]
720 .iter()
721 .map(|&c| c as _)
722 .collect(),
723 );
724 Ok(path_os_str.into())
725 }
726
727 /// Sets the blocking mode for this socket.
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>728 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
729 let mut nonblocking = nonblocking as libc::c_int;
730 // SAFETY:
731 // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
732 // and does not continue holding the file descriptor after the call.
733 let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
734 if ret < 0 {
735 Err(io::Error::last_os_error())
736 } else {
737 Ok(())
738 }
739 }
740 }
741
742 impl AsRawDescriptor for UnixSeqpacketListener {
as_raw_descriptor(&self) -> RawDescriptor743 fn as_raw_descriptor(&self) -> RawDescriptor {
744 self.descriptor.as_raw_descriptor()
745 }
746 }
747
748 impl From<UnixSeqpacketListener> for OwnedFd {
from(val: UnixSeqpacketListener) -> Self749 fn from(val: UnixSeqpacketListener) -> Self {
750 val.descriptor.into()
751 }
752 }
753
754 /// Used to attempt to clean up a `UnixSeqpacketListener` after it is dropped.
755 pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
756
757 impl AsRawDescriptor for UnlinkUnixSeqpacketListener {
as_raw_descriptor(&self) -> RawDescriptor758 fn as_raw_descriptor(&self) -> RawDescriptor {
759 self.0.as_raw_descriptor()
760 }
761 }
762
763 impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
as_ref(&self) -> &UnixSeqpacketListener764 fn as_ref(&self) -> &UnixSeqpacketListener {
765 &self.0
766 }
767 }
768
769 impl Deref for UnlinkUnixSeqpacketListener {
770 type Target = UnixSeqpacketListener;
deref(&self) -> &Self::Target771 fn deref(&self) -> &Self::Target {
772 &self.0
773 }
774 }
775
776 impl Drop for UnlinkUnixSeqpacketListener {
drop(&mut self)777 fn drop(&mut self) {
778 if let Ok(path) = self.0.path() {
779 if let Err(e) = remove_file(path) {
780 warn!("failed to remove control socket file: {:?}", e);
781 }
782 }
783 }
784 }
785
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sockaddr_un_zero_length_input() {
        sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_long_input_err() {
        let too_long = "a".repeat(108);
        assert!(sockaddr_un(Path::new(&too_long)).is_err());
    }

    #[test]
    fn sockaddr_un_long_input_pass() {
        let longest_allowed = "a".repeat(107);
        sockaddr_un(Path::new(&longest_allowed)).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_len_check() {
        let name = "a".repeat(50);
        let (_, len) = sockaddr_un(Path::new(&name)).expect("sockaddr_un failed");
        // Header, then the path bytes, then the NUL terminator.
        let expected = sun_path_offset() + name.len() + 1;
        assert_eq!(len, expected as u32);
    }

    #[test]
    #[allow(clippy::unnecessary_cast)]
    // c_char is u8 on aarch64 and i8 on x86, so clippy's suggested fix of changing
    // `'a' as libc::c_char` below to `b'a'` won't work everywhere.
    #[allow(clippy::char_lit_as_u8)]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let name = "a".repeat(path_size);
        let (addr, len) = sockaddr_un(Path::new(&name)).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // `sun_path` must hold `path_size` copies of 'a' followed by zeros.
        let expected: Vec<libc::c_char> = (0..108)
            .map(|i| if i < path_size { 'a' as libc::c_char } else { 0 })
            .collect();
        for (actual, wanted) in addr.sun_path.iter().zip(expected.iter()) {
            assert_eq!(actual, wanted);
        }
    }
}
835