xref: /aosp_15_r20/external/pigweed/pw_protobuf/py/pw_protobuf/proto_tree.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""This module defines data structures for protobuf entities."""
15
16from __future__ import annotations
17
18import abc
19import collections
20import dataclasses
21import enum
22import itertools
23
24from typing import (
25    Callable,
26    Iterator,
27    TypeVar,
28    cast,
29)
30
31from google.protobuf import descriptor_pb2
32
33from pw_protobuf import edition_constants, options, symbol_name_mapping
34from pw_protobuf_codegen_protos.codegen_options_pb2 import CodegenOptions
35from pw_protobuf_protos.field_options_pb2 import pwpb as pwpb_field_options
36
# Generic attribute type collected by ProtoNode._attr_hierarchy().
T = TypeVar('T')  # pylint: disable=invalid-name

# Currently, protoc does not do a traversal to look up the package name of all
# messages that are referenced in the file. For such "external" message names,
# we are unable to find where the "::pwpb" subnamespace would be inserted by our
# codegen. This namespace provides us with an alternative, more verbose
# namespace that the codegen can use as a fallback in these cases. For example,
# for the symbol name `my.external.package.ProtoMsg.SubMsg`, we would use
# `::pw::pwpb_codegen_private::my::external::package::ProtoMsg::SubMsg` to
# refer to the pw_protobuf generated code, when package name info is not
# available.
#
# TODO: b/258832150 - Explore removing this if possible
EXTERNAL_SYMBOL_WORKAROUND_NAMESPACE = 'pw::pwpb_codegen_private'
50
51
class ProtoNode(abc.ABC):
    """A ProtoNode represents a C++ scope mapping of an entity in a .proto file.

    Nodes form a tree beginning at a top-level (global) scope, descending into a
    hierarchy of .proto packages and the messages and enums defined within them.
    """

    class Type(enum.Enum):
        """The type of a ProtoNode.

        PACKAGE maps to a C++ namespace.
        MESSAGE maps to a C++ "Encoder" class within its own namespace.
        ENUM maps to a C++ enum within its parent's namespace.
        EXTERNAL represents a node defined within a different compilation unit.
        SERVICE represents an RPC service definition.
        """

        PACKAGE = 1
        MESSAGE = 2
        ENUM = 3
        EXTERNAL = 4
        SERVICE = 5

    def __init__(self, name: str):
        self._name: str = name
        # Direct children keyed by name, in insertion order.
        self._children: dict[str, ProtoNode] = collections.OrderedDict()
        self._parent: ProtoNode | None = None

    @abc.abstractmethod
    def type(self) -> ProtoNode.Type:
        """The type of the node."""

    def children(self) -> list[ProtoNode]:
        """Returns a new list of this node's direct children."""
        return list(self._children.values())

    def parent(self) -> ProtoNode | None:
        """Returns this node's parent, or None if it has not been attached."""
        return self._parent

    def name(self) -> str:
        """Returns the unqualified name of this node."""
        return self._name

    def cpp_name(self) -> str:
        """The name of this node in generated C++ code."""
        return symbol_name_mapping.fix_cc_identifier(self._name).replace(
            '.', '::'
        )

    def _package_or_external(self) -> ProtoNode:
        """Returns this node's deepest package or external ancestor node.

        This method may need to return an external node, as a fallback for
        external names that are referenced, but not processed into a more
        regular proto tree. This is because there is no way to find the package
        name of a node referring to an external symbol.
        """
        node: ProtoNode | None = self
        while (
            node
            and node.type() != ProtoNode.Type.PACKAGE
            and node.type() != ProtoNode.Type.EXTERNAL
        ):
            node = node.parent()

        assert node, 'proto tree was built without a root'
        return node

    def cpp_namespace(
        self,
        root: ProtoNode | None = None,
        codegen_subnamespace: str | None = 'pwpb',
    ) -> str:
        """C++ namespace of the node, up to the specified root.

        Args:
          root: Namespace from which this ProtoNode is referred. If this
            ProtoNode has `root` as an ancestor namespace, then the ancestor
            namespace scopes above `root` are omitted.

          codegen_subnamespace: A subnamespace that is appended to the package
            declared in the .proto file. It is appended to the declared package,
            but before any namespaces that are needed for messages etc. This
            feature can be used to allow different codegen tools to output
            different, non-conflicting symbols for the same protos.

            By default, this is "pwpb", which reflects the default behaviour
            of the pwpb codegen.
        """
        self_pkg_or_ext = self._package_or_external()
        root_pkg_or_ext = (
            root._package_or_external()  # pylint: disable=protected-access
            if root is not None
            else None
        )
        if root_pkg_or_ext:
            assert root_pkg_or_ext.type() != ProtoNode.Type.EXTERNAL

        def compute_hierarchy() -> Iterator[str]:
            if self_pkg_or_ext.type() == ProtoNode.Type.EXTERNAL:
                # Can't figure out where the namespace cutoff is. Punt to using
                # the external symbol workaround.
                #
                # TODO: b/250945489 - Investigate removing this limitation /
                # hack
                return itertools.chain(
                    [EXTERNAL_SYMBOL_WORKAROUND_NAMESPACE],
                    self._attr_hierarchy(ProtoNode.cpp_name, root=None),
                )

            if root is None or root_pkg_or_ext is None:  # extra check for mypy
                # TODO: b/250945489 - maybe elide "::{codegen_subnamespace}"
                # here, if this node doesn't have any package?
                same_package = False
            else:
                # The packages match only if every scope name agrees; a
                # length mismatch pairs a name with None and so differs.
                same_package = all(
                    str_a == str_b
                    for str_a, str_b in itertools.zip_longest(
                        self_pkg_or_ext._attr_hierarchy(  # pylint: disable=protected-access
                            ProtoNode.cpp_name, root=None
                        ),
                        root_pkg_or_ext._attr_hierarchy(  # pylint: disable=protected-access
                            ProtoNode.cpp_name, root=None
                        ),
                    )
                )

            if same_package:
                # This ProtoNode and the requested root are in the same package,
                # so the `codegen_subnamespace` should be omitted.
                return self._attr_hierarchy(ProtoNode.cpp_name, root)

            # The given root is either effectively nonexistent (common ancestor
            # is ""), or is only a partial match for the package of this node.
            # Either way, we will have to insert `codegen_subnamespace` after
            # the relevant package string.
            package_hierarchy = self_pkg_or_ext._attr_hierarchy(  # pylint: disable=protected-access
                ProtoNode.cpp_name, root
            )
            maybe_subnamespace = (
                [codegen_subnamespace] if codegen_subnamespace else []
            )
            inside_hierarchy = self._attr_hierarchy(
                ProtoNode.cpp_name, self_pkg_or_ext
            )

            return itertools.chain(
                package_hierarchy, maybe_subnamespace, inside_hierarchy
            )

        joined_namespace = '::'.join(
            name for name in compute_hierarchy() if name
        )

        return (
            '' if joined_namespace == codegen_subnamespace else joined_namespace
        )

    def _joined_names(self, separator: str) -> str:
        """Joins the non-empty node names from the tree root down to this node.

        Nodes with empty names (such as the global root) are skipped, rather
        than stripping leading separators from the joined string. This way a
        first real component that itself begins with the separator character
        (e.g. a package named `_private`) is preserved.
        """
        return separator.join(
            name
            for name in self._attr_hierarchy(lambda node: node.name(), None)
            if name
        )

    def proto_path(self) -> str:
        """Fully-qualified package path of the node."""
        return self._joined_names('.')

    def pwpb_struct(self) -> str:
        """Name of the pw_protobuf struct for this proto."""
        return '::' + self.cpp_namespace() + '::Message'

    def pwpb_table(self) -> str:
        """Name of the pw_protobuf table constant for this proto."""
        return '::' + self.cpp_namespace() + '::kMessageFields'

    def nanopb_fields(self) -> str:
        """Name of the Nanopb variable that represents the proto fields."""
        return self._nanopb_name() + '_fields'

    def nanopb_struct(self) -> str:
        """Name of the Nanopb struct for this proto."""
        return '::' + self._nanopb_name()

    def _nanopb_name(self) -> str:
        """Underscore-joined path of the node, used to build Nanopb names."""
        return self._joined_names('_')

    def common_ancestor(self, other: ProtoNode) -> ProtoNode | None:
        """Finds the earliest common ancestor of this node and other.

        Returns None if `other` is None or the nodes share no ancestor.
        """

        if other is None:
            return None

        own_depth = self.depth()
        other_depth = other.depth()
        diff = abs(own_depth - other_depth)

        # `second` is the deeper node; lift it to the other node's depth.
        if own_depth < other_depth:
            first: ProtoNode | None = self
            second: ProtoNode | None = other
        else:
            first = other
            second = self

        while diff > 0:
            assert second is not None
            second = second.parent()
            diff -= 1

        # Walk both nodes up in lockstep until they meet.
        while first != second:
            if first is None or second is None:
                return None

            first = first.parent()
            second = second.parent()

        return first

    def depth(self) -> int:
        """Returns the depth of this node from the root."""
        depth = 0
        node = self._parent
        while node:
            depth += 1
            node = node.parent()
        return depth

    def add_child(self, child: ProtoNode) -> None:
        """Inserts a new node into the tree as a child of this node.

        If `child` is already attached elsewhere, it is first detached from
        its previous parent. A child with the same name replaces the old one.

        Args:
          child: The node to insert.

        Raises:
          ValueError: This node does not allow nesting the given type of child.
        """
        if not self._supports_child(child):
            raise ValueError(
                'Invalid child %s for node of type %s'
                % (child.type(), self.type())
            )

        # pylint: disable=protected-access
        if child._parent is not None:
            del child._parent._children[child.name()]

        child._parent = self
        self._children[child.name()] = child
        # pylint: enable=protected-access

    def find(self, path: str) -> ProtoNode | None:
        """Finds a node within this node's subtree.

        Args:
          path: The '.'-separated path to the sought node, relative to this
            node.

        Returns:
          The node, or None if any component of the path is missing.
        """
        node = self

        # pylint: disable=protected-access
        for section in path.split('.'):
            child = node._children.get(section)
            if child is None:
                return None
            node = child
        # pylint: enable=protected-access

        return node

    def __iter__(self) -> Iterator[ProtoNode]:
        """Iterates depth-first through all nodes in this node's subtree."""
        yield self
        for child in self._children.values():
            yield from child

    def _attr_hierarchy(
        self,
        attr_accessor: Callable[[ProtoNode], T],
        root: ProtoNode | None,
    ) -> Iterator[T]:
        """Fetches node attributes at each level of the tree from the root.

        Args:
          attr_accessor: Function which extracts attributes from a ProtoNode.
          root: The node at which to terminate (exclusive). If None, or if it
            is not an ancestor, the walk continues to the top of the tree.

        Returns:
          An iterator to a list of the selected attributes from the root to the
          current node.
        """
        hierarchy = []
        node: ProtoNode | None = self
        while node is not None and node != root:
            hierarchy.append(attr_accessor(node))
            node = node.parent()
        return reversed(hierarchy)

    @abc.abstractmethod
    def _supports_child(self, child: ProtoNode) -> bool:
        """Returns True if child is a valid child type for the current node."""
