xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/aio/_channel.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2019 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Invocation-side implementation of gRPC Asyncio Python."""
15
16import asyncio
17import sys
18from typing import Any, Iterable, List, Optional, Sequence
19
20import grpc
21from grpc import _common
22from grpc import _compression
23from grpc import _grpcio_metadata
24from grpc._cython import cygrpc
25
26from . import _base_call
27from . import _base_channel
28from ._call import StreamStreamCall
29from ._call import StreamUnaryCall
30from ._call import UnaryStreamCall
31from ._call import UnaryUnaryCall
32from ._interceptor import ClientInterceptor
33from ._interceptor import InterceptedStreamStreamCall
34from ._interceptor import InterceptedStreamUnaryCall
35from ._interceptor import InterceptedUnaryStreamCall
36from ._interceptor import InterceptedUnaryUnaryCall
37from ._interceptor import StreamStreamClientInterceptor
38from ._interceptor import StreamUnaryClientInterceptor
39from ._interceptor import UnaryStreamClientInterceptor
40from ._interceptor import UnaryUnaryClientInterceptor
41from ._metadata import Metadata
42from ._typing import ChannelArgumentType
43from ._typing import DeserializingFunction
44from ._typing import MetadataType
45from ._typing import RequestIterableType
46from ._typing import RequestType
47from ._typing import ResponseType
48from ._typing import SerializingFunction
49from ._utils import _timeout_to_deadline
50
# User-agent token sent as the primary user-agent channel argument (see
# ``_augment_channel_arguments``).
_USER_AGENT = f"grpc-python-asyncio/{_grpcio_metadata.__version__}"
52
# Compare the whole version tuple. Checking only the minor component
# (``sys.version_info[1] < 7``) would route a future major release such as
# 4.0 to the legacy ``asyncio.Task.all_tasks`` API, which was removed in 3.9.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return the event loop's tasks via the pre-3.7 ``Task`` API."""
        return asyncio.Task.all_tasks()  # pylint: disable=no-member

else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return the event loop's tasks via ``asyncio.all_tasks``."""
        return asyncio.all_tasks()
62
63
def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Append the SDK-managed channel arguments to the caller's options.

    Args:
      base_options: Channel arguments supplied by the user.
      compression: Optional channel-wide compression algorithm, translated by
        ``_compression.create_channel_option``.

    Returns:
      A tuple holding, in order: every entry of ``base_options``, the
      compression option entries, and the primary user-agent pair.
    """
    compression_options = _compression.create_channel_option(compression)
    user_agent_option = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    return (*base_options, *compression_options, user_agent_option)
81
82
class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """

    # Declared once each; ``_loop`` was previously annotated twice.
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the per-method state shared by every call it starts.

        Args:
          channel: The Cython channel calls are issued on.
          method: The fully-qualified method name, already encoded to bytes.
          request_serializer: Serializer applied to outgoing requests.
          response_deserializer: Deserializer applied to incoming responses.
          interceptors: Interceptors for this RPC shape; a falsy value means
            calls are created without interception.
          references: Objects held for the lifetime of this callable. The
            channel passes ``[self]``, presumably to keep itself alive while
            the callable is in use (NOTE(review): confirm).
          loop: Event loop the created calls run on.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Build the metadata to send with a single call.

        Args:
          metadata: Caller-supplied metadata. A falsy value (``None`` or an
            empty container) becomes an empty ``Metadata``, and a plain
            ``tuple`` is converted with ``Metadata.from_tuple``. Any other
            value is kept as-is.
          compression: Optional per-call compression. When set, the result is
            rebuilt as ``Metadata`` from
            ``_compression.augment_metadata(metadata, compression)``.

        Returns:
          The metadata for the call. It is a ``Metadata`` instance except
          when a non-tuple, non-``Metadata`` sequence is passed without
          ``compression``; that sequence is returned unchanged.
        """
        metadata = metadata or Metadata()
        if not isinstance(metadata, Metadata) and isinstance(metadata, tuple):
            metadata = Metadata.from_tuple(metadata)
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata
133
134
class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    """Starts unary-request / unary-response RPCs for one method."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        """Create the call object for one unary-unary invocation.

        Without interceptors, a ``UnaryUnaryCall`` is built and ``timeout`` is
        converted to a deadline here. Otherwise an
        ``InterceptedUnaryUnaryCall`` receives the raw ``timeout``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing constructor arguments common to both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, *shared_args
            )
        return UnaryUnaryCall(
            request, _timeout_to_deadline(timeout), *shared_args
        )
