xref: /aosp_15_r20/external/pigweed/pw_protobuf/py/pw_protobuf/codegen_pwpb.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2023 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""This module defines the generated code for pw_protobuf C++ classes."""
15
16import abc
17from dataclasses import dataclass
18import enum
19
20# Type ignore here for graphlib-backport on Python 3.8
21from graphlib import CycleError, TopologicalSorter  # type: ignore
22from itertools import takewhile
23import os
24import sys
25from typing import Iterable, Type
26from typing import cast
27
28from google.protobuf import descriptor_pb2
29
30from pw_protobuf.output_file import OutputFile
31from pw_protobuf.proto_tree import ProtoEnum, ProtoMessage, ProtoMessageField
32from pw_protobuf.proto_tree import ProtoNode
33from pw_protobuf.proto_tree import build_node_tree
34from pw_protobuf.proto_tree import EXTERNAL_SYMBOL_WORKAROUND_NAMESPACE
35
# Name and version string of this code generator.
PLUGIN_NAME = 'pw_protobuf'
PLUGIN_VERSION = '0.1.0'

# File extensions of the generated C++ header and source files.
PROTO_H_EXTENSION = '.pwpb.h'
PROTO_CC_EXTENSION = '.pwpb.cc'

# Fully qualified C++ namespaces referenced from generated code. The leading
# '::' keeps references unambiguous inside nested generated namespaces.
PROTOBUF_NAMESPACE = '::pw::protobuf'
_INTERNAL_NAMESPACE = '::pw::protobuf::internal'
44
45
@dataclass
class GeneratorOptions:
    """User-selectable switches that shape the generated pwpb C++ code."""

    # When set, fields belonging to a oneof are (de)serialized through one
    # callback per oneof group (see ProtoMember.callback_type()).
    oneof_callbacks: bool
    # NOTE(review): not read in this part of the file; presumably omits the
    # legacy snake_case field-name enums from the output.
    exclude_legacy_snake_case_field_name_enums: bool
    # NOTE(review): not read in this part of the file; presumably suppresses
    # emission of the legacy namespace.
    suppress_legacy_namespace: bool
51
52
class CodegenError(Exception):
    """Raised when a proto definition cannot be turned into pwpb C++ code.

    Attributes:
      error_message: human-readable description of the problem.
      node: the proto node being generated when the error occurred.
      field: the offending field within ``node``, or None if not
        field-specific.
    """

    def __init__(
        self,
        error_message: str,
        node: ProtoNode,
        field: ProtoMessageField | None,
    ):
        self.error_message = error_message
        self.node = node
        self.field = field
        super().__init__(f'pwpb codegen error: {error_message}')

    def formatted_message(self) -> str:
        """Returns a multi-line message pointing at the failing proto path.

        The field line is appended only when a field was supplied.
        """
        header = f'pwpb codegen error: {self.error_message}'
        location = f'    at {self.node.proto_path()}'
        if self.field is None:
            return f'{header}\n{location}'
        return f'{header}\n{location}\n    in field {self.field.name()}'
75
76
77class ClassType(enum.Enum):
78    """Type of class."""
79
80    MEMORY_ENCODER = 1
81    STREAMING_ENCODER = 2
82    # MEMORY_DECODER = 3
83    STREAMING_DECODER = 4
84
85    def base_class_name(self) -> str:
86        """Returns the base class used by this class type."""
87        if self is self.STREAMING_ENCODER:
88            return 'StreamEncoder'
89        if self is self.MEMORY_ENCODER:
90            return 'MemoryEncoder'
91        if self is self.STREAMING_DECODER:
92            return 'StreamDecoder'
93
94        raise ValueError('Unknown class type')
95
96    def codegen_class_name(self) -> str:
97        """Returns the base class used by this class type."""
98        if self is self.STREAMING_ENCODER:
99            return 'StreamEncoder'
100        if self is self.MEMORY_ENCODER:
101            return 'MemoryEncoder'
102        if self is self.STREAMING_DECODER:
103            return 'StreamDecoder'
104
105        raise ValueError('Unknown class type')
106
107    def is_encoder(self) -> bool:
108        """Returns True if this class type is an encoder."""
109        if self is self.STREAMING_ENCODER:
110            return True
111        if self is self.MEMORY_ENCODER:
112            return True
113        if self is self.STREAMING_DECODER:
114            return False
115
116        raise ValueError('Unknown class type')
117
118
# protoc captures stdout, so we need to printf debug to stderr.
def debug_print(*values, **print_kwargs):
    """print() replacement that always writes to stderr.

    Accepts the same positional values and keyword options as print(),
    except ``file``, which is fixed to sys.stderr.
    """
    stream = sys.stderr
    print(*values, file=stream, **print_kwargs)
122
123
class _CallbackType(enum.Enum):
    """How a field is (de)serialized through user callbacks.

    The values are rendered into generated field tables as enumerators of
    the C++ ``CallbackType`` type via as_cpp().
    """

    NONE = 0
    SINGLE_FIELD = 1
    ONEOF_GROUP = 2

    def as_cpp(self) -> str:
        """Returns the C++ enumerator name for this callback type."""
        # Built locally: a dict assigned in the Enum class body would itself
        # become an enum member.
        cpp_names = {
            _CallbackType.NONE: 'kNone',
            _CallbackType.SINGLE_FIELD: 'kSingleField',
            _CallbackType.ONEOF_GROUP: 'kOneOfGroup',
        }
        return cpp_names[self]
137
138
class ProtoMember(abc.ABC):
    """Base class for a C++ class member generated for one protobuf field.

    Concrete subclasses describe either generated methods or struct
    properties of the generated message class.
    """

    def __init__(
        self,
        codegen_options: GeneratorOptions,
        field: ProtoMessageField,
        scope: ProtoNode,
        root: ProtoNode,
    ):
        """Creates an instance of a class member.

        Args:
          codegen_options: generator-wide options affecting the output.
          field: the ProtoMessageField this member is generated for.
          scope: the ProtoNode namespace in which the member is defined.
          root: the root ProtoNode, used when a type must be referenced from
            the outermost scope.
        """
        self._codegen_options: GeneratorOptions = codegen_options
        self._field: ProtoMessageField = field
        self._scope: ProtoNode = scope
        self._root: ProtoNode = root

    @abc.abstractmethod
    def name(self) -> str:
        """Returns the name of the member, e.g. DoSomething."""

    @abc.abstractmethod
    def should_appear(self) -> bool:  # pylint: disable=no-self-use
        """Whether the member should be generated."""

    @abc.abstractmethod
    def _use_callback(self) -> bool:
        """Whether the member should be encoded and decoded with a callback."""

    def callback_type(self) -> _CallbackType:
        """Classifies how this member is (de)serialized via callbacks.

        With oneof callbacks enabled, any field that reports a oneof uses the
        group callback. Otherwise, the field's ``use_callback`` option or the
        subclass's _use_callback() selects a single-field callback.

        NOTE(review): unlike MessageProperty.should_appear(), this does not
        consult oneof().is_synthetic(); confirm whether oneof() can return a
        synthetic oneof here.
        """
        in_oneof_group = (
            self._codegen_options.oneof_callbacks
            and self._field.oneof() is not None
        )
        if in_oneof_group:
            return _CallbackType.ONEOF_GROUP

        options = self._field.options()
        assert options is not None

        wants_callback = options.use_callback or self._use_callback()
        return (
            _CallbackType.SINGLE_FIELD if wants_callback else _CallbackType.NONE
        )

    def field_cast(self) -> str:
        """Returns the C++ expression for this field's number as uint32_t."""
        return f'static_cast<uint32_t>(Fields::{self._field.enum_name()})'

    def _relative_type_namespace(self, from_root: bool = False) -> str:
        """Returns relative namespace between member's scope and field type.

        Args:
          from_root: resolve relative to the root node instead of the
            member's own scope.
        """
        scope = self._scope
        if from_root:
            scope = self._root

        type_node = self._field.type_node()
        assert type_node is not None

        # A member referring to its own class would get an empty namespace,
        # so such references are resolved from the root instead.
        if type_node == scope:
            scope = self._root

        namespace = type_node.cpp_namespace(scope.common_ancestor(type_node))
        assert namespace
        return namespace
207
208
class ProtoMethod(ProtoMember):
    """Base class for a generated C++ method tied to one protobuf field.

    Subclasses describe the method's name, parameters, return type and body
    as strings of C++ source.
    """

    def __init__(
        self,
        codegen_options: GeneratorOptions,
        field: ProtoMessageField,
        scope: ProtoNode,
        root: ProtoNode,
        base_class: str,
    ):
        super().__init__(codegen_options, field, scope, root)
        # Name of the C++ base class whose functions generated bodies call,
        # e.g. '{base_class}::WriteUint32(...)'.
        self._base_class: str = base_class

    @abc.abstractmethod
    def params(self) -> list[tuple[str, str]]:
        """Returns the method parameters as (type, name) pairs.

        e.g.
        [('int', 'foo'), ('const char*', 'bar')]
        """

    @abc.abstractmethod
    def body(self) -> list[str]:
        """Returns the method body, one C++ source line per list entry.

        e.g.
        [
          'int baz = bar[foo];',
          'return (baz ^ foo) >> 3;'
        ]
        """

    @abc.abstractmethod
    def return_type(self, from_root: bool = False) -> str:
        """Returns the C++ return type of the method, e.g. int.

        For non-primitive return types, from_root selects whether the
        namespace is relative to the message's scope (default) or to the
        root scope.
        """

    @abc.abstractmethod
    def in_class_definition(self) -> bool:
        """Determines where the method should be defined.

        True means the definition is inlined in the class body; False means
        the method is only declared there and defined later.
        """

    def should_appear(self) -> bool:  # pylint: disable=no-self-use
        """Methods are generated unconditionally unless a subclass says not."""
        return True

    def _use_callback(self) -> bool:  # pylint: disable=no-self-use
        return False

    def param_string(self) -> str:
        """Renders params() as a C++ parameter list, e.g. 'int a, char b'."""
        rendered = (
            f'{param_type} {param_name}'
            for param_type, param_name in self.params()
        )
        return ', '.join(rendered)
269
270
class WriteMethod(ProtoMethod):
    """Base class for encoder ``Write{Field}`` methods.

    For a proto field ``foo`` the generated method has the shape:

        ::pw::Status WriteFoo({params...}) {
          return {Base}::Write{Type}(static_cast<uint32_t>(Fields::kFoo),
                                     {params...});
        }

    Subclasses supply params() and _encoder_fn().
    """

    def name(self) -> str:
        return f'Write{self._field.name()}'

    def return_type(self, from_root: bool = False) -> str:
        return '::pw::Status'

    def body(self) -> list[str]:
        arg_names = ', '.join(param_name for _, param_name in self.params())
        return [
            f'return {self._base_class}::{self._encoder_fn()}'
            f'({self.field_cast()}, {arg_names});'
        ]

    def params(self) -> list[tuple[str, str]]:
        """Method parameters; must be provided by subclasses."""
        raise NotImplementedError()

    def in_class_definition(self) -> bool:
        return True

    def _encoder_fn(self) -> str:
        """Name of the base-class encoder function, e.g. 'WriteUint32'.

        Must be provided by subclasses.
        """
        raise NotImplementedError()
310
311
class PackedWriteMethod(WriteMethod):
    """Write method for a packed repeated field.

    Identical to WriteMethod except that it is only emitted when the field
    is repeated.
    """

    def _encoder_fn(self) -> str:
        raise NotImplementedError()

    def should_appear(self) -> bool:
        return self._field.is_repeated()
323
324
class ReadMethod(ProtoMethod):
    """Base class representing a decoder read method.

    For a proto field ``foo`` the generated method has the shape:

        ::pw::Result<{ctype}> ReadFoo({params...}) {
          ::pw::Result<uint32_t> field_number = FieldNumber();
          PW_ASSERT(field_number.ok());
          PW_ASSERT(field_number.value() ==
                    static_cast<uint32_t>(Fields::kFoo));
          return {Base}::Read{Type}({params...});
        }

    Subclasses supply _result_type() and _decoder_fn().
    """

    def name(self) -> str:
        return f'Read{self._field.name()}'

    def return_type(self, from_root: bool = False) -> str:
        return f'::pw::Result<{self._result_type()}>'

    def _result_type(self) -> str:
        """C++ type produced by the decoder function.

        Must be provided by subclasses, e.g. 'uint32_t',
        'pw::span<std::byte>'.
        """
        raise NotImplementedError()

    def body(self) -> list[str]:
        # The field-number checks guard against calling ReadFoo() while the
        # decoder is positioned on a different field.
        return [
            '::pw::Result<uint32_t> field_number = FieldNumber();',
            'PW_ASSERT(field_number.ok());',
            f'PW_ASSERT(field_number.value() == {self.field_cast()});',
            *self._decoder_body(),
        ]

    def _decoder_body(self) -> list[str]:
        """Returns the decoding part of the body as C++ source lines."""
        arg_names = ', '.join(param_name for _, param_name in self.params())
        return [
            f'return {self._base_class}::{self._decoder_fn()}({arg_names});'
        ]

    def _decoder_fn(self) -> str:
        """Name of the base-class decoder function, e.g. 'ReadUint32'.

        Must be provided by subclasses.
        """
        raise NotImplementedError()

    def params(self) -> list[tuple[str, str]]:
        """Method parameters; none by default, subclasses may override."""
        return []

    def in_class_definition(self) -> bool:
        return True
388
389
class PackedReadMethod(ReadMethod):
    """Read method that decodes a packed repeated field into a span.

    Only emitted for repeated fields; returns a StatusWithSize rather than a
    Result.
    """

    def params(self) -> list[tuple[str, str]]:
        element_type = self._result_type()
        return [(f'pw::span<{element_type}>', 'out')]

    def return_type(self, from_root: bool = False) -> str:
        return '::pw::StatusWithSize'

    def should_appear(self) -> bool:
        return self._field.is_repeated()
404
405
class PackedReadVectorMethod(ReadMethod):
    """Read method that appends a packed repeated field into a pw::Vector.

    An alternative to the span-based read for repeated fields; only emitted
    when the field is repeated.
    """

    def params(self) -> list[tuple[str, str]]:
        element_type = self._result_type()
        return [(f'::pw::Vector<{element_type}>&', 'out')]

    def return_type(self, from_root: bool = False) -> str:
        return '::pw::Status'

    def should_appear(self) -> bool:
        return self._field.is_repeated()
