# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interceptors implementation of gRPC Asyncio Python."""
from abc import ABCMeta
from abc import abstractmethod
import asyncio
import collections
import functools
from typing import (
    AsyncIterable,
    Awaitable,
    Callable,
    Iterator,
    List,
    Optional,
    Sequence,
    Union,
)

import grpc
from grpc._cython import cygrpc

from . import _base_call
from ._call import AioRpcError
from ._call import StreamStreamCall
from ._call import StreamUnaryCall
from ._call import UnaryStreamCall
from ._call import UnaryUnaryCall
from ._call import _API_STYLE_ERROR
from ._call import _RPC_ALREADY_FINISHED_DETAILS
from ._call import _RPC_HALF_CLOSED_DETAILS
from ._metadata import Metadata
from ._typing import DeserializingFunction
from ._typing import DoneCallbackType
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseIterableType
from ._typing import ResponseType
from ._typing import SerializingFunction
from ._utils import _timeout_to_deadline

_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"


class ServerInterceptor(metaclass=ABCMeta):
    """Affords intercepting incoming RPCs on the service-side.

    This is an EXPERIMENTAL API.
    """

    @abstractmethod
    async def intercept_service(
        self,
        continuation: Callable[
            [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler]
        ],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler:
        """Intercepts incoming RPCs before handing them over to a handler.

        State can be passed from an interceptor to downstream interceptors
        via contextvars. The first interceptor is called from an empty
        contextvars.Context, and the same Context is used for downstream
        interceptors and for the final handler call. Note that there are no
        guarantees that interceptors and handlers will be called from the
        same thread.

        Args:
            continuation: A function that takes a HandlerCallDetails and
                proceeds to invoke the next interceptor in the chain, if any,
                or the RPC handler lookup logic, with the call details passed
                as an argument, and returns an RpcMethodHandler instance if
                the RPC is considered serviced, or None otherwise.
            handler_call_details: A HandlerCallDetails describing the RPC.

        Returns:
            An RpcMethodHandler with which the RPC may be serviced if the
            interceptor chooses to service this RPC, or None otherwise.
        """


class ClientCallDetails(
    collections.namedtuple(
        "ClientCallDetails",
        ("method", "timeout", "metadata", "credentials", "wait_for_ready"),
    ),
    grpc.ClientCallDetails,
):
    """Describes an RPC to be invoked.

    This is an EXPERIMENTAL API.

    Args:
        method: The method name of the RPC.
        timeout: An optional duration of time in seconds to allow for the RPC.
        metadata: Optional metadata to be transmitted to the service-side of
            the RPC.
        credentials: An optional CallCredentials for the RPC.
        wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism.
    """

    method: str
    timeout: Optional[float]
    metadata: Optional[Metadata]
    credentials: Optional[grpc.CallCredentials]
    wait_for_ready: Optional[bool]


class ClientInterceptor(metaclass=ABCMeta):
    """Base class used for all Aio Client Interceptor classes"""


class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting unary-unary invocations."""

    @abstractmethod
    async def intercept_unary_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[UnaryUnaryCall, ResponseType]:
        """Intercepts a unary-unary invocation asynchronously.

        Args:
            continuation: A coroutine that proceeds with the invocation by
                executing the next interceptor in the chain or invoking the
                actual RPC on the underlying Channel. It is the interceptor's
                responsibility to call it if it decides to move the RPC forward.
                The interceptor can use
                `call = await continuation(client_call_details, request)`
                to continue with the RPC. `continuation` returns the call to the
                RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request: The request value for the RPC.

        Returns:
            An object with the RPC response.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """


class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting unary-stream invocations."""

    @abstractmethod
    async def intercept_unary_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryStreamCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[ResponseIterableType, UnaryStreamCall]:
        """Intercepts a unary-stream invocation asynchronously.

        The function could return the call object or an asynchronous
        iterator; if it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
            continuation: A coroutine that proceeds with the invocation by
                executing the next interceptor in the chain or invoking the
                actual RPC on the underlying Channel. It is the interceptor's
                responsibility to call it if it decides to move the RPC forward.
                The interceptor can use
                `call = await continuation(client_call_details, request)`
                to continue with the RPC. `continuation` returns the call to the
                RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request: The request value for the RPC.

        Returns:
            The RPC Call or an asynchronous iterator.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """


class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-unary invocations."""

    @abstractmethod
    async def intercept_stream_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestIterableType], StreamUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> StreamUnaryCall:
        """Intercepts a stream-unary invocation asynchronously.

        Within the interceptor, call methods like `write` (or even awaiting
        the call) should be used carefully, since the caller could be
        expecting an untouched call, for example to start writing messages
        to it.

        Args:
            continuation: A coroutine that proceeds with the invocation by
                executing the next interceptor in the chain or invoking the
                actual RPC on the underlying Channel. It is the interceptor's
                responsibility to call it if it decides to move the RPC forward.
                The interceptor can use
                `call = await continuation(client_call_details, request_iterator)`
                to continue with the RPC. `continuation` returns the call to the
                RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request_iterator: The request iterator that will produce requests
                for the RPC.

        Returns:
            The RPC Call.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """


class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-stream invocations."""

    @abstractmethod
    async def intercept_stream_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestIterableType], StreamStreamCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> Union[ResponseIterableType, StreamStreamCall]:
        """Intercepts a stream-stream invocation asynchronously.

        Within the interceptor, call methods like `write` (or even awaiting
        the call) should be used carefully, since the caller could be
        expecting an untouched call, for example to start writing messages
        to it.

        The function could return the call object or an asynchronous
        iterator; if it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
            continuation: A coroutine that proceeds with the invocation by
                executing the next interceptor in the chain or invoking the
                actual RPC on the underlying Channel. It is the interceptor's
                responsibility to call it if it decides to move the RPC forward.
                The interceptor can use
                `call = await continuation(client_call_details, request_iterator)`
                to continue with the RPC. `continuation` returns the call to the
                RPC.
            client_call_details: A ClientCallDetails object describing the
                outgoing RPC.
            request_iterator: The request iterator that will produce requests
                for the RPC.

        Returns:
            The RPC Call or an asynchronous iterator.

        Raises:
            AioRpcError: Indicating that the RPC terminated with non-OK status.
            asyncio.CancelledError: Indicating that the RPC was canceled.
        """


