# Copyright 2020 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import collections
import concurrent.futures
import datetime
import logging
import signal
import threading
import time
from typing import (
    DefaultDict,
    Dict,
    FrozenSet,
    Iterable,
    List,
    Mapping,
    Sequence,
    Set,
    Tuple,
)

import grpc
from grpc import _typing as grpc_typing
import grpc_admin
from grpc_channelz.v1 import channelz

from src.proto.grpc.testing import empty_pb2
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2
from src.proto.grpc.testing import test_pb2_grpc

logger = logging.getLogger()
console_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt="%(asctime)s: %(levelname)-8s %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# The RPC methods this client knows how to drive.
_SUPPORTED_METHODS = (
    "UnaryCall",
    "EmptyCall",
)

# CamelCase method name -> CAPS_SNAKE key used in accumulated-stats responses.
_METHOD_CAMEL_TO_CAPS_SNAKE = {
    "UnaryCall": "UNARY_CALL",
    "EmptyCall": "EMPTY_CALL",
}

# CamelCase method name -> ClientConfigureRequest.RpcType enum value.
_METHOD_STR_TO_ENUM = {
    "UnaryCall": messages_pb2.ClientConfigureRequest.UNARY_CALL,
    "EmptyCall": messages_pb2.ClientConfigureRequest.EMPTY_CALL,
}

_METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}

# Method name -> sequence of (key, value) metadata pairs to attach to its RPCs.
PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]


# FutureFromCall is both a grpc.Call and grpc.Future
class FutureFromCallType(grpc.Call, grpc.Future):
    pass


# How long a channel thread waits on its config condition before re-checking,
# when its method is currently disabled (qps == 0).
_CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)


class _StatsWatcher:
    """Accumulates per-peer stats for RPCs whose ids fall in [start, end).

    All mutable state is guarded by ``_condition``; ``on_rpc_complete`` is
    called from RPC-completion paths while ``await_rpc_stats_response``
    blocks a stats-service thread until the window is filled (or times out).
    """

    _start: int
    _end: int
    _rpcs_needed: int
    _rpcs_by_peer: DefaultDict[str, int]
    _rpcs_by_method: DefaultDict[str, DefaultDict[str, int]]
    _no_remote_peer: int
    _lock: threading.Lock
    _condition: threading.Condition
    _metadata_keys: FrozenSet[str]
    _include_all_metadata: bool
    _metadata_by_peer: DefaultDict[
        str, messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
    ]

    def __init__(self, start: int, end: int, metadata_keys: Iterable[str]):
        self._start = start
        self._end = end
        self._rpcs_needed = end - start
        self._rpcs_by_peer = collections.defaultdict(int)
        self._rpcs_by_method = collections.defaultdict(
            lambda: collections.defaultdict(int)
        )
        self._condition = threading.Condition()
        self._no_remote_peer = 0
        self._metadata_keys = frozenset(
            self._sanitize_metadata_key(key) for key in metadata_keys
        )
        # "*" means: record every metadata key, not just the listed ones.
        self._include_all_metadata = "*" in self._metadata_keys
        self._metadata_by_peer = collections.defaultdict(
            messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
        )

    @classmethod
    def _sanitize_metadata_key(cls, metadata_key: str) -> str:
        """Normalizes a metadata key for case-insensitive matching."""
        return metadata_key.strip().lower()

    def _add_metadata(
        self,
        rpc_metadata: messages_pb2.LoadBalancerStatsResponse.RpcMetadata,
        metadata_to_add: grpc_typing.MetadataType,
        metadata_type: messages_pb2.LoadBalancerStatsResponse.MetadataType,
    ) -> None:
        """Appends the watched subset of ``metadata_to_add`` to ``rpc_metadata``."""
        for key, value in metadata_to_add:
            if (
                self._include_all_metadata
                or self._sanitize_metadata_key(key) in self._metadata_keys
            ):
                rpc_metadata.metadata.append(
                    messages_pb2.LoadBalancerStatsResponse.MetadataEntry(
                        key=key, value=value, type=metadata_type
                    )
                )

    def on_rpc_complete(
        self,
        request_id: int,
        peer: str,
        method: str,
        *,
        initial_metadata: grpc_typing.MetadataType,
        trailing_metadata: grpc_typing.MetadataType,
    ) -> None:
        """Records statistics for a single RPC."""
        # RPCs outside this watcher's id window are ignored entirely.
        if self._start <= request_id < self._end:
            with self._condition:
                if not peer:
                    # No peer hostname means the RPC failed before reaching
                    # a backend; it counts toward num_failures.
                    self._no_remote_peer += 1
                else:
                    self._rpcs_by_peer[peer] += 1
                    self._rpcs_by_method[method][peer] += 1
                    if self._metadata_keys:
                        rpc_metadata = (
                            messages_pb2.LoadBalancerStatsResponse.RpcMetadata()
                        )
                        self._add_metadata(
                            rpc_metadata,
                            initial_metadata,
                            messages_pb2.LoadBalancerStatsResponse.MetadataType.INITIAL,
                        )
                        self._add_metadata(
                            rpc_metadata,
                            trailing_metadata,
                            messages_pb2.LoadBalancerStatsResponse.MetadataType.TRAILING,
                        )
                        self._metadata_by_peer[peer].rpc_metadata.append(
                            rpc_metadata
                        )
                self._rpcs_needed -= 1
                self._condition.notify()

    def await_rpc_stats_response(
        self, timeout_sec: int
    ) -> messages_pb2.LoadBalancerStatsResponse:
        """Blocks until a full response has been collected."""
        with self._condition:
            self._condition.wait_for(
                lambda: not self._rpcs_needed, timeout=float(timeout_sec)
            )
            response = messages_pb2.LoadBalancerStatsResponse()
            for peer, count in self._rpcs_by_peer.items():
                response.rpcs_by_peer[peer] = count
            for method, count_by_peer in self._rpcs_by_method.items():
                for peer, count in count_by_peer.items():
                    response.rpcs_by_method[method].rpcs_by_peer[peer] = count
            for peer, metadata_by_peer in self._metadata_by_peer.items():
                response.metadatas_by_peer[peer].CopyFrom(metadata_by_peer)
            # Anything still outstanding at timeout is reported as a failure.
            response.num_failures = self._no_remote_peer + self._rpcs_needed
            return response


