1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 //! Provides an async blocking pool whose tasks can be cancelled.
6
7 use std::collections::HashMap;
8 use std::future::Future;
9 use std::sync::Arc;
10 use std::time::Duration;
11 use std::time::Instant;
12
13 use once_cell::sync::Lazy;
14 use sync::Condvar;
15 use sync::Mutex;
16 use thiserror::Error as ThisError;
17
18 use crate::BlockingPool;
19
/// Global executor backing [`unblock`] and [`unblock_disarm`].
///
/// This is convenient, though not preferred. Pros/cons:
/// + It avoids passing an executor all the way down to each call site.
/// + The call site can assume that the executor will never shut down.
/// + Provides similar functionality as async_task with a few improvements around ability to cancel.
/// - Globals are harder to reason about.
///
/// The pool is created lazily on first use with `max_threads = 256` and a 10 second keepalive.
static EXECUTOR: Lazy<CancellableBlockingPool> =
    Lazy::new(|| CancellableBlockingPool::new(256, Duration::from_secs(10)));

/// Timeout that `CancellableBlockingPool::shutdown` passes to `shutdown_with_timeout`.
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
31
/// Stages a pool goes through while it is being wound down.
///
/// The variants are declared in the order the pool moves through them. The derived `PartialOrd`
/// depends on that declaration order: `disarm` uses `>= WindDownStates::Disarmed` to detect that
/// winding down has already started, so reordering the variants changes behaviour.
#[derive(PartialEq, Eq, PartialOrd, Default)]
enum WindDownStates {
    /// Initial state; nothing has asked the pool to wind down yet.
    #[default]
    Armed,
    /// Set by `disarm` right before it starts cancelling the registered tasks.
    Disarmed,
    /// Set by `shutdown_with_timeout` while the underlying `BlockingPool` is being shut down.
    ShuttingDown,
    /// Set once the underlying pool's shutdown call has returned, whether or not it timed out.
    ShutDown,
}
40
/// Mutable bookkeeping of a pool. It lives behind the `Inner::state` mutex.
#[derive(Default)]
struct State {
    /// How far the pool has progressed towards being shut down.
    wind_down: WindDownStates,

    /// Helps to generate unique id to associate `cancel` with task. This is the id that the next
    /// call to `Inner::add_cancellable` hands out; it is incremented on every registration.
    current_cancellable_id: u64,

    /// A map of all the `cancel` routines of queued/in-flight tasks, keyed by the id returned
    /// from `Inner::add_cancellable`.
    cancellables: HashMap<u64, Box<dyn Fn() + Send + 'static>>,
}
51
/// Action to take when a timeout expires.
///
/// NOTE(review): nothing in this file reads this type; presumably it is consumed by other
/// modules of the crate - confirm before changing or removing it.
#[derive(Debug, Clone, Copy)]
pub enum TimeoutAction {
    /// Do nothing on timeout.
    None,
    /// Panic the thread on timeout.
    Panic,
}
59
/// Errors reported when shutting down a `CancellableBlockingPool`.
#[derive(ThisError, Debug, PartialEq, Eq)]
pub enum Error {
    /// The underlying `BlockingPool` reported an error from its shutdown call, i.e. it did not
    /// finish shutting down before the deadline.
    #[error("Timeout occurred while trying to join threads")]
    Timedout,
    /// Another shutdown of the same pool has started and has not returned yet.
    #[error("Shutdown is in progress")]
    ShutdownInProgress,
    /// An earlier shutdown of the same pool has already returned.
    #[error("Already shut down")]
    AlreadyShutdown,
}
69
/// Shared core of a `CancellableBlockingPool`; every clone of the pool points at the same
/// `Inner` through an `Arc`.
struct Inner {
    /// The pool that actually runs the submitted closures.
    blocking_pool: BlockingPool,
    /// Wind-down stage and the registry of `cancel` routines.
    state: Mutex<State>,