421
422
class FindMethod(ReadMethod):
    """A method for finding a field within a serialized message.

    For a field ``foo`` the generated method is ``FindFoo(message)``, taking
    the encoded message as a ``::pw::ConstByteSpan``. Singular fields return
    ``::pw::Result<{ctype}>`` from a ``::pw::protobuf::Find{Type}`` call.
    Repeated fields return a ``::pw::protobuf::{Type}Finder`` object.

    Subclasses define _result_type(), _find_fn() and _finder().
    """

    def name(self) -> str:
        return 'Find{}'.format(self._field.name())

    def params(self) -> list[tuple[str, str]]:
        return [('::pw::ConstByteSpan', 'message')]

    def return_type(self, from_root: bool = False) -> str:
        if self._field.is_repeated():
            return f'{PROTOBUF_NAMESPACE}::{self._finder()}'
        return '::pw::Result<{}>'.format(self._result_type())

    def body(self) -> list[str]:
        # Repeated fields are looked up through a finder type; singular
        # fields through a Find{Type}() function. Both live in
        # PROTOBUF_NAMESPACE and take the same (message, field) arguments.
        if self._field.is_repeated():
            lookup = self._finder()
        else:
            lookup = self._find_fn()
        return [
            f'return {PROTOBUF_NAMESPACE}::{lookup}'
            f'(message, {self.field_cast()});'
        ]

    def _find_fn(self) -> str:
        """The find function to call.

        Defined in subclasses.

        e.g. 'FindUint32', 'FindBytes', etc.
        """
        raise NotImplementedError()

    def _finder(self) -> str:
        """Type of the finder object for the field type.

        Defined in subclasses, e.g. 'DoubleFinder', 'BytesFinder'.
        """
        raise NotImplementedError(
            f'{type(self).__name__} does not define a finder type'
        )
463
464
class FindStreamMethod(FindMethod):
    """Find method that scans a message read from a ``pw::stream::Reader``.

    Same as FindMethod, but the generated function takes the message as a
    stream reader instead of a byte span.
    """

    def params(self) -> list[tuple[str, str]]:
        return [('::pw::stream::Reader&', 'message_stream')]

    def body(self) -> list[str]:
        if self._field.is_repeated():
            lookup = self._finder()
        else:
            lookup = self._find_fn()
        return [
            f'return {PROTOBUF_NAMESPACE}::{lookup}'
            f'(message_stream, {self.field_cast()});'
        ]
485
486
class MessageProperty(ProtoMember):
    """Base class for a C++ property for a field in a protobuf message.

    Each instance describes one member of the generated ``Message`` struct:
    its C++ type, wire type, encoded-size expression, and the entry emitted
    for it by table_entry().
    """

    def name(self) -> str:
        """Returns the struct member name (the field's field_name())."""
        return self._field.field_name()

    def should_appear(self) -> bool:
        # With oneof callbacks enabled, fields of a real (non-synthetic) oneof
        # get no struct member of their own: the whole oneof is represented by
        # a single member named after the oneof (see table_entry()). Synthetic
        # oneofs (what protoc generates for proto3 `optional`) stay ordinary
        # fields.
        oneof = self._field.oneof()
        if self._codegen_options.oneof_callbacks and oneof is not None:
            return oneof.is_synthetic()

        return True

    @abc.abstractmethod
    def type_name(self, from_root: bool = False) -> str:
        """Returns the type of the property, e.g. uint32_t."""

    @abc.abstractmethod
    def wire_type(self) -> str:
        """Returns the wire type of the property, e.g. kVarint."""

    def varint_decode_type(self) -> str:
        """Returns the varint decoding type of the property, e.g. kZigZag.

        Defined in subclasses that return kVarint for wire_type().
        """
        raise NotImplementedError()

    def is_string(self) -> bool:  # pylint: disable=no-self-use
        """True if this field is a string field (as opposed to bytes)."""
        return False

    @staticmethod
    def repeated_field_container(type_name: str, max_size: str) -> str:
        """Returns the container type used for repeated fields.

        Defaults to ::pw::Vector<type, max_size>. String fields use
        ::pw::InlineString<max_size> instead.

        Args:
          type_name: C++ element type.
          max_size: C++ expression (usually a constant name) for capacity.
        """
        return f'::pw::Vector<{type_name}, {max_size}>'

    def _use_callback(self) -> bool:  # pylint: disable=no-self-use
        """Returns whether the decoder should use a callback.

        Repeated fields without a max_count have no bounded storage, so they
        fall back to a callback.
        """
        return self._field.is_repeated() and self.max_size() == 0

    def is_optional(self) -> bool:
        """Returns whether the decoder should use std::optional.

        True only for fields that track presence, have no max_count, and are
        not length-delimited (kDelimited).
        """
        return (
            self._field.has_presence()
            and self.max_size() == 0
            and self.wire_type() != 'kDelimited'
        )

    def is_repeated(self) -> bool:
        return self._field.is_repeated()

    def max_size(self) -> int:
        """Returns the maximum size of the field.

        This is the field's max_count option for repeated fields and 0 for
        singular fields.
        """
        if self._field.is_repeated():
            options = self._field.options()
            assert options is not None
            return options.max_count

        return 0

    def is_fixed_size(self) -> bool:
        """Returns whether the decoder should use a fixed sized field.

        Reads the fixed_count option for repeated fields; always False for
        singular fields.
        """
        if self._field.is_repeated():
            options = self._field.options()
            assert options is not None
            return options.fixed_count

        return False

    def sub_table(self) -> str:  # pylint: disable=no-self-use
        """Returns the nested field-table expression; '{}' (empty) here."""
        return '{}'

    def struct_member_type(self, from_root: bool = False) -> str:
        """Returns the structure member type.

        Checked in order: single-field callback, std::optional, plain type
        (no max size), std::array (fixed size), then the repeated container.
        """
        if self.callback_type() is _CallbackType.SINGLE_FIELD:
            return (
                f'{PROTOBUF_NAMESPACE}::Callback<StreamEncoder, StreamDecoder>'
            )

        # Optional fields are wrapped in std::optional
        if self.is_optional():
            return 'std::optional<{}>'.format(self.type_name(from_root))

        # Non-repeated fields have a member of just the type name.
        max_size = self.max_size()
        if max_size == 0:
            return self.type_name(from_root)

        # Fixed size fields use std::array.
        if self.is_fixed_size():
            return 'std::array<{}, {}>'.format(
                self.type_name(from_root), self.max_size_constant_name()
            )

        # Otherwise prefer pw::Vector for repeated fields.
        return self.repeated_field_container(
            self.type_name(from_root), self.max_size_constant_name()
        )

    def max_size_constant_name(self) -> str:
        """Returns the C++ constant name holding the max size, e.g. kFooMaxSize."""
        return f'k{self._field.name()}MaxSize'

    def _varint_type_table_entry(self) -> str:
        """Returns the VarintType table argument for this field."""
        if self.wire_type() == 'kVarint':
            return '{}::VarintType::{}'.format(
                _INTERNAL_NAMESPACE, self.varint_decode_type()
            )

        # Non-varint fields still need a value in this slot; 0 is a
        # placeholder.
        return f'static_cast<{_INTERNAL_NAMESPACE}::VarintType>(0)'

    def _wire_type_table_entry(self) -> str:
        """Returns the fully qualified WireType enumerator for this field."""
        return '{}::WireType::{}'.format(PROTOBUF_NAMESPACE, self.wire_type())

    def _elem_size_table_entry(self) -> str:
        """Returns a sizeof() expression for one element of this field."""
        return 'sizeof({})'.format(self.type_name())

    def _bool_attr(self, attr: str) -> str:
        """C++ string for a bool argument that includes the argument name.

        Calls the zero-argument method named ``attr`` on self, e.g.
        _bool_attr('is_string') -> '/*is_string=*/false'. The whole string is
        lowercased so Python's True/False become C++ true/false.
        """
        return f'/*{attr}=*/{bool(getattr(self, attr)())}'.lower()

    def table_entry(self) -> list[str]:
        """Returns the C++ argument expressions for this field's table entry.

        The list is positional. NOTE(review): its order presumably must match
        the C++ field-table entry constructor; confirm before reordering.
        """

        # Members of a real oneof share one struct member named after the
        # oneof when oneof callbacks are enabled.
        oneof = self._field.oneof()
        if (
            self._codegen_options.oneof_callbacks
            and oneof is not None
            and not oneof.is_synthetic()
        ):
            struct_member = oneof.name
        else:
            struct_member = self.name()

        return [
            self.field_cast(),
            self._wire_type_table_entry(),
            self._elem_size_table_entry(),
            self._varint_type_table_entry(),
            self._bool_attr('is_string'),
            self._bool_attr('is_fixed_size'),
            self._bool_attr('is_repeated'),
            self._bool_attr('is_optional'),
            (
                f'{_INTERNAL_NAMESPACE}::CallbackType::'
                + self.callback_type().as_cpp()
            ),
            'offsetof(Message, {})'.format(struct_member),
            'sizeof(Message::{})'.format(struct_member),
            self.sub_table(),
        ]

    @abc.abstractmethod
    def _size_fn(self) -> str:
        """Returns the name of the field size function."""

    def _size_length(self) -> str | None:  # pylint: disable=no-self-use
        """Returns the length to add to the maximum encoded size.

        None (the default) means nothing is added.
        """
        return None

    def max_encoded_size(self) -> str:
        """Returns a constant expression for field's maximum encoded size."""
        size_call = '{}::{}({})'.format(
            PROTOBUF_NAMESPACE, self._size_fn(), self.field_cast()
        )

        if self.is_repeated():
            # We have to assume the worst-case: a non-packed repeated field,
            # which is encoded as one record per entry.
            # https://protobuf.dev/programming-guides/encoding/#packed
            if self.max_size():
                size_call += f' * {self.max_size_constant_name()}'
            else:
                # TODO: https://pwbug.dev/379868242 - Change this to return
                # None to indicate that we don't know the maximum encoded size,
                # because the field is unconstrained.
                msg = 'TODO: https://pwbug.dev/379868242 - Max size unknown!'
                size_call += f' /* {msg} */'

        size_length: str | None = self._size_length()
        if size_length is None:
            return size_call

        return f'{size_call} + {size_length}'

    def include_in_scratch_size(self) -> bool:  # pylint: disable=no-self-use
        """Returns whether the field contributes to the scratch buffer size."""
        return False
680
681
682#
683# The following code defines write and read methods for each of the
684# complex protobuf types.
685#
686
687
class SubMessageEncoderMethod(ProtoMethod):
    """Generates ``Get{Field}Encoder()``, returning a nested-message encoder."""

    def name(self) -> str:
        return f'Get{self._field.name()}Encoder'

    def return_type(self, from_root: bool = False) -> str:
        namespace = self._relative_type_namespace(from_root)
        return f'{namespace}::StreamEncoder'

    def params(self) -> list[tuple[str, str]]:
        return []

    def body(self) -> list[str]:
        namespace = self._relative_type_namespace()
        nested = f'{self._base_class}::GetNestedEncoder({self.field_cast()})'
        return [f'return {namespace}::StreamEncoder({nested});']

    def in_class_definition(self) -> bool:
        # Defined out of line: the submessage's class may not have been
        # defined yet at this point in the generated code.
        return False
712
713
class SubMessageDecoderMethod(ReadMethod):
    """Generates ``Get{Field}Decoder()``, returning a nested-message decoder."""

    def name(self) -> str:
        return f'Get{self._field.name()}Decoder'

    def return_type(self, from_root: bool = False) -> str:
        namespace = self._relative_type_namespace(from_root)
        return f'{namespace}::StreamDecoder'

    def _decoder_body(self) -> list[str]:
        namespace = self._relative_type_namespace()
        return [f'return {namespace}::StreamDecoder(GetNestedDecoder());']

    def in_class_definition(self) -> bool:
        # Defined out of line: the submessage's class may not have been
        # defined yet at this point in the generated code.
        return False
735
736
class SubMessageFindMethod(FindMethod):
    """Find method for a submessage field, yielding its encoded bytes."""

    def _finder(self) -> str:
        return 'BytesFinder'

    def _find_fn(self) -> str:
        return 'FindBytes'

    def _result_type(self) -> str:
        return '::pw::ConstByteSpan'
748
749
class SubMessageProperty(MessageProperty):
    """Property which contains a nested message.

    Repeated submessages, and submessages whose dependency was removed to
    break a cycle, are handled through a callback rather than a member of
    the nested ``Message`` type.
    """

    def __init__(
        self,
        codegen_options: GeneratorOptions,
        field: ProtoMessageField,
        scope: ProtoNode,
        root: ProtoNode,
    ):
        """Creates the property.

        Raises:
          CodegenError: if a repeated message sets max_count or fixed_count.
        """
        super().__init__(codegen_options, field, scope, root)

        if not self._field.is_repeated():
            return
        if self.max_size() != 0 or self.is_fixed_size():
            raise CodegenError(
                'Repeated messages cannot set a max_count or fixed_count',
                scope,
                field,
            )

    def _dependency_removed(self) -> bool:
        """Returns true if the message dependency was removed to break a cycle.

        Proto allows cycles between messages, but C++ does not allow cycles
        between class references, so a broken edge is replaced by a callback.
        """
        type_node = self._field.type_node()
        assert type_node is not None
        cycles = cast(ProtoMessage, self._scope).dependency_cycles()
        return type_node in cycles

    def _elem_size_table_entry(self) -> str:
        # Repeated submessages always go through a callback (max_count and
        # fixed_count are rejected in __init__), so only the member size is
        # meaningful. 0 keeps elem_size small enough to fit in 4 bits.
        return '0'

    def type_name(self, from_root: bool = False) -> str:
        return f'{self._relative_type_namespace(from_root)}::Message'

    def _use_callback(self) -> bool:
        # Neither a cycle-broken dependency nor a repeated submessage has a
        # size the struct can hold inline.
        if self._dependency_removed():
            return True
        return self._field.is_repeated()

    def wire_type(self) -> str:
        return 'kDelimited'

    def sub_table(self) -> str:
        if self.callback_type() is _CallbackType.NONE:
            return f'&{self._relative_type_namespace()}::kMessageFields'
        return 'nullptr'

    def _size_fn(self) -> str:
        # The WithoutValue variant reserves the maximum delimited-length
        # varint, since a nested message may contain callbacks and be longer
        # than expected, and to cover MemoryEncoder scratch overhead.
        return 'SizeOfDelimitedFieldWithoutValue'

    def _size_length(self) -> str | None:
        if self.callback_type() is _CallbackType.NONE:
            return f'{self._relative_type_namespace()}::kMaxEncodedSizeBytes'
        return None

    def include_in_scratch_size(self) -> bool:
        return True
