xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/_plugin_wrapping.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import collections
16import logging
17import threading
18from typing import Callable, Optional, Type
19
20import grpc
21from grpc import _common
22from grpc._cython import cygrpc
23from grpc._typing import MetadataType
24
25_LOGGER = logging.getLogger(__name__)
26
27
class _AuthMetadataContext(
    collections.namedtuple(
        "AuthMetadataContext", ["service_url", "method_name"]
    ),
    grpc.AuthMetadataContext,
):
    """Immutable ``grpc.AuthMetadataContext`` handed to metadata plugins.

    A namedtuple with two fields, ``service_url`` and ``method_name``. In this
    module both are populated with values passed through ``_common.decode``.
    """
39
40
41class _CallbackState(object):
42    def __init__(self):
43        self.lock = threading.Lock()
44        self.called = False
45        self.exception = None
46
47
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """Single-use callback handed to a user's ``grpc.AuthMetadataPlugin``.

    Forwards the plugin's result to the wrapped ``callback`` as a
    ``(metadata, status_code, details)`` triple. Invoking it more than once,
    or after the plugin has already raised, is rejected with RuntimeError.
    """

    _state: _CallbackState
    _callback: Callable

    def __init__(self, state: _CallbackState, callback: Callable):
        # ``state`` is shared with the _Plugin call that created this wrapper;
        # both sides consult it under ``state.lock`` so the outcome is not
        # reported twice.
        self._state = state
        self._callback = callback

    def __call__(
        self, metadata: MetadataType, error: Optional[BaseException]
    ):
        """Report the plugin's result to the wrapped callback.

        Args:
          metadata: Metadata to attach to the call. Ignored when ``error`` is
            set (None is forwarded instead).
          error: An exception instance describing the failure, or None on
            success. Its ``str()`` becomes the status details.

        Raises:
          RuntimeError: If the plugin already raised an exception, or if this
            callback has already been invoked.
        """
        with self._state.lock:
            if self._state.exception is not None:
                # The plugin raised; _Plugin has taken over error reporting.
                raise RuntimeError(
                    'AuthMetadataPluginCallback raised exception "{}"!'.format(
                        self._state.exception
                    )
                )
            if self._state.called:
                raise RuntimeError(
                    "AuthMetadataPluginCallback invoked more than once!"
                )
            self._state.called = True
        # The wrapped callback runs outside the lock; ``called`` is already
        # set, so no other path will report for this invocation.
        if error is None:
            self._callback(metadata, cygrpc.StatusCode.ok, None)
        else:
            self._callback(
                None, cygrpc.StatusCode.internal, _common.encode(str(error))
            )
79
80
class _Plugin(object):
    """Adapter that runs a user ``grpc.AuthMetadataPlugin`` for one RPC.

    Instances are passed to ``cygrpc.MetadataPluginCallCredentials`` by
    ``metadata_plugin_call_credentials`` and invoked as
    ``plugin(service_url, method_name, callback)``.
    """

    _metadata_plugin: grpc.AuthMetadataPlugin

    def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
        self._metadata_plugin = metadata_plugin
        # The plugin may be invoked on a thread created by Core, which will not
        # have the context propagated. A snapshot of the creating thread's
        # contextvars is kept so it can be installed on the invoking thread.
        # NOTE(review): nothing in this file reads _stored_ctx; presumably the
        # cygrpc layer does, so verify before removing or renaming it.
        try:
            import contextvars  # pylint: disable=wrong-import-position
        except ImportError:
            # Support versions predating contextvars: nothing to propagate.
            self._stored_ctx = None
        else:
            self._stored_ctx = contextvars.copy_context()

    def __call__(self, service_url: str, method_name: str, callback: Callable):
        """Run the wrapped plugin once and make sure a failure is reported.

        Args:
          service_url: Service URL, passed through ``_common.decode``.
          method_name: Method name, passed through ``_common.decode``.
          callback: Called as ``callback(metadata, status_code, details)``,
            either by the plugin via ``_AuthMetadataPluginCallback``, or here
            with ``StatusCode.internal`` if the plugin raises before
            invoking it.
        """
        auth_context = _AuthMetadataContext(
            _common.decode(service_url), _common.decode(method_name)
        )
        state = _CallbackState()
        try:
            self._metadata_plugin(
                auth_context, _AuthMetadataPluginCallback(state, callback)
            )
        except Exception as exception:  # pylint: disable=broad-except
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!',
                self._metadata_plugin,
            )
            with state.lock:
                # Recording the exception makes any later use of the callback
                # wrapper raise. Sampling ``called`` under the same lock
                # avoids reporting the outcome a second time.
                state.exception = exception
                already_reported = state.called
            if not already_reported:
                callback(
                    None,
                    cygrpc.StatusCode.internal,
                    _common.encode(str(exception)),
                )
120
121
def metadata_plugin_call_credentials(
    metadata_plugin: grpc.AuthMetadataPlugin, name: Optional[str]
) -> grpc.CallCredentials:
    """Wrap ``metadata_plugin`` in ``grpc.CallCredentials``.

    Args:
      metadata_plugin: The user's auth metadata plugin.
      name: Name passed (encoded) to ``cygrpc.MetadataPluginCallCredentials``.
        When None, the plugin's ``__name__`` is used if it has one, otherwise
        the name of its class.

    Returns:
      A ``grpc.CallCredentials`` wrapping a
      ``cygrpc.MetadataPluginCallCredentials`` around a ``_Plugin`` adapter.
    """
    if name is not None:
        effective_name = name
    else:
        # Plain functions expose __name__; other callables (e.g. plugin
        # instances) fall back to the name of their class.
        effective_name = getattr(
            metadata_plugin, "__name__", metadata_plugin.__class__.__name__
        )
    wrapped_plugin = _Plugin(metadata_plugin)
    return grpc.CallCredentials(
        cygrpc.MetadataPluginCallCredentials(
            wrapped_plugin, _common.encode(effective_name)
        )
    )
136    )
137