xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/aio/_interceptor.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Interceptors implementation of gRPC Asyncio Python."""
15from abc import ABCMeta
16from abc import abstractmethod
17import asyncio
18import collections
19import functools
20from typing import (
21    AsyncIterable,
22    Awaitable,
23    Callable,
24    Iterator,
25    List,
26    Optional,
27    Sequence,
28    Union,
29)
30
31import grpc
32from grpc._cython import cygrpc
33
34from . import _base_call
35from ._call import AioRpcError
36from ._call import StreamStreamCall
37from ._call import StreamUnaryCall
38from ._call import UnaryStreamCall
39from ._call import UnaryUnaryCall
40from ._call import _API_STYLE_ERROR
41from ._call import _RPC_ALREADY_FINISHED_DETAILS
42from ._call import _RPC_HALF_CLOSED_DETAILS
43from ._metadata import Metadata
44from ._typing import DeserializingFunction
45from ._typing import DoneCallbackType
46from ._typing import RequestIterableType
47from ._typing import RequestType
48from ._typing import ResponseIterableType
49from ._typing import ResponseType
50from ._typing import SerializingFunction
51from ._utils import _timeout_to_deadline
52
# Details string reported by ``InterceptedCall.details()`` when the
# interceptors task was cancelled before it produced a call object.
_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
54
55
class ServerInterceptor(metaclass=ABCMeta):
    """Affords intercepting incoming RPCs on the service-side.

    Subclasses implement the ``intercept_service`` coroutine.

    This is an EXPERIMENTAL API.
    """

    @abstractmethod
    async def intercept_service(
        self,
        continuation: Callable[
            [grpc.HandlerCallDetails],
            Awaitable[Optional[grpc.RpcMethodHandler]],
        ],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> Optional[grpc.RpcMethodHandler]:
        """Intercepts incoming RPCs before handing them over to a handler.

        State can be passed from an interceptor to downstream interceptors
        via contextvars. The first interceptor is called from an empty
        contextvars.Context, and the same Context is used for downstream
        interceptors and for the final handler call. Note that there are no
        guarantees that interceptors and handlers will be called from the
        same thread.

        Args:
            continuation: A function that takes a HandlerCallDetails and
                proceeds to invoke the next interceptor in the chain, if any,
                or the RPC handler lookup logic, with the call details passed
                as an argument, and returns an RpcMethodHandler instance if
                the RPC is considered serviced, or None otherwise.
            handler_call_details: A HandlerCallDetails describing the RPC.

        Returns:
            An RpcMethodHandler with which the RPC may be serviced if the
            interceptor chooses to service this RPC, or None otherwise.
        """
91
92
class ClientCallDetails(
    collections.namedtuple(
        "ClientCallDetails",
        ("method", "timeout", "metadata", "credentials", "wait_for_ready"),
    ),
    grpc.ClientCallDetails,
):
    """Describes an RPC to be invoked.

    This is an EXPERIMENTAL API.

    Instances are immutable tuples (namedtuple base); an interceptor that
    wants to alter a field builds a modified copy, e.g. with ``_replace``.

    Args:
        method: The method name of the RPC.
        timeout: An optional duration of time in seconds to allow for the RPC.
        metadata: Optional metadata to be transmitted to the service-side of
          the RPC.
        credentials: An optional CallCredentials for the RPC.
        wait_for_ready: An optional flag to enable :term:`wait_for_ready` mechanism.
    """

    # These class-level annotations only document the field types; the
    # fields themselves are provided by the namedtuple base class.
    # NOTE(review): the intercepted-call classes in this module construct
    # this with a ``method`` argument annotated as ``bytes``; the ``str``
    # annotation may be inaccurate — confirm.
    method: str
    timeout: Optional[float]
    metadata: Optional[Metadata]
    credentials: Optional[grpc.CallCredentials]
    wait_for_ready: Optional[bool]
118
119
class ClientInterceptor(metaclass=ABCMeta):
    """Base class used for all Aio Client Interceptor classes.

    It declares no methods of its own; the per-arity subclasses
    (unary-unary, unary-stream, stream-unary, stream-stream) each declare
    their own abstract ``intercept_*`` coroutine.
    """
122
123
class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting unary-unary invocations."""

    @abstractmethod
    async def intercept_unary_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[UnaryUnaryCall, ResponseType]:
        """Intercepts a unary-unary invocation asynchronously.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request: The request value for the RPC.

        Returns:
          Either a call object (typically the one obtained from
          `continuation`) or the response value itself.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
158
159
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting unary-stream invocations."""

    @abstractmethod
    async def intercept_unary_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestType], UnaryStreamCall
        ],
        client_call_details: ClientCallDetails,
        request: RequestType,
    ) -> Union[ResponseIterableType, UnaryStreamCall]:
        """Intercepts a unary-stream invocation asynchronously.

        The function can return either the call object or an asynchronous
        iterator; if it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request: The request value for the RPC.

        Returns:
          The RPC Call or an asynchronous iterator.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
198
199
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-unary invocations."""

    @abstractmethod
    async def intercept_stream_unary(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestIterableType], StreamUnaryCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> StreamUnaryCall:
        """Intercepts a stream-unary invocation asynchronously.

        Within the interceptor, methods of the call such as `write`, or even
        awaiting the call, should be used carefully, since the caller could
        be expecting an untouched call, for example to start writing
        messages to it.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request_iterator)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC Call.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
240
241
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
    """Affords intercepting stream-stream invocations."""

    @abstractmethod
    async def intercept_stream_stream(
        self,
        continuation: Callable[
            [ClientCallDetails, RequestIterableType], StreamStreamCall
        ],
        client_call_details: ClientCallDetails,
        request_iterator: RequestIterableType,
    ) -> Union[ResponseIterableType, StreamStreamCall]:
        """Intercepts a stream-stream invocation asynchronously.

        Within the interceptor, methods of the call such as `write`, or even
        awaiting the call, should be used carefully, since the caller could
        be expecting an untouched call, for example to start writing
        messages to it.

        The function can return either the call object or an asynchronous
        iterator; if it returns an asynchronous iterator, that iterator
        becomes the source of the reads done by the caller.

        Args:
          continuation: A coroutine that proceeds with the invocation by
            executing the next interceptor in the chain or invoking the
            actual RPC on the underlying Channel. It is the interceptor's
            responsibility to call it if it decides to move the RPC forward.
            The interceptor can use
            `call = await continuation(client_call_details, request_iterator)`
            to continue with the RPC. `continuation` returns the call to the
            RPC.
          client_call_details: A ClientCallDetails object describing the
            outgoing RPC.
          request_iterator: The request iterator that will produce requests
            for the RPC.

        Returns:
          The RPC Call or an asynchronous iterator.

        Raises:
          AioRpcError: Indicating that the RPC terminated with non-OK status.
          asyncio.CancelledError: Indicating that the RPC was canceled.
        """
286
287
class InterceptedCall:
    """Base implementation for all intercepted call arities.

    Interceptors might have some work to do before the RPC invocation, with
    the ability to change the invocation parameters, and some work to do
    after the RPC invocation, with access to the wrapped call object.

    It also handles early and late cancellations: when the RPC has not even
    started and execution is still held by the interceptors, or when the RPC
    has finished but execution is still held by the interceptors.

    Once the RPC is finally executed, all methods are delegated to the
    intercepted call, which is the same call object returned to the
    interceptors.

    As the base class for all intercepted calls, it implements the logic
    around final status, metadata and cancellation.
    """

    _interceptors_task: asyncio.Task
    # Done-callbacks registered before the interceptors task finished; they
    # are flushed (fired or forwarded) once that task completes.
    _pending_add_done_callbacks: List[DoneCallbackType]

    def __init__(self, interceptors_task: asyncio.Task) -> None:
        """Wrap the task that runs the interceptor chain.

        Args:
          interceptors_task: Task whose result is the call object that every
            other method of this class delegates to.
        """
        self._interceptors_task = interceptors_task
        self._pending_add_done_callbacks = []
        self._interceptors_task.add_done_callback(
            self._fire_or_add_pending_done_callbacks
        )

    def __del__(self):
        # A subclass ``__init__`` may fail before reaching
        # ``super().__init__`` (for instance if ``loop.create_task`` raises),
        # in which case no task was ever stored. Calling ``cancel()`` then
        # would raise AttributeError from inside ``__del__``.
        if getattr(self, "_interceptors_task", None) is None:
            return
        self.cancel()

    def _fire_or_add_pending_done_callbacks(
        self, interceptors_task: asyncio.Task
    ) -> None:
        """Flush the callbacks registered while the interceptors ran.

        Installed as a done-callback of the interceptors task. If the RPC is
        already over (the chain raised AioRpcError, was cancelled, or
        returned a finished call) the callbacks are invoked immediately with
        this object; otherwise they are attached to the intercepted call.
        """
        if not self._pending_add_done_callbacks:
            return

        call_completed = False

        try:
            call = interceptors_task.result()
            if call.done():
                call_completed = True
        except (AioRpcError, asyncio.CancelledError):
            call_completed = True

        if call_completed:
            for callback in self._pending_add_done_callbacks:
                callback(self)
        else:
            for callback in self._pending_add_done_callbacks:
                # Re-target the callback so it receives this wrapper
                # instead of the inner call object.
                callback = functools.partial(
                    self._wrap_add_done_callback, callback
                )
                call.add_done_callback(callback)

        self._pending_add_done_callbacks = []

    def _wrap_add_done_callback(
        self, callback: DoneCallbackType, unused_call: _base_call.Call
    ) -> None:
        """Invoke ``callback`` with this wrapper, ignoring the inner call."""
        callback(self)

    def cancel(self) -> bool:
        """Cancel the interceptor chain if still running, else the RPC.

        Returns:
          The result of cancelling the interceptors task while it is still
          running; False if the chain ended with AioRpcError or was
          cancelled; otherwise the result of the intercepted call's
          ``cancel()``.
        """
        if not self._interceptors_task.done():
            # There is no intercepted call available yet, so try to cancel
            # it by using the generic asyncio cancellation method.
            return self._interceptors_task.cancel()

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            return False

        return call.cancel()

    def cancelled(self) -> bool:
        """Whether the RPC (or the interceptor chain) was cancelled."""
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except AioRpcError as err:
            return err.code() == grpc.StatusCode.CANCELLED
        except asyncio.CancelledError:
            return True

        return call.cancelled()

    def done(self) -> bool:
        """Whether the RPC has terminated, including a failed chain."""
        if not self._interceptors_task.done():
            return False

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            return True

        return call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Register ``callback`` to be called with this object when done.

        Callbacks added while the interceptors are still running are queued
        and flushed by ``_fire_or_add_pending_done_callbacks``.
        """
        if not self._interceptors_task.done():
            self._pending_add_done_callbacks.append(callback)
            return

        try:
            call = self._interceptors_task.result()
        except (AioRpcError, asyncio.CancelledError):
            callback(self)
            return

        if call.done():
            callback(self)
        else:
            callback = functools.partial(self._wrap_add_done_callback, callback)
            call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()

    async def initial_metadata(self) -> Optional[Metadata]:
        """Initial metadata of the RPC; None if the chain was cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.initial_metadata()
        except asyncio.CancelledError:
            return None

        return await call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        """Trailing metadata of the RPC; None if the chain was cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.trailing_metadata()
        except asyncio.CancelledError:
            return None

        return await call.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        """Final status code; CANCELLED if the chain was cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.code()
        except asyncio.CancelledError:
            return grpc.StatusCode.CANCELLED

        return await call.code()

    async def details(self) -> str:
        """Final status details of the RPC."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.details()
        except asyncio.CancelledError:
            return _LOCAL_CANCELLATION_DETAILS

        return await call.details()

    async def debug_error_string(self) -> Optional[str]:
        """Debug error string; empty if the chain was cancelled."""
        try:
            call = await self._interceptors_task
        except AioRpcError as err:
            return err.debug_error_string()
        except asyncio.CancelledError:
            return ""

        return await call.debug_error_string()

    async def wait_for_connection(self) -> None:
        """Wait for the chain, then delegate to the intercepted call."""
        call = await self._interceptors_task
        return await call.wait_for_connection()
