1 /*
2  * Copyright 2020 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.testing.integration;
18 
19 import com.google.common.base.CaseFormat;
20 import com.google.common.base.Splitter;
21 import com.google.common.collect.ImmutableList;
22 import com.google.common.primitives.Ints;
23 import com.google.common.util.concurrent.FutureCallback;
24 import com.google.common.util.concurrent.Futures;
25 import com.google.common.util.concurrent.ListenableScheduledFuture;
26 import com.google.common.util.concurrent.ListeningScheduledExecutorService;
27 import com.google.common.util.concurrent.MoreExecutors;
28 import com.google.common.util.concurrent.SettableFuture;
29 import io.grpc.CallOptions;
30 import io.grpc.Channel;
31 import io.grpc.ClientCall;
32 import io.grpc.ClientInterceptor;
33 import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
34 import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
35 import io.grpc.Grpc;
36 import io.grpc.InsecureChannelCredentials;
37 import io.grpc.ManagedChannel;
38 import io.grpc.Metadata;
39 import io.grpc.MethodDescriptor;
40 import io.grpc.Server;
41 import io.grpc.Status;
42 import io.grpc.netty.NettyServerBuilder;
43 import io.grpc.protobuf.services.ProtoReflectionService;
44 import io.grpc.services.AdminInterface;
45 import io.grpc.stub.StreamObserver;
46 import io.grpc.testing.integration.Messages.ClientConfigureRequest;
47 import io.grpc.testing.integration.Messages.ClientConfigureRequest.RpcType;
48 import io.grpc.testing.integration.Messages.ClientConfigureResponse;
49 import io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsRequest;
50 import io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse;
51 import io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse.MethodStats;
52 import io.grpc.testing.integration.Messages.LoadBalancerStatsRequest;
53 import io.grpc.testing.integration.Messages.LoadBalancerStatsResponse;
54 import io.grpc.testing.integration.Messages.SimpleRequest;
55 import io.grpc.testing.integration.Messages.SimpleResponse;
56 import io.grpc.xds.XdsChannelCredentials;
57 import java.util.ArrayList;
58 import java.util.Collections;
59 import java.util.EnumMap;
60 import java.util.HashMap;
61 import java.util.HashSet;
62 import java.util.List;
63 import java.util.Map;
64 import java.util.Set;
65 import java.util.concurrent.CountDownLatch;
66 import java.util.concurrent.ExecutionException;
67 import java.util.concurrent.Executors;
68 import java.util.concurrent.TimeUnit;
69 import java.util.concurrent.atomic.AtomicReference;
70 import java.util.logging.Level;
71 import java.util.logging.Logger;
72 import javax.annotation.Nullable;
73 import javax.annotation.concurrent.ThreadSafe;
74 
75 /** Client for xDS interop tests. */
76 public final class XdsTestClient {
77   private static Logger logger = Logger.getLogger(XdsTestClient.class.getName());
78 
79   private final Set<XdsStatsWatcher> watchers = new HashSet<>();
80   private final Object lock = new Object();
81   private final List<ManagedChannel> channels = new ArrayList<>();
82   private final StatsAccumulator statsAccumulator = new StatsAccumulator();
83 
84   private int numChannels = 1;
85   private boolean printResponse = false;
86   private int qps = 1;
87   private volatile List<RpcConfig> rpcConfigs;
88   private int rpcTimeoutSec = 20;
89   private boolean secureMode = false;
90   private String server = "localhost:8080";
91   private int statsPort = 8081;
92   private Server statsServer;
93   private long currentRequestId;
94   private ListeningScheduledExecutorService exec;
95 
96   /**
97    * The main application allowing this client to be launched from the command line.
98    */
main(String[] args)99   public static void main(String[] args) {
100     final XdsTestClient client = new XdsTestClient();
101     client.parseArgs(args);
102     Runtime.getRuntime()
103         .addShutdownHook(
104             new Thread() {
105               @Override
106               @SuppressWarnings("CatchAndPrintStackTrace")
107               public void run() {
108                 try {
109                   client.stop();
110                 } catch (Exception e) {
111                   e.printStackTrace();
112                 }
113               }
114             });
115     client.run();
116   }
117 
parseArgs(String[] args)118   private void parseArgs(String[] args) {
119     boolean usage = false;
120     List<RpcType> rpcTypes = ImmutableList.of(RpcType.UNARY_CALL);
121     EnumMap<RpcType, Metadata> metadata = new EnumMap<>(RpcType.class);
122     for (String arg : args) {
123       if (!arg.startsWith("--")) {
124         System.err.println("All arguments must start with '--': " + arg);
125         usage = true;
126         break;
127       }
128       String[] parts = arg.substring(2).split("=", 2);
129       String key = parts[0];
130       if ("help".equals(key)) {
131         usage = true;
132         break;
133       }
134       if (parts.length != 2) {
135         System.err.println("All arguments must be of the form --arg=value");
136         usage = true;
137         break;
138       }
139       String value = parts[1];
140       if ("metadata".equals(key)) {
141         metadata = parseMetadata(value);
142       } else if ("num_channels".equals(key)) {
143         numChannels = Integer.valueOf(value);
144       } else if ("print_response".equals(key)) {
145         printResponse = Boolean.valueOf(value);
146       } else if ("qps".equals(key)) {
147         qps = Integer.valueOf(value);
148       } else if ("rpc".equals(key)) {
149         rpcTypes = parseRpcs(value);
150       } else if ("rpc_timeout_sec".equals(key)) {
151         rpcTimeoutSec = Integer.valueOf(value);
152       } else if ("server".equals(key)) {
153         server = value;
154       } else if ("stats_port".equals(key)) {
155         statsPort = Integer.valueOf(value);
156       } else if ("secure_mode".equals(key)) {
157         secureMode = Boolean.valueOf(value);
158       } else {
159         System.err.println("Unknown argument: " + key);
160         usage = true;
161         break;
162       }
163     }
164     List<RpcConfig> configs = new ArrayList<>();
165     for (RpcType type : rpcTypes) {
166       Metadata md = new Metadata();
167       if (metadata.containsKey(type)) {
168         md = metadata.get(type);
169       }
170       configs.add(new RpcConfig(type, md, rpcTimeoutSec));
171     }
172     rpcConfigs = Collections.unmodifiableList(configs);
173 
174     if (usage) {
175       XdsTestClient c = new XdsTestClient();
176       System.err.println(
177           "Usage: [ARGS...]"
178               + "\n"
179               + "\n  --num_channels=INT     Default: "
180               + c.numChannels
181               + "\n  --print_response=BOOL  Write RPC response to stdout. Default: "
182               + c.printResponse
183               + "\n  --qps=INT              Qps per channel, for each type of RPC. Default: "
184               + c.qps
185               + "\n  --rpc=STR              Types of RPCs to make, ',' separated string. RPCs can "
186               + "be EmptyCall or UnaryCall. Default: UnaryCall"
187               + "\n[deprecated] Use XdsUpdateClientConfigureService"
188               + "\n  --metadata=STR         The metadata to send with each RPC, in the format "
189               + "EmptyCall:key1:value1,UnaryCall:key2:value2."
190               + "\n[deprecated] Use XdsUpdateClientConfigureService"
191               + "\n  --rpc_timeout_sec=INT  Per RPC timeout seconds. Default: "
192               + c.rpcTimeoutSec
193               + "\n  --server=host:port     Address of server. Default: "
194               + c.server
195               + "\n  --secure_mode=BOOLEAN  Use true to enable XdsCredentials. Default: "
196               + c.secureMode
197               + "\n  --stats_port=INT       Port to expose peer distribution stats service. "
198               + "Default: "
199               + c.statsPort);
200       System.exit(1);
201     }
202   }
203 
parseRpcs(String rpcArg)204   private static List<RpcType> parseRpcs(String rpcArg) {
205     List<RpcType> rpcs = new ArrayList<>();
206     for (String rpc : Splitter.on(',').split(rpcArg)) {
207       rpcs.add(parseRpc(rpc));
208     }
209     return rpcs;
210   }
211 
parseMetadata(String metadataArg)212   private static EnumMap<RpcType, Metadata> parseMetadata(String metadataArg) {
213     EnumMap<RpcType, Metadata> rpcMetadata = new EnumMap<>(RpcType.class);
214     for (String metadata : Splitter.on(',').omitEmptyStrings().split(metadataArg)) {
215       List<String> parts = Splitter.on(':').splitToList(metadata);
216       if (parts.size() != 3) {
217         throw new IllegalArgumentException("Invalid metadata: '" + metadata + "'");
218       }
219       RpcType rpc = parseRpc(parts.get(0));
220       String key = parts.get(1);
221       String value = parts.get(2);
222       Metadata md = new Metadata();
223       md.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value);
224       if (rpcMetadata.containsKey(rpc)) {
225         rpcMetadata.get(rpc).merge(md);
226       } else {
227         rpcMetadata.put(rpc, md);
228       }
229     }
230     return rpcMetadata;
231   }
232 
parseRpc(String rpc)233   private static RpcType parseRpc(String rpc) {
234     if ("EmptyCall".equals(rpc)) {
235       return RpcType.EMPTY_CALL;
236     } else if ("UnaryCall".equals(rpc)) {
237       return RpcType.UNARY_CALL;
238     } else {
239       throw new IllegalArgumentException("Unknown RPC: '" + rpc + "'");
240     }
241   }
242 
run()243   private void run() {
244     statsServer =
245         NettyServerBuilder.forPort(statsPort)
246             .addService(new XdsStatsImpl())
247             .addService(new ConfigureUpdateServiceImpl())
248             .addService(ProtoReflectionService.newInstance())
249             .addServices(AdminInterface.getStandardServices())
250             .build();
251     try {
252       statsServer.start();
253       for (int i = 0; i < numChannels; i++) {
254         channels.add(
255             Grpc.newChannelBuilder(
256                     server,
257                     secureMode
258                         ? XdsChannelCredentials.create(InsecureChannelCredentials.create())
259                         : InsecureChannelCredentials.create())
260                 .enableRetry()
261                 .build());
262       }
263       exec = MoreExecutors.listeningDecorator(Executors.newSingleThreadScheduledExecutor());
264       runQps();
265     } catch (Throwable t) {
266       logger.log(Level.SEVERE, "Error running client", t);
267       System.exit(1);
268     }
269   }
270 
stop()271   private void stop() throws InterruptedException {
272     if (statsServer != null) {
273       statsServer.shutdownNow();
274       if (!statsServer.awaitTermination(5, TimeUnit.SECONDS)) {
275         System.err.println("Timed out waiting for server shutdown");
276       }
277     }
278     for (ManagedChannel channel : channels) {
279       channel.shutdownNow();
280     }
281     if (exec != null) {
282       exec.shutdownNow();
283     }
284   }
285 
286 
runQps()287   private void runQps() throws InterruptedException, ExecutionException {
288     final SettableFuture<Void> failure = SettableFuture.create();
289     final class PeriodicRpc implements Runnable {
290 
291       @Override
292       public void run() {
293         List<RpcConfig> configs = rpcConfigs;
294         for (RpcConfig cfg : configs) {
295           makeRpc(cfg);
296         }
297       }
298 
299       private void makeRpc(final RpcConfig config) {
300         final long requestId;
301         final Set<XdsStatsWatcher> savedWatchers = new HashSet<>();
302         synchronized (lock) {
303           currentRequestId += 1;
304           requestId = currentRequestId;
305           savedWatchers.addAll(watchers);
306         }
307 
308         ManagedChannel channel = channels.get((int) (requestId % channels.size()));
309         TestServiceGrpc.TestServiceStub stub = TestServiceGrpc.newStub(channel);
310         final AtomicReference<ClientCall<?, ?>> clientCallRef = new AtomicReference<>();
311         final AtomicReference<String> hostnameRef = new AtomicReference<>();
312         stub =
313             stub.withDeadlineAfter(config.timeoutSec, TimeUnit.SECONDS)
314                 .withInterceptors(
315                     new ClientInterceptor() {
316                       @Override
317                       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
318                           MethodDescriptor<ReqT, RespT> method,
319                           CallOptions callOptions,
320                           Channel next) {
321                         ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
322                         clientCallRef.set(call);
323                         return new SimpleForwardingClientCall<ReqT, RespT>(call) {
324                           @Override
325                           public void start(Listener<RespT> responseListener, Metadata headers) {
326                             headers.merge(config.metadata);
327                             super.start(
328                                 new SimpleForwardingClientCallListener<RespT>(responseListener) {
329                                   @Override
330                                   public void onHeaders(Metadata headers) {
331                                     hostnameRef.set(headers.get(XdsTestServer.HOSTNAME_KEY));
332                                     super.onHeaders(headers);
333                                   }
334                                 },
335                                 headers);
336                           }
337                         };
338                       }
339                     });
340 
341         if (config.rpcType == RpcType.EMPTY_CALL) {
342           stub.emptyCall(
343               EmptyProtos.Empty.getDefaultInstance(),
344               new StreamObserver<EmptyProtos.Empty>() {
345                 @Override
346                 public void onCompleted() {
347                   handleRpcCompleted(requestId, config.rpcType, hostnameRef.get(), savedWatchers);
348                 }
349 
350                 @Override
351                 public void onError(Throwable t) {
352                   handleRpcError(requestId, config.rpcType, Status.fromThrowable(t),
353                       savedWatchers);
354                 }
355 
356                 @Override
357                 public void onNext(EmptyProtos.Empty response) {}
358               });
359         } else if (config.rpcType == RpcType.UNARY_CALL) {
360           SimpleRequest request = SimpleRequest.newBuilder().setFillServerId(true).build();
361           stub.unaryCall(
362               request,
363               new StreamObserver<SimpleResponse>() {
364                 @Override
365                 public void onCompleted() {
366                   handleRpcCompleted(requestId, config.rpcType, hostnameRef.get(), savedWatchers);
367                 }
368 
369                 @Override
370                 public void onError(Throwable t) {
371                   if (printResponse) {
372                     logger.log(Level.WARNING, "Rpc failed", t);
373                   }
374                   handleRpcError(requestId, config.rpcType, Status.fromThrowable(t),
375                       savedWatchers);
376                 }
377 
378                 @Override
379                 public void onNext(SimpleResponse response) {
380                   // TODO(ericgribkoff) Currently some test environments cannot access the stats RPC
381                   // service and rely on parsing stdout.
382                   if (printResponse) {
383                     System.out.println(
384                         "Greeting: Hello world, this is "
385                             + response.getHostname()
386                             + ", from "
387                             + clientCallRef
388                                 .get()
389                                 .getAttributes()
390                                 .get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
391                   }
392                   // Use the hostname from the response if not present in the metadata.
393                   // TODO(ericgribkoff) Delete when server is deployed that sets metadata value.
394                   if (hostnameRef.get() == null) {
395                     hostnameRef.set(response.getHostname());
396                   }
397                 }
398               });
399         } else {
400           throw new AssertionError("Unknown RPC type: " + config.rpcType);
401         }
402         statsAccumulator.recordRpcStarted(config.rpcType);
403       }
404 
405       private void handleRpcCompleted(long requestId, RpcType rpcType, String hostname,
406           Set<XdsStatsWatcher> watchers) {
407         statsAccumulator.recordRpcFinished(rpcType, Status.OK);
408         notifyWatchers(watchers, rpcType, requestId, hostname);
409       }
410 
411       private void handleRpcError(long requestId, RpcType rpcType, Status status,
412           Set<XdsStatsWatcher> watchers) {
413         statsAccumulator.recordRpcFinished(rpcType, status);
414         notifyWatchers(watchers, rpcType, requestId, null);
415       }
416     }
417 
418     long nanosPerQuery = TimeUnit.SECONDS.toNanos(1) / qps;
419     ListenableScheduledFuture<?> future =
420         exec.scheduleAtFixedRate(new PeriodicRpc(), 0, nanosPerQuery, TimeUnit.NANOSECONDS);
421     Futures.addCallback(
422         future,
423         new FutureCallback<Object>() {
424 
425           @Override
426           public void onFailure(Throwable t) {
427             failure.setException(t);
428           }
429 
430           @Override
431           public void onSuccess(Object o) {}
432         },
433         MoreExecutors.directExecutor());
434 
435     failure.get();
436   }
437 
notifyWatchers( Set<XdsStatsWatcher> watchers, RpcType rpcType, long requestId, String hostname)438   private void notifyWatchers(
439       Set<XdsStatsWatcher> watchers, RpcType rpcType, long requestId, String hostname) {
440     for (XdsStatsWatcher watcher : watchers) {
441       watcher.rpcCompleted(rpcType, requestId, hostname);
442     }
443   }
444 
445   private final class ConfigureUpdateServiceImpl extends
446       XdsUpdateClientConfigureServiceGrpc.XdsUpdateClientConfigureServiceImplBase {
447     @Override
configure(ClientConfigureRequest request, StreamObserver<ClientConfigureResponse> responseObserver)448     public void configure(ClientConfigureRequest request,
449         StreamObserver<ClientConfigureResponse> responseObserver) {
450       EnumMap<RpcType, Metadata> newMetadata = new EnumMap<>(RpcType.class);
451       for (ClientConfigureRequest.Metadata metadata : request.getMetadataList()) {
452         Metadata md = newMetadata.get(metadata.getType());
453         if (md == null) {
454           md = new Metadata();
455         }
456         md.put(Metadata.Key.of(metadata.getKey(), Metadata.ASCII_STRING_MARSHALLER),
457             metadata.getValue());
458         newMetadata.put(metadata.getType(), md);
459       }
460       List<RpcConfig> configs = new ArrayList<>();
461       for (RpcType type : request.getTypesList()) {
462         Metadata md = newMetadata.containsKey(type) ? newMetadata.get(type) : new Metadata();
463         int timeout = request.getTimeoutSec() != 0 ? request.getTimeoutSec() : rpcTimeoutSec;
464         configs.add(new RpcConfig(type, md, timeout));
465       }
466       rpcConfigs = Collections.unmodifiableList(configs);
467       responseObserver.onNext(ClientConfigureResponse.getDefaultInstance());
468       responseObserver.onCompleted();
469     }
470   }
471 
472   private class XdsStatsImpl extends LoadBalancerStatsServiceGrpc.LoadBalancerStatsServiceImplBase {
473     @Override
getClientStats( LoadBalancerStatsRequest req, StreamObserver<LoadBalancerStatsResponse> responseObserver)474     public void getClientStats(
475         LoadBalancerStatsRequest req, StreamObserver<LoadBalancerStatsResponse> responseObserver) {
476       XdsStatsWatcher watcher;
477       synchronized (lock) {
478         long startId = currentRequestId + 1;
479         long endId = startId + req.getNumRpcs();
480         watcher = new XdsStatsWatcher(startId, endId);
481         watchers.add(watcher);
482       }
483       LoadBalancerStatsResponse response = watcher.waitForRpcStats(req.getTimeoutSec());
484       synchronized (lock) {
485         watchers.remove(watcher);
486       }
487       responseObserver.onNext(response);
488       responseObserver.onCompleted();
489     }
490 
491     @Override
getClientAccumulatedStats(LoadBalancerAccumulatedStatsRequest request, StreamObserver<LoadBalancerAccumulatedStatsResponse> responseObserver)492     public void getClientAccumulatedStats(LoadBalancerAccumulatedStatsRequest request,
493         StreamObserver<LoadBalancerAccumulatedStatsResponse> responseObserver) {
494       responseObserver.onNext(statsAccumulator.getRpcStats());
495       responseObserver.onCompleted();
496     }
497   }
498 
499   /** Configuration applies to the specific type of RPCs. */
500   private static final class RpcConfig {
501     private final RpcType rpcType;
502     private final Metadata metadata;
503     private final int timeoutSec;
504 
RpcConfig(RpcType rpcType, Metadata metadata, int timeoutSec)505     private RpcConfig(RpcType rpcType, Metadata metadata, int timeoutSec) {
506       this.rpcType = rpcType;
507       this.metadata = metadata;
508       this.timeoutSec = timeoutSec;
509     }
510   }
511 
512   /** Stats recorder for test RPCs. */
513   @ThreadSafe
514   private static final class StatsAccumulator {
515     private final Map<String, Integer> rpcsStartedByMethod = new HashMap<>();
516     // TODO(chengyuanzhang): delete the following two after corresponding fields deleted in proto.
517     private final Map<String, Integer> rpcsFailedByMethod = new HashMap<>();
518     private final Map<String, Integer> rpcsSucceededByMethod = new HashMap<>();
519     private final Map<String, Map<Integer, Integer>> rpcStatusByMethod = new HashMap<>();
520 
recordRpcStarted(RpcType rpcType)521     private synchronized void recordRpcStarted(RpcType rpcType) {
522       String method = getRpcTypeString(rpcType);
523       int count = rpcsStartedByMethod.containsKey(method) ? rpcsStartedByMethod.get(method) : 0;
524       rpcsStartedByMethod.put(method, count + 1);
525     }
526 
recordRpcFinished(RpcType rpcType, Status status)527     private synchronized void recordRpcFinished(RpcType rpcType, Status status) {
528       String method = getRpcTypeString(rpcType);
529       if (status.isOk()) {
530         int count =
531             rpcsSucceededByMethod.containsKey(method) ? rpcsSucceededByMethod.get(method) : 0;
532         rpcsSucceededByMethod.put(method, count + 1);
533       } else {
534         int count = rpcsFailedByMethod.containsKey(method) ? rpcsFailedByMethod.get(method) : 0;
535         rpcsFailedByMethod.put(method, count + 1);
536       }
537       int statusCode = status.getCode().value();
538       Map<Integer, Integer> statusCounts = rpcStatusByMethod.get(method);
539       if (statusCounts == null) {
540         statusCounts = new HashMap<>();
541         rpcStatusByMethod.put(method, statusCounts);
542       }
543       int count = statusCounts.containsKey(statusCode) ? statusCounts.get(statusCode) : 0;
544       statusCounts.put(statusCode, count + 1);
545     }
546 
547     @SuppressWarnings("deprecation")
getRpcStats()548     private synchronized LoadBalancerAccumulatedStatsResponse getRpcStats() {
549       LoadBalancerAccumulatedStatsResponse.Builder builder =
550           LoadBalancerAccumulatedStatsResponse.newBuilder();
551       builder.putAllNumRpcsStartedByMethod(rpcsStartedByMethod);
552       builder.putAllNumRpcsSucceededByMethod(rpcsSucceededByMethod);
553       builder.putAllNumRpcsFailedByMethod(rpcsFailedByMethod);
554 
555       for (String method : rpcsStartedByMethod.keySet()) {
556         MethodStats.Builder methodStatsBuilder = MethodStats.newBuilder();
557         methodStatsBuilder.setRpcsStarted(rpcsStartedByMethod.get(method));
558         if (rpcStatusByMethod.containsKey(method)) {
559           methodStatsBuilder.putAllResult(rpcStatusByMethod.get(method));
560         }
561         builder.putStatsPerMethod(method, methodStatsBuilder.build());
562       }
563       return builder.build();
564     }
565 
566     // e.g., RpcType.UNARY_CALL -> "UNARY_CALL"
getRpcTypeString(RpcType rpcType)567     private static String getRpcTypeString(RpcType rpcType) {
568       return rpcType.name();
569     }
570   }
571 
572   /** Records the remote peer distribution for a given range of RPCs. */
573   private static class XdsStatsWatcher {
574     private final CountDownLatch latch;
575     private final long startId;
576     private final long endId;
577     private final Map<String, Integer> rpcsByPeer = new HashMap<>();
578     private final EnumMap<RpcType, Map<String, Integer>> rpcsByTypeAndPeer =
579         new EnumMap<>(RpcType.class);
580     private final Object lock = new Object();
581     private int rpcsFailed;
582 
XdsStatsWatcher(long startId, long endId)583     private XdsStatsWatcher(long startId, long endId) {
584       latch = new CountDownLatch(Ints.checkedCast(endId - startId));
585       this.startId = startId;
586       this.endId = endId;
587     }
588 
rpcCompleted(RpcType rpcType, long requestId, @Nullable String hostname)589     void rpcCompleted(RpcType rpcType, long requestId, @Nullable String hostname) {
590       synchronized (lock) {
591         if (startId <= requestId && requestId < endId) {
592           if (hostname != null) {
593             if (rpcsByPeer.containsKey(hostname)) {
594               rpcsByPeer.put(hostname, rpcsByPeer.get(hostname) + 1);
595             } else {
596               rpcsByPeer.put(hostname, 1);
597             }
598             if (rpcsByTypeAndPeer.containsKey(rpcType)) {
599               if (rpcsByTypeAndPeer.get(rpcType).containsKey(hostname)) {
600                 rpcsByTypeAndPeer
601                     .get(rpcType)
602                     .put(hostname, rpcsByTypeAndPeer.get(rpcType).get(hostname) + 1);
603               } else {
604                 rpcsByTypeAndPeer.get(rpcType).put(hostname, 1);
605               }
606             } else {
607               Map<String, Integer> rpcMap = new HashMap<>();
608               rpcMap.put(hostname, 1);
609               rpcsByTypeAndPeer.put(rpcType, rpcMap);
610             }
611           } else {
612             rpcsFailed += 1;
613           }
614           latch.countDown();
615         }
616       }
617     }
618 
waitForRpcStats(long timeoutSeconds)619     LoadBalancerStatsResponse waitForRpcStats(long timeoutSeconds) {
620       try {
621         boolean success = latch.await(timeoutSeconds, TimeUnit.SECONDS);
622         if (!success) {
623           logger.log(Level.INFO, "Await timed out, returning partial stats");
624         }
625       } catch (InterruptedException e) {
626         logger.log(Level.INFO, "Await interrupted, returning partial stats", e);
627         Thread.currentThread().interrupt();
628       }
629       LoadBalancerStatsResponse.Builder builder = LoadBalancerStatsResponse.newBuilder();
630       synchronized (lock) {
631         builder.putAllRpcsByPeer(rpcsByPeer);
632         for (Map.Entry<RpcType, Map<String, Integer>> entry : rpcsByTypeAndPeer.entrySet()) {
633           LoadBalancerStatsResponse.RpcsByPeer.Builder rpcs =
634               LoadBalancerStatsResponse.RpcsByPeer.newBuilder();
635           rpcs.putAllRpcsByPeer(entry.getValue());
636           builder.putRpcsByMethod(getRpcTypeString(entry.getKey()), rpcs.build());
637         }
638         builder.setNumFailures(rpcsFailed);
639       }
640       return builder.build();
641     }
642 
643     // e.g., RpcType.UNARY_CALL -> "UnaryCall"
getRpcTypeString(RpcType rpcType)644     private static String getRpcTypeString(RpcType rpcType) {
645       return CaseFormat.UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, rpcType.name());
646     }
647   }
648 }
649