xref: /aosp_15_r20/external/crosvm/third_party/vmm_vhost/src/lib.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright (C) 2019 Alibaba Cloud. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
3 
4 //! Virtio Vhost Backend Drivers
5 //!
6 //! Virtio devices use virtqueues to transport data efficiently. The first generation of virtqueue
7 //! is a set of three different single-producer, single-consumer ring structures designed to store
8 //! generic scatter-gather I/O. The virtio specification 1.1 introduces an alternative compact
9 //! virtqueue layout named "Packed Virtqueue", which is more friendly to memory cache system and
10 //! hardware implemented virtio devices. The packed virtqueue uses read-write memory, that means
11 //! the memory will be both read and written by both host and guest. The new Packed Virtqueue is
12 //! preferred for performance.
13 //!
//! Vhost is a mechanism to improve the performance of Virtio devices by delegating data plane
//! operations to dedicated IO service processes. Only the configuration, I/O submission
//! notification, and I/O completion interruption are piped through the hypervisor.
17 //! It uses the same virtqueue layout as Virtio to allow Vhost devices to be mapped directly to
18 //! Virtio devices. This allows a Vhost device to be accessed directly by a guest OS inside a
19 //! hypervisor process with an existing Virtio (PCI) driver.
20 //!
21 //! The initial vhost implementation is a part of the Linux kernel and uses ioctl interface to
22 //! communicate with userspace applications. Dedicated kernel worker threads are created to handle
23 //! IO requests from the guest.
24 //!
//! Later, the Vhost-user protocol was introduced to complement the ioctl interface used to control
//! the vhost implementation in the Linux kernel. It implements the control plane needed to
//! establish virtqueue sharing with a user space process on the same host. It uses communication
//! over a Unix domain socket to share file descriptors in the ancillary data of the message. The
//! protocol defines 2 sides of the communication, frontend and backend. Frontend is the application
//! that shares its virtqueues. Backend is the consumer of the virtqueues. Frontend and backend can
//! be either a client (i.e. connecting) or server (listening) in the socket communication.
32 
33 use std::fs::File;
34 use std::io::Error as IOError;
35 use std::num::TryFromIntError;
36 
37 use remain::sorted;
38 use thiserror::Error as ThisError;
39 
40 mod backend;
41 pub use backend::*;
42 
43 pub mod message;
44 pub use message::VHOST_USER_F_PROTOCOL_FEATURES;
45 
46 pub mod connection;
47 
48 mod sys;
49 pub use connection::Connection;
50 pub use message::BackendReq;
51 pub use message::FrontendReq;
52 #[cfg(unix)]
53 pub use sys::unix;
54 
55 pub(crate) mod backend_client;
56 pub use backend_client::BackendClient;
57 mod frontend_server;
58 pub use self::frontend_server::Frontend;
59 mod backend_server;
60 mod frontend_client;
61 pub use self::backend_server::Backend;
62 pub use self::backend_server::BackendServer;
63 pub use self::frontend_client::FrontendClient;
64 pub use self::frontend_server::FrontendServer;
65 
66 /// Errors for vhost-user operations
67 #[sorted]
68 #[derive(Debug, ThisError)]
69 pub enum Error {
70     /// Failure from the backend side.
71     #[error("backend internal error")]
72     BackendInternalError,
73     /// client exited properly.
74     #[error("client exited properly")]
75     ClientExit,
76     /// Failure to deserialize data.
77     #[error("failed to deserialize data")]
78     DeserializationFailed,
79     /// client disconnected.
80     /// If connection is closed properly, use `ClientExit` instead.
81     #[error("client closed the connection")]
82     Disconnect,
83     #[error("Failed to enter suspended state")]
84     EnterSuspendedState(anyhow::Error),
85     /// Virtio/protocol features mismatch.
86     #[error("virtio features mismatch")]
87     FeatureMismatch,
88     /// Failure from the frontend side.
89     #[error("frontend Internal error")]
90     FrontendInternalError,
91     /// Fd array in question is too big or too small
92     #[error("wrong number of attached fds")]
93     IncorrectFds,
94     /// Invalid cast to int.
95     #[error("invalid cast to int: {0}")]
96     InvalidCastToInt(TryFromIntError),
97     /// Invalid message format, flag or content.
98     #[error("invalid message")]
99     InvalidMessage,
100     /// Unsupported operations due to that the protocol feature hasn't been negotiated.
101     #[error("invalid operation")]
102     InvalidOperation,
103     /// Invalid parameters.
104     #[error("invalid parameters")]
105     InvalidParam,
106     /// Message is too large
107     #[error("oversized message")]
108     OversizedMsg,
109     /// Only part of a message have been sent or received successfully
110     #[error("partial message")]
111     PartialMessage,
112     /// Provided recv buffer was too small, and data was dropped.
113     #[error("buffer for recv was too small, data was dropped: got size {got}, needed {want}")]
114     RecvBufferTooSmall {
115         /// The size of the buffer received.
116         got: usize,
117         /// The expected size of the buffer.
118         want: usize,
119     },
120     /// Error from request handler
121     #[error("handler failed to handle request: {0}")]
122     ReqHandlerError(IOError),
123     /// Failure to restore.
124     #[error("Failed to restore")]
125     RestoreError(anyhow::Error),
126     /// Failure to serialize data.
127     #[error("failed to serialize data")]
128     SerializationFailed,
129     /// Failure to run device specific sleep.
130     #[error("Failed to run device specific sleep: {0}")]
131     SleepError(anyhow::Error),
132     /// Failure to snapshot.
133     #[error("Failed to snapshot")]
134     SnapshotError(anyhow::Error),
135     /// The socket is broken or has been closed.
136     #[error("socket is broken: {0}")]
137     SocketBroken(std::io::Error),
138     /// Can't connect to peer.
139     #[error("can't connect to peer: {0}")]
140     SocketConnect(std::io::Error),
141     /// Generic socket errors.
142     #[error("socket error: {0}")]
143     SocketError(std::io::Error),
144     /// Fail to get socket from the fd
145     #[error("Failed get socket from the fd: {0}")]
146     SocketFromFdError(std::path::PathBuf),
147     /// Should retry the socket operation again.
148     #[error("temporary socket error: {0}")]
149     SocketRetry(std::io::Error),
150     /// Failure to stop a queue.
151     #[error("failed to stop queue")]
152     StopQueueError(anyhow::Error),
153     /// Error from tx/rx on a Tube.
154     #[error("failed to read/write on Tube: {0}")]
155     TubeError(base::TubeError),
156     /// Error from VFIO device.
157     #[error("error occurred in VFIO device: {0}")]
158     VfioDeviceError(anyhow::Error),
159     /// Error from invalid vring index.
160     #[error("Vring index not found: {0}")]
161     VringIndexNotFound(usize),
162     /// Failure to run device specific wake.
163     #[error("Failed to run device specific wake: {0}")]
164     WakeError(anyhow::Error),
165 }
166 
167 impl From<base::TubeError> for Error {
from(err: base::TubeError) -> Self168     fn from(err: base::TubeError) -> Self {
169         match err {
170             base::TubeError::Disconnected => Error::Disconnect,
171             err => Error::TubeError(err),
172         }
173     }
174 }
175 
176 impl From<std::io::Error> for Error {
from(err: std::io::Error) -> Self177     fn from(err: std::io::Error) -> Self {
178         Error::SocketError(err)
179     }
180 }
181 
182 impl From<base::Error> for Error {
183     /// Convert raw socket errors into meaningful vhost-user errors.
184     ///
185     /// The base::Error is a simple wrapper over the raw errno, which doesn't means
186     /// much to the vhost-user connection manager. So convert it into meaningful errors to simplify
187     /// the connection manager logic.
188     ///
189     /// # Return:
190     /// * - Error::SocketRetry: temporary error caused by signals or short of resources.
191     /// * - Error::SocketBroken: the underline socket is broken.
192     /// * - Error::SocketError: other socket related errors.
193     #[allow(unreachable_patterns)] // EWOULDBLOCK equals to EGAIN on linux
from(err: base::Error) -> Self194     fn from(err: base::Error) -> Self {
195         match err.errno() {
196             // Retry:
197             // * EAGAIN, EWOULDBLOCK: The socket is marked nonblocking and the requested operation
198             //   would block.
199             // * EINTR: A signal occurred before any data was transmitted
200             // * ENOBUFS: The  output  queue  for  a network interface was full.  This generally
201             //   indicates that the interface has stopped sending, but may be caused by transient
202             //   congestion.
203             // * ENOMEM: No memory available.
204             libc::EAGAIN | libc::EWOULDBLOCK | libc::EINTR | libc::ENOBUFS | libc::ENOMEM => {
205                 Error::SocketRetry(err.into())
206             }
207             // Broken:
208             // * ECONNRESET: Connection reset by peer.
209             // * EPIPE: The local end has been shut down on a connection oriented socket. In this
210             //   case the process will also receive a SIGPIPE unless MSG_NOSIGNAL is set.
211             libc::ECONNRESET | libc::EPIPE => Error::SocketBroken(err.into()),
212             // Write permission is denied on the destination socket file, or search permission is
213             // denied for one of the directories the path prefix.
214             libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)),
215             // Catch all other errors
216             e => Error::SocketError(IOError::from_raw_os_error(e)),
217         }
218     }
219 }
220 
/// Result of vhost-user operations; failures are reported as this crate's `Error`.
pub type Result<T> = std::result::Result<T, Error>;

