xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2020 The gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import argparse
16import collections
17import concurrent.futures
18import datetime
19import logging
20import signal
21import threading
22import time
23from typing import (
24    DefaultDict,
25    Dict,
26    FrozenSet,
27    Iterable,
28    List,
29    Mapping,
30    Sequence,
31    Set,
32    Tuple,
33)
34
35import grpc
36from grpc import _typing as grpc_typing
37import grpc_admin
38from grpc_channelz.v1 import channelz
39
40from src.proto.grpc.testing import empty_pb2
41from src.proto.grpc.testing import messages_pb2
42from src.proto.grpc.testing import test_pb2
43from src.proto.grpc.testing import test_pb2_grpc
44
# Configure the root logger (getLogger() with no name); loggers propagate to
# it by default. Records are formatted with a timestamp and level, and
# StreamHandler() with no argument writes to sys.stderr. ``formatter`` is also
# reused for the optional --log_file handler.
logger = logging.getLogger()
console_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt="%(asctime)s: %(levelname)-8s %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
50
# TestService methods this client knows how to drive.
_SUPPORTED_METHODS = (
    "UnaryCall",
    "EmptyCall",
)

# Method name -> key used in the per-method maps of the accumulated-stats
# response (see GetClientAccumulatedStats).
_METHOD_CAMEL_TO_CAPS_SNAKE = {
    "UnaryCall": "UNARY_CALL",
    "EmptyCall": "EMPTY_CALL",
}

# Method name -> ClientConfigureRequest enum value, used to interpret the
# ``types`` and metadata ``type`` fields of Configure() requests.
_METHOD_STR_TO_ENUM = {
    "UnaryCall": messages_pb2.ClientConfigureRequest.UNARY_CALL,
    "EmptyCall": messages_pb2.ClientConfigureRequest.EMPTY_CALL,
}

# Inverse of _METHOD_STR_TO_ENUM.
_METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}

# Method name -> (key, value) metadata pairs attached to that method's RPCs.
PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]
69
70
class FutureFromCallType(grpc.Call, grpc.Future):
    """Annotation-only type for in-flight RPC handles.

    The objects returned by a multicallable's ``future()`` method act both as
    a ``grpc.Call`` and as a ``grpc.Future``; this class names that
    combination for type annotations and is not instantiated here.
    """
