use super::future::ResponseFuture; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio_util::sync::PollSemaphore; use tower_service::Service; use futures_core::ready; use std::{ sync::Arc, task::{Context, Poll}, }; /// Enforces a limit on the concurrent number of requests the underlying /// service can handle. #[derive(Debug)] pub struct ConcurrencyLimit { inner: T, semaphore: PollSemaphore, /// The currently acquired semaphore permit, if there is sufficient /// concurrency to send a new request. /// /// The permit is acquired in `poll_ready`, and taken in `call` when sending /// a new request. permit: Option, } impl ConcurrencyLimit { /// Create a new concurrency limiter. pub fn new(inner: T, max: usize) -> Self { Self::with_semaphore(inner, Arc::new(Semaphore::new(max))) } /// Create a new concurrency limiter with a provided shared semaphore pub fn with_semaphore(inner: T, semaphore: Arc) -> Self { ConcurrencyLimit { inner, semaphore: PollSemaphore::new(semaphore), permit: None, } } /// Get a reference to the inner service pub fn get_ref(&self) -> &T { &self.inner } /// Get a mutable reference to the inner service pub fn get_mut(&mut self) -> &mut T { &mut self.inner } /// Consume `self`, returning the inner service pub fn into_inner(self) -> T { self.inner } } impl Service for ConcurrencyLimit where S: Service, { type Response = S::Response; type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { // If we haven't already acquired a permit from the semaphore, try to // acquire one first. if self.permit.is_none() { self.permit = ready!(self.semaphore.poll_acquire(cx)); debug_assert!( self.permit.is_some(), "ConcurrencyLimit semaphore is never closed, so `poll_acquire` \ should never fail", ); } // Once we've acquired a permit (or if we already had one), poll the // inner service. self.inner.poll_ready(cx) } fn call(&mut self, request: Request) -> Self::Future { // Take the permit let permit = self .permit .take() .expect("max requests in-flight; poll_ready must be called first"); // Call the inner service let future = self.inner.call(request); ResponseFuture::new(future, permit) } } impl Clone for ConcurrencyLimit { fn clone(&self) -> Self { // Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`. // Instead, when cloning the service, create a new service with the // same semaphore, but with the permit in the un-acquired state. Self { inner: self.inner.clone(), semaphore: self.semaphore.clone(), permit: None, } } } #[cfg(feature = "load")] #[cfg_attr(docsrs, doc(cfg(feature = "load")))] impl crate::load::Load for ConcurrencyLimit where S: crate::load::Load, { type Metric = S::Metric; fn load(&self) -> Self::Metric { self.inner.load() } }