xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/beta/_client_adaptations.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Translates gRPC's client-side API into gRPC's client-side Beta API."""
15
16import grpc
17from grpc import _common
18from grpc.beta import _metadata
19from grpc.beta import interfaces
20from grpc.framework.common import cardinality
21from grpc.framework.foundation import future
22from grpc.framework.interfaces.face import face
23
24# pylint: disable=too-many-arguments,too-many-locals,unused-argument
25
# Translation table from a gRPC status code to the Beta API's
# (abortion kind, exception class) pair. Codes missing from this table are
# handled by _abortion / _abortion_error with their own fallbacks.
_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = {
    status_code: (abortion_kind, error_class)
    for status_code, abortion_kind, error_class in (
        (
            grpc.StatusCode.CANCELLED,
            face.Abortion.Kind.CANCELLED,
            face.CancellationError,
        ),
        (
            grpc.StatusCode.UNKNOWN,
            face.Abortion.Kind.REMOTE_FAILURE,
            face.RemoteError,
        ),
        (
            grpc.StatusCode.DEADLINE_EXCEEDED,
            face.Abortion.Kind.EXPIRED,
            face.ExpirationError,
        ),
        (
            grpc.StatusCode.UNIMPLEMENTED,
            face.Abortion.Kind.LOCAL_FAILURE,
            face.LocalError,
        ),
    )
}
44
45
def _effective_metadata(metadata, metadata_transformer):
    """Return the metadata to send, optionally passed through a transformer.

    Args:
      metadata: Caller-supplied metadata, or None (treated as an empty tuple).
      metadata_transformer: Optional callable applied to the metadata; when
        None the metadata is returned unchanged.

    Returns:
      The (possibly transformed) metadata.
    """
    if metadata is None:
        metadata = ()
    if metadata_transformer is not None:
        return metadata_transformer(metadata)
    return metadata
52
53
def _credentials(grpc_call_options):
    """Return the credentials carried by the call options, or None if absent."""
    if grpc_call_options is None:
        return None
    return grpc_call_options.credentials
56
57
def _abortion(rpc_error_call):
    """Describe a terminated call as a face.Abortion.

    Args:
      rpc_error_call: An object exposing code(), initial_metadata(),
        trailing_metadata() and details(), such as a grpc call.

    Returns:
      A face.Abortion whose kind is looked up from the status code; codes not
      in the translation table are reported as LOCAL_FAILURE.
    """
    status_code = rpc_error_call.code()
    kind_and_class = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(
        status_code
    )
    if kind_and_class is None:
        abortion_kind = face.Abortion.Kind.LOCAL_FAILURE
    else:
        abortion_kind = kind_and_class[0]
    return face.Abortion(
        abortion_kind,
        rpc_error_call.initial_metadata(),
        rpc_error_call.trailing_metadata(),
        status_code,
        rpc_error_call.details(),
    )
69
70
def _abortion_error(rpc_error_call):
    """Build the Beta API exception corresponding to a failed call.

    Args:
      rpc_error_call: An object exposing code(), initial_metadata(),
        trailing_metadata() and details(), such as a grpc.RpcError.

    Returns:
      An instance of the exception class mapped from the status code, or of
      face.AbortionError when the code is not in the translation table. The
      exception is returned, not raised.
    """
    status_code = rpc_error_call.code()
    kind_and_class = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(
        status_code
    )
    if kind_and_class is None:
        error_class = face.AbortionError
    else:
        error_class = kind_and_class[1]
    return error_class(
        rpc_error_call.initial_metadata(),
        rpc_error_call.trailing_metadata(),
        status_code,
        rpc_error_call.details(),
    )
81
82
class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
    """Protocol context returned by _Rendezvous.protocol_context.

    Per-request compression control is not implemented, so the single hook
    below does nothing.
    """

    def disable_next_request_compression(self):
        """Do nothing; compression control is not implemented yet."""
        # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
        return None
86
87
class _Rendezvous(future.Future, face.Call):
    """Presents a gRPC call as a Beta API future, call and response iterator.

    Which wrapped objects are present depends on how the RPC was started:

    * future-style invocations pass the grpc future as both
      ``response_future`` and ``call``;
    * streaming-response invocations pass the grpc response iterator as both
      ``response_iterator`` and ``call``;
    * blocking ``with_call`` invocations pass only ``call``.

    Methods that delegate to an object that was passed as None fail with
    AttributeError.
    """

    def __init__(self, response_future, response_iterator, call):
        self._future = response_future
        self._iterator = response_iterator
        self._call = call

    def cancel(self):
        return self._call.cancel()

    def cancelled(self):
        return self._future.cancelled()

    def running(self):
        return self._future.running()

    def done(self):
        return self._future.done()

    def result(self, timeout=None):
        """Return the RPC's response, translating gRPC errors to Beta ones.

        Raises:
          The exception built by _abortion_error if the RPC failed.
          future.TimeoutError: if ``timeout`` elapsed before completion.
          future.CancelledError: if the RPC was cancelled.
        """
        try:
            return self._future.result(timeout=timeout)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def exception(self, timeout=None):
        """Return the RPC's failure as a Beta exception, or None on success.

        Raises:
          future.TimeoutError: if ``timeout`` elapsed before completion.
          future.CancelledError: if the RPC was cancelled.
        """
        try:
            rpc_error_call = self._future.exception(timeout=timeout)
            if rpc_error_call is None:
                return None
            else:
                return _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def traceback(self, timeout=None):
        """Return the traceback of the RPC's failure, if any.

        Raises:
          future.TimeoutError: if ``timeout`` elapsed before completion.
          future.CancelledError: if the RPC was cancelled.
        """
        try:
            return self._future.traceback(timeout=timeout)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def add_done_callback(self, fn):
        # Callers expect to receive this adapter, not the underlying future.
        self._future.add_done_callback(lambda ignored_callback: fn(self))

    def __iter__(self):
        return self

    def _next(self):
        """Return the next response, translating gRPC errors to Beta ones."""
        try:
            return next(self._iterator)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)

    def __next__(self):
        return self._next()

    def next(self):
        return self._next()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def add_abortion_callback(self, abortion_callback):
        """Arrange for ``abortion_callback`` to receive a face.Abortion.

        The callback runs only if the RPC ends with a non-OK status code.
        """

        def done_callback():
            if self.code() is not grpc.StatusCode.OK:
                abortion_callback(_abortion(self._call))

        # add_callback reports False when it did not register the callback
        # (e.g. the RPC already terminated); perform the check right away then.
        registered = self._call.add_callback(done_callback)
        return None if registered else done_callback()

    def protocol_context(self):
        return _InvocationProtocolContext()

    def initial_metadata(self):
        return _metadata.beta(self._call.initial_metadata())

    def terminal_metadata(self):
        # The Beta API calls it "terminal" metadata, but grpc calls/futures
        # expose it as trailing_metadata(); there is no terminal_metadata()
        # on the wrapped object.
        return _metadata.beta(self._call.trailing_metadata())

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()
182
183
def _blocking_unary_unary(
    channel,
    group,
    method,
    timeout,
    with_call,
    protocol_options,
    metadata,
    metadata_transformer,
    request,
    request_serializer,
    response_deserializer,
):
    """Invoke a unary-unary RPC and block until it completes.

    Args:
      channel: The channel on which the RPC is invoked.
      group: The service (group) name; combined with ``method`` into the
        fully-qualified method name.
      method: The method name.
      timeout: Passed through to the underlying invocation.
      with_call: When truthy, also return the call object.
      protocol_options: Optional call options supplying credentials.
      metadata: Optional metadata, passed through ``metadata_transformer``.
      metadata_transformer: Optional callable applied to the metadata.
      request: The request message.
      request_serializer: Serializer for the request, or None.
      response_deserializer: Deserializer for the response, or None.

    Returns:
      The response, or ``(response, call)`` when ``with_call`` is truthy,
      where ``call`` is a _Rendezvous around the grpc call.

    Raises:
      The exception built by _abortion_error when the RPC fails with
      grpc.RpcError.
    """
    try:
        invoke = channel.unary_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer,
        )
        call_kwargs = {
            "timeout": timeout,
            "metadata": _metadata.unbeta(
                _effective_metadata(metadata, metadata_transformer)
            ),
            "credentials": _credentials(protocol_options),
        }
        if not with_call:
            return invoke(request, **call_kwargs)
        response, call = invoke.with_call(request, **call_kwargs)
        return response, _Rendezvous(None, None, call)
    except grpc.RpcError as rpc_error:
        raise _abortion_error(rpc_error)
221
222
def _future_unary_unary(
    channel,
    group,
    method,
    timeout,
    protocol_options,
    metadata,
    metadata_transformer,
    request,
    request_serializer,
    response_deserializer,
):
    """Start a unary-unary RPC asynchronously.

    Returns:
      A _Rendezvous around the grpc response future; the future also serves
      as the call object.
    """
    rpc = channel.unary_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer,
    )
    outgoing_metadata = _metadata.unbeta(
        _effective_metadata(metadata, metadata_transformer)
    )
    pending = rpc.future(
        request,
        timeout=timeout,
        metadata=outgoing_metadata,
        credentials=_credentials(protocol_options),
    )
    return _Rendezvous(pending, None, pending)
248
249
def _unary_stream(
    channel,
    group,
    method,
    timeout,
    protocol_options,
    metadata,
    metadata_transformer,
    request,
    request_serializer,
    response_deserializer,
):
    """Start a unary-stream RPC.

    Returns:
      A _Rendezvous that iterates over the responses; the grpc response
      iterator also serves as the call object.
    """
    rpc = channel.unary_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer,
    )
    outgoing_metadata = _metadata.unbeta(
        _effective_metadata(metadata, metadata_transformer)
    )
    responses = rpc(
        request,
        timeout=timeout,
        metadata=outgoing_metadata,
        credentials=_credentials(protocol_options),
    )
    return _Rendezvous(None, responses, responses)
275
276
def _blocking_stream_unary(
    channel,
    group,
    method,
    timeout,
    with_call,
    protocol_options,
    metadata,
    metadata_transformer,
    request_iterator,
    request_serializer,
    response_deserializer,
):
    """Invoke a stream-unary RPC and block until it completes.

    Returns:
      The response, or ``(response, call)`` when ``with_call`` is truthy,
      where ``call`` is a _Rendezvous around the grpc call.

    Raises:
      The exception built by _abortion_error when the RPC fails with
      grpc.RpcError.
    """
    try:
        invoke = channel.stream_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer,
        )
        call_kwargs = {
            "timeout": timeout,
            "metadata": _metadata.unbeta(
                _effective_metadata(metadata, metadata_transformer)
            ),
            "credentials": _credentials(protocol_options),
        }
        if not with_call:
            return invoke(request_iterator, **call_kwargs)
        response, call = invoke.with_call(request_iterator, **call_kwargs)
        return response, _Rendezvous(None, None, call)
    except grpc.RpcError as rpc_error:
        raise _abortion_error(rpc_error)
314
315
def _future_stream_unary(
    channel,
    group,
    method,
    timeout,
    protocol_options,
    metadata,
    metadata_transformer,
    request_iterator,
    request_serializer,
    response_deserializer,
):
    """Start a stream-unary RPC asynchronously.

    Returns:
      A _Rendezvous around the grpc response future; the future also serves
      as the call object.
    """
    rpc = channel.stream_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer,
    )
    outgoing_metadata = _metadata.unbeta(
        _effective_metadata(metadata, metadata_transformer)
    )
    pending = rpc.future(
        request_iterator,
        timeout=timeout,
        metadata=outgoing_metadata,
        credentials=_credentials(protocol_options),
    )
    return _Rendezvous(pending, None, pending)
341
342
def _stream_stream(
    channel,
    group,
    method,
    timeout,
    protocol_options,
    metadata,
    metadata_transformer,
    request_iterator,
    request_serializer,
    response_deserializer,
):
    """Start a stream-stream RPC.

    Returns:
      A _Rendezvous that iterates over the responses; the grpc response
      iterator also serves as the call object.
    """
    rpc = channel.stream_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer,
    )
    outgoing_metadata = _metadata.unbeta(
        _effective_metadata(metadata, metadata_transformer)
    )
    responses = rpc(
        request_iterator,
        timeout=timeout,
        metadata=outgoing_metadata,
        credentials=_credentials(protocol_options),
    )
    return _Rendezvous(None, responses, responses)
368
369
class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
    """Beta unary-unary multi-callable bound to one method of a channel."""

    def __init__(
        self,
        channel,
        group,
        method,
        metadata_transformer,
        request_serializer,
        response_deserializer,
    ):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(
        self,
        request,
        timeout,
        metadata=None,
        with_call=False,
        protocol_options=None,
    ):
        """Invoke the RPC and block until it completes."""
        return _blocking_unary_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            with_call=with_call,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer,
        )

    def future(self, request, timeout, metadata=None, protocol_options=None):
        """Start the RPC asynchronously and return a future-like _Rendezvous."""
        return _future_unary_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer,
        )

    def event(
        self,
        request,
        receiver,
        abortion_callback,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Event-style invocation is not supported by this adapter."""
        raise NotImplementedError()