351
352
class ProtoPackage(ProtoNode):
    """A protobuf package; it maps to a C++ namespace and may nest anything."""

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.PACKAGE

    def _supports_child(self, child: ProtoNode) -> bool:
        del child  # Every kind of node may live inside a package.
        return True
361
362
class ProtoEnum(ProtoNode):
    """An enum declared in a .proto file, together with its values."""

    def __init__(self, name: str):
        super().__init__(name)
        # (converted value name, numeric value) pairs in insertion order.
        self._values: list[tuple[str, int]] = []

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.ENUM

    def values(self) -> list[tuple[str, int]]:
        """Returns a copy of the recorded (name, number) pairs."""
        return self._values.copy()

    def add_value(self, name: str, value: int) -> None:
        """Records one enum value.

        The name is first passed through
        symbol_name_mapping.fix_cc_enum_value_name and then upper-cased.
        """
        cc_value_name = symbol_name_mapping.fix_cc_enum_value_name(name)
        entry = (ProtoMessageField.upper_snake_case(cc_value_name), value)
        self._values.append(entry)

    def _supports_child(self, child: ProtoNode) -> bool:
        # Enums are leaves: nothing may be nested inside one.
        return False
389
390
class ProtoMessage(ProtoNode):
    """A message declared in a .proto file.

    Nested enums and messages are child nodes. Fields, oneof groups, and the
    message types that fields refer to are tracked on the node itself.
    """

    @dataclasses.dataclass
    class OneOf:
        """A oneof group and the fields that were added to it."""

        name: str
        fields: list[ProtoMessageField] = dataclasses.field(
            default_factory=list
        )

        def is_synthetic(self) -> bool:
            """Tells whether this oneof was generated by protoc.

            protoc represents a proto3 `optional` field as a "synthetic" oneof
            with that single field as its only member. pw_protobuf does not
            support oneof in general, but it does special-case proto3 optional
            fields, so this distinguishes compiler-made oneofs from ones the
            user actually wrote.
            """
            # https://cs.opensource.google/protobuf/protobuf/+/main:src/google/protobuf/descriptor.proto;l=305;drc=5a68dddcf9564f92815296099f07f7dfe8713908
            if len(self.fields) != 1:
                return False
            return self.fields[0].has_presence()

    def __init__(self, name: str):
        super().__init__(name)
        self._fields: list[ProtoMessageField] = []
        self._oneofs: list[ProtoMessage.OneOf] = []
        # Filled in lazily by dependencies(); None until first requested.
        self._dependencies: list[ProtoMessage] | None = None
        # Entries moved out of _dependencies by remove_dependency_cycle().
        self._dependency_cycles: list[ProtoMessage] = []

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.MESSAGE

    def fields(self) -> list[ProtoMessageField]:
        """Returns a copy of the message's fields in insertion order."""
        return self._fields[:]

    def oneofs(self) -> list[ProtoMessage.OneOf]:
        """Returns a copy of the message's oneof groups."""
        return self._oneofs[:]

    def add_field(
        self,
        field: ProtoMessageField,
        oneof_index: int | None = None,
    ) -> None:
        """Appends a field, optionally registering it with a oneof group.

        Args:
          field: The field to add.
          oneof_index: Index into the oneofs added via add_oneof(), or None if
            the field is not part of a oneof.
        """
        self._fields.append(field)

        if oneof_index is None:
            return

        group = self._oneofs[oneof_index]
        group.fields.append(field)
        field._oneof = group  # pylint: disable=protected-access

    def add_oneof(self, name: str) -> None:
        """Declares a new, initially empty oneof group."""
        self._oneofs.append(ProtoMessage.OneOf(name))

    def _supports_child(self, child: ProtoNode) -> bool:
        # Only enums and messages can be nested inside a message.
        return child.type() in (self.Type.ENUM, self.Type.MESSAGE)

    def dependencies(self) -> list[ProtoMessage]:
        """Returns the message types referenced by this message's fields.

        The list is computed on first use and cached. Message-typed fields
        whose type resolved to a non-message node (e.g. an external node) are
        not included.
        """
        if self._dependencies is None:
            message_field_type = descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE
            found: list[ProtoMessage] = []
            self._dependencies = found
            for message_field in self._fields:
                if message_field.type() != message_field_type:
                    continue

                target = message_field.type_node()
                assert target is not None
                if target.type() == ProtoNode.Type.MESSAGE:
                    found.append(cast(ProtoMessage, target))

        return self._dependencies[:]

    def dependency_cycles(self) -> list[ProtoMessage]:
        """Returns a copy of the dependencies that were marked as cycles."""
        return self._dependency_cycles[:]

    def remove_dependency_cycle(self, dependency: ProtoMessage):
        """Moves `dependency` from the dependency list to the cycle list.

        dependencies() must have been called first, and `dependency` must be
        one of its entries.
        """
        assert self._dependencies is not None
        assert dependency in self._dependencies
        self._dependencies.remove(dependency)
        self._dependency_cycles.append(dependency)