466
467
class _InterceptedUnaryResponseMixin:
    """Makes an intercepted call awaitable for its single response.

    Awaiting an instance first waits for the interceptors task, then awaits
    the call object that task produced and hands back that call's result.
    """

    def __await__(self):
        intercepted_call = yield from self._interceptors_task.__await__()
        return (yield from intercepted_call.__await__())
473
474
class _InterceptedStreamResponseMixin:
    """Response-side behaviour for server-streaming intercepted calls.

    Both ``async for`` and ``read()`` pull from one shared async generator,
    which waits for the interceptors task and then relays every message of
    the call it produced.
    """

    _response_aiter: Optional[AsyncIterable[ResponseType]]

    def _init_stream_response_mixin(self) -> None:
        # The generator is created lazily: an async generator that is built
        # but never consumed makes asyncio log a warning.
        self._response_aiter = None

    async def _wait_for_interceptor_task_response_iterator(
        self,
    ) -> ResponseType:
        intercepted_call = await self._interceptors_task
        async for message in intercepted_call:
            yield message

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        if self._response_aiter is not None:
            return self._response_aiter
        self._response_aiter = (
            self._wait_for_interceptor_task_response_iterator()
        )
        return self._response_aiter

    async def read(self) -> ResponseType:
        responses = self._response_aiter
        if responses is None:
            responses = (
                self._wait_for_interceptor_task_response_iterator()
            )
            self._response_aiter = responses
        return await responses.asend(None)
