1 // Copyright 2022 The Chromium OS Authors. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3
4 //! Unix specific code that keeps rest of the code in the crate platform independent.
5
6 use std::any::Any;
7 use std::fs::File;
8 use std::io::ErrorKind;
9 use std::io::IoSlice;
10 use std::io::IoSliceMut;
11 use std::os::fd::OwnedFd;
12 use std::os::unix::net::UnixListener;
13 use std::os::unix::net::UnixStream;
14 use std::path::Path;
15 use std::path::PathBuf;
16
17 use base::AsRawDescriptor;
18 use base::RawDescriptor;
19 use base::ReadNotifier;
20 use base::SafeDescriptor;
21 use base::ScmSocket;
22
23 use crate::connection::Listener;
24 use crate::frontend_server::FrontendServer;
25 use crate::message::FrontendReq;
26 use crate::message::Req;
27 use crate::message::MAX_ATTACHED_FD_ENTRIES;
28 use crate::Connection;
29 use crate::Error;
30 use crate::Frontend;
31 use crate::Result;
32
/// Alias to enable platform independent code.
pub type SystemListener = UnixListener;

// Platform-independent name for this module's connection type.
pub use SocketPlatformConnection as PlatformConnection;

/// Unix domain socket listener for accepting incoming connections.
pub struct SocketListener {
    /// The bound listening socket.
    fd: SystemListener,
    /// Guard whose drop removes the socket file created by `SocketListener::new`.
    /// Becomes `None` once `take_resources_for_parent` has handed it out.
    drop_path: Option<Box<dyn Any>>,
}
43
44 impl SocketListener {
45 /// Create a unix domain socket listener.
46 ///
47 /// # Return:
48 /// * - the new SocketListener object on success.
49 /// * - SocketError: failed to create listener socket.
new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self>50 pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
51 if unlink {
52 let _ = std::fs::remove_file(&path);
53 }
54 let fd = SystemListener::bind(&path).map_err(Error::SocketError)?;
55
56 struct DropPath {
57 path: PathBuf,
58 }
59
60 impl Drop for DropPath {
61 fn drop(&mut self) {
62 let _ = std::fs::remove_file(&self.path);
63 }
64 }
65
66 Ok(SocketListener {
67 fd,
68 drop_path: Some(Box::new(DropPath {
69 path: path.as_ref().to_owned(),
70 })),
71 })
72 }
73
74 /// Take and return the resources that the parent process needs to keep alive as long as the
75 /// child process lives, in case of incoming fork.
take_resources_for_parent(&mut self) -> Option<Box<dyn Any>>76 pub fn take_resources_for_parent(&mut self) -> Option<Box<dyn Any>> {
77 self.drop_path.take()
78 }
79 }
80
81 impl Listener for SocketListener {
82 /// Accept an incoming connection.
83 ///
84 /// # Return:
85 /// * - Some(SystemListener): new SystemListener object if new incoming connection is available.
86 /// * - None: no incoming connection available.
87 /// * - SocketError: errors from accept().
accept(&mut self) -> Result<Option<Connection<FrontendReq>>>88 fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>> {
89 loop {
90 match self.fd.accept() {
91 Ok((stream, _addr)) => {
92 return Ok(Some(Connection::try_from(stream)?));
93 }
94 Err(e) => {
95 match e.kind() {
96 // No incoming connection available.
97 ErrorKind::WouldBlock => return Ok(None),
98 // New connection closed by peer.
99 ErrorKind::ConnectionAborted => return Ok(None),
100 // Interrupted by signals, retry
101 ErrorKind::Interrupted => continue,
102 _ => return Err(Error::SocketError(e)),
103 }
104 }
105 }
106 }
107 }
108
109 /// Change blocking status on the listener.
110 ///
111 /// # Return:
112 /// * - () on success.
113 /// * - SocketError: failure from set_nonblocking().
set_nonblocking(&self, block: bool) -> Result<()>114 fn set_nonblocking(&self, block: bool) -> Result<()> {
115 self.fd.set_nonblocking(block).map_err(Error::SocketError)
116 }
117 }
118
119 impl AsRawDescriptor for SocketListener {
as_raw_descriptor(&self) -> RawDescriptor120 fn as_raw_descriptor(&self) -> RawDescriptor {
121 self.fd.as_raw_descriptor()
122 }
123 }
124
/// Unix domain socket based vhost-user connection.
pub struct SocketPlatformConnection {
    /// Stream socket wrapper used to send and receive bytes together with attached file
    /// descriptors.
    sock: ScmSocket<UnixStream>,
}
129
/// Moves the cursor of a list of byte slices forward by `count` bytes.
///
/// Mirrors the `IoSlice::advance_slices` API, but works on plain `&[u8]` slices. Slices that are
/// fully consumed, including zero-length ones, are dropped from the front of `bufs`. The first
/// remaining slice is trimmed by whatever is left of `count`. If `count` exceeds the total
/// length, `bufs` simply becomes empty.
fn advance_slices(bufs: &mut &mut [&[u8]], mut count: usize) {
    // How many leading slices are entirely covered by `count`; `count` keeps the leftover.
    let consumed = bufs
        .iter()
        .take_while(|slice| {
            if slice.len() <= count {
                count -= slice.len();
                true
            } else {
                false
            }
        })
        .count();

    let remaining = &mut std::mem::take(bufs)[consumed..];
    if let Some(first) = remaining.first_mut() {
        // `take_while` stopped here, so this slice is strictly longer than `count`.
        let head: &[u8] = *first;
        *first = &head[count..];
    }
    *bufs = remaining;
}
148
149 impl SocketPlatformConnection {
150 /// Sends all bytes from scatter-gather vectors with optional attached file descriptors. Will
151 /// loop until all data has been transfered.
152 ///
153 /// # TODO
154 /// This function takes a slice of `&[u8]` instead of `IoSlice` because the internal
155 /// cursor needs to be moved by `advance_slices()`.
156 /// Once `IoSlice::advance_slices()` becomes stable, this should be updated.
157 /// <https://github.com/rust-lang/rust/issues/62726>.
send_iovec_all( &self, mut iovs: &mut [&[u8]], mut fds: Option<&[RawDescriptor]>, ) -> Result<()>158 fn send_iovec_all(
159 &self,
160 mut iovs: &mut [&[u8]],
161 mut fds: Option<&[RawDescriptor]>,
162 ) -> Result<()> {
163 // Guarantee that `iovs` becomes empty if it doesn't contain any data.
164 advance_slices(&mut iovs, 0);
165
166 while !iovs.is_empty() {
167 let iovec: Vec<_> = iovs.iter_mut().map(|i| IoSlice::new(i)).collect();
168 match self.sock.send_vectored_with_fds(&iovec, fds.unwrap_or(&[])) {
169 Ok(n) => {
170 fds = None;
171 advance_slices(&mut iovs, n);
172 }
173 Err(e) => match e.kind() {
174 ErrorKind::WouldBlock | ErrorKind::Interrupted => {}
175 _ => return Err(Error::SocketError(e)),
176 },
177 }
178 }
179 Ok(())
180 }
181
182 /// Sends a single message over the socket with optional attached file descriptors.
183 ///
184 /// - `hdr`: vhost message header
185 /// - `body`: vhost message body (may be empty to send a header-only message)
186 /// - `payload`: additional bytes to append to `body` (may be empty)
send_message( &self, hdr: &[u8], body: &[u8], payload: &[u8], fds: Option<&[RawDescriptor]>, ) -> Result<()>187 pub fn send_message(
188 &self,
189 hdr: &[u8],
190 body: &[u8],
191 payload: &[u8],
192 fds: Option<&[RawDescriptor]>,
193 ) -> Result<()> {
194 let mut iobufs = [hdr, body, payload];
195 self.send_iovec_all(&mut iobufs, fds)
196 }
197
198 /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
199 /// file.
200 ///
201 /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
202 /// tricky to pass file descriptors through such a communication channel. Let's assume that a
203 /// sender sending a message with some file descriptors attached. To successfully receive those
204 /// attached file descriptors, the receiver must obey following rules:
205 /// 1) file descriptors are attached to a message.
206 /// 2) message(packet) boundaries must be respected on the receive side.
207 ///
208 /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
209 /// attached file descriptors will get lost.
210 /// Note that this function wraps received file descriptors as `File`.
211 ///
212 /// # Return:
213 /// * - (number of bytes received, [received files]) on success
214 /// * - Disconnect: the connection is closed.
215 /// * - SocketRetry: temporary error caused by signals or short of resources.
216 /// * - SocketBroken: the underline socket is broken.
217 /// * - SocketError: other socket related errors.
recv_into_bufs( &self, bufs: &mut [IoSliceMut], allow_fd: bool, ) -> Result<(usize, Option<Vec<File>>)>218 pub fn recv_into_bufs(
219 &self,
220 bufs: &mut [IoSliceMut],
221 allow_fd: bool,
222 ) -> Result<(usize, Option<Vec<File>>)> {
223 let max_fds = if allow_fd { MAX_ATTACHED_FD_ENTRIES } else { 0 };
224 let (bytes, fds) = self.sock.recv_vectored_with_fds(bufs, max_fds)?;
225
226 // 0-bytes indicates that the connection is closed.
227 if bytes == 0 {
228 return Err(Error::Disconnect);
229 }
230
231 let files = if fds.is_empty() {
232 None
233 } else {
234 Some(fds.into_iter().map(File::from).collect())
235 };
236
237 Ok((bytes, files))
238 }
239 }
240
241 impl AsRawDescriptor for SocketPlatformConnection {
as_raw_descriptor(&self) -> RawDescriptor242 fn as_raw_descriptor(&self) -> RawDescriptor {
243 self.sock.as_raw_descriptor()
244 }
245 }
246
247 impl ReadNotifier for SocketPlatformConnection {
get_read_notifier(&self) -> &dyn AsRawDescriptor248 fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
249 &self.sock
250 }
251 }
252
253 impl<R: Req> TryFrom<SafeDescriptor> for Connection<R> {
254 type Error = Error;
255
try_from(fd: SafeDescriptor) -> Result<Self>256 fn try_from(fd: SafeDescriptor) -> Result<Self> {
257 UnixStream::from(fd).try_into()
258 }
259 }
260
261 impl<R: Req> TryFrom<UnixStream> for Connection<R> {
262 type Error = Error;
263
try_from(sock: UnixStream) -> Result<Self>264 fn try_from(sock: UnixStream) -> Result<Self> {
265 Ok(Self(
266 SocketPlatformConnection {
267 sock: sock.try_into()?,
268 },
269 std::marker::PhantomData,
270 std::marker::PhantomData,
271 ))
272 }
273 }
274
275 impl<R: Req> Connection<R> {
276 /// Create a pair of unnamed vhost-user connections connected to each other.
pair() -> Result<(Self, Self)>277 pub fn pair() -> Result<(Self, Self)> {
278 let (client, server) = UnixStream::pair()?;
279 let client_connection = Connection::try_from(client)?;
280 let server_connection = Connection::try_from(server)?;
281 Ok((client_connection, server_connection))
282 }
283 }
284
285 impl<S: Frontend> AsRawDescriptor for FrontendServer<S> {
as_raw_descriptor(&self) -> RawDescriptor286 fn as_raw_descriptor(&self) -> RawDescriptor {
287 self.sub_sock.as_raw_descriptor()
288 }
289 }
290
291 impl<S: Frontend> ReadNotifier for FrontendServer<S> {
get_read_notifier(&self) -> &dyn AsRawDescriptor292 fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
293 self.sub_sock.0.get_read_notifier()
294 }
295 }
296
297 impl<S: Frontend> FrontendServer<S> {
298 /// Create a `FrontendServer` that uses a Unix stream internally.
299 ///
300 /// The returned `SafeDescriptor` is the client side of the stream and should be sent to the
301 /// backend using [BackendClient::set_slave_request_fd()].
302 ///
303 /// [BackendClient::set_slave_request_fd()]: struct.BackendClient.html#method.set_slave_request_fd
with_stream(backend: S) -> Result<(Self, SafeDescriptor)>304 pub fn with_stream(backend: S) -> Result<(Self, SafeDescriptor)> {
305 let (tx, rx) = UnixStream::pair()?;
306 let rx_connection = Connection::try_from(rx)?;
307 Ok((
308 Self::new(backend, rx_connection)?,
309 SafeDescriptor::from(OwnedFd::from(tx)),
310 ))
311 }
312 }
313
#[cfg(test)]
pub(crate) mod tests {
    use tempfile::Builder;
    use tempfile::TempDir;

