1# Copyright 2020 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""This module defines data structures for protobuf entities.""" 15 16from __future__ import annotations 17 18import abc 19import collections 20import dataclasses 21import enum 22import itertools 23 24from typing import ( 25 Callable, 26 Iterator, 27 TypeVar, 28 cast, 29) 30 31from google.protobuf import descriptor_pb2 32 33from pw_protobuf import edition_constants, options, symbol_name_mapping 34from pw_protobuf_codegen_protos.codegen_options_pb2 import CodegenOptions 35from pw_protobuf_protos.field_options_pb2 import pwpb as pwpb_field_options 36 37T = TypeVar('T') # pylint: disable=invalid-name 38 39# Currently, protoc does not do a traversal to look up the package name of all 40# messages that are referenced in the file. For such "external" message names, 41# we are unable to find where the "::pwpb" subnamespace would be inserted by our 42# codegen. This namespace provides us with an alternative, more verbose 43# namespace that the codegen can use as a fallback in these cases. For example, 44# for the symbol name `my.external.package.ProtoMsg.SubMsg`, we would use 45# `::pw::pwpb_codegen_private::my::external::package:ProtoMsg::SubMsg` to refer 46# to the pw_protobuf generated code, when package name info is not available. 
#
# TODO: b/258832150 - Explore removing this if possible
EXTERNAL_SYMBOL_WORKAROUND_NAMESPACE = 'pw::pwpb_codegen_private'


