# xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/unit/_metadata_test.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests the server-side and client-side metadata APIs."""

16import logging
17import unittest
18import weakref
19
20import grpc
21from grpc import _channel
22
23from tests.unit import test_common
24from tests.unit.framework.common import test_constants
25
# Channel options that set the primary and secondary user agents; the
# servicer (validate_client_metadata) checks that both end up in the
# "user-agent" header it receives.
_CHANNEL_ARGS = (
    ("grpc.primary_user_agent", "primary-agent"),
    ("grpc.secondary_user_agent", "secondary-agent"),
)

# Payloads are raw bytes: no (de)serializers are configured anywhere.
_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x00\x00\x00"

# Fully qualified method names served by _GenericHandler, one per arity.
_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"

# Client -> server metadata. The plain key is given as bytes here but as str
# in the expected value, so the test checks that it arrives as str. The
# "-bin" key carries a bytes value.
_INVOCATION_METADATA = (
    (b"invocation-md-key", "invocation-md-value"),
    ("invocation-md-key-bin", b"\x00\x01"),
)
_EXPECTED_INVOCATION_METADATA = (
    ("invocation-md-key", "invocation-md-value"),
    ("invocation-md-key-bin", b"\x00\x01"),
)

# Server -> client initial metadata; same bytes-key vs. str-key pattern.
_INITIAL_METADATA = (
    (b"initial-md-key", "initial-md-value"),
    ("initial-md-key-bin", b"\x00\x02"),
)
_EXPECTED_INITIAL_METADATA = (
    ("initial-md-key", "initial-md-value"),
    ("initial-md-key-bin", b"\x00\x02"),
)

# Server -> client trailing metadata. All keys are already str, so the
# expected value is the very same tuple.
_TRAILING_METADATA = (
    ("server-trailing-md-key", "server-trailing-md-value"),
    ("server-trailing-md-key-bin", b"\x00\x03"),
)
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA
86
87
def _user_agent(metadata):
    """Return the value of the first ``user-agent`` entry in *metadata*.

    Args:
      metadata: an iterable of ``(key, value)`` pairs, such as the result of
        ``servicer_context.invocation_metadata()``.

    Raises:
      KeyError: if no entry has the key ``"user-agent"``.
    """
    # A private sentinel (rather than None) keeps a None-valued header
    # distinguishable from a missing one.
    not_found = object()
    agent = next(
        (value for key, value in metadata if key == "user-agent"), not_found
    )
    if agent is not_found:
        raise KeyError("No user agent!")
    return agent
93
94
def validate_client_metadata(test, servicer_context):
    """Assert that the client's invocation metadata reached the server.

    Checks that ``_EXPECTED_INVOCATION_METADATA`` was transmitted, and that
    the ``user-agent`` header starts with ``"primary-agent "`` followed by
    ``_channel._USER_AGENT`` and ends with ``"secondary-agent"``. These are
    the values configured through ``_CHANNEL_ARGS``.

    Args:
      test: the test case used to make assertions (in this file, a
        ``weakref.proxy`` to the running ``MetadataTest``).
      servicer_context: the server-side context of the call under test.

    Raises:
      AssertionError: if a check fails. The message now shows the values
        actually observed instead of just "False is not true".
      KeyError: if the invocation metadata has no ``user-agent`` entry.
    """
    invocation_metadata = servicer_context.invocation_metadata()
    test.assertTrue(
        test_common.metadata_transmitted(
            _EXPECTED_INVOCATION_METADATA, invocation_metadata
        ),
        "expected invocation metadata %r, server saw %r"
        % (_EXPECTED_INVOCATION_METADATA, invocation_metadata),
    )
    user_agent = _user_agent(invocation_metadata)
    expected_prefix = "primary-agent " + _channel._USER_AGENT
    test.assertTrue(
        user_agent.startswith(expected_prefix),
        "user agent %r does not start with %r" % (user_agent, expected_prefix),
    )
    test.assertTrue(
        user_agent.endswith("secondary-agent"),
        "user agent %r does not end with %r" % (user_agent, "secondary-agent"),
    )
107
108
def handle_unary_unary(test, request, servicer_context):
    """Serve a unary-unary call.

    Validates the client's metadata, sends ``_INITIAL_METADATA``, sets
    ``_TRAILING_METADATA`` on the context and answers with ``_RESPONSE``.
    *request* itself is not inspected.
    """
    context = servicer_context
    validate_client_metadata(test, context)
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    return _RESPONSE
114
115
def handle_unary_stream(test, request, servicer_context):
    """Serve a unary-stream call.

    Validates the client's metadata, sends ``_INITIAL_METADATA``, sets
    ``_TRAILING_METADATA`` and then streams ``test_constants.STREAM_LENGTH``
    copies of ``_RESPONSE``. Because this is a generator, none of that runs
    until the first response is requested.
    """
    context = servicer_context
    validate_client_metadata(test, context)
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    yield from (_RESPONSE for _ in range(test_constants.STREAM_LENGTH))
122
123
def handle_stream_unary(test, request_iterator, servicer_context):
    """Serve a stream-unary call.

    Validates the client's metadata, sends ``_INITIAL_METADATA``, sets
    ``_TRAILING_METADATA``, consumes every request and answers with a single
    ``_RESPONSE``.
    """
    context = servicer_context
    validate_client_metadata(test, context)
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop
    for _unused_request in request_iterator:
        pass  # drain the request stream; the contents are not checked
    return _RESPONSE