823
824
class BytesReaderMethod(ReadMethod):
    """Generates ``Get{Field}Reader()``, returning a streaming bytes reader."""

    def _decoder_fn(self) -> str:
        return 'GetBytesReader'

    def return_type(self, from_root: bool = False) -> str:
        return f'{PROTOBUF_NAMESPACE}::StreamDecoder::BytesReader'

    def name(self) -> str:
        return f'Get{self._field.name()}Reader'
836
837
838#
839# The following code defines write and read methods for each of the
840# primitive protobuf types.
841#
842
843
class DoubleWriteMethod(WriteMethod):
    """Write method for a singular ``double`` field."""

    def _encoder_fn(self) -> str:
        return 'WriteDouble'

    def params(self) -> list[tuple[str, str]]:
        return [('double', 'value')]
852
853
class PackedDoubleWriteMethod(PackedWriteMethod):
    """Write method taking a span of doubles, written as a packed field."""

    def _encoder_fn(self) -> str:
        return 'WritePackedDouble'

    def params(self) -> list[tuple[str, str]]:
        return [('pw::span<const double>', 'values')]
862
863
class PackedDoubleWriteVectorMethod(PackedWriteMethod):
    """Write method taking a ``pw::Vector`` of doubles for a repeated field."""

    def _encoder_fn(self) -> str:
        return 'WriteRepeatedDouble'

    def params(self) -> list[tuple[str, str]]:
        return [('const ::pw::Vector<double>&', 'values')]
872
873
class DoubleReadMethod(ReadMethod):
    """Read method for a singular ``double`` field."""

    def _decoder_fn(self) -> str:
        return 'ReadDouble'

    def _result_type(self) -> str:
        return 'double'
882
883
class PackedDoubleReadMethod(PackedReadMethod):
    """Read method decoding packed doubles into a caller-provided span."""

    def _decoder_fn(self) -> str:
        return 'ReadPackedDouble'

    def _result_type(self) -> str:
        return 'double'
892
893
class PackedDoubleReadVectorMethod(PackedReadVectorMethod):
    """Read method decoding repeated doubles into a ``pw::Vector``."""

    def _decoder_fn(self) -> str:
        return 'ReadRepeatedDouble'

    def _result_type(self) -> str:
        return 'double'
902
903
class DoubleFindMethod(FindMethod):
    """Find method locating a ``double`` field in an encoded message."""

    def _finder(self) -> str:
        return 'DoubleFinder'

    def _find_fn(self) -> str:
        return 'FindDouble'

    def _result_type(self) -> str:
        return 'double'
915
916
class DoubleFindStreamMethod(FindStreamMethod):
    """Find method locating a ``double`` field in a message stream."""

    def _finder(self) -> str:
        return 'DoubleStreamFinder'

    def _find_fn(self) -> str:
        return 'FindDouble'

    def _result_type(self) -> str:
        return 'double'
928
929
class DoubleProperty(MessageProperty):
    """Struct property for a ``double`` field (kFixed64 on the wire)."""

    def _size_fn(self) -> str:
        return 'SizeOfFieldDouble'

    def wire_type(self) -> str:
        return 'kFixed64'

    def type_name(self, from_root: bool = False) -> str:
        return 'double'
941
942
class FloatWriteMethod(WriteMethod):
    """Write method for a singular ``float`` field."""

    def _encoder_fn(self) -> str:
        return 'WriteFloat'

    def params(self) -> list[tuple[str, str]]:
        return [('float', 'value')]
951
952
class PackedFloatWriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto float values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for float."""
        return 'WritePackedFloat'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const float`` named ``values``."""
        return [('pw::span<const float>', 'values')]
961
962
class PackedFloatWriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto float values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for float."""
        return 'WriteRepeatedFloat'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<float>`` reference named ``values``."""
        return [('const ::pw::Vector<float>&', 'values')]
971
972
class FloatReadMethod(ReadMethod):
    """Emits the reader for a singular proto float field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for float."""
        return 'ReadFloat'

    def _result_type(self) -> str:
        """C++ type produced when reading a float."""
        return 'float'
981
982
class PackedFloatReadMethod(PackedReadMethod):
    """Emits the reader for packed proto float values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for float."""
        return 'ReadPackedFloat'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'float'
991
992
class PackedFloatReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto float values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for float."""
        return 'ReadRepeatedFloat'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'float'
1001
1002
class FloatFindMethod(FindMethod):
    """Emits the find method for a proto float field."""

    def _finder(self) -> str:
        """Name of the finder type for float."""
        return 'FloatFinder'

    def _find_fn(self) -> str:
        """Name of the find function for float."""
        return 'FindFloat'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'float'
1014
1015
class FloatFindStreamMethod(FindStreamMethod):
    """Emits the stream-based find method for a proto float field."""

    def _finder(self) -> str:
        """Name of the stream finder type for float."""
        return 'FloatStreamFinder'

    def _find_fn(self) -> str:
        """Name of the find function for float."""
        return 'FindFloat'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'float'
1027
1028
class FloatProperty(MessageProperty):
    """Message property whose value is a proto float."""

    def _size_fn(self) -> str:
        """Name of the size helper for a float field."""
        return 'SizeOfFieldFloat'

    def wire_type(self) -> str:
        """Wire type enumerator: 32-bit fixed width."""
        return 'kFixed32'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'float'
1040
1041
class Int32WriteMethod(WriteMethod):
    """Emits the writer for a singular proto int32 field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for int32."""
        return 'WriteInt32'

    def params(self) -> list[tuple[str, str]]:
        """A single ``int32_t`` argument named ``value``."""
        return [('int32_t', 'value')]
1050
1051
class PackedInt32WriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto int32 values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for int32."""
        return 'WritePackedInt32'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const int32_t`` named ``values``."""
        return [('pw::span<const int32_t>', 'values')]
1060
1061
class PackedInt32WriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto int32 values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for int32."""
        return 'WriteRepeatedInt32'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<int32_t>`` reference named ``values``."""
        return [('const ::pw::Vector<int32_t>&', 'values')]
1070
1071
class Int32ReadMethod(ReadMethod):
    """Emits the reader for a singular proto int32 field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for int32."""
        return 'ReadInt32'

    def _result_type(self) -> str:
        """C++ type produced when reading an int32."""
        return 'int32_t'
1080
1081
class PackedInt32ReadMethod(PackedReadMethod):
    """Emits the reader for packed proto int32 values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for int32."""
        return 'ReadPackedInt32'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'int32_t'
1090
1091
class PackedInt32ReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto int32 values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for int32."""
        return 'ReadRepeatedInt32'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'int32_t'
1100
1101
class Int32FindMethod(FindMethod):
    """Emits the find method for a proto int32 field."""

    def _finder(self) -> str:
        """Name of the finder type for int32."""
        return 'Int32Finder'

    def _find_fn(self) -> str:
        """Name of the find function for int32."""
        return 'FindInt32'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int32_t'
1113
1114
class Int32FindStreamMethod(FindStreamMethod):
    """Emits the stream-based find method for a proto int32 field."""

    def _finder(self) -> str:
        """Name of the stream finder type for int32."""
        return 'Int32StreamFinder'

    def _find_fn(self) -> str:
        """Name of the find function for int32."""
        return 'FindInt32'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int32_t'
1126
1127
class Int32Property(MessageProperty):
    """Message property whose value is a proto int32."""

    def _size_fn(self) -> str:
        """Name of the size helper for an int32 field."""
        return 'SizeOfFieldInt32'

    def varint_decode_type(self) -> str:
        """Varint decode mode: plain (non-zigzag) signed."""
        return 'kNormal'

    def wire_type(self) -> str:
        """Wire type enumerator: varint."""
        return 'kVarint'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'int32_t'
1142
1143
class Sint32WriteMethod(WriteMethod):
    """Emits the writer for a singular proto sint32 field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for sint32."""
        return 'WriteSint32'

    def params(self) -> list[tuple[str, str]]:
        """A single ``int32_t`` argument named ``value``."""
        return [('int32_t', 'value')]
1152
1153
class PackedSint32WriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto sint32 values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for sint32."""
        return 'WritePackedSint32'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const int32_t`` named ``values``."""
        return [('pw::span<const int32_t>', 'values')]
1162
1163
class PackedSint32WriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto sint32 values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for sint32."""
        return 'WriteRepeatedSint32'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<int32_t>`` reference named ``values``."""
        return [('const ::pw::Vector<int32_t>&', 'values')]
1172
1173
class Sint32ReadMethod(ReadMethod):
    """Emits the reader for a singular proto sint32 field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for sint32."""
        return 'ReadSint32'

    def _result_type(self) -> str:
        """C++ type produced when reading a sint32."""
        return 'int32_t'
1182
1183
class PackedSint32ReadMethod(PackedReadMethod):
    """Emits the reader for packed proto sint32 values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for sint32."""
        return 'ReadPackedSint32'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'int32_t'
1192
1193
class PackedSint32ReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto sint32 values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for sint32."""
        return 'ReadRepeatedSint32'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'int32_t'
1202
1203
class Sint32FindMethod(FindMethod):
    """Emits the find method for a proto sint32 field."""

    def _finder(self) -> str:
        """Name of the finder type for sint32."""
        return 'Sint32Finder'

    def _find_fn(self) -> str:
        """Name of the find function for sint32."""
        return 'FindSint32'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int32_t'
1215
1216
class Sint32FindStreamMethod(FindStreamMethod):
    """Emits the stream-based find method for a proto sint32 field."""

    def _finder(self) -> str:
        """Name of the stream finder type for sint32."""
        return 'Sint32StreamFinder'

    def _find_fn(self) -> str:
        """Name of the find function for sint32."""
        return 'FindSint32'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int32_t'
1228
1229
class Sint32Property(MessageProperty):
    """Message property whose value is a proto sint32."""

    def _size_fn(self) -> str:
        """Name of the size helper for a sint32 field."""
        return 'SizeOfFieldSint32'

    def varint_decode_type(self) -> str:
        """Varint decode mode: zigzag-encoded signed."""
        return 'kZigZag'

    def wire_type(self) -> str:
        """Wire type enumerator: varint."""
        return 'kVarint'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'int32_t'
1244
1245
class Sfixed32WriteMethod(WriteMethod):
    """Emits the writer for a singular proto sfixed32 field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for sfixed32."""
        return 'WriteSfixed32'

    def params(self) -> list[tuple[str, str]]:
        """A single ``int32_t`` argument named ``value``."""
        return [('int32_t', 'value')]
1254
1255
class PackedSfixed32WriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto sfixed32 values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for sfixed32."""
        return 'WritePackedSfixed32'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const int32_t`` named ``values``."""
        return [('pw::span<const int32_t>', 'values')]
1264
1265
class PackedSfixed32WriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto sfixed32 values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for sfixed32."""
        return 'WriteRepeatedSfixed32'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<int32_t>`` reference named ``values``."""
        return [('const ::pw::Vector<int32_t>&', 'values')]
1274
1275
class Sfixed32ReadMethod(ReadMethod):
    """Emits the reader for a singular proto sfixed32 field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for sfixed32."""
        return 'ReadSfixed32'

    def _result_type(self) -> str:
        """C++ type produced when reading a sfixed32."""
        return 'int32_t'
1284
1285
class PackedSfixed32ReadMethod(PackedReadMethod):
    """Emits the reader for packed proto sfixed32 values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for sfixed32."""
        return 'ReadPackedSfixed32'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'int32_t'
1294
1295
class PackedSfixed32ReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto sfixed32 values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for sfixed32."""
        return 'ReadRepeatedSfixed32'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'int32_t'
1304
1305
class Sfixed32FindMethod(FindMethod):
    """Emits the find method for a proto sfixed32 field."""

    def _finder(self) -> str:
        """Name of the finder type for sfixed32."""
        return 'Sfixed32Finder'

    def _find_fn(self) -> str:
        """Name of the find function for sfixed32."""
        return 'FindSfixed32'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int32_t'
1317
1318
class Sfixed32FindStreamMethod(FindStreamMethod):
    """Emits the stream-based find method for a proto sfixed32 field."""

    def _finder(self) -> str:
        """Name of the stream finder type for sfixed32."""
        return 'Sfixed32StreamFinder'

    def _find_fn(self) -> str:
        """Name of the find function for sfixed32."""
        return 'FindSfixed32'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int32_t'
1330
1331
class Sfixed32Property(MessageProperty):
    """Message property whose value is a proto sfixed32."""

    def _size_fn(self) -> str:
        """Name of the size helper for a sfixed32 field."""
        return 'SizeOfFieldSfixed32'

    def wire_type(self) -> str:
        """Wire type enumerator: 32-bit fixed width."""
        return 'kFixed32'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'int32_t'
1343
1344
class Int64WriteMethod(WriteMethod):
    """Emits the writer for a singular proto int64 field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for int64."""
        return 'WriteInt64'

    def params(self) -> list[tuple[str, str]]:
        """A single ``int64_t`` argument named ``value``."""
        return [('int64_t', 'value')]
1353
1354
class PackedInt64WriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto int64 values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for int64."""
        return 'WritePackedInt64'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const int64_t`` named ``values``."""
        return [('pw::span<const int64_t>', 'values')]
1363
1364
class PackedInt64WriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto int64 values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for int64."""
        return 'WriteRepeatedInt64'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<int64_t>`` reference named ``values``."""
        return [('const ::pw::Vector<int64_t>&', 'values')]
1373
1374
class Int64ReadMethod(ReadMethod):
    """Emits the reader for a singular proto int64 field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for int64."""
        return 'ReadInt64'

    def _result_type(self) -> str:
        """C++ type produced when reading an int64."""
        return 'int64_t'
1383
1384
class PackedInt64ReadMethod(PackedReadMethod):
    """Emits the reader for packed proto int64 values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for int64."""
        return 'ReadPackedInt64'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'int64_t'