class ProtoNode(abc.ABC):
    """A C++ scope corresponding to an entity declared in a .proto file.

    ProtoNodes form a tree. The root is the unnamed global scope; below it are
    the .proto packages, and below those the messages, enums, and services
    declared within them.
    """

    class Type(enum.Enum):
        """The kind of entity a ProtoNode stands for.

        PACKAGE: a C++ namespace.
        MESSAGE: a C++ "Encoder" class inside a namespace of its own.
        ENUM: a C++ enum placed in its parent's namespace.
        EXTERNAL: a node defined within a different compilation unit.
        SERVICE: an RPC service definition.
        """

        PACKAGE = 1
        MESSAGE = 2
        ENUM = 3
        EXTERNAL = 4
        SERVICE = 5

    def __init__(self, name: str):
        self._name: str = name
        self._parent: ProtoNode | None = None
        # Keyed by child name. Insertion order is the order reported by
        # children() and used by depth-first iteration.
        self._children: dict[str, ProtoNode] = collections.OrderedDict()

    @abc.abstractmethod
    def type(self) -> ProtoNode.Type:
        """Returns the kind of this node."""

    def children(self) -> list[ProtoNode]:
        """Returns a new list holding this node's direct children."""
        return [*self._children.values()]

    def parent(self) -> ProtoNode | None:
        """Returns the enclosing node, or None if this node has no parent."""
        return self._parent

    def name(self) -> str:
        """Returns the unqualified name this node was created with."""
        return self._name

    def cpp_name(self) -> str:
        """Returns the name of this node as spelled in generated C++ code."""
        identifier = symbol_name_mapping.fix_cc_identifier(self._name)
        return identifier.replace('.', '::')

    def _package_or_external(self) -> ProtoNode:
        """Returns the nearest package or external node, starting at self.

        External nodes are accepted as a fallback: symbols referenced from
        other compilation units are not processed into a regular proto tree,
        so there is no way to find the package of such a symbol.
        """
        boundary_types = (ProtoNode.Type.PACKAGE, ProtoNode.Type.EXTERNAL)
        current: ProtoNode | None = self
        while current and current.type() not in boundary_types:
            current = current.parent()

        assert current, 'proto tree was built without a root'
        return current

    def cpp_namespace(
        self,
        root: ProtoNode | None = None,
        codegen_subnamespace: str | None = 'pwpb',
    ) -> str:
        """Returns the C++ namespace of this node as referenced from `root`.

        Args:
            root: The namespace this node is being referenced from. If `root`
                is an ancestor namespace of this node, the scopes of `root`
                and everything above it are omitted from the result.

            codegen_subnamespace: A namespace inserted after the package
                declared in the .proto file, but before any scopes needed
                for messages etc. Different codegen tools can use different
                values to emit non-conflicting symbols for the same protos.
                Defaults to "pwpb", matching the pwpb codegen. A falsy value
                means no subnamespace is inserted.

        Returns:
            The '::'-joined namespace with empty components dropped, or '' if
            the result would consist solely of `codegen_subnamespace`.
        """
        own_scope = self._package_or_external()
        root_scope = (
            None
            if root is None
            else root._package_or_external()  # pylint: disable=protected-access
        )
        if root_scope:
            assert root_scope.type() != ProtoNode.Type.EXTERNAL

        def qualified_names(node: ProtoNode) -> Iterator[str]:
            """C++ names of every scope from the global root down to node."""
            # pylint: disable=protected-access
            return node._attr_hierarchy(ProtoNode.cpp_name, root=None)

        def scope_names() -> Iterator[str]:
            if own_scope.type() == ProtoNode.Type.EXTERNAL:
                # The package boundary of an external symbol is unknown, so
                # fall back to the external symbol workaround namespace.
                #
                # TODO: b/250945489 - Investigate removing this limitation /
                # hack
                return itertools.chain(
                    [EXTERNAL_SYMBOL_WORKAROUND_NAMESPACE],
                    qualified_names(self),
                )

            # TODO: b/250945489 - maybe elide "::{codegen_subnamespace}"
            # here, if this node doesn't have any package?
            shares_root_package = (
                root is not None
                and root_scope is not None  # extra check for mypy
                and all(
                    ours == theirs
                    for ours, theirs in itertools.zip_longest(
                        qualified_names(own_scope),
                        qualified_names(root_scope),
                    )
                )
            )
            if shares_root_package:
                # Same package as `root`: leave the subnamespace out.
                return self._attr_hierarchy(ProtoNode.cpp_name, root)

            # `root` is absent (common ancestor is "") or only partially
            # matches this node's package, so the subnamespace goes right
            # after the package scopes.
            # pylint: disable=protected-access
            package_names = own_scope._attr_hierarchy(
                ProtoNode.cpp_name, root
            )
            subnamespace = (
                [codegen_subnamespace] if codegen_subnamespace else []
            )
            nested_names = self._attr_hierarchy(ProtoNode.cpp_name, own_scope)
            return itertools.chain(package_names, subnamespace, nested_names)

        namespace = '::'.join(filter(None, scope_names()))
        if namespace == codegen_subnamespace:
            return ''
        return namespace

    def proto_path(self) -> str:
        """Returns the fully-qualified .proto path of this node."""
        dotted = '.'.join(self._attr_hierarchy(lambda n: n.name(), None))
        return dotted.lstrip('.')

    def pwpb_struct(self) -> str:
        """Returns the name of the pw_protobuf struct for this proto."""
        namespace = self.cpp_namespace()
        return '::' + namespace + '::Message'

    def pwpb_table(self) -> str:
        """Returns the name of the pw_protobuf table constant for this proto."""
        namespace = self.cpp_namespace()
        return '::' + namespace + '::kMessageFields'

    def nanopb_fields(self) -> str:
        """Returns the Nanopb variable name describing this proto's fields."""
        base = self._nanopb_name()
        return base + '_fields'

    def nanopb_struct(self) -> str:
        """Returns the name of the Nanopb struct for this proto."""
        base = self._nanopb_name()
        return '::' + base

    def _nanopb_name(self) -> str:
        """Joins the names from the global root with '_' (Nanopb style)."""
        joined = '_'.join(self._attr_hierarchy(lambda n: n.name(), None))
        return joined.lstrip('_')

    def common_ancestor(self, other: ProtoNode) -> ProtoNode | None:
        """Returns the deepest node that is an ancestor of (or equal to) both.

        Returns None if `other` is None or the nodes share no ancestor.
        """
        if other is None:
            return None

        own_depth, other_depth = self.depth(), other.depth()
        shallow: ProtoNode | None
        deep: ProtoNode | None
        if own_depth < other_depth:
            shallow, deep = self, other
        else:
            shallow, deep = other, self

        # Raise the deeper node until both sit at the same depth.
        for _ in range(abs(own_depth - other_depth)):
            assert deep is not None
            deep = deep.parent()

        # Walk both up in lockstep until they meet.
        while shallow != deep:
            if shallow is None or deep is None:
                return None
            shallow = shallow.parent()
            deep = deep.parent()

        return shallow

    def depth(self) -> int:
        """Returns how many ancestors this node has (0 for a root)."""
        levels = 0
        ancestor = self._parent
        while ancestor:
            levels += 1
            ancestor = ancestor.parent()
        return levels

    def add_child(self, child: ProtoNode) -> None:
        """Attaches `child` under this node, detaching it from any old parent.

        Args:
            child: The node to insert.

        Raises:
            ValueError: This node does not allow nesting the given type of child.
        """
        if not self._supports_child(child):
            raise ValueError(
                'Invalid child %s for node of type %s'
                % (child.type(), self.type())
            )

        # pylint: disable=protected-access
        previous_parent = child._parent
        if previous_parent is not None:
            del previous_parent._children[child.name()]

        child._parent = self
        self._children[child.name()] = child
        # pylint: enable=protected-access

    def find(self, path: str) -> ProtoNode | None:
        """Looks up a node in this node's subtree.

        Args:
            path: Dot-separated names, relative to this node.

        Returns:
            The matching node, or None if any component is missing.
        """
        node = self

        # pylint: disable=protected-access
        for segment in path.split('.'):
            found = node._children.get(segment)
            if found is None:
                return None
            node = found
        # pylint: enable=protected-access

        return node

    def __iter__(self) -> Iterator[ProtoNode]:
        """Yields this node and then its whole subtree, depth-first."""
        yield self
        for subtree in self._children.values():
            yield from subtree

    def _attr_hierarchy(
        self,
        attr_accessor: Callable[[ProtoNode], T],
        root: ProtoNode | None,
    ) -> Iterator[T]:
        """Collects an attribute from every node between `root` and self.

        Args:
            attr_accessor: Function which extracts attributes from a ProtoNode.
            root: The ancestor at which to stop (excluded from the result);
                None walks all the way up.

        Returns:
            An iterator over the attributes, ordered from the top-most node
            down to this node.
        """
        collected = []
        current: ProtoNode | None = self
        while current is not None and current != root:
            collected.append(attr_accessor(current))
            current = current.parent()
        return reversed(collected)

    @abc.abstractmethod
    def _supports_child(self, child: ProtoNode) -> bool:
        """Returns True if `child` may be nested under this node."""


