xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2020 The gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test of propagation of contextvars to AuthMetadataPlugin threads.."""
15
16import contextlib
17import logging
18import os
19import queue
20import sys
21import threading
22import unittest
23
24import grpc
25
26from tests.unit import test_common
27
28_UNARY_UNARY = "/test/UnaryUnary"
29_REQUEST = b"0000"
30
31
32def _unary_unary_handler(request, context):
33    return request
34
35
def contextvars_supported():
    """Report whether the ``contextvars`` module can be imported.

    Returns:
      True if ``import contextvars`` succeeds, False if it raises ImportError.
    """
    try:
        import contextvars  # pylint: disable=unused-import,import-outside-toplevel
    except ImportError:
        return False
    return True
43
44
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler that serves only the ``_UNARY_UNARY`` echo method."""

    def service(self, handler_call_details):
        """Look up the handler for an incoming RPC.

        Args:
          handler_call_details: Details of the incoming call; only ``method``
            is inspected.

        Returns:
          A unary-unary handler wrapping ``_unary_unary_handler`` when
          ``method`` is ``_UNARY_UNARY``, otherwise None.
        """
        # The documented grpc.GenericRpcHandler.service contract is to return
        # None for methods this handler does not serve. That lets the server
        # try other registered handlers and treat the method as unrecognized.
        # Raising here would instead be reported as a failure inside the
        # handler lookup.
        if handler_call_details.method != _UNARY_UNARY:
            return None
        return grpc.unary_unary_rpc_method_handler(_unary_unary_handler)
51
52
@contextlib.contextmanager
def _server():
    """Run a test server that serves ``_GenericHandler`` on localhost.

    Yields:
      The port returned by ``add_insecure_port("localhost:0")``, i.e. the port
      the server actually bound.

    The server is stopped when the ``with`` block exits, whether normally or
    via an exception.
    """
    # Build the server before entering ``try``. If construction itself fails,
    # ``finally`` must not run ``server.stop`` on an unbound name; that would
    # replace the real error with an UnboundLocalError.
    server = test_common.test_server()
    try:
        port = server.add_insecure_port("localhost:0")
        server.add_generic_rpc_handlers((_GenericHandler(),))
        server.start()
        yield port
    finally:
        server.stop(None)
64
65
if contextvars_supported():
    import contextvars

    # Value stored in ``test_var`` by the tests and expected by the plugin.
    _EXPECTED_VALUE = 24601
    test_var = contextvars.ContextVar("test_var", default=None)

    def set_up_expected_context():
        """Set ``test_var`` to ``_EXPECTED_VALUE`` in the current context."""
        test_var.set(_EXPECTED_VALUE)

    class TestCallCredentials(grpc.AuthMetadataPlugin):
        """Metadata plugin that checks ``test_var`` holds ``_EXPECTED_VALUE``.

        gRPC invokes this plugin to obtain call metadata. Seeing the expected
        value shows that the caller's contextvars reached the plugin
        invocation. The plugin always supplies empty metadata.
        """

        def __init__(self):
            super().__init__()
            # Recorded by __call__ so that assert_called can inspect them.
            self._invoked = False
            self._recorded_value = None

        def __call__(self, context, callback):
            value = test_var.get()
            self._invoked = True
            self._recorded_value = value
            if (
                value != _EXPECTED_VALUE
                and not test_common.running_under_gevent()
            ):
                # contextvars do not work under gevent, but the rest of this
                # test is still valuable as a test of concurrent runs of the
                # metadata credentials code path.
                raise AssertionError(
                    "{} != {}".format(value, _EXPECTED_VALUE)
                )
            callback((), None)

        def assert_called(self, test):
            """Assert via ``test`` that the plugin ran and saw the expected value.

            Args:
              test: A ``unittest.TestCase`` used to make the assertions.
            """
            test.assertTrue(self._invoked)
            test.assertEqual(_EXPECTED_VALUE, self._recorded_value)

else:

    def set_up_expected_context():
        """No-op: contextvars is unavailable, so there is nothing to set."""

    class TestCallCredentials(grpc.AuthMetadataPlugin):
        """Metadata plugin that always succeeds with empty metadata."""

        def __call__(self, context, callback):
            callback((), None)
101
102
# TODO(https://github.com/grpc/grpc/issues/22257)
@unittest.skipIf(os.name == "nt", "LocalCredentials not supported on Windows.")
class ContextVarsPropagationTest(unittest.TestCase):
    """Checks that contextvars set by the caller are visible to the plugin."""

    @staticmethod
    def _composite_credentials():
        """Compose local channel credentials with a fresh TestCallCredentials.

        Call only after set_up_expected_context(). The tests rely on the
        plugin observing the context set there, presumably captured when the
        call credentials are created (confirm in grpc's plugin wrapper).
        """
        local_credentials = grpc.local_channel_credentials()
        call_credentials = grpc.metadata_call_credentials(
            TestCallCredentials(), "test call credentials"
        )
        return grpc.composite_channel_credentials(
            local_credentials, call_credentials
        )

    def test_propagation_to_auth_plugin(self):
        """A single RPC through the metadata plugin succeeds.

        TestCallCredentials raises when it observes an unexpected value, which
        is expected to fail the RPC.
        """
        set_up_expected_context()
        with _server() as port:
            target = "localhost:{}".format(port)
            composite_credentials = self._composite_credentials()
            with grpc.secure_channel(target, composite_credentials) as channel:
                stub = channel.unary_unary(
                    _UNARY_UNARY,
                    _registered_method=True,
                )
                response = stub(_REQUEST, wait_for_ready=True)
                self.assertEqual(_REQUEST, response)

    def test_concurrent_propagation(self):
        """Many threads issue RPCs at once through the same credentials."""
        _THREAD_COUNT = 32
        _RPC_COUNT = 32

        set_up_expected_context()
        with _server() as port:
            target = "localhost:{}".format(port)
            composite_credentials = self._composite_credentials()
            wait_group = test_common.WaitGroup(_THREAD_COUNT)

            def _run_on_thread(exception_queue):
                # True once this thread has called wait_group.done().
                checked_in = False
                try:
                    with grpc.secure_channel(
                        target, composite_credentials
                    ) as channel:
                        stub = channel.unary_unary(
                            _UNARY_UNARY,
                            _registered_method=True,
                        )
                        # Line all threads up so their RPCs overlap.
                        wait_group.done()
                        checked_in = True
                        wait_group.wait()
                        for _ in range(_RPC_COUNT):
                            response = stub(_REQUEST, wait_for_ready=True)
                            self.assertEqual(_REQUEST, response)
                except Exception as e:  # pylint: disable=broad-except
                    exception_queue.put(e)
                finally:
                    # If setup failed before checking in, still check in.
                    # Otherwise threads blocked in wait_group.wait() (and the
                    # join below) would presumably hang forever.
                    if not checked_in:
                        wait_group.done()

            threads = []
            for _ in range(_THREAD_COUNT):
                q = queue.Queue()
                # Thread.setDaemon() is deprecated since Python 3.10.
                thread = threading.Thread(
                    target=_run_on_thread, args=(q,), daemon=True
                )
                thread.start()
                threads.append((thread, q))

            for thread, q in threads:
                thread.join()
                if not q.empty():
                    raise q.get()
173
174
if __name__ == "__main__":
    # Give the root logger a default stderr handler (unless logging is
    # already configured) so log output emitted during the run is shown.
    logging.basicConfig()
    unittest.main(verbosity=2)
178