# Guards all of the module-level counters and the watcher set below.
_global_lock = threading.Lock()
_stop_event = threading.Event()
_global_rpc_id: int = 0
_watchers: Set[_StatsWatcher] = set()
_global_server = None
_global_rpcs_started: Mapping[str, int] = collections.defaultdict(int)
_global_rpcs_succeeded: Mapping[str, int] = collections.defaultdict(int)
_global_rpcs_failed: Mapping[str, int] = collections.defaultdict(int)

# Mapping[method, Mapping[status_code, count]]
_global_rpc_statuses: Mapping[str, Mapping[int, int]] = collections.defaultdict(
    lambda: collections.defaultdict(int)
)


def _handle_sigint(sig, frame) -> None:
    """SIGINT handler: stops the channel threads and the stats server."""
    logger.warning("Received SIGINT")
    _stop_event.set()
    _global_server.stop(None)


class _LoadBalancerStatsServicer(
    test_pb2_grpc.LoadBalancerStatsServiceServicer
):
    """Serves per-window and cumulative RPC distribution stats."""

    def __init__(self):
        # Fixed: was `super(_LoadBalancerStatsServicer).__init__()`, which
        # initializes an unbound super object rather than this instance.
        super().__init__()

    def GetClientStats(
        self,
        request: messages_pb2.LoadBalancerStatsRequest,
        context: grpc.ServicerContext,
    ) -> messages_pb2.LoadBalancerStatsResponse:
        """Watches the next ``request.num_rpcs`` RPCs and reports their peers."""
        logger.info("Received stats request.")
        start = None
        end = None
        watcher = None
        with _global_lock:
            # Watch the window of RPC ids that starts after the last id issued.
            start = _global_rpc_id + 1
            end = start + request.num_rpcs
            watcher = _StatsWatcher(start, end, request.metadata_keys)
            _watchers.add(watcher)
        response = watcher.await_rpc_stats_response(request.timeout_sec)
        with _global_lock:
            _watchers.remove(watcher)
        logger.info("Returning stats response: %s", response)
        return response

    def GetClientAccumulatedStats(
        self,
        request: messages_pb2.LoadBalancerAccumulatedStatsRequest,
        context: grpc.ServicerContext,
    ) -> messages_pb2.LoadBalancerAccumulatedStatsResponse:
        """Reports lifetime started/succeeded/failed counters per method."""
        logger.info("Received cumulative stats request.")
        response = messages_pb2.LoadBalancerAccumulatedStatsResponse()
        with _global_lock:
            for method in _SUPPORTED_METHODS:
                caps_method = _METHOD_CAMEL_TO_CAPS_SNAKE[method]
                response.num_rpcs_started_by_method[
                    caps_method
                ] = _global_rpcs_started[method]
                response.num_rpcs_succeeded_by_method[
                    caps_method
                ] = _global_rpcs_succeeded[method]
                response.num_rpcs_failed_by_method[
                    caps_method
                ] = _global_rpcs_failed[method]
                response.stats_per_method[
                    caps_method
                ].rpcs_started = _global_rpcs_started[method]
                for code, count in _global_rpc_statuses[method].items():
                    response.stats_per_method[caps_method].result[code] = count
        logger.info("Returning cumulative stats response.")
        return response


