xref: /aosp_15_r20/external/crosvm/cros_async/src/sync/cv.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2020 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::cell::UnsafeCell;
6 use std::hint;
7 use std::mem;
8 use std::sync::atomic::AtomicUsize;
9 use std::sync::atomic::Ordering;
10 use std::sync::Arc;
11 
12 use super::super::sync::mu::RawRwLock;
13 use super::super::sync::mu::RwLockReadGuard;
14 use super::super::sync::mu::RwLockWriteGuard;
15 use super::super::sync::waiter::Kind as WaiterKind;
16 use super::super::sync::waiter::Waiter;
17 use super::super::sync::waiter::WaiterAdapter;
18 use super::super::sync::waiter::WaiterList;
19 use super::super::sync::waiter::WaitingFor;
20 
/// Bit of `Condvar::state` that is set while some thread holds the internal spin lock guarding
/// the waiter list and the associated-lock address.
const SPINLOCK: usize = 0b01;
/// Bit of `Condvar::state` that is set while the waiter list may contain entries.
const HAS_WAITERS: usize = 0b10;
23 
24 /// A primitive to wait for an event to occur without consuming CPU time.
25 ///
26 /// Condition variables are used in combination with a `RwLock` when a thread wants to wait for some
27 /// condition to become true. The condition must always be verified while holding the `RwLock` lock.
28 /// It is an error to use a `Condvar` with more than one `RwLock` while there are threads waiting on
29 /// the `Condvar`.
30 ///
31 /// # Examples
32 ///
33 /// ```edition2018
34 /// use std::sync::Arc;
35 /// use std::thread;
36 /// use std::sync::mpsc::channel;
37 ///
38 /// use cros_async::{
39 ///     block_on,
40 ///     sync::{Condvar, RwLock},
41 /// };
42 ///
43 /// const N: usize = 13;
44 ///
45 /// // Spawn a few threads to increment a shared variable (non-atomically), and
46 /// // let all threads waiting on the Condvar know once the increments are done.
47 /// let data = Arc::new(RwLock::new(0));
48 /// let cv = Arc::new(Condvar::new());
49 ///
50 /// for _ in 0..N {
51 ///     let (data, cv) = (data.clone(), cv.clone());
52 ///     thread::spawn(move || {
53 ///         let mut data = block_on(data.lock());
54 ///         *data += 1;
55 ///         if *data == N {
56 ///             cv.notify_all();
57 ///         }
58 ///     });
59 /// }
60 ///
61 /// let mut val = block_on(data.lock());
62 /// while *val != N {
63 ///     val = block_on(cv.wait(val));
64 /// }
65 /// ```
66 #[repr(align(128))]
67 pub struct Condvar {
68     state: AtomicUsize,
69     waiters: UnsafeCell<WaiterList>,
70     mu: UnsafeCell<usize>,
71 }
72 
73 impl Condvar {
74     /// Creates a new condition variable ready to be waited on and notified.
new() -> Condvar75     pub fn new() -> Condvar {
76         Condvar {
77             state: AtomicUsize::new(0),
78             waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())),
79             mu: UnsafeCell::new(0),
80         }
81     }
82 
83     /// Block the current thread until this `Condvar` is notified by another thread.
84     ///
85     /// This method will atomically unlock the `RwLock` held by `guard` and then block the current
86     /// thread. Any call to `notify_one` or `notify_all` after the `RwLock` is unlocked may wake up
87     /// the thread.
88     ///
89     /// To allow for more efficient scheduling, this call may return even when the programmer
90     /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a
91     /// loop that checks the predicate before continuing.
92     ///
93     /// Callers that are not in an async context may wish to use the `block_on` method to block the
94     /// thread until the `Condvar` is notified.
95     ///
96     /// # Panics
97     ///
98     /// This method will panic if used with more than one `RwLock` at the same time.
99     ///
100     /// # Examples
101     ///
102     /// ```
103     /// # use std::sync::Arc;
104     /// # use std::thread;
105     ///
106     /// # use cros_async::{
107     /// #     block_on,
108     /// #     sync::{Condvar, RwLock},
109     /// # };
110     ///
111     /// # let mu = Arc::new(RwLock::new(false));
112     /// # let cv = Arc::new(Condvar::new());
113     /// # let (mu2, cv2) = (mu.clone(), cv.clone());
114     ///
115     /// # let t = thread::spawn(move || {
116     /// #     *block_on(mu2.lock()) = true;
117     /// #     cv2.notify_all();
118     /// # });
119     ///
120     /// let mut ready = block_on(mu.lock());
121     /// while !*ready {
122     ///     ready = block_on(cv.wait(ready));
123     /// }
124     ///
125     /// # t.join().expect("failed to join thread");
126     /// ```
127     // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
128     // that doesn't compile.
129     #[allow(clippy::needless_lifetimes)]
wait<'g, T>(&self, guard: RwLockWriteGuard<'g, T>) -> RwLockWriteGuard<'g, T>130     pub async fn wait<'g, T>(&self, guard: RwLockWriteGuard<'g, T>) -> RwLockWriteGuard<'g, T> {
131         let waiter = Arc::new(Waiter::new(
132             WaiterKind::Exclusive,
133             cancel_waiter,
134             self as *const Condvar as usize,
135             WaitingFor::Condvar,
136         ));
137 
138         self.add_waiter(waiter.clone(), guard.as_raw_rwlock());
139 
140         // Get a reference to the rwlock and then drop the lock.
141         let mu = guard.into_inner();
142 
143         // Wait to be woken up.
144         waiter.wait().await;
145 
146         // Now re-acquire the lock.
147         mu.lock_from_cv().await
148     }
149 
150     /// Like `wait()` but takes and returns a `RwLockReadGuard` instead.
151     // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
152     // that doesn't compile.
153     #[allow(clippy::needless_lifetimes)]
wait_read<'g, T>(&self, guard: RwLockReadGuard<'g, T>) -> RwLockReadGuard<'g, T>154     pub async fn wait_read<'g, T>(&self, guard: RwLockReadGuard<'g, T>) -> RwLockReadGuard<'g, T> {
155         let waiter = Arc::new(Waiter::new(
156             WaiterKind::Shared,
157             cancel_waiter,
158             self as *const Condvar as usize,
159             WaitingFor::Condvar,
160         ));
161 
162         self.add_waiter(waiter.clone(), guard.as_raw_rwlock());
163 
164         // Get a reference to the rwlock and then drop the lock.
165         let mu = guard.into_inner();
166 
167         // Wait to be woken up.
168         waiter.wait().await;
169 
170         // Now re-acquire the lock.
171         mu.read_lock_from_cv().await
172     }
173 
add_waiter(&self, waiter: Arc<Waiter>, raw_rwlock: &RawRwLock)174     fn add_waiter(&self, waiter: Arc<Waiter>, raw_rwlock: &RawRwLock) {
175         // Acquire the spin lock.
176         let mut oldstate = self.state.load(Ordering::Relaxed);
177         while (oldstate & SPINLOCK) != 0
178             || self
179                 .state
180                 .compare_exchange_weak(
181                     oldstate,
182                     oldstate | SPINLOCK | HAS_WAITERS,
183                     Ordering::Acquire,
184                     Ordering::Relaxed,
185                 )
186                 .is_err()
187         {
188             hint::spin_loop();
189             oldstate = self.state.load(Ordering::Relaxed);
190         }
191 
192         // SAFETY:
193         // Safe because the spin lock guarantees exclusive access and the reference does not escape
194         // this function.
195         let mu = unsafe { &mut *self.mu.get() };
196         let muptr = raw_rwlock as *const RawRwLock as usize;
197 
198         match *mu {
199             0 => *mu = muptr,
200             p if p == muptr => {}
201             _ => panic!("Attempting to use Condvar with more than one RwLock at the same time"),
202         }
203 
204         // SAFETY:
205         // Safe because the spin lock guarantees exclusive access.
206         unsafe { (*self.waiters.get()).push_back(waiter) };
207 
208         // Release the spin lock. Use a direct store here because no other thread can modify
209         // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier
210         // because we just added a waiter.
211         self.state.store(HAS_WAITERS, Ordering::Release);
212     }
213 
214     /// Notify at most one thread currently waiting on the `Condvar`.
215     ///
216     /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to
217     /// `wait`.
218     ///
219     /// Unlike more traditional condition variable interfaces, this method requires a reference to
220     /// the `RwLock` associated with this `Condvar`. This is because it is inherently racy to call
221     /// `notify_one` or `notify_all` without first acquiring the `RwLock` lock. Additionally, taking
222     /// a reference to the `RwLock` here allows us to make some optimizations that can improve
223     /// performance by reducing unnecessary wakeups.
notify_one(&self)224     pub fn notify_one(&self) {
225         let mut oldstate = self.state.load(Ordering::Relaxed);
226         if (oldstate & HAS_WAITERS) == 0 {
227             // No waiters.
228             return;
229         }
230 
231         while (oldstate & SPINLOCK) != 0
232             || self
233                 .state
234                 .compare_exchange_weak(
235                     oldstate,
236                     oldstate | SPINLOCK,
237                     Ordering::Acquire,
238                     Ordering::Relaxed,
239                 )
240                 .is_err()
241         {
242             hint::spin_loop();
243             oldstate = self.state.load(Ordering::Relaxed);
244         }
245 
246         // SAFETY:
247         // Safe because the spin lock guarantees exclusive access and the reference does not escape
248         // this function.
249         let waiters = unsafe { &mut *self.waiters.get() };
250         let wake_list = get_wake_list(waiters);
251 
252         let newstate = if waiters.is_empty() {
253             // SAFETY:
254             // Also clear the rwlock associated with this Condvar since there are no longer any
255             // waiters.  Safe because the spin lock guarantees exclusive access.
256             unsafe { *self.mu.get() = 0 };
257 
258             // We are releasing the spin lock and there are no more waiters so we can clear all bits
259             // in `self.state`.
260             0
261         } else {
262             // There are still waiters so we need to keep the HAS_WAITERS bit in the state.
263             HAS_WAITERS
264         };
265 
266         // Release the spin lock.
267         self.state.store(newstate, Ordering::Release);
268 
269         // Now wake any waiters in the wake list.
270         for w in wake_list {
271             w.wake();
272         }
273     }
274 
275     /// Notify all threads currently waiting on the `Condvar`.
276     ///
277     /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`.
278     ///
279     /// Unlike more traditional condition variable interfaces, this method requires a reference to
280     /// the `RwLock` associated with this `Condvar`. This is because it is inherently racy to call
281     /// `notify_one` or `notify_all` without first acquiring the `RwLock` lock. Additionally, taking
282     /// a reference to the `RwLock` here allows us to make some optimizations that can improve
283     /// performance by reducing unnecessary wakeups.
notify_all(&self)284     pub fn notify_all(&self) {
285         let mut oldstate = self.state.load(Ordering::Relaxed);
286         if (oldstate & HAS_WAITERS) == 0 {
287             // No waiters.
288             return;
289         }
290 
291         while (oldstate & SPINLOCK) != 0
292             || self
293                 .state
294                 .compare_exchange_weak(
295                     oldstate,
296                     oldstate | SPINLOCK,
297                     Ordering::Acquire,
298                     Ordering::Relaxed,
299                 )
300                 .is_err()
301         {
302             hint::spin_loop();
303             oldstate = self.state.load(Ordering::Relaxed);
304         }
305 
306         // SAFETY:
307         // Safe because the spin lock guarantees exclusive access to `self.waiters`.
308         let wake_list = unsafe { (*self.waiters.get()).take() };
309 
310         // SAFETY:
311         // Clear the rwlock associated with this Condvar since there are no longer any waiters. Safe
312         // because we the spin lock guarantees exclusive access.
313         unsafe { *self.mu.get() = 0 };
314 
315         // Mark any waiters left as no longer waiting for the Condvar.
316         for w in &wake_list {
317             w.set_waiting_for(WaitingFor::None);
318         }
319 
320         // Release the spin lock.  We can clear all bits in the state since we took all the waiters.
321         self.state.store(0, Ordering::Release);
322 
323         // Now wake any waiters in the wake list.
324         for w in wake_list {
325             w.wake();
326         }
327     }
328 
cancel_waiter(&self, waiter: &Waiter, wake_next: bool)329     fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) {
330         let mut oldstate = self.state.load(Ordering::Relaxed);
331         while oldstate & SPINLOCK != 0
332             || self
333                 .state
334                 .compare_exchange_weak(
335                     oldstate,
336                     oldstate | SPINLOCK,
337                     Ordering::Acquire,
338                     Ordering::Relaxed,
339                 )
340                 .is_err()
341         {
342             hint::spin_loop();
343             oldstate = self.state.load(Ordering::Relaxed);
344         }
345 
346         // SAFETY:
347         // Safe because the spin lock provides exclusive access and the reference does not escape
348         // this function.
349         let waiters = unsafe { &mut *self.waiters.get() };
350 
351         let waiting_for = waiter.is_waiting_for();
352         // Don't drop the old waiter now as we're still holding the spin lock.
353         let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar {
354             // SAFETY:
355             // Safe because we know that the waiter is still linked and is waiting for the Condvar,
356             // which guarantees that it is still in `self.waiters`.
357             let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) };
358             cursor.remove()
359         } else {
360             None
361         };
362 
363         let wake_list = if wake_next || waiting_for == WaitingFor::None {
364             // Either the waiter was already woken or it's been removed from the condvar's waiter
365             // list and is going to be woken. Either way, we need to wake up another thread.
366             get_wake_list(waiters)
367         } else {
368             WaiterList::new(WaiterAdapter::new())
369         };
370 
371         let set_on_release = if waiters.is_empty() {
372             // SAFETY:
373             // Clear the rwlock associated with this Condvar since there are no longer any waiters.
374             // Safe because we the spin lock guarantees exclusive access.
375             unsafe { *self.mu.get() = 0 };
376 
377             0
378         } else {
379             HAS_WAITERS
380         };
381 
382         self.state.store(set_on_release, Ordering::Release);
383 
384         // Now wake any waiters still left in the wake list.
385         for w in wake_list {
386             w.wake();
387         }
388 
389         mem::drop(old_waiter);
390     }
391 }
392 
// TODO(b/315998194): Add safety comment
// Draft safety argument (NOT verified, hence the TODO is kept): `Condvar` stores an
// `AtomicUsize`, a `usize` lock address that this file only compares and never dereferences, and
// a `WaiterList` of `Arc<Waiter>`. Moving it to another thread is therefore sound only if
// `Waiter` is safe to send and share between threads, which is defined elsewhere. Confirm
// before replacing the TODO.
#[allow(clippy::undocumented_unsafe_blocks)]
unsafe impl Send for Condvar {}
// TODO(b/315998194): Add safety comment
// Draft safety argument (NOT verified): every access to the `UnsafeCell` fields (`waiters`,
// `mu`) in `impl Condvar` happens while holding the `SPINLOCK` bit of `state`. That bit is taken
// with an `Acquire` compare-exchange and released with a `Release` store, so concurrent `&Condvar`
// users never race on those fields. Waiters are also woken from other threads, which again relies
// on `Waiter` being `Send + Sync`. Confirm there are no other accesses before replacing the TODO.
#[allow(clippy::undocumented_unsafe_blocks)]
unsafe impl Sync for Condvar {}
399 
400 impl Default for Condvar {
default() -> Self401     fn default() -> Self {
402         Self::new()
403     }
404 }
405 
406 // Scan `waiters` and return all waiters that should be woken up.
407 //
408 // If the first waiter is trying to acquire a shared lock, then all waiters in the list that are
409 // waiting for a shared lock are also woken up. In addition one writer is woken up, if possible.
410 //
411 // If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and
412 // the rest of the list is not scanned.
get_wake_list(waiters: &mut WaiterList) -> WaiterList413 fn get_wake_list(waiters: &mut WaiterList) -> WaiterList {
414     let mut to_wake = WaiterList::new(WaiterAdapter::new());
415     let mut cursor = waiters.front_mut();
416 
417     let mut waking_readers = false;
418     let mut all_readers = true;
419     while let Some(w) = cursor.get() {
420         match w.kind() {
421             WaiterKind::Exclusive if !waking_readers => {
422                 // This is the first waiter and it's a writer. No need to check the other waiters.
423                 // Also mark the waiter as having been removed from the Condvar's waiter list.
424                 let waiter = cursor.remove().unwrap();
425                 waiter.set_waiting_for(WaitingFor::None);
426                 to_wake.push_back(waiter);
427                 break;
428             }
429 
430             WaiterKind::Shared => {
431                 // This is a reader and the first waiter in the list was not a writer so wake up all
432                 // the readers in the wait list.
433                 let waiter = cursor.remove().unwrap();
434                 waiter.set_waiting_for(WaitingFor::None);
435                 to_wake.push_back(waiter);
436                 waking_readers = true;
437             }
438 
439             WaiterKind::Exclusive => {
440                 debug_assert!(waking_readers);
441                 if all_readers {
442                     // We are waking readers but we need to ensure that at least one writer is woken
443                     // up. Since we haven't yet woken up a writer, wake up this one.
444                     let waiter = cursor.remove().unwrap();
445                     waiter.set_waiting_for(WaitingFor::None);
446                     to_wake.push_back(waiter);
447                     all_readers = false;
448                 } else {
449                     // We are waking readers and have already woken one writer. Skip this one.
450                     cursor.move_next();
451                 }
452             }
453         }
454     }
455 
456     to_wake
457 }
458 
cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool)459 fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) {
460     let condvar = cv as *const Condvar;
461 
462     // SAFETY:
463     // Safe because the thread that owns the waiter being canceled must also own a reference to the
464     // Condvar, which guarantees that this pointer is valid.
465     unsafe { (*condvar).cancel_waiter(waiter, wake_next) }
466 }
467 
468 // TODO(b/194338842): Fix tests for windows
469 #[cfg(any(target_os = "android", target_os = "linux"))]
470 #[cfg(test)]
471 mod test {
472     use std::future::Future;
473     use std::mem;
474     use std::ptr;
475     use std::rc::Rc;
476     use std::sync::mpsc::channel;
477     use std::sync::mpsc::Sender;
478     use std::sync::Arc;
479     use std::task::Context;
480     use std::task::Poll;
481     use std::thread;
482     use std::thread::JoinHandle;
483     use std::time::Duration;
484 
485     use futures::channel::oneshot;
486     use futures::select;
487     use futures::task::waker_ref;
488     use futures::task::ArcWake;
489     use futures::FutureExt;
490     use futures_executor::LocalPool;
491     use futures_executor::LocalSpawner;
492     use futures_executor::ThreadPool;
493     use futures_util::task::LocalSpawnExt;
494 
495     use super::super::super::block_on;
496     use super::super::super::sync::RwLock;
497     use super::*;
498 
499     // Dummy waker used when we want to manually drive futures.
500     struct TestWaker;
501     impl ArcWake for TestWaker {
wake_by_ref(_arc_self: &Arc<Self>)502         fn wake_by_ref(_arc_self: &Arc<Self>) {}
503     }
504 
505     #[test]
smoke()506     fn smoke() {
507         let cv = Condvar::new();
508         cv.notify_one();
509         cv.notify_all();
510     }
511 
512     #[test]
notify_one()513     fn notify_one() {
514         let mu = Arc::new(RwLock::new(()));
515         let cv = Arc::new(Condvar::new());
516 
517         let mu2 = mu.clone();
518         let cv2 = cv.clone();
519 
520         let guard = block_on(mu.lock());
521         thread::spawn(move || {
522             let _g = block_on(mu2.lock());
523             cv2.notify_one();
524         });
525 
526         let guard = block_on(cv.wait(guard));
527         mem::drop(guard);
528     }
529 
530     #[test]
multi_rwlock()531     fn multi_rwlock() {
532         const NUM_THREADS: usize = 5;
533 
534         let mu = Arc::new(RwLock::new(false));
535         let cv = Arc::new(Condvar::new());
536 
537         let mut threads = Vec::with_capacity(NUM_THREADS);
538         for _ in 0..NUM_THREADS {
539             let mu = mu.clone();
540             let cv = cv.clone();
541 
542             threads.push(thread::spawn(move || {
543                 let mut ready = block_on(mu.lock());
544                 while !*ready {
545                     ready = block_on(cv.wait(ready));
546                 }
547             }));
548         }
549 
550         let mut g = block_on(mu.lock());
551         *g = true;
552         mem::drop(g);
553         cv.notify_all();
554 
555         threads
556             .into_iter()
557             .try_for_each(JoinHandle::join)
558             .expect("Failed to join threads");
559 
560         // Now use the Condvar with a different rwlock.
561         let alt_mu = Arc::new(RwLock::new(None));
562         let alt_mu2 = alt_mu.clone();
563         let cv2 = cv.clone();
564         let handle = thread::spawn(move || {
565             let mut g = block_on(alt_mu2.lock());
566             while g.is_none() {
567                 g = block_on(cv2.wait(g));
568             }
569         });
570 
571         let mut alt_g = block_on(alt_mu.lock());
572         *alt_g = Some(());
573         mem::drop(alt_g);
574         cv.notify_all();
575 
576         handle
577             .join()
578             .expect("Failed to join thread alternate rwlock");
579     }
580 
581     #[test]
notify_one_single_thread_async()582     fn notify_one_single_thread_async() {
583         async fn notify(mu: Rc<RwLock<()>>, cv: Rc<Condvar>) {
584             let _g = mu.lock().await;
585             cv.notify_one();
586         }
587 
588         async fn wait(mu: Rc<RwLock<()>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
589             let mu2 = Rc::clone(&mu);
590             let cv2 = Rc::clone(&cv);
591 
592             let g = mu.lock().await;
593             // Has to be spawned _after_ acquiring the lock to prevent a race
594             // where the notify happens before the waiter has acquired the lock.
595             spawner
596                 .spawn_local(notify(mu2, cv2))
597                 .expect("Failed to spawn `notify` task");
598             let _g = cv.wait(g).await;
599         }
600 
601         let mut ex = LocalPool::new();
602         let spawner = ex.spawner();
603 
604         let mu = Rc::new(RwLock::new(()));
605         let cv = Rc::new(Condvar::new());
606 
607         spawner
608             .spawn_local(wait(mu, cv, spawner.clone()))
609             .expect("Failed to spawn `wait` task");
610 
611         ex.run();
612     }
613 
614     #[test]
notify_one_multi_thread_async()615     fn notify_one_multi_thread_async() {
616         async fn notify(mu: Arc<RwLock<()>>, cv: Arc<Condvar>) {
617             let _g = mu.lock().await;
618             cv.notify_one();
619         }
620 
621         async fn wait(mu: Arc<RwLock<()>>, cv: Arc<Condvar>, tx: Sender<()>, pool: ThreadPool) {
622             let mu2 = Arc::clone(&mu);
623             let cv2 = Arc::clone(&cv);
624 
625             let g = mu.lock().await;
626             // Has to be spawned _after_ acquiring the lock to prevent a race
627             // where the notify happens before the waiter has acquired the lock.
628             pool.spawn_ok(notify(mu2, cv2));
629             let _g = cv.wait(g).await;
630 
631             tx.send(()).expect("Failed to send completion notification");
632         }
633 
634         let ex = ThreadPool::new().expect("Failed to create ThreadPool");
635 
636         let mu = Arc::new(RwLock::new(()));
637         let cv = Arc::new(Condvar::new());
638 
639         let (tx, rx) = channel();
640         ex.spawn_ok(wait(mu, cv, tx, ex.clone()));
641 
642         rx.recv_timeout(Duration::from_secs(5))
643             .expect("Failed to receive completion notification");
644     }
645 
646     #[test]
notify_one_with_cancel()647     fn notify_one_with_cancel() {
648         const TASKS: usize = 17;
649         const OBSERVERS: usize = 7;
650         const ITERATIONS: usize = 103;
651 
652         async fn observe(mu: &Arc<RwLock<usize>>, cv: &Arc<Condvar>) {
653             let mut count = mu.read_lock().await;
654             while *count == 0 {
655                 count = cv.wait_read(count).await;
656             }
657             // SAFETY: Safe because count is valid and is byte aligned.
658             let _ = unsafe { ptr::read_volatile(&*count as *const usize) };
659         }
660 
661         async fn decrement(mu: &Arc<RwLock<usize>>, cv: &Arc<Condvar>) {
662             let mut count = mu.lock().await;
663             while *count == 0 {
664                 count = cv.wait(count).await;
665             }
666             *count -= 1;
667         }
668 
669         async fn increment(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
670             for _ in 0..TASKS * OBSERVERS * ITERATIONS {
671                 *mu.lock().await += 1;
672                 cv.notify_one();
673             }
674 
675             done.send(()).expect("Failed to send completion message");
676         }
677 
678         async fn observe_either(
679             mu: Arc<RwLock<usize>>,
680             cv: Arc<Condvar>,
681             alt_mu: Arc<RwLock<usize>>,
682             alt_cv: Arc<Condvar>,
683             done: Sender<()>,
684         ) {
685             for _ in 0..ITERATIONS {
686                 select! {
687                     () = observe(&mu, &cv).fuse() => {},
688                     () = observe(&alt_mu, &alt_cv).fuse() => {},
689                 }
690             }
691 
692             done.send(()).expect("Failed to send completion message");
693         }
694 
695         async fn decrement_either(
696             mu: Arc<RwLock<usize>>,
697             cv: Arc<Condvar>,
698             alt_mu: Arc<RwLock<usize>>,
699             alt_cv: Arc<Condvar>,
700             done: Sender<()>,
701         ) {
702             for _ in 0..ITERATIONS {
703                 select! {
704                     () = decrement(&mu, &cv).fuse() => {},
705                     () = decrement(&alt_mu, &alt_cv).fuse() => {},
706                 }
707             }
708 
709             done.send(()).expect("Failed to send completion message");
710         }
711 
712         let ex = ThreadPool::new().expect("Failed to create ThreadPool");
713 
714         let mu = Arc::new(RwLock::new(0usize));
715         let alt_mu = Arc::new(RwLock::new(0usize));
716 
717         let cv = Arc::new(Condvar::new());
718         let alt_cv = Arc::new(Condvar::new());
719 
720         let (tx, rx) = channel();
721         for _ in 0..TASKS {
722             ex.spawn_ok(decrement_either(
723                 Arc::clone(&mu),
724                 Arc::clone(&cv),
725                 Arc::clone(&alt_mu),
726                 Arc::clone(&alt_cv),
727                 tx.clone(),
728             ));
729         }
730 
731         for _ in 0..OBSERVERS {
732             ex.spawn_ok(observe_either(
733                 Arc::clone(&mu),
734                 Arc::clone(&cv),
735                 Arc::clone(&alt_mu),
736                 Arc::clone(&alt_cv),
737                 tx.clone(),
738             ));
739         }
740 
741         ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
742         ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));
743 
744         for _ in 0..TASKS + OBSERVERS + 2 {
745             if let Err(e) = rx.recv_timeout(Duration::from_secs(20)) {
746                 panic!("Error while waiting for threads to complete: {}", e);
747             }
748         }
749 
750         assert_eq!(
751             *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
752             (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS)
753         );
754         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
755         assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
756     }
757 
758     #[test]
notify_all_with_cancel()759     fn notify_all_with_cancel() {
760         const TASKS: usize = 17;
761         const ITERATIONS: usize = 103;
762 
763         async fn decrement(mu: &Arc<RwLock<usize>>, cv: &Arc<Condvar>) {
764             let mut count = mu.lock().await;
765             while *count == 0 {
766                 count = cv.wait(count).await;
767             }
768             *count -= 1;
769         }
770 
771         async fn increment(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
772             for _ in 0..TASKS * ITERATIONS {
773                 *mu.lock().await += 1;
774                 cv.notify_all();
775             }
776 
777             done.send(()).expect("Failed to send completion message");
778         }
779 
780         async fn decrement_either(
781             mu: Arc<RwLock<usize>>,
782             cv: Arc<Condvar>,
783             alt_mu: Arc<RwLock<usize>>,
784             alt_cv: Arc<Condvar>,
785             done: Sender<()>,
786         ) {
787             for _ in 0..ITERATIONS {
788                 select! {
789                     () = decrement(&mu, &cv).fuse() => {},
790                     () = decrement(&alt_mu, &alt_cv).fuse() => {},
791                 }
792             }
793 
794             done.send(()).expect("Failed to send completion message");
795         }
796 
797         let ex = ThreadPool::new().expect("Failed to create ThreadPool");
798 
799         let mu = Arc::new(RwLock::new(0usize));
800         let alt_mu = Arc::new(RwLock::new(0usize));
801 
802         let cv = Arc::new(Condvar::new());
803         let alt_cv = Arc::new(Condvar::new());
804 
805         let (tx, rx) = channel();
806         for _ in 0..TASKS {
807             ex.spawn_ok(decrement_either(
808                 Arc::clone(&mu),
809                 Arc::clone(&cv),
810                 Arc::clone(&alt_mu),
811                 Arc::clone(&alt_cv),
812                 tx.clone(),
813             ));
814         }
815 
816         ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
817         ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));
818 
819         for _ in 0..TASKS + 2 {
820             if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) {
821                 panic!("Error while waiting for threads to complete: {}", e);
822             }
823         }
824 
825         assert_eq!(
826             *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
827             TASKS * ITERATIONS
828         );
829         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
830         assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
831     }
832     #[test]
notify_all()833     fn notify_all() {
834         const THREADS: usize = 13;
835 
836         let mu = Arc::new(RwLock::new(0));
837         let cv = Arc::new(Condvar::new());
838         let (tx, rx) = channel();
839 
840         let mut threads = Vec::with_capacity(THREADS);
841         for _ in 0..THREADS {
842             let mu2 = mu.clone();
843             let cv2 = cv.clone();
844             let tx2 = tx.clone();
845 
846             threads.push(thread::spawn(move || {
847                 let mut count = block_on(mu2.lock());
848                 *count += 1;
849                 if *count == THREADS {
850                     tx2.send(()).unwrap();
851                 }
852 
853                 while *count != 0 {
854                     count = block_on(cv2.wait(count));
855                 }
856             }));
857         }
858 
859         mem::drop(tx);
860 
861         // Wait till all threads have started.
862         rx.recv_timeout(Duration::from_secs(5)).unwrap();
863 
864         let mut count = block_on(mu.lock());
865         *count = 0;
866         mem::drop(count);
867         cv.notify_all();
868 
869         for t in threads {
870             t.join().unwrap();
871         }
872     }
873 
874     #[test]
notify_all_single_thread_async()875     fn notify_all_single_thread_async() {
876         const TASKS: usize = 13;
877 
878         async fn reset(mu: Rc<RwLock<usize>>, cv: Rc<Condvar>) {
879             let mut count = mu.lock().await;
880             *count = 0;
881             cv.notify_all();
882         }
883 
884         async fn watcher(mu: Rc<RwLock<usize>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
885             let mut count = mu.lock().await;
886             *count += 1;
887             if *count == TASKS {
888                 spawner
889                     .spawn_local(reset(mu.clone(), cv.clone()))
890                     .expect("Failed to spawn reset task");
891             }
892 
893             while *count != 0 {
894                 count = cv.wait(count).await;
895             }
896         }
897 
898         let mut ex = LocalPool::new();
899         let spawner = ex.spawner();
900 
901         let mu = Rc::new(RwLock::new(0));
902         let cv = Rc::new(Condvar::new());
903 
904         for _ in 0..TASKS {
905             spawner
906                 .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone()))
907                 .expect("Failed to spawn watcher task");
908         }
909 
910         ex.run();
911     }
912 
913     #[test]
notify_all_multi_thread_async()914     fn notify_all_multi_thread_async() {
915         const TASKS: usize = 13;
916 
917         async fn reset(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
918             let mut count = mu.lock().await;
919             *count = 0;
920             cv.notify_all();
921         }
922 
923         async fn watcher(
924             mu: Arc<RwLock<usize>>,
925             cv: Arc<Condvar>,
926             pool: ThreadPool,
927             tx: Sender<()>,
928         ) {
929             let mut count = mu.lock().await;
930             *count += 1;
931             if *count == TASKS {
932                 pool.spawn_ok(reset(mu.clone(), cv.clone()));
933             }
934 
935             while *count != 0 {
936                 count = cv.wait(count).await;
937             }
938 
939             tx.send(()).expect("Failed to send completion notification");
940         }
941 
942         let pool = ThreadPool::new().expect("Failed to create ThreadPool");
943 
944         let mu = Arc::new(RwLock::new(0));
945         let cv = Arc::new(Condvar::new());
946 
947         let (tx, rx) = channel();
948         for _ in 0..TASKS {
949             pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone()));
950         }
951 
952         for _ in 0..TASKS {
953             rx.recv_timeout(Duration::from_secs(5))
954                 .expect("Failed to receive completion notification");
955         }
956     }
957 
958     #[test]
wake_all_readers()959     fn wake_all_readers() {
960         async fn read(mu: Arc<RwLock<bool>>, cv: Arc<Condvar>) {
961             let mut ready = mu.read_lock().await;
962             while !*ready {
963                 ready = cv.wait_read(ready).await;
964             }
965         }
966 
967         let mu = Arc::new(RwLock::new(false));
968         let cv = Arc::new(Condvar::new());
969         let mut readers = [
970             Box::pin(read(mu.clone(), cv.clone())),
971             Box::pin(read(mu.clone(), cv.clone())),
972             Box::pin(read(mu.clone(), cv.clone())),
973             Box::pin(read(mu.clone(), cv.clone())),
974         ];
975 
976         let arc_waker = Arc::new(TestWaker);
977         let waker = waker_ref(&arc_waker);
978         let mut cx = Context::from_waker(&waker);
979 
980         // First have all the readers wait on the Condvar.
981         for r in &mut readers {
982             if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
983                 panic!("reader unexpectedly ready");
984             }
985         }
986 
987         assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
988 
989         // Now make the condition true and notify the condvar. Even though we will call notify_one,
990         // all the readers should be woken up.
991         *block_on(mu.lock()) = true;
992         cv.notify_one();
993 
994         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
995 
996         // All readers should now be able to complete.
997         for r in &mut readers {
998             if r.as_mut().poll(&mut cx).is_pending() {
999                 panic!("reader unable to complete");
1000             }
1001         }
1002     }
1003 
1004     #[test]
cancel_before_notify()1005     fn cancel_before_notify() {
1006         async fn dec(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
1007             let mut count = mu.lock().await;
1008 
1009             while *count == 0 {
1010                 count = cv.wait(count).await;
1011             }
1012 
1013             *count -= 1;
1014         }
1015 
1016         let mu = Arc::new(RwLock::new(0));
1017         let cv = Arc::new(Condvar::new());
1018 
1019         let arc_waker = Arc::new(TestWaker);
1020         let waker = waker_ref(&arc_waker);
1021         let mut cx = Context::from_waker(&waker);
1022 
1023         let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
1024         let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
1025 
1026         if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
1027             panic!("future unexpectedly ready");
1028         }
1029         if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
1030             panic!("future unexpectedly ready");
1031         }
1032         assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
1033 
1034         *block_on(mu.lock()) = 2;
1035         // Drop fut1 before notifying the cv.
1036         mem::drop(fut1);
1037         cv.notify_one();
1038 
1039         // fut2 should now be ready to complete.
1040         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
1041 
1042         if fut2.as_mut().poll(&mut cx).is_pending() {
1043             panic!("future unable to complete");
1044         }
1045 
1046         assert_eq!(*block_on(mu.lock()), 1);
1047     }
1048 
1049     #[test]
cancel_after_notify_one()1050     fn cancel_after_notify_one() {
1051         async fn dec(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
1052             let mut count = mu.lock().await;
1053 
1054             while *count == 0 {
1055                 count = cv.wait(count).await;
1056             }
1057 
1058             *count -= 1;
1059         }
1060 
1061         let mu = Arc::new(RwLock::new(0));
1062         let cv = Arc::new(Condvar::new());
1063 
1064         let arc_waker = Arc::new(TestWaker);
1065         let waker = waker_ref(&arc_waker);
1066         let mut cx = Context::from_waker(&waker);
1067 
1068         let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
1069         let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
1070 
1071         if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
1072             panic!("future unexpectedly ready");
1073         }
1074         if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
1075             panic!("future unexpectedly ready");
1076         }
1077         assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
1078 
1079         *block_on(mu.lock()) = 2;
1080         cv.notify_one();
1081 
1082         // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2.
1083         mem::drop(fut1);
1084         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
1085 
1086         if fut2.as_mut().poll(&mut cx).is_pending() {
1087             panic!("future unable to complete");
1088         }
1089 
1090         assert_eq!(*block_on(mu.lock()), 1);
1091     }
1092 
1093     #[test]
cancel_after_notify_all()1094     fn cancel_after_notify_all() {
1095         async fn dec(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
1096             let mut count = mu.lock().await;
1097 
1098             while *count == 0 {
1099                 count = cv.wait(count).await;
1100             }
1101 
1102             *count -= 1;
1103         }
1104 
1105         let mu = Arc::new(RwLock::new(0));
1106         let cv = Arc::new(Condvar::new());
1107 
1108         let arc_waker = Arc::new(TestWaker);
1109         let waker = waker_ref(&arc_waker);
1110         let mut cx = Context::from_waker(&waker);
1111 
1112         let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
1113         let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
1114 
1115         if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
1116             panic!("future unexpectedly ready");
1117         }
1118         if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
1119             panic!("future unexpectedly ready");
1120         }
1121         assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
1122 
1123         let mut count = block_on(mu.lock());
1124         *count = 2;
1125 
1126         // Notify the cv while holding the lock. This should wake up both waiters.
1127         cv.notify_all();
1128         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
1129 
1130         mem::drop(count);
1131 
1132         mem::drop(fut1);
1133 
1134         if fut2.as_mut().poll(&mut cx).is_pending() {
1135             panic!("future unable to complete");
1136         }
1137 
1138         assert_eq!(*block_on(mu.lock()), 1);
1139     }
1140 
1141     #[test]
timed_wait()1142     fn timed_wait() {
1143         async fn wait_deadline(
1144             mu: Arc<RwLock<usize>>,
1145             cv: Arc<Condvar>,
1146             timeout: oneshot::Receiver<()>,
1147         ) {
1148             let mut count = mu.lock().await;
1149 
1150             if *count == 0 {
1151                 let mut rx = timeout.fuse();
1152 
1153                 while *count == 0 {
1154                     select! {
1155                         res = rx => {
1156                             if let Err(e) = res {
1157                                 panic!("Error while receiving timeout notification: {}", e);
1158                             }
1159 
1160                             return;
1161                         },
1162                         c = cv.wait(count).fuse() => count = c,
1163                     }
1164                 }
1165             }
1166 
1167             *count += 1;
1168         }
1169 
1170         let mu = Arc::new(RwLock::new(0));
1171         let cv = Arc::new(Condvar::new());
1172 
1173         let arc_waker = Arc::new(TestWaker);
1174         let waker = waker_ref(&arc_waker);
1175         let mut cx = Context::from_waker(&waker);
1176 
1177         let (tx, rx) = oneshot::channel();
1178         let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx));
1179 
1180         if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) {
1181             panic!("wait_deadline unexpectedly ready");
1182         }
1183 
1184         assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS);
1185 
1186         // Signal the channel, which should cancel the wait.
1187         tx.send(()).expect("Failed to send wakeup");
1188 
1189         // Wait for the timer to run out.
1190         if wait.as_mut().poll(&mut cx).is_pending() {
1191             panic!("wait_deadline unable to complete in time");
1192         }
1193 
1194         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
1195         assert_eq!(*block_on(mu.lock()), 0);
1196     }
1197 }
1198