178
179
class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    """Starts unary-request / streaming-response RPCs for one method."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        """Create the call object for one unary-stream invocation.

        Without interceptors, a ``UnaryStreamCall`` is built and ``timeout``
        is converted to a deadline here. Otherwise an
        ``InterceptedUnaryStreamCall`` receives the raw ``timeout``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing constructor arguments common to both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, *shared_args
            )
        return UnaryStreamCall(
            request, _timeout_to_deadline(timeout), *shared_args
        )
224
225
class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    """Starts streaming-request / unary-response RPCs for one method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        """Create the call object for one stream-unary invocation.

        ``request_iterator`` (default ``None``) is forwarded unchanged.
        Without interceptors, a ``StreamUnaryCall`` is built and ``timeout``
        is converted to a deadline here. Otherwise an
        ``InterceptedStreamUnaryCall`` receives the raw ``timeout``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing constructor arguments common to both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, *shared_args
            )
        return StreamUnaryCall(
            request_iterator, _timeout_to_deadline(timeout), *shared_args
        )
269
270
class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    """Starts streaming-request / streaming-response RPCs for one method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        """Create the call object for one stream-stream invocation.

        ``request_iterator`` (default ``None``) is forwarded unchanged.
        Without interceptors, a ``StreamStreamCall`` is built and ``timeout``
        is converted to a deadline here. Otherwise an
        ``InterceptedStreamStreamCall`` receives the raw ``timeout``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing constructor arguments common to both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedStreamStreamCall(
                self._interceptors, request_iterator, timeout, *shared_args
            )
        return StreamStreamCall(
            request_iterator, _timeout_to_deadline(timeout), *shared_args
        )
314
315
class Channel(_base_channel.Channel):
    """Asyncio client channel backed by a ``cygrpc.AioChannel``.

    Interceptors given at construction are sorted by the RPC shape they
    handle and attached to the matching multi-callables.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            supported client interceptor classes.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            # NOTE(review): the ``elif`` chain registers an interceptor only
            # for the FIRST interface it matches, even if it implements
            # several of them.
            for interceptor in interceptors:
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Leaving the ``async with`` closes immediately, without a grace period.
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Shut the channel down and cancel RPCs still running on it.

        In-flight calls are discovered by inspecting the top frame of every
        task of the running loop. A task counts as a call of this channel when
        its frame's ``self`` local is a ``_base_call.Call`` bound to this
        channel's Cython channel.

        Args:
          grace: Seconds to wait for those call tasks before cancelling them.
            ``None`` or ``0`` skips the wait.

        Raises:
          cygrpc.InternalError: If a ``Call`` object exposes neither
            ``_channel`` nor ``_cython_call``.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is spurious, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            if candidate:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, "_channel"):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, "_cython_call"):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f"Unrecognized call object: {candidate}"
                        )

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel, optionally waiting ``grace`` seconds for calls."""
        await self._close(grace)

    def __del__(self):
        # ``_channel`` may be missing if ``__init__`` raised before creating
        # it (e.g. on an unsupported interceptor).
        if hasattr(self, "_channel"):
            if not self._channel.closed():
                self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        """Return the current connectivity state.

        Args:
          try_to_connect: Forwarded to the Cython channel's connectivity
            check.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Wait until the state differs from ``last_observed_state``.

        The wait is delegated to the Cython channel's
        ``watch_connectivity_state`` with no deadline.
        """
        # Await outside the ``assert``: under ``python -O`` assert statements
        # are stripped, which previously skipped the wait entirely and made
        # ``channel_ready`` spin in a busy loop.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        assert state_changed

    async def channel_ready(self) -> None:
        """Wait until the channel is READY, requesting a connection each check."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # TODO(xuanwn): Implement this method after we have
    # observability for Asyncio.
    def _get_registered_call_handle(self, method: str) -> int:
        # Stub: currently returns None despite the ``int`` annotation.
        pass

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryUnaryMultiCallable:
        """Create a unary-unary multi-callable; ``_registered_method`` is ignored."""
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryStreamMultiCallable:
        """Create a unary-stream multi-callable; ``_registered_method`` is ignored."""
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamUnaryMultiCallable:
        """Create a stream-unary multi-callable; ``_registered_method`` is ignored."""
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamStreamMultiCallable:
        """Create a stream-stream multi-callable; ``_registered_method`` is ignored."""
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )
565
566
def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel.
    """
    channel_options = options if options is not None else ()
    # No channel credentials are supplied: this is the insecure variant.
    return Channel(
        target=target,
        options=channel_options,
        credentials=None,
        compression=compression,
        interceptors=interceptors,
    )
594
595
def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = options if options is not None else ()
    # ``Channel`` takes the low-level credentials wrapped by the public object.
    return Channel(
        target=target,
        options=channel_options,
        credentials=credentials._credentials,
        compression=compression,
        interceptors=interceptors,
    )
625