472
473
class ProtoService(ProtoNode):
    """An RPC service declared in a .proto file.

    Methods are stored on the service itself; services have no child nodes.
    """

    def __init__(self, name: str):
        super().__init__(name)
        self._methods: list[ProtoServiceMethod] = []

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.SERVICE

    def methods(self) -> list[ProtoServiceMethod]:
        """Returns a copy of the service's methods in insertion order."""
        return self._methods[:]

    def add_method(self, method: ProtoServiceMethod) -> None:
        """Appends an RPC method to this service."""
        self._methods.append(method)

    def _supports_child(self, child: ProtoNode) -> bool:
        del child  # Services never nest other nodes.
        return False
492
493
class ProtoExternal(ProtoNode):
    """A placeholder for a node that lives in another compilation unit.

    Such nodes usually come from an imported .proto file. Because their real
    kind is unknown here, they carry no members or extra data. They exist
    only so that names can be resolved across compilation units within the
    node graph.
    """

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.EXTERNAL

    def _supports_child(self, child: ProtoNode) -> bool:
        del child  # Any node may be placed beneath an external placeholder.
        return True
509
510
# Fields are not part of the proto node tree: each field is held by the
# ProtoMessage it was added to and is processed together with that message.
class ProtoMessageField:
    """A single field declared inside a protobuf message."""

    def __init__(
        self,
        field_name: str,
        field_number: int,
        field_type: int,
        type_node: ProtoNode | None = None,
        has_presence: bool = False,
        repeated: bool = False,
        codegen_options: CodegenOptions | None = None,
    ):
        # The stored name is the C++-safe form of the .proto field name.
        self._field_name = symbol_name_mapping.fix_cc_identifier(field_name)
        self._number: int = field_number
        self._type: int = field_type
        self._type_node: ProtoNode | None = type_node
        self._has_presence: bool = has_presence
        self._repeated: bool = repeated
        self._options: CodegenOptions | None = codegen_options
        # Assigned by ProtoMessage.add_field() when the field is in a oneof.
        self._oneof: ProtoMessage.OneOf | None = None

    def name(self) -> str:
        """The field name converted to UpperCamelCase."""
        return ProtoMessageField.upper_camel_case(self._field_name)

    def field_name(self) -> str:
        """The C++-safe field name as stored (snake_case in .proto style)."""
        return self._field_name

    def enum_name(self) -> str:
        """The `k`-prefixed UpperCamelCase name of this field."""
        return f'k{self.name()}'

    def legacy_enum_name(self) -> str:
        """The UPPER_SNAKE_CASE name, after C++ enum-value name fixing."""
        fixed = symbol_name_mapping.fix_cc_enum_value_name(self._field_name)
        return ProtoMessageField.upper_snake_case(fixed)

    def number(self) -> int:
        """The field number declared in the .proto file."""
        return self._number

    def type(self) -> int:
        """The descriptor field type (a FieldDescriptorProto.Type value)."""
        return self._type

    def type_node(self) -> ProtoNode | None:
        """The node for the field's message/enum type, if it has one."""
        return self._type_node

    def has_presence(self) -> bool:
        """Whether the field tracks explicit presence."""
        return self._has_presence

    def is_repeated(self) -> bool:
        """Whether the field is declared repeated."""
        return self._repeated

    def options(self) -> CodegenOptions | None:
        """The codegen options that apply to this field, if any."""
        return self._options

    def oneof(self) -> ProtoMessage.OneOf | None:
        """The user-declared oneof containing this field, if any.

        Synthetic oneofs (generated by protoc for proto3 optional fields) are
        reported as None.
        """
        group = self._oneof
        if group is None or group.is_synthetic():
            return None
        return group

    @staticmethod
    def upper_camel_case(field_name: str) -> str:
        """Converts a field name to UpperCamelCase."""
        return ''.join(
            part.lower().capitalize() for part in field_name.split('_')
        )

    @staticmethod
    def upper_snake_case(field_name: str) -> str:
        """Converts a field name to UPPER_SNAKE_CASE."""
        upper_name = field_name.upper()
        return upper_name