132
133
def handle_stream_stream(test, request_iterator, servicer_context):
    """Serve a stream-stream call.

    Validates the client's metadata, sends ``_INITIAL_METADATA``, sets
    ``_TRAILING_METADATA``, then yields one ``_RESPONSE`` for each request
    received. Because this is a generator, none of that runs until the
    first response is requested.
    """
    context = servicer_context
    validate_client_metadata(test, context)
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _unused_request in request_iterator:
        yield _RESPONSE
142
143
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler for one request/response streaming combination.

    Exactly one of ``unary_unary``, ``unary_stream``, ``stream_unary`` and
    ``stream_stream`` is set. It forwards its two arguments to the matching
    module-level ``handle_*`` function, with *test* prepended. The other
    three behaviours stay ``None``.
    """

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        # No (de)serializers: _REQUEST and _RESPONSE are already bytes.
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        if request_streaming:
            if response_streaming:
                self.stream_stream = lambda requests, context: (
                    handle_stream_stream(test, requests, context)
                )
            else:
                self.stream_unary = lambda requests, context: (
                    handle_stream_unary(test, requests, context)
                )
        elif response_streaming:
            self.unary_stream = lambda request, context: (
                handle_unary_stream(test, request, context)
            )
        else:
            self.unary_unary = lambda request, context: (
                handle_unary_unary(test, request, context)
            )
162
163
class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four ``/test/...`` methods to a matching _MethodHandler."""

    # Fully qualified method name -> (request_streaming, response_streaming).
    _STREAMING_BY_METHOD = {
        _UNARY_UNARY: (False, False),
        _UNARY_STREAM: (False, True),
        _STREAM_UNARY: (True, False),
        _STREAM_STREAM: (True, True),
    }

    def __init__(self, test):
        self._test = test

    def service(self, handler_call_details):
        """Return a handler for a known method, or None for any other."""
        shape = self._STREAMING_BY_METHOD.get(handler_call_details.method)
        if shape is None:
            return None
        request_streaming, response_streaming = shape
        return _MethodHandler(self._test, request_streaming, response_streaming)
179
180
class MetadataTest(unittest.TestCase):
    """End-to-end metadata tests for all four RPC arities.

    Every test sends ``_INVOCATION_METADATA`` from the client, which the
    server-side handlers check with ``validate_client_metadata``. The test
    then checks that the server's initial and trailing metadata reached the
    client.
    """

    def setUp(self):
        self._server = test_common.test_server()
        # Handlers receive only a weak proxy to this test case, presumably so
        # the server does not keep the TestCase alive via a strong reference.
        self._server.add_generic_rpc_handlers(
            (_GenericHandler(weakref.proxy(self)),)
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        # Cleanups (unlike tearDown) also run when a later step of setUp
        # raises, so a started server is never leaked.
        self.addCleanup(self._server.stop, 0)
        self._channel = grpc.insecure_channel(
            "localhost:%d" % port, options=_CHANNEL_ARGS
        )
        # Cleanups run last-in first-out, so the channel closes before the
        # server stops.
        self.addCleanup(self._channel.close)

    def _assertMetadataTransmitted(self, expected, actual):
        """Assert ``test_common.metadata_transmitted(expected, actual)``.

        On failure, reports both collections instead of "False is not true".
        """
        self.assertTrue(
            test_common.metadata_transmitted(expected, actual),
            "expected metadata %r, got %r" % (expected, actual),
        )

    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(
            _UNARY_UNARY, _registered_method=True
        )
        unused_response, call = multi_callable.with_call(
            _REQUEST, metadata=_INVOCATION_METADATA
        )
        self._assertMetadataTransmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        self._assertMetadataTransmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(
            _UNARY_STREAM, _registered_method=True
        )
        call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
        self._assertMetadataTransmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        # Drain the response stream before reading trailing metadata.
        for _ in call:
            pass
        self._assertMetadataTransmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(
            _STREAM_UNARY, _registered_method=True
        )
        unused_response, call = multi_callable.with_call(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA,
        )
        self._assertMetadataTransmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        self._assertMetadataTransmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(
            _STREAM_STREAM, _registered_method=True
        )
        call = multi_callable(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA,
        )
        self._assertMetadataTransmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        # Drain the response stream before reading trailing metadata.
        for _ in call:
            pass
        self._assertMetadataTransmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )
272
273
if __name__ == "__main__":
    # Give the root logger a default handler so log records emitted during
    # the run are printed.
    logging.basicConfig()
    # Run every TestCase in this module, reporting each test by name.
    unittest.main(verbosity=2)
277