74
75
# Maximum time an idle channel thread (qps == 0) waits on its configuration
# condition before looping back to re-check the stop event.
_CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)
77
78
class _StatsWatcher:
    """Collects peer statistics for RPCs whose ids fall in ``[start, end)``.

    Completed RPCs are reported through on_rpc_complete(); a caller blocks in
    await_rpc_stats_response() until every RPC in the window has been
    recorded or the timeout expires. All mutable state is guarded by
    ``_condition``.
    """

    _start: int
    _end: int
    _rpcs_needed: int
    _rpcs_by_peer: DefaultDict[str, int]
    _rpcs_by_method: DefaultDict[str, DefaultDict[str, int]]
    _no_remote_peer: int
    _condition: threading.Condition
    _metadata_keys: FrozenSet[str]
    _include_all_metadata: bool
    _metadata_by_peer: DefaultDict[
        str, messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
    ]

    def __init__(self, start: int, end: int, metadata_keys: Iterable[str]):
        """Create a watcher for request ids ``start <= id < end``.

        Args:
          start: First request id (inclusive) to record.
          end: Request id (exclusive) at which recording stops.
          metadata_keys: Metadata keys to capture for each recorded RPC. Keys
            are compared after stripping whitespace and lower-casing; the key
            ``"*"`` captures every metadata entry.
        """
        self._start = start
        self._end = end
        # Counts down to zero as RPCs in the window complete.
        self._rpcs_needed = end - start
        self._rpcs_by_peer = collections.defaultdict(int)
        self._rpcs_by_method = collections.defaultdict(
            lambda: collections.defaultdict(int)
        )
        self._condition = threading.Condition()
        self._no_remote_peer = 0
        self._metadata_keys = frozenset(
            self._sanitize_metadata_key(key) for key in metadata_keys
        )
        self._include_all_metadata = "*" in self._metadata_keys
        self._metadata_by_peer = collections.defaultdict(
            messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
        )

    @classmethod
    def _sanitize_metadata_key(cls, metadata_key: str) -> str:
        """Normalize a metadata key for case-insensitive comparison."""
        return metadata_key.strip().lower()

    def _add_metadata(
        self,
        rpc_metadata: messages_pb2.LoadBalancerStatsResponse.RpcMetadata,
        metadata_to_add: grpc_typing.MetadataType,
        metadata_type: messages_pb2.LoadBalancerStatsResponse.MetadataType,
    ) -> None:
        """Append the requested entries of ``metadata_to_add`` to ``rpc_metadata``.

        An entry is kept when all metadata was requested (``"*"``) or its
        normalized key is one of the requested keys. The stored key and value
        are the original, unnormalized ones, tagged with ``metadata_type``.
        """
        for key, value in metadata_to_add:
            if (
                self._include_all_metadata
                or self._sanitize_metadata_key(key) in self._metadata_keys
            ):
                rpc_metadata.metadata.append(
                    messages_pb2.LoadBalancerStatsResponse.MetadataEntry(
                        key=key, value=value, type=metadata_type
                    )
                )

    def on_rpc_complete(
        self,
        request_id: int,
        peer: str,
        method: str,
        *,
        initial_metadata: grpc_typing.MetadataType,
        trailing_metadata: grpc_typing.MetadataType,
    ) -> None:
        """Record one completed RPC if its id falls inside the watched window.

        An empty ``peer`` is counted as an RPC without a remote peer, which is
        later reported as a failure. Otherwise the RPC is counted for ``peer``
        overall and per ``method``. If any metadata keys were requested, the
        matching initial and trailing metadata are also stored under ``peer``.

        Args:
          request_id: Id assigned to the RPC when it was started.
          peer: Hostname of the backend that served the RPC, or "" if unknown.
          method: RPC method name, e.g. "UnaryCall".
          initial_metadata: Initial metadata received for the RPC.
          trailing_metadata: Trailing metadata received for the RPC.
        """
        if self._start <= request_id < self._end:
            with self._condition:
                if not peer:
                    self._no_remote_peer += 1
                else:
                    self._rpcs_by_peer[peer] += 1
                    self._rpcs_by_method[method][peer] += 1
                    if self._metadata_keys:
                        rpc_metadata = (
                            messages_pb2.LoadBalancerStatsResponse.RpcMetadata()
                        )
                        self._add_metadata(
                            rpc_metadata,
                            initial_metadata,
                            messages_pb2.LoadBalancerStatsResponse.MetadataType.INITIAL,
                        )
                        self._add_metadata(
                            rpc_metadata,
                            trailing_metadata,
                            messages_pb2.LoadBalancerStatsResponse.MetadataType.TRAILING,
                        )
                        self._metadata_by_peer[peer].rpc_metadata.append(
                            rpc_metadata
                        )
                self._rpcs_needed -= 1
                # Wake the thread blocked in await_rpc_stats_response().
                self._condition.notify()

    def await_rpc_stats_response(
        self, timeout_sec: int
    ) -> messages_pb2.LoadBalancerStatsResponse:
        """Wait for the window to fill, then build a stats response.

        Blocks until every RPC in the window has been recorded or
        ``timeout_sec`` seconds elapse, whichever comes first.

        Returns:
          A LoadBalancerStatsResponse with per-peer and per-method counts and
          any captured metadata. ``num_failures`` counts RPCs with no remote
          peer plus RPCs in the window that had not completed when the wait
          ended.
        """
        with self._condition:
            self._condition.wait_for(
                lambda: not self._rpcs_needed, timeout=float(timeout_sec)
            )
            response = messages_pb2.LoadBalancerStatsResponse()
            for peer, count in self._rpcs_by_peer.items():
                response.rpcs_by_peer[peer] = count
            for method, count_by_peer in self._rpcs_by_method.items():
                for peer, count in count_by_peer.items():
                    response.rpcs_by_method[method].rpcs_by_peer[peer] = count
            for peer, metadata_by_peer in self._metadata_by_peer.items():
                response.metadatas_by_peer[peer].CopyFrom(metadata_by_peer)
            response.num_failures = self._no_remote_peer + self._rpcs_needed
        return response