class ProtoPackage(ProtoNode):
    """A protobuf package; accepts any kind of child."""

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.PACKAGE

    def _supports_child(self, child: ProtoNode) -> bool:
        return True


class ProtoEnum(ProtoNode):
    """An enum declared in a .proto file."""

    def __init__(self, name: str):
        super().__init__(name)
        # (C++ value name, numeric value) pairs, in declaration order.
        self._values: list[tuple[str, int]] = []

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.ENUM

    def values(self) -> list[tuple[str, int]]:
        """Returns a copy of the (name, number) pairs added so far."""
        return self._values.copy()

    def add_value(self, name: str, value: int) -> None:
        """Records a value, converting its name to a C++ UPPER_SNAKE_CASE."""
        fixed = symbol_name_mapping.fix_cc_enum_value_name(name)
        self._values.append((ProtoMessageField.upper_snake_case(fixed), value))

    def _supports_child(self, child: ProtoNode) -> bool:
        # Nothing can be nested inside an enum.
        return False


class ProtoMessage(ProtoNode):
    """A message declared in a .proto file."""

    @dataclasses.dataclass
    class OneOf:
        """A oneof declared in a message, with the fields assigned to it."""

        name: str
        fields: list[ProtoMessageField] = dataclasses.field(
            default_factory=list
        )

        def is_synthetic(self) -> bool:
            """Returns whether this oneof looks compiler-generated.

            protoc represents a proto3 `optional` field as a "synthetic" oneof
            with exactly one member. pw_protobuf does not support oneof in
            general but special-cases proto3 optional fields, so this tells a
            user-declared oneof apart from a compiler-generated one.
            https://cs.opensource.google/protobuf/protobuf/+/main:src/google/protobuf/descriptor.proto;l=305;drc=5a68dddcf9564f92815296099f07f7dfe8713908

            NOTE(review): the check relies only on the member count and the
            field's has_presence(); for edition-2023 files, where presence is
            computed without proto3_optional, a single-member user oneof may
            also match. Confirm whether that case needs handling.
            """
            if len(self.fields) != 1:
                return False
            return self.fields[0].has_presence()

    def __init__(self, name: str):
        super().__init__(name)
        self._fields: list[ProtoMessageField] = []
        self._oneofs: list[ProtoMessage.OneOf] = []
        # Computed lazily by dependencies(); None means "not computed yet".
        self._dependencies: list[ProtoMessage] | None = None
        self._dependency_cycles: list[ProtoMessage] = []

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.MESSAGE

    def fields(self) -> list[ProtoMessageField]:
        """Returns a copy of the message's fields in the order added."""
        return self._fields.copy()

    def oneofs(self) -> list[ProtoMessage.OneOf]:
        """Returns a copy of the message's oneofs in declaration order."""
        return self._oneofs.copy()

    def add_field(
        self,
        field: ProtoMessageField,
        oneof_index: int | None = None,
    ) -> None:
        """Appends a field, also registering it with a oneof if given.

        Args:
            field: The field to add.
            oneof_index: Index into oneofs() of the oneof containing the
                field, or None if it belongs to none.
        """
        self._fields.append(field)
        if oneof_index is None:
            return

        group = self._oneofs[oneof_index]
        group.fields.append(field)
        field._oneof = group  # pylint: disable=protected-access

    def add_oneof(self, name) -> None:
        """Declares a new, initially empty, oneof."""
        self._oneofs.append(ProtoMessage.OneOf(name))

    def _supports_child(self, child: ProtoNode) -> bool:
        # Messages may nest enums and other messages only.
        return child.type() in (self.Type.ENUM, self.Type.MESSAGE)

    def dependencies(self) -> list[ProtoMessage]:
        """Returns the messages referenced by this message's fields.

        Only TYPE_MESSAGE fields whose type node is a ProtoMessage are
        included. The list is built on first call and cached; entries moved
        out by remove_dependency_cycle() are not included.
        """
        if self._dependencies is None:
            self._dependencies = []
            for field in self._fields:
                is_message_field = (
                    field.type()
                    == descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE
                )
                if not is_message_field:
                    continue

                target = field.type_node()
                assert target is not None
                if target.type() == ProtoNode.Type.MESSAGE:
                    self._dependencies.append(cast(ProtoMessage, target))

        return list(self._dependencies)

    def dependency_cycles(self) -> list[ProtoMessage]:
        """Returns a copy of the dependencies moved here as cycles."""
        return self._dependency_cycles.copy()

    def remove_dependency_cycle(self, dependency: ProtoMessage):
        """Moves `dependency` from dependencies() to dependency_cycles().

        dependencies() must already have been called, and `dependency` must
        be among its results.
        """
        assert self._dependencies is not None
        assert dependency in self._dependencies
        self._dependencies.remove(dependency)
        self._dependency_cycles.append(dependency)


