1 use super::future::ResponseFuture;
2 use tokio::sync::{OwnedSemaphorePermit, Semaphore};
3 use tokio_util::sync::PollSemaphore;
4 use tower_service::Service;
5 
6 use futures_core::ready;
7 use std::{
8     sync::Arc,
9     task::{Context, Poll},
10 };
11 
12 /// Enforces a limit on the concurrent number of requests the underlying
13 /// service can handle.
14 #[derive(Debug)]
15 pub struct ConcurrencyLimit<T> {
16     inner: T,
17     semaphore: PollSemaphore,
18     /// The currently acquired semaphore permit, if there is sufficient
19     /// concurrency to send a new request.
20     ///
21     /// The permit is acquired in `poll_ready`, and taken in `call` when sending
22     /// a new request.
23     permit: Option<OwnedSemaphorePermit>,
24 }
25 
26 impl<T> ConcurrencyLimit<T> {
27     /// Create a new concurrency limiter.
new(inner: T, max: usize) -> Self28     pub fn new(inner: T, max: usize) -> Self {
29         Self::with_semaphore(inner, Arc::new(Semaphore::new(max)))
30     }
31 
32     /// Create a new concurrency limiter with a provided shared semaphore
with_semaphore(inner: T, semaphore: Arc<Semaphore>) -> Self33     pub fn with_semaphore(inner: T, semaphore: Arc<Semaphore>) -> Self {
34         ConcurrencyLimit {
35             inner,
36             semaphore: PollSemaphore::new(semaphore),
37             permit: None,
38         }
39     }
40 
41     /// Get a reference to the inner service
get_ref(&self) -> &T42     pub fn get_ref(&self) -> &T {
43         &self.inner
44     }
45 
46     /// Get a mutable reference to the inner service
get_mut(&mut self) -> &mut T47     pub fn get_mut(&mut self) -> &mut T {
48         &mut self.inner
49     }
50 
51     /// Consume `self`, returning the inner service
into_inner(self) -> T52     pub fn into_inner(self) -> T {
53         self.inner
54     }
55 }
56 
57 impl<S, Request> Service<Request> for ConcurrencyLimit<S>
58 where
59     S: Service<Request>,
60 {
61     type Response = S::Response;
62     type Error = S::Error;
63     type Future = ResponseFuture<S::Future>;
64 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>65     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
66         // If we haven't already acquired a permit from the semaphore, try to
67         // acquire one first.
68         if self.permit.is_none() {
69             self.permit = ready!(self.semaphore.poll_acquire(cx));
70             debug_assert!(
71                 self.permit.is_some(),
72                 "ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
73                  should never fail",
74             );
75         }
76 
77         // Once we've acquired a permit (or if we already had one), poll the
78         // inner service.
79         self.inner.poll_ready(cx)
80     }
81 
call(&mut self, request: Request) -> Self::Future82     fn call(&mut self, request: Request) -> Self::Future {
83         // Take the permit
84         let permit = self
85             .permit
86             .take()
87             .expect("max requests in-flight; poll_ready must be called first");
88 
89         // Call the inner service
90         let future = self.inner.call(request);
91 
92         ResponseFuture::new(future, permit)
93     }
94 }
95 
96 impl<T: Clone> Clone for ConcurrencyLimit<T> {
clone(&self) -> Self97     fn clone(&self) -> Self {
98         // Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`.
99         // Instead, when cloning the service, create a new service with the
100         // same semaphore, but with the permit in the un-acquired state.
101         Self {
102             inner: self.inner.clone(),
103             semaphore: self.semaphore.clone(),
104             permit: None,
105         }
106     }
107 }
108 
#[cfg(feature = "load")]
#[cfg_attr(docsrs, doc(cfg(feature = "load")))]
impl<S> crate::load::Load for ConcurrencyLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// Reports the load metric of the wrapped service unchanged.
    fn load(&self) -> Self::Metric {
        <S as crate::load::Load>::load(&self.inner)
    }
}
120