1 // Copyright 2023 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use std::future::Future; 6 use std::pin::Pin; 7 use std::sync::Arc; 8 use std::sync::OnceLock; 9 10 use base::AsRawDescriptors; 11 use base::RawDescriptor; 12 use tokio::runtime::Runtime; 13 use tokio::task::LocalSet; 14 15 use crate::sys::platform::tokio_source::TokioSource; 16 use crate::AsyncError; 17 use crate::AsyncResult; 18 use crate::ExecutorTrait; 19 use crate::IntoAsync; 20 use crate::IoSource; 21 use crate::TaskHandle; 22 23 mod send_wrapper { 24 use std::thread; 25 26 #[derive(Clone)] 27 pub(super) struct SendWrapper<T> { 28 instance: T, 29 thread_id: thread::ThreadId, 30 } 31 32 impl<T> SendWrapper<T> { new(instance: T) -> SendWrapper<T>33 pub(super) fn new(instance: T) -> SendWrapper<T> { 34 SendWrapper { 35 instance, 36 thread_id: thread::current().id(), 37 } 38 } 39 } 40 41 // SAFETY: panics when the value is accessed on the wrong thread. 42 unsafe impl<T> Send for SendWrapper<T> {} 43 // SAFETY: panics when the value is accessed on the wrong thread. 44 unsafe impl<T> Sync for SendWrapper<T> {} 45 46 impl<T> Drop for SendWrapper<T> { drop(&mut self)47 fn drop(&mut self) { 48 if self.thread_id != thread::current().id() { 49 panic!("SendWrapper value was dropped on the wrong thread"); 50 } 51 } 52 } 53 54 impl<T> std::ops::Deref for SendWrapper<T> { 55 type Target = T; 56 deref(&self) -> &T57 fn deref(&self) -> &T { 58 if self.thread_id != thread::current().id() { 59 panic!("SendWrapper value was accessed on the wrong thread"); 60 } 61 &self.instance 62 } 63 } 64 } 65 66 #[derive(Clone)] 67 pub struct TokioExecutor { 68 runtime: Arc<Runtime>, 69 local_set: Arc<OnceLock<send_wrapper::SendWrapper<LocalSet>>>, 70 } 71 72 impl TokioExecutor { new() -> AsyncResult<Self>73 pub fn new() -> AsyncResult<Self> { 74 Ok(TokioExecutor { 75 runtime: Arc::new(Runtime::new().map_err(AsyncError::Io)?), 76 local_set: Arc::new(OnceLock::new()), 77 }) 78 } 79 } 80 81 impl ExecutorTrait for TokioExecutor { async_from<'a, F: IntoAsync + 'a>(&self, f: F) -> AsyncResult<IoSource<F>>82 fn async_from<'a, F: IntoAsync + 'a>(&self, f: F) -> AsyncResult<IoSource<F>> { 83 Ok(IoSource::Tokio(TokioSource::new( 84 f, 85 self.runtime.handle().clone(), 86 )?)) 87 } 88 run_until<F: Future>(&self, f: F) -> AsyncResult<F::Output>89 fn run_until<F: Future>(&self, f: F) -> AsyncResult<F::Output> { 90 let local_set = self 91 .local_set 92 .get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new())); 93 Ok(self 94 .runtime 95 .block_on(async { local_set.run_until(f).await })) 96 } 97 spawn<F>(&self, f: F) -> TaskHandle<F::Output> where F: Future + Send + 'static, F::Output: Send + 'static,98 fn spawn<F>(&self, f: F) -> TaskHandle<F::Output> 99 where 100 F: Future + Send + 'static, 101 F::Output: Send + 'static, 102 { 103 TaskHandle::Tokio(TokioTaskHandle { 104 join_handle: Some(self.runtime.spawn(f)), 105 }) 106 } 107 spawn_blocking<F, R>(&self, f: F) -> TaskHandle<R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,108 fn spawn_blocking<F, R>(&self, f: F) -> TaskHandle<R> 109 where 110 F: FnOnce() -> R + Send + 'static, 111 R: Send + 'static, 112 { 113 TaskHandle::Tokio(TokioTaskHandle { 114 join_handle: Some(self.runtime.spawn_blocking(f)), 115 }) 116 } 117 spawn_local<F>(&self, f: F) -> TaskHandle<F::Output> where F: Future + 'static, F::Output: 'static,118 fn spawn_local<F>(&self, f: F) -> TaskHandle<F::Output> 119 where 120 F: Future + 'static, 121 F::Output: 'static, 122 { 123 let local_set = self 124 .local_set 125 .get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new())); 126 TaskHandle::Tokio(TokioTaskHandle { 127 join_handle: Some(local_set.spawn_local(f)), 128 }) 129 } 130 } 131 132 impl AsRawDescriptors for TokioExecutor { as_raw_descriptors(&self) -> Vec<RawDescriptor>133 fn as_raw_descriptors(&self) -> Vec<RawDescriptor> { 134 todo!(); 135 } 136 } 137 138 pub struct TokioTaskHandle<T> { 139 join_handle: Option<tokio::task::JoinHandle<T>>, 140 } 141 impl<R> TokioTaskHandle<R> { cancel(mut self) -> Option<R>142 pub async fn cancel(mut self) -> Option<R> { 143 match self.join_handle.take() { 144 Some(handle) => { 145 handle.abort(); 146 handle.await.ok() 147 } 148 None => None, 149 } 150 } detach(mut self)151 pub fn detach(mut self) { 152 self.join_handle.take(); 153 } 154 } 155 impl<R: 'static> Future for TokioTaskHandle<R> { 156 type Output = R; poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output>157 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output> { 158 let self_mut = self.get_mut(); 159 Pin::new(self_mut.join_handle.as_mut().unwrap()) 160 .poll(cx) 161 .map(|v| v.unwrap()) 162 } 163 } 164 impl<T> std::ops::Drop for TokioTaskHandle<T> { drop(&mut self)165 fn drop(&mut self) { 166 if let Some(handle) = self.join_handle.take() { 167 handle.abort() 168 } 169 } 170 } 171