class ProtoService(ProtoNode):
    """An RPC service declared in a .proto file."""

    def __init__(self, name: str):
        super().__init__(name)
        self._methods: list[ProtoServiceMethod] = []

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.SERVICE

    def methods(self) -> list[ProtoServiceMethod]:
        """Returns a copy of the service's methods in the order added."""
        return self._methods.copy()

    def add_method(self, method: ProtoServiceMethod) -> None:
        """Appends an RPC method to the service."""
        self._methods.append(method)

    def _supports_child(self, child: ProtoNode) -> bool:
        return False


class ProtoExternal(ProtoNode):
    """A node that comes from another compilation unit.

    External nodes stand in for entities that are not defined in the current
    compilation unit, most likely because they come from an imported .proto
    file. Their kind is unknown, so they carry no members or extra data; they
    exist only so names can be resolved across compilation units.
    """

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.EXTERNAL

    def _supports_child(self, child: ProtoNode) -> bool:
        return True


# This class is not a node and does not appear in the proto tree.
# Fields belong to proto messages and are processed separately.
class ProtoMessageField:
    """Representation of a field within a protobuf message."""

    def __init__(
        self,
        field_name: str,
        field_number: int,
        field_type: int,
        type_node: ProtoNode | None = None,
        has_presence: bool = False,
        repeated: bool = False,
        codegen_options: CodegenOptions | None = None,
    ):
        """Creates a field.

        Args:
            field_name: The field name from the .proto file; it is passed
                through symbol_name_mapping.fix_cc_identifier before storage.
            field_number: The field's number.
            field_type: A descriptor_pb2.FieldDescriptorProto.Type value.
            type_node: The node for the field's type, if it has a named type.
            has_presence: Whether the field tracks explicit presence.
            repeated: Whether the field is repeated.
            codegen_options: Codegen options to apply to this field, if any.
        """
        self._field_name = symbol_name_mapping.fix_cc_identifier(field_name)
        self._number: int = field_number
        self._type: int = field_type
        self._type_node: ProtoNode | None = type_node
        self._has_presence: bool = has_presence
        self._repeated: bool = repeated
        self._options: CodegenOptions | None = codegen_options
        # Assigned by ProtoMessage.add_field when the field is in a oneof.
        self._oneof: ProtoMessage.OneOf | None = None

    def name(self) -> str:
        """Returns the field name in UpperCamelCase."""
        return self.upper_camel_case(self._field_name)

    def field_name(self) -> str:
        """Returns the (C++-safe) field name as stored."""
        return self._field_name

    def enum_name(self) -> str:
        """Returns the kUpperCamelCase name used for the field's enumerator."""
        return 'k' + self.name()

    def legacy_enum_name(self) -> str:
        """Returns the older UPPER_SNAKE_CASE enumerator name."""
        return self.upper_snake_case(
            symbol_name_mapping.fix_cc_enum_value_name(self._field_name)
        )

    def number(self) -> int:
        return self._number

    def type(self) -> int:
        """Returns the descriptor_pb2.FieldDescriptorProto.Type value."""
        return self._type

    def type_node(self) -> ProtoNode | None:
        return self._type_node

    def has_presence(self) -> bool:
        return self._has_presence

    def is_repeated(self) -> bool:
        return self._repeated

    def options(self) -> CodegenOptions | None:
        return self._options

    def oneof(self) -> ProtoMessage.OneOf | None:
        """Returns the user-declared oneof holding this field, if any.

        Synthetic oneofs (see ProtoMessage.OneOf.is_synthetic) are reported
        as None.
        """
        if self._oneof is not None and not self._oneof.is_synthetic():
            return self._oneof
        return None

    @staticmethod
    def upper_camel_case(field_name: str) -> str:
        """Converts a field name to UpperCamelCase."""
        name_components = field_name.split('_')
        return ''.join([word.lower().capitalize() for word in name_components])

    @staticmethod
    def upper_snake_case(field_name: str) -> str:
        """Converts a field name to UPPER_SNAKE_CASE."""
        return field_name.upper()