582
583
class ProtoServiceMethod:
    """An RPC method declared inside a protobuf service."""

    class Type(enum.Enum):
        UNARY = 'kUnary'
        SERVER_STREAMING = 'kServerStreaming'
        CLIENT_STREAMING = 'kClientStreaming'
        BIDIRECTIONAL_STREAMING = 'kBidirectionalStreaming'

        def cc_enum(self) -> str:
            """Returns the pw_rpc MethodType C++ enum for this method type."""
            return f'::pw::rpc::MethodType::{self.value}'

    def __init__(
        self,
        service: ProtoService,
        name: str,
        method_type: Type,
        request_type: ProtoNode,
        response_type: ProtoNode,
    ):
        self._service = service
        self._name = name
        self._type = method_type
        self._request_type = request_type
        self._response_type = response_type

    def service(self) -> ProtoService:
        """The service this method belongs to."""
        return self._service

    def name(self) -> str:
        """The method's name as declared."""
        return self._name

    def type(self) -> Type:
        """The method's streaming kind."""
        return self._type

    def server_streaming(self) -> bool:
        """True when the server sends a stream of responses."""
        kind = self._type
        return (
            kind == self.Type.SERVER_STREAMING
            or kind == self.Type.BIDIRECTIONAL_STREAMING
        )

    def client_streaming(self) -> bool:
        """True when the client sends a stream of requests."""
        kind = self._type
        return (
            kind == self.Type.CLIENT_STREAMING
            or kind == self.Type.BIDIRECTIONAL_STREAMING
        )

    def request_type(self) -> ProtoNode:
        """The node for the request message type."""
        return self._request_type

    def response_type(self) -> ProtoNode:
        """The node for the response message type."""
        return self._response_type
