xref: /aosp_15_r20/external/crosvm/cros_async/src/tokio_executor.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2023 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::future::Future;
6 use std::pin::Pin;
7 use std::sync::Arc;
8 use std::sync::OnceLock;
9 
10 use base::AsRawDescriptors;
11 use base::RawDescriptor;
12 use tokio::runtime::Runtime;
13 use tokio::task::LocalSet;
14 
15 use crate::sys::platform::tokio_source::TokioSource;
16 use crate::AsyncError;
17 use crate::AsyncResult;
18 use crate::ExecutorTrait;
19 use crate::IntoAsync;
20 use crate::IoSource;
21 use crate::TaskHandle;
22 
mod send_wrapper {
    use std::mem::ManuallyDrop;
    use std::thread;

    /// Wraps a value that may only be touched on the thread that created the wrapper.
    ///
    /// The wrapper itself is `Send` and `Sync`, so it can live inside types that must be.
    /// Every access to the inner value first checks the current thread:
    ///
    /// * dereferencing or cloning on another thread panics;
    /// * dropping on another thread never runs the inner value's destructor. The value is
    ///   leaked, and the drop panics unless the thread is already unwinding, because a
    ///   second panic during unwinding would abort the process.
    pub(super) struct SendWrapper<T> {
        // `ManuallyDrop` lets `Drop` decide whether the inner destructor may run at all.
        // With a plain field, the compiler-generated drop glue would still drop the field
        // after a panicking `drop`, i.e. on the wrong thread.
        instance: ManuallyDrop<T>,
        thread_id: thread::ThreadId,
    }

    impl<T> SendWrapper<T> {
        /// Wraps `instance`, binding it to the calling thread.
        pub(super) fn new(instance: T) -> SendWrapper<T> {
            SendWrapper {
                instance: ManuallyDrop::new(instance),
                thread_id: thread::current().id(),
            }
        }

        /// Returns true when called on the thread that created this wrapper.
        fn on_owner_thread(&self) -> bool {
            self.thread_id == thread::current().id()
        }
    }

    // SAFETY: the inner value is only reachable through `Deref`, `Clone` and `Drop`. Each of
    // these checks the current thread first, and on any other thread it panics or leaks the
    // value instead of touching it.
    unsafe impl<T> Send for SendWrapper<T> {}
    // SAFETY: a shared reference only exposes `Deref` and `Clone`, and both panic when used
    // on a thread other than the owner. All `&T` accesses therefore happen on one thread.
    unsafe impl<T> Sync for SendWrapper<T> {}

    impl<T: Clone> Clone for SendWrapper<T> {
        /// Clones the inner value; panics when called on a thread other than the owner.
        fn clone(&self) -> Self {
            // Dereferencing performs the thread check before `T::clone` runs. The clone is
            // created on the owner thread and is therefore bound to that same thread.
            SendWrapper::new((**self).clone())
        }
    }

    impl<T> Drop for SendWrapper<T> {
        fn drop(&mut self) {
            if self.on_owner_thread() {
                // SAFETY: `instance` is dropped exactly once, here, and is never accessed
                // again because `self` is being destroyed.
                unsafe { ManuallyDrop::drop(&mut self.instance) };
            } else if !thread::panicking() {
                // The inner value has been leaked rather than dropped on a foreign thread.
                panic!("SendWrapper value was dropped on the wrong thread");
            }
            // Otherwise this thread is already unwinding on the wrong thread. Leak silently,
            // because panicking again would abort the process.
        }
    }

    impl<T> std::ops::Deref for SendWrapper<T> {
        type Target = T;

