1 //! Synchronization primitive allowing multiple threads to synchronize the 2 //! beginning of some computation. 3 //! 4 //! Implementation adapted from the 'Barrier' type of the standard library. See: 5 //! <https://doc.rust-lang.org/std/sync/struct.Barrier.html> 6 //! 7 //! Copyright 2014 The Rust Project Developers. See the COPYRIGHT 8 //! file at the top-level directory of this distribution and at 9 //! <http://rust-lang.org/COPYRIGHT>. 10 //! 11 //! Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or 12 //! <http://www.apache.org/licenses/LICENSE-2.0>> or the MIT license 13 //! <LICENSE-MIT or <http://opensource.org/licenses/MIT>>, at your 14 //! option. This file may not be copied, modified, or distributed 15 //! except according to those terms. 16 17 use crate::{mutex::Mutex, RelaxStrategy, Spin}; 18 19 /// A primitive that synchronizes the execution of multiple threads. 20 /// 21 /// # Example 22 /// 23 /// ``` 24 /// use spin; 25 /// use std::sync::Arc; 26 /// use std::thread; 27 /// 28 /// let mut handles = Vec::with_capacity(10); 29 /// let barrier = Arc::new(spin::Barrier::new(10)); 30 /// for _ in 0..10 { 31 /// let c = barrier.clone(); 32 /// // The same messages will be printed together. 33 /// // You will NOT see any interleaving. 34 /// handles.push(thread::spawn(move|| { 35 /// println!("before wait"); 36 /// c.wait(); 37 /// println!("after wait"); 38 /// })); 39 /// } 40 /// // Wait for other threads to finish. 41 /// for handle in handles { 42 /// handle.join().unwrap(); 43 /// } 44 /// ``` 45 pub struct Barrier<R = Spin> { 46 lock: Mutex<BarrierState, R>, 47 num_threads: usize, 48 } 49 50 // The inner state of a double barrier 51 struct BarrierState { 52 count: usize, 53 generation_id: usize, 54 } 55 56 /// A `BarrierWaitResult` is returned by [`wait`] when all threads in the [`Barrier`] 57 /// have rendezvoused. 58 /// 59 /// [`wait`]: struct.Barrier.html#method.wait 60 /// [`Barrier`]: struct.Barrier.html 61 /// 62 /// # Examples 63 /// 64 /// ``` 65 /// use spin; 66 /// 67 /// let barrier = spin::Barrier::new(1); 68 /// let barrier_wait_result = barrier.wait(); 69 /// ``` 70 pub struct BarrierWaitResult(bool); 71 72 impl<R: RelaxStrategy> Barrier<R> { 73 /// Blocks the current thread until all threads have rendezvoused here. 74 /// 75 /// Barriers are re-usable after all threads have rendezvoused once, and can 76 /// be used continuously. 77 /// 78 /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that 79 /// returns `true` from [`is_leader`] when returning from this function, and 80 /// all other threads will receive a result that will return `false` from 81 /// [`is_leader`]. 82 /// 83 /// [`BarrierWaitResult`]: struct.BarrierWaitResult.html 84 /// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader 85 /// 86 /// # Examples 87 /// 88 /// ``` 89 /// use spin; 90 /// use std::sync::Arc; 91 /// use std::thread; 92 /// 93 /// let mut handles = Vec::with_capacity(10); 94 /// let barrier = Arc::new(spin::Barrier::new(10)); 95 /// for _ in 0..10 { 96 /// let c = barrier.clone(); 97 /// // The same messages will be printed together. 98 /// // You will NOT see any interleaving. 99 /// handles.push(thread::spawn(move|| { 100 /// println!("before wait"); 101 /// c.wait(); 102 /// println!("after wait"); 103 /// })); 104 /// } 105 /// // Wait for other threads to finish. 106 /// for handle in handles { 107 /// handle.join().unwrap(); 108 /// } 109 /// ``` wait(&self) -> BarrierWaitResult110 pub fn wait(&self) -> BarrierWaitResult { 111 let mut lock = self.lock.lock(); 112 lock.count += 1; 113 114 if lock.count < self.num_threads { 115 // not the leader 116 let local_gen = lock.generation_id; 117 118 while local_gen == lock.generation_id && lock.count < self.num_threads { 119 drop(lock); 120 R::relax(); 121 lock = self.lock.lock(); 122 } 123 BarrierWaitResult(false) 124 } else { 125 // this thread is the leader, 126 // and is responsible for incrementing the generation 127 lock.count = 0; 128 lock.generation_id = lock.generation_id.wrapping_add(1); 129 BarrierWaitResult(true) 130 } 131 } 132 } 133 134 impl<R> Barrier<R> { 135 /// Creates a new barrier that can block a given number of threads. 136 /// 137 /// A barrier will block `n`-1 threads which call [`wait`] and then wake up 138 /// all threads at once when the `n`th thread calls [`wait`]. A Barrier created 139 /// with n = 0 will behave identically to one created with n = 1. 140 /// 141 /// [`wait`]: #method.wait 142 /// 143 /// # Examples 144 /// 145 /// ``` 146 /// use spin; 147 /// 148 /// let barrier = spin::Barrier::new(10); 149 /// ``` new(n: usize) -> Self150 pub const fn new(n: usize) -> Self { 151 Self { 152 lock: Mutex::new(BarrierState { 153 count: 0, 154 generation_id: 0, 155 }), 156 num_threads: n, 157 } 158 } 159 } 160 161 impl BarrierWaitResult { 162 /// Returns whether this thread from [`wait`] is the "leader thread". 163 /// 164 /// Only one thread will have `true` returned from their result, all other 165 /// threads will have `false` returned. 166 /// 167 /// [`wait`]: struct.Barrier.html#method.wait 168 /// 169 /// # Examples 170 /// 171 /// ``` 172 /// use spin; 173 /// 174 /// let barrier = spin::Barrier::new(1); 175 /// let barrier_wait_result = barrier.wait(); 176 /// println!("{:?}", barrier_wait_result.is_leader()); 177 /// ``` is_leader(&self) -> bool178 pub fn is_leader(&self) -> bool { 179 self.0 180 } 181 } 182 183 #[cfg(test)] 184 mod tests { 185 use std::prelude::v1::*; 186 187 use std::sync::mpsc::{channel, TryRecvError}; 188 use std::sync::Arc; 189 use std::thread; 190 191 type Barrier = super::Barrier; 192 use_barrier(n: usize, barrier: Arc<Barrier>)193 fn use_barrier(n: usize, barrier: Arc<Barrier>) { 194 let (tx, rx) = channel(); 195 196 let mut ts = Vec::new(); 197 for _ in 0..n - 1 { 198 let c = barrier.clone(); 199 let tx = tx.clone(); 200 ts.push(thread::spawn(move || { 201 tx.send(c.wait().is_leader()).unwrap(); 202 })); 203 } 204 205 // At this point, all spawned threads should be blocked, 206 // so we shouldn't get anything from the port 207 assert!(match rx.try_recv() { 208 Err(TryRecvError::Empty) => true, 209 _ => false, 210 }); 211 212 let mut leader_found = barrier.wait().is_leader(); 213 214 // Now, the barrier is cleared and we should get data. 215 for _ in 0..n - 1 { 216 if rx.recv().unwrap() { 217 assert!(!leader_found); 218 leader_found = true; 219 } 220 } 221 assert!(leader_found); 222 223 for t in ts { 224 t.join().unwrap(); 225 } 226 } 227 228 #[test] test_barrier()229 fn test_barrier() { 230 const N: usize = 10; 231 232 let barrier = Arc::new(Barrier::new(N)); 233 234 use_barrier(N, barrier.clone()); 235 236 // use barrier twice to ensure it is reusable 237 use_barrier(N, barrier.clone()); 238 } 239 } 240