503
504
class _InterceptedStreamRequestMixin:
    """Request-side behaviour for client-streaming intercepted calls.

    Two API styles are supported. If the caller supplies a request iterator,
    it is handed straight to the interceptors. Otherwise this mixin creates
    a queue (maxsize 1) and an async generator that drains it; the generator
    is used as the request iterator, and ``write()``/``done_writing()`` feed
    the queue. Concrete classes must provide ``_loop`` and
    ``_interceptors_task``.
    """

    _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
    _write_to_iterator_queue: Optional[asyncio.Queue]
    _status_code_task: Optional[asyncio.Task]

    # Queued by ``done_writing()`` to end the proxy request iterator.
    _FINISH_ITERATOR_SENTINEL = object()

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ) -> RequestIterableType:
        """Return the request iterator the interceptors should receive.

        Args:
          request_iterator: Iterator supplied by the caller, or None to use
            the ``write()``/``done_writing()`` style.

        Returns:
          ``request_iterator`` unchanged, or the proxy generator when it is
          None.
        """
        if request_iterator is None:
            # We provide our own request iterator which is a proxy
            # of the future writes that will be done by the caller.
            self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
            self._write_to_iterator_async_gen = (
                self._proxy_writes_as_request_iterator()
            )
            self._status_code_task = None
            request_iterator = self._write_to_iterator_async_gen
        else:
            self._write_to_iterator_queue = None

        return request_iterator

    async def _proxy_writes_as_request_iterator(self):
        """Yield queued requests until the finish sentinel is dequeued."""
        # Do not start draining before the interceptor chain has finished.
        await self._interceptors_task

        while True:
            value = await self._write_to_iterator_queue.get()
            if (
                value
                is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL
            ):
                break
            yield value

    async def _write_to_iterator_queue_interruptible(
        self, request: RequestType, call: _base_call.Call
    ):
        """Queue ``request``, giving up if ``call`` terminates first.

        Args:
          request: Value to queue (a request or the finish sentinel).
          call: The intercepted call; its ``code()`` completes when the RPC
            terminates, which interrupts a blocked ``put``.
        """
        if self._status_code_task is None:
            self._status_code_task = self._loop.create_task(call.code())

        put_task = self._loop.create_task(
            self._write_to_iterator_queue.put(request)
        )
        done, _ = await asyncio.wait(
            (put_task, self._status_code_task),
            return_when=asyncio.FIRST_COMPLETED,
        )
        if put_task not in done:
            # The call terminated before the request could be queued, so
            # nobody will ever drain the queue again. Cancel the put rather
            # than leaving a task pending forever. The shared status task
            # is kept: it is reused by later writes.
            put_task.cancel()

    async def write(self, request: RequestType) -> None:
        """Send ``request`` through the proxy request iterator.

        Raises:
          UsageError: The call was created with a request iterator.
          asyncio.InvalidStateError: The RPC already finished or the caller
            already signalled ``done_writing``.
        """
        # If no queue was created it means that requests
        # should be expected through an iterator provided
        # by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except (asyncio.CancelledError, AioRpcError):
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        elif call._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)

        await self._write_to_iterator_queue_interruptible(request, call)

        # The write may have been interrupted by the call terminating.
        if call.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        # If no queue was created it means that requests
        # should be expected through an iterator provided
        # by the caller.
        if self._write_to_iterator_queue is None:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

        try:
            call = await self._interceptors_task
        except asyncio.CancelledError:
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)

        await self._write_to_iterator_queue_interruptible(
            _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call
        )
