// Copyright 2024 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
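
//! A tokio-based async IO source. Pollable descriptors are registered with the
//! tokio reactor through [`AsyncFd`]; descriptors that tokio cannot poll (such
//! as regular files) fall back to the runtime's blocking thread pool.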

use std::io;
use std::os::fd::AsRawFd;
use std::os::fd::OwnedFd;
use std::os::fd::RawFd;
use std::sync::Arc;

use base::add_fd_flags;
use base::clone_descriptor;
use base::linux::fallocate;
use base::linux::FallocateMode;
use base::AsRawDescriptor;
use base::VolatileSlice;
use tokio::io::unix::AsyncFd;

use crate::mem::MemRegion;
use crate::AsyncError;
use crate::AsyncResult;
use crate::BackingMemory;

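/// Errors that can be returned by a [`TokioSource`] operation.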
#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("Failed to copy the FD for the polling context: '{0}'")]
    DuplicatingFd(base::Error),
    #[error("Failed to punch hole in file: '{0}'.")]
    Fallocate(base::Error),
    #[error("Failed to fdatasync: '{0}'")]
    Fdatasync(io::Error),
    #[error("Failed to fsync: '{0}'")]
    Fsync(io::Error),
    #[error("Failed to join task: '{0}'")]
    Join(tokio::task::JoinError),
    #[error("Cannot wait on file descriptor")]
    NonWaitable,
    #[error("Failed to read: '{0}'")]
    Read(io::Error),
    #[error("Failed to set nonblocking: '{0}'")]
    SettingNonBlocking(base::Error),
    #[error("Tokio Async FD error: '{0}'")]
    TokioAsyncFd(io::Error),
    #[error("Failed to write: '{0}'")]
    Write(io::Error),
}

impl From<Error> for io::Error {
    fn from(e: Error) -> Self {
        use Error::*;
        match e {
            DuplicatingFd(e) => e.into(),
            Fallocate(e) => e.into(),
            Fdatasync(e) => e,
            Fsync(e) => e,
            Join(e) => io::Error::new(io::ErrorKind::Other, e),
            NonWaitable => io::Error::new(io::ErrorKind::Other, e),
            Read(e) => e,
            SettingNonBlocking(e) => e.into(),
            TokioAsyncFd(e) => e,
            Write(e) => e,
        }
    }
}

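/// How the wrapped descriptor is driven: registered with tokio's reactor when
/// it is pollable, or handed to the blocking thread pool otherwise.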
enum FdType {
    Async(AsyncFd<Arc<OwnedFd>>),
    Blocking(Arc<OwnedFd>),
}

impl AsRawFd for FdType {
    fn as_raw_fd(&self) -> RawFd {
        match self {
            FdType::Async(async_fd) => async_fd.as_raw_fd(),
            FdType::Blocking(blocking) => blocking.as_raw_fd(),
        }
    }
}

impl From<Error> for AsyncError {
    fn from(e: Error) -> AsyncError {
        AsyncError::SysVariants(e.into())
    }
}

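// The helpers below issue the raw syscalls. They run either inside
// `AsyncFd::async_io` (non-blocking descriptors) or on the runtime's blocking
// thread pool via `spawn_blocking`.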
fn do_fdatasync(raw: Arc<OwnedFd>) -> io::Result<()> {
    let fd = raw.as_raw_fd();
    // SAFETY: we partially own `raw`
    match unsafe { libc::fdatasync(fd) } {
        0 => Ok(()),
        _ => Err(io::Error::last_os_error()),
    }
}

fn do_fsync(raw: Arc<OwnedFd>) -> io::Result<()> {
    let fd = raw.as_raw_fd();
    // SAFETY: we partially own `raw`
    match unsafe { libc::fsync(fd) } {
        0 => Ok(()),
        _ => Err(io::Error::last_os_error()),
    }
}

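// Note: `VolatileSlice` is ABI-compatible with `libc::iovec`, which is what
// allows the slices to be passed directly to the vectored IO syscalls below.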
fn do_read_vectored(
    raw: Arc<OwnedFd>,
    file_offset: Option<u64>,
    io_vecs: &[VolatileSlice],
) -> io::Result<usize> {
    let ptr = io_vecs.as_ptr() as *const libc::iovec;
    let len = io_vecs.len() as i32;
    let fd = raw.as_raw_fd();
    let res = match file_offset {
        // SAFETY: we partially own `raw`, `io_vecs` is validated
        Some(off) => unsafe { libc::preadv64(fd, ptr, len, off as libc::off64_t) },
        // SAFETY: we partially own `raw`, `io_vecs` is validated
        None => unsafe { libc::readv(fd, ptr, len) },
    };
    match res {
        r if r >= 0 => Ok(res as usize),
        _ => Err(io::Error::last_os_error()),
    }
}

fn do_read(raw: Arc<OwnedFd>, file_offset: Option<u64>, buf: &mut [u8]) -> io::Result<usize> {
    let fd = raw.as_raw_fd();
    let ptr = buf.as_mut_ptr() as *mut libc::c_void;
    let res = match file_offset {
        // SAFETY: we partially own `raw`, `ptr` has space up to buf.len()
        Some(off) => unsafe { libc::pread64(fd, ptr, buf.len(), off as libc::off64_t) },
        // SAFETY: we partially own `raw`, `ptr` has space up to buf.len()
        None => unsafe { libc::read(fd, ptr, buf.len()) },
    };
    match res {
        r if r >= 0 => Ok(res as usize),
        _ => Err(io::Error::last_os_error()),
    }
}

fn do_write(raw: Arc<OwnedFd>, file_offset: Option<u64>, buf: &[u8]) -> io::Result<usize> {
    let fd = raw.as_raw_fd();
    let ptr = buf.as_ptr() as *const libc::c_void;
    let res = match file_offset {
        // SAFETY: we partially own `raw`, `ptr` has data up to buf.len()
        Some(off) => unsafe { libc::pwrite64(fd, ptr, buf.len(), off as libc::off64_t) },
        // SAFETY: we partially own `raw`, `ptr` has data up to buf.len()
        None => unsafe { libc::write(fd, ptr, buf.len()) },
    };
    match res {
        r if r >= 0 => Ok(res as usize),
        _ => Err(io::Error::last_os_error()),
    }
}

fn do_write_vectored(
    raw: Arc<OwnedFd>,
    file_offset: Option<u64>,
    io_vecs: &[VolatileSlice],
) -> io::Result<usize> {
    let ptr = io_vecs.as_ptr() as *const libc::iovec;
    let len = io_vecs.len() as i32;
    let fd = raw.as_raw_fd();
    let res = match file_offset {
        // SAFETY: we partially own `raw`, `io_vecs` is validated
        Some(off) => unsafe { libc::pwritev64(fd, ptr, len, off as libc::off64_t) },
        // SAFETY: we partially own `raw`, `io_vecs` is validated
        None => unsafe { libc::writev(fd, ptr, len) },
    };
    match res {
        r if r >= 0 => Ok(res as usize),
        _ => Err(io::Error::last_os_error()),
    }
}

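/// An async IO source that uses tokio to perform its operations on a
/// duplicated copy of the descriptor owned by `inner`.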
pub struct TokioSource<T> {
    fd: FdType,
    inner: T,
    runtime: tokio::runtime::Handle,
}
impl<T: AsRawDescriptor> TokioSource<T> {
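    /// Creates a `TokioSource`, duplicating the descriptor of `inner`.
    ///
    /// If tokio can poll the duplicated descriptor it is registered with the
    /// reactor and switched to non-blocking mode; otherwise (e.g. a regular
    /// file) operations fall back to the runtime's blocking thread pool.
    ///
    /// A minimal usage sketch (not compiled as a doctest; assumes a tokio
    /// runtime is available and that the wrapped type implements
    /// `AsRawDescriptor`, as `std::fs::File` does via `base`):
    ///
    /// ```ignore
    /// let rt = tokio::runtime::Runtime::new().unwrap();
    /// let file = std::fs::File::open("/dev/zero").unwrap();
    /// let source = TokioSource::new(file, rt.handle().clone()).unwrap();
    /// let (len, _buf) = rt.block_on(source.read_to_vec(None, vec![0u8; 16])).unwrap();
    /// assert_eq!(len, 16);
    /// ```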
    pub fn new(inner: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>, Error> {
        let _guard = runtime.enter(); // Required for AsyncFd
        let safe_fd = clone_descriptor(&inner).map_err(Error::DuplicatingFd)?;
        let fd_arc: Arc<OwnedFd> = Arc::new(safe_fd.into());
        let fd = match AsyncFd::new(fd_arc.clone()) {
            Ok(async_fd) => {
                add_fd_flags(async_fd.get_ref().as_raw_descriptor(), libc::O_NONBLOCK)
                    .map_err(Error::SettingNonBlocking)?;
                FdType::Async(async_fd)
            }
            Err(e) if e.kind() == io::ErrorKind::PermissionDenied => FdType::Blocking(fd_arc),
            Err(e) => return Err(Error::TokioAsyncFd(e)),
        };
        Ok(TokioSource { fd, inner, runtime })
    }

    pub fn as_source(&self) -> &T {
        &self.inner
    }

    pub fn as_source_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    fn clone_fd(&self) -> Arc<OwnedFd> {
        match &self.fd {
            FdType::Async(async_fd) => async_fd.get_ref().clone(),
            FdType::Blocking(blocking) => blocking.clone(),
        }
    }

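    /// Calls `fdatasync` on the descriptor, running it on the blocking thread
    /// pool.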
    pub async fn fdatasync(&self) -> AsyncResult<()> {
        let fd = self.clone_fd();
        Ok(self
            .runtime
            .spawn_blocking(move || do_fdatasync(fd))
            .await
            .map_err(Error::Join)?
            .map_err(Error::Fdatasync)?)
    }

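    /// Calls `fsync` on the descriptor, running it on the blocking thread pool.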
    pub async fn fsync(&self) -> AsyncResult<()> {
        let fd = self.clone_fd();
        Ok(self
            .runtime
            .spawn_blocking(move || do_fsync(fd))
            .await
            .map_err(Error::Join)?
            .map_err(Error::Fsync)?)
    }

    pub fn into_source(self) -> T {
        self.inner
    }

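    /// Reads from the descriptor into `vec`, at `file_offset` if given,
    /// returning the number of bytes read along with the buffer.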
    pub async fn read_to_vec(
        &self,
        file_offset: Option<u64>,
        mut vec: Vec<u8>,
    ) -> AsyncResult<(usize, Vec<u8>)> {
        Ok(match &self.fd {
            FdType::Async(async_fd) => {
                let res = async_fd
                    .async_io(tokio::io::Interest::READABLE, |fd| {
                        do_read(fd.clone(), file_offset, &mut vec)
                    })
                    .await
                    .map_err(AsyncError::Io)?;
                (res, vec)
            }
            FdType::Blocking(blocking) => {
                let fd = blocking.clone();
                self.runtime
                    .spawn_blocking(move || {
                        let size = do_read(fd, file_offset, &mut vec)?;
                        Ok((size, vec))
                    })
                    .await
                    .map_err(Error::Join)?
                    .map_err(Error::Read)?
            }
        })
    }

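    /// Reads from the descriptor into the regions of `mem` described by
    /// `mem_offsets` (a vectored read), returning the number of bytes read.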
    pub async fn read_to_mem(
        &self,
        file_offset: Option<u64>,
        mem: Arc<dyn BackingMemory + Send + Sync>,
        mem_offsets: impl IntoIterator<Item = MemRegion>,
    ) -> AsyncResult<usize> {
        let mem_offsets_vec: Vec<MemRegion> = mem_offsets.into_iter().collect();
        Ok(match &self.fd {
            FdType::Async(async_fd) => {
                let iovecs = mem_offsets_vec
                    .into_iter()
                    .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
                    .collect::<Vec<VolatileSlice>>();
                async_fd
                    .async_io(tokio::io::Interest::READABLE, |fd| {
                        do_read_vectored(fd.clone(), file_offset, &iovecs)
                    })
                    .await
                    .map_err(AsyncError::Io)?
            }
            FdType::Blocking(blocking) => {
                let fd = blocking.clone();
                self.runtime
                    .spawn_blocking(move || {
                        let iovecs = mem_offsets_vec
                            .into_iter()
                            .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
                            .collect::<Vec<VolatileSlice>>();
                        do_read_vectored(fd, file_offset, &iovecs)
                    })
                    .await
                    .map_err(Error::Join)?
                    .map_err(Error::Read)?
            }
        })
    }

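    /// Punches a hole of `len` bytes in the file at `file_offset` via
    /// `fallocate`, running on the blocking thread pool.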
    pub async fn punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
        let fd = self.clone_fd();
        Ok(self
            .runtime
            .spawn_blocking(move || fallocate(&*fd, FallocateMode::PunchHole, file_offset, len))
            .await
            .map_err(Error::Join)?
            .map_err(Error::Fallocate)?)
    }

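    /// Waits for the descriptor to become readable. Fails with
    /// [`Error::NonWaitable`] if the descriptor is not pollable.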
    pub async fn wait_readable(&self) -> AsyncResult<()> {
        match &self.fd {
            FdType::Async(async_fd) => async_fd
                .readable()
                .await
                .map_err(crate::AsyncError::Io)?
                .retain_ready(),
            FdType::Blocking(_) => return Err(Error::NonWaitable.into()),
        }
        Ok(())
    }

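    /// Writes the regions of `mem` described by `mem_offsets` to the
    /// descriptor (a vectored write), returning the number of bytes written.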
    pub async fn write_from_mem(
        &self,
        file_offset: Option<u64>,
        mem: Arc<dyn BackingMemory + Send + Sync>,
        mem_offsets: impl IntoIterator<Item = MemRegion>,
    ) -> AsyncResult<usize> {
        let mem_offsets_vec: Vec<MemRegion> = mem_offsets.into_iter().collect();
        Ok(match &self.fd {
            FdType::Async(async_fd) => {
                let iovecs = mem_offsets_vec
                    .into_iter()
                    .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
                    .collect::<Vec<VolatileSlice>>();
                async_fd
                    .async_io(tokio::io::Interest::WRITABLE, |fd| {
                        do_write_vectored(fd.clone(), file_offset, &iovecs)
                    })
                    .await
                    .map_err(AsyncError::Io)?
            }
            FdType::Blocking(blocking) => {
                let fd = blocking.clone();
                self.runtime
                    .spawn_blocking(move || {
                        let iovecs = mem_offsets_vec
                            .into_iter()
                            .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
                            .collect::<Vec<VolatileSlice>>();
                        do_write_vectored(fd, file_offset, &iovecs)
                    })
                    .await
                    .map_err(Error::Join)?
                    .map_err(Error::Write)?
            }
        })
    }

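    /// Writes the contents of `vec` to the descriptor, at `file_offset` if
    /// given, returning the number of bytes written along with the buffer.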
    pub async fn write_from_vec(
        &self,
        file_offset: Option<u64>,
        vec: Vec<u8>,
    ) -> AsyncResult<(usize, Vec<u8>)> {
        Ok(match &self.fd {
            FdType::Async(async_fd) => {
                let res = async_fd
                    .async_io(tokio::io::Interest::WRITABLE, |fd| {
                        do_write(fd.clone(), file_offset, &vec)
                    })
                    .await
                    .map_err(AsyncError::Io)?;
                (res, vec)
            }
            FdType::Blocking(blocking) => {
                let fd = blocking.clone();
                self.runtime
                    .spawn_blocking(move || {
                        let size = do_write(fd, file_offset, &vec)?;
                        Ok((size, vec))
                    })
                    .await
                    .map_err(Error::Join)?
                    .map_err(Error::Write)?
            }
        })
    }

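    /// Zeroes `len` bytes of the file at `file_offset` via `fallocate`, running
    /// on the blocking thread pool.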
    pub async fn write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
        let fd = self.clone_fd();
        Ok(self
            .runtime
            .spawn_blocking(move || fallocate(&*fd, FallocateMode::ZeroRange, file_offset, len))
            .await
            .map_err(Error::Join)?
            .map_err(Error::Fallocate)?)
    }
}