1 use futures_sink::Sink;
2 use std::pin::Pin;
3 use std::task::{Context, Poll};
4 use std::{fmt, mem};
5 use tokio::sync::mpsc::OwnedPermit;
6 use tokio::sync::mpsc::Sender;
7 
8 use super::ReusableBoxFuture;
9 
/// Error produced by `PollSender` once the channel can no longer accept items.
#[derive(Debug)]
pub struct PollSendError<T>(Option<T>);

impl<T> PollSendError<T> {
    /// Consumes the error and hands back the item it carries, if any.
    ///
    /// When the error was produced by `start_send`/`send_item`, the returned value is the item
    /// the caller tried to send.  In every other case this is `None`.
    pub fn into_inner(self) -> Option<T> {
        let Self(item) = self;
        item
    }
}

impl<T> fmt::Display for PollSendError<T> {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed")
    }
}

impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
31 
// Internal state machine driven by `PollSender`.
#[derive(Debug)]
enum State<T> {
    // No reservation in progress; holds the sender that the next `poll_reserve` hands to the
    // acquire future.
    Idle(Sender<T>),
    // A permit is being acquired through the `PollSender`'s `acquire` future.
    Acquiring,
    // A permit has been obtained and is waiting to be consumed by `send_item` or given back by
    // `abort_send`.
    ReadyToSend(OwnedPermit<T>),
    // No further reservations are possible.  This is also the placeholder left behind by
    // `take_state`.
    Closed,
}
39 
/// A wrapper around [`mpsc::Sender`] that can be polled.
///
/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
#[derive(Debug)]
pub struct PollSender<T> {
    // Retained clone of the wrapped sender; set to `None` by `close`, which is how the closed
    // status is tracked independently of `state`.
    sender: Option<Sender<T>>,
    // Current position in the reserve/send state machine.
    state: State<T>,
    // Reusable future that resolves to a permit (or a closed-channel error) while reserving.
    acquire: PollSenderFuture<T>,
}
49 
50 // Creates a future for acquiring a permit from the underlying channel.  This is used to ensure
51 // there's capacity for a send to complete.
52 //
53 // By reusing the same async fn for both `Some` and `None`, we make sure every future passed to
54 // ReusableBoxFuture has the same underlying type, and hence the same size and alignment.
make_acquire_future<T>( data: Option<Sender<T>>, ) -> Result<OwnedPermit<T>, PollSendError<T>>55 async fn make_acquire_future<T>(
56     data: Option<Sender<T>>,
57 ) -> Result<OwnedPermit<T>, PollSendError<T>> {
58     match data {
59         Some(sender) => sender
60             .reserve_owned()
61             .await
62             .map_err(|_| PollSendError(None)),
63         None => unreachable!("this future should not be pollable in this state"),
64     }
65 }
66 
// Reusable boxed future whose output is either an owned permit or a `PollSendError`.
type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>;
68 
#[derive(Debug)]
// TODO: This should be replaced with a type_alias_impl_trait to eliminate `'static` and all the
// transmutes.
//
// Newtype over the reusable acquire future; the lifetime is pinned to `'static` here and the
// `impl` blocks below cast it where a shorter lifetime is needed.
struct PollSenderFuture<T>(InnerFuture<'static, T>);
72 
73 impl<T> PollSenderFuture<T> {
74     /// Create with an empty inner future with no `Send` bound.
empty() -> Self75     fn empty() -> Self {
76         // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not
77         // compatible with the transitive bounds required by `Sender<T>`.
78         Self(ReusableBoxFuture::new(async { unreachable!() }))
79     }
80 }
81 
82 impl<T: Send> PollSenderFuture<T> {
83     /// Create with an empty inner future.
new() -> Self84     fn new() -> Self {
85         let v = InnerFuture::new(make_acquire_future(None));
86         // This is safe because `make_acquire_future(None)` is actually `'static`
87         Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) })
88     }
89 
90     /// Poll the inner future.
poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>>91     fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> {
92         self.0.poll(cx)
93     }
94 
95     /// Replace the inner future.
set(&mut self, sender: Option<Sender<T>>)96     fn set(&mut self, sender: Option<Sender<T>>) {
97         let inner: *mut InnerFuture<'static, T> = &mut self.0;
98         let inner: *mut InnerFuture<'_, T> = inner.cast();
99         // SAFETY: The `make_acquire_future(sender)` future must not exist after the type `T`
100         // becomes invalid, and this casts away the type-level lifetime check for that. However, the
101         // inner future is never moved out of this `PollSenderFuture<T>`, so the future will not
102         // live longer than the `PollSenderFuture<T>` lives. A `PollSenderFuture<T>` is guaranteed
103         // to not exist after the type `T` becomes invalid, because it is annotated with a `T`, so
104         // this is ok.
105         let inner = unsafe { &mut *inner };
106         inner.set(make_acquire_future(sender));
107     }
108 }
109 
110 impl<T: Send> PollSender<T> {
111     /// Creates a new `PollSender`.
new(sender: Sender<T>) -> Self112     pub fn new(sender: Sender<T>) -> Self {
113         Self {
114             sender: Some(sender.clone()),
115             state: State::Idle(sender),
116             acquire: PollSenderFuture::new(),
117         }
118     }
119 
take_state(&mut self) -> State<T>120     fn take_state(&mut self) -> State<T> {
121         mem::replace(&mut self.state, State::Closed)
122     }
123 
124     /// Attempts to prepare the sender to receive a value.
125     ///
126     /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to
127     /// `send_item`.
128     ///
129     /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value,
130     /// by reserving a slot in the channel for the item to be sent. If this method returns
131     /// `Poll::Pending`, the current task is registered to be notified (via
132     /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again.
133     ///
134     /// # Errors
135     ///
136     /// If the channel is closed, an error will be returned.  This is a permanent state.
poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>>137     pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
138         loop {
139             let (result, next_state) = match self.take_state() {
140                 State::Idle(sender) => {
141                     // Start trying to acquire a permit to reserve a slot for our send, and
142                     // immediately loop back around to poll it the first time.
143                     self.acquire.set(Some(sender));
144                     (None, State::Acquiring)
145                 }
146                 State::Acquiring => match self.acquire.poll(cx) {
147                     // Channel has capacity.
148                     Poll::Ready(Ok(permit)) => {
149                         (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit))
150                     }
151                     // Channel is closed.
152                     Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed),
153                     // Channel doesn't have capacity yet, so we need to wait.
154                     Poll::Pending => (Some(Poll::Pending), State::Acquiring),
155                 },
156                 // We're closed, either by choice or because the underlying sender was closed.
157                 s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s),
158                 // We're already ready to send an item.
159                 s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s),
160             };
161 
162             self.state = next_state;
163             if let Some(result) = result {
164                 return result;
165             }
166         }
167     }
168 
169     /// Sends an item to the channel.
170     ///
171     /// Before calling `send_item`, `poll_reserve` must be called with a successful return
172     /// value of `Poll::Ready(Ok(()))`.
173     ///
174     /// # Errors
175     ///
176     /// If the channel is closed, an error will be returned.  This is a permanent state.
177     ///
178     /// # Panics
179     ///
180     /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method
181     /// will panic.
182     #[track_caller]
send_item(&mut self, value: T) -> Result<(), PollSendError<T>>183     pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> {
184         let (result, next_state) = match self.take_state() {
185             State::Idle(_) | State::Acquiring => {
186                 panic!("`send_item` called without first calling `poll_reserve`")
187             }
188             // We have a permit to send our item, so go ahead, which gets us our sender back.
189             State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))),
190             // We're closed, either by choice or because the underlying sender was closed.
191             State::Closed => (Err(PollSendError(Some(value))), State::Closed),
192         };
193 
194         // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`.
195         self.state = if self.sender.is_some() {
196             next_state
197         } else {
198             State::Closed
199         };
200         result
201     }
202 
203     /// Checks whether this sender is been closed.
204     ///
205     /// The underlying channel that this sender was wrapping may still be open.
is_closed(&self) -> bool206     pub fn is_closed(&self) -> bool {
207         matches!(self.state, State::Closed) || self.sender.is_none()
208     }
209 
210     /// Gets a reference to the `Sender` of the underlying channel.
211     ///
212     /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender
213     /// was wrapping may still be open.
get_ref(&self) -> Option<&Sender<T>>214     pub fn get_ref(&self) -> Option<&Sender<T>> {
215         self.sender.as_ref()
216     }
217 
218     /// Closes this sender.
219     ///
220     /// No more messages will be able to be sent from this sender, but the underlying channel will
221     /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel.
222     ///
223     /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made
224     /// to `send_item` in order to consume the reserved slot.  After that, no further sends will be
225     /// possible.  If you do not intend to send another item, you can release the reserved slot back
226     /// to the underlying sender by calling [`abort_send`].
227     ///
228     /// [`abort_send`]: crate::sync::PollSender::abort_send
229     /// [`Receiver`]: tokio::sync::mpsc::Receiver
close(&mut self)230     pub fn close(&mut self) {
231         // Mark ourselves officially closed by dropping our main sender.
232         self.sender = None;
233 
234         // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly
235         // transition to the closed state.  Otherwise, leave the existing permit in place for the
236         // caller if they want to complete the send.
237         match self.state {
238             State::Idle(_) => self.state = State::Closed,
239             State::Acquiring => {
240                 self.acquire.set(None);
241                 self.state = State::Closed;
242             }
243             _ => {}
244         }
245     }
246 
247     /// Aborts the current in-progress send, if any.
248     ///
249     /// Returns `true` if a send was aborted.  If the sender was closed prior to calling
250     /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be
251     /// ready to attempt another send.
abort_send(&mut self) -> bool252     pub fn abort_send(&mut self) -> bool {
253         // We may have been closed in the meantime, after a call to `poll_reserve` already
254         // succeeded.  We'll check if `self.sender` is `None` to see if we should transition to the
255         // closed state when we actually abort a send, rather than resetting ourselves back to idle.
256 
257         let (result, next_state) = match self.take_state() {
258             // We're currently trying to reserve a slot to send into.
259             State::Acquiring => {
260                 // Replacing the future drops the in-flight one.
261                 self.acquire.set(None);
262 
263                 // If we haven't closed yet, we have to clone our stored sender since we have no way
264                 // to get it back from the acquire future we just dropped.
265                 let state = match self.sender.clone() {
266                     Some(sender) => State::Idle(sender),
267                     None => State::Closed,
268                 };
269                 (true, state)
270             }
271             // We got the permit.  If we haven't closed yet, get the sender back.
272             State::ReadyToSend(permit) => {
273                 let state = if self.sender.is_some() {
274                     State::Idle(permit.release())
275                 } else {
276                     State::Closed
277                 };
278                 (true, state)
279             }
280             s => (false, s),
281         };
282 
283         self.state = next_state;
284         result
285     }
286 }
287 
288 impl<T> Clone for PollSender<T> {
289     /// Clones this `PollSender`.
290     ///
291     /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`.
clone(&self) -> PollSender<T>292     fn clone(&self) -> PollSender<T> {
293         let (sender, state) = match self.sender.clone() {
294             Some(sender) => (Some(sender.clone()), State::Idle(sender)),
295             None => (None, State::Closed),
296         };
297 
298         Self {
299             sender,
300             state,
301             acquire: PollSenderFuture::empty(),
302         }
303     }
304 }
305 
306 impl<T: Send> Sink<T> for PollSender<T> {
307     type Error = PollSendError<T>;
308 
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>309     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
310         Pin::into_inner(self).poll_reserve(cx)
311     }
312 
poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>313     fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
314         Poll::Ready(Ok(()))
315     }
316 
start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error>317     fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
318         Pin::into_inner(self).send_item(item)
319     }
320 
poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>321     fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
322         Pin::into_inner(self).close();
323         Poll::Ready(Ok(()))
324     }
325 }
326