1 use futures_util::future::{AbortHandle, Abortable};
2 use std::fmt;
3 use std::fmt::{Debug, Formatter};
4 use std::future::Future;
5 use std::sync::atomic::{AtomicUsize, Ordering};
6 use std::sync::Arc;
7 use tokio::runtime::Builder;
8 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
9 use tokio::sync::oneshot;
10 use tokio::task::{spawn_local, JoinHandle, LocalSet};
11 
/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
///
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
/// execute on the same thread) inside the Future you supply to the various spawn methods
/// of `LocalPoolHandle`.
///
/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
///
/// # Examples
///
/// ```
/// use std::rc::Rc;
/// use tokio::task;
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() {
///     let pool = LocalPoolHandle::new(5);
///
///     let output = pool.spawn_pinned(|| {
///         // `data` is !Send + !Sync
///         let data = Rc::new("local data");
///         let data_clone = data.clone();
///
///         async move {
///             task::spawn_local(async move {
///                 println!("{}", data_clone);
///             });
///
///             data.to_string()
///         }
///     }).await.unwrap();
///     println!("output: {}", output);
/// }
/// ```
///
#[derive(Clone)]
pub struct LocalPoolHandle {
    // Shared ownership of the worker set; cloning the handle just clones
    // this `Arc`, so all clones spawn onto the same pool.
    pool: Arc<LocalPool>,
}
54 
55 impl LocalPoolHandle {
56     /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
57     /// pool via [`LocalPoolHandle::spawn_pinned`].
58     ///
59     /// # Panics
60     ///
61     /// Panics if the pool size is less than one.
62     #[track_caller]
new(pool_size: usize) -> LocalPoolHandle63     pub fn new(pool_size: usize) -> LocalPoolHandle {
64         assert!(pool_size > 0);
65 
66         let workers = (0..pool_size)
67             .map(|_| LocalWorkerHandle::new_worker())
68             .collect();
69 
70         let pool = Arc::new(LocalPool { workers });
71 
72         LocalPoolHandle { pool }
73     }
74 
75     /// Returns the number of threads of the Pool.
76     #[inline]
num_threads(&self) -> usize77     pub fn num_threads(&self) -> usize {
78         self.pool.workers.len()
79     }
80 
81     /// Returns the number of tasks scheduled on each worker. The indices of the
82     /// worker threads correspond to the indices of the returned `Vec`.
get_task_loads_for_each_worker(&self) -> Vec<usize>83     pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
84         self.pool
85             .workers
86             .iter()
87             .map(|worker| worker.task_count.load(Ordering::SeqCst))
88             .collect::<Vec<_>>()
89     }
90 
91     /// Spawn a task onto a worker thread and pin it there so it can't be moved
92     /// off of the thread. Note that the future is not [`Send`], but the
93     /// [`FnOnce`] which creates it is.
94     ///
95     /// # Examples
96     /// ```
97     /// use std::rc::Rc;
98     /// use tokio_util::task::LocalPoolHandle;
99     ///
100     /// #[tokio::main]
101     /// async fn main() {
102     ///     // Create the local pool
103     ///     let pool = LocalPoolHandle::new(1);
104     ///
105     ///     // Spawn a !Send future onto the pool and await it
106     ///     let output = pool
107     ///         .spawn_pinned(|| {
108     ///             // Rc is !Send + !Sync
109     ///             let local_data = Rc::new("test");
110     ///
111     ///             // This future holds an Rc, so it is !Send
112     ///             async move { local_data.to_string() }
113     ///         })
114     ///         .await
115     ///         .unwrap();
116     ///
117     ///     assert_eq!(output, "test");
118     /// }
119     /// ```
spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output> where F: FnOnce() -> Fut, F: Send + 'static, Fut: Future + 'static, Fut::Output: Send + 'static,120     pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
121     where
122         F: FnOnce() -> Fut,
123         F: Send + 'static,
124         Fut: Future + 'static,
125         Fut::Output: Send + 'static,
126     {
127         self.pool
128             .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
129     }
130 
131     /// Differs from `spawn_pinned` only in that you can choose a specific worker thread
132     /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
133     /// number of tasks scheduled.
134     ///
135     /// A worker thread is chosen by index. Indices are 0 based and the largest index
136     /// is given by `num_threads() - 1`
137     ///
138     /// # Panics
139     ///
140     /// This method panics if the index is out of bounds.
141     ///
142     /// # Examples
143     ///
144     /// This method can be used to spawn a task on all worker threads of the pool:
145     ///
146     /// ```
147     /// use tokio_util::task::LocalPoolHandle;
148     ///
149     /// #[tokio::main]
150     /// async fn main() {
151     ///     const NUM_WORKERS: usize = 3;
152     ///     let pool = LocalPoolHandle::new(NUM_WORKERS);
153     ///     let handles = (0..pool.num_threads())
154     ///         .map(|worker_idx| {
155     ///             pool.spawn_pinned_by_idx(
156     ///                 || {
157     ///                     async {
158     ///                         "test"
159     ///                     }
160     ///                 },
161     ///                 worker_idx,
162     ///             )
163     ///         })
164     ///         .collect::<Vec<_>>();
165     ///
166     ///     for handle in handles {
167     ///         handle.await.unwrap();
168     ///     }
169     /// }
170     /// ```
171     ///
172     #[track_caller]
spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output> where F: FnOnce() -> Fut, F: Send + 'static, Fut: Future + 'static, Fut::Output: Send + 'static,173     pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
174     where
175         F: FnOnce() -> Fut,
176         F: Send + 'static,
177         Fut: Future + 'static,
178         Fut::Output: Send + 'static,
179     {
180         self.pool
181             .spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
182     }
183 }
184 
185 impl Debug for LocalPoolHandle {
fmt(&self, f: &mut Formatter<'_>) -> fmt::Result186     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
187         f.write_str("LocalPoolHandle")
188     }
189 }
190 
/// Strategy for picking which worker thread receives a spawned task.
enum WorkerChoice {
    // Choose the worker with the smallest current task count.
    LeastBurdened,
    // Choose the worker at this 0-based index.
    ByIdx(usize),
}
195 
/// The shared state behind a [`LocalPoolHandle`]: the fixed set of workers.
struct LocalPool {
    // One handle per worker thread; the set is fixed for the pool's lifetime.
    workers: Box<[LocalWorkerHandle]>,
}
199 
200 impl LocalPool {
201     /// Spawn a `?Send` future onto a worker
202     #[track_caller]
spawn_pinned<F, Fut>( &self, create_task: F, worker_choice: WorkerChoice, ) -> JoinHandle<Fut::Output> where F: FnOnce() -> Fut, F: Send + 'static, Fut: Future + 'static, Fut::Output: Send + 'static,203     fn spawn_pinned<F, Fut>(
204         &self,
205         create_task: F,
206         worker_choice: WorkerChoice,
207     ) -> JoinHandle<Fut::Output>
208     where
209         F: FnOnce() -> Fut,
210         F: Send + 'static,
211         Fut: Future + 'static,
212         Fut::Output: Send + 'static,
213     {
214         let (sender, receiver) = oneshot::channel();
215         let (worker, job_guard) = match worker_choice {
216             WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
217             WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
218         };
219         let worker_spawner = worker.spawner.clone();
220 
221         // Spawn a future onto the worker's runtime so we can immediately return
222         // a join handle.
223         worker.runtime_handle.spawn(async move {
224             // Move the job guard into the task
225             let _job_guard = job_guard;
226 
227             // Propagate aborts via Abortable/AbortHandle
228             let (abort_handle, abort_registration) = AbortHandle::new_pair();
229             let _abort_guard = AbortGuard(abort_handle);
230 
231             // Inside the future we can't run spawn_local yet because we're not
232             // in the context of a LocalSet. We need to send create_task to the
233             // LocalSet task for spawning.
234             let spawn_task = Box::new(move || {
235                 // Once we're in the LocalSet context we can call spawn_local
236                 let join_handle =
237                     spawn_local(
238                         async move { Abortable::new(create_task(), abort_registration).await },
239                     );
240 
241                 // Send the join handle back to the spawner. If sending fails,
242                 // we assume the parent task was canceled, so cancel this task
243                 // as well.
244                 if let Err(join_handle) = sender.send(join_handle) {
245                     join_handle.abort()
246                 }
247             });
248 
249             // Send the callback to the LocalSet task
250             if let Err(e) = worker_spawner.send(spawn_task) {
251                 // Propagate the error as a panic in the join handle.
252                 panic!("Failed to send job to worker: {e}");
253             }
254 
255             // Wait for the task's join handle
256             let join_handle = match receiver.await {
257                 Ok(handle) => handle,
258                 Err(e) => {
259                     // We sent the task successfully, but failed to get its
260                     // join handle... We assume something happened to the worker
261                     // and the task was not spawned. Propagate the error as a
262                     // panic in the join handle.
263                     panic!("Worker failed to send join handle: {e}");
264                 }
265             };
266 
267             // Wait for the task to complete
268             let join_result = join_handle.await;
269 
270             match join_result {
271                 Ok(Ok(output)) => output,
272                 Ok(Err(_)) => {
273                     // Pinned task was aborted. But that only happens if this
274                     // task is aborted. So this is an impossible branch.
275                     unreachable!(
276                         "Reaching this branch means this task was previously \
277                          aborted but it continued running anyways"
278                     )
279                 }
280                 Err(e) => {
281                     if e.is_panic() {
282                         std::panic::resume_unwind(e.into_panic());
283                     } else if e.is_cancelled() {
284                         // No one else should have the join handle, so this is
285                         // unexpected. Forward this error as a panic in the join
286                         // handle.
287                         panic!("spawn_pinned task was canceled: {e}");
288                     } else {
289                         // Something unknown happened (not a panic or
290                         // cancellation). Forward this error as a panic in the
291                         // join handle.
292                         panic!("spawn_pinned task failed: {e}");
293                     }
294                 }
295             }
296         })
297     }
298 
299     /// Find the worker with the least number of tasks, increment its task
300     /// count, and return its handle. Make sure to actually spawn a task on
301     /// the worker so the task count is kept consistent with load.
302     ///
303     /// A job count guard is also returned to ensure the task count gets
304     /// decremented when the job is done.
find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard)305     fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
306         loop {
307             let (worker, task_count) = self
308                 .workers
309                 .iter()
310                 .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
311                 .min_by_key(|&(_, count)| count)
312                 .expect("There must be more than one worker");
313 
314             // Make sure the task count hasn't changed since when we choose this
315             // worker. Otherwise, restart the search.
316             if worker
317                 .task_count
318                 .compare_exchange(
319                     task_count,
320                     task_count + 1,
321                     Ordering::SeqCst,
322                     Ordering::Relaxed,
323                 )
324                 .is_ok()
325             {
326                 return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
327             }
328         }
329     }
330 
331     #[track_caller]
find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard)332     fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
333         let worker = &self.workers[idx];
334         worker.task_count.fetch_add(1, Ordering::SeqCst);
335 
336         (worker, JobCountGuard(Arc::clone(&worker.task_count)))
337     }
338 }
339 
/// RAII guard tying a job's lifetime to a worker's task counter: the count
/// was incremented when the job was scheduled, and dropping this guard
/// decrements it again.
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
    fn drop(&mut self) {
        // Give back the slot we reserved at spawn time.
        let prev = self.0.fetch_sub(1, Ordering::SeqCst);
        // The counter was incremented before this guard was created, so it
        // can never underflow here.
        debug_assert!(prev > 0);
    }
}
351 
/// Calls abort on the handle when dropped.
struct AbortGuard(AbortHandle);

