1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 //! The protocol for vhost-user is based on the existing implementation of vhost for the Linux
5 //! Kernel. The protocol defines two sides of the communication, master and slave. Master is
6 //! the application that shares its virtqueues. Slave is the consumer of the virtqueues.
7 //!
8 //! The communication channel between the master and the slave includes two sub channels. One is
9 //! used to send requests from the master to the slave and optional replies from the slave to the
10 //! master. This sub channel is created on master startup by connecting to the slave service
11 //! endpoint. The other is used to send requests from the slave to the master and optional replies
12 //! from the master to the slave. This sub channel is created by the master issuing a
13 //! VHOST_USER_SET_SLAVE_REQ_FD request to the slave with an auxiliary file descriptor.
14 //!
15 //! Unix domain socket is used as the underlying communication channel because the master needs to
16 //! send file descriptors to the slave.
17 //!
18 //! Most messages that can be sent via the Unix domain socket implementing vhost-user have an
19 //! equivalent ioctl to the kernel implementation.
20 
21 use std::fs::File;
22 use std::io::Error as IOError;
23 
24 pub mod message;
25 pub use self::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
26 
27 mod connection;
28 pub use self::connection::Listener;
29 
30 #[cfg(feature = "vhost-user-master")]
31 mod master;
32 #[cfg(feature = "vhost-user-master")]
33 pub use self::master::{Master, VhostUserMaster};
34 #[cfg(feature = "vhost-user")]
35 mod master_req_handler;
36 #[cfg(feature = "vhost-user")]
37 pub use self::master_req_handler::{
38     MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut,
39 };
40 
41 #[cfg(feature = "vhost-user-slave")]
42 mod slave;
43 #[cfg(feature = "vhost-user-slave")]
44 pub use self::slave::SlaveListener;
45 #[cfg(feature = "vhost-user-slave")]
46 mod slave_req_handler;
47 #[cfg(feature = "vhost-user-slave")]
48 pub use self::slave_req_handler::{
49     SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut,
50 };
51 #[cfg(feature = "vhost-user-slave")]
52 mod slave_req;
53 #[cfg(feature = "vhost-user-slave")]
54 pub use self::slave_req::Slave;
55 
/// Errors for vhost-user operations.
///
/// The variants fall into three rough groups: protocol/usage errors (invalid parameters,
/// messages or features), socket errors wrapping the underlying `std::io::Error`, and failures
/// reported by a request handler. See [`Error::should_reconnect`] for which of them warrant
/// rebuilding the connection.
#[derive(Debug)]
pub enum Error {
    /// Invalid parameters.
    InvalidParam,
    /// Invalid operation; the payload is a static, human-readable reason.
    InvalidOperation(&'static str),
    /// Unsupported operation because the given virtio feature is missing.
    InactiveFeature(VhostUserVirtioFeatures),
    /// Unsupported operation because the given protocol feature hasn't been negotiated.
    InactiveOperation(VhostUserProtocolFeatures),
    /// Invalid message format, flag or content.
    InvalidMessage,
    /// Only part of a message has been sent or received successfully.
    PartialMessage,
    /// The peer disconnected from the socket.
    Disconnected,
    /// Message is too large.
    OversizedMsg,
    /// Fd array in question is too big or too small.
    IncorrectFds,
    /// Can't connect to peer.
    SocketConnect(std::io::Error),
    /// Generic socket errors.
    SocketError(std::io::Error),
    /// The socket is broken or has been closed.
    SocketBroken(std::io::Error),
    /// The socket operation should be retried.
    SocketRetry(std::io::Error),
    /// Failure from the slave side.
    SlaveInternalError,
    /// Failure from the master side.
    MasterInternalError,
    /// Virtio/protocol features mismatch.
    FeatureMismatch,
    /// Error returned by a request handler.
    ReqHandlerError(IOError),
    /// memfd file creation error.
    MemFdCreateError,
    /// File truncate error.
    ///
    /// NOTE(review): the variant name is spelled "Trucate"; it is part of the public enum, so
    /// renaming it would presumably break existing callers.
    FileTrucateError,
    /// memfd file seal error.
    MemFdSealError,
}
100 
101 impl std::fmt::Display for Error {
fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result102     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
103         match self {
104             Error::InvalidParam => write!(f, "invalid parameters"),
105             Error::InvalidOperation(reason) => write!(f, "invalid operation: {}", reason),
106             Error::InactiveFeature(bits) => write!(f, "inactive feature: {}", bits.bits()),
107             Error::InactiveOperation(bits) => {
108                 write!(f, "inactive protocol operation: {}", bits.bits())
109             }
110             Error::InvalidMessage => write!(f, "invalid message"),
111             Error::PartialMessage => write!(f, "partial message"),
112             Error::Disconnected => write!(f, "peer disconnected"),
113             Error::OversizedMsg => write!(f, "oversized message"),
114             Error::IncorrectFds => write!(f, "wrong number of attached fds"),
115             Error::SocketError(e) => write!(f, "socket error: {}", e),
116             Error::SocketConnect(e) => write!(f, "can't connect to peer: {}", e),
117             Error::SocketBroken(e) => write!(f, "socket is broken: {}", e),
118             Error::SocketRetry(e) => write!(f, "temporary socket error: {}", e),
119             Error::SlaveInternalError => write!(f, "slave internal error"),
120             Error::MasterInternalError => write!(f, "Master internal error"),
121             Error::FeatureMismatch => write!(f, "virtio/protocol features mismatch"),
122             Error::ReqHandlerError(e) => write!(f, "handler failed to handle request: {}", e),
123             Error::MemFdCreateError => {
124                 write!(f, "handler failed to allocate memfd during get_inflight_fd")
125             }
126             Error::FileTrucateError => {
127                 write!(f, "handler failed to trucate memfd during get_inflight_fd")
128             }
129             Error::MemFdSealError => write!(
130                 f,
131                 "handler failed to apply seals to memfd during get_inflight_fd"
132             ),
133         }
134     }
135 }
136 
// Marker impl: only the trait's default methods are used, so no `source()` is reported even for
// variants that wrap an I/O error (their `Display` output already includes it).
impl std::error::Error for Error {}
138 
139 impl Error {
140     /// Determine whether to rebuild the underline communication channel.
should_reconnect(&self) -> bool141     pub fn should_reconnect(&self) -> bool {
142         match *self {
143             // Should reconnect because it may be caused by temporary network errors.
144             Error::PartialMessage => true,
145             // Should reconnect because the underline socket is broken.
146             Error::SocketBroken(_) => true,
147             // Slave internal error, hope it recovers on reconnect.
148             Error::SlaveInternalError => true,
149             // Master internal error, hope it recovers on reconnect.
150             Error::MasterInternalError => true,
151             // Should just retry the IO operation instead of rebuilding the underline connection.
152             Error::SocketRetry(_) => false,
153             // Looks like the peer deliberately disconnected the socket.
154             Error::Disconnected => false,
155             Error::InvalidParam | Error::InvalidOperation(_) => false,
156             Error::InactiveFeature(_) | Error::InactiveOperation(_) => false,
157             Error::InvalidMessage | Error::IncorrectFds | Error::OversizedMsg => false,
158             Error::SocketError(_) | Error::SocketConnect(_) => false,
159             Error::FeatureMismatch => false,
160             Error::ReqHandlerError(_) => false,
161             Error::MemFdCreateError | Error::FileTrucateError | Error::MemFdSealError => false,
162         }
163     }
164 }
165 
166 impl std::convert::From<vmm_sys_util::errno::Error> for Error {
167     /// Convert raw socket errors into meaningful vhost-user errors.
168     ///
169     /// The vmm_sys_util::errno::Error is a simple wrapper over the raw errno, which doesn't means
170     /// much to the vhost-user connection manager. So convert it into meaningful errors to simplify
171     /// the connection manager logic.
172     ///
173     /// # Return:
174     /// * - Error::SocketRetry: temporary error caused by signals or short of resources.
175     /// * - Error::SocketBroken: the underline socket is broken.
176     /// * - Error::SocketError: other socket related errors.
177     #[allow(unreachable_patterns)] // EWOULDBLOCK equals to EGAIN on linux
from(err: vmm_sys_util::errno::Error) -> Self178     fn from(err: vmm_sys_util::errno::Error) -> Self {
179         match err.errno() {
180             // The socket is marked nonblocking and the requested operation would block.
181             libc::EAGAIN => Error::SocketRetry(IOError::from_raw_os_error(libc::EAGAIN)),
182             // The socket is marked nonblocking and the requested operation would block.
183             libc::EWOULDBLOCK => Error::SocketRetry(IOError::from_raw_os_error(libc::EWOULDBLOCK)),
184             // A signal occurred before any data was transmitted
185             libc::EINTR => Error::SocketRetry(IOError::from_raw_os_error(libc::EINTR)),
186             // The  output  queue  for  a network interface was full.  This generally indicates
187             // that the interface has stopped sending, but may be caused by transient congestion.
188             libc::ENOBUFS => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOBUFS)),
189             // No memory available.
190             libc::ENOMEM => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOMEM)),
191             // Connection reset by peer.
192             libc::ECONNRESET => Error::SocketBroken(IOError::from_raw_os_error(libc::ECONNRESET)),
193             // The local end has been shut down on a connection oriented socket. In this  case the
194             // process will also receive a SIGPIPE unless MSG_NOSIGNAL is set.
195             libc::EPIPE => Error::SocketBroken(IOError::from_raw_os_error(libc::EPIPE)),
196             // Write permission is denied on the destination socket file, or search permission is
197             // denied for one of the directories the path prefix.
198             libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)),
199             // Catch all other errors
200             e => Error::SocketError(IOError::from_raw_os_error(e)),
201         }
202     }
203 }
204 
/// Result of vhost-user operations, with the error type fixed to this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
207 
/// Result of request handler.
///
/// Unlike [`Result`], the error type here is a plain `std::io::Error`.
pub type HandlerResult<T> = std::result::Result<T, IOError>;
210 
/// Extract the only file from an optional list of files.
///
/// Yields `Some(file)` exactly when `files` is `Some` and holds a single element. In every other
/// case (`None`, an empty vector, or more than one file) the result is `None` and any files that
/// were passed in are dropped.
pub(crate) fn take_single_file(files: Option<Vec<File>>) -> Option<File> {
    files
        .filter(|list| list.len() == 1)
        .and_then(|mut list| list.pop())
}
220 
221 #[cfg(all(test, feature = "vhost-user-slave"))]
222 mod dummy_slave;
223 
224 #[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))]
225 mod tests {
226     use std::fs::File;
227     use std::os::unix::io::AsRawFd;
228     use std::path::{Path, PathBuf};
229     use std::sync::{Arc, Barrier, Mutex};
230     use std::thread;
231     use vmm_sys_util::rand::rand_alphanumerics;
232     use vmm_sys_util::tempfile::TempFile;
233 
234     use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES};
235     use super::message::*;
236     use super::*;
237     use crate::backend::VhostBackend;
238     use crate::{VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
239 
temp_path() -> PathBuf240     fn temp_path() -> PathBuf {
241         PathBuf::from(format!(
242             "/tmp/vhost_test_{}",
243             rand_alphanumerics(8).to_str().unwrap()
244         ))
245     }
246 
create_slave<P, S>(path: P, backend: Arc<S>) -> (Master, SlaveReqHandler<S>) where P: AsRef<Path>, S: VhostUserSlaveReqHandler,247     fn create_slave<P, S>(path: P, backend: Arc<S>) -> (Master, SlaveReqHandler<S>)
248     where
249         P: AsRef<Path>,
250         S: VhostUserSlaveReqHandler,
251     {
252         let listener = Listener::new(&path, true).unwrap();
253         let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
254         let master = Master::connect(&path, 1).unwrap();
255         (master, slave_listener.accept().unwrap().unwrap())
256     }
257 
258     #[test]
create_dummy_slave()259     fn create_dummy_slave() {
260         let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
261 
262         slave.set_owner().unwrap();
263         assert!(slave.set_owner().is_err());
264     }
265 
266     #[test]
test_set_owner()267     fn test_set_owner() {
268         let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
269         let path = temp_path();
270         let (master, mut slave) = create_slave(path, slave_be.clone());
271 
272         assert!(!slave_be.lock().unwrap().owned);
273         master.set_owner().unwrap();
274         slave.handle_request().unwrap();
275         assert!(slave_be.lock().unwrap().owned);
276         master.set_owner().unwrap();
277         assert!(slave.handle_request().is_err());
278         assert!(slave_be.lock().unwrap().owned);
279     }
280 
    // Exercises owner, virtio feature and protocol feature negotiation between a master (this
    // thread) and a slave request handler running on a spawned thread. The slave handles exactly
    // as many requests as the master issues, in the same order.
    //
    // NOTE(review): the JoinHandle of the spawned thread is discarded. If the last assertion in
    // the slave thread fails, `sbar.wait()` is never reached and this thread would block on
    // `mbar.wait()` instead of reporting the failure.
    #[test]
    fn test_set_features() {
        let mbar = Arc::new(Barrier::new(2));
        let sbar = mbar.clone();
        let path = temp_path();
        let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
        let (mut master, mut slave) = create_slave(path, slave_be.clone());

        thread::spawn(move || {
            // set_owner()
            slave.handle_request().unwrap();
            assert!(slave_be.lock().unwrap().owned);

            // get_features() / set_features(): the master acks everything except bit 0.
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();
            assert_eq!(
                slave_be.lock().unwrap().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get_protocol_features() / set_protocol_features()
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();
            assert_eq!(
                slave_be.lock().unwrap().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );

            // Rendezvous with the master thread before the handler is dropped.
            sbar.wait();
        });

        master.set_owner().unwrap();

        // set virtio features
        let features = master.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        master.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        // set vhost protocol features
        let features = master.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        master.set_protocol_features(features).unwrap();

        // Wait until the slave thread has finished its assertions.
        mbar.wait();
    }
324 
    // Walks through a full master/slave conversation: ownership, feature negotiation, inflight
    // fd exchange, memory table, config space, slave request fd, vring setup and memory region
    // add/remove. The slave thread handles one request per master call, in the same order, so
    // the two halves of this test must be kept in lock-step.
    //
    // NOTE(review): the JoinHandle of the spawned thread is discarded. If the slave thread
    // panics after the master has sent its last request, `sbar.wait()` is never reached and
    // this thread would block on `mbar.wait()` instead of reporting the failure.
    #[test]
    fn test_master_slave_process() {
        let mbar = Arc::new(Barrier::new(2));
        let sbar = mbar.clone();
        let path = temp_path();
        let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
        let (mut master, mut slave) = create_slave(path, slave_be.clone());

        thread::spawn(move || {
            // set_owner()
            slave.handle_request().unwrap();
            assert!(slave_be.lock().unwrap().owned);

            // get/set_features(): the master acks everything except bit 0.
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();
            assert_eq!(
                slave_be.lock().unwrap().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get/set_protocol_features()
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();

            let mut features = VhostUserProtocolFeatures::all();

            // Disable Xen mmap feature (mirrors what the master does below).
            if !cfg!(feature = "xen") {
                features.remove(VhostUserProtocolFeatures::XEN_MMAP);
            }

            assert_eq!(
                slave_be.lock().unwrap().acked_protocol_features,
                features.bits()
            );

            // get_inflight_fd()
            slave.handle_request().unwrap();
            // set_inflight_fd()
            slave.handle_request().unwrap();

            // get_queue_num()
            slave.handle_request().unwrap();

            // set_mem_table()
            slave.handle_request().unwrap();

            // set/get_config()
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();

            // set_slave_request_fd
            slave.handle_request().unwrap();

            // set_vring_enable
            slave.handle_request().unwrap();

            // set_log_base,set_log_fd(): handling fails on the slave side even though the
            // master-side calls below succeed.
            // NOTE(review): presumably the dummy backend rejects these and the master does not
            // wait for an acknowledgement - confirm against dummy_slave and master.
            slave.handle_request().unwrap_err();
            slave.handle_request().unwrap_err();

            // set_vring_xxx: num, base, addr, call, kick, err
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();

            // get_max_mem_slots()
            slave.handle_request().unwrap();

            // add_mem_region()
            slave.handle_request().unwrap();

            // remove_mem_region()
            slave.handle_request().unwrap();

            // Rendezvous with the master thread before the handler is dropped.
            sbar.wait();
        });

        master.set_owner().unwrap();

        // set virtio features
        let features = master.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        master.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        // set vhost protocol features
        let mut features = master.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());

        // Disable Xen mmap feature.
        if !cfg!(feature = "xen") {
            features.remove(VhostUserProtocolFeatures::XEN_MMAP);
        }

        master.set_protocol_features(features).unwrap();

        // Retrieve inflight I/O tracking information
        let (inflight_info, inflight_file) = master
            .get_inflight_fd(&VhostUserInflight {
                num_queues: 2,
                queue_size: 256,
                ..Default::default()
            })
            .unwrap();
        // Set the buffer back to the backend
        master
            .set_inflight_fd(&inflight_info, inflight_file.as_raw_fd())
            .unwrap();

        let num = master.get_queue_num().unwrap();
        assert_eq!(num, 2);

        // An eventfd stands in wherever the requests below need a file descriptor.
        let eventfd = vmm_sys_util::eventfd::EventFd::new(0).unwrap();
        let mem = [VhostUserMemoryRegionInfo::new(
            0,
            0x10_0000,
            0,
            0,
            eventfd.as_raw_fd(),
        )];
        master.set_mem_table(&mem).unwrap();

        // Write four 0xa5 bytes at config offset 0x100, then read them back.
        master
            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8; 4])
            .unwrap();
        let buf = [0x0u8; 4];
        let (reply_body, reply_payload) = master
            .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
            .unwrap();
        let offset = reply_body.offset;
        assert_eq!(offset, 0x100);
        assert_eq!(&reply_payload, &[0xa5; 4]);

        master.set_slave_request_fd(&eventfd).unwrap();
        master.set_vring_enable(0, true).unwrap();

        // Both log requests are sent successfully; the slave thread expects handling them to
        // fail (see the `unwrap_err()` calls above).
        master
            .set_log_base(
                0,
                Some(VhostUserDirtyLogRegion {
                    mmap_size: 0x1000,
                    mmap_offset: 0,
                    mmap_handle: eventfd.as_raw_fd(),
                }),
            )
            .unwrap();
        master.set_log_fd(eventfd.as_raw_fd()).unwrap();

        // Six vring requests, matching the six handled by the slave thread.
        master.set_vring_num(0, 256).unwrap();
        master.set_vring_base(0, 0).unwrap();
        let config = VringConfigData {
            queue_max_size: 256,
            queue_size: 128,
            flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
            desc_table_addr: 0x1000,
            used_ring_addr: 0x2000,
            avail_ring_addr: 0x3000,
            log_addr: Some(0x4000),
        };
        master.set_vring_addr(0, &config).unwrap();
        master.set_vring_call(0, &eventfd).unwrap();
        master.set_vring_kick(0, &eventfd).unwrap();
        master.set_vring_err(0, &eventfd).unwrap();

        let max_mem_slots = master.get_max_mem_slots().unwrap();
        assert_eq!(max_mem_slots, 32);

        // Add a second memory region backed by a temporary file, then remove it again.
        let region_file: File = TempFile::new().unwrap().into_file();
        let region =
            VhostUserMemoryRegionInfo::new(0x10_0000, 0x10_0000, 0, 0, region_file.as_raw_fd());
        master.add_mem_region(&region).unwrap();

        master.remove_mem_region(&region).unwrap();

        // Wait until the slave thread has handled every request.
        mbar.wait();
    }
