xref: /aosp_15_r20/external/crosvm/cros_async/src/sys/linux/tokio_source.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2024 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::io;
6 use std::os::fd::AsRawFd;
7 use std::os::fd::OwnedFd;
8 use std::os::fd::RawFd;
9 use std::sync::Arc;
10 
11 use base::add_fd_flags;
12 use base::clone_descriptor;
13 use base::linux::fallocate;
14 use base::linux::FallocateMode;
15 use base::AsRawDescriptor;
16 use base::VolatileSlice;
17 use tokio::io::unix::AsyncFd;
18 
19 use crate::mem::MemRegion;
20 use crate::AsyncError;
21 use crate::AsyncResult;
22 use crate::BackingMemory;
23 
24 #[derive(Debug, thiserror::Error)]
25 pub enum Error {
26     #[error("Failed to copy the FD for the polling context: '{0}'")]
27     DuplicatingFd(base::Error),
28     #[error("Failed to punch hole in file: '{0}'.")]
29     Fallocate(base::Error),
30     #[error("Failed to fdatasync: '{0}'")]
31     Fdatasync(io::Error),
32     #[error("Failed to fsync: '{0}'")]
33     Fsync(io::Error),
34     #[error("Failed to join task: '{0}'")]
35     Join(tokio::task::JoinError),
36     #[error("Cannot wait on file descriptor")]
37     NonWaitable,
38     #[error("Failed to read: '{0}'")]
39     Read(io::Error),
40     #[error("Failed to set nonblocking: '{0}'")]
41     SettingNonBlocking(base::Error),
42     #[error("Tokio Async FD error: '{0}'")]
43     TokioAsyncFd(io::Error),
44     #[error("Failed to write: '{0}'")]
45     Write(io::Error),
46 }
47 
48 impl From<Error> for io::Error {
from(e: Error) -> Self49     fn from(e: Error) -> Self {
50         use Error::*;
51         match e {
52             DuplicatingFd(e) => e.into(),
53             Fallocate(e) => e.into(),
54             Fdatasync(e) => e,
55             Fsync(e) => e,
56             Join(e) => io::Error::new(io::ErrorKind::Other, e),
57             NonWaitable => io::Error::new(io::ErrorKind::Other, e),
58             Read(e) => e,
59             SettingNonBlocking(e) => e.into(),
60             TokioAsyncFd(e) => e,
61             Write(e) => e,
62         }
63     }
64 }
65 
66 enum FdType {
67     Async(AsyncFd<Arc<OwnedFd>>),
68     Blocking(Arc<OwnedFd>),
69 }
70 
71 impl AsRawFd for FdType {
as_raw_fd(&self) -> RawFd72     fn as_raw_fd(&self) -> RawFd {
73         match self {
74             FdType::Async(async_fd) => async_fd.as_raw_fd(),
75             FdType::Blocking(blocking) => blocking.as_raw_fd(),
76         }
77     }
78 }
79 
80 impl From<Error> for AsyncError {
from(e: Error) -> AsyncError81     fn from(e: Error) -> AsyncError {
82         AsyncError::SysVariants(e.into())
83     }
84 }
85 
do_fdatasync(raw: Arc<OwnedFd>) -> io::Result<()>86 fn do_fdatasync(raw: Arc<OwnedFd>) -> io::Result<()> {
87     let fd = raw.as_raw_fd();
88     // SAFETY: we partially own `raw`
89     match unsafe { libc::fdatasync(fd) } {
90         0 => Ok(()),
91         _ => Err(io::Error::last_os_error()),
92     }
93 }
94 
do_fsync(raw: Arc<OwnedFd>) -> io::Result<()>95 fn do_fsync(raw: Arc<OwnedFd>) -> io::Result<()> {
96     let fd = raw.as_raw_fd();
97     // SAFETY: we partially own `raw`
98     match unsafe { libc::fsync(fd) } {
99         0 => Ok(()),
100         _ => Err(io::Error::last_os_error()),
101     }
102 }
103 
do_read_vectored( raw: Arc<OwnedFd>, file_offset: Option<u64>, io_vecs: &[VolatileSlice], ) -> io::Result<usize>104 fn do_read_vectored(
105     raw: Arc<OwnedFd>,
106     file_offset: Option<u64>,
107     io_vecs: &[VolatileSlice],
108 ) -> io::Result<usize> {
109     let ptr = io_vecs.as_ptr() as *const libc::iovec;
110     let len = io_vecs.len() as i32;
111     let fd = raw.as_raw_fd();
112     let res = match file_offset {
113         // SAFETY: we partially own `raw`, `io_vecs` is validated
114         Some(off) => unsafe { libc::preadv64(fd, ptr, len, off as libc::off64_t) },
115         // SAFETY: we partially own `raw`, `io_vecs` is validated
116         None => unsafe { libc::readv(fd, ptr, len) },
117     };
118     match res {
119         r if r >= 0 => Ok(res as usize),
120         _ => Err(io::Error::last_os_error()),
121     }
122 }
do_read(raw: Arc<OwnedFd>, file_offset: Option<u64>, buf: &mut [u8]) -> io::Result<usize>123 fn do_read(raw: Arc<OwnedFd>, file_offset: Option<u64>, buf: &mut [u8]) -> io::Result<usize> {
124     let fd = raw.as_raw_fd();
125     let ptr = buf.as_mut_ptr() as *mut libc::c_void;
126     let res = match file_offset {
127         // SAFETY: we partially own `raw`, `ptr` has space up to vec.len()
128         Some(off) => unsafe { libc::pread64(fd, ptr, buf.len(), off as libc::off64_t) },
129         // SAFETY: we partially own `raw`, `ptr` has space up to vec.len()
130         None => unsafe { libc::read(fd, ptr, buf.len()) },
131     };
132     match res {
133         r if r >= 0 => Ok(res as usize),
134         _ => Err(io::Error::last_os_error()),
135     }
136 }
137 
do_write(raw: Arc<OwnedFd>, file_offset: Option<u64>, buf: &[u8]) -> io::Result<usize>138 fn do_write(raw: Arc<OwnedFd>, file_offset: Option<u64>, buf: &[u8]) -> io::Result<usize> {
139     let fd = raw.as_raw_fd();
140     let ptr = buf.as_ptr() as *const libc::c_void;
141     let res = match file_offset {
142         // SAFETY: we partially own `raw`, `ptr` has data up to vec.len()
143         Some(off) => unsafe { libc::pwrite64(fd, ptr, buf.len(), off as libc::off64_t) },
144         // SAFETY: we partially own `raw`, `ptr` has data up to vec.len()
145         None => unsafe { libc::write(fd, ptr, buf.len()) },
146     };
147     match res {
148         r if r >= 0 => Ok(res as usize),
149         _ => Err(io::Error::last_os_error()),
150     }
151 }
152 
do_write_vectored( raw: Arc<OwnedFd>, file_offset: Option<u64>, io_vecs: &[VolatileSlice], ) -> io::Result<usize>153 fn do_write_vectored(
154     raw: Arc<OwnedFd>,
155     file_offset: Option<u64>,
156     io_vecs: &[VolatileSlice],
157 ) -> io::Result<usize> {
158     let ptr = io_vecs.as_ptr() as *const libc::iovec;
159     let len = io_vecs.len() as i32;
160     let fd = raw.as_raw_fd();
161     let res = match file_offset {
162         // SAFETY: we partially own `raw`, `io_vecs` is validated
163         Some(off) => unsafe { libc::pwritev64(fd, ptr, len, off as libc::off64_t) },
164         // SAFETY: we partially own `raw`, `io_vecs` is validated
165         None => unsafe { libc::writev(fd, ptr, len) },
166     };
167     match res {
168         r if r >= 0 => Ok(res as usize),
169         _ => Err(io::Error::last_os_error()),
170     }
171 }
172 
173 pub struct TokioSource<T> {
174     fd: FdType,
175     inner: T,
176     runtime: tokio::runtime::Handle,
177 }
178 impl<T: AsRawDescriptor> TokioSource<T> {
new(inner: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>, Error>179     pub fn new(inner: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>, Error> {
180         let _guard = runtime.enter(); // Required for AsyncFd
181         let safe_fd = clone_descriptor(&inner).map_err(Error::DuplicatingFd)?;
182         let fd_arc: Arc<OwnedFd> = Arc::new(safe_fd.into());
183         let fd = match AsyncFd::new(fd_arc.clone()) {
184             Ok(async_fd) => {
185                 add_fd_flags(async_fd.get_ref().as_raw_descriptor(), libc::O_NONBLOCK)
186                     .map_err(Error::SettingNonBlocking)?;
187                 FdType::Async(async_fd)
188             }
189             Err(e) if e.kind() == io::ErrorKind::PermissionDenied => FdType::Blocking(fd_arc),
190             Err(e) => return Err(Error::TokioAsyncFd(e)),
191         };
192         Ok(TokioSource { fd, inner, runtime })
193     }
194 
as_source(&self) -> &T195     pub fn as_source(&self) -> &T {
196         &self.inner
197     }
198 
as_source_mut(&mut self) -> &mut T199     pub fn as_source_mut(&mut self) -> &mut T {
200         &mut self.inner
201     }
202 
clone_fd(&self) -> Arc<OwnedFd>203     fn clone_fd(&self) -> Arc<OwnedFd> {
204         match &self.fd {
205             FdType::Async(async_fd) => async_fd.get_ref().clone(),
206             FdType::Blocking(blocking) => blocking.clone(),
207         }
208     }
209 
fdatasync(&self) -> AsyncResult<()>210     pub async fn fdatasync(&self) -> AsyncResult<()> {
211         let fd = self.clone_fd();
212         Ok(self
213             .runtime
214             .spawn_blocking(move || do_fdatasync(fd))
215             .await
216             .map_err(Error::Join)?
217             .map_err(Error::Fdatasync)?)
218     }
219 
fsync(&self) -> AsyncResult<()>220     pub async fn fsync(&self) -> AsyncResult<()> {
221         let fd = self.clone_fd();
222         Ok(self
223             .runtime
224             .spawn_blocking(move || do_fsync(fd))
225             .await
226             .map_err(Error::Join)?
227             .map_err(Error::Fsync)?)
228     }
229 
into_source(self) -> T230     pub fn into_source(self) -> T {
231         self.inner
232     }
233 
read_to_vec( &self, file_offset: Option<u64>, mut vec: Vec<u8>, ) -> AsyncResult<(usize, Vec<u8>)>234     pub async fn read_to_vec(
235         &self,
236         file_offset: Option<u64>,
237         mut vec: Vec<u8>,
238     ) -> AsyncResult<(usize, Vec<u8>)> {
239         Ok(match &self.fd {
240             FdType::Async(async_fd) => {
241                 let res = async_fd
242                     .async_io(tokio::io::Interest::READABLE, |fd| {
243                         do_read(fd.clone(), file_offset, &mut vec)
244                     })
245                     .await
246                     .map_err(AsyncError::Io)?;
247                 (res, vec)
248             }
249             FdType::Blocking(blocking) => {
250                 let fd = blocking.clone();
251                 self.runtime
252                     .spawn_blocking(move || {
253                         let size = do_read(fd, file_offset, &mut vec)?;
254                         Ok((size, vec))
255                     })
256                     .await
257                     .map_err(Error::Join)?
258                     .map_err(Error::Read)?
259             }
260         })
261     }
262 
read_to_mem( &self, file_offset: Option<u64>, mem: Arc<dyn BackingMemory + Send + Sync>, mem_offsets: impl IntoIterator<Item = MemRegion>, ) -> AsyncResult<usize>263     pub async fn read_to_mem(
264         &self,
265         file_offset: Option<u64>,
266         mem: Arc<dyn BackingMemory + Send + Sync>,
267         mem_offsets: impl IntoIterator<Item = MemRegion>,
268     ) -> AsyncResult<usize> {
269         let mem_offsets_vec: Vec<MemRegion> = mem_offsets.into_iter().collect();
270         Ok(match &self.fd {
271             FdType::Async(async_fd) => {
272                 let iovecs = mem_offsets_vec
273                     .into_iter()
274                     .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
275                     .collect::<Vec<VolatileSlice>>();
276                 async_fd
277                     .async_io(tokio::io::Interest::READABLE, |fd| {
278                         do_read_vectored(fd.clone(), file_offset, &iovecs)
279                     })
280                     .await
281                     .map_err(AsyncError::Io)?
282             }
283             FdType::Blocking(blocking) => {
284                 let fd = blocking.clone();
285                 self.runtime
286                     .spawn_blocking(move || {
287                         let iovecs = mem_offsets_vec
288                             .into_iter()
289                             .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
290                             .collect::<Vec<VolatileSlice>>();
291                         do_read_vectored(fd, file_offset, &iovecs)
292                     })
293                     .await
294                     .map_err(Error::Join)?
295                     .map_err(Error::Read)?
296             }
297         })
298     }
299 
punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()>300     pub async fn punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
301         let fd = self.clone_fd();
302         Ok(self
303             .runtime
304             .spawn_blocking(move || fallocate(&*fd, FallocateMode::PunchHole, file_offset, len))
305             .await
306             .map_err(Error::Join)?
307             .map_err(Error::Fallocate)?)
308     }
309 
wait_readable(&self) -> AsyncResult<()>310     pub async fn wait_readable(&self) -> AsyncResult<()> {
311         match &self.fd {
312             FdType::Async(async_fd) => async_fd
313                 .readable()
314                 .await
315                 .map_err(crate::AsyncError::Io)?
316                 .retain_ready(),
317             FdType::Blocking(_) => return Err(Error::NonWaitable.into()),
318         }
319         Ok(())
320     }
321 
write_from_mem( &self, file_offset: Option<u64>, mem: Arc<dyn BackingMemory + Send + Sync>, mem_offsets: impl IntoIterator<Item = MemRegion>, ) -> AsyncResult<usize>322     pub async fn write_from_mem(
323         &self,
324         file_offset: Option<u64>,
325         mem: Arc<dyn BackingMemory + Send + Sync>,
326         mem_offsets: impl IntoIterator<Item = MemRegion>,
327     ) -> AsyncResult<usize> {
328         let mem_offsets_vec: Vec<MemRegion> = mem_offsets.into_iter().collect();
329         Ok(match &self.fd {
330             FdType::Async(async_fd) => {
331                 let iovecs = mem_offsets_vec
332                     .into_iter()
333                     .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
334                     .collect::<Vec<VolatileSlice>>();
335                 async_fd
336                     .async_io(tokio::io::Interest::WRITABLE, |fd| {
337                         do_write_vectored(fd.clone(), file_offset, &iovecs)
338                     })
339                     .await
340                     .map_err(AsyncError::Io)?
341             }
342             FdType::Blocking(blocking) => {
343                 let fd = blocking.clone();
344                 self.runtime
345                     .spawn_blocking(move || {
346                         let iovecs = mem_offsets_vec
347                             .into_iter()
348                             .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
349                             .collect::<Vec<VolatileSlice>>();
350                         do_write_vectored(fd, file_offset, &iovecs)
351                     })
352                     .await
353                     .map_err(Error::Join)?
354                     .map_err(Error::Read)?
355             }
356         })
357     }
358 
write_from_vec( &self, file_offset: Option<u64>, vec: Vec<u8>, ) -> AsyncResult<(usize, Vec<u8>)>359     pub async fn write_from_vec(
360         &self,
361         file_offset: Option<u64>,
362         vec: Vec<u8>,
363     ) -> AsyncResult<(usize, Vec<u8>)> {
364         Ok(match &self.fd {
365             FdType::Async(async_fd) => {
366                 let res = async_fd
367                     .async_io(tokio::io::Interest::WRITABLE, |fd| {
368                         do_write(fd.clone(), file_offset, &vec)
369                     })
370                     .await
371                     .map_err(AsyncError::Io)?;
372                 (res, vec)
373             }
374             FdType::Blocking(blocking) => {
375                 let fd = blocking.clone();
376                 self.runtime
377                     .spawn_blocking(move || {
378                         let size = do_write(fd.clone(), file_offset, &vec)?;
379                         Ok((size, vec))
380                     })
381                     .await
382                     .map_err(Error::Join)?
383                     .map_err(Error::Read)?
384             }
385         })
386     }
387 
write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()>388     pub async fn write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
389         let fd = self.clone_fd();
390         Ok(self
391             .runtime
392             .spawn_blocking(move || fallocate(&*fd, FallocateMode::ZeroRange, file_offset, len))
393             .await
394             .map_err(Error::Join)?
395             .map_err(Error::Fallocate)?)
396     }
397 }
398