xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test helpers for RPC invocation tests."""
15
16import datetime
17import threading
18
19import grpc
20from grpc.framework.foundation import logging_pool
21
22from tests.unit import test_common
23from tests.unit import thread_pool
24from tests.unit.framework.common import test_constants
25from tests.unit.framework.common import test_control
26
# Deliberately asymmetric toy (de)serializers: a request is doubled on the way
# out and halved on the way in, a response is tripled and then cut to a third.
# Each deserializer undoes the matching serializer for any input.
def _SERIALIZE_REQUEST(bytestring):
    """Return the request repeated twice."""
    return bytestring * 2


def _DESERIALIZE_REQUEST(bytestring):
    """Return the second half of a serialized request."""
    midpoint = len(bytestring) // 2
    return bytestring[midpoint:]


def _SERIALIZE_RESPONSE(bytestring):
    """Return the response repeated three times."""
    return bytestring * 3


def _DESERIALIZE_RESPONSE(bytestring):
    """Return the first third of a serialized response."""
    one_third = len(bytestring) // 3
    return bytestring[:one_third]


# Fully-qualified method paths served by the generic handler in this module.
_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_UNARY_STREAM_NON_BLOCKING = "/test/UnaryStreamNonBlocking"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"
_STREAM_STREAM_NON_BLOCKING = "/test/StreamStreamNonBlocking"

# A short timeout expressed as a float number of seconds (4.0).
TIMEOUT_SHORT = datetime.timedelta(seconds=4).total_seconds()
40
41
class Callback(object):
    """One-argument callable that records its argument for a waiting reader.

    Calling the instance stores the value and wakes every thread blocked in
    ``value()``. ``value()`` blocks, without a timeout, until the instance has
    been called at least once and then returns the most recently stored value.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._value = None
        self._called = False

    def __call__(self, value):
        """Record ``value`` and notify all waiters."""
        with self._condition:
            self._called = True
            self._value = value
            self._condition.notify_all()

    def value(self):
        """Block until the callback has fired, then return the stored value."""
        with self._condition:
            # wait_for re-checks the predicate after every wake-up, which
            # guards against spurious wake-ups.
            self._condition.wait_for(lambda: self._called)
            return self._value
59
60
class _Handler(object):
    """Server-side behaviours for each RPC arity used by the tests.

    Every behaviour calls ``self._control.control()`` at each step so that the
    control object supplied by the test can intervene, and attaches the
    trailing metadata ``(("testkey", "testvalue"),)`` whenever a servicer
    context is supplied.
    """

    # Trailing metadata attached by every behaviour when a context is given.
    _TRAILING_METADATA = (("testkey", "testvalue"),)

    def __init__(self, control, thread_pool):
        self._control = control
        self._thread_pool = thread_pool
        # Bound methods do not accept attribute assignment, so the markers are
        # set on the underlying function objects, which are shared by every
        # instance of this class.
        for bound_method in (
            self.handle_unary_stream_non_blocking,
            self.handle_stream_stream_non_blocking,
        ):
            function = bound_method.__func__
            function.experimental_non_blocking = True
            function.experimental_thread_pool = self._thread_pool

    def _send_trailing_metadata(self, servicer_context):
        """Attach the fixed trailing metadata unless the context is None."""
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(self._TRAILING_METADATA)

    def handle_unary_unary(self, request, servicer_context):
        """Echo ``request`` back as the single response."""
        self._control.control()
        self._send_trailing_metadata(servicer_context)
        if servicer_context is not None:
            # TODO(https://github.com/grpc/grpc/issues/8483): test the values
            # returned by these methods rather than only "smoke" testing that
            # they return after having been called.
            servicer_context.is_active()
            servicer_context.time_remaining()
        return request

    def handle_unary_stream(self, request, servicer_context):
        """Yield ``request`` ``test_constants.STREAM_LENGTH`` times."""
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        self._send_trailing_metadata(servicer_context)

    def handle_unary_stream_non_blocking(
        self, request, servicer_context, on_next
    ):
        """Push ``request`` to ``on_next`` ``STREAM_LENGTH`` times, then None.

        The final ``on_next(None)`` is presumably the end-of-stream marker.
        """
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            on_next(request)
        self._control.control()
        self._send_trailing_metadata(servicer_context)
        on_next(None)

    def handle_stream_unary(self, request_iterator, servicer_context):
        """Return all requests concatenated into one bytes response."""
        if servicer_context is not None:
            servicer_context.invocation_metadata()
        self._control.control()
        received = []
        for request in request_iterator:
            self._control.control()
            received.append(request)
        self._control.control()
        self._send_trailing_metadata(servicer_context)
        return b"".join(received)

    def handle_stream_stream(self, request_iterator, servicer_context):
        """Yield every incoming request back as a response."""
        self._control.control()
        self._send_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            yield request
        self._control.control()

    def handle_stream_stream_non_blocking(
        self, request_iterator, servicer_context, on_next
    ):
        """Push every incoming request to ``on_next``, then push None.

        The final ``on_next(None)`` is presumably the end-of-stream marker.
        """
        self._control.control()
        self._send_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            on_next(request)
        self._control.control()
        on_next(None)
