1 use super::future::ResponseFuture; 2 use tokio::sync::{OwnedSemaphorePermit, Semaphore}; 3 use tokio_util::sync::PollSemaphore; 4 use tower_service::Service; 5 6 use futures_core::ready; 7 use std::{ 8 sync::Arc, 9 task::{Context, Poll}, 10 }; 11 12 /// Enforces a limit on the concurrent number of requests the underlying 13 /// service can handle. 14 #[derive(Debug)] 15 pub struct ConcurrencyLimit<T> { 16 inner: T, 17 semaphore: PollSemaphore, 18 /// The currently acquired semaphore permit, if there is sufficient 19 /// concurrency to send a new request. 20 /// 21 /// The permit is acquired in `poll_ready`, and taken in `call` when sending 22 /// a new request. 23 permit: Option<OwnedSemaphorePermit>, 24 } 25 26 impl<T> ConcurrencyLimit<T> { 27 /// Create a new concurrency limiter. new(inner: T, max: usize) -> Self28 pub fn new(inner: T, max: usize) -> Self { 29 Self::with_semaphore(inner, Arc::new(Semaphore::new(max))) 30 } 31 32 /// Create a new concurrency limiter with a provided shared semaphore with_semaphore(inner: T, semaphore: Arc<Semaphore>) -> Self33 pub fn with_semaphore(inner: T, semaphore: Arc<Semaphore>) -> Self { 34 ConcurrencyLimit { 35 inner, 36 semaphore: PollSemaphore::new(semaphore), 37 permit: None, 38 } 39 } 40 41 /// Get a reference to the inner service get_ref(&self) -> &T42 pub fn get_ref(&self) -> &T { 43 &self.inner 44 } 45 46 /// Get a mutable reference to the inner service get_mut(&mut self) -> &mut T47 pub fn get_mut(&mut self) -> &mut T { 48 &mut self.inner 49 } 50 51 /// Consume `self`, returning the inner service into_inner(self) -> T52 pub fn into_inner(self) -> T { 53 self.inner 54 } 55 } 56 57 impl<S, Request> Service<Request> for ConcurrencyLimit<S> 58 where 59 S: Service<Request>, 60 { 61 type Response = S::Response; 62 type Error = S::Error; 63 type Future = ResponseFuture<S::Future>; 64 poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>65 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 66 // If we haven't already acquired a permit from the semaphore, try to 67 // acquire one first. 68 if self.permit.is_none() { 69 self.permit = ready!(self.semaphore.poll_acquire(cx)); 70 debug_assert!( 71 self.permit.is_some(), 72 "ConcurrencyLimit semaphore is never closed, so `poll_acquire` \ 73 should never fail", 74 ); 75 } 76 77 // Once we've acquired a permit (or if we already had one), poll the 78 // inner service. 79 self.inner.poll_ready(cx) 80 } 81 call(&mut self, request: Request) -> Self::Future82 fn call(&mut self, request: Request) -> Self::Future { 83 // Take the permit 84 let permit = self 85 .permit 86 .take() 87 .expect("max requests in-flight; poll_ready must be called first"); 88 89 // Call the inner service 90 let future = self.inner.call(request); 91 92 ResponseFuture::new(future, permit) 93 } 94 } 95 96 impl<T: Clone> Clone for ConcurrencyLimit<T> { clone(&self) -> Self97 fn clone(&self) -> Self { 98 // Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`. 99 // Instead, when cloning the service, create a new service with the 100 // same semaphore, but with the permit in the un-acquired state. 101 Self { 102 inner: self.inner.clone(), 103 semaphore: self.semaphore.clone(), 104 permit: None, 105 } 106 } 107 } 108 109 #[cfg(feature = "load")] 110 #[cfg_attr(docsrs, doc(cfg(feature = "load")))] 111 impl<S> crate::load::Load for ConcurrencyLimit<S> 112 where 113 S: crate::load::Load, 114 { 115 type Metric = S::Metric; load(&self) -> Self::Metric116 fn load(&self) -> Self::Metric { 117 self.inner.load() 118 } 119 } 120