# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import logging
import threading
from typing import Callable, Optional, Type

import grpc
from grpc import _common
from grpc._cython import cygrpc
from grpc._typing import MetadataType

_LOGGER = logging.getLogger(__name__)


class _AuthMetadataContext(
    collections.namedtuple(
        "AuthMetadataContext",
        (
            "service_url",
            "method_name",
        ),
    ),
    grpc.AuthMetadataContext,
):
    """Immutable (service_url, method_name) pair handed to a metadata plugin.

    Mixes a namedtuple into ``grpc.AuthMetadataContext`` so that plugins
    receive an instance of the public interface type.
    """


class _CallbackState(object):
    """Bookkeeping shared by one plugin invocation and its callback.

    All attributes other than ``lock`` are read and written only while
    holding ``lock``.
    """

    def __init__(self):
        self.lock = threading.Lock()
        # True once the plugin's callback has been invoked.
        self.called = False
        # The exception raised by the plugin itself, or None if it has not
        # raised.
        self.exception = None


class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """Callback given to an AuthMetadataPlugin; forwards its result onward.

    Guarantees the wrapped callback runs at most once, and never after the
    plugin itself raised (in that case _Plugin.__call__ reports the failure).
    """

    _state: _CallbackState
    _callback: Callable

    def __init__(self, state: _CallbackState, callback: Callable):
        self._state = state
        self._callback = callback

    def __call__(
        self, metadata: Optional[MetadataType], error: Optional[BaseException]
    ):
        """Report the plugin's result.

        Args:
          metadata: Metadata to attach to the call. It is ignored when
            ``error`` is not None.
          error: An exception instance describing why metadata could not be
            produced, or None on success.

        Raises:
          RuntimeError: If this callback was already invoked, or if the
            plugin had already raised an exception when it was invoked.
        """
        with self._state.lock:
            if self._state.exception is not None:
                # The plugin raised before calling us; its failure has already
                # been (or is about to be) reported by _Plugin.__call__.
                raise RuntimeError(
                    'AuthMetadataPluginCallback raised exception "{}"!'.format(
                        self._state.exception
                    )
                )
            if self._state.called:
                raise RuntimeError(
                    "AuthMetadataPluginCallback invoked more than once!"
                )
            self._state.called = True
        # The downstream callback is invoked outside the lock so that it can
        # do arbitrary work without holding the state lock.
        if error is None:
            self._callback(metadata, cygrpc.StatusCode.ok, None)
        else:
            self._callback(
                None, cygrpc.StatusCode.internal, _common.encode(str(error))
            )


class _Plugin(object):
    """Adapts a grpc.AuthMetadataPlugin to the form used by cygrpc.

    Instances are wrapped in ``cygrpc.MetadataPluginCallCredentials`` by
    metadata_plugin_call_credentials.
    """

    _metadata_plugin: grpc.AuthMetadataPlugin

    def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
        self._metadata_plugin = metadata_plugin
        self._stored_ctx = None

        try:
            import contextvars  # pylint: disable=wrong-import-position

            # The plugin may be invoked on a thread created by Core, which will not
            # have the context propagated. This context is stored and installed in
            # the thread invoking the plugin.
            self._stored_ctx = contextvars.copy_context()
        except ImportError:
            # Support versions predating contextvars.
            pass

    def __call__(self, service_url: str, method_name: str, callback: Callable):
        """Invoke the wrapped plugin for a single call.

        Args:
          service_url: The service URL. It is passed through _common.decode;
            NOTE(review): likely bytes when coming from cygrpc, so confirm.
          method_name: The method name, decoded the same way.
          callback: Invoked as ``callback(metadata, status_code, details)``.

        Exceptions raised by the plugin are logged and not propagated. If the
        plugin has not yet invoked its callback, ``callback`` is called with
        ``StatusCode.internal`` and the encoded exception text.
        """
        context = _AuthMetadataContext(
            _common.decode(service_url), _common.decode(method_name)
        )
        callback_state = _CallbackState()
        try:
            self._metadata_plugin(
                context, _AuthMetadataPluginCallback(callback_state, callback)
            )
        except Exception as exception:  # pylint: disable=broad-except
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!',
                self._metadata_plugin,
            )
            with callback_state.lock:
                # Recording the exception makes any later callback invocation
                # fail instead of reporting a second result.
                callback_state.exception = exception
                if callback_state.called:
                    # A result was already delivered; do not report twice.
                    return
            callback(
                None, cygrpc.StatusCode.internal, _common.encode(str(exception))
            )


def metadata_plugin_call_credentials(
    metadata_plugin: grpc.AuthMetadataPlugin, name: Optional[str] = None
) -> grpc.CallCredentials:
    """Create CallCredentials backed by an AuthMetadataPlugin.

    Args:
      metadata_plugin: The plugin that produces per-call metadata.
      name: Name for the credentials. When None, the plugin's ``__name__``
        is used, falling back to the name of its class.

    Returns:
      A grpc.CallCredentials wrapping a cygrpc.MetadataPluginCallCredentials.
    """
    if name is None:
        try:
            effective_name = metadata_plugin.__name__
        except AttributeError:
            # Callable objects such as class instances have no __name__.
            effective_name = metadata_plugin.__class__.__name__
    else:
        effective_name = name
    return grpc.CallCredentials(
        cygrpc.MetadataPluginCallCredentials(
            _Plugin(metadata_plugin), _common.encode(effective_name)
        )
    )