180
181
class _MethodHandler(grpc.RpcMethodHandler):
    """Plain attribute holder implementing ``grpc.RpcMethodHandler``.

    Each constructor argument is stored unchanged on the attribute of the same
    name. Callers in this module pass ``None`` for every behaviour slot except
    the one matching the method's arity.
    """

    def __init__(
        self,
        request_streaming,
        response_streaming,
        request_deserializer,
        response_serializer,
        unary_unary,
        unary_stream,
        stream_unary,
        stream_stream,
    ):
        # Arity of the method.
        self.request_streaming, self.response_streaming = (
            request_streaming,
            response_streaming,
        )
        # Wire conversion callables (may be None).
        self.request_deserializer, self.response_serializer = (
            request_deserializer,
            response_serializer,
        )
        # Behaviour slots, one per arity.
        self.unary_unary, self.unary_stream = unary_unary, unary_stream
        self.stream_unary, self.stream_stream = stream_unary, stream_stream
202
203
class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the ``/test/...`` method paths to behaviours of a ``_Handler``."""

    def __init__(self, handler):
        self._handler = handler

    def service(self, handler_call_details):
        """Return the ``_MethodHandler`` for the called method, or None.

        Args:
            handler_call_details: object whose ``method`` attribute is compared
                against the method-path constants of this module.

        Returns:
            A new ``_MethodHandler`` with exactly one behaviour slot filled, or
            ``None`` when the method path is not one of the known paths.
        """
        # Columns: path, request_streaming, response_streaming,
        # request_deserializer, response_serializer, behaviour slot,
        # name of the _Handler method that fills the slot.
        routes = (
            (
                _UNARY_UNARY,
                False,
                False,
                None,
                None,
                "unary_unary",
                "handle_unary_unary",
            ),
            (
                _UNARY_STREAM,
                False,
                True,
                _DESERIALIZE_REQUEST,
                _SERIALIZE_RESPONSE,
                "unary_stream",
                "handle_unary_stream",
            ),
            (
                _UNARY_STREAM_NON_BLOCKING,
                False,
                True,
                _DESERIALIZE_REQUEST,
                _SERIALIZE_RESPONSE,
                "unary_stream",
                "handle_unary_stream_non_blocking",
            ),
            (
                _STREAM_UNARY,
                True,
                False,
                _DESERIALIZE_REQUEST,
                _SERIALIZE_RESPONSE,
                "stream_unary",
                "handle_stream_unary",
            ),
            (
                _STREAM_STREAM,
                True,
                True,
                None,
                None,
                "stream_stream",
                "handle_stream_stream",
            ),
            (
                _STREAM_STREAM_NON_BLOCKING,
                True,
                True,
                None,
                None,
                "stream_stream",
                "handle_stream_stream_non_blocking",
            ),
        )
        called_method = handler_call_details.method
        for (
            path,
            request_streaming,
            response_streaming,
            deserializer,
            serializer,
            slot,
            behaviour_name,
        ) in routes:
            if called_method == path:
                # Every slot is None except the one for this method's arity.
                behaviours = dict.fromkeys(
                    (
                        "unary_unary",
                        "unary_stream",
                        "stream_unary",
                        "stream_stream",
                    )
                )
                behaviours[slot] = getattr(self._handler, behaviour_name)
                return _MethodHandler(
                    request_streaming,
                    response_streaming,
                    deserializer,
                    serializer,
                    **behaviours
                )
        return None
277
278
def unary_unary_multi_callable(channel):
    """Return ``channel``'s unary-unary multi-callable for ``_UNARY_UNARY``.

    No serializer or deserializer is supplied for this method.
    """
    multi_callable = channel.unary_unary(_UNARY_UNARY, _registered_method=True)
    return multi_callable
284
285
def unary_stream_multi_callable(channel):
    """Return ``channel``'s unary-stream multi-callable for ``_UNARY_STREAM``.

    Requests go through ``_SERIALIZE_REQUEST`` and responses through
    ``_DESERIALIZE_RESPONSE``.
    """
    options = {
        "request_serializer": _SERIALIZE_REQUEST,
        "response_deserializer": _DESERIALIZE_RESPONSE,
        "_registered_method": True,
    }
    return channel.unary_stream(_UNARY_STREAM, **options)
293
294
def unary_stream_non_blocking_multi_callable(channel):
    """Return the unary-stream multi-callable for the non-blocking method.

    Targets ``_UNARY_STREAM_NON_BLOCKING``; requests go through
    ``_SERIALIZE_REQUEST`` and responses through ``_DESERIALIZE_RESPONSE``.
    """
    options = {
        "request_serializer": _SERIALIZE_REQUEST,
        "response_deserializer": _DESERIALIZE_RESPONSE,
        "_registered_method": True,
    }
    return channel.unary_stream(_UNARY_STREAM_NON_BLOCKING, **options)
302
303
def stream_unary_multi_callable(channel):
    """Return ``channel``'s stream-unary multi-callable for ``_STREAM_UNARY``.

    Requests go through ``_SERIALIZE_REQUEST`` and responses through
    ``_DESERIALIZE_RESPONSE``.
    """
    options = {
        "request_serializer": _SERIALIZE_REQUEST,
        "response_deserializer": _DESERIALIZE_RESPONSE,
        "_registered_method": True,
    }
    return channel.stream_unary(_STREAM_UNARY, **options)
311
312
def stream_stream_multi_callable(channel):
    """Return ``channel``'s stream-stream multi-callable for ``_STREAM_STREAM``.

    No serializer or deserializer is supplied for this method.
    """
    multi_callable = channel.stream_stream(
        _STREAM_STREAM, _registered_method=True
    )
    return multi_callable
318
319
def stream_stream_non_blocking_multi_callable(channel):
    """Return the stream-stream multi-callable for the non-blocking method.

    Targets ``_STREAM_STREAM_NON_BLOCKING``; no serializer or deserializer is
    supplied for this method.
    """
    multi_callable = channel.stream_stream(
        _STREAM_STREAM_NON_BLOCKING, _registered_method=True
    )
    return multi_callable
325
326
class BaseRPCTest(object):
    """Shared fixture and RPC scenarios for streaming-response tests.

    The class defines ``setUp``/``tearDown`` and calls ``self.assert*``
    methods it does not define, so it is presumably meant to be mixed into a
    ``unittest.TestCase`` subclass by the concrete tests.

    Every scenario takes a ``multi_callable`` (as produced by the
    ``*_multi_callable`` factories of this module) and invokes it against the
    server started in ``setUp``.
    """

    def setUp(self):
        """Start a test server with the generic handler and open a channel."""
        self._control = test_control.PauseFailControl()
        self._thread_pool = thread_pool.RecordingThreadPool(max_workers=None)
        self._handler = _Handler(self._control, self._thread_pool)

        self._server = test_common.test_server()
        port = self._server.add_insecure_port("[::]:0")
        self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        """Stop the server, then close the channel.

        The channel is closed even when stopping the server raises, so a
        failing shutdown does not leak the channel into later tests; the
        original exception still propagates.
        """
        try:
            self._server.stop(None)
        finally:
            self._channel.close()

    def _consume_one_stream_response_unary_request(self, multi_callable):
        """Send one request and read only the first streamed response."""
        request = b"\x57\x38"

        response_iterator = multi_callable(
            request,
            metadata=(("test", "ConsumingOneStreamResponseUnaryRequest"),),
        )
        next(response_iterator)

    def _consume_some_but_not_all_stream_responses_unary_request(
        self, multi_callable
    ):
        """Send one request and read half of ``STREAM_LENGTH`` responses."""
        request = b"\x57\x38"

        response_iterator = multi_callable(
            request,
            metadata=(
                ("test", "ConsumingSomeButNotAllStreamResponsesUnaryRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_some_but_not_all_stream_responses_stream_request(
        self, multi_callable
    ):
        """Send ``STREAM_LENGTH`` requests and read half as many responses."""
        requests = tuple(
            b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        response_iterator = multi_callable(
            request_iterator,
            metadata=(
                ("test", "ConsumingSomeButNotAllStreamResponsesStreamRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_too_many_stream_responses_stream_request(self, multi_callable):
        """Drain the stream, then keep reading past its end.

        After ``STREAM_LENGTH`` successful reads every further ``next`` must
        raise ``StopIteration``, and the call must report status ``OK`` with
        non-None initial metadata, details and trailing metadata.
        """
        requests = tuple(
            b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        response_iterator = multi_callable(
            request_iterator,
            metadata=(
                ("test", "ConsumingTooManyStreamResponsesStreamRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH):
            next(response_iterator)
        for _ in range(test_constants.STREAM_LENGTH):
            with self.assertRaises(StopIteration):
                next(response_iterator)

        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.OK, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_unary_request_stream_response(self, multi_callable):
        """Cancel a unary-request call while the control is paused.

        The call is cancelled only after ``block_until_paused()`` returns;
        reading afterwards must raise ``grpc.RpcError`` with ``CANCELLED``.
        """
        request = b"\x07\x19"

        with self._control.pause():
            response_iterator = multi_callable(
                request,
                metadata=(("test", "CancelledUnaryRequestStreamResponse"),),
            )
            self._control.block_until_paused()
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(response_iterator)
        self.assertIs(
            grpc.StatusCode.CANCELLED, exception_context.exception.code()
        )
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_stream_request_stream_response(self, multi_callable):
        """Cancel a stream-request call immediately after starting it.

        Unlike the unary-request variant this does not wait for the control
        to report paused before cancelling.
        """
        requests = tuple(
            b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        with self._control.pause():
            response_iterator = multi_callable(
                request_iterator,
                metadata=(("test", "CancelledStreamRequestStreamResponse"),),
            )
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError):
            next(response_iterator)
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _expired_unary_request_stream_response(self, multi_callable):
        """Let a unary-request call time out while the control is paused.

        The call uses ``test_constants.SHORT_TIMEOUT`` and must fail with
        ``DEADLINE_EXCEEDED`` on both the exception and the call object.
        """
        request = b"\x07\x19"

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request,
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(("test", "ExpiredUnaryRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED,
            exception_context.exception.code(),
        )
        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()
        )

    def _expired_stream_request_stream_response(self, multi_callable):
        """Let a stream-request call time out while the control is paused.

        The call uses ``test_constants.SHORT_TIMEOUT`` and must fail with
        ``DEADLINE_EXCEEDED`` on both the exception and the call object.
        """
        requests = tuple(
            b"\x67\x18" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request_iterator,
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(("test", "ExpiredStreamRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED,
            exception_context.exception.code(),
        )
        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()
        )

    def _failed_unary_request_stream_response(self, multi_callable):
        """Invoke with the control set to fail and expect status ``UNKNOWN``."""
        request = b"\x37\x17"

        with self.assertRaises(grpc.RpcError) as exception_context:
            with self._control.fail():
                response_iterator = multi_callable(
                    request,
                    metadata=(("test", "FailedUnaryRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    def _failed_stream_request_stream_response(self, multi_callable):
        """Drain a stream-request call with the control set to fail.

        Both the raised ``grpc.RpcError`` and the call object must report
        status ``UNKNOWN``.
        """
        requests = tuple(
            b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        with self._control.fail():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request_iterator,
                    metadata=(("test", "FailedStreamRequestStreamResponse"),),
                )
                tuple(response_iterator)

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )
        self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())

    def _ignored_unary_stream_request_future_unary_response(
        self, multi_callable
    ):
        """Start a unary-request call and discard the returned object."""
        request = b"\x37\x17"

        multi_callable(
            request, metadata=(("test", "IgnoredUnaryRequestStreamResponse"),)
        )

    def _ignored_stream_request_stream_response(self, multi_callable):
        """Start a stream-request call and discard the returned object."""
        requests = tuple(
            b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        multi_callable(
            request_iterator,
            metadata=(("test", "IgnoredStreamRequestStreamResponse"),),
        )
544