637
638
def _add_enum_fields(enum_node: ProtoNode, proto_enum) -> None:
    """Copies every value of an enum descriptor into an enum node.

    Args:
      enum_node: The node to populate; must be of type ENUM.
      proto_enum: The enum descriptor whose `value` entries are read.
    """
    assert enum_node.type() == ProtoNode.Type.ENUM
    target = cast(ProtoEnum, enum_node)

    for enum_value in proto_enum.value:
        target.add_value(enum_value.name, enum_value.number)
646
647
def _create_external_nodes(root: ProtoNode, path: str) -> ProtoNode:
    """Walks a '.'-separated path from `root`, creating missing components.

    Components that already exist are reused. Missing ones are added as
    ProtoExternal nodes.

    Returns:
      The node for the last component of `path`.
    """
    current = root
    for component in path.split('.'):
        existing = current.find(component)
        if existing is None:
            existing = ProtoExternal(component)
            current.add_child(existing)
        current = existing

    return current
660
661
def _find_or_create_node(
    global_root: ProtoNode, package_root: ProtoNode, path: str
) -> ProtoNode:
    """Looks up a node by proto path, creating placeholders when it is absent.

    A path starting with '.' is fully qualified and resolved from
    `global_root`. Any other path is resolved relative to `package_root`.
    When the lookup fails, the path is created via _create_external_nodes.
    """
    is_fully_qualified = path[0] == '.'
    search_root = global_root if is_fully_qualified else package_root
    relative_path = path[1:] if is_fully_qualified else path

    found = search_root.find(relative_path)
    if found is not None:
        return found

    # Types that aren't part of this compilation context (for example, ones
    # imported from other .proto files) get external placeholder nodes.
    return _create_external_nodes(search_root, relative_path)
