xref: /aosp_15_r20/external/crosvm/base/src/sys/linux/vsock.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
//! Support for virtual sockets.
6 use std::fmt;
7 use std::io;
8 use std::mem;
9 use std::mem::size_of;
10 use std::num::ParseIntError;
11 use std::os::raw::c_uchar;
12 use std::os::raw::c_uint;
13 use std::os::raw::c_ushort;
14 use std::os::unix::io::AsRawFd;
15 use std::os::unix::io::IntoRawFd;
16 use std::os::unix::io::RawFd;
17 use std::result;
18 use std::str::FromStr;
19 
20 use libc::c_void;
21 use libc::sa_family_t;
22 use libc::size_t;
23 use libc::sockaddr;
24 use libc::socklen_t;
25 use libc::F_GETFL;
26 use libc::F_SETFL;
27 use libc::O_NONBLOCK;
28 use libc::VMADDR_CID_ANY;
29 use libc::VMADDR_CID_HOST;
30 use libc::VMADDR_CID_HYPERVISOR;
31 use thiserror::Error;
32 
// Address family number identifying vsock sockets; stored in `sockaddr_vm::svm_family`.
const AF_VSOCK: sa_family_t = 40;

// CID of the vsock loopback address.
const VMADDR_CID_LOCAL: c_uint = 1;

/// Vsock equivalent of binding on port 0. Binds to a random port.
pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;

// Number of trailing padding bytes in `sockaddr_vm`: whatever remains of a generic `sockaddr`
// once the family, reserved, port and CID fields are accounted for. Taken directly from
// linux/vm_sockets.h.
const PADDING: usize = size_of::<sockaddr>()
    - size_of::<sa_family_t>()
    - size_of::<c_ushort>()
    - (2 * size_of::<c_uint>());
