//! An asynchronously awaitable `CancellationToken`.
//! The token allows signaling a cancellation request to one or more tasks.
3 pub(crate) mod guard;
4 mod tree_node;
5 
6 use crate::loom::sync::Arc;
7 use crate::util::MaybeDangling;
8 use core::future::Future;
9 use core::pin::Pin;
10 use core::task::{Context, Poll};
11 
12 use guard::DropGuard;
13 use pin_project_lite::pin_project;
14 
/// A token which can be used to signal a cancellation request to one or more
/// tasks.
///
/// Tasks can call [`CancellationToken::cancelled()`] in order to
/// obtain a Future which will be resolved when cancellation is requested.
///
/// Cancellation can be requested through the [`CancellationToken::cancel`] method.
///
/// Cloning a token yields a handle to the same cancellation state: cancelling
/// any clone cancels all of them. To get a token that is cancelled together
/// with its parent but can also be cancelled on its own, use
/// [`CancellationToken::child_token`] instead.
///
/// # Examples
///
/// ```no_run
/// use tokio::select;
/// use tokio_util::sync::CancellationToken;
///
/// #[tokio::main]
/// async fn main() {
///     let token = CancellationToken::new();
///     let cloned_token = token.clone();
///
///     let join_handle = tokio::spawn(async move {
///         // Wait for either cancellation or a very long time
///         select! {
///             _ = cloned_token.cancelled() => {
///                 // The token was cancelled
///                 5
///             }
///             _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
///                 99
///             }
///         }
///     });
///
///     tokio::spawn(async move {
///         tokio::time::sleep(std::time::Duration::from_millis(10)).await;
///         token.cancel();
///     });
///
///     assert_eq!(5, join_handle.await.unwrap());
/// }
/// ```
pub struct CancellationToken {
    // Shared node of the cancellation tree. `Clone` shares this same `Arc`,
    // while `child_token` creates a new node via `tree_node::child_node`.
    inner: Arc<tree_node::TreeNode>,
}
58 
// Unwind safety is asserted manually rather than derived automatically.
// NOTE(review): presumably some type reachable through `inner` (defined in
// `tree_node`, not visible here) is not auto-`UnwindSafe`/`RefUnwindSafe`;
// confirm that a panic cannot leave the shared cancellation state observably
// broken before relying on these impls.
impl std::panic::UnwindSafe for CancellationToken {}
impl std::panic::RefUnwindSafe for CancellationToken {}
61 
pin_project! {
    /// A Future that is resolved once the corresponding [`CancellationToken`]
    /// is cancelled.
    ///
    /// Returned by [`CancellationToken::cancelled`]; it borrows the token for
    /// as long as it exists.
    #[must_use = "futures do nothing unless polled"]
    pub struct WaitForCancellationFuture<'a> {
        // Token being waited on; its `is_cancelled` flag is checked on every poll.
        cancellation_token: &'a CancellationToken,
        // Notification future obtained from the token's inner node. It is
        // replaced with a fresh one whenever it completes while the token is
        // not (yet) reported as cancelled.
        #[pin]
        future: tokio::sync::futures::Notified<'a>,
    }
}
72 
pin_project! {
    /// A Future that is resolved once the corresponding [`CancellationToken`]
    /// is cancelled.
    ///
    /// This is the counterpart to [`WaitForCancellationFuture`] that takes
    /// [`CancellationToken`] by value instead of using a reference. It is
    /// returned by [`CancellationToken::cancelled_owned`].
    #[must_use = "futures do nothing unless polled"]
    pub struct WaitForCancellationFutureOwned {
        // This field internally has a reference to the cancellation token, but camouflages
        // the relationship with `'static`. To avoid Undefined Behavior, we must ensure
        // that the reference is only used while the cancellation token is still alive. To
        // do that, we ensure that the future is the first field, so that it is dropped
        // before the cancellation token.
        //
        // We use `MaybeDangling` here because without it, the compiler could assert
        // the reference inside `future` to be valid even after the destructor of that
        // field runs. (Specifically, when the `WaitForCancellationFutureOwned` is passed
        // as an argument to a function, the reference can be asserted to be valid for the
        // rest of that function.) To avoid that, we use `MaybeDangling` which tells the
        // compiler that the reference stored inside it might not be valid.
        //
        // See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706>
        // for more info.
        #[pin]
        future: MaybeDangling<tokio::sync::futures::Notified<'static>>,
        // Owned token that keeps the node referenced by `future` alive. It must
        // stay declared after `future` so that it is dropped last.
        cancellation_token: CancellationToken,
    }
}
101 
102 // ===== impl CancellationToken =====
103 
104 impl core::fmt::Debug for CancellationToken {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result105     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
106         f.debug_struct("CancellationToken")
107             .field("is_cancelled", &self.is_cancelled())
108             .finish()
109     }
110 }
111 
112 impl Clone for CancellationToken {
113     /// Creates a clone of the `CancellationToken` which will get cancelled
114     /// whenever the current token gets cancelled, and vice versa.
clone(&self) -> Self115     fn clone(&self) -> Self {
116         tree_node::increase_handle_refcount(&self.inner);
117         CancellationToken {
118             inner: self.inner.clone(),
119         }
120     }
121 }
122 
123 impl Drop for CancellationToken {
drop(&mut self)124     fn drop(&mut self) {
125         tree_node::decrease_handle_refcount(&self.inner);
126     }
127 }
128 
129 impl Default for CancellationToken {
default() -> CancellationToken130     fn default() -> CancellationToken {
131         CancellationToken::new()
132     }
133 }
134 
135 impl CancellationToken {
136     /// Creates a new `CancellationToken` in the non-cancelled state.
new() -> CancellationToken137     pub fn new() -> CancellationToken {
138         CancellationToken {
139             inner: Arc::new(tree_node::TreeNode::new()),
140         }
141     }
142 
143     /// Creates a `CancellationToken` which will get cancelled whenever the
144     /// current token gets cancelled. Unlike a cloned `CancellationToken`,
145     /// cancelling a child token does not cancel the parent token.
146     ///
147     /// If the current token is already cancelled, the child token will get
148     /// returned in cancelled state.
149     ///
150     /// # Examples
151     ///
152     /// ```no_run
153     /// use tokio::select;
154     /// use tokio_util::sync::CancellationToken;
155     ///
156     /// #[tokio::main]
157     /// async fn main() {
158     ///     let token = CancellationToken::new();
159     ///     let child_token = token.child_token();
160     ///
161     ///     let join_handle = tokio::spawn(async move {
162     ///         // Wait for either cancellation or a very long time
163     ///         select! {
164     ///             _ = child_token.cancelled() => {
165     ///                 // The token was cancelled
166     ///                 5
167     ///             }
168     ///             _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
169     ///                 99
170     ///             }
171     ///         }
172     ///     });
173     ///
174     ///     tokio::spawn(async move {
175     ///         tokio::time::sleep(std::time::Duration::from_millis(10)).await;
176     ///         token.cancel();
177     ///     });
178     ///
179     ///     assert_eq!(5, join_handle.await.unwrap());
180     /// }
181     /// ```
child_token(&self) -> CancellationToken182     pub fn child_token(&self) -> CancellationToken {
183         CancellationToken {
184             inner: tree_node::child_node(&self.inner),
185         }
186     }
187 
188     /// Cancel the [`CancellationToken`] and all child tokens which had been
189     /// derived from it.
190     ///
191     /// This will wake up all tasks which are waiting for cancellation.
192     ///
193     /// Be aware that cancellation is not an atomic operation. It is possible
194     /// for another thread running in parallel with a call to `cancel` to first
195     /// receive `true` from `is_cancelled` on one child node, and then receive
196     /// `false` from `is_cancelled` on another child node. However, once the
197     /// call to `cancel` returns, all child nodes have been fully cancelled.
cancel(&self)198     pub fn cancel(&self) {
199         tree_node::cancel(&self.inner);
200     }
201 
202     /// Returns `true` if the `CancellationToken` is cancelled.
is_cancelled(&self) -> bool203     pub fn is_cancelled(&self) -> bool {
204         tree_node::is_cancelled(&self.inner)
205     }
206 
207     /// Returns a `Future` that gets fulfilled when cancellation is requested.
208     ///
209     /// The future will complete immediately if the token is already cancelled
210     /// when this method is called.
211     ///
212     /// # Cancel safety
213     ///
214     /// This method is cancel safe.
cancelled(&self) -> WaitForCancellationFuture<'_>215     pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
216         WaitForCancellationFuture {
217             cancellation_token: self,
218             future: self.inner.notified(),
219         }
220     }
221 
222     /// Returns a `Future` that gets fulfilled when cancellation is requested.
223     ///
224     /// The future will complete immediately if the token is already cancelled
225     /// when this method is called.
226     ///
227     /// The function takes self by value and returns a future that owns the
228     /// token.
229     ///
230     /// # Cancel safety
231     ///
232     /// This method is cancel safe.
cancelled_owned(self) -> WaitForCancellationFutureOwned233     pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned {
234         WaitForCancellationFutureOwned::new(self)
235     }
236 
237     /// Creates a `DropGuard` for this token.
238     ///
239     /// Returned guard will cancel this token (and all its children) on drop
240     /// unless disarmed.
drop_guard(self) -> DropGuard241     pub fn drop_guard(self) -> DropGuard {
242         DropGuard { inner: Some(self) }
243     }
244 
245     /// Runs a future to completion and returns its result wrapped inside of an `Option`
246     /// unless the `CancellationToken` is cancelled. In that case the function returns
247     /// `None` and the future gets dropped.
248     ///
249     /// # Cancel safety
250     ///
251     /// This method is only cancel safe if `fut` is cancel safe.
run_until_cancelled<F>(&self, fut: F) -> Option<F::Output> where F: Future,252     pub async fn run_until_cancelled<F>(&self, fut: F) -> Option<F::Output>
253     where
254         F: Future,
255     {
256         pin_project! {
257             /// A Future that is resolved once the corresponding [`CancellationToken`]
258             /// is cancelled or a given Future gets resolved. It is biased towards the
259             /// Future completion.
260             #[must_use = "futures do nothing unless polled"]
261             struct RunUntilCancelledFuture<'a, F: Future> {
262                 #[pin]
263                 cancellation: WaitForCancellationFuture<'a>,
264                 #[pin]
265                 future: F,
266             }
267         }
268 
269         impl<'a, F: Future> Future for RunUntilCancelledFuture<'a, F> {
270             type Output = Option<F::Output>;
271 
272             fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
273                 let this = self.project();
274                 if let Poll::Ready(res) = this.future.poll(cx) {
275                     Poll::Ready(Some(res))
276                 } else if this.cancellation.poll(cx).is_ready() {
277                     Poll::Ready(None)
278                 } else {
279                     Poll::Pending
280                 }
281             }
282         }
283 
284         RunUntilCancelledFuture {
285             cancellation: self.cancelled(),
286             future: fut,
287         }
288         .await
289     }
290 }
291 
292 // ===== impl WaitForCancellationFuture =====
293 
294 impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result295     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
296         f.debug_struct("WaitForCancellationFuture").finish()
297     }
298 }
299 
300 impl<'a> Future for WaitForCancellationFuture<'a> {
301     type Output = ();
302 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()>303     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
304         let mut this = self.project();
305         loop {
306             if this.cancellation_token.is_cancelled() {
307                 return Poll::Ready(());
308             }
309 
310             // No wakeups can be lost here because there is always a call to
311             // `is_cancelled` between the creation of the future and the call to
312             // `poll`, and the code that sets the cancelled flag does so before
313             // waking the `Notified`.
314             if this.future.as_mut().poll(cx).is_pending() {
315                 return Poll::Pending;
316             }
317 
318             this.future.set(this.cancellation_token.inner.notified());
319         }
320     }
321 }
322 
323 // ===== impl WaitForCancellationFutureOwned =====
324 
325 impl core::fmt::Debug for WaitForCancellationFutureOwned {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result326     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
327         f.debug_struct("WaitForCancellationFutureOwned").finish()
328     }
329 }
330 
331 impl WaitForCancellationFutureOwned {
new(cancellation_token: CancellationToken) -> Self332     fn new(cancellation_token: CancellationToken) -> Self {
333         WaitForCancellationFutureOwned {
334             // cancellation_token holds a heap allocation and is guaranteed to have a
335             // stable deref, thus it would be ok to move the cancellation_token while
336             // the future holds a reference to it.
337             //
338             // # Safety
339             //
340             // cancellation_token is dropped after future due to the field ordering.
341             future: MaybeDangling::new(unsafe { Self::new_future(&cancellation_token) }),
342             cancellation_token,
343         }
344     }
345 
346     /// # Safety
347     /// The returned future must be destroyed before the cancellation token is
348     /// destroyed.
new_future( cancellation_token: &CancellationToken, ) -> tokio::sync::futures::Notified<'static>349     unsafe fn new_future(
350         cancellation_token: &CancellationToken,
351     ) -> tokio::sync::futures::Notified<'static> {
352         let inner_ptr = Arc::as_ptr(&cancellation_token.inner);
353         // SAFETY: The `Arc::as_ptr` method guarantees that `inner_ptr` remains
354         // valid until the strong count of the Arc drops to zero, and the caller
355         // guarantees that they will drop the future before that happens.
356         (*inner_ptr).notified()
357     }
358 }
359 
360 impl Future for WaitForCancellationFutureOwned {
361     type Output = ();
362 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()>363     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
364         let mut this = self.project();
365 
366         loop {
367             if this.cancellation_token.is_cancelled() {
368                 return Poll::Ready(());
369             }
370 
371             // No wakeups can be lost here because there is always a call to
372             // `is_cancelled` between the creation of the future and the call to
373             // `poll`, and the code that sets the cancelled flag does so before
374             // waking the `Notified`.
375             if this.future.as_mut().poll(cx).is_pending() {
376                 return Poll::Pending;
377             }
378 
379             // # Safety
380             //
381             // cancellation_token is dropped after future due to the field ordering.
382             this.future.set(MaybeDangling::new(unsafe {
383                 Self::new_future(this.cancellation_token)
384             }));
385         }
386     }
387 }
388