xref: /aosp_15_r20/external/pigweed/pw_rpc/py/pw_rpc/codegen_pwpb.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2022 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""This module generates the code for pw_protobuf pw_rpc services."""
15
16import os
17from typing import Iterable
18
19from pw_protobuf.output_file import OutputFile
20from pw_protobuf.proto_tree import ProtoServiceMethod
21from pw_protobuf.proto_tree import build_node_tree
22from pw_rpc import codegen
23from pw_rpc.codegen import (
24    client_call_type,
25    get_id,
26    CodeGenerator,
27    RPC_NAMESPACE,
28)
29
# Extension of the pw_protobuf header in which a .proto file's messages and
# enums are generated.
PWPB_H_EXTENSION = '.pwpb.h'
# Extension appended after '.rpc' to name the generated RPC header. The pwpb
# RPC header deliberately shares the pwpb message-header extension.
PROTO_H_EXTENSION = PWPB_H_EXTENSION
32
33
def _proto_filename_to_pwpb_header(proto_file: str) -> str:
    """Maps a .proto path to the pw_protobuf header generated for it.

    The last extension of ``proto_file`` is replaced by PWPB_H_EXTENSION.
    """
    stem, _extension = os.path.splitext(proto_file)
    return stem + PWPB_H_EXTENSION
38
39
def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Maps a .proto path to the C++ RPC header this generator produces.

    The last extension of ``proto_file`` is replaced by ``.rpc`` followed by
    PROTO_H_EXTENSION.
    """
    stem = os.path.splitext(proto_file)[0]
    return ''.join((stem, '.rpc', PROTO_H_EXTENSION))
44
45
def _serde(method: ProtoServiceMethod) -> str:
    """Returns the C++ expression naming the PwpbMethodSerde for a method.

    The serde template is instantiated with the addresses of the request and
    response types' pwpb tables.
    """
    request_table = method.request_type().pwpb_table()
    response_table = method.response_type().pwpb_table()
    return (
        f'{RPC_NAMESPACE}::internal::kPwpbMethodSerde<'
        f'&{request_table}, &{response_table}>'
    )
53
54
def _client_call(
    method: ProtoServiceMethod, response: str | None = None
) -> str:
    """Spells the C++ client call type returned by a generated client call.

    Args:
        method: The RPC method being generated.
        response: Response struct type to use; when None, the method's pwpb
            response struct is used.

    Returns:
        ``<call type><Response>``, or ``<call type><Request, Response>`` for
        client-streaming methods.
    """
    # Client-streaming calls are also parameterized on the request struct.
    leading = (
        [method.request_type().pwpb_struct()]
        if method.client_streaming()
        else []
    )

    if response is None:
        response = method.response_type().pwpb_struct()

    type_args = ', '.join([*leading, response])
    return f'{client_call_type(method, "Pwpb")}<{type_args}>'