1393
1394
class PackedInt64ReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto int64 values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for int64."""
        return 'ReadRepeatedInt64'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'int64_t'
1403
1404
class Int64FindMethod(FindMethod):
    """Emits the find method for a proto int64 field."""

    def _finder(self) -> str:
        """Name of the finder type for int64."""
        return 'Int64Finder'

    def _find_fn(self) -> str:
        """Name of the find function for int64."""
        return 'FindInt64'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int64_t'
1416
1417
class Int64FindStreamMethod(FindStreamMethod):
    """Emits the stream-based find method for a proto int64 field."""

    def _finder(self) -> str:
        """Name of the stream finder type for int64."""
        return 'Int64StreamFinder'

    def _find_fn(self) -> str:
        """Name of the find function for int64."""
        return 'FindInt64'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int64_t'
1429
1430
class Int64Property(MessageProperty):
    """Message property whose value is a proto int64."""

    def _size_fn(self) -> str:
        """Name of the size helper for an int64 field."""
        return 'SizeOfFieldInt64'

    def varint_decode_type(self) -> str:
        """Varint decode mode: plain (non-zigzag) signed."""
        return 'kNormal'

    def wire_type(self) -> str:
        """Wire type enumerator: varint."""
        return 'kVarint'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'int64_t'
1445
1446
class Sint64WriteMethod(WriteMethod):
    """Emits the writer for a singular proto sint64 field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for sint64."""
        return 'WriteSint64'

    def params(self) -> list[tuple[str, str]]:
        """A single ``int64_t`` argument named ``value``."""
        return [('int64_t', 'value')]
1455
1456
class PackedSint64WriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto sint64 values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for sint64."""
        return 'WritePackedSint64'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const int64_t`` named ``values``."""
        return [('pw::span<const int64_t>', 'values')]
1465
1466
class PackedSint64WriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto sint64 values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for sint64."""
        return 'WriteRepeatedSint64'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<int64_t>`` reference named ``values``."""
        return [('const ::pw::Vector<int64_t>&', 'values')]
1475
1476
class Sint64ReadMethod(ReadMethod):
    """Emits the reader for a singular proto sint64 field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for sint64."""
        return 'ReadSint64'

    def _result_type(self) -> str:
        """C++ type produced when reading a sint64."""
        return 'int64_t'
1485
1486
class PackedSint64ReadMethod(PackedReadMethod):
    """Emits the reader for packed proto sint64 values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for sint64."""
        return 'ReadPackedSint64'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'int64_t'
1495
1496
class PackedSint64ReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto sint64 values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for sint64."""
        return 'ReadRepeatedSint64'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'int64_t'
1505
1506
class Sint64FindMethod(FindMethod):
    """Method which finds a proto sint64 value."""

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int64_t'

    def _find_fn(self) -> str:
        """Name of the find function for sint64."""
        return 'FindSint64'

    def _finder(self) -> str:
        """Name of the finder type for sint64."""
        # Every other scalar FindMethod subclass overrides ``_finder``
        # (e.g. 'Sint32Finder', 'Int64Finder'); sint64 was missing it and
        # so fell back to the base-class implementation.
        return 'Sint64Finder'
1515
1516
class Sint64FindStreamMethod(FindStreamMethod):
    """Method which finds a proto sint64 value in a stream."""

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int64_t'

    def _find_fn(self) -> str:
        """Name of the find function for sint64."""
        return 'FindSint64'

    def _finder(self) -> str:
        """Name of the stream finder type for sint64."""
        # Matches the '<Type>StreamFinder' override present on every other
        # scalar FindStreamMethod subclass; sint64 was missing it.
        return 'Sint64StreamFinder'
1525
1526
class Sint64Property(MessageProperty):
    """Message property whose value is a proto sint64."""

    def _size_fn(self) -> str:
        """Name of the size helper for a sint64 field."""
        return 'SizeOfFieldSint64'

    def varint_decode_type(self) -> str:
        """Varint decode mode: zigzag-encoded signed."""
        return 'kZigZag'

    def wire_type(self) -> str:
        """Wire type enumerator: varint."""
        return 'kVarint'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'int64_t'
1541
1542
class Sfixed64WriteMethod(WriteMethod):
    """Emits the writer for a singular proto sfixed64 field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for sfixed64."""
        return 'WriteSfixed64'

    def params(self) -> list[tuple[str, str]]:
        """A single ``int64_t`` argument named ``value``."""
        return [('int64_t', 'value')]
1551
1552
class PackedSfixed64WriteMethod(PackedWriteMethod):
    """Method which writes a packed list of sfixed64."""

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const int64_t`` named ``values``."""
        return [('pw::span<const int64_t>', 'values')]

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for sfixed64."""
        # Previously 'WritePackedSfixed4' (typo). That name breaks the
        # 'WritePacked<Type>' pattern used by every sibling writer
        # ('WritePackedSfixed32', 'WritePackedFixed64', ...).
        return 'WritePackedSfixed64'
1561
1562
class PackedSfixed64WriteVectorMethod(PackedWriteMethod):
    """Method which writes a packed vector of sfixed64."""

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<int64_t>`` reference named ``values``."""
        return [('const ::pw::Vector<int64_t>&', 'values')]

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for sfixed64."""
        # Previously 'WriteRepeatedSfixed4' (typo). That name breaks the
        # 'WriteRepeated<Type>' pattern used by every sibling vector writer.
        return 'WriteRepeatedSfixed64'
1571
1572
class Sfixed64ReadMethod(ReadMethod):
    """Emits the reader for a singular proto sfixed64 field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for sfixed64."""
        return 'ReadSfixed64'

    def _result_type(self) -> str:
        """C++ type produced when reading a sfixed64."""
        return 'int64_t'
1581
1582
class PackedSfixed64ReadMethod(PackedReadMethod):
    """Emits the reader for packed proto sfixed64 values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for sfixed64."""
        return 'ReadPackedSfixed64'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'int64_t'
1591
1592
class PackedSfixed64ReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto sfixed64 values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for sfixed64."""
        return 'ReadRepeatedSfixed64'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'int64_t'
1601
1602
class Sfixed64FindMethod(FindMethod):
    """Method which finds a proto sfixed64 value."""

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int64_t'

    def _find_fn(self) -> str:
        """Name of the find function for sfixed64."""
        return 'FindSfixed64'

    def _finder(self) -> str:
        """Name of the finder type for sfixed64."""
        # Every other scalar FindMethod subclass overrides ``_finder``
        # (e.g. 'Sfixed32Finder', 'Fixed64Finder'); sfixed64 was missing it.
        return 'Sfixed64Finder'
1611
1612
class Sfixed64FindStreamMethod(FindStreamMethod):
    """Method which finds a proto sfixed64 value in a stream."""

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'int64_t'

    def _find_fn(self) -> str:
        """Name of the find function for sfixed64."""
        return 'FindSfixed64'

    def _finder(self) -> str:
        """Name of the stream finder type for sfixed64."""
        # Matches the '<Type>StreamFinder' override present on every other
        # scalar FindStreamMethod subclass; sfixed64 was missing it.
        return 'Sfixed64StreamFinder'
1621
1622
class Sfixed64Property(MessageProperty):
    """Message property whose value is a proto sfixed64."""

    def _size_fn(self) -> str:
        """Name of the size helper for a sfixed64 field."""
        return 'SizeOfFieldSfixed64'

    def wire_type(self) -> str:
        """Wire type enumerator: 64-bit fixed width."""
        return 'kFixed64'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'int64_t'
1634
1635
class Uint32WriteMethod(WriteMethod):
    """Emits the writer for a singular proto uint32 field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for uint32."""
        return 'WriteUint32'

    def params(self) -> list[tuple[str, str]]:
        """A single ``uint32_t`` argument named ``value``."""
        return [('uint32_t', 'value')]
1644
1645
class PackedUint32WriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto uint32 values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for uint32."""
        return 'WritePackedUint32'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const uint32_t`` named ``values``."""
        return [('pw::span<const uint32_t>', 'values')]
1654
1655
class PackedUint32WriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto uint32 values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for uint32."""
        return 'WriteRepeatedUint32'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<uint32_t>`` reference named ``values``."""
        return [('const ::pw::Vector<uint32_t>&', 'values')]
1664
1665
class Uint32ReadMethod(ReadMethod):
    """Emits the reader for a singular proto uint32 field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for uint32."""
        return 'ReadUint32'

    def _result_type(self) -> str:
        """C++ type produced when reading a uint32."""
        return 'uint32_t'
1674
1675
class PackedUint32ReadMethod(PackedReadMethod):
    """Emits the reader for packed proto uint32 values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for uint32."""
        return 'ReadPackedUint32'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'uint32_t'
1684
1685
class PackedUint32ReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto uint32 values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for uint32."""
        return 'ReadRepeatedUint32'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'uint32_t'
1694
1695
class Uint32FindMethod(FindMethod):
    """Emits the find method for a proto uint32 field."""

    def _finder(self) -> str:
        """Name of the finder type for uint32."""
        return 'Uint32Finder'

    def _find_fn(self) -> str:
        """Name of the find function for uint32."""
        return 'FindUint32'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'uint32_t'
1707
1708
class Uint32FindStreamMethod(FindStreamMethod):
    """Emits the stream-based find method for a proto uint32 field."""

    def _finder(self) -> str:
        """Name of the stream finder type for uint32."""
        return 'Uint32StreamFinder'

    def _find_fn(self) -> str:
        """Name of the find function for uint32."""
        return 'FindUint32'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'uint32_t'
1720
1721
class Uint32Property(MessageProperty):
    """Message property whose value is a proto uint32."""

    def _size_fn(self) -> str:
        """Name of the size helper for a uint32 field."""
        return 'SizeOfFieldUint32'

    def varint_decode_type(self) -> str:
        """Varint decode mode: unsigned."""
        return 'kUnsigned'

    def wire_type(self) -> str:
        """Wire type enumerator: varint."""
        return 'kVarint'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'uint32_t'
1736
1737
class Fixed32WriteMethod(WriteMethod):
    """Emits the writer for a singular proto fixed32 field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for fixed32."""
        return 'WriteFixed32'

    def params(self) -> list[tuple[str, str]]:
        """A single ``uint32_t`` argument named ``value``."""
        return [('uint32_t', 'value')]
1746
1747
class PackedFixed32WriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto fixed32 values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for fixed32."""
        return 'WritePackedFixed32'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const uint32_t`` named ``values``."""
        return [('pw::span<const uint32_t>', 'values')]
1756
1757
class PackedFixed32WriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto fixed32 values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for fixed32."""
        return 'WriteRepeatedFixed32'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<uint32_t>`` reference named ``values``."""
        return [('const ::pw::Vector<uint32_t>&', 'values')]
1766
1767
class Fixed32ReadMethod(ReadMethod):
    """Emits the reader for a singular proto fixed32 field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for fixed32."""
        return 'ReadFixed32'

    def _result_type(self) -> str:
        """C++ type produced when reading a fixed32."""
        return 'uint32_t'
1776
1777
class PackedFixed32ReadMethod(PackedReadMethod):
    """Emits the reader for packed proto fixed32 values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for fixed32."""
        return 'ReadPackedFixed32'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'uint32_t'
1786
1787
class PackedFixed32ReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto fixed32 values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for fixed32."""
        return 'ReadRepeatedFixed32'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'uint32_t'
1796
1797
class Fixed32FindMethod(FindMethod):
    """Emits the find method for a proto fixed32 field."""

    def _finder(self) -> str:
        """Name of the finder type for fixed32."""
        return 'Fixed32Finder'

    def _find_fn(self) -> str:
        """Name of the find function for fixed32."""
        return 'FindFixed32'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'uint32_t'
1809
1810
class Fixed32FindStreamMethod(FindStreamMethod):
    """Emits the stream-based find method for a proto fixed32 field."""

    def _finder(self) -> str:
        """Name of the stream finder type for fixed32."""
        return 'Fixed32StreamFinder'

    def _find_fn(self) -> str:
        """Name of the find function for fixed32."""
        return 'FindFixed32'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'uint32_t'
1822
1823
class Fixed32Property(MessageProperty):
    """Message property whose value is a proto fixed32."""

    def _size_fn(self) -> str:
        """Name of the size helper for a fixed32 field."""
        return 'SizeOfFieldFixed32'

    def wire_type(self) -> str:
        """Wire type enumerator: 32-bit fixed width."""
        return 'kFixed32'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'uint32_t'
1835
1836
class Uint64WriteMethod(WriteMethod):
    """Emits the writer for a singular proto uint64 field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for uint64."""
        return 'WriteUint64'

    def params(self) -> list[tuple[str, str]]:
        """A single ``uint64_t`` argument named ``value``."""
        return [('uint64_t', 'value')]
1845
1846
class PackedUint64WriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto uint64 values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for uint64."""
        return 'WritePackedUint64'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const uint64_t`` named ``values``."""
        return [('pw::span<const uint64_t>', 'values')]
1855
1856
class PackedUint64WriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto uint64 values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for uint64."""
        return 'WriteRepeatedUint64'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<uint64_t>`` reference named ``values``."""
        return [('const ::pw::Vector<uint64_t>&', 'values')]
1865
1866
class Uint64ReadMethod(ReadMethod):
    """Emits the reader for a singular proto uint64 field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for uint64."""
        return 'ReadUint64'

    def _result_type(self) -> str:
        """C++ type produced when reading a uint64."""
        return 'uint64_t'
1875
1876
class PackedUint64ReadMethod(PackedReadMethod):
    """Emits the reader for packed proto uint64 values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for uint64."""
        return 'ReadPackedUint64'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'uint64_t'
1885
1886
class PackedUint64ReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto uint64 values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for uint64."""
        return 'ReadRepeatedUint64'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'uint64_t'
1895
1896
class Uint64FindMethod(FindMethod):
    """Emits the find method for a proto uint64 field."""

    def _finder(self) -> str:
        """Name of the finder type for uint64."""
        return 'Uint64Finder'

    def _find_fn(self) -> str:
        """Name of the find function for uint64."""
        return 'FindUint64'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'uint64_t'
1908
1909
class Uint64FindStreamMethod(FindStreamMethod):
    """Emits the stream-based find method for a proto uint64 field."""

    def _finder(self) -> str:
        """Name of the stream finder type for uint64."""
        return 'Uint64StreamFinder'

    def _find_fn(self) -> str:
        """Name of the find function for uint64."""
        return 'FindUint64'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'uint64_t'
1921
1922
class Uint64Property(MessageProperty):
    """Message property whose value is a proto uint64."""

    def _size_fn(self) -> str:
        """Name of the size helper for a uint64 field."""
        return 'SizeOfFieldUint64'

    def varint_decode_type(self) -> str:
        """Varint decode mode: unsigned."""
        return 'kUnsigned'

    def wire_type(self) -> str:
        """Wire type enumerator: varint."""
        return 'kVarint'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'uint64_t'
1937
1938
class Fixed64WriteMethod(WriteMethod):
    """Emits the writer for a singular proto fixed64 field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for fixed64."""
        return 'WriteFixed64'

    def params(self) -> list[tuple[str, str]]:
        """A single ``uint64_t`` argument named ``value``."""
        return [('uint64_t', 'value')]