188
189
# Guards _global_rpc_id, _watchers and the per-method counters below.
_global_lock = threading.Lock()
# Set by the SIGINT handler; channel threads exit their send loops once set.
_stop_event = threading.Event()
# Id handed to the next RPC started by any channel thread.
_global_rpc_id: int = 0
# Watchers belonging to GetClientStats requests currently in progress.
_watchers: Set[_StatsWatcher] = set()
# Stats/configure server; assigned in _run() and stopped by the SIGINT handler.
_global_server = None
# Per-method counts of RPCs started, succeeded and failed.
_global_rpcs_started: DefaultDict[str, int] = collections.defaultdict(int)
_global_rpcs_succeeded: DefaultDict[str, int] = collections.defaultdict(int)
_global_rpcs_failed: DefaultDict[str, int] = collections.defaultdict(int)

# Mapping[method, Mapping[status_code, count]]; status_code is the integer
# value taken from the grpc.StatusCode of each finished RPC.
_global_rpc_statuses: DefaultDict[
    str, DefaultDict[int, int]
] = collections.defaultdict(lambda: collections.defaultdict(int))
203
204
def _handle_sigint(sig, frame) -> None:
    """SIGINT handler: stop the channel threads and the stats server.

    Args:
      sig: Signal number (unused).
      frame: Interrupted stack frame (unused).
    """
    logger.warning("Received SIGINT")
    # Channel threads poll this event and leave their send loops.
    _stop_event.set()
    # NOTE(review): the handler is installed before _run() assigns
    # _global_server, so a very early SIGINT raises AttributeError here,
    # which presumably aborts startup. Merely skipping stop() when the server
    # is None would instead let _run() start the server and block forever,
    # so any guard here must also make _run() bail out.
    _global_server.stop(None)
209
210
class _LoadBalancerStatsServicer(
    test_pb2_grpc.LoadBalancerStatsServiceServicer
):
    """Serves windowed and cumulative RPC statistics to the test driver."""

    def __init__(self):
        # Zero-argument super() binds to this instance, so the base class
        # initializer actually runs (super(Cls) alone is an unbound proxy).
        super().__init__()

    def GetClientStats(
        self,
        request: messages_pb2.LoadBalancerStatsRequest,
        context: grpc.ServicerContext,
    ) -> messages_pb2.LoadBalancerStatsResponse:
        """Report which peers served the next ``request.num_rpcs`` RPCs.

        Registers a _StatsWatcher for a window of upcoming request ids and
        blocks until that window completes or ``request.timeout_sec`` elapses.
        The watcher is always deregistered, even if the wait raises.
        """
        logger.info("Received stats request.")
        with _global_lock:
            # NOTE(review): the next RPC started gets id _global_rpc_id, so
            # this window skips one RPC. It is harmless, but possibly
            # unintended.
            start = _global_rpc_id + 1
            end = start + request.num_rpcs
            watcher = _StatsWatcher(start, end, request.metadata_keys)
            _watchers.add(watcher)
        try:
            response = watcher.await_rpc_stats_response(request.timeout_sec)
        finally:
            with _global_lock:
                _watchers.remove(watcher)
        logger.info("Returning stats response: %s", response)
        return response

    def GetClientAccumulatedStats(
        self,
        request: messages_pb2.LoadBalancerAccumulatedStatsRequest,
        context: grpc.ServicerContext,
    ) -> messages_pb2.LoadBalancerAccumulatedStatsResponse:
        """Report cumulative per-method counters since the client started.

        For every supported method, fills in the started/succeeded/failed
        counts and the per-status-code result counts. Keys use the
        CAPS_SNAKE method names.
        """
        logger.info("Received cumulative stats request.")
        response = messages_pb2.LoadBalancerAccumulatedStatsResponse()
        with _global_lock:
            for method in _SUPPORTED_METHODS:
                caps_method = _METHOD_CAMEL_TO_CAPS_SNAKE[method]
                response.num_rpcs_started_by_method[
                    caps_method
                ] = _global_rpcs_started[method]
                response.num_rpcs_succeeded_by_method[
                    caps_method
                ] = _global_rpcs_succeeded[method]
                response.num_rpcs_failed_by_method[
                    caps_method
                ] = _global_rpcs_failed[method]
                response.stats_per_method[
                    caps_method
                ].rpcs_started = _global_rpcs_started[method]
                for code, count in _global_rpc_statuses[method].items():
                    response.stats_per_method[caps_method].result[code] = count
        logger.info("Returning cumulative stats response.")
        return response
