xref: /aosp_15_r20/external/pigweed/pw_rpc/py/tests/client_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests creating pw_rpc client."""
16
17import unittest
18from typing import Any, Callable
19
20from pw_protobuf_compiler import python_protos
21from pw_status import Status
22
23from pw_rpc import callback_client, client, packets
24import pw_rpc.ids
25from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
26from pw_rpc.descriptors import RpcIds
27
# proto3 test schema: a single service that declares one RPC of each kind
# (unary, server streaming, client streaming and bidirectional streaming).
TEST_PROTO_1: str = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

# proto2 test schema in a second package: two services (Alpha and Bravo),
# each declaring exactly one method.
TEST_PROTO_2: str = """\
syntax = "proto2";

package pw.test2;

message Request {
  optional float magic_number = 1;
}

message Response {
}

service Alpha {
  rpc Unary(Request) returns (Response) {}
}

service Bravo {
  rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""
76
# Compile both test schemas into one library shared by every test case.
PROTOS = python_protos.Library.from_strings(
    [TEST_PROTO_1, TEST_PROTO_2]
)
78
# Arbitrary IDs used to build incoming packets. The tests below rely on the
# channel, service and method IDs not matching anything the test client
# registers (they expect NOT_FOUND for them).
SOME_CHANNEL_ID: int = 237
SOME_SERVICE_ID: int = 193
SOME_METHOD_ID: int = 769
SOME_CALL_ID: int = 452

# IDs of the two channels create_client() registers, in registration order.
CLIENT_FIRST_CHANNEL_ID: int = 557
CLIENT_SECOND_CHANNEL_ID: int = 474
86
87
def create_client(
    proto_modules: Any,
    first_channel_output_fn: Callable[[bytes], Any] = lambda _: None,
) -> client.Client:
    """Builds a callback-based pw_rpc Client with two channels.

    Args:
      proto_modules: Protobuf modules whose services the client is built from.
      first_channel_output_fn: Output function for the CLIENT_FIRST_CHANNEL_ID
        channel; it is handed the encoded packets written to that channel.
        The default discards them.

    Returns:
      A client whose channels are CLIENT_FIRST_CHANNEL_ID and
      CLIENT_SECOND_CHANNEL_ID, listed in that order. The second channel's
      output is always discarded.
    """
    impl = callback_client.Impl()
    channels = [
        client.Channel(CLIENT_FIRST_CHANNEL_ID, first_channel_output_fn),
        client.Channel(CLIENT_SECOND_CHANNEL_ID, lambda _: None),
    ]
    return client.Client.from_modules(impl, channels, proto_modules)
100
101
class ChannelClientTest(unittest.TestCase):
    """Tests service/method lookup through a ChannelClient."""

    def setUp(self) -> None:
        # Every test in this class talks through the client's first channel.
        self._channel_client: client.ChannelClient = create_client(
            PROTOS.modules()
        ).channel(CLIENT_FIRST_CHANNEL_ID)

    def test_access_service_client_as_attribute_or_index(self) -> None:
        # Attribute access, name lookup and ID lookup yield the same object.
        rpcs = self._channel_client.rpcs
        by_attribute = rpcs.pw.test1.PublicService
        self.assertIs(by_attribute, rpcs['pw.test1.PublicService'])
        self.assertIs(
            by_attribute, rpcs[pw_rpc.ids.calculate('pw.test1.PublicService')]
        )

    def test_access_method_client_as_attribute_or_index(self) -> None:
        rpcs = self._channel_client.rpcs
        by_attribute = rpcs.pw.test2.Alpha.Unary
        alpha = rpcs['pw.test2.Alpha']
        self.assertIs(by_attribute, alpha['Unary'])
        self.assertIs(by_attribute, alpha[pw_rpc.ids.calculate('Unary')])

    def test_service_name(self) -> None:
        service = self._channel_client.rpcs.pw.test2.Alpha.Unary.service
        self.assertEqual(service.name, 'Alpha')
        self.assertEqual(service.full_name, 'pw.test2.Alpha')

    def test_method_name(self) -> None:
        method = self._channel_client.rpcs.pw.test2.Alpha.Unary.method
        self.assertEqual(method.name, 'Unary')
        self.assertEqual(method.full_name, 'pw.test2.Alpha.Unary')

    def test_iterate_over_all_methods(self) -> None:
        rpcs = self._channel_client.rpcs
        public = rpcs.pw.test1.PublicService
        expected = {
            public.SomeUnary,
            public.SomeServerStreaming,
            public.SomeClientStreaming,
            public.SomeBidiStreaming,
            rpcs.pw.test2.Alpha.Unary,
            rpcs.pw.test2.Bravo.BidiStreaming,
        }
        self.assertEqual(set(self._channel_client.methods()), expected)

    def test_check_for_presence_of_services(self) -> None:
        rpcs = self._channel_client.rpcs
        full_name = 'pw.test1.PublicService'
        for key in (full_name, pw_rpc.ids.calculate(full_name)):
            self.assertIn(key, rpcs)

    def test_check_for_presence_of_missing_services(self) -> None:
        # Only fully qualified names or IDs match; the short name does not.
        rpcs = self._channel_client.rpcs
        for missing in ('PublicService', 'NotAService', -1213):
            self.assertNotIn(missing, rpcs)

    def test_check_for_presence_of_methods(self) -> None:
        service = self._channel_client.rpcs.pw.test1.PublicService
        for key in ('SomeUnary', pw_rpc.ids.calculate('SomeUnary')):
            self.assertIn(key, service)

    def test_check_for_presence_of_missing_methods(self) -> None:
        # Partial method names must not match.
        service = self._channel_client.rpcs.pw.test1.PublicService
        for missing in ('Some', 'Unary', 12345):
            self.assertNotIn(missing, service)

    def test_method_fully_qualified_name(self) -> None:
        # Both the "/" and "." separators are accepted before the method name.
        expected = self._channel_client.rpcs.pw.test2.Alpha.Unary
        for name in ('pw.test2.Alpha/Unary', 'pw.test2.Alpha.Unary'):
            self.assertIs(self._channel_client.method(name), expected)
197
198
class ClientTest(unittest.TestCase):
    """Tests the pw_rpc Client independently of the ClientImpl."""

    def setUp(self) -> None:
        # Raw bytes of the most recent packet the client wrote to its first
        # channel; stays None until the client sends something.
        self._last_packet_sent_bytes: bytes | None = None
        self._client = create_client(PROTOS.modules(), self._save_packet)

    def _save_packet(self, packet: bytes) -> None:
        """Output function for the first channel: remembers the packet."""
        self._last_packet_sent_bytes = packet

    def _last_packet_sent(self) -> RpcPacket:
        """Decodes the most recent packet sent on the first channel.

        Fails the current test if no packet has been sent. self.fail() is used
        instead of a bare assert so the check is not stripped under `python -O`.
        """
        if self._last_packet_sent_bytes is None:
            self.fail('The client did not send any packet')
        packet = RpcPacket()
        packet.MergeFromString(self._last_packet_sent_bytes)
        return packet

    def test_channel(self) -> None:
        self.assertEqual(
            self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel.id,
            CLIENT_FIRST_CHANNEL_ID,
        )
        self.assertEqual(
            self._client.channel(CLIENT_SECOND_CHANNEL_ID).channel.id,
            CLIENT_SECOND_CHANNEL_ID,
        )

    def test_channel_default_is_first_listed(self) -> None:
        self.assertEqual(
            self._client.channel().channel.id, CLIENT_FIRST_CHANNEL_ID
        )

    def test_channel_invalid(self) -> None:
        with self.assertRaises(KeyError):
            self._client.channel(404)

    def test_all_methods(self) -> None:
        services = self._client.services

        all_methods = {
            services['pw.test1.PublicService'].methods['SomeUnary'],
            services['pw.test1.PublicService'].methods['SomeServerStreaming'],
            services['pw.test1.PublicService'].methods['SomeClientStreaming'],
            services['pw.test1.PublicService'].methods['SomeBidiStreaming'],
            services['pw.test2.Alpha'].methods['Unary'],
            services['pw.test2.Bravo'].methods['BidiStreaming'],
        }
        self.assertEqual(set(self._client.methods()), all_methods)

    def test_method_present(self) -> None:
        self.assertIs(
            self._client.method('pw.test1.PublicService.SomeUnary'),
            self._client.services['pw.test1.PublicService'].methods[
                'SomeUnary'
            ],
        )
        self.assertIs(
            self._client.method('pw.test1.PublicService/SomeUnary'),
            self._client.services['pw.test1.PublicService'].methods[
                'SomeUnary'
            ],
        )

    def test_method_invalid_format(self) -> None:
        with self.assertRaises(ValueError):
            self._client.method('SomeUnary')

    def test_method_not_present(self) -> None:
        with self.assertRaises(KeyError):
            self._client.method('pw.test1.PublicService/ThisIsNotGood')

        with self.assertRaises(KeyError):
            self._client.method('nothing.Good')

    def test_process_packet_invalid_proto_data(self) -> None:
        self.assertIs(
            self._client.process_packet(b'NOT a packet!'), Status.DATA_LOSS
        )

    def test_process_packet_not_for_client(self) -> None:
        # REQUEST packets travel client -> server, so the client rejects them.
        self.assertIs(
            self._client.process_packet(
                RpcPacket(type=PacketType.REQUEST).SerializeToString()
            ),
            Status.INVALID_ARGUMENT,
        )

    def test_process_packet_unrecognized_channel(self) -> None:
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        SOME_CHANNEL_ID,
                        SOME_SERVICE_ID,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    PROTOS.packages.pw.test2.Request(),
                )
            ),
            Status.NOT_FOUND,
        )

    def test_process_packet_unrecognized_service(self) -> None:
        # The packet is accepted (OK), but the client replies on the same
        # channel with a CLIENT_ERROR carrying NOT_FOUND.
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        SOME_SERVICE_ID,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    PROTOS.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=CLIENT_FIRST_CHANNEL_ID,
                service_id=SOME_SERVICE_ID,
                method_id=SOME_METHOD_ID,
                call_id=SOME_CALL_ID,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_process_packet_unrecognized_method(self) -> None:
        service = next(iter(self._client.services))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        service.id,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    PROTOS.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=CLIENT_FIRST_CHANNEL_ID,
                service_id=service.id,
                method_id=SOME_METHOD_ID,
                call_id=SOME_CALL_ID,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_process_packet_non_pending_method(self) -> None:
        # A known method with no pending call yields FAILED_PRECONDITION.
        service = next(iter(self._client.services))
        method = next(iter(service.methods))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        service.id,
                        method.id,
                        SOME_CALL_ID,
                    ),
                    PROTOS.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=CLIENT_FIRST_CHANNEL_ID,
                service_id=service.id,
                method_id=method.id,
                call_id=SOME_CALL_ID,
                status=Status.FAILED_PRECONDITION.value,
            ),
        )

    def test_process_packet_non_pending_calls_response_callback(self) -> None:
        method = self._client.method('pw.test1.PublicService.SomeUnary')
        reply = method.response_type(payload='hello')

        # Record each invocation and assert afterwards rather than asserting
        # inside the callback. Otherwise, a callback that is never invoked, or
        # whose AssertionError is swallowed by process_packet, would make this
        # test pass without checking anything.
        calls: list[tuple[client.PendingRpc, Any, Status | None]] = []

        def response_callback(
            rpc: client.PendingRpc,
            message: Any,
            status: Status | None,
        ) -> None:
            calls.append((rpc, message, status))

        self._client.response_callback = response_callback

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        method.service.id,
                        method.id,
                        SOME_CALL_ID,
                    ),
                    reply,
                )
            ),
            Status.OK,
        )

        self.assertEqual(len(calls), 1, 'response_callback was not called once')
        rpc, message, status = calls[0]
        self.assertEqual(
            rpc,
            client.PendingRpc(
                self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel,
                method.service,
                method,
                call_id=SOME_CALL_ID,
            ),
        )
        self.assertEqual(message, reply)
        self.assertIs(status, Status.OK)
427
428
if __name__ == '__main__':
    # Run the TestCase classes in this module when executed as a script.
    unittest.main()
431