xref: /aosp_15_r20/external/crosvm/cros_async/src/sys/windows/tokio_source.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::fs::File;
6 use std::io;
7 use std::io::Read;
8 use std::io::Seek;
9 use std::io::SeekFrom;
10 use std::io::Write;
11 use std::mem::ManuallyDrop;
12 use std::sync::Arc;
13 
14 use base::AsRawDescriptor;
15 use base::FileReadWriteAtVolatile;
16 use base::FileReadWriteVolatile;
17 use base::FromRawDescriptor;
18 use base::PunchHole;
19 use base::VolatileSlice;
20 use base::WriteZeroesAt;
21 use smallvec::SmallVec;
22 use sync::Mutex;
23 
24 use crate::mem::MemRegion;
25 use crate::AsyncError;
26 use crate::AsyncResult;
27 use crate::BackingMemory;
28 
/// Errors produced by `TokioSource` operations.
///
/// Most variants carry the lower-level error that caused the failure and print it as part of
/// their message. The `From` impls in this file convert this type into `AsyncError` and
/// `io::Error`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Seeking the source file to the requested offset failed.
    #[error("An error occurred trying to seek: {0}.")]
    IoSeekError(io::Error),
    /// Reading from the source file failed.
    #[error("An error occurred trying to read: {0}.")]
    IoReadError(io::Error),
    /// Writing to the source file failed.
    #[error("An error occurred trying to write: {0}.")]
    IoWriteError(io::Error),
    /// Flushing the source file failed.
    #[error("An error occurred trying to flush: {0}.")]
    IoFlushError(io::Error),
    /// Punching a hole in the source file failed.
    #[error("An error occurred trying to punch hole: {0}.")]
    IoPunchHoleError(io::Error),
    /// Writing zeroes to the source file failed.
    #[error("An error occurred trying to write zeroes: {0}.")]
    IoWriteZeroesError(io::Error),
    /// Awaiting the spawned blocking task failed.
    #[error("Failed to join task: '{0}'")]
    Join(tokio::task::JoinError),
    /// Duplicating the source handles failed (not constructed by the code visible in this file).
    #[error("An error occurred trying to duplicate source handles: {0}.")]
    HandleDuplicationFailed(io::Error),
    /// Waiting on the source handles failed (not constructed by the code visible in this file).
    #[error("An error occurred trying to wait on source handles: {0}.")]
    HandleWaitFailed(base::Error),
    /// A memory region could not be turned into a `VolatileSlice` of the backing memory.
    #[error("An error occurred trying to get a VolatileSlice into BackingMemory: {0}.")]
    BackingMemoryVolatileSliceFetchFailed(crate::mem::Error),
    /// The `TokioSource` no longer exists (not constructed by the code visible in this file).
    #[error("TokioSource is gone, so no handles are available to fulfill the IO request.")]
    NoTokioSource,
    /// The operation found the shared file already released, which happens once the
    /// `TokioSource` has been dropped or consumed by `into_source`.
    #[error("Operation on TokioSource is cancelled.")]
    OperationCancelled,
    /// The operation was aborted unexpectedly (not constructed by the code visible in this file).
    #[error("Operation on TokioSource was aborted (unexpected).")]
    OperationAborted,
}
58 
59 impl From<Error> for AsyncError {
from(e: Error) -> AsyncError60     fn from(e: Error) -> AsyncError {
61         AsyncError::SysVariants(e.into())
62     }
63 }
64 
65 impl From<Error> for io::Error {
from(e: Error) -> Self66     fn from(e: Error) -> Self {
67         use Error::*;
68         match e {
69             IoSeekError(e) => e,
70             IoReadError(e) => e,
71             IoWriteError(e) => e,
72             IoFlushError(e) => e,
73             IoPunchHoleError(e) => e,
74             IoWriteZeroesError(e) => e,
75             Join(e) => io::Error::new(io::ErrorKind::Other, e),
76             HandleDuplicationFailed(e) => e,
77             HandleWaitFailed(e) => e.into(),
78             BackingMemoryVolatileSliceFetchFailed(e) => io::Error::new(io::ErrorKind::Other, e),
79             NoTokioSource => io::Error::new(io::ErrorKind::Other, NoTokioSource),
80             OperationCancelled => io::Error::new(io::ErrorKind::Interrupted, OperationCancelled),
81             OperationAborted => io::Error::new(io::ErrorKind::Interrupted, OperationAborted),
82         }
83     }
84 }
85 
/// Result type whose error is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
87 
/// An IO source whose file operations are executed as blocking tasks on a tokio runtime.
///
/// `source_file` is a `File` built from the raw descriptor of `source`, so both refer to the same
/// underlying descriptor.
pub struct TokioSource<T: AsRawDescriptor> {
    /// The wrapped source. Kept in an `Option` so that `into_source` can move it out.
    source: Option<T>,
    /// `File` created from `source`'s raw descriptor and shared with the blocking tasks.
    /// It is wrapped in `ManuallyDrop`, so releasing it does not run `File`'s destructor.
    /// `into_source` and `Drop` set it to `None`; operations that then find `None` fail with
    /// `Error::OperationCancelled`.
    source_file: Arc<Mutex<Option<ManuallyDrop<File>>>>,
    /// Handle to the tokio runtime used for `spawn_blocking`.
    runtime: tokio::runtime::Handle,
}
93 
94 impl<T: AsRawDescriptor> TokioSource<T> {
new(source: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>>95     pub(crate) fn new(source: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>> {
96         let descriptor = source.as_raw_descriptor();
97         // SAFETY: The Drop implementation makes sure `source` outlives `source_file`.
98         let source_file = unsafe { ManuallyDrop::new(File::from_raw_descriptor(descriptor)) };
99         Ok(Self {
100             source: Some(source),
101             source_file: Arc::new(Mutex::new(Some(source_file))),
102             runtime,
103         })
104     }
105     #[inline]
get_slices( mem: &Arc<dyn BackingMemory + Send + Sync>, mem_offsets: Vec<MemRegion>, ) -> Result<SmallVec<[VolatileSlice<'_>; 16]>>106     fn get_slices(
107         mem: &Arc<dyn BackingMemory + Send + Sync>,
108         mem_offsets: Vec<MemRegion>,
109     ) -> Result<SmallVec<[VolatileSlice<'_>; 16]>> {
110         mem_offsets
111             .into_iter()
112             .map(|region| {
113                 mem.get_volatile_slice(region)
114                     .map_err(Error::BackingMemoryVolatileSliceFetchFailed)
115             })
116             .collect::<Result<SmallVec<[VolatileSlice; 16]>>>()
117     }
as_source(&self) -> &T118     pub fn as_source(&self) -> &T {
119         self.source.as_ref().unwrap()
120     }
as_source_mut(&mut self) -> &mut T121     pub fn as_source_mut(&mut self) -> &mut T {
122         self.source.as_mut().unwrap()
123     }
fdatasync(&self) -> AsyncResult<()>124     pub async fn fdatasync(&self) -> AsyncResult<()> {
125         // TODO(b/282003931): Fall back to regular fsync.
126         self.fsync().await
127     }
fsync(&self) -> AsyncResult<()>128     pub async fn fsync(&self) -> AsyncResult<()> {
129         let source_file = self.source_file.clone();
130         Ok(self
131             .runtime
132             .spawn_blocking(move || {
133                 source_file
134                     .lock()
135                     .as_mut()
136                     .ok_or(Error::OperationCancelled)?
137                     .flush()
138                     .map_err(Error::IoFlushError)
139             })
140             .await
141             .map_err(Error::Join)??)
142     }
into_source(mut self) -> T143     pub fn into_source(mut self) -> T {
144         self.source_file.lock().take();
145         self.source.take().unwrap()
146     }
punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()>147     pub async fn punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
148         let source_file = self.source_file.clone();
149         Ok(self
150             .runtime
151             .spawn_blocking(move || {
152                 source_file
153                     .lock()
154                     .as_mut()
155                     .ok_or(Error::OperationCancelled)?
156                     .punch_hole(file_offset, len)
157                     .map_err(Error::IoPunchHoleError)
158             })
159             .await
160             .map_err(Error::Join)??)
161     }
read_to_mem( &self, file_offset: Option<u64>, mem: Arc<dyn BackingMemory + Send + Sync>, mem_offsets: impl IntoIterator<Item = MemRegion>, ) -> AsyncResult<usize>162     pub async fn read_to_mem(
163         &self,
164         file_offset: Option<u64>,
165         mem: Arc<dyn BackingMemory + Send + Sync>,
166         mem_offsets: impl IntoIterator<Item = MemRegion>,
167     ) -> AsyncResult<usize> {
168         let mem_offsets = mem_offsets.into_iter().collect();
169         let source_file = self.source_file.clone();
170         Ok(self
171             .runtime
172             .spawn_blocking(move || {
173                 let mut file_lock = source_file.lock();
174                 let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?;
175                 let memory_slices = Self::get_slices(&mem, mem_offsets)?;
176                 match file_offset {
177                     Some(file_offset) => file
178                         .read_vectored_at_volatile(memory_slices.as_slice(), file_offset)
179                         .map_err(Error::IoReadError),
180                     None => file
181                         .read_vectored_volatile(memory_slices.as_slice())
182                         .map_err(Error::IoReadError),
183                 }
184             })
185             .await
186             .map_err(Error::Join)??)
187     }
read_to_vec( &self, file_offset: Option<u64>, mut vec: Vec<u8>, ) -> AsyncResult<(usize, Vec<u8>)>188     pub async fn read_to_vec(
189         &self,
190         file_offset: Option<u64>,
191         mut vec: Vec<u8>,
192     ) -> AsyncResult<(usize, Vec<u8>)> {
193         let source_file = self.source_file.clone();
194         Ok(self
195             .runtime
196             .spawn_blocking(move || {
197                 let mut file_lock = source_file.lock();
198                 let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?;
199                 if let Some(file_offset) = file_offset {
200                     file.seek(SeekFrom::Start(file_offset))
201                         .map_err(Error::IoSeekError)?;
202                 }
203                 Ok::<(usize, Vec<u8>), Error>((
204                     file.read(vec.as_mut_slice()).map_err(Error::IoReadError)?,
205                     vec,
206                 ))
207             })
208             .await
209             .map_err(Error::Join)??)
210     }
wait_readable(&self) -> AsyncResult<()>211     pub async fn wait_readable(&self) -> AsyncResult<()> {
212         unimplemented!();
213     }
wait_for_handle(&self) -> AsyncResult<()>214     pub async fn wait_for_handle(&self) -> AsyncResult<()> {
215         base::sys::windows::async_wait_for_single_object(self.source.as_ref().unwrap()).await?;
216         Ok(())
217     }
write_from_mem( &self, file_offset: Option<u64>, mem: Arc<dyn BackingMemory + Send + Sync>, mem_offsets: impl IntoIterator<Item = MemRegion>, ) -> AsyncResult<usize>218     pub async fn write_from_mem(
219         &self,
220         file_offset: Option<u64>,
221         mem: Arc<dyn BackingMemory + Send + Sync>,
222         mem_offsets: impl IntoIterator<Item = MemRegion>,
223     ) -> AsyncResult<usize> {
224         let mem_offsets = mem_offsets.into_iter().collect();
225         let source_file = self.source_file.clone();
226         Ok(self
227             .runtime
228             .spawn_blocking(move || {
229                 let mut file_lock = source_file.lock();
230                 let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?;
231                 let memory_slices = Self::get_slices(&mem, mem_offsets)?;
232                 match file_offset {
233                     Some(file_offset) => file
234                         .write_vectored_at_volatile(memory_slices.as_slice(), file_offset)
235                         .map_err(Error::IoWriteError),
236                     None => file
237                         .write_vectored_volatile(memory_slices.as_slice())
238                         .map_err(Error::IoWriteError),
239                 }
240             })
241             .await
242             .map_err(Error::Join)??)
243     }
write_from_vec( &self, file_offset: Option<u64>, vec: Vec<u8>, ) -> AsyncResult<(usize, Vec<u8>)>244     pub async fn write_from_vec(
245         &self,
246         file_offset: Option<u64>,
247         vec: Vec<u8>,
248     ) -> AsyncResult<(usize, Vec<u8>)> {
249         let source_file = self.source_file.clone();
250         Ok(self
251             .runtime
252             .spawn_blocking(move || {
253                 let mut file_lock = source_file.lock();
254                 let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?;
255                 if let Some(file_offset) = file_offset {
256                     file.seek(SeekFrom::Start(file_offset))
257                         .map_err(Error::IoSeekError)?;
258                 }
259                 Ok::<(usize, Vec<u8>), Error>((
260                     file.write(vec.as_slice()).map_err(Error::IoWriteError)?,
261                     vec,
262                 ))
263             })
264             .await
265             .map_err(Error::Join)??)
266     }
write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()>267     pub async fn write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
268         let source_file = self.source_file.clone();
269         Ok(self
270             .runtime
271             .spawn_blocking(move || {
272                 // ZeroRange calls `punch_hole` which doesn't extend the File size if it needs to.
273                 // Will fix if it becomes a problem.
274                 source_file
275                     .lock()
276                     .as_mut()
277                     .ok_or(Error::OperationCancelled)?
278                     .write_zeroes_at(file_offset, len as usize)
279                     .map_err(Error::IoWriteZeroesError)
280                     .map(|_| ())
281             })
282             .await
283             .map_err(Error::Join)??)
284     }
285 }
286 impl<T: AsRawDescriptor> Drop for TokioSource<T> {
drop(&mut self)287     fn drop(&mut self) {
288         let mut source_file = self.source_file.lock();
289         source_file.take();
290     }
291 }
292