xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/_interceptor.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2017 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementation of gRPC Python interceptors."""
15
16import collections
17import sys
18import types
19from typing import Any, Callable, Optional, Sequence, Tuple, Union
20
21import grpc
22
23from ._typing import DeserializingFunction
24from ._typing import DoneCallbackType
25from ._typing import MetadataType
26from ._typing import RequestIterableType
27from ._typing import SerializingFunction
28
29
class _ServicePipeline(object):
    """Runs a fixed sequence of server interceptors ahead of a final step.

    Interceptors are consulted in the order they were supplied. Each one is
    handed a continuation that resumes the chain at the next interceptor;
    once the chain is exhausted, the ``thunk`` given to :meth:`execute` is
    called with the handler call details.
    """

    interceptors: Tuple[grpc.ServerInterceptor]

    def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]):
        # Snapshot into a tuple so later mutation of the caller's sequence
        # cannot change the pipeline.
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk: Callable, index: int) -> Callable:
        """Return a callable that resumes the chain at position ``index``."""

        def resume(context):
            return self._intercept_at(thunk, index, context)

        return resume

    def _intercept_at(
        self, thunk: Callable, index: int, context: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Run the interceptor at ``index``, or ``thunk`` past the end."""
        if index >= len(self.interceptors):
            return thunk(context)
        current = self.interceptors[index]
        proceed = self._continuation(thunk, index + 1)
        return current.intercept_service(proceed, context)

    def execute(
        self, thunk: Callable, context: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Run the whole chain for ``context``, ending in ``thunk``."""
        return self._intercept_at(thunk, 0, context)
53
54
def service_pipeline(
    interceptors: Optional[Sequence[grpc.ServerInterceptor]],
) -> Optional[_ServicePipeline]:
    """Build a server interceptor pipeline, or return None if there is nothing to run.

    Args:
      interceptors: Server interceptors in invocation order; ``None`` or an
        empty sequence means no interception.

    Returns:
      A ``_ServicePipeline`` over ``interceptors``, or ``None`` when
      ``interceptors`` is falsy.
    """
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)
59
60
class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails",
        "method timeout metadata credentials wait_for_ready compression",
    ),
    grpc.ClientCallDetails,
):
    """Concrete, immutable ``grpc.ClientCallDetails`` backed by a namedtuple.

    Field order is (method, timeout, metadata, credentials, wait_for_ready,
    compression). The multi-callables in this module construct instances
    positionally in that order.
    """
76
77
def _unwrap_client_call_details(
    call_details: grpc.ClientCallDetails,
    default_details: grpc.ClientCallDetails,
) -> Tuple[
    str,
    Optional[float],
    Optional[MetadataType],
    Optional[grpc.CallCredentials],
    Optional[bool],
    Optional[grpc.Compression],
]:
    """Extract call options from ``call_details``, falling back field by field.

    The object an interceptor passes to its continuation is not required to
    define every attribute. For each field, reading it from ``call_details``
    is attempted first; if that raises ``AttributeError``, the value from
    ``default_details`` is used instead.

    Args:
      call_details: Details handed to the continuation by an interceptor.
      default_details: Details the RPC was originally invoked with.

    Returns:
      A ``(method, timeout, metadata, credentials, wait_for_ready,
      compression)`` tuple.
    """

    def _field(name: str) -> Any:
        # getattr() with a dynamic name needs no pytype suppression, unlike
        # the direct attribute accesses this replaces.
        try:
            return getattr(call_details, name)
        except AttributeError:
            return getattr(default_details, name)

    # Fields are read in the same order as before, so any property side
    # effects on ``call_details`` happen in the same sequence.
    return tuple(
        _field(name)
        for name in (
            "method",
            "timeout",
            "metadata",
            "credentials",
            "wait_for_ready",
            "compression",
        )
    )
127
128
class _FailureOutcome(
    grpc.RpcError, grpc.Future, grpc.Call
):  # pylint: disable=too-many-ancestors
    """Already-finished stand-in for an RPC that failed with an exception.

    This module returns one in place of a call/future when an exception is
    caught while invoking an RPC through an interceptor. It reports a
    terminal ``INTERNAL`` status. Reading its result, or iterating it as a
    response stream, re-raises the captured exception.
    """

    _exception: Exception
    _traceback: types.TracebackType

    def __init__(self, exception: Exception, traceback: types.TracebackType):
        super().__init__()
        self._exception = exception
        self._traceback = traceback

    # ---- Call status: fixed values describing a failed, finished RPC ----

    def code(self) -> Optional[grpc.StatusCode]:
        return grpc.StatusCode.INTERNAL

    def details(self) -> Optional[str]:
        return "Exception raised while intercepting the RPC"

    def initial_metadata(self) -> Optional[MetadataType]:
        return None

    def trailing_metadata(self) -> Optional[MetadataType]:
        return None

    def time_remaining(self) -> Optional[float]:
        return None

    def is_active(self) -> bool:
        return False

    # ---- Future state: done, never running, not cancellable ----

    def done(self) -> bool:
        return True

    def running(self) -> bool:
        return False

    def cancelled(self) -> bool:
        return False

    def cancel(self) -> bool:
        return False

    def result(self, ignored_timeout: Optional[float] = None):
        raise self._exception

    def exception(
        self, ignored_timeout: Optional[float] = None
    ) -> Optional[Exception]:
        return self._exception

    def traceback(
        self, ignored_timeout: Optional[float] = None
    ) -> Optional[types.TracebackType]:
        return self._traceback

    # ---- Callbacks ----

    def add_callback(self, unused_callback) -> bool:
        # The callback is never registered; False is reported to the caller.
        return False

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        # Already done, so the callback runs immediately on this thread.
        fn(self)

    # ---- Response-stream protocol: fails on the first item ----

    def __iter__(self):
        return self

    def __next__(self):
        raise self._exception

    def next(self):
        return self.__next__()
197
198
class _UnaryOutcome(grpc.Call, grpc.Future):
    """Successful unary RPC outcome: a response plus the call that produced it.

    Call-level queries (metadata, status, deadline, cancellation and
    callbacks) are forwarded to the wrapped call. Future-level queries
    describe a finished future whose result is the stored response.
    """

    _response: Any
    _call: grpc.Call

    def __init__(self, response: Any, call: grpc.Call):
        self._response = response
        self._call = call

    # ---- Forwarded to the underlying call ----

    def initial_metadata(self) -> Optional[MetadataType]:
        return self._call.initial_metadata()

    def trailing_metadata(self) -> Optional[MetadataType]:
        return self._call.trailing_metadata()

    def code(self) -> Optional[grpc.StatusCode]:
        return self._call.code()

    def details(self) -> Optional[str]:
        return self._call.details()

    def is_active(self) -> bool:
        return self._call.is_active()

    def time_remaining(self) -> Optional[float]:
        return self._call.time_remaining()

    def cancel(self) -> bool:
        return self._call.cancel()

    def add_callback(self, callback) -> bool:
        return self._call.add_callback(callback)

    # ---- Future view: already completed with a response ----

    def done(self) -> bool:
        return True

    def running(self) -> bool:
        return False

    def cancelled(self) -> bool:
        return False

    def result(self, ignored_timeout: Optional[float] = None):
        return self._response

    def exception(self, ignored_timeout: Optional[float] = None):
        return None

    def traceback(self, ignored_timeout: Optional[float] = None):
        return None

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        # Already done, so the callback runs immediately on this thread.
        fn(self)
251
252
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable whose invocations pass through an interceptor.

    ``thunk`` maps a method name to the underlying channel's unary-unary
    multi-callable. It is resolved inside the continuation, so the
    interceptor can substitute a different method name via the details it
    passes along.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryUnaryClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.UnaryUnaryClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _details(
        self, timeout, metadata, credentials, wait_for_ready, compression
    ) -> grpc.ClientCallDetails:
        """Bundle caller-supplied call options with this callable's method."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        response_and_call = self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
        return response_and_call[0]

    def _with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Run the interceptor synchronously; return (response, outcome)."""
        original_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        # Parameter names are kept as-is: interceptors may pass them by keyword.
        def continuation(new_details, request):
            (
                target_method,
                call_timeout,
                call_metadata,
                call_credentials,
                call_wait_for_ready,
                call_compression,
            ) = _unwrap_client_call_details(new_details, original_details)
            try:
                response, call = self._thunk(target_method).with_call(
                    request,
                    timeout=call_timeout,
                    metadata=call_metadata,
                    credentials=call_credentials,
                    wait_for_ready=call_wait_for_ready,
                    compression=call_compression,
                )
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # The raised error object itself becomes the outcome.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_unary_unary(
            continuation, original_details, request
        )
        # For a _FailureOutcome, result() re-raises the captured exception.
        return outcome.result(), outcome

    def with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        return self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

    def future(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Run the interceptor asynchronously; errors surface via the future."""
        original_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request):
            (
                target_method,
                call_timeout,
                call_metadata,
                call_credentials,
                call_wait_for_ready,
                call_compression,
            ) = _unwrap_client_call_details(new_details, original_details)
            return self._thunk(target_method).future(
                request,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials,
                wait_for_ready=call_wait_for_ready,
                compression=call_compression,
            )

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, original_details, request
            )
        except Exception as exception:  # pylint:disable=broad-except
            # Report the failure through the returned object instead of raising.
            return _FailureOutcome(exception, sys.exc_info()[2])
