# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of interoperability test methods."""

# NOTE(lidiz) This module only exists in Bazel BUILD file, for more details
# please refer to comments in the "bazel_namespace_package_hack" module.
try:
    from tests import bazel_namespace_package_hack

    bazel_namespace_package_hack.sys_path_to_site_dir_hack()
except ImportError:
    pass

import collections
import enum
import json
import os
import threading
import time

from google import auth as google_auth
from google.auth import environment_vars as google_auth_environment_vars
from google.auth.transport import grpc as google_auth_transport_grpc
from google.auth.transport import requests as google_auth_transport_requests
import grpc

from src.proto.grpc.testing import empty_pb2
from src.proto.grpc.testing import messages_pb2

_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"

# Response sizes requested from the server by the streaming test cases.
_STREAMING_RESPONSE_SIZES = (
    31415,
    9,
    2653,
    58979,
)
# Payload body sizes sent to the server by the streaming test cases.
_STREAMING_PAYLOAD_SIZES = (
    27182,
    8,
    1828,
    45904,
)


def _expect_status_code(call, expected_code):
    """Raise ValueError unless ``call.code()`` equals ``expected_code``."""
    if call.code() != expected_code:
        raise ValueError(
            "expected code %s, got %s" % (expected_code, call.code())
        )


def _expect_status_details(call, expected_details):
    """Raise ValueError unless ``call.details()`` equals ``expected_details``."""
    if call.details() != expected_details:
        raise ValueError(
            "expected message %s, got %s" % (expected_details, call.details())
        )


def _validate_status_code_and_details(call, expected_code, expected_details):
    """Check both the status code and the status details of ``call``."""
    _expect_status_code(call, expected_code)
    _expect_status_details(call, expected_details)


def _validate_payload_type_and_length(response, expected_type, expected_length):
    """Check the payload type and body length of a response message.

    Raises:
      ValueError: if the payload type differs from ``expected_type`` or the
        body length differs from ``expected_length``.
    """
    received_type = response.payload.type
    # Payload types are plain enum ints, so compare by value, not identity.
    if received_type != expected_type:
        raise ValueError(
            "expected payload type %s, got %s" % (expected_type, received_type)
        )
    received_length = len(response.payload.body)
    if received_length != expected_length:
        raise ValueError(
            "expected payload body size %d, got %d"
            % (expected_length, received_length)
        )