        /// Borrows the inner value; panics when called on a thread other than the owner.
        fn deref(&self) -> &T {
            if !self.on_owner_thread() {
                panic!("SendWrapper value was accessed on the wrong thread");
            }
            &self.instance
        }
    }
}
65 
/// Executor backed by a tokio `Runtime`.
///
/// Clones share the same runtime and the same lazily created `LocalSet`, because both are
/// held behind `Arc`s.
#[derive(Clone)]
pub struct TokioExecutor {
    // Runtime used by `run_until` (`block_on`), `spawn`, `spawn_blocking`, and whose handle
    // is given to the IO sources created by `async_from`.
    runtime: Arc<Runtime>,
    // `LocalSet` used by `spawn_local` and `run_until`, created on first use. The
    // `SendWrapper` ties it to the thread that created it; dereferencing it from any other
    // thread panics.
    local_set: Arc<OnceLock<send_wrapper::SendWrapper<LocalSet>>>,
}
71 
72 impl TokioExecutor {
new() -> AsyncResult<Self>73     pub fn new() -> AsyncResult<Self> {
74         Ok(TokioExecutor {
75             runtime: Arc::new(Runtime::new().map_err(AsyncError::Io)?),
76             local_set: Arc::new(OnceLock::new()),
77         })
78     }
79 }
80 
81 impl ExecutorTrait for TokioExecutor {
async_from<'a, F: IntoAsync + 'a>(&self, f: F) -> AsyncResult<IoSource<F>>82     fn async_from<'a, F: IntoAsync + 'a>(&self, f: F) -> AsyncResult<IoSource<F>> {
83         Ok(IoSource::Tokio(TokioSource::new(
84             f,
85             self.runtime.handle().clone(),
86         )?))
87     }
88 
run_until<F: Future>(&self, f: F) -> AsyncResult<F::Output>89     fn run_until<F: Future>(&self, f: F) -> AsyncResult<F::Output> {
90         let local_set = self
91             .local_set
92             .get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new()));
93         Ok(self
94             .runtime
95             .block_on(async { local_set.run_until(f).await }))
96     }
97 
spawn<F>(&self, f: F) -> TaskHandle<F::Output> where F: Future + Send + 'static, F::Output: Send + 'static,98     fn spawn<F>(&self, f: F) -> TaskHandle<F::Output>
99     where
100         F: Future + Send + 'static,
101         F::Output: Send + 'static,
102     {
103         TaskHandle::Tokio(TokioTaskHandle {
104             join_handle: Some(self.runtime.spawn(f)),
105         })
106     }
107 
spawn_blocking<F, R>(&self, f: F) -> TaskHandle<R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,108     fn spawn_blocking<F, R>(&self, f: F) -> TaskHandle<R>
109     where
110         F: FnOnce() -> R + Send + 'static,
111         R: Send + 'static,
112     {
113         TaskHandle::Tokio(TokioTaskHandle {
114             join_handle: Some(self.runtime.spawn_blocking(f)),
115         })
116     }
117 
spawn_local<F>(&self, f: F) -> TaskHandle<F::Output> where F: Future + 'static, F::Output: 'static,118     fn spawn_local<F>(&self, f: F) -> TaskHandle<F::Output>
119     where
120         F: Future + 'static,
121         F::Output: 'static,
122     {
123         let local_set = self
124             .local_set
125             .get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new()));
126         TaskHandle::Tokio(TokioTaskHandle {
127             join_handle: Some(local_set.spawn_local(f)),
128         })
129     }
130 }
131 
132 impl AsRawDescriptors for TokioExecutor {
as_raw_descriptors(&self) -> Vec<RawDescriptor>133     fn as_raw_descriptors(&self) -> Vec<RawDescriptor> {
134         todo!();
135     }
136 }
137 
/// Handle to a task spawned through a [`TokioExecutor`].
///
/// Awaiting the handle yields the task's output. Dropping a handle that still owns its
/// `JoinHandle` aborts the task; `detach` gives up ownership so the task keeps running.
pub struct TokioTaskHandle<T> {
    // `Some` until `cancel` or `detach` takes it; `Drop` aborts the task if it is still here.
    join_handle: Option<tokio::task::JoinHandle<T>>,
}
141 impl<R> TokioTaskHandle<R> {
cancel(mut self) -> Option<R>142     pub async fn cancel(mut self) -> Option<R> {
143         match self.join_handle.take() {
144             Some(handle) => {
145                 handle.abort();
146                 handle.await.ok()
147             }
148             None => None,
149         }
150     }
detach(mut self)151     pub fn detach(mut self) {
152         self.join_handle.take();
153     }
154 }
155 impl<R: 'static> Future for TokioTaskHandle<R> {
156     type Output = R;
poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output>157     fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output> {
158         let self_mut = self.get_mut();
159         Pin::new(self_mut.join_handle.as_mut().unwrap())
160             .poll(cx)
161             .map(|v| v.unwrap())
162     }
163 }
164 impl<T> std::ops::Drop for TokioTaskHandle<T> {
drop(&mut self)165     fn drop(&mut self) {
166         if let Some(handle) = self.join_handle.take() {
167             handle.abort()
168         }
169     }
170 }
171