394
395
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable whose invocations pass through an interceptor.

    ``thunk`` maps a method name to the underlying channel's unary-stream
    multi-callable. It is resolved inside the continuation, so the
    interceptor can substitute a different method name.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryStreamClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.UnaryStreamClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _details(
        self, timeout, metadata, credentials, wait_for_ready, compression
    ) -> grpc.ClientCallDetails:
        """Bundle caller-supplied call options with this callable's method."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ):
        """Start the RPC through the interceptor.

        Returns whatever the interceptor returns, or a ``_FailureOutcome``
        if the interceptor raised.
        """
        original_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        # Parameter names are kept as-is: interceptors may pass them by keyword.
        def continuation(new_details, request):
            (
                target_method,
                call_timeout,
                call_metadata,
                call_credentials,
                call_wait_for_ready,
                call_compression,
            ) = _unwrap_client_call_details(new_details, original_details)
            return self._thunk(target_method)(
                request,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials,
                wait_for_ready=call_wait_for_ready,
                compression=call_compression,
            )

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, original_details, request
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
453
454
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable whose invocations pass through an interceptor.

    ``thunk`` maps a method name to the underlying channel's stream-unary
    multi-callable. It is resolved inside the continuation, so the
    interceptor can substitute a different method name.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamUnaryClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.StreamUnaryClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _details(
        self, timeout, metadata, credentials, wait_for_ready, compression
    ) -> grpc.ClientCallDetails:
        """Bundle caller-supplied call options with this callable's method."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    def __call__(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        response_and_call = self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
        return response_and_call[0]

    def _with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Run the interceptor synchronously; return (response, outcome)."""
        original_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        # Parameter names are kept as-is: interceptors may pass them by keyword.
        def continuation(new_details, request_iterator):
            (
                target_method,
                call_timeout,
                call_metadata,
                call_credentials,
                call_wait_for_ready,
                call_compression,
            ) = _unwrap_client_call_details(new_details, original_details)
            try:
                response, call = self._thunk(target_method).with_call(
                    request_iterator,
                    timeout=call_timeout,
                    metadata=call_metadata,
                    credentials=call_credentials,
                    wait_for_ready=call_wait_for_ready,
                    compression=call_compression,
                )
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # The raised error object itself becomes the outcome.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_stream_unary(
            continuation, original_details, request_iterator
        )
        # For a _FailureOutcome, result() re-raises the captured exception.
        return outcome.result(), outcome

    def with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        return self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

    def future(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Run the interceptor asynchronously; errors surface via the future."""
        original_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request_iterator):
            (
                target_method,
                call_timeout,
                call_metadata,
                call_credentials,
                call_wait_for_ready,
                call_compression,
            ) = _unwrap_client_call_details(new_details, original_details)
            return self._thunk(target_method).future(
                request_iterator,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials,
                wait_for_ready=call_wait_for_ready,
                compression=call_compression,
            )

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, original_details, request_iterator
            )
        except Exception as exception:  # pylint:disable=broad-except
            # Report the failure through the returned object instead of raising.
            return _FailureOutcome(exception, sys.exc_info()[2])
