#[cfg(feature = "http2")]
use std::future::Future;
#[cfg(feature = "http2")]
use std::marker::Unpin;
#[cfg(feature = "http2")]
use std::pin::Pin;
use std::task::{Context, Poll};

use futures_util::FutureExt;
use tokio::sync::{mpsc, oneshot};

pub(crate) type RetryPromise<T, U> = oneshot::Receiver<Result<U, (crate::Error, Option<T>)>>;
pub(crate) type Promise<T> = oneshot::Receiver<Result<T, crate::Error>>;

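/// Creates a request dispatch channel: an unbounded mpsc paired with a
/// `want::Giver`/`Taker`, so the `Sender` only admits a request once the
/// `Receiver` has signaled it wants one (plus a single buffered slot).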
pub(crate) fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) {
    let (tx, rx) = mpsc::unbounded_channel();
    let (giver, taker) = want::new();
    let tx = Sender {
        buffered_once: false,
        giver,
        inner: tx,
    };
    let rx = Receiver { inner: rx, taker };
    (tx, rx)
}

/// A bounded sender of requests and callbacks for when responses are ready.
///
/// While the inner sender is unbounded, the Giver is used to determine
/// if the Receiver is ready for another request.
pub(crate) struct Sender<T, U> {
    /// One message is always allowed, even if the Receiver hasn't asked
    /// for it yet. This boolean keeps track of whether we've sent one
    /// without notice.
    buffered_once: bool,
    /// The Giver helps watch that the Receiver side has been polled
    /// when the queue is empty. This helps us know when a request and
    /// response have been fully processed, and a connection is ready
    /// for more.
    giver: want::Giver,
    /// Actually bounded by the Giver, plus `buffered_once`.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}

/// An unbounded version of the `Sender`.
///
/// Cannot poll the Giver, but can still use it to determine if the Receiver
/// has been dropped. However, this version can be cloned.
#[cfg(feature = "http2")]
pub(crate) struct UnboundedSender<T, U> {
    /// Only used for `is_closed`, since mpsc::UnboundedSender cannot be checked.
    giver: want::SharedGiver,
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
impl<T, U> Sender<T, U> {
    pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
        self.giver
            .poll_want(cx)
            .map_err(|_| crate::Error::new_closed())
    }

    pub(crate) fn is_ready(&self) -> bool {
        self.giver.is_wanting()
    }

    pub(crate) fn is_closed(&self) -> bool {
        self.giver.is_canceled()
    }

    fn can_send(&mut self) -> bool {
        if self.giver.give() || !self.buffered_once {
            // If the receiver is ready *now*, then of course we can send.
            //
            // If the receiver isn't ready yet, but we don't have anything
            // in the channel yet, then allow one message.
            self.buffered_once = true;
            true
        } else {
            false
        }
    }

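    /// Attempts to send a request, returning a `RetryPromise` for the
    /// response. If the channel isn't ready (or the `Receiver` is gone),
    /// the request is recovered from the `Envelope` and handed back so the
    /// caller can retry it.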
    pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
        if !self.can_send() {
            return Err(val);
        }
        let (tx, rx) = oneshot::channel();
        self.inner
            .send(Envelope(Some((val, Callback::Retry(Some(tx))))))
            .map(move |_| rx)
            .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
    }

    pub(crate) fn send(&mut self, val: T) -> Result<Promise<U>, T> {
        if !self.can_send() {
            return Err(val);
        }
        let (tx, rx) = oneshot::channel();
        self.inner
            .send(Envelope(Some((val, Callback::NoRetry(Some(tx))))))
            .map(move |_| rx)
            .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
    }

    #[cfg(feature = "http2")]
    pub(crate) fn unbound(self) -> UnboundedSender<T, U> {
        UnboundedSender {
            giver: self.giver.shared(),
            inner: self.inner,
        }
    }
}

#[cfg(feature = "http2")]
impl<T, U> UnboundedSender<T, U> {
    pub(crate) fn is_ready(&self) -> bool {
        !self.giver.is_canceled()
    }

    pub(crate) fn is_closed(&self) -> bool {
        self.giver.is_canceled()
    }

    pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
        let (tx, rx) = oneshot::channel();
        self.inner
            .send(Envelope(Some((val, Callback::Retry(Some(tx))))))
            .map(move |_| rx)
            .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
    }

    #[cfg(all(feature = "backports", feature = "http2"))]
    pub(crate) fn send(&mut self, val: T) -> Result<Promise<U>, T> {
        let (tx, rx) = oneshot::channel();
        self.inner
            .send(Envelope(Some((val, Callback::NoRetry(Some(tx))))))
            .map(move |_| rx)
            .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
    }
}

#[cfg(feature = "http2")]
impl<T, U> Clone for UnboundedSender<T, U> {
    fn clone(&self) -> Self {
        UnboundedSender {
            giver: self.giver.clone(),
            inner: self.inner.clone(),
        }
    }
}

pub(crate) struct Receiver<T, U> {
    inner: mpsc::UnboundedReceiver<Envelope<T, U>>,
    taker: want::Taker,
}

impl<T, U> Receiver<T, U> {
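    /// Polls the queue for the next request. When the queue is empty, this
    /// signals `want` through the `Taker` so the paired `Giver` reports the
    /// `Sender` as ready for another request.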
    pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<(T, Callback<T, U>)>> {
        match self.inner.poll_recv(cx) {
            Poll::Ready(item) => {
                Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped")))
            }
            Poll::Pending => {
                self.taker.want();
                Poll::Pending
            }
        }
    }

    #[cfg(feature = "http1")]
    pub(crate) fn close(&mut self) {
        self.taker.cancel();
        self.inner.close();
    }

    #[cfg(feature = "http1")]
    pub(crate) fn try_recv(&mut self) -> Option<(T, Callback<T, U>)> {
        match self.inner.recv().now_or_never() {
            Some(Some(mut env)) => env.0.take(),
            _ => None,
        }
    }
}

impl<T, U> Drop for Receiver<T, U> {
    fn drop(&mut self) {
        // Notify the giver about the closure first, before dropping
        // the mpsc::Receiver.
        self.taker.cancel();
    }
}

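/// Wraps a queued request together with its callback so that, if the message
/// is dropped without being taken (e.g. the connection closes while it is
/// still queued), the callback receives a canceled error along with the
/// original request value.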
struct Envelope<T, U>(Option<(T, Callback<T, U>)>);

impl<T, U> Drop for Envelope<T, U> {
    fn drop(&mut self) {
        if let Some((val, cb)) = self.0.take() {
            cb.send(Err((
                crate::Error::new_canceled().with("connection closed"),
                Some(val),
            )));
        }
    }
}

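/// The sending half of a dispatched request's oneshot channel.
///
/// `Retry` returns the original request alongside the error so it may be
/// retried; `NoRetry` reports only the error. The inner `Option` is taken by
/// `send`, so the `Drop` impl only reports a "dispatch gone" error if the
/// callback was never completed.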
pub(crate) enum Callback<T, U> {
    Retry(Option<oneshot::Sender<Result<U, (crate::Error, Option<T>)>>>),
    NoRetry(Option<oneshot::Sender<Result<U, crate::Error>>>),
}

impl<T, U> Drop for Callback<T, U> {
    fn drop(&mut self) {
        // FIXME(nox): What errors do we want here?
        let error = crate::Error::new_user_dispatch_gone().with(if std::thread::panicking() {
            "user code panicked"
        } else {
            "runtime dropped the dispatch task"
        });

        match self {
            Callback::Retry(tx) => {
                if let Some(tx) = tx.take() {
                    let _ = tx.send(Err((error, None)));
                }
            }
            Callback::NoRetry(tx) => {
                if let Some(tx) = tx.take() {
                    let _ = tx.send(Err(error));
                }
            }
        }
    }
}

impl<T, U> Callback<T, U> {
    #[cfg(feature = "http2")]
    pub(crate) fn is_canceled(&self) -> bool {
        match *self {
            Callback::Retry(Some(ref tx)) => tx.is_closed(),
            Callback::NoRetry(Some(ref tx)) => tx.is_closed(),
            _ => unreachable!(),
        }
    }

    pub(crate) fn poll_canceled(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        match *self {
            Callback::Retry(Some(ref mut tx)) => tx.poll_closed(cx),
            Callback::NoRetry(Some(ref mut tx)) => tx.poll_closed(cx),
            _ => unreachable!(),
        }
    }

    pub(crate) fn send(mut self, val: Result<U, (crate::Error, Option<T>)>) {
        match self {
            Callback::Retry(ref mut tx) => {
                let _ = tx.take().unwrap().send(val);
            }
            Callback::NoRetry(ref mut tx) => {
                let _ = tx.take().unwrap().send(val.map_err(|e| e.0));
            }
        }
    }

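    /// Drives `when` to completion and forwards its output to this callback,
    /// finishing early if the callback is canceled (its receiver is dropped)
    /// before the future resolves.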
    #[cfg(feature = "http2")]
    pub(crate) async fn send_when(
        self,
        mut when: impl Future<Output = Result<U, (crate::Error, Option<T>)>> + Unpin,
    ) {
        use futures_util::future;
        use tracing::trace;

        let mut cb = Some(self);

        // "select" on this callback being canceled, and the future completing
        future::poll_fn(move |cx| {
            match Pin::new(&mut when).poll(cx) {
                Poll::Ready(Ok(res)) => {
                    cb.take().expect("polled after complete").send(Ok(res));
                    Poll::Ready(())
                }
                Poll::Pending => {
                    // check if the callback is canceled
                    ready!(cb.as_mut().unwrap().poll_canceled(cx));
                    trace!("send_when canceled");
                    Poll::Ready(())
                }
                Poll::Ready(Err(err)) => {
                    cb.take().expect("polled after complete").send(Err(err));
                    Poll::Ready(())
                }
            }
        })
        .await
    }
}

#[cfg(test)]
mod tests {
    #[cfg(feature = "nightly")]
    extern crate test;

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use super::{channel, Callback, Receiver};

    #[derive(Debug)]
    struct Custom(i32);

    impl<T, U> Future for Receiver<T, U> {
        type Output = Option<(T, Callback<T, U>)>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_recv(cx)
        }
    }

    /// Helper to check if the future is ready after polling once.
    struct PollOnce<'a, F>(&'a mut F);

    impl<F, T> Future for PollOnce<'_, F>
    where
        F: Future<Output = T> + Unpin,
    {
        type Output = Option<()>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            match Pin::new(&mut self.0).poll(cx) {
                Poll::Ready(_) => Poll::Ready(Some(())),
                Poll::Pending => Poll::Ready(None),
            }
        }
    }

    #[tokio::test]
    async fn drop_receiver_sends_cancel_errors() {
        let _ = pretty_env_logger::try_init();

        let (mut tx, mut rx) = channel::<Custom, ()>();

        // must poll once for try_send to succeed
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let promise = tx.try_send(Custom(43)).unwrap();
        drop(rx);

        let fulfilled = promise.await;
        let err = fulfilled
            .expect("fulfilled")
            .expect_err("promise should error");
        match (err.0.kind(), err.1) {
            (&crate::error::Kind::Canceled, Some(_)) => (),
            e => panic!("expected Error::Cancel(_), found {:?}", e),
        }
    }

    #[tokio::test]
    async fn sender_checks_for_want_on_send() {
        let (mut tx, mut rx) = channel::<Custom, ()>();

        // one is allowed to buffer, second is rejected
        let _ = tx.try_send(Custom(1)).expect("1 buffered");
        tx.try_send(Custom(2)).expect_err("2 not ready");

        assert!(PollOnce(&mut rx).await.is_some(), "rx once");

        // Even though 1 has been popped, only 1 could be buffered for the
        // lifetime of the channel.
        tx.try_send(Custom(2)).expect_err("2 still not ready");

        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let _ = tx.try_send(Custom(2)).expect("2 ready");
    }

    #[cfg(feature = "http2")]
    #[test]
    fn unbounded_sender_doesnt_bound_on_want() {
        let (tx, rx) = channel::<Custom, ()>();
        let mut tx = tx.unbound();

        let _ = tx.try_send(Custom(1)).unwrap();
        let _ = tx.try_send(Custom(2)).unwrap();
        let _ = tx.try_send(Custom(3)).unwrap();

        drop(rx);

        let _ = tx.try_send(Custom(4)).unwrap_err();
    }

    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_throughput(b: &mut test::Bencher) {
        use crate::{Body, Request, Response};

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (mut tx, mut rx) = channel::<Request<Body>, Response<Body>>();

        b.iter(move || {
            let _ = tx.send(Request::default()).unwrap();
            rt.block_on(async {
                loop {
                    let poll_once = PollOnce(&mut rx);
                    let opt = poll_once.await;
                    if opt.is_none() {
                        break;
                    }
                }
            });
        })
    }

    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_not_ready(b: &mut test::Bencher) {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (_tx, mut rx) = channel::<i32, ()>();
        b.iter(move || {
            rt.block_on(async {
                let poll_once = PollOnce(&mut rx);
                assert!(poll_once.await.is_none());
            });
        })
    }

    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_cancel(b: &mut test::Bencher) {
        let (_tx, mut rx) = channel::<i32, ()>();

        b.iter(move || {
            rx.taker.cancel();
        })
    }
}