1#!/usr/bin/env python3 2# Copyright 2021 The Pigweed Authors 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); you may not 5# use this file except in compliance with the License. You may obtain a copy of 6# the License at 7# 8# https://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13# License for the specific language governing permissions and limitations under 14# the License. 15"""Tests using the callback client for pw_rpc.""" 16 17import unittest 18from unittest import mock 19from typing import Any 20 21from pw_protobuf_compiler import python_protos 22from pw_status import Status 23 24from pw_rpc import callback_client, client, descriptors, packets 25from pw_rpc.internal import packet_pb2 26 27TEST_PROTO_1 = """\ 28syntax = "proto3"; 29 30package pw.test1; 31 32message SomeMessage { 33 uint32 magic_number = 1; 34} 35 36message AnotherMessage { 37 enum Result { 38 FAILED = 0; 39 FAILED_MISERABLY = 1; 40 I_DONT_WANT_TO_TALK_ABOUT_IT = 2; 41 } 42 43 Result result = 1; 44 string payload = 2; 45} 46 47service PublicService { 48 rpc SomeUnary(SomeMessage) returns (AnotherMessage) {} 49 rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {} 50 rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {} 51 rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {} 52} 53""" 54 55PROTOS = python_protos.Library.from_strings(TEST_PROTO_1) 56CLIENT_CHANNEL_ID: int = 489 57 58 59def _message_bytes(msg) -> bytes: 60 return msg if isinstance(msg, bytes) else msg.SerializeToString() 61 62 63class _CallbackClientImplTestBase(unittest.TestCase): 64 """Supports writing tests that require responses from an RPC server.""" 65 66 def setUp(self) -> None: 67 self._request = PROTOS.packages.pw.test1.SomeMessage 68 69 self._client = client.Client.from_modules( 70 callback_client.Impl(), 71 [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)], 72 PROTOS.modules(), 73 ) 74 self._service = self._client.channel( 75 CLIENT_CHANNEL_ID 76 ).rpcs.pw.test1.PublicService 77 78 self.requests: list[packet_pb2.RpcPacket] = [] 79 self._next_packets: list[tuple[bytes, Status]] = [] 80 self.send_responses_after_packets: float = 1 81 82 self.output_exception: Exception | None = None 83 84 def last_request(self) -> packet_pb2.RpcPacket: 85 assert self.requests 86 return self.requests[-1] 87 88 def _enqueue_response( 89 self, 90 channel_id: int = CLIENT_CHANNEL_ID, 91 method: descriptors.Method | None = None, 92 status: Status = Status.OK, 93 payload: bytes = b'', 94 *, 95 ids: tuple[int, int] | None = None, 96 process_status: Status = Status.OK, 97 call_id: int = client.OPEN_CALL_ID, 98 ) -> None: 99 if method: 100 assert ids is None 101 service_id, method_id = method.service.id, method.id 102 else: 103 assert ids is not None and method is None 104 service_id, method_id = ids 105 106 self._next_packets.append( 107 ( 108 packet_pb2.RpcPacket( 109 type=packet_pb2.PacketType.RESPONSE, 110 channel_id=channel_id, 111 service_id=service_id, 112 method_id=method_id, 113 call_id=call_id, 114 status=status.value, 115 payload=_message_bytes(payload), 116 ).SerializeToString(), 117 process_status, 118 ) 119 ) 120 121 def _enqueue_server_stream( 122 self, 123 channel_id: int, 124 method, 125 response, 126 process_status=Status.OK, 127 call_id: int = client.OPEN_CALL_ID, 128 ) -> None: 129 self._next_packets.append( 130 ( 131 packet_pb2.RpcPacket( 132 type=packet_pb2.PacketType.SERVER_STREAM, 133 channel_id=channel_id, 134 service_id=method.service.id, 135 method_id=method.id, 136 call_id=call_id, 137 payload=_message_bytes(response), 138 ).SerializeToString(), 139 process_status, 140 ) 141 ) 142 143 def _enqueue_error( 144 self, 145 channel_id: int, 146 service, 147 method, 148 status: Status, 149 process_status=Status.OK, 150 call_id: int = client.OPEN_CALL_ID, 151 ) -> None: 152 self._next_packets.append( 153 ( 154 packet_pb2.RpcPacket( 155 type=packet_pb2.PacketType.SERVER_ERROR, 156 channel_id=channel_id, 157 service_id=service 158 if isinstance(service, int) 159 else service.id, 160 method_id=method if isinstance(method, int) else method.id, 161 call_id=call_id, 162 status=status.value, 163 ).SerializeToString(), 164 process_status, 165 ) 166 ) 167 168 def _handle_packet(self, data: bytes) -> None: 169 if self.output_exception: 170 raise self.output_exception # pylint: disable=raising-bad-type 171 172 self.requests.append(packets.decode(data)) 173 174 if self.send_responses_after_packets > 1: 175 self.send_responses_after_packets -= 1 176 return 177 178 self._process_enqueued_packets() 179 180 def _process_enqueued_packets(self) -> None: 181 # Set send_responses_after_packets to infinity to prevent potential 182 # infinite recursion when a packet causes another packet to send. 183 send_after_count = self.send_responses_after_packets 184 self.send_responses_after_packets = float('inf') 185 186 for packet, status in self._next_packets: 187 self.assertIs(status, self._client.process_packet(packet)) 188 189 self._next_packets.clear() 190 self.send_responses_after_packets = send_after_count 191 192 def _sent_payload(self, message_type: type) -> Any: 193 message = message_type() 194 message.ParseFromString(self.last_request().payload) 195 return message 196 197 198# Disable docstring requirements for test functions. 199# pylint: disable=missing-function-docstring 200 201 202class CallbackClientImplTest(_CallbackClientImplTestBase): 203 """Tests the callback_client.Impl client implementation.""" 204 205 def test_callback_exceptions_suppressed(self) -> None: 206 stub = self._service.SomeUnary 207 208 self._enqueue_response(CLIENT_CHANNEL_ID, stub.method) 209 exception_msg = 'YOU BROKE IT O-]-<' 210 211 with self.assertLogs(callback_client.__package__, 'ERROR') as logs: 212 stub.invoke( 213 self._request(), mock.Mock(side_effect=Exception(exception_msg)) 214 ) 215 216 self.assertIn(exception_msg, ''.join(logs.output)) 217 218 # Make sure we can still invoke the RPC. 219 self._enqueue_response(CLIENT_CHANNEL_ID, stub.method, Status.UNKNOWN) 220 status, _ = stub() 221 self.assertIs(status, Status.UNKNOWN) 222 223 def test_ignore_bad_packets_with_pending_rpc(self) -> None: 224 method = self._service.SomeUnary.method 225 service_id = method.service.id 226 227 # Unknown channel 228 self._enqueue_response(999, method, process_status=Status.NOT_FOUND) 229 # Bad service 230 self._enqueue_response( 231 CLIENT_CHANNEL_ID, ids=(999, method.id), process_status=Status.OK 232 ) 233 # Bad method 234 self._enqueue_response( 235 CLIENT_CHANNEL_ID, ids=(service_id, 999), process_status=Status.OK 236 ) 237 # For RPC not pending (is Status.OK because the packet is processed) 238 self._enqueue_response( 239 CLIENT_CHANNEL_ID, 240 ids=(service_id, self._service.SomeBidiStreaming.method.id), 241 process_status=Status.OK, 242 ) 243 244 self._enqueue_response( 245 CLIENT_CHANNEL_ID, method, process_status=Status.OK 246 ) 247 248 status, response = self._service.SomeUnary(magic_number=6) 249 self.assertIs(Status.OK, status) 250 self.assertEqual('', response.payload) 251 252 def test_server_error_for_unknown_call_sends_no_errors(self) -> None: 253 method = self._service.SomeUnary.method 254 service_id = method.service.id 255 256 # Unknown channel 257 self._enqueue_error( 258 999, 259 service_id, 260 method, 261 Status.NOT_FOUND, 262 process_status=Status.NOT_FOUND, 263 ) 264 # Bad service 265 self._enqueue_error( 266 CLIENT_CHANNEL_ID, 999, method.id, Status.INVALID_ARGUMENT 267 ) 268 # Bad method 269 self._enqueue_error( 270 CLIENT_CHANNEL_ID, service_id, 999, Status.INVALID_ARGUMENT 271 ) 272 # For RPC not pending 273 self._enqueue_error( 274 CLIENT_CHANNEL_ID, 275 service_id, 276 self._service.SomeBidiStreaming.method.id, 277 Status.NOT_FOUND, 278 ) 279 280 self._process_enqueued_packets() 281 282 self.assertEqual(self.requests, []) 283 284 def test_exception_if_payload_fails_to_decode(self) -> None: 285 method = self._service.SomeUnary.method 286 287 self._enqueue_response( 288 CLIENT_CHANNEL_ID, 289 method, 290 Status.OK, 291 b'INVALID DATA!!!', 292 process_status=Status.OK, 293 ) 294 295 with self.assertRaises(callback_client.RpcError) as context: 296 self._service.SomeUnary(magic_number=6) 297 298 self.assertIs(context.exception.status, Status.DATA_LOSS) 299 300 def test_rpc_help_contains_method_name(self) -> None: 301 rpc = self._service.SomeUnary 302 self.assertIn(rpc.method.full_name, rpc.help()) 303 304 def test_default_timeouts_set_on_impl(self) -> None: 305 impl = callback_client.Impl(None, 1.5) 306 307 self.assertEqual(impl.default_unary_timeout_s, None) 308 self.assertEqual(impl.default_stream_timeout_s, 1.5) 309 310 def test_default_timeouts_set_for_all_rpcs(self) -> None: 311 rpc_client = client.Client.from_modules( 312 callback_client.Impl(99, 100), 313 [client.Channel(CLIENT_CHANNEL_ID, lambda *a, **b: None)], 314 PROTOS.modules(), 315 ) 316 rpcs = rpc_client.channel(CLIENT_CHANNEL_ID).rpcs 317 318 self.assertEqual( 319 rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99 320 ) 321 self.assertEqual( 322 rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s, 323 100, 324 ) 325 self.assertEqual( 326 rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s, 327 99, 328 ) 329 self.assertEqual( 330 rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s, 100 331 ) 332 333 def test_rpc_provides_request_type(self) -> None: 334 self.assertIs( 335 self._service.SomeUnary.request, 336 self._service.SomeUnary.method.request_type, 337 ) 338 339 def test_rpc_provides_response_type(self) -> None: 340 self.assertIs( 341 self._service.SomeUnary.request, 342 self._service.SomeUnary.method.request_type, 343 ) 344 345 346class UnaryTest(_CallbackClientImplTestBase): 347 """Tests for invoking a unary RPC.""" 348 349 def setUp(self) -> None: 350 super().setUp() 351 self.rpc = self._service.SomeUnary 352 self.method = self.rpc.method 353 354 def test_blocking_call(self) -> None: 355 for _ in range(3): 356 self._enqueue_response( 357 CLIENT_CHANNEL_ID, 358 self.method, 359 Status.ABORTED, 360 self.method.response_type(payload='0_o'), 361 ) 362 363 status, response = self._service.SomeUnary( 364 self.method.request_type(magic_number=6) 365 ) 366 367 self.assertEqual( 368 6, self._sent_payload(self.method.request_type).magic_number 369 ) 370 371 self.assertIs(Status.ABORTED, status) 372 self.assertEqual('0_o', response.payload) 373 374 def test_nonblocking_call(self) -> None: 375 for _ in range(3): 376 callback = mock.Mock() 377 call = self.rpc.invoke( 378 self._request(magic_number=5), callback, callback 379 ) 380 381 self._enqueue_response( 382 CLIENT_CHANNEL_ID, 383 self.method, 384 Status.ABORTED, 385 self.method.response_type(payload='0_o'), 386 call_id=call.call_id, 387 ) 388 self._process_enqueued_packets() 389 390 callback.assert_has_calls( 391 [ 392 mock.call(call, self.method.response_type(payload='0_o')), 393 mock.call(call, Status.ABORTED), 394 ] 395 ) 396 397 self.assertEqual( 398 5, self._sent_payload(self.method.request_type).magic_number 399 ) 400 401 def test_concurrent_nonblocking_calls(self) -> None: 402 # Start several calls to the same method 403 callbacks_and_calls: list[ 404 tuple[mock.Mock, callback_client.call.Call] 405 ] = [] 406 for _ in range(3): 407 callback = mock.Mock() 408 call = self.rpc.invoke(self._request(magic_number=5), callback) 409 callbacks_and_calls.append((callback, call)) 410 411 # Respond only to the last call 412 last_callback, last_call = callbacks_and_calls.pop() 413 last_payload = self.method.response_type(payload='last payload') 414 self._enqueue_response( 415 CLIENT_CHANNEL_ID, 416 self.method, 417 payload=last_payload, 418 call_id=last_call.call_id, 419 ) 420 self._process_enqueued_packets() 421 422 # Assert that only the last caller received a response 423 last_callback.assert_called_once_with(last_call, last_payload) 424 for remaining_callback, _ in callbacks_and_calls: 425 remaining_callback.assert_not_called() 426 427 # Respond to the other callers and check for receipt 428 other_payload = self.method.response_type(payload='other payload') 429 for callback, call in callbacks_and_calls: 430 self._enqueue_response( 431 CLIENT_CHANNEL_ID, 432 self.method, 433 payload=other_payload, 434 call_id=call.call_id, 435 ) 436 self._process_enqueued_packets() 437 callback.assert_called_once_with(call, other_payload) 438 439 def test_open(self) -> None: 440 self.output_exception = IOError('this test should not send packets!') 441 442 for packet_id in (client.OPEN_CALL_ID, 123): 443 for _ in range(3): 444 self._enqueue_response( 445 CLIENT_CHANNEL_ID, 446 self.method, 447 Status.ABORTED, 448 self.method.response_type(payload='0_o'), 449 call_id=packet_id, 450 ) 451 452 callback = mock.Mock() 453 call = self.rpc.open(callback, callback, callback) 454 self.assertEqual(self.requests, []) 455 456 self._process_enqueued_packets() 457 458 callback.assert_has_calls( 459 [ 460 mock.call( 461 call, self.method.response_type(payload='0_o') 462 ), 463 mock.call(call, Status.ABORTED), 464 ] 465 ) 466 self.assertEqual(call.call_id, packet_id, "Adopts inbound ID") 467 468 def test_blocking_server_error(self) -> None: 469 for _ in range(3): 470 self._enqueue_error( 471 CLIENT_CHANNEL_ID, 472 self.method.service, 473 self.method, 474 Status.NOT_FOUND, 475 ) 476 477 with self.assertRaises(callback_client.RpcError) as context: 478 self._service.SomeUnary( 479 self.method.request_type(magic_number=6) 480 ) 481 482 self.assertIs(context.exception.status, Status.NOT_FOUND) 483 484 def test_nonblocking_cancel(self) -> None: 485 callback = mock.Mock() 486 487 for _ in range(3): 488 call = self._service.SomeUnary.invoke( 489 self._request(magic_number=55), callback 490 ) 491 492 self.assertGreater(len(self.requests), 0) 493 self.requests.clear() 494 495 self.assertTrue(call.cancel()) 496 self.assertFalse(call.cancel()) # Already cancelled, returns False 497 498 self.assertEqual( 499 self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR 500 ) 501 self.assertEqual(self.last_request().status, Status.CANCELLED.value) 502 503 callback.assert_not_called() 504 505 def test_nonblocking_with_request_args(self) -> None: 506 self.rpc.invoke(request_args=dict(magic_number=1138)) 507 self.assertEqual( 508 self._sent_payload(self.rpc.request).magic_number, 1138 509 ) 510 511 def test_blocking_timeout_as_argument(self) -> None: 512 with self.assertRaises(callback_client.RpcTimeout): 513 self._service.SomeUnary(pw_rpc_timeout_s=0.0001) 514 515 def test_blocking_timeout_set_default(self) -> None: 516 self._service.SomeUnary.default_timeout_s = 0.0001 517 518 with self.assertRaises(callback_client.RpcTimeout): 519 self._service.SomeUnary() 520 521 def test_nonblocking_duplicate_calls_not_cancelled(self) -> None: 522 first_call = self.rpc.invoke() 523 self.assertFalse(first_call.completed()) 524 525 second_call = self.rpc.invoke() 526 527 self.assertIs(first_call.error, None) 528 self.assertIs(second_call.error, None) 529 530 def test_nonblocking_exception_in_callback(self) -> None: 531 exception = ValueError('something went wrong! (intentionally)') 532 533 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 534 535 call = self.rpc.invoke(on_completed=mock.Mock(side_effect=exception)) 536 537 with self.assertRaises(RuntimeError) as context: 538 call.wait() 539 540 self.assertEqual(context.exception.__cause__, exception) 541 542 def test_unary_response(self) -> None: 543 proto = PROTOS.packages.pw.test1.SomeMessage(magic_number=123) 544 self.assertEqual( 545 repr(callback_client.UnaryResponse(Status.ABORTED, proto)), 546 '(Status.ABORTED, pw.test1.SomeMessage(magic_number=123))', 547 ) 548 self.assertEqual( 549 repr(callback_client.UnaryResponse(Status.OK, None)), 550 '(Status.OK, None)', 551 ) 552 553 def test_on_call_hook(self) -> None: 554 hook_function = mock.Mock() 555 556 self._client = client.Client.from_modules( 557 callback_client.Impl(on_call_hook=hook_function), 558 [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)], 559 PROTOS.modules(), 560 ) 561 562 self._service = self._client.channel( 563 CLIENT_CHANNEL_ID 564 ).rpcs.pw.test1.PublicService 565 566 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 567 self._service.SomeUnary(self.method.request_type(magic_number=6)) 568 569 hook_function.assert_called_once() 570 self.assertEqual( 571 hook_function.call_args[0][0].method.full_name, 572 self.method.full_name, 573 ) 574 575 576class ServerStreamingTest(_CallbackClientImplTestBase): 577 """Tests for server streaming RPCs.""" 578 579 def setUp(self) -> None: 580 super().setUp() 581 self.rpc = self._service.SomeServerStreaming 582 self.method = self.rpc.method 583 584 def test_blocking_call(self) -> None: 585 rep1 = self.method.response_type(payload='!!!') 586 rep2 = self.method.response_type(payload='?') 587 588 for _ in range(3): 589 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 590 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2) 591 self._enqueue_response( 592 CLIENT_CHANNEL_ID, self.method, Status.ABORTED 593 ) 594 595 self.assertEqual( 596 [rep1, rep2], 597 self._service.SomeServerStreaming(magic_number=4).responses, 598 ) 599 600 self.assertEqual( 601 4, self._sent_payload(self.method.request_type).magic_number 602 ) 603 604 def test_nonblocking_call(self) -> None: 605 rep1 = self.method.response_type(payload='!!!') 606 rep2 = self.method.response_type(payload='?') 607 608 for _ in range(3): 609 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 610 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2) 611 self._enqueue_response( 612 CLIENT_CHANNEL_ID, self.method, Status.ABORTED 613 ) 614 615 callback = mock.Mock() 616 call = self.rpc.invoke( 617 self._request(magic_number=3), callback, callback 618 ) 619 620 callback.assert_has_calls( 621 [ 622 mock.call(call, self.method.response_type(payload='!!!')), 623 mock.call(call, self.method.response_type(payload='?')), 624 mock.call(call, Status.ABORTED), 625 ] 626 ) 627 628 self.assertEqual( 629 3, self._sent_payload(self.method.request_type).magic_number 630 ) 631 632 def test_open(self) -> None: 633 self.output_exception = IOError('this test should not send packets!') 634 rep1 = self.method.response_type(payload='!!!') 635 rep2 = self.method.response_type(payload='?') 636 637 for packet_id in (client.OPEN_CALL_ID, 123): 638 for _ in range(3): 639 self._enqueue_server_stream( 640 CLIENT_CHANNEL_ID, self.method, rep1, call_id=packet_id 641 ) 642 self._enqueue_server_stream( 643 CLIENT_CHANNEL_ID, self.method, rep2, call_id=packet_id 644 ) 645 self._enqueue_response( 646 CLIENT_CHANNEL_ID, 647 self.method, 648 Status.ABORTED, 649 call_id=packet_id, 650 ) 651 652 callback = mock.Mock() 653 call = self.rpc.open(callback, callback, callback) 654 self.assertEqual(self.requests, []) 655 656 self._process_enqueued_packets() 657 658 callback.assert_has_calls( 659 [ 660 mock.call( 661 call, self.method.response_type(payload='!!!') 662 ), 663 mock.call(call, self.method.response_type(payload='?')), 664 mock.call(call, Status.ABORTED), 665 ] 666 ) 667 self.assertEqual(call.call_id, packet_id, "Adopts inbound ID") 668 669 def test_nonblocking_cancel(self) -> None: 670 resp = self.rpc.method.response_type(payload='!!!') 671 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.rpc.method, resp) 672 673 callback = mock.Mock() 674 call = self.rpc.invoke(self._request(magic_number=3), callback) 675 callback.assert_called_once_with( 676 call, self.rpc.method.response_type(payload='!!!') 677 ) 678 679 callback.reset_mock() 680 681 call.cancel() 682 683 self.assertEqual( 684 self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR 685 ) 686 self.assertEqual(self.last_request().status, Status.CANCELLED.value) 687 688 # Ensure the RPC can be called after being cancelled. 689 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp) 690 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 691 692 call = self.rpc.invoke( 693 self._request(magic_number=3), callback, callback 694 ) 695 696 callback.assert_has_calls( 697 [ 698 mock.call(call, self.method.response_type(payload='!!!')), 699 mock.call(call, Status.OK), 700 ] 701 ) 702 703 def test_request_completion(self) -> None: 704 resp = self.rpc.method.response_type(payload='!!!') 705 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.rpc.method, resp) 706 707 callback = mock.Mock() 708 call = self.rpc.invoke(self._request(magic_number=3), callback) 709 callback.assert_called_once_with( 710 call, self.rpc.method.response_type(payload='!!!') 711 ) 712 713 callback.reset_mock() 714 715 call.request_completion() 716 717 self.assertEqual( 718 self.last_request().type, 719 packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION, 720 ) 721 722 def test_nonblocking_with_request_args(self) -> None: 723 self.rpc.invoke(request_args=dict(magic_number=1138)) 724 self.assertEqual( 725 self._sent_payload(self.rpc.request).magic_number, 1138 726 ) 727 728 def test_blocking_timeout(self) -> None: 729 with self.assertRaises(callback_client.RpcTimeout): 730 self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001) 731 732 def test_nonblocking_iteration_timeout(self) -> None: 733 call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001) 734 with self.assertRaises(callback_client.RpcTimeout): 735 for _ in call: 736 pass 737 738 def test_nonblocking_duplicate_calls_not_cancelled(self) -> None: 739 first_call = self.rpc.invoke() 740 self.assertFalse(first_call.completed()) 741 742 second_call = self.rpc.invoke() 743 744 self.assertIs(first_call.error, None) 745 self.assertIs(second_call.error, None) 746 747 def test_nonblocking_iterate_over_count(self) -> None: 748 reply = self.method.response_type(payload='!?') 749 750 for _ in range(4): 751 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply) 752 753 call = self.rpc.invoke() 754 755 self.assertEqual(list(call.get_responses(count=1)), [reply]) 756 self.assertEqual(next(iter(call)), reply) 757 self.assertEqual(list(call.get_responses(count=2)), [reply, reply]) 758 759 def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None: 760 reply = self.method.response_type(payload='!?') 761 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply) 762 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 763 764 call = self.rpc.invoke() 765 766 self.assertEqual(list(call.get_responses()), [reply]) 767 self.assertEqual(list(call.get_responses()), []) 768 self.assertEqual(list(call), []) 769 770 771class ClientStreamingTest(_CallbackClientImplTestBase): 772 """Tests for client streaming RPCs.""" 773 774 def setUp(self) -> None: 775 super().setUp() 776 self.rpc = self._service.SomeClientStreaming 777 self.method = self.rpc.method 778 779 def test_blocking_call(self) -> None: 780 requests = [ 781 self.method.request_type(magic_number=123), 782 self.method.request_type(magic_number=456), 783 ] 784 785 # Send after len(requests) and the client stream end packet. 786 self.send_responses_after_packets = 3 787 response = self.method.response_type(payload='yo') 788 self._enqueue_response( 789 CLIENT_CHANNEL_ID, self.method, Status.OK, response 790 ) 791 792 results = self.rpc(requests) 793 self.assertIs(results.status, Status.OK) 794 self.assertEqual(results.response, response) 795 796 def test_blocking_server_error(self) -> None: 797 requests = [self.method.request_type(magic_number=123)] 798 799 # Send after len(requests) and the client stream end packet. 800 self._enqueue_error( 801 CLIENT_CHANNEL_ID, 802 self.method.service, 803 self.method, 804 Status.NOT_FOUND, 805 ) 806 807 with self.assertRaises(callback_client.RpcError) as context: 808 self.rpc(requests) 809 810 self.assertIs(context.exception.status, Status.NOT_FOUND) 811 812 def test_nonblocking_call(self) -> None: 813 """Tests a successful client streaming RPC ended by the server.""" 814 payload_1 = self.method.response_type(payload='-_-') 815 816 for _ in range(3): 817 stream = self._service.SomeClientStreaming.invoke() 818 self.assertFalse(stream.completed()) 819 820 stream.send(magic_number=31) 821 self.assertIs( 822 packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type 823 ) 824 self.assertEqual( 825 31, self._sent_payload(self.method.request_type).magic_number 826 ) 827 self.assertFalse(stream.completed()) 828 829 # Enqueue the server response to be sent after the next message. 830 self._enqueue_response( 831 CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1 832 ) 833 834 stream.send(magic_number=32) 835 self.assertIs( 836 packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type 837 ) 838 self.assertEqual( 839 32, self._sent_payload(self.method.request_type).magic_number 840 ) 841 842 self.assertTrue(stream.completed()) 843 self.assertIs(Status.OK, stream.status) 844 self.assertIsNone(stream.error) 845 self.assertEqual(payload_1, stream.response) 846 847 def test_open(self) -> None: 848 self.output_exception = IOError('this test should not send packets!') 849 payload = self.method.response_type(payload='-_-') 850 851 for packet_id in (client.OPEN_CALL_ID, 123): 852 for _ in range(3): 853 self._enqueue_response( 854 CLIENT_CHANNEL_ID, 855 self.method, 856 Status.OK, 857 payload, 858 call_id=packet_id, 859 ) 860 861 callback = mock.Mock() 862 call = self.rpc.open(callback, callback, callback) 863 self.assertEqual(self.requests, []) 864 865 self._process_enqueued_packets() 866 867 callback.assert_has_calls( 868 [ 869 mock.call(call, payload), 870 mock.call(call, Status.OK), 871 ] 872 ) 873 self.assertEqual(call.call_id, packet_id, "Adopts inbound ID") 874 875 def test_nonblocking_finish(self) -> None: 876 """Tests a client streaming RPC ended by the client.""" 877 payload_1 = self.method.response_type(payload='-_-') 878 879 for _ in range(3): 880 stream = self._service.SomeClientStreaming.invoke() 881 self.assertFalse(stream.completed()) 882 883 stream.send(magic_number=37) 884 self.assertIs( 885 packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type 886 ) 887 self.assertEqual( 888 37, self._sent_payload(self.method.request_type).magic_number 889 ) 890 self.assertFalse(stream.completed()) 891 892 # Enqueue the server response to be sent after the next message. 893 self._enqueue_response( 894 CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1 895 ) 896 897 stream.finish_and_wait() 898 self.assertIs( 899 packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION, 900 self.last_request().type, 901 ) 902 903 self.assertTrue(stream.completed()) 904 self.assertIs(Status.OK, stream.status) 905 self.assertIsNone(stream.error) 906 self.assertEqual(payload_1, stream.response) 907 908 def test_nonblocking_cancel(self) -> None: 909 for _ in range(3): 910 stream = self._service.SomeClientStreaming.invoke() 911 stream.send(magic_number=37) 912 913 self.assertTrue(stream.cancel()) 914 self.assertIs( 915 packet_pb2.PacketType.CLIENT_ERROR, self.last_request().type 916 ) 917 self.assertIs(Status.CANCELLED.value, self.last_request().status) 918 self.assertFalse(stream.cancel()) 919 920 self.assertTrue(stream.completed()) 921 self.assertIs(stream.error, Status.CANCELLED) 922 923 def test_nonblocking_server_error(self) -> None: 924 for _ in range(3): 925 stream = self._service.SomeClientStreaming.invoke() 926 927 self._enqueue_error( 928 CLIENT_CHANNEL_ID, 929 self.method.service, 930 self.method, 931 Status.INVALID_ARGUMENT, 932 ) 933 stream.send(magic_number=2**32 - 1) 934 935 with self.assertRaises(callback_client.RpcError) as context: 936 stream.finish_and_wait() 937 938 self.assertIs(context.exception.status, Status.INVALID_ARGUMENT) 939 940 def test_nonblocking_server_error_after_stream_end(self) -> None: 941 for _ in range(3): 942 stream = self._service.SomeClientStreaming.invoke() 943 944 # Error will be sent in response to the CLIENT_REQUEST_COMPLETION 945 # packet. 946 self._enqueue_error( 947 CLIENT_CHANNEL_ID, 948 self.method.service, 949 self.method, 950 Status.INVALID_ARGUMENT, 951 ) 952 953 with self.assertRaises(callback_client.RpcError) as context: 954 stream.finish_and_wait() 955 956 self.assertIs(context.exception.status, Status.INVALID_ARGUMENT) 957 958 def test_nonblocking_send_after_cancelled(self) -> None: 959 call = self._service.SomeClientStreaming.invoke() 960 self.assertTrue(call.cancel()) 961 962 with self.assertRaises(callback_client.RpcError) as context: 963 call.send(payload='hello') 964 965 self.assertIs(context.exception.status, Status.CANCELLED) 966 967 def test_nonblocking_finish_after_completed(self) -> None: 968 reply = self.method.response_type(payload='!?') 969 self._enqueue_response( 970 CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE, reply 971 ) 972 973 call = self.rpc.invoke() 974 result = call.finish_and_wait() 975 self.assertEqual(result.response, reply) 976 977 self.assertEqual(result, call.finish_and_wait()) 978 self.assertEqual(result, call.finish_and_wait()) 979 980 def test_nonblocking_finish_after_error(self) -> None: 981 self._enqueue_error( 982 CLIENT_CHANNEL_ID, 983 self.method.service, 984 self.method, 985 Status.UNAVAILABLE, 986 ) 987 988 call = self.rpc.invoke() 989 990 for _ in range(3): 991 with self.assertRaises(callback_client.RpcError) as context: 992 call.finish_and_wait() 993 994 self.assertIs(context.exception.status, Status.UNAVAILABLE) 995 self.assertIs(call.error, Status.UNAVAILABLE) 996 self.assertIsNone(call.response) 997 998 def test_nonblocking_duplicate_calls_not_cancelled(self) -> None: 999 first_call = self.rpc.invoke() 1000 self.assertFalse(first_call.completed()) 1001 1002 second_call = self.rpc.invoke() 1003 1004 self.assertIs(first_call.error, None) 1005 self.assertIs(second_call.error, None) 1006 1007 1008class BidirectionalStreamingTest(_CallbackClientImplTestBase): 1009 """Tests for bidirectional streaming RPCs.""" 1010 1011 def setUp(self) -> None: 1012 super().setUp() 1013 self.rpc = self._service.SomeBidiStreaming 1014 self.method = self.rpc.method 1015 1016 def test_blocking_call(self) -> None: 1017 requests = [ 1018 self.method.request_type(magic_number=123), 1019 self.method.request_type(magic_number=456), 1020 ] 1021 1022 # Send after len(requests) and the client stream end packet. 1023 self.send_responses_after_packets = 3 1024 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.NOT_FOUND) 1025 1026 results = self.rpc(requests) 1027 self.assertIs(results.status, Status.NOT_FOUND) 1028 self.assertFalse(results.responses) 1029 1030 def test_blocking_server_error(self) -> None: 1031 requests = [self.method.request_type(magic_number=123)] 1032 1033 # Send after len(requests) and the client stream end packet. 1034 self._enqueue_error( 1035 CLIENT_CHANNEL_ID, 1036 self.method.service, 1037 self.method, 1038 Status.NOT_FOUND, 1039 ) 1040 1041 with self.assertRaises(callback_client.RpcError) as context: 1042 self.rpc(requests) 1043 1044 self.assertIs(context.exception.status, Status.NOT_FOUND) 1045 1046 def test_nonblocking_call(self) -> None: 1047 """Tests a bidirectional streaming RPC ended by the server.""" 1048 rep1 = self.method.response_type(payload='!!!') 1049 rep2 = self.method.response_type(payload='?') 1050 1051 for _ in range(3): 1052 responses: list = [] 1053 stream = self._service.SomeBidiStreaming.invoke( 1054 lambda _, res, responses=responses: responses.append(res) 1055 ) 1056 self.assertFalse(stream.completed()) 1057 1058 stream.send(magic_number=55) 1059 self.assertIs( 1060 packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type 1061 ) 1062 self.assertEqual( 1063 55, self._sent_payload(self.method.request_type).magic_number 1064 ) 1065 self.assertFalse(stream.completed()) 1066 self.assertEqual([], responses) 1067 1068 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 1069 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2) 1070 1071 stream.send(magic_number=66) 1072 self.assertIs( 1073 packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type 1074 ) 1075 self.assertEqual( 1076 66, self._sent_payload(self.method.request_type).magic_number 1077 ) 1078 self.assertFalse(stream.completed()) 1079 self.assertEqual([rep1, rep2], responses) 1080 1081 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 1082 1083 stream.send(magic_number=77) 1084 self.assertTrue(stream.completed()) 1085 self.assertEqual([rep1, rep2], responses) 1086 1087 self.assertIs(Status.OK, stream.status) 1088 self.assertIsNone(stream.error) 1089 1090 def test_open(self) -> None: 1091 self.output_exception = IOError('this test should not send packets!') 1092 rep1 = self.method.response_type(payload='!!!') 1093 rep2 = self.method.response_type(payload='?') 1094 1095 for packet_id in (client.OPEN_CALL_ID, 123): 1096 for _ in range(3): 1097 self._enqueue_server_stream( 1098 CLIENT_CHANNEL_ID, self.method, rep1, call_id=packet_id 1099 ) 1100 self._enqueue_server_stream( 1101 CLIENT_CHANNEL_ID, self.method, rep2, call_id=packet_id 1102 ) 1103 self._enqueue_response( 1104 CLIENT_CHANNEL_ID, self.method, Status.OK, call_id=packet_id 1105 ) 1106 1107 callback = mock.Mock() 1108 call = self.rpc.open(callback, callback, callback) 1109 self.assertEqual(self.requests, []) 1110 1111 self._process_enqueued_packets() 1112 1113 callback.assert_has_calls( 1114 [ 1115 mock.call( 1116 call, self.method.response_type(payload='!!!') 1117 ), 1118 mock.call(call, self.method.response_type(payload='?')), 1119 mock.call(call, Status.OK), 1120 ] 1121 ) 1122 self.assertEqual(call.call_id, packet_id, "Adopts inbound ID") 1123 1124 @mock.patch('pw_rpc.callback_client.call.Call._default_response') 1125 def test_nonblocking(self, callback) -> None: 1126 """Tests a bidirectional streaming RPC ended by the server.""" 1127 reply = self.method.response_type(payload='This is the payload!') 1128 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply) 1129 1130 self._service.SomeBidiStreaming.invoke() 1131 1132 callback.assert_called_once_with(mock.ANY, reply) 1133 1134 def test_nonblocking_server_error(self) -> None: 1135 rep1 = self.method.response_type(payload='!!!') 1136 1137 for _ in range(3): 1138 responses: list = [] 1139 stream = self._service.SomeBidiStreaming.invoke( 1140 lambda _, res, responses=responses: responses.append(res) 1141 ) 1142 self.assertFalse(stream.completed()) 1143 1144 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 1145 1146 stream.send(magic_number=55) 1147 self.assertFalse(stream.completed()) 1148 self.assertEqual([rep1], responses) 1149 1150 self._enqueue_error( 1151 CLIENT_CHANNEL_ID, 1152 self.method.service, 1153 self.method, 1154 Status.OUT_OF_RANGE, 1155 ) 1156 1157 stream.send(magic_number=99999) 1158 self.assertTrue(stream.completed()) 1159 self.assertEqual([rep1], responses) 1160 1161 self.assertIsNone(stream.status) 1162 self.assertIs(Status.OUT_OF_RANGE, stream.error) 1163 1164 with self.assertRaises(callback_client.RpcError) as context: 1165 stream.finish_and_wait() 1166 self.assertIs(context.exception.status, Status.OUT_OF_RANGE) 1167 1168 def test_nonblocking_server_error_after_stream_end(self) -> None: 1169 for _ in range(3): 1170 stream = self._service.SomeBidiStreaming.invoke() 1171 1172 # Error will be sent in response to the CLIENT_REQUEST_COMPLETION 1173 # packet. 1174 self._enqueue_error( 1175 CLIENT_CHANNEL_ID, 1176 self.method.service, 1177 self.method, 1178 Status.INVALID_ARGUMENT, 1179 ) 1180 1181 with self.assertRaises(callback_client.RpcError) as context: 1182 stream.finish_and_wait() 1183 1184 self.assertIs(context.exception.status, Status.INVALID_ARGUMENT) 1185 1186 def test_nonblocking_send_after_cancelled(self) -> None: 1187 call = self._service.SomeBidiStreaming.invoke() 1188 self.assertTrue(call.cancel()) 1189 1190 with self.assertRaises(callback_client.RpcError) as context: 1191 call.send(payload='hello') 1192 1193 self.assertIs(context.exception.status, Status.CANCELLED) 1194 1195 def test_nonblocking_finish_after_completed(self) -> None: 1196 reply = self.method.response_type(payload='!?') 1197 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply) 1198 self._enqueue_response( 1199 CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE 1200 ) 1201 1202 call = self.rpc.invoke() 1203 result = call.finish_and_wait() 1204 self.assertEqual(result.responses, [reply]) 1205 1206 self.assertEqual(result, call.finish_and_wait()) 1207 self.assertEqual(result, call.finish_and_wait()) 1208 1209 def test_nonblocking_finish_after_error(self) -> None: 1210 reply = self.method.response_type(payload='!?') 1211 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply) 1212 self._enqueue_error( 1213 CLIENT_CHANNEL_ID, 1214 self.method.service, 1215 self.method, 1216 Status.UNAVAILABLE, 1217 ) 1218 1219 call = self.rpc.invoke() 1220 1221 for _ in range(3): 1222 with self.assertRaises(callback_client.RpcError) as context: 1223 call.finish_and_wait() 1224 1225 self.assertIs(context.exception.status, Status.UNAVAILABLE) 1226 self.assertIs(call.error, Status.UNAVAILABLE) 1227 self.assertEqual(list(call.responses), [reply]) 1228 1229 def test_nonblocking_duplicate_calls_not_cancelled(self) -> None: 1230 first_call = self.rpc.invoke() 1231 self.assertFalse(first_call.completed()) 1232 1233 second_call = self.rpc.invoke() 1234 1235 self.assertIs(first_call.error, None) 1236 self.assertIs(second_call.error, None) 1237 1238 def test_max_responses(self) -> None: 1239 rep1 = self.method.response_type(payload='a') 1240 rep2 = self.method.response_type(payload='b') 1241 rep3 = self.method.response_type(payload='c') 1242 rep4 = self.method.response_type(payload='d') 1243 rep5 = self.method.response_type(payload='e') 1244 1245 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 1246 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2) 1247 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep3) 1248 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep4) 1249 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep5) 1250 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 1251 1252 responses: list = [] 1253 call = self.rpc.invoke( 1254 on_next=lambda _, res, responses=responses: responses.append(res), 1255 max_responses=4, 1256 ) 1257 result = call.finish_and_wait() 1258 1259 # All 5 responses are received, but only the most recent 4 are stored 1260 # in the call. 1261 self.assertEqual(responses, [rep1, rep2, rep3, rep4, rep5]) 1262 self.assertEqual(result.responses, [rep2, rep3, rep4, rep5]) 1263 self.assertEqual(result.responses, list(call.responses)) 1264 1265 def test_stream_response(self) -> None: 1266 proto = PROTOS.packages.pw.test1.SomeMessage(magic_number=123) 1267 self.assertEqual( 1268 repr(callback_client.StreamResponse(Status.ABORTED, [proto] * 2)), 1269 '(Status.ABORTED, [pw.test1.SomeMessage(magic_number=123), ' 1270 'pw.test1.SomeMessage(magic_number=123)])', 1271 ) 1272 self.assertEqual( 1273 repr(callback_client.StreamResponse(Status.OK, [])), 1274 '(Status.OK, [])', 1275 ) 1276 1277 1278if __name__ == '__main__': 1279 unittest.main() 1280