1 // Copyright 2022 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use std::fs::File; 6 use std::io; 7 use std::io::Read; 8 use std::io::Seek; 9 use std::io::SeekFrom; 10 use std::io::Write; 11 use std::mem::ManuallyDrop; 12 use std::sync::Arc; 13 14 use base::AsRawDescriptor; 15 use base::FileReadWriteAtVolatile; 16 use base::FileReadWriteVolatile; 17 use base::FromRawDescriptor; 18 use base::PunchHole; 19 use base::VolatileSlice; 20 use base::WriteZeroesAt; 21 use smallvec::SmallVec; 22 use sync::Mutex; 23 24 use crate::mem::MemRegion; 25 use crate::AsyncError; 26 use crate::AsyncResult; 27 use crate::BackingMemory; 28 29 #[derive(Debug, thiserror::Error)] 30 pub enum Error { 31 #[error("An error occurred trying to seek: {0}.")] 32 IoSeekError(io::Error), 33 #[error("An error occurred trying to read: {0}.")] 34 IoReadError(io::Error), 35 #[error("An error occurred trying to write: {0}.")] 36 IoWriteError(io::Error), 37 #[error("An error occurred trying to flush: {0}.")] 38 IoFlushError(io::Error), 39 #[error("An error occurred trying to punch hole: {0}.")] 40 IoPunchHoleError(io::Error), 41 #[error("An error occurred trying to write zeroes: {0}.")] 42 IoWriteZeroesError(io::Error), 43 #[error("Failed to join task: '{0}'")] 44 Join(tokio::task::JoinError), 45 #[error("An error occurred trying to duplicate source handles: {0}.")] 46 HandleDuplicationFailed(io::Error), 47 #[error("An error occurred trying to wait on source handles: {0}.")] 48 HandleWaitFailed(base::Error), 49 #[error("An error occurred trying to get a VolatileSlice into BackingMemory: {0}.")] 50 BackingMemoryVolatileSliceFetchFailed(crate::mem::Error), 51 #[error("TokioSource is gone, so no handles are available to fulfill the IO request.")] 52 NoTokioSource, 53 #[error("Operation on TokioSource is cancelled.")] 54 OperationCancelled, 55 #[error("Operation on TokioSource was aborted (unexpected).")] 56 OperationAborted, 57 } 58 59 impl From<Error> for AsyncError { from(e: Error) -> AsyncError60 fn from(e: Error) -> AsyncError { 61 AsyncError::SysVariants(e.into()) 62 } 63 } 64 65 impl From<Error> for io::Error { from(e: Error) -> Self66 fn from(e: Error) -> Self { 67 use Error::*; 68 match e { 69 IoSeekError(e) => e, 70 IoReadError(e) => e, 71 IoWriteError(e) => e, 72 IoFlushError(e) => e, 73 IoPunchHoleError(e) => e, 74 IoWriteZeroesError(e) => e, 75 Join(e) => io::Error::new(io::ErrorKind::Other, e), 76 HandleDuplicationFailed(e) => e, 77 HandleWaitFailed(e) => e.into(), 78 BackingMemoryVolatileSliceFetchFailed(e) => io::Error::new(io::ErrorKind::Other, e), 79 NoTokioSource => io::Error::new(io::ErrorKind::Other, NoTokioSource), 80 OperationCancelled => io::Error::new(io::ErrorKind::Interrupted, OperationCancelled), 81 OperationAborted => io::Error::new(io::ErrorKind::Interrupted, OperationAborted), 82 } 83 } 84 } 85 86 pub type Result<T> = std::result::Result<T, Error>; 87 88 pub struct TokioSource<T: AsRawDescriptor> { 89 source: Option<T>, 90 source_file: Arc<Mutex<Option<ManuallyDrop<File>>>>, 91 runtime: tokio::runtime::Handle, 92 } 93 94 impl<T: AsRawDescriptor> TokioSource<T> { new(source: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>>95 pub(crate) fn new(source: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>> { 96 let descriptor = source.as_raw_descriptor(); 97 // SAFETY: The Drop implementation makes sure `source` outlives `source_file`. 98 let source_file = unsafe { ManuallyDrop::new(File::from_raw_descriptor(descriptor)) }; 99 Ok(Self { 100 source: Some(source), 101 source_file: Arc::new(Mutex::new(Some(source_file))), 102 runtime, 103 }) 104 } 105 #[inline] get_slices( mem: &Arc<dyn BackingMemory + Send + Sync>, mem_offsets: Vec<MemRegion>, ) -> Result<SmallVec<[VolatileSlice<'_>; 16]>>106 fn get_slices( 107 mem: &Arc<dyn BackingMemory + Send + Sync>, 108 mem_offsets: Vec<MemRegion>, 109 ) -> Result<SmallVec<[VolatileSlice<'_>; 16]>> { 110 mem_offsets 111 .into_iter() 112 .map(|region| { 113 mem.get_volatile_slice(region) 114 .map_err(Error::BackingMemoryVolatileSliceFetchFailed) 115 }) 116 .collect::<Result<SmallVec<[VolatileSlice; 16]>>>() 117 } as_source(&self) -> &T118 pub fn as_source(&self) -> &T { 119 self.source.as_ref().unwrap() 120 } as_source_mut(&mut self) -> &mut T121 pub fn as_source_mut(&mut self) -> &mut T { 122 self.source.as_mut().unwrap() 123 } fdatasync(&self) -> AsyncResult<()>124 pub async fn fdatasync(&self) -> AsyncResult<()> { 125 // TODO(b/282003931): Fall back to regular fsync. 126 self.fsync().await 127 } fsync(&self) -> AsyncResult<()>128 pub async fn fsync(&self) -> AsyncResult<()> { 129 let source_file = self.source_file.clone(); 130 Ok(self 131 .runtime 132 .spawn_blocking(move || { 133 source_file 134 .lock() 135 .as_mut() 136 .ok_or(Error::OperationCancelled)? 137 .flush() 138 .map_err(Error::IoFlushError) 139 }) 140 .await 141 .map_err(Error::Join)??) 142 } into_source(mut self) -> T143 pub fn into_source(mut self) -> T { 144 self.source_file.lock().take(); 145 self.source.take().unwrap() 146 } punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()>147 pub async fn punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()> { 148 let source_file = self.source_file.clone(); 149 Ok(self 150 .runtime 151 .spawn_blocking(move || { 152 source_file 153 .lock() 154 .as_mut() 155 .ok_or(Error::OperationCancelled)? 156 .punch_hole(file_offset, len) 157 .map_err(Error::IoPunchHoleError) 158 }) 159 .await 160 .map_err(Error::Join)??) 161 } read_to_mem( &self, file_offset: Option<u64>, mem: Arc<dyn BackingMemory + Send + Sync>, mem_offsets: impl IntoIterator<Item = MemRegion>, ) -> AsyncResult<usize>162 pub async fn read_to_mem( 163 &self, 164 file_offset: Option<u64>, 165 mem: Arc<dyn BackingMemory + Send + Sync>, 166 mem_offsets: impl IntoIterator<Item = MemRegion>, 167 ) -> AsyncResult<usize> { 168 let mem_offsets = mem_offsets.into_iter().collect(); 169 let source_file = self.source_file.clone(); 170 Ok(self 171 .runtime 172 .spawn_blocking(move || { 173 let mut file_lock = source_file.lock(); 174 let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?; 175 let memory_slices = Self::get_slices(&mem, mem_offsets)?; 176 match file_offset { 177 Some(file_offset) => file 178 .read_vectored_at_volatile(memory_slices.as_slice(), file_offset) 179 .map_err(Error::IoReadError), 180 None => file 181 .read_vectored_volatile(memory_slices.as_slice()) 182 .map_err(Error::IoReadError), 183 } 184 }) 185 .await 186 .map_err(Error::Join)??) 187 } read_to_vec( &self, file_offset: Option<u64>, mut vec: Vec<u8>, ) -> AsyncResult<(usize, Vec<u8>)>188 pub async fn read_to_vec( 189 &self, 190 file_offset: Option<u64>, 191 mut vec: Vec<u8>, 192 ) -> AsyncResult<(usize, Vec<u8>)> { 193 let source_file = self.source_file.clone(); 194 Ok(self 195 .runtime 196 .spawn_blocking(move || { 197 let mut file_lock = source_file.lock(); 198 let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?; 199 if let Some(file_offset) = file_offset { 200 file.seek(SeekFrom::Start(file_offset)) 201 .map_err(Error::IoSeekError)?; 202 } 203 Ok::<(usize, Vec<u8>), Error>(( 204 file.read(vec.as_mut_slice()).map_err(Error::IoReadError)?, 205 vec, 206 )) 207 }) 208 .await 209 .map_err(Error::Join)??) 210 } wait_readable(&self) -> AsyncResult<()>211 pub async fn wait_readable(&self) -> AsyncResult<()> { 212 unimplemented!(); 213 } wait_for_handle(&self) -> AsyncResult<()>214 pub async fn wait_for_handle(&self) -> AsyncResult<()> { 215 base::sys::windows::async_wait_for_single_object(self.source.as_ref().unwrap()).await?; 216 Ok(()) 217 } write_from_mem( &self, file_offset: Option<u64>, mem: Arc<dyn BackingMemory + Send + Sync>, mem_offsets: impl IntoIterator<Item = MemRegion>, ) -> AsyncResult<usize>218 pub async fn write_from_mem( 219 &self, 220 file_offset: Option<u64>, 221 mem: Arc<dyn BackingMemory + Send + Sync>, 222 mem_offsets: impl IntoIterator<Item = MemRegion>, 223 ) -> AsyncResult<usize> { 224 let mem_offsets = mem_offsets.into_iter().collect(); 225 let source_file = self.source_file.clone(); 226 Ok(self 227 .runtime 228 .spawn_blocking(move || { 229 let mut file_lock = source_file.lock(); 230 let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?; 231 let memory_slices = Self::get_slices(&mem, mem_offsets)?; 232 match file_offset { 233 Some(file_offset) => file 234 .write_vectored_at_volatile(memory_slices.as_slice(), file_offset) 235 .map_err(Error::IoWriteError), 236 None => file 237 .write_vectored_volatile(memory_slices.as_slice()) 238 .map_err(Error::IoWriteError), 239 } 240 }) 241 .await 242 .map_err(Error::Join)??) 243 } write_from_vec( &self, file_offset: Option<u64>, vec: Vec<u8>, ) -> AsyncResult<(usize, Vec<u8>)>244 pub async fn write_from_vec( 245 &self, 246 file_offset: Option<u64>, 247 vec: Vec<u8>, 248 ) -> AsyncResult<(usize, Vec<u8>)> { 249 let source_file = self.source_file.clone(); 250 Ok(self 251 .runtime 252 .spawn_blocking(move || { 253 let mut file_lock = source_file.lock(); 254 let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?; 255 if let Some(file_offset) = file_offset { 256 file.seek(SeekFrom::Start(file_offset)) 257 .map_err(Error::IoSeekError)?; 258 } 259 Ok::<(usize, Vec<u8>), Error>(( 260 file.write(vec.as_slice()).map_err(Error::IoWriteError)?, 261 vec, 262 )) 263 }) 264 .await 265 .map_err(Error::Join)??) 266 } write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()>267 pub async fn write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()> { 268 let source_file = self.source_file.clone(); 269 Ok(self 270 .runtime 271 .spawn_blocking(move || { 272 // ZeroRange calls `punch_hole` which doesn't extend the File size if it needs to. 273 // Will fix if it becomes a problem. 274 source_file 275 .lock() 276 .as_mut() 277 .ok_or(Error::OperationCancelled)? 278 .write_zeroes_at(file_offset, len as usize) 279 .map_err(Error::IoWriteZeroesError) 280 .map(|_| ()) 281 }) 282 .await 283 .map_err(Error::Join)??) 284 } 285 } 286 impl<T: AsRawDescriptor> Drop for TokioSource<T> { drop(&mut self)287 fn drop(&mut self) { 288 let mut source_file = self.source_file.lock(); 289 source_file.take(); 290 } 291 } 292