1 use crate::task::AtomicWaker;
2 use alloc::sync::Arc;
3 use core::fmt;
4 use core::pin::Pin;
5 use core::sync::atomic::{AtomicBool, Ordering};
6 use futures_core::future::Future;
7 use futures_core::task::{Context, Poll};
8 use futures_core::Stream;
9 use pin_project_lite::pin_project;
10 
pin_project! {
    /// A future/stream which can be remotely short-circuited using an `AbortHandle`.
    ///
    /// Build one with `Abortable::new`, passing the `AbortRegistration` half of a
    /// pair created by `AbortHandle::new_pair`. Used as a future it resolves to
    /// `Ok(output)` or `Err(Aborted)`; used as a stream it yields the inner items
    /// and ends (yields `None`) once an abort has been observed.
    #[derive(Debug, Clone)]
    #[must_use = "futures/streams do nothing unless you poll them"]
    pub struct Abortable<T> {
        // The wrapped future or stream; structurally pinned so `try_poll` can
        // obtain a `Pin<&mut T>` through the projection.
        #[pin]
        task: T,
        // Abort flag and waker shared with every handle/registration of the same pair.
        inner: Arc<AbortInner>,
    }
}
21 
22 impl<T> Abortable<T> {
23     /// Creates a new `Abortable` future/stream using an existing `AbortRegistration`.
24     /// `AbortRegistration`s can be acquired through `AbortHandle::new`.
25     ///
26     /// When `abort` is called on the handle tied to `reg` or if `abort` has
27     /// already been called, the future/stream will complete immediately without making
28     /// any further progress.
29     ///
30     /// # Examples:
31     ///
32     /// Usage with futures:
33     ///
34     /// ```
35     /// # futures::executor::block_on(async {
36     /// use futures::future::{Abortable, AbortHandle, Aborted};
37     ///
38     /// let (abort_handle, abort_registration) = AbortHandle::new_pair();
39     /// let future = Abortable::new(async { 2 }, abort_registration);
40     /// abort_handle.abort();
41     /// assert_eq!(future.await, Err(Aborted));
42     /// # });
43     /// ```
44     ///
45     /// Usage with streams:
46     ///
47     /// ```
48     /// # futures::executor::block_on(async {
49     /// # use futures::future::{Abortable, AbortHandle};
50     /// # use futures::stream::{self, StreamExt};
51     ///
52     /// let (abort_handle, abort_registration) = AbortHandle::new_pair();
53     /// let mut stream = Abortable::new(stream::iter(vec![1, 2, 3]), abort_registration);
54     /// abort_handle.abort();
55     /// assert_eq!(stream.next().await, None);
56     /// # });
57     /// ```
new(task: T, reg: AbortRegistration) -> Self58     pub fn new(task: T, reg: AbortRegistration) -> Self {
59         Self { task, inner: reg.inner }
60     }
61 
62     /// Checks whether the task has been aborted. Note that all this
63     /// method indicates is whether [`AbortHandle::abort`] was *called*.
64     /// This means that it will return `true` even if:
65     /// * `abort` was called after the task had completed.
66     /// * `abort` was called while the task was being polled - the task may still be running and
67     ///   will not be stopped until `poll` returns.
is_aborted(&self) -> bool68     pub fn is_aborted(&self) -> bool {
69         self.inner.aborted.load(Ordering::Relaxed)
70     }
71 }
72 
/// A registration handle for an `Abortable` task.
/// Values of this type can be acquired from `AbortHandle::new_pair` and are used
/// in calls to `Abortable::new`.
#[derive(Debug)]
pub struct AbortRegistration {
    // Abort state shared with the `AbortHandle`s of the same pair; it is moved
    // into the `Abortable` by `Abortable::new`.
    pub(crate) inner: Arc<AbortInner>,
}
80 
81 impl AbortRegistration {
82     /// Create an [`AbortHandle`] from the given [`AbortRegistration`].
83     ///
84     /// The created [`AbortHandle`] is functionally the same as any other
85     /// [`AbortHandle`]s that are associated with the same [`AbortRegistration`],
86     /// such as the one created by [`AbortHandle::new_pair`].
handle(&self) -> AbortHandle87     pub fn handle(&self) -> AbortHandle {
88         AbortHandle { inner: self.inner.clone() }
89     }
90 }
91 
/// A handle to an `Abortable` task.
///
/// Calling `abort` on this handle, or on any clone of it, requests that the
/// associated task stop. Cloning is cheap: every clone shares the same state.
#[derive(Debug, Clone)]
pub struct AbortHandle {
    // Shared abort flag and waker; the same allocation is held by the paired
    // `AbortRegistration` and, once constructed, by the `Abortable` itself.
    inner: Arc<AbortInner>,
}
97 
98 impl AbortHandle {
99     /// Creates an (`AbortHandle`, `AbortRegistration`) pair which can be used
100     /// to abort a running future or stream.
101     ///
102     /// This function is usually paired with a call to [`Abortable::new`].
new_pair() -> (Self, AbortRegistration)103     pub fn new_pair() -> (Self, AbortRegistration) {
104         let inner =
105             Arc::new(AbortInner { waker: AtomicWaker::new(), aborted: AtomicBool::new(false) });
106 
107         (Self { inner: inner.clone() }, AbortRegistration { inner })
108     }
109 }
110 
/// State shared by an `Abortable` and all of its handles/registrations: the
/// waker to notify when an abort is requested, and a flag recording whether
/// an abort has been requested.
#[derive(Debug)]
pub(crate) struct AbortInner {
    // Registered by `Abortable::try_poll` when the inner task is pending;
    // woken by `AbortHandle::abort`.
    pub(crate) waker: AtomicWaker,
    // Set to `true` by `AbortHandle::abort` (nothing in this file resets it).
    pub(crate) aborted: AtomicBool,
}
118 
/// Indicator that the `Abortable` task was aborted.
///
/// This is the error value an `Abortable` future resolves to once an abort has
/// been requested through its `AbortHandle`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Aborted;