1947
1948
class PackedFixed64WriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto fixed64 values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for fixed64."""
        return 'WritePackedFixed64'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const uint64_t`` named ``values``."""
        return [('pw::span<const uint64_t>', 'values')]
1957
1958
class PackedFixed64WriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto fixed64 values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for fixed64."""
        return 'WriteRepeatedFixed64'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<uint64_t>`` reference named ``values``."""
        return [('const ::pw::Vector<uint64_t>&', 'values')]
1967
1968
class Fixed64ReadMethod(ReadMethod):
    """Emits the reader for a singular proto fixed64 field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for fixed64."""
        return 'ReadFixed64'

    def _result_type(self) -> str:
        """C++ type produced when reading a fixed64."""
        return 'uint64_t'
1977
1978
class PackedFixed64ReadMethod(PackedReadMethod):
    """Emits the reader for packed proto fixed64 values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for fixed64."""
        return 'ReadPackedFixed64'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'uint64_t'
1987
1988
class PackedFixed64ReadVectorMethod(PackedReadVectorMethod):
    """Emits the vector-based reader for packed proto fixed64 values."""

    def _decoder_fn(self) -> str:
        """Name of the repeated decoder function for fixed64."""
        return 'ReadRepeatedFixed64'

    def _result_type(self) -> str:
        """C++ element type of the decoded values."""
        return 'uint64_t'
1997
1998
class Fixed64FindMethod(FindMethod):
    """Emits the find method for a proto fixed64 field."""

    def _finder(self) -> str:
        """Name of the finder type for fixed64."""
        return 'Fixed64Finder'

    def _find_fn(self) -> str:
        """Name of the find function for fixed64."""
        return 'FindFixed64'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'uint64_t'
2010
2011
class Fixed64FindStreamMethod(FindStreamMethod):
    """Emits the stream-based find method for a proto fixed64 field."""

    def _finder(self) -> str:
        """Name of the stream finder type for fixed64."""
        return 'Fixed64StreamFinder'

    def _find_fn(self) -> str:
        """Name of the find function for fixed64."""
        return 'FindFixed64'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'uint64_t'
2023
2024
class Fixed64Property(MessageProperty):
    """Message property whose value is a proto fixed64."""

    def _size_fn(self) -> str:
        """Name of the size helper for a fixed64 field."""
        return 'SizeOfFieldFixed64'

    def wire_type(self) -> str:
        """Wire type enumerator: 64-bit fixed width."""
        return 'kFixed64'

    def type_name(self, from_root: bool = False) -> str:
        """C++ member type; ``from_root`` does not affect the result."""
        return 'uint64_t'
2036
2037
class BoolWriteMethod(WriteMethod):
    """Emits the writer for a singular proto bool field."""

    def _encoder_fn(self) -> str:
        """Name of the encoder function for bool."""
        return 'WriteBool'

    def params(self) -> list[tuple[str, str]]:
        """A single ``bool`` argument named ``value``."""
        return [('bool', 'value')]
2046
2047
class PackedBoolWriteMethod(PackedWriteMethod):
    """Emits the writer for a packed span of proto bool values."""

    def _encoder_fn(self) -> str:
        """Name of the packed encoder function for bool."""
        return 'WritePackedBool'

    def params(self) -> list[tuple[str, str]]:
        """A span of ``const bool`` named ``values``."""
        return [('pw::span<const bool>', 'values')]
2056
2057
class PackedBoolWriteVectorMethod(PackedWriteMethod):
    """Emits the writer taking a ``pw::Vector`` of proto bool values."""

    def _encoder_fn(self) -> str:
        """Name of the repeated encoder function for bool."""
        return 'WriteRepeatedBool'

    def params(self) -> list[tuple[str, str]]:
        """A const ``pw::Vector<bool>`` reference named ``values``."""
        return [('const ::pw::Vector<bool>&', 'values')]
2066
2067
class BoolReadMethod(ReadMethod):
    """Emits the reader for a singular proto bool field."""

    def _decoder_fn(self) -> str:
        """Name of the decoder function for bool."""
        return 'ReadBool'

    def _result_type(self) -> str:
        """C++ type produced when reading a bool."""
        return 'bool'
2076
2077
class PackedBoolReadMethod(PackedReadMethod):
    """Emits the reader for packed proto bool values."""

    def _decoder_fn(self) -> str:
        """Name of the packed decoder function for bool."""
        return 'ReadPackedBool'

    def _result_type(self) -> str:
        """C++ element type of the packed values."""
        return 'bool'
2086
2087
class BoolFindMethod(FindMethod):
    """Emits the find method for a proto bool field."""

    def _finder(self) -> str:
        """Name of the finder type for bool."""
        return 'BoolFinder'

    def _find_fn(self) -> str:
        """Name of the find function for bool."""
        return 'FindBool'

    def _result_type(self) -> str:
        """C++ type of the value that is found."""
        return 'bool'
2099
2100
class BoolFindStreamMethod(FindStreamMethod):
    """Find method that locates a ``bool`` field in a message stream."""

    def _finder(self) -> str:
        return 'BoolStreamFinder'

    def _find_fn(self) -> str:
        return 'FindBool'

    def _result_type(self) -> str:
        return 'bool'
2112
2113
class BoolProperty(MessageProperty):
    """Message struct property for a proto ``bool`` field.

    Bools travel as unsigned varints on the wire.
    """

    def wire_type(self) -> str:
        return 'kVarint'

    def varint_decode_type(self) -> str:
        return 'kUnsigned'

    def type_name(self, from_root: bool = False) -> str:
        # Builtin scalar type: the namespace root is irrelevant.
        return 'bool'

    def _size_fn(self) -> str:
        return 'SizeOfFieldBool'
2128
2129
class BytesWriteMethod(WriteMethod):
    """Encoder method that writes a single proto ``bytes`` field."""

    def _encoder_fn(self) -> str:
        return 'WriteBytes'

    def params(self) -> list[tuple[str, str]]:
        # Generated signature: `pw::span<const std::byte> value`.
        return [('pw::span<const std::byte>', 'value')]
2138
2139
class BytesReadMethod(ReadMethod):
    """Decoder method that reads a ``bytes`` field into a caller buffer."""

    def _decoder_fn(self) -> str:
        return 'ReadBytes'

    def params(self) -> list[tuple[str, str]]:
        # Caller-provided output buffer.
        return [('pw::span<std::byte>', 'out')]

    def return_type(self, from_root: bool = False) -> str:
        # A status paired with a size rather than a plain Result.
        return '::pw::StatusWithSize'
2151
2152
class BytesFindMethod(FindMethod):
    """Find method that locates a ``bytes`` field in an in-memory message.

    The result is a ``::pw::ConstByteSpan``.
    """

    def _finder(self) -> str:
        return 'BytesFinder'

    def _find_fn(self) -> str:
        return 'FindBytes'

    def _result_type(self) -> str:
        return '::pw::ConstByteSpan'
2164
2165
class BytesFindStreamMethod(FindStreamMethod):
    """Find method that copies a ``bytes`` field from a stream into a buffer."""

    def _find_fn(self) -> str:
        return 'FindBytes'

    def return_type(self, from_root: bool = False) -> str:
        return '::pw::StatusWithSize'

    def params(self) -> list[tuple[str, str]]:
        reader = ('::pw::stream::Reader&', 'message_stream')
        destination = ('::pw::ByteSpan', 'out')
        return [reader, destination]

    def body(self) -> list[str]:
        # Forward straight to the ::pw::protobuf helper, passing the caller's
        # output buffer through unchanged.
        callee = f'{PROTOBUF_NAMESPACE}::{self._find_fn()}'
        arguments = f'message_stream, {self.field_cast()}, out'
        return [f'return {callee}({arguments});']
2188
2189
class BytesProperty(MessageProperty):
    """Message struct property for a proto ``bytes`` field.

    A non-repeated field with a non-zero ``max_size`` option is stored inline
    as ``std::byte`` elements; otherwise the field is handled via a callback.
    """

    def type_name(self, from_root: bool = False) -> str:
        return 'std::byte'

    def wire_type(self) -> str:
        return 'kDelimited'

    def max_size(self) -> int:
        # Repeated bytes fields never get an inline size.
        if self._field.is_repeated():
            return 0
        field_options = self._field.options()
        assert field_options is not None
        return field_options.max_size

    def is_fixed_size(self) -> bool:
        if self._field.is_repeated():
            return False
        field_options = self._field.options()
        assert field_options is not None
        return field_options.fixed_size

    def _use_callback(self) -> bool:
        # No usable max size means the data cannot be stored inline.
        return self.max_size() == 0

    def _size_fn(self) -> str:
        # The WithoutValue variant always counts the maximum-length size
        # varint, which covers MemoryEncoder's scratch overhead.
        return 'SizeOfDelimitedFieldWithoutValue'

    def _size_length(self) -> str | None:
        uses_callback = self.callback_type() is not _CallbackType.NONE
        return None if uses_callback else self.max_size_constant_name()
2228
2229
class StringLenWriteMethod(WriteMethod):
    """Encoder method writing a ``string`` field from a pointer and length."""

    def _encoder_fn(self) -> str:
        return 'WriteString'

    def params(self) -> list[tuple[str, str]]:
        pointer = ('const char*', 'value')
        length = ('size_t', 'len')
        return [pointer, length]
2238
2239
class StringWriteMethod(WriteMethod):
    """Encoder method writing a ``string`` field from a ``std::string_view``."""

    def _encoder_fn(self) -> str:
        return 'WriteString'

    def params(self) -> list[tuple[str, str]]:
        return [('std::string_view', 'value')]
2248
2249
class StringReadMethod(ReadMethod):
    """Decoder method that reads a ``string`` field into a char buffer."""

    def _decoder_fn(self) -> str:
        return 'ReadString'

    def params(self) -> list[tuple[str, str]]:
        # Caller-provided output buffer.
        return [('pw::span<char>', 'out')]

    def return_type(self, from_root: bool = False) -> str:
        return '::pw::StatusWithSize'
2261
2262
class StringFindMethod(FindMethod):
    """Find method that locates a ``string`` field in an in-memory message.

    The result is a ``std::string_view``.
    """

    def _finder(self) -> str:
        return 'StringFinder'

    def _find_fn(self) -> str:
        return 'FindString'

    def _result_type(self) -> str:
        return 'std::string_view'
2274
2275
class StringFindStreamMethod(FindStreamMethod):
    """Find method that copies a ``string`` field from a stream into chars."""

    def _find_fn(self) -> str:
        return 'FindString'

    def return_type(self, from_root: bool = False) -> str:
        return '::pw::StatusWithSize'

    def params(self) -> list[tuple[str, str]]:
        reader = ('::pw::stream::Reader&', 'message_stream')
        destination = ('::pw::span<char>', 'out')
        return [reader, destination]

    def body(self) -> list[str]:
        # Delegate to the ::pw::protobuf helper with the caller's buffer.
        callee = f'{PROTOBUF_NAMESPACE}::{self._find_fn()}'
        arguments = f'message_stream, {self.field_cast()}, out'
        return [f'return {callee}({arguments});']
2298
2299
class StringFindStreamMethodInlineString(FindStreamMethod):
    """Find method that reads a stream ``string`` field into an InlineString."""

    def _find_fn(self) -> str:
        return 'FindString'

    def return_type(self, from_root: bool = False) -> str:
        return '::pw::StatusWithSize'

    def params(self) -> list[tuple[str, str]]:
        reader = ('::pw::stream::Reader&', 'message_stream')
        destination = ('::pw::InlineString<>&', 'out')
        return [reader, destination]

    def body(self) -> list[str]:
        # Same helper as the span overload; C++ overload resolution picks the
        # InlineString variant from the argument type.
        callee = f'{PROTOBUF_NAMESPACE}::{self._find_fn()}'
        arguments = f'message_stream, {self.field_cast()}, out'
        return [f'return {callee}({arguments});']
2322
2323
class StringProperty(MessageProperty):
    """Message struct property for a proto ``string`` field.

    A non-repeated field with a non-zero ``max_size`` option is stored inline
    as ``char`` data; otherwise the field is handled via a callback.
    """

    def type_name(self, from_root: bool = False) -> str:
        return 'char'

    def wire_type(self) -> str:
        return 'kDelimited'

    def is_string(self) -> bool:
        return True

    def is_fixed_size(self) -> bool:
        # Strings are never treated as fixed size.
        return False

    def max_size(self) -> int:
        # Repeated string fields never get an inline size.
        if self._field.is_repeated():
            return 0
        field_options = self._field.options()
        assert field_options is not None
        return field_options.max_size

    def _use_callback(self) -> bool:
        return self.max_size() == 0

    @staticmethod
    def repeated_field_container(type_name: str, max_size: str) -> str:
        # Inline strings are held in a fixed-capacity InlineBasicString.
        return f'::pw::InlineBasicString<{type_name}, {max_size}>'

    def _size_fn(self) -> str:
        # The WithoutValue variant always counts the maximum-length size
        # varint, which covers MemoryEncoder's scratch overhead.
        return 'SizeOfDelimitedFieldWithoutValue'

    def _size_length(self) -> str | None:
        uses_callback = self.callback_type() is not _CallbackType.NONE
        return None if uses_callback else self.max_size_constant_name()
2364
2365
class EnumWriteMethod(WriteMethod):
    """Encoder method that writes a proto enum field.

    The enum value is cast to ``uint32_t`` and written with the base class's
    ``WriteUint32``. The method is defined inline in the class body.
    """

    def params(self) -> list[tuple[str, str]]:
        return [(self._relative_type_namespace(), 'value')]

    def in_class_definition(self) -> bool:
        return True

    def body(self) -> list[str]:
        base = self._base_class
        tag = self.field_cast()
        return [
            f'return {base}::WriteUint32({tag}, static_cast<uint32_t>(value));'
        ]

    def _encoder_fn(self) -> str:
        # body() is overridden, so there is no single encoder function.
        raise NotImplementedError()
2386
2387
class PackedEnumWriteMethod(PackedWriteMethod):
    """Encoder method that writes a span of enum values as packed uint32s."""

    def params(self) -> list[tuple[str, str]]:
        enum_type = self._relative_type_namespace()
        return [(f'pw::span<const {enum_type}>', 'values')]

    def in_class_definition(self) -> bool:
        return True

    def body(self) -> list[str]:
        # The enums are reinterpreted as a contiguous uint32_t array (the
        # generated enums are declared `: uint32_t`).
        arg = self.params()[0][1]
        return [
            f'return {self._base_class}::WritePackedUint32('
            f'{self.field_cast()}, pw::span(reinterpret_cast<const uint32_t*>('
            f'{arg}.data()), {arg}.size()));'
        ]

    def _encoder_fn(self) -> str:
        # body() is overridden, so there is no single encoder function.
        raise NotImplementedError()