def _start_rpc(
    method: str,
    metadata: Sequence[Tuple[str, str]],
    request_id: int,
    stub: test_pb2_grpc.TestServiceStub,
    timeout: float,
    futures: Mapping[int, Tuple[FutureFromCallType, str]],
) -> None:
    """Starts one async RPC of ``method`` and records its future in ``futures``.

    Raises:
        ValueError: if ``method`` is not one of the supported methods.
    """
    logger.debug(f"Sending {method} request to backend: {request_id}")
    if method == "UnaryCall":
        future = stub.UnaryCall.future(
            messages_pb2.SimpleRequest(), metadata=metadata, timeout=timeout
        )
    elif method == "EmptyCall":
        future = stub.EmptyCall.future(
            empty_pb2.Empty(), metadata=metadata, timeout=timeout
        )
    else:
        raise ValueError(f"Unrecognized method '{method}'.")
    futures[request_id] = (future, method)


def _on_rpc_done(
    rpc_id: int, future: FutureFromCallType, method: str, print_response: bool
) -> None:
    """Updates global counters and notifies watchers for a finished RPC."""
    exception = future.exception()
    hostname = ""
    with _global_lock:
        _global_rpc_statuses[method][future.code().value[0]] += 1
    if exception is not None:
        with _global_lock:
            _global_rpcs_failed[method] += 1
        if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            logger.error(f"RPC {rpc_id} timed out")
        else:
            logger.error(exception)
    else:
        response = future.result()
        hostname = None
        # Prefer the "hostname" initial-metadata entry; fall back to the
        # hostname field in the response message.
        for metadatum in future.initial_metadata():
            if metadatum[0] == "hostname":
                hostname = metadatum[1]
                break
        else:
            hostname = response.hostname
        if future.code() == grpc.StatusCode.OK:
            with _global_lock:
                _global_rpcs_succeeded[method] += 1
        else:
            with _global_lock:
                _global_rpcs_failed[method] += 1
        if print_response:
            if future.code() == grpc.StatusCode.OK:
                logger.debug("Successful response.")
            else:
                logger.debug(f"RPC failed: {rpc_id}")
    with _global_lock:
        for watcher in _watchers:
            watcher.on_rpc_complete(
                rpc_id,
                hostname,
                method,
                initial_metadata=future.initial_metadata(),
                trailing_metadata=future.trailing_metadata(),
            )


def _remove_completed_rpcs(
    rpc_futures: Dict[int, Tuple[FutureFromCallType, str]],
    print_response: bool,
) -> None:
    """Records and drops every RPC in ``rpc_futures`` that has finished.

    Fixed: previously passed the module-level ``args.print_response`` to
    ``_on_rpc_done`` instead of the ``print_response`` parameter, coupling
    this helper to the ``__main__`` guard's globals. The annotation is also
    corrected to match the ``(future, method)`` tuple values actually stored.
    """
    logger.debug("Removing completed RPCs")
    done = []
    for future_id, (future, method) in rpc_futures.items():
        if future.done():
            _on_rpc_done(future_id, future, method, print_response)
            done.append(future_id)
    # Delete after iteration: mutating the dict inside the loop above would
    # invalidate the iterator.
    for rpc_id in done:
        del rpc_futures[rpc_id]


def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
    """Best-effort cancellation of every still-in-flight RPC."""
    logger.info("Cancelling all remaining RPCs")
    for future, _ in futures.values():
        future.cancel()


class _ChannelConfiguration:
    """Configuration for a single client channel.

    Instances of this class are meant to be dealt with as PODs. That is,
    data member should be accessed directly. This class is not thread-safe.
    When accessing any of its members, the lock member should be held.
    """

    def __init__(
        self,
        method: str,
        metadata: Sequence[Tuple[str, str]],
        qps: int,
        server: str,
        rpc_timeout_sec: int,
        print_response: bool,
        secure_mode: bool,
    ):
        # condition is signalled when a change is made to the config.
        self.condition = threading.Condition()

        self.method = method
        self.metadata = metadata
        self.qps = qps
        self.server = server
        self.rpc_timeout_sec = rpc_timeout_sec
        self.print_response = print_response
        self.secure_mode = secure_mode


