# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Functions for working with pw_rpc packets."""

from google.protobuf import message
from pw_status import Status

from pw_rpc.descriptors import RpcIds, PendingRpc
from pw_rpc.internal import packet_pb2


def decode(data: bytes) -> packet_pb2.RpcPacket:
    """Parses serialized bytes into a new RpcPacket message.

    Args:
      data: The serialized RpcPacket protobuf.

    Returns:
      A freshly created RpcPacket populated from ``data``.
    """
    packet = packet_pb2.RpcPacket()
    packet.MergeFromString(data)
    return packet


def decode_payload(packet, payload_type):
    """Parses a packet's payload bytes into a new ``payload_type`` instance.

    Args:
      packet: Object whose ``payload`` attribute holds the serialized bytes.
      payload_type: Zero-argument callable that creates the message to fill;
        the result must provide ``MergeFromString``.

    Returns:
      The newly created message with ``packet.payload`` merged into it.
    """
    payload = payload_type()
    payload.MergeFromString(packet.payload)
    return payload


def encode_request(
    rpc: PendingRpc | RpcIds, request: message.Message | None
) -> bytes:
    """Serializes a REQUEST packet for the given RPC.

    Args:
      rpc: Supplies the channel, service, method, and call IDs.
      request: The request message; ``None`` produces an empty payload.

    Returns:
      The serialized RpcPacket.
    """
    payload = b'' if request is None else request.SerializeToString()

    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.REQUEST,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        payload=payload,
    ).SerializeToString()


def encode_response(
    rpc: PendingRpc | RpcIds,
    response: message.Message | None = None,
    status: Status = Status.OK,
) -> bytes:
    """Serializes a RESPONSE packet for the given RPC.

    Args:
      rpc: Supplies the channel, service, method, and call IDs.
      response: The response message; ``None`` produces an empty payload.
      status: The status to place in the packet. Defaults to ``Status.OK``.

    Returns:
      The serialized RpcPacket.
    """
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.RESPONSE,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        payload=b'' if response is None else response.SerializeToString(),
        status=status.value,
    ).SerializeToString()


def encode_client_stream(
    rpc: PendingRpc | RpcIds, request: message.Message
) -> bytes:
    """Serializes a CLIENT_STREAM packet carrying ``request`` as its payload.

    Args:
      rpc: Supplies the channel, service, method, and call IDs.
      request: The message to serialize into the packet payload.

    Returns:
      The serialized RpcPacket.
    """
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_STREAM,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        payload=request.SerializeToString(),
    ).SerializeToString()


def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes:
    """Serializes a CLIENT_ERROR packet that reuses the IDs of ``packet``.

    Args:
      packet: The packet whose channel, service, method, and call IDs are
        copied into the error packet.
      status: The status to place in the error packet.

    Returns:
      The serialized RpcPacket.
    """
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_ERROR,
        channel_id=packet.channel_id,
        service_id=packet.service_id,
        method_id=packet.method_id,
        call_id=packet.call_id,
        status=status.value,
    ).SerializeToString()


def encode_cancel(rpc: PendingRpc | RpcIds) -> bytes:
    """Serializes a cancellation packet for the given RPC.

    The packet is a CLIENT_ERROR packet whose status is ``Status.CANCELLED``.

    Args:
      rpc: Supplies the channel, service, method, and call IDs.

    Returns:
      The serialized RpcPacket.
    """
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_ERROR,
        status=Status.CANCELLED.value,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    ).SerializeToString()


def encode_client_stream_end(rpc: PendingRpc | RpcIds) -> bytes:
    """Serializes a CLIENT_REQUEST_COMPLETION packet for the given RPC.

    Args:
      rpc: Supplies the channel, service, method, and call IDs.

    Returns:
      The serialized RpcPacket.
    """
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    ).SerializeToString()


def encode_server_stream(rpc: RpcIds, payload: message.Message) -> bytes:
    """Serializes a SERVER_STREAM packet carrying ``payload``.

    Args:
      rpc: Supplies the channel, service, method, and call IDs.
      payload: The message to serialize into the packet payload.

    Returns:
      The serialized RpcPacket.
    """
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.SERVER_STREAM,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        payload=payload.SerializeToString(),
    ).SerializeToString()


def encode_server_error(rpc: RpcIds, status: Status) -> bytes:
    """Serializes a SERVER_ERROR packet with the given non-OK status.

    Args:
      rpc: Supplies the channel, service, method, and call IDs.
      status: The error status; must not be OK.

    Returns:
      The serialized RpcPacket.

    Raises:
      AssertionError: If ``status`` is OK (only when assertions are enabled).
    """
    # NOTE(review): this check is stripped when Python runs with -O. It is
    # kept as an assert so the exception type seen by callers is unchanged.
    assert not status.ok()
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.SERVER_ERROR,
        status=status.value,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    ).SerializeToString()


def for_server(packet: packet_pb2.RpcPacket) -> bool:
    """Returns True if the packet's type value is even.

    NOTE(review): presumably the PacketType enum assigns even values to
    packets addressed to the server -- confirm against the enum definition.
    """
    return packet.type % 2 == 0