use futures_util::future::{AbortHandle, Abortable};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::runtime::Builder;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};

/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
///
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
/// execute on the same thread) inside the Future you supply to the various spawn methods
/// of `LocalPoolHandle`.
///
/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
///
/// # Examples
///
/// ```
/// use std::rc::Rc;
/// use tokio::task;
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() {
///     let pool = LocalPoolHandle::new(5);
///
///     let output = pool.spawn_pinned(|| {
///         // `data` is !Send + !Sync
///         let data = Rc::new("local data");
///         let data_clone = data.clone();
///
///         async move {
///             task::spawn_local(async move {
///                 println!("{}", data_clone);
///             });
///
///             data.to_string()
///         }
///     }).await.unwrap();
///     println!("output: {}", output);
/// }
/// ```
///
#[derive(Clone)]
pub struct LocalPoolHandle {
    pool: Arc<LocalPool>,
}

impl LocalPoolHandle {
    /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
    /// pool via [`LocalPoolHandle::spawn_pinned`].
    ///
    /// # Panics
    ///
    /// Panics if the pool size is less than one.
    #[track_caller]
    pub fn new(pool_size: usize) -> LocalPoolHandle {
        assert!(pool_size > 0);

        let workers = (0..pool_size)
            .map(|_| LocalWorkerHandle::new_worker())
            .collect();

        let pool = Arc::new(LocalPool { workers });

        LocalPoolHandle { pool }
    }

    /// Returns the number of worker threads in the pool.
    #[inline]
    pub fn num_threads(&self) -> usize {
        self.pool.workers.len()
    }

    /// Returns the number of tasks scheduled on each worker. The indices of the
    /// worker threads correspond to the indices of the returned `Vec`.
    pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
        self.pool
            .workers
            .iter()
            .map(|worker| worker.task_count.load(Ordering::SeqCst))
            .collect::<Vec<_>>()
    }

    /// Spawn a task onto a worker thread and pin it there so it can't be moved
    /// off of the thread. Note that the future is not [`Send`], but the
    /// [`FnOnce`] which creates it is.
    ///
    /// # Examples
    /// ```
    /// use std::rc::Rc;
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     // Create the local pool
    ///     let pool = LocalPoolHandle::new(1);
    ///
    ///     // Spawn a !Send future onto the pool and await it
    ///     let output = pool
    ///         .spawn_pinned(|| {
    ///             // Rc is !Send + !Sync
    ///             let local_data = Rc::new("test");
    ///
    ///             // This future holds an Rc, so it is !Send
    ///             async move { local_data.to_string() }
    ///         })
    ///         .await
    ///         .unwrap();
    ///
    ///     assert_eq!(output, "test");
    /// }
    /// ```
    pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        self.pool
            .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
    }

    /// Differs from `spawn_pinned` only in that you can choose a specific worker thread
    /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
    /// number of tasks scheduled.
    ///
    /// A worker thread is chosen by index. Indices are zero-based and the largest index
    /// is given by `num_threads() - 1`.
    ///
    /// # Panics
    ///
    /// This method panics if the index is out of bounds.
    ///
    /// # Examples
    ///
    /// This method can be used to spawn a task on all worker threads of the pool:
    ///
    /// ```
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     const NUM_WORKERS: usize = 3;
    ///     let pool = LocalPoolHandle::new(NUM_WORKERS);
    ///     let handles = (0..pool.num_threads())
    ///         .map(|worker_idx| {
    ///             pool.spawn_pinned_by_idx(
    ///                 || {
    ///                     async {
    ///                         "test"
    ///                     }
    ///                 },
    ///                 worker_idx,
    ///             )
    ///         })
    ///         .collect::<Vec<_>>();
    ///
    ///     for handle in handles {
    ///         handle.await.unwrap();
    ///     }
    /// }
    /// ```
    ///
    #[track_caller]
    pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        self.pool
            .spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
    }
}

impl Debug for LocalPoolHandle {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("LocalPoolHandle")
    }
}

enum WorkerChoice {
    LeastBurdened,
    ByIdx(usize),
}

struct LocalPool {
    workers: Box<[LocalWorkerHandle]>,
}

impl LocalPool {
    /// Spawn a `?Send` future onto a worker
    #[track_caller]
    fn spawn_pinned<F, Fut>(
        &self,
        create_task: F,
        worker_choice: WorkerChoice,
    ) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        let (sender, receiver) = oneshot::channel();

        let (worker, job_guard) = match worker_choice {
            WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
            WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
        };
        let worker_spawner = worker.spawner.clone();

        // Spawn a future onto the worker's runtime so we can immediately return
        // a join handle.
        worker.runtime_handle.spawn(async move {
            // Move the job guard into the task
            let _job_guard = job_guard;

            // Propagate aborts via Abortable/AbortHandle
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let _abort_guard = AbortGuard(abort_handle);

            // Inside the future we can't run spawn_local yet because we're not
            // in the context of a LocalSet. We need to send create_task to the
            // LocalSet task for spawning.
            let spawn_task = Box::new(move || {
                // Once we're in the LocalSet context we can call spawn_local
                let join_handle = spawn_local(async move {
                    Abortable::new(create_task(), abort_registration).await
                });

                // Send the join handle back to the spawner. If sending fails,
                // we assume the parent task was canceled, so cancel this task
                // as well.
                if let Err(join_handle) = sender.send(join_handle) {
                    join_handle.abort()
                }
            });

            // Send the callback to the LocalSet task
            if let Err(e) = worker_spawner.send(spawn_task) {
                // Propagate the error as a panic in the join handle.
                panic!("Failed to send job to worker: {e}");
            }

            // Wait for the task's join handle
            let join_handle = match receiver.await {
                Ok(handle) => handle,
                Err(e) => {
                    // We sent the task successfully, but failed to get its
                    // join handle... We assume something happened to the worker
                    // and the task was not spawned. Propagate the error as a
                    // panic in the join handle.
                    panic!("Worker failed to send join handle: {e}");
                }
            };

            // Wait for the task to complete
            let join_result = join_handle.await;

            match join_result {
                Ok(Ok(output)) => output,
                Ok(Err(_)) => {
                    // Pinned task was aborted. But that only happens if this
                    // task is aborted. So this is an impossible branch.
                    unreachable!(
                        "Reaching this branch means this task was previously \
                         aborted but it continued running anyway"
                    )
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else if e.is_cancelled() {
                        // No one else should have the join handle, so this is
                        // unexpected. Forward this error as a panic in the join
                        // handle.
                        panic!("spawn_pinned task was canceled: {e}");
                    } else {
                        // Something unknown happened (not a panic or
                        // cancellation). Forward this error as a panic in the
                        // join handle.
                        panic!("spawn_pinned task failed: {e}");
                    }
                }
            }
        })
    }

    /// Find the worker with the least number of tasks, increment its task
    /// count, and return its handle. Make sure to actually spawn a task on
    /// the worker so the task count is kept consistent with load.
    ///
    /// A job count guard is also returned to ensure the task count gets
    /// decremented when the job is done.
    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
        loop {
            let (worker, task_count) = self
                .workers
                .iter()
                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
                .min_by_key(|&(_, count)| count)
                .expect("There must be at least one worker");

            // Make sure the task count hasn't changed since we chose this
            // worker. Otherwise, restart the search.
            if worker
                .task_count
                .compare_exchange(
                    task_count,
                    task_count + 1,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
            }
        }
    }

    #[track_caller]
    fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
        let worker = &self.workers[idx];
        worker.task_count.fetch_add(1, Ordering::SeqCst);

        (worker, JobCountGuard(Arc::clone(&worker.task_count)))
    }
}

/// Automatically decrements a worker's job count when a job finishes (when
/// this gets dropped).
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
    fn drop(&mut self) {
        // Decrement the job count
        let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
        debug_assert!(previous_value >= 1);
    }
}

/// Calls abort on the handle when dropped.
struct AbortGuard(AbortHandle);

impl Drop for AbortGuard {
    fn drop(&mut self) {
        self.0.abort();
    }
}

type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;

struct LocalWorkerHandle {
    runtime_handle: tokio::runtime::Handle,
    spawner: UnboundedSender<PinnedFutureSpawner>,
    task_count: Arc<AtomicUsize>,
}

impl LocalWorkerHandle {
    /// Create a new worker for executing pinned tasks
    fn new_worker() -> LocalWorkerHandle {
        let (sender, receiver) = unbounded_channel();
        let runtime = Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to start a pinned worker thread runtime");
        let runtime_handle = runtime.handle().clone();
        let task_count = Arc::new(AtomicUsize::new(0));
        let task_count_clone = Arc::clone(&task_count);
        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));

        LocalWorkerHandle {
            runtime_handle,
            spawner: sender,
            task_count,
        }
    }

    fn run(
        runtime: tokio::runtime::Runtime,
        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
        task_count: Arc<AtomicUsize>,
    ) {
        let local_set = LocalSet::new();
        local_set.block_on(&runtime, async {
            while let Some(spawn_task) = task_receiver.recv().await {
                // Calls spawn_local(future)
                (spawn_task)();
            }
        });

        // If there are any tasks on the runtime associated with a LocalSet task
        // that has already completed, but whose output has not yet been
        // reported, let that task complete.
        //
        // Since the task_count is decremented when the runtime task exits,
        // reading that counter lets us know if any such tasks completed during
        // the call to `block_on`.
        //
        // Tasks on the LocalSet can't complete during this loop since they're
        // stored on the LocalSet and we aren't accessing it.
        let mut previous_task_count = task_count.load(Ordering::SeqCst);
        loop {
            // This call will also run tasks spawned on the runtime.
            runtime.block_on(tokio::task::yield_now());
            let new_task_count = task_count.load(Ordering::SeqCst);
            if new_task_count == previous_task_count {
                break;
            } else {
                previous_task_count = new_task_count;
            }
        }

        // It's now no longer possible for a task on the runtime to be
        // associated with a LocalSet task that has completed. Drop both the
        // LocalSet and runtime to let tasks on the runtime be cancelled if and
        // only if they are still on the LocalSet.
        //
        // Drop the LocalSet task first so that anyone awaiting the runtime
        // JoinHandle will see the cancelled error after the LocalSet task
        // destructor has completed.
        drop(local_set);
        drop(runtime);
    }
}
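
// Illustrative sketch (not part of the original module): a minimal unit test
// showing how the task-count bookkeeping above is observable through the
// public API. It assumes a dev-dependency on tokio with the `rt` and `macros`
// features, as is typical for this crate's tests; the module and test names
// are placeholders.
#[cfg(test)]
mod illustrative_tests {
    use super::LocalPoolHandle;
    use std::rc::Rc;

    #[tokio::test]
    async fn spawn_by_idx_reports_one_load_per_worker() {
        let pool = LocalPoolHandle::new(2);

        // Pin a `!Send` future to worker 0. The closure is `Send`; the `Rc`
        // (which is `!Send`) is only created inside it.
        let handle = pool.spawn_pinned_by_idx(
            || {
                let data = Rc::new("pinned");
                async move { data.to_string() }
            },
            0,
        );

        // The pool reports exactly one load value per worker thread. The
        // individual counts depend on timing, so only the length is asserted.
        assert_eq!(pool.get_task_loads_for_each_worker().len(), 2);

        assert_eq!(handle.await.unwrap(), "pinned");
    }
}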