1 use futures_util::ready;
2 use pin_project_lite::pin_project;
3 use std::time::Duration;
4 use std::{
5     future::Future,
6     pin::Pin,
7     task::{Context, Poll},
8 };
9 use tower_service::Service;
10 
11 use crate::util::Oneshot;
12 
/// A policy which specifies how long each request should be delayed for.
///
/// Implementations inspect a borrowed request and return the [`Duration`]
/// that should elapse before that request is dispatched.
pub trait Policy<Request> {
    /// Returns how long `req` should be held back before it is dispatched.
    fn delay(&self, req: &Request) -> Duration;
}
17 
/// A middleware which delays sending the request to the underlying service
/// for an amount of time specified by the policy.
///
/// `Delay` is `Clone` whenever both the policy and the wrapped service are.
#[derive(Clone, Debug)]
pub struct Delay<P, S> {
    /// Consulted once per request to decide how long to wait.
    policy: P,
    /// The wrapped service; a clone of it is handed to each response future.
    service: S,
}
25 
26 pin_project! {
27     #[derive(Debug)]
28     pub struct ResponseFuture<Request, S>
29     where
30         S: Service<Request>,
31     {
32         service: Option<S>,
33         #[pin]
34         state: State<Request, Oneshot<S, Request>>,
35     }
36 }
37 
38 pin_project! {
39     #[project = StateProj]
40     #[derive(Debug)]
41     enum State<Request, F> {
42         Delaying {
43             #[pin]
44             delay: tokio::time::Sleep,
45             req: Option<Request>,
46         },
47         Called {
48             #[pin]
49             fut: F,
50         },
51     }
52 }
53 
54 impl<Request, F> State<Request, F> {
delaying(delay: tokio::time::Sleep, req: Option<Request>) -> Self55     fn delaying(delay: tokio::time::Sleep, req: Option<Request>) -> Self {
56         Self::Delaying { delay, req }
57     }
58 
called(fut: F) -> Self59     fn called(fut: F) -> Self {
60         Self::Called { fut }
61     }
62 }
63 
64 impl<P, S> Delay<P, S> {
new<Request>(policy: P, service: S) -> Self where P: Policy<Request>, S: Service<Request> + Clone, S::Error: Into<crate::BoxError>,65     pub fn new<Request>(policy: P, service: S) -> Self
66     where
67         P: Policy<Request>,
68         S: Service<Request> + Clone,
69         S::Error: Into<crate::BoxError>,
70     {
71         Delay { policy, service }
72     }
73 }
74 
75 impl<Request, P, S> Service<Request> for Delay<P, S>
76 where
77     P: Policy<Request>,
78     S: Service<Request> + Clone,
79     S::Error: Into<crate::BoxError>,
80 {
81     type Response = S::Response;
82     type Error = crate::BoxError;
83     type Future = ResponseFuture<Request, S>;
84 
poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>85     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86         // Calling self.service.poll_ready would reserve a slot for the delayed request,
87         // potentially well in advance of actually making it.  Instead, signal readiness here and
88         // treat the service as a Oneshot in the future.
89         Poll::Ready(Ok(()))
90     }
91 
call(&mut self, request: Request) -> Self::Future92     fn call(&mut self, request: Request) -> Self::Future {
93         let delay = self.policy.delay(&request);
94         ResponseFuture {
95             service: Some(self.service.clone()),
96             state: State::delaying(tokio::time::sleep(delay), Some(request)),
97         }
98     }
99 }
100 
101 impl<Request, S, T, E> Future for ResponseFuture<Request, S>
102 where
103     E: Into<crate::BoxError>,
104     S: Service<Request, Response = T, Error = E>,
105 {
106     type Output = Result<T, crate::BoxError>;
107 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>108     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
109         let mut this = self.project();
110 
111         loop {
112             match this.state.as_mut().project() {
113                 StateProj::Delaying { delay, req } => {
114                     ready!(delay.poll(cx));
115                     let req = req.take().expect("Missing request in delay");
116                     let svc = this.service.take().expect("Missing service in delay");
117                     let fut = Oneshot::new(svc, req);
118                     this.state.set(State::called(fut));
119                 }
120                 StateProj::Called { fut } => {
121                     return fut.poll(cx).map_err(Into::into);
122                 }
123             };
124         }
125     }
126 }
127