xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/aio/_call.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Invocation-side implementation of gRPC Asyncio Python."""
15
16import asyncio
17import enum
18from functools import partial
19import inspect
20import logging
21import traceback
22from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple
23
24import grpc
25from grpc import _common
26from grpc._cython import cygrpc
27
28from . import _base_call
29from ._metadata import Metadata
30from ._typing import DeserializingFunction
31from ._typing import DoneCallbackType
32from ._typing import MetadatumType
33from ._typing import RequestIterableType
34from ._typing import RequestType
35from ._typing import ResponseType
36from ._typing import SerializingFunction
37
# Names exported by ``from ... import *`` on this module.
__all__ = "AioRpcError", "Call", "UnaryUnaryCall", "UnaryStreamCall"

# Details string handed to the Cython call by ``Call.cancel()``.
_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
# Details string used when ``Call.__del__`` cancels an unfinished call.
_GC_CANCELLATION_DETAILS = "Cancelled upon garbage collection!"
# Messages of the ``asyncio.InvalidStateError`` raised when a write is
# attempted on a finished or half-closed RPC.
_RPC_ALREADY_FINISHED_DETAILS = "RPC already finished."
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
# Message of the ``cygrpc.UsageError`` raised when the iterator style and the
# explicit read/write style are mixed on one RPC.
_API_STYLE_ERROR = (
    "The iterator and read/write APIs may not be mixed on a single RPC."
)

# Template with placeholders for class name, status and details, in that
# order. NOTE(review): not referenced in the portion of the file shown here.
_OK_CALL_REPRESENTATION = (
    '<{} of RPC that terminated with:\n\tstatus = {}\n\tdetails = "{}"\n>'
)

# Template used by ``AioRpcError._repr``; placeholders are class name, status,
# details and debug error string, in that order.
_NON_OK_CALL_REPRESENTATION = (
    "<{} of RPC that terminated with:\n"
    "\tstatus = {}\n"
    '\tdetails = "{}"\n'
    '\tdebug_error_string = "{}"\n'
    ">"
)

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
61
62
class AioRpcError(grpc.RpcError):
    """RpcError flavour raised by the asynchronous API.

    An instance is a snapshot of the final status of an RPC: every value it
    exposes was supplied to the constructor, so the accessors are ordinary
    methods rather than coroutines.
    """

    _code: grpc.StatusCode
    _details: Optional[str]
    _initial_metadata: Optional[Metadata]
    _trailing_metadata: Optional[Metadata]
    _debug_error_string: Optional[str]

    def __init__(
        self,
        code: grpc.StatusCode,
        initial_metadata: Metadata,
        trailing_metadata: Metadata,
        details: Optional[str] = None,
        debug_error_string: Optional[str] = None,
    ) -> None:
        """Stores the final status values of the RPC.

        Args:
          code: The status code with which the RPC has been finalized.
          initial_metadata: The initial metadata to expose through
            `initial_metadata()`.
          trailing_metadata: The trailing metadata to expose through
            `trailing_metadata()`.
          details: Optional details explaining the reason of the error.
          debug_error_string: Optional debug error string to expose through
            `debug_error_string()`.
        """
        super().__init__()
        self._code = code
        self._initial_metadata = initial_metadata
        self._trailing_metadata = trailing_metadata
        self._details = details
        self._debug_error_string = debug_error_string

    def code(self) -> grpc.StatusCode:
        """Returns the `grpc.StatusCode` this error was created with."""
        return self._code

    def details(self) -> Optional[str]:
        """Returns the error description, or None when none was given."""
        return self._details

    def initial_metadata(self) -> Metadata:
        """Returns the initial metadata this error was created with."""
        return self._initial_metadata

    def trailing_metadata(self) -> Metadata:
        """Returns the trailing metadata this error was created with."""
        return self._trailing_metadata

    def debug_error_string(self) -> str:
        """Returns the debug error string this error was created with."""
        return self._debug_error_string

    def _repr(self) -> str:
        """Renders class name, status, details and debug string as text."""
        fields = (
            self.__class__.__name__,
            self._code,
            self._details,
            self._debug_error_string,
        )
        return _NON_OK_CALL_REPRESENTATION.format(*fields)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()

    def __reduce__(self):
        # Pickle support: rebuild the error by calling its own class with the
        # constructor arguments in positional order.
        constructor_args = (
            self._code,
            self._initial_metadata,
            self._trailing_metadata,
            self._details,
            self._debug_error_string,
        )
        return type(self), constructor_args
167
168
def _create_rpc_error(
    initial_metadata: Metadata, status: cygrpc.AioRpcStatus
) -> AioRpcError:
    """Builds an `AioRpcError` from initial metadata and a final status.

    Args:
      initial_metadata: The initial metadata; it is converted with
        `Metadata.from_tuple` before being stored on the error.
      status: The final status object supplying the code, trailing metadata,
        details and debug error string.

    Returns:
      The populated `AioRpcError` (the caller is responsible for raising it).
    """
    status_code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
    initial = Metadata.from_tuple(initial_metadata)
    trailing = Metadata.from_tuple(status.trailing_metadata())
    return AioRpcError(
        status_code,
        initial,
        trailing,
        details=status.details(),
        debug_error_string=status.debug_error_string(),
    )
179
180
class Call:
    """Shared foundation of the client-side RPC call objects.

    Wraps a `cygrpc._AioCall` and exposes its final status, its metadata and
    cancellation.
    """

    _loop: asyncio.AbstractEventLoop
    _code: grpc.StatusCode
    _cython_call: cygrpc._AioCall
    _metadata: Tuple[MetadatumType, ...]
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction

    def __init__(
        self,
        cython_call: cygrpc._AioCall,
        metadata: Metadata,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Stores the collaborators; `metadata` is frozen into a tuple."""
        self._cython_call = cython_call
        self._loop = loop
        self._metadata = tuple(metadata)
        self._response_deserializer = response_deserializer
        self._request_serializer = request_serializer

    def __del__(self) -> None:
        # '_cython_call' may be missing by the time the Call is finalized,
        # hence the hasattr() guard before touching it.
        if hasattr(self, "_cython_call") and not self._cython_call.done():
            self._cancel(_GC_CANCELLATION_DETAILS)

    def cancelled(self) -> bool:
        """Reports whether the underlying call has been cancelled."""
        return self._cython_call.cancelled()

    def _cancel(self, details: str) -> bool:
        """Cancels the underlying call with `details` unless it is done.

        Returns:
          True when the cancellation was forwarded, False when the call had
          already finished.
        """
        if self._cython_call.done():
            return False
        self._cython_call.cancel(details)
        return True

    def cancel(self) -> bool:
        """Cancels the RPC on behalf of the application."""
        return self._cancel(_LOCAL_CANCELLATION_DETAILS)

    def done(self) -> bool:
        """Reports whether the underlying call has finished."""
        return self._cython_call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Registers `callback`; it will be invoked with this Call object."""
        self._cython_call.add_done_callback(partial(callback, self))

    def time_remaining(self) -> Optional[float]:
        """Returns what the underlying call reports as time remaining."""
        return self._cython_call.time_remaining()

    async def initial_metadata(self) -> Metadata:
        """Awaits the initial metadata and wraps it in `Metadata`."""
        raw_metadata = await self._cython_call.initial_metadata()
        return Metadata.from_tuple(raw_metadata)

    async def trailing_metadata(self) -> Metadata:
        """Awaits the final status and wraps its trailing metadata."""
        status = await self._cython_call.status()
        return Metadata.from_tuple(status.trailing_metadata())

    async def code(self) -> grpc.StatusCode:
        """Awaits the final status and maps its code to `grpc.StatusCode`."""
        status = await self._cython_call.status()
        return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]

    async def details(self) -> str:
        """Awaits the final status and returns its details."""
        status = await self._cython_call.status()
        return status.details()

    async def debug_error_string(self) -> str:
        """Awaits the final status and returns its debug error string."""
        status = await self._cython_call.status()
        return status.debug_error_string()

    async def _raise_for_status(self) -> None:
        """Raises if the RPC did not end with an OK status.

        Raises:
          asyncio.CancelledError: The call was cancelled locally.
          AioRpcError: The final status code is not `grpc.StatusCode.OK`.
        """
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        final_code = await self.code()
        if final_code != grpc.StatusCode.OK:
            initial_metadata = await self.initial_metadata()
            final_status = await self._cython_call.status()
            raise _create_rpc_error(initial_metadata, final_status)

    def _repr(self) -> str:
        """Delegates the textual form to the underlying call."""
        return repr(self._cython_call)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()
275
276
class _APIStyle(enum.IntEnum):
    """API style with which a stream of an RPC is being driven.

    The stream mixins in this module record a style per stream and raise
    `cygrpc.UsageError` when a different style is used on the same stream.
    """

    # No style has been recorded yet.
    UNKNOWN = 0
    # Iteration style: `__aiter__` on responses, a request iterator on requests.
    ASYNC_GENERATOR = 1
    # Explicit `read()` / `write()` / `done_writing()` calls.
    READER_WRITER = 2
281
282
class _UnaryResponseMixin(Call, Generic[ResponseType]):
    """Mixin for calls that produce one response and can be awaited."""

    _call_response: asyncio.Task

    def _init_unary_response_mixin(self, response_task: asyncio.Task):
        """Records the task whose result is the response (or `cygrpc.EOF`)."""
        self._call_response = response_task

    def cancel(self) -> bool:
        """Cancels the RPC and, if that took effect, the response task."""
        if not super().cancel():
            return False
        self._call_response.cancel()
        return True

    def __await__(self) -> Generator[Any, None, ResponseType]:
        """Waits for the RPC to finish and returns its response.

        Raises:
          asyncio.CancelledError: The awaiting was cancelled, or the call was
            cancelled locally.
          AioRpcError: The response task produced `cygrpc.EOF` without a
            local cancellation.
        """
        try:
            outcome = yield from self._call_response
        except asyncio.CancelledError:
            # Corner case: the application may cancel right after the Call
            # object is created, in which case the CancelledError shows up
            # here. Make sure the RPC itself is cancelled before re-raising.
            if not self.cancelled():
                self.cancel()
            raise

        if outcome is not cygrpc.EOF:
            return outcome

        # NOTE(lidiz) Raising RpcError inside the task would make AsyncIO log
        # 'Task exception was never retrieved' whenever users do not await
        # it; raising here avoids that spam. Because '__await__' may contain
        # only one 'yield from', the private attributes of the Cython call
        # are read directly instead of being awaited.
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        raise _create_rpc_error(
            self._cython_call._initial_metadata,
            self._cython_call._status,
        )
324
325
class _StreamResponseMixin(Call):
    """Mixin for calls whose responses arrive as a stream."""

    _message_aiter: AsyncIterator[ResponseType]
    _preparation: asyncio.Task
    _response_style: _APIStyle

    def _init_stream_response_mixin(self, preparation: asyncio.Task):
        """Records the task that must complete before any read happens."""
        self._preparation = preparation
        self._message_aiter = None
        self._response_style = _APIStyle.UNKNOWN

    def _update_response_style(self, style: _APIStyle):
        """Latches the first style used; rejects a later, different one.

        Raises:
          cygrpc.UsageError: A different style was already recorded.
        """
        current = self._response_style
        if current is _APIStyle.UNKNOWN:
            self._response_style = style
            return
        if current is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        """Cancels the RPC and, if that took effect, the preparation task."""
        if not super().cancel():
            return False
        self._preparation.cancel()
        return True

    async def _fetch_stream_responses(self) -> ResponseType:
        """Async generator yielding messages until `cygrpc.EOF` is read."""
        while True:
            message = await self._read()
            if message is cygrpc.EOF:
                break
            yield message

        # If the read operation failed, Core should explain why.
        await self._raise_for_status()

    def __aiter__(self) -> AsyncIterator[ResponseType]:
        """Returns the response iterator, creating it on first use."""
        self._update_response_style(_APIStyle.ASYNC_GENERATOR)
        message_aiter = self._message_aiter
        if message_aiter is None:
            message_aiter = self._fetch_stream_responses()
            self._message_aiter = message_aiter
        return message_aiter

    async def _read(self) -> ResponseType:
        """Reads one message, returning it deserialized or `cygrpc.EOF`."""
        # Wait for the request being sent
        await self._preparation

        # Reads response message from Core
        try:
            serialized = await self._cython_call.receive_serialized_message()
        except asyncio.CancelledError:
            # Propagate the cancellation to the RPC before re-raising.
            if not self.cancelled():
                self.cancel()
            raise

        if raw_is_eof := serialized is cygrpc.EOF:
            return cygrpc.EOF
        return _common.deserialize(serialized, self._response_deserializer)

    async def read(self) -> ResponseType:
        """Reads one response message using the reader/writer style.

        Returns:
          The next message, or `cygrpc.EOF` when the stream has ended with
          an OK status.

        Raises:
          asyncio.CancelledError: The call was cancelled locally.
          AioRpcError: The RPC ended with a non-OK status.
          cygrpc.UsageError: The iterator style was already used.
        """
        if self.done():
            await self._raise_for_status()
            return cygrpc.EOF
        self._update_response_style(_APIStyle.READER_WRITER)

        message = await self._read()
        if message is not cygrpc.EOF:
            return message

        # If the read operation failed, Core should explain why.
        await self._raise_for_status()
        return message
395
396
class _StreamRequestMixin(Call):
    """Mixin for calls whose requests are sent as a stream.

    Requests come either from an iterator handed over at construction time
    (consumed by a background task) or from explicit `write()` /
    `done_writing()` calls; the two styles cannot be mixed.
    """

    _metadata_sent: asyncio.Event
    _done_writing_flag: bool
    _async_request_poller: Optional[asyncio.Task]
    _request_style: _APIStyle

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ):
        """Sets up the request side of the call.

        Args:
          request_iterator: Optional iterable (sync or async) of requests.
            When given, a task is created to consume it and the request
            style is fixed to ASYNC_GENERATOR; otherwise the style is fixed
            to READER_WRITER.
        """
        self._metadata_sent = asyncio.Event()
        self._done_writing_flag = False

        # If user passes in an async iterator, create a consumer Task.
        if request_iterator is not None:
            self._async_request_poller = self._loop.create_task(
                self._consume_request_iterator(request_iterator)
            )
            self._request_style = _APIStyle.ASYNC_GENERATOR
        else:
            self._async_request_poller = None
            self._request_style = _APIStyle.READER_WRITER

    def _raise_for_different_style(self, style: _APIStyle):
        """Raises `cygrpc.UsageError` unless `style` is the recorded one."""
        if self._request_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        """Cancels the RPC and, if that took effect, the consumer task."""
        if super().cancel():
            if self._async_request_poller is not None:
                self._async_request_poller.cancel()
            return True
        else:
            return False

    def _metadata_sent_observer(self):
        """Callback that marks the metadata as sent."""
        self._metadata_sent.set()

    async def _forward_request(self, request: RequestType) -> bool:
        """Writes one request taken from the request iterator.

        Returns:
          True when the request was written; False when the write failed
          with an `AioRpcError`, which is logged at debug level and means
          the iterator should not be consumed any further.
        """
        try:
            await self._write(request)
        except AioRpcError as rpc_error:
            _LOGGER.debug(
                "Exception while consuming the request_iterator: %s",
                rpc_error,
            )
            return False
        return True

    async def _consume_request_iterator(
        self, request_iterator: RequestIterableType
    ) -> None:
        """Drains `request_iterator` into the RPC, then half-closes it.

        Consumption stops silently at the first write that fails with an
        `AioRpcError`. Any other exception cancels the RPC and is logged;
        nothing is allowed to escape this coroutine.
        """
        try:
            if inspect.isasyncgen(request_iterator) or hasattr(
                request_iterator, "__aiter__"
            ):
                async for request in request_iterator:
                    if not await self._forward_request(request):
                        return
            else:
                for request in request_iterator:
                    if not await self._forward_request(request):
                        return

            await self._done_writing()
        except:  # pylint: disable=bare-except
            # Client iterators can raise exceptions, which we should handle by
            # cancelling the RPC and logging the client's error. No exceptions
            # should escape this function.
            _LOGGER.debug(
                "Client request_iterator raised exception:\n%s",
                traceback.format_exc(),
            )
            self.cancel()

    async def _write(self, request: RequestType) -> None:
        """Serializes and sends one request message.

        Raises:
          asyncio.InvalidStateError: The RPC already finished, or it was
            half closed by an earlier done-writing.
          asyncio.CancelledError: The send was cancelled (the RPC is
            cancelled too before re-raising), or the call was cancelled
            locally.
          AioRpcError: The RPC ended with a non-OK status.
        """
        if self.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        if self._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
        if not self._metadata_sent.is_set():
            await self._metadata_sent.wait()
            # The RPC may have finished while waiting for the metadata.
            if self.done():
                await self._raise_for_status()

        serialized_request = _common.serialize(
            request, self._request_serializer
        )
        try:
            await self._cython_call.send_serialized_message(serialized_request)
        except cygrpc.InternalError as err:
            # Record the failure on the call, then surface its final status.
            self._cython_call.set_internal_error(str(err))
            await self._raise_for_status()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def _done_writing(self) -> None:
        """Half-closes the RPC once; later calls and finished RPCs are no-ops."""
        if self.done():
            # If the RPC is finished, do nothing.
            return
        if not self._done_writing_flag:
            # If the done writing is not sent before, try to send it.
            self._done_writing_flag = True
            try:
                await self._cython_call.send_receive_close()
            except asyncio.CancelledError:
                if not self.cancelled():
                    self.cancel()
                raise

    async def write(self, request: RequestType) -> None:
        """Sends one request using the reader/writer style.

        Raises:
          cygrpc.UsageError: The call was created with a request iterator.
        """
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._write(request)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.

        Raises:
          cygrpc.UsageError: The call was created with a request iterator.
        """
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._done_writing()

    async def wait_for_connection(self) -> None:
        """Waits until the metadata is sent; raises if the RPC then is done
        with a non-OK status or was cancelled locally."""
        await self._metadata_sent.wait()
        if self.done():
            await self._raise_for_status()