263
264
def _start_rpc(
    method: str,
    metadata: Sequence[Tuple[str, str]],
    request_id: int,
    stub: test_pb2_grpc.TestServiceStub,
    timeout: float,
    futures: Mapping[int, Tuple[FutureFromCallType, str]],
) -> None:
    """Start one asynchronous RPC and track it in ``futures``.

    Args:
      method: Either "UnaryCall" or "EmptyCall".
      metadata: (key, value) pairs sent with the call.
      request_id: Key under which the call is stored in ``futures``.
      stub: TestService stub bound to the channel to use.
      timeout: Per-call deadline in seconds.
      futures: In-flight calls; receives ``(future, method)`` at
        ``request_id``.

    Raises:
      ValueError: If ``method`` is not one of the names above.
    """
    logger.debug(f"Sending {method} request to backend: {request_id}")
    if method == "UnaryCall":
        multicallable, request = stub.UnaryCall, messages_pb2.SimpleRequest()
    elif method == "EmptyCall":
        multicallable, request = stub.EmptyCall, empty_pb2.Empty()
    else:
        raise ValueError(f"Unrecognized method '{method}'.")
    call_future = multicallable.future(
        request, metadata=metadata, timeout=timeout
    )
    futures[request_id] = (call_future, method)
285
286
def _on_rpc_done(
    rpc_id: int, future: FutureFromCallType, method: str, print_response: bool
) -> None:
    """Record the outcome of a finished RPC.

    Updates the global per-method status, success and failure counters, and
    reports the RPC, with the peer hostname, to every active _StatsWatcher.

    Args:
      rpc_id: Request id assigned when the RPC was started.
      future: The completed RPC future.
      method: Name of the RPC method, e.g. "UnaryCall".
      print_response: If true, log a debug line for each RPC that completed
        without an exception.
    """
    exception = future.exception()
    # A failed RPC is reported with no peer.
    hostname = ""
    with _global_lock:
        _global_rpc_statuses[method][future.code().value[0]] += 1
    if exception is not None:
        with _global_lock:
            _global_rpcs_failed[method] += 1
        if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            logger.error(f"RPC {rpc_id} timed out")
        else:
            logger.error(exception)
    else:
        response = future.result()
        # Prefer the "hostname" header. Otherwise fall back to the response
        # body. getattr() guards response types without a hostname field
        # (presumably EmptyCall's Empty), which previously raised
        # AttributeError and killed the channel thread.
        hostname = next(
            (
                metadatum[1]
                for metadatum in (future.initial_metadata() or ())
                if metadatum[0] == "hostname"
            ),
            None,
        )
        if hostname is None:
            hostname = getattr(response, "hostname", "")
        succeeded = future.code() == grpc.StatusCode.OK
        with _global_lock:
            if succeeded:
                _global_rpcs_succeeded[method] += 1
            else:
                _global_rpcs_failed[method] += 1
        if print_response:
            if succeeded:
                logger.debug("Successful response.")
            else:
                logger.debug(f"RPC failed: {rpc_id}")
    # The future is done, so these do not block; fetch them outside the lock.
    initial_metadata = future.initial_metadata()
    trailing_metadata = future.trailing_metadata()
    with _global_lock:
        for watcher in _watchers:
            watcher.on_rpc_complete(
                rpc_id,
                hostname,
                method,
                initial_metadata=initial_metadata,
                trailing_metadata=trailing_metadata,
            )
330
331
def _remove_completed_rpcs(
    rpc_futures: Dict[int, Tuple[FutureFromCallType, str]],
    print_response: bool,
) -> None:
    """Process and drop every finished RPC tracked in ``rpc_futures``.

    Each completed future is passed to _on_rpc_done and then removed from
    the mapping in place. Unfinished futures are left untouched.

    Args:
      rpc_futures: In-flight RPCs keyed by request id, as
        ``(future, method)`` pairs.
      print_response: Forwarded to _on_rpc_done to control per-RPC debug
        logging. Use this parameter, not the module-level ``args``, so the
        caller's configuration is honoured.
    """
    logger.debug("Removing completed RPCs")
    done = []
    for future_id, (future, method) in rpc_futures.items():
        if future.done():
            _on_rpc_done(future_id, future, method, print_response)
            done.append(future_id)
    # Delete after iterating; mutating a dict during iteration raises.
    for rpc_id in done:
        del rpc_futures[rpc_id]
