//! Thread pool for blocking operations
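//!
//! This pool backs `tokio::task::spawn_blocking`: closures submitted there are
//! queued here and executed on dedicated OS threads so they do not stall the
//! async workers. A minimal usage sketch of the public entry point
//! (illustrative, not a doctest):
//!
//! ```ignore
//! let handle = tokio::task::spawn_blocking(|| {
//!     // Blocking or CPU-heavy work runs on the blocking pool.
//!     std::fs::read_to_string("Cargo.toml")
//! });
//! let contents = handle.await.unwrap().unwrap();
//! ```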

use crate::loom::sync::{Arc, Condvar, Mutex};
use crate::loom::thread;
use crate::runtime::blocking::schedule::BlockingSchedule;
use crate::runtime::blocking::{shutdown, BlockingTask};
use crate::runtime::builder::ThreadNameFn;
use crate::runtime::task::{self, JoinHandle};
use crate::runtime::{Builder, Callback, Handle, BOX_FUTURE_THRESHOLD};
use crate::util::metric_atomics::MetricAtomicUsize;
use crate::util::trace::{blocking_task, SpawnMeta};

use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::io;
use std::sync::atomic::Ordering;
use std::time::Duration;

pub(crate) struct BlockingPool {
    spawner: Spawner,
    shutdown_rx: shutdown::Receiver,
}

#[derive(Clone)]
pub(crate) struct Spawner {
    inner: Arc<Inner>,
}

#[derive(Default)]
pub(crate) struct SpawnerMetrics {
    num_threads: MetricAtomicUsize,
    num_idle_threads: MetricAtomicUsize,
    queue_depth: MetricAtomicUsize,
}

impl SpawnerMetrics {
    fn num_threads(&self) -> usize {
        self.num_threads.load(Ordering::Relaxed)
    }

    fn num_idle_threads(&self) -> usize {
        self.num_idle_threads.load(Ordering::Relaxed)
    }

    cfg_unstable_metrics! {
        fn queue_depth(&self) -> usize {
            self.queue_depth.load(Ordering::Relaxed)
        }
    }

    fn inc_num_threads(&self) {
        self.num_threads.increment();
    }

    fn dec_num_threads(&self) {
        self.num_threads.decrement();
    }

    fn inc_num_idle_threads(&self) {
        self.num_idle_threads.increment();
    }

    fn dec_num_idle_threads(&self) -> usize {
        self.num_idle_threads.decrement()
    }

    fn inc_queue_depth(&self) {
        self.queue_depth.increment();
    }

    fn dec_queue_depth(&self) {
        self.queue_depth.decrement();
    }
}

struct Inner {
    /// State shared between worker threads.
    shared: Mutex<Shared>,

    /// Pool threads wait on this.
    condvar: Condvar,

    /// Spawned threads use this name.
    thread_name: ThreadNameFn,

    /// Spawned thread stack size.
    stack_size: Option<usize>,

    /// Call after a thread starts.
    after_start: Option<Callback>,

    /// Call before a thread stops.
    before_stop: Option<Callback>,

    // Maximum number of threads.
    thread_cap: usize,

    // Customizable wait timeout.
    keep_alive: Duration,

    // Metrics about the pool.
    metrics: SpawnerMetrics,
}

struct Shared {
    queue: VecDeque<Task>,
    num_notify: u32,
    shutdown: bool,
    shutdown_tx: Option<shutdown::Sender>,
    /// Prior to shutdown, we clean up `JoinHandles` by having each timed-out
    /// thread join on the previous timed-out thread. This is not strictly
    /// necessary but helps avoid Valgrind false positives, see
    /// <https://github.com/tokio-rs/tokio/commit/646fbae76535e397ef79dbcaacb945d4c829f666>
    /// for more information.
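    /// (Each exiting worker takes the handle stored here, joins it after
    /// releasing the lock, and replaces it with its own handle, forming a
    /// chain of joins.)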
    last_exiting_thread: Option<thread::JoinHandle<()>>,
    /// This holds the `JoinHandles` for all running threads; on shutdown, the thread
    /// calling shutdown handles joining on these.
    worker_threads: HashMap<usize, thread::JoinHandle<()>>,
    /// This is a counter used to iterate `worker_threads` in a consistent order (for loom's
    /// benefit).
    worker_thread_index: usize,
}

pub(crate) struct Task {
    task: task::UnownedTask<BlockingSchedule>,
    mandatory: Mandatory,
}

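/// Whether a queued blocking task must still be run when the runtime shuts
/// down before the task has executed (used by the `cfg_fs!` code paths).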
#[derive(PartialEq, Eq)]
pub(crate) enum Mandatory {
    #[cfg_attr(not(fs), allow(dead_code))]
    Mandatory,
    NonMandatory,
}

pub(crate) enum SpawnError {
    /// Pool is shutting down and the task was not scheduled
    ShuttingDown,
    /// There are no worker threads available to take the task
    /// and the OS failed to spawn a new one
    NoThreads(io::Error),
}

impl From<SpawnError> for io::Error {
    fn from(e: SpawnError) -> Self {
        match e {
            SpawnError::ShuttingDown => {
                io::Error::new(io::ErrorKind::Other, "blocking pool shutting down")
            }
            SpawnError::NoThreads(e) => e,
        }
    }
}

impl Task {
    pub(crate) fn new(task: task::UnownedTask<BlockingSchedule>, mandatory: Mandatory) -> Task {
        Task { task, mandatory }
    }

    fn run(self) {
        self.task.run();
    }

    fn shutdown_or_run_if_mandatory(self) {
        match self.mandatory {
            Mandatory::NonMandatory => self.task.shutdown(),
            Mandatory::Mandatory => self.task.run(),
        }
    }
}

const KEEP_ALIVE: Duration = Duration::from_secs(10);
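
// The 10-second default above can be overridden through the runtime builder.
// A sketch of the configuration side (standard builder APIs):
//
//     let rt = tokio::runtime::Builder::new_multi_thread()
//         .thread_keep_alive(std::time::Duration::from_secs(60))
//         .max_blocking_threads(512)
//         .build()
//         .unwrap();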

/// Runs the provided function on an executor dedicated to blocking operations.
/// Tasks are scheduled as non-mandatory, meaning they may not run at all if
/// the runtime shuts down first.
#[track_caller]
#[cfg_attr(target_os = "wasi", allow(dead_code))]
pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    let rt = Handle::current();
    rt.spawn_blocking(func)
}


cfg_fs! {
    #[cfg_attr(any(
        all(loom, not(test)), // the function is covered by loom tests
        test
    ), allow(dead_code))]
    /// Runs the provided function on an executor dedicated to blocking
    /// operations. Tasks will be scheduled as mandatory, meaning they are
    /// guaranteed to run unless a shutdown is already taking place. In case a
    /// shutdown is already taking place, `None` will be returned.
    pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let rt = Handle::current();
        rt.inner.blocking_spawner().spawn_mandatory_blocking(&rt, func)
    }
}
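
// Illustrative call pattern for `spawn_mandatory_blocking` (a sketch, not a
// specific call site from this crate); the `None` arm must be handled because
// the pool may already be shutting down:
//
//     match spawn_mandatory_blocking(move || flush_buffer_to_disk(buf)) {
//         Some(handle) => { /* await the handle for the result */ }
//         None => { /* runtime is shutting down; report an error instead */ }
//     }
//
// `flush_buffer_to_disk` and `buf` are hypothetical names used only for the
// example.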

// ===== impl BlockingPool =====

impl BlockingPool {
    pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool {
        let (shutdown_tx, shutdown_rx) = shutdown::channel();
        let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE);

        BlockingPool {
            spawner: Spawner {
                inner: Arc::new(Inner {
                    shared: Mutex::new(Shared {
                        queue: VecDeque::new(),
                        num_notify: 0,
                        shutdown: false,
                        shutdown_tx: Some(shutdown_tx),
                        last_exiting_thread: None,
                        worker_threads: HashMap::new(),
                        worker_thread_index: 0,
                    }),
                    condvar: Condvar::new(),
                    thread_name: builder.thread_name.clone(),
                    stack_size: builder.thread_stack_size,
                    after_start: builder.after_start.clone(),
                    before_stop: builder.before_stop.clone(),
                    thread_cap,
                    keep_alive,
                    metrics: SpawnerMetrics::default(),
                }),
            },
            shutdown_rx,
        }
    }

    pub(crate) fn spawner(&self) -> &Spawner {
        &self.spawner
    }

    pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
        let mut shared = self.spawner.inner.shared.lock();

        // This function can be called multiple times: first by an explicit
        // call to `shutdown`, then again by the drop handler. The flag below
        // prevents shutting down twice.
        if shared.shutdown {
            return;
        }

        shared.shutdown = true;
        shared.shutdown_tx = None;
        self.spawner.inner.condvar.notify_all();

        let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread);
        let workers = std::mem::take(&mut shared.worker_threads);

        drop(shared);

        if self.shutdown_rx.wait(timeout) {
            let _ = last_exited_thread.map(thread::JoinHandle::join);

            // Loom requires that execution be deterministic, so sort by thread ID before joining.
            // (HashMaps use a randomly-seeded hash function, so the order is nondeterministic)
            #[cfg(loom)]
            let workers: Vec<(usize, thread::JoinHandle<()>)> = {
                let mut workers: Vec<_> = workers.into_iter().collect();
                workers.sort_by_key(|(id, _)| *id);
                workers
            };

            for (_id, handle) in workers {
                let _ = handle.join();
            }
        }
    }
}
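
// From the public API, this shutdown path is reached through
// `Runtime::shutdown_timeout` or when the runtime is dropped (illustrative
// sketch of the timeout variant):
//
//     let rt = tokio::runtime::Runtime::new().unwrap();
//     rt.spawn_blocking(|| std::thread::sleep(std::time::Duration::from_secs(5)));
//     rt.shutdown_timeout(std::time::Duration::from_millis(100));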

impl Drop for BlockingPool {
    fn drop(&mut self) {
        self.shutdown(None);
    }
}

impl fmt::Debug for BlockingPool {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("BlockingPool").finish()
    }
}

// ===== impl Spawner =====

impl Spawner {
    #[track_caller]
    pub(crate) fn spawn_blocking<F, R>(&self, rt: &Handle, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let fn_size = std::mem::size_of::<F>();
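        // Closures larger than `BOX_FUTURE_THRESHOLD` are boxed before being
        // wrapped in a task so the task allocation itself stays small; smaller
        // closures are stored inline.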
        let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
            self.spawn_blocking_inner(
                Box::new(func),
                Mandatory::NonMandatory,
                SpawnMeta::new_unnamed(fn_size),
                rt,
            )
        } else {
            self.spawn_blocking_inner(
                func,
                Mandatory::NonMandatory,
                SpawnMeta::new_unnamed(fn_size),
                rt,
            )
        };

        match spawn_result {
            Ok(()) => join_handle,
            // Compat: do not panic here, return the join_handle even though it will never resolve
            Err(SpawnError::ShuttingDown) => join_handle,
            Err(SpawnError::NoThreads(e)) => {
                panic!("OS can't spawn worker thread: {e}")
            }
        }
    }

    cfg_fs! {
        #[track_caller]
        #[cfg_attr(any(
            all(loom, not(test)), // the function is covered by loom tests
            test
        ), allow(dead_code))]
        pub(crate) fn spawn_mandatory_blocking<F, R>(&self, rt: &Handle, func: F) -> Option<JoinHandle<R>>
        where
            F: FnOnce() -> R + Send + 'static,
            R: Send + 'static,
        {
            let fn_size = std::mem::size_of::<F>();
            let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
                self.spawn_blocking_inner(
                    Box::new(func),
                    Mandatory::Mandatory,
                    SpawnMeta::new_unnamed(fn_size),
                    rt,
                )
            } else {
                self.spawn_blocking_inner(
                    func,
                    Mandatory::Mandatory,
                    SpawnMeta::new_unnamed(fn_size),
                    rt,
                )
            };

            if spawn_result.is_ok() {
                Some(join_handle)
            } else {
                None
            }
        }
    }

    #[track_caller]
    pub(crate) fn spawn_blocking_inner<F, R>(
        &self,
        func: F,
        is_mandatory: Mandatory,
        spawn_meta: SpawnMeta<'_>,
        rt: &Handle,
    ) -> (JoinHandle<R>, Result<(), SpawnError>)
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let id = task::Id::next();
        let fut =
            blocking_task::<F, BlockingTask<F>>(BlockingTask::new(func), spawn_meta, id.as_u64());

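        // `task::unowned` returns the raw task (queued on the pool below) plus
        // the `JoinHandle` handed back to the caller; blocking tasks are not
        // tracked in an owned-task list, hence the `UnownedTask` type.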
        let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id);

        let spawned = self.spawn_task(Task::new(task, is_mandatory), rt);
        (handle, spawned)
    }

    fn spawn_task(&self, task: Task, rt: &Handle) -> Result<(), SpawnError> {
        let mut shared = self.inner.shared.lock();

        if shared.shutdown {
            // Shutdown the task: it's fine to shutdown this task (even if
            // mandatory) because it was scheduled after the shutdown of the
            // runtime began.
            task.task.shutdown();

            // no need to even push this task; it would never get picked up
            return Err(SpawnError::ShuttingDown);
        }

        shared.queue.push_back(task);
        self.inner.metrics.inc_queue_depth();

        if self.inner.metrics.num_idle_threads() == 0 {
            // No threads are able to process the task.

            if self.inner.metrics.num_threads() == self.inner.thread_cap {
                // At the max number of threads: leave the task queued; a
                // currently busy worker will pick it up when it finishes its
                // current task.
            } else {
                assert!(shared.shutdown_tx.is_some());
                let shutdown_tx = shared.shutdown_tx.clone();

                if let Some(shutdown_tx) = shutdown_tx {
                    let id = shared.worker_thread_index;

                    match self.spawn_thread(shutdown_tx, rt, id) {
                        Ok(handle) => {
                            self.inner.metrics.inc_num_threads();
                            shared.worker_thread_index += 1;
                            shared.worker_threads.insert(id, handle);
                        }
                        Err(ref e)
                            if is_temporary_os_thread_error(e)
                                && self.inner.metrics.num_threads() > 0 =>
                        {
                            // OS temporarily failed to spawn a new thread.
                            // The task will be picked up eventually by a currently
                            // busy thread.
                        }
                        Err(e) => {
                            // The OS refused to spawn the thread and there is no thread
                            // to pick up the task that has just been pushed to the queue.
                            return Err(SpawnError::NoThreads(e));
                        }
                    }
                }
            }
        } else {
            // Notify an idle worker thread. The notification counter tracks
            // exactly how many notifications are outstanding; because condvar
            // waits can wake up spuriously, workers rely on this counter to
            // stay in a consistent state.
            self.inner.metrics.dec_num_idle_threads();
            shared.num_notify += 1;
            self.inner.condvar.notify_one();
        }

        Ok(())
    }

    fn spawn_thread(
        &self,
        shutdown_tx: shutdown::Sender,
        rt: &Handle,
        id: usize,
    ) -> io::Result<thread::JoinHandle<()>> {
        let mut builder = thread::Builder::new().name((self.inner.thread_name)());

        if let Some(stack_size) = self.inner.stack_size {
            builder = builder.stack_size(stack_size);
        }

        let rt = rt.clone();

        builder.spawn(move || {
            // Only the reference should be moved into the closure
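            // Entering the runtime context makes the handle available inside
            // blocking closures (e.g. `Handle::current()` works from within
            // `spawn_blocking` bodies).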
            let _enter = rt.enter();
            rt.inner.blocking_spawner().inner.run(id);
            drop(shutdown_tx);
        })
    }
}

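// These counters back the `RuntimeMetrics` accessors for the blocking pool.
// A sketch of the read side (assuming the corresponding metrics are available
// for your tokio version / feature flags):
//
//     let metrics = tokio::runtime::Handle::current().metrics();
//     println!("blocking threads: {}", metrics.num_blocking_threads());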
cfg_unstable_metrics! {
    impl Spawner {
        pub(crate) fn num_threads(&self) -> usize {
            self.inner.metrics.num_threads()
        }

        pub(crate) fn num_idle_threads(&self) -> usize {
            self.inner.metrics.num_idle_threads()
        }

        pub(crate) fn queue_depth(&self) -> usize {
            self.inner.metrics.queue_depth()
        }
    }
}

// Tells whether the error when spawning a thread is temporary.
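// On Unix this typically corresponds to `EAGAIN` from thread creation, which
// std surfaces as `io::ErrorKind::WouldBlock`.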
#[inline]
fn is_temporary_os_thread_error(error: &io::Error) -> bool {
    matches!(error.kind(), io::ErrorKind::WouldBlock)
}

impl Inner {
    fn run(&self, worker_thread_id: usize) {
        if let Some(f) = &self.after_start {
            f();
        }

        let mut shared = self.shared.lock();
        let mut join_on_thread = None;

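        // Worker state machine (summary of the loop below):
        //   BUSY: pop and run queued tasks until the queue is empty.
        //   IDLE: wait on the condvar for up to `keep_alive`:
        //     - a notification (`num_notify > 0`) sends the worker back to BUSY,
        //     - a timeout makes the thread exit,
        //     - shutdown drains the queue (running mandatory tasks) and exits.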
        'main: loop {
            // BUSY
            while let Some(task) = shared.queue.pop_front() {
                self.metrics.dec_queue_depth();
                drop(shared);
                task.run();

                shared = self.shared.lock();
            }

            // IDLE
            self.metrics.inc_num_idle_threads();

            while !shared.shutdown {
                let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap();

                shared = lock_result.0;
                let timeout_result = lock_result.1;

                if shared.num_notify != 0 {
                    // We have received a legitimate wakeup,
                    // acknowledge it by decrementing the counter
                    // and transition to the BUSY state.
                    shared.num_notify -= 1;
                    break;
                }

                // Even if the condvar "timed out", if the pool is entering the
                // shutdown phase, we want to perform the cleanup logic.
                if !shared.shutdown && timeout_result.timed_out() {
                    // We'll join the prior timed-out thread's JoinHandle after dropping the lock.
                    // This isn't done when shutting down, because the thread calling shutdown will
                    // handle joining everything.
                    let my_handle = shared.worker_threads.remove(&worker_thread_id);
                    join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle);

                    break 'main;
                }

                // Spurious wakeup detected, go back to sleep.
            }

            if shared.shutdown {
                // Drain the queue
                while let Some(task) = shared.queue.pop_front() {
                    self.metrics.dec_queue_depth();
                    drop(shared);

                    task.shutdown_or_run_if_mandatory();

                    shared = self.shared.lock();
                }

                // Work was produced, and we "took" it (by decrementing num_notify).
                // This means that num_idle was decremented once for our wakeup.
                // But, since we are exiting, we need to "undo" that, as we'll stay idle.
                self.metrics.inc_num_idle_threads();
                // NOTE: Technically we should also do num_notify++ and notify again,
                // but since we're shutting down anyway, that won't be necessary.
                break;
            }
        }

        // Thread exit
        self.metrics.dec_num_threads();

        // num_idle should now be tracked exactly, panic
        // with a descriptive message if it is not the
        // case.
        let prev_idle = self.metrics.dec_num_idle_threads();
        assert!(
            prev_idle >= self.metrics.num_idle_threads(),
            "num_idle_threads underflowed on thread exit"
        );

        if shared.shutdown && self.metrics.num_threads() == 0 {
            self.condvar.notify_one();
        }

        drop(shared);

        if let Some(f) = &self.before_stop {
            f();
        }

        if let Some(handle) = join_on_thread {
            let _ = handle.join();
        }
    }
}

impl fmt::Debug for Spawner {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("blocking::Spawner").finish()
    }
}