69
70
71def _function(
72    method: ProtoServiceMethod,
73    name: str | None = None,
74) -> str:
75    return f'auto {name or method.name()}'
76
77
78def _user_args(
79    method: ProtoServiceMethod, response: str | None = None
80) -> Iterable[str]:
81    if not method.client_streaming():
82        yield f'const {method.request_type().pwpb_struct()}& request'
83
84    if response is None:
85        response = method.response_type().pwpb_struct()
86
87    if method.server_streaming():
88        yield f'::pw::Function<void(const {response}&)>&& on_next = nullptr'
89        yield '::pw::Function<void(::pw::Status)>&& on_completed = nullptr'
90    else:
91        yield (
92            f'::pw::Function<void(const {response}&, ::pw::Status)>&& '
93            'on_completed = nullptr'
94        )
95
96    yield '::pw::Function<void(::pw::Status)>&& on_error = nullptr'
97
98
class PwpbCodeGenerator(CodeGenerator):
    """Generates an RPC service and client using the pw_protobuf API."""

    def name(self) -> str:
        """Returns the name of this generator."""
        return 'pwpb'

    def method_union_name(self) -> str:
        """Returns the C++ method union type used by pwpb services."""
        return 'PwpbMethodUnion'

    def includes(self, proto_file_name: str) -> Iterable[str]:
        """Yields the #include directives for the generated RPC header.

        Args:
            proto_file_name: Path of the .proto file being processed.
        """
        yield '#include "pw_rpc/pwpb/client_reader_writer.h"'
        yield '#include "pw_rpc/pwpb/internal/method_union.h"'
        yield '#include "pw_rpc/pwpb/server_reader_writer.h"'

        # Include the corresponding pwpb header file for this proto file, in
        # which the file's messages and enums are generated. All other files
        # imported from the .proto file are #included in there.
        pwpb_header = _proto_filename_to_pwpb_header(proto_file_name)
        yield f'#include "{pwpb_header}"'

    def service_aliases(self) -> None:
        """Emits ServerWriter/ServerReader/ServerReaderWriter aliases.

        Each alias forwards to the corresponding Pwpb* class in the RPC
        namespace.
        """
        self.line('template <typename Response>')
        self.line(
            'using ServerWriter = '
            f'{RPC_NAMESPACE}::PwpbServerWriter<Response>;'
        )
        self.line('template <typename Request, typename Response>')
        self.line(
            'using ServerReader = '
            f'{RPC_NAMESPACE}::PwpbServerReader<Request, Response>;'
        )
        self.line('template <typename Request, typename Response>')
        self.line(
            'using ServerReaderWriter = '
            f'{RPC_NAMESPACE}::PwpbServerReaderWriter<Request, Response>;'
        )

    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Emits the method descriptor entry for ``method``.

        The entry is a ``GetPwpbOrRawMethodFor<...>(id, serde)`` call
        followed by a trailing comma.
        """
        impl_method = f'&Implementation::{method.name()}'

        self.line(
            f'{RPC_NAMESPACE}::internal::GetPwpbOrRawMethodFor<{impl_method}, '
            f'{method.type().cc_enum()}, '
            f'{method.request_type().pwpb_struct()}, '
            f'{method.response_type().pwpb_struct()}>('
        )
        with self.indent(4):
            self.line(f'{get_id(method)},  // Hash of "{method.name()}"')
            self.line(f'{_serde(method)}),')

    def _default_response_template(self, method: ProtoServiceMethod) -> None:
        """Emits ``template <typename Response = <method response struct>>``.

        This line precedes the ``<Method>Template`` overloads, which allow
        specifying a custom response struct. The previous inline version
        emitted no space after ``=``.
        """
        self.line(
            'template <typename Response = '
            f'{method.response_type().pwpb_struct()}>'
        )

    def _client_member_function(
        self,
        method: ProtoServiceMethod,
        *,
        response: str | None = None,
        name: str | None = None,
        dynamic: bool,
    ) -> None:
        """Emits one const client member function that starts a call.

        Args:
            method: The RPC method to generate a client function for.
            response: Response struct type; when None, the method's pwpb
                response struct is used.
            name: C++ function name; when None, the method's name is used.
            dynamic: Whether to start the call with ``StartDynamic`` instead
                of ``Start``.
        """
        if response is None:
            response = method.response_type().pwpb_struct()

        if name is None:
            name = method.name()

        self.line(f'{_function(method, name)}(')
        self.indented_list(*_user_args(method, response), end=') const {')

        with self.indent():
            client_call = _client_call(method, response)
            base = 'Stream' if method.server_streaming() else 'Unary'
            self.line(
                f'return {RPC_NAMESPACE}::internal::'
                f'Pwpb{base}ResponseClientCall<{response}>::'
                f'template Start{"Dynamic" if dynamic else ""}'
                f'<{client_call}>('
            )

            service_client = RPC_NAMESPACE + '::internal::ServiceClient'

            args = [
                f'{service_client}::client()',
                f'{service_client}::channel_id()',
                'kServiceId',
                get_id(method),
                _serde(method),
            ]
            if method.server_streaming():
                args.append('std::move(on_next)')

            args.append('std::move(on_completed)')
            args.append('std::move(on_error)')

            # Only non-client-streaming calls receive the request up front.
            if not method.client_streaming():
                args.append('request')

            self.indented_list(*args, end=');')

        self.line('}')

    def client_member_function(
        self, method: ProtoServiceMethod, *, dynamic: bool
    ) -> None:
        """Outputs client code for a single RPC method.

        For non-dynamic clients, this also emits a ``<Method>Template``
        overload whose Response type defaults to the method's pwpb response
        struct.
        """
        self._client_member_function(method, dynamic=dynamic)

        if dynamic:  # Skip custom response overload
            return

        # Generate functions that allow specifying a custom response struct.
        self._default_response_template(method)
        self._client_member_function(
            method,
            response='Response',
            name=method.name() + 'Template',
            dynamic=dynamic,
        )

    def _client_static_function(
        self,
        method: ProtoServiceMethod,
        response: str | None = None,
        name: str | None = None,
    ) -> None:
        """Emits a static function that forwards to the Client member.

        The static function takes a client and a channel ID, constructs
        ``Client(client, channel_id)``, and calls its member function called
        ``name``.

        Args:
            method: The RPC method to generate a function for.
            response: Response struct type; when None, the method's pwpb
                response struct is used.
            name: C++ function name; when None, the method's name is used.
        """
        if response is None:
            response = method.response_type().pwpb_struct()

        if name is None:
            name = method.name()

        self.line(f'static {_function(method, name)}(')
        self.indented_list(
            f'{RPC_NAMESPACE}::Client& client',
            'uint32_t channel_id',
            *_user_args(method, response),
            end=') {',
        )

        with self.indent():
            self.line(f'return Client(client, channel_id).{name}(')

            args = []

            if not method.client_streaming():
                args.append('request')

            if method.server_streaming():
                args.append('std::move(on_next)')

            self.indented_list(
                *args,
                'std::move(on_completed)',
                'std::move(on_error)',
                end=');',
            )

        self.line('}')

    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Outputs the static client function(s) for a single RPC method.

        A ``<Method>Template`` overload with a defaulted Response type is
        emitted after the plain function.
        """
        self._client_static_function(method)

        self._default_response_template(method)
        self._client_static_function(
            method, 'Response', method.name() + 'Template'
        )

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Emits the Request/Response aliases and serde() of a MethodInfo."""
        self.line()
        self.line(f'using Request = {method.request_type().pwpb_struct()};')
        self.line(f'using Response = {method.response_type().pwpb_struct()};')
        self.line()
        self.line(
            f'static constexpr const {RPC_NAMESPACE}::'
            'PwpbMethodSerde& serde() {'
        )
        with self.indent():
            self.line(f'return {_serde(method)};')
        self.line('}')
282
283
class StubGenerator(codegen.StubGenerator):
    """Generates pw_protobuf RPC stubs.

    The ``*_signature`` methods return the C++ declaration of a service
    method for each RPC kind. They use the method's pwpb request/response
    structs and the ServerWriter/ServerReader/ServerReaderWriter aliases.
    """

    @staticmethod
    def _pwpb_structs(method: ProtoServiceMethod) -> tuple[str, str]:
        """Returns the (request, response) pwpb struct names of a method."""
        return (
            method.request_type().pwpb_struct(),
            method.response_type().pwpb_struct(),
        )

    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Declaration of a unary method returning ``::pw::Status``."""
        request, response = self._pwpb_structs(method)
        params = f'const {request}& request, {response}& response'
        return f'::pw::Status {prefix}{method.name()}( {params})'

    def unary_stub(
        self, method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes a placeholder unary body.

        The body marks the request and response as unused and returns
        Unimplemented.
        """
        stub_lines = (
            codegen.STUB_REQUEST_TODO,
            'static_cast<void>(request);',
            codegen.STUB_RESPONSE_TODO,
            'static_cast<void>(response);',
            'return ::pw::Status::Unimplemented();',
        )
        for stub_line in stub_lines:
            output.write_line(stub_line)

    def server_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Declaration of a method taking a request and a ServerWriter."""
        request, response = self._pwpb_structs(method)
        params = f'const {request}& request, ServerWriter<{response}>& writer'
        return f'void {prefix}{method.name()}( {params})'

    def client_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Declaration of a method taking a ServerReader."""
        request, response = self._pwpb_structs(method)
        params = f'ServerReader<{request}, {response}>& reader'
        return f'void {prefix}{method.name()}( {params})'

    def bidirectional_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Declaration of a method taking a ServerReaderWriter."""
        request, response = self._pwpb_structs(method)
        params = (
            f'ServerReaderWriter<{request}, {response}>& reader_writer'
        )
        return f'void {prefix}{method.name()}( {params})'
329
330
def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Generates code for a single .proto file.

    Builds the proto node tree and runs package and stub generation into one
    PwpbCodeGenerator. That generator's output path is derived from
    ``proto_file.name``.

    Returns:
        A one-element list containing the generator's output file.
    """
    _root, package_root = build_node_tree(proto_file)
    generator = PwpbCodeGenerator(
        _proto_filename_to_generated_header(proto_file.name)
    )

    codegen.generate_package(proto_file, package_root, generator)
    codegen.package_stubs(package_root, generator, StubGenerator())

    return [generator.output]
342    return [generator.output]
343