433
434
class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
    """Beta unary-stream multi-callable bound to one method of a channel."""

    def __init__(
        self,
        channel,
        group,
        method,
        metadata_transformer,
        request_serializer,
        response_deserializer,
    ):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self, request, timeout, metadata=None, protocol_options=None):
        """Start the RPC and return a _Rendezvous over its responses."""
        return _unary_stream(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer,
        )

    def event(
        self,
        request,
        receiver,
        abortion_callback,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Event-style invocation is not supported by this adapter."""
        raise NotImplementedError()
476
477
class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
    """Beta stream-unary multi-callable bound to one method of a channel."""

    def __init__(
        self,
        channel,
        group,
        method,
        metadata_transformer,
        request_serializer,
        response_deserializer,
    ):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(
        self,
        request_iterator,
        timeout,
        metadata=None,
        with_call=False,
        protocol_options=None,
    ):
        """Invoke the RPC and block until it completes."""
        return _blocking_stream_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            with_call=with_call,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer,
        )

    def future(
        self, request_iterator, timeout, metadata=None, protocol_options=None
    ):
        """Start the RPC asynchronously and return a future-like _Rendezvous."""
        return _future_stream_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer,
        )

    def event(
        self,
        receiver,
        abortion_callback,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Event-style invocation is not supported by this adapter."""
        raise NotImplementedError()
542
543
class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
    """Beta stream-stream multi-callable bound to one method of a channel."""

    def __init__(
        self,
        channel,
        group,
        method,
        metadata_transformer,
        request_serializer,
        response_deserializer,
    ):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(
        self, request_iterator, timeout, metadata=None, protocol_options=None
    ):
        """Start the RPC and return a _Rendezvous over its responses."""
        return _stream_stream(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer,
        )

    def event(
        self,
        receiver,
        abortion_callback,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Event-style invocation is not supported by this adapter."""
        raise NotImplementedError()
586
587
class _GenericStub(face.GenericStub):
    """Beta GenericStub implemented on top of a gRPC channel.

    Serializers and deserializers are looked up by ``(group, method)``; a
    method without a registered entry gets None, which is passed through to
    the channel unchanged.
    """

    def __init__(
        self,
        channel,
        metadata_transformer,
        request_serializers,
        response_deserializers,
    ):
        self._channel = channel
        self._metadata_transformer = metadata_transformer
        # Either mapping may be None (or empty); normalize to a dict.
        self._request_serializers = request_serializers or {}
        self._response_deserializers = response_deserializers or {}

    def _codecs(self, group, method):
        """Return ``(request_serializer, response_deserializer)`` for a method."""
        key = (group, method)
        return (
            self._request_serializers.get(key),
            self._response_deserializers.get(key),
        )

    def blocking_unary_unary(
        self,
        group,
        method,
        request,
        timeout,
        metadata=None,
        with_call=None,
        protocol_options=None,
    ):
        """Invoke a unary-unary RPC and block until it completes."""
        request_serializer, response_deserializer = self._codecs(group, method)
        return _blocking_unary_unary(
            self._channel,
            group,
            method,
            timeout,
            with_call,
            protocol_options,
            metadata,
            self._metadata_transformer,
            request,
            request_serializer,
            response_deserializer,
        )

    def future_unary_unary(
        self,
        group,
        method,
        request,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Start a unary-unary RPC asynchronously."""
        request_serializer, response_deserializer = self._codecs(group, method)
        return _future_unary_unary(
            self._channel,
            group,
            method,
            timeout,
            protocol_options,
            metadata,
            self._metadata_transformer,
            request,
            request_serializer,
            response_deserializer,
        )

    def inline_unary_stream(
        self,
        group,
        method,
        request,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Start a unary-stream RPC and return an iterator over responses."""
        request_serializer, response_deserializer = self._codecs(group, method)
        return _unary_stream(
            self._channel,
            group,
            method,
            timeout,
            protocol_options,
            metadata,
            self._metadata_transformer,
            request,
            request_serializer,
            response_deserializer,
        )

    def blocking_stream_unary(
        self,
        group,
        method,
        request_iterator,
        timeout,
        metadata=None,
        with_call=None,
        protocol_options=None,
    ):
        """Invoke a stream-unary RPC and block until it completes."""
        request_serializer, response_deserializer = self._codecs(group, method)
        return _blocking_stream_unary(
            self._channel,
            group,
            method,
            timeout,
            with_call,
            protocol_options,
            metadata,
            self._metadata_transformer,
            request_iterator,
            request_serializer,
            response_deserializer,
        )

    def future_stream_unary(
        self,
        group,
        method,
        request_iterator,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Start a stream-unary RPC asynchronously."""
        request_serializer, response_deserializer = self._codecs(group, method)
        return _future_stream_unary(
            self._channel,
            group,
            method,
            timeout,
            protocol_options,
            metadata,
            self._metadata_transformer,
            request_iterator,
            request_serializer,
            response_deserializer,
        )

    def inline_stream_stream(
        self,
        group,
        method,
        request_iterator,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Start a stream-stream RPC and return an iterator over responses."""
        request_serializer, response_deserializer = self._codecs(group, method)
        return _stream_stream(
            self._channel,
            group,
            method,
            timeout,
            protocol_options,
            metadata,
            self._metadata_transformer,
            request_iterator,
            request_serializer,
            response_deserializer,
        )

    def event_unary_unary(
        self,
        group,
        method,
        request,
        receiver,
        abortion_callback,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Event-style invocation is not supported by this adapter."""
        raise NotImplementedError()

    def event_unary_stream(
        self,
        group,
        method,
        request,
        receiver,
        abortion_callback,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Event-style invocation is not supported by this adapter."""
        raise NotImplementedError()

    def event_stream_unary(
        self,
        group,
        method,
        receiver,
        abortion_callback,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Event-style invocation is not supported by this adapter."""
        raise NotImplementedError()

    def event_stream_stream(
        self,
        group,
        method,
        receiver,
        abortion_callback,
        timeout,
        metadata=None,
        protocol_options=None,
    ):
        """Event-style invocation is not supported by this adapter."""
        raise NotImplementedError()

    def unary_unary(self, group, method):
        """Return a unary-unary multi-callable for the given method."""
        request_serializer, response_deserializer = self._codecs(group, method)
        return _UnaryUnaryMultiCallable(
            self._channel,
            group,
            method,
            self._metadata_transformer,
            request_serializer,
            response_deserializer,
        )

    def unary_stream(self, group, method):
        """Return a unary-stream multi-callable for the given method."""
        request_serializer, response_deserializer = self._codecs(group, method)
        return _UnaryStreamMultiCallable(
            self._channel,
            group,
            method,
            self._metadata_transformer,
            request_serializer,
            response_deserializer,
        )

    def stream_unary(self, group, method):
        """Return a stream-unary multi-callable for the given method."""
        request_serializer, response_deserializer = self._codecs(group, method)
        return _StreamUnaryMultiCallable(
            self._channel,
            group,
            method,
            self._metadata_transformer,
            request_serializer,
            response_deserializer,
        )

    def stream_stream(self, group, method):
        """Return a stream-stream multi-callable for the given method."""
        request_serializer, response_deserializer = self._codecs(group, method)
        return _StreamStreamMultiCallable(
            self._channel,
            group,
            method,
            self._metadata_transformer,
            request_serializer,
            response_deserializer,
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # No cleanup is performed and exceptions are never suppressed.
        return False
952
953
class _DynamicStub(face.DynamicStub):
    """Beta DynamicStub exposing a service's methods as attributes.

    Looking up ``stub.<name>`` finds ``<name>`` in ``cardinalities`` and
    returns the matching multi-callable from the backing generic stub.
    """

    def __init__(self, backing_generic_stub, group, cardinalities):
        self._generic_stub = backing_generic_stub
        self._group = group
        self._cardinalities = cardinalities

    def __getattr__(self, attr):
        """Resolve ``attr`` as the name of an RPC method of this service.

        Raises:
          AttributeError: if ``attr`` has no known cardinality.
        """
        # Read the cardinalities from the instance __dict__ instead of via
        # self._cardinalities: __getattr__ also runs on instances whose
        # __init__ never ran (e.g. ones built through __new__ by copy or
        # pickle while probing for __setstate__). There, self._cardinalities
        # would re-enter __getattr__ and recurse until RecursionError instead
        # of raising AttributeError.
        cardinalities = vars(self).get("_cardinalities")
        method_cardinality = (
            None if cardinalities is None else cardinalities.get(attr)
        )
        if method_cardinality is cardinality.Cardinality.UNARY_UNARY:
            return self._generic_stub.unary_unary(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.UNARY_STREAM:
            return self._generic_stub.unary_stream(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.STREAM_UNARY:
            return self._generic_stub.stream_unary(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.STREAM_STREAM:
            return self._generic_stub.stream_stream(self._group, attr)
        else:
            raise AttributeError(
                '_DynamicStub object has no attribute "%s"!' % attr
            )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # No cleanup is performed and exceptions are never suppressed.
        return False
980
981
def generic_stub(
    channel,
    host,
    metadata_transformer,
    request_serializers,
    response_deserializers,
):
    """Create a Beta GenericStub that issues RPCs over ``channel``.

    Args:
      channel: The channel on which RPCs are invoked.
      host: Accepted but not used by this implementation.
      metadata_transformer: Optional callable applied to outgoing metadata.
      request_serializers: Optional mapping from ``(group, method)`` to a
        request serializer.
      response_deserializers: Optional mapping from ``(group, method)`` to a
        response deserializer.

    Returns:
      A _GenericStub.
    """
    stub = _GenericStub(
        channel,
        metadata_transformer=metadata_transformer,
        request_serializers=request_serializers,
        response_deserializers=response_deserializers,
    )
    return stub
995
996
def dynamic_stub(
    channel,
    service,
    cardinalities,
    host,
    metadata_transformer,
    request_serializers,
    response_deserializers,
):
    """Create a Beta DynamicStub for one service, backed by a generic stub.

    Args:
      channel: The channel on which RPCs are invoked.
      service: The service name, used as the group for every method.
      cardinalities: Mapping from method name to cardinality.Cardinality.
      host: Accepted but not used by this implementation.
      metadata_transformer: Optional callable applied to outgoing metadata.
      request_serializers: Optional mapping from ``(group, method)`` to a
        request serializer.
      response_deserializers: Optional mapping from ``(group, method)`` to a
        response deserializer.

    Returns:
      A _DynamicStub.
    """
    backing_stub = _GenericStub(
        channel,
        metadata_transformer=metadata_transformer,
        request_serializers=request_serializers,
        response_deserializers=response_deserializers,
    )
    return _DynamicStub(backing_stub, service, cardinalities)
1016