xref: /aosp_15_r20/external/crosvm/base/src/sys/windows/thread.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::any::Any;
6 use std::panic;
7 use std::sync::mpsc::channel;
8 use std::sync::mpsc::Receiver;
9 use std::thread;
10 use std::thread::JoinHandle;
11 use std::time::Duration;
12 
13 /// Spawns a thread that can be joined with a timeout.
spawn_with_timeout<F, T>(f: F) -> JoinHandleWithTimeout<T> where F: FnOnce() -> T, F: Send + 'static, T: Send + 'static,14 pub fn spawn_with_timeout<F, T>(f: F) -> JoinHandleWithTimeout<T>
15 where
16     F: FnOnce() -> T,
17     F: Send + 'static,
18     T: Send + 'static,
19 {
20     // Use a channel to signal completion to the join handle
21     let (tx, rx) = channel();
22     let handle = thread::spawn(move || {
23         let val = panic::catch_unwind(panic::AssertUnwindSafe(f));
24         tx.send(()).unwrap();
25         val
26     });
27     JoinHandleWithTimeout { handle, rx }
28 }
29 
/// Handle to a thread spawned by `spawn_with_timeout` that can be joined with a deadline.
pub struct JoinHandleWithTimeout<T> {
    /// Join handle of the worker thread. The thread's return value is the `catch_unwind` outcome
    /// of the user closure, so a caught panic shows up as the inner `Err(payload)`.
    handle: JoinHandle<thread::Result<T>>,
    /// Receives a unit message when the worker signals that the user closure has finished.
    rx: Receiver<()>,
}

/// Error returned by [`JoinHandleWithTimeout::try_join`].
#[derive(Debug)]
pub enum JoinError {
    /// The thread panicked; carries the panic payload.
    Panic(Box<dyn Any>),
    /// The thread did not signal completion within the requested timeout.
    Timeout,
}

impl<T> JoinHandleWithTimeout<T> {
    /// Tries to join the thread, waiting at most `timeout` for it to finish.
    ///
    /// Returns the thread's result on success.
    ///
    /// # Errors
    ///
    /// * [`JoinError::Timeout`] if the thread has not signalled completion within `timeout`.
    ///   Because `self` is consumed, the thread keeps running detached and its result is
    ///   discarded.
    /// * [`JoinError::Panic`] carrying the panic payload if the thread panicked.
    pub fn try_join(self, timeout: Duration) -> Result<T, JoinError> {
        use std::sync::mpsc::RecvTimeoutError;

        match self.rx.recv_timeout(timeout) {
            // `Disconnected` means the sender was dropped without signalling, i.e. the thread
            // exited without reaching the send (e.g. by unwinding past it). The thread is
            // finished or finishing, so joining should return promptly and yields the real
            // outcome instead of a misleading `Timeout`.
            Ok(()) | Err(RecvTimeoutError::Disconnected) => self
                .handle
                .join()
                // Flatten the outer join result (a panic that escaped the thread body) with the
                // inner `catch_unwind` result; both carry a panic payload.
                .and_then(|result| result)
                // The closure is needed: it coerces `Box<dyn Any + Send>` into `Box<dyn Any>`.
                .map_err(|payload| JoinError::Panic(payload)),
            Err(RecvTimeoutError::Timeout) => Err(JoinError::Timeout),
        }
    }
}
51