# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of gRPC Python interceptors."""

import collections
import sys
import types
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import grpc

from ._typing import DeserializingFunction
from ._typing import DoneCallbackType
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import SerializingFunction

# The four client-side interceptor interfaces accepted by `intercept_channel`.
_CLIENT_INTERCEPTOR_TYPES = (
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
)

# Static-typing counterpart of `_CLIENT_INTERCEPTOR_TYPES`.
_ClientInterceptor = Union[
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
]


class _ServicePipeline:
    """Runs a handler lookup through an ordered chain of server interceptors."""

    interceptors: Tuple[grpc.ServerInterceptor, ...]

    def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]):
        """Snapshot `interceptors` into an immutable tuple."""
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk: Callable, index: int) -> Callable:
        """Return a callable that resumes the chain at position `index`."""
        return lambda context: self._intercept_at(thunk, index, context)

    def _intercept_at(
        self, thunk: Callable, index: int, context: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Invoke the interceptor at `index`, or `thunk` once the chain ends.

        Each interceptor receives a continuation bound to `index + 1`, so it
        decides whether and when the rest of the chain runs.
        """
        if index < len(self.interceptors):
            interceptor = self.interceptors[index]
            thunk = self._continuation(thunk, index + 1)
            return interceptor.intercept_service(thunk, context)
        else:
            return thunk(context)

    def execute(
        self, thunk: Callable, context: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Run the chain from its first interceptor.

        Args:
          thunk: Callable invoked with `context` after every interceptor has
            called its continuation.
          context: The call details handed to each interceptor.

        Returns:
          Whatever the first interceptor returns (or `thunk(context)` when the
          pipeline holds no interceptors).
        """
        return self._intercept_at(thunk, 0, context)


def service_pipeline(
    interceptors: Optional[Sequence[grpc.ServerInterceptor]],
) -> Optional[_ServicePipeline]:
    """Build a `_ServicePipeline`, or return None for a None/empty sequence."""
    return _ServicePipeline(interceptors) if interceptors else None


class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails",
        (
            "method",
            "timeout",
            "metadata",
            "credentials",
            "wait_for_ready",
            "compression",
        ),
    ),
    grpc.ClientCallDetails,
):
    """Immutable six-field record that is also a `grpc.ClientCallDetails`."""


def _unwrap_client_call_details(
    call_details: grpc.ClientCallDetails,
    default_details: grpc.ClientCallDetails,
) -> Tuple[
    str, float, MetadataType, grpc.CallCredentials, bool, grpc.Compression
]:
    """Read the six call-detail fields, falling back to `default_details`.

    A field is taken from `default_details` only when reading it from
    `call_details` raises `AttributeError`; a present-but-None value on
    `call_details` is returned as is.

    Args:
      call_details: The details object to read from first.
      default_details: Source of any attribute `call_details` lacks.

    Returns:
      A `(method, timeout, metadata, credentials, wait_for_ready,
      compression)` tuple.
    """

    def _resolve(name: str) -> Any:
        # The fallback is looked up lazily so `default_details` is only
        # touched for attributes that are actually missing.
        try:
            return getattr(call_details, name)
        except AttributeError:
            return getattr(default_details, name)

    method = _resolve("method")
    timeout = _resolve("timeout")
    metadata = _resolve("metadata")
    credentials = _resolve("credentials")
    wait_for_ready = _resolve("wait_for_ready")
    compression = _resolve("compression")

    return method, timeout, metadata, credentials, wait_for_ready, compression