504 
505     #[test]
test_error_display()506     fn test_error_display() {
507         assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters");
508         assert_eq!(
509             format!("{}", Error::InvalidOperation("reason")),
510             "invalid operation: reason"
511         );
512     }
513 
514     #[test]
test_should_reconnect()515     fn test_should_reconnect() {
516         assert!(Error::PartialMessage.should_reconnect());
517         assert!(Error::SlaveInternalError.should_reconnect());
518         assert!(Error::MasterInternalError.should_reconnect());
519         assert!(!Error::InvalidParam.should_reconnect());
520         assert!(!Error::InvalidOperation("reason").should_reconnect());
521         assert!(
522             !Error::InactiveFeature(VhostUserVirtioFeatures::PROTOCOL_FEATURES).should_reconnect()
523         );
524         assert!(!Error::InactiveOperation(VhostUserProtocolFeatures::all()).should_reconnect());
525         assert!(!Error::InvalidMessage.should_reconnect());
526         assert!(!Error::IncorrectFds.should_reconnect());
527         assert!(!Error::OversizedMsg.should_reconnect());
528         assert!(!Error::FeatureMismatch.should_reconnect());
529     }
530 
531     #[test]
test_error_from_sys_util_error()532     fn test_error_from_sys_util_error() {
533         let e: Error = vmm_sys_util::errno::Error::new(libc::EAGAIN).into();
534         if let Error::SocketRetry(e1) = e {
535             assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN);
536         } else {
537             panic!("invalid error code conversion!");
538         }
539     }
540 }
541