xref: /aosp_15_r20/external/pigweed/pw_rpc/py/tests/callback_client_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests using the callback client for pw_rpc."""
16
17import unittest
18from unittest import mock
19from typing import Any
20
21from pw_protobuf_compiler import python_protos
22from pw_status import Status
23
24from pw_rpc import callback_client, client, descriptors, packets
25from pw_rpc.internal import packet_pb2
26
# Test proto: package pw.test1 with two messages and a service that declares
# one RPC of each kind (unary, server streaming, client streaming, and
# bidirectional streaming).
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

# Protobuf library built from TEST_PROTO_1; the tests take message classes and
# the modules handed to the RPC client from it.
PROTOS = python_protos.Library.from_strings(TEST_PROTO_1)
# Channel ID the test clients are configured with.
CLIENT_CHANNEL_ID: int = 489
57
58
59def _message_bytes(msg) -> bytes:
60    return msg if isinstance(msg, bytes) else msg.SerializeToString()
61
62
class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server.

    Packets sent by the client are decoded and recorded in ``self.requests``.
    Server packets are queued with the ``_enqueue_*`` helpers and delivered to
    the client either explicitly with ``_process_enqueued_packets`` or
    automatically once the client has sent ``send_responses_after_packets``
    packets.
    """

    def setUp(self) -> None:
        self._request = PROTOS.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
            PROTOS.modules(),
        )
        self._service = self._client.channel(
            CLIENT_CHANNEL_ID
        ).rpcs.pw.test1.PublicService

        # Every packet sent by the client, decoded, in the order sent.
        self.requests: list[packet_pb2.RpcPacket] = []
        # Pending (serialized packet, expected process_packet status) pairs.
        self._next_packets: list[tuple[bytes, Status]] = []
        # How many outgoing packets to wait for before delivering the pending
        # packets. A float so it can temporarily be set to infinity.
        self.send_responses_after_packets: float = 1

        # When set, any attempt by the client to send a packet raises this.
        self.output_exception: Exception | None = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recently sent packet; fails if none was sent."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_packet(self, process_status: Status, **fields: Any) -> None:
        """Queues an RpcPacket built from ``fields`` for later delivery.

        Args:
            process_status: Status ``process_packet`` must return when the
                packet is delivered.
            **fields: Field values passed to the ``RpcPacket`` constructor.
        """
        packet = packet_pb2.RpcPacket(**fields)
        self._next_packets.append((packet.SerializeToString(), process_status))

    def _enqueue_response(
        self,
        channel_id: int = CLIENT_CHANNEL_ID,
        method: descriptors.Method | None = None,
        status: Status = Status.OK,
        payload: Any = b'',
        *,
        ids: tuple[int, int] | None = None,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a RESPONSE packet.

        Exactly one of ``method`` or ``ids`` must be provided.

        Args:
            channel_id: Channel the packet is addressed to.
            method: Method whose service and method IDs are used.
            status: Status carried by the response.
            payload: Raw bytes or a protobuf message to serialize.
            ids: Explicit ``(service_id, method_id)``, e.g. for IDs that do
                not match any known method.
            process_status: Status ``process_packet`` must return.
            call_id: Call ID of the packet.
        """
        if method is not None:
            assert ids is None, 'pass either method or ids, not both'
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None, 'either method or ids is required'
            service_id, method_id = ids

        self._enqueue_packet(
            process_status,
            type=packet_pb2.PacketType.RESPONSE,
            channel_id=channel_id,
            service_id=service_id,
            method_id=method_id,
            call_id=call_id,
            status=status.value,
            payload=_message_bytes(payload),
        )

    def _enqueue_server_stream(
        self,
        channel_id: int,
        method: descriptors.Method,
        response: Any,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a SERVER_STREAM packet carrying ``response``.

        ``response`` may be raw bytes or a protobuf message.
        """
        self._enqueue_packet(
            process_status,
            type=packet_pb2.PacketType.SERVER_STREAM,
            channel_id=channel_id,
            service_id=method.service.id,
            method_id=method.id,
            call_id=call_id,
            payload=_message_bytes(response),
        )

    def _enqueue_error(
        self,
        channel_id: int,
        service,
        method,
        status: Status,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a SERVER_ERROR packet.

        ``service`` and ``method`` may each be an integer ID or an object with
        an ``id`` attribute.
        """
        self._enqueue_packet(
            process_status,
            type=packet_pb2.PacketType.SERVER_ERROR,
            channel_id=channel_id,
            service_id=service if isinstance(service, int) else service.id,
            method_id=method if isinstance(method, int) else method.id,
            call_id=call_id,
            status=status.value,
        )

    def _handle_packet(self, data: bytes) -> None:
        """Channel output: records a sent packet and may deliver replies.

        Raises ``output_exception`` if it is set. Otherwise decodes and records
        the packet, then delivers the queued packets once
        ``send_responses_after_packets`` outgoing packets have been seen.
        """
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Delivers every queued packet to the client, then clears the queue.

        Each packet's ``process_packet`` result must match the status it was
        queued with.
        """
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        try:
            # Packets queued while this loop runs are delivered too, because
            # list iteration picks up appended items.
            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            # Restore fixture state even if an assertion above fails.
            self._next_packets.clear()
            self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Parses the most recently sent packet's payload as message_type."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message
196
197
198# Disable docstring requirements for test functions.
199# pylint: disable=missing-function-docstring
200
201
class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""

    def test_callback_exceptions_suppressed(self) -> None:
        stub = self._service.SomeUnary

        self._enqueue_response(CLIENT_CHANNEL_ID, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(
                self._request(), mock.Mock(side_effect=Exception(exception_msg))
            )

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(CLIENT_CHANNEL_ID, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(
            CLIENT_CHANNEL_ID, ids=(999, method.id), process_status=Status.OK
        )
        # Bad method
        self._enqueue_response(
            CLIENT_CHANNEL_ID, ids=(service_id, 999), process_status=Status.OK
        )
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            ids=(service_id, self._service.SomeBidiStreaming.method.id),
            process_status=Status.OK,
        )

        self._enqueue_response(
            CLIENT_CHANNEL_ID, method, process_status=Status.OK
        )

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(
            999,
            service_id,
            method,
            Status.NOT_FOUND,
            process_status=Status.NOT_FOUND,
        )
        # Bad service
        self._enqueue_error(
            CLIENT_CHANNEL_ID, 999, method.id, Status.INVALID_ARGUMENT
        )
        # Bad method
        self._enqueue_error(
            CLIENT_CHANNEL_ID, service_id, 999, Status.INVALID_ARGUMENT
        )
        # For RPC not pending
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            service_id,
            self._service.SomeBidiStreaming.method.id,
            Status.NOT_FOUND,
        )

        self._process_enqueued_packets()

        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        method = self._service.SomeUnary.method

        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            method,
            Status.OK,
            b'INVALID DATA!!!',
            process_status=Status.OK,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(CLIENT_CHANNEL_ID, lambda *a, **b: None)],
            PROTOS.modules(),
        )
        rpcs = rpc_client.channel(CLIENT_CHANNEL_ID).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s, 100
        )

    def test_rpc_provides_request_type(self) -> None:
        self.assertIs(
            self._service.SomeUnary.request,
            self._service.SomeUnary.method.request_type,
        )

    def test_rpc_provides_response_type(self) -> None:
        # Previously a copy of the request-type test; check the response side.
        self.assertIs(
            self._service.SomeUnary.response,
            self._service.SomeUnary.method.response_type,
        )