531
532
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
    """Call object of a unary-unary RPC.

    This is what calling a `UnaryUnaryMultiCallable` instance returns.
    """

    _request: RequestType
    _invocation_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request: RequestType,
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Creates the underlying call and schedules the invocation task."""
        cython_call = channel.call(
            method, deadline, credentials, wait_for_ready
        )
        super().__init__(
            cython_call,
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._request = request
        self._context = cygrpc.build_census_context()
        invocation = loop.create_task(self._invoke())
        self._invocation_task = invocation
        self._init_unary_response_mixin(invocation)

    async def _invoke(self) -> ResponseType:
        """Runs the RPC; returns the response or `cygrpc.EOF` on failure."""
        payload = _common.serialize(self._request, self._request_serializer)

        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
        # because the asyncio.Task class do not cache the exception object.
        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
        # The cancellation is therefore swallowed here on purpose and the
        # outcome is reported through the return value below.
        try:
            serialized_response = await self._cython_call.unary_unary(
                payload, self._metadata, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()

        if not self._cython_call.is_ok():
            return cygrpc.EOF
        return _common.deserialize(
            serialized_response, self._response_deserializer
        )

    async def wait_for_connection(self) -> None:
        """Waits for the invocation task, then raises on a failed RPC."""
        await self._invocation_task
        if self.done():
            await self._raise_for_status()
595
596
class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
    """Call object of a unary-stream RPC.

    This is what calling a `UnaryStreamMultiCallable` instance returns.
    """

    _request: RequestType
    _send_unary_request_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request: RequestType,
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Creates the underlying call and schedules sending the request."""
        cython_call = channel.call(
            method, deadline, credentials, wait_for_ready
        )
        super().__init__(
            cython_call,
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._request = request
        self._context = cygrpc.build_census_context()
        send_task = loop.create_task(self._send_unary_request())
        self._send_unary_request_task = send_task
        # Reads wait on this task before touching the response stream.
        self._init_stream_response_mixin(send_task)

    async def _send_unary_request(self) -> ResponseType:
        """Serializes the request and initiates the unary-stream call."""
        payload = _common.serialize(self._request, self._request_serializer)
        try:
            await self._cython_call.initiate_unary_stream(
                payload, self._metadata, self._context
            )
        except asyncio.CancelledError:
            # Propagate the cancellation to the RPC before re-raising.
            if not self.cancelled():
                self.cancel()
            raise

    async def wait_for_connection(self) -> None:
        """Waits for the request to be sent, then raises on a failed RPC."""
        await self._send_unary_request_task
        if self.done():
            await self._raise_for_status()
651
652
# pylint: disable=too-many-ancestors
class StreamUnaryCall(
    _StreamRequestMixin, _UnaryResponseMixin, Call, _base_call.StreamUnaryCall
):
    """Call object of a stream-unary RPC.

    This is what calling a `StreamUnaryMultiCallable` instance returns.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request_iterator: Optional[RequestIterableType],
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Creates the underlying call and wires request and response sides."""
        cython_call = channel.call(
            method, deadline, credentials, wait_for_ready
        )
        super().__init__(
            cython_call,
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )

        self._context = cygrpc.build_census_context()
        self._init_stream_request_mixin(request_iterator)
        response_task = loop.create_task(self._conduct_rpc())
        self._init_unary_response_mixin(response_task)

    async def _conduct_rpc(self) -> ResponseType:
        """Runs the RPC; returns the response or `cygrpc.EOF` on failure."""
        try:
            serialized_response = await self._cython_call.stream_unary(
                self._metadata, self._metadata_sent_observer, self._context
            )
        except asyncio.CancelledError:
            # Propagate the cancellation to the RPC before re-raising.
            if not self.cancelled():
                self.cancel()
            raise

        if not self._cython_call.is_ok():
            return cygrpc.EOF
        return _common.deserialize(
            serialized_response, self._response_deserializer
        )
704
705
class StreamStreamCall(
    _StreamRequestMixin, _StreamResponseMixin, Call, _base_call.StreamStreamCall
):
    """Call object of a stream-stream RPC.

    This is what calling a `StreamStreamMultiCallable` instance returns.
    """

    _initializer: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request_iterator: Optional[RequestIterableType],
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Creates the underlying call and wires request and response sides."""
        cython_call = channel.call(
            method, deadline, credentials, wait_for_ready
        )
        super().__init__(
            cython_call,
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._context = cygrpc.build_census_context()
        preparation = self._loop.create_task(self._prepare_rpc())
        self._initializer = preparation
        self._init_stream_request_mixin(request_iterator)
        self._init_stream_response_mixin(preparation)

    async def _prepare_rpc(self):
        """Initiates the stream so that messages can be sent and received.

        Every other operation on the stream should happen only after this
        coroutine has completed.
        """
        try:
            await self._cython_call.initiate_stream_stream(
                self._metadata, self._metadata_sent_observer, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            # No need to raise RpcError here, because no one will `await` this task.
756