683
684
def _field_has_presence(
    proto_file: descriptor_pb2.FileDescriptorProto,
    field: descriptor_pb2.FieldDescriptorProto,
    repeated: bool,
) -> bool:
    """Returns whether a field descriptor should be treated as having presence."""
    # TODO: pwbug.dev/366316523 - The generated protobuf types do not
    # include the "edition" property, so getattr is used. Fix this when
    # we upgrade protobuf and mypy-protobuf.
    if getattr(proto_file, 'edition', None) == '2023':
        # Repeated fields never track presence. Message-typed fields always
        # do. Other fields do unless their features request IMPLICIT presence.
        #
        # NOTE(review): only the field's own `features` are consulted here;
        # features inherited from enclosing messages or from the file (e.g. a
        # file-level `features.field_presence = IMPLICIT`) are not resolved.
        # Confirm whether that matters for the inputs this tool must support.
        if repeated:
            return False
        # `field.type` is a FieldDescriptorProto.Type integer, so it must be
        # compared with TYPE_MESSAGE (not with ProtoNode.Type.MESSAGE).
        if field.type == descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE:
            return True
        return (
            field.options.features.field_presence
            != edition_constants.FieldPresence.IMPLICIT.value
        )

    # If the file does not use editions, only consider explicit
    # proto3 optionality.
    return field.proto3_optional