def _run_single_channel(config: _ChannelConfiguration) -> None:
    """Drives RPCs on one channel at ``config.qps`` until ``_stop_event`` is set."""
    global _global_rpc_id  # pylint: disable=global-statement
    with config.condition:
        server = config.server
    channel = None
    if config.secure_mode:
        fallback_creds = grpc.experimental.insecure_channel_credentials()
        channel_creds = grpc.xds_channel_credentials(fallback_creds)
        channel = grpc.secure_channel(server, channel_creds)
    else:
        channel = grpc.insecure_channel(server)
    with channel:
        stub = test_pb2_grpc.TestServiceStub(channel)
        futures: Dict[int, Tuple[FutureFromCallType, str]] = {}
        while not _stop_event.is_set():
            with config.condition:
                if config.qps == 0:
                    # Method currently disabled; wait for a config change
                    # (or the periodic timeout) and re-check.
                    config.condition.wait(
                        timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds()
                    )
                    continue
                else:
                    duration_per_query = 1.0 / float(config.qps)
                request_id = None
                with _global_lock:
                    request_id = _global_rpc_id
                    _global_rpc_id += 1
                    _global_rpcs_started[config.method] += 1
                start = time.time()
                end = start + duration_per_query
                _start_rpc(
                    config.method,
                    config.metadata,
                    request_id,
                    stub,
                    float(config.rpc_timeout_sec),
                    futures,
                )
                print_response = config.print_response
            _remove_completed_rpcs(futures, print_response)
            logger.debug(f"Currently {len(futures)} in-flight RPCs")
            now = time.time()
            while now < end:
                time.sleep(end - now)
                now = time.time()
        _cancel_all_rpcs(futures)


class _XdsUpdateClientConfigureServicer(
    test_pb2_grpc.XdsUpdateClientConfigureServiceServicer
):
    """Lets the test harness reconfigure which methods run, and with what."""

    def __init__(
        self, per_method_configs: Mapping[str, _ChannelConfiguration], qps: int
    ):
        # Fixed: was `super(_XdsUpdateClientConfigureServicer).__init__()`,
        # which initializes an unbound super object rather than this instance.
        super().__init__()
        self._per_method_configs = per_method_configs
        self._qps = qps

    def Configure(
        self,
        request: messages_pb2.ClientConfigureRequest,
        context: grpc.ServicerContext,
    ) -> messages_pb2.ClientConfigureResponse:
        """Enables the requested methods at the configured QPS; disables the rest."""
        logger.info("Received Configure RPC: %s", request)
        method_strs = [_METHOD_ENUM_TO_STR[t] for t in request.types]
        for method in _SUPPORTED_METHODS:
            method_enum = _METHOD_STR_TO_ENUM[method]
            channel_config = self._per_method_configs[method]
            if method in method_strs:
                qps = self._qps
                metadata = (
                    (md.key, md.value)
                    for md in request.metadata
                    if md.type == method_enum
                )
                # For backward compatibility, do not change timeout when we
                # receive a default value timeout.
                if request.timeout_sec == 0:
                    timeout_sec = channel_config.rpc_timeout_sec
                else:
                    timeout_sec = request.timeout_sec
            else:
                qps = 0
                metadata = ()
                # Leave timeout unchanged for backward compatibility.
                timeout_sec = channel_config.rpc_timeout_sec
            with channel_config.condition:
                channel_config.qps = qps
                channel_config.metadata = list(metadata)
                channel_config.rpc_timeout_sec = timeout_sec
                channel_config.condition.notify_all()
        return messages_pb2.ClientConfigureResponse()


class _MethodHandle:
    """An object grouping together threads driving RPCs for a method."""

    _channel_threads: List[threading.Thread]

    def __init__(
        self, num_channels: int, channel_config: _ChannelConfiguration
    ):
        """Creates and starts a group of threads running the indicated method."""
        self._channel_threads = []
        for i in range(num_channels):
            thread = threading.Thread(
                target=_run_single_channel, args=(channel_config,)
            )
            thread.start()
            self._channel_threads.append(thread)

    def stop(self) -> None:
        """Joins all threads referenced by the handle."""
        for channel_thread in self._channel_threads:
            channel_thread.join()