343
344
def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
    """Request cancellation of every RPC tracked in ``futures``.

    Only ``cancel()`` is invoked on each future; the entries themselves are
    left in the mapping.
    """
    logger.info("Cancelling all remaining RPCs")
    for tracked in futures.values():
        tracked_future = tracked[0]
        tracked_future.cancel()
349
350
351class _ChannelConfiguration:
352    """Configuration for a single client channel.
353
354    Instances of this class are meant to be dealt with as PODs. That is,
355    data member should be accessed directly. This class is not thread-safe.
356    When accessing any of its members, the lock member should be held.
357    """
358
359    def __init__(
360        self,
361        method: str,
362        metadata: Sequence[Tuple[str, str]],
363        qps: int,
364        server: str,
365        rpc_timeout_sec: int,
366        print_response: bool,
367        secure_mode: bool,
368    ):
369        # condition is signalled when a change is made to the config.
370        self.condition = threading.Condition()
371
372        self.method = method
373        self.metadata = metadata
374        self.qps = qps
375        self.server = server
376        self.rpc_timeout_sec = rpc_timeout_sec
377        self.print_response = print_response
378        self.secure_mode = secure_mode
379
380
def _run_single_channel(config: _ChannelConfiguration) -> None:
    """Send RPCs over one channel at the rate in ``config`` until stopped.

    Opens a channel to ``config.server``. When ``config.secure_mode`` is set,
    the channel uses xDS credentials with an insecure fallback. Until
    ``_stop_event`` is set, the thread starts one RPC every
    ``1 / config.qps`` seconds and reaps completed RPCs after each send.
    While ``config.qps`` is 0 the thread idles on ``config.condition``.
    Outstanding RPCs are cancelled before the channel is closed.

    Every read of ``config`` happens under ``config.condition``, as the
    class requires.
    """
    global _global_rpc_id  # pylint: disable=global-statement
    with config.condition:
        server = config.server
        secure_mode = config.secure_mode
    if secure_mode:
        fallback_creds = grpc.experimental.insecure_channel_credentials()
        channel_creds = grpc.xds_channel_credentials(fallback_creds)
        channel = grpc.secure_channel(server, channel_creds)
    else:
        channel = grpc.insecure_channel(server)
    with channel:
        stub = test_pb2_grpc.TestServiceStub(channel)
        futures: Dict[int, Tuple[FutureFromCallType, str]] = {}
        while not _stop_event.is_set():
            with config.condition:
                if config.qps == 0:
                    # Idle: wait for a Configure() update, timing out so the
                    # stop event is re-checked periodically.
                    config.condition.wait(
                        timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds()
                    )
                    continue
                duration_per_query = 1.0 / float(config.qps)
                with _global_lock:
                    request_id = _global_rpc_id
                    _global_rpc_id += 1
                    _global_rpcs_started[config.method] += 1
                start = time.time()
                end = start + duration_per_query
                _start_rpc(
                    config.method,
                    config.metadata,
                    request_id,
                    stub,
                    float(config.rpc_timeout_sec),
                    futures,
                )
                # Snapshot under the lock; used after the lock is released.
                print_response = config.print_response
            _remove_completed_rpcs(futures, print_response)
            logger.debug(f"Currently {len(futures)} in-flight RPCs")
            # Pace sends; loop in case sleep() returns early.
            now = time.time()
            while now < end:
                time.sleep(end - now)
                now = time.time()
        _cancel_all_rpcs(futures)