class ProtoServiceMethod:
    """A method defined in a protobuf service."""

    class Type(enum.Enum):
        """Streaming kind of an RPC; values are pw_rpc MethodType names."""

        UNARY = 'kUnary'
        SERVER_STREAMING = 'kServerStreaming'
        CLIENT_STREAMING = 'kClientStreaming'
        BIDIRECTIONAL_STREAMING = 'kBidirectionalStreaming'

        def cc_enum(self) -> str:
            """Returns the pw_rpc MethodType C++ enum for this method type."""
            return '::pw::rpc::MethodType::' + self.value

    def __init__(
        self,
        service: ProtoService,
        name: str,
        method_type: Type,
        request_type: ProtoNode,
        response_type: ProtoNode,
    ):
        self._service = service
        self._name = name
        self._type = method_type
        self._request_type = request_type
        self._response_type = response_type

    def service(self) -> ProtoService:
        return self._service

    def name(self) -> str:
        return self._name

    def type(self) -> Type:
        return self._type

    def server_streaming(self) -> bool:
        """True for server-streaming and bidirectional-streaming methods."""
        return self._type in (
            self.Type.SERVER_STREAMING,
            self.Type.BIDIRECTIONAL_STREAMING,
        )

    def client_streaming(self) -> bool:
        """True for client-streaming and bidirectional-streaming methods."""
        return self._type in (
            self.Type.CLIENT_STREAMING,
            self.Type.BIDIRECTIONAL_STREAMING,
        )

    def request_type(self) -> ProtoNode:
        return self._request_type

    def response_type(self) -> ProtoNode:
        return self._response_type


def _add_enum_fields(enum_node: ProtoNode, proto_enum) -> None:
    """Adds fields from a protobuf enum descriptor to an enum node."""
    assert enum_node.type() == ProtoNode.Type.ENUM
    enum_node = cast(ProtoEnum, enum_node)

    for value in proto_enum.value:
        enum_node.add_value(value.name, value.number)


def _create_external_nodes(root: ProtoNode, path: str) -> ProtoNode:
    """Creates external nodes for a path starting from the given root.

    Existing path components are reused; missing ones are created as
    ProtoExternal nodes. Returns the node for the last component.
    """
    node = root
    for part in path.split('.'):
        child = node.find(part)
        if not child:
            child = ProtoExternal(part)
            node.add_child(child)
        node = child

    return node