    /// This condvar gets notified when `cancellables` is empty after removing an
    /// entry. `drain_cancellables` waits on it (with a timeout) between rounds of `cancel` calls.
    cancellables_cv: Condvar,
}
78
79 impl Inner {
spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,80 pub fn spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R>
81 where
82 F: FnOnce() -> R + Send + 'static,
83 R: Send + 'static,
84 {
85 self.blocking_pool.spawn(f)
86 }
87
88 /// Adds cancel to a cancellables and returns an `id` with which `cancel` can be
89 /// accessed/removed.
add_cancellable(&self, cancel: Box<dyn Fn() + Send + 'static>) -> u6490 fn add_cancellable(&self, cancel: Box<dyn Fn() + Send + 'static>) -> u64 {
91 let mut state = self.state.lock();
92 let id = state.current_cancellable_id;
93 state.current_cancellable_id += 1;
94 state.cancellables.insert(id, cancel);
95 id
96 }
97 }
98
/// A thread pool for running work that may block.
///
/// This is a wrapper around `BlockingPool` with an ability to cancel queued tasks.
/// See [BlockingPool] for more info.
///
/// Cloning the pool yields another handle to the *same* pool: all clones share one inner state
/// through an `Arc`. Note that the `Drop` implementation calls `shutdown`, so dropping any one
/// handle shuts the shared pool down for every other handle as well.
///
/// # Examples
///
/// Spawn a task to run in the `CancellableBlockingPool` and await on its result.
///
/// ```edition2018
/// use cros_async::CancellableBlockingPool;
///
/// # async fn do_it() {
///     let pool = CancellableBlockingPool::default();
///     let CANCELLED = 0;
///
///     let res = pool.spawn(move || {
///         // Do some CPU-intensive or blocking work here.
///
///         42
///     }, move || CANCELLED).await;
///
///     assert_eq!(res, 42);
/// # }
/// # futures::executor::block_on(do_it());
/// ```
#[derive(Clone)]
pub struct CancellableBlockingPool {
    /// State shared between all clones of this pool and the tasks spawned on it.
    inner: Arc<Inner>,
}
129
130 impl CancellableBlockingPool {
131 const RETRY_COUNT: usize = 10;
132 const SLEEP_DURATION: Duration = Duration::from_millis(100);
133
134 /// Create a new `CancellableBlockingPool`.
135 ///
136 /// When we try to shutdown or drop `CancellableBlockingPool`, it may happen that a hung thread
137 /// might prevent `CancellableBlockingPool` pool from getting dropped. On failure to shutdown in
138 /// `watchdog_opts.timeout` duration, `CancellableBlockingPool` can take an action specified by
139 /// `watchdog_opts.action`.
140 ///
141 /// See also: [BlockingPool::new()](BlockingPool::new)
new(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool142 pub fn new(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool {
143 CancellableBlockingPool {
144 inner: Arc::new(Inner {
145 blocking_pool: BlockingPool::new(max_threads, keepalive),
146 state: Default::default(),
147 cancellables_cv: Condvar::new(),
148 }),
149 }
150 }
151
152 /// Like [Self::new] but with pre-allocating capacity for up to `max_threads`.
with_capacity(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool153 pub fn with_capacity(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool {
154 CancellableBlockingPool {
155 inner: Arc::new(Inner {
156 blocking_pool: BlockingPool::with_capacity(max_threads, keepalive),
157 state: Mutex::new(State::default()),
158 cancellables_cv: Condvar::new(),
159 }),
160 }
161 }
162
163 /// Spawn a task to run in the `CancellableBlockingPool`.
164 ///
165 /// Callers may `await` the returned `Task` to be notified when the work is completed.
166 /// Dropping the future will not cancel the task.
167 ///
168 /// `cancel` helps to cancel a queued or in-flight operation `f`.
169 /// `cancel` may be called more than once if `f` doesn't respond to `cancel`.
170 /// `cancel` is not called if `f` completes successfully. For example,
171 /// # Examples
172 ///
173 /// ```edition2018
174 /// use {cros_async::CancellableBlockingPool, std::sync::{Arc, Mutex, Condvar}};
175 ///
176 /// # async fn cancel_it() {
177 /// let pool = CancellableBlockingPool::default();
178 /// let cancelled: i32 = 1;
179 /// let success: i32 = 2;
180 ///
181 /// let shared = Arc::new((Mutex::new(0), Condvar::new()));
182 /// let shared2 = shared.clone();
183 /// let shared3 = shared.clone();
184 ///
185 /// let res = pool
186 /// .spawn(
187 /// move || {
188 /// let guard = shared.0.lock().unwrap();
189 /// let mut guard = shared.1.wait_while(guard, |state| *state == 0).unwrap();
190 /// if *guard != cancelled {
191 /// *guard = success;
192 /// }
193 /// },
194 /// move || {
195 /// *shared2.0.lock().unwrap() = cancelled;
196 /// shared2.1.notify_all();
197 /// },
198 /// )
199 /// .await;
200 /// pool.shutdown();
201 ///
202 /// assert_eq!(*shared3.0.lock().unwrap(), cancelled);
203 /// # }
204 /// ```
spawn<F, R, G>(&self, f: F, cancel: G) -> impl Future<Output = R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static, G: Fn() -> R + Send + 'static,205 pub fn spawn<F, R, G>(&self, f: F, cancel: G) -> impl Future<Output = R>
206 where
207 F: FnOnce() -> R + Send + 'static,
208 R: Send + 'static,
209 G: Fn() -> R + Send + 'static,
210 {
211 let inner = self.inner.clone();
212 let cancelled = Arc::new(Mutex::new(None));
213 let cancelled_spawn = cancelled.clone();
214 let id = inner.add_cancellable(Box::new(move || {
215 let mut c = cancelled.lock();
216 *c = Some(cancel());
217 }));
218
219 self.inner.spawn(move || {
220 if let Some(res) = cancelled_spawn.lock().take() {
221 return res;
222 }
223 let ret = f();
224 let mut state = inner.state.lock();
225 state.cancellables.remove(&id);
226 if state.cancellables.is_empty() {
227 inner.cancellables_cv.notify_one();
228 }
229 ret
230 })
231 }
232
233 /// Iterates over all the queued tasks and marks them as cancelled.
drain_cancellables(&self)234 fn drain_cancellables(&self) {
235 let mut state = self.inner.state.lock();
236 // Iterate a few times to try cancelling all the tasks.
237 for _ in 0..Self::RETRY_COUNT {
238 // Nothing left to do.
239 if state.cancellables.is_empty() {
240 return;
241 }
242
243 // We only cancel the task and do not remove it from the cancellables. It is runner's
244 // job to remove from state.cancellables.
245 for cancel in state.cancellables.values() {
246 cancel();
247 }
248 // Hold the state lock in a block before sleeping so that woken up threads can get to
249 // hold the lock.
250 // Wait for a while so that the threads get a chance complete task in flight.
251 let (state1, _cv_timeout) = self
252 .inner
253 .cancellables_cv
254 .wait_timeout(state, Self::SLEEP_DURATION);
255 state = state1;
256 }
257 }
258
259 /// Marks all the queued and in-flight tasks as cancelled. Any tasks queued after `disarm`ing
260 /// will be cancelled.
261 /// Does not wait for all the tasks to get cancelled.
disarm(&self)262 pub fn disarm(&self) {
263 {
264 let mut state = self.inner.state.lock();
265
266 if state.wind_down >= WindDownStates::Disarmed {
267 return;
268 }
269
270 // At this point any new incoming request will be cancelled when run.
271 state.wind_down = WindDownStates::Disarmed;
272 }
273 self.drain_cancellables();
274 }
275
276 /// Shut down the `CancellableBlockingPool`.
277 ///
278 /// This will block until all work that has been started by the worker threads is finished. Any
279 /// work that was added to the `CancellableBlockingPool` but not yet picked up by a worker
280 /// thread will not complete and `await`ing on the `Task` for that work will panic.
shutdown(&self) -> Result<(), Error>281 pub fn shutdown(&self) -> Result<(), Error> {
282 self.shutdown_with_timeout(DEFAULT_SHUTDOWN_TIMEOUT)
283 }
284
shutdown_with_timeout(&self, timeout: Duration) -> Result<(), Error>285 fn shutdown_with_timeout(&self, timeout: Duration) -> Result<(), Error> {
286 self.disarm();
287 {
288 let mut state = self.inner.state.lock();
289 if state.wind_down == WindDownStates::ShuttingDown {
290 return Err(Error::ShutdownInProgress);
291 }
292 if state.wind_down == WindDownStates::ShutDown {
293 return Err(Error::AlreadyShutdown);
294 }
295 state.wind_down = WindDownStates::ShuttingDown;
296 }
297
298 let res = self
299 .inner
300 .blocking_pool
301 .shutdown(/* deadline: */ Some(Instant::now() + timeout));
302
303 self.inner.state.lock().wind_down = WindDownStates::ShutDown;
304 match res {
305 Ok(_) => Ok(()),
306 Err(_) => Err(Error::Timedout),
307 }
308 }
309 }
310
311 impl Default for CancellableBlockingPool {
default() -> CancellableBlockingPool312 fn default() -> CancellableBlockingPool {
313 CancellableBlockingPool::new(256, Duration::from_secs(10))
314 }
315 }
316
317 impl Drop for CancellableBlockingPool {
drop(&mut self)318 fn drop(&mut self) {
319 if let Err(e) = self.shutdown() {
320 base::error!("CancellableBlockingPool::shutdown failed: {}", e);
321 }
322 }
323 }
324
/// Spawn a task to run in the `CancellableBlockingPool` static executor.
///
/// `cancel` is used to cancel the queued or in-flight operation `f`. It is called when the
/// executor is disarmed or shut down, and it may be called multiple times if the running task
/// doesn't get cancelled on the first attempt.
///
/// Callers may `await` the returned future to be notified when the work is completed.
///
/// See also: [CancellableBlockingPool::spawn].
pub fn unblock<F, R, G>(f: F, cancel: G) -> impl Future<Output = R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
    G: Fn() -> R + Send + 'static,
{
    EXECUTOR.spawn(f, cancel)
}
342
/// Marks all the queued and in-flight tasks of the static executor as cancelled.
///
/// Does not wait for all the tasks to get cancelled.
///
/// NOTE(review): the original documentation states that any task queued after disarming will be
/// cancelled, but `CancellableBlockingPool::spawn` does not look at the disarmed state - confirm
/// the intended behaviour.
///
/// See also: [CancellableBlockingPool::disarm].
pub fn unblock_disarm() {
    EXECUTOR.disarm()
}
349
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;

    use futures::executor::block_on;
    use sync::Condvar;
    use sync::Mutex;

    use crate::blocking::Error;
    use crate::CancellableBlockingPool;

    /// A task that is still queued when the pool is disarmed resolves to its `cancel` value.
    #[test]
    fn disarm_with_pending_work() {
        // Create a pool with only one thread.
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let mu = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());
        let blocker_is_running = Arc::new(Barrier::new(2));

        // First spawn a thread that blocks the pool.
        let task_mu = mu.clone();
        let task_cv = cv.clone();
        let task_blocker_is_running = blocker_is_running.clone();
        let _blocking_task = pool.spawn(
            move || {
                task_blocker_is_running.wait();
                let mut ready = task_mu.lock();
                while !*ready {
                    ready = task_cv.wait(ready);
                }
            },
            move || {},
        );

        // Wait for the worker to start running the blocking thread.
        blocker_is_running.wait();

        // This task will never finish because we will disarm the pool first.
        let unfinished = pool.spawn(|| 5, || 0);

        // Disarming should cancel the task.
        pool.disarm();

        // Shutdown the blocking thread. This will allow a worker to pick up the task that has
        // to be cancelled.
        *mu.lock() = true;
        cv.notify_all();

        // We expect the cancelled value to be returned.
        assert_eq!(block_on(unfinished), 0);

        // Now the pool is empty and can be shutdown without blocking.
        pool.shutdown().unwrap();
    }

    /// Shutting down while a task ignores cancellation reports a timeout.
    #[test]
    fn shutdown_with_blocked_work_should_timeout() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let running = Arc::new((Mutex::new(false), Condvar::new()));
        let running1 = running.clone();
        let _blocking_task = pool.spawn(
            move || {
                *running1.0.lock() = true;
                running1.1.notify_one();
                thread::sleep(Duration::from_secs(10000));
            },
            move || {},
        );

        // Wait until the worker has actually started the blocking task.
        let mut is_running = running.0.lock();
        while !*is_running {
            is_running = running.1.wait(is_running);
        }

        // This shutdown will wait for the full timeout period, so use a short timeout.
        assert_eq!(
            pool.shutdown_with_timeout(Duration::from_millis(1)),
            Err(Error::Timedout)
        );
    }

    /// A second shutdown after a completed one is rejected.
    #[test]
    fn multiple_shutdown_returns_error() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));
        let _ = pool.shutdown();
        assert_eq!(pool.shutdown(), Err(Error::AlreadyShutdown));
    }

    /// A shutdown issued while another one is still waiting is rejected.
    #[test]
    fn shutdown_in_progress() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let running = Arc::new((Mutex::new(false), Condvar::new()));
        let running1 = running.clone();
        let _blocking_task = pool.spawn(
            move || {
                *running1.0.lock() = true;
                running1.1.notify_one();
                thread::sleep(Duration::from_secs(10000));
            },
            move || {},
        );

        // Wait until the worker has actually started the blocking task.
        let mut is_running = running.0.lock();
        while !*is_running {
            is_running = running.1.wait(is_running);
        }

        let pool_clone = pool.clone();
        let observer = thread::spawn(move || {
            // Wait until the main thread has started shutting down the underlying pool. Yield
            // while polling so that this thread does not monopolise a core.
            while !pool_clone.inner.blocking_pool.shutting_down() {
                thread::yield_now();
            }
            assert_eq!(pool_clone.shutdown(), Err(Error::ShutdownInProgress));
        });

        // This shutdown will wait for the full timeout period, so use a short timeout.
        // However, it also needs to wait long enough for the thread spawned above to observe the
        // shutting_down state, so don't make it too short.
        assert_eq!(
            pool.shutdown_with_timeout(Duration::from_millis(200)),
            Err(Error::Timedout)
        );

        // A panic in a spawned thread does not fail the test by itself. Join the observer so
        // that a failed assertion inside it is propagated.
        observer.join().expect("observer thread panicked");
    }
}
478