427
428
class _XdsUpdateClientConfigureServicer(
    test_pb2_grpc.XdsUpdateClientConfigureServiceServicer
):
    """Applies Configure() requests to the per-method channel configurations."""

    def __init__(
        self, per_method_configs: Mapping[str, _ChannelConfiguration], qps: int
    ):
        """Create the servicer.

        Args:
          per_method_configs: Channel configuration for every supported
            method, keyed by method name.
          qps: Rate applied to each method enabled by a Configure() request.
        """
        # Zero-argument super() binds to this instance, so the base class
        # initializer actually runs (super(Cls) alone is an unbound proxy).
        super().__init__()
        self._per_method_configs = per_method_configs
        self._qps = qps

    def Configure(
        self,
        request: messages_pb2.ClientConfigureRequest,
        context: grpc.ServicerContext,
    ) -> messages_pb2.ClientConfigureResponse:
        """Enable the requested methods and disable all others.

        Each method listed in ``request.types`` runs at the configured qps
        with the request's metadata entries whose type matches that method.
        Every other method gets qps 0 and no metadata. A ``timeout_sec`` of 0
        leaves the RPC timeout unchanged, for backward compatibility.
        Waiting channel threads are notified of every change.
        """
        logger.info("Received Configure RPC: %s", request)
        method_strs = {_METHOD_ENUM_TO_STR[t] for t in request.types}
        for method in _SUPPORTED_METHODS:
            method_enum = _METHOD_STR_TO_ENUM[method]
            channel_config = self._per_method_configs[method]
            if method in method_strs:
                qps = self._qps
                metadata = [
                    (md.key, md.value)
                    for md in request.metadata
                    if md.type == method_enum
                ]
                # For backward compatibility, do not change timeout when we
                # receive a default value timeout.
                if request.timeout_sec == 0:
                    timeout_sec = channel_config.rpc_timeout_sec
                else:
                    timeout_sec = request.timeout_sec
            else:
                qps = 0
                metadata = []
                # Leave timeout unchanged for backward compatibility.
                timeout_sec = channel_config.rpc_timeout_sec
            with channel_config.condition:
                channel_config.qps = qps
                channel_config.metadata = metadata
                channel_config.rpc_timeout_sec = timeout_sec
                channel_config.condition.notify_all()
        return messages_pb2.ClientConfigureResponse()
473
474
class _MethodHandle:
    """Owns the worker threads that drive RPCs for one method."""

    _channel_threads: List[threading.Thread]

    def __init__(
        self, num_channels: int, channel_config: _ChannelConfiguration
    ):
        """Start ``num_channels`` threads sharing ``channel_config``.

        Each thread runs _run_single_channel and is started immediately.
        """
        self._channel_threads = []
        for _ in range(num_channels):
            worker = threading.Thread(
                target=_run_single_channel, args=(channel_config,)
            )
            worker.start()
            self._channel_threads.append(worker)

    def stop(self) -> None:
        """Block until every worker thread started by this handle exits."""
        for worker in self._channel_threads:
            worker.join()
496
497
def _run(
    args: argparse.Namespace,
    methods: Sequence[str],
    per_method_metadata: PerMethodMetadataType,
) -> None:
    """Start the channel threads and serve the control services until terminated.

    Creates one _ChannelConfiguration and ``args.num_channels`` threads for
    every supported method. Methods absent from ``methods`` start with qps 0
    and stay idle until a Configure() call enables them. Then the stats,
    configure, channelz and admin services are served on
    ``args.stats_port``. Once the server terminates, waits for all channel
    threads to exit.
    """
    global _global_server  # pylint: disable=global-statement
    logger.info("Starting python xDS Interop Client.")
    channel_configs = {}
    method_handles = []
    for method in _SUPPORTED_METHODS:
        config = _ChannelConfiguration(
            method=method,
            metadata=per_method_metadata.get(method, []),
            qps=args.qps if method in methods else 0,
            server=args.server,
            rpc_timeout_sec=args.rpc_timeout_sec,
            print_response=args.print_response,
            secure_mode=args.secure_mode,
        )
        channel_configs[method] = config
        method_handles.append(_MethodHandle(args.num_channels, config))

    server = grpc.server(concurrent.futures.ThreadPoolExecutor())
    _global_server = server
    server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
    stats_servicer = _LoadBalancerStatsServicer()
    test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
        stats_servicer, server
    )
    configure_servicer = _XdsUpdateClientConfigureServicer(
        channel_configs, args.qps
    )
    test_pb2_grpc.add_XdsUpdateClientConfigureServiceServicer_to_server(
        configure_servicer, server
    )
    channelz.add_channelz_servicer(server)
    grpc_admin.add_admin_servicers(server)
    server.start()
    server.wait_for_termination()

    for handle in method_handles:
        handle.stop()
538
539
def parse_metadata_arg(metadata_arg: str) -> PerMethodMetadataType:
    """Parse the --metadata flag into per-method (key, value) lists.

    Args:
      metadata_arg: Comma-delimited ``METHOD:KEY:VALUE`` entries; may be
        empty.

    Returns:
      A mapping from method name to its list of (key, value) pairs. Methods
      without entries are absent.

    Raises:
      ValueError: If an entry does not have exactly three ``:``-separated
        parts, or names an unsupported method.
    """
    # Check the parameter itself, not the module-level ``args``.
    # "".split(",") yields [""], so an empty flag must short-circuit here.
    metadata = metadata_arg.split(",") if metadata_arg else []
    per_method_metadata = collections.defaultdict(list)
    for metadatum in metadata:
        elems = metadatum.split(":")
        if len(elems) != 3:
            raise ValueError(
                f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'"
            )
        if elems[0] not in _SUPPORTED_METHODS:
            raise ValueError(f"Unrecognized method '{elems[0]}'")
        per_method_metadata[elems[0]].append((elems[1], elems[2]))
    return per_method_metadata