def _find_or_create_node(
    global_root: ProtoNode, package_root: ProtoNode, path: str
) -> ProtoNode:
    """Searches the proto tree for a node by path, creating it if not found.

    A path starting with '.' is fully qualified and resolved from
    `global_root`; any other path is resolved from `package_root`.
    """
    if path[0] == '.':
        # Fully qualified path.
        root_relative_path = path[1:]
        search_root = global_root
    else:
        root_relative_path = path
        search_root = package_root

    node = search_root.find(root_relative_path)
    if node is None:
        # Create nodes for field types that don't exist within this
        # compilation context, such as those imported from other .proto
        # files.
        node = _create_external_nodes(search_root, root_relative_path)

    return node


def _uses_edition_2023(
    proto_file: descriptor_pb2.FileDescriptorProto,
) -> bool:
    """Returns True if the file descriptor declares edition 2023.

    Older descriptor.proto revisions store the edition as a string ("2023");
    newer ones use the `Edition` enum (EDITION_2023). Both forms are
    accepted. If neither the field nor the enum exists, returns False.
    """
    # TODO: pwbug.dev/366316523 - The generated protobuf types do not
    # include the "edition" property, so getattr is used. Fix this when
    # we upgrade protobuf and mypy-protobuf.
    edition = getattr(proto_file, 'edition', None)
    if edition == '2023':
        return True
    edition_2023 = getattr(descriptor_pb2, 'EDITION_2023', None)
    return edition_2023 is not None and edition == edition_2023


def _add_message_fields(
    proto_file: descriptor_pb2.FileDescriptorProto,
    global_root: ProtoNode,
    package_root: ProtoNode,
    message: ProtoNode,
    proto_message,
    proto_options,
) -> None:
    """Adds fields from a protobuf message descriptor to a message node.

    Args:
        proto_file: The file declaring the message; its edition selects the
            field presence rules.
        global_root: Root of the whole tree, for fully qualified type names.
        package_root: The file's package node, for relative type names.
        message: The ProtoMessage node to populate.
        proto_message: The message descriptor whose fields are added.
        proto_options: Parsed codegen options, or None.
    """
    assert message.type() == ProtoNode.Type.MESSAGE
    message = cast(ProtoMessage, message)

    # The edition is a property of the file, so check it once.
    is_edition_2023 = _uses_edition_2023(proto_file)

    type_node: ProtoNode | None

    for field in proto_message.field:
        if field.type_name:
            # The "type_name" member contains the global .proto path of the
            # field's type object, for example ".pw.protobuf.test.KeyValuePair".
            # Try to find the node for this object within the current context.
            type_node = _find_or_create_node(
                global_root, package_root, field.type_name
            )
        else:
            type_node = None

        repeated = (
            field.label == descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
        )

        if is_edition_2023:
            # Singular message fields always have presence; other singular
            # fields have it unless they opt into implicit presence.
            # field.type is a FieldDescriptorProto.Type integer, so it must
            # be compared with TYPE_MESSAGE (not a ProtoNode.Type member).
            has_presence = not repeated and (
                field.type == descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE
                or field.options.features.field_presence
                != edition_constants.FieldPresence.IMPLICIT.value
            )
        else:
            # If the file does not use editions, only consider explicit
            # proto3 optionality.
            has_presence = field.proto3_optional

        codegen_options = (
            options.match_options(
                '.'.join((message.proto_path(), field.name)), proto_options
            )
            if proto_options is not None
            else None
        )

        field_options = (
            options.create_from_field_options(
                field.options.Extensions[pwpb_field_options]
            )
            if field.options.HasExtension(pwpb_field_options)
            else None
        )

        # Per-field options and codegen options are merged when both exist;
        # otherwise whichever one is present is used.
        merged_options = None

        if field_options and codegen_options:
            merged_options = options.merge_field_and_codegen_options(
                field_options, codegen_options
            )
        elif field_options:
            merged_options = field_options
        elif codegen_options:
            merged_options = codegen_options

        oneof_index = (
            field.oneof_index if field.HasField('oneof_index') else None
        )

        message.add_field(
            ProtoMessageField(
                field.name,
                field.number,
                field.type,
                type_node,
                has_presence,
                repeated,
                merged_options,
            ),
            oneof_index=oneof_index,
        )


