1 #[cfg(feature = "http2")]
2 use std::future::Future;
3 use std::marker::Unpin;
4 #[cfg(feature = "http2")]
5 use std::pin::Pin;
6 use std::task::{Context, Poll};
7
8 use futures_util::FutureExt;
9 use tokio::sync::{mpsc, oneshot};
10
/// Receiving half handed out by `try_send`: resolves to the response, or to an
/// error that may carry the original request back so the caller can retry it.
pub(crate) type RetryPromise<T, U> = oneshot::Receiver<Result<U, (crate::Error, Option<T>)>>;
/// Receiving half handed out by `send`: resolves to the response or an error,
/// without returning the request.
pub(crate) type Promise<T> = oneshot::Receiver<Result<T, crate::Error>>;
13
channel<T, U>() -> (Sender<T, U>, Receiver<T, U>)14 pub(crate) fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) {
15 let (tx, rx) = mpsc::unbounded_channel();
16 let (giver, taker) = want::new();
17 let tx = Sender {
18 buffered_once: false,
19 giver,
20 inner: tx,
21 };
22 let rx = Receiver { inner: rx, taker };
23 (tx, rx)
24 }
25
/// A bounded sender of requests and callbacks for when responses are ready.
///
/// While the inner sender is unbounded, the Giver is used to determine
/// if the Receiver is ready for another request.
pub(crate) struct Sender<T, U> {
    /// One message is always allowed, even if the Receiver hasn't asked
    /// for it yet. This boolean keeps track of whether we've sent one
    /// without notice.
    buffered_once: bool,
    /// The Giver helps watch that the Receiver side has been polled
    /// when the queue is empty. This helps us know when a request and
    /// response have been fully processed, and a connection is ready
    /// for more.
    giver: want::Giver,
    /// Actually bounded by the Giver, plus `buffered_once`.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
43
/// An unbounded version.
///
/// Cannot poll the Giver, but can still use it to determine if the Receiver
/// has been dropped. However, this version can be cloned.
#[cfg(feature = "http2")]
pub(crate) struct UnboundedSender<T, U> {
    /// Only consulted for cancellation (`is_ready` / `is_closed`), since the
    /// mpsc::UnboundedSender cannot be checked for that directly here.
    giver: want::SharedGiver,
    /// Queue of requests; never gated by the giver on this sender.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
54
55 impl<T, U> Sender<T, U> {
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>>56 pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
57 self.giver
58 .poll_want(cx)
59 .map_err(|_| crate::Error::new_closed())
60 }
61
is_ready(&self) -> bool62 pub(crate) fn is_ready(&self) -> bool {
63 self.giver.is_wanting()
64 }
65
is_closed(&self) -> bool66 pub(crate) fn is_closed(&self) -> bool {
67 self.giver.is_canceled()
68 }
69
can_send(&mut self) -> bool70 fn can_send(&mut self) -> bool {
71 if self.giver.give() || !self.buffered_once {
72 // If the receiver is ready *now*, then of course we can send.
73 //
74 // If the receiver isn't ready yet, but we don't have anything
75 // in the channel yet, then allow one message.
76 self.buffered_once = true;
77 true
78 } else {
79 false
80 }
81 }
82
try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T>83 pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
84 if !self.can_send() {
85 return Err(val);
86 }
87 let (tx, rx) = oneshot::channel();
88 self.inner
89 .send(Envelope(Some((val, Callback::Retry(Some(tx))))))
90 .map(move |_| rx)
91 .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
92 }
93
send(&mut self, val: T) -> Result<Promise<U>, T>94 pub(crate) fn send(&mut self, val: T) -> Result<Promise<U>, T> {
95 if !self.can_send() {
96 return Err(val);
97 }
98 let (tx, rx) = oneshot::channel();
99 self.inner
100 .send(Envelope(Some((val, Callback::NoRetry(Some(tx))))))
101 .map(move |_| rx)
102 .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
103 }
104
105 #[cfg(feature = "http2")]
unbound(self) -> UnboundedSender<T, U>106 pub(crate) fn unbound(self) -> UnboundedSender<T, U> {
107 UnboundedSender {
108 giver: self.giver.shared(),
109 inner: self.inner,
110 }
111 }
112 }
113
#[cfg(feature = "http2")]
impl<T, U> UnboundedSender<T, U> {
    /// An unbounded sender counts as ready for as long as the receiver is alive.
    pub(crate) fn is_ready(&self) -> bool {
        !self.is_closed()
    }

