1 // Copyright 2020 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::cell::UnsafeCell;
6 use std::hint;
7 use std::mem;
8 use std::sync::atomic::AtomicUsize;
9 use std::sync::atomic::Ordering;
10 use std::sync::Arc;
11
12 use super::super::sync::mu::RawRwLock;
13 use super::super::sync::mu::RwLockReadGuard;
14 use super::super::sync::mu::RwLockWriteGuard;
15 use super::super::sync::waiter::Kind as WaiterKind;
16 use super::super::sync::waiter::Waiter;
17 use super::super::sync::waiter::WaiterAdapter;
18 use super::super::sync::waiter::WaiterList;
19 use super::super::sync::waiter::WaitingFor;
20
// Bit in `Condvar::state` that is set while a thread holds the internal spin lock protecting
// the waiter list and the associated rwlock address.
const SPINLOCK: usize = 1 << 0;
// Bit in `Condvar::state` that is set when a waiter is added and cleared again once the waiter
// list is found to be empty.
const HAS_WAITERS: usize = 1 << 1;
23
/// A primitive to wait for an event to occur without consuming CPU time.
///
/// Condition variables are used in combination with a `RwLock` when a thread wants to wait for some
/// condition to become true. The condition must always be verified while holding the `RwLock` lock.
/// It is an error to use a `Condvar` with more than one `RwLock` while there are threads waiting on
/// the `Condvar`.
///
/// # Examples
///
/// ```edition2018
/// use std::sync::Arc;
/// use std::thread;
/// use std::sync::mpsc::channel;
///
/// use cros_async::{
///     block_on,
///     sync::{Condvar, RwLock},
/// };
///
/// const N: usize = 13;
///
/// // Spawn a few threads to increment a shared variable (non-atomically), and
/// // let all threads waiting on the Condvar know once the increments are done.
/// let data = Arc::new(RwLock::new(0));
/// let cv = Arc::new(Condvar::new());
///
/// for _ in 0..N {
///     let (data, cv) = (data.clone(), cv.clone());
///     thread::spawn(move || {
///         let mut data = block_on(data.lock());
///         *data += 1;
///         if *data == N {
///             cv.notify_all();
///         }
///     });
/// }
///
/// let mut val = block_on(data.lock());
/// while *val != N {
///     val = block_on(cv.wait(val));
/// }
/// ```
#[repr(align(128))]
pub struct Condvar {
    // Bit set built from `SPINLOCK` and `HAS_WAITERS`.
    state: AtomicUsize,
    // Waiters queued on this condvar. Only accessed while the spin lock bit in `state` is held.
    waiters: UnsafeCell<WaiterList>,
    // Address of the `RawRwLock` used by the current waiters, or 0 when no rwlock is associated.
    // Only accessed while the spin lock bit in `state` is held.
    mu: UnsafeCell<usize>,
}
72
73 impl Condvar {
74 /// Creates a new condition variable ready to be waited on and notified.
new() -> Condvar75 pub fn new() -> Condvar {
76 Condvar {
77 state: AtomicUsize::new(0),
78 waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())),
79 mu: UnsafeCell::new(0),
80 }
81 }
82
83 /// Block the current thread until this `Condvar` is notified by another thread.
84 ///
85 /// This method will atomically unlock the `RwLock` held by `guard` and then block the current
86 /// thread. Any call to `notify_one` or `notify_all` after the `RwLock` is unlocked may wake up
87 /// the thread.
88 ///
89 /// To allow for more efficient scheduling, this call may return even when the programmer
90 /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a
91 /// loop that checks the predicate before continuing.
92 ///
93 /// Callers that are not in an async context may wish to use the `block_on` method to block the
94 /// thread until the `Condvar` is notified.
95 ///
96 /// # Panics
97 ///
98 /// This method will panic if used with more than one `RwLock` at the same time.
99 ///
100 /// # Examples
101 ///
102 /// ```
103 /// # use std::sync::Arc;
104 /// # use std::thread;
105 ///
106 /// # use cros_async::{
107 /// # block_on,
108 /// # sync::{Condvar, RwLock},
109 /// # };
110 ///
111 /// # let mu = Arc::new(RwLock::new(false));
112 /// # let cv = Arc::new(Condvar::new());
113 /// # let (mu2, cv2) = (mu.clone(), cv.clone());
114 ///
115 /// # let t = thread::spawn(move || {
116 /// # *block_on(mu2.lock()) = true;
117 /// # cv2.notify_all();
118 /// # });
119 ///
120 /// let mut ready = block_on(mu.lock());
121 /// while !*ready {
122 /// ready = block_on(cv.wait(ready));
123 /// }
124 ///
125 /// # t.join().expect("failed to join thread");
126 /// ```
127 // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
128 // that doesn't compile.
129 #[allow(clippy::needless_lifetimes)]
wait<'g, T>(&self, guard: RwLockWriteGuard<'g, T>) -> RwLockWriteGuard<'g, T>130 pub async fn wait<'g, T>(&self, guard: RwLockWriteGuard<'g, T>) -> RwLockWriteGuard<'g, T> {
131 let waiter = Arc::new(Waiter::new(
132 WaiterKind::Exclusive,
133 cancel_waiter,
134 self as *const Condvar as usize,
135 WaitingFor::Condvar,
136 ));
137
138 self.add_waiter(waiter.clone(), guard.as_raw_rwlock());
139
140 // Get a reference to the rwlock and then drop the lock.
141 let mu = guard.into_inner();
142
143 // Wait to be woken up.
144 waiter.wait().await;
145
146 // Now re-acquire the lock.
147 mu.lock_from_cv().await
148 }
149
150 /// Like `wait()` but takes and returns a `RwLockReadGuard` instead.
151 // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
152 // that doesn't compile.
153 #[allow(clippy::needless_lifetimes)]
wait_read<'g, T>(&self, guard: RwLockReadGuard<'g, T>) -> RwLockReadGuard<'g, T>154 pub async fn wait_read<'g, T>(&self, guard: RwLockReadGuard<'g, T>) -> RwLockReadGuard<'g, T> {
155 let waiter = Arc::new(Waiter::new(
156 WaiterKind::Shared,
157 cancel_waiter,
158 self as *const Condvar as usize,
159 WaitingFor::Condvar,
160 ));
161
162 self.add_waiter(waiter.clone(), guard.as_raw_rwlock());
163
164 // Get a reference to the rwlock and then drop the lock.
165 let mu = guard.into_inner();
166
167 // Wait to be woken up.
168 waiter.wait().await;
169
170 // Now re-acquire the lock.
171 mu.read_lock_from_cv().await
172 }
173
add_waiter(&self, waiter: Arc<Waiter>, raw_rwlock: &RawRwLock)174 fn add_waiter(&self, waiter: Arc<Waiter>, raw_rwlock: &RawRwLock) {
175 // Acquire the spin lock.
176 let mut oldstate = self.state.load(Ordering::Relaxed);
177 while (oldstate & SPINLOCK) != 0
178 || self
179 .state
180 .compare_exchange_weak(
181 oldstate,
182 oldstate | SPINLOCK | HAS_WAITERS,
183 Ordering::Acquire,
184 Ordering::Relaxed,
185 )
186 .is_err()
187 {
188 hint::spin_loop();
189 oldstate = self.state.load(Ordering::Relaxed);
190 }
191
192 // SAFETY:
193 // Safe because the spin lock guarantees exclusive access and the reference does not escape
194 // this function.
195 let mu = unsafe { &mut *self.mu.get() };
196 let muptr = raw_rwlock as *const RawRwLock as usize;
197
198 match *mu {
199 0 => *mu = muptr,
200 p if p == muptr => {}
201 _ => panic!("Attempting to use Condvar with more than one RwLock at the same time"),
202 }
203
204 // SAFETY:
205 // Safe because the spin lock guarantees exclusive access.
206 unsafe { (*self.waiters.get()).push_back(waiter) };
207
208 // Release the spin lock. Use a direct store here because no other thread can modify
209 // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier
210 // because we just added a waiter.
211 self.state.store(HAS_WAITERS, Ordering::Release);
212 }
213
/// Wake the waiter(s) at the front of this `Condvar`'s wait queue.
///
/// If the first queued waiter came from `wait` (an exclusive waiter), only that waiter is
/// woken. If the first queued waiter came from `wait_read` (a shared waiter), every queued
/// shared waiter is woken together with at most one exclusive waiter. If nothing is waiting
/// this is a no-op.
///
/// It is inherently racy to call `notify_one` or `notify_all` without first acquiring the
/// `RwLock` associated with this `Condvar`. This method does not take the lock or a reference
/// to it; acquiring it is the caller's responsibility.
pub fn notify_one(&self) {
    let mut oldstate = self.state.load(Ordering::Relaxed);
    if (oldstate & HAS_WAITERS) == 0 {
        // No waiters.
        return;
    }

    // Acquire the spin lock, preserving whatever other bits are currently set.
    while (oldstate & SPINLOCK) != 0
        || self
            .state
            .compare_exchange_weak(
                oldstate,
                oldstate | SPINLOCK,
                Ordering::Acquire,
                Ordering::Relaxed,
            )
            .is_err()
    {
        hint::spin_loop();
        oldstate = self.state.load(Ordering::Relaxed);
    }

    // SAFETY:
    // Safe because the spin lock guarantees exclusive access and the reference does not escape
    // this function.
    let waiters = unsafe { &mut *self.waiters.get() };
    let wake_list = get_wake_list(waiters);

    let newstate = if waiters.is_empty() {
        // SAFETY:
        // Also clear the rwlock associated with this Condvar since there are no longer any
        // waiters. Safe because the spin lock guarantees exclusive access.
        unsafe { *self.mu.get() = 0 };

        // We are releasing the spin lock and there are no more waiters so we can clear all bits
        // in `self.state`.
        0
    } else {
        // There are still waiters so we need to keep the HAS_WAITERS bit in the state.
        HAS_WAITERS
    };

    // Release the spin lock.
    self.state.store(newstate, Ordering::Release);

    // Now wake any waiters in the wake list. This happens after the spin lock is released.
    for w in wake_list {
        w.wake();
    }
}
274
/// Notify all threads currently waiting on the `Condvar`.
///
/// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`
/// or `wait_read`. The rwlock association is cleared as well, so the `Condvar` may afterwards
/// be used with a different `RwLock`. If nothing is waiting this is a no-op.
///
/// It is inherently racy to call `notify_one` or `notify_all` without first acquiring the
/// `RwLock` associated with this `Condvar`. This method does not take the lock or a reference
/// to it; acquiring it is the caller's responsibility.
pub fn notify_all(&self) {
    let mut oldstate = self.state.load(Ordering::Relaxed);
    if (oldstate & HAS_WAITERS) == 0 {
        // No waiters.
        return;
    }

    // Acquire the spin lock, preserving whatever other bits are currently set.
    while (oldstate & SPINLOCK) != 0
        || self
            .state
            .compare_exchange_weak(
                oldstate,
                oldstate | SPINLOCK,
                Ordering::Acquire,
                Ordering::Relaxed,
            )
            .is_err()
    {
        hint::spin_loop();
        oldstate = self.state.load(Ordering::Relaxed);
    }

    // SAFETY:
    // Safe because the spin lock guarantees exclusive access to `self.waiters`.
    let wake_list = unsafe { (*self.waiters.get()).take() };

    // SAFETY:
    // Clear the rwlock associated with this Condvar since there are no longer any waiters. Safe
    // because the spin lock guarantees exclusive access.
    unsafe { *self.mu.get() = 0 };

    // Mark every waiter we took as no longer waiting for the Condvar.
    for w in &wake_list {
        w.set_waiting_for(WaitingFor::None);
    }

    // Release the spin lock. We can clear all bits in the state since we took all the waiters.
    self.state.store(0, Ordering::Release);

    // Now wake any waiters in the wake list. This happens after the spin lock is released.
    for w in wake_list {
        w.wake();
    }
}
328
// Handles cancellation of `waiter`.
//
// If `waiter` is still linked and still waiting for this Condvar it is removed from the wait
// list. If `wake_next` is set, or the waiter had already been marked as no longer waiting for
// the Condvar, a new wake list is computed so that the notification is passed on to other
// waiters. The state bits and the rwlock association are updated to match the remaining list.
fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) {
    // Acquire the spin lock, preserving whatever other bits are currently set.
    let mut oldstate = self.state.load(Ordering::Relaxed);
    while oldstate & SPINLOCK != 0
        || self
            .state
            .compare_exchange_weak(
                oldstate,
                oldstate | SPINLOCK,
                Ordering::Acquire,
                Ordering::Relaxed,
            )
            .is_err()
    {
        hint::spin_loop();
        oldstate = self.state.load(Ordering::Relaxed);
    }

    // SAFETY:
    // Safe because the spin lock provides exclusive access and the reference does not escape
    // this function.
    let waiters = unsafe { &mut *self.waiters.get() };

    let waiting_for = waiter.is_waiting_for();
    // Don't drop the old waiter now as we're still holding the spin lock.
    let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar {
        // SAFETY:
        // Safe because we know that the waiter is still linked and is waiting for the Condvar,
        // which guarantees that it is still in `self.waiters`.
        let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) };
        cursor.remove()
    } else {
        None
    };

    let wake_list = if wake_next || waiting_for == WaitingFor::None {
        // Either the waiter was already woken or it's been removed from the condvar's waiter
        // list and is going to be woken. Either way, we need to wake up another thread.
        get_wake_list(waiters)
    } else {
        WaiterList::new(WaiterAdapter::new())
    };

    let set_on_release = if waiters.is_empty() {
        // SAFETY:
        // Clear the rwlock associated with this Condvar since there are no longer any waiters.
        // Safe because the spin lock guarantees exclusive access.
        unsafe { *self.mu.get() = 0 };

        0
    } else {
        HAS_WAITERS
    };

    // Release the spin lock.
    self.state.store(set_on_release, Ordering::Release);

    // Now wake any waiters still left in the wake list.
    for w in wake_list {
        w.wake();
    }

    // The spin lock has been released, so the removed waiter can be dropped now.
    mem::drop(old_waiter);
}
391 }
392
// TODO(b/315998194): Add safety comment
// NOTE(review): presumably sound because `waiters` and `mu` are only accessed while the spin
// lock bit in `state` is held, as the SAFETY comments inside `impl Condvar` state -- confirm
// before promoting this note to the official safety comment.
#[allow(clippy::undocumented_unsafe_blocks)]
unsafe impl Send for Condvar {}
// TODO(b/315998194): Add safety comment
// NOTE(review): presumably the same spin-lock argument as for `Send` applies here -- confirm.
#[allow(clippy::undocumented_unsafe_blocks)]
unsafe impl Sync for Condvar {}
399
400 impl Default for Condvar {
default() -> Self401 fn default() -> Self {
402 Self::new()
403 }
404 }
405
406 // Scan `waiters` and return all waiters that should be woken up.
407 //
408 // If the first waiter is trying to acquire a shared lock, then all waiters in the list that are
409 // waiting for a shared lock are also woken up. In addition one writer is woken up, if possible.
410 //
411 // If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and
412 // the rest of the list is not scanned.
get_wake_list(waiters: &mut WaiterList) -> WaiterList413 fn get_wake_list(waiters: &mut WaiterList) -> WaiterList {
414 let mut to_wake = WaiterList::new(WaiterAdapter::new());
415 let mut cursor = waiters.front_mut();
416
417 let mut waking_readers = false;
418 let mut all_readers = true;
419 while let Some(w) = cursor.get() {
420 match w.kind() {
421 WaiterKind::Exclusive if !waking_readers => {
422 // This is the first waiter and it's a writer. No need to check the other waiters.
423 // Also mark the waiter as having been removed from the Condvar's waiter list.
424 let waiter = cursor.remove().unwrap();
425 waiter.set_waiting_for(WaitingFor::None);
426 to_wake.push_back(waiter);
427 break;
428 }
429
430 WaiterKind::Shared => {
431 // This is a reader and the first waiter in the list was not a writer so wake up all
432 // the readers in the wait list.
433 let waiter = cursor.remove().unwrap();
434 waiter.set_waiting_for(WaitingFor::None);
435 to_wake.push_back(waiter);
436 waking_readers = true;
437 }
438
439 WaiterKind::Exclusive => {
440 debug_assert!(waking_readers);
441 if all_readers {
442 // We are waking readers but we need to ensure that at least one writer is woken
443 // up. Since we haven't yet woken up a writer, wake up this one.
444 let waiter = cursor.remove().unwrap();
445 waiter.set_waiting_for(WaitingFor::None);
446 to_wake.push_back(waiter);
447 all_readers = false;
448 } else {
449 // We are waking readers and have already woken one writer. Skip this one.
450 cursor.move_next();
451 }
452 }
453 }
454 }
455
456 to_wake
457 }
458
cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool)459 fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) {
460 let condvar = cv as *const Condvar;
461
462 // SAFETY:
463 // Safe because the thread that owns the waiter being canceled must also own a reference to the
464 // Condvar, which guarantees that this pointer is valid.
465 unsafe { (*condvar).cancel_waiter(waiter, wake_next) }
466 }
467
468 // TODO(b/194338842): Fix tests for windows
469 #[cfg(any(target_os = "android", target_os = "linux"))]
470 #[cfg(test)]
471 mod test {
472 use std::future::Future;
473 use std::mem;
474 use std::ptr;
475 use std::rc::Rc;
476 use std::sync::mpsc::channel;
477 use std::sync::mpsc::Sender;
478 use std::sync::Arc;
479 use std::task::Context;
480 use std::task::Poll;
481 use std::thread;
482 use std::thread::JoinHandle;
483 use std::time::Duration;
484
485 use futures::channel::oneshot;
486 use futures::select;
487 use futures::task::waker_ref;
488 use futures::task::ArcWake;
489 use futures::FutureExt;
490 use futures_executor::LocalPool;
491 use futures_executor::LocalSpawner;
492 use futures_executor::ThreadPool;
493 use futures_util::task::LocalSpawnExt;
494
495 use super::super::super::block_on;
496 use super::super::super::sync::RwLock;
497 use super::*;
498
499 // Dummy waker used when we want to manually drive futures.
500 struct TestWaker;
501 impl ArcWake for TestWaker {
wake_by_ref(_arc_self: &Arc<Self>)502 fn wake_by_ref(_arc_self: &Arc<Self>) {}
503 }
504
#[test]
fn smoke() {
    // Notifying a condvar that has no waiters must simply return.
    let condvar = Condvar::new();
    condvar.notify_one();
    condvar.notify_all();
}
511
#[test]
fn notify_one() {
    let mu = Arc::new(RwLock::new(()));
    let cv = Arc::new(Condvar::new());

    let (notifier_mu, notifier_cv) = (Arc::clone(&mu), Arc::clone(&cv));

    // Take the lock before starting the notifier so that it cannot notify until `wait` has
    // released the lock.
    let held = block_on(mu.lock());
    thread::spawn(move || {
        let _locked = block_on(notifier_mu.lock());
        notifier_cv.notify_one();
    });

    let reacquired = block_on(cv.wait(held));
    mem::drop(reacquired);
}
529
#[test]
fn multi_rwlock() {
    const NUM_THREADS: usize = 5;

    let mu = Arc::new(RwLock::new(false));
    let cv = Arc::new(Condvar::new());

    // Phase one: several threads wait on `cv` using `mu`.
    let mut workers = Vec::with_capacity(NUM_THREADS);
    for _ in 0..NUM_THREADS {
        let (mu, cv) = (Arc::clone(&mu), Arc::clone(&cv));

        let handle = thread::spawn(move || {
            let mut ready = block_on(mu.lock());
            while !*ready {
                ready = block_on(cv.wait(ready));
            }
        });
        workers.push(handle);
    }

    {
        let mut flag = block_on(mu.lock());
        *flag = true;
    }
    cv.notify_all();

    workers
        .into_iter()
        .try_for_each(JoinHandle::join)
        .expect("Failed to join threads");

    // Now use the Condvar with a different rwlock.
    let alt_mu = Arc::new(RwLock::new(None));
    let (waiter_mu, waiter_cv) = (Arc::clone(&alt_mu), Arc::clone(&cv));
    let alt_worker = thread::spawn(move || {
        let mut slot = block_on(waiter_mu.lock());
        while slot.is_none() {
            slot = block_on(waiter_cv.wait(slot));
        }
    });

    {
        let mut slot = block_on(alt_mu.lock());
        *slot = Some(());
    }
    cv.notify_all();

    alt_worker
        .join()
        .expect("Failed to join thread alternate rwlock");
}
580
#[test]
fn notify_one_single_thread_async() {
    // Takes the lock and then notifies a single waiter.
    async fn notify(mu: Rc<RwLock<()>>, cv: Rc<Condvar>) {
        let _held = mu.lock().await;
        cv.notify_one();
    }

    // Takes the lock, spawns the notifier and then waits on the condvar.
    async fn wait(mu: Rc<RwLock<()>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
        let notifier = notify(Rc::clone(&mu), Rc::clone(&cv));

        let held = mu.lock().await;
        // Has to be spawned _after_ acquiring the lock to prevent a race
        // where the notify happens before the waiter has acquired the lock.
        spawner
            .spawn_local(notifier)
            .expect("Failed to spawn `notify` task");
        let _reacquired = cv.wait(held).await;
    }

    let mut pool = LocalPool::new();
    let spawner = pool.spawner();

    let mu = Rc::new(RwLock::new(()));
    let cv = Rc::new(Condvar::new());

    let waiter = wait(mu, cv, spawner.clone());
    spawner
        .spawn_local(waiter)
        .expect("Failed to spawn `wait` task");

    pool.run();
}
613
#[test]
fn notify_one_multi_thread_async() {
    // Takes the lock and then notifies a single waiter.
    async fn notify(mu: Arc<RwLock<()>>, cv: Arc<Condvar>) {
        let _held = mu.lock().await;
        cv.notify_one();
    }

    // Takes the lock, spawns the notifier on the pool, waits on the condvar and finally reports
    // completion through `tx`.
    async fn wait(mu: Arc<RwLock<()>>, cv: Arc<Condvar>, tx: Sender<()>, pool: ThreadPool) {
        let notifier = notify(Arc::clone(&mu), Arc::clone(&cv));

        let held = mu.lock().await;
        // Has to be spawned _after_ acquiring the lock to prevent a race
        // where the notify happens before the waiter has acquired the lock.
        pool.spawn_ok(notifier);
        let _reacquired = cv.wait(held).await;

        tx.send(()).expect("Failed to send completion notification");
    }

    let pool = ThreadPool::new().expect("Failed to create ThreadPool");

    let mu = Arc::new(RwLock::new(()));
    let cv = Arc::new(Condvar::new());

    let (tx, rx) = channel();
    let waiter = wait(mu, cv, tx, pool.clone());
    pool.spawn_ok(waiter);

    rx.recv_timeout(Duration::from_secs(5))
        .expect("Failed to receive completion notification");
}
645
#[test]
fn notify_one_with_cancel() {
    const TASKS: usize = 17;
    const OBSERVERS: usize = 7;
    const ITERATIONS: usize = 103;

    // Reader: waits under a read lock until the counter is non-zero, then reads it once.
    async fn observe(mu: &Arc<RwLock<usize>>, cv: &Arc<Condvar>) {
        let mut value = mu.read_lock().await;
        while *value == 0 {
            value = cv.wait_read(value).await;
        }
        // SAFETY: Safe because count is valid and is byte aligned.
        let _ = unsafe { ptr::read_volatile(&*value as *const usize) };
    }

    // Writer: waits under a write lock until the counter is non-zero, then decrements it.
    async fn decrement(mu: &Arc<RwLock<usize>>, cv: &Arc<Condvar>) {
        let mut value = mu.lock().await;
        while *value == 0 {
            value = cv.wait(value).await;
        }
        *value -= 1;
    }

    // Producer: bumps the counter and notifies after every increment.
    async fn increment(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
        for _ in 0..TASKS * OBSERVERS * ITERATIONS {
            *mu.lock().await += 1;
            cv.notify_one();
        }

        done.send(()).expect("Failed to send completion message");
    }

    // Races an observer on each lock/condvar pair; the loser is dropped, which cancels its wait.
    async fn observe_either(
        mu: Arc<RwLock<usize>>,
        cv: Arc<Condvar>,
        alt_mu: Arc<RwLock<usize>>,
        alt_cv: Arc<Condvar>,
        done: Sender<()>,
    ) {
        for _ in 0..ITERATIONS {
            select! {
                () = observe(&mu, &cv).fuse() => {},
                () = observe(&alt_mu, &alt_cv).fuse() => {},
            }
        }

        done.send(()).expect("Failed to send completion message");
    }

    // Races a decrement on each lock/condvar pair; the loser is dropped, which cancels its wait.
    async fn decrement_either(
        mu: Arc<RwLock<usize>>,
        cv: Arc<Condvar>,
        alt_mu: Arc<RwLock<usize>>,
        alt_cv: Arc<Condvar>,
        done: Sender<()>,
    ) {
        for _ in 0..ITERATIONS {
            select! {
                () = decrement(&mu, &cv).fuse() => {},
                () = decrement(&alt_mu, &alt_cv).fuse() => {},
            }
        }

        done.send(()).expect("Failed to send completion message");
    }

    let pool = ThreadPool::new().expect("Failed to create ThreadPool");

    let mu = Arc::new(RwLock::new(0usize));
    let alt_mu = Arc::new(RwLock::new(0usize));

    let cv = Arc::new(Condvar::new());
    let alt_cv = Arc::new(Condvar::new());

    let (tx, rx) = channel();

    for _ in 0..TASKS {
        let task = decrement_either(
            Arc::clone(&mu),
            Arc::clone(&cv),
            Arc::clone(&alt_mu),
            Arc::clone(&alt_cv),
            tx.clone(),
        );
        pool.spawn_ok(task);
    }

    for _ in 0..OBSERVERS {
        let task = observe_either(
            Arc::clone(&mu),
            Arc::clone(&cv),
            Arc::clone(&alt_mu),
            Arc::clone(&alt_cv),
            tx.clone(),
        );
        pool.spawn_ok(task);
    }

    pool.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
    pool.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));

    // One completion message per decrementer, per observer and per incrementer.
    for _ in 0..TASKS + OBSERVERS + 2 {
        rx.recv_timeout(Duration::from_secs(20))
            .unwrap_or_else(|e| panic!("Error while waiting for threads to complete: {}", e));
    }

    // Two incrementers each add TASKS * OBSERVERS * ITERATIONS; every decrementer removes
    // ITERATIONS in total across the two counters.
    let remaining = *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock());
    assert_eq!(
        remaining,
        (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS)
    );
    assert_eq!(cv.state.load(Ordering::Relaxed), 0);
    assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
}
757
#[test]
fn notify_all_with_cancel() {
    const TASKS: usize = 17;
    const ITERATIONS: usize = 103;

    // Writer: waits under a write lock until the counter is non-zero, then decrements it.
    async fn decrement(mu: &Arc<RwLock<usize>>, cv: &Arc<Condvar>) {
        let mut value = mu.lock().await;
        while *value == 0 {
            value = cv.wait(value).await;
        }
        *value -= 1;
    }

    // Producer: bumps the counter and notifies everyone after every increment.
    async fn increment(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
        for _ in 0..TASKS * ITERATIONS {
            *mu.lock().await += 1;
            cv.notify_all();
        }

        done.send(()).expect("Failed to send completion message");
    }

    // Races a decrement on each lock/condvar pair; the loser is dropped, which cancels its wait.
    async fn decrement_either(
        mu: Arc<RwLock<usize>>,
        cv: Arc<Condvar>,
        alt_mu: Arc<RwLock<usize>>,
        alt_cv: Arc<Condvar>,
        done: Sender<()>,
    ) {
        for _ in 0..ITERATIONS {
            select! {
                () = decrement(&mu, &cv).fuse() => {},
                () = decrement(&alt_mu, &alt_cv).fuse() => {},
            }
        }

        done.send(()).expect("Failed to send completion message");
    }

    let pool = ThreadPool::new().expect("Failed to create ThreadPool");

    let mu = Arc::new(RwLock::new(0usize));
    let alt_mu = Arc::new(RwLock::new(0usize));

    let cv = Arc::new(Condvar::new());
    let alt_cv = Arc::new(Condvar::new());

    let (tx, rx) = channel();

    for _ in 0..TASKS {
        let task = decrement_either(
            Arc::clone(&mu),
            Arc::clone(&cv),
            Arc::clone(&alt_mu),
            Arc::clone(&alt_cv),
            tx.clone(),
        );
        pool.spawn_ok(task);
    }

    pool.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
    pool.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));

    // One completion message per decrementer and per incrementer.
    for _ in 0..TASKS + 2 {
        rx.recv_timeout(Duration::from_secs(10))
            .unwrap_or_else(|e| panic!("Error while waiting for threads to complete: {}", e));
    }

    // Two incrementers each add TASKS * ITERATIONS; the decrementers remove TASKS * ITERATIONS
    // in total across the two counters.
    let remaining = *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock());
    assert_eq!(remaining, TASKS * ITERATIONS);
    assert_eq!(cv.state.load(Ordering::Relaxed), 0);
    assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
}
#[test]
fn notify_all() {
    const THREADS: usize = 13;

    let mu = Arc::new(RwLock::new(0));
    let cv = Arc::new(Condvar::new());
    let (tx, rx) = channel();

    let mut workers = Vec::with_capacity(THREADS);
    for _ in 0..THREADS {
        let (mu, cv, started) = (Arc::clone(&mu), Arc::clone(&cv), tx.clone());

        let handle = thread::spawn(move || {
            let mut count = block_on(mu.lock());
            *count += 1;
            // The last thread to check in tells the main thread that everyone has started.
            if *count == THREADS {
                started.send(()).unwrap();
            }

            while *count != 0 {
                count = block_on(cv.wait(count));
            }
        });
        workers.push(handle);
    }

    mem::drop(tx);

    // Wait till all threads have started.
    rx.recv_timeout(Duration::from_secs(5)).unwrap();

    // Reset the counter, release the lock and only then wake every waiter.
    let mut count = block_on(mu.lock());
    *count = 0;
    mem::drop(count);
    cv.notify_all();

    for worker in workers {
        worker.join().unwrap();
    }
}
873
#[test]
fn notify_all_single_thread_async() {
    const TASKS: usize = 13;

    // Zeroes the counter and wakes every watcher while still holding the lock.
    async fn reset(mu: Rc<RwLock<usize>>, cv: Rc<Condvar>) {
        let mut value = mu.lock().await;
        *value = 0;
        cv.notify_all();
    }

    // Checks in by bumping the counter, then waits until the counter has been reset. The last
    // watcher to check in spawns the `reset` task.
    async fn watcher(mu: Rc<RwLock<usize>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
        let mut value = mu.lock().await;
        *value += 1;
        if *value == TASKS {
            let resetter = reset(Rc::clone(&mu), Rc::clone(&cv));
            spawner
                .spawn_local(resetter)
                .expect("Failed to spawn reset task");
        }

        while *value != 0 {
            value = cv.wait(value).await;
        }
    }

    let mut pool = LocalPool::new();
    let spawner = pool.spawner();

    let mu = Rc::new(RwLock::new(0));
    let cv = Rc::new(Condvar::new());

    for _ in 0..TASKS {
        let task = watcher(Rc::clone(&mu), Rc::clone(&cv), spawner.clone());
        spawner
            .spawn_local(task)
            .expect("Failed to spawn watcher task");
    }

    pool.run();
}
912
#[test]
fn notify_all_multi_thread_async() {
    const TASKS: usize = 13;

    // Zeroes the counter and wakes every watcher while still holding the lock.
    async fn reset(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
        let mut value = mu.lock().await;
        *value = 0;
        cv.notify_all();
    }

    // Checks in by bumping the counter, waits until the counter has been reset and then reports
    // completion through `tx`. The last watcher to check in spawns the `reset` task.
    async fn watcher(
        mu: Arc<RwLock<usize>>,
        cv: Arc<Condvar>,
        pool: ThreadPool,
        tx: Sender<()>,
    ) {
        let mut value = mu.lock().await;
        *value += 1;
        if *value == TASKS {
            let resetter = reset(Arc::clone(&mu), Arc::clone(&cv));
            pool.spawn_ok(resetter);
        }

        while *value != 0 {
            value = cv.wait(value).await;
        }

        tx.send(()).expect("Failed to send completion notification");
    }

    let pool = ThreadPool::new().expect("Failed to create ThreadPool");

    let mu = Arc::new(RwLock::new(0));
    let cv = Arc::new(Condvar::new());

    let (tx, rx) = channel();
    for _ in 0..TASKS {
        let task = watcher(Arc::clone(&mu), Arc::clone(&cv), pool.clone(), tx.clone());
        pool.spawn_ok(task);
    }

    // Every watcher must report completion within the timeout.
    for _ in 0..TASKS {
        rx.recv_timeout(Duration::from_secs(5))
            .expect("Failed to receive completion notification");
    }
}
957
#[test]
fn wake_all_readers() {
    // Waits under a read lock until the flag becomes true.
    async fn read(mu: Arc<RwLock<bool>>, cv: Arc<Condvar>) {
        let mut flag = mu.read_lock().await;
        while !*flag {
            flag = cv.wait_read(flag).await;
        }
    }

    let mu = Arc::new(RwLock::new(false));
    let cv = Arc::new(Condvar::new());
    let mut readers = [
        Box::pin(read(Arc::clone(&mu), Arc::clone(&cv))),
        Box::pin(read(Arc::clone(&mu), Arc::clone(&cv))),
        Box::pin(read(Arc::clone(&mu), Arc::clone(&cv))),
        Box::pin(read(Arc::clone(&mu), Arc::clone(&cv))),
    ];

    let arc_waker = Arc::new(TestWaker);
    let waker = waker_ref(&arc_waker);
    let mut cx = Context::from_waker(&waker);

    // First have all the readers wait on the Condvar.
    for reader in &mut readers {
        match reader.as_mut().poll(&mut cx) {
            Poll::Pending => {}
            Poll::Ready(()) => panic!("reader unexpectedly ready"),
        }
    }

    assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

    // Now make the condition true and notify the condvar. Even though we will call notify_one,
    // all the readers should be woken up.
    *block_on(mu.lock()) = true;
    cv.notify_one();

    assert_eq!(cv.state.load(Ordering::Relaxed), 0);

    // All readers should now be able to complete.
    for reader in &mut readers {
        match reader.as_mut().poll(&mut cx) {
            Poll::Ready(()) => {}
            Poll::Pending => panic!("reader unable to complete"),
        }
    }
}
1003
#[test]
fn cancel_before_notify() {
    // Waits until the counter is non-zero and then decrements it.
    async fn dec(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
        let mut value = mu.lock().await;

        while *value == 0 {
            value = cv.wait(value).await;
        }

        *value -= 1;
    }

    let mu = Arc::new(RwLock::new(0));
    let cv = Arc::new(Condvar::new());

    let arc_waker = Arc::new(TestWaker);
    let waker = waker_ref(&arc_waker);
    let mut cx = Context::from_waker(&waker);

    let mut fut1 = Box::pin(dec(Arc::clone(&mu), Arc::clone(&cv)));
    let mut fut2 = Box::pin(dec(Arc::clone(&mu), Arc::clone(&cv)));

    // Both futures must park on the condvar because the counter is still zero.
    match fut1.as_mut().poll(&mut cx) {
        Poll::Pending => {}
        Poll::Ready(()) => panic!("future unexpectedly ready"),
    }
    match fut2.as_mut().poll(&mut cx) {
        Poll::Pending => {}
        Poll::Ready(()) => panic!("future unexpectedly ready"),
    }
    assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

    *block_on(mu.lock()) = 2;
    // Drop fut1 before notifying the cv.
    mem::drop(fut1);
    cv.notify_one();

    // fut2 should now be ready to complete.
    assert_eq!(cv.state.load(Ordering::Relaxed), 0);

    match fut2.as_mut().poll(&mut cx) {
        Poll::Ready(()) => {}
        Poll::Pending => panic!("future unable to complete"),
    }

    assert_eq!(*block_on(mu.lock()), 1);
}
1048
#[test]
fn cancel_after_notify_one() {
    // Waits until the counter is non-zero and then decrements it.
    async fn dec(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
        let mut value = mu.lock().await;

        while *value == 0 {
            value = cv.wait(value).await;
        }

        *value -= 1;
    }

    let mu = Arc::new(RwLock::new(0));
    let cv = Arc::new(Condvar::new());

    let arc_waker = Arc::new(TestWaker);
    let waker = waker_ref(&arc_waker);
    let mut cx = Context::from_waker(&waker);

    let mut fut1 = Box::pin(dec(Arc::clone(&mu), Arc::clone(&cv)));
    let mut fut2 = Box::pin(dec(Arc::clone(&mu), Arc::clone(&cv)));

    // Both futures must park on the condvar because the counter is still zero.
    match fut1.as_mut().poll(&mut cx) {
        Poll::Pending => {}
        Poll::Ready(()) => panic!("future unexpectedly ready"),
    }
    match fut2.as_mut().poll(&mut cx) {
        Poll::Pending => {}
        Poll::Ready(()) => panic!("future unexpectedly ready"),
    }
    assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

    *block_on(mu.lock()) = 2;
    cv.notify_one();

    // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2.
    mem::drop(fut1);
    assert_eq!(cv.state.load(Ordering::Relaxed), 0);

    match fut2.as_mut().poll(&mut cx) {
        Poll::Ready(()) => {}
        Poll::Pending => panic!("future unable to complete"),
    }

    assert_eq!(*block_on(mu.lock()), 1);
}
1092
#[test]
fn cancel_after_notify_all() {
    // Waits until the counter is non-zero and then decrements it.
    async fn dec(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
        let mut value = mu.lock().await;

        while *value == 0 {
            value = cv.wait(value).await;
        }

        *value -= 1;
    }

    let mu = Arc::new(RwLock::new(0));
    let cv = Arc::new(Condvar::new());

    let arc_waker = Arc::new(TestWaker);
    let waker = waker_ref(&arc_waker);
    let mut cx = Context::from_waker(&waker);

    let mut fut1 = Box::pin(dec(Arc::clone(&mu), Arc::clone(&cv)));
    let mut fut2 = Box::pin(dec(Arc::clone(&mu), Arc::clone(&cv)));

    // Both futures must park on the condvar because the counter is still zero.
    match fut1.as_mut().poll(&mut cx) {
        Poll::Pending => {}
        Poll::Ready(()) => panic!("future unexpectedly ready"),
    }
    match fut2.as_mut().poll(&mut cx) {
        Poll::Pending => {}
        Poll::Ready(()) => panic!("future unexpectedly ready"),
    }
    assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

    let mut held = block_on(mu.lock());
    *held = 2;

    // Notify the cv while holding the lock. This should wake up both waiters.
    cv.notify_all();
    assert_eq!(cv.state.load(Ordering::Relaxed), 0);

    mem::drop(held);

    // Cancel the first (already notified) future; the second must still be able to finish.
    mem::drop(fut1);

    match fut2.as_mut().poll(&mut cx) {
        Poll::Ready(()) => {}
        Poll::Pending => panic!("future unable to complete"),
    }

    assert_eq!(*block_on(mu.lock()), 1);
}
1140
#[test]
fn timed_wait() {
    // Waits for the counter to become non-zero and then increments it, unless `timeout` fires
    // first, in which case the wait is abandoned and the counter is left untouched.
    async fn wait_deadline(
        mu: Arc<RwLock<usize>>,
        cv: Arc<Condvar>,
        timeout: oneshot::Receiver<()>,
    ) {
        let mut count = mu.lock().await;

        if *count == 0 {
            let mut rx = timeout.fuse();

            while *count == 0 {
                select! {
                    res = rx => {
                        if let Err(e) = res {
                            panic!("Error while receiving timeout notification: {}", e);
                        }

                        return;
                    },
                    c = cv.wait(count).fuse() => count = c,
                }
            }
        }

        *count += 1;
    }

    let mu = Arc::new(RwLock::new(0));
    let cv = Arc::new(Condvar::new());

    let arc_waker = Arc::new(TestWaker);
    let waker = waker_ref(&arc_waker);
    let mut cx = Context::from_waker(&waker);

    let (timeout_tx, timeout_rx) = oneshot::channel();
    let mut waiting = Box::pin(wait_deadline(Arc::clone(&mu), Arc::clone(&cv), timeout_rx));

    // The counter is zero, so the first poll must leave the future parked on the condvar.
    match waiting.as_mut().poll(&mut cx) {
        Poll::Pending => {}
        Poll::Ready(()) => panic!("wait_deadline unexpectedly ready"),
    }

    assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS);

    // Signal the channel, which should cancel the wait.
    timeout_tx.send(()).expect("Failed to send wakeup");

    // Wait for the timer to run out.
    match waiting.as_mut().poll(&mut cx) {
        Poll::Ready(()) => {}
        Poll::Pending => panic!("wait_deadline unable to complete in time"),
    }

    // The cancelled wait must have removed itself from the condvar and left the counter alone.
    assert_eq!(cv.state.load(Ordering::Relaxed), 0);
    assert_eq!(*block_on(mu.lock()), 0);
}
1197 }
1198