xref: /aosp_15_r20/external/crosvm/cros_async/src/blocking/cancellable_pool.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Provides an async blocking pool whose tasks can be cancelled.
6 
7 use std::collections::HashMap;
8 use std::future::Future;
9 use std::sync::Arc;
10 use std::time::Duration;
11 use std::time::Instant;
12 
13 use once_cell::sync::Lazy;
14 use sync::Condvar;
15 use sync::Mutex;
16 use thiserror::Error as ThisError;
17 
18 use crate::BlockingPool;
19 
20 /// Global executor.
21 ///
22 /// This is convenient, though not preferred. Pros/cons:
23 /// + It avoids passing executor all the way to each call sites.
24 /// + The call site can assume that executor will never shutdown.
25 /// + Provides similar functionality as async_task with a few improvements around ability to cancel.
26 /// - Globals are harder to reason about.
27 static EXECUTOR: Lazy<CancellableBlockingPool> =
28     Lazy::new(|| CancellableBlockingPool::new(256, Duration::from_secs(10)));
29 
30 const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
31 
/// How far a pool has progressed towards being shut down.
///
/// The variants are declared in the order a pool moves through them. The derived `PartialOrd`
/// depends on that declaration order (callers compare with `>=`), so the variants must not be
/// reordered.
#[derive(Default, PartialEq, Eq, PartialOrd)]
enum WindDownStates {
    /// Initial state: nothing has asked the pool to wind down yet.
    #[default]
    Armed,
    /// `disarm` has been called; outstanding tasks have been asked to cancel.
    Disarmed,
    /// A shutdown of the underlying pool is currently running.
    ShuttingDown,
    /// The shutdown attempt has finished (whether or not it timed out).
    ShutDown,
}
40 
41 #[derive(Default)]
42 struct State {
43     wind_down: WindDownStates,
44 
45     /// Helps to generate unique id to associate `cancel` with task.
46     current_cancellable_id: u64,
47 
48     /// A map of all the `cancel` routines of queued/in-flight tasks.
49     cancellables: HashMap<u64, Box<dyn Fn() + Send + 'static>>,
50 }
51 
/// What to do when a timeout expires.
///
/// NOTE(review): nothing else in this file refers to this type; confirm whether it still has
/// users elsewhere in the crate.
#[derive(Clone, Copy, Debug)]
pub enum TimeoutAction {
    /// Take no action when the timeout expires.
    None,
    /// Panic the thread when the timeout expires.
    Panic,
}
59 
60 #[derive(ThisError, Debug, PartialEq, Eq)]
61 pub enum Error {
62     #[error("Timeout occurred while trying to join threads")]
63     Timedout,
64     #[error("Shutdown is in progress")]
65     ShutdownInProgress,
66     #[error("Already shut down")]
67     AlreadyShutdown,
68 }
69 
70 struct Inner {
71     blocking_pool: BlockingPool,
72     state: Mutex<State>,
73 
74     /// This condvar gets notified when `cancellables` is empty after removing an
75     /// entry.
76     cancellables_cv: Condvar,
77 }
78 
79 impl Inner {
spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,80     pub fn spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R>
81     where
82         F: FnOnce() -> R + Send + 'static,
83         R: Send + 'static,
84     {
85         self.blocking_pool.spawn(f)
86     }
87 
88     /// Adds cancel to a cancellables and returns an `id` with which `cancel` can be
89     /// accessed/removed.
add_cancellable(&self, cancel: Box<dyn Fn() + Send + 'static>) -> u6490     fn add_cancellable(&self, cancel: Box<dyn Fn() + Send + 'static>) -> u64 {
91         let mut state = self.state.lock();
92         let id = state.current_cancellable_id;
93         state.current_cancellable_id += 1;
94         state.cancellables.insert(id, cancel);
95         id
96     }
97 }
98 
99 /// A thread pool for running work that may block.
100 ///
101 /// This is a wrapper around `BlockingPool` with an ability to cancel queued tasks.
102 /// See [BlockingPool] for more info.
103 ///
104 /// # Examples
105 ///
106 /// Spawn a task to run in the `CancellableBlockingPool` and await on its result.
107 ///
108 /// ```edition2018
109 /// use cros_async::CancellableBlockingPool;
110 ///
111 /// # async fn do_it() {
112 ///     let pool = CancellableBlockingPool::default();
113 ///     let CANCELLED = 0;
114 ///
115 ///     let res = pool.spawn(move || {
116 ///         // Do some CPU-intensive or blocking work here.
117 ///
118 ///         42
119 ///     }, move || CANCELLED).await;
120 ///
121 ///     assert_eq!(res, 42);
122 /// # }
123 /// # futures::executor::block_on(do_it());
124 /// ```
125 #[derive(Clone)]
126 pub struct CancellableBlockingPool {
127     inner: Arc<Inner>,
128 }
129 
130 impl CancellableBlockingPool {
131     const RETRY_COUNT: usize = 10;
132     const SLEEP_DURATION: Duration = Duration::from_millis(100);
133 
134     /// Create a new `CancellableBlockingPool`.
135     ///
136     /// When we try to shutdown or drop `CancellableBlockingPool`, it may happen that a hung thread
137     /// might prevent `CancellableBlockingPool` pool from getting dropped. On failure to shutdown in
138     /// `watchdog_opts.timeout` duration, `CancellableBlockingPool` can take an action specified by
139     /// `watchdog_opts.action`.
140     ///
141     /// See also: [BlockingPool::new()](BlockingPool::new)
new(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool142     pub fn new(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool {
143         CancellableBlockingPool {
144             inner: Arc::new(Inner {
145                 blocking_pool: BlockingPool::new(max_threads, keepalive),
146                 state: Default::default(),
147                 cancellables_cv: Condvar::new(),
148             }),
149         }
150     }
151 
152     /// Like [Self::new] but with pre-allocating capacity for up to `max_threads`.
with_capacity(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool153     pub fn with_capacity(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool {
154         CancellableBlockingPool {
155             inner: Arc::new(Inner {
156                 blocking_pool: BlockingPool::with_capacity(max_threads, keepalive),
157                 state: Mutex::new(State::default()),
158                 cancellables_cv: Condvar::new(),
159             }),
160         }
161     }
162 
163     /// Spawn a task to run in the `CancellableBlockingPool`.
164     ///
165     /// Callers may `await` the returned `Task` to be notified when the work is completed.
166     /// Dropping the future will not cancel the task.
167     ///
168     /// `cancel` helps to cancel a queued or in-flight operation `f`.
169     /// `cancel` may be called more than once if `f` doesn't respond to `cancel`.
170     /// `cancel` is not called if `f` completes successfully. For example,
171     /// # Examples
172     ///
173     /// ```edition2018
174     /// use {cros_async::CancellableBlockingPool, std::sync::{Arc, Mutex, Condvar}};
175     ///
176     /// # async fn cancel_it() {
177     ///    let pool = CancellableBlockingPool::default();
178     ///    let cancelled: i32 = 1;
179     ///    let success: i32 = 2;
180     ///
181     ///    let shared = Arc::new((Mutex::new(0), Condvar::new()));
182     ///    let shared2 = shared.clone();
183     ///    let shared3 = shared.clone();
184     ///
185     ///    let res = pool
186     ///        .spawn(
187     ///            move || {
188     ///                let guard = shared.0.lock().unwrap();
189     ///                let mut guard = shared.1.wait_while(guard, |state| *state == 0).unwrap();
190     ///                if *guard != cancelled {
191     ///                    *guard = success;
192     ///                }
193     ///            },
194     ///            move || {
195     ///                *shared2.0.lock().unwrap() = cancelled;
196     ///                shared2.1.notify_all();
197     ///            },
198     ///        )
199     ///        .await;
200     ///    pool.shutdown();
201     ///
202     ///    assert_eq!(*shared3.0.lock().unwrap(), cancelled);
203     /// # }
204     /// ```
spawn<F, R, G>(&self, f: F, cancel: G) -> impl Future<Output = R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static, G: Fn() -> R + Send + 'static,205     pub fn spawn<F, R, G>(&self, f: F, cancel: G) -> impl Future<Output = R>
206     where
207         F: FnOnce() -> R + Send + 'static,
208         R: Send + 'static,
209         G: Fn() -> R + Send + 'static,
210     {
211         let inner = self.inner.clone();
212         let cancelled = Arc::new(Mutex::new(None));
213         let cancelled_spawn = cancelled.clone();
214         let id = inner.add_cancellable(Box::new(move || {
215             let mut c = cancelled.lock();
216             *c = Some(cancel());
217         }));
218 
219         self.inner.spawn(move || {
220             if let Some(res) = cancelled_spawn.lock().take() {
221                 return res;
222             }
223             let ret = f();
224             let mut state = inner.state.lock();
225             state.cancellables.remove(&id);
226             if state.cancellables.is_empty() {
227                 inner.cancellables_cv.notify_one();
228             }
229             ret
230         })
231     }
232 
233     /// Iterates over all the queued tasks and marks them as cancelled.
drain_cancellables(&self)234     fn drain_cancellables(&self) {
235         let mut state = self.inner.state.lock();
236         // Iterate a few times to try cancelling all the tasks.
237         for _ in 0..Self::RETRY_COUNT {
238             // Nothing left to do.
239             if state.cancellables.is_empty() {
240                 return;
241             }
242 
243             // We only cancel the task and do not remove it from the cancellables. It is runner's
244             // job to remove from state.cancellables.
245             for cancel in state.cancellables.values() {
246                 cancel();
247             }
248             // Hold the state lock in a block before sleeping so that woken up threads can get to
249             // hold the lock.
250             // Wait for a while so that the threads get a chance complete task in flight.
251             let (state1, _cv_timeout) = self
252                 .inner
253                 .cancellables_cv
254                 .wait_timeout(state, Self::SLEEP_DURATION);
255             state = state1;
256         }
257     }
258 
259     /// Marks all the queued and in-flight tasks as cancelled. Any tasks queued after `disarm`ing
260     /// will be cancelled.
261     /// Does not wait for all the tasks to get cancelled.
disarm(&self)262     pub fn disarm(&self) {
263         {
264             let mut state = self.inner.state.lock();
265 
266             if state.wind_down >= WindDownStates::Disarmed {
267                 return;
268             }
269 
270             // At this point any new incoming request will be cancelled when run.
271             state.wind_down = WindDownStates::Disarmed;
272         }
273         self.drain_cancellables();
274     }
275 
276     /// Shut down the `CancellableBlockingPool`.
277     ///
278     /// This will block until all work that has been started by the worker threads is finished. Any
279     /// work that was added to the `CancellableBlockingPool` but not yet picked up by a worker
280     /// thread will not complete and `await`ing on the `Task` for that work will panic.
shutdown(&self) -> Result<(), Error>281     pub fn shutdown(&self) -> Result<(), Error> {
282         self.shutdown_with_timeout(DEFAULT_SHUTDOWN_TIMEOUT)
283     }
284 
shutdown_with_timeout(&self, timeout: Duration) -> Result<(), Error>285     fn shutdown_with_timeout(&self, timeout: Duration) -> Result<(), Error> {
286         self.disarm();
287         {
288             let mut state = self.inner.state.lock();
289             if state.wind_down == WindDownStates::ShuttingDown {
290                 return Err(Error::ShutdownInProgress);
291             }
292             if state.wind_down == WindDownStates::ShutDown {
293                 return Err(Error::AlreadyShutdown);
294             }
295             state.wind_down = WindDownStates::ShuttingDown;
296         }
297 
298         let res = self
299             .inner
300             .blocking_pool
301             .shutdown(/* deadline: */ Some(Instant::now() + timeout));
302 
303         self.inner.state.lock().wind_down = WindDownStates::ShutDown;
304         match res {
305             Ok(_) => Ok(()),
306             Err(_) => Err(Error::Timedout),
307         }
308     }
309 }
310 
311 impl Default for CancellableBlockingPool {
default() -> CancellableBlockingPool312     fn default() -> CancellableBlockingPool {
313         CancellableBlockingPool::new(256, Duration::from_secs(10))
314     }
315 }
316 
317 impl Drop for CancellableBlockingPool {
drop(&mut self)318     fn drop(&mut self) {
319         if let Err(e) = self.shutdown() {
320             base::error!("CancellableBlockingPool::shutdown failed: {}", e);
321         }
322     }
323 }
324 
325 /// Spawn a task to run in the `CancellableBlockingPool` static executor.
326 ///
327 /// `cancel` in-flight operation. cancel is called on operation during `disarm` or during
328 /// `shutdown`.  Cancel may be called multiple times if running task doesn't get cancelled on first
329 /// attempt.
330 ///
331 /// Callers may `await` the returned `Task` to be notified when the work is completed.
332 ///
333 /// See also: `spawn`.
unblock<F, R, G>(f: F, cancel: G) -> impl Future<Output = R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static, G: Fn() -> R + Send + 'static,334 pub fn unblock<F, R, G>(f: F, cancel: G) -> impl Future<Output = R>
335 where
336     F: FnOnce() -> R + Send + 'static,
337     R: Send + 'static,
338     G: Fn() -> R + Send + 'static,
339 {
340     EXECUTOR.spawn(f, cancel)
341 }
342 
343 /// Marks all the queued and in-flight tasks as cancelled. Any tasks queued after `disarm`ing
344 /// will be cancelled.
345 /// Doesn't not wait for all the tasks to get cancelled.
unblock_disarm()346 pub fn unblock_disarm() {
347     EXECUTOR.disarm()
348 }
349 
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;

    use futures::executor::block_on;
    use sync::Condvar;
    use sync::Mutex;

    use crate::blocking::Error;
    use crate::CancellableBlockingPool;

    /// A task that is still queued when the pool is disarmed must resolve to its cancel value.
    #[test]
    fn disarm_with_pending_work() {
        // A single worker guarantees the second task stays queued behind the first one.
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let release = Arc::new(Mutex::new(false));
        let release_cv = Arc::new(Condvar::new());
        let worker_started = Arc::new(Barrier::new(2));

        // Occupy the only worker with a task that blocks until `release` is set.
        let _blocking_task = {
            let release = Arc::clone(&release);
            let release_cv = Arc::clone(&release_cv);
            let worker_started = Arc::clone(&worker_started);
            pool.spawn(
                move || {
                    worker_started.wait();
                    let mut released = release.lock();
                    while !*released {
                        released = release_cv.wait(released);
                    }
                },
                || {},
            )
        };

        // Rendezvous with the worker so we know the blocking task is running.
        worker_started.wait();

        // Queued behind the blocker; it gets cancelled before a worker can run it.
        let unfinished = pool.spawn(|| 5, || 0);

        // Disarming should cancel the queued task.
        pool.disarm();

        // Let the blocker finish so that the worker can pick up the cancelled task.
        *release.lock() = true;
        release_cv.notify_all();

        // The value produced by `cancel` is expected, not the one produced by the task body.
        assert_eq!(block_on(unfinished), 0);

        // Nothing is left in the pool, so shutting down does not block.
        pool.shutdown().unwrap();
    }

    /// Shutdown must report a timeout when a task never finishes.
    #[test]
    fn shutdown_with_blocked_work_should_timeout() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let started = Arc::new((Mutex::new(false), Condvar::new()));
        let started_in_task = Arc::clone(&started);
        let _blocking_task = pool.spawn(
            move || {
                let (flag, cv) = &*started_in_task;
                *flag.lock() = true;
                cv.notify_one();
                thread::sleep(Duration::from_secs(10000));
            },
            || {},
        );

        // Block until the task has signalled that it is running.
        let (flag, cv) = &*started;
        let mut is_started = flag.lock();
        while !*is_started {
            is_started = cv.wait(is_started);
        }

        // The shutdown waits for the whole timeout, so keep the timeout short.
        let outcome = pool.shutdown_with_timeout(Duration::from_millis(1));
        assert_eq!(outcome, Err(Error::Timedout));
    }

    /// A second shutdown after a finished one must be rejected.
    #[test]
    fn multiple_shutdown_returns_error() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));
        let _first_attempt = pool.shutdown();
        let second_attempt = pool.shutdown();
        assert_eq!(second_attempt, Err(Error::AlreadyShutdown));
    }

    /// A shutdown racing with another, still running, shutdown must be rejected.
    #[test]
    fn shutdown_in_progress() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let started = Arc::new((Mutex::new(false), Condvar::new()));
        let started_in_task = Arc::clone(&started);
        let _blocking_task = pool.spawn(
            move || {
                let (flag, cv) = &*started_in_task;
                *flag.lock() = true;
                cv.notify_one();
                thread::sleep(Duration::from_secs(10000));
            },
            || {},
        );

        // Block until the task has signalled that it is running.
        let (flag, cv) = &*started;
        let mut is_started = flag.lock();
        while !*is_started {
            is_started = cv.wait(is_started);
        }

        let racing_pool = pool.clone();
        thread::spawn(move || {
            // Spin until the shutdown started by the main thread becomes visible.
            while !racing_pool.inner.blocking_pool.shutting_down() {}
            assert_eq!(racing_pool.shutdown(), Err(Error::ShutdownInProgress));
        });

        // The shutdown waits for the whole timeout, so the timeout should be short. It still
        // has to be long enough for the thread spawned above to observe the shutting_down
        // state, so it must not be too short either.
        let outcome = pool.shutdown_with_timeout(Duration::from_millis(200));
        assert_eq!(outcome, Err(Error::Timedout));
    }
}
478