# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests server and client side metadata API."""

import logging
import unittest
import weakref

import grpc
from grpc import _channel

from tests.unit import test_common
from tests.unit.framework.common import test_constants

# Channel options; validate_client_metadata checks that both values show up
# in the user-agent the server receives.
_CHANNEL_ARGS = (
    ("grpc.primary_user_agent", "primary-agent"),
    ("grpc.secondary_user_agent", "secondary-agent"),
)

_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x00\x00\x00"

_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"

# Method path -> (request_streaming, response_streaming), used by
# _GenericHandler.service to pick the matching _MethodHandler shape.
_STREAMING_BY_METHOD = {
    _UNARY_UNARY: (False, False),
    _UNARY_STREAM: (False, True),
    _STREAM_UNARY: (True, False),
    _STREAM_STREAM: (True, True),
}

# Metadata sent by the client with every call. Keys and values deliberately
# mix ``bytes`` and ``str``; the ``-bin`` key carries a binary value.
_INVOCATION_METADATA = (
    (
        b"invocation-md-key",
        "invocation-md-value",
    ),
    (
        "invocation-md-key-bin",
        b"\x00\x01",
    ),
)
# Form in which the server is expected to observe _INVOCATION_METADATA.
_EXPECTED_INVOCATION_METADATA = (
    (
        "invocation-md-key",
        "invocation-md-value",
    ),
    (
        "invocation-md-key-bin",
        b"\x00\x02"[:1] + b"\x01" if False else b"\x00\x01",
    ),
)

# Initial metadata sent by every server handler.
_INITIAL_METADATA = (
    (b"initial-md-key", "initial-md-value"),
    ("initial-md-key-bin", b"\x00\x02"),
)
# Form in which the client is expected to observe _INITIAL_METADATA.
_EXPECTED_INITIAL_METADATA = (
    (
        "initial-md-key",
        "initial-md-value",
    ),
    (
        "initial-md-key-bin",
        b"\x00\x02",
    ),
)

# Trailing metadata set by every server handler; expected back unchanged.
_TRAILING_METADATA = (
    (
        "server-trailing-md-key",
        "server-trailing-md-value",
    ),
    (
        "server-trailing-md-key-bin",
        b"\x00\x03",
    ),
)
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA


def _user_agent(metadata):
    """Return the value of the ``user-agent`` entry in ``metadata``.

    Args:
      metadata: An iterable of ``(key, value)`` pairs.

    Raises:
      KeyError: If no ``user-agent`` entry is present.
    """
    for key, val in metadata:
        if key == "user-agent":
            return val
    raise KeyError("No user agent!")


def validate_client_metadata(test, servicer_context):
    """Server-side check that the client's metadata arrived as expected.

    Verifies that _EXPECTED_INVOCATION_METADATA was transmitted. It also
    verifies that the user-agent starts with the primary agent followed by
    grpc's own agent string (``_channel._USER_AGENT``) and ends with the
    secondary agent configured in _CHANNEL_ARGS.

    Args:
      test: The TestCase (or a proxy to it) used for assertions.
      servicer_context: The ServicerContext of the call being handled.
    """
    invocation_metadata = servicer_context.invocation_metadata()
    test.assertTrue(
        test_common.metadata_transmitted(
            _EXPECTED_INVOCATION_METADATA, invocation_metadata
        ),
        "invocation metadata %r not transmitted; server saw %r"
        % (_EXPECTED_INVOCATION_METADATA, invocation_metadata),
    )
    user_agent = _user_agent(invocation_metadata)
    expected_prefix = "primary-agent " + _channel._USER_AGENT
    test.assertTrue(
        user_agent.startswith(expected_prefix),
        "user agent %r does not start with %r" % (user_agent, expected_prefix),
    )
    test.assertTrue(
        user_agent.endswith("secondary-agent"),
        "user agent %r does not end with %r" % (user_agent, "secondary-agent"),
    )


def handle_unary_unary(test, request, servicer_context):
    """Unary-unary handler: validate, send metadata, return one response."""
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    return _RESPONSE


def handle_unary_stream(test, request, servicer_context):
    """Unary-stream handler: validate, send metadata, stream responses.

    This is a generator, so the validation and metadata calls run only once
    the response stream starts being consumed.
    """
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE


def handle_stream_unary(test, request_iterator, servicer_context):
    """Stream-unary handler: validate, send metadata, drain, respond once."""
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(test, request_iterator, servicer_context):
    """Stream-stream handler: validate, send metadata, echo one per request."""
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _ in request_iterator:
        yield _RESPONSE


class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler that dispatches to one of the handle_* functions.

    No request deserializer or response serializer is set. Exactly one of the
    four behavior attributes is populated, chosen by the streaming flags; the
    rest stay None.
    """

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        # Each behavior closes over ``test`` so the module-level handlers
        # can assert against the running TestCase.
        if self.request_streaming and self.response_streaming:
            self.stream_stream = lambda requests, context: handle_stream_stream(
                test, requests, context
            )
        elif self.request_streaming:
            self.stream_unary = lambda requests, context: handle_stream_unary(
                test, requests, context
            )
        elif self.response_streaming:
            self.unary_stream = lambda request, context: handle_unary_stream(
                test, request, context
            )
        else:
            self.unary_unary = lambda request, context: handle_unary_unary(
                test, request, context
            )


class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler serving the four ``/test/*`` methods."""

    def __init__(self, test):
        self._test = test

    def service(self, handler_call_details):
        """Return a _MethodHandler for a known method path, else None."""
        streaming = _STREAMING_BY_METHOD.get(handler_call_details.method)
        if streaming is None:
            return None
        request_streaming, response_streaming = streaming
        return _MethodHandler(self._test, request_streaming, response_streaming)


class MetadataTest(unittest.TestCase):
    """Round-trips invocation, initial and trailing metadata for all arities."""

    def setUp(self):
        self._server = test_common.test_server()
        # Give the handler a weak proxy so it does not hold a strong
        # reference to this TestCase.
        self._server.add_generic_rpc_handlers(
            (_GenericHandler(weakref.proxy(self)),)
        )
        # Port 0: bind any free port; add_insecure_port returns the one used.
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel(
            "localhost:%d" % port, options=_CHANNEL_ARGS
        )

    def tearDown(self):
        # Close the channel even if stopping the server raises.
        try:
            self._server.stop(0)
        finally:
            self._channel.close()

    def _assert_metadata_transmitted(self, expected, actual):
        """Assert ``test_common.metadata_transmitted(expected, actual)``.

        Both values are included in the failure message to aid debugging.
        """
        self.assertTrue(
            test_common.metadata_transmitted(expected, actual),
            "expected metadata %r not transmitted; got %r" % (expected, actual),
        )

    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(
            _UNARY_UNARY, _registered_method=True
        )
        unused_response, call = multi_callable.with_call(
            _REQUEST, metadata=_INVOCATION_METADATA
        )
        self._assert_metadata_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        self._assert_metadata_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(
            _UNARY_STREAM, _registered_method=True
        )
        call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
        self._assert_metadata_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        # Consume the whole stream before asking for trailing metadata.
        for _ in call:
            pass
        self._assert_metadata_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(
            _STREAM_UNARY, _registered_method=True
        )
        unused_response, call = multi_callable.with_call(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA,
        )
        self._assert_metadata_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        self._assert_metadata_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(
            _STREAM_STREAM, _registered_method=True
        )
        call = multi_callable(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA,
        )
        self._assert_metadata_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        # Consume the whole stream before asking for trailing metadata.
        for _ in call:
            pass
        self._assert_metadata_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)