1 use std::sync::atomic::{AtomicPtr, Ordering};
2 use std::sync::RwLock;
3 
4 use super::sealed::{CaS, InnerStrategy, Protected};
5 use crate::as_raw::AsRaw;
6 use crate::ref_cnt::RefCnt;
7 
8 impl<T: RefCnt> Protected<T> for T {
9     #[inline]
from_inner(ptr: T) -> Self10     fn from_inner(ptr: T) -> Self {
11         ptr
12     }
13 
14     #[inline]
into_inner(self) -> T15     fn into_inner(self) -> T {
16         self
17     }
18 }
19 
20 impl<T: RefCnt> InnerStrategy<T> for RwLock<()> {
21     type Protected = T;
load(&self, storage: &AtomicPtr<T::Base>) -> T22     unsafe fn load(&self, storage: &AtomicPtr<T::Base>) -> T {
23         let _guard = self.read().expect("We don't panic in here");
24         let ptr = storage.load(Ordering::Acquire);
25         let ptr = T::from_ptr(ptr as *const T::Base);
26         T::inc(&ptr);
27 
28         ptr
29     }
30 
wait_for_readers(&self, _: *const T::Base, _: &AtomicPtr<T::Base>)31     unsafe fn wait_for_readers(&self, _: *const T::Base, _: &AtomicPtr<T::Base>) {
32         // By acquiring the write lock, we make sure there are no read locks present across it.
33         drop(self.write().expect("We don't panic in here"));
34     }
35 }
36 
37 impl<T: RefCnt> CaS<T> for RwLock<()> {
compare_and_swap<C: AsRaw<T::Base>>( &self, storage: &AtomicPtr<T::Base>, current: C, new: T, ) -> Self::Protected38     unsafe fn compare_and_swap<C: AsRaw<T::Base>>(
39         &self,
40         storage: &AtomicPtr<T::Base>,
41         current: C,
42         new: T,
43     ) -> Self::Protected {
44         let _lock = self.write();
45         let cur = current.as_raw() as *mut T::Base;
46         let new = T::into_ptr(new);
47         let swapped = storage.compare_exchange(cur, new, Ordering::AcqRel, Ordering::Relaxed);
48         let old = match swapped {
49             Ok(old) => old,
50             Err(old) => old,
51         };
52         let old = T::from_ptr(old as *const T::Base);
53         if swapped.is_err() {
54             // If the new didn't go in, we need to destroy it and increment count in the old that
55             // we just duplicated
56             T::inc(&old);
57             drop(T::from_ptr(new));
58         }
59         drop(current);
60         old
61     }
62 }
63