xref: /aosp_15_r20/external/pigweed/pw_rpc/py/pw_rpc/codegen.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Common RPC codegen utilities."""
15
16import abc
17from datetime import datetime
18import os
19from typing import cast, Any, Iterable
20
21from pw_protobuf.output_file import OutputFile
22from pw_protobuf.proto_tree import ProtoNode, ProtoService, ProtoServiceMethod
23from pw_rpc import ids
24
# Name and version of this code generator. generate_package() embeds both in
# the banner comment at the top of the generated header.
PLUGIN_NAME = 'pw_rpc_codegen'
PLUGIN_VERSION = '0.3.0'

# Fully qualified C++ namespace used to refer to pw_rpc symbols from generated
# code.
RPC_NAMESPACE = '::pw::rpc'

# Placeholder C++ comments for generated stub method bodies. The default
# StubGenerator stub methods write these verbatim into the output.
# NOTE(review): the todo-check markers presumably exempt these literal TODO
# strings from a TODO-format lint check -- confirm against the lint config.
# todo-check: disable
STUB_REQUEST_TODO = (
    '// TODO: Read the request as appropriate for your application'
)
STUB_RESPONSE_TODO = (
    '// TODO: Fill in the response as appropriate for your application'
)
STUB_WRITER_TODO = (
    '// TODO: Send responses with the writer as appropriate for your '
    'application'
)
STUB_READER_TODO = (
    '// TODO: Set the client stream callback and send a response as '
    'appropriate for your application'
)
STUB_READER_WRITER_TODO = (
    '// TODO: Set the client stream callback and send responses as '
    'appropriate for your application'
)
# todo-check: enable
50
51
def get_id(item: ProtoService | ProtoServiceMethod) -> str:
    """Returns the pw_rpc ID of a service or method as a hex literal string.

    A service is identified by the hash of its full proto path; a method is
    identified by the hash of its bare name. The result is formatted as ``0x``
    followed by the ID zero-padded to at least eight lowercase hex digits.
    """
    if isinstance(item, ProtoService):
        hashed_name = item.proto_path()
    else:
        hashed_name = item.name()

    return f'0x{ids.calculate(hashed_name):08x}'
55
56
def client_call_type(method: ProtoServiceMethod, prefix: str) -> str:
    """Returns the qualified name of the client call class for a method.

    The class is UnaryReceiver, ClientReader, ClientWriter, or
    ClientReaderWriter, chosen by the method type. ``prefix`` is inserted
    directly before the class name, and the result is qualified with
    RPC_NAMESPACE.

    Raises:
        NotImplementedError: The method type is not one of the four handled
            types.
    """
    method_types = ProtoServiceMethod.Type
    call_classes = (
        (method_types.UNARY, 'UnaryReceiver'),
        (method_types.SERVER_STREAMING, 'ClientReader'),
        (method_types.CLIENT_STREAMING, 'ClientWriter'),
        (method_types.BIDIRECTIONAL_STREAMING, 'ClientReaderWriter'),
    )

    # Compare by identity, checking the candidates in a fixed order.
    for method_type, call_class in call_classes:
        if method.type() is method_type:
            return f'{RPC_NAMESPACE}::{prefix}{call_class}'

    raise NotImplementedError(f'Unknown {method.type()}')
71
72
class CodeGenerator(abc.ABC):
    """Generates RPC code for services and clients.

    This base class owns the OutputFile that generated code is written to and
    provides small helpers for writing to it. Subclasses supply the
    implementation-specific pieces through the abstract methods; the two
    non-abstract hooks at the end are optional and do nothing by default.
    """

    def __init__(self, output_filename: str) -> None:
        """Creates the OutputFile that all generated lines are written to."""
        self.output = OutputFile(output_filename)

    def indent(self, amount: int = OutputFile.INDENT_WIDTH) -> Any:
        """Indents the output. Use in a with block."""
        return self.output.indent(amount)

    def line(self, value: str = '') -> None:
        """Writes a line to the output; writes an empty line by default."""
        self.output.write_line(value)

    def indented_list(self, *args: str, end: str = ',') -> None:
        """Outputs each arg on its own line inside an indent of 4.

        Every arg except the last is followed by a comma; the last arg is
        followed by ``end``. Nothing is written when no args are provided.
        """
        # With no args there is no last element to attach `end` to; indexing
        # args[-1] would raise IndexError, so emit nothing instead.
        if not args:
            return

        *leading, last = args

        with self.indent(4):
            for arg in leading:
                self.line(arg + ',')

            self.line(last + end)

    @abc.abstractmethod
    def name(self) -> str:
        """Name of the pw_rpc implementation."""

    @abc.abstractmethod
    def method_union_name(self) -> str:
        """Name of the MethodUnion class to use."""

    @abc.abstractmethod
    def includes(self, proto_file_name: str) -> Iterable[str]:
        """Yields #include lines."""

    @abc.abstractmethod
    def service_aliases(self) -> None:
        """Generates reader/writer aliases."""

    @abc.abstractmethod
    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Generates code for a service method."""

    @abc.abstractmethod
    def client_member_function(
        self, method: ProtoServiceMethod, *, dynamic: bool
    ) -> None:
        """Generates the client code for the Client member functions."""

    @abc.abstractmethod
    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Generates method static functions that instantiate a Client."""

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Generates impl-specific additions to the MethodInfo specialization.

        May be empty if the generator has nothing to add to the MethodInfo.
        """

    def private_additions(self, service: ProtoService) -> None:
        """Additions to the private section of the outer generated class."""
133
134
def generate_package(
    file_descriptor_proto, proto_package: ProtoNode, gen: CodeGenerator
) -> None:
    """Writes the service and client code for every service in a package.

    The output consists of a banner comment, the #include lines (the fixed
    pw_rpc headers plus those supplied by ``gen.includes``, sorted together),
    the wrapper class of each service inside the package and pw_rpc
    namespaces, and finally the MethodInfo specializations.

    Args:
        file_descriptor_proto: Descriptor of the .proto file; only its
            ``name`` is read, and it is forwarded to ``gen.includes``.
        proto_package: Package node whose SERVICE children are generated.
        gen: Implementation-specific generator that receives the output.
    """
    assert proto_package.type() == ProtoNode.Type.PACKAGE

    # Banner, include guard and standard library headers.
    output_basename = os.path.basename(gen.output.name())
    preamble = (
        f'// {output_basename} automatically '
        f'generated by {PLUGIN_NAME} {PLUGIN_VERSION}',
        f'// on {datetime.now().isoformat()}',
        '// clang-format off',
        '#pragma once\n',
        '#include <array>',
        '#include <cstdint>',
        '#include <type_traits>\n',
    )
    for preamble_line in preamble:
        gen.line(preamble_line)

    # pw_rpc headers and implementation-specific headers are sorted together.
    pw_rpc_includes = [
        '#include "pw_rpc/internal/config.h"',
        '#include "pw_rpc/internal/method_info.h"',
        '#include "pw_rpc/internal/method_lookup.h"',
        '#include "pw_rpc/internal/service_client.h"',
        '#include "pw_rpc/method_type.h"',
        '#include "pw_rpc/service.h"',
        '#include "pw_rpc/service_id.h"',
    ]
    all_includes = [
        *pw_rpc_includes,
        *gen.includes(file_descriptor_proto.name),
    ]
    for include_line in sorted(all_includes):
        gen.line(include_line)

    gen.line()

    # Open the package namespace, if there is one, without a leading "::".
    package_namespace = proto_package.cpp_namespace(codegen_subnamespace=None)
    if package_namespace:
        file_namespace = package_namespace.removeprefix('::')
        gen.line(f'namespace {file_namespace} {{')
    else:
        file_namespace = ''

    impl_namespace = f'pw_rpc::{gen.name()}'
    gen.line(f'namespace {impl_namespace} {{')
    gen.line()

    services = [
        cast(ProtoService, child)
        for child in proto_package
        if child.type() == ProtoNode.Type.SERVICE
    ]

    for service in services:
        _generate_service_and_client(gen, service)

    gen.line()
    gen.line(f'}}  // namespace {impl_namespace}\n')

    if file_namespace:
        gen.line(f'}}  // namespace {file_namespace}')

    # The MethodInfo specializations are written after all namespaces close.
    gen.line()
    gen.line(
        '// Specialize MethodInfo for each RPC to provide metadata at '
        'compile time.'
    )
    for service in services:
        _generate_info(gen, file_namespace, service)
203
204
def _generate_service_and_client(
    gen: CodeGenerator, service: ProtoService
) -> None:
    """Generates the wrapper class holding a service's server and client code.

    The class is named after the service and cannot be instantiated. Its public
    section contains service_id(), the Service base class, the Client, the
    DynamicClient (pwpb only), and the static client functions. Its private
    section holds the kServiceId constant.
    """
    wrapper_name = service.name()

    gen.line(
        '// Wrapper class that namespaces server and client code '
        'for this RPC service.'
    )
    gen.line(f'class {wrapper_name} final {{')
    gen.line(' public:')

    with gen.indent():
        # The wrapper only provides scoping, so its constructor is deleted.
        gen.line(f'{wrapper_name}() = delete;')
        gen.line()

        gen.line('static constexpr ::pw::rpc::ServiceId service_id() {')
        with gen.indent():
            gen.line('return ::pw::rpc::internal::WrapServiceId(kServiceId);')
        gen.line('}')
        gen.line()

        _generate_service(gen, service)
        gen.line()
        _generate_client(gen, service)

        # DynamicClient is only generated for pwpb for now.
        if gen.name() == 'pwpb':
            gen.line('#if PW_RPC_DYNAMIC_ALLOCATION')
            _generate_client(gen, service, dynamic=True)
            gen.line('#endif  // PW_RPC_DYNAMIC_ALLOCATION')

        _generate_client_free_functions(gen, service)

    gen.line(' private:')

    with gen.indent():
        service_path = service.proto_path()
        gen.line(f'// Hash of "{service_path}".')
        gen.line(f'static constexpr uint32_t kServiceId = {get_id(service)};')

    gen.line('};')
246
247
def _check_method_name(method: ProtoServiceMethod) -> None:
    """Rejects method names that cannot be used in the generated code.

    Raises:
        ValueError: The method has the same name as its enclosing service, or
            its name is Service, ServiceInfo, or Client.
    """
    method_name = method.name()
    service = method.service()

    # Methods with the same name as their enclosing service will fail
    # to compile because the generated method will be indistinguishable
    # from a constructor.
    if method_name == service.name():
        raise ValueError(
            'Attempted to compile `pw_rpc` for proto with method '
            f'`{method_name}` inside a service of the same name. '
            '`pw_rpc` does not yet support methods with the same name as their '
            'enclosing service.'
        )

    reserved_names = ('Service', 'ServiceInfo', 'Client')
    if method_name in reserved_names:
        raise ValueError(
            f'"{service.proto_path()}.{method_name}" is not a '
            f'valid method name! The name "{method_name}" is reserved '
            'for internal use by pw_rpc.'
        )
265
266
def _generate_client(
    gen: CodeGenerator, service: ProtoService, *, dynamic: bool = False
) -> None:
    """Generates the Client (or DynamicClient) class for a service.

    The class derives from the pw_rpc internal ServiceClient, has a constexpr
    constructor taking a client and channel ID, aliases ServiceInfo to the
    wrapper class, and gets one member function per method from
    ``gen.client_member_function``.

    Args:
        gen: Generator that receives the output.
        service: Service whose methods are exposed by the client.
        dynamic: If true, the class is named DynamicClient and ``dynamic=True``
            is forwarded to ``gen.client_member_function``.
    """
    if dynamic:
        class_name = 'DynamicClient'
    else:
        class_name = 'Client'

    gen.line('// The Client is used to invoke RPCs for this service.')
    gen.line(
        f'class {class_name} final : '
        f'public {RPC_NAMESPACE}::internal::ServiceClient {{'
    )
    gen.line(' public:')

    with gen.indent():
        gen.line(
            f'constexpr {class_name}('
            f'{RPC_NAMESPACE}::Client& client, uint32_t channel_id)'
        )
        gen.line('    : ServiceClient(client, channel_id) {}')
        gen.line()
        gen.line(f'using ServiceInfo = {service.name()};')

        # Each member function is preceded by a blank line.
        for rpc_method in service.methods():
            gen.line()
            gen.client_member_function(rpc_method, dynamic=dynamic)

    gen.line('};')
    gen.line()
294
295
def _generate_client_free_functions(
    gen: CodeGenerator, service: ProtoService
) -> None:
    """Generates the static client function for each method of a service.

    Each method name is validated with _check_method_name before its function
    is generated, so an invalid name raises ValueError here.
    """
    header_comment = (
        '// Static functions for invoking RPCs on a pw_rpc server. '
        'These functions are ',
        '// equivalent to instantiating a Client and calling the '
        'corresponding RPC.',
    )
    for comment_line in header_comment:
        gen.line(comment_line)

    for rpc_method in service.methods():
        _check_method_name(rpc_method)
        gen.client_static_function(rpc_method)
        gen.line()
311
312
def _generate_info(
    gen: CodeGenerator, namespace: str, service: ProtoService
) -> None:
    """Generates a MethodInfo specialization for each method of a service.

    Each specialization provides the service ID, method ID, and method type,
    a Function() accessor for the implementation's member function, a
    FunctionTemplate() accessor (pwpb and nanopb only), the GeneratedClient and
    ServiceClass aliases, and whatever ``gen.method_info_specialization`` adds.

    Args:
        gen: Generator that receives the output.
        namespace: C++ namespace of the package without a leading "::"; empty
            if the package has no namespace.
        service: Service whose methods are described.
    """
    method_info = f'struct {RPC_NAMESPACE.lstrip(":")}::internal::MethodInfo'
    service_id = get_id(service)

    for method in service.methods():
        method_name = method.name()

        gen.line('template <>')
        gen.line(
            f'{method_info}<{namespace}::pw_rpc::{gen.name()}::'
            f'{service.name()}::{method_name}> {{'
        )

        with gen.indent():
            gen.line(f'static constexpr uint32_t kServiceId = {service_id};')
            gen.line(f'static constexpr uint32_t kMethodId = {get_id(method)};')
            gen.line(
                f'static constexpr {RPC_NAMESPACE}::MethodType kType = '
                f'{method.type().cc_enum()};'
            )
            gen.line()

            gen.line('template <typename ServiceImpl>')
            gen.line('static constexpr auto Function() {')
            with gen.indent():
                gen.line(f'return &ServiceImpl::{method_name};')
            gen.line('}')

            if gen.name() in ('pwpb', 'nanopb'):
                gen.line('template <typename ServiceImpl, typename Response>')
                gen.line('static constexpr auto FunctionTemplate() {')
                with gen.indent():
                    gen.line(
                        'return &ServiceImpl::template '
                        f'{method_name}Template<Response>;'
                    )
                gen.line('}')

            # Absolute path of the wrapper class; the package namespace is
            # only included when there is one.
            package_scope = '::' + namespace if namespace else ''
            service_class = (
                f'{package_scope}::pw_rpc::{gen.name()}::{service.name()}'
            )
            gen.line(f'using GeneratedClient = {service_class}::Client;')
            gen.line(f'using ServiceClass = {service_class};')

            gen.method_info_specialization(method)

        gen.line('};')
        gen.line()
372
373
def _generate_service(gen: CodeGenerator, service: ProtoService) -> None:
    """Generates the C++ Service base class template for an RPC service.

    The public section holds the aliases from ``gen.service_aliases``, a name()
    function, and the ServiceInfo alias. The protected section holds the
    constructor. The private section holds the method table (one entry per
    method from ``gen.method_descriptor``) and the method ID lookup table.
    """
    rpc_service_base = f'{RPC_NAMESPACE}::Service'
    service_name = service.name()

    for header_line in (
        '// The RPC service base class.',
        '// Inherit from this to implement an RPC service for a pw_rpc server.',
        'template <typename Implementation>',
        f'class Service : public {rpc_service_base} {{',
        ' public:',
    ):
        gen.line(header_line)

    with gen.indent():
        gen.service_aliases()

        gen.line()
        gen.line(
            'static constexpr const char* name() '
            f'{{ return "{service_name}"; }}'
        )
        gen.line()
        gen.line(f'using ServiceInfo = {service_name};')
        gen.line()

    gen.line(' protected:')

    with gen.indent():
        gen.line(
            f'constexpr Service() : {rpc_service_base}'
            '(kServiceId, kPwRpcMethods) {}'
        )

    gen.line()
    gen.line(' private:')

    with gen.indent():
        gen.line('friend class ::pw::rpc::internal::MethodLookup;')
        gen.line()

        # Generate the method table
        methods = service.methods()
        method_union = f'{RPC_NAMESPACE}::internal::{gen.method_union_name()}'
        gen.line(
            f'static constexpr std::array<{method_union}, '
            f'{len(methods)}> kPwRpcMethods = {{'
        )

        with gen.indent(4):
            for rpc_method in methods:
                gen.method_descriptor(rpc_method)

        gen.line('};\n')

        # Generate the method lookup table
        _method_lookup_table(gen, service)

    gen.line('};')
430
431
def _method_lookup_table(gen: CodeGenerator, service: ProtoService) -> None:
    """Generates the kPwRpcMethodIds array of a service's method IDs.

    The array has one uint32_t entry per method, each followed by a comment
    naming the method the ID was computed from.
    """
    methods = service.methods()

    gen.line(
        f'static constexpr std::array<uint32_t, {len(methods)}> '
        'kPwRpcMethodIds = {'
    )

    with gen.indent(4):
        for rpc_method in methods:
            method_id = get_id(rpc_method)
            gen.line(f'{method_id},  // Hash of "{rpc_method.name()}"')

    gen.line('};')
444
445
class StubGenerator(abc.ABC):
    """Generates stub method implementations that can be copied-and-pasted.

    Each method type has a ``*_signature`` method, which returns the C++
    signature as a string, and a ``*_stub`` method, which writes the body of
    the stub to an OutputFile. The streaming stubs have default bodies.
    """

    @abc.abstractmethod
    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the signature of this unary method."""

    @abc.abstractmethod
    def unary_stub(
        self, method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the stub body for this unary method to the output."""

    @abc.abstractmethod
    def server_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the signature of this server streaming method."""

    def server_streaming_stub(  # pylint: disable=no-self-use
        self, unused_method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the stub body for this server streaming method."""
        # The casts to void mark the stub's arguments as used.
        for stub_line in (
            STUB_REQUEST_TODO,
            'static_cast<void>(request);',
            STUB_WRITER_TODO,
            'static_cast<void>(writer);',
        ):
            output.write_line(stub_line)

    @abc.abstractmethod
    def client_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the signature of this client streaming method."""

    def client_streaming_stub(  # pylint: disable=no-self-use
        self, unused_method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the stub body for this client streaming method."""
        for stub_line in (STUB_READER_TODO, 'static_cast<void>(reader);'):
            output.write_line(stub_line)

    @abc.abstractmethod
    def bidirectional_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the signature of this bidirectional streaming method."""

    def bidirectional_streaming_stub(  # pylint: disable=no-self-use
        self, unused_method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes the stub body for this bidirectional streaming method."""
        for stub_line in (
            STUB_READER_WRITER_TODO,
            'static_cast<void>(reader_writer);',
        ):
            output.write_line(stub_line)
499
500
def _select_stub_methods(gen: StubGenerator, method: ProtoServiceMethod):
    """Returns the (signature, stub) methods of gen matching the method type.

    Raises:
        NotImplementedError: The method type is not one of the four handled
            types.
    """
    method_types = ProtoServiceMethod.Type
    method_type = method.type()

    if method_type is method_types.UNARY:
        selected = (gen.unary_signature, gen.unary_stub)
    elif method_type is method_types.SERVER_STREAMING:
        selected = (gen.server_streaming_signature, gen.server_streaming_stub)
    elif method_type is method_types.CLIENT_STREAMING:
        selected = (gen.client_streaming_signature, gen.client_streaming_stub)
    elif method_type is method_types.BIDIRECTIONAL_STREAMING:
        selected = (
            gen.bidirectional_streaming_signature,
            gen.bidirectional_streaming_stub,
        )
    else:
        raise NotImplementedError(f'Unrecognized method type {method_type}')

    return selected
518
519
# Banner that package_stubs() writes at the start of the generated stub
# section. It is a raw string so that the backslashes in the ASCII art are
# emitted literally rather than being treated as escape sequences.
_STUBS_COMMENT = r'''
/*
    ____                __                          __        __  _
   /  _/___ ___  ____  / /__  ____ ___  ___  ____  / /_____ _/ /_(_)___  ____
   / // __ `__ \/ __ \/ / _ \/ __ `__ \/ _ \/ __ \/ __/ __ `/ __/ / __ \/ __ \
 _/ // / / / / / /_/ / /  __/ / / / / /  __/ / / / /_/ /_/ / /_/ / /_/ / / / /
/___/_/ /_/ /_/ .___/_/\___/_/ /_/ /_/\___/_/ /_/\__/\__,_/\__/_/\____/_/ /_/
             /_/
   _____ __        __         __
  / ___// /___  __/ /_  _____/ /
  \__ \/ __/ / / / __ \/ ___/ /
 ___/ / /_/ /_/ / /_/ (__  )_/
/____/\__/\__,_/_.___/____(_)

*/
// This section provides stub implementations of the RPC services in this file.
// The code below may be referenced or copied to serve as a starting point for
// your RPC service implementations.
'''
539
540
def package_stubs(
    proto_package: ProtoNode, gen: CodeGenerator, stub_generator: StubGenerator
) -> None:
    """Generates the RPC stubs for a package.

    The stubs are wrapped in an #ifdef on
    _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS. The class declarations of all
    services are written first, followed by the method definitions, each group
    inside the package namespace if the package has one.
    """
    package_namespace = proto_package.cpp_namespace(codegen_subnamespace=None)
    in_namespace = bool(package_namespace)
    file_ns = package_namespace.removeprefix('::') if in_namespace else ''

    def start_ns() -> None:
        # Opens the package namespace; does nothing without a namespace.
        if in_namespace:
            gen.line(f'namespace {file_ns} {{\n')

    def finish_ns() -> None:
        # Closes the package namespace; does nothing without a namespace.
        if in_namespace:
            gen.line(f'}}  // namespace {file_ns}\n')

    services = [
        cast(ProtoService, child)
        for child in proto_package
        if child.type() == ProtoNode.Type.SERVICE
    ]

    gen.line('#ifdef _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS')
    gen.line(_STUBS_COMMENT)

    gen.line(f'#include "{gen.output.name()}"\n')

    # Declarations of the implementation classes.
    start_ns()
    for service in services:
        _service_declaration_stub(service, gen, stub_generator)
    gen.line()
    finish_ns()

    # Definitions of the methods of the implementation classes.
    start_ns()
    for service in services:
        _service_definition_stub(service, gen, stub_generator)
        gen.line()
    finish_ns()

    gen.line('#endif  // _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS')
588
589
def _service_declaration_stub(
    service: ProtoService, gen: CodeGenerator, stub_generator: StubGenerator
) -> None:
    """Generates the stub class declaration for a service.

    The class is named after the service, derives from the generated Service
    base class, and declares one public method per RPC. Declarations are
    separated by blank lines.
    """
    class_name = service.name()

    gen.line(f'// Implementation class for {service.proto_path()}.')
    gen.line(
        f'class {class_name} : public pw_rpc::{gen.name()}::'
        f'{class_name}::Service<{class_name}> {{'
    )
    gen.line(' public:')

    with gen.indent():
        for index, rpc_method in enumerate(service.methods()):
            # A blank line goes between declarations, not before the first.
            if index > 0:
                gen.line()

            make_signature, _ = _select_stub_methods(stub_generator, rpc_method)
            gen.line(make_signature(rpc_method, '') + ';')

    gen.line('};\n')
615
616
def _service_definition_stub(
    service: ProtoService, gen: CodeGenerator, stub_generator: StubGenerator
) -> None:
    """Generates the stub method definitions for a service.

    Each definition is the method signature qualified with the service name,
    followed by a braced body written by the matching stub method. Definitions
    are separated by blank lines.
    """
    gen.line(f'// Method definitions for {service.proto_path()}.')

    for index, rpc_method in enumerate(service.methods()):
        # A blank line goes between definitions, not before the first.
        if index > 0:
            gen.line()

        make_signature, write_stub = _select_stub_methods(
            stub_generator, rpc_method
        )

        qualifier = f'{service.name()}::'
        gen.line(make_signature(rpc_method, qualifier) + ' {')
        with gen.indent():
            write_stub(rpc_method, gen.output)
        gen.line('}')
636