//! Task-based helpers for polling futures and streams on a mock task that records wake notifications.
2 
3 #![allow(clippy::mutex_atomic)]
4 
5 use std::future::Future;
6 use std::mem;
7 use std::ops;
8 use std::pin::Pin;
9 use std::sync::{Arc, Condvar, Mutex};
10 use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
11 
12 use tokio_stream::Stream;
13 
14 /// TODO: dox
spawn<T>(task: T) -> Spawn<T>15 pub fn spawn<T>(task: T) -> Spawn<T> {
16     Spawn {
17         task: MockTask::new(),
18         future: Box::pin(task),
19     }
20 }
21 
/// Future spawned on a mock task.
///
/// Owns a pinned, heap-allocated value (typically a future or a stream) and a
/// mock task whose waker records wake notifications. Created by [`spawn`].
#[derive(Debug)]
pub struct Spawn<T> {
    /// Mock task whose waker is handed to the wrapped value when it is polled
    /// through `enter`, `poll`, or `poll_next`.
    task: MockTask,
    /// The wrapped value, pinned on the heap so it can be polled in place.
    future: Pin<Box<T>>,
}
28 
/// Mock task
///
/// A mock task is able to intercept and track wake notifications.
///
/// Clones share the same `ThreadWaker` through the `Arc`, so a wake recorded
/// via one clone is visible through all of them.
#[derive(Debug, Clone)]
struct MockTask {
    /// Shared wake-tracking state. Every `Waker` produced by this task holds a
    /// strong reference to it.
    waker: Arc<ThreadWaker>,
}
36 
/// Wake-tracking state shared between a `MockTask` and the `Waker`s it hands out.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state: one of `IDLE`, `WAKE`, or `SLEEP`.
    state: Mutex<usize>,
    /// Notified by `wake` when the previous state was `SLEEP`.
    // NOTE(review): nothing in this file stores `SLEEP` or waits on this
    // condvar; confirm whether it is still needed.
    condvar: Condvar,
}
42 
/// `ThreadWaker` state: no wake has been recorded since the last `clear`.
const IDLE: usize = 0;
/// `ThreadWaker` state: a wake has been recorded since the last `clear`.
const WAKE: usize = 1;
/// `ThreadWaker` state: the other side is presumed to be blocked on the
/// condvar, so `wake` must notify it.
// NOTE(review): no code in this file ever stores `SLEEP`.
const SLEEP: usize = 2;
46 
47 impl<T> Spawn<T> {
48     /// Consumes `self` returning the inner value
into_inner(self) -> T where T: Unpin,49     pub fn into_inner(self) -> T
50     where
51         T: Unpin,
52     {
53         *Pin::into_inner(self.future)
54     }
55 
56     /// Returns `true` if the inner future has received a wake notification
57     /// since the last call to `enter`.
is_woken(&self) -> bool58     pub fn is_woken(&self) -> bool {
59         self.task.is_woken()
60     }
61 
62     /// Returns the number of references to the task waker
63     ///
64     /// The task itself holds a reference. The return value will never be zero.
waker_ref_count(&self) -> usize65     pub fn waker_ref_count(&self) -> usize {
66         self.task.waker_ref_count()
67     }
68 
69     /// Enter the task context
enter<F, R>(&mut self, f: F) -> R where F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,70     pub fn enter<F, R>(&mut self, f: F) -> R
71     where
72         F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
73     {
74         let fut = self.future.as_mut();
75         self.task.enter(|cx| f(cx, fut))
76     }
77 }
78 
79 impl<T: Unpin> ops::Deref for Spawn<T> {
80     type Target = T;
81 
deref(&self) -> &T82     fn deref(&self) -> &T {
83         &self.future
84     }
85 }
86 
87 impl<T: Unpin> ops::DerefMut for Spawn<T> {
deref_mut(&mut self) -> &mut T88     fn deref_mut(&mut self) -> &mut T {
89         &mut self.future
90     }
91 }
92 
93 impl<T: Future> Spawn<T> {
94     /// Polls a future
poll(&mut self) -> Poll<T::Output>95     pub fn poll(&mut self) -> Poll<T::Output> {
96         let fut = self.future.as_mut();
97         self.task.enter(|cx| fut.poll(cx))
98     }
99 }
100 
101 impl<T: Stream> Spawn<T> {
102     /// Polls a stream
poll_next(&mut self) -> Poll<Option<T::Item>>103     pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
104         let stream = self.future.as_mut();
105         self.task.enter(|cx| stream.poll_next(cx))
106     }
107 }
108 
109 impl<T: Future> Future for Spawn<T> {
110     type Output = T::Output;
111 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>112     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
113         self.future.as_mut().poll(cx)
114     }
115 }
116 
117 impl<T: Stream> Stream for Spawn<T> {
118     type Item = T::Item;
119 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>120     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
121         self.future.as_mut().poll_next(cx)
122     }
123 }
124 
125 impl MockTask {
126     /// Creates new mock task
new() -> Self127     fn new() -> Self {
128         MockTask {
129             waker: Arc::new(ThreadWaker::new()),
130         }
131     }
132 
133     /// Runs a closure from the context of the task.
134     ///
135     /// Any wake notifications resulting from the execution of the closure are
136     /// tracked.
enter<F, R>(&mut self, f: F) -> R where F: FnOnce(&mut Context<'_>) -> R,137     fn enter<F, R>(&mut self, f: F) -> R
138     where
139         F: FnOnce(&mut Context<'_>) -> R,
140     {
141         self.waker.clear();
142         let waker = self.waker();
143         let mut cx = Context::from_waker(&waker);
144 
145         f(&mut cx)
146     }
147 
148     /// Returns `true` if the inner future has received a wake notification
149     /// since the last call to `enter`.
is_woken(&self) -> bool150     fn is_woken(&self) -> bool {
151         self.waker.is_woken()
152     }
153 
154     /// Returns the number of references to the task waker
155     ///
156     /// The task itself holds a reference. The return value will never be zero.
waker_ref_count(&self) -> usize157     fn waker_ref_count(&self) -> usize {
158         Arc::strong_count(&self.waker)
159     }
160 
waker(&self) -> Waker161     fn waker(&self) -> Waker {
162         unsafe {
163             let raw = to_raw(self.waker.clone());
164             Waker::from_raw(raw)
165         }
166     }
167 }
168 
169 impl Default for MockTask {
default() -> Self170     fn default() -> Self {
171         Self::new()
172     }
173 }
174 
175 impl ThreadWaker {
new() -> Self176     fn new() -> Self {
177         ThreadWaker {
178             state: Mutex::new(IDLE),
179             condvar: Condvar::new(),
180         }
181     }
182 
183     /// Clears any previously received wakes, avoiding potential spurrious
184     /// wake notifications. This should only be called immediately before running the
185     /// task.
clear(&self)186     fn clear(&self) {
187         *self.state.lock().unwrap() = IDLE;
188     }
189 
is_woken(&self) -> bool190     fn is_woken(&self) -> bool {
191         match *self.state.lock().unwrap() {
192             IDLE => false,
193             WAKE => true,
194             _ => unreachable!(),
195         }
196     }
197 
wake(&self)198     fn wake(&self) {
199         // First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
200         let mut state = self.state.lock().unwrap();
201         let prev = *state;
202 
203         if prev == WAKE {
204             return;
205         }
206 
207         *state = WAKE;
208 
209         if prev == IDLE {
210             return;
211         }
212 
213         // The other half is sleeping, so we wake it up.
214         assert_eq!(prev, SLEEP);
215         self.condvar.notify_one();
216     }
217 }
218 
/// Vtable behind every `Waker` built by `MockTask::waker`.
///
/// Each entry interprets the raw data pointer as an `Arc<ThreadWaker>`
/// produced by `Arc::into_raw` (see `to_raw` / `from_raw`).
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
220 
to_raw(waker: Arc<ThreadWaker>) -> RawWaker221 unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
222     RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
223 }
224 
from_raw(raw: *const ()) -> Arc<ThreadWaker>225 unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
226     Arc::from_raw(raw as *const ThreadWaker)
227 }
228 
clone(raw: *const ()) -> RawWaker229 unsafe fn clone(raw: *const ()) -> RawWaker {
230     let waker = from_raw(raw);
231 
232     // Increment the ref count
233     mem::forget(waker.clone());
234 
235     to_raw(waker)
236 }
237 
wake(raw: *const ())238 unsafe fn wake(raw: *const ()) {
239     let waker = from_raw(raw);
240     waker.wake();
241 }
242 
wake_by_ref(raw: *const ())243 unsafe fn wake_by_ref(raw: *const ()) {
244     let waker = from_raw(raw);
245     waker.wake();
246 
247     // We don't actually own a reference to the unparker
248     mem::forget(waker);
249 }
250 
drop_waker(raw: *const ())251 unsafe fn drop_waker(raw: *const ()) {
252     let _ = from_raw(raw);
253 }
254