xref: /aosp_15_r20/external/grpc-grpc-java/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java (revision e07d83d3ffcef9ecfc9f7f475418ec639ff0e5fe)
1 /*
2  * Copyright 2023 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.xds;
18 
19 import static com.google.common.base.Preconditions.checkArgument;
20 import static com.google.common.base.Preconditions.checkNotNull;
21 
22 import com.google.common.annotations.VisibleForTesting;
23 import com.google.common.base.MoreObjects;
24 import com.google.common.base.Preconditions;
25 import io.grpc.ConnectivityState;
26 import io.grpc.ConnectivityStateInfo;
27 import io.grpc.Deadline.Ticker;
28 import io.grpc.EquivalentAddressGroup;
29 import io.grpc.ExperimentalApi;
30 import io.grpc.LoadBalancer;
31 import io.grpc.NameResolver;
32 import io.grpc.Status;
33 import io.grpc.SynchronizationContext;
34 import io.grpc.SynchronizationContext.ScheduledHandle;
35 import io.grpc.services.MetricReport;
36 import io.grpc.util.ForwardingLoadBalancerHelper;
37 import io.grpc.util.ForwardingSubchannel;
38 import io.grpc.util.RoundRobinLoadBalancer;
39 import io.grpc.xds.orca.OrcaOobUtil;
40 import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
41 import io.grpc.xds.orca.OrcaPerRequestUtil;
42 import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
43 import java.util.HashMap;
44 import java.util.HashSet;
45 import java.util.List;
46 import java.util.Map;
47 import java.util.PriorityQueue;
48 import java.util.Random;
49 import java.util.concurrent.ScheduledExecutorService;
50 import java.util.concurrent.TimeUnit;
51 import java.util.logging.Level;
52 import java.util.logging.Logger;
53 
54 /**
55  * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over
56  * the {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are
57  * determined by backend metrics using ORCA.
58  */
59 @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885")
60 final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
61   private static final Logger log = Logger.getLogger(
62       WeightedRoundRobinLoadBalancer.class.getName());
63   private WeightedRoundRobinLoadBalancerConfig config;
64   private final SynchronizationContext syncContext;
65   private final ScheduledExecutorService timeService;
66   private ScheduledHandle weightUpdateTimer;
67   private final Runnable updateWeightTask;
68   private final Random random;
69   private final long infTime;
70   private final Ticker ticker;
71 
WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker)72   public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) {
73     this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, new Random());
74   }
75 
WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker, Random random)76   public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker, Random random) {
77     super(helper);
78     helper.setLoadBalancer(this);
79     this.ticker = checkNotNull(ticker, "ticker");
80     this.infTime = ticker.nanoTime() + Long.MAX_VALUE;
81     this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
82     this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
83     this.updateWeightTask = new UpdateWeightTask();
84     this.random = random;
85     log.log(Level.FINE, "weighted_round_robin LB created");
86   }
87 
88   @VisibleForTesting
WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker, Random random)89   WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker, Random random) {
90     this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, random);
91   }
92 
93   @Override
acceptResolvedAddresses(ResolvedAddresses resolvedAddresses)94   public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
95     if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) {
96       handleNameResolutionError(Status.UNAVAILABLE.withDescription(
97               "NameResolver returned no WeightedRoundRobinLoadBalancerConfig. addrs="
98                       + resolvedAddresses.getAddresses()
99                       + ", attrs=" + resolvedAddresses.getAttributes()));
100       return false;
101     }
102     config =
103             (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
104     boolean accepted = super.acceptResolvedAddresses(resolvedAddresses);
105     if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
106       weightUpdateTimer.cancel();
107     }
108     updateWeightTask.run();
109     afterAcceptAddresses();
110     return accepted;
111   }
112 
113   @Override
createReadyPicker(List<Subchannel> activeList)114   public RoundRobinPicker createReadyPicker(List<Subchannel> activeList) {
115     return new WeightedRoundRobinPicker(activeList, config.enableOobLoadReport,
116         config.errorUtilizationPenalty);
117   }
118 
119   private final class UpdateWeightTask implements Runnable {
120     @Override
run()121     public void run() {
122       if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
123         ((WeightedRoundRobinPicker)currentPicker).updateWeight();
124       }
125       weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
126           TimeUnit.NANOSECONDS, timeService);
127     }
128   }
129 
afterAcceptAddresses()130   private void afterAcceptAddresses() {
131     for (Subchannel subchannel : getSubchannels()) {
132       WrrSubchannel weightedSubchannel = (WrrSubchannel) subchannel;
133       if (config.enableOobLoadReport) {
134         OrcaOobUtil.setListener(weightedSubchannel,
135             weightedSubchannel.new OrcaReportListener(config.errorUtilizationPenalty),
136                 OrcaOobUtil.OrcaReportingConfig.newBuilder()
137                         .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS)
138                         .build());
139       } else {
140         OrcaOobUtil.setListener(weightedSubchannel, null, null);
141       }
142     }
143   }
144 
145   @Override
shutdown()146   public void shutdown() {
147     if (weightUpdateTimer != null) {
148       weightUpdateTimer.cancel();
149     }
150     super.shutdown();
151   }
152 
153   private static final class WrrHelper extends ForwardingLoadBalancerHelper {
154     private final Helper delegate;
155     private WeightedRoundRobinLoadBalancer wrr;
156 
WrrHelper(Helper helper)157     WrrHelper(Helper helper) {
158       this.delegate = helper;
159     }
160 
setLoadBalancer(WeightedRoundRobinLoadBalancer lb)161     void setLoadBalancer(WeightedRoundRobinLoadBalancer lb) {
162       this.wrr = lb;
163     }
164 
165     @Override
delegate()166     protected Helper delegate() {
167       return delegate;
168     }
169 
170     @Override
createSubchannel(CreateSubchannelArgs args)171     public Subchannel createSubchannel(CreateSubchannelArgs args) {
172       return wrr.new WrrSubchannel(delegate().createSubchannel(args));
173     }
174   }
175 
176   @VisibleForTesting
177   final class WrrSubchannel extends ForwardingSubchannel {
178     private final Subchannel delegate;
179     private volatile long lastUpdated;
180     private volatile long nonEmptySince;
181     private volatile double weight;
182 
WrrSubchannel(Subchannel delegate)183     WrrSubchannel(Subchannel delegate) {
184       this.delegate = checkNotNull(delegate, "delegate");
185     }
186 
187     @Override
start(SubchannelStateListener listener)188     public void start(SubchannelStateListener listener) {
189       delegate().start(new SubchannelStateListener() {
190         @Override
191         public void onSubchannelState(ConnectivityStateInfo newState) {
192           if (newState.getState().equals(ConnectivityState.READY)) {
193             nonEmptySince = infTime;
194           }
195           listener.onSubchannelState(newState);
196         }
197       });
198     }
199 
getWeight()200     private double getWeight() {
201       if (config == null) {
202         return 0;
203       }
204       long now = ticker.nanoTime();
205       if (now - lastUpdated >= config.weightExpirationPeriodNanos) {
206         nonEmptySince = infTime;
207         return 0;
208       } else if (now - nonEmptySince < config.blackoutPeriodNanos
209           && config.blackoutPeriodNanos > 0) {
210         return 0;
211       } else {
212         return weight;
213       }
214     }
215 
216     @Override
delegate()217     protected Subchannel delegate() {
218       return delegate;
219     }
220 
221     final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener {
222       private final float errorUtilizationPenalty;
223 
OrcaReportListener(float errorUtilizationPenalty)224       OrcaReportListener(float errorUtilizationPenalty) {
225         this.errorUtilizationPenalty = errorUtilizationPenalty;
226       }
227 
228       @Override
onLoadReport(MetricReport report)229       public void onLoadReport(MetricReport report) {
230         double newWeight = 0;
231         // Prefer application utilization and fallback to CPU utilization if unset.
232         double utilization =
233             report.getApplicationUtilization() > 0 ? report.getApplicationUtilization()
234                 : report.getCpuUtilization();
235         if (utilization > 0 && report.getQps() > 0) {
236           double penalty = 0;
237           if (report.getEps() > 0 && errorUtilizationPenalty > 0) {
238             penalty = report.getEps() / report.getQps() * errorUtilizationPenalty;
239           }
240           newWeight = report.getQps() / (utilization + penalty);
241         }
242         if (newWeight == 0) {
243           return;
244         }
245         if (nonEmptySince == infTime) {
246           nonEmptySince = ticker.nanoTime();
247         }
248         lastUpdated = ticker.nanoTime();
249         weight = newWeight;
250       }
251     }
252   }
253 
254   @VisibleForTesting
255   final class WeightedRoundRobinPicker extends RoundRobinPicker {
256     private final List<Subchannel> list;
257     private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap =
258         new HashMap<>();
259     private final boolean enableOobLoadReport;
260     private final float errorUtilizationPenalty;
261     private volatile EdfScheduler scheduler;
262 
WeightedRoundRobinPicker(List<Subchannel> list, boolean enableOobLoadReport, float errorUtilizationPenalty)263     WeightedRoundRobinPicker(List<Subchannel> list, boolean enableOobLoadReport,
264         float errorUtilizationPenalty) {
265       checkNotNull(list, "list");
266       Preconditions.checkArgument(!list.isEmpty(), "empty list");
267       this.list = list;
268       for (Subchannel subchannel : list) {
269         this.subchannelToReportListenerMap.put(subchannel,
270             ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty));
271       }
272       this.enableOobLoadReport = enableOobLoadReport;
273       this.errorUtilizationPenalty = errorUtilizationPenalty;
274       updateWeight();
275     }
276 
277     @Override
pickSubchannel(PickSubchannelArgs args)278     public PickResult pickSubchannel(PickSubchannelArgs args) {
279       Subchannel subchannel = list.get(scheduler.pick());
280       if (!enableOobLoadReport) {
281         return PickResult.withSubchannel(subchannel,
282             OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
283                 subchannelToReportListenerMap.getOrDefault(subchannel,
284                     ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty))));
285       } else {
286         return PickResult.withSubchannel(subchannel);
287       }
288     }
289 
updateWeight()290     private void updateWeight() {
291       int weightedChannelCount = 0;
292       double avgWeight = 0;
293       for (Subchannel value : list) {
294         double newWeight = ((WrrSubchannel) value).getWeight();
295         if (newWeight > 0) {
296           avgWeight += newWeight;
297           weightedChannelCount++;
298         }
299       }
300       EdfScheduler scheduler = new EdfScheduler(list.size(), random);
301       if (weightedChannelCount >= 1) {
302         avgWeight /= 1.0 * weightedChannelCount;
303       } else {
304         avgWeight = 1;
305       }
306       for (int i = 0; i < list.size(); i++) {
307         WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
308         double newWeight = subchannel.getWeight();
309         scheduler.add(i, newWeight > 0 ? newWeight : avgWeight);
310       }
311       this.scheduler = scheduler;
312     }
313 
314     @Override
toString()315     public String toString() {
316       return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
317           .add("enableOobLoadReport", enableOobLoadReport)
318           .add("errorUtilizationPenalty", errorUtilizationPenalty)
319           .add("list", list).toString();
320     }
321 
322     @VisibleForTesting
getList()323     List<Subchannel> getList() {
324       return list;
325     }
326 
327     @Override
isEquivalentTo(RoundRobinPicker picker)328     public boolean isEquivalentTo(RoundRobinPicker picker) {
329       if (!(picker instanceof WeightedRoundRobinPicker)) {
330         return false;
331       }
332       WeightedRoundRobinPicker other = (WeightedRoundRobinPicker) picker;
333       if (other == this) {
334         return true;
335       }
336       // the lists cannot contain duplicate subchannels
337       return enableOobLoadReport == other.enableOobLoadReport
338           && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
339           && list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list);
340     }
341   }
342 
343   /**
344    * The earliest deadline first implementation in which each object is
345    * chosen deterministically and periodically with frequency proportional to its weight.
346    *
347    * <p>Specifically, each object added to chooser is given a deadline equal to the multiplicative
348    * inverse of its weight. The place of each object in its deadline is tracked, and each call to
349    * choose returns the object with the least remaining time in its deadline.
350    * (Ties are broken by the order in which the children were added to the chooser.) The deadline
351    * advances by the multiplicative inverse of the object's weight.
352    * For example, if items A and B are added with weights 0.5 and 0.2, successive chooses return:
353    *
354    * <ul>
355    *   <li>In the first call, the deadlines are A=2 (1/0.5) and B=5 (1/0.2), so A is returned.
356    *   The deadline of A is updated to 4.
357    *   <li>Next, the remaining deadlines are A=4 and B=5, so A is returned. The deadline of A (2) is
358    *       updated to A=6.
359    *   <li>Remaining deadlines are A=6 and B=5, so B is returned. The deadline of B is updated with
360    *       with B=10.
361    *   <li>Remaining deadlines are A=6 and B=10, so A is returned. The deadline of A is updated with
362    *        A=8.
363    *   <li>Remaining deadlines are A=8 and B=10, so A is returned. The deadline of A is updated with
364    *       A=10.
365    *   <li>Remaining deadlines are A=10 and B=10, so A is returned. The deadline of A is updated
366    *      with A=12.
367    *   <li>Remaining deadlines are A=12 and B=10, so B is returned. The deadline of B is updated
368    *      with B=15.
369    *   <li>etc.
370    * </ul>
371    *
372    * <p>In short: the entry with the highest weight is preferred.
373    *
374    * <ul>
375    *   <li>add() - O(lg n)
376    *   <li>pick() - O(lg n)
377    * </ul>
378    *
379    */
380   @VisibleForTesting
381   static final class EdfScheduler {
382     private final PriorityQueue<ObjectState> prioQueue;
383 
384     /**
385      * Weights below this value will be upped to this minimum weight.
386      */
387     private static final double MINIMUM_WEIGHT = 0.0001;
388 
389     private final Object lock = new Object();
390 
391     private final Random random;
392 
393     /**
394      * Use the item's deadline as the order in the priority queue. If the deadlines are the same,
395      * use the index. Index should be unique.
396      */
EdfScheduler(int initialCapacity, Random random)397     EdfScheduler(int initialCapacity, Random random) {
398       this.prioQueue = new PriorityQueue<ObjectState>(initialCapacity, (o1, o2) -> {
399         if (o1.deadline == o2.deadline) {
400           return Integer.compare(o1.index, o2.index);
401         } else {
402           return Double.compare(o1.deadline, o2.deadline);
403         }
404       });
405       this.random = random;
406     }
407 
408     /**
409      * Adds the item in the scheduler. This is not thread safe.
410      *
411      * @param index The field {@link ObjectState#index} to be added
412      * @param weight positive weight for the added object
413      */
add(int index, double weight)414     void add(int index, double weight) {
415       checkArgument(weight > 0.0, "Weights need to be positive.");
416       ObjectState state = new ObjectState(Math.max(weight, MINIMUM_WEIGHT), index);
417       // Randomize the initial deadline.
418       state.deadline = random.nextDouble() * (1 / state.weight);
419       prioQueue.add(state);
420     }
421 
422     /**
423      * Picks the next WRR object.
424      */
pick()425     int pick() {
426       synchronized (lock) {
427         ObjectState minObject = prioQueue.remove();
428         minObject.deadline += 1.0 / minObject.weight;
429         prioQueue.add(minObject);
430         return minObject.index;
431       }
432     }
433   }
434 
435   /** Holds the state of the object. */
436   @VisibleForTesting
437   static class ObjectState {
438     private final double weight;
439     private final int index;
440     private volatile double deadline;
441 
ObjectState(double weight, int index)442     ObjectState(double weight, int index) {
443       this.weight = weight;
444       this.index = index;
445     }
446   }
447 
448   static final class WeightedRoundRobinLoadBalancerConfig {
449     final long blackoutPeriodNanos;
450     final long weightExpirationPeriodNanos;
451     final boolean enableOobLoadReport;
452     final long oobReportingPeriodNanos;
453     final long weightUpdatePeriodNanos;
454     final float errorUtilizationPenalty;
455 
newBuilder()456     public static Builder newBuilder() {
457       return new Builder();
458     }
459 
WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos, long weightExpirationPeriodNanos, boolean enableOobLoadReport, long oobReportingPeriodNanos, long weightUpdatePeriodNanos, float errorUtilizationPenalty)460     private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos,
461                                                  long weightExpirationPeriodNanos,
462                                                  boolean enableOobLoadReport,
463                                                  long oobReportingPeriodNanos,
464                                                  long weightUpdatePeriodNanos,
465                                                  float errorUtilizationPenalty) {
466       this.blackoutPeriodNanos = blackoutPeriodNanos;
467       this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
468       this.enableOobLoadReport = enableOobLoadReport;
469       this.oobReportingPeriodNanos = oobReportingPeriodNanos;
470       this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
471       this.errorUtilizationPenalty = errorUtilizationPenalty;
472     }
473 
474     static final class Builder {
475       long blackoutPeriodNanos = 10_000_000_000L; // 10s
476       long weightExpirationPeriodNanos = 180_000_000_000L; //3min
477       boolean enableOobLoadReport = false;
478       long oobReportingPeriodNanos = 10_000_000_000L; // 10s
479       long weightUpdatePeriodNanos = 1_000_000_000L; // 1s
480       float errorUtilizationPenalty = 1.0F;
481 
Builder()482       private Builder() {
483 
484       }
485 
setBlackoutPeriodNanos(long blackoutPeriodNanos)486       Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) {
487         this.blackoutPeriodNanos = blackoutPeriodNanos;
488         return this;
489       }
490 
setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos)491       Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) {
492         this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
493         return this;
494       }
495 
setEnableOobLoadReport(boolean enableOobLoadReport)496       Builder setEnableOobLoadReport(boolean enableOobLoadReport) {
497         this.enableOobLoadReport = enableOobLoadReport;
498         return this;
499       }
500 
setOobReportingPeriodNanos(long oobReportingPeriodNanos)501       Builder setOobReportingPeriodNanos(long oobReportingPeriodNanos) {
502         this.oobReportingPeriodNanos = oobReportingPeriodNanos;
503         return this;
504       }
505 
setWeightUpdatePeriodNanos(long weightUpdatePeriodNanos)506       Builder setWeightUpdatePeriodNanos(long weightUpdatePeriodNanos) {
507         this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
508         return this;
509       }
510 
setErrorUtilizationPenalty(float errorUtilizationPenalty)511       Builder setErrorUtilizationPenalty(float errorUtilizationPenalty) {
512         this.errorUtilizationPenalty = errorUtilizationPenalty;
513         return this;
514       }
515 
build()516       WeightedRoundRobinLoadBalancerConfig build() {
517         return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos,
518                 weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos,
519                 weightUpdatePeriodNanos, errorUtilizationPenalty);
520       }
521     }
522   }
523 }
524