1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::any::Any;
6 use std::panic;
7 use std::sync::mpsc::channel;
8 use std::sync::mpsc::Receiver;
9 use std::thread;
10 use std::thread::JoinHandle;
11 use std::time::Duration;
12
13 /// Spawns a thread that can be joined with a timeout.
spawn_with_timeout<F, T>(f: F) -> JoinHandleWithTimeout<T> where F: FnOnce() -> T, F: Send + 'static, T: Send + 'static,14 pub fn spawn_with_timeout<F, T>(f: F) -> JoinHandleWithTimeout<T>
15 where
16 F: FnOnce() -> T,
17 F: Send + 'static,
18 T: Send + 'static,
19 {
20 // Use a channel to signal completion to the join handle
21 let (tx, rx) = channel();
22 let handle = thread::spawn(move || {
23 let val = panic::catch_unwind(panic::AssertUnwindSafe(f));
24 tx.send(()).unwrap();
25 val
26 });
27 JoinHandleWithTimeout { handle, rx }
28 }
29
/// Handle to a thread started by [`spawn_with_timeout`].
///
/// Unlike a plain [`JoinHandle`], it can be joined with an upper bound on how long the caller
/// waits, via [`JoinHandleWithTimeout::try_join`].
pub struct JoinHandleWithTimeout<T> {
    /// The worker thread. Its result is a `std::thread::Result` because the user closure is run
    /// under `catch_unwind` inside the thread.
    handle: JoinHandle<std::thread::Result<T>>,
    /// Receives a single `()` from the worker once the user closure has finished.
    rx: Receiver<()>,
}
34
/// Error returned by [`JoinHandleWithTimeout::try_join`].
#[derive(Debug)]
pub enum JoinError {
    /// The spawned closure panicked; holds the panic payload.
    Panic(Box<dyn Any>),
    /// The thread did not finish within the requested timeout.
    Timeout,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JoinError::Panic(payload) => {
                // `panic!` payloads are typically `&'static str` (literal message) or `String`
                // (formatted message); anything else is reported without a message.
                if let Some(msg) = payload.downcast_ref::<&str>() {
                    write!(f, "thread panicked: {}", msg)
                } else if let Some(msg) = payload.downcast_ref::<String>() {
                    write!(f, "thread panicked: {}", msg)
                } else {
                    f.write_str("thread panicked")
                }
            }
            JoinError::Timeout => f.write_str("timed out waiting for thread to finish"),
        }
    }
}

impl std::error::Error for JoinError {}
40
41 impl<T> JoinHandleWithTimeout<T> {
42 /// Tries to join the thread. Returns an error if the join takes more than `timeout_ms`.
try_join(self, timeout: Duration) -> Result<T, JoinError>43 pub fn try_join(self, timeout: Duration) -> Result<T, JoinError> {
44 if self.rx.recv_timeout(timeout).is_ok() {
45 self.handle.join().unwrap().map_err(|e| JoinError::Panic(e))
46 } else {
47 Err(JoinError::Timeout)
48 }
49 }
50 }
51