def _merged_field_options(
    message: ProtoMessage,
    field: descriptor_pb2.FieldDescriptorProto,
    proto_options,
) -> CodegenOptions | None:
    """Combines options-file settings with the field's inline pwpb options.

    Options matched from `proto_options` for this field's path, and options
    from the field's `pwpb` extension, are merged when both exist. Otherwise
    whichever one exists is returned, or None if neither does.
    """
    codegen_options = (
        options.match_options(
            '.'.join((message.proto_path(), field.name)), proto_options
        )
        if proto_options is not None
        else None
    )

    field_options = (
        options.create_from_field_options(
            field.options.Extensions[pwpb_field_options]
        )
        if field.options.HasExtension(pwpb_field_options)
        else None
    )

    if field_options and codegen_options:
        return options.merge_field_and_codegen_options(
            field_options, codegen_options
        )
    if field_options:
        return field_options
    if codegen_options:
        return codegen_options
    return None


def _add_message_fields(
    proto_file: descriptor_pb2.FileDescriptorProto,
    global_root: ProtoNode,
    package_root: ProtoNode,
    message: ProtoNode,
    proto_message,
    proto_options,
) -> None:
    """Adds fields from a protobuf message descriptor to a message node.

    Args:
      proto_file: The file descriptor the message comes from. It is used to
        decide how field presence is determined.
      global_root: Root of the whole proto tree (for fully-qualified types).
      package_root: Node of the file's package (for relative types).
      message: The ProtoMessage node to populate.
      proto_message: The message descriptor whose `field` entries are read.
      proto_options: Parsed codegen options, or None.
    """
    assert message.type() == ProtoNode.Type.MESSAGE
    message = cast(ProtoMessage, message)

    type_node: ProtoNode | None

    for field in proto_message.field:
        if field.type_name:
            # The "type_name" member contains the global .proto path of the
            # field's type object, for example ".pw.protobuf.test.KeyValuePair".
            # Try to find the node for this object within the current context.
            type_node = _find_or_create_node(
                global_root, package_root, field.type_name
            )
        else:
            type_node = None

        repeated = (
            field.label == descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
        )
        has_presence = _field_has_presence(proto_file, field, repeated)
        merged_options = _merged_field_options(message, field, proto_options)

        oneof_index = (
            field.oneof_index if field.HasField('oneof_index') else None
        )

        message.add_field(
            ProtoMessageField(
                field.name,
                field.number,
                field.type,
                type_node,
                has_presence,
                repeated,
                merged_options,
            ),
            oneof_index=oneof_index,
        )
773
774
def _add_service_methods(
    global_root: ProtoNode,
    package_root: ProtoNode,
    service: ProtoNode,
    proto_service,
) -> None:
    """Adds one ProtoServiceMethod to `service` per method in the descriptor.

    Request and response types are resolved (or created as external
    placeholders) through _find_or_create_node.
    """
    assert service.type() == ProtoNode.Type.SERVICE
    service = cast(ProtoService, service)

    # Keyed by (client_streaming, server_streaming).
    kinds_by_streaming = {
        (False, False): ProtoServiceMethod.Type.UNARY,
        (False, True): ProtoServiceMethod.Type.SERVER_STREAMING,
        (True, False): ProtoServiceMethod.Type.CLIENT_STREAMING,
        (True, True): ProtoServiceMethod.Type.BIDIRECTIONAL_STREAMING,
    }

    for proto_method in proto_service.method:
        streaming = (
            bool(proto_method.client_streaming),
            bool(proto_method.server_streaming),
        )
        kind = kinds_by_streaming[streaming]

        request_node = _find_or_create_node(
            global_root, package_root, proto_method.input_type
        )
        response_node = _find_or_create_node(
            global_root, package_root, proto_method.output_type
        )

        new_method = ProtoServiceMethod(
            service, proto_method.name, kind, request_node, response_node
        )
        service.add_method(new_method)
