//! Helpers for manually polling futures and streams on a mock task that records wake notifications.
2
3 #![allow(clippy::mutex_atomic)]
4
5 use std::future::Future;
6 use std::mem;
7 use std::ops;
8 use std::pin::Pin;
9 use std::sync::{Arc, Condvar, Mutex};
10 use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
11
12 use tokio_stream::Stream;
13
14 /// TODO: dox
spawn<T>(task: T) -> Spawn<T>15 pub fn spawn<T>(task: T) -> Spawn<T> {
16 Spawn {
17 task: MockTask::new(),
18 future: Box::pin(task),
19 }
20 }
21
/// Future spawned on a mock task
///
/// Pairs a pinned, heap-allocated value (a future or stream, judging by the
/// impls below) with a mock task, so the value can be polled manually while
/// wake notifications are recorded.
#[derive(Debug)]
pub struct Spawn<T> {
    /// Mock task whose waker is supplied when polling through `enter`, `poll`
    /// or `poll_next`.
    task: MockTask,
    /// The spawned value, boxed and pinned so it can be polled in place.
    future: Pin<Box<T>>,
}
28
/// Mock task
///
/// A mock task is able to intercept and track wake notifications.
/// Cloning a `MockTask` shares the same underlying wake state.
#[derive(Debug, Clone)]
struct MockTask {
    /// Shared wake state; every `Waker` handed out by this task holds its own
    /// strong reference to it.
    waker: Arc<ThreadWaker>,
}
36
/// Wake-notification state shared between a `MockTask` and the wakers it
/// hands out.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state: one of `IDLE`, `WAKE` or `SLEEP`.
    state: Mutex<usize>,
    /// Notified by `wake` when the previous state was `SLEEP`.
    condvar: Condvar,
}
42
/// No wake notification has been recorded since the state was last cleared.
const IDLE: usize = 0;
/// A wake notification has been recorded.
const WAKE: usize = 1;
/// The other half is sleeping and must be notified through the condvar on wake.
/// NOTE(review): no code visible here ever stores this state — confirm it is
/// still needed.
const SLEEP: usize = 2;
46
47 impl<T> Spawn<T> {
48 /// Consumes `self` returning the inner value
into_inner(self) -> T where T: Unpin,49 pub fn into_inner(self) -> T
50 where
51 T: Unpin,
52 {
53 *Pin::into_inner(self.future)
54 }
55
56 /// Returns `true` if the inner future has received a wake notification
57 /// since the last call to `enter`.
is_woken(&self) -> bool58 pub fn is_woken(&self) -> bool {
59 self.task.is_woken()
60 }
61
62 /// Returns the number of references to the task waker
63 ///
64 /// The task itself holds a reference. The return value will never be zero.
waker_ref_count(&self) -> usize65 pub fn waker_ref_count(&self) -> usize {
66 self.task.waker_ref_count()
67 }
68
69 /// Enter the task context
enter<F, R>(&mut self, f: F) -> R where F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,70 pub fn enter<F, R>(&mut self, f: F) -> R
71 where
72 F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
73 {
74 let fut = self.future.as_mut();
75 self.task.enter(|cx| f(cx, fut))
76 }
77 }
78
79 impl<T: Unpin> ops::Deref for Spawn<T> {
80 type Target = T;
81
deref(&self) -> &T82 fn deref(&self) -> &T {
83 &self.future
84 }
85 }
86
87 impl<T: Unpin> ops::DerefMut for Spawn<T> {
deref_mut(&mut self) -> &mut T88 fn deref_mut(&mut self) -> &mut T {
89 &mut self.future
90 }
91 }
92
93 impl<T: Future> Spawn<T> {
94 /// Polls a future
poll(&mut self) -> Poll<T::Output>95 pub fn poll(&mut self) -> Poll<T::Output> {
96 let fut = self.future.as_mut();
97 self.task.enter(|cx| fut.poll(cx))
98 }
99 }
100
101 impl<T: Stream> Spawn<T> {
102 /// Polls a stream
poll_next(&mut self) -> Poll<Option<T::Item>>103 pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
104 let stream = self.future.as_mut();
105 self.task.enter(|cx| stream.poll_next(cx))
106 }
107 }
108
109 impl<T: Future> Future for Spawn<T> {
110 type Output = T::Output;
111
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>112 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
113 self.future.as_mut().poll(cx)
114 }
115 }
116
117 impl<T: Stream> Stream for Spawn<T> {
118 type Item = T::Item;
119
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>120 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
121 self.future.as_mut().poll_next(cx)
122 }
123 }
124
125 impl MockTask {
126 /// Creates new mock task
new() -> Self127 fn new() -> Self {
128 MockTask {
129 waker: Arc::new(ThreadWaker::new()),
130 }
131 }
132
133 /// Runs a closure from the context of the task.
134 ///
135 /// Any wake notifications resulting from the execution of the closure are
136 /// tracked.
enter<F, R>(&mut self, f: F) -> R where F: FnOnce(&mut Context<'_>) -> R,137 fn enter<F, R>(&mut self, f: F) -> R
138 where
139 F: FnOnce(&mut Context<'_>) -> R,
140 {
141 self.waker.clear();
142 let waker = self.waker();
143 let mut cx = Context::from_waker(&waker);
144
145 f(&mut cx)
146 }
147
148 /// Returns `true` if the inner future has received a wake notification
149 /// since the last call to `enter`.
is_woken(&self) -> bool150 fn is_woken(&self) -> bool {
151 self.waker.is_woken()
152 }
153
154 /// Returns the number of references to the task waker
155 ///
156 /// The task itself holds a reference. The return value will never be zero.
waker_ref_count(&self) -> usize157 fn waker_ref_count(&self) -> usize {
158 Arc::strong_count(&self.waker)
159 }
160
waker(&self) -> Waker161 fn waker(&self) -> Waker {
162 unsafe {
163 let raw = to_raw(self.waker.clone());
164 Waker::from_raw(raw)
165 }
166 }
167 }
168
169 impl Default for MockTask {
default() -> Self170 fn default() -> Self {
171 Self::new()
172 }
173 }
174
175 impl ThreadWaker {
new() -> Self176 fn new() -> Self {
177 ThreadWaker {
178 state: Mutex::new(IDLE),
179 condvar: Condvar::new(),
180 }
181 }
182
183 /// Clears any previously received wakes, avoiding potential spurrious
184 /// wake notifications. This should only be called immediately before running the
185 /// task.
clear(&self)186 fn clear(&self) {
187 *self.state.lock().unwrap() = IDLE;
188 }
189
is_woken(&self) -> bool190 fn is_woken(&self) -> bool {
191 match *self.state.lock().unwrap() {
192 IDLE => false,
193 WAKE => true,
194 _ => unreachable!(),
195 }
196 }
197
wake(&self)198 fn wake(&self) {
199 // First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
200 let mut state = self.state.lock().unwrap();
201 let prev = *state;
202
203 if prev == WAKE {
204 return;
205 }
206
207 *state = WAKE;
208
209 if prev == IDLE {
210 return;
211 }
212
213 // The other half is sleeping, so we wake it up.
214 assert_eq!(prev, SLEEP);
215 self.condvar.notify_one();
216 }
217 }
218
/// Vtable for the `RawWaker`s produced by `to_raw`. Every entry treats the data
/// pointer as an `Arc<ThreadWaker>` that was turned into a raw pointer with
/// `Arc::into_raw`.
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
220
to_raw(waker: Arc<ThreadWaker>) -> RawWaker221 unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
222 RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
223 }
224
from_raw(raw: *const ()) -> Arc<ThreadWaker>225 unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
226 Arc::from_raw(raw as *const ThreadWaker)
227 }
228
clone(raw: *const ()) -> RawWaker229 unsafe fn clone(raw: *const ()) -> RawWaker {
230 let waker = from_raw(raw);
231
232 // Increment the ref count
233 mem::forget(waker.clone());
234
235 to_raw(waker)
236 }
237
wake(raw: *const ())238 unsafe fn wake(raw: *const ()) {
239 let waker = from_raw(raw);
240 waker.wake();
241 }
242
wake_by_ref(raw: *const ())243 unsafe fn wake_by_ref(raw: *const ()) {
244 let waker = from_raw(raw);
245 waker.wake();
246
247 // We don't actually own a reference to the unparker
248 mem::forget(waker);
249 }
250
drop_waker(raw: *const ())251 unsafe fn drop_waker(raw: *const ()) {
252 let _ = from_raw(raw);
253 }
254