//! A `Load` implementation that measures load using the Peak-EWMA (peak exponentially-weighted
//! moving average) of response latencies.
2
3 #[cfg(feature = "discover")]
4 use crate::discover::{Change, Discover};
5 #[cfg(feature = "discover")]
6 use futures_core::{ready, Stream};
7 #[cfg(feature = "discover")]
8 use pin_project_lite::pin_project;
9 #[cfg(feature = "discover")]
10 use std::pin::Pin;
11
12 use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13 use super::Load;
14 use std::task::{Context, Poll};
15 use std::{
16 sync::{Arc, Mutex},
17 time::Duration,
18 };
19 use tokio::time::Instant;
20 use tower_service::Service;
21 use tracing::trace;
22
23 /// Measures the load of the underlying service using Peak-EWMA load measurement.
24 ///
25 /// [`PeakEwma`] implements [`Load`] with the [`Cost`] metric that estimates the amount of
26 /// pending work to an endpoint. Work is calculated by multiplying the
27 /// exponentially-weighted moving average (EWMA) of response latencies by the number of
28 /// pending requests. The Peak-EWMA algorithm is designed to be especially sensitive to
29 /// worst-case latencies. Over time, the peak latency value decays towards the moving
30 /// average of latencies to the endpoint.
31 ///
32 /// When no latency information has been measured for an endpoint, an arbitrary default
33 /// RTT of 1 second is used to prevent the endpoint from being overloaded before a
34 /// meaningful baseline can be established..
35 ///
36 /// ## Note
37 ///
38 /// This is derived from [Finagle][finagle], which is distributed under the Apache V2
39 /// license. Copyright 2017, Twitter Inc.
40 ///
41 /// [finagle]:
42 /// https://github.com/twitter/finagle/blob/9cc08d15216497bb03a1cafda96b7266cfbbcff1/finagle-core/src/main/scala/com/twitter/finagle/loadbalancer/PeakEwma.scala
43 #[derive(Debug)]
44 pub struct PeakEwma<S, C = CompleteOnResponse> {
45 service: S,
46 decay_ns: f64,
47 rtt_estimate: Arc<Mutex<RttEstimate>>,
48 completion: C,
49 }
50
#[cfg(feature = "discover")]
pin_project! {
    /// Wraps a `D`-typed stream of discovered services with `PeakEwma`.
    ///
    /// Every service inserted by the inner stream is yielded wrapped in a
    /// [`PeakEwma`] configured from this struct's settings; removals are passed
    /// through unchanged.
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PeakEwmaDiscover<D, C = CompleteOnResponse> {
        // The inner discovery stream; structurally pinned so it can be polled
        // through the pin projection.
        #[pin]
        discover: D,
        // Decay time constant, in nanoseconds, handed to each wrapped service.
        decay_ns: f64,
        // Initial RTT estimate for each newly inserted service.
        default_rtt: Duration,
        // Completion policy cloned into each wrapped service.
        completion: C,
    }
}
64
/// Represents the relative cost of communicating with a service.
///
/// The underlying value estimates the amount of pending work to a service: the Peak-EWMA
/// latency estimate (in nanoseconds) multiplied by one more than the number of pending
/// requests, so an idle service still reports a cost proportional to its latency rather
/// than zero.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Cost(f64);
71
72 /// Tracks an in-flight request and updates the RTT-estimate on Drop.
73 #[derive(Debug)]
74 pub struct Handle {
75 sent_at: Instant,
76 decay_ns: f64,
77 rtt_estimate: Arc<Mutex<RttEstimate>>,
78 }
79
/// The round-trip-time estimate shared by a service and its in-flight requests.
///
/// Pairs the current Peak-EWMA value with the instant it was last written; decay is
/// computed from the time elapsed since that instant.
#[derive(Debug)]
struct RttEstimate {
    /// When `rtt_ns` was last updated.
    update_at: Instant,
    /// Current RTT estimate, in nanoseconds.
    rtt_ns: f64,
}
86
/// Nanoseconds in one millisecond; used to render nanosecond estimates as milliseconds.
const NANOS_PER_MILLI: f64 = 1e6;
88
89 // ===== impl PeakEwma =====
90
91 impl<S, C> PeakEwma<S, C> {
92 /// Wraps an `S`-typed service so that its load is tracked by the EWMA of its peak latency.
new(service: S, default_rtt: Duration, decay_ns: f64, completion: C) -> Self93 pub fn new(service: S, default_rtt: Duration, decay_ns: f64, completion: C) -> Self {
94 debug_assert!(decay_ns > 0.0, "decay_ns must be positive");
95 Self {
96 service,
97 decay_ns,
98 rtt_estimate: Arc::new(Mutex::new(RttEstimate::new(nanos(default_rtt)))),
99 completion,
100 }
101 }
102
handle(&self) -> Handle103 fn handle(&self) -> Handle {
104 Handle {
105 decay_ns: self.decay_ns,
106 sent_at: Instant::now(),
107 rtt_estimate: self.rtt_estimate.clone(),
108 }
109 }
110 }
111
112 impl<S, C, Request> Service<Request> for PeakEwma<S, C>
113 where
114 S: Service<Request>,
115 C: TrackCompletion<Handle, S::Response>,
116 {
117 type Response = C::Output;
118 type Error = S::Error;
119 type Future = TrackCompletionFuture<S::Future, C, Handle>;
120
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>121 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122 self.service.poll_ready(cx)
123 }
124
call(&mut self, req: Request) -> Self::Future125 fn call(&mut self, req: Request) -> Self::Future {
126 TrackCompletionFuture::new(
127 self.completion.clone(),
128 self.handle(),
129 self.service.call(req),
130 )
131 }
132 }
133
134 impl<S, C> Load for PeakEwma<S, C> {
135 type Metric = Cost;
136
load(&self) -> Self::Metric137 fn load(&self) -> Self::Metric {
138 let pending = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
139
140 // Update the RTT estimate to account for decay since the last update.
141 // If an estimate has not been established, a default is provided
142 let estimate = self.update_estimate();
143
144 let cost = Cost(estimate * f64::from(pending + 1));
145 trace!(
146 "load estimate={:.0}ms pending={} cost={:?}",
147 estimate / NANOS_PER_MILLI,
148 pending,
149 cost,
150 );
151 cost
152 }
153 }
154
155 impl<S, C> PeakEwma<S, C> {
update_estimate(&self) -> f64156 fn update_estimate(&self) -> f64 {
157 let mut rtt = self.rtt_estimate.lock().expect("peak ewma prior_estimate");
158 rtt.decay(self.decay_ns)
159 }
160 }
161
162 // ===== impl PeakEwmaDiscover =====
163
#[cfg(feature = "discover")]
impl<D, C> PeakEwmaDiscover<D, C> {
    /// Wraps a `D`-typed [`Discover`] so that services have a [`PeakEwma`] load metric.
    ///
    /// The provided `default_rtt` is used as the default RTT estimate for newly
    /// added services.
    ///
    /// The `decay` value determines over what time period an RTT estimate should
    /// decay; it is converted to nanoseconds and used as the time constant of the
    /// exponential decay.
    ///
    /// The `completion` policy is cloned into every [`PeakEwma`] this stream yields.
    pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        PeakEwmaDiscover {
            discover,
            decay_ns: nanos(decay),
            default_rtt,
            completion,
        }
    }
}
187
#[cfg(feature = "discover")]
#[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
impl<D, C> Stream for PeakEwmaDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PeakEwma<D::Service, C>>, D::Error>;

    /// Polls the inner [`Discover`], wrapping each inserted service in a [`PeakEwma`].
    ///
    /// End-of-stream and errors are forwarded as-is; removals pass through unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let discovered = ready!(this.discover.poll_discover(cx));
        let change = match discovered {
            None => return Poll::Ready(None),
            Some(Err(error)) => return Poll::Ready(Some(Err(error))),
            Some(Ok(Change::Remove(key))) => Change::Remove(key),
            Some(Ok(Change::Insert(key, service))) => Change::Insert(
                key,
                PeakEwma::new(
                    service,
                    *this.default_rtt,
                    *this.decay_ns,
                    this.completion.clone(),
                ),
            ),
        };
        Poll::Ready(Some(Ok(change)))
    }
}
216
217 // ===== impl RttEstimate =====
218
219 impl RttEstimate {
new(rtt_ns: f64) -> Self220 fn new(rtt_ns: f64) -> Self {
221 debug_assert!(0.0 < rtt_ns, "rtt must be positive");
222 Self {
223 rtt_ns,
224 update_at: Instant::now(),
225 }
226 }
227
228 /// Decays the RTT estimate with a decay period of `decay_ns`.
decay(&mut self, decay_ns: f64) -> f64229 fn decay(&mut self, decay_ns: f64) -> f64 {
230 // Updates with a 0 duration so that the estimate decays towards 0.
231 let now = Instant::now();
232 self.update(now, now, decay_ns)
233 }
234
235 /// Updates the Peak-EWMA RTT estimate.
236 ///
237 /// The elapsed time from `sent_at` to `recv_at` is added
update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64238 fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
239 debug_assert!(
240 sent_at <= recv_at,
241 "recv_at={:?} after sent_at={:?}",
242 recv_at,
243 sent_at
244 );
245 let rtt = nanos(recv_at.saturating_duration_since(sent_at));
246
247 let now = Instant::now();
248 debug_assert!(
249 self.update_at <= now,
250 "update_at={:?} in the future",
251 self.update_at
252 );
253
254 self.rtt_ns = if self.rtt_ns < rtt {
255 // For Peak-EWMA, always use the worst-case (peak) value as the estimate for
256 // subsequent requests.
257 trace!(
258 "update peak rtt={}ms prior={}ms",
259 rtt / NANOS_PER_MILLI,
260 self.rtt_ns / NANOS_PER_MILLI,
261 );
262 rtt
263 } else {
264 // When an RTT is observed that is less than the estimated RTT, we decay the
265 // prior estimate according to how much time has elapsed since the last
266 // update. The inverse of the decay is used to scale the estimate towards the
267 // observed RTT value.
268 let elapsed = nanos(now.saturating_duration_since(self.update_at));
269 let decay = (-elapsed / decay_ns).exp();
270 let recency = 1.0 - decay;
271 let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
272 trace!(
273 "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
274 rtt / NANOS_PER_MILLI,
275 self.rtt_ns - next_estimate,
276 next_estimate / NANOS_PER_MILLI,
277 );
278 next_estimate
279 };
280 self.update_at = now;
281
282 self.rtt_ns
283 }
284 }
285
286 // ===== impl Handle =====
287
288 impl Drop for Handle {
drop(&mut self)289 fn drop(&mut self) {
290 let recv_at = Instant::now();
291
292 if let Ok(mut rtt) = self.rtt_estimate.lock() {
293 rtt.update(self.sent_at, recv_at, self.decay_ns);
294 }
295 }
296 }
297
// ===== utilities =====
299
/// Converts a [`Duration`] into a floating-point number of nanoseconds.
///
/// The whole-seconds part is scaled with a saturating multiplication, so it is clamped at
/// `u64::MAX` nanoseconds (~585 years) before the fractional nanoseconds are added. That
/// limit is far beyond any realistic request latency.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let whole_secs_ns = d.as_secs().saturating_mul(NANOS_PER_SEC) as f64;
    whole_secs_ns + f64::from(d.subsec_nanos())
}
310
#[cfg(test)]
mod tests {
    use futures_util::future;
    use std::time::Duration;
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    /// A service that is always ready and immediately responds with `Ok(())`.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// The default RTT estimate decays, so that new nodes are considered if the
    /// default RTT is too high.
    #[tokio::test]
    async fn default_decay() {
        // Pause the clock so `time::advance` controls elapsed time exactly.
        time::pause();

        // 10ms default RTT with a 1s decay time constant.
        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        // 10ms * exp(-0.1) ≈ 9.05ms
        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        // 10ms * exp(-0.2) ≈ 8.19ms
        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    // Pending requests multiply the cost, a slow response raises the estimate to its
    // observed (peak) RTT, and the estimate then decays while no responses arrive.
    #[tokio::test]
    async fn compound_decay() {
        time::pause();

        // 20ms default RTT with a 1s decay time constant.
        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        // One pending request doubles the (slightly decayed) estimate.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        // Two pending requests triple the (further decayed) estimate.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        // rsp0 completes 200ms after it was sent, above the estimate, so the estimate
        // jumps to 200ms; with rsp1 still pending the cost is 200ms * 2.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        // rsp1 also took 200ms, matching the estimate; nothing is pending, so the cost
        // is 200ms * 1.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // Check that values decay as time elapses
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        // The whole-seconds part saturates at u64::MAX nanoseconds before the
        // sub-second nanoseconds are added.
        assert_eq!(
            super::nanos(Duration::new(::std::u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}
408