# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test helpers for RPC invocation tests."""

import datetime
import threading

import grpc
from grpc.framework.foundation import logging_pool

from tests.unit import test_common
from tests.unit import thread_pool
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control

# Codecs installed by the methods that use (de)serializers. Each deserializer
# inverts its matching serializer: requests are doubled on the wire and
# responses are tripled.
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2 :]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[: len(bytestring) // 3]

_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_UNARY_STREAM_NON_BLOCKING = "/test/UnaryStreamNonBlocking"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"
_STREAM_STREAM_NON_BLOCKING = "/test/StreamStreamNonBlocking"

# Trailing metadata attached by every server-side handler that has a context.
_TRAILING_METADATA = (("testkey", "testvalue"),)

TIMEOUT_SHORT = datetime.timedelta(seconds=4).total_seconds()


class Callback(object):
    """A callable that records the value it is invoked with.

    ``value()`` blocks until the callback has been invoked at least once and
    then returns the most recently recorded value.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._value = None
        self._called = False

    def __call__(self, value):
        with self._condition:
            self._value = value
            self._called = True
            self._condition.notify_all()

    def value(self):
        """Waits (with no timeout) until called, then returns the value."""
        with self._condition:
            self._condition.wait_for(lambda: self._called)
            return self._value


def _set_trailing_metadata(servicer_context):
    """Attaches _TRAILING_METADATA unless there is no servicer context."""
    if servicer_context is not None:
        servicer_context.set_trailing_metadata(_TRAILING_METADATA)


def _flag_non_blocking(behavior, pool):
    """Returns a wrapper around ``behavior`` flagged as non-blocking.

    The ``experimental_non_blocking`` / ``experimental_thread_pool`` attributes
    are set on a fresh function owned by the caller. They are deliberately not
    set on ``behavior.__func__``, which is shared by every instance of the
    defining class.

    Args:
      behavior: The callable to delegate to (e.g. a bound handler method).
      pool: The thread pool to advertise via ``experimental_thread_pool``.

    Returns:
      A callable that forwards all arguments to ``behavior``.
    """

    def non_blocking_behavior(*args, **kwargs):
        return behavior(*args, **kwargs)

    non_blocking_behavior.experimental_non_blocking = True
    non_blocking_behavior.experimental_thread_pool = pool
    return non_blocking_behavior


def _repeated_requests(payload):
    """Returns an iterator over STREAM_LENGTH copies of ``payload``."""
    return iter((payload,) * test_constants.STREAM_LENGTH)


class _Handler(object):
    """Server-side behaviors for every test method.

    Each behavior calls ``control.control()`` at fixed points so tests can
    pause or fail it (see ``test_control.PauseFailControl``). It also echoes
    requests back as responses.
    """

    def __init__(self, control, thread_pool):
        self._control = control
        self._thread_pool = thread_pool
        # Flag the non-blocking behaviors on per-instance wrappers. Mutating
        # ``method.__func__`` (as this used to do) changes the function shared
        # by all _Handler instances, so each new instance would silently
        # redirect every existing instance to its own thread pool.
        self.handle_unary_stream_non_blocking = _flag_non_blocking(
            self.handle_unary_stream_non_blocking, self._thread_pool
        )
        self.handle_stream_stream_non_blocking = _flag_non_blocking(
            self.handle_stream_stream_non_blocking, self._thread_pool
        )

    def handle_unary_unary(self, request, servicer_context):
        """Echoes ``request`` and smoke-tests a few context methods."""
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(_TRAILING_METADATA)
            # TODO(https://github.com/grpc/grpc/issues/8483): test the values
            # returned by these methods rather than only "smoke" testing that
            # they return after having been called.
            servicer_context.is_active()
            servicer_context.time_remaining()
        return request

    def handle_unary_stream(self, request, servicer_context):
        """Yields ``request`` STREAM_LENGTH times, then sets metadata."""
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        _set_trailing_metadata(servicer_context)

    def handle_unary_stream_non_blocking(
        self, request, servicer_context, on_next
    ):
        """Sends ``request`` STREAM_LENGTH times via ``on_next``.

        The stream is terminated by a final ``on_next(None)``.
        """
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            on_next(request)
        self._control.control()
        _set_trailing_metadata(servicer_context)
        on_next(None)

    def handle_stream_unary(self, request_iterator, servicer_context):
        """Returns the concatenation of all received requests."""
        if servicer_context is not None:
            servicer_context.invocation_metadata()
        self._control.control()
        response_elements = []
        for request in request_iterator:
            self._control.control()
            response_elements.append(request)
        self._control.control()
        _set_trailing_metadata(servicer_context)
        return b"".join(response_elements)

    def handle_stream_stream(self, request_iterator, servicer_context):
        """Sets metadata up front, then echoes each request as a response."""
        self._control.control()
        _set_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            yield request
        self._control.control()

    def handle_stream_stream_non_blocking(
        self, request_iterator, servicer_context, on_next
    ):
        """Echoes each request via ``on_next``, then sends ``on_next(None)``."""
        self._control.control()
        _set_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            on_next(request)
        self._control.control()
        on_next(None)


class _MethodHandler(grpc.RpcMethodHandler):
    """Plain attribute holder implementing ``grpc.RpcMethodHandler``."""

    def __init__(
        self,
        request_streaming,
        response_streaming,
        request_deserializer,
        response_serializer,
        unary_unary,
        unary_stream,
        stream_unary,
        stream_stream,
    ):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = request_deserializer
        self.response_serializer = response_serializer
        self.unary_unary = unary_unary
        self.unary_stream = unary_stream
        self.stream_unary = stream_unary
        self.stream_stream = stream_stream


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the /test/* method names to the behaviors of a _Handler."""

    def __init__(self, handler):
        self._handler = handler

    def service(self, handler_call_details):
        """Returns the _MethodHandler for the requested method, or None.

        Only the UnaryStream, UnaryStreamNonBlocking and StreamUnary methods
        install (de)serializers. The others pass bytes through unchanged,
        mirroring the client-side multi-callable factories in this module.
        """
        method = handler_call_details.method
        handler = self._handler
        if method == _UNARY_UNARY:
            return _MethodHandler(
                request_streaming=False,
                response_streaming=False,
                request_deserializer=None,
                response_serializer=None,
                unary_unary=handler.handle_unary_unary,
                unary_stream=None,
                stream_unary=None,
                stream_stream=None,
            )
        elif method == _UNARY_STREAM:
            return _MethodHandler(
                request_streaming=False,
                response_streaming=True,
                request_deserializer=_DESERIALIZE_REQUEST,
                response_serializer=_SERIALIZE_RESPONSE,
                unary_unary=None,
                unary_stream=handler.handle_unary_stream,
                stream_unary=None,
                stream_stream=None,
            )
        elif method == _UNARY_STREAM_NON_BLOCKING:
            return _MethodHandler(
                request_streaming=False,
                response_streaming=True,
                request_deserializer=_DESERIALIZE_REQUEST,
                response_serializer=_SERIALIZE_RESPONSE,
                unary_unary=None,
                unary_stream=handler.handle_unary_stream_non_blocking,
                stream_unary=None,
                stream_stream=None,
            )
        elif method == _STREAM_UNARY:
            return _MethodHandler(
                request_streaming=True,
                response_streaming=False,
                request_deserializer=_DESERIALIZE_REQUEST,
                response_serializer=_SERIALIZE_RESPONSE,
                unary_unary=None,
                unary_stream=None,
                stream_unary=handler.handle_stream_unary,
                stream_stream=None,
            )
        elif method == _STREAM_STREAM:
            return _MethodHandler(
                request_streaming=True,
                response_streaming=True,
                request_deserializer=None,
                response_serializer=None,
                unary_unary=None,
                unary_stream=None,
                stream_unary=None,
                stream_stream=handler.handle_stream_stream,
            )
        elif method == _STREAM_STREAM_NON_BLOCKING:
            return _MethodHandler(
                request_streaming=True,
                response_streaming=True,
                request_deserializer=None,
                response_serializer=None,
                unary_unary=None,
                unary_stream=None,
                stream_unary=None,
                stream_stream=handler.handle_stream_stream_non_blocking,
            )
        else:
            return None


def unary_unary_multi_callable(channel):
    """Returns a client stub for /test/UnaryUnary (no codecs)."""
    return channel.unary_unary(
        _UNARY_UNARY,
        _registered_method=True,
    )


def unary_stream_multi_callable(channel):
    """Returns a client stub for /test/UnaryStream."""
    return channel.unary_stream(
        _UNARY_STREAM,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def unary_stream_non_blocking_multi_callable(channel):
    """Returns a client stub for /test/UnaryStreamNonBlocking."""
    return channel.unary_stream(
        _UNARY_STREAM_NON_BLOCKING,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def stream_unary_multi_callable(channel):
    """Returns a client stub for /test/StreamUnary."""
    return channel.stream_unary(
        _STREAM_UNARY,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def stream_stream_multi_callable(channel):
    """Returns a client stub for /test/StreamStream (no codecs)."""
    return channel.stream_stream(
        _STREAM_STREAM,
        _registered_method=True,
    )


def stream_stream_non_blocking_multi_callable(channel):
    """Returns a client stub for /test/StreamStreamNonBlocking (no codecs)."""
    return channel.stream_stream(
        _STREAM_STREAM_NON_BLOCKING,
        _registered_method=True,
    )


class BaseRPCTest(object):
    """Shared fixture and scenario helpers for RPC invocation tests.

    Intended as a mixin: it calls ``assertRaises``/``assertIs``/... without
    defining them, presumably supplied by ``unittest.TestCase``.
    """

    def setUp(self):
        self._control = test_control.PauseFailControl()
        self._thread_pool = thread_pool.RecordingThreadPool(max_workers=None)
        self._handler = _Handler(self._control, self._thread_pool)

        self._server = test_common.test_server()
        # Port 0: let the OS pick a free port, then connect to it below.
        port = self._server.add_insecure_port("[::]:0")
        self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        self._server.stop(None)
        self._channel.close()

    def _consume_one_stream_response_unary_request(self, multi_callable):
        """Reads only the first response of a unary-request stream."""
        request = b"\x57\x38"

        response_iterator = multi_callable(
            request,
            metadata=(("test", "ConsumingOneStreamResponseUnaryRequest"),),
        )
        next(response_iterator)

    def _consume_some_but_not_all_stream_responses_unary_request(
        self, multi_callable
    ):
        """Reads half of the responses of a unary-request stream."""
        request = b"\x57\x38"

        response_iterator = multi_callable(
            request,
            metadata=(
                ("test", "ConsumingSomeButNotAllStreamResponsesUnaryRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_some_but_not_all_stream_responses_stream_request(
        self, multi_callable
    ):
        """Reads half of the responses of a stream-request stream."""
        request_iterator = _repeated_requests(b"\x67\x88")

        response_iterator = multi_callable(
            request_iterator,
            metadata=(
                ("test", "ConsumingSomeButNotAllStreamResponsesStreamRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_too_many_stream_responses_stream_request(self, multi_callable):
        """Reads past the end of a stream and checks the final RPC state."""
        request_iterator = _repeated_requests(b"\x67\x88")

        response_iterator = multi_callable(
            request_iterator,
            metadata=(
                ("test", "ConsumingTooManyStreamResponsesStreamRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH):
            next(response_iterator)
        # Every read past the end must keep raising StopIteration.
        for _ in range(test_constants.STREAM_LENGTH):
            with self.assertRaises(StopIteration):
                next(response_iterator)

        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.OK, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_unary_request_stream_response(self, multi_callable):
        """Cancels a unary-request stream while the server is paused."""
        request = b"\x07\x19"

        with self._control.pause():
            response_iterator = multi_callable(
                request,
                metadata=(("test", "CancelledUnaryRequestStreamResponse"),),
            )
            self._control.block_until_paused()
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(response_iterator)
        self.assertIs(
            grpc.StatusCode.CANCELLED, exception_context.exception.code()
        )
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_stream_request_stream_response(self, multi_callable):
        """Cancels a stream-request stream while the server is paused."""
        request_iterator = _repeated_requests(b"\x07\x08")

        with self._control.pause():
            response_iterator = multi_callable(
                request_iterator,
                metadata=(("test", "CancelledStreamRequestStreamResponse"),),
            )
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError):
            next(response_iterator)
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _expired_unary_request_stream_response(self, multi_callable):
        """Lets a unary-request stream hit its deadline on a paused server."""
        request = b"\x07\x19"

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request,
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(("test", "ExpiredUnaryRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED,
            exception_context.exception.code(),
        )
        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()
        )

    def _expired_stream_request_stream_response(self, multi_callable):
        """Lets a stream-request stream hit its deadline on a paused server."""
        request_iterator = _repeated_requests(b"\x67\x18")

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request_iterator,
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(("test", "ExpiredStreamRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED,
            exception_context.exception.code(),
        )
        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()
        )

    def _failed_unary_request_stream_response(self, multi_callable):
        """Makes the server behavior fail and expects StatusCode.UNKNOWN."""
        request = b"\x37\x17"

        with self.assertRaises(grpc.RpcError) as exception_context:
            with self._control.fail():
                response_iterator = multi_callable(
                    request,
                    metadata=(("test", "FailedUnaryRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    def _failed_stream_request_stream_response(self, multi_callable):
        """Makes a stream-request behavior fail; expects StatusCode.UNKNOWN."""
        request_iterator = _repeated_requests(b"\x67\x88")

        with self._control.fail():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request_iterator,
                    metadata=(("test", "FailedStreamRequestStreamResponse"),),
                )
                tuple(response_iterator)

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )
        self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())

    def _ignored_unary_stream_request_future_unary_response(
        self, multi_callable
    ):
        """Starts a unary-request stream and never reads from it."""
        request = b"\x37\x17"

        multi_callable(
            request, metadata=(("test", "IgnoredUnaryRequestStreamResponse"),)
        )

    def _ignored_stream_request_stream_response(self, multi_callable):
        """Starts a stream-request stream and never reads from it."""
        request_iterator = _repeated_requests(b"\x67\x88")

        multi_callable(
            request_iterator,
            metadata=(("test", "IgnoredStreamRequestStreamResponse"),),
        )