xref: /aosp_15_r20/external/pigweed/pw_rpc/py/pw_rpc/packets.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Functions for working with pw_rpc packets."""
15
16from google.protobuf import message
17from pw_status import Status
18
19from pw_rpc.descriptors import RpcIds, PendingRpc
20from pw_rpc.internal import packet_pb2
21
22
def decode(data: bytes) -> packet_pb2.RpcPacket:
    """Parses wire-format bytes into a new RpcPacket.

    Args:
        data: Serialized RpcPacket bytes.

    Returns:
        A freshly constructed RpcPacket populated from ``data``.
    """
    # FromString builds an empty message and merges ``data`` into it.
    return packet_pb2.RpcPacket.FromString(data)
27
28
def decode_payload(packet, payload_type):
    """Decodes the ``payload`` bytes of ``packet`` as a ``payload_type``.

    Args:
        packet: Any object exposing a ``payload`` attribute of serialized
            bytes (normally a decoded RpcPacket).
        payload_type: Message class; it is instantiated with no arguments
            and the payload bytes are merged into the new instance.

    Returns:
        The new ``payload_type`` instance.
    """
    decoded_message = payload_type()
    raw_bytes = packet.payload
    decoded_message.MergeFromString(raw_bytes)
    return decoded_message
33
34
def encode_request(
    rpc: PendingRpc | RpcIds, request: message.Message | None
) -> bytes:
    """Encodes a REQUEST packet for the given call.

    Args:
        rpc: Supplies the channel, service, method, and call ids.
        request: Message to serialize as the payload, or None for an
            empty payload.

    Returns:
        The serialized RpcPacket bytes.
    """
    if request is None:
        body = b''
    else:
        body = request.SerializeToString()

    packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.REQUEST,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        payload=body,
    )
    return packet.SerializeToString()
48
49
def encode_response(
    rpc: PendingRpc | RpcIds,
    response: message.Message | None = None,
    status: Status = Status.OK,
) -> bytes:
    """Encodes a RESPONSE packet for the given call.

    Args:
        rpc: Supplies the channel, service, method, and call ids.
        response: Message to serialize as the payload; None (the default)
            produces an empty payload.
        status: Status written to the packet as its numeric ``value``;
            defaults to OK.

    Returns:
        The serialized RpcPacket bytes.
    """
    body = response.SerializeToString() if response is not None else b''
    packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.RESPONSE,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        payload=body,
        status=status.value,
    )
    return packet.SerializeToString()
64
65
def encode_client_stream(
    rpc: PendingRpc | RpcIds, request: message.Message
) -> bytes:
    """Encodes a CLIENT_STREAM packet carrying one request message.

    Args:
        rpc: Supplies the channel, service, method, and call ids.
        request: Message serialized into the packet's payload.

    Returns:
        The serialized RpcPacket bytes.
    """
    body = request.SerializeToString()
    packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_STREAM,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        payload=body,
    )
    return packet.SerializeToString()
77
78
def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes:
    """Encodes a CLIENT_ERROR packet addressed to the same call as ``packet``.

    The channel, service, method, and call ids are copied from ``packet``.
    No payload is set.

    Args:
        packet: Packet whose ids identify the call being reported on.
        status: Error status written as its numeric ``value``.

    Returns:
        The serialized RpcPacket bytes.
    """
    error_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_ERROR,
        channel_id=packet.channel_id,
        service_id=packet.service_id,
        method_id=packet.method_id,
        call_id=packet.call_id,
        status=status.value,
    )
    return error_packet.SerializeToString()
88
89
def encode_cancel(rpc: PendingRpc | RpcIds) -> bytes:
    """Encodes a CLIENT_ERROR packet with CANCELLED status for a call.

    Args:
        rpc: Supplies the channel, service, method, and call ids.

    Returns:
        The serialized RpcPacket bytes (no payload).
    """
    # Keyword order does not affect the serialized bytes; fields are listed
    # in the same order as the sibling encoders for readability.
    cancel_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_ERROR,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        status=Status.CANCELLED.value,
    )
    return cancel_packet.SerializeToString()
99
100
def encode_client_stream_end(rpc: PendingRpc | RpcIds) -> bytes:
    """Encodes a CLIENT_REQUEST_COMPLETION packet for a call.

    Only the packet type and the four ids are set; there is no payload or
    status. The function name suggests this marks the end of the client's
    stream (not verified here).

    Args:
        rpc: Supplies the channel, service, method, and call ids.

    Returns:
        The serialized RpcPacket bytes.
    """
    completion = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    )
    return completion.SerializeToString()
109
110
def encode_server_stream(
    rpc: PendingRpc | RpcIds, payload: message.Message
) -> bytes:
    """Encodes a SERVER_STREAM packet carrying one streamed message.

    Accepts the same identifier types as the other encoders in this module.
    Only the ``channel_id``, ``service_id``, ``method_id``, and ``call_id``
    attributes of ``rpc`` are read.

    Args:
        rpc: Supplies the channel, service, method, and call ids.
        payload: Message serialized into the packet's payload.

    Returns:
        The serialized RpcPacket bytes.
    """
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.SERVER_STREAM,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        payload=payload.SerializeToString(),
    ).SerializeToString()
120
121
def encode_server_error(rpc: PendingRpc | RpcIds, status: Status) -> bytes:
    """Encodes a SERVER_ERROR packet with a non-OK status for a call.

    Accepts the same identifier types as the other encoders in this module.
    Only the four id attributes of ``rpc`` are read. No payload is set.

    Args:
        rpc: Supplies the channel, service, method, and call ids.
        status: Error status written as its numeric ``value``; must not
            be OK.

    Returns:
        The serialized RpcPacket bytes.

    Raises:
        AssertionError: If ``status.ok()`` is true and assertions are
            enabled.
    """
    # Programmer-error check only: Python skips assert statements under -O,
    # so this is not a runtime guarantee.
    assert (
        not status.ok()
    ), f'SERVER_ERROR packets require a non-OK status; got {status}'
    return packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.SERVER_ERROR,
        status=status.value,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    ).SerializeToString()
132
133
def for_server(packet: packet_pb2.RpcPacket) -> bool:
    """Returns True when ``packet.type`` has an even numeric value.

    NOTE(review): this presumably relies on the PacketType enum giving even
    numbers to packets sent from client to server. Confirm against the
    packet .proto definition.
    """
    # The low bit is clear exactly when the integer is even.
    return not packet.type & 1
136