2413
2414
class PackedEnumWriteVectorMethod(PackedEnumWriteMethod):
    """Encoder method that writes a ``pw::Vector`` of enums as packed uint32s.

    Only the parameter type differs from PackedEnumWriteMethod; the generated
    body is inherited.
    """

    def params(self) -> list[tuple[str, str]]:
        enum_type = self._relative_type_namespace()
        return [(f'const ::pw::Vector<{enum_type}>&', 'values')]
2427
2428
class EnumReadMethod(ReadMethod):
    """Decoder method that reads a uint32 and casts it to the enum type."""

    def _result_type(self) -> str:
        return self._relative_type_namespace()

    def _decoder_body(self) -> list[str]:
        # Propagate a failed read; otherwise cast the raw value to the enum.
        return [
            '::pw::Result<uint32_t> value = ReadUint32();',
            'if (!value.ok()) {',
            '  return value.status();',
            '}',
            f'return static_cast<{self._result_type()}>(value.value());',
        ]
2444
2445
class PackedEnumReadMethod(PackedReadMethod):
    """Decoder method reading packed enum values into a caller span."""

    def _result_type(self) -> str:
        return self._relative_type_namespace()

    def _decoder_body(self) -> list[str]:
        # The output span is reinterpreted as uint32_t storage for the packed
        # uint32 reader.
        buf = self.params()[0][1]
        return [
            'return ReadPackedUint32('
            f'pw::span(reinterpret_cast<uint32_t*>({buf}.data()), '
            f'{buf}.size()));'
        ]
2459
2460
class PackedEnumReadVectorMethod(PackedReadVectorMethod):
    """Decoder method reading packed enum values into a ``pw::Vector``."""

    def _result_type(self) -> str:
        return self._relative_type_namespace()

    def _decoder_body(self) -> list[str]:
        # The enum vector is reinterpreted as a uint32_t vector.
        vec = self.params()[0][1]
        return [
            'return ReadRepeatedUint32('
            f'*reinterpret_cast<pw::Vector<uint32_t>*>(&{vec}));'
        ]
2473
2474
class EnumFindMethod(FindMethod):
    """Find method that locates an enum field in an in-memory message."""

    def _result_type(self) -> str:
        return self._relative_type_namespace()

    def _find_fn(self) -> str:
        return 'FindUint32'

    def _finder(self) -> str:
        return f'EnumFinder<{self._result_type()}>'

    def body(self) -> list[str]:
        # Repeated fields use the generic finder body from the base class.
        if self._field.is_repeated():
            return super().body()

        # Singular: find as uint32, then cast to the enum on success.
        find_call = (
            f'{PROTOBUF_NAMESPACE}::{self._find_fn()}'
            f'(message, {self.field_cast()});'
        )
        return [
            f'::pw::Result<uint32_t> result = {find_call}',
            'if (!result.ok()) {',
            '  return result.status();',
            '}',
            f'return static_cast<{self._result_type()}>(result.value());',
        ]
2502
2503
class EnumFindStreamMethod(FindStreamMethod):
    """Find method that locates an enum field in a message stream."""

    def _result_type(self) -> str:
        return self._relative_type_namespace()

    def _find_fn(self) -> str:
        return 'FindUint32'

    def _finder(self) -> str:
        return f'EnumStreamFinder<{self._result_type()}>'

    def body(self) -> list[str]:
        # Repeated fields use the generic finder body from the base class.
        if self._field.is_repeated():
            return super().body()

        # Singular: find as uint32, then cast to the enum on success.
        find_call = (
            f'{PROTOBUF_NAMESPACE}::{self._find_fn()}'
            f'(message_stream, {self.field_cast()});'
        )
        return [
            f'::pw::Result<uint32_t> result = {find_call}',
            'if (!result.ok()) {',
            '  return result.status();',
            '}',
            f'return static_cast<{self._result_type()}>(result.value());',
        ]
2531
2532
class EnumProperty(MessageProperty):
    """Message struct property for a proto enum field.

    The member type is the generated C++ enum; on the wire it is an unsigned
    varint.
    """

    def wire_type(self) -> str:
        return 'kVarint'

    def varint_decode_type(self) -> str:
        return 'kUnsigned'

    def type_name(self, from_root: bool = False) -> str:
        # Enum types are qualified, so the caller's root choice matters here.
        return self._relative_type_namespace(from_root=from_root)

    def _size_fn(self) -> str:
        return 'SizeOfFieldEnum'