class InterceptedCall:
    """Base implementation for all intercepted call arities.

    Interceptors might have some work to do before the RPC invocation, with
    the ability to change the invocation parameters, and some work to do
    after the RPC invocation, with access to the wrapped call.

    It also handles early and late cancellations: when the RPC has not even
    started and execution is still held by the interceptors, or when the RPC
    has finished but execution is, again, still held by the interceptors.

    Once the RPC is finally executed, all methods are delegated to the
    intercepted call, which is the same call returned to the interceptors.

    As the base class for all intercepted calls, it implements the logic
    around final status, metadata and cancellation.
    """

    _interceptors_task: asyncio.Task
    _pending_add_done_callbacks: List[DoneCallbackType]

    def __init__(self, interceptors_task: asyncio.Task) -> None:
        self._interceptors_task = interceptors_task
        self._pending_add_done_callbacks = []
        self._interceptors_task.add_done_callback(
            self._fire_or_add_pending_done_callbacks
        )

    def __del__(self):
        # A subclass __init__ can fail before the interceptors task is
        # assigned (e.g. if `loop.create_task` raises); in that case there
        # is nothing to cancel and calling `cancel()` would only raise
        # AttributeError from the finalizer.
        if getattr(self, "_interceptors_task", None) is None:
            return
        self.cancel()

    def _fire_or_add_pending_done_callbacks(
        self, interceptors_task: asyncio.Task
    ) -> None:
        """Flushes callbacks registered before the interceptors finished.

        Runs as a done callback of the interceptors task. If the RPC has
        already terminated (or the interceptors failed/were cancelled), the
        pending callbacks are invoked immediately with this object;
        otherwise they are attached to the call the interceptors returned.
        """
        if not self._pending_add_done_callbacks:
            return

        call_completed = False

        try:
            call = interceptors_task.result()
            if call.done():
                call_completed = True
        except (AioRpcError, asyncio.CancelledError):
            call_completed = True

        if call_completed:
            for callback in self._pending_add_done_callbacks:
                callback(self)
        else:
            for callback in self._pending_add_done_callbacks:
                callback = functools.partial(
                    self._wrap_add_done_callback, callback
                )
                call.add_done_callback(callback)

        self._pending_add_done_callbacks = []

    def _wrap_add_done_callback(
        self, callback: DoneCallbackType, unused_call: _base_call.Call
    ) -> None:
        # Callbacks must observe this intercepted call, not the inner one.
        callback(self)

    def cancel(self) -> bool:
        if not self._interceptors_task.done():
            # The intercepted call is not available yet; try to cancel
            # the interceptors using the generic Asyncio cancellation
            # method.
            return self._interceptors_task.cancel()

        try:
            call = self._interceptors_task.result()
        except AioRpcError:
            return False
        except asyncio.CancelledError:
            return False

        return call.cancel()

    def cancelled(self) -> bool:
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except AioRpcError as err:
            return err.code() == grpc.StatusCode.CANCELLED
        except asyncio.CancelledError:
            return True

        return call.cancelled()

    def done(self) -> bool:
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            return True

        return call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        if not self._interceptors_task.done():
            # Deferred until the interceptors finish; see
            # `_fire_or_add_pending_done_callbacks`.
            self._pending_add_done_callbacks.append(callback)
            return

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            callback(self)
            return

        if call.done():
            callback(self)
        else:
            callback = functools.partial(self._wrap_add_done_callback, callback)
            call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.initial_metadata()
        except asyncio.CancelledError:
            return None

        return await call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.trailing_metadata()
        except asyncio.CancelledError:
            return None

        return await call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.code()
        except asyncio.CancelledError:
            return grpc.StatusCode.CANCELLED

        return await call.code()

    async def details(self) -> str:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.details()
        except asyncio.CancelledError:
            return _LOCAL_CANCELLATION_DETAILS

        return await call.details()

    async def debug_error_string(self) -> Optional[str]:
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.debug_error_string()
        except asyncio.CancelledError:
            return ""

        return await call.debug_error_string()

    async def wait_for_connection(self) -> None:
        call = await self._interceptors_task
        return await call.wait_for_connection()