553
554
def parse_rpc_arg(rpc_arg: str) -> Sequence[str]:
    """Split the --rpc flag into method names, rejecting unsupported ones.

    Args:
      rpc_arg: Comma-delimited method names.

    Returns:
      The method names in the order given.

    Raises:
      ValueError: If any name is not in _SUPPORTED_METHODS.
    """
    requested = rpc_arg.split(",")
    unsupported = [name for name in requested if name not in _SUPPORTED_METHODS]
    if unsupported:
        supported = ", ".join(_SUPPORTED_METHODS)
        raise ValueError(f"--rpc supported methods: {supported}")
    return requested
562
563
def bool_arg(arg: str) -> bool:
    """argparse ``type=`` converter for boolean flags.

    Accepts "true"/"yes"/"y" and "false"/"no"/"n", case-insensitively.

    Raises:
      argparse.ArgumentTypeError: For any other value.
    """
    verdicts = {
        "true": True,
        "yes": True,
        "y": True,
        "false": False,
        "no": False,
        "n": False,
    }
    verdict = verdicts.get(arg.lower())
    if verdict is None:
        raise argparse.ArgumentTypeError(f"Could not parse '{arg}' as a bool.")
    return verdict
571
572
if __name__ == "__main__":
    # Command-line interface. The order in which arguments are registered
    # determines their order in --help output. ``args`` is bound at module
    # level.
    parser = argparse.ArgumentParser(
        description="Run Python XDS interop client."
    )
    parser.add_argument(
        "--num_channels",
        default=1,
        type=int,
        help="The number of channels from which to send requests.",
    )
    # String default: argparse passes string defaults through ``type``, so
    # this becomes False via bool_arg.
    # NOTE(review): despite the help text, this flag only gates
    # logger.debug() lines in _on_rpc_done, which are emitted only when
    # --verbose sets the DEBUG level.
    parser.add_argument(
        "--print_response",
        default="False",
        type=bool_arg,
        help="Write RPC response to STDOUT.",
    )
    parser.add_argument(
        "--qps",
        default=1,
        type=int,
        help="The number of queries to send from each channel per second.",
    )
    parser.add_argument(
        "--rpc_timeout_sec",
        default=30,
        type=int,
        help="The per-RPC timeout in seconds.",
    )
    parser.add_argument(
        "--server", default="localhost:50051", help="The address of the server."
    )
    parser.add_argument(
        "--stats_port",
        default=50052,
        type=int,
        help="The port on which to expose the peer distribution stats service.",
    )
    parser.add_argument(
        "--secure_mode",
        default="False",
        type=bool_arg,
        help="If specified, uses xDS credentials to connect to the server.",
    )
    parser.add_argument(
        "--verbose",
        help="verbose log output",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--log_file", default=None, type=str, help="A file to log to."
    )
    rpc_help = "A comma-delimited list of RPC methods to run. Must be one of "
    rpc_help += ", ".join(_SUPPORTED_METHODS)
    rpc_help += "."
    parser.add_argument("--rpc", default="UnaryCall", type=str, help=rpc_help)
    metadata_help = (
        "A comma-delimited list of 3-tuples of the form "
        + "METHOD:KEY:VALUE, e.g. "
        + "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3"
    )
    parser.add_argument("--metadata", default="", type=str, help=metadata_help)
    args = parser.parse_args()
    # Install the handler before _run() starts any channel threads.
    signal.signal(signal.SIGINT, _handle_sigint)
    if args.verbose:
        logger.setLevel(logging.DEBUG)
    if args.log_file:
        # Append to the log file, reusing the console formatter.
        file_handler = logging.FileHandler(args.log_file, mode="a")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    # NOTE(review): invalid --rpc or --metadata values surface as a
    # ValueError traceback rather than an argparse usage error.
    _run(args, parse_rpc_arg(args.rpc), parse_metadata_arg(args.metadata))
644