1 //! A `Load` implementation that measures load using the PeakEWMA response latency.
2 
3 #[cfg(feature = "discover")]
4 use crate::discover::{Change, Discover};
5 #[cfg(feature = "discover")]
6 use futures_core::{ready, Stream};
7 #[cfg(feature = "discover")]
8 use pin_project_lite::pin_project;
9 #[cfg(feature = "discover")]
10 use std::pin::Pin;
11 
12 use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13 use super::Load;
14 use std::task::{Context, Poll};
15 use std::{
16     sync::{Arc, Mutex},
17     time::Duration,
18 };
19 use tokio::time::Instant;
20 use tower_service::Service;
21 use tracing::trace;
22 
/// Measures the load of the underlying service using Peak-EWMA load measurement.
///
/// [`PeakEwma`] implements [`Load`] with the [`Cost`] metric that estimates the amount of
/// pending work to an endpoint. Work is calculated by multiplying the
/// exponentially-weighted moving average (EWMA) of response latencies by the number of
/// pending requests. The Peak-EWMA algorithm is designed to be especially sensitive to
/// worst-case latencies. Over time, the peak latency value decays towards the moving
/// average of latencies to the endpoint.
///
/// When no latency information has been measured for an endpoint, the default RTT
/// supplied at construction time is used to prevent the endpoint from being overloaded
/// before a meaningful baseline can be established.
///
/// ## Note
///
/// This is derived from [Finagle][finagle], which is distributed under the Apache V2
/// license. Copyright 2017, Twitter Inc.
///
/// [finagle]:
/// https://github.com/twitter/finagle/blob/9cc08d15216497bb03a1cafda96b7266cfbbcff1/finagle-core/src/main/scala/com/twitter/finagle/loadbalancer/PeakEwma.scala
#[derive(Debug)]
pub struct PeakEwma<S, C = CompleteOnResponse> {
    /// The wrapped service whose response latency is measured.
    service: S,
    /// Decay period of the RTT estimate, in nanoseconds.
    decay_ns: f64,
    /// RTT estimate shared with the [`Handle`] of every in-flight request; the number of
    /// outstanding clones is what `load` uses as the pending-request count.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
    /// Completion tracker that is cloned into each response future.
    completion: C,
}
50 
#[cfg(feature = "discover")]
pin_project! {
    /// Wraps a `D`-typed stream of discovered services with `PeakEwma`.
    ///
    /// Each service inserted by the inner stream is wrapped in a [`PeakEwma`] that is
    /// configured with the default RTT, decay period and completion tracker stored here.
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PeakEwmaDiscover<D, C = CompleteOnResponse> {
        // The inner discovery stream; pinned because it is polled in place.
        #[pin]
        discover: D,
        // Decay period, in nanoseconds, handed to every `PeakEwma` that is created.
        decay_ns: f64,
        // Initial RTT estimate given to newly inserted services.
        default_rtt: Duration,
        // Completion tracker cloned into every `PeakEwma` that is created.
        completion: C,
    }
}
64 
/// Represents the relative cost of communicating with a service.
///
/// The underlying value estimates the amount of pending work to a service: the Peak-EWMA
/// latency estimate multiplied by the number of pending requests.
///
/// As produced by [`PeakEwma`]'s `load`, the wrapped value is the RTT estimate in
/// nanoseconds multiplied by one more than the number of in-flight requests. Costs are
/// ordered by that value (`PartialOrd`), so a lower cost means less estimated pending work.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Cost(f64);
71 
/// Tracks an in-flight request and updates the RTT-estimate on Drop.
///
/// A handle is created when a request is dispatched. When it is dropped, the time elapsed
/// since `sent_at` is folded into the shared RTT estimate.
#[derive(Debug)]
pub struct Handle {
    /// When the tracked request was dispatched.
    sent_at: Instant,
    /// Decay period, in nanoseconds, used when the estimate is updated on drop.
    decay_ns: f64,
    /// The RTT estimate shared with the `PeakEwma` that created this handle.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
}
79 
/// Holds the current RTT estimate and the last time this value was updated.
#[derive(Debug)]
struct RttEstimate {
    /// When `rtt_ns` was last written; decay is computed from the time elapsed since then.
    update_at: Instant,
    /// The current RTT estimate, in nanoseconds.
    rtt_ns: f64,
}
86 
/// Number of nanoseconds in one millisecond; used to express nanosecond values in
/// milliseconds (for example in trace output).
const NANOS_PER_MILLI: f64 = 1_000_000.0;
88 
89 // ===== impl PeakEwma =====
90 
91 impl<S, C> PeakEwma<S, C> {
92     /// Wraps an `S`-typed service so that its load is tracked by the EWMA of its peak latency.
new(service: S, default_rtt: Duration, decay_ns: f64, completion: C) -> Self93     pub fn new(service: S, default_rtt: Duration, decay_ns: f64, completion: C) -> Self {
94         debug_assert!(decay_ns > 0.0, "decay_ns must be positive");
95         Self {
96             service,
97             decay_ns,
98             rtt_estimate: Arc::new(Mutex::new(RttEstimate::new(nanos(default_rtt)))),
99             completion,
100         }
101     }
102 
handle(&self) -> Handle103     fn handle(&self) -> Handle {
104         Handle {
105             decay_ns: self.decay_ns,
106             sent_at: Instant::now(),
107             rtt_estimate: self.rtt_estimate.clone(),
108         }
109     }
110 }
111 
112 impl<S, C, Request> Service<Request> for PeakEwma<S, C>
113 where
114     S: Service<Request>,
115     C: TrackCompletion<Handle, S::Response>,
116 {
117     type Response = C::Output;
118     type Error = S::Error;
119     type Future = TrackCompletionFuture<S::Future, C, Handle>;
120 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>121     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122         self.service.poll_ready(cx)
123     }
124 
call(&mut self, req: Request) -> Self::Future125     fn call(&mut self, req: Request) -> Self::Future {
126         TrackCompletionFuture::new(
127             self.completion.clone(),
128             self.handle(),
129             self.service.call(req),
130         )
131     }
132 }
133 
134 impl<S, C> Load for PeakEwma<S, C> {
135     type Metric = Cost;
136 
load(&self) -> Self::Metric137     fn load(&self) -> Self::Metric {
138         let pending = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
139 
140         // Update the RTT estimate to account for decay since the last update.
141         // If an estimate has not been established, a default is provided
142         let estimate = self.update_estimate();
143 
144         let cost = Cost(estimate * f64::from(pending + 1));
145         trace!(
146             "load estimate={:.0}ms pending={} cost={:?}",
147             estimate / NANOS_PER_MILLI,
148             pending,
149             cost,
150         );
151         cost
152     }
153 }
154 
155 impl<S, C> PeakEwma<S, C> {
update_estimate(&self) -> f64156     fn update_estimate(&self) -> f64 {
157         let mut rtt = self.rtt_estimate.lock().expect("peak ewma prior_estimate");
158         rtt.decay(self.decay_ns)
159     }
160 }
161 
162 // ===== impl PeakEwmaDiscover =====
163 
#[cfg(feature = "discover")]
impl<D, C> PeakEwmaDiscover<D, C> {
    /// Wraps a `D`-typed [`Discover`] so that services have a [`PeakEwma`] load metric.
    ///
    /// The provided `default_rtt` is used as the initial RTT estimate for newly added
    /// services.
    ///
    /// The `decay` value determines over what time period an RTT estimate should decay.
    /// It is converted to nanoseconds here and passed to every [`PeakEwma`] that is
    /// created, where debug builds assert that it is positive.
    pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        let decay_ns = nanos(decay);

