// There's a lot of scary concurrent code in this module, but it is copied from
// `std::sync::Once` with two changes:
//   * no poisoning
//   * init function can fail

use std::{
    cell::{Cell, UnsafeCell},
    panic::{RefUnwindSafe, UnwindSafe},
    sync::atomic::{AtomicBool, AtomicPtr, Ordering},
    thread::{self, Thread},
};

#[derive(Debug)]
pub(crate) struct OnceCell<T> {
    // This `queue` field is the core of the implementation. It encodes two
    // pieces of information:
    //
    // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
    // * Linked list of threads waiting for the current cell.
    //
    // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
    // allow waiters.
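    //
    // For example (a sketch of the encoding, not literal addresses): a `queue`
    // pointer whose low two bits are `0b01` is in the `RUNNING` state, and
    // clearing those bits yields the head of the waiter list, exactly as the
    // slow path below does:
    //
    //     let state = strict::addr(queue) & STATE_MASK;             // RUNNING
    //     let head = strict::map_addr(queue, |q| q & !STATE_MASK);  // *mut Waiter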
    queue: AtomicPtr<Waiter>,
    value: UnsafeCell<Option<T>>,
}

// Why do we need `T: Send`?
// Thread A creates a `OnceCell` and shares it with scoped thread B, which
// fills the cell, which is then destroyed by A. That is, the destructor
// observes a sent value.
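//
// A sketch of that scenario (`make_t` is a stand-in for any constructor of a
// `Send` value):
//
//     let cell: OnceCell<T> = OnceCell::new();
//     std::thread::scope(|s| {
//         // B fills the cell with a value it created.
//         s.spawn(|| { let _ = cell.initialize(|| Ok::<T, ()>(make_t())); });
//     });
//     // A drops the cell, and with it the value created on B: `T` must be `Send`.
//     drop(cell);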
unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
unsafe impl<T: Send> Send for OnceCell<T> {}

impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}

impl<T> OnceCell<T> {
    pub(crate) const fn new() -> OnceCell<T> {
        OnceCell { queue: AtomicPtr::new(INCOMPLETE_PTR), value: UnsafeCell::new(None) }
    }

    pub(crate) const fn with_value(value: T) -> OnceCell<T> {
        OnceCell { queue: AtomicPtr::new(COMPLETE_PTR), value: UnsafeCell::new(Some(value)) }
    }

    /// Safety: synchronizes with the store to `value` via a `Release`/`Acquire`
    /// pair of operations on `queue`.
    #[inline]
    pub(crate) fn is_initialized(&self) -> bool {
        // An `Acquire` load is enough because it makes all the initialization
        // operations visible to us, and, this being a fast path, the weaker
        // ordering helps with performance. This `Acquire` synchronizes with
        // the `Release`/`AcqRel` stores on the slow path.
        self.queue.load(Ordering::Acquire) == COMPLETE_PTR
    }

    /// Safety: synchronizes with the store to `value` via the `Release`/`Acquire`
    /// operations on `queue`; writes the value only once because we never go back
    /// to the `INCOMPLETE` state after a successful write.
    #[cold]
    pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let mut f = Some(f);
        let mut res: Result<(), E> = Ok(());
        let slot: *mut Option<T> = self.value.get();
        initialize_or_wait(
            &self.queue,
            Some(&mut || {
                let f = unsafe { f.take().unwrap_unchecked() };
                match f() {
                    Ok(value) => {
                        unsafe { *slot = Some(value) };
                        true
                    }
                    Err(err) => {
                        res = Err(err);
                        false
                    }
                }
            }),
        );
        res
    }

    #[cold]
    pub(crate) fn wait(&self) {
        initialize_or_wait(&self.queue, None);
    }

    /// Get the reference to the underlying value, without checking if the cell
    /// is initialized.
    ///
    /// # Safety
    ///
    /// Caller must ensure that the cell is in initialized state, and that
    /// the contents are acquired by (synchronized to) this thread.
    pub(crate) unsafe fn get_unchecked(&self) -> &T {
        debug_assert!(self.is_initialized());
        let slot = &*self.value.get();
        slot.as_ref().unwrap_unchecked()
    }

    /// Gets the mutable reference to the underlying value.
    /// Returns `None` if the cell is empty.
    pub(crate) fn get_mut(&mut self) -> Option<&mut T> {
        // Safe because we have unique access.
        unsafe { &mut *self.value.get() }.as_mut()
    }

    /// Consumes this `OnceCell`, returning the wrapped value.
    /// Returns `None` if the cell was empty.
    #[inline]
    pub(crate) fn into_inner(self) -> Option<T> {
        // Because `into_inner` takes `self` by value, the compiler statically
        // verifies that it is not currently borrowed.
        // So, it is safe to move out `Option<T>`.
        self.value.into_inner()
    }
}

// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in
// the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;
const INCOMPLETE_PTR: *mut Waiter = INCOMPLETE as *mut Waiter;
const COMPLETE_PTR: *mut Waiter = COMPLETE as *mut Waiter;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

/// Representation of a node in the linked list of waiters in the RUNNING state.
/// A waiter is stored on the stack of the waiting thread.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
    thread: Cell<Option<Thread>>,
    signaled: AtomicBool,
    next: *mut Waiter,
}

/// Drains and notifies the queue of waiters on drop.
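/// If the initializer panics, the guard is dropped with `new_queue` still set to
/// `INCOMPLETE_PTR`, so the cell goes back to the `INCOMPLETE` state (no
/// poisoning) and any parked waiters are woken up to retry.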
struct Guard<'a> {
    queue: &'a AtomicPtr<Waiter>,
    new_queue: *mut Waiter,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        // Publish the final state and, at the same time, take ownership of the
        // whole waiter list.
        let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

        let state = strict::addr(queue) & STATE_MASK;
        assert_eq!(state, RUNNING);

        unsafe {
            // Walk the list, waking every waiter. `next` must be read before
            // `signaled` is set: once a waiter observes `signaled == true`, it
            // may return from `wait` and its stack node may be gone.
            let mut waiter = strict::map_addr(queue, |q| q & !STATE_MASK);
            while !waiter.is_null() {
                let next = (*waiter).next;
                let thread = (*waiter).thread.take().unwrap();
                (*waiter).signaled.store(true, Ordering::Release);
                waiter = next;
                thread.unpark();
            }
        }
    }
}

// Corresponds to `std::sync::Once::call_inner`.
//
// Originally copied from std, but since modified to remove poisoning and to
// support wait.
//
// Note: this is intentionally monomorphic: the initializer is passed as a
// `&mut dyn FnMut` rather than a generic closure, so the function body is
// compiled only once.
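//
// A sketch of the state transitions driven by this loop (low two bits of `queue`):
//
//     INCOMPLETE --CAS--> RUNNING --init() returns true----------> COMPLETE
//                         RUNNING --init() returns false/panics--> INCOMPLETE
//
// Threads that cannot make progress (state is RUNNING, or INCOMPLETE with no
// init closure) park themselves in `wait` and are unparked by `Guard::drop`.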
#[inline(never)]
fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) {
    let mut curr_queue = queue.load(Ordering::Acquire);

    loop {
        let curr_state = strict::addr(curr_queue) & STATE_MASK;
        match (curr_state, &mut init) {
            (COMPLETE, _) => return,
            (INCOMPLETE, Some(init)) => {
                let exchange = queue.compare_exchange(
                    curr_queue,
                    strict::map_addr(curr_queue, |q| (q & !STATE_MASK) | RUNNING),
                    Ordering::Acquire,
                    Ordering::Acquire,
                );
                if let Err(new_queue) = exchange {
                    curr_queue = new_queue;
                    continue;
                }
                let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR };
                if init() {
                    guard.new_queue = COMPLETE_PTR;
                }
                return;
            }
            (INCOMPLETE, None) | (RUNNING, _) => {
                wait(queue, curr_queue);
                curr_queue = queue.load(Ordering::Acquire);
            }
            _ => debug_assert!(false),
        }
    }
}

fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) {
    let curr_state = strict::addr(curr_queue) & STATE_MASK;
    loop {
        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: strict::map_addr(curr_queue, |q| q & !STATE_MASK),
        };
        let me = &node as *const Waiter as *mut Waiter;

        let exchange = queue.compare_exchange(
            curr_queue,
            strict::map_addr(me, |q| q | curr_state),
            Ordering::Release,
            Ordering::Relaxed,
        );
        if let Err(new_queue) = exchange {
            // The state bits changed under us (e.g. initialization finished), so
            // there is nothing to wait for anymore; let the caller re-check.
            if strict::addr(new_queue) & STATE_MASK != curr_state {
                return;
            }
            curr_queue = new_queue;
            continue;
        }

        // `park` can wake up spuriously, so loop until the `Guard` signals us.
        while !node.signaled.load(Ordering::Acquire) {
            thread::park();
        }
        break;
    }
}

// Polyfill of strict provenance from https://crates.io/crates/sptr.
//
// Use free-standing functions rather than a trait to keep things simple and to
// avoid any potential conflicts with a future stable std API.
mod strict {
    #[must_use]
    #[inline]
    pub(crate) fn addr<T>(ptr: *mut T) -> usize
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        // SAFETY: Pointer-to-integer transmutes are valid (if you are okay with losing the
        // provenance).
        unsafe { core::mem::transmute(ptr) }
    }

    #[must_use]
    #[inline]
    pub(crate) fn with_addr<T>(ptr: *mut T, addr: usize) -> *mut T
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        //
        // In the meantime, this operation is defined to be "as if" it was
        // a wrapping_offset, so we can emulate it as such. This should properly
        // restore pointer provenance even under today's compiler.
        let self_addr = self::addr(ptr) as isize;
        let dest_addr = addr as isize;
        let offset = dest_addr.wrapping_sub(self_addr);

        // This is the canonical desugaring of this operation,
        // but `pointer::cast` was only stabilized in 1.38.
        // self.cast::<u8>().wrapping_offset(offset).cast::<T>()
        (ptr as *mut u8).wrapping_offset(offset) as *mut T
    }

    #[must_use]
    #[inline]
    pub(crate) fn map_addr<T>(ptr: *mut T, f: impl FnOnce(usize) -> usize) -> *mut T
    where
        T: Sized,
    {
        self::with_addr(ptr, f(addr(ptr)))
    }
}

// These tests are snatched from std as well.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }
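
    #[test]
    fn fallible_init_does_not_poison() {
        // A small sketch of the "init function can fail" point from the module
        // docs: a failed initializer leaves the cell empty, and a later
        // initializer can still succeed.
        let cell: OnceCell<u32> = OnceCell::new();

        // The first attempt fails and the cell stays uninitialized.
        assert!(cell.initialize(|| Err::<u32, ()>(())).is_err());
        assert!(!cell.is_initialized());

        // A later attempt succeeds and the value becomes observable.
        assert!(cell.initialize(|| Ok::<u32, ()>(92)).is_ok());
        assert!(cell.is_initialized());
        assert_eq!(unsafe { *cell.get_unchecked() }, 92);
    }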

    #[test]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // we can subvert poisoning, however
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // once any success happens, we stop propagating the poison
        O.init(|| {});
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // make sure someone's waiting inside the once via a force
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // put another waiter on the once
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
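
    #[test]
    fn waiter_alignment_leaves_state_bits_free() {
        // Sanity sketch for the pointer-tagging scheme: `Waiter` is aligned to at
        // least 4 bytes, so the two low bits of a waiter pointer are free to hold
        // the INCOMPLETE/RUNNING/COMPLETE state, and tagging is reversible.
        use super::{strict, Waiter, RUNNING, STATE_MASK};

        assert!(std::mem::align_of::<Waiter>() >= 4);

        // Tag a (never dereferenced) pointer with the RUNNING bit and strip it
        // again; the original address must round-trip.
        let ptr: *mut Waiter = std::ptr::null_mut();
        let tagged = strict::map_addr(ptr, |a| (a & !STATE_MASK) | RUNNING);
        assert_eq!(strict::addr(tagged) & STATE_MASK, RUNNING);
        assert_eq!(strict::addr(strict::map_addr(tagged, |a| a & !STATE_MASK)), strict::addr(ptr));
    }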
}