# Copyright 2020 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test of propagation of contextvars to AuthMetadataPlugin threads."""

import contextlib
import logging
import os
import queue
import sys
import threading
import unittest

import grpc

from tests.unit import test_common

_UNARY_UNARY = "/test/UnaryUnary"
_REQUEST = b"0000"


def _unary_unary_handler(request, context):
    """Echo the request back to the caller unchanged."""
    return request


def contextvars_supported():
    """Return True if the ``contextvars`` module can be imported."""
    try:
        import contextvars  # pylint: disable=unused-import

        return True
    except ImportError:
        return False


class _GenericHandler(grpc.GenericRpcHandler):
    """Route ``_UNARY_UNARY`` to the echo handler; reject any other method."""

    def service(self, handler_call_details):
        if handler_call_details.method == _UNARY_UNARY:
            return grpc.unary_unary_rpc_method_handler(_unary_unary_handler)
        else:
            raise NotImplementedError()


@contextlib.contextmanager
def _server():
    """Run an insecure echo server bound to an ephemeral localhost port.

    Yields:
      The port number the server was bound to.
    """
    # Construct the server before entering ``try``: if construction fails
    # there is nothing to stop, and touching ``server`` in ``finally`` would
    # raise UnboundLocalError and hide the original exception.
    server = test_common.test_server()
    try:
        target = "localhost:0"
        port = server.add_insecure_port(target)
        server.add_generic_rpc_handlers((_GenericHandler(),))
        server.start()
        yield port
    finally:
        server.stop(None)


if contextvars_supported():
    import contextvars

    _EXPECTED_VALUE = 24601
    test_var = contextvars.ContextVar("test_var", default=None)

    def set_up_expected_context():
        """Set ``test_var`` to ``_EXPECTED_VALUE`` in the current context."""
        test_var.set(_EXPECTED_VALUE)

    class TestCallCredentials(grpc.AuthMetadataPlugin):
        """Metadata plugin that checks ``test_var`` is visible when invoked."""

        def __init__(self):
            super().__init__()
            # Filled in by __call__ so that assert_called() has something to
            # inspect; without this, assert_called() raised AttributeError.
            self._invoked = False
            self._recorded_value = None

        def __call__(self, context, callback):
            # Read into a local so the check below is not affected by other
            # concurrent invocations overwriting the recorded attribute.
            value = test_var.get()
            self._recorded_value = value
            self._invoked = True
            if (
                value != _EXPECTED_VALUE
                and not test_common.running_under_gevent()
            ):
                # contextvars do not work under gevent, but the rest of this
                # test is still valuable as a test of concurrent runs of the
                # metadata credentials code path.
                raise AssertionError(
                    "{} != {}".format(value, _EXPECTED_VALUE)
                )
            callback((), None)

        def assert_called(self, test):
            """Assert via ``test`` that the plugin ran and saw the value."""
            test.assertTrue(self._invoked)
            test.assertEqual(_EXPECTED_VALUE, self._recorded_value)

else:

    def set_up_expected_context():
        """No-op: contextvars are unavailable on this interpreter."""

    class TestCallCredentials(grpc.AuthMetadataPlugin):
        """Metadata plugin that supplies empty metadata without checks."""

        def __call__(self, context, callback):
            callback((), None)


# TODO(https://github.com/grpc/grpc/issues/22257)
@unittest.skipIf(os.name == "nt", "LocalCredentials not supported on Windows.")
class ContextVarsPropagationTest(unittest.TestCase):
    def test_propagation_to_auth_plugin(self):
        set_up_expected_context()
        with _server() as port:
            target = "localhost:{}".format(port)
            local_credentials = grpc.local_channel_credentials()
            test_call_credentials = TestCallCredentials()
            call_credentials = grpc.metadata_call_credentials(
                test_call_credentials, "test call credentials"
            )
            composite_credentials = grpc.composite_channel_credentials(
                local_credentials, call_credentials
            )
            with grpc.secure_channel(target, composite_credentials) as channel:
                stub = channel.unary_unary(
                    _UNARY_UNARY,
                    _registered_method=True,
                )
                response = stub(_REQUEST, wait_for_ready=True)
                self.assertEqual(_REQUEST, response)

    def test_concurrent_propagation(self):
        _THREAD_COUNT = 32
        _RPC_COUNT = 32

        set_up_expected_context()
        with _server() as port:
            target = "localhost:{}".format(port)
            local_credentials = grpc.local_channel_credentials()
            test_call_credentials = TestCallCredentials()
            call_credentials = grpc.metadata_call_credentials(
                test_call_credentials, "test call credentials"
            )
            composite_credentials = grpc.composite_channel_credentials(
                local_credentials, call_credentials
            )
            wait_group = test_common.WaitGroup(_THREAD_COUNT)

            def _run_on_thread(exception_queue):
                """Issue _RPC_COUNT RPCs once every thread has its stub ready.

                Any exception is put on ``exception_queue`` for the main
                thread to re-raise.
                """
                reached_barrier = False
                try:
                    with grpc.secure_channel(
                        target, composite_credentials
                    ) as channel:
                        stub = channel.unary_unary(
                            _UNARY_UNARY,
                            _registered_method=True,
                        )
                        # Barrier: start all RPCs as concurrently as possible.
                        wait_group.done()
                        reached_barrier = True
                        wait_group.wait()
                        for _ in range(_RPC_COUNT):
                            response = stub(_REQUEST, wait_for_ready=True)
                            self.assertEqual(_REQUEST, response)
                except Exception as e:  # pylint: disable=broad-except
                    # Still check in at the barrier so sibling threads blocked
                    # in wait_group.wait() are not stranded forever.
                    if not reached_barrier:
                        wait_group.done()
                    exception_queue.put(e)

            threads = []

            for _ in range(_THREAD_COUNT):
                q = queue.Queue()
                thread = threading.Thread(
                    target=_run_on_thread, args=(q,), daemon=True
                )
                thread.start()
                threads.append((thread, q))

            for thread, q in threads:
                thread.join()
                if not q.empty():
                    raise q.get()


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)