class _FailureOutcome(
    grpc.RpcError, grpc.Future, grpc.Call
):  # pylint: disable=too-many-ancestors
    """An already-finished outcome carrying an exception raised in-process.

    Reports `StatusCode.INTERNAL`, re-raises the stored exception from
    `result()` and from iteration, and runs done-callbacks immediately.
    """

    _exception: Exception
    _traceback: types.TracebackType

    def __init__(self, exception: Exception, traceback: types.TracebackType):
        super().__init__()
        self._exception = exception
        self._traceback = traceback

    def initial_metadata(self) -> Optional[MetadataType]:
        return None

    def trailing_metadata(self) -> Optional[MetadataType]:
        return None

    def code(self) -> Optional[grpc.StatusCode]:
        return grpc.StatusCode.INTERNAL

    def details(self) -> Optional[str]:
        return "Exception raised while intercepting the RPC"

    def cancel(self) -> bool:
        return False

    def cancelled(self) -> bool:
        return False

    def is_active(self) -> bool:
        return False

    def time_remaining(self) -> Optional[float]:
        return None

    def running(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        """Raise the stored exception; the timeout is ignored."""
        raise self._exception

    def exception(
        self, ignored_timeout: Optional[float] = None
    ) -> Optional[Exception]:
        """Return the stored exception; the timeout is ignored."""
        return self._exception

    def traceback(
        self, ignored_timeout: Optional[float] = None
    ) -> Optional[types.TracebackType]:
        """Return the stored traceback; the timeout is ignored."""
        return self._traceback

    def add_callback(self, unused_callback) -> bool:
        """Never register the callback; always returns False."""
        return False

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Call `fn(self)` right away, since this outcome is already done."""
        fn(self)

    def __iter__(self):
        return self

    def __next__(self):
        # Iterating a failed streaming outcome surfaces the original error.
        raise self._exception

    def next(self):
        return self.__next__()


class _UnaryOutcome(grpc.Call, grpc.Future):
    """An already-finished outcome pairing a response with its `grpc.Call`.

    Call-related queries are delegated to the wrapped call; future-related
    queries report a completed, non-cancelled, exception-free result.
    """

    _response: Any
    _call: grpc.Call

    def __init__(self, response: Any, call: grpc.Call):
        self._response = response
        self._call = call

    def initial_metadata(self) -> Optional[MetadataType]:
        return self._call.initial_metadata()

    def trailing_metadata(self) -> Optional[MetadataType]:
        return self._call.trailing_metadata()

    def code(self) -> Optional[grpc.StatusCode]:
        return self._call.code()

    def details(self) -> Optional[str]:
        return self._call.details()

    def is_active(self) -> bool:
        return self._call.is_active()

    def time_remaining(self) -> Optional[float]:
        return self._call.time_remaining()

    def cancel(self) -> bool:
        return self._call.cancel()

    def add_callback(self, callback) -> bool:
        return self._call.add_callback(callback)

    def cancelled(self) -> bool:
        return False

    def running(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        """Return the stored response; the timeout is ignored."""
        return self._response

    def exception(self, ignored_timeout: Optional[float] = None):
        return None

    def traceback(self, ignored_timeout: Optional[float] = None):
        return None

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Call `fn(self)` right away, since this outcome is already done."""
        fn(self)


class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable that routes calls through an interceptor.

    `thunk` maps a method name to the underlying multi-callable, so the
    interceptor may redirect the call to a different method.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryUnaryClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.UnaryUnaryClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Invoke the RPC through the interceptor and return the response."""
        response, ignored_call = self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
        return response

    def _with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Run the interceptor and return `(call.result(), call)`.

        The continuation given to the interceptor never raises for a failed
        RPC: a `grpc.RpcError` is returned as a value and any other exception
        is wrapped in `_FailureOutcome`. Whatever `result()` of the object
        returned by the interceptor returns or raises is propagated.
        Exceptions raised by the interceptor itself are not caught here.
        """
        client_call_details = _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

        def continuation(new_details, request):
            (
                new_method,
                new_timeout,
                new_metadata,
                new_credentials,
                new_wait_for_ready,
                new_compression,
            ) = _unwrap_client_call_details(new_details, client_call_details)
            try:
                response, call = self._thunk(new_method).with_call(
                    request,
                    timeout=new_timeout,
                    metadata=new_metadata,
                    credentials=new_credentials,
                    wait_for_ready=new_wait_for_ready,
                    compression=new_compression,
                )
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        call = self._interceptor.intercept_unary_unary(
            continuation, client_call_details, request
        )
        return call.result(), call

    def with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Invoke the RPC through the interceptor; return response and call."""
        return self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

    def future(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Invoke the RPC through the interceptor using the future API.

        Returns what the interceptor returns; if the interceptor raises, the
        exception is wrapped in a `_FailureOutcome` instead of propagating.
        """
        client_call_details = _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

        def continuation(new_details, request):
            (
                new_method,
                new_timeout,
                new_metadata,
                new_credentials,
                new_wait_for_ready,
                new_compression,
            ) = _unwrap_client_call_details(new_details, client_call_details)
            return self._thunk(new_method).future(
                request,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials,
                wait_for_ready=new_wait_for_ready,
                compression=new_compression,
            )

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, client_call_details, request
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that routes calls through an interceptor."""

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryStreamClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.UnaryStreamClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ):
        """Invoke the RPC through the interceptor.

        Returns what the interceptor returns; if the interceptor raises, the
        exception is wrapped in a `_FailureOutcome` instead of propagating.
        """
        client_call_details = _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

        def continuation(new_details, request):
            (
                new_method,
                new_timeout,
                new_metadata,
                new_credentials,
                new_wait_for_ready,
                new_compression,
            ) = _unwrap_client_call_details(new_details, client_call_details)
            return self._thunk(new_method)(
                request,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials,
                wait_for_ready=new_wait_for_ready,
                compression=new_compression,
            )

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, client_call_details, request
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable that routes calls through an interceptor."""

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamUnaryClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.StreamUnaryClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Invoke the RPC through the interceptor and return the response."""
        response, ignored_call = self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
        return response

    def _with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Run the interceptor and return `(call.result(), call)`.

        The continuation given to the interceptor never raises for a failed
        RPC: a `grpc.RpcError` is returned as a value and any other exception
        is wrapped in `_FailureOutcome`. Whatever `result()` of the object
        returned by the interceptor returns or raises is propagated.
        Exceptions raised by the interceptor itself are not caught here.
        """
        client_call_details = _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

        def continuation(new_details, request_iterator):
            (
                new_method,
                new_timeout,
                new_metadata,
                new_credentials,
                new_wait_for_ready,
                new_compression,
            ) = _unwrap_client_call_details(new_details, client_call_details)
            try:
                response, call = self._thunk(new_method).with_call(
                    request_iterator,
                    timeout=new_timeout,
                    metadata=new_metadata,
                    credentials=new_credentials,
                    wait_for_ready=new_wait_for_ready,
                    compression=new_compression,
                )
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        call = self._interceptor.intercept_stream_unary(
            continuation, client_call_details, request_iterator
        )
        return call.result(), call

    def with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Invoke the RPC through the interceptor; return response and call."""
        return self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

    def future(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Invoke the RPC through the interceptor using the future API.

        Returns what the interceptor returns; if the interceptor raises, the
        exception is wrapped in a `_FailureOutcome` instead of propagating.
        """
        client_call_details = _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

        def continuation(new_details, request_iterator):
            (
                new_method,
                new_timeout,
                new_metadata,
                new_credentials,
                new_wait_for_ready,
                new_compression,
            ) = _unwrap_client_call_details(new_details, client_call_details)
            return self._thunk(new_method).future(
                request_iterator,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials,
                wait_for_ready=new_wait_for_ready,
                compression=new_compression,
            )

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, client_call_details, request_iterator
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable that routes calls through an interceptor."""

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamStreamClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.StreamStreamClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ):
        """Invoke the RPC through the interceptor.

        Returns what the interceptor returns; if the interceptor raises, the
        exception is wrapped in a `_FailureOutcome` instead of propagating.
        """
        client_call_details = _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

        def continuation(new_details, request_iterator):
            (
                new_method,
                new_timeout,
                new_metadata,
                new_credentials,
                new_wait_for_ready,
                new_compression,
            ) = _unwrap_client_call_details(new_details, client_call_details)
            return self._thunk(new_method)(
                request_iterator,
                timeout=new_timeout,
                metadata=new_metadata,
                credentials=new_credentials,
                wait_for_ready=new_wait_for_ready,
                compression=new_compression,
            )

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, client_call_details, request_iterator
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _Channel(grpc.Channel):
    """A channel that wraps another channel with a single client interceptor.

    For each RPC arity, the interceptor is applied only when it implements
    the matching interceptor interface; otherwise the wrapped channel's
    multi-callable is returned unchanged.
    """

    _channel: grpc.Channel
    _interceptor: _ClientInterceptor

    def __init__(
        self,
        channel: grpc.Channel,
        interceptor: _ClientInterceptor,
    ):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(
        self, callback: Callable, try_to_connect: Optional[bool] = False
    ):
        """Forward the subscription to the wrapped channel."""
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback: Callable):
        """Forward the unsubscription to the wrapped channel."""
        self._channel.unsubscribe(callback)

    # pylint: disable=arguments-differ
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryUnaryMultiCallable:
        """Create a unary-unary multi-callable, intercepted when applicable."""
        # pytype: disable=wrong-arg-count
        thunk = lambda m: self._channel.unary_unary(
            m,
            request_serializer,
            response_deserializer,
            _registered_method,
        )
        # pytype: enable=wrong-arg-count
        if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    # pylint: disable=arguments-differ
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryStreamMultiCallable:
        """Create a unary-stream multi-callable, intercepted when applicable."""
        # pytype: disable=wrong-arg-count
        thunk = lambda m: self._channel.unary_stream(
            m,
            request_serializer,
            response_deserializer,
            _registered_method,
        )
        # pytype: enable=wrong-arg-count
        if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
            return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    # pylint: disable=arguments-differ
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamUnaryMultiCallable:
        """Create a stream-unary multi-callable, intercepted when applicable."""
        # pytype: disable=wrong-arg-count
        thunk = lambda m: self._channel.stream_unary(
            m,
            request_serializer,
            response_deserializer,
            _registered_method,
        )
        # pytype: enable=wrong-arg-count
        if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
            return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    # pylint: disable=arguments-differ
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamStreamMultiCallable:
        """Create a stream-stream multi-callable, intercepted when applicable."""
        # pytype: disable=wrong-arg-count
        thunk = lambda m: self._channel.stream_stream(
            m,
            request_serializer,
            response_deserializer,
            _registered_method,
        )
        # pytype: enable=wrong-arg-count
        if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
            return _StreamStreamMultiCallable(thunk, method, self._interceptor)
        else:
            return thunk(method)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the wrapped channel; returning False lets any exception from
        # the `with` body propagate.
        self._close()
        return False

    def close(self):
        """Close the wrapped channel."""
        self._channel.close()


def intercept_channel(
    channel: grpc.Channel,
    *interceptors: _ClientInterceptor,
) -> grpc.Channel:
    """Wrap `channel` with each of `interceptors`.

    Interceptors are applied in reverse order, so the first interceptor
    passed ends up outermost on the returned channel.

    Args:
      channel: The channel to wrap.
      *interceptors: Zero or more client interceptors, each implementing at
        least one of the four client interceptor interfaces.

    Returns:
      The wrapped channel, or `channel` itself when no interceptors are given.

    Raises:
      TypeError: If an interceptor implements none of the client interceptor
        interfaces.
    """
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, _CLIENT_INTERCEPTOR_TYPES):
            raise TypeError(
                "interceptor must be "
                "grpc.UnaryUnaryClientInterceptor or "
                "grpc.UnaryStreamClientInterceptor or "
                "grpc.StreamUnaryClientInterceptor or "
                "grpc.StreamStreamClientInterceptor"
            )
        channel = _Channel(channel, interceptor)
    return channel