#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests creating pw_rpc client."""

import unittest
from typing import Any, Callable

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
import pw_rpc.ids
from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
from pw_rpc.descriptors import RpcIds

# Proto definitions compiled at import time to give the client real services.
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

TEST_PROTO_2 = """\
syntax = "proto2";

package pw.test2;

message Request {
  optional float magic_number = 1;
}

message Response {
}

service Alpha {
  rpc Unary(Request) returns (Response) {}
}

service Bravo {
  rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""

PROTOS = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])

# Arbitrary IDs used to build incoming packets. The tests below expect the
# channel, service, and method IDs not to match anything that
# create_client() registers.
SOME_CHANNEL_ID: int = 237
SOME_SERVICE_ID: int = 193
SOME_METHOD_ID: int = 769
SOME_CALL_ID: int = 452

# IDs of the two channels that create_client() registers, in order.
CLIENT_FIRST_CHANNEL_ID: int = 557
CLIENT_SECOND_CHANNEL_ID: int = 474


def create_client(
    proto_modules: Any,
    first_channel_output_fn: Callable[[bytes], Any] = lambda _: None,
) -> client.Client:
    """Creates a callback-based Client with two channels.

    Args:
        proto_modules: Compiled proto modules whose services the client knows.
        first_channel_output_fn: Output function for the first channel
            (CLIENT_FIRST_CHANNEL_ID); receives every encoded packet sent on
            it. Defaults to discarding the data.

    Returns:
        A Client whose channels are CLIENT_FIRST_CHANNEL_ID followed by
        CLIENT_SECOND_CHANNEL_ID. The second channel discards its output.
    """
    return client.Client.from_modules(
        callback_client.Impl(),
        [
            client.Channel(CLIENT_FIRST_CHANNEL_ID, first_channel_output_fn),
            client.Channel(CLIENT_SECOND_CHANNEL_ID, lambda _: None),
        ],
        proto_modules,
    )


class ChannelClientTest(unittest.TestCase):
    """Tests the ChannelClient."""

    def setUp(self) -> None:
        client_instance = create_client(PROTOS.modules())
        self._channel_client: client.ChannelClient = client_instance.channel(
            CLIENT_FIRST_CHANNEL_ID
        )

    def test_access_service_client_as_attribute_or_index(self) -> None:
        self.assertIs(
            self._channel_client.rpcs.pw.test1.PublicService,
            self._channel_client.rpcs['pw.test1.PublicService'],
        )
        self.assertIs(
            self._channel_client.rpcs.pw.test1.PublicService,
            self._channel_client.rpcs[
                pw_rpc.ids.calculate('pw.test1.PublicService')
            ],
        )

    def test_access_method_client_as_attribute_or_index(self) -> None:
        self.assertIs(
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
            self._channel_client.rpcs['pw.test2.Alpha']['Unary'],
        )
        self.assertIs(
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
            self._channel_client.rpcs['pw.test2.Alpha'][
                pw_rpc.ids.calculate('Unary')
            ],
        )

    def test_service_name(self) -> None:
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.service.name, 'Alpha'
        )
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.service.full_name,
            'pw.test2.Alpha',
        )

    def test_method_name(self) -> None:
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.method.name, 'Unary'
        )
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.method.full_name,
            'pw.test2.Alpha.Unary',
        )

    def test_iterate_over_all_methods(self) -> None:
        channel_client = self._channel_client
        all_methods = {
            channel_client.rpcs.pw.test1.PublicService.SomeUnary,
            channel_client.rpcs.pw.test1.PublicService.SomeServerStreaming,
            channel_client.rpcs.pw.test1.PublicService.SomeClientStreaming,
            channel_client.rpcs.pw.test1.PublicService.SomeBidiStreaming,
            channel_client.rpcs.pw.test2.Alpha.Unary,
            channel_client.rpcs.pw.test2.Bravo.BidiStreaming,
        }
        self.assertEqual(set(channel_client.methods()), all_methods)

    def test_check_for_presence_of_services(self) -> None:
        self.assertIn('pw.test1.PublicService', self._channel_client.rpcs)
        self.assertIn(
            pw_rpc.ids.calculate('pw.test1.PublicService'),
            self._channel_client.rpcs,
        )

    def test_check_for_presence_of_missing_services(self) -> None:
        # Services must be looked up by fully qualified name, not short name.
        self.assertNotIn('PublicService', self._channel_client.rpcs)
        self.assertNotIn('NotAService', self._channel_client.rpcs)
        self.assertNotIn(-1213, self._channel_client.rpcs)

    def test_check_for_presence_of_methods(self) -> None:
        service = self._channel_client.rpcs.pw.test1.PublicService
        self.assertIn('SomeUnary', service)
        self.assertIn(pw_rpc.ids.calculate('SomeUnary'), service)

    def test_check_for_presence_of_missing_methods(self) -> None:
        # Partial method names must not match.
        service = self._channel_client.rpcs.pw.test1.PublicService
        self.assertNotIn('Some', service)
        self.assertNotIn('Unary', service)
        self.assertNotIn(12345, service)

    def test_method_fully_qualified_name(self) -> None:
        # Both the '/' and '.' separators between service and method work.
        self.assertIs(
            self._channel_client.method('pw.test2.Alpha/Unary'),
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
        )
        self.assertIs(
            self._channel_client.method('pw.test2.Alpha.Unary'),
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
        )


class ClientTest(unittest.TestCase):
    """Tests the pw_rpc Client independently of the ClientImpl."""

    def setUp(self) -> None:
        # Most recent bytes written to the first channel, or None if nothing
        # has been sent yet.
        self._last_packet_sent_bytes: bytes | None = None
        self._client = create_client(PROTOS.modules(), self._save_packet)

    def _save_packet(self, packet: bytes) -> None:
        """Output function for the first channel; records the sent bytes."""
        self._last_packet_sent_bytes = packet

    def _last_packet_sent(self) -> RpcPacket:
        """Decodes the last packet sent on the first channel.

        Fails (via assert) if no packet has been sent since setUp().
        """
        packet = RpcPacket()
        assert self._last_packet_sent_bytes is not None, 'no packet was sent'
        packet.MergeFromString(self._last_packet_sent_bytes)
        return packet

    def test_channel(self) -> None:
        self.assertEqual(
            self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel.id,
            CLIENT_FIRST_CHANNEL_ID,
        )
        self.assertEqual(
            self._client.channel(CLIENT_SECOND_CHANNEL_ID).channel.id,
            CLIENT_SECOND_CHANNEL_ID,
        )

    def test_channel_default_is_first_listed(self) -> None:
        self.assertEqual(
            self._client.channel().channel.id, CLIENT_FIRST_CHANNEL_ID
        )

    def test_channel_invalid(self) -> None:
        with self.assertRaises(KeyError):
            self._client.channel(404)

    def test_all_methods(self) -> None:
        services = self._client.services

        all_methods = {
            services['pw.test1.PublicService'].methods['SomeUnary'],
            services['pw.test1.PublicService'].methods['SomeServerStreaming'],
            services['pw.test1.PublicService'].methods['SomeClientStreaming'],
            services['pw.test1.PublicService'].methods['SomeBidiStreaming'],
            services['pw.test2.Alpha'].methods['Unary'],
            services['pw.test2.Bravo'].methods['BidiStreaming'],
        }
        self.assertEqual(set(self._client.methods()), all_methods)

    def test_method_present(self) -> None:
        self.assertIs(
            self._client.method('pw.test1.PublicService.SomeUnary'),
            self._client.services['pw.test1.PublicService'].methods[
                'SomeUnary'
            ],
        )
        self.assertIs(
            self._client.method('pw.test1.PublicService/SomeUnary'),
            self._client.services['pw.test1.PublicService'].methods[
                'SomeUnary'
            ],
        )

    def test_method_invalid_format(self) -> None:
        with self.assertRaises(ValueError):
            self._client.method('SomeUnary')

    def test_method_not_present(self) -> None:
        with self.assertRaises(KeyError):
            self._client.method('pw.test1.PublicService/ThisIsNotGood')

        with self.assertRaises(KeyError):
            self._client.method('nothing.Good')

    def test_process_packet_invalid_proto_data(self) -> None:
        self.assertIs(
            self._client.process_packet(b'NOT a packet!'), Status.DATA_LOSS
        )

    def test_process_packet_not_for_client(self) -> None:
        # REQUEST packets are sent to servers, not clients.
        self.assertIs(
            self._client.process_packet(
                RpcPacket(type=PacketType.REQUEST).SerializeToString()
            ),
            Status.INVALID_ARGUMENT,
        )

    def test_process_packet_unrecognized_channel(self) -> None:
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        SOME_CHANNEL_ID,
                        SOME_SERVICE_ID,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    PROTOS.packages.pw.test2.Request(),
                )
            ),
            Status.NOT_FOUND,
        )

    def test_process_packet_unrecognized_service(self) -> None:
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        SOME_SERVICE_ID,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    PROTOS.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=CLIENT_FIRST_CHANNEL_ID,
                service_id=SOME_SERVICE_ID,
                method_id=SOME_METHOD_ID,
                call_id=SOME_CALL_ID,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_process_packet_unrecognized_method(self) -> None:
        service = next(iter(self._client.services))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        service.id,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    PROTOS.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=CLIENT_FIRST_CHANNEL_ID,
                service_id=service.id,
                method_id=SOME_METHOD_ID,
                call_id=SOME_CALL_ID,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_process_packet_non_pending_method(self) -> None:
        service = next(iter(self._client.services))
        method = next(iter(service.methods))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        service.id,
                        method.id,
                        SOME_CALL_ID,
                    ),
                    PROTOS.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=CLIENT_FIRST_CHANNEL_ID,
                service_id=service.id,
                method_id=method.id,
                call_id=SOME_CALL_ID,
                status=Status.FAILED_PRECONDITION.value,
            ),
        )

    def test_process_packet_non_pending_calls_response_callback(self) -> None:
        method = self._client.method('pw.test1.PublicService.SomeUnary')
        reply = method.response_type(payload='hello')

        # Record the callback's arguments and check them after process_packet()
        # returns. Asserting only inside the callback would let this test pass
        # vacuously if the callback were never invoked.
        calls: list[tuple[client.PendingRpc, Any, Status | None]] = []

        def response_callback(
            rpc: client.PendingRpc,
            message,
            status: Status | None,
        ) -> None:
            calls.append((rpc, message, status))

        self._client.response_callback = response_callback

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        method.service.id,
                        method.id,
                        SOME_CALL_ID,
                    ),
                    reply,
                )
            ),
            Status.OK,
        )

        self.assertEqual(len(calls), 1, 'response_callback not called once')
        rpc, message, status = calls[0]
        self.assertEqual(
            rpc,
            client.PendingRpc(
                self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel,
                method.service,
                method,
                call_id=SOME_CALL_ID,
            ),
        )
        self.assertEqual(message, reply)
        self.assertIs(status, Status.OK)


if __name__ == '__main__':
    unittest.main()