class _InterceptedUnaryResponseMixin:
    """Makes an intercepted call awaitable for its single response."""

    def __await__(self):
        call = yield from self._interceptors_task.__await__()
        response = yield from call.__await__()
        return response


class _InterceptedStreamResponseMixin:
    """Provides `__aiter__`/`read` over an intercepted response stream."""

    _response_aiter: Optional[AsyncIterable[ResponseType]]

    def _init_stream_response_mixin(self) -> None:
        # The iterator is initialized lazily; otherwise, if it is never
        # consumed, Asyncio emits a logging warning.
        self._response_aiter = None

    async def _wait_for_interceptor_task_response_iterator(
        self,
    ) -> AsyncIterable[ResponseType]:
        call = await self._interceptors_task
        async for response in call:
            yield response

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        if self._response_aiter is None:
            self._response_aiter = (
                self._wait_for_interceptor_task_response_iterator()
            )
        return self._response_aiter

    async def read(self) -> ResponseType:
        # Shares the same underlying iterator as `__aiter__`, so reads and
        # async iteration can be interleaved.
        if self._response_aiter is None:
            self._response_aiter = (
                self._wait_for_interceptor_task_response_iterator()
            )
        return await self._response_aiter.asend(None)


class _InterceptedStreamRequestMixin:
    """Provides `write`/`done_writing` for intercepted request streams."""

    _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
    _write_to_iterator_queue: Optional[asyncio.Queue]
    _status_code_task: Optional[asyncio.Task]

    _FINISH_ITERATOR_SENTINEL = object()

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ) -> RequestIterableType:
        """Sets up the request source for the RPC.

        Args:
            request_iterator: The caller-provided request iterator, or None
                when the caller will use `write`/`done_writing` instead.

        Returns:
            The request iterator to hand to the interceptors: either the
            caller's own iterator or an internal proxy fed by `write`.
        """
        if request_iterator is None:
            # We provide our own request iterator which is a proxy
            # of the future writes that will be done by the caller.
            self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
            self._write_to_iterator_async_gen = (
                self._proxy_writes_as_request_iterator()
            )
            request_iterator = self._write_to_iterator_async_gen
        else:
            self._write_to_iterator_queue = None
            self._write_to_iterator_async_gen = None

        self._status_code_task = None
        return request_iterator

    async def _proxy_writes_as_request_iterator(self):
        await self._interceptors_task

        while True:
            value = await self._write_to_iterator_queue.get()
            if (
                value
                is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL
            ):
                break
            yield value

    async def _write_to_iterator_queue_interruptible(
        self, request: RequestType, call: InterceptedCall
    ):
        # Write the specified 'request' to the request iterator queue using the
        # specified 'call' to allow for interruption of the write in the case
        # of abrupt termination of the call.
        if self._status_code_task is None:
            self._status_code_task = self._loop.create_task(call.code())

        put_task = self._loop.create_task(
            self._write_to_iterator_queue.put(request)
        )
        try:
            await asyncio.wait(
                (put_task, self._status_code_task),
                return_when=asyncio.FIRST_COMPLETED,
            )
        finally:
            # `asyncio.wait` never cancels the tasks it waits on. If the call
            # terminated first (or this write was cancelled), the put would
            # otherwise linger forever on a queue nobody drains. The shared
            # status task is intentionally left running.
            if not put_task.done():
                put_task.cancel()

    async def write(self, request: RequestType) -> None:
        # If no queue was created it means that requests
        # should be expected through an iterator provided
        # by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except (asyncio.CancelledError, AioRpcError):
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        elif call._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)

        await self._write_to_iterator_queue_interruptible(request, call)

        # The write may have been interrupted by the call terminating.
        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        # If no queue was created it means that requests
        # should be expected through an iterator provided
        # by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except asyncio.CancelledError:
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        await self._write_to_iterator_queue_interruptible(
            _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call
        )


