xref: /aosp_15_r20/external/pigweed/pw_rpc/py/pw_rpc/codegen_raw.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""This module generates the code for raw pw_rpc services."""
15
16import os
17from typing import Iterable
18
19from pw_protobuf.output_file import OutputFile
20from pw_protobuf.proto_tree import ProtoServiceMethod
21from pw_protobuf.proto_tree import build_node_tree
22from pw_rpc import codegen
23from pw_rpc.codegen import (
24    client_call_type,
25    get_id,
26    CodeGenerator,
27    RPC_NAMESPACE,
28)
29
30PROTO_H_EXTENSION = '.pb.h'
31
32
def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Maps a .proto file name to its generated raw RPC C++ header name.

    The last extension of proto_file is dropped and '.raw_rpc' followed by
    PROTO_H_EXTENSION is appended in its place.
    """
    stem, _ = os.path.splitext(proto_file)
    return f'{stem}.raw_rpc{PROTO_H_EXTENSION}'
37
38
def _proto_filename_to_stub_header(proto_file: str) -> str:
    """Maps a .proto file name to its raw RPC stub C++ header name.

    The last extension of proto_file is dropped and '.raw_rpc.stub' followed
    by PROTO_H_EXTENSION is appended in its place.
    """
    stem, _ = os.path.splitext(proto_file)
    return f'{stem}.raw_rpc.stub{PROTO_H_EXTENSION}'
43
44
def _function(method: ProtoServiceMethod) -> str:
    """Returns '<raw client call type> <method name>' for an RPC method.

    The result is used as the leading return type and name of the generated
    C++ client functions.
    """
    call_type = client_call_type(method, 'Raw')
    return f'{call_type} {method.name()}'
47
48
def _user_args(method: ProtoServiceMethod) -> Iterable[str]:
    """Yields the C++ parameter declarations a caller supplies for an RPC.

    The sequence is, in order:
      - a request span, unless the method is client streaming;
      - the response callbacks: on_next plus a status-only on_completed for
        server streaming methods, otherwise a single on_completed that takes
        both the payload and the status;
      - the on_error callback, always.
    """
    if not method.client_streaming():
        yield '::pw::ConstByteSpan request'

    if not method.server_streaming():
        # Single response: payload and status arrive in one callback.
        yield (
            '::pw::Function<void(::pw::ConstByteSpan, ::pw::Status)>&& '
            'on_completed = nullptr'
        )
    else:
        # Response stream: one callback per payload, one for the status.
        yield from (
            '::pw::Function<void(::pw::ConstByteSpan)>&& on_next = nullptr',
            '::pw::Function<void(::pw::Status)>&& on_completed = nullptr',
        )

    yield '::pw::Function<void(::pw::Status)>&& on_error = nullptr'