def _run(
    args: argparse.Namespace,
    methods: Sequence[str],
    per_method_metadata: PerMethodMetadataType,
) -> None:
    """Starts the channel threads and serves stats until the server stops."""
    logger.info("Starting python xDS Interop Client.")
    global _global_server  # pylint: disable=global-statement
    method_handles = []
    channel_configs = {}
    for method in _SUPPORTED_METHODS:
        # Methods not selected via --rpc start disabled (qps == 0) but can
        # be enabled later through the Configure RPC.
        if method in methods:
            qps = args.qps
        else:
            qps = 0
        channel_config = _ChannelConfiguration(
            method,
            per_method_metadata.get(method, []),
            qps,
            args.server,
            args.rpc_timeout_sec,
            args.print_response,
            args.secure_mode,
        )
        channel_configs[method] = channel_config
        method_handles.append(_MethodHandle(args.num_channels, channel_config))
    _global_server = grpc.server(concurrent.futures.ThreadPoolExecutor())
    _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
    test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
        _LoadBalancerStatsServicer(), _global_server
    )
    test_pb2_grpc.add_XdsUpdateClientConfigureServiceServicer_to_server(
        _XdsUpdateClientConfigureServicer(channel_configs, args.qps),
        _global_server,
    )
    channelz.add_channelz_servicer(_global_server)
    grpc_admin.add_admin_servicers(_global_server)
    _global_server.start()
    _global_server.wait_for_termination()
    for method_handle in method_handles:
        method_handle.stop()


def parse_metadata_arg(metadata_arg: str) -> PerMethodMetadataType:
    """Parses --metadata ("METHOD:KEY:VALUE,...") into a per-method mapping.

    Fixed: previously tested the module-level ``args.metadata`` instead of
    the ``metadata_arg`` parameter.

    Raises:
        ValueError: on a malformed tuple or an unrecognized method name.
    """
    metadata = metadata_arg.split(",") if metadata_arg else []
    per_method_metadata = collections.defaultdict(list)
    for metadatum in metadata:
        elems = metadatum.split(":")
        if len(elems) != 3:
            raise ValueError(
                f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'"
            )
        if elems[0] not in _SUPPORTED_METHODS:
            raise ValueError(f"Unrecognized method '{elems[0]}'")
        per_method_metadata[elems[0]].append((elems[1], elems[2]))
    return per_method_metadata


def parse_rpc_arg(rpc_arg: str) -> Sequence[str]:
    """Parses --rpc, a comma-delimited list of supported method names.

    Raises:
        ValueError: if any listed method is unsupported.
    """
    methods = rpc_arg.split(",")
    if set(methods) - set(_SUPPORTED_METHODS):
        raise ValueError(
            "--rpc supported methods: {}".format(", ".join(_SUPPORTED_METHODS))
        )
    return methods


def bool_arg(arg: str) -> bool:
    """argparse type converter accepting true/yes/y and false/no/n (any case).

    Raises:
        argparse.ArgumentTypeError: for any other input.
    """
    if arg.lower() in ("true", "yes", "y"):
        return True
    elif arg.lower() in ("false", "no", "n"):
        return False
    else:
        raise argparse.ArgumentTypeError(f"Could not parse '{arg}' as a bool.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run Python XDS interop client."
    )
    parser.add_argument(
        "--num_channels",
        default=1,
        type=int,
        help="The number of channels from which to send requests.",
    )
    parser.add_argument(
        "--print_response",
        default="False",
        type=bool_arg,
        help="Write RPC response to STDOUT.",
    )
    parser.add_argument(
        "--qps",
        default=1,
        type=int,
        help="The number of queries to send from each channel per second.",
    )
    parser.add_argument(
        "--rpc_timeout_sec",
        default=30,
        type=int,
        help="The per-RPC timeout in seconds.",
    )
    parser.add_argument(
        "--server", default="localhost:50051", help="The address of the server."
    )
    parser.add_argument(
        "--stats_port",
        default=50052,
        type=int,
        help="The port on which to expose the peer distribution stats service.",
    )
    parser.add_argument(
        "--secure_mode",
        default="False",
        type=bool_arg,
        help="If specified, uses xDS credentials to connect to the server.",
    )
    parser.add_argument(
        "--verbose",
        help="verbose log output",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--log_file", default=None, type=str, help="A file to log to."
    )
    rpc_help = "A comma-delimited list of RPC methods to run. Must be one of "
    rpc_help += ", ".join(_SUPPORTED_METHODS)
    rpc_help += "."
    parser.add_argument("--rpc", default="UnaryCall", type=str, help=rpc_help)
    metadata_help = (
        "A comma-delimited list of 3-tuples of the form "
        + "METHOD:KEY:VALUE, e.g. "
        + "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3"
    )
    parser.add_argument("--metadata", default="", type=str, help=metadata_help)
    args = parser.parse_args()
    signal.signal(signal.SIGINT, _handle_sigint)
    if args.verbose:
        logger.setLevel(logging.DEBUG)
    if args.log_file:
        file_handler = logging.FileHandler(args.log_file, mode="a")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    _run(args, parse_rpc_arg(args.rpc), parse_metadata_arg(args.metadata))