1 use super::{
2     error::{Closed, ServiceError},
3     message::Message,
4 };
5 use futures_core::ready;
6 use std::sync::{Arc, Mutex, Weak};
7 use std::{
8     future::Future,
9     pin::Pin,
10     task::{Context, Poll},
11 };
12 use tokio::sync::{mpsc, Semaphore};
13 use tower_service::Service;
14 
pin_project_lite::pin_project! {
    /// Task that handles processing the buffer. This type should not be used
    /// directly, instead `Buffer` requires an `Executor` that can accept this task.
    ///
    /// The worker is a `Future`: each poll pulls queued messages, waits for the
    /// inner service to become ready, and hands the response future back to the
    /// caller through the message's `tx` channel.
    ///
    /// The struct is `pub` in the private module and the type is *not* re-exported
    /// as part of the public API. This is the "sealed" pattern to include "private"
    /// types in public traits that are not meant for consumers of the library to
    /// implement (only call).
    #[derive(Debug)]
    pub struct Worker<T, Request>
    where
        T: Service<Request>,
    {
        // Message that was pulled off the channel while the service was not yet
        // ready; `poll_next_msg` hands it out again before reading from `rx`.
        current_message: Option<Message<Request, T::Future>>,
        // Receiving end of the request channel.
        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
        // The wrapped service that requests are dispatched to.
        service: T,
        // Set once the channel has yielded `None`; the worker is done after that.
        finish: bool,
        // Error recorded when the service's `poll_ready` failed. While set, every
        // message that is still received gets answered with a clone of it.
        failed: Option<ServiceError>,
        // Shared slot in which `failed` publishes the error.
        handle: Handle,
        // Weak reference to the buffer's semaphore; taken (and the semaphore
        // closed) at most once by `close_semaphore`.
        close: Option<Weak<Semaphore>>,
    }

    impl<T: Service<Request>, Request> PinnedDrop for Worker<T, Request>
    {
        // Dropping the worker closes the semaphore (if that has not happened
        // already) so that pending tasks are woken.
        fn drop(mut this: Pin<&mut Self>) {
            this.as_mut().close_semaphore();
        }
    }
}
44 
/// Shared slot through which the worker publishes the error that made it fail.
///
/// Cloning a `Handle` clones the inner `Arc`, so all clones observe the same
/// slot; `get_error_on_closed` reads the stored error back out.
#[derive(Debug)]
pub(crate) struct Handle {
    // `None` until `Worker::failed` stores the service error.
    inner: Arc<Mutex<Option<ServiceError>>>,
}
50 
51 impl<T, Request> Worker<T, Request>
52 where
53     T: Service<Request>,
54 {
55     /// Closes the buffer's semaphore if it is still open, waking any pending
56     /// tasks.
close_semaphore(&mut self)57     fn close_semaphore(&mut self) {
58         if let Some(close) = self.close.take().as_ref().and_then(Weak::upgrade) {
59             tracing::debug!("buffer closing; waking pending tasks");
60             close.close();
61         } else {
62             tracing::trace!("buffer already closed");
63         }
64     }
65 }
66 
67 impl<T, Request> Worker<T, Request>
68 where
69     T: Service<Request>,
70     T::Error: Into<crate::BoxError>,
71 {
new( service: T, rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>, semaphore: &Arc<Semaphore>, ) -> (Handle, Worker<T, Request>)72     pub(crate) fn new(
73         service: T,
74         rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
75         semaphore: &Arc<Semaphore>,
76     ) -> (Handle, Worker<T, Request>) {
77         let handle = Handle {
78             inner: Arc::new(Mutex::new(None)),
79         };
80 
81         let semaphore = Arc::downgrade(semaphore);
82         let worker = Worker {
83             current_message: None,
84             finish: false,
85             failed: None,
86             rx,
87             service,
88             handle: handle.clone(),
89             close: Some(semaphore),
90         };
91 
92         (handle, worker)
93     }
94 
95     /// Return the next queued Message that hasn't been canceled.
96     ///
97     /// If a `Message` is returned, the `bool` is true if this is the first time we received this
98     /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
poll_next_msg( &mut self, cx: &mut Context<'_>, ) -> Poll<Option<(Message<Request, T::Future>, bool)>>99     fn poll_next_msg(
100         &mut self,
101         cx: &mut Context<'_>,
102     ) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
103         if self.finish {
104             // We've already received None and are shutting down
105             return Poll::Ready(None);
106         }
107 
108         tracing::trace!("worker polling for next message");
109         if let Some(msg) = self.current_message.take() {
110             // If the oneshot sender is closed, then the receiver is dropped,
111             // and nobody cares about the response. If this is the case, we
112             // should continue to the next request.
113             if !msg.tx.is_closed() {
114                 tracing::trace!("resuming buffered request");
115                 return Poll::Ready(Some((msg, false)));
116             }
117 
118             tracing::trace!("dropping cancelled buffered request");
119         }
120 
121         // Get the next request
122         while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
123             if !msg.tx.is_closed() {
124                 tracing::trace!("processing new request");
125                 return Poll::Ready(Some((msg, true)));
126             }
127             // Otherwise, request is canceled, so pop the next one.
128             tracing::trace!("dropping cancelled request");
129         }
130 
131         Poll::Ready(None)
132     }
133 
failed(&mut self, error: crate::BoxError)134     fn failed(&mut self, error: crate::BoxError) {
135         // The underlying service failed when we called `poll_ready` on it with the given `error`. We
136         // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
137         // an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
138         // requests will also fail with the same error.
139 
140         // Note that we need to handle the case where some handle is concurrently trying to send us
141         // a request. We need to make sure that *either* the send of the request fails *or* it
142         // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
143         // case where we send errors to all outstanding requests, and *then* the caller sends its
144         // request. We do this by *first* exposing the error, *then* closing the channel used to
145         // send more requests (so the client will see the error when the send fails), and *then*
146         // sending the error to all outstanding requests.
147         let error = ServiceError::new(error);
148 
149         let mut inner = self.handle.inner.lock().unwrap();
150 
151         if inner.is_some() {
152             // Future::poll was called after we've already errored out!
153             return;
154         }
155 
156         *inner = Some(error.clone());
157         drop(inner);
158 
159         self.rx.close();
160 
161         // By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
162         // which will trigger the `self.finish == true` phase. We just need to make sure that any
163         // requests that we receive before we've exhausted the receiver receive the error:
164         self.failed = Some(error);
165     }
166 }
167 
168 impl<T, Request> Future for Worker<T, Request>
169 where
170     T: Service<Request>,
171     T::Error: Into<crate::BoxError>,
172 {
173     type Output = ();
174 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>175     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
176         if self.finish {
177             return Poll::Ready(());
178         }
179 
180         loop {
181             match ready!(self.poll_next_msg(cx)) {
182                 Some((msg, first)) => {
183                     let _guard = msg.span.enter();
184                     if let Some(ref failed) = self.failed {
185                         tracing::trace!("notifying caller about worker failure");
186                         let _ = msg.tx.send(Err(failed.clone()));
187                         continue;
188                     }
189 
190                     // Wait for the service to be ready
191                     tracing::trace!(
192                         resumed = !first,
193                         message = "worker received request; waiting for service readiness"
194                     );
195                     match self.service.poll_ready(cx) {
196                         Poll::Ready(Ok(())) => {
197                             tracing::debug!(service.ready = true, message = "processing request");
198                             let response = self.service.call(msg.request);
199 
200                             // Send the response future back to the sender.
201                             //
202                             // An error means the request had been canceled in-between
203                             // our calls, the response future will just be dropped.
204                             tracing::trace!("returning response future");
205                             let _ = msg.tx.send(Ok(response));
206                         }
207                         Poll::Pending => {
208                             tracing::trace!(service.ready = false, message = "delay");
209                             // Put out current message back in its slot.
210                             drop(_guard);
211                             self.current_message = Some(msg);
212                             return Poll::Pending;
213                         }
214                         Poll::Ready(Err(e)) => {
215                             let error = e.into();
216                             tracing::debug!({ %error }, "service failed");
217                             drop(_guard);
218                             self.failed(error);
219                             let _ = msg.tx.send(Err(self
220                                 .failed
221                                 .as_ref()
222                                 .expect("Worker::failed did not set self.failed?")
223                                 .clone()));
224                             // Wake any tasks waiting on channel capacity.
225                             self.close_semaphore();
226                         }
227                     }
228                 }
229                 None => {
230                     // No more more requests _ever_.
231                     self.finish = true;
232                     return Poll::Ready(());
233                 }
234             }
235         }
236     }
237 }
238 
239 impl Handle {
get_error_on_closed(&self) -> crate::BoxError240     pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
241         self.inner
242             .lock()
243             .unwrap()
244             .as_ref()
245             .map(|svc_err| svc_err.clone().into())
246             .unwrap_or_else(|| Closed::new().into())
247     }
248 }
249 
250 impl Clone for Handle {
clone(&self) -> Handle251     fn clone(&self) -> Handle {
252         Handle {
253             inner: self.inner.clone(),
254         }
255     }
256 }
257