xref: /aosp_15_r20/external/crosvm/third_party/vmm_vhost/src/sys/unix.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
// Copyright 2022 The Chromium OS Authors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3 
//! Unix-specific code that keeps the rest of the crate platform independent.
5 
6 use std::any::Any;
7 use std::fs::File;
8 use std::io::ErrorKind;
9 use std::io::IoSlice;
10 use std::io::IoSliceMut;
11 use std::os::fd::OwnedFd;
12 use std::os::unix::net::UnixListener;
13 use std::os::unix::net::UnixStream;
14 use std::path::Path;
15 use std::path::PathBuf;
16 
17 use base::AsRawDescriptor;
18 use base::RawDescriptor;
19 use base::ReadNotifier;
20 use base::SafeDescriptor;
21 use base::ScmSocket;
22 
23 use crate::connection::Listener;
24 use crate::frontend_server::FrontendServer;
25 use crate::message::FrontendReq;
26 use crate::message::Req;
27 use crate::message::MAX_ATTACHED_FD_ENTRIES;
28 use crate::Connection;
29 use crate::Error;
30 use crate::Frontend;
31 use crate::Result;
32 
/// Platform listener type. Platform-independent code in the crate names this alias instead of
/// naming `UnixListener` directly.
pub type SystemListener = UnixListener;
35 
36 pub use SocketPlatformConnection as PlatformConnection;
37 
38 /// Unix domain socket listener for accepting incoming connections.
39 pub struct SocketListener {
40     fd: SystemListener,
41     drop_path: Option<Box<dyn Any>>,
42 }
43 
44 impl SocketListener {
45     /// Create a unix domain socket listener.
46     ///
47     /// # Return:
48     /// * - the new SocketListener object on success.
49     /// * - SocketError: failed to create listener socket.
new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self>50     pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
51         if unlink {
52             let _ = std::fs::remove_file(&path);
53         }
54         let fd = SystemListener::bind(&path).map_err(Error::SocketError)?;
55 
56         struct DropPath {
57             path: PathBuf,
58         }
59 
60         impl Drop for DropPath {
61             fn drop(&mut self) {
62                 let _ = std::fs::remove_file(&self.path);
63             }
64         }
65 
66         Ok(SocketListener {
67             fd,
68             drop_path: Some(Box::new(DropPath {
69                 path: path.as_ref().to_owned(),
70             })),
71         })
72     }
73 
74     /// Take and return the resources that the parent process needs to keep alive as long as the
75     /// child process lives, in case of incoming fork.
take_resources_for_parent(&mut self) -> Option<Box<dyn Any>>76     pub fn take_resources_for_parent(&mut self) -> Option<Box<dyn Any>> {
77         self.drop_path.take()
78     }
79 }
80 
81 impl Listener for SocketListener {
82     /// Accept an incoming connection.
83     ///
84     /// # Return:
85     /// * - Some(SystemListener): new SystemListener object if new incoming connection is available.
86     /// * - None: no incoming connection available.
87     /// * - SocketError: errors from accept().
accept(&mut self) -> Result<Option<Connection<FrontendReq>>>88     fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>> {
89         loop {
90             match self.fd.accept() {
91                 Ok((stream, _addr)) => {
92                     return Ok(Some(Connection::try_from(stream)?));
93                 }
94                 Err(e) => {
95                     match e.kind() {
96                         // No incoming connection available.
97                         ErrorKind::WouldBlock => return Ok(None),
98                         // New connection closed by peer.
99                         ErrorKind::ConnectionAborted => return Ok(None),
100                         // Interrupted by signals, retry
101                         ErrorKind::Interrupted => continue,
102                         _ => return Err(Error::SocketError(e)),
103                     }
104                 }
105             }
106         }
107     }
108 
109     /// Change blocking status on the listener.
110     ///
111     /// # Return:
112     /// * - () on success.
113     /// * - SocketError: failure from set_nonblocking().
set_nonblocking(&self, block: bool) -> Result<()>114     fn set_nonblocking(&self, block: bool) -> Result<()> {
115         self.fd.set_nonblocking(block).map_err(Error::SocketError)
116     }
117 }
118 
119 impl AsRawDescriptor for SocketListener {
as_raw_descriptor(&self) -> RawDescriptor120     fn as_raw_descriptor(&self) -> RawDescriptor {
121         self.fd.as_raw_descriptor()
122     }
123 }
124 
125 /// Unix domain socket based vhost-user connection.
126 pub struct SocketPlatformConnection {
127     sock: ScmSocket<UnixStream>,
128 }
129 
/// Advances the cursor of a list of byte slices by `count` bytes.
///
/// Slices that are fully covered by `count` are dropped from the front of `bufs` (empty slices
/// count as covered even when `count` is 0). The first remaining slice is then trimmed by the
/// leftover byte count. If `count` exceeds the total length, `bufs` ends up empty.
///
/// This mirrors the nightly `IoSlice::advance_slices`, but for `&[u8]`.
fn advance_slices(bufs: &mut &mut [&[u8]], mut count: usize) {
    // Count the leading slices that `count` bytes consume completely.
    let mut consumed = 0;
    while let Some(head) = bufs.get(consumed) {
        if head.len() > count {
            break;
        }
        count -= head.len();
        consumed += 1;
    }

    // Move the unique borrow out of `bufs` (leaving an empty slice behind) so it can be
    // narrowed and stored back.
    let remaining = std::mem::take(bufs);
    *bufs = &mut remaining[consumed..];

    // Trim the partially consumed bytes from the new first slice, if there is one.
    if let Some(first) = bufs.first_mut() {
        let head: &[u8] = *first;
        *first = &head[count..];
    }
}
148 
149 impl SocketPlatformConnection {
150     /// Sends all bytes from scatter-gather vectors with optional attached file descriptors. Will
151     /// loop until all data has been transfered.
152     ///
153     /// # TODO
154     /// This function takes a slice of `&[u8]` instead of `IoSlice` because the internal
155     /// cursor needs to be moved by `advance_slices()`.
156     /// Once `IoSlice::advance_slices()` becomes stable, this should be updated.
157     /// <https://github.com/rust-lang/rust/issues/62726>.
send_iovec_all( &self, mut iovs: &mut [&[u8]], mut fds: Option<&[RawDescriptor]>, ) -> Result<()>158     fn send_iovec_all(
159         &self,
160         mut iovs: &mut [&[u8]],
161         mut fds: Option<&[RawDescriptor]>,
162     ) -> Result<()> {
163         // Guarantee that `iovs` becomes empty if it doesn't contain any data.
164         advance_slices(&mut iovs, 0);
165 
166         while !iovs.is_empty() {
167             let iovec: Vec<_> = iovs.iter_mut().map(|i| IoSlice::new(i)).collect();
168             match self.sock.send_vectored_with_fds(&iovec, fds.unwrap_or(&[])) {
169                 Ok(n) => {
170                     fds = None;
171                     advance_slices(&mut iovs, n);
172                 }
173                 Err(e) => match e.kind() {
174                     ErrorKind::WouldBlock | ErrorKind::Interrupted => {}
175                     _ => return Err(Error::SocketError(e)),
176                 },
177             }
178         }
179         Ok(())
180     }
181 
182     /// Sends a single message over the socket with optional attached file descriptors.
183     ///
184     /// - `hdr`: vhost message header
185     /// - `body`: vhost message body (may be empty to send a header-only message)
186     /// - `payload`: additional bytes to append to `body` (may be empty)
send_message( &self, hdr: &[u8], body: &[u8], payload: &[u8], fds: Option<&[RawDescriptor]>, ) -> Result<()>187     pub fn send_message(
188         &self,
189         hdr: &[u8],
190         body: &[u8],
191         payload: &[u8],
192         fds: Option<&[RawDescriptor]>,
193     ) -> Result<()> {
194         let mut iobufs = [hdr, body, payload];
195         self.send_iovec_all(&mut iobufs, fds)
196     }
197 
198     /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
199     /// file.
200     ///
201     /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
202     /// tricky to pass file descriptors through such a communication channel. Let's assume that a
203     /// sender sending a message with some file descriptors attached. To successfully receive those
204     /// attached file descriptors, the receiver must obey following rules:
205     ///   1) file descriptors are attached to a message.
206     ///   2) message(packet) boundaries must be respected on the receive side.
207     ///
208     /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
209     /// attached file descriptors will get lost.
210     /// Note that this function wraps received file descriptors as `File`.
211     ///
212     /// # Return:
213     /// * - (number of bytes received, [received files]) on success
214     /// * - Disconnect: the connection is closed.
215     /// * - SocketRetry: temporary error caused by signals or short of resources.
216     /// * - SocketBroken: the underline socket is broken.
217     /// * - SocketError: other socket related errors.
recv_into_bufs( &self, bufs: &mut [IoSliceMut], allow_fd: bool, ) -> Result<(usize, Option<Vec<File>>)>218     pub fn recv_into_bufs(
219         &self,
220         bufs: &mut [IoSliceMut],
221         allow_fd: bool,
222     ) -> Result<(usize, Option<Vec<File>>)> {
223         let max_fds = if allow_fd { MAX_ATTACHED_FD_ENTRIES } else { 0 };
224         let (bytes, fds) = self.sock.recv_vectored_with_fds(bufs, max_fds)?;
225 
226         // 0-bytes indicates that the connection is closed.
227         if bytes == 0 {
228             return Err(Error::Disconnect);
229         }
230 
231         let files = if fds.is_empty() {
232             None
233         } else {
234             Some(fds.into_iter().map(File::from).collect())
235         };
236 
237         Ok((bytes, files))
238     }
239 }
240 
241 impl AsRawDescriptor for SocketPlatformConnection {
as_raw_descriptor(&self) -> RawDescriptor242     fn as_raw_descriptor(&self) -> RawDescriptor {
243         self.sock.as_raw_descriptor()
244     }
245 }
246 
247 impl ReadNotifier for SocketPlatformConnection {
get_read_notifier(&self) -> &dyn AsRawDescriptor248     fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
249         &self.sock
250     }
251 }
252 
253 impl<R: Req> TryFrom<SafeDescriptor> for Connection<R> {
254     type Error = Error;
255 
try_from(fd: SafeDescriptor) -> Result<Self>256     fn try_from(fd: SafeDescriptor) -> Result<Self> {
257         UnixStream::from(fd).try_into()
258     }
259 }
260 
261 impl<R: Req> TryFrom<UnixStream> for Connection<R> {
262     type Error = Error;
263 
try_from(sock: UnixStream) -> Result<Self>264     fn try_from(sock: UnixStream) -> Result<Self> {
265         Ok(Self(
266             SocketPlatformConnection {
267                 sock: sock.try_into()?,
268             },
269             std::marker::PhantomData,
270             std::marker::PhantomData,
271         ))
272     }
273 }
274 
275 impl<R: Req> Connection<R> {
276     /// Create a pair of unnamed vhost-user connections connected to each other.
pair() -> Result<(Self, Self)>277     pub fn pair() -> Result<(Self, Self)> {
278         let (client, server) = UnixStream::pair()?;
279         let client_connection = Connection::try_from(client)?;
280         let server_connection = Connection::try_from(server)?;
281         Ok((client_connection, server_connection))
282     }
283 }
284 
285 impl<S: Frontend> AsRawDescriptor for FrontendServer<S> {
as_raw_descriptor(&self) -> RawDescriptor286     fn as_raw_descriptor(&self) -> RawDescriptor {
287         self.sub_sock.as_raw_descriptor()
288     }
289 }
290 
291 impl<S: Frontend> ReadNotifier for FrontendServer<S> {
get_read_notifier(&self) -> &dyn AsRawDescriptor292     fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
293         self.sub_sock.0.get_read_notifier()
294     }
295 }
296 
297 impl<S: Frontend> FrontendServer<S> {
298     /// Create a `FrontendServer` that uses a Unix stream internally.
299     ///
300     /// The returned `SafeDescriptor` is the client side of the stream and should be sent to the
301     /// backend using [BackendClient::set_slave_request_fd()].
302     ///
303     /// [BackendClient::set_slave_request_fd()]: struct.BackendClient.html#method.set_slave_request_fd
with_stream(backend: S) -> Result<(Self, SafeDescriptor)>304     pub fn with_stream(backend: S) -> Result<(Self, SafeDescriptor)> {
305         let (tx, rx) = UnixStream::pair()?;
306         let rx_connection = Connection::try_from(rx)?;
307         Ok((
308             Self::new(backend, rx_connection)?,
309             SafeDescriptor::from(OwnedFd::from(tx)),
310         ))
311     }
312 }
313 
#[cfg(test)]
pub(crate) mod tests {
    use tempfile::Builder;
    use tempfile::TempDir;