        Self {
            discover,
            decay_ns,
            default_rtt,
            completion,
        }
    }
}
187 
#[cfg(feature = "discover")]
#[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
impl<D, C> Stream for PeakEwmaDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PeakEwma<D::Service, C>>, D::Error>;

    /// Polls the inner discovery stream and forwards its changes, wrapping every
    /// inserted service in a [`PeakEwma`]. Removals are passed through unchanged,
    /// discovery errors are returned as-is, and the end of the inner stream ends this one.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        let discovered = match ready!(this.discover.poll_discover(cx)).transpose()? {
            Some(change) => change,
            None => return Poll::Ready(None),
        };

        let wrapped = match discovered {
            Change::Insert(key, svc) => {
                let tracked = PeakEwma::new(
                    svc,
                    *this.default_rtt,
                    *this.decay_ns,
                    this.completion.clone(),
                );
                Change::Insert(key, tracked)
            }
            Change::Remove(key) => Change::Remove(key),
        };

        Poll::Ready(Some(Ok(wrapped)))
    }
}
216 
217 // ===== impl RttEstimate =====
218 
219 impl RttEstimate {
new(rtt_ns: f64) -> Self220     fn new(rtt_ns: f64) -> Self {
221         debug_assert!(0.0 < rtt_ns, "rtt must be positive");
222         Self {
223             rtt_ns,
224             update_at: Instant::now(),
225         }
226     }
227 
228     /// Decays the RTT estimate with a decay period of `decay_ns`.
decay(&mut self, decay_ns: f64) -> f64229     fn decay(&mut self, decay_ns: f64) -> f64 {
230         // Updates with a 0 duration so that the estimate decays towards 0.
231         let now = Instant::now();
232         self.update(now, now, decay_ns)
233     }
234 
235     /// Updates the Peak-EWMA RTT estimate.
236     ///
237     /// The elapsed time from `sent_at` to `recv_at` is added
update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64238     fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
239         debug_assert!(
240             sent_at <= recv_at,
241             "recv_at={:?} after sent_at={:?}",
242             recv_at,
243             sent_at
244         );
245         let rtt = nanos(recv_at.saturating_duration_since(sent_at));
246 
247         let now = Instant::now();
248         debug_assert!(
249             self.update_at <= now,
250             "update_at={:?} in the future",
251             self.update_at
252         );
253 
254         self.rtt_ns = if self.rtt_ns < rtt {
255             // For Peak-EWMA, always use the worst-case (peak) value as the estimate for
256             // subsequent requests.
257             trace!(
258                 "update peak rtt={}ms prior={}ms",
259                 rtt / NANOS_PER_MILLI,
260                 self.rtt_ns / NANOS_PER_MILLI,
261             );
262             rtt
263         } else {
264             // When an RTT is observed that is less than the estimated RTT, we decay the
265             // prior estimate according to how much time has elapsed since the last
266             // update. The inverse of the decay is used to scale the estimate towards the
267             // observed RTT value.
268             let elapsed = nanos(now.saturating_duration_since(self.update_at));
269             let decay = (-elapsed / decay_ns).exp();
270             let recency = 1.0 - decay;
271             let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
272             trace!(
273                 "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
274                 rtt / NANOS_PER_MILLI,
275                 self.rtt_ns - next_estimate,
276                 next_estimate / NANOS_PER_MILLI,
277             );
278             next_estimate
279         };
280         self.update_at = now;
281 
282         self.rtt_ns
283     }
284 }
285 
286 // ===== impl Handle =====
287 
288 impl Drop for Handle {
drop(&mut self)289     fn drop(&mut self) {
290         let recv_at = Instant::now();
291 
292         if let Ok(mut rtt) = self.rtt_estimate.lock() {
293             rtt.update(self.sent_at, recv_at, self.decay_ns);
294         }
295     }
296 }
297 
// ===== utilities =====
299 
// Converts a duration into a nanosecond count expressed as an `f64`.
//
// Whole seconds are scaled to nanoseconds with a saturating `u64` multiplication, so the
// seconds component tops out at `u64::MAX` nanoseconds (~585 years), which should be far
// more than any request latency. The sub-second nanoseconds are then added in floating
// point, where very large values lose precision.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let whole_secs_ns = d.as_secs().saturating_mul(NANOS_PER_SEC);
    whole_secs_ns as f64 + f64::from(d.subsec_nanos())
}
310 
#[cfg(test)]
mod tests {
    use futures_util::future;
    use std::time::Duration;
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    /// A service that is always ready and answers every request immediately with `Ok(())`.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// The default RTT estimate decays, so that new nodes are considered if the
    /// default RTT is too high.
    #[tokio::test]
    async fn default_decay() {
        // Pause the clock so that time only moves through `time::advance`.
        time::pause();

        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        // No time has elapsed and nothing is pending: the load is exactly the default RTT.
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    /// The load grows with the number of in-flight requests, reflects the latency
    /// observed when responses complete, and decays again as time elapses afterwards.
    #[tokio::test]
    async fn compound_decay() {
        // Pause the clock so that time only moves through `time::advance`.
        time::pause();

        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        // One request in flight: the estimate is multiplied by two.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        // Two requests in flight: the estimate is multiplied by three.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        // The first response completes 200ms after it was sent, which becomes the new
        // peak estimate; one request is still pending, so the cost is 200ms * 2.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        // The second response also takes 200ms; nothing is pending any more, so the
        // cost equals the 200ms estimate.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // Check that values decay as time elapses
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    /// `nanos` converts exactly for small durations and saturates the seconds component
    /// for the largest representable duration.
    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        assert_eq!(
            super::nanos(Duration::new(::std::u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}
408