344
345
class UnaryTest(_CallbackClientImplTestBase):
    """Tests for invoking a unary RPC."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeUnary
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        for _ in range(3):
            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                Status.ABORTED,
                self.method.response_type(payload='0_o'),
            )

            status, response = self._service.SomeUnary(
                self.method.request_type(magic_number=6)
            )

            self.assertEqual(
                6, self._sent_payload(self.method.request_type).magic_number
            )

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        for _ in range(3):
            callback = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=5), callback, callback
            )

            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                Status.ABORTED,
                self.method.response_type(payload='0_o'),
                call_id=call.call_id,
            )
            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='0_o')),
                    mock.call(call, Status.ABORTED),
                ]
            )

            self.assertEqual(
                5, self._sent_payload(self.method.request_type).magic_number
            )

    def test_concurrent_nonblocking_calls(self) -> None:
        # Start several calls to the same method
        callbacks_and_calls: list[
            tuple[mock.Mock, callback_client.call.Call]
        ] = []
        for _ in range(3):
            callback = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=5), callback)
            callbacks_and_calls.append((callback, call))

        # Respond only to the last call
        last_callback, last_call = callbacks_and_calls.pop()
        last_payload = self.method.response_type(payload='last payload')
        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            self.method,
            payload=last_payload,
            call_id=last_call.call_id,
        )
        self._process_enqueued_packets()

        # Assert that only the last caller received a response
        last_callback.assert_called_once_with(last_call, last_payload)
        for remaining_callback, _ in callbacks_and_calls:
            remaining_callback.assert_not_called()

        # Respond to the other callers and check for receipt
        other_payload = self.method.response_type(payload='other payload')
        for callback, call in callbacks_and_calls:
            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                payload=other_payload,
                call_id=call.call_id,
            )
            self._process_enqueued_packets()
            callback.assert_called_once_with(call, other_payload)

    def test_open(self) -> None:
        # open() must not transmit anything; any send raises this error.
        self.output_exception = IOError('this test should not send packets!')

        for packet_id in (client.OPEN_CALL_ID, 123):
            for _ in range(3):
                self._enqueue_response(
                    CLIENT_CHANNEL_ID,
                    self.method,
                    Status.ABORTED,
                    self.method.response_type(payload='0_o'),
                    call_id=packet_id,
                )

                callback = mock.Mock()
                call = self.rpc.open(callback, callback, callback)
                self.assertEqual(self.requests, [])

                self._process_enqueued_packets()

                callback.assert_has_calls(
                    [
                        mock.call(
                            call, self.method.response_type(payload='0_o')
                        ),
                        mock.call(call, Status.ABORTED),
                    ]
                )
                self.assertEqual(call.call_id, packet_id, "Adopts inbound ID")

    def test_blocking_server_error(self) -> None:
        for _ in range(3):
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.NOT_FOUND,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(
                    self.method.request_type(magic_number=6)
                )

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback
            )

            # Drop the invocation packet so last_request() below only sees
            # what cancel() sent.
            self.assertGreater(len(self.requests), 0)
            self.requests.clear()

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            self.assertEqual(
                self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR
            )
            self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        callback.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138
        )

    def test_blocking_timeout_as_argument(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIsNone(first_call.error)
        self.assertIsNone(second_call.error)

    def test_nonblocking_exception_in_callback(self) -> None:
        exception = ValueError('something went wrong! (intentionally)')

        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke(on_completed=mock.Mock(side_effect=exception))

        with self.assertRaises(RuntimeError) as context:
            call.wait()

        self.assertEqual(context.exception.__cause__, exception)

    def test_unary_response(self) -> None:
        proto = PROTOS.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.UnaryResponse(Status.ABORTED, proto)),
            '(Status.ABORTED, pw.test1.SomeMessage(magic_number=123))',
        )
        self.assertEqual(
            repr(callback_client.UnaryResponse(Status.OK, None)),
            '(Status.OK, None)',
        )

    def test_on_call_hook(self) -> None:
        hook_function = mock.Mock()

        # Rebuild the client so the hook is installed on its Impl.
        self._client = client.Client.from_modules(
            callback_client.Impl(on_call_hook=hook_function),
            [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
            PROTOS.modules(),
        )

        self._service = self._client.channel(
            CLIENT_CHANNEL_ID
        ).rpcs.pw.test1.PublicService

        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)
        self._service.SomeUnary(self.method.request_type(magic_number=6))

        hook_function.assert_called_once()
        self.assertEqual(
            hook_function.call_args[0][0].method.full_name,
            self.method.full_name,
        )
574
575
class ServerStreamingTest(_CallbackClientImplTestBase):
    """Tests for server streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeServerStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED
            )

            self.assertEqual(
                [rep1, rep2],
                self._service.SomeServerStreaming(magic_number=4).responses,
            )

            self.assertEqual(
                4, self._sent_payload(self.method.request_type).magic_number
            )

    def test_nonblocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED
            )

            callback = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=3), callback, callback
            )

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='!!!')),
                    mock.call(call, self.method.response_type(payload='?')),
                    mock.call(call, Status.ABORTED),
                ]
            )

            self.assertEqual(
                3, self._sent_payload(self.method.request_type).magic_number
            )

    def test_open(self) -> None:
        # open() must not transmit anything; any send raises this error.
        self.output_exception = IOError('this test should not send packets!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for packet_id in (client.OPEN_CALL_ID, 123):
            for _ in range(3):
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, rep1, call_id=packet_id
                )
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, rep2, call_id=packet_id
                )
                self._enqueue_response(
                    CLIENT_CHANNEL_ID,
                    self.method,
                    Status.ABORTED,
                    call_id=packet_id,
                )

                callback = mock.Mock()
                call = self.rpc.open(callback, callback, callback)
                self.assertEqual(self.requests, [])

                self._process_enqueued_packets()

                callback.assert_has_calls(
                    [
                        mock.call(
                            call, self.method.response_type(payload='!!!')
                        ),
                        mock.call(call, self.method.response_type(payload='?')),
                        mock.call(call, Status.ABORTED),
                    ]
                )
                self.assertEqual(call.call_id, packet_id, "Adopts inbound ID")

    def test_nonblocking_cancel(self) -> None:
        resp = self.method.response_type(payload='!!!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.method.response_type(payload='!!!')
        )

        callback.reset_mock()

        call.cancel()

        self.assertEqual(
            self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR
        )
        self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke(
            self._request(magic_number=3), callback, callback
        )

        callback.assert_has_calls(
            [
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, Status.OK),
            ]
        )

    def test_request_completion(self) -> None:
        resp = self.method.response_type(payload='!!!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.method.response_type(payload='!!!')
        )

        callback.reset_mock()

        call.request_completion()

        self.assertEqual(
            self.last_request().type,
            packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
        )

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138
        )

    def test_blocking_timeout(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in call:
                pass

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIsNone(first_call.error)
        self.assertIsNone(second_call.error)

    def test_nonblocking_iterate_over_count(self) -> None:
        reply = self.method.response_type(payload='!?')

        for _ in range(4):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses(count=1)), [reply])
        self.assertEqual(next(iter(call)), reply)
        self.assertEqual(list(call.get_responses(count=2)), [reply, reply])

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses()), [reply])
        self.assertEqual(list(call.get_responses()), [])
        self.assertEqual(list(call), [])