def _add_service_methods(
    global_root: ProtoNode,
    package_root: ProtoNode,
    service: ProtoNode,
    proto_service,
) -> None:
    """Adds a ProtoServiceMethod to `service` for each method descriptor.

    Request and response types are resolved (or created as external nodes)
    the same way field types are.
    """
    assert service.type() == ProtoNode.Type.SERVICE
    service = cast(ProtoService, service)

    for method in proto_service.method:
        if method.client_streaming and method.server_streaming:
            method_type = ProtoServiceMethod.Type.BIDIRECTIONAL_STREAMING
        elif method.client_streaming:
            method_type = ProtoServiceMethod.Type.CLIENT_STREAMING
        elif method.server_streaming:
            method_type = ProtoServiceMethod.Type.SERVER_STREAMING
        else:
            method_type = ProtoServiceMethod.Type.UNARY

        request_node = _find_or_create_node(
            global_root, package_root, method.input_type
        )
        response_node = _find_or_create_node(
            global_root, package_root, method.output_type
        )

        service.add_method(
            ProtoServiceMethod(
                service, method.name, method_type, request_node, response_node
            )
        )


def _populate_fields(
    proto_file: descriptor_pb2.FileDescriptorProto,
    global_root: ProtoNode,
    package_root: ProtoNode,
    proto_options: options.ParsedOptions | None,
) -> None:
    """Traverses a proto file, adding all message and enum fields to a tree.

    Expects the node skeleton created by _build_hierarchy to already exist.
    """

    def populate_message(node, message):
        """Recursively populates nested messages and enums."""
        _add_message_fields(
            proto_file, global_root, package_root, node, message, proto_options
        )

        for proto_enum in message.enum_type:
            _add_enum_fields(node.find(proto_enum.name), proto_enum)
        for msg in message.nested_type:
            populate_message(node.find(msg.name), msg)

    # Iterate through the proto file, populating top-level objects.
    for proto_enum in proto_file.enum_type:
        enum_node = package_root.find(proto_enum.name)
        assert enum_node is not None
        _add_enum_fields(enum_node, proto_enum)

    for message in proto_file.message_type:
        populate_message(package_root.find(message.name), message)

    for service in proto_file.service:
        service_node = package_root.find(service.name)
        assert service_node is not None
        _add_service_methods(global_root, package_root, service_node, service)


def _build_hierarchy(
    proto_file: descriptor_pb2.FileDescriptorProto,
) -> tuple[ProtoPackage, ProtoPackage]:
    """Creates a ProtoNode hierarchy from a proto file descriptor.

    Returns:
        (global root, node for the file's package). Enums, messages (with
        their oneofs and nested types), and services are created as empty
        nodes; _populate_fields fills them in.
    """
    root = ProtoPackage('')
    package_root = root

    # NOTE(review): when the file declares no package, ''.split('.') yields
    # [''], so a single package node named '' is created under the root. A
    # fully qualified type name such as ".Msg" would then not be found from
    # the global root by _find_or_create_node — confirm this is intended.
    for part in proto_file.package.split('.'):
        package = ProtoPackage(part)
        package_root.add_child(package)
        package_root = package

    def build_message_subtree(proto_message):
        node = ProtoMessage(proto_message.name)
        for oneof in proto_message.oneof_decl:
            node.add_oneof(oneof.name)
        for proto_enum in proto_message.enum_type:
            node.add_child(ProtoEnum(proto_enum.name))
        for submessage in proto_message.nested_type:
            node.add_child(build_message_subtree(submessage))

        return node

    for proto_enum in proto_file.enum_type:
        package_root.add_child(ProtoEnum(proto_enum.name))

    for message in proto_file.message_type:
        package_root.add_child(build_message_subtree(message))

    for service in proto_file.service:
        package_root.add_child(ProtoService(service.name))

    return root, package_root


def build_node_tree(
    file_descriptor_proto: descriptor_pb2.FileDescriptorProto,
    proto_options: options.ParsedOptions | None = None,
) -> tuple[ProtoNode, ProtoNode]:
    """Constructs a tree of proto nodes from a file descriptor.

    Returns the root node of the entire proto package tree and the node
    representing the file's package.
    """
    global_root, package_root = _build_hierarchy(file_descriptor_proto)
    _populate_fields(
        file_descriptor_proto, global_root, package_root, proto_options
    )
    return global_root, package_root