def _load_service_account_email():
    """Return ``client_email`` from the JSON key file named in the environment.

    The file path is read from the environment variable named by
    ``google.auth.environment_vars.CREDENTIALS`` (KeyError if it is unset).
    The file is closed before returning.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    with open(json_key_filename, "r") as json_key_file:
        return json.load(json_key_file)["client_email"]


def _large_unary_common_behavior(
    stub, fill_username, fill_oauth_scope, call_credentials
):
    """Issue a large UnaryCall, validate the payload and return the response.

    Args:
      stub: Test service stub exposing ``UnaryCall``.
      fill_username: Value for the request's ``fill_username`` field.
      fill_oauth_scope: Value for the request's ``fill_oauth_scope`` field.
      call_credentials: Per-call credentials, or None.

    Returns:
      The SimpleResponse received from the server.
    """
    size = 314159
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=size,
        payload=messages_pb2.Payload(body=b"\x00" * 271828),
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope,
    )
    response_future = stub.UnaryCall.future(
        request, credentials=call_credentials
    )
    response = response_future.result()
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
    return response


def _empty_unary(stub):
    """Call EmptyCall and check that the response is an ``empty_pb2.Empty``."""
    response = stub.EmptyCall(empty_pb2.Empty())
    if not isinstance(response, empty_pb2.Empty):
        raise TypeError(
            'response is of type "%s", not empty_pb2.Empty!' % type(response)
        )


def _large_unary(stub):
    """Run the large unary case with no username, scope or credentials."""
    _large_unary_common_behavior(stub, False, False, None)


def _client_streaming(stub):
    """Stream several payloads and check the aggregated size reported back."""
    payload_body_sizes = _STREAMING_PAYLOAD_SIZES
    requests = (
        messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b"\x00" * size)
        )
        for size in payload_body_sizes
    )
    response = stub.StreamingInputCall(requests)
    # The server must report the sum of all body sizes sent (74922).
    if response.aggregated_payload_size != sum(payload_body_sizes):
        raise ValueError(
            "incorrect size %d!" % response.aggregated_payload_size
        )


def _server_streaming(stub):
    """Request one response per size and check each response in order.

    Raises:
      ValueError: if any response has the wrong type or size, or if the
        server sends more or fewer responses than requested.
    """
    sizes = _STREAMING_RESPONSE_SIZES

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=size) for size in sizes
        ),
    )
    response_iterator = stub.StreamingOutputCall(request)
    response_count = 0
    for response in response_iterator:
        if response_count >= len(sizes):
            raise ValueError(
                "expected %d responses, got more" % len(sizes)
            )
        _validate_payload_type_and_length(
            response, messages_pb2.COMPRESSABLE, sizes[response_count]
        )
        response_count += 1
    if response_count != len(sizes):
        raise ValueError(
            "expected %d responses, got %d" % (len(sizes), response_count)
        )


class _Pipe(object):
    """A thread-safe, closable request iterator fed via ``add``.

    ``next`` blocks until a value has been added or the pipe is closed.
    Once the pipe is closed and drained, iteration stops.
    Using it as a context manager closes the pipe on exit.
    """

    def __init__(self):
        self._condition = threading.Condition()
        # FIFO of pending values; deque gives O(1) removal from the front.
        self._values = collections.deque()
        self._open = True

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.popleft()
            else:
                raise StopIteration()

    def add(self, value):
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        with self._condition:
            self._open = False
            self._condition.notify()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()


def _ping_pong(stub):
    """Alternate one request and one response on a FullDuplexCall.

    Each response's type and size are checked.
    """
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        for response_size, payload_size in zip(
            _STREAMING_RESPONSE_SIZES, _STREAMING_PAYLOAD_SIZES
        ):
            request = messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(
                    messages_pb2.ResponseParameters(size=response_size),
                ),
                payload=messages_pb2.Payload(body=b"\x00" * payload_size),
            )
            pipe.add(request)
            response = next(response_iterator)
            _validate_payload_type_and_length(
                response, messages_pb2.COMPRESSABLE, response_size
            )


def _cancel_after_begin(stub):
    """Cancel a client-streaming call before sending any request.

    Checks that the call reports itself as cancelled with status CANCELLED.
    """
    with _Pipe() as pipe:
        response_future = stub.StreamingInputCall.future(pipe)
        response_future.cancel()
        if not response_future.cancelled():
            raise ValueError("expected cancelled method to return True")
        if response_future.code() is not grpc.StatusCode.CANCELLED:
            raise ValueError("expected status code CANCELLED")


def _cancel_after_first_response(stub):
    """Cancel a FullDuplexCall after its first response.

    Checks that further iteration raises an RpcError with status CANCELLED.
    """
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)

        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(
                messages_pb2.ResponseParameters(
                    size=_STREAMING_RESPONSE_SIZES[0]
                ),
            ),
            payload=messages_pb2.Payload(
                body=b"\x00" * _STREAMING_PAYLOAD_SIZES[0]
            ),
        )
        pipe.add(request)
        # The response contents are checked by the ping-pong test; only wait
        # for one to arrive here.
        next(response_iterator)
        response_iterator.cancel()

        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.CANCELLED:
                raise
        else:
            raise ValueError("expected call to be cancelled")


def _timeout_on_sleeping_server(stub):
    """Check that a FullDuplexCall with a 1 ms deadline fails.

    The expected failure is an RpcError with status DEADLINE_EXCEEDED.
    """
    request_payload_size = 27182
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, timeout=0.001)

        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            payload=messages_pb2.Payload(body=b"\x00" * request_payload_size),
        )
        pipe.add(request)
        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
                raise
        else:
            raise ValueError("expected call to exceed deadline")


def _empty_stream(stub):
    """Close the request stream immediately and expect zero responses."""
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        pipe.close()
        try:
            next(response_iterator)
        except StopIteration:
            pass
        else:
            raise ValueError("expected exactly 0 responses")


def _status_code_and_message(stub):
    """Ask the server to echo a status, then check code and details.

    The check is done on both a UnaryCall and a FullDuplexCall.
    """
    details = "test status message"
    code = 2
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b"\x00"),
        response_status=messages_pb2.EchoStatus(code=code, message=details),
    )
    response_future = stub.UnaryCall.future(request)
    _validate_status_code_and_details(response_future, status, details)

    # Test with a FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),),
            payload=messages_pb2.Payload(body=b"\x00"),
            response_status=messages_pb2.EchoStatus(code=code, message=details),
        )
        pipe.add(request)  # sends the initial request.
        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            # Explicit check rather than `assert`, which is stripped under -O.
            _expect_status_code(rpc_error, status)
    # Dropping out of with block closes the pipe
    _validate_status_code_and_details(response_iterator, status, details)


def _unimplemented_method(test_service_stub):
    """Call UnimplementedCall and expect status UNIMPLEMENTED."""
    response_future = test_service_stub.UnimplementedCall.future(
        empty_pb2.Empty()
    )
    _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)


def _unimplemented_service(unimplemented_service_stub):
    """Call UnimplementedCall on the unimplemented service; expect UNIMPLEMENTED."""
    response_future = unimplemented_service_stub.UnimplementedCall.future(
        empty_pb2.Empty()
    )
    _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)


def _custom_metadata(stub):
    """Send echo metadata keys and check the server echoes them back.

    The initial value should come back as initial metadata and the trailing
    value as trailing metadata. Checked on a UnaryCall and a FullDuplexCall.
    """
    initial_metadata_value = "test_initial_metadata_value"
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = (
        (_INITIAL_METADATA_KEY, initial_metadata_value),
        (_TRAILING_METADATA_KEY, trailing_metadata_value),
    )

    def _validate_metadata(response):
        # Use .get so a missing key reports a ValueError (with "None" as the
        # received value) rather than an opaque KeyError.
        initial_metadata = dict(response.initial_metadata())
        received_initial = initial_metadata.get(_INITIAL_METADATA_KEY)
        if received_initial != initial_metadata_value:
            raise ValueError(
                "expected initial metadata %s, got %s"
                % (initial_metadata_value, received_initial)
            )
        trailing_metadata = dict(response.trailing_metadata())
        received_trailing = trailing_metadata.get(_TRAILING_METADATA_KEY)
        if received_trailing != trailing_metadata_value:
            raise ValueError(
                "expected trailing metadata %s, got %s"
                % (trailing_metadata_value, received_trailing)
            )

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b"\x00"),
    )
    response_future = stub.UnaryCall.future(request, metadata=metadata)
    _validate_metadata(response_future)

    # Testing with FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),),
        )
        pipe.add(request)  # Sends the request
        next(response_iterator)  # Causes server to send trailing metadata
    # Dropping out of the with block closes the pipe
    _validate_metadata(response_iterator)


def _compute_engine_creds(stub, args):
    """Check that the username in the response is the default service account.

    The expected account is ``args.default_service_account``.
    """
    response = _large_unary_common_behavior(stub, True, True, None)
    if args.default_service_account != response.username:
        raise ValueError(
            "expected username %s, got %s"
            % (args.default_service_account, response.username)
        )


def _oauth2_auth_token(stub, args):
    """Check the username and OAuth scope in the response.

    The username must match the key file's client_email, and the returned
    scope must appear within ``args.oauth_scope``.
    """
    wanted_email = _load_service_account_email()
    response = _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError(
            "expected username %s, got %s" % (wanted_email, response.username)
        )
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope
            )
        )


def _jwt_token_creds(stub, args):
    """Check that the username in the response matches the key file's client_email."""
    wanted_email = _load_service_account_email()
    response = _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError(
            "expected username %s, got %s" % (wanted_email, response.username)
        )