impl Drop for AbortGuard {
    fn drop(&mut self) {
        // Signal the paired `Abortable` future to stop, so the pinned task
        // doesn't keep running after its supervising task goes away.
        self.0.abort();
    }
}
360 
/// A boxed callback sent to a worker thread; it performs the actual
/// `spawn_local` once it runs inside the worker's `LocalSet` context.
type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;
362 
/// Handle to one worker thread of the pool.
struct LocalWorkerHandle {
    // Handle to the worker's current-thread runtime, used to spawn the
    // supervising task that relays jobs and awaits their join handles.
    runtime_handle: tokio::runtime::Handle,
    // Channel for sending spawn callbacks into the worker's LocalSet task.
    spawner: UnboundedSender<PinnedFutureSpawner>,
    // Number of tasks currently scheduled on this worker (shared with the
    // worker thread and with JobCountGuards).
    task_count: Arc<AtomicUsize>,
}
368 
impl LocalWorkerHandle {
    /// Create a new worker for executing pinned tasks.
    ///
    /// Builds a current-thread runtime, then moves it (together with the
    /// receiving end of the job channel) onto a dedicated OS thread that
    /// runs [`Self::run`] until the channel closes.
    fn new_worker() -> LocalWorkerHandle {
        let (sender, receiver) = unbounded_channel();
        let runtime = Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to start a pinned worker thread runtime");
        // Keep a handle so spawn_pinned can submit the supervising task to
        // this runtime from other threads.
        let runtime_handle = runtime.handle().clone();
        let task_count = Arc::new(AtomicUsize::new(0));
        let task_count_clone = Arc::clone(&task_count);

        // The runtime and receiver are moved into the worker thread; the
        // thread exits once `run` returns.
        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));

        LocalWorkerHandle {
            runtime_handle,
            spawner: sender,
            task_count,
        }
    }

    /// Worker thread main loop: drive a `LocalSet`, executing each spawn
    /// callback received over `task_receiver`, then drain and shut down
    /// cleanly once all senders are dropped.
    fn run(
        runtime: tokio::runtime::Runtime,
        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
        task_count: Arc<AtomicUsize>,
    ) {
        let local_set = LocalSet::new();
        local_set.block_on(&runtime, async {
            while let Some(spawn_task) = task_receiver.recv().await {
                // Calls spawn_local(future)
                (spawn_task)();
            }
        });

        // If there are any tasks on the runtime associated with a LocalSet task
        // that has already completed, but whose output has not yet been
        // reported, let that task complete.
        //
        // Since the task_count is decremented when the runtime task exits,
        // reading that counter lets us know if any such tasks completed during
        // the call to `block_on`.
        //
        // Tasks on the LocalSet can't complete during this loop since they're
        // stored on the LocalSet and we aren't accessing it.
        let mut previous_task_count = task_count.load(Ordering::SeqCst);
        loop {
            // This call will also run tasks spawned on the runtime.
            runtime.block_on(tokio::task::yield_now());
            let new_task_count = task_count.load(Ordering::SeqCst);
            // A stable count means no more supervising tasks finished during
            // this round, so the drain is complete.
            if new_task_count == previous_task_count {
                break;
            } else {
                previous_task_count = new_task_count;
            }
        }

        // It's now no longer possible for a task on the runtime to be
        // associated with a LocalSet task that has completed. Drop both the
        // LocalSet and runtime to let tasks on the runtime be cancelled if and
        // only if they are still on the LocalSet.
        //
        // Drop the LocalSet task first so that anyone awaiting the runtime
        // JoinHandle will see the cancelled error after the LocalSet task
        // destructor has completed.
        drop(local_set);
        drop(runtime);
    }
}
437