xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/interop/methods.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementations of interoperability test methods."""
15
16# NOTE(lidiz) This module only exists in Bazel BUILD file, for more details
17# please refer to comments in the "bazel_namespace_package_hack" module.
18try:
19    from tests import bazel_namespace_package_hack
20
21    bazel_namespace_package_hack.sys_path_to_site_dir_hack()
22except ImportError:
23    pass
24
25import enum
26import json
27import os
28import threading
29import time
30
31from google import auth as google_auth
32from google.auth import environment_vars as google_auth_environment_vars
33from google.auth.transport import grpc as google_auth_transport_grpc
34from google.auth.transport import requests as google_auth_transport_requests
35import grpc
36
37from src.proto.grpc.testing import empty_pb2
38from src.proto.grpc.testing import messages_pb2
39
40_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
41_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"
42
43
44def _expect_status_code(call, expected_code):
45    if call.code() != expected_code:
46        raise ValueError(
47            "expected code %s, got %s" % (expected_code, call.code())
48        )
49
50
51def _expect_status_details(call, expected_details):
52    if call.details() != expected_details:
53        raise ValueError(
54            "expected message %s, got %s" % (expected_details, call.details())
55        )
56
57
58def _validate_status_code_and_details(call, expected_code, expected_details):
59    _expect_status_code(call, expected_code)
60    _expect_status_details(call, expected_details)
61
62
63def _validate_payload_type_and_length(response, expected_type, expected_length):
64    if response.payload.type is not expected_type:
65        raise ValueError(
66            "expected payload type %s, got %s"
67            % (expected_type, type(response.payload.type))
68        )
69    elif len(response.payload.body) != expected_length:
70        raise ValueError(
71            "expected payload body size %d, got %d"
72            % (expected_length, len(response.payload.body))
73        )
74
75
76def _large_unary_common_behavior(
77    stub, fill_username, fill_oauth_scope, call_credentials
78):
79    size = 314159
80    request = messages_pb2.SimpleRequest(
81        response_type=messages_pb2.COMPRESSABLE,
82        response_size=size,
83        payload=messages_pb2.Payload(body=b"\x00" * 271828),
84        fill_username=fill_username,
85        fill_oauth_scope=fill_oauth_scope,
86    )
87    response_future = stub.UnaryCall.future(
88        request, credentials=call_credentials
89    )
90    response = response_future.result()
91    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
92    return response
93
94
95def _empty_unary(stub):
96    response = stub.EmptyCall(empty_pb2.Empty())
97    if not isinstance(response, empty_pb2.Empty):
98        raise TypeError(
99            'response is of type "%s", not empty_pb2.Empty!' % type(response)
100        )
101
102
103def _large_unary(stub):
104    _large_unary_common_behavior(stub, False, False, None)
105
106
107def _client_streaming(stub):
108    payload_body_sizes = (
109        27182,
110        8,
111        1828,
112        45904,
113    )
114    payloads = (
115        messages_pb2.Payload(body=b"\x00" * size) for size in payload_body_sizes
116    )
117    requests = (
118        messages_pb2.StreamingInputCallRequest(payload=payload)
119        for payload in payloads
120    )
121    response = stub.StreamingInputCall(requests)
122    if response.aggregated_payload_size != 74922:
123        raise ValueError(
124            "incorrect size %d!" % response.aggregated_payload_size
125        )
126
127
128def _server_streaming(stub):
129    sizes = (
130        31415,
131        9,
132        2653,
133        58979,
134    )
135
136    request = messages_pb2.StreamingOutputCallRequest(
137        response_type=messages_pb2.COMPRESSABLE,
138        response_parameters=(
139            messages_pb2.ResponseParameters(size=sizes[0]),
140            messages_pb2.ResponseParameters(size=sizes[1]),
141            messages_pb2.ResponseParameters(size=sizes[2]),
142            messages_pb2.ResponseParameters(size=sizes[3]),
143        ),
144    )
145    response_iterator = stub.StreamingOutputCall(request)
146    for index, response in enumerate(response_iterator):
147        _validate_payload_type_and_length(
148            response, messages_pb2.COMPRESSABLE, sizes[index]
149        )
150
151
152class _Pipe(object):
153    def __init__(self):
154        self._condition = threading.Condition()
155        self._values = []
156        self._open = True
157
158    def __iter__(self):
159        return self
160
161    def __next__(self):
162        return self.next()
163
164    def next(self):
165        with self._condition:
166            while not self._values and self._open:
167                self._condition.wait()
168            if self._values:
169                return self._values.pop(0)
170            else:
171                raise StopIteration()
172
173    def add(self, value):
174        with self._condition:
175            self._values.append(value)
176            self._condition.notify()
177
178    def close(self):
179        with self._condition:
180            self._open = False
181            self._condition.notify()
182
183    def __enter__(self):
184        return self
185
186    def __exit__(self, type, value, traceback):
187        self.close()
188
189
190def _ping_pong(stub):
191    request_response_sizes = (
192        31415,
193        9,
194        2653,
195        58979,
196    )
197    request_payload_sizes = (
198        27182,
199        8,
200        1828,
201        45904,
202    )
203
204    with _Pipe() as pipe:
205        response_iterator = stub.FullDuplexCall(pipe)
206        for response_size, payload_size in zip(
207            request_response_sizes, request_payload_sizes
208        ):
209            request = messages_pb2.StreamingOutputCallRequest(
210                response_type=messages_pb2.COMPRESSABLE,
211                response_parameters=(
212                    messages_pb2.ResponseParameters(size=response_size),
213                ),
214                payload=messages_pb2.Payload(body=b"\x00" * payload_size),
215            )
216            pipe.add(request)
217            response = next(response_iterator)
218            _validate_payload_type_and_length(
219                response, messages_pb2.COMPRESSABLE, response_size
220            )
221
222
223def _cancel_after_begin(stub):
224    with _Pipe() as pipe:
225        response_future = stub.StreamingInputCall.future(pipe)
226        response_future.cancel()
227        if not response_future.cancelled():
228            raise ValueError("expected cancelled method to return True")
229        if response_future.code() is not grpc.StatusCode.CANCELLED:
230            raise ValueError("expected status code CANCELLED")
231
232
233def _cancel_after_first_response(stub):
234    request_response_sizes = (
235        31415,
236        9,
237        2653,
238        58979,
239    )
240    request_payload_sizes = (
241        27182,
242        8,
243        1828,
244        45904,
245    )
246    with _Pipe() as pipe:
247        response_iterator = stub.FullDuplexCall(pipe)
248
249        response_size = request_response_sizes[0]
250        payload_size = request_payload_sizes[0]
251        request = messages_pb2.StreamingOutputCallRequest(
252            response_type=messages_pb2.COMPRESSABLE,
253            response_parameters=(
254                messages_pb2.ResponseParameters(size=response_size),
255            ),
256            payload=messages_pb2.Payload(body=b"\x00" * payload_size),
257        )
258        pipe.add(request)
259        response = next(response_iterator)
260        # We test the contents of `response` in the Ping Pong test - don't check
261        # them here.
262        response_iterator.cancel()
263
264        try:
265            next(response_iterator)
266        except grpc.RpcError as rpc_error:
267            if rpc_error.code() is not grpc.StatusCode.CANCELLED:
268                raise
269        else:
270            raise ValueError("expected call to be cancelled")
271
272
273def _timeout_on_sleeping_server(stub):
274    request_payload_size = 27182
275    with _Pipe() as pipe:
276        response_iterator = stub.FullDuplexCall(pipe, timeout=0.001)
277
278        request = messages_pb2.StreamingOutputCallRequest(
279            response_type=messages_pb2.COMPRESSABLE,
280            payload=messages_pb2.Payload(body=b"\x00" * request_payload_size),
281        )
282        pipe.add(request)
283        try:
284            next(response_iterator)
285        except grpc.RpcError as rpc_error:
286            if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
287                raise
288        else:
289            raise ValueError("expected call to exceed deadline")
290
291
292def _empty_stream(stub):
293    with _Pipe() as pipe:
294        response_iterator = stub.FullDuplexCall(pipe)
295        pipe.close()
296        try:
297            next(response_iterator)
298            raise ValueError("expected exactly 0 responses")
299        except StopIteration:
300            pass
301
302
303def _status_code_and_message(stub):
304    details = "test status message"
305    code = 2
306    status = grpc.StatusCode.UNKNOWN  # code = 2
307
308    # Test with a UnaryCall
309    request = messages_pb2.SimpleRequest(
310        response_type=messages_pb2.COMPRESSABLE,
311        response_size=1,
312        payload=messages_pb2.Payload(body=b"\x00"),
313        response_status=messages_pb2.EchoStatus(code=code, message=details),
314    )
315    response_future = stub.UnaryCall.future(request)
316    _validate_status_code_and_details(response_future, status, details)
317
318    # Test with a FullDuplexCall
319    with _Pipe() as pipe:
320        response_iterator = stub.FullDuplexCall(pipe)
321        request = messages_pb2.StreamingOutputCallRequest(
322            response_type=messages_pb2.COMPRESSABLE,
323            response_parameters=(messages_pb2.ResponseParameters(size=1),),
324            payload=messages_pb2.Payload(body=b"\x00"),
325            response_status=messages_pb2.EchoStatus(code=code, message=details),
326        )
327        pipe.add(request)  # sends the initial request.
328    try:
329        next(response_iterator)
330    except grpc.RpcError as rpc_error:
331        assert rpc_error.code() == status
332    # Dropping out of with block closes the pipe
333    _validate_status_code_and_details(response_iterator, status, details)
334
335
336def _unimplemented_method(test_service_stub):
337    response_future = test_service_stub.UnimplementedCall.future(
338        empty_pb2.Empty()
339    )
340    _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)
341
342
343def _unimplemented_service(unimplemented_service_stub):
344    response_future = unimplemented_service_stub.UnimplementedCall.future(
345        empty_pb2.Empty()
346    )
347    _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)
348
349
350def _custom_metadata(stub):
351    initial_metadata_value = "test_initial_metadata_value"
352    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
353    metadata = (
354        (_INITIAL_METADATA_KEY, initial_metadata_value),
355        (_TRAILING_METADATA_KEY, trailing_metadata_value),
356    )
357
358    def _validate_metadata(response):
359        initial_metadata = dict(response.initial_metadata())
360        if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
361            raise ValueError(
362                "expected initial metadata %s, got %s"
363                % (
364                    initial_metadata_value,
365                    initial_metadata[_INITIAL_METADATA_KEY],
366                )
367            )
368        trailing_metadata = dict(response.trailing_metadata())
369        if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
370            raise ValueError(
371                "expected trailing metadata %s, got %s"
372                % (
373                    trailing_metadata_value,
374                    trailing_metadata[_TRAILING_METADATA_KEY],
375                )
376            )
377
378    # Testing with UnaryCall
379    request = messages_pb2.SimpleRequest(
380        response_type=messages_pb2.COMPRESSABLE,
381        response_size=1,
382        payload=messages_pb2.Payload(body=b"\x00"),
383    )
384    response_future = stub.UnaryCall.future(request, metadata=metadata)
385    _validate_metadata(response_future)
386
387    # Testing with FullDuplexCall
388    with _Pipe() as pipe:
389        response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
390        request = messages_pb2.StreamingOutputCallRequest(
391            response_type=messages_pb2.COMPRESSABLE,
392            response_parameters=(messages_pb2.ResponseParameters(size=1),),
393        )
394        pipe.add(request)  # Sends the request
395        next(response_iterator)  # Causes server to send trailing metadata
396    # Dropping out of the with block closes the pipe
397    _validate_metadata(response_iterator)
398
399
400def _compute_engine_creds(stub, args):
401    response = _large_unary_common_behavior(stub, True, True, None)
402    if args.default_service_account != response.username:
403        raise ValueError(
404            "expected username %s, got %s"
405            % (args.default_service_account, response.username)
406        )
407
408
409def _oauth2_auth_token(stub, args):
410    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
411    wanted_email = json.load(open(json_key_filename, "r"))["client_email"]
412    response = _large_unary_common_behavior(stub, True, True, None)
413    if wanted_email != response.username:
414        raise ValueError(
415            "expected username %s, got %s" % (wanted_email, response.username)
416        )
417    if args.oauth_scope.find(response.oauth_scope) == -1:
418        raise ValueError(
419            'expected to find oauth scope "{}" in received "{}"'.format(
420                response.oauth_scope, args.oauth_scope
421            )
422        )
423
424
425def _jwt_token_creds(stub, args):
426    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
427    wanted_email = json.load(open(json_key_filename, "r"))["client_email"]
428    response = _large_unary_common_behavior(stub, True, False, None)
429    if wanted_email != response.username:
430        raise ValueError(
431            "expected username %s, got %s" % (wanted_email, response.username)
432        )
433
434
435def _per_rpc_creds(stub, args):
436    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
437    wanted_email = json.load(open(json_key_filename, "r"))["client_email"]
438    google_credentials, unused_project_id = google_auth.default(
439        scopes=[args.oauth_scope]
440    )
441    call_credentials = grpc.metadata_call_credentials(
442        google_auth_transport_grpc.AuthMetadataPlugin(
443            credentials=google_credentials,
444            request=google_auth_transport_requests.Request(),
445        )
446    )
447    response = _large_unary_common_behavior(stub, True, False, call_credentials)
448    if wanted_email != response.username:
449        raise ValueError(
450            "expected username %s, got %s" % (wanted_email, response.username)
451        )
452
453
454def _special_status_message(stub, args):
455    details = (
456        b"\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP"
457        b" \xf0\x9f\x98\x88\t\n".decode("utf-8")
458    )
459    code = 2
460    status = grpc.StatusCode.UNKNOWN  # code = 2
461
462    # Test with a UnaryCall
463    request = messages_pb2.SimpleRequest(
464        response_type=messages_pb2.COMPRESSABLE,
465        response_size=1,
466        payload=messages_pb2.Payload(body=b"\x00"),
467        response_status=messages_pb2.EchoStatus(code=code, message=details),
468    )
469    response_future = stub.UnaryCall.future(request)
470    _validate_status_code_and_details(response_future, status, details)
471
472
473@enum.unique
474class TestCase(enum.Enum):
475    EMPTY_UNARY = "empty_unary"
476    LARGE_UNARY = "large_unary"
477    SERVER_STREAMING = "server_streaming"
478    CLIENT_STREAMING = "client_streaming"
479    PING_PONG = "ping_pong"
480    CANCEL_AFTER_BEGIN = "cancel_after_begin"
481    CANCEL_AFTER_FIRST_RESPONSE = "cancel_after_first_response"
482    EMPTY_STREAM = "empty_stream"
483    STATUS_CODE_AND_MESSAGE = "status_code_and_message"
484    UNIMPLEMENTED_METHOD = "unimplemented_method"
485    UNIMPLEMENTED_SERVICE = "unimplemented_service"
486    CUSTOM_METADATA = "custom_metadata"
487    COMPUTE_ENGINE_CREDS = "compute_engine_creds"
488    OAUTH2_AUTH_TOKEN = "oauth2_auth_token"
489    JWT_TOKEN_CREDS = "jwt_token_creds"
490    PER_RPC_CREDS = "per_rpc_creds"
491    TIMEOUT_ON_SLEEPING_SERVER = "timeout_on_sleeping_server"
492    SPECIAL_STATUS_MESSAGE = "special_status_message"
493
494    def test_interoperability(self, stub, args):
495        if self is TestCase.EMPTY_UNARY:
496            _empty_unary(stub)
497        elif self is TestCase.LARGE_UNARY:
498            _large_unary(stub)
499        elif self is TestCase.SERVER_STREAMING:
500            _server_streaming(stub)
501        elif self is TestCase.CLIENT_STREAMING:
502            _client_streaming(stub)
503        elif self is TestCase.PING_PONG:
504            _ping_pong(stub)
505        elif self is TestCase.CANCEL_AFTER_BEGIN:
506            _cancel_after_begin(stub)
507        elif self is TestCase.CANCEL_AFTER_FIRST_RESPONSE:
508            _cancel_after_first_response(stub)
509        elif self is TestCase.TIMEOUT_ON_SLEEPING_SERVER:
510            _timeout_on_sleeping_server(stub)
511        elif self is TestCase.EMPTY_STREAM:
512            _empty_stream(stub)
513        elif self is TestCase.STATUS_CODE_AND_MESSAGE:
514            _status_code_and_message(stub)
515        elif self is TestCase.UNIMPLEMENTED_METHOD:
516            _unimplemented_method(stub)
517        elif self is TestCase.UNIMPLEMENTED_SERVICE:
518            _unimplemented_service(stub)
519        elif self is TestCase.CUSTOM_METADATA:
520            _custom_metadata(stub)
521        elif self is TestCase.COMPUTE_ENGINE_CREDS:
522            _compute_engine_creds(stub, args)
523        elif self is TestCase.OAUTH2_AUTH_TOKEN:
524            _oauth2_auth_token(stub, args)
525        elif self is TestCase.JWT_TOKEN_CREDS:
526            _jwt_token_creds(stub, args)
527        elif self is TestCase.PER_RPC_CREDS:
528            _per_rpc_creds(stub, args)
529        elif self is TestCase.SPECIAL_STATUS_MESSAGE:
530            _special_status_message(stub, args)
531        else:
532            raise NotImplementedError(
533                'Test case "%s" not implemented!' % self.name
534            )
535