def _per_rpc_creds(stub, args):
    """Attach google-auth default credentials to a single call.

    The credentials are scoped to ``args.oauth_scope``. The username in the
    response must match the key file's client_email.
    """
    wanted_email = _load_service_account_email()
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope]
    )
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request(),
        )
    )
    response = _large_unary_common_behavior(stub, True, False, call_credentials)
    if wanted_email != response.username:
        raise ValueError(
            "expected username %s, got %s" % (wanted_email, response.username)
        )


def _special_status_message(stub, args):
    """Echo a status whose message has whitespace and BMP/non-BMP characters.

    Checks that the code and details come back unchanged.
    """
    details = (
        b"\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP"
        b" \xf0\x9f\x98\x88\t\n".decode("utf-8")
    )
    code = 2
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b"\x00"),
        response_status=messages_pb2.EchoStatus(code=code, message=details),
    )
    response_future = stub.UnaryCall.future(request)
    _validate_status_code_and_details(response_future, status, details)


@enum.unique
class TestCase(enum.Enum):
    """Interop test cases; each member's value is the test-case name string."""

    EMPTY_UNARY = "empty_unary"
    LARGE_UNARY = "large_unary"
    SERVER_STREAMING = "server_streaming"
    CLIENT_STREAMING = "client_streaming"
    PING_PONG = "ping_pong"
    CANCEL_AFTER_BEGIN = "cancel_after_begin"
    CANCEL_AFTER_FIRST_RESPONSE = "cancel_after_first_response"
    EMPTY_STREAM = "empty_stream"
    STATUS_CODE_AND_MESSAGE = "status_code_and_message"
    UNIMPLEMENTED_METHOD = "unimplemented_method"
    UNIMPLEMENTED_SERVICE = "unimplemented_service"
    CUSTOM_METADATA = "custom_metadata"
    COMPUTE_ENGINE_CREDS = "compute_engine_creds"
    OAUTH2_AUTH_TOKEN = "oauth2_auth_token"
    JWT_TOKEN_CREDS = "jwt_token_creds"
    PER_RPC_CREDS = "per_rpc_creds"
    TIMEOUT_ON_SLEEPING_SERVER = "timeout_on_sleeping_server"
    SPECIAL_STATUS_MESSAGE = "special_status_message"

    def test_interoperability(self, stub, args):
        """Run this test case against ``stub``.

        ``args`` is forwarded only to the cases that read it (credentials,
        special status message).

        Raises:
          NotImplementedError: if this member has no implementation.
        """
        if self is TestCase.EMPTY_UNARY:
            _empty_unary(stub)
        elif self is TestCase.LARGE_UNARY:
            _large_unary(stub)
        elif self is TestCase.SERVER_STREAMING:
            _server_streaming(stub)
        elif self is TestCase.CLIENT_STREAMING:
            _client_streaming(stub)
        elif self is TestCase.PING_PONG:
            _ping_pong(stub)
        elif self is TestCase.CANCEL_AFTER_BEGIN:
            _cancel_after_begin(stub)
        elif self is TestCase.CANCEL_AFTER_FIRST_RESPONSE:
            _cancel_after_first_response(stub)
        elif self is TestCase.TIMEOUT_ON_SLEEPING_SERVER:
            _timeout_on_sleeping_server(stub)
        elif self is TestCase.EMPTY_STREAM:
            _empty_stream(stub)
        elif self is TestCase.STATUS_CODE_AND_MESSAGE:
            _status_code_and_message(stub)
        elif self is TestCase.UNIMPLEMENTED_METHOD:
            _unimplemented_method(stub)
        elif self is TestCase.UNIMPLEMENTED_SERVICE:
            _unimplemented_service(stub)
        elif self is TestCase.CUSTOM_METADATA:
            _custom_metadata(stub)
        elif self is TestCase.COMPUTE_ENGINE_CREDS:
            _compute_engine_creds(stub, args)
        elif self is TestCase.OAUTH2_AUTH_TOKEN:
            _oauth2_auth_token(stub, args)
        elif self is TestCase.JWT_TOKEN_CREDS:
            _jwt_token_creds(stub, args)
        elif self is TestCase.PER_RPC_CREDS:
            _per_rpc_creds(stub, args)
        elif self is TestCase.SPECIAL_STATUS_MESSAGE:
            _special_status_message(stub, args)
        else:
            raise NotImplementedError(
                'Test case "%s" not implemented!' % self.name
            )