601
602
class InterceptedUnaryUnaryCall(
    _InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall
):
    """Runs a `UnaryUnaryCall` through a chain of client interceptors.

    Awaiting an instance first waits for the interceptor chain to finish and
    then proxies to whatever call object the chain produced.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        invocation = self._invoke(
            interceptors,
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            request,
            request_serializer,
            response_deserializer,
        )
        super().__init__(loop.create_task(invocation))

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> UnaryUnaryCall:
        """Run the interceptor chain and return the resulting call.

        Each interceptor receives a continuation bound to the remaining
        interceptors; the last continuation issues the real RPC.
        """

        async def _run_interceptor(
            interceptors: List[UnaryUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryUnaryCall:
            if not interceptors:
                # End of the chain: issue the real RPC on the channel.
                return UnaryUnaryCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

            head, tail = interceptors[0], interceptors[1:]
            continuation = functools.partial(_run_interceptor, tail)
            outcome = await head.intercept_unary_unary(
                continuation, client_call_details, request
            )
            # An interceptor may short-circuit with a bare response value;
            # wrap it so callers always get a call object.
            if isinstance(outcome, _base_call.UnaryUnaryCall):
                return outcome
            return UnaryUnaryCallResponse(outcome)

        initial_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), initial_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
702        raise NotImplementedError()
703
704
class InterceptedUnaryStreamCall(
    _InterceptedStreamResponseMixin, InterceptedCall, _base_call.UnaryStreamCall
):
    """Used for running a `UnaryStreamCall` wrapped by interceptors.

    Reads (``async for`` / ``read()``) are proxied to the call produced by
    the interceptor chain once the chain has finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Most recent call produced while running the chain. When an interceptor
    # returns a bare response iterator instead of a call, this call is
    # wrapped together with that iterator.
    _last_returned_call_from_interceptors: Optional[
        _base_call.UnaryStreamCall
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        request: RequestType,
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[UnaryStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request: RequestType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> UnaryStreamCall:
        """Run the RPC call wrapped in interceptors.

        Each interceptor receives a continuation bound to the remaining
        interceptors; the last continuation issues the real RPC.
        """

        async def _run_interceptor(
            interceptors: List[UnaryStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request: RequestType,
        ) -> _base_call.UnaryStreamCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_unary_stream(
                    continuation, client_call_details, request
                )

                if isinstance(
                    call_or_response_iterator, _base_call.UnaryStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # The interceptor returned a plain async iterator: pair
                    # it with the last call seen so status and metadata
                    # still come from a real call.
                    self._last_returned_call_from_interceptors = (
                        UnaryStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                self._last_returned_call_from_interceptors = UnaryStreamCall(
                    request,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request
        )

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()
818
819
class InterceptedStreamUnaryCall(
    _InterceptedUnaryResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamUnaryCall,
):
    """Used for running a `StreamUnaryCall` wrapped by interceptors.

    The interceptor chain is run inside a task created on `loop` during
    construction. The `__await__` method is proxied to the intercepted call
    only when the interceptor task is finished.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Schedule the interceptor chain for a stream-unary RPC.

        Args:
          interceptors: Interceptors to run, outermost first.
          request_iterator: Requests to send, or None.
          timeout: Optional timeout in seconds, converted to a deadline when
            the underlying call is created.
          metadata: Metadata for the RPC.
          credentials: Optional per-call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          channel: Channel the underlying call is created on.
          method: Fully qualified method name.
          request_serializer: Serializer applied to each request.
          response_deserializer: Deserializer applied to the response.
          loop: Event loop on which the interceptor task is created.
        """
        self._loop = loop
        self._channel = channel
        # The request mixin returns the iterator that is actually handed to
        # the interceptors; NOTE(review): presumably a replacement when
        # `request_iterator` is None -- confirm in the mixin.
        request_iterator = self._init_stream_request_mixin(request_iterator)
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamUnaryClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> _base_call.StreamUnaryCall:
        """Run the RPC call wrapped in interceptors.

        Each interceptor receives a continuation that runs the remaining
        interceptors. Once none remain, a real `StreamUnaryCall` is created.

        Returns:
          Whatever the outermost interceptor returns, or the underlying
          `StreamUnaryCall` when there are no interceptors.
        """

        async def _run_interceptor(
            interceptors: List[StreamUnaryClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamUnaryCall:
            # `interceptors` must be a list: it is indexed and sliced below.
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                return await interceptors[0].intercept_stream_unary(
                    continuation, client_call_details, request_iterator
                )
            else:
                return StreamUnaryCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()
919
920
class InterceptedStreamStreamCall(
    _InterceptedStreamResponseMixin,
    _InterceptedStreamRequestMixin,
    InterceptedCall,
    _base_call.StreamStreamCall,
):
    """Used for running a `StreamStreamCall` wrapped by interceptors.

    The interceptor chain is run inside a task created on `loop` during
    construction.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    # Most recent call produced while running the interceptor chain; None
    # until the chain has produced one. This is an annotation only -- the
    # instance value is assigned in __init__.
    _last_returned_call_from_interceptors: Optional[
        _base_call.StreamStreamCall
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        request_iterator: Optional[RequestIterableType],
        timeout: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Schedule the interceptor chain for a stream-stream RPC.

        Args:
          interceptors: Interceptors to run, outermost first.
          request_iterator: Requests to send, or None.
          timeout: Optional timeout in seconds, converted to a deadline when
            the underlying call is created.
          metadata: Metadata for the RPC.
          credentials: Optional per-call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          channel: Channel the underlying call is created on.
          method: Fully qualified method name.
          request_serializer: Serializer applied to each request.
          response_deserializer: Deserializer applied to each response.
          loop: Event loop on which the interceptor task is created.
        """
        self._loop = loop
        self._channel = channel
        self._init_stream_response_mixin()
        # The request mixin returns the iterator that is actually handed to
        # the interceptors; NOTE(review): presumably a replacement when
        # `request_iterator` is None -- confirm in the mixin.
        request_iterator = self._init_stream_request_mixin(request_iterator)
        # Must be set before the task below runs, because _invoke reads it.
        self._last_returned_call_from_interceptors = None
        interceptors_task = loop.create_task(
            self._invoke(
                interceptors,
                method,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                request_iterator,
                request_serializer,
                response_deserializer,
            )
        )
        super().__init__(interceptors_task)

    # pylint: disable=too-many-arguments
    async def _invoke(
        self,
        interceptors: Sequence[StreamStreamClientInterceptor],
        method: bytes,
        timeout: Optional[float],
        metadata: Optional[Metadata],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        request_iterator: RequestIterableType,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
    ) -> _base_call.StreamStreamCall:
        """Run the RPC call wrapped in interceptors.

        Each interceptor receives a continuation that runs the remaining
        interceptors. Once none remain, a real `StreamStreamCall` is
        created. If an interceptor returns something other than a
        `StreamStreamCall`, it is treated as a response iterator and wrapped
        in a `StreamStreamCallResponseIterator`, together with the last call
        recorded for this RPC.

        Returns:
          The call recorded in `_last_returned_call_from_interceptors`.
        """

        async def _run_interceptor(
            interceptors: List[StreamStreamClientInterceptor],
            client_call_details: ClientCallDetails,
            request_iterator: RequestIterableType,
        ) -> _base_call.StreamStreamCall:
            if interceptors:
                continuation = functools.partial(
                    _run_interceptor, interceptors[1:]
                )

                call_or_response_iterator = await interceptors[
                    0
                ].intercept_stream_stream(
                    continuation, client_call_details, request_iterator
                )

                if isinstance(
                    call_or_response_iterator, _base_call.StreamStreamCall
                ):
                    self._last_returned_call_from_interceptors = (
                        call_or_response_iterator
                    )
                else:
                    # NOTE(review): if this interceptor never awaited its
                    # continuation, the recorded call is still None and the
                    # wrapper's delegating methods (cancel, code, ...) would
                    # fail on it.
                    self._last_returned_call_from_interceptors = (
                        StreamStreamCallResponseIterator(
                            self._last_returned_call_from_interceptors,
                            call_or_response_iterator,
                        )
                    )
                return self._last_returned_call_from_interceptors
            else:
                self._last_returned_call_from_interceptors = StreamStreamCall(
                    request_iterator,
                    _timeout_to_deadline(client_call_details.timeout),
                    client_call_details.metadata,
                    client_call_details.credentials,
                    client_call_details.wait_for_ready,
                    self._channel,
                    client_call_details.method,
                    request_serializer,
                    response_deserializer,
                    self._loop,
                )
                return self._last_returned_call_from_interceptors

        client_call_details = ClientCallDetails(
            method, timeout, metadata, credentials, wait_for_ready
        )
        return await _run_interceptor(
            list(interceptors), client_call_details, request_iterator
        )

    def time_remaining(self) -> Optional[float]:
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError()
1039
1040
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
    """A `UnaryUnaryCall` that is already finished with a known response.

    Awaiting it produces the stored response without suspending. Status
    queries report `grpc.StatusCode.OK`, empty details, and no metadata.
    """

    _response: ResponseType

    def __init__(self, response: ResponseType) -> None:
        self._response = response

    # Lifecycle: the call is complete from construction onward, so it can
    # never be cancelled.

    def done(self) -> bool:
        return True

    def cancelled(self) -> bool:
        return False

    def cancel(self) -> bool:
        return False

    def add_done_callback(self, unused_callback) -> None:
        raise NotImplementedError()

    def time_remaining(self) -> Optional[float]:
        raise NotImplementedError()

    async def wait_for_connection(self) -> None:
        return None

    # Status and metadata: fixed "successful, nothing to report" values.

    async def code(self) -> grpc.StatusCode:
        return grpc.StatusCode.OK

    async def details(self) -> str:
        return ""

    async def debug_error_string(self) -> Optional[str]:
        return None

    async def initial_metadata(self) -> Optional[Metadata]:
        return None

    async def trailing_metadata(self) -> Optional[Metadata]:
        return None

    def __await__(self):
        # `yield from ()` yields nothing, but it makes this method a
        # generator, which the await protocol requires. The response is
        # delivered through the generator's return value.
        yield from ()
        return self._response
1088
1089
class _StreamCallResponseIterator:
    """Pairs a streaming call with a substitute response iterator.

    Asynchronous iteration is served by `response_iterator`. Every
    call-control, status and metadata method is forwarded to `call`.
    """

    _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
    _response_iterator: AsyncIterable[ResponseType]

    def __init__(
        self,
        call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall],
        response_iterator: AsyncIterable[ResponseType],
    ) -> None:
        self._call = call
        self._response_iterator = response_iterator

    def __aiter__(self):
        # Iteration goes to the substitute iterator, not the wrapped call.
        source = self._response_iterator
        return source.__aiter__()

    # Synchronous call control, forwarded to the wrapped call.

    def done(self) -> bool:
        return self._call.done()

    def cancelled(self) -> bool:
        return self._call.cancelled()

    def cancel(self) -> bool:
        return self._call.cancel()

    def add_done_callback(self, callback) -> None:
        self._call.add_done_callback(callback)

    def time_remaining(self) -> Optional[float]:
        return self._call.time_remaining()

    # Asynchronous queries, awaited on the wrapped call.

    async def wait_for_connection(self) -> None:
        return await self._call.wait_for_connection()

    async def code(self) -> grpc.StatusCode:
        return await self._call.code()

    async def details(self) -> str:
        return await self._call.details()

    async def debug_error_string(self) -> Optional[str]:
        return await self._call.debug_error_string()

    async def initial_metadata(self) -> Optional[Metadata]:
        return await self._call.initial_metadata()

    async def trailing_metadata(self) -> Optional[Metadata]:
        return await self._call.trailing_metadata()
1137
1138
class UnaryStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.UnaryStreamCall
):
    """UnaryStreamCall class which uses an alternative response iterator."""

    async def read(self) -> ResponseType:
        """Not supported; always raises NotImplementedError."""
        # Behind the scenes everything goes through the
        # async iterator. So this path should not be reached.
        raise NotImplementedError()
1148
1149
class StreamStreamCallResponseIterator(
    _StreamCallResponseIterator, _base_call.StreamStreamCall
):
    """StreamStreamCall class which uses an alternative response iterator."""

    async def read(self) -> ResponseType:
        """Not supported; always raises NotImplementedError."""
        # Behind the scenes everything goes through the
        # async iterator. So this path should not be reached.
        raise NotImplementedError()

    async def write(self, request: RequestType) -> None:
        """Not supported; always raises NotImplementedError."""
        # Behind the scenes everything goes through the
        # async iterator provided by the InterceptedStreamStreamCall.
        # So this path should not be reached.
        raise NotImplementedError()

    async def done_writing(self) -> None:
        """Not supported; always raises NotImplementedError."""
        # Behind the scenes everything goes through the
        # async iterator provided by the InterceptedStreamStreamCall.
        # So this path should not be reached.
        raise NotImplementedError()

    @property
    def _done_writing_flag(self) -> bool:
        """The wrapped call's own `_done_writing_flag`, forwarded as-is."""
        # Reads a private attribute of the wrapped call; NOTE(review): this
        # assumes the wrapped call defines `_done_writing_flag`.
        return self._call._done_writing_flag
1175