    /// Has the receiving side been canceled (closed or dropped)?
    pub(crate) fn is_closed(&self) -> bool {
        self.giver.is_canceled()
    }

    /// Queues `val` with a retry-capable callback, with no `want` gating.
    ///
    /// Returns the request back as `Err(val)` if the receiver has gone away.
    pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
        let (tx, rx) = oneshot::channel();
        match self.inner.send(Envelope(Some((val, Callback::Retry(Some(tx)))))) {
            Ok(()) => Ok(rx),
            Err(mut rejected) => {
                // Nobody will await this promise; release it first.
                drop(rx);
                Err((rejected.0).0.take().expect("envelope not dropped").0)
            }
        }
    }

    /// Queues `val` with a non-retry callback, with no `want` gating.
    #[cfg(all(feature = "backports", feature = "http2"))]
    pub(crate) fn send(&mut self, val: T) -> Result<Promise<U>, T> {
        let (tx, rx) = oneshot::channel();
        match self.inner.send(Envelope(Some((val, Callback::NoRetry(Some(tx)))))) {
            Ok(()) => Ok(rx),
            Err(mut rejected) => {
                // Nobody will await this promise; release it first.
                drop(rx);
                Err((rejected.0).0.take().expect("envelope not dropped").0)
            }
        }
    }
}
141
#[cfg(feature = "http2")]
impl<T, U> Clone for UnboundedSender<T, U> {
    /// Produces another handle to the same queue and the same shared giver.
    fn clone(&self) -> Self {
        let giver = self.giver.clone();
        let inner = self.inner.clone();
        Self { giver, inner }
    }
}
151
/// Receiving half of `channel()`: yields queued requests with their callbacks.
pub(crate) struct Receiver<T, U> {
    /// Queue of envelopes pushed by the sender(s).
    inner: mpsc::UnboundedReceiver<Envelope<T, U>>,
    /// Signals "want" to the giver when the queue is empty, and cancellation
    /// on close or drop.
    taker: want::Taker,
}
156
157 impl<T, U> Receiver<T, U> {
poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<(T, Callback<T, U>)>>158 pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<(T, Callback<T, U>)>> {
159 match self.inner.poll_recv(cx) {
160 Poll::Ready(item) => {
161 Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped")))
162 }
163 Poll::Pending => {
164 self.taker.want();
165 Poll::Pending
166 }
167 }
168 }
169
170 #[cfg(feature = "http1")]
close(&mut self)171 pub(crate) fn close(&mut self) {
172 self.taker.cancel();
173 self.inner.close();
174 }
175
176 #[cfg(feature = "http1")]
try_recv(&mut self) -> Option<(T, Callback<T, U>)>177 pub(crate) fn try_recv(&mut self) -> Option<(T, Callback<T, U>)> {
178 match self.inner.recv().now_or_never() {
179 Some(Some(mut env)) => env.0.take(),
180 _ => None,
181 }
182 }
183 }
184
185 impl<T, U> Drop for Receiver<T, U> {
drop(&mut self)186 fn drop(&mut self) {
187 // Notify the giver about the closure first, before dropping
188 // the mpsc::Receiver.
189 self.taker.cancel();
190 }
191 }
192
/// A queued request plus its callback. The `Option` is `None` once the contents
/// have been taken out; if it is still `Some` on drop, the request is returned
/// to the caller as a canceled error.
struct Envelope<T, U>(Option<(T, Callback<T, U>)>);
194
195 impl<T, U> Drop for Envelope<T, U> {
drop(&mut self)196 fn drop(&mut self) {
197 if let Some((val, cb)) = self.0.take() {
198 cb.send(Err((
199 crate::Error::new_canceled().with("connection closed"),
200 Some(val),
201 )));
202 }
203 }
204 }
205
/// Where the response for a queued request is delivered.
///
/// The inner `Option` becomes `None` once the sender has been used by `send`.
pub(crate) enum Callback<T, U> {
    /// Errors may carry the original request back so it can be retried.
    Retry(Option<oneshot::Sender<Result<U, (crate::Error, Option<T>)>>>),
    /// Errors never return the request.
    NoRetry(Option<oneshot::Sender<Result<U, crate::Error>>>),
}
210
211 impl<T, U> Drop for Callback<T, U> {
drop(&mut self)212 fn drop(&mut self) {
213 // FIXME(nox): What errors do we want here?
214 let error = crate::Error::new_user_dispatch_gone().with(if std::thread::panicking() {
215 "user code panicked"
216 } else {
217 "runtime dropped the dispatch task"
218 });
219
220 match self {
221 Callback::Retry(tx) => {
222 if let Some(tx) = tx.take() {
223 let _ = tx.send(Err((error, None)));
224 }
225 }
226 Callback::NoRetry(tx) => {
227 if let Some(tx) = tx.take() {
228 let _ = tx.send(Err(error));
229 }
230 }
231 }
232 }
233 }
234
235 impl<T, U> Callback<T, U> {
236 #[cfg(feature = "http2")]
is_canceled(&self) -> bool237 pub(crate) fn is_canceled(&self) -> bool {
238 match *self {
239 Callback::Retry(Some(ref tx)) => tx.is_closed(),
240 Callback::NoRetry(Some(ref tx)) => tx.is_closed(),
241 _ => unreachable!(),
242 }
243 }
244
poll_canceled(&mut self, cx: &mut Context<'_>) -> Poll<()>245 pub(crate) fn poll_canceled(&mut self, cx: &mut Context<'_>) -> Poll<()> {
246 match *self {
247 Callback::Retry(Some(ref mut tx)) => tx.poll_closed(cx),
248 Callback::NoRetry(Some(ref mut tx)) => tx.poll_closed(cx),
249 _ => unreachable!(),
250 }
251 }
252
send(mut self, val: Result<U, (crate::Error, Option<T>)>)253 pub(crate) fn send(mut self, val: Result<U, (crate::Error, Option<T>)>) {
254 match self {
255 Callback::Retry(ref mut tx) => {
256 let _ = tx.take().unwrap().send(val);
257 }
258 Callback::NoRetry(ref mut tx) => {
259 let _ = tx.take().unwrap().send(val.map_err(|e| e.0));
260 }
261 }
262 }
263
264 #[cfg(feature = "http2")]
send_when( self, mut when: impl Future<Output = Result<U, (crate::Error, Option<T>)>> + Unpin, )265 pub(crate) async fn send_when(
266 self,
267 mut when: impl Future<Output = Result<U, (crate::Error, Option<T>)>> + Unpin,
268 ) {
269 use futures_util::future;
270 use tracing::trace;
271
272 let mut cb = Some(self);
273
274 // "select" on this callback being canceled, and the future completing
275 future::poll_fn(move |cx| {
276 match Pin::new(&mut when).poll(cx) {
277 Poll::Ready(Ok(res)) => {
278 cb.take().expect("polled after complete").send(Ok(res));
279 Poll::Ready(())
280 }
281 Poll::Pending => {
282 // check if the callback is canceled
283 ready!(cb.as_mut().unwrap().poll_canceled(cx));
284 trace!("send_when canceled");
285 Poll::Ready(())
286 }
287 Poll::Ready(Err(err)) => {
288 cb.take().expect("polled after complete").send(Err(err));
289 Poll::Ready(())
290 }
291 }
292 })
293 .await
294 }
295 }
296
#[cfg(test)]
mod tests {
    // Benchmarks need the unstable `test` crate, enabled only with `nightly`.
    #[cfg(feature = "nightly")]
    extern crate test;

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use super::{channel, Callback, Receiver};