    use super::*;
    use crate::backend_client::BackendClient;
    use crate::connection::Listener;
    use crate::message::FrontendReq;
    use crate::Connection;

    /// Creates a fresh temporary directory to hold test sockets.
    pub(crate) fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    /// Connects to the listener socket at `path` and wraps the stream as a `Connection`.
    fn connect(path: &Path) -> Result<Connection<FrontendReq>> {
        let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
        Connection::try_from(sock)
    }

    #[test]
    fn create_listener() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = SocketListener::new(&path, true).unwrap();

        assert!(listener.as_raw_descriptor() > 0);
    }

    #[test]
    fn accept_connection() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // accept on a fd without incoming connection
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn test_create_failure() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");

        // `let _ = ...` drops the listener at the end of the statement, and dropping it removes
        // the socket file.
        let _ = SocketListener::new(&path, true).unwrap();
        // The path is free again, so binding without unlinking succeeds. This listener is also
        // dropped (and its socket file removed) right away.
        assert!(SocketListener::new(&path, false).is_ok());
        // Nothing is listening on `path` any more.
        assert!(connect(&path).is_err());

        let mut listener = SocketListener::new(&path, true).unwrap();
        // While `listener` is alive its socket file exists, so binding without unlinking fails.
        assert!(SocketListener::new(&path, false).is_err());
        listener.set_nonblocking(true).unwrap();

        let backend_connection = connect(&path).unwrap();
        let _backend_client = BackendClient::new(backend_connection);
        let _server_connection = listener.accept().unwrap().unwrap();
    }

    #[test]
    fn test_advance_slices() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSlice.html#method.advance_slices
        let buf1 = [1; 8];
        let buf2 = [2; 16];
        let buf3 = [3; 8];
        let mut bufs = &mut [&buf1[..], &buf2[..], &buf3[..]][..];
        advance_slices(&mut bufs, 10);
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }
}
387