596
597
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable whose invocations pass through an interceptor.

    ``thunk`` maps a method name to the underlying channel's stream-stream
    multi-callable. It is resolved inside the continuation, so the
    interceptor can substitute a different method name.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamStreamClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.StreamStreamClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _details(
        self, timeout, metadata, credentials, wait_for_ready, compression
    ) -> grpc.ClientCallDetails:
        """Bundle caller-supplied call options with this callable's method."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    def __call__(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ):
        """Start the RPC through the interceptor.

        Returns whatever the interceptor returns, or a ``_FailureOutcome``
        if the interceptor raised.
        """
        original_details = self._details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        # Parameter names are kept as-is: interceptors may pass them by keyword.
        def continuation(new_details, request_iterator):
            (
                target_method,
                call_timeout,
                call_metadata,
                call_credentials,
                call_wait_for_ready,
                call_compression,
            ) = _unwrap_client_call_details(new_details, original_details)
            return self._thunk(target_method)(
                request_iterator,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials,
                wait_for_ready=call_wait_for_ready,
                compression=call_compression,
            )

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, original_details, request_iterator
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])
655
656
class _Channel(grpc.Channel):
    """``grpc.Channel`` wrapper that applies a single client interceptor.

    Each multi-callable factory returns an intercepting wrapper when the
    interceptor implements the matching ``*ClientInterceptor`` interface.
    Otherwise it returns the wrapped channel's own multi-callable.
    Connectivity subscription and closing are forwarded unchanged.
    """

    _channel: grpc.Channel
    _interceptor: Union[
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
    ]

    def __init__(
        self,
        channel: grpc.Channel,
        interceptor: Union[
            grpc.UnaryUnaryClientInterceptor,
            grpc.UnaryStreamClientInterceptor,
            grpc.StreamStreamClientInterceptor,
            grpc.StreamUnaryClientInterceptor,
        ],
    ):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(
        self, callback: Callable, try_to_connect: Optional[bool] = False
    ):
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback: Callable):
        self._channel.unsubscribe(callback)

    def _multi_callable(
        self,
        factory_name,
        interceptor_type,
        wrapper_type,
        method,
        request_serializer,
        response_deserializer,
        _registered_method,
    ):
        """Build the multi-callable for ``method``, intercepted if applicable.

        Args:
          factory_name: Name of the wrapped channel's factory method, such
            as "unary_unary". It is looked up each time the thunk runs.
          interceptor_type: Interface the interceptor must implement for
            this RPC shape to be intercepted.
          wrapper_type: Intercepting multi-callable class to build.
          method, request_serializer, response_deserializer,
          _registered_method: Forwarded positionally to the factory.
        """

        # pytype: disable=wrong-arg-count
        def thunk(m):
            factory = getattr(self._channel, factory_name)
            return factory(
                m,
                request_serializer,
                response_deserializer,
                _registered_method,
            )

        # pytype: enable=wrong-arg-count
        if not isinstance(self._interceptor, interceptor_type):
            return thunk(method)
        return wrapper_type(thunk, method, self._interceptor)

    # pylint: disable=arguments-differ
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryUnaryMultiCallable:
        return self._multi_callable(
            "unary_unary",
            grpc.UnaryUnaryClientInterceptor,
            _UnaryUnaryMultiCallable,
            method,
            request_serializer,
            response_deserializer,
            _registered_method,
        )

    # pylint: disable=arguments-differ
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryStreamMultiCallable:
        return self._multi_callable(
            "unary_stream",
            grpc.UnaryStreamClientInterceptor,
            _UnaryStreamMultiCallable,
            method,
            request_serializer,
            response_deserializer,
            _registered_method,
        )

    # pylint: disable=arguments-differ
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamUnaryMultiCallable:
        return self._multi_callable(
            "stream_unary",
            grpc.StreamUnaryClientInterceptor,
            _StreamUnaryMultiCallable,
            method,
            request_serializer,
            response_deserializer,
            _registered_method,
        )

    # pylint: disable=arguments-differ
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamStreamMultiCallable:
        return self._multi_callable(
            "stream_stream",
            grpc.StreamStreamClientInterceptor,
            _StreamStreamMultiCallable,
            method,
            request_serializer,
            response_deserializer,
            _registered_method,
        )

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the wrapped channel and let any in-flight exception propagate.
        self._close()
        return False

    def close(self):
        self._channel.close()
783
784
def intercept_channel(
    channel: grpc.Channel,
    *interceptors: Union[
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
    ],
) -> grpc.Channel:
    """Return a channel whose RPCs are routed through ``interceptors``.

    Each interceptor wraps the channel built so far. Because the loop walks
    the interceptors in reverse, the first interceptor given ends up as the
    outermost wrapper and is the first to see each RPC it intercepts.

    Args:
      channel: The channel to wrap.
      *interceptors: Zero or more objects, each implementing at least one of
        the four ``grpc.*ClientInterceptor`` interfaces.

    Returns:
      The wrapped channel, or ``channel`` itself if no interceptors are given.

    Raises:
      TypeError: If an interceptor implements none of the client
        interceptor interfaces.
    """
    accepted_types = (
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
        grpc.StreamStreamClientInterceptor,
    )
    # ``interceptors`` is a tuple, so it can be reversed without copying.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, accepted_types):
            raise TypeError(
                "interceptor must be "
                "grpc.UnaryUnaryClientInterceptor or "
                "grpc.UnaryStreamClientInterceptor or "
                "grpc.StreamUnaryClientInterceptor or "
                "grpc.StreamStreamClientInterceptor"
            )
        channel = _Channel(channel, interceptor)
    return channel
814