    /// Minimal request payload for the tests below.
    #[derive(Debug)]
    struct Custom(i32);

    /// Test-only adapter that lets a `Receiver` be `.await`ed directly. Each
    /// poll delegates to `Receiver::poll_recv`.
    impl<T, U> Future for Receiver<T, U> {
        type Output = Option<(T, Callback<T, U>)>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_recv(cx)
        }
    }

    /// Helper to check if the future is ready after polling once.
    ///
    /// Always resolves immediately: `Some(())` if the wrapped future was ready
    /// on that single poll, `None` if it returned `Pending`.
    struct PollOnce<'a, F>(&'a mut F);

    impl<F, T> Future for PollOnce<'_, F>
    where
        F: Future<Output = T> + Unpin,
    {
        type Output = Option<()>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            match Pin::new(&mut self.0).poll(cx) {
                Poll::Ready(_) => Poll::Ready(Some(())),
                Poll::Pending => Poll::Ready(None),
            }
        }
    }

    /// Dropping the receiver while a request is queued must resolve that
    /// request's promise with a `Canceled` error that carries the request back.
    #[tokio::test]
    async fn drop_receiver_sends_cancel_errors() {
        let _ = pretty_env_logger::try_init();

        let (mut tx, mut rx) = channel::<Custom, ()>();

        // must poll once for try_send to succeed
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let promise = tx.try_send(Custom(43)).unwrap();
        drop(rx);

        let fulfilled = promise.await;
        let err = fulfilled
            .expect("fulfilled")
            .expect_err("promise should error");
        match (err.0.kind(), err.1) {
            (&crate::error::Kind::Canceled, Some(_)) => (),
            e => panic!("expected Error::Cancel(_), found {:?}", e),
        }
    }

    /// The bounded sender allows one unsolicited message. After that it only
    /// sends once the receiver has been polled while empty (signalling "want").
    #[tokio::test]
    async fn sender_checks_for_want_on_send() {
        let (mut tx, mut rx) = channel::<Custom, ()>();

        // one is allowed to buffer, second is rejected
        let _ = tx.try_send(Custom(1)).expect("1 buffered");
        tx.try_send(Custom(2)).expect_err("2 not ready");

        assert!(PollOnce(&mut rx).await.is_some(), "rx once");

        // Even though 1 has been popped, only 1 could be buffered for the
        // lifetime of the channel.
        tx.try_send(Custom(2)).expect_err("2 still not ready");

        // Polling the now-empty receiver signals "want", re-enabling sends.
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let _ = tx.try_send(Custom(2)).expect("2 ready");
    }

    /// The unbounded sender ignores "want" entirely and only fails once the
    /// receiver has been dropped.
    #[cfg(feature = "http2")]
    #[test]
    fn unbounded_sender_doesnt_bound_on_want() {
        let (tx, rx) = channel::<Custom, ()>();
        let mut tx = tx.unbound();

        let _ = tx.try_send(Custom(1)).unwrap();
        let _ = tx.try_send(Custom(2)).unwrap();
        let _ = tx.try_send(Custom(3)).unwrap();

        drop(rx);

        let _ = tx.try_send(Custom(4)).unwrap_err();
    }

    /// Each iteration sends one request, then drains the receiver until it
    /// reports `Pending`. That poll also signals "want" for the next send.
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_throughput(b: &mut test::Bencher) {
        use crate::{Body, Request, Response};

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (mut tx, mut rx) = channel::<Request<Body>, Response<Body>>();

        b.iter(move || {
            let _ = tx.send(Request::default()).unwrap();
            rt.block_on(async {
                loop {
                    let poll_once = PollOnce(&mut rx);
                    let opt = poll_once.await;
                    if opt.is_none() {
                        break;
                    }
                }
            });
        })
    }

    /// Measures the cost of polling an always-empty receiver.
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_not_ready(b: &mut test::Bencher) {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (_tx, mut rx) = channel::<i32, ()>();
        b.iter(move || {
            rt.block_on(async {
                let poll_once = PollOnce(&mut rx);
                assert!(poll_once.await.is_none());
            });
        })
    }

    /// Measures the cost of repeatedly canceling the taker.
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_cancel(b: &mut test::Bencher) {
        let (_tx, mut rx) = channel::<i32, ()>();

        b.iter(move || {
            rx.taker.cancel();
        })
    }
}
443