/// Result of request handler; failures are reported as a standard I/O error (`IOError`).
pub type HandlerResult<T> = std::result::Result<T, IOError>;
226 
/// Extracts the only file from `files`.
///
/// Yields `Some(file)` when the vector holds exactly one file, and `None` when it is empty or
/// holds more than one file.
pub(crate) fn into_single_file(mut files: Vec<File>) -> Option<File> {
    match files.len() {
        // With exactly one element, popping the last element returns that element.
        1 => files.pop(),
        _ => None,
    }
}
235 
236 #[cfg(test)]
237 mod test_backend;
238 
239 #[cfg(test)]
240 mod tests {
241     use std::sync::Arc;
242     use std::sync::Barrier;
243     use std::thread;
244 
245     use base::AsRawDescriptor;
246     use tempfile::tempfile;
247 
248     use super::*;
249     use crate::message::*;
250     use crate::test_backend::TestBackend;
251     use crate::test_backend::VIRTIO_FEATURES;
252     use crate::VhostUserMemoryRegionInfo;
253     use crate::VringConfigData;
254 
create_client_server_pair<S>(backend: S) -> (BackendClient, BackendServer<S>) where S: Backend,255     fn create_client_server_pair<S>(backend: S) -> (BackendClient, BackendServer<S>)
256     where
257         S: Backend,
258     {
259         let (client_connection, server_connection) = Connection::pair().unwrap();
260         let backend_client = BackendClient::new(client_connection);
261         (
262             backend_client,
263             BackendServer::<S>::new(server_connection, backend),
264         )
265     }
266 
267     /// Utility function to process a header and a message together.
handle_request(h: &mut BackendServer<TestBackend>) -> Result<()>268     fn handle_request(h: &mut BackendServer<TestBackend>) -> Result<()> {
269         // We assume that a header comes together with message body in tests so we don't wait before
270         // calling `process_message()`.
271         let (hdr, files) = h.recv_header()?;
272         h.process_message(hdr, files)
273     }
274 
275     #[test]
create_test_backend()276     fn create_test_backend() {
277         let mut backend = TestBackend::new();
278 
279         backend.set_owner().unwrap();
280         assert!(backend.set_owner().is_err());
281     }
282 
283     #[test]
test_set_owner()284     fn test_set_owner() {
285         let test_backend = TestBackend::new();
286         let (backend_client, mut backend_server) = create_client_server_pair(test_backend);
287 
288         assert!(!backend_server.as_ref().owned);
289         backend_client.set_owner().unwrap();
290         handle_request(&mut backend_server).unwrap();
291         assert!(backend_server.as_ref().owned);
292         backend_client.set_owner().unwrap();
293         assert!(handle_request(&mut backend_server).is_err());
294         assert!(backend_server.as_ref().owned);
295     }
296 
297     #[test]
test_set_features()298     fn test_set_features() {
299         let mbar = Arc::new(Barrier::new(2));
300         let sbar = mbar.clone();
301         let test_backend = TestBackend::new();
302         let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);
303 
304         thread::spawn(move || {
305             handle_request(&mut backend_server).unwrap();
306             assert!(backend_server.as_ref().owned);
307 
308             handle_request(&mut backend_server).unwrap();
309             handle_request(&mut backend_server).unwrap();
310             assert_eq!(
311                 backend_server.as_ref().acked_features,
312                 VIRTIO_FEATURES & !0x1
313             );
314 
315             handle_request(&mut backend_server).unwrap();
316             handle_request(&mut backend_server).unwrap();
317             assert_eq!(
318                 backend_server.as_ref().acked_protocol_features,
319                 VhostUserProtocolFeatures::all().bits()
320             );
321 
322             sbar.wait();
323         });
324 
325         backend_client.set_owner().unwrap();
326 
327         // set virtio features
328         let features = backend_client.get_features().unwrap();
329         assert_eq!(features, VIRTIO_FEATURES);
330         backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();
331 
332         // set vhost protocol features
333         let features = backend_client.get_protocol_features().unwrap();
334         assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
335         backend_client.set_protocol_features(features).unwrap();
336 
337         mbar.wait();
338     }
339 
340     #[test]
test_client_server_process()341     fn test_client_server_process() {
342         let mbar = Arc::new(Barrier::new(2));
343         let sbar = mbar.clone();
344         let test_backend = TestBackend::new();
345         let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);
346 
347         thread::spawn(move || {
348             // set_own()
349             handle_request(&mut backend_server).unwrap();
350             assert!(backend_server.as_ref().owned);
351 
352             // get/set_features()
353             handle_request(&mut backend_server).unwrap();
354             handle_request(&mut backend_server).unwrap();
355             assert_eq!(
356                 backend_server.as_ref().acked_features,
357                 VIRTIO_FEATURES & !0x1
358             );
359 
360             handle_request(&mut backend_server).unwrap();
361             handle_request(&mut backend_server).unwrap();
362             assert_eq!(
363                 backend_server.as_ref().acked_protocol_features,
364                 VhostUserProtocolFeatures::all().bits()
365             );
366 
367             // get_inflight_fd()
368             handle_request(&mut backend_server).unwrap();
369             // set_inflight_fd()
370             handle_request(&mut backend_server).unwrap();
371 
372             // get_queue_num()
373             handle_request(&mut backend_server).unwrap();
374 
375             // set_mem_table()
376             handle_request(&mut backend_server).unwrap();
377 
378             // get/set_config()
379             handle_request(&mut backend_server).unwrap();
380             handle_request(&mut backend_server).unwrap();
381 
382             // set_backend_req_fd
383             handle_request(&mut backend_server).unwrap();
384 
385             // set_vring_enable
386             handle_request(&mut backend_server).unwrap();
387 
388             // set_log_base,set_log_fd()
389             handle_request(&mut backend_server).unwrap_err();
390             handle_request(&mut backend_server).unwrap_err();
391 
392             // set_vring_xxx
393             handle_request(&mut backend_server).unwrap();
394             handle_request(&mut backend_server).unwrap();
395             handle_request(&mut backend_server).unwrap();
396             handle_request(&mut backend_server).unwrap();
397             handle_request(&mut backend_server).unwrap();
398             handle_request(&mut backend_server).unwrap();
399 
400             // get_max_mem_slots()
401             handle_request(&mut backend_server).unwrap();
402 
403             // add_mem_region()
404             handle_request(&mut backend_server).unwrap();
405 
406             // remove_mem_region()
407             handle_request(&mut backend_server).unwrap();
408 
409             sbar.wait();
410         });
411 
412         backend_client.set_owner().unwrap();
413 
414         // set virtio features
415         let features = backend_client.get_features().unwrap();
416         assert_eq!(features, VIRTIO_FEATURES);
417         backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();
418 
419         // set vhost protocol features
420         let features = backend_client.get_protocol_features().unwrap();
421         assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
422         backend_client.set_protocol_features(features).unwrap();
423 
424         // Retrieve inflight I/O tracking information
425         let (inflight_info, inflight_file) = backend_client
426             .get_inflight_fd(&VhostUserInflight {
427                 num_queues: 2,
428                 queue_size: 256,
429                 ..Default::default()
430             })
431             .unwrap();
432         // Set the buffer back to the backend
433         backend_client
434             .set_inflight_fd(&inflight_info, inflight_file.as_raw_descriptor())
435             .unwrap();
436 
437         let num = backend_client.get_queue_num().unwrap();
438         assert_eq!(num, 2);
439 
440         let event = base::Event::new().unwrap();
441         let mem = [VhostUserMemoryRegionInfo {
442             guest_phys_addr: 0,
443             memory_size: 0x10_0000,
444             userspace_addr: 0,
445             mmap_offset: 0,
446             mmap_handle: event.as_raw_descriptor(),
447         }];
448         backend_client.set_mem_table(&mem).unwrap();
449 
450         backend_client
451             .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
452             .unwrap();
453         let buf = [0x0u8; 4];
454         let (reply_body, reply_payload) = backend_client
455             .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
456             .unwrap();
457         let offset = reply_body.offset;
458         assert_eq!(offset, 0x100);
459         assert_eq!(reply_payload[0], 0xa5);
460 
461         #[cfg(windows)]
462         let tubes = base::Tube::pair().unwrap();
463         #[cfg(windows)]
464         let descriptor =
465             // SAFETY:
466             // Safe because we will be importing the Tube in the other thread.
467             unsafe { tube_transporter::packed_tube::pack(tubes.0, std::process::id()).unwrap() };
468 
469         #[cfg(unix)]
470         let descriptor = base::Event::new().unwrap();
471 
472         backend_client.set_backend_req_fd(&descriptor).unwrap();
473         backend_client.set_vring_enable(0, true).unwrap();
474 
475         // unimplemented yet
476         backend_client
477             .set_log_base(0, Some(event.as_raw_descriptor()))
478             .unwrap();
479         backend_client
480             .set_log_fd(event.as_raw_descriptor())
481             .unwrap();
482 
483         backend_client.set_vring_num(0, 256).unwrap();
484         backend_client.set_vring_base(0, 0).unwrap();
485         let config = VringConfigData {
486             queue_size: 128,
487             flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
488             desc_table_addr: 0x1000,
489             used_ring_addr: 0x2000,
490             avail_ring_addr: 0x3000,
491             log_addr: Some(0x4000),
492         };
493         backend_client.set_vring_addr(0, &config).unwrap();
494         backend_client.set_vring_call(0, &event).unwrap();
495         backend_client.set_vring_kick(0, &event).unwrap();
496         backend_client.set_vring_err(0, &event).unwrap();
497 
498         let max_mem_slots = backend_client.get_max_mem_slots().unwrap();
499         assert_eq!(max_mem_slots, 32);
500 
501         let region_file = tempfile().unwrap();
502         let region = VhostUserMemoryRegionInfo {
503             guest_phys_addr: 0x10_0000,
504             memory_size: 0x10_0000,
505             userspace_addr: 0,
506             mmap_offset: 0,
507             mmap_handle: region_file.as_raw_descriptor(),
508         };
509         backend_client.add_mem_region(&region).unwrap();
510 
511         backend_client.remove_mem_region(&region).unwrap();
512 
513         mbar.wait();
514     }
515 
516     #[test]
test_error_display()517     fn test_error_display() {
518         assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters");
519         assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
520     }
521 
522     #[test]
test_error_from_base_error()523     fn test_error_from_base_error() {
524         let e: Error = base::Error::new(libc::EAGAIN).into();
525         if let Error::SocketRetry(e1) = e {
526             assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN);
527         } else {
528             panic!("invalid error code conversion!");
529         }
530     }
531 }
532