#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for pw_rpc packet encoding and decoding."""

import unittest

from pw_status import Status

from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
from pw_rpc import packets
from pw_rpc.descriptors import RpcIds

_TEST_IDS = RpcIds(1, 2, 3, 4)

# Arbitrary non-default value so the test payload serializes to non-empty bytes.
_TEST_STATUS = 321


def _test_payload() -> RpcPacket:
    """Returns a fresh protobuf message to use as a packet payload.

    RpcPacket is reused here only as a convenient message type for payloads.
    A new instance is built on each call so tests never share a message object.
    """
    return RpcPacket(status=_TEST_STATUS)


def _packet(packet_type, ids: RpcIds = _TEST_IDS, **fields) -> RpcPacket:
    """Builds an RpcPacket of ``packet_type`` addressed by ``ids``.

    Args:
      packet_type: PacketType value for the packet's ``type`` field.
      ids: supplies channel_id, service_id, method_id and call_id.
      **fields: any additional RpcPacket fields (e.g. ``payload``, ``status``).
    """
    return RpcPacket(
        type=packet_type,
        channel_id=ids.channel_id,
        service_id=ids.service_id,
        method_id=ids.method_id,
        call_id=ids.call_id,
        **fields,
    )


def _parse(data: bytes) -> RpcPacket:
    """Parses serialized packet bytes back into an RpcPacket for comparison."""
    packet = RpcPacket()
    packet.ParseFromString(data)
    return packet


_TEST_REQUEST = _packet(
    PacketType.REQUEST, payload=_test_payload().SerializeToString()
)
_TEST_RESPONSE = _packet(
    PacketType.RESPONSE, payload=_test_payload().SerializeToString()
)


class PacketsTest(unittest.TestCase):
    """Tests for packet encoding and decoding.

    Every assertion passes the expected packet first and the actual packet
    second.
    """

    def test_encode_request(self):
        data = packets.encode_request(_TEST_IDS, _test_payload())

        self.assertEqual(_TEST_REQUEST, _parse(data))

    def test_encode_response(self):
        data = packets.encode_response(_TEST_IDS, _test_payload(), Status.OK)

        self.assertEqual(_TEST_RESPONSE, _parse(data))

    def test_encode_cancel(self):
        # Use IDs distinct from _TEST_IDS so that any field mix-up is caught.
        ids = RpcIds(9, 8, 7, 6)
        data = packets.encode_cancel(ids)

        self.assertEqual(
            _packet(
                PacketType.CLIENT_ERROR, ids, status=Status.CANCELLED.value
            ),
            _parse(data),
        )

    def test_encode_client_error(self):
        data = packets.encode_client_error(_TEST_REQUEST, Status.NOT_FOUND)

        self.assertEqual(
            _packet(PacketType.CLIENT_ERROR, status=Status.NOT_FOUND.value),
            _parse(data),
        )

    def test_encode_server_error(self):
        data = packets.encode_server_error(_TEST_REQUEST, Status.UNKNOWN)

        self.assertEqual(
            _packet(PacketType.SERVER_ERROR, status=Status.UNKNOWN.value),
            _parse(data),
        )

    def test_encode_server_stream(self):
        data = packets.encode_server_stream(_TEST_REQUEST, _test_payload())

        self.assertEqual(
            _packet(
                PacketType.SERVER_STREAM,
                payload=_test_payload().SerializeToString(),
            ),
            _parse(data),
        )

    def test_decode(self):
        self.assertEqual(
            _TEST_REQUEST, packets.decode(_TEST_REQUEST.SerializeToString())
        )

    def test_for_server(self):
        # A REQUEST packet is addressed to the server; a RESPONSE is not.
        self.assertTrue(packets.for_server(_TEST_REQUEST))
        self.assertFalse(packets.for_server(_TEST_RESPONSE))


if __name__ == '__main__':
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()