impl fmt::Display for Aborted {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("`Abortable` future has been aborted")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Aborted {}
131 
132 impl<T> Abortable<T> {
try_poll<I>( mut self: Pin<&mut Self>, cx: &mut Context<'_>, poll: impl Fn(Pin<&mut T>, &mut Context<'_>) -> Poll<I>, ) -> Poll<Result<I, Aborted>>133     fn try_poll<I>(
134         mut self: Pin<&mut Self>,
135         cx: &mut Context<'_>,
136         poll: impl Fn(Pin<&mut T>, &mut Context<'_>) -> Poll<I>,
137     ) -> Poll<Result<I, Aborted>> {
138         // Check if the task has been aborted
139         if self.is_aborted() {
140             return Poll::Ready(Err(Aborted));
141         }
142 
143         // attempt to complete the task
144         if let Poll::Ready(x) = poll(self.as_mut().project().task, cx) {
145             return Poll::Ready(Ok(x));
146         }
147 
148         // Register to receive a wakeup if the task is aborted in the future
149         self.inner.waker.register(cx.waker());
150 
151         // Check to see if the task was aborted between the first check and
152         // registration.
153         // Checking with `is_aborted` which uses `Relaxed` is sufficient because
154         // `register` introduces an `AcqRel` barrier.
155         if self.is_aborted() {
156             return Poll::Ready(Err(Aborted));
157         }
158 
159         Poll::Pending
160     }
161 }
162 
163 impl<Fut> Future for Abortable<Fut>
164 where
165     Fut: Future,
166 {
167     type Output = Result<Fut::Output, Aborted>;
168 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>169     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
170         self.try_poll(cx, |fut, cx| fut.poll(cx))
171     }
172 }
173 
174 impl<St> Stream for Abortable<St>
175 where
176     St: Stream,
177 {
178     type Item = St::Item;
179 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>180     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
181         self.try_poll(cx, |stream, cx| stream.poll_next(cx)).map(Result::ok).map(Option::flatten)
182     }
183 }
184 
185 impl AbortHandle {
186     /// Abort the `Abortable` stream/future associated with this handle.
187     ///
188     /// Notifies the Abortable task associated with this handle that it
189     /// should abort. Note that if the task is currently being polled on
190     /// another thread, it will not immediately stop running. Instead, it will
191     /// continue to run until its poll method returns.
abort(&self)192     pub fn abort(&self) {
193         self.inner.aborted.store(true, Ordering::Relaxed);
194         self.inner.waker.wake();
195     }
196 
197     /// Checks whether [`AbortHandle::abort`] was *called* on any associated
198     /// [`AbortHandle`]s, which includes all the [`AbortHandle`]s linked with
199     /// the same [`AbortRegistration`]. This means that it will return `true`
200     /// even if:
201     /// * `abort` was called after the task had completed.
202     /// * `abort` was called while the task was being polled - the task may still be running and
203     ///   will not be stopped until `poll` returns.
204     ///
205     /// This operation has a Relaxed ordering.
is_aborted(&self) -> bool206     pub fn is_aborted(&self) -> bool {
207         self.inner.aborted.load(Ordering::Relaxed)
208     }
209 }
210