class InterceptedUnaryUnaryCall(
    _InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall
):
    """Used for running a `UnaryUnaryCall` wrapped by interceptors.

    The `__await__` method is proxied to the intercepted call only when
    the interceptor task is finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> UnaryUnaryCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[UnaryUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryUnaryCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )
                call_or_response = await interceptors[0].intercept_unary_unary(
                    continuation, client_call_details, request
                )

                if isinstance(call_or_response, _base_call.UnaryUnaryCall):
                    return call_or_response
                else:
                    # The interceptor short-circuited with a plain response.
                    return UnaryUnaryCallResponse(call_or_response)

            else:
                return UnaryUnaryCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()


class InterceptedUnaryStreamCall(
    _InterceptedStreamResponseMixin, InterceptedCall, _base_call.UnaryStreamCall
):
    """Used for running a `UnaryStreamCall` wrapped by interceptors."""

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _last_returned_call_from_interceptors: Optional[_base_call.UnaryStreamCall]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> UnaryStreamCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[UnaryStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryStreamCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_unary_stream(
                    continuation, client_call_details, request
                )

                if isinstance(
                    call_or_response_iterator, _base_call.UnaryStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a bare async iterator; pair it
                    # with the innermost call so status/metadata still work.
                    self._last_returned_call_from_interceptors = (
                        UnaryStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                self._last_returned_call_from_interceptors = UnaryStreamCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()


class InterceptedStreamUnaryCall(
    _InterceptedUnaryResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamUnaryCall,
):
    """Used for running a `StreamUnaryCall` wrapped by interceptors.

    The `__await__` method is proxied to the intercepted call only when
    the interceptor task is finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        request_iterator = self._init_stream_request_mixin(request_iterator)
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> StreamUnaryCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[StreamUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamUnaryCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                return await interceptors[0].intercept_stream_unary(
                    continuation, client_call_details, request_iterator
                )
            else:
                return StreamUnaryCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()


class InterceptedStreamStreamCall(
    _InterceptedStreamResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamStreamCall,
):
    """Used for running a `StreamStreamCall` wrapped by interceptors."""

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _last_returned_call_from_interceptors: Optional[
        _base_call.StreamStreamCall
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        request_iterator = self._init_stream_request_mixin(request_iterator)
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> StreamStreamCall:
        """Run the RPC call wrapped in interceptors"""

        async def _run_interceptor(
            interceptors: List[StreamStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamStreamCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_stream_stream(
                    continuation, client_call_details, request_iterator
                )

                if isinstance(
                    call_or_response_iterator, _base_call.StreamStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a bare async iterator; pair it
                    # with the innermost call so status/metadata still work.
                    self._last_returned_call_from_interceptors = (
                        StreamStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                self._last_returned_call_from_interceptors = StreamStreamCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )
                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()


class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
    """Final UnaryUnaryCall class finished with a response.

    Returned when a unary-unary interceptor produces a response directly
    instead of a call: the "call" is already done with status OK.
    """

    _response: ResponseType

    def __init__(self, response: ResponseType) -> None:
        self._response = response

    def cancel(self) -> bool:
        return False

    def cancelled(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def add_done_callback(self, unused_callback) -> None:
        raise NotImplementedError()

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        return None

    async def trailing_metadata(self) -> Optional[Metadata]:
        return None

    async def code(self) -> grpc.StatusCode:
        return grpc.StatusCode.OK

    async def details(self) -> str:
        return ""

    async def debug_error_string(self) -> Optional[str]:
        return None

    def __await__(self):
        if False:  # pylint: disable=using-constant-test
            # This code path is never used, but a yield statement is needed
            # for telling the interpreter that __await__ is a generator.
            yield None
        return self._response

    async def wait_for_connection(self) -> None:
        pass


class _StreamCallResponseIterator:
    """Pairs a call with an interceptor-supplied response iterator.

    Call-level operations (status, metadata, cancellation) are delegated to
    the wrapped call, while async iteration is served by the iterator.
    """

    _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
    _response_iterator: AsyncIterable[ResponseType]

    def __init__(
        self,
        call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall],
        response_iterator: AsyncIterable[ResponseType],
    ) -> None:
        self._response_iterator = response_iterator
        self._call = call

    def cancel(self) -> bool:
        return self._call.cancel()

    def cancelled(self) -> bool:
        return self._call.cancelled()

    def done(self) -> bool:
        return self._call.done()

    def add_done_callback(self, callback) -> None:
        self._call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        return self._call.time_remaining()

    async def initial_metadata(self) -> Optional[Metadata]:
        return await self._call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        return await self._call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        return await self._call.code()

    async def details(self) -> str:
        return await self._call.details()

    async def debug_error_string(self) -> Optional[str]:
        return await self._call.debug_error_string()

    def __aiter__(self):
        return self._response_iterator.__aiter__()

    async def wait_for_connection(self) -> None:
        return await self._call.wait_for_connection()


class UnaryStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.UnaryStreamCall
):
    """UnaryStreamCall class which uses an alternative response iterator."""

    async def read(self) -> ResponseType:
        # Behind the scenes everything goes through the
        # async iterator. So this path should not be reached.
        raise NotImplementedError()


class StreamStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.StreamStreamCall
):
    """StreamStreamCall class which uses an alternative response iterator."""

    async def read(self) -> ResponseType:
        # Behind the scenes everything goes through the
        # async iterator. So this path should not be reached.
        raise NotImplementedError()

    async def write(self, request: RequestType) -> None:
        # Behind the scenes everything goes through the
        # async iterator provided by the InterceptedStreamStreamCall.
        # So this path should not be reached.
        raise NotImplementedError()

    async def done_writing(self) -> None:
        # Behind the scenes everything goes through the
        # async iterator provided by the InterceptedStreamStreamCall.
        # So this path should not be reached.
        raise NotImplementedError()

    @property
    def _done_writing_flag(self) -> bool:
        # Read by `_InterceptedStreamRequestMixin.write` to detect half-close.
        return self._call._done_writing_flag