//! A `Barrier` that additionally provides `wait_timeout`.
//!
//! This implementation mirrors `std::sync::Barrier` from the Rust standard
//! library, extended with a timed variant of `wait`.
4 
5 use crate::loom::sync::{Condvar, Mutex};
6 use std::fmt;
7 use std::time::{Duration, Instant};
8 
9 /// A barrier enables multiple threads to synchronize the beginning
10 /// of some computation.
11 ///
12 /// # Examples
13 ///
14 /// ```
15 /// use std::sync::{Arc, Barrier};
16 /// use std::thread;
17 ///
18 /// let mut handles = Vec::with_capacity(10);
19 /// let barrier = Arc::new(Barrier::new(10));
20 /// for _ in 0..10 {
21 ///     let c = Arc::clone(&barrier);
22 ///     // The same messages will be printed together.
23 ///     // You will NOT see any interleaving.
24 ///     handles.push(thread::spawn(move|| {
25 ///         println!("before wait");
26 ///         c.wait();
27 ///         println!("after wait");
28 ///     }));
29 /// }
30 /// // Wait for other threads to finish.
31 /// for handle in handles {
32 ///     handle.join().unwrap();
33 /// }
34 /// ```
35 pub(crate) struct Barrier {
36     lock: Mutex<BarrierState>,
37     cvar: Condvar,
38     num_threads: usize,
39 }
40 
// The mutable state of a `Barrier`, kept behind the barrier's mutex.
//
// Waiters remember the `generation_id` they arrived in. The releasing
// (leader) thread resets `count` and bumps `generation_id`. A waiter that
// wakes and still sees its own generation knows the wakeup was spurious.
struct BarrierState {
    // Threads that have arrived in the current generation.
    count: usize,
    // Identifier of the current generation; advanced (wrapping) on every release.
    generation_id: usize,
}
46 
/// The outcome handed to every thread released by [`Barrier::wait()`] once
/// the [`Barrier`] has gathered all of its threads.
///
/// Exactly one thread per release is the leader; use
/// [`BarrierWaitResult::is_leader()`] to find out which.
///
/// # Examples
///
/// ```
/// use std::sync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait();
/// ```
pub(crate) struct BarrierWaitResult(
    // `true` only for the last thread to arrive, i.e. the one that released the others.
    bool,
);
59 
60 impl fmt::Debug for Barrier {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result61     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62         f.debug_struct("Barrier").finish_non_exhaustive()
63     }
64 }
65 
66 impl Barrier {
67     /// Creates a new barrier that can block a given number of threads.
68     ///
69     /// A barrier will block `n`-1 threads which call [`wait()`] and then wake
70     /// up all threads at once when the `n`th thread calls [`wait()`].
71     ///
72     /// [`wait()`]: Barrier::wait
73     ///
74     /// # Examples
75     ///
76     /// ```
77     /// use std::sync::Barrier;
78     ///
79     /// let barrier = Barrier::new(10);
80     /// ```
81     #[must_use]
new(n: usize) -> Barrier82     pub(crate) fn new(n: usize) -> Barrier {
83         Barrier {
84             lock: Mutex::new(BarrierState {
85                 count: 0,
86                 generation_id: 0,
87             }),
88             cvar: Condvar::new(),
89             num_threads: n,
90         }
91     }
92 
93     /// Blocks the current thread until all threads have rendezvoused here.
94     ///
95     /// Barriers are re-usable after all threads have rendezvoused once, and can
96     /// be used continuously.
97     ///
98     /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
99     /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
100     /// from this function, and all other threads will receive a result that
101     /// will return `false` from [`BarrierWaitResult::is_leader()`].
102     ///
103     /// # Examples
104     ///
105     /// ```
106     /// use std::sync::{Arc, Barrier};
107     /// use std::thread;
108     ///
109     /// let mut handles = Vec::with_capacity(10);
110     /// let barrier = Arc::new(Barrier::new(10));
111     /// for _ in 0..10 {
112     ///     let c = Arc::clone(&barrier);
113     ///     // The same messages will be printed together.
114     ///     // You will NOT see any interleaving.
115     ///     handles.push(thread::spawn(move|| {
116     ///         println!("before wait");
117     ///         c.wait();
118     ///         println!("after wait");
119     ///     }));
120     /// }
121     /// // Wait for other threads to finish.
122     /// for handle in handles {
123     ///     handle.join().unwrap();
124     /// }
125     /// ```
wait(&self) -> BarrierWaitResult126     pub(crate) fn wait(&self) -> BarrierWaitResult {
127         let mut lock = self.lock.lock();
128         let local_gen = lock.generation_id;
129         lock.count += 1;
130         if lock.count < self.num_threads {
131             // We need a while loop to guard against spurious wakeups.
132             // https://en.wikipedia.org/wiki/Spurious_wakeup
133             while local_gen == lock.generation_id {
134                 lock = self.cvar.wait(lock).unwrap();
135             }
136             BarrierWaitResult(false)
137         } else {
138             lock.count = 0;
139             lock.generation_id = lock.generation_id.wrapping_add(1);
140             self.cvar.notify_all();
141             BarrierWaitResult(true)
142         }
143     }
144 
145     /// Blocks the current thread until all threads have rendezvoused here for
146     /// at most `timeout` duration.
wait_timeout(&self, timeout: Duration) -> Option<BarrierWaitResult>147     pub(crate) fn wait_timeout(&self, timeout: Duration) -> Option<BarrierWaitResult> {
148         // This implementation mirrors `wait`, but with each blocking operation
149         // replaced by a timeout-amenable alternative.
150 
151         let deadline = Instant::now() + timeout;
152 
153         // Acquire `self.lock` with at most `timeout` duration.
154         let mut lock = loop {
155             if let Some(guard) = self.lock.try_lock() {
156                 break guard;
157             } else if Instant::now() > deadline {
158                 return None;
159             } else {
160                 std::thread::yield_now();
161             }
162         };
163 
164         // Shrink the `timeout` to account for the time taken to acquire `lock`.
165         let timeout = deadline.saturating_duration_since(Instant::now());
166 
167         let local_gen = lock.generation_id;
168         lock.count += 1;
169         if lock.count < self.num_threads {
170             // We need a while loop to guard against spurious wakeups.
171             // https://en.wikipedia.org/wiki/Spurious_wakeup
172             while local_gen == lock.generation_id {
173                 let (guard, timeout_result) = self.cvar.wait_timeout(lock, timeout).unwrap();
174                 lock = guard;
175                 if timeout_result.timed_out() {
176                     return None;
177                 }
178             }
179             Some(BarrierWaitResult(false))
180         } else {
181             lock.count = 0;
182             lock.generation_id = lock.generation_id.wrapping_add(1);
183             self.cvar.notify_all();
184             Some(BarrierWaitResult(true))
185         }
186     }
187 }
188 
189 impl fmt::Debug for BarrierWaitResult {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result190     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191         f.debug_struct("BarrierWaitResult")
192             .field("is_leader", &self.is_leader())
193             .finish()
194     }
195 }
196 
197 impl BarrierWaitResult {
198     /// Returns `true` if this thread is the "leader thread" for the call to
199     /// [`Barrier::wait()`].
200     ///
201     /// Only one thread will have `true` returned from their result, all other
202     /// threads will have `false` returned.
203     ///
204     /// # Examples
205     ///
206     /// ```
207     /// use std::sync::Barrier;
208     ///
209     /// let barrier = Barrier::new(1);
210     /// let barrier_wait_result = barrier.wait();
211     /// println!("{:?}", barrier_wait_result.is_leader());
212     /// ```
213     #[must_use]
is_leader(&self) -> bool214     pub(crate) fn is_leader(&self) -> bool {
215         self.0
216     }
217 }
218