xref: /aosp_15_r20/external/pigweed/pw_rpc/py/tests/packets_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
"""Tests for pw_rpc packet encoding and decoding."""
16
17import unittest
18
19from pw_status import Status
20
21from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
22from pw_rpc import packets
23from pw_rpc.descriptors import RpcIds
24
# IDs shared by every packet constructed from the module-level constants.
_TEST_IDS = RpcIds(1, 2, 3, 4)

# Arbitrary status value; it only serves as content for the test payloads.
_TEST_STATUS = 321

# Expected REQUEST and RESPONSE packets for _TEST_IDS. They differ only in
# their packet type, so both are built from the same field values (REQUEST
# first, then RESPONSE).
_TEST_REQUEST, _TEST_RESPONSE = (
    RpcPacket(
        type=packet_type,
        channel_id=_TEST_IDS.channel_id,
        service_id=_TEST_IDS.service_id,
        method_id=_TEST_IDS.method_id,
        call_id=_TEST_IDS.call_id,
        payload=RpcPacket(status=_TEST_STATUS).SerializeToString(),
    )
    for packet_type in (PacketType.REQUEST, PacketType.RESPONSE)
)
43)
44
45
class PacketsTest(unittest.TestCase):
    """Tests for packet encoding and decoding.

    Each encode test parses the produced bytes back into an ``RpcPacket`` and
    compares it to the expected message. Every ``assertEqual`` call passes the
    expected value first and the actual value second.
    """

    @staticmethod
    def _parse(data: bytes) -> RpcPacket:
        """Parses serialized packet bytes into a new RpcPacket message."""
        packet = RpcPacket()
        packet.ParseFromString(data)
        return packet

    def test_encode_request(self):
        data = packets.encode_request(_TEST_IDS, RpcPacket(status=_TEST_STATUS))

        self.assertEqual(_TEST_REQUEST, self._parse(data))

    def test_encode_response(self):
        data = packets.encode_response(
            _TEST_IDS, RpcPacket(status=_TEST_STATUS), Status.OK
        )

        self.assertEqual(_TEST_RESPONSE, self._parse(data))

    def test_encode_cancel(self):
        data = packets.encode_cancel(RpcIds(9, 8, 7, 6))

        # A cancellation is expected to be a CLIENT_ERROR packet carrying the
        # CANCELLED status for the given IDs.
        self.assertEqual(
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=9,
                service_id=8,
                method_id=7,
                call_id=6,
                status=Status.CANCELLED.value,
            ),
            self._parse(data),
        )

    def test_encode_client_error(self):
        data = packets.encode_client_error(_TEST_REQUEST, Status.NOT_FOUND)

        # The error packet reuses the IDs of the packet it responds to.
        self.assertEqual(
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=_TEST_IDS.channel_id,
                service_id=_TEST_IDS.service_id,
                method_id=_TEST_IDS.method_id,
                call_id=_TEST_IDS.call_id,
                status=Status.NOT_FOUND.value,
            ),
            self._parse(data),
        )

    def test_encode_server_error(self):
        data = packets.encode_server_error(_TEST_REQUEST, Status.UNKNOWN)

        self.assertEqual(
            RpcPacket(
                type=PacketType.SERVER_ERROR,
                channel_id=_TEST_IDS.channel_id,
                service_id=_TEST_IDS.service_id,
                method_id=_TEST_IDS.method_id,
                call_id=_TEST_IDS.call_id,
                status=Status.UNKNOWN.value,
            ),
            self._parse(data),
        )

    def test_encode_server_stream(self):
        data = packets.encode_server_stream(
            _TEST_REQUEST, RpcPacket(status=_TEST_STATUS)
        )

        # The payload message is expected to be embedded in serialized form.
        self.assertEqual(
            RpcPacket(
                type=PacketType.SERVER_STREAM,
                channel_id=_TEST_IDS.channel_id,
                service_id=_TEST_IDS.service_id,
                method_id=_TEST_IDS.method_id,
                call_id=_TEST_IDS.call_id,
                payload=RpcPacket(status=_TEST_STATUS).SerializeToString(),
            ),
            self._parse(data),
        )

    def test_decode(self):
        # Round-trip packets of more than one type through serialization.
        for packet in (_TEST_REQUEST, _TEST_RESPONSE):
            with self.subTest(packet_type=packet.type):
                self.assertEqual(
                    packet, packets.decode(packet.SerializeToString())
                )

    def test_for_server(self):
        self.assertTrue(packets.for_server(_TEST_REQUEST))
        self.assertFalse(packets.for_server(_TEST_RESPONSE))
147
148
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()
151