1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3
4 //! Structs for Unix Domain Socket listener and endpoint.
5
6 #![allow(dead_code)]
7
8 use std::fs::File;
9 use std::io::ErrorKind;
10 use std::marker::PhantomData;
11 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
12 use std::os::unix::net::{UnixListener, UnixStream};
13 use std::path::{Path, PathBuf};
14 use std::{mem, slice};
15
16 use libc::{c_void, iovec};
17 use vm_memory::ByteValued;
18 use vmm_sys_util::sock_ctrl_msg::ScmSocket;
19
20 use super::message::*;
21 use super::{Error, Result};
22
/// Unix domain socket listener for accepting incoming connections.
///
/// When the listener was bound to a filesystem path, that path is remembered and the
/// socket file is removed (best effort) when the listener is dropped.
pub struct Listener {
    // The underlying listening socket.
    fd: UnixListener,
    // Path the socket was bound to; `None` when the listener was built from a raw fd.
    path: Option<PathBuf>,
}
28
29 impl Listener {
30 /// Create a unix domain socket listener.
31 ///
32 /// # Return:
33 /// * - the new Listener object on success.
34 /// * - SocketError: failed to create listener socket.
new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self>35 pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
36 if unlink {
37 let _ = std::fs::remove_file(&path);
38 }
39 let fd = UnixListener::bind(&path).map_err(Error::SocketError)?;
40 Ok(Listener {
41 fd,
42 path: Some(path.as_ref().to_owned()),
43 })
44 }
45
46 /// Accept an incoming connection.
47 ///
48 /// # Return:
49 /// * - Some(UnixStream): new UnixStream object if new incoming connection is available.
50 /// * - None: no incoming connection available.
51 /// * - SocketError: errors from accept().
accept(&self) -> Result<Option<UnixStream>>52 pub fn accept(&self) -> Result<Option<UnixStream>> {
53 loop {
54 match self.fd.accept() {
55 Ok((socket, _addr)) => return Ok(Some(socket)),
56 Err(e) => {
57 match e.kind() {
58 // No incoming connection available.
59 ErrorKind::WouldBlock => return Ok(None),
60 // New connection closed by peer.
61 ErrorKind::ConnectionAborted => return Ok(None),
62 // Interrupted by signals, retry
63 ErrorKind::Interrupted => continue,
64 _ => return Err(Error::SocketError(e)),
65 }
66 }
67 }
68 }
69 }
70
71 /// Change blocking status on the listener.
72 ///
73 /// # Return:
74 /// * - () on success.
75 /// * - SocketError: failure from set_nonblocking().
set_nonblocking(&self, block: bool) -> Result<()>76 pub fn set_nonblocking(&self, block: bool) -> Result<()> {
77 self.fd.set_nonblocking(block).map_err(Error::SocketError)
78 }
79 }
80
81 impl AsRawFd for Listener {
as_raw_fd(&self) -> RawFd82 fn as_raw_fd(&self) -> RawFd {
83 self.fd.as_raw_fd()
84 }
85 }
86
87 impl FromRawFd for Listener {
from_raw_fd(fd: RawFd) -> Self88 unsafe fn from_raw_fd(fd: RawFd) -> Self {
89 Listener {
90 fd: UnixListener::from_raw_fd(fd),
91 path: None,
92 }
93 }
94 }
95
96 impl Drop for Listener {
drop(&mut self)97 fn drop(&mut self) {
98 if let Some(path) = &self.path {
99 let _ = std::fs::remove_file(path);
100 }
101 }
102 }
103
/// Unix domain socket endpoint for vhost-user connection.
///
/// `R` is the request type used in the `VhostUserMsgHeader<R>` of messages sent and
/// received through this endpoint; it is tracked only at the type level.
pub(super) struct Endpoint<R: Req> {
    // Connected stream used for both sending and receiving.
    sock: UnixStream,
    // Zero-sized marker binding the endpoint to its request type.
    _r: PhantomData<R>,
}
109
110 impl<R: Req> Endpoint<R> {
111 /// Create a new stream by connecting to server at `str`.
112 ///
113 /// # Return:
114 /// * - the new Endpoint object on success.
115 /// * - SocketConnect: failed to connect to peer.
connect<P: AsRef<Path>>(path: P) -> Result<Self>116 pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
117 let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
118 Ok(Self::from_stream(sock))
119 }
120
121 /// Create an endpoint from a stream object.
from_stream(sock: UnixStream) -> Self122 pub fn from_stream(sock: UnixStream) -> Self {
123 Endpoint {
124 sock,
125 _r: PhantomData,
126 }
127 }
128
129 /// Sends bytes from scatter-gather vectors over the socket with optional attached file
130 /// descriptors.
131 ///
132 /// # Return:
133 /// * - number of bytes sent on success
134 /// * - SocketRetry: temporary error caused by signals or short of resources.
135 /// * - SocketBroken: the underline socket is broken.
136 /// * - SocketError: other socket related errors.
send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize>137 pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
138 let rfds = match fds {
139 Some(rfds) => rfds,
140 _ => &[],
141 };
142 self.sock.send_with_fds(iovs, rfds).map_err(Into::into)
143 }
144
145 /// Sends all bytes from scatter-gather vectors over the socket with optional attached file
146 /// descriptors. Will loop until all data has been transfered.
147 ///
148 /// # Return:
149 /// * - number of bytes sent on success
150 /// * - SocketBroken: the underline socket is broken.
151 /// * - SocketError: other socket related errors.
send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize>152 pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
153 let mut data_sent = 0;
154 let mut data_total = 0;
155 let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect();
156 for len in &iov_lens {
157 data_total += len;
158 }
159
160 while (data_total - data_sent) > 0 {
161 let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent);
162 let iov = &iovs[nr_skip][offset..];
163
164 let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat();
165 let sfds = if data_sent == 0 { fds } else { None };
166
167 let sent = self.send_iovec(data, sfds);
168 match sent {
169 Ok(0) => return Ok(data_sent),
170 Ok(n) => data_sent += n,
171 Err(e) => match e {
172 Error::SocketRetry(_) => {}
173 _ => return Err(e),
174 },
175 }
176 }
177 Ok(data_sent)
178 }
179
180 /// Sends bytes from a slice over the socket with optional attached file descriptors.
181 ///
182 /// # Return:
183 /// * - number of bytes sent on success
184 /// * - SocketRetry: temporary error caused by signals or short of resources.
185 /// * - SocketBroken: the underline socket is broken.
186 /// * - SocketError: other socket related errors.
send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize>187 pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
188 self.send_iovec(&[data], fds)
189 }
190
191 /// Sends a header-only message with optional attached file descriptors.
192 ///
193 /// # Return:
194 /// * - number of bytes sent on success
195 /// * - SocketRetry: temporary error caused by signals or short of resources.
196 /// * - SocketBroken: the underline socket is broken.
197 /// * - SocketError: other socket related errors.
198 /// * - PartialMessage: received a partial message.
send_header( &mut self, hdr: &VhostUserMsgHeader<R>, fds: Option<&[RawFd]>, ) -> Result<()>199 pub fn send_header(
200 &mut self,
201 hdr: &VhostUserMsgHeader<R>,
202 fds: Option<&[RawFd]>,
203 ) -> Result<()> {
204 // SAFETY: Safe because there can't be other mutable referance to hdr.
205 let iovs = unsafe {
206 [slice::from_raw_parts(
207 hdr as *const VhostUserMsgHeader<R> as *const u8,
208 mem::size_of::<VhostUserMsgHeader<R>>(),
209 )]
210 };
211 let bytes = self.send_iovec_all(&iovs[..], fds)?;
212 if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
213 return Err(Error::PartialMessage);
214 }
215 Ok(())
216 }
217
218 /// Send a message with header and body. Optional file descriptors may be attached to
219 /// the message.
220 ///
221 /// # Return:
222 /// * - number of bytes sent on success
223 /// * - SocketRetry: temporary error caused by signals or short of resources.
224 /// * - SocketBroken: the underline socket is broken.
225 /// * - SocketError: other socket related errors.
226 /// * - PartialMessage: received a partial message.
send_message<T: ByteValued>( &mut self, hdr: &VhostUserMsgHeader<R>, body: &T, fds: Option<&[RawFd]>, ) -> Result<()>227 pub fn send_message<T: ByteValued>(
228 &mut self,
229 hdr: &VhostUserMsgHeader<R>,
230 body: &T,
231 fds: Option<&[RawFd]>,
232 ) -> Result<()> {
233 if mem::size_of::<T>() > MAX_MSG_SIZE {
234 return Err(Error::OversizedMsg);
235 }
236 let bytes = self.send_iovec_all(&[hdr.as_slice(), body.as_slice()], fds)?;
237 if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() {
238 return Err(Error::PartialMessage);
239 }
240 Ok(())
241 }
242
243 /// Send a message with header, body and payload. Optional file descriptors
244 /// may also be attached to the message.
245 ///
246 /// # Return:
247 /// * - number of bytes sent on success
248 /// * - SocketRetry: temporary error caused by signals or short of resources.
249 /// * - SocketBroken: the underline socket is broken.
250 /// * - SocketError: other socket related errors.
251 /// * - OversizedMsg: message size is too big.
252 /// * - PartialMessage: received a partial message.
253 /// * - IncorrectFds: wrong number of attached fds.
send_message_with_payload<T: ByteValued>( &mut self, hdr: &VhostUserMsgHeader<R>, body: &T, payload: &[u8], fds: Option<&[RawFd]>, ) -> Result<()>254 pub fn send_message_with_payload<T: ByteValued>(
255 &mut self,
256 hdr: &VhostUserMsgHeader<R>,
257 body: &T,
258 payload: &[u8],
259 fds: Option<&[RawFd]>,
260 ) -> Result<()> {
261 let len = payload.len();
262 if mem::size_of::<T>() > MAX_MSG_SIZE {
263 return Err(Error::OversizedMsg);
264 }
265 if len > MAX_MSG_SIZE - mem::size_of::<T>() {
266 return Err(Error::OversizedMsg);
267 }
268 if let Some(fd_arr) = fds {
269 if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
270 return Err(Error::IncorrectFds);
271 }
272 }
273
274 let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len;
275 let len = self.send_iovec_all(&[hdr.as_slice(), body.as_slice(), payload], fds)?;
276 if len != total {
277 return Err(Error::PartialMessage);
278 }
279 Ok(())
280 }
281
282 /// Reads bytes from the socket into the given scatter/gather vectors.
283 ///
284 /// # Return:
285 /// * - (number of bytes received, buf) on success
286 /// * - SocketRetry: temporary error caused by signals or short of resources.
287 /// * - SocketBroken: the underline socket is broken.
288 /// * - SocketError: other socket related errors.
recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)>289 pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> {
290 let mut rbuf = vec![0u8; len];
291 let mut iovs = [iovec {
292 iov_base: rbuf.as_mut_ptr() as *mut c_void,
293 iov_len: len,
294 }];
295 // SAFETY: Safe because we own rbuf and it's safe to fill a byte array with arbitrary data.
296 let (bytes, _) = unsafe { self.sock.recv_with_fds(&mut iovs, &mut [])? };
297 Ok((bytes, rbuf))
298 }
299
300 /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
301 /// file.
302 ///
303 /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
304 /// tricky to pass file descriptors through such a communication channel. Let's assume that a
305 /// sender sending a message with some file descriptors attached. To successfully receive those
306 /// attached file descriptors, the receiver must obey following rules:
307 /// 1) file descriptors are attached to a message.
308 /// 2) message(packet) boundaries must be respected on the receive side.
309 /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
310 /// attached file descriptors will get lost.
311 /// Note that this function wraps received file descriptors as `File`.
312 ///
313 /// # Return:
314 /// * - (number of bytes received, [received files]) on success
315 /// * - SocketRetry: temporary error caused by signals or short of resources.
316 /// * - SocketBroken: the underline socket is broken.
317 /// * - SocketError: other socket related errors.
318 ///
319 /// # Safety
320 ///
321 /// It is the callers responsibility to ensure it is safe for arbitrary data to be
322 /// written to the iovec pointers.
recv_into_iovec( &mut self, iovs: &mut [iovec], ) -> Result<(usize, Option<Vec<File>>)>323 pub unsafe fn recv_into_iovec(
324 &mut self,
325 iovs: &mut [iovec],
326 ) -> Result<(usize, Option<Vec<File>>)> {
327 let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
328 let (bytes, fds) = self.sock.recv_with_fds(iovs, &mut fd_array)?;
329
330 let files = match fds {
331 0 => None,
332 n => {
333 let files = fd_array
334 .iter()
335 .take(n)
336 .map(|fd| {
337 // Safe because we have the ownership of `fd`.
338 File::from_raw_fd(*fd)
339 })
340 .collect();
341 Some(files)
342 }
343 };
344
345 Ok((bytes, files))
346 }
347
348 /// Reads all bytes from the socket into the given scatter/gather vectors with optional
349 /// attached files. Will loop until all data has been transferred.
350 ///
351 /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
352 /// tricky to pass file descriptors through such a communication channel. Let's assume that a
353 /// sender sending a message with some file descriptors attached. To successfully receive those
354 /// attached file descriptors, the receiver must obey following rules:
355 /// 1) file descriptors are attached to a message.
356 /// 2) message(packet) boundaries must be respected on the receive side.
357 /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
358 /// attached file descriptors will get lost.
359 /// Note that this function wraps received file descriptors as `File`.
360 ///
361 /// # Return:
362 /// * - (number of bytes received, [received fds]) on success
363 /// * - SocketBroken: the underline socket is broken.
364 /// * - SocketError: other socket related errors.
365 ///
366 /// # Safety
367 ///
368 /// It is the callers responsibility to ensure it is safe for arbitrary data to be
369 /// written to the iovec pointers.
recv_into_iovec_all( &mut self, iovs: &mut [iovec], ) -> Result<(usize, Option<Vec<File>>)>370 pub unsafe fn recv_into_iovec_all(
371 &mut self,
372 iovs: &mut [iovec],
373 ) -> Result<(usize, Option<Vec<File>>)> {
374 let mut data_read = 0;
375 let mut data_total = 0;
376 let mut rfds = None;
377 let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect();
378 for len in &iov_lens {
379 data_total += len;
380 }
381
382 while (data_total - data_read) > 0 {
383 let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read);
384 let iov = &mut iovs[nr_skip];
385
386 let mut data = [
387 &[iovec {
388 iov_base: (iov.iov_base as usize + offset) as *mut c_void,
389 iov_len: iov.iov_len - offset,
390 }],
391 &iovs[(nr_skip + 1)..],
392 ]
393 .concat();
394
395 let res = self.recv_into_iovec(&mut data);
396 match res {
397 Ok((0, _)) => return Ok((data_read, rfds)),
398 Ok((n, fds)) => {
399 if data_read == 0 {
400 rfds = fds;
401 }
402 data_read += n;
403 }
404 Err(e) => match e {
405 Error::SocketRetry(_) => {}
406 _ => return Err(e),
407 },
408 }
409 }
410 Ok((data_read, rfds))
411 }
412
413 /// Reads bytes from the socket into a new buffer with optional attached
414 /// files. Received file descriptors are set close-on-exec and converted to `File`.
415 ///
416 /// # Return:
417 /// * - (number of bytes received, buf, [received files]) on success.
418 /// * - SocketRetry: temporary error caused by signals or short of resources.
419 /// * - SocketBroken: the underline socket is broken.
420 /// * - SocketError: other socket related errors.
recv_into_buf( &mut self, buf_size: usize, ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)>421 pub fn recv_into_buf(
422 &mut self,
423 buf_size: usize,
424 ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> {
425 let mut buf = vec![0u8; buf_size];
426 let (bytes, files) = {
427 let mut iovs = [iovec {
428 iov_base: buf.as_mut_ptr() as *mut c_void,
429 iov_len: buf_size,
430 }];
431 // SAFETY: Safe because we own buf and it's safe to fill a byte array with arbitrary data.
432 unsafe { self.recv_into_iovec(&mut iovs)? }
433 };
434 Ok((bytes, buf, files))
435 }
436
437 /// Receive a header-only message with optional attached files.
438 /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
439 /// accepted and all other file descriptor will be discard silently.
440 ///
441 /// # Return:
442 /// * - (message header, [received files]) on success.
443 /// * - SocketRetry: temporary error caused by signals or short of resources.
444 /// * - SocketBroken: the underline socket is broken.
445 /// * - SocketError: other socket related errors.
446 /// * - PartialMessage: received a partial message.
447 /// * - InvalidMessage: received a invalid message.
recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)>448 pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> {
449 let mut hdr = VhostUserMsgHeader::default();
450 let mut iovs = [iovec {
451 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
452 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
453 }];
454 // SAFETY: Safe because we own hdr and it's ByteValued.
455 let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
456
457 if bytes == 0 {
458 return Err(Error::Disconnected);
459 } else if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
460 return Err(Error::PartialMessage);
461 } else if !hdr.is_valid() {
462 return Err(Error::InvalidMessage);
463 }
464
465 Ok((hdr, files))
466 }
467
468 /// Receive a message with optional attached file descriptors.
469 /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
470 /// accepted and all other file descriptor will be discard silently.
471 ///
472 /// # Return:
473 /// * - (message header, message body, [received files]) on success.
474 /// * - SocketRetry: temporary error caused by signals or short of resources.
475 /// * - SocketBroken: the underline socket is broken.
476 /// * - SocketError: other socket related errors.
477 /// * - PartialMessage: received a partial message.
478 /// * - InvalidMessage: received a invalid message.
recv_body<T: ByteValued + Sized + VhostUserMsgValidator>( &mut self, ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)>479 pub fn recv_body<T: ByteValued + Sized + VhostUserMsgValidator>(
480 &mut self,
481 ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> {
482 let mut hdr = VhostUserMsgHeader::default();
483 let mut body: T = Default::default();
484 let mut iovs = [
485 iovec {
486 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
487 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
488 },
489 iovec {
490 iov_base: (&mut body as *mut T) as *mut c_void,
491 iov_len: mem::size_of::<T>(),
492 },
493 ];
494 // SAFETY: Safe because we own hdr and body and they're ByteValued.
495 let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
496
497 let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
498 if bytes != total {
499 return Err(Error::PartialMessage);
500 } else if !hdr.is_valid() || !body.is_valid() {
501 return Err(Error::InvalidMessage);
502 }
503
504 Ok((hdr, body, files))
505 }
506
507 /// Receive a message with header and optional content. Callers need to
508 /// pre-allocate a big enough buffer to receive the message body and
509 /// optional payload. If there are attached file descriptor associated
510 /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors
511 /// will be accepted and all other file descriptor will be discard
512 /// silently.
513 ///
514 /// # Return:
515 /// * - (message header, message size, [received files]) on success.
516 /// * - SocketRetry: temporary error caused by signals or short of resources.
517 /// * - SocketBroken: the underline socket is broken.
518 /// * - SocketError: other socket related errors.
519 /// * - PartialMessage: received a partial message.
520 /// * - InvalidMessage: received a invalid message.
recv_body_into_buf( &mut self, buf: &mut [u8], ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)>521 pub fn recv_body_into_buf(
522 &mut self,
523 buf: &mut [u8],
524 ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> {
525 let mut hdr = VhostUserMsgHeader::default();
526 let mut iovs = [
527 iovec {
528 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
529 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
530 },
531 iovec {
532 iov_base: buf.as_mut_ptr() as *mut c_void,
533 iov_len: buf.len(),
534 },
535 ];
536 // SAFETY: Safe because we own hdr and have a mutable borrow of buf, and hdr is ByteValued
537 // and it's safe to fill a byte slice with arbitrary data.
538 let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
539
540 if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
541 return Err(Error::PartialMessage);
542 } else if !hdr.is_valid() {
543 return Err(Error::InvalidMessage);
544 }
545
546 Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files))
547 }
548
549 /// Receive a message with optional payload and attached file descriptors.
550 /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
551 /// accepted and all other file descriptor will be discard silently.
552 ///
553 /// # Return:
554 /// * - (message header, message body, size of payload, [received files]) on success.
555 /// * - SocketRetry: temporary error caused by signals or short of resources.
556 /// * - SocketBroken: the underline socket is broken.
557 /// * - SocketError: other socket related errors.
558 /// * - PartialMessage: received a partial message.
559 /// * - InvalidMessage: received a invalid message.
560 #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
recv_payload_into_buf<T: ByteValued + Sized + VhostUserMsgValidator>( &mut self, buf: &mut [u8], ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)>561 pub fn recv_payload_into_buf<T: ByteValued + Sized + VhostUserMsgValidator>(
562 &mut self,
563 buf: &mut [u8],
564 ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> {
565 let mut hdr = VhostUserMsgHeader::default();
566 let mut body: T = Default::default();
567 let mut iovs = [
568 iovec {
569 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
570 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
571 },
572 iovec {
573 iov_base: (&mut body as *mut T) as *mut c_void,
574 iov_len: mem::size_of::<T>(),
575 },
576 iovec {
577 iov_base: buf.as_mut_ptr() as *mut c_void,
578 iov_len: buf.len(),
579 },
580 ];
581 // SAFETY: Safe because we own hdr and body and have a mutable borrow of buf, and
582 // hdr and body are ByteValued, and it's safe to fill a byte slice with
583 // arbitrary data.
584 let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
585
586 let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
587 if bytes < total {
588 return Err(Error::PartialMessage);
589 } else if !hdr.is_valid() || !body.is_valid() {
590 return Err(Error::InvalidMessage);
591 }
592
593 Ok((hdr, body, bytes - total, files))
594 }
595 }
596
597 impl<T: Req> AsRawFd for Endpoint<T> {
as_raw_fd(&self) -> RawFd598 fn as_raw_fd(&self) -> RawFd {
599 self.sock.as_raw_fd()
600 }
601 }
602
// Given a slice of sizes and the `skip_size`, return the position of byte `skip_size`
// as `(index of the first entry not entirely skipped, offset within that entry)`.
// Zero-length entries at the boundary are skipped as well. If `skip_size` covers the
// whole slice, the index equals `iov_lens.len()` and the offset is whatever is left over.
// For example:
//     let iov_lens = vec![4, 4, 5];
//     let size = 6;
//     assert_eq!(get_sub_iovs_offset(&iov_lens, size), (1, 2));
fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
    let mut remaining = skip_size;
    let mut nr_skip = 0;

    for &len in iov_lens {
        if remaining < len {
            // `remaining` falls inside this entry.
            break;
        }
        remaining -= len;
        nr_skip += 1;
    }
    (nr_skip, remaining)
}
622
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Seek, SeekFrom, Write};
    use std::os::unix::io::IntoRawFd;
    use vmm_sys_util::rand::rand_alphanumerics;
    use vmm_sys_util::tempfile::TempFile;

    /// Returns a random socket path under `/tmp` (nothing is created at that path).
    fn temp_path() -> PathBuf {
        PathBuf::from(format!(
            "/tmp/vhost_test_{}",
            rand_alphanumerics(8).to_str().unwrap()
        ))
    }

    #[test]
    fn create_listener() {
        let path = temp_path();
        let listener = Listener::new(path, true).unwrap();

        assert!(listener.as_raw_fd() > 0);
    }

    #[test]
    fn create_listener_from_raw_fd() {
        // A self-cleaning temporary file; its descriptor is handed over with
        // `into_raw_fd()` so that only the `Listener` closes it (no double close).
        let file = TempFile::new().unwrap().into_file();

        // SAFETY: Safe because `file` owned a valid fd to a file just created, and that
        // ownership is transferred to the listener.
        let listener = unsafe { Listener::from_raw_fd(file.into_raw_fd()) };

        assert!(listener.as_raw_fd() > 0);
    }

    #[test]
    fn accept_connection() {
        let path = temp_path();
        let listener = Listener::new(path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // accept on a fd without incoming connection
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn send_data() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        let buf1 = vec![0x1, 0x2, 0x3, 0x4];
        let mut len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..bytes]);

        len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
    }

    #[test]
    fn send_fd() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        let mut fd = TempFile::new().unwrap().into_file();
        write!(fd, "test").unwrap();

        // Normal case for sending/receiving file descriptors
        let buf1 = vec![0x1, 0x2, 0x3, 0x4];
        let len = master
            .send_slice(&buf1[..], Some(&[fd.as_raw_fd()]))
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..]);
        assert!(files.is_some());
        let files = files.unwrap();
        {
            assert_eq!(files.len(), 1);
            let mut file = &files[0];
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }

        // Following communication pattern should work:
        // Sending side: data(header, body) with fds
        // Receiving side: data(header) with fds, data(body)
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        assert!(files.is_some());
        let files = files.unwrap();
        {
            assert_eq!(files.len(), 3);
            let mut file = &files[1];
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }
        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(files.is_none());

        // Following communication pattern should not work:
        // Sending side: data(header, body) with fds
        // Receiving side: data(header), data(body) with fds
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf4) = slave.recv_data(2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf4[..]);
        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(files.is_none());

        // Following communication pattern should work:
        // Sending side: data, data with fds
        // Receiving side: data, data with fds
        let len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..]);
        assert!(files.is_none());

        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        assert!(files.is_some());
        let files = files.unwrap();
        {
            assert_eq!(files.len(), 3);
            let mut file = &files[1];
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }
        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(files.is_none());

        // Following communication pattern should not work:
        // Sending side: data1, data2 with fds
        // Receiving side: data + partial of data2, left of data2 with fds
        let len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, _) = slave.recv_data(5).unwrap();
        assert_eq!(bytes, 5);

        let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 3);
        assert!(files.is_none());

        // If the target fd array is too small, extra file descriptors will get lost.
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 4);
        assert!(files.is_some());
    }

    #[test]
    fn send_recv() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        let mut hdr1 =
            VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32);
        hdr1.set_need_reply(true);
        let features1 = 0x1u64;
        master.send_message(&hdr1, &features1, None).unwrap();

        let mut features2 = 0u64;

        // SAFETY: Safe because features2 is valid and it's an `u64`.
        let slice = unsafe {
            slice::from_raw_parts_mut(
                (&mut features2 as *mut u64) as *mut u8,
                mem::size_of::<u64>(),
            )
        };
        let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap();
        assert_eq!(hdr1, hdr2);
        assert_eq!(bytes, 8);
        assert_eq!(features1, features2);
        assert!(files.is_none());

        master.send_header(&hdr1, None).unwrap();
        let (hdr2, files) = slave.recv_header().unwrap();
        assert_eq!(hdr1, hdr2);
        assert!(files.is_none());
    }

    #[test]
    fn partial_message() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        let mut master = UnixStream::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        write!(master, "a").unwrap();
        drop(master);
        assert!(matches!(slave.recv_header(), Err(Error::PartialMessage)));
    }

    #[test]
    fn disconnected() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        // Connect and immediately close the client side, so the server only sees EOF.
        drop(UnixStream::connect(&path).unwrap());
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        assert!(matches!(slave.recv_header(), Err(Error::Disconnected)));
    }
}
904