1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 //! Structs for Unix Domain Socket listener and endpoint.
5 
6 #![allow(dead_code)]
7 
8 use std::fs::File;
9 use std::io::ErrorKind;
10 use std::marker::PhantomData;
11 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
12 use std::os::unix::net::{UnixListener, UnixStream};
13 use std::path::{Path, PathBuf};
14 use std::{mem, slice};
15 
16 use libc::{c_void, iovec};
17 use vm_memory::ByteValued;
18 use vmm_sys_util::sock_ctrl_msg::ScmSocket;
19 
20 use super::message::*;
21 use super::{Error, Result};
22 
23 /// Unix domain socket listener for accepting incoming connections.
24 pub struct Listener {
25     fd: UnixListener,
26     path: Option<PathBuf>,
27 }
28 
29 impl Listener {
30     /// Create a unix domain socket listener.
31     ///
32     /// # Return:
33     /// * - the new Listener object on success.
34     /// * - SocketError: failed to create listener socket.
new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self>35     pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
36         if unlink {
37             let _ = std::fs::remove_file(&path);
38         }
39         let fd = UnixListener::bind(&path).map_err(Error::SocketError)?;
40         Ok(Listener {
41             fd,
42             path: Some(path.as_ref().to_owned()),
43         })
44     }
45 
46     /// Accept an incoming connection.
47     ///
48     /// # Return:
49     /// * - Some(UnixStream): new UnixStream object if new incoming connection is available.
50     /// * - None: no incoming connection available.
51     /// * - SocketError: errors from accept().
accept(&self) -> Result<Option<UnixStream>>52     pub fn accept(&self) -> Result<Option<UnixStream>> {
53         loop {
54             match self.fd.accept() {
55                 Ok((socket, _addr)) => return Ok(Some(socket)),
56                 Err(e) => {
57                     match e.kind() {
58                         // No incoming connection available.
59                         ErrorKind::WouldBlock => return Ok(None),
60                         // New connection closed by peer.
61                         ErrorKind::ConnectionAborted => return Ok(None),
62                         // Interrupted by signals, retry
63                         ErrorKind::Interrupted => continue,
64                         _ => return Err(Error::SocketError(e)),
65                     }
66                 }
67             }
68         }
69     }
70 
71     /// Change blocking status on the listener.
72     ///
73     /// # Return:
74     /// * - () on success.
75     /// * - SocketError: failure from set_nonblocking().
set_nonblocking(&self, block: bool) -> Result<()>76     pub fn set_nonblocking(&self, block: bool) -> Result<()> {
77         self.fd.set_nonblocking(block).map_err(Error::SocketError)
78     }
79 }
80 
81 impl AsRawFd for Listener {
as_raw_fd(&self) -> RawFd82     fn as_raw_fd(&self) -> RawFd {
83         self.fd.as_raw_fd()
84     }
85 }
86 
87 impl FromRawFd for Listener {
from_raw_fd(fd: RawFd) -> Self88     unsafe fn from_raw_fd(fd: RawFd) -> Self {
89         Listener {
90             fd: UnixListener::from_raw_fd(fd),
91             path: None,
92         }
93     }
94 }
95 
96 impl Drop for Listener {
drop(&mut self)97     fn drop(&mut self) {
98         if let Some(path) = &self.path {
99             let _ = std::fs::remove_file(path);
100         }
101     }
102 }
103 
104 /// Unix domain socket endpoint for vhost-user connection.
105 pub(super) struct Endpoint<R: Req> {
106     sock: UnixStream,
107     _r: PhantomData<R>,
108 }
109 
110 impl<R: Req> Endpoint<R> {
111     /// Create a new stream by connecting to server at `str`.
112     ///
113     /// # Return:
114     /// * - the new Endpoint object on success.
115     /// * - SocketConnect: failed to connect to peer.
connect<P: AsRef<Path>>(path: P) -> Result<Self>116     pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
117         let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
118         Ok(Self::from_stream(sock))
119     }
120 
121     /// Create an endpoint from a stream object.
from_stream(sock: UnixStream) -> Self122     pub fn from_stream(sock: UnixStream) -> Self {
123         Endpoint {
124             sock,
125             _r: PhantomData,
126         }
127     }
128 
129     /// Sends bytes from scatter-gather vectors over the socket with optional attached file
130     /// descriptors.
131     ///
132     /// # Return:
133     /// * - number of bytes sent on success
134     /// * - SocketRetry: temporary error caused by signals or short of resources.
135     /// * - SocketBroken: the underline socket is broken.
136     /// * - SocketError: other socket related errors.
send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize>137     pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
138         let rfds = match fds {
139             Some(rfds) => rfds,
140             _ => &[],
141         };
142         self.sock.send_with_fds(iovs, rfds).map_err(Into::into)
143     }
144 
145     /// Sends all bytes from scatter-gather vectors over the socket with optional attached file
146     /// descriptors. Will loop until all data has been transfered.
147     ///
148     /// # Return:
149     /// * - number of bytes sent on success
150     /// * - SocketBroken: the underline socket is broken.
151     /// * - SocketError: other socket related errors.
send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize>152     pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
153         let mut data_sent = 0;
154         let mut data_total = 0;
155         let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect();
156         for len in &iov_lens {
157             data_total += len;
158         }
159 
160         while (data_total - data_sent) > 0 {
161             let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent);
162             let iov = &iovs[nr_skip][offset..];
163 
164             let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat();
165             let sfds = if data_sent == 0 { fds } else { None };
166 
167             let sent = self.send_iovec(data, sfds);
168             match sent {
169                 Ok(0) => return Ok(data_sent),
170                 Ok(n) => data_sent += n,
171                 Err(e) => match e {
172                     Error::SocketRetry(_) => {}
173                     _ => return Err(e),
174                 },
175             }
176         }
177         Ok(data_sent)
178     }
179 
180     /// Sends bytes from a slice over the socket with optional attached file descriptors.
181     ///
182     /// # Return:
183     /// * - number of bytes sent on success
184     /// * - SocketRetry: temporary error caused by signals or short of resources.
185     /// * - SocketBroken: the underline socket is broken.
186     /// * - SocketError: other socket related errors.
send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize>187     pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
188         self.send_iovec(&[data], fds)
189     }
190 
191     /// Sends a header-only message with optional attached file descriptors.
192     ///
193     /// # Return:
194     /// * - number of bytes sent on success
195     /// * - SocketRetry: temporary error caused by signals or short of resources.
196     /// * - SocketBroken: the underline socket is broken.
197     /// * - SocketError: other socket related errors.
198     /// * - PartialMessage: received a partial message.
send_header( &mut self, hdr: &VhostUserMsgHeader<R>, fds: Option<&[RawFd]>, ) -> Result<()>199     pub fn send_header(
200         &mut self,
201         hdr: &VhostUserMsgHeader<R>,
202         fds: Option<&[RawFd]>,
203     ) -> Result<()> {
204         // SAFETY: Safe because there can't be other mutable referance to hdr.
205         let iovs = unsafe {
206             [slice::from_raw_parts(
207                 hdr as *const VhostUserMsgHeader<R> as *const u8,
208                 mem::size_of::<VhostUserMsgHeader<R>>(),
209             )]
210         };
211         let bytes = self.send_iovec_all(&iovs[..], fds)?;
212         if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
213             return Err(Error::PartialMessage);
214         }
215         Ok(())
216     }
217 
218     /// Send a message with header and body. Optional file descriptors may be attached to
219     /// the message.
220     ///
221     /// # Return:
222     /// * - number of bytes sent on success
223     /// * - SocketRetry: temporary error caused by signals or short of resources.
224     /// * - SocketBroken: the underline socket is broken.
225     /// * - SocketError: other socket related errors.
226     /// * - PartialMessage: received a partial message.
send_message<T: ByteValued>( &mut self, hdr: &VhostUserMsgHeader<R>, body: &T, fds: Option<&[RawFd]>, ) -> Result<()>227     pub fn send_message<T: ByteValued>(
228         &mut self,
229         hdr: &VhostUserMsgHeader<R>,
230         body: &T,
231         fds: Option<&[RawFd]>,
232     ) -> Result<()> {
233         if mem::size_of::<T>() > MAX_MSG_SIZE {
234             return Err(Error::OversizedMsg);
235         }
236         let bytes = self.send_iovec_all(&[hdr.as_slice(), body.as_slice()], fds)?;
237         if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() {
238             return Err(Error::PartialMessage);
239         }
240         Ok(())
241     }
242 
243     /// Send a message with header, body and payload. Optional file descriptors
244     /// may also be attached to the message.
245     ///
246     /// # Return:
247     /// * - number of bytes sent on success
248     /// * - SocketRetry: temporary error caused by signals or short of resources.
249     /// * - SocketBroken: the underline socket is broken.
250     /// * - SocketError: other socket related errors.
251     /// * - OversizedMsg: message size is too big.
252     /// * - PartialMessage: received a partial message.
253     /// * - IncorrectFds: wrong number of attached fds.
send_message_with_payload<T: ByteValued>( &mut self, hdr: &VhostUserMsgHeader<R>, body: &T, payload: &[u8], fds: Option<&[RawFd]>, ) -> Result<()>254     pub fn send_message_with_payload<T: ByteValued>(
255         &mut self,
256         hdr: &VhostUserMsgHeader<R>,
257         body: &T,
258         payload: &[u8],
259         fds: Option<&[RawFd]>,
260     ) -> Result<()> {
261         let len = payload.len();
262         if mem::size_of::<T>() > MAX_MSG_SIZE {
263             return Err(Error::OversizedMsg);
264         }
265         if len > MAX_MSG_SIZE - mem::size_of::<T>() {
266             return Err(Error::OversizedMsg);
267         }
268         if let Some(fd_arr) = fds {
269             if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
270                 return Err(Error::IncorrectFds);
271             }
272         }
273 
274         let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len;
275         let len = self.send_iovec_all(&[hdr.as_slice(), body.as_slice(), payload], fds)?;
276         if len != total {
277             return Err(Error::PartialMessage);
278         }
279         Ok(())
280     }
281 
282     /// Reads bytes from the socket into the given scatter/gather vectors.
283     ///
284     /// # Return:
285     /// * - (number of bytes received, buf) on success
286     /// * - SocketRetry: temporary error caused by signals or short of resources.
287     /// * - SocketBroken: the underline socket is broken.
288     /// * - SocketError: other socket related errors.
recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)>289     pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> {
290         let mut rbuf = vec![0u8; len];
291         let mut iovs = [iovec {
292             iov_base: rbuf.as_mut_ptr() as *mut c_void,
293             iov_len: len,
294         }];
295         // SAFETY: Safe because we own rbuf and it's safe to fill a byte array with arbitrary data.
296         let (bytes, _) = unsafe { self.sock.recv_with_fds(&mut iovs, &mut [])? };
297         Ok((bytes, rbuf))
298     }
299 
300     /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
301     /// file.
302     ///
303     /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
304     /// tricky to pass file descriptors through such a communication channel. Let's assume that a
305     /// sender sending a message with some file descriptors attached. To successfully receive those
306     /// attached file descriptors, the receiver must obey following rules:
307     ///   1) file descriptors are attached to a message.
308     ///   2) message(packet) boundaries must be respected on the receive side.
309     /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
310     /// attached file descriptors will get lost.
311     /// Note that this function wraps received file descriptors as `File`.
312     ///
313     /// # Return:
314     /// * - (number of bytes received, [received files]) on success
315     /// * - SocketRetry: temporary error caused by signals or short of resources.
316     /// * - SocketBroken: the underline socket is broken.
317     /// * - SocketError: other socket related errors.
318     ///
319     /// # Safety
320     ///
321     /// It is the callers responsibility to ensure it is safe for arbitrary data to be
322     /// written to the iovec pointers.
recv_into_iovec( &mut self, iovs: &mut [iovec], ) -> Result<(usize, Option<Vec<File>>)>323     pub unsafe fn recv_into_iovec(
324         &mut self,
325         iovs: &mut [iovec],
326     ) -> Result<(usize, Option<Vec<File>>)> {
327         let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
328         let (bytes, fds) = self.sock.recv_with_fds(iovs, &mut fd_array)?;
329 
330         let files = match fds {
331             0 => None,
332             n => {
333                 let files = fd_array
334                     .iter()
335                     .take(n)
336                     .map(|fd| {
337                         // Safe because we have the ownership of `fd`.
338                         File::from_raw_fd(*fd)
339                     })
340                     .collect();
341                 Some(files)
342             }
343         };
344 
345         Ok((bytes, files))
346     }
347 
348     /// Reads all bytes from the socket into the given scatter/gather vectors with optional
349     /// attached files. Will loop until all data has been transferred.
350     ///
351     /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
352     /// tricky to pass file descriptors through such a communication channel. Let's assume that a
353     /// sender sending a message with some file descriptors attached. To successfully receive those
354     /// attached file descriptors, the receiver must obey following rules:
355     ///   1) file descriptors are attached to a message.
356     ///   2) message(packet) boundaries must be respected on the receive side.
357     /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
358     /// attached file descriptors will get lost.
359     /// Note that this function wraps received file descriptors as `File`.
360     ///
361     /// # Return:
362     /// * - (number of bytes received, [received fds]) on success
363     /// * - SocketBroken: the underline socket is broken.
364     /// * - SocketError: other socket related errors.
365     ///
366     /// # Safety
367     ///
368     /// It is the callers responsibility to ensure it is safe for arbitrary data to be
369     /// written to the iovec pointers.
recv_into_iovec_all( &mut self, iovs: &mut [iovec], ) -> Result<(usize, Option<Vec<File>>)>370     pub unsafe fn recv_into_iovec_all(
371         &mut self,
372         iovs: &mut [iovec],
373     ) -> Result<(usize, Option<Vec<File>>)> {
374         let mut data_read = 0;
375         let mut data_total = 0;
376         let mut rfds = None;
377         let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect();
378         for len in &iov_lens {
379             data_total += len;
380         }
381 
382         while (data_total - data_read) > 0 {
383             let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read);
384             let iov = &mut iovs[nr_skip];
385 
386             let mut data = [
387                 &[iovec {
388                     iov_base: (iov.iov_base as usize + offset) as *mut c_void,
389                     iov_len: iov.iov_len - offset,
390                 }],
391                 &iovs[(nr_skip + 1)..],
392             ]
393             .concat();
394 
395             let res = self.recv_into_iovec(&mut data);
396             match res {
397                 Ok((0, _)) => return Ok((data_read, rfds)),
398                 Ok((n, fds)) => {
399                     if data_read == 0 {
400                         rfds = fds;
401                     }
402                     data_read += n;
403                 }
404                 Err(e) => match e {
405                     Error::SocketRetry(_) => {}
406                     _ => return Err(e),
407                 },
408             }
409         }
410         Ok((data_read, rfds))
411     }
412 
413     /// Reads bytes from the socket into a new buffer with optional attached
414     /// files. Received file descriptors are set close-on-exec and converted to `File`.
415     ///
416     /// # Return:
417     /// * - (number of bytes received, buf, [received files]) on success.
418     /// * - SocketRetry: temporary error caused by signals or short of resources.
419     /// * - SocketBroken: the underline socket is broken.
420     /// * - SocketError: other socket related errors.
recv_into_buf( &mut self, buf_size: usize, ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)>421     pub fn recv_into_buf(
422         &mut self,
423         buf_size: usize,
424     ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> {
425         let mut buf = vec![0u8; buf_size];
426         let (bytes, files) = {
427             let mut iovs = [iovec {
428                 iov_base: buf.as_mut_ptr() as *mut c_void,
429                 iov_len: buf_size,
430             }];
431             // SAFETY: Safe because we own buf and it's safe to fill a byte array with arbitrary data.
432             unsafe { self.recv_into_iovec(&mut iovs)? }
433         };
434         Ok((bytes, buf, files))
435     }
436 
437     /// Receive a header-only message with optional attached files.
438     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
439     /// accepted and all other file descriptor will be discard silently.
440     ///
441     /// # Return:
442     /// * - (message header, [received files]) on success.
443     /// * - SocketRetry: temporary error caused by signals or short of resources.
444     /// * - SocketBroken: the underline socket is broken.
445     /// * - SocketError: other socket related errors.
446     /// * - PartialMessage: received a partial message.
447     /// * - InvalidMessage: received a invalid message.
recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)>448     pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> {
449         let mut hdr = VhostUserMsgHeader::default();
450         let mut iovs = [iovec {
451             iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
452             iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
453         }];
454         // SAFETY: Safe because we own hdr and it's ByteValued.
455         let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
456 
457         if bytes == 0 {
458             return Err(Error::Disconnected);
459         } else if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
460             return Err(Error::PartialMessage);
461         } else if !hdr.is_valid() {
462             return Err(Error::InvalidMessage);
463         }
464 
465         Ok((hdr, files))
466     }
467 
468     /// Receive a message with optional attached file descriptors.
469     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
470     /// accepted and all other file descriptor will be discard silently.
471     ///
472     /// # Return:
473     /// * - (message header, message body, [received files]) on success.
474     /// * - SocketRetry: temporary error caused by signals or short of resources.
475     /// * - SocketBroken: the underline socket is broken.
476     /// * - SocketError: other socket related errors.
477     /// * - PartialMessage: received a partial message.
478     /// * - InvalidMessage: received a invalid message.
recv_body<T: ByteValued + Sized + VhostUserMsgValidator>( &mut self, ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)>479     pub fn recv_body<T: ByteValued + Sized + VhostUserMsgValidator>(
480         &mut self,
481     ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> {
482         let mut hdr = VhostUserMsgHeader::default();
483         let mut body: T = Default::default();
484         let mut iovs = [
485             iovec {
486                 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
487                 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
488             },
489             iovec {
490                 iov_base: (&mut body as *mut T) as *mut c_void,
491                 iov_len: mem::size_of::<T>(),
492             },
493         ];
494         // SAFETY: Safe because we own hdr and body and they're ByteValued.
495         let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
496 
497         let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
498         if bytes != total {
499             return Err(Error::PartialMessage);
500         } else if !hdr.is_valid() || !body.is_valid() {
501             return Err(Error::InvalidMessage);
502         }
503 
504         Ok((hdr, body, files))
505     }
506 
507     /// Receive a message with header and optional content. Callers need to
508     /// pre-allocate a big enough buffer to receive the message body and
509     /// optional payload. If there are attached file descriptor associated
510     /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors
511     /// will be accepted and all other file descriptor will be discard
512     /// silently.
513     ///
514     /// # Return:
515     /// * - (message header, message size, [received files]) on success.
516     /// * - SocketRetry: temporary error caused by signals or short of resources.
517     /// * - SocketBroken: the underline socket is broken.
518     /// * - SocketError: other socket related errors.
519     /// * - PartialMessage: received a partial message.
520     /// * - InvalidMessage: received a invalid message.
recv_body_into_buf( &mut self, buf: &mut [u8], ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)>521     pub fn recv_body_into_buf(
522         &mut self,
523         buf: &mut [u8],
524     ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> {
525         let mut hdr = VhostUserMsgHeader::default();
526         let mut iovs = [
527             iovec {
528                 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
529                 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
530             },
531             iovec {
532                 iov_base: buf.as_mut_ptr() as *mut c_void,
533                 iov_len: buf.len(),
534             },
535         ];
536         // SAFETY: Safe because we own hdr and have a mutable borrow of buf, and hdr is ByteValued
537         // and it's safe to fill a byte slice with arbitrary data.
538         let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
539 
540         if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
541             return Err(Error::PartialMessage);
542         } else if !hdr.is_valid() {
543             return Err(Error::InvalidMessage);
544         }
545 
546         Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files))
547     }
548 
549     /// Receive a message with optional payload and attached file descriptors.
550     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
551     /// accepted and all other file descriptor will be discard silently.
552     ///
553     /// # Return:
554     /// * - (message header, message body, size of payload, [received files]) on success.
555     /// * - SocketRetry: temporary error caused by signals or short of resources.
556     /// * - SocketBroken: the underline socket is broken.
557     /// * - SocketError: other socket related errors.
558     /// * - PartialMessage: received a partial message.
559     /// * - InvalidMessage: received a invalid message.
560     #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
recv_payload_into_buf<T: ByteValued + Sized + VhostUserMsgValidator>( &mut self, buf: &mut [u8], ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)>561     pub fn recv_payload_into_buf<T: ByteValued + Sized + VhostUserMsgValidator>(
562         &mut self,
563         buf: &mut [u8],
564     ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> {
565         let mut hdr = VhostUserMsgHeader::default();
566         let mut body: T = Default::default();
567         let mut iovs = [
568             iovec {
569                 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
570                 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
571             },
572             iovec {
573                 iov_base: (&mut body as *mut T) as *mut c_void,
574                 iov_len: mem::size_of::<T>(),
575             },
576             iovec {
577                 iov_base: buf.as_mut_ptr() as *mut c_void,
578                 iov_len: buf.len(),
579             },
580         ];
581         // SAFETY: Safe because we own hdr and body and have a mutable borrow of buf, and
582         // hdr and body are ByteValued, and it's safe to fill a byte slice with
583         // arbitrary data.
584         let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
585 
586         let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
587         if bytes < total {
588             return Err(Error::PartialMessage);
589         } else if !hdr.is_valid() || !body.is_valid() {
590             return Err(Error::InvalidMessage);
591         }
592 
593         Ok((hdr, body, bytes - total, files))
594     }
595 }
596 
597 impl<T: Req> AsRawFd for Endpoint<T> {
as_raw_fd(&self) -> RawFd598     fn as_raw_fd(&self) -> RawFd {
599         self.sock.as_raw_fd()
600     }
601 }
602 
// Locate byte position `skip_size` inside a sequence of buffers whose lengths are `iov_lens`.
//
// Returns `(nr_skip, offset)`:
//   * `nr_skip` is the number of leading buffers that lie entirely before `skip_size`
//     (i.e. the index of the buffer containing that byte);
//   * `offset` is the position of that byte within buffer `nr_skip`.
// Zero-length buffers are always counted as skipped. If `skip_size` is at or past the end of
// all buffers, `nr_skip == iov_lens.len()` and `offset` is the overshoot.
//
// For example:
//     let iov_lens = vec![4, 4, 5];
//     let size = 6;
//     assert_eq!(get_sub_iovs_offset(&iov_lens, size), (1, 2));
fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
    let mut remaining = skip_size;
    let mut nr_skip = 0;

    for &len in iov_lens {
        if remaining < len {
            // The target byte lies inside this buffer.
            break;
        }
        remaining -= len;
        nr_skip += 1;
    }
    (nr_skip, remaining)
}
622 
623 #[cfg(test)]
624 mod tests {
625     use super::*;
626     use std::io::{Read, Seek, SeekFrom, Write};
627     use vmm_sys_util::rand::rand_alphanumerics;
628     use vmm_sys_util::tempfile::TempFile;
629 
temp_path() -> PathBuf630     fn temp_path() -> PathBuf {
631         PathBuf::from(format!(
632             "/tmp/vhost_test_{}",
633             rand_alphanumerics(8).to_str().unwrap()
634         ))
635     }
636 
637     #[test]
create_listener()638     fn create_listener() {
639         let path = temp_path();
640         let listener = Listener::new(path, true).unwrap();
641 
642         assert!(listener.as_raw_fd() > 0);
643     }
644 
645     #[test]
create_listener_from_raw_fd()646     fn create_listener_from_raw_fd() {
647         let path = temp_path();
648         let file = File::create(path).unwrap();
649 
650         // SAFETY: Safe because `file` contains a valid fd to a file just created.
651         let listener = unsafe { Listener::from_raw_fd(file.as_raw_fd()) };
652 
653         assert!(listener.as_raw_fd() > 0);
654     }
655 
656     #[test]
accept_connection()657     fn accept_connection() {
658         let path = temp_path();
659         let listener = Listener::new(path, true).unwrap();
660         listener.set_nonblocking(true).unwrap();
661 
662         // accept on a fd without incoming connection
663         let conn = listener.accept().unwrap();
664         assert!(conn.is_none());
665     }
666 
667     #[test]
send_data()668     fn send_data() {
669         let path = temp_path();
670         let listener = Listener::new(&path, true).unwrap();
671         listener.set_nonblocking(true).unwrap();
672         let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
673         let sock = listener.accept().unwrap().unwrap();
674         let mut slave = Endpoint::<MasterReq>::from_stream(sock);
675 
676         let buf1 = vec![0x1, 0x2, 0x3, 0x4];
677         let mut len = master.send_slice(&buf1[..], None).unwrap();
678         assert_eq!(len, 4);
679         let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap();
680         assert_eq!(bytes, 4);
681         assert_eq!(&buf1[..], &buf2[..bytes]);
682 
683         len = master.send_slice(&buf1[..], None).unwrap();
684         assert_eq!(len, 4);
685         let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
686         assert_eq!(bytes, 2);
687         assert_eq!(&buf1[..2], &buf2[..]);
688         let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
689         assert_eq!(bytes, 2);
690         assert_eq!(&buf1[2..], &buf2[..]);
691     }
692 
693     #[test]
send_fd()694     fn send_fd() {
695         let path = temp_path();
696         let listener = Listener::new(&path, true).unwrap();
697         listener.set_nonblocking(true).unwrap();
698         let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
699         let sock = listener.accept().unwrap().unwrap();
700         let mut slave = Endpoint::<MasterReq>::from_stream(sock);
701 
702         let mut fd = TempFile::new().unwrap().into_file();
703         write!(fd, "test").unwrap();
704 
705         // Normal case for sending/receiving file descriptors
706         let buf1 = vec![0x1, 0x2, 0x3, 0x4];
707         let len = master
708             .send_slice(&buf1[..], Some(&[fd.as_raw_fd()]))
709             .unwrap();
710         assert_eq!(len, 4);
711 
712         let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap();
713         assert_eq!(bytes, 4);
714         assert_eq!(&buf1[..], &buf2[..]);
715         assert!(files.is_some());
716         let files = files.unwrap();
717         {
718             assert_eq!(files.len(), 1);
719             let mut file = &files[0];
720             let mut content = String::new();
721             file.seek(SeekFrom::Start(0)).unwrap();
722             file.read_to_string(&mut content).unwrap();
723             assert_eq!(content, "test");
724         }
725 
726         // Following communication pattern should work:
727         // Sending side: data(header, body) with fds
728         // Receiving side: data(header) with fds, data(body)
729         let len = master
730             .send_slice(
731                 &buf1[..],
732                 Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
733             )
734             .unwrap();
735         assert_eq!(len, 4);
736 
737         let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
738         assert_eq!(bytes, 2);
739         assert_eq!(&buf1[..2], &buf2[..]);
740         assert!(files.is_some());
741         let files = files.unwrap();
742         {
743             assert_eq!(files.len(), 3);
744             let mut file = &files[1];
745             let mut content = String::new();
746             file.seek(SeekFrom::Start(0)).unwrap();
747             file.read_to_string(&mut content).unwrap();
748             assert_eq!(content, "test");
749         }
750         let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
751         assert_eq!(bytes, 2);
752         assert_eq!(&buf1[2..], &buf2[..]);
753         assert!(files.is_none());
754 
755         // Following communication pattern should not work:
756         // Sending side: data(header, body) with fds
757         // Receiving side: data(header), data(body) with fds
758         let len = master
759             .send_slice(
760                 &buf1[..],
761                 Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
762             )
763             .unwrap();
764         assert_eq!(len, 4);
765 
766         let (bytes, buf4) = slave.recv_data(2).unwrap();
767         assert_eq!(bytes, 2);
768         assert_eq!(&buf1[..2], &buf4[..]);
769         let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
770         assert_eq!(bytes, 2);
771         assert_eq!(&buf1[2..], &buf2[..]);
772         assert!(files.is_none());
773 
774         // Following communication pattern should work:
775         // Sending side: data, data with fds
776         // Receiving side: data, data with fds
777         let len = master.send_slice(&buf1[..], None).unwrap();
778         assert_eq!(len, 4);
779         let len = master
780             .send_slice(
781                 &buf1[..],
782                 Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
783             )
784             .unwrap();
785         assert_eq!(len, 4);
786 
787         let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap();
788         assert_eq!(bytes, 4);
789         assert_eq!(&buf1[..], &buf2[..]);
790         assert!(files.is_none());
791 
792         let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
793         assert_eq!(bytes, 2);
794         assert_eq!(&buf1[..2], &buf2[..]);
795         assert!(files.is_some());
796         let files = files.unwrap();
797         {
798             assert_eq!(files.len(), 3);
799             let mut file = &files[1];
800             let mut content = String::new();
801             file.seek(SeekFrom::Start(0)).unwrap();
802             file.read_to_string(&mut content).unwrap();
803             assert_eq!(content, "test");
804         }
805         let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
806         assert_eq!(bytes, 2);
807         assert_eq!(&buf1[2..], &buf2[..]);
808         assert!(files.is_none());
809 
810         // Following communication pattern should not work:
811         // Sending side: data1, data2 with fds
812         // Receiving side: data + partial of data2, left of data2 with fds
813         let len = master.send_slice(&buf1[..], None).unwrap();
814         assert_eq!(len, 4);
815         let len = master
816             .send_slice(
817                 &buf1[..],
818                 Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
819             )
820             .unwrap();
821         assert_eq!(len, 4);
822 
823         let (bytes, _) = slave.recv_data(5).unwrap();
824         assert_eq!(bytes, 5);
825 
826         let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
827         assert_eq!(bytes, 3);
828         assert!(files.is_none());
829 
830         // If the target fd array is too small, extra file descriptors will get lost.
831         let len = master
832             .send_slice(
833                 &buf1[..],
834                 Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
835             )
836             .unwrap();
837         assert_eq!(len, 4);
838 
839         let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
840         assert_eq!(bytes, 4);
841         assert!(files.is_some());
842     }
843 
844     #[test]
send_recv()845     fn send_recv() {
846         let path = temp_path();
847         let listener = Listener::new(&path, true).unwrap();
848         listener.set_nonblocking(true).unwrap();
849         let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
850         let sock = listener.accept().unwrap().unwrap();
851         let mut slave = Endpoint::<MasterReq>::from_stream(sock);
852 
853         let mut hdr1 =
854             VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32);
855         hdr1.set_need_reply(true);
856         let features1 = 0x1u64;
857         master.send_message(&hdr1, &features1, None).unwrap();
858 
859         let mut features2 = 0u64;
860 
861         // SAFETY: Safe because features2 is valid and it's an `u64`.
862         let slice = unsafe {
863             slice::from_raw_parts_mut(
864                 (&mut features2 as *mut u64) as *mut u8,
865                 mem::size_of::<u64>(),
866             )
867         };
868         let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap();
869         assert_eq!(hdr1, hdr2);
870         assert_eq!(bytes, 8);
871         assert_eq!(features1, features2);
872         assert!(files.is_none());
873 
874         master.send_header(&hdr1, None).unwrap();
875         let (hdr2, files) = slave.recv_header().unwrap();
876         assert_eq!(hdr1, hdr2);
877         assert!(files.is_none());
878     }
879 
880     #[test]
partial_message()881     fn partial_message() {
882         let path = temp_path();
883         let listener = Listener::new(&path, true).unwrap();
884         let mut master = UnixStream::connect(&path).unwrap();
885         let sock = listener.accept().unwrap().unwrap();
886         let mut slave = Endpoint::<MasterReq>::from_stream(sock);
887 
888         write!(master, "a").unwrap();
889         drop(master);
890         assert!(matches!(slave.recv_header(), Err(Error::PartialMessage)));
891     }
892 
893     #[test]
disconnected()894     fn disconnected() {
895         let path = temp_path();
896         let listener = Listener::new(&path, true).unwrap();
897         let _ = UnixStream::connect(&path).unwrap();
898         let sock = listener.accept().unwrap().unwrap();
899         let mut slave = Endpoint::<MasterReq>::from_stream(sock);
900 
901         assert!(matches!(slave.recv_header(), Err(Error::Disconnected)));
902     }
903 }
904