2547
2548
# Mapping of protobuf field types to the encoder method classes generated for
# fields of that type. proto_field_methods() selects this table when the class
# being generated is an encoder. Each class in a list is instantiated per field
# and emitted only if its should_appear() is true. The list order is the order
# in which the methods appear in the generated C++ class.
PROTO_FIELD_WRITE_METHODS: dict[int, list] = {
    descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: [
        DoubleWriteMethod,
        PackedDoubleWriteMethod,
        PackedDoubleWriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: [
        FloatWriteMethod,
        PackedFloatWriteMethod,
        PackedFloatWriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_INT32: [
        Int32WriteMethod,
        PackedInt32WriteMethod,
        PackedInt32WriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: [
        Sint32WriteMethod,
        PackedSint32WriteMethod,
        PackedSint32WriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: [
        Sfixed32WriteMethod,
        PackedSfixed32WriteMethod,
        PackedSfixed32WriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_INT64: [
        Int64WriteMethod,
        PackedInt64WriteMethod,
        PackedInt64WriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: [
        Sint64WriteMethod,
        PackedSint64WriteMethod,
        PackedSint64WriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: [
        Sfixed64WriteMethod,
        PackedSfixed64WriteMethod,
        PackedSfixed64WriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: [
        Uint32WriteMethod,
        PackedUint32WriteMethod,
        PackedUint32WriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: [
        Fixed32WriteMethod,
        PackedFixed32WriteMethod,
        PackedFixed32WriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: [
        Uint64WriteMethod,
        PackedUint64WriteMethod,
        PackedUint64WriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: [
        Fixed64WriteMethod,
        PackedFixed64WriteMethod,
        PackedFixed64WriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: [
        BoolWriteMethod,
        PackedBoolWriteMethod,
        PackedBoolWriteVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: [BytesWriteMethod],
    # Two string overloads: (const char*, size_t) and std::string_view.
    descriptor_pb2.FieldDescriptorProto.TYPE_STRING: [
        StringLenWriteMethod,
        StringWriteMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [SubMessageEncoderMethod],
    descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [
        EnumWriteMethod,
        PackedEnumWriteMethod,
        PackedEnumWriteVectorMethod,
    ],
}
2628
# Mapping of protobuf field types to the decoder method classes generated for
# fields of that type. proto_field_methods() selects this table for
# non-encoder classes. The list order is the emission order in the generated
# C++ class.
PROTO_FIELD_READ_METHODS: dict[int, list] = {
    descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: [
        DoubleReadMethod,
        PackedDoubleReadMethod,
        PackedDoubleReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: [
        FloatReadMethod,
        PackedFloatReadMethod,
        PackedFloatReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_INT32: [
        Int32ReadMethod,
        PackedInt32ReadMethod,
        PackedInt32ReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: [
        Sint32ReadMethod,
        PackedSint32ReadMethod,
        PackedSint32ReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: [
        Sfixed32ReadMethod,
        PackedSfixed32ReadMethod,
        PackedSfixed32ReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_INT64: [
        Int64ReadMethod,
        PackedInt64ReadMethod,
        PackedInt64ReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: [
        Sint64ReadMethod,
        PackedSint64ReadMethod,
        PackedSint64ReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: [
        Sfixed64ReadMethod,
        PackedSfixed64ReadMethod,
        PackedSfixed64ReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: [
        Uint32ReadMethod,
        PackedUint32ReadMethod,
        PackedUint32ReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: [
        Fixed32ReadMethod,
        PackedFixed32ReadMethod,
        PackedFixed32ReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: [
        Uint64ReadMethod,
        PackedUint64ReadMethod,
        PackedUint64ReadVectorMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: [
        Fixed64ReadMethod,
        PackedFixed64ReadMethod,
        PackedFixed64ReadVectorMethod,
    ],
    # NOTE(review): unlike the other scalar types there is no packed
    # vector read method for bool here — presumably intentional; confirm.
    descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: [
        BoolReadMethod,
        PackedBoolReadMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: [
        BytesReadMethod,
        BytesReaderMethod,
    ],
    # Strings reuse BytesReaderMethod for the reader-style accessor.
    descriptor_pb2.FieldDescriptorProto.TYPE_STRING: [
        StringReadMethod,
        BytesReaderMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [SubMessageDecoderMethod],
    descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [
        EnumReadMethod,
        PackedEnumReadMethod,
        PackedEnumReadVectorMethod,
    ],
}
2709
# Mapping of protobuf field types to their "find" method classes. For most
# types these are an in-memory variant (FindMethod) and a stream variant
# (FindStreamMethod). Strings also get an InlineString stream overload.
PROTO_FIELD_FIND_METHODS: dict[int, list] = {
    descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: [
        DoubleFindMethod,
        DoubleFindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: [
        FloatFindMethod,
        FloatFindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_INT32: [
        Int32FindMethod,
        Int32FindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: [
        Sint32FindMethod,
        Sint32FindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: [
        Sfixed32FindMethod,
        Sfixed32FindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_INT64: [
        Int64FindMethod,
        Int64FindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: [
        Sint64FindMethod,
        Sint64FindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: [
        Sfixed64FindMethod,
        Sfixed64FindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: [
        Uint32FindMethod,
        Uint32FindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: [
        Fixed32FindMethod,
        Fixed32FindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: [
        Uint64FindMethod,
        Uint64FindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: [
        Fixed64FindMethod,
        Fixed64FindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: [
        BoolFindMethod,
        BoolFindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: [
        BytesFindMethod,
        BytesFindStreamMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_STRING: [
        StringFindMethod,
        StringFindStreamMethod,
        StringFindStreamMethodInlineString,
    ],
    # Submessages have only the in-memory find method.
    descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [
        SubMessageFindMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [
        EnumFindMethod,
        EnumFindStreamMethod,
    ],
}
2780
# Mapping of protobuf field types to the MessageProperty subclass describing
# the corresponding member of the generated message struct (C++ type, wire
# type, size function). proto_message_field_props() looks fields up here.
# Each entry must pair a type with its own width/encoding class: for example,
# sfixed64 must use the 64-bit property, not the 32-bit one.
PROTO_FIELD_PROPERTIES: dict[int, Type[MessageProperty]] = {
    descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: DoubleProperty,
    descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: FloatProperty,
    descriptor_pb2.FieldDescriptorProto.TYPE_INT32: Int32Property,
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: Sint32Property,
    descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: Sfixed32Property,
    descriptor_pb2.FieldDescriptorProto.TYPE_INT64: Int64Property,
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: Sint64Property,
    descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: Sfixed64Property,
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: Uint32Property,
    descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: Fixed32Property,
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: Uint64Property,
    descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: Fixed64Property,
    descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: BoolProperty,
    descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: BytesProperty,
    descriptor_pb2.FieldDescriptorProto.TYPE_STRING: StringProperty,
    descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: SubMessageProperty,
    descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: EnumProperty,
}
2800
2801
def proto_message_field_props(
    codegen_options: GeneratorOptions,
    message: ProtoMessage,
    root: ProtoNode,
    include_hidden: bool = False,
) -> Iterable[MessageProperty]:
    """Lazily produces a MessageProperty for every field of a message.

    The property class is chosen from PROTO_FIELD_PROPERTIES by field type.
    Properties whose should_appear() is false are skipped unless
    `include_hidden` is set.

    Args:
      codegen_options: Generator options passed to each property.
      message: The ProtoMessage whose fields are walked.
      root: The root ProtoNode of the tree.
      include_hidden: If True, also produce properties for fields that would
          not appear in the generated struct.

    Yields:
      One appropriately-typed MessageProperty per selected field, in field
      order.
    """
    candidates = (
        PROTO_FIELD_PROPERTIES[field.type()](
            codegen_options, field, message, root
        )
        for field in message.fields()
    )
    # should_appear() is only consulted when hidden fields are excluded.
    yield from (
        candidate
        for candidate in candidates
        if include_hidden or candidate.should_appear()
    )
2828
2829
def proto_field_methods(class_type: ClassType, field_type: int) -> list:
    """Returns the method classes to generate for a field of `field_type`.

    Encoder classes use PROTO_FIELD_WRITE_METHODS; all other class types use
    PROTO_FIELD_READ_METHODS.
    """
    if class_type.is_encoder():
        table = PROTO_FIELD_WRITE_METHODS
    else:
        table = PROTO_FIELD_READ_METHODS
    return table[field_type]
2836
2837
def generate_class_for_message(
    message: ProtoMessage,
    root: ProtoNode,
    output: OutputFile,
    codegen_options: GeneratorOptions,
    class_type: ClassType,
) -> None:
    """Creates a C++ class to encode or decode a protobuf message.

    Writes `class <message ns>::<ClassName> : public ::pw::protobuf::<Base>`
    to `output`. The class inherits the base constructors and gets a move
    constructor from the base type. It then receives class-type-specific
    helpers:
      - MEMORY_ENCODER: a conversion operator to the streaming encoder type.
      - STREAMING_DECODER: a typed Field() and a table-driven Read(Message&).
      - STREAMING_ENCODER / MEMORY_ENCODER: a table-driven
        Write(const Message&).
    Finally, one method per applicable field method class is emitted.
    Methods whose in_class_definition() is false are only declared here. Their
    definitions are written separately (see define_not_in_class_methods).

    Args:
      message: The message node to generate a class for (must be a MESSAGE).
      root: The root ProtoNode, used to compute relative C++ namespaces.
      output: Destination for the generated C++ lines.
      codegen_options: Generator options forwarded to each field method.
      class_type: Which encoder/decoder class flavor to generate.
    """
    assert message.type() == ProtoNode.Type.MESSAGE

    base_class_name = class_type.base_class_name()
    class_name = class_type.codegen_class_name()

    # Message classes inherit from the base proto message class in codegen.h
    # and use its constructor.
    base_class = f'{PROTOBUF_NAMESPACE}::{base_class_name}'
    output.write_line(
        f'class {message.cpp_namespace(root=root)}::{class_name} '
        f': public {base_class} {{'
    )
    output.write_line(' public:')

    with output.indent():
        # Inherit the constructors from the base class.
        output.write_line(f'using {base_class}::{base_class_name};')

        # Declare a move constructor that takes a base class.
        output.write_line(
            f'constexpr {class_name}({base_class}&& parent) '
            f': {base_class}(std::move(parent)) {{}}'
        )

        # Allow MemoryEncoder& to be converted to StreamEncoder&.
        if class_type == ClassType.MEMORY_ENCODER:
            # Fully qualified (leading ::) name of this message's generated
            # streaming encoder class.
            stream_type = (
                f'::{message.cpp_namespace()}::'
                f'{ClassType.STREAMING_ENCODER.codegen_class_name()}'
            )
            output.write_line(
                f'operator {stream_type}&() '
                f' {{ return static_cast<{stream_type}&>('
                f'*static_cast<{PROTOBUF_NAMESPACE}::StreamEncoder*>(this));}}'
            )

        # Add a typed Field() member to StreamDecoder
        if class_type == ClassType.STREAMING_DECODER:
            output.write_line()
            output.write_line('::pw::Result<Fields> Field() {')
            with output.indent():
                output.write_line(
                    '::pw::Result<uint32_t> result ' '= FieldNumber();'
                )
                output.write_line('if (!result.ok()) {')
                with output.indent():
                    output.write_line('return result.status();')
                output.write_line('}')
                output.write_line('return static_cast<Fields>(result.value());')
            output.write_line('}')

        # Generate entry for message table read or write methods.
        if class_type == ClassType.STREAMING_DECODER:
            output.write_line()
            output.write_line('::pw::Status Read(Message& message) {')
            with output.indent():
                output.write_line(
                    f'return {base_class}::Read('
                    'pw::as_writable_bytes(pw::span(&message, 1)), '
                    'kMessageFields);'
                )
            output.write_line('}')
        elif class_type in (
            ClassType.STREAMING_ENCODER,
            ClassType.MEMORY_ENCODER,
        ):
            output.write_line()
            output.write_line('::pw::Status Write(const Message& message) {')
            with output.indent():
                output.write_line(
                    f'return {base_class}::Write('
                    'pw::as_bytes(pw::span(&message, 1)), kMessageFields);'
                )
            output.write_line('}')

        # Generate methods for each of the message's fields.
        for field in message.fields():
            for method_class in proto_field_methods(class_type, field.type()):
                method = method_class(
                    codegen_options, field, message, root, base_class
                )
                if not method.should_appear():
                    continue

                output.write_line()
                method_signature = (
                    f'{method.return_type()} '
                    f'{method.name()}({method.param_string()})'
                )

                if not method.in_class_definition():
                    # Method will be defined outside of the class at the end of
                    # the file.
                    output.write_line(f'{method_signature};')
                    continue

                output.write_line(f'{method_signature} {{')
                with output.indent():
                    for line in method.body():
                        output.write_line(line)
                output.write_line('}')

    output.write_line('};')
2949
2950
def define_not_in_class_methods(
    message: ProtoMessage,
    root: ProtoNode,
    output: OutputFile,
    codegen_options: GeneratorOptions,
    class_type: ClassType,
) -> None:
    """Writes out-of-class definitions for a message class's methods.

    Only methods that should appear but are not defined inline in the class
    body (in_class_definition() is false) are emitted. Each is written as an
    `inline` function qualified with the generated class name.

    Args:
      message: The message node whose class methods are defined.
      root: The root ProtoNode, used to compute relative C++ namespaces.
      output: Destination for the generated C++ lines.
      codegen_options: Generator options forwarded to each field method.
      class_type: Which encoder/decoder class flavor the methods belong to.
    """
    assert message.type() == ProtoNode.Type.MESSAGE

    base_class = f'{PROTOBUF_NAMESPACE}::{class_type.base_class_name()}'

    for field in message.fields():
        for method_cls in proto_field_methods(class_type, field.type()):
            method = method_cls(codegen_options, field, message, root, base_class)
            # Inline-defined methods were already written in the class body.
            needs_definition = (
                method.should_appear() and not method.in_class_definition()
            )
            if not needs_definition:
                continue

            output.write_line()
            qualified_class = (
                f'{message.cpp_namespace(root=root)}::'
                f'{class_type.codegen_class_name()}'
            )
            return_type = method.return_type(from_root=True)
            signature = (
                f'inline {return_type} '
                f'{qualified_class}::{method.name()}({method.param_string()})'
            )
            output.write_line(f'{signature} {{')
            with output.indent():
                for body_line in method.body():
                    output.write_line(body_line)
            output.write_line('}')
2990
2991
def _common_value_prefix(proto_enum: ProtoEnum) -> str:
    """Calculate the common prefix of all enum values.

    Given an enumeration:
        enum Thing {
            THING_ONE = 1;
            THING_TWO = 2;
            THING_THREE = 3;
        }

    It will return 'THING_', resulting in generated "style" aliases of
    'kOne', 'kTwo', and 'kThree'.

    The prefix is walked back to the last _, so that the enumeration:
        enum Activity {
            ACTIVITY_RUN = 1;
            ACTIVITY_ROW = 2;
        }

    Returns 'ACTIVITY_' and not 'ACTIVITY_R'.

    Enums with zero or one value have no common prefix ('').
    """
    names = [name for name, _ in proto_enum.values()]
    if len(names) <= 1:
        return ''

    # Character-wise longest common prefix of every value name.
    shared = os.path.commonprefix(names)
    # Trim back to (and including) the last underscore, if any.
    head, separator, _ = shared.rpartition('_')
    return head + separator
3025
3026
def generate_code_for_enum(
    proto_enum: ProtoEnum, root: ProtoNode, output: OutputFile
) -> None:
    """Emits a C++ `enum class` (backed by uint32_t) mirroring a proto enum.

    Every proto value is written under its original name. When it differs
    from that name, a kConstantCase alias follows it. The alias is built
    from the value name with the enum's common value prefix removed.
    """
    assert proto_enum.type() == ProtoNode.Type.ENUM

    prefix_length = len(_common_value_prefix(proto_enum))
    enum_name = proto_enum.cpp_namespace(root=root)
    output.write_line(f'enum class {enum_name} : uint32_t {{')
    with output.indent():
        for value_name, value_number in proto_enum.values():
            output.write_line(f'{value_name} = {value_number},')

            alias = 'k' + ProtoMessageField.upper_camel_case(
                value_name[prefix_length:]
            )
            if alias == value_name:
                continue
            output.write_line(f'{alias} = {value_name},')

    output.write_line('};')
3048
3049
def generate_function_for_enum(
    proto_enum: ProtoEnum, root: ProtoNode, output: OutputFile
) -> None:
    """Creates a C++ validation function for a proto enum.

    Emits `constexpr bool IsValid<Enum>(<Enum> value)`. It returns true for a
    value matching any declared enumerator and false otherwise.

    Args:
        proto_enum: The enum to generate the function for.
        root: Passed as `root` to `cpp_namespace` to name the enum.
        output: Destination for the generated C++.
    """
    assert proto_enum.type() == ProtoNode.Type.ENUM

    enum_name = proto_enum.cpp_namespace(root=root)
    output.write_line(
        f'constexpr bool IsValid{enum_name}({enum_name} value) {{'
    )
    with output.indent():
        output.write_line('switch (value) {')
        with output.indent():
            emitted_numbers: set = set()
            for name, number in proto_enum.values():
                # Enums declared with `allow_alias` give several names the
                # same number. A second `case` for that number would be a
                # duplicate case value and fail to compile, so emit one case
                # per distinct number.
                if number in emitted_numbers:
                    continue
                emitted_numbers.add(number)
                output.write_line(f'case {enum_name}::{name}: return true;')
            output.write_line('default: return false;')
        output.write_line('}')
    output.write_line('}')
3068
3069
def generate_to_string_for_enum(
    proto_enum: ProtoEnum, root: ProtoNode, output: OutputFile
) -> None:
    """Creates a C++ to string function for a proto enum.

    Emits `constexpr const char* <Enum>ToString(<Enum> value)`. It returns
    the enumerator's proto name, or "" for values matching no enumerator.
    When several names share one number (aliases), the function returns the
    first of them yielded by `proto_enum.values()`.

    Args:
        proto_enum: The enum to generate the function for.
        root: Passed as `root` to `cpp_namespace` to name the enum.
        output: Destination for the generated C++.
    """
    assert proto_enum.type() == ProtoNode.Type.ENUM

    enum_name = proto_enum.cpp_namespace(root=root)
    output.write_line(
        f'// Returns string names for {enum_name}; '
        'returns "" for invalid enum values.'
    )
    output.write_line(
        f'constexpr const char* {enum_name}ToString({enum_name} value) {{'
    )
    with output.indent():
        output.write_line('switch (value) {')
        with output.indent():
            emitted_numbers: set = set()
            for name, number in proto_enum.values():
                # Aliased names (enums with `allow_alias`) share a number.
                # Repeating a case label for one number does not compile, so
                # keep only the first name seen for each number.
                if number in emitted_numbers:
                    continue
                emitted_numbers.add(number)
                output.write_line(
                    f'case {enum_name}::{name}: return "{name}";'
                )
            output.write_line('default: return "";')
        output.write_line('}')
    output.write_line('}')
3092
3093
def forward_declare(
    message: ProtoMessage,
    root: ProtoNode,
    output: OutputFile,
    codegen_options: GeneratorOptions,
) -> None:
    """Opens a message's namespace and forward-declares what lives in it.

    In order, the namespace receives:
      * the `Fields` enum of field numbers, plus legacy SNAKE_CASE aliases
        unless they are excluded;
      * constants for fixed-size fields;
      * declarations of `Message`, `StreamEncoder`, `MemoryEncoder` and
        `StreamDecoder`;
      * the message's nested enums, with their validation and to-string
        functions.
    """
    namespace = message.cpp_namespace(root=root)
    output.write_line()
    output.write_line(f'namespace {namespace} {{')

    # Field-number enumerators. The legacy SNAKE_CASE names support migration
    # to kConstantCase and come after all of the kConstantCase names.
    field_enumerators = [
        f'{field.enum_name()} = {field.number()},'
        for field in message.fields()
    ]
    if not codegen_options.exclude_legacy_snake_case_field_name_enums:
        field_enumerators.extend(
            f'{field.legacy_enum_name()} = {field.number()},'
            for field in message.fields()
        )

    output.write_line('enum class Fields : uint32_t {')
    with output.indent():
        for enumerator in field_enumerators:
            output.write_line(enumerator)
    output.write_line('};')

    # Constants for fixed-size fields.
    output.write_line()
    for prop in proto_message_field_props(codegen_options, message, root):
        max_size = prop.max_size()
        if not max_size:
            continue
        output.write_line(
            f'static constexpr size_t {prop.max_size_constant_name()} '
            f'= {max_size};'
        )

    # Forward declarations in blank-line-separated groups: the message
    # struct, the encoder classes, then the decoder class.
    declaration_groups = (
        ('struct Message;',),
        ('class StreamEncoder;', 'class MemoryEncoder;'),
        ('class StreamDecoder;',),
    )
    for group in declaration_groups:
        output.write_line()
        for declaration in group:
            output.write_line(declaration)

    # Nested enums, each followed by its IsValid and ToString helpers.
    enum_generators = (
        generate_code_for_enum,
        generate_function_for_enum,
        generate_to_string_for_enum,
    )
    for child in message.children():
        if child.type() != ProtoNode.Type.ENUM:
            continue
        nested_enum = cast(ProtoEnum, child)
        for generate in enum_generators:
            output.write_line()
            generate(nested_enum, message, output)

    output.write_line(f'}}  // namespace {namespace}')
3154
3155
def generate_struct_for_message(
    message: ProtoMessage,
    root: ProtoNode,
    output: OutputFile,
    codegen_options: GeneratorOptions,
) -> None:
    """Emits the `Message` struct holding a protobuf message's field values.

    The struct has one member per field property. When oneof callbacks are
    enabled, it also has one `OneOf` member per non-synthetic oneof.
    `operator==` compares every field member that is not callback-backed,
    and `operator!=` negates it.
    """
    assert message.type() == ProtoNode.Type.MESSAGE

    namespace = message.cpp_namespace(root=root)
    output.write_line(f'struct {namespace}::Message {{')

    with output.indent():
        # Only members without a callback take part in equality.
        comparisons: list[str] = []
        for prop in proto_message_field_props(codegen_options, message, root):
            member_type = prop.struct_member_type()
            member = prop.name()
            output.write_line(f'{member_type} {member};')
            if prop.callback_type() is _CallbackType.NONE:
                comparisons.append(f'this->{member} == other.{member}')

        if codegen_options.oneof_callbacks:
            fields_enum = f'{namespace}::Fields'
            for oneof in message.oneofs():
                if oneof.is_synthetic():
                    continue
                output.write_line(
                    f'{PROTOBUF_NAMESPACE}::OneOf'
                    f'<StreamEncoder, StreamDecoder, {fields_enum}> '
                    f'{oneof.name};'
                )

        output.write_line()
        output.write_line('bool operator==(const Message& other) const {')
        with output.indent():
            if not comparisons:
                # Nothing to compare; mark `other` used to avoid a warning.
                output.write_line('static_cast<void>(other);')
                output.write_line('return true;')
            else:
                output.write_line(f'return {" && ".join(comparisons)};')
        output.write_line('}')
        output.write_line(
            'bool operator!=(const Message& other) const '
            '{ return !(*this == other); }'
        )

    output.write_line('};')
3205
3206
def generate_table_for_message(
    message: ProtoMessage,
    root: ProtoNode,
    output: OutputFile,
    codegen_options: GeneratorOptions,
) -> None:
    """Creates a C++ array to hold a protobuf message description.

    Inside the message's namespace this emits:
      * static_asserts that fail at compile time if the `Message` struct
        cannot be converted into a field table;
      * `_kMessageFields`, only when there is at least one field property
        (hidden ones included), and the `kMessageFields` span;
      * `ToTuple`/`ToMutableTuple` over the default (non-hidden) field
        properties, when there are any.

    The -Winvalid-offsetof suppression is pushed before the static_asserts.
    It is popped after the table on every path, so it never stays active for
    the code that follows in the generated header.
    """
    assert message.type() == ProtoNode.Type.MESSAGE

    namespace = message.cpp_namespace(root=root)
    output.write_line(f'namespace {namespace} {{')

    properties = list(proto_message_field_props(codegen_options, message, root))

    output.write_line('PW_MODIFY_DIAGNOSTICS_PUSH();')
    output.write_line('PW_MODIFY_DIAGNOSTIC(ignored, "-Winvalid-offsetof");')

    # Generate static_asserts to fail at compile-time if the structure cannot
    # be converted into a table.
    for idx, prop in enumerate(properties):
        if idx > 0:
            output.write_line(
                f'static_assert(offsetof(Message, {prop.name()}) > 0);'
            )
        output.write_line(
            f'static_assert(sizeof(Message::{prop.name()}) <= '
            f'{_INTERNAL_NAMESPACE}::MessageField::kMaxFieldSize);'
        )

    # Zero-length C arrays are not permitted by the C++ standard, so only
    # generate the message fields array if it is non-empty. Zero-length
    # std::arrays are valid, but older toolchains may not support constexpr
    # std::arrays, even with -std=c++17.
    #
    # The kMessageFields span is generated whether the message has fields or
    # not. Only the span is referenced elsewhere.
    all_properties = list(
        proto_message_field_props(
            codegen_options, message, root, include_hidden=True
        )
    )
    if all_properties:
        output.write_line(
            f'inline constexpr {_INTERNAL_NAMESPACE}::MessageField '
            '_kMessageFields[] = {'
        )
        # Generate members for each of the message's fields.
        with output.indent():
            for prop in all_properties:
                table = ', '.join(prop.table_entry())
                output.write_line(f'{{{table}}},')
        output.write_line('};')

    # Pop on every path, including messages without fields. Otherwise the
    # push above is left unbalanced and the -Winvalid-offsetof suppression
    # stays active for the rest of the generated header.
    output.write_line('PW_MODIFY_DIAGNOSTICS_POP();')

    if not all_properties:
        output.write_line(
            f'inline constexpr pw::span<const {_INTERNAL_NAMESPACE}::'
            'MessageField> kMessageFields;'
        )
    else:
        output.write_line(
            f'inline constexpr pw::span<const {_INTERNAL_NAMESPACE}::'
            'MessageField> kMessageFields = _kMessageFields;'
        )

        if properties:
            member_list = ', '.join(
                f'message.{prop.name()}' for prop in properties
            )

            # Generate std::tuple for main Message fields only.
            output.write_line(
                'inline constexpr auto ToTuple(const Message &message) {'
            )
            output.write_line(f'  return std::tie({member_list});')
            output.write_line('}')

            # Generate mutable std::tuple for Message fields.
            output.write_line(
                'inline constexpr auto ToMutableTuple(Message &message) {'
            )
            output.write_line(f'  return std::tie({member_list});')
            output.write_line('}')

    output.write_line(f'}}  // namespace {namespace}')
3295
3296
def generate_sizes_for_message(
    message: ProtoMessage,
    root: ProtoNode,
    output: OutputFile,
    codegen_options: GeneratorOptions,
) -> None:
    """Creates C++ constants for the encoded sizes of a protobuf message.

    Emits, inside the message's namespace:
      * `kMaxEncodedSizeBytes`: the sum of every field property's
        `max_encoded_size()`, or 0 when there are none;
      * `kScratchBufferSizeBytes`: the `std::max` of `max_encoded_size()`
        over properties whose `include_in_scratch_size()` is true, or 0 when
        there are none.
    """
    assert message.type() == ProtoNode.Type.MESSAGE

    namespace = message.cpp_namespace(root=root)
    output.write_line(f'namespace {namespace} {{')

    property_sizes: list[str] = []
    scratch_sizes: list[str] = []
    for prop in proto_message_field_props(codegen_options, message, root):
        encoded_size = prop.max_encoded_size()
        property_sizes.append(encoded_size)
        if prop.include_in_scratch_size():
            scratch_sizes.append(encoded_size)

    # One summand per line: every term but the last ends with ' +'.
    output.write_line('inline constexpr size_t kMaxEncodedSizeBytes =')
    with output.indent():
        if property_sizes:
            for size in property_sizes[:-1]:
                output.write_line(f'{size} +')
            output.write_line(f'{property_sizes[-1]};')
        else:
            output.write_line('0;')

    output.write_line()
    if scratch_sizes:
        output.write_line(
            'inline constexpr size_t kScratchBufferSizeBytes = std::max({'
        )
        with output.indent():
            for size in scratch_sizes:
                output.write_line(f'{size},')
        output.write_line('});')
    else:
        output.write_line(
            'inline constexpr size_t kScratchBufferSizeBytes = 0;'
        )

    output.write_line(f'}}  // namespace {namespace}')
3339
3340
def generate_find_functions_for_message(
    message: ProtoMessage,
    root: ProtoNode,
    output: OutputFile,
    codegen_options: GeneratorOptions,
) -> None:
    """Creates inline C++ Find functions for a protobuf message's fields.

    For each field whose type has an entry in PROTO_FIELD_FIND_METHODS, one
    inline function is emitted per listed method class, inside the message's
    namespace. Fields whose type has no entry are skipped.
    """
    assert message.type() == ProtoNode.Type.MESSAGE

    namespace = message.cpp_namespace(root=root)
    output.write_line(f'namespace {namespace} {{')

    for field in message.fields():
        try:
            methods = PROTO_FIELD_FIND_METHODS[field.type()]
        except KeyError:
            # No Find functions exist for this field type.
            continue

        for cls in methods:
            # These are namespace-scope functions rather than class members,
            # so an empty string is passed as the base class.
            method = cls(codegen_options, field, message, root, '')
            method_signature = (
                f'inline {method.return_type()} '
                f'{method.name()}({method.param_string()})'
            )

            output.write_line()
            output.write_line(f'{method_signature} {{')

            with output.indent():
                for line in method.body():
                    output.write_line(line)

            output.write_line('}')

    output.write_line(f'}}  // namespace {namespace}')
3376
3377
def generate_is_trivially_comparable_specialization(
    message: ProtoMessage,
    root: ProtoNode,
    output: OutputFile,
    codegen_options: GeneratorOptions,
) -> None:
    """Specializes `IsTriviallyComparable<>` for a message's `Message` struct.

    The specialization returns true exactly when no field property of the
    message is callback-backed.
    """
    trivially_comparable = all(
        prop.callback_type() is _CallbackType.NONE
        for prop in proto_message_field_props(codegen_options, message, root)
    )
    cpp_bool = 'true' if trivially_comparable else 'false'

    # Fully qualified from the global namespace (no `root` passed).
    qualified_message = f'::{message.cpp_namespace()}::Message'

    output.write_line('template <>')
    output.write_line(
        f'constexpr bool IsTriviallyComparable<{qualified_message}>() {{'
    )
    output.write_line(f'  return {cpp_bool};')
    output.write_line('}')
3398
3399
def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Maps a .proto path to its generated header path.

    Replaces the path's final extension with PROTO_H_EXTENSION.
    """
    stem, _extension = os.path.splitext(proto_file)
    return f'{stem}{PROTO_H_EXTENSION}'
3403
3404
def dependency_sorted_messages(package: ProtoNode):
    """Yields the messages in the package sorted after their dependencies.

    Every MESSAGE node in `package` is yielded after the messages it depends
    on. While the dependency graph contains a cycle, one edge of the reported
    cycle is removed via `remove_dependency_cycle` and sorting is retried.
    """
    # Map each message to the messages it depends on (its predecessors).
    graph: dict[ProtoMessage, list[ProtoMessage]] = {}
    for node in package:
        if node.type() != ProtoNode.Type.MESSAGE:
            continue
        message = cast(ProtoMessage, node)
        graph[message] = message.dependencies()

    sorter: TopologicalSorter
    while True:
        sorter = TopologicalSorter(graph)
        try:
            sorter.prepare()
        except CycleError as err:
            # err.args[1] lists the cycle with each node an immediate
            # predecessor of the next. So cycle[1] depends on cycle[0];
            # drop that dependency and try again.
            cycle = err.args[1]
            dependent = cycle[1]
            dependent.remove_dependency_cycle(cycle[0])
            graph[dependent] = dependent.dependencies()
        else:
            break

    # Hand out each ready batch, then mark it done to release its dependents.
    while sorter.is_active():
        ready = sorter.get_ready()
        yield from ready
        sorter.done(*ready)
3434
3435
def _strip_leading_scope_operator(namespace: str) -> str:
    """Returns `namespace` without a leading '::' global-scope qualifier."""
    return namespace[2:] if namespace.startswith('::') else namespace


def generate_code_for_package(
    file_descriptor_proto,
    package: ProtoNode,
    output: OutputFile,
    codegen_options: GeneratorOptions,
) -> None:
    """Generates code for a single .pb.h file corresponding to a .proto file.

    Args:
        file_descriptor_proto: Descriptor of the .proto file. Each entry of
            its `dependency` field becomes an #include of that file's
            generated header.
        package: The PACKAGE node at the root of the file's proto tree.
        output: Destination for the generated header.
        codegen_options: Options selecting optional parts of the output.
    """
    assert package.type() == ProtoNode.Type.PACKAGE

    output.write_line(
        f'// {os.path.basename(output.name())} automatically '
        f'generated by {PLUGIN_NAME} {PLUGIN_VERSION}'
    )
    output.write_line('#pragma once\n')
    output.write_line('#include <algorithm>')
    output.write_line('#include <array>')
    output.write_line('#include <cstddef>')
    output.write_line('#include <cstdint>')
    output.write_line('#include <optional>')
    output.write_line('#include <string_view>')
    # std::tie, used by the generated ToTuple helpers, lives in <tuple>.
    output.write_line('#include <tuple>\n')
    output.write_line('#include "pw_assert/assert.h"')
    output.write_line('#include "pw_containers/vector.h"')
    output.write_line('#include "pw_preprocessor/compiler.h"')
    output.write_line('#include "pw_protobuf/encoder.h"')
    output.write_line('#include "pw_protobuf/find.h"')
    output.write_line('#include "pw_protobuf/internal/codegen.h"')
    output.write_line('#include "pw_protobuf/serialized_size.h"')
    output.write_line('#include "pw_protobuf/stream_decoder.h"')
    output.write_line('#include "pw_result/result.h"')
    output.write_line('#include "pw_span/span.h"')
    output.write_line('#include "pw_status/status.h"')
    output.write_line('#include "pw_status/status_with_size.h"')
    output.write_line('#include "pw_string/string.h"')

    for imported_file in file_descriptor_proto.dependency:
        generated_header = _proto_filename_to_generated_header(imported_file)
        output.write_line(f'#include "{generated_header}"')

    # Strip any leading '::' once. The namespace is then spelled the same in
    # its definition, its closing comment, and the using-directives below,
    # which add their own '::' prefix.
    file_namespace = _strip_leading_scope_operator(package.cpp_namespace())

    if file_namespace:
        output.write_line(f'\nnamespace {file_namespace} {{')

    for node in package:
        if node.type() == ProtoNode.Type.MESSAGE:
            forward_declare(
                cast(ProtoMessage, node),
                package,
                output,
                codegen_options,
            )

    # Define all top-level enums.
    for node in package.children():
        if node.type() == ProtoNode.Type.ENUM:
            top_level_enum = cast(ProtoEnum, node)
            output.write_line()
            generate_code_for_enum(top_level_enum, package, output)
            output.write_line()
            generate_function_for_enum(top_level_enum, package, output)
            output.write_line()
            generate_to_string_for_enum(top_level_enum, package, output)

    class_types = (
        ClassType.STREAMING_ENCODER,
        ClassType.MEMORY_ENCODER,
        ClassType.STREAMING_DECODER,
    )

    # Run through all messages, generating structs and classes for each.
    messages: list[ProtoMessage] = []
    for message in dependency_sorted_messages(package):
        for generate in (
            generate_struct_for_message,
            generate_table_for_message,
            generate_sizes_for_message,
            generate_find_functions_for_message,
        ):
            output.write_line()
            generate(message, package, output, codegen_options)
        for class_type in class_types:
            output.write_line()
            generate_class_for_message(
                message,
                package,
                output,
                codegen_options,
                class_type,
            )
        messages.append(message)

    # Run a second pass through the messages, this time defining all of the
    # methods which were previously only declared.
    for message in messages:
        for class_type in class_types:
            define_not_in_class_methods(
                message,
                package,
                output,
                codegen_options,
                class_type,
            )

    # Aliasing namespaces aren't needed if the package namespace is empty
    # (since everyone can see the global namespace). It shouldn't ever be
    # empty, though.
    if file_namespace:
        output.write_line(f'\n}}  // namespace {file_namespace}')

        legacy_namespace = _strip_leading_scope_operator(
            package.cpp_namespace(codegen_subnamespace=None)
        )

        if not codegen_options.suppress_legacy_namespace:
            output.write_line()
            output.write_line(
                '// Aliases for legacy pwpb codegen interface. '
                'Please use the'
            )
            output.write_line('// `::pwpb`-suffixed names in new code.')
            output.write_line(f'namespace {legacy_namespace} {{')
            output.write_line(f'using namespace ::{file_namespace};')
            output.write_line(f'}}  // namespace {legacy_namespace}')

        # TODO: b/250945489 - Remove this if possible
        output.write_line()
        output.write_line(
            '// Codegen implementation detail; do not use this namespace!'
        )

        external_lookup_namespace = (
            f'{EXTERNAL_SYMBOL_WORKAROUND_NAMESPACE}::{legacy_namespace}'
        )

        output.write_line(f'namespace {external_lookup_namespace} {{')
        output.write_line(f'using namespace ::{file_namespace};')
        output.write_line(f'}}  // namespace {external_lookup_namespace}')

    if messages:
        proto_namespace = PROTOBUF_NAMESPACE.lstrip(':')
        output.write_line()
        output.write_line(f'namespace {proto_namespace} {{')

        for message in messages:
            generate_is_trivially_comparable_specialization(
                message,
                package,
                output,
                codegen_options,
            )

        output.write_line(f'}}  // namespace {proto_namespace}')
3616
3617
def process_proto_file(
    proto_file,
    proto_options,
    codegen_options: GeneratorOptions,
) -> Iterable[OutputFile] | None:
    """Generates code for a single .proto file.

    Returns:
        A one-element list holding the generated header. Returns None when
        code generation raised a CodegenError, whose formatted message is
        printed to stderr.
    """
    # build_node_tree makes two passes through the file. The first builds
    # the tree of all message/enum nodes; the second creates the fields in
    # each. Non-primitive fields need pointers to their types, so the whole
    # tree must be in memory first.
    _, package_root = build_node_tree(proto_file, proto_options=proto_options)

    header = OutputFile(_proto_filename_to_generated_header(proto_file.name))

    try:
        generate_code_for_package(
            proto_file, package_root, header, codegen_options
        )
    except CodegenError as error:
        print(error.formatted_message(), file=sys.stderr)
        return None

    return [header]
3646