1 use std::cmp;
2 use std::fmt;
3 use std::mem;
4 
5 use crate::errors::InvalidThreadAccess;
6 use crate::fragile::Fragile;
7 use crate::sticky::Sticky;
8 use crate::StackToken;
9 
/// Internal storage for a [`SemiSticky`].
///
/// [`SemiSticky::new`] selects the variant with `mem::needs_drop::<T>()`.
enum SemiStickyImpl<T: 'static> {
    /// Chosen when `T` does not need to be dropped; the [`Fragile`] is boxed.
    Fragile(Box<Fragile<T>>),
    /// Chosen when `T` needs to be dropped.
    Sticky(Sticky<T>),
}
14 
/// A [`SemiSticky<T>`] keeps a value `T` stored in a thread if `T` needs to
/// be dropped (as reported by [`std::mem::needs_drop`]).
///
/// This is a combined version of [`Fragile`] and [`Sticky`].  If the type
/// does not need to be dropped it will effectively be a [`Fragile`],
/// otherwise it will internally behave like a [`Sticky`].
///
/// This type requires `T: 'static` for the same reasons as [`Sticky`] and
/// also uses [`StackToken`]s.
pub struct SemiSticky<T: 'static> {
    // Either a boxed `Fragile` or a `Sticky`, selected once in `SemiSticky::new`.
    inner: SemiStickyImpl<T>,
}
26 
27 impl<T> SemiSticky<T> {
28     /// Creates a new [`SemiSticky`] wrapping a `value`.
29     ///
30     /// The value that is moved into the `SemiSticky` can be non `Send` and
31     /// will be anchored to the thread that created the object.  If the
32     /// sticky wrapper type ends up being send from thread to thread
33     /// only the original thread can interact with the value.  In case the
34     /// value does not have `Drop` it will be stored in the [`SemiSticky`]
35     /// instead.
new(value: T) -> Self36     pub fn new(value: T) -> Self {
37         SemiSticky {
38             inner: if mem::needs_drop::<T>() {
39                 SemiStickyImpl::Sticky(Sticky::new(value))
40             } else {
41                 SemiStickyImpl::Fragile(Box::new(Fragile::new(value)))
42             },
43         }
44     }
45 
46     /// Returns `true` if the access is valid.
47     ///
48     /// This will be `false` if the value was sent to another thread.
is_valid(&self) -> bool49     pub fn is_valid(&self) -> bool {
50         match self.inner {
51             SemiStickyImpl::Fragile(ref inner) => inner.is_valid(),
52             SemiStickyImpl::Sticky(ref inner) => inner.is_valid(),
53         }
54     }
55 
56     /// Consumes the [`SemiSticky`], returning the wrapped value.
57     ///
58     /// # Panics
59     ///
60     /// Panics if called from a different thread than the one where the
61     /// original value was created.
into_inner(self) -> T62     pub fn into_inner(self) -> T {
63         match self.inner {
64             SemiStickyImpl::Fragile(inner) => inner.into_inner(),
65             SemiStickyImpl::Sticky(inner) => inner.into_inner(),
66         }
67     }
68 
69     /// Consumes the [`SemiSticky`], returning the wrapped value if successful.
70     ///
71     /// The wrapped value is returned if this is called from the same thread
72     /// as the one where the original value was created, otherwise the
73     /// [`SemiSticky`] is returned as `Err(self)`.
try_into_inner(self) -> Result<T, Self>74     pub fn try_into_inner(self) -> Result<T, Self> {
75         match self.inner {
76             SemiStickyImpl::Fragile(inner) => inner.try_into_inner().map_err(|inner| SemiSticky {
77                 inner: SemiStickyImpl::Fragile(Box::new(inner)),
78             }),
79             SemiStickyImpl::Sticky(inner) => inner.try_into_inner().map_err(|inner| SemiSticky {
80                 inner: SemiStickyImpl::Sticky(inner),
81             }),
82         }
83     }
84 
85     /// Immutably borrows the wrapped value.
86     ///
87     /// # Panics
88     ///
89     /// Panics if the calling thread is not the one that wrapped the value.
90     /// For a non-panicking variant, use [`try_get`](Self::try_get).
get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T91     pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T {
92         match self.inner {
93             SemiStickyImpl::Fragile(ref inner) => inner.get(),
94             SemiStickyImpl::Sticky(ref inner) => inner.get(_proof),
95         }
96     }
97 
98     /// Mutably borrows the wrapped value.
99     ///
100     /// # Panics
101     ///
102     /// Panics if the calling thread is not the one that wrapped the value.
103     /// For a non-panicking variant, use [`try_get_mut`](Self::try_get_mut).
get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T104     pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T {
105         match self.inner {
106             SemiStickyImpl::Fragile(ref mut inner) => inner.get_mut(),
107             SemiStickyImpl::Sticky(ref mut inner) => inner.get_mut(_proof),
108         }
109     }
110 
111     /// Tries to immutably borrow the wrapped value.
112     ///
113     /// Returns `None` if the calling thread is not the one that wrapped the value.
try_get<'stack>( &'stack self, _proof: &'stack StackToken, ) -> Result<&'stack T, InvalidThreadAccess>114     pub fn try_get<'stack>(
115         &'stack self,
116         _proof: &'stack StackToken,
117     ) -> Result<&'stack T, InvalidThreadAccess> {
118         match self.inner {
119             SemiStickyImpl::Fragile(ref inner) => inner.try_get(),
120             SemiStickyImpl::Sticky(ref inner) => inner.try_get(_proof),
121         }
122     }
123 
124     /// Tries to mutably borrow the wrapped value.
125     ///
126     /// Returns `None` if the calling thread is not the one that wrapped the value.
try_get_mut<'stack>( &'stack mut self, _proof: &'stack StackToken, ) -> Result<&'stack mut T, InvalidThreadAccess>127     pub fn try_get_mut<'stack>(
128         &'stack mut self,
129         _proof: &'stack StackToken,
130     ) -> Result<&'stack mut T, InvalidThreadAccess> {
131         match self.inner {
132             SemiStickyImpl::Fragile(ref mut inner) => inner.try_get_mut(),
133             SemiStickyImpl::Sticky(ref mut inner) => inner.try_get_mut(_proof),
134         }
135     }
136 }
137 
138 impl<T> From<T> for SemiSticky<T> {
139     #[inline]
from(t: T) -> SemiSticky<T>140     fn from(t: T) -> SemiSticky<T> {
141         SemiSticky::new(t)
142     }
143 }
144 
145 impl<T: Clone> Clone for SemiSticky<T> {
146     #[inline]
clone(&self) -> SemiSticky<T>147     fn clone(&self) -> SemiSticky<T> {
148         crate::stack_token!(tok);
149         SemiSticky::new(self.get(tok).clone())
150     }
151 }
152 
153 impl<T: Default> Default for SemiSticky<T> {
154     #[inline]
default() -> SemiSticky<T>155     fn default() -> SemiSticky<T> {
156         SemiSticky::new(T::default())
157     }
158 }
159 
160 impl<T: PartialEq> PartialEq for SemiSticky<T> {
161     #[inline]
eq(&self, other: &SemiSticky<T>) -> bool162     fn eq(&self, other: &SemiSticky<T>) -> bool {
163         crate::stack_token!(tok);
164         *self.get(tok) == *other.get(tok)
165     }
166 }
167 
168 impl<T: Eq> Eq for SemiSticky<T> {}
169 
170 impl<T: PartialOrd> PartialOrd for SemiSticky<T> {
171     #[inline]
partial_cmp(&self, other: &SemiSticky<T>) -> Option<cmp::Ordering>172     fn partial_cmp(&self, other: &SemiSticky<T>) -> Option<cmp::Ordering> {
173         crate::stack_token!(tok);
174         self.get(tok).partial_cmp(other.get(tok))
175     }
176 
177     #[inline]
lt(&self, other: &SemiSticky<T>) -> bool178     fn lt(&self, other: &SemiSticky<T>) -> bool {
179         crate::stack_token!(tok);
180         *self.get(tok) < *other.get(tok)
181     }
182 
183     #[inline]
le(&self, other: &SemiSticky<T>) -> bool184     fn le(&self, other: &SemiSticky<T>) -> bool {
185         crate::stack_token!(tok);
186         *self.get(tok) <= *other.get(tok)
187     }
188 
189     #[inline]
gt(&self, other: &SemiSticky<T>) -> bool190     fn gt(&self, other: &SemiSticky<T>) -> bool {
191         crate::stack_token!(tok);
192         *self.get(tok) > *other.get(tok)
193     }
194 
195     #[inline]
ge(&self, other: &SemiSticky<T>) -> bool196     fn ge(&self, other: &SemiSticky<T>) -> bool {
197         crate::stack_token!(tok);
198         *self.get(tok) >= *other.get(tok)
199     }
200 }
201 
202 impl<T: Ord> Ord for SemiSticky<T> {
203     #[inline]
cmp(&self, other: &SemiSticky<T>) -> cmp::Ordering204     fn cmp(&self, other: &SemiSticky<T>) -> cmp::Ordering {
205         crate::stack_token!(tok);
206         self.get(tok).cmp(other.get(tok))
207     }
208 }
209 
210 impl<T: fmt::Display> fmt::Display for SemiSticky<T> {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>211     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
212         crate::stack_token!(tok);
213         fmt::Display::fmt(self.get(tok), f)
214     }
215 }
216 
217 impl<T: fmt::Debug> fmt::Debug for SemiSticky<T> {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>218     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
219         crate::stack_token!(tok);
220         match self.try_get(tok) {
221             Ok(value) => f.debug_struct("SemiSticky").field("value", value).finish(),
222             Err(..) => {
223                 struct InvalidPlaceholder;
224                 impl fmt::Debug for InvalidPlaceholder {
225                     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
226                         f.write_str("<invalid thread>")
227                     }
228                 }
229 
230                 f.debug_struct("SemiSticky")
231                     .field("value", &InvalidPlaceholder)
232                     .finish()
233             }
234         }
235     }
236 }
237 
#[test]
fn test_basic() {
    use std::thread;

    let wrapped = SemiSticky::new(true);
    crate::stack_token!(tok);
    assert_eq!(wrapped.to_string(), "true");
    assert_eq!(wrapped.get(tok), &true);
    assert!(wrapped.try_get(tok).is_ok());

    let worker = thread::spawn(move || {
        // The value was wrapped on the test thread, so access fails here.
        crate::stack_token!(tok);
        assert!(wrapped.try_get(tok).is_err());
    });
    worker.join().unwrap();
}
253 
#[test]
fn test_mut() {
    let mut wrapped = SemiSticky::new(true);
    crate::stack_token!(tok);

    let slot = wrapped.get_mut(tok);
    *slot = false;

    assert_eq!(wrapped.to_string(), "false");
    assert_eq!(wrapped.get(tok), &false);
}
262 
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let foreign = SemiSticky::new(true);
    let worker = thread::spawn(move || {
        crate::stack_token!(tok);
        // `get` panics because `foreign` was wrapped on the spawning thread.
        foreign.get(tok);
    });
    // Propagate the worker's panic to the test thread.
    worker.join().unwrap();
}
275 
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // Records that its destructor ran.
    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let wrapped = SemiSticky::new(X(Arc::clone(&was_called)));
    drop(wrapped);
    assert!(was_called.load(Ordering::SeqCst));
}
291 
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    // Records that its destructor ran.
    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));

    let flag = Arc::clone(&was_called);
    thread::spawn(move || {
        let wrapped = SemiSticky::new(X(Arc::clone(&flag)));

        let mover = thread::spawn(move || {
            // moves it here but do not deallocate
            crate::stack_token!(tok);
            wrapped.try_get(tok).ok();
        });
        assert!(mover.join().is_ok());

        // Letting the value go on a foreign thread must not have run `X::drop`.
        assert!(!flag.load(Ordering::SeqCst));
    })
    .join()
    .unwrap();

    // After the owning thread has finished, the destructor has run.
    assert!(was_called.load(Ordering::SeqCst));
}
327 
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;

    // `Rc` is not `Send`, yet the wrapper itself can be moved to another thread.
    let shared = SemiSticky::new(Rc::new(true));
    let worker = thread::spawn(move || {
        crate::stack_token!(tok);
        assert!(shared.try_get(tok).is_err());
    });
    worker.join().unwrap();
}
340