48 
// C-layout vsock socket address. Pointers to this struct are cast to `*const sockaddr` /
// `*mut sockaddr` when calling into libc, so the field order and the trailing padding must
// not change. `Default` yields an all-zero address.
#[repr(C)]
#[derive(Default)]
struct sockaddr_vm {
    // Address family of the socket address.
    svm_family: sa_family_t,
    // Reserved; left at its default of zero by this file.
    svm_reserved1: c_ushort,
    // Port number.
    svm_port: c_uint,
    // Context id (the vsock equivalent of an IP address).
    svm_cid: c_uint,
    // Padding that brings the struct up to the size of `sockaddr`.
    svm_zero: [c_uchar; PADDING],
}
58 
/// Error returned when a value cannot be converted into a vsock [`SocketAddr`], for example
/// when a string is not of the form "vsock:cid:port".
#[derive(Error, Debug)]
#[error("failed to parse vsock address")]
pub struct AddrParseError;
62 
/// The vsock equivalent of an IP address.
///
/// The well-known context ids get their own variants; every other value is carried verbatim
/// in [`VsockCid::Cid`]. The derived ordering follows the declaration order of the variants.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum VsockCid {
    /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint.
    /// Corresponds to `VMADDR_CID_ANY`.
    Any,
    /// An address that refers to the bare-metal machine that serves as the hypervisor.
    /// Corresponds to `VMADDR_CID_HYPERVISOR`.
    Hypervisor,
    /// The loopback address. Corresponds to `VMADDR_CID_LOCAL`.
    Local,
    /// The parent machine. It may not be the hypervisor for nested VMs.
    /// Corresponds to `VMADDR_CID_HOST`.
    Host,
    /// An assigned CID that serves as the address for VSOCK.
    Cid(c_uint),
}
77 
78 impl fmt::Display for VsockCid {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result79     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
80         match &self {
81             VsockCid::Any => write!(fmt, "Any"),
82             VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
83             VsockCid::Local => write!(fmt, "Local"),
84             VsockCid::Host => write!(fmt, "Host"),
85             VsockCid::Cid(c) => write!(fmt, "'{}'", c),
86         }
87     }
88 }
89 
90 impl From<c_uint> for VsockCid {
from(c: c_uint) -> Self91     fn from(c: c_uint) -> Self {
92         match c {
93             VMADDR_CID_ANY => VsockCid::Any,
94             VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
95             VMADDR_CID_LOCAL => VsockCid::Local,
96             VMADDR_CID_HOST => VsockCid::Host,
97             _ => VsockCid::Cid(c),
98         }
99     }
100 }
101 
102 impl FromStr for VsockCid {
103     type Err = ParseIntError;
104 
from_str(s: &str) -> Result<Self, Self::Err>105     fn from_str(s: &str) -> Result<Self, Self::Err> {
106         let c: c_uint = s.parse()?;
107         Ok(c.into())
108     }
109 }
110 
111 impl From<VsockCid> for c_uint {
from(cid: VsockCid) -> c_uint112     fn from(cid: VsockCid) -> c_uint {
113         match cid {
114             VsockCid::Any => VMADDR_CID_ANY,
115             VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
116             VsockCid::Local => VMADDR_CID_LOCAL,
117             VsockCid::Host => VMADDR_CID_HOST,
118             VsockCid::Cid(c) => c,
119         }
120     }
121 }
122 
/// An address associated with a virtual socket.
///
/// Pairs a context id with a port, much like an IP address and port pair.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct SocketAddr {
    /// Context id of the endpoint.
    pub cid: VsockCid,
    /// Port number on that endpoint.
    pub port: c_uint,
}
129 
/// Conversion into a vsock [`SocketAddr`].
///
/// Lets functions that take an address accept any implementing type through one generic
/// parameter.
pub trait ToSocketAddr {
    /// Converts `self` into a `SocketAddr`, or returns `AddrParseError` if `self` does not
    /// describe a valid vsock address.
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
}
133 
134 impl ToSocketAddr for SocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>135     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
136         Ok(*self)
137     }
138 }
139 
140 impl ToSocketAddr for str {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>141     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
142         self.parse()
143     }
144 }
145 
146 impl ToSocketAddr for (VsockCid, c_uint) {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>147     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
148         let (cid, port) = *self;
149         Ok(SocketAddr { cid, port })
150     }
151 }
152 
153 impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>154     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
155         (**self).to_socket_addr()
156     }
157 }
158 
159 impl FromStr for SocketAddr {
160     type Err = AddrParseError;
161 
162     /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
163     /// "vsock:cid:port".
from_str(s: &str) -> Result<SocketAddr, AddrParseError>164     fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
165         let components: Vec<&str> = s.split(':').collect();
166         if components.len() != 3 || components[0] != "vsock" {
167             return Err(AddrParseError);
168         }
169 
170         Ok(SocketAddr {
171             cid: components[1].parse().map_err(|_| AddrParseError)?,
172             port: components[2].parse().map_err(|_| AddrParseError)?,
173         })
174     }
175 }
176 
177 impl fmt::Display for SocketAddr {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result178     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
179         write!(fmt, "{}:{}", self.cid, self.port)
180     }
181 }
182 
183 /// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
184 /// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()>185 unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
186     let flags = libc::fcntl(fd, F_GETFL, 0);
187     if flags < 0 {
188         return Err(io::Error::last_os_error());
189     }
190 
191     let flags = if nonblocking {
192         flags | O_NONBLOCK
193     } else {
194         flags & !O_NONBLOCK
195     };
196 
197     let ret = libc::fcntl(fd, F_SETFL, flags);
198     if ret < 0 {
199         return Err(io::Error::last_os_error());
200     }
201 
202     Ok(())
203 }
204 
/// A virtual socket.
///
/// Do not use this type unless you need to change socket options or query the
/// state of the socket prior to calling listen or connect. Instead use either VsockStream or
/// VsockListener.
#[derive(Debug)]
pub struct VsockSocket {
    // Raw descriptor of the underlying socket.
    fd: RawFd,
}
214 
215 impl VsockSocket {
new() -> io::Result<Self>216     pub fn new() -> io::Result<Self> {
217         // SAFETY: trivially safe
218         let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
219         if fd < 0 {
220             Err(io::Error::last_os_error())
221         } else {
222             Ok(VsockSocket { fd })
223         }
224     }
225 
bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()>226     pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
227         let sockaddr = addr
228             .to_socket_addr()
229             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
230 
231         // The compiler should optimize this out since these are both compile-time constants.
232         assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
233 
234         let svm = sockaddr_vm {
235             svm_family: AF_VSOCK,
236             svm_cid: sockaddr.cid.into(),
237             svm_port: sockaddr.port,
238             ..Default::default()
239         };
240 
241         // SAFETY:
242         // Safe because this doesn't modify any memory and we check the return value.
243         let ret = unsafe {
244             libc::bind(
245                 self.fd,
246                 &svm as *const sockaddr_vm as *const sockaddr,
247                 size_of::<sockaddr_vm>() as socklen_t,
248             )
249         };
250         if ret < 0 {
251             let bind_err = io::Error::last_os_error();
252             Err(bind_err)
253         } else {
254             Ok(())
255         }
256     }
257 
connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream>258     pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
259         let sockaddr = addr
260             .to_socket_addr()
261             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
262 
263         let svm = sockaddr_vm {
264             svm_family: AF_VSOCK,
265             svm_cid: sockaddr.cid.into(),
266             svm_port: sockaddr.port,
267             ..Default::default()
268         };
269 
270         // SAFETY:
271         // Safe because this just connects a vsock socket, and the return value is checked.
272         let ret = unsafe {
273             libc::connect(
274                 self.fd,
275                 &svm as *const sockaddr_vm as *const sockaddr,
276                 size_of::<sockaddr_vm>() as socklen_t,
277             )
278         };
279         if ret < 0 {
280             let connect_err = io::Error::last_os_error();
281             Err(connect_err)
282         } else {
283             Ok(VsockStream { sock: self })
284         }
285     }
286 
listen(self) -> io::Result<VsockListener>287     pub fn listen(self) -> io::Result<VsockListener> {
288         // SAFETY:
289         // Safe because this doesn't modify any memory and we check the return value.
290         let ret = unsafe { libc::listen(self.fd, 1) };
291         if ret < 0 {
292             let listen_err = io::Error::last_os_error();
293             return Err(listen_err);
294         }
295         Ok(VsockListener { sock: self })
296     }
297 
298     /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u32>299     pub fn local_port(&self) -> io::Result<u32> {
300         let mut svm: sockaddr_vm = Default::default();
301 
302         let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
303         // SAFETY:
304         // Safe because we give a valid pointer for addrlen and check the length.
305         let ret = unsafe {
306             // Get the socket address that was actually bound.
307             libc::getsockname(
308                 self.fd,
309                 &mut svm as *mut sockaddr_vm as *mut sockaddr,
310                 &mut addrlen as *mut socklen_t,
311             )
312         };
313         if ret < 0 {
314             let getsockname_err = io::Error::last_os_error();
315             Err(getsockname_err)
316         } else {
317             // If this doesn't match, it's not safe to get the port out of the sockaddr.
318             assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
319 
320             Ok(svm.svm_port)
321         }
322     }
323 
try_clone(&self) -> io::Result<Self>324     pub fn try_clone(&self) -> io::Result<Self> {
325         // SAFETY:
326         // Safe because this doesn't modify any memory and we check the return value.
327         let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
328         if dup_fd < 0 {
329             Err(io::Error::last_os_error())
330         } else {
331             Ok(Self { fd: dup_fd })
332         }
333     }
334 
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>335     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
336         // SAFETY:
337         // Safe because the fd is valid and owned by this stream.
338         unsafe { set_nonblocking(self.fd, nonblocking) }
339     }
340 }
341 
342 impl IntoRawFd for VsockSocket {
into_raw_fd(self) -> RawFd343     fn into_raw_fd(self) -> RawFd {
344         let fd = self.fd;
345         mem::forget(self);
346         fd
347     }
348 }
349 
350 impl AsRawFd for VsockSocket {
as_raw_fd(&self) -> RawFd351     fn as_raw_fd(&self) -> RawFd {
352         self.fd
353     }
354 }
355 
356 impl Drop for VsockSocket {
drop(&mut self)357     fn drop(&mut self) {
358         // SAFETY:
359         // Safe because this doesn't modify any memory and we are the only
360         // owner of the file descriptor.
361         unsafe { libc::close(self.fd) };
362     }
363 }
364 
365 /// A virtual stream socket.
366 #[derive(Debug)]
367 pub struct VsockStream {
368     sock: VsockSocket,
369 }
370 
371 impl VsockStream {
connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream>372     pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
373         let sock = VsockSocket::new()?;
374         sock.connect(addr)
375     }
376 
377     /// Returns the port that this stream is bound to.
local_port(&self) -> io::Result<u32>378     pub fn local_port(&self) -> io::Result<u32> {
379         self.sock.local_port()
380     }
381 
try_clone(&self) -> io::Result<VsockStream>382     pub fn try_clone(&self) -> io::Result<VsockStream> {
383         self.sock.try_clone().map(|f| VsockStream { sock: f })
384     }
385 
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>386     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
387         self.sock.set_nonblocking(nonblocking)
388     }
389 }
390 
391 impl io::Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>392     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
393         // SAFETY:
394         // Safe because this will only modify the contents of |buf| and we check the return value.
395         let ret = unsafe {
396             libc::read(
397                 self.sock.as_raw_fd(),
398                 buf as *mut [u8] as *mut c_void,
399                 buf.len() as size_t,
400             )
401         };
402         if ret < 0 {
403             return Err(io::Error::last_os_error());
404         }
405 
406         Ok(ret as usize)
407     }
408 }
409 
410 impl io::Write for VsockStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>411     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
412         // SAFETY:
413         // Safe because this doesn't modify any memory and we check the return value.
414         let ret = unsafe {
415             libc::write(
416                 self.sock.as_raw_fd(),
417                 buf as *const [u8] as *const c_void,
418                 buf.len() as size_t,
419             )
420         };
421         if ret < 0 {
422             return Err(io::Error::last_os_error());
423         }
424 
425         Ok(ret as usize)
426     }
427 
flush(&mut self) -> io::Result<()>428     fn flush(&mut self) -> io::Result<()> {
429         // No buffered data so nothing to do.
430         Ok(())
431     }
432 }
433 
434 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd435     fn as_raw_fd(&self) -> RawFd {
436         self.sock.as_raw_fd()
437     }
438 }
439 
440 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd441     fn into_raw_fd(self) -> RawFd {
442         self.sock.into_raw_fd()
443     }
444 }
445 
/// Represents a virtual socket server.
///
/// Wraps a [`VsockSocket`] that has been put into the listening state.
#[derive(Debug)]
pub struct VsockListener {
    // The listening socket.
    sock: VsockSocket,
}
451 
452 impl VsockListener {
453     /// Creates a new `VsockListener` bound to the specified port on the current virtual socket
454     /// endpoint.
bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener>455     pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
456         let mut sock = VsockSocket::new()?;
457         sock.bind(addr)?;
458         sock.listen()
459     }
460 
461     /// Returns the port that this listener is bound to.
local_port(&self) -> io::Result<u32>462     pub fn local_port(&self) -> io::Result<u32> {
463         self.sock.local_port()
464     }
465 
466     /// Accepts a new incoming connection on this listener.  Blocks the calling thread until a
467     /// new connection is established.  When established, returns the corresponding `VsockStream`
468     /// and the remote peer's address.
accept(&self) -> io::Result<(VsockStream, SocketAddr)>469     pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
470         let mut svm: sockaddr_vm = Default::default();
471 
472         let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
473         // SAFETY:
474         // Safe because this will only modify |svm| and we check the return value.
475         let fd = unsafe {
476             libc::accept4(
477                 self.sock.as_raw_fd(),
478                 &mut svm as *mut sockaddr_vm as *mut sockaddr,
479                 &mut socklen as *mut socklen_t,
480                 libc::SOCK_CLOEXEC,
481             )
482         };
483         if fd < 0 {
484             return Err(io::Error::last_os_error());
485         }
486 
487         if svm.svm_family != AF_VSOCK {
488             return Err(io::Error::new(
489                 io::ErrorKind::InvalidData,
490                 format!("unexpected address family: {}", svm.svm_family),
491             ));
492         }
493 
494         Ok((
495             VsockStream {
496                 sock: VsockSocket { fd },
497             },
498             SocketAddr {
499                 cid: svm.svm_cid.into(),
500                 port: svm.svm_port,
501             },
502         ))
503     }
504 
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>505     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
506         self.sock.set_nonblocking(nonblocking)
507     }
508 }
509 
510 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd511     fn as_raw_fd(&self) -> RawFd {
512         self.sock.as_raw_fd()
513     }
514 }
515