769
770
771class ClientStreamingTest(_CallbackClientImplTestBase):
772    """Tests for client streaming RPCs."""
773
774    def setUp(self) -> None:
775        super().setUp()
776        self.rpc = self._service.SomeClientStreaming
777        self.method = self.rpc.method
778
779    def test_blocking_call(self) -> None:
780        requests = [
781            self.method.request_type(magic_number=123),
782            self.method.request_type(magic_number=456),
783        ]
784
785        # Send after len(requests) and the client stream end packet.
786        self.send_responses_after_packets = 3
787        response = self.method.response_type(payload='yo')
788        self._enqueue_response(
789            CLIENT_CHANNEL_ID, self.method, Status.OK, response
790        )
791
792        results = self.rpc(requests)
793        self.assertIs(results.status, Status.OK)
794        self.assertEqual(results.response, response)
795
796    def test_blocking_server_error(self) -> None:
797        requests = [self.method.request_type(magic_number=123)]
798
799        # Send after len(requests) and the client stream end packet.
800        self._enqueue_error(
801            CLIENT_CHANNEL_ID,
802            self.method.service,
803            self.method,
804            Status.NOT_FOUND,
805        )
806
807        with self.assertRaises(callback_client.RpcError) as context:
808            self.rpc(requests)
809
810        self.assertIs(context.exception.status, Status.NOT_FOUND)
811
812    def test_nonblocking_call(self) -> None:
813        """Tests a successful client streaming RPC ended by the server."""
814        payload_1 = self.method.response_type(payload='-_-')
815
816        for _ in range(3):
817            stream = self._service.SomeClientStreaming.invoke()
818            self.assertFalse(stream.completed())
819
820            stream.send(magic_number=31)
821            self.assertIs(
822                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
823            )
824            self.assertEqual(
825                31, self._sent_payload(self.method.request_type).magic_number
826            )
827            self.assertFalse(stream.completed())
828
829            # Enqueue the server response to be sent after the next message.
830            self._enqueue_response(
831                CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1
832            )
833
834            stream.send(magic_number=32)
835            self.assertIs(
836                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
837            )
838            self.assertEqual(
839                32, self._sent_payload(self.method.request_type).magic_number
840            )
841
842            self.assertTrue(stream.completed())
843            self.assertIs(Status.OK, stream.status)
844            self.assertIsNone(stream.error)
845            self.assertEqual(payload_1, stream.response)
846
847    def test_open(self) -> None:
848        self.output_exception = IOError('this test should not send packets!')
849        payload = self.method.response_type(payload='-_-')
850
851        for packet_id in (client.OPEN_CALL_ID, 123):
852            for _ in range(3):
853                self._enqueue_response(
854                    CLIENT_CHANNEL_ID,
855                    self.method,
856                    Status.OK,
857                    payload,
858                    call_id=packet_id,
859                )
860
861                callback = mock.Mock()
862                call = self.rpc.open(callback, callback, callback)
863                self.assertEqual(self.requests, [])
864
865                self._process_enqueued_packets()
866
867                callback.assert_has_calls(
868                    [
869                        mock.call(call, payload),
870                        mock.call(call, Status.OK),
871                    ]
872                )
873                self.assertEqual(call.call_id, packet_id, "Adopts inbound ID")
874
875    def test_nonblocking_finish(self) -> None:
876        """Tests a client streaming RPC ended by the client."""
877        payload_1 = self.method.response_type(payload='-_-')
878
879        for _ in range(3):
880            stream = self._service.SomeClientStreaming.invoke()
881            self.assertFalse(stream.completed())
882
883            stream.send(magic_number=37)
884            self.assertIs(
885                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
886            )
887            self.assertEqual(
888                37, self._sent_payload(self.method.request_type).magic_number
889            )
890            self.assertFalse(stream.completed())
891
892            # Enqueue the server response to be sent after the next message.
893            self._enqueue_response(
894                CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1
895            )
896
897            stream.finish_and_wait()
898            self.assertIs(
899                packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
900                self.last_request().type,
901            )
902
903            self.assertTrue(stream.completed())
904            self.assertIs(Status.OK, stream.status)
905            self.assertIsNone(stream.error)
906            self.assertEqual(payload_1, stream.response)
907
908    def test_nonblocking_cancel(self) -> None:
909        for _ in range(3):
910            stream = self._service.SomeClientStreaming.invoke()
911            stream.send(magic_number=37)
912
913            self.assertTrue(stream.cancel())
914            self.assertIs(
915                packet_pb2.PacketType.CLIENT_ERROR, self.last_request().type
916            )
917            self.assertIs(Status.CANCELLED.value, self.last_request().status)
918            self.assertFalse(stream.cancel())
919
920            self.assertTrue(stream.completed())
921            self.assertIs(stream.error, Status.CANCELLED)
922
923    def test_nonblocking_server_error(self) -> None:
924        for _ in range(3):
925            stream = self._service.SomeClientStreaming.invoke()
926
927            self._enqueue_error(
928                CLIENT_CHANNEL_ID,
929                self.method.service,
930                self.method,
931                Status.INVALID_ARGUMENT,
932            )
933            stream.send(magic_number=2**32 - 1)
934
935            with self.assertRaises(callback_client.RpcError) as context:
936                stream.finish_and_wait()
937
938            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)
939
940    def test_nonblocking_server_error_after_stream_end(self) -> None:
941        for _ in range(3):
942            stream = self._service.SomeClientStreaming.invoke()
943
944            # Error will be sent in response to the CLIENT_REQUEST_COMPLETION
945            # packet.
946            self._enqueue_error(
947                CLIENT_CHANNEL_ID,
948                self.method.service,
949                self.method,
950                Status.INVALID_ARGUMENT,
951            )
952
953            with self.assertRaises(callback_client.RpcError) as context:
954                stream.finish_and_wait()
955
956            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)
957
958    def test_nonblocking_send_after_cancelled(self) -> None:
959        call = self._service.SomeClientStreaming.invoke()
960        self.assertTrue(call.cancel())
961
962        with self.assertRaises(callback_client.RpcError) as context:
963            call.send(payload='hello')
964
965        self.assertIs(context.exception.status, Status.CANCELLED)
966
967    def test_nonblocking_finish_after_completed(self) -> None:
968        reply = self.method.response_type(payload='!?')
969        self._enqueue_response(
970            CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE, reply
971        )
972
973        call = self.rpc.invoke()
974        result = call.finish_and_wait()
975        self.assertEqual(result.response, reply)
976
977        self.assertEqual(result, call.finish_and_wait())
978        self.assertEqual(result, call.finish_and_wait())
979
980    def test_nonblocking_finish_after_error(self) -> None:
981        self._enqueue_error(
982            CLIENT_CHANNEL_ID,
983            self.method.service,
984            self.method,
985            Status.UNAVAILABLE,
986        )
987
988        call = self.rpc.invoke()
989
990        for _ in range(3):
991            with self.assertRaises(callback_client.RpcError) as context:
992                call.finish_and_wait()
993
994            self.assertIs(context.exception.status, Status.UNAVAILABLE)
995            self.assertIs(call.error, Status.UNAVAILABLE)
996            self.assertIsNone(call.response)
997
998    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
999        first_call = self.rpc.invoke()
1000        self.assertFalse(first_call.completed())
1001
1002        second_call = self.rpc.invoke()
1003
1004        self.assertIs(first_call.error, None)
1005        self.assertIs(second_call.error, None)
1006
1007
class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        # Every test in this class exercises the SomeBidiStreaming method.
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """Tests a blocking call that the server ends with NOT_FOUND."""
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        """Tests that a server error is raised from a blocking call."""
        requests = [self.method.request_type(magic_number=123)]

        # Unlike test_blocking_call, this test does not override
        # send_responses_after_packets, so the error is delivered according
        # to the fixture's default.
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.NOT_FOUND,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            responses: list = []
            # Bind `responses` as a default so each iteration's lambda
            # appends to its own list.
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            self.assertIs(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                55, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)

            stream.send(magic_number=66)
            self.assertIs(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                66, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        """Tests that open() receives packets without sending any.

        The opened call must adopt the call ID carried by inbound packets,
        whether that is the open-call ID or an explicit ID.
        """
        self.output_exception = IOError('this test should not send packets!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for packet_id in (client.OPEN_CALL_ID, 123):
            for _ in range(3):
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, rep1, call_id=packet_id
                )
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, rep2, call_id=packet_id
                )
                self._enqueue_response(
                    CLIENT_CHANNEL_ID, self.method, Status.OK, call_id=packet_id
                )

                # One mock serves as the next, completed and error callback.
                callback = mock.Mock()
                call = self.rpc.open(callback, callback, callback)
                self.assertEqual(self.requests, [])

                self._process_enqueued_packets()

                callback.assert_has_calls(
                    [
                        mock.call(call, rep1),
                        mock.call(call, rep2),
                        mock.call(call, Status.OK),
                    ]
                )
                self.assertEqual(call.call_id, packet_id, "Adopts inbound ID")

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, default_response) -> None:
        """Tests that stream responses reach the default response callback.

        invoke() is called without callbacks, so the server's stream packet
        should be delivered to Call._default_response, which is patched here.
        """
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        self._service.SomeBidiStreaming.invoke()

        default_response.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        """Tests a server error that arrives after some stream responses."""
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.OUT_OF_RANGE,
            )

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            # An errored call has no status, only an error.
            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        """Tests an error that answers the client's end-of-stream packet."""
        for _ in range(3):
            stream = self._service.SomeBidiStreaming.invoke()

            # Error will be sent in response to the CLIENT_REQUEST_COMPLETION
            # packet.
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        """Tests that sending on a cancelled call raises a CANCELLED error.

        The request uses a field that exists on SomeMessage, so the test
        checks only the cancelled-state handling.
        """
        call = self._service.SomeBidiStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(magic_number=1)

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        """Tests that finish_and_wait() is repeatable after completion."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE
        )

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        """Tests that finish_and_wait() keeps raising after a server error.

        Responses received before the error remain available on the call.
        """
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.UNAVAILABLE,
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(list(call.responses), [reply])

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        """Tests that a second invoke() does not fail the first call."""
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIsNone(first_call.error)
        self.assertIsNone(second_call.error)

    def test_max_responses(self) -> None:
        """Tests that max_responses limits stored, not delivered, responses."""
        rep1 = self.method.response_type(payload='a')
        rep2 = self.method.response_type(payload='b')
        rep3 = self.method.response_type(payload='c')
        rep4 = self.method.response_type(payload='d')
        rep5 = self.method.response_type(payload='e')

        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep3)
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep4)
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep5)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        responses: list = []
        call = self.rpc.invoke(
            on_next=lambda _, res, responses=responses: responses.append(res),
            max_responses=4,
        )
        result = call.finish_and_wait()

        # All 5 responses are received, but only the most recent 4 are stored
        # in the call.
        self.assertEqual(responses, [rep1, rep2, rep3, rep4, rep5])
        self.assertEqual(result.responses, [rep2, rep3, rep4, rep5])
        self.assertEqual(result.responses, list(call.responses))

    def test_stream_response(self) -> None:
        """Tests the repr() of StreamResponse."""
        proto = PROTOS.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.ABORTED, [proto] * 2)),
            '(Status.ABORTED, [pw.test1.SomeMessage(magic_number=123), '
            'pw.test1.SomeMessage(magic_number=123)])',
        )
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.OK, [])),
            '(Status.OK, [])',
        )
1276
1277
if __name__ == '__main__':
    # Run all test cases in this module when executed as a script.
    unittest.main()
1280