    use super::*;
    use crate::backend_client::BackendClient;
    use crate::connection::Listener;
    use crate::message::FrontendReq;
    use crate::Connection;

    /// Creates a fresh temporary directory (prefix `/tmp/vhost_test`) for test socket files.
    pub(crate) fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    /// Connects to the socket at `path` and wraps the stream as a frontend connection.
    fn connect(path: &Path) -> Result<Connection<FrontendReq>> {
        let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
        Connection::try_from(sock)
    }

    #[test]
    fn create_listener() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = SocketListener::new(&path, true).unwrap();

        assert!(listener.as_raw_descriptor() > 0);
    }

    #[test]
    fn accept_connection() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // accept on a fd without incoming connection
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn test_create_failure() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        // `let _ =` drops the listener right away, and dropping it removes the socket file.
        let _ = SocketListener::new(&path, true).unwrap();
        // The path was cleaned up, so binding again without unlinking must succeed. This
        // listener is a temporary too, so it is dropped (and its file removed) immediately.
        assert!(SocketListener::new(&path, false).is_ok());
        // Nothing is listening on the path any more.
        assert!(connect(&path).is_err());

        let mut listener = SocketListener::new(&path, true).unwrap();
        // The path is held by a live listener, so binding it without unlinking fails.
        assert!(SocketListener::new(&path, false).is_err());
        listener.set_nonblocking(true).unwrap();

        let backend_connection = connect(&path).unwrap();
        let _backend_client = BackendClient::new(backend_connection);
        let _server_connection = listener.accept().unwrap().unwrap();
    }

    #[test]
    fn test_advance_slices() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSlice.html#method.advance_slices
        let buf1 = [1; 8];
        let buf2 = [2; 16];
        let buf3 = [3; 8];
        let mut bufs = &mut [&buf1[..], &buf2[..], &buf3[..]][..];
        advance_slices(&mut bufs, 10);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }

    #[test]
    fn test_advance_slices_empty_and_exhausted() {
        let empty: [u8; 0] = [];
        let data = [7u8; 3];
        let mut storage = [&empty[..], &data[..], &empty[..]];
        let mut bufs: &mut [&[u8]] = &mut storage[..];

        // Advancing by zero drops only the leading empty slices.
        advance_slices(&mut bufs, 0);
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], &data[..]);

        // Consuming the remaining data also drops the trailing empty slice.
        advance_slices(&mut bufs, 3);
        assert!(bufs.is_empty());

        // Advancing past the end leaves the slices empty instead of panicking.
        let mut storage = [&data[..]];
        let mut bufs: &mut [&[u8]] = &mut storage[..];
        advance_slices(&mut bufs, 10);
        assert!(bufs.is_empty());
    }
}
387