1 use futures_util::future::{AbortHandle, Abortable}; 2 use std::fmt; 3 use std::fmt::{Debug, Formatter}; 4 use std::future::Future; 5 use std::sync::atomic::{AtomicUsize, Ordering}; 6 use std::sync::Arc; 7 use tokio::runtime::Builder; 8 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; 9 use tokio::sync::oneshot; 10 use tokio::task::{spawn_local, JoinHandle, LocalSet}; 11 12 /// A cloneable handle to a local pool, used for spawning `!Send` tasks. 13 /// 14 /// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread 15 /// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will 16 /// execute on the same thread) inside the Future you supply to the various spawn methods 17 /// of `LocalPoolHandle`. 18 /// 19 /// [`tokio::task::LocalSet`]: tokio::task::LocalSet 20 /// [`tokio::task::spawn_local`]: tokio::task::spawn_local 21 /// 22 /// # Examples 23 /// 24 /// ``` 25 /// use std::rc::Rc; 26 /// use tokio::task; 27 /// use tokio_util::task::LocalPoolHandle; 28 /// 29 /// #[tokio::main(flavor = "current_thread")] 30 /// async fn main() { 31 /// let pool = LocalPoolHandle::new(5); 32 /// 33 /// let output = pool.spawn_pinned(|| { 34 /// // `data` is !Send + !Sync 35 /// let data = Rc::new("local data"); 36 /// let data_clone = data.clone(); 37 /// 38 /// async move { 39 /// task::spawn_local(async move { 40 /// println!("{}", data_clone); 41 /// }); 42 /// 43 /// data.to_string() 44 /// } 45 /// }).await.unwrap(); 46 /// println!("output: {}", output); 47 /// } 48 /// ``` 49 /// 50 #[derive(Clone)] 51 pub struct LocalPoolHandle { 52 pool: Arc<LocalPool>, 53 } 54 55 impl LocalPoolHandle { 56 /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this 57 /// pool via [`LocalPoolHandle::spawn_pinned`]. 58 /// 59 /// # Panics 60 /// 61 /// Panics if the pool size is less than one. 62 #[track_caller] new(pool_size: usize) -> LocalPoolHandle63 pub fn new(pool_size: usize) -> LocalPoolHandle { 64 assert!(pool_size > 0); 65 66 let workers = (0..pool_size) 67 .map(|_| LocalWorkerHandle::new_worker()) 68 .collect(); 69 70 let pool = Arc::new(LocalPool { workers }); 71 72 LocalPoolHandle { pool } 73 } 74 75 /// Returns the number of threads of the Pool. 76 #[inline] num_threads(&self) -> usize77 pub fn num_threads(&self) -> usize { 78 self.pool.workers.len() 79 } 80 81 /// Returns the number of tasks scheduled on each worker. The indices of the 82 /// worker threads correspond to the indices of the returned `Vec`. get_task_loads_for_each_worker(&self) -> Vec<usize>83 pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> { 84 self.pool 85 .workers 86 .iter() 87 .map(|worker| worker.task_count.load(Ordering::SeqCst)) 88 .collect::<Vec<_>>() 89 } 90 91 /// Spawn a task onto a worker thread and pin it there so it can't be moved 92 /// off of the thread. Note that the future is not [`Send`], but the 93 /// [`FnOnce`] which creates it is. 94 /// 95 /// # Examples 96 /// ``` 97 /// use std::rc::Rc; 98 /// use tokio_util::task::LocalPoolHandle; 99 /// 100 /// #[tokio::main] 101 /// async fn main() { 102 /// // Create the local pool 103 /// let pool = LocalPoolHandle::new(1); 104 /// 105 /// // Spawn a !Send future onto the pool and await it 106 /// let output = pool 107 /// .spawn_pinned(|| { 108 /// // Rc is !Send + !Sync 109 /// let local_data = Rc::new("test"); 110 /// 111 /// // This future holds an Rc, so it is !Send 112 /// async move { local_data.to_string() } 113 /// }) 114 /// .await 115 /// .unwrap(); 116 /// 117 /// assert_eq!(output, "test"); 118 /// } 119 /// ``` spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output> where F: FnOnce() -> Fut, F: Send + 'static, Fut: Future + 'static, Fut::Output: Send + 'static,120 pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output> 121 where 122 F: FnOnce() -> Fut, 123 F: Send + 'static, 124 Fut: Future + 'static, 125 Fut::Output: Send + 'static, 126 { 127 self.pool 128 .spawn_pinned(create_task, WorkerChoice::LeastBurdened) 129 } 130 131 /// Differs from `spawn_pinned` only in that you can choose a specific worker thread 132 /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest 133 /// number of tasks scheduled. 134 /// 135 /// A worker thread is chosen by index. Indices are 0 based and the largest index 136 /// is given by `num_threads() - 1` 137 /// 138 /// # Panics 139 /// 140 /// This method panics if the index is out of bounds. 141 /// 142 /// # Examples 143 /// 144 /// This method can be used to spawn a task on all worker threads of the pool: 145 /// 146 /// ``` 147 /// use tokio_util::task::LocalPoolHandle; 148 /// 149 /// #[tokio::main] 150 /// async fn main() { 151 /// const NUM_WORKERS: usize = 3; 152 /// let pool = LocalPoolHandle::new(NUM_WORKERS); 153 /// let handles = (0..pool.num_threads()) 154 /// .map(|worker_idx| { 155 /// pool.spawn_pinned_by_idx( 156 /// || { 157 /// async { 158 /// "test" 159 /// } 160 /// }, 161 /// worker_idx, 162 /// ) 163 /// }) 164 /// .collect::<Vec<_>>(); 165 /// 166 /// for handle in handles { 167 /// handle.await.unwrap(); 168 /// } 169 /// } 170 /// ``` 171 /// 172 #[track_caller] spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output> where F: FnOnce() -> Fut, F: Send + 'static, Fut: Future + 'static, Fut::Output: Send + 'static,173 pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output> 174 where 175 F: FnOnce() -> Fut, 176 F: Send + 'static, 177 Fut: Future + 'static, 178 Fut::Output: Send + 'static, 179 { 180 self.pool 181 .spawn_pinned(create_task, WorkerChoice::ByIdx(idx)) 182 } 183 } 184 185 impl Debug for LocalPoolHandle { fmt(&self, f: &mut Formatter<'_>) -> fmt::Result186 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 187 f.write_str("LocalPoolHandle") 188 } 189 } 190 191 enum WorkerChoice { 192 LeastBurdened, 193 ByIdx(usize), 194 } 195 196 struct LocalPool { 197 workers: Box<[LocalWorkerHandle]>, 198 } 199 200 impl LocalPool { 201 /// Spawn a `?Send` future onto a worker 202 #[track_caller] spawn_pinned<F, Fut>( &self, create_task: F, worker_choice: WorkerChoice, ) -> JoinHandle<Fut::Output> where F: FnOnce() -> Fut, F: Send + 'static, Fut: Future + 'static, Fut::Output: Send + 'static,203 fn spawn_pinned<F, Fut>( 204 &self, 205 create_task: F, 206 worker_choice: WorkerChoice, 207 ) -> JoinHandle<Fut::Output> 208 where 209 F: FnOnce() -> Fut, 210 F: Send + 'static, 211 Fut: Future + 'static, 212 Fut::Output: Send + 'static, 213 { 214 let (sender, receiver) = oneshot::channel(); 215 let (worker, job_guard) = match worker_choice { 216 WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(), 217 WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx), 218 }; 219 let worker_spawner = worker.spawner.clone(); 220 221 // Spawn a future onto the worker's runtime so we can immediately return 222 // a join handle. 223 worker.runtime_handle.spawn(async move { 224 // Move the job guard into the task 225 let _job_guard = job_guard; 226 227 // Propagate aborts via Abortable/AbortHandle 228 let (abort_handle, abort_registration) = AbortHandle::new_pair(); 229 let _abort_guard = AbortGuard(abort_handle); 230 231 // Inside the future we can't run spawn_local yet because we're not 232 // in the context of a LocalSet. We need to send create_task to the 233 // LocalSet task for spawning. 234 let spawn_task = Box::new(move || { 235 // Once we're in the LocalSet context we can call spawn_local 236 let join_handle = 237 spawn_local( 238 async move { Abortable::new(create_task(), abort_registration).await }, 239 ); 240 241 // Send the join handle back to the spawner. If sending fails, 242 // we assume the parent task was canceled, so cancel this task 243 // as well. 244 if let Err(join_handle) = sender.send(join_handle) { 245 join_handle.abort() 246 } 247 }); 248 249 // Send the callback to the LocalSet task 250 if let Err(e) = worker_spawner.send(spawn_task) { 251 // Propagate the error as a panic in the join handle. 252 panic!("Failed to send job to worker: {e}"); 253 } 254 255 // Wait for the task's join handle 256 let join_handle = match receiver.await { 257 Ok(handle) => handle, 258 Err(e) => { 259 // We sent the task successfully, but failed to get its 260 // join handle... We assume something happened to the worker 261 // and the task was not spawned. Propagate the error as a 262 // panic in the join handle. 263 panic!("Worker failed to send join handle: {e}"); 264 } 265 }; 266 267 // Wait for the task to complete 268 let join_result = join_handle.await; 269 270 match join_result { 271 Ok(Ok(output)) => output, 272 Ok(Err(_)) => { 273 // Pinned task was aborted. But that only happens if this 274 // task is aborted. So this is an impossible branch. 275 unreachable!( 276 "Reaching this branch means this task was previously \ 277 aborted but it continued running anyways" 278 ) 279 } 280 Err(e) => { 281 if e.is_panic() { 282 std::panic::resume_unwind(e.into_panic()); 283 } else if e.is_cancelled() { 284 // No one else should have the join handle, so this is 285 // unexpected. Forward this error as a panic in the join 286 // handle. 287 panic!("spawn_pinned task was canceled: {e}"); 288 } else { 289 // Something unknown happened (not a panic or 290 // cancellation). Forward this error as a panic in the 291 // join handle. 292 panic!("spawn_pinned task failed: {e}"); 293 } 294 } 295 } 296 }) 297 } 298 299 /// Find the worker with the least number of tasks, increment its task 300 /// count, and return its handle. Make sure to actually spawn a task on 301 /// the worker so the task count is kept consistent with load. 302 /// 303 /// A job count guard is also returned to ensure the task count gets 304 /// decremented when the job is done. find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard)305 fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) { 306 loop { 307 let (worker, task_count) = self 308 .workers 309 .iter() 310 .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst))) 311 .min_by_key(|&(_, count)| count) 312 .expect("There must be more than one worker"); 313 314 // Make sure the task count hasn't changed since when we choose this 315 // worker. Otherwise, restart the search. 316 if worker 317 .task_count 318 .compare_exchange( 319 task_count, 320 task_count + 1, 321 Ordering::SeqCst, 322 Ordering::Relaxed, 323 ) 324 .is_ok() 325 { 326 return (worker, JobCountGuard(Arc::clone(&worker.task_count))); 327 } 328 } 329 } 330 331 #[track_caller] find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard)332 fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) { 333 let worker = &self.workers[idx]; 334 worker.task_count.fetch_add(1, Ordering::SeqCst); 335 336 (worker, JobCountGuard(Arc::clone(&worker.task_count))) 337 } 338 } 339 340 /// Automatically decrements a worker's job count when a job finishes (when 341 /// this gets dropped). 342 struct JobCountGuard(Arc<AtomicUsize>); 343 344 impl Drop for JobCountGuard { drop(&mut self)345 fn drop(&mut self) { 346 // Decrement the job count 347 let previous_value = self.0.fetch_sub(1, Ordering::SeqCst); 348 debug_assert!(previous_value >= 1); 349 } 350 } 351 352 /// Calls abort on the handle when dropped. 353 struct AbortGuard(AbortHandle); 354 355 impl Drop for AbortGuard { drop(&mut self)356 fn drop(&mut self) { 357 self.0.abort(); 358 } 359 } 360 361 type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>; 362 363 struct LocalWorkerHandle { 364 runtime_handle: tokio::runtime::Handle, 365 spawner: UnboundedSender<PinnedFutureSpawner>, 366 task_count: Arc<AtomicUsize>, 367 } 368 369 impl LocalWorkerHandle { 370 /// Create a new worker for executing pinned tasks new_worker() -> LocalWorkerHandle371 fn new_worker() -> LocalWorkerHandle { 372 let (sender, receiver) = unbounded_channel(); 373 let runtime = Builder::new_current_thread() 374 .enable_all() 375 .build() 376 .expect("Failed to start a pinned worker thread runtime"); 377 let runtime_handle = runtime.handle().clone(); 378 let task_count = Arc::new(AtomicUsize::new(0)); 379 let task_count_clone = Arc::clone(&task_count); 380 381 std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone)); 382 383 LocalWorkerHandle { 384 runtime_handle, 385 spawner: sender, 386 task_count, 387 } 388 } 389 run( runtime: tokio::runtime::Runtime, mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>, task_count: Arc<AtomicUsize>, )390 fn run( 391 runtime: tokio::runtime::Runtime, 392 mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>, 393 task_count: Arc<AtomicUsize>, 394 ) { 395 let local_set = LocalSet::new(); 396 local_set.block_on(&runtime, async { 397 while let Some(spawn_task) = task_receiver.recv().await { 398 // Calls spawn_local(future) 399 (spawn_task)(); 400 } 401 }); 402 403 // If there are any tasks on the runtime associated with a LocalSet task 404 // that has already completed, but whose output has not yet been 405 // reported, let that task complete. 406 // 407 // Since the task_count is decremented when the runtime task exits, 408 // reading that counter lets us know if any such tasks completed during 409 // the call to `block_on`. 410 // 411 // Tasks on the LocalSet can't complete during this loop since they're 412 // stored on the LocalSet and we aren't accessing it. 413 let mut previous_task_count = task_count.load(Ordering::SeqCst); 414 loop { 415 // This call will also run tasks spawned on the runtime. 416 runtime.block_on(tokio::task::yield_now()); 417 let new_task_count = task_count.load(Ordering::SeqCst); 418 if new_task_count == previous_task_count { 419 break; 420 } else { 421 previous_task_count = new_task_count; 422 } 423 } 424 425 // It's now no longer possible for a task on the runtime to be 426 // associated with a LocalSet task that has completed. Drop both the 427 // LocalSet and runtime to let tasks on the runtime be cancelled if and 428 // only if they are still on the LocalSet. 429 // 430 // Drop the LocalSet task first so that anyone awaiting the runtime 431 // JoinHandle will see the cancelled error after the LocalSet task 432 // destructor has completed. 433 drop(local_set); 434 drop(runtime); 435 } 436 } 437