1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 /// Support for virtual sockets.
6 use std::fmt;
7 use std::io;
8 use std::mem;
9 use std::mem::size_of;
10 use std::num::ParseIntError;
11 use std::os::raw::c_uchar;
12 use std::os::raw::c_uint;
13 use std::os::raw::c_ushort;
14 use std::os::unix::io::AsRawFd;
15 use std::os::unix::io::IntoRawFd;
16 use std::os::unix::io::RawFd;
17 use std::result;
18 use std::str::FromStr;
19
20 use libc::c_void;
21 use libc::sa_family_t;
22 use libc::size_t;
23 use libc::sockaddr;
24 use libc::socklen_t;
25 use libc::F_GETFL;
26 use libc::F_SETFL;
27 use libc::O_NONBLOCK;
28 use libc::VMADDR_CID_ANY;
29 use libc::VMADDR_CID_HOST;
30 use libc::VMADDR_CID_HYPERVISOR;
31 use thiserror::Error;
32
33 // The domain for vsock sockets.
34 const AF_VSOCK: sa_family_t = 40;
35
36 // Vsock loopback address.
37 const VMADDR_CID_LOCAL: c_uint = 1;
38
39 /// Vsock equivalent of binding on port 0. Binds to a random port.
40 pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;
41
42 // The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly
43 // from linux/vm_sockets.h.
44 const PADDING: usize = size_of::<sockaddr>()
45 - size_of::<sa_family_t>()
46 - size_of::<c_ushort>()
47 - (2 * size_of::<c_uint>());
48
49 #[repr(C)]
50 #[derive(Default)]
51 struct sockaddr_vm {
52 svm_family: sa_family_t,
53 svm_reserved1: c_ushort,
54 svm_port: c_uint,
55 svm_cid: c_uint,
56 svm_zero: [c_uchar; PADDING],
57 }
58
59 #[derive(Error, Debug)]
60 #[error("failed to parse vsock address")]
61 pub struct AddrParseError;
62
/// The vsock counterpart of an IP address: a context id (CID).
///
/// The well-known CIDs get dedicated variants; every other value is carried in
/// [`VsockCid::Cid`]. The derived ordering follows the variant declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum VsockCid {
    /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint.
    Any,
    /// An address that refers to the bare-metal machine that serves as the hypervisor.
    Hypervisor,
    /// The loopback address.
    Local,
    /// The parent machine. It may not be the hypervisor for nested VMs.
    Host,
    /// An assigned CID that serves as the address for VSOCK.
    Cid(c_uint),
}
77
78 impl fmt::Display for VsockCid {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result79 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
80 match &self {
81 VsockCid::Any => write!(fmt, "Any"),
82 VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
83 VsockCid::Local => write!(fmt, "Local"),
84 VsockCid::Host => write!(fmt, "Host"),
85 VsockCid::Cid(c) => write!(fmt, "'{}'", c),
86 }
87 }
88 }
89
90 impl From<c_uint> for VsockCid {
from(c: c_uint) -> Self91 fn from(c: c_uint) -> Self {
92 match c {
93 VMADDR_CID_ANY => VsockCid::Any,
94 VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
95 VMADDR_CID_LOCAL => VsockCid::Local,
96 VMADDR_CID_HOST => VsockCid::Host,
97 _ => VsockCid::Cid(c),
98 }
99 }
100 }
101
102 impl FromStr for VsockCid {
103 type Err = ParseIntError;
104
from_str(s: &str) -> Result<Self, Self::Err>105 fn from_str(s: &str) -> Result<Self, Self::Err> {
106 let c: c_uint = s.parse()?;
107 Ok(c.into())
108 }
109 }
110
111 impl From<VsockCid> for c_uint {
from(cid: VsockCid) -> c_uint112 fn from(cid: VsockCid) -> c_uint {
113 match cid {
114 VsockCid::Any => VMADDR_CID_ANY,
115 VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
116 VsockCid::Local => VMADDR_CID_LOCAL,
117 VsockCid::Host => VMADDR_CID_HOST,
118 VsockCid::Cid(c) => c,
119 }
120 }
121 }
122
123 /// An address associated with a virtual socket.
124 #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
125 pub struct SocketAddr {
126 pub cid: VsockCid,
127 pub port: c_uint,
128 }
129
130 pub trait ToSocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>131 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
132 }
133
134 impl ToSocketAddr for SocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>135 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
136 Ok(*self)
137 }
138 }
139
140 impl ToSocketAddr for str {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>141 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
142 self.parse()
143 }
144 }
145
146 impl ToSocketAddr for (VsockCid, c_uint) {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>147 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
148 let (cid, port) = *self;
149 Ok(SocketAddr { cid, port })
150 }
151 }
152
153 impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>154 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
155 (**self).to_socket_addr()
156 }
157 }
158
159 impl FromStr for SocketAddr {
160 type Err = AddrParseError;
161
162 /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
163 /// "vsock:cid:port".
from_str(s: &str) -> Result<SocketAddr, AddrParseError>164 fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
165 let components: Vec<&str> = s.split(':').collect();
166 if components.len() != 3 || components[0] != "vsock" {
167 return Err(AddrParseError);
168 }
169
170 Ok(SocketAddr {
171 cid: components[1].parse().map_err(|_| AddrParseError)?,
172 port: components[2].parse().map_err(|_| AddrParseError)?,
173 })
174 }
175 }
176
177 impl fmt::Display for SocketAddr {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result178 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
179 write!(fmt, "{}:{}", self.cid, self.port)
180 }
181 }
182
183 /// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
184 /// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()>185 unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
186 let flags = libc::fcntl(fd, F_GETFL, 0);
187 if flags < 0 {
188 return Err(io::Error::last_os_error());
189 }
190
191 let flags = if nonblocking {
192 flags | O_NONBLOCK
193 } else {
194 flags & !O_NONBLOCK
195 };
196
197 let ret = libc::fcntl(fd, F_SETFL, flags);
198 if ret < 0 {
199 return Err(io::Error::last_os_error());
200 }
201
202 Ok(())
203 }
204
/// A virtual socket.
///
/// Owns its file descriptor and closes it when dropped.
///
/// Do not use this type unless you need to change socket options or query the state of the
/// socket prior to calling listen or connect. Instead use either VsockStream or VsockListener.
#[derive(Debug)]
pub struct VsockSocket {
    /// The owned socket descriptor.
    fd: RawFd,
}
214
215 impl VsockSocket {
new() -> io::Result<Self>216 pub fn new() -> io::Result<Self> {
217 // SAFETY: trivially safe
218 let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
219 if fd < 0 {
220 Err(io::Error::last_os_error())
221 } else {
222 Ok(VsockSocket { fd })
223 }
224 }
225
bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()>226 pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
227 let sockaddr = addr
228 .to_socket_addr()
229 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
230
231 // The compiler should optimize this out since these are both compile-time constants.
232 assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
233
234 let svm = sockaddr_vm {
235 svm_family: AF_VSOCK,
236 svm_cid: sockaddr.cid.into(),
237 svm_port: sockaddr.port,
238 ..Default::default()
239 };
240
241 // SAFETY:
242 // Safe because this doesn't modify any memory and we check the return value.
243 let ret = unsafe {
244 libc::bind(
245 self.fd,
246 &svm as *const sockaddr_vm as *const sockaddr,
247 size_of::<sockaddr_vm>() as socklen_t,
248 )
249 };
250 if ret < 0 {
251 let bind_err = io::Error::last_os_error();
252 Err(bind_err)
253 } else {
254 Ok(())
255 }
256 }
257
connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream>258 pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
259 let sockaddr = addr
260 .to_socket_addr()
261 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
262
263 let svm = sockaddr_vm {
264 svm_family: AF_VSOCK,
265 svm_cid: sockaddr.cid.into(),
266 svm_port: sockaddr.port,
267 ..Default::default()
268 };
269
270 // SAFETY:
271 // Safe because this just connects a vsock socket, and the return value is checked.
272 let ret = unsafe {
273 libc::connect(
274 self.fd,
275 &svm as *const sockaddr_vm as *const sockaddr,
276 size_of::<sockaddr_vm>() as socklen_t,
277 )
278 };
279 if ret < 0 {
280 let connect_err = io::Error::last_os_error();
281 Err(connect_err)
282 } else {
283 Ok(VsockStream { sock: self })
284 }
285 }
286
listen(self) -> io::Result<VsockListener>287 pub fn listen(self) -> io::Result<VsockListener> {
288 // SAFETY:
289 // Safe because this doesn't modify any memory and we check the return value.
290 let ret = unsafe { libc::listen(self.fd, 1) };
291 if ret < 0 {
292 let listen_err = io::Error::last_os_error();
293 return Err(listen_err);
294 }
295 Ok(VsockListener { sock: self })
296 }
297
298 /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u32>299 pub fn local_port(&self) -> io::Result<u32> {
300 let mut svm: sockaddr_vm = Default::default();
301
302 let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
303 // SAFETY:
304 // Safe because we give a valid pointer for addrlen and check the length.
305 let ret = unsafe {
306 // Get the socket address that was actually bound.
307 libc::getsockname(
308 self.fd,
309 &mut svm as *mut sockaddr_vm as *mut sockaddr,
310 &mut addrlen as *mut socklen_t,
311 )
312 };
313 if ret < 0 {
314 let getsockname_err = io::Error::last_os_error();
315 Err(getsockname_err)
316 } else {
317 // If this doesn't match, it's not safe to get the port out of the sockaddr.
318 assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
319
320 Ok(svm.svm_port)
321 }
322 }
323
try_clone(&self) -> io::Result<Self>324 pub fn try_clone(&self) -> io::Result<Self> {
325 // SAFETY:
326 // Safe because this doesn't modify any memory and we check the return value.
327 let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
328 if dup_fd < 0 {
329 Err(io::Error::last_os_error())
330 } else {
331 Ok(Self { fd: dup_fd })
332 }
333 }
334
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>335 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
336 // SAFETY:
337 // Safe because the fd is valid and owned by this stream.
338 unsafe { set_nonblocking(self.fd, nonblocking) }
339 }
340 }
341
342 impl IntoRawFd for VsockSocket {
into_raw_fd(self) -> RawFd343 fn into_raw_fd(self) -> RawFd {
344 let fd = self.fd;
345 mem::forget(self);
346 fd
347 }
348 }
349
350 impl AsRawFd for VsockSocket {
as_raw_fd(&self) -> RawFd351 fn as_raw_fd(&self) -> RawFd {
352 self.fd
353 }
354 }
355
356 impl Drop for VsockSocket {
drop(&mut self)357 fn drop(&mut self) {
358 // SAFETY:
359 // Safe because this doesn't modify any memory and we are the only
360 // owner of the file descriptor.
361 unsafe { libc::close(self.fd) };
362 }
363 }
364
365 /// A virtual stream socket.
366 #[derive(Debug)]
367 pub struct VsockStream {
368 sock: VsockSocket,
369 }
370
371 impl VsockStream {
connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream>372 pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
373 let sock = VsockSocket::new()?;
374 sock.connect(addr)
375 }
376
377 /// Returns the port that this stream is bound to.
local_port(&self) -> io::Result<u32>378 pub fn local_port(&self) -> io::Result<u32> {
379 self.sock.local_port()
380 }
381
try_clone(&self) -> io::Result<VsockStream>382 pub fn try_clone(&self) -> io::Result<VsockStream> {
383 self.sock.try_clone().map(|f| VsockStream { sock: f })
384 }
385
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>386 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
387 self.sock.set_nonblocking(nonblocking)
388 }
389 }
390
391 impl io::Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>392 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
393 // SAFETY:
394 // Safe because this will only modify the contents of |buf| and we check the return value.
395 let ret = unsafe {
396 libc::read(
397 self.sock.as_raw_fd(),
398 buf as *mut [u8] as *mut c_void,
399 buf.len() as size_t,
400 )
401 };
402 if ret < 0 {
403 return Err(io::Error::last_os_error());
404 }
405
406 Ok(ret as usize)
407 }
408 }
409
410 impl io::Write for VsockStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>411 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
412 // SAFETY:
413 // Safe because this doesn't modify any memory and we check the return value.
414 let ret = unsafe {
415 libc::write(
416 self.sock.as_raw_fd(),
417 buf as *const [u8] as *const c_void,
418 buf.len() as size_t,
419 )
420 };
421 if ret < 0 {
422 return Err(io::Error::last_os_error());
423 }
424
425 Ok(ret as usize)
426 }
427
flush(&mut self) -> io::Result<()>428 fn flush(&mut self) -> io::Result<()> {
429 // No buffered data so nothing to do.
430 Ok(())
431 }
432 }
433
434 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd435 fn as_raw_fd(&self) -> RawFd {
436 self.sock.as_raw_fd()
437 }
438 }
439
440 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd441 fn into_raw_fd(self) -> RawFd {
442 self.sock.into_raw_fd()
443 }
444 }
445
446 /// Represents a virtual socket server.
447 #[derive(Debug)]
448 pub struct VsockListener {
449 sock: VsockSocket,
450 }
451
452 impl VsockListener {
453 /// Creates a new `VsockListener` bound to the specified port on the current virtual socket
454 /// endpoint.
bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener>455 pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
456 let mut sock = VsockSocket::new()?;
457 sock.bind(addr)?;
458 sock.listen()
459 }
460
461 /// Returns the port that this listener is bound to.
local_port(&self) -> io::Result<u32>462 pub fn local_port(&self) -> io::Result<u32> {
463 self.sock.local_port()
464 }
465
466 /// Accepts a new incoming connection on this listener. Blocks the calling thread until a
467 /// new connection is established. When established, returns the corresponding `VsockStream`
468 /// and the remote peer's address.
accept(&self) -> io::Result<(VsockStream, SocketAddr)>469 pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
470 let mut svm: sockaddr_vm = Default::default();
471
472 let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
473 // SAFETY:
474 // Safe because this will only modify |svm| and we check the return value.
475 let fd = unsafe {
476 libc::accept4(
477 self.sock.as_raw_fd(),
478 &mut svm as *mut sockaddr_vm as *mut sockaddr,
479 &mut socklen as *mut socklen_t,
480 libc::SOCK_CLOEXEC,
481 )
482 };
483 if fd < 0 {
484 return Err(io::Error::last_os_error());
485 }
486
487 if svm.svm_family != AF_VSOCK {
488 return Err(io::Error::new(
489 io::ErrorKind::InvalidData,
490 format!("unexpected address family: {}", svm.svm_family),
491 ));
492 }
493
494 Ok((
495 VsockStream {
496 sock: VsockSocket { fd },
497 },
498 SocketAddr {
499 cid: svm.svm_cid.into(),
500 port: svm.svm_port,
501 },
502 ))
503 }
504
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>505 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
506 self.sock.set_nonblocking(nonblocking)
507 }
508 }
509
510 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd511 fn as_raw_fd(&self) -> RawFd {
512 self.sock.as_raw_fd()
513 }
514 }
515