806
807
def _populate_fields(
    proto_file: descriptor_pb2.FileDescriptorProto,
    global_root: ProtoNode,
    package_root: ProtoNode,
    proto_options: options.ParsedOptions | None,
) -> None:
    """Fills in the enum values, message fields and service methods of a file.

    The node skeleton for `proto_file` is expected to already exist beneath
    `package_root`. Nodes are looked up by name and are never created here.
    """

    def fill_message(message_node, proto_message):
        """Adds a message's fields, then descends into its nested types."""
        _add_message_fields(
            proto_file,
            global_root,
            package_root,
            message_node,
            proto_message,
            proto_options,
        )

        for nested_enum in proto_message.enum_type:
            _add_enum_fields(message_node.find(nested_enum.name), nested_enum)
        for nested_message in proto_message.nested_type:
            fill_message(message_node.find(nested_message.name), nested_message)

    # Top-level enums first, then messages (recursively), then services.
    for top_enum in proto_file.enum_type:
        top_enum_node = package_root.find(top_enum.name)
        assert top_enum_node is not None
        _add_enum_fields(top_enum_node, top_enum)

    for top_message in proto_file.message_type:
        fill_message(package_root.find(top_message.name), top_message)

    for proto_service in proto_file.service:
        service_node = package_root.find(proto_service.name)
        assert service_node is not None
        _add_service_methods(
            global_root, package_root, service_node, proto_service
        )
840
841
def _build_hierarchy(
    proto_file: descriptor_pb2.FileDescriptorProto,
) -> tuple[ProtoPackage, ProtoPackage]:
    """Builds the node skeleton (packages, enums, messages, services) of a file.

    Fields, enum values and methods are not filled in here.

    Returns:
      A (global_root, package_root) pair. The first is the unnamed global
      root; the second is the innermost node of the file's declared package.
    """

    def make_message(proto_message) -> ProtoMessage:
        """Creates a message node with its oneofs and nested enums/messages."""
        message_node = ProtoMessage(proto_message.name)
        for oneof_decl in proto_message.oneof_decl:
            message_node.add_oneof(oneof_decl.name)
        for nested_enum in proto_message.enum_type:
            message_node.add_child(ProtoEnum(nested_enum.name))
        for nested_message in proto_message.nested_type:
            message_node.add_child(make_message(nested_message))
        return message_node

    global_root = ProtoPackage('')
    package_root = global_root

    # NOTE(review): for a file with no `package`, ''.split('.') yields [''],
    # so a package node with an empty name is inserted below the global root.
    # Fully-qualified type names such as '.Msg' are looked up from the global
    # root, so confirm they still resolve to this file's nodes as intended.
    for component in proto_file.package.split('.'):
        package_node = ProtoPackage(component)
        package_root.add_child(package_node)
        package_root = package_node

    for top_enum in proto_file.enum_type:
        package_root.add_child(ProtoEnum(top_enum.name))

    for top_message in proto_file.message_type:
        package_root.add_child(make_message(top_message))

    for proto_service in proto_file.service:
        package_root.add_child(ProtoService(proto_service.name))

    return global_root, package_root
876
877
def build_node_tree(
    file_descriptor_proto: descriptor_pb2.FileDescriptorProto,
    proto_options: options.ParsedOptions | None = None,
) -> tuple[ProtoNode, ProtoNode]:
    """Constructs a tree of proto nodes from a file descriptor.

    Args:
      file_descriptor_proto: The file to build the tree for.
      proto_options: Optional parsed codegen options applied to fields.

    Returns:
      A (global_root, package_root) tuple: the root node of the entire proto
      package tree and the node representing the file's package.
    """
    roots = _build_hierarchy(file_descriptor_proto)
    _populate_fields(file_descriptor_proto, *roots, proto_options)
    return roots
892