63
64
class RawCodeGenerator(CodeGenerator):
    """Code generator for pw_rpc services and clients using raw buffers.

    The methods below produce C++ source text and hand it to the line(),
    indent() and indented_list() helpers accessed on self.
    """

    def name(self) -> str:
        """Returns the identifier of this generator, 'raw'."""
        return 'raw'

    def method_union_name(self) -> str:
        """Returns 'RawMethodUnion', the method union name for raw RPC."""
        return 'RawMethodUnion'

    def includes(self, unused_proto_file_name: str) -> Iterable[str]:
        """Yields the #include lines used by the generated raw RPC code.

        The proto file name is ignored; the same three headers are always
        produced.
        """
        yield from (
            '#include "pw_rpc/raw/client_reader_writer.h"',
            '#include "pw_rpc/raw/internal/method_union.h"',
            '#include "pw_rpc/raw/server_reader_writer.h"',
        )

    def service_aliases(self) -> None:
        """Emits using-declarations for the raw server reader/writer types."""
        for alias in (
            'RawServerWriter',
            'RawServerReader',
            'RawServerReaderWriter',
        ):
            self.line(f'using {alias} = {RPC_NAMESPACE}::{alias};')

    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Emits the GetRawMethodFor<...>(id) expression for one method.

        The expression spans two lines; the second carries the method ID and
        a trailing comment naming the method the ID was derived from.
        """
        name = method.name()
        type_enum = method.type().cc_enum()

        self.line(
            f'{RPC_NAMESPACE}::internal::GetRawMethodFor<'
            f'&Implementation::{name}, {type_enum}>('
        )
        self.line(f'    {get_id(method)}),  // Hash of "{name}"')

    def client_member_function(
        self, method: ProtoServiceMethod, *, dynamic: bool
    ) -> None:
        """Emits the const client member function that starts a raw call.

        When dynamic is true only a comment is emitted, since the dynamic
        client is not implemented for raw RPC.
        """
        if dynamic:
            self.line('// DynamicClient is not implemented for raw RPC')
            return

        # Signature: return type, name, and the user-facing parameters.
        self.line(f'{_function(method)}(')
        self.indented_list(*_user_args(method), end=') const {')

        with self.indent():
            streams_responses = method.server_streaming()
            call_base = 'Stream' if streams_responses else 'Unary'
            call_type = client_call_type(method, 'Raw')
            self.line(
                f'return {RPC_NAMESPACE}::internal::'
                f'{call_base}ResponseClientCall::Start<{call_type}>('
            )

            service_client = f'{RPC_NAMESPACE}::internal::ServiceClient'
            call_args = [
                f'{service_client}::client()',
                f'{service_client}::channel_id()',
                'kServiceId',
                get_id(method),
            ]
            # on_next only exists for methods with a response stream.
            if streams_responses:
                call_args.append('std::move(on_next)')
            call_args += ['std::move(on_completed)', 'std::move(on_error)']
            # Client streaming methods take no request parameter, so an empty
            # initializer is passed in its place.
            call_args.append('{}' if method.client_streaming() else 'request')

            self.indented_list(*call_args, end=');')

        self.line('}')

    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Emits the static client function that forwards to a Client.

        The generated function takes a client and a channel ID ahead of the
        user-facing parameters, constructs a Client from them, and calls the
        member function of the same name with the remaining arguments.
        """
        params = [
            f'{RPC_NAMESPACE}::Client& client',
            'uint32_t channel_id',
            *_user_args(method),
        ]
        self.line(f'static {_function(method)}(')
        self.indented_list(*params, end=') {')

        with self.indent():
            self.line(f'return Client(client, channel_id).{method.name()}(')

            # Forward exactly the parameters the member function declares.
            forwarded = [] if method.client_streaming() else ['request']
            if method.server_streaming():
                forwarded.append('std::move(on_next)')
            forwarded += ['std::move(on_completed)', 'std::move(on_error)']

            self.indented_list(*forwarded, end=');')

        self.line('}')

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Emits void Request/Response aliases that mark a raw method."""
        self.line()
        # Request and Response are aliased to void so that raw methods stand
        # out as a special case. Raw RPCs work with ConstByteSpans: copying
        # the span does not copy the bytes it refers to, which would lead to
        # dangling pointers without special treatment.
        #
        # Helpers/traits that use Request/Response and need to support raw
        # must provide a dedicated implementation that copies the actual
        # data.
        for alias in ('Request', 'Response'):
            self.line(f'using {alias} = void;')
172
173
class StubGenerator(codegen.StubGenerator):
    """Produces C++ signatures and placeholder bodies for raw RPC stubs.

    Every signature method returns a string of the form
    'void <prefix><method name>(<parameters>)'.
    """

    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the signature of a unary method stub."""
        name = method.name()
        return (
            f'void {prefix}{name}(pw::ConstByteSpan request, '
            'pw::rpc::RawUnaryResponder& responder)'
        )

    def unary_stub(
        self, method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the placeholder body of a unary method stub to output.

        Each TODO line taken from codegen is followed by a cast of the
        matching parameter to void. The method argument is not used.
        """
        stub_lines = (
            codegen.STUB_REQUEST_TODO,
            'static_cast<void>(request);',
            codegen.STUB_RESPONSE_TODO,
            'static_cast<void>(responder);',
        )
        for stub_line in stub_lines:
            output.write_line(stub_line)

    def server_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the signature of a server streaming method stub."""
        name = method.name()
        return (
            f'void {prefix}{name}('
            'pw::ConstByteSpan request, RawServerWriter& writer)'
        )

    def client_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the signature of a client streaming method stub."""
        name = method.name()
        return f'void {prefix}{name}(RawServerReader& reader)'

    def bidirectional_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the signature of a bidirectional streaming method stub."""
        name = method.name()
        return f'void {prefix}{name}(RawServerReaderWriter& reader_writer)'
211
212
def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Generates the raw RPC output for a single .proto file.

    Args:
      proto_file: Description of the .proto file; its name attribute is read
        to derive the output header name. (Presumably a protobuf file
        descriptor -- confirm against the caller.)

    Returns:
      A one-element list holding the output of the RawCodeGenerator after
      the package code and the stubs have been generated into it.
    """
    _, package_root = build_node_tree(proto_file)

    generator = RawCodeGenerator(
        _proto_filename_to_generated_header(proto_file.name)
    )

    # The stubs are generated through the same generator as the package.
    codegen.generate_package(proto_file, package_root, generator)
    codegen.package_stubs(package_root, generator, StubGenerator())

    return [generator.output]
225