# mypy: allow-untyped-defs
from __future__ import annotations

import abc
import collections
import copy
import operator
from typing import Any, Dict, Final, Generator, Iterator, Sequence, Tuple

import torch
import torch.fx
from torch.onnx._internal.fx import _pass, diagnostics
from torch.utils import _pytree as pytree


_FX_TRACER_NN_MODULE_META_TYPE = Tuple[str, type]
"""Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer."""
_FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict
"""Legacy type of `node.meta["nn_module_stack"]` produced by FX symbolic tracer."""

_DYNAMO_NN_MODULE_META_TYPE = Tuple[str, Tuple[str, type]]
"""Type of item from `node.meta["nn_module_stack"].items()` produced by FX dynamo tracer."""
_DYNAMO_NN_MODULE_STACK_META_TYPE = Dict[str, _DYNAMO_NN_MODULE_META_TYPE]
"""Type of `node.meta["nn_module_stack"]` produced by FX dynamo tracer."""


class _ModuleMeta:
    """Meta information about a module.

    This class is used to represent the module information in a more structured way.
    It parses raw module information from a single item from
    `node.meta["nn_module_stack"].items()`.

    See the uses of `from_raw_meta`, `from_fx_tracer_produced_raw_meta`, and
    `from_dynamo_produced_raw_meta` for how to create an instance.

    Attributes:
        _module_class: The class of the module. E.g. `torch.nn.modules.sparse.Embedding`.
        _module_name: The name of the module. E.g. `L__self___h_1_mlp_c_proj`.
        _raw_meta: The raw meta '(module_name, node.meta["nn_module_stack"][module_name])'.
    """

    _module_class: Final[type | str | None]  # type: ignore[misc]
    _module_name: Final[str]  # type: ignore[misc]
    _raw_meta: Final[tuple[Any, Any]]  # type: ignore[misc]

    def __init__(
        self,
        module_name: str,
        module_class: type | str | None,
        raw_meta: tuple[Any, Any],
    ):
        self._module_name = module_name
        self._module_class = module_class
        self._raw_meta = raw_meta

    @property
    def module_display_name(self) -> str:
        """The display name of the module.

        E.g. `h_1_mlp_c_proj`.
        """
        # E.g., from 'L__self___h_1_mlp_c_proj' to 'h_1_mlp_c_proj'.
        name = self.module_name
        if name.startswith("L__self___"):
            name = name[len("L__self___") :]
        return name

    @property
    def qualified_module_class_name(self) -> str:
        """Qualified name of the module class.

        E.g. `torch_nn_modules_sparse_Embedding`.
        """
        if self._module_class is None:
            return ""
        mod_cls = self._module_class
        if isinstance(mod_cls, type):
            mod_cls = mod_cls.__module__ + "." + mod_cls.__qualname__
        return mod_cls.replace(".", "_")

    @property
    def module_class_name(self) -> str:
        """Name of the module class.

        E.g. `Embedding`.
        """
        if self._module_class is None:
            return ""
        if isinstance(self._module_class, type):
            return self._module_class.__name__
        return self._module_class

    @property
    def module_name(self) -> str:
        """Name of the module.

        E.g. `L__self___h_1_mlp_c_proj`.
        """
        return self._module_name

    @property
    def raw_meta(self) -> tuple[Any, Any]:
        """Returns the raw module meta data.

        I.e. (module_name, node.meta['nn_module_stack'][module_name]).
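
        Example:

            An illustrative sketch, assuming the meta came from the dynamo tracer for a
            hypothetical submodule named 'L__self___fc':

                ("L__self___fc", ("L['self'].fc", torch.nn.Linear))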
107 """ 108 return self._raw_meta 109 110 def __eq__(self, __value: object) -> bool: 111 if not isinstance(__value, _ModuleMeta): 112 return False 113 return ( 114 self._module_name == __value._module_name 115 and self._module_class == __value._module_class 116 ) 117 118 def __hash__(self) -> int: 119 return hash((self._module_name, self._module_class)) 120 121 def __repr__(self) -> str: 122 return f"ModuleMeta(name={self._module_name}, class={self._module_class})" 123 124 @classmethod 125 def create_root(cls) -> _ModuleMeta: 126 """Create an empty module meta representing root module.""" 127 return _ModuleMeta("", None, ("", None)) 128 129 @classmethod 130 def from_fx_tracer_produced_raw_meta( 131 cls, raw_meta: _FX_TRACER_NN_MODULE_META_TYPE 132 ) -> _ModuleMeta: 133 """Create a module meta from raw meta produced by FX symbolic tracer.""" 134 module_name, module_class = raw_meta 135 return _ModuleMeta(module_name, module_class, raw_meta) 136 137 @classmethod 138 def from_dynamo_produced_raw_meta( 139 cls, raw_meta: _DYNAMO_NN_MODULE_META_TYPE 140 ) -> _ModuleMeta: 141 """Create a module meta from raw meta produced by FX dynamo tracer.""" 142 module_name, (qualified_name, module_class) = raw_meta 143 return _ModuleMeta(module_name, module_class, raw_meta) 144 145 @classmethod 146 def from_raw_meta( 147 cls, 148 raw_meta: _FX_TRACER_NN_MODULE_META_TYPE | _DYNAMO_NN_MODULE_META_TYPE, 149 ) -> _ModuleMeta: 150 if ( 151 isinstance(raw_meta, tuple) 152 and len(raw_meta) == 2 153 and isinstance(raw_meta[1], type) 154 ): 155 # Trying to do `instance(raw_meta, _FX_TRACER_NN_MODULE_META_TYPE)` 156 return _ModuleMeta.from_fx_tracer_produced_raw_meta(raw_meta) 157 if ( 158 isinstance(raw_meta, tuple) 159 and len(raw_meta) == 2 160 and isinstance(raw_meta[1], tuple) 161 ): 162 # Trying to do `instance(raw_meta, _DYNAMO_NN_MODULE_META_TYPE)` 163 return _ModuleMeta.from_dynamo_produced_raw_meta(raw_meta) 164 raise TypeError( 165 f"Unknown type of raw meta item from node.meta['nn_module_stack'].items(): {type(raw_meta)}" 166 ) 167 168 169class _ModuleStackMeta: 170 """Meta information about the module call stack. 171 172 This class is used to represent the module call stack information in a more 173 structured way. It parses raw module stack information from `node.meta["nn_module_stack"]`. 

    Example of raw module stack information:

    If produced by dynamo:

        {
            'L__self___h_1': (
                "L['self'].h[1]",
                <class 'transformers.models.gpt2.modeling_gpt2.GPT2Block'>
            ),
            'L__self___h_1_attn': (
                "L['self'].h[1].attn",
                <class 'transformers.models.gpt2.modeling_gpt2.GPT2Attention'>
            )
        }

    If produced by fx.symbolic_trace:

        {
            'h.1': <class 'transformers.models.gpt2.modeling_gpt2.GPT2Block'>,
            'h.1.attn': <class 'transformers.models.gpt2.modeling_gpt2.GPT2Attention'>
        }
    """

    _module_stack: Final[list[_ModuleMeta]]  # type: ignore[misc]

    def __init__(
        self,
        nn_module_stack_meta: _FX_TRACER_NN_MODULE_STACK_META_TYPE
        | _DYNAMO_NN_MODULE_STACK_META_TYPE
        | None,
        is_exported_program: bool = True,
    ):
        self._module_stack = []
        if nn_module_stack_meta is None:
            return
        raw_meta = copy.copy(nn_module_stack_meta)
        for item in raw_meta.items():
            # If produced by torch.export.export, there is another call stack layer
            # that we need to skip.
            if is_exported_program:
                is_exported_program = False
                continue
            self.push(_ModuleMeta.from_raw_meta(item))  # type: ignore[arg-type]

    def __len__(self) -> int:
        return len(self._module_stack)

    def __getitem__(self, index: int) -> _ModuleMeta:
        return self._module_stack[index]

    def __iter__(self) -> Iterator[_ModuleMeta]:
        return iter(self._module_stack)

    def is_empty_or_root(self) -> bool:
        return len(self._module_stack) == 0

    def top(self) -> _ModuleMeta:
        """Returns the top module meta in the stack, i.e. the meta for the leaf module.

        Example:

            Consider the following module stack:

            stack = [GPT, block1, Attention_1, MLP]

            stack.top() == MLP
        """
        if self.is_empty_or_root():
            return _ModuleMeta.create_root()
        return self._module_stack[-1]

    def is_superset_of(
        self,
        module_stack: _ModuleStackMeta,
    ) -> bool:
        """Determines if self is a superset of the provided module stack.

        I.e., if self includes all elements from the provided module stack, plus additional
        elements on top. If self is empty or root, this method always returns False.

        Example:

            Consider the following module stacks:

            stack_1 = [GPT, block1, Attention_1, MLP]
            stack_2 = [GPT, block1]

            stack_1.is_superset_of(stack_2) == True
            stack_2.is_superset_of(stack_1) == False

            stack_3 = [GPT, block2, Attention_1]

            stack_1.is_superset_of(stack_3) == False
            stack_3.is_superset_of(stack_1) == False
        """
        if self.is_empty_or_root():
            return False

        if module_stack.is_empty_or_root():
            return True

        if len(self) <= len(module_stack):
            return False

        for i, parent_key in enumerate(module_stack):
            if self[i] != parent_key:
                return False

        return True

    def push(self, module_meta: _ModuleMeta) -> None:
        """Pushes a module meta to the stack."""
        self._module_stack.append(module_meta)

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, _ModuleStackMeta):
            return False
        return self._module_stack == __value._module_stack

    @property
    def raw_meta(self) -> dict[str, tuple[str, type]] | None:
        """Returns the raw module stack meta data, i.e. node.meta['nn_module_stack'].
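
        An illustrative sketch of the returned shape, reusing the hypothetical
        dynamo-produced entries from the class docstring:

            {
                'L__self___h_1': ("L['self'].h[1]", GPT2Block),
                'L__self___h_1_attn': ("L['self'].h[1].attn", GPT2Attention),
            }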
        """
        return {
            module_meta.raw_meta[0]: module_meta.raw_meta[1]
            for module_meta in self._module_stack
        }

    def __repr__(self) -> str:
        return f"ModuleStackMeta({self._module_stack})"

    @property
    def module_display_name(self) -> str:
        """Returns the module display name of the top module."""
        return self.top().module_display_name

    @property
    def qualified_module_class_name(self) -> str:
        """Returns the qualified module class name of the top module."""
        return self.top().qualified_module_class_name

    @property
    def module_class(self) -> type | str | None:
        """Returns the module class of the top module."""
        return self.top()._module_class


def _module_stack_meta_from_node(
    node: torch.fx.Node, is_exported_program: bool = False
) -> _ModuleStackMeta:
    return _ModuleStackMeta(
        node.meta.get("nn_module_stack"), is_exported_program=is_exported_program
    )


def _get_unique_module_name(module_names: dict[str, int], module_name: str) -> str:
    module_names.setdefault(module_name, 0)
    module_names[module_name] += 1
    return f"{module_name}_{module_names[module_name]}"


class _IRNode(abc.ABC):
    """Base class for IR nodes.

    IR nodes are used for the Modularize pass only. They add a layer of abstraction on top of
    torch.fx.Node.

    [NOTE: Modularize Pass Implementation]
    The main job of the pass is to group `fx.Node`s that belong to the same `nn.Module`
    forward call, and then create a `call_module` node and a sub `fx.GraphModule` from them.
    Each `fx.Node` possesses an `nn_module_stack` meta data that contains information
    about the module call stack. See `_ModuleStackMeta` for examples.

    Analysis step
    -------------

    Each module call is identified by a set of base stack layers. For each module call,
    the pass creates a `_ModuleNode` and groups the sequence of nodes that share the
    same base stack layers.

    For example,

        stack_of_node_0 = [GPT, block0]
        stack_of_node_1 = [GPT, block1]
        stack_of_node_2 = [GPT, block1, Attention1, MLP]
        stack_of_node_3 = [GPT, block1, Attention1]
        stack_of_node_4 = [GPT, block2]

    All nodes belong to the `GPT` module call, since they share the base stack layers [GPT].
    [node_1, node_2, node_3] are grouped for `GPT.block1`, because they share the base
    stack layers [GPT, block1]. And [node_2, node_3] for `GPT.block1.Attention1`, [node_0]
    for `GPT.block0`, and [node_4] for `GPT.block2`, respectively.

    After the analysis step, a hierarchical representation is generated.

    For the above example, the representation is:

        _ModuleNode(GPT)
            _ModuleNode(block0)
                _LeafNode(node_0)
            _ModuleNode(block1)
                _LeafNode(node_1)
                _ModuleNode(Attention1)
                    _ModuleNode(MLP)
                        _LeafNode(node_2)
                    _LeafNode(node_3)
            _ModuleNode(block2)
                _LeafNode(node_4)
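
    A minimal sketch of how this grouping is driven (the stacks are the hypothetical ones
    above; `reference_module` is the flattened root `fx.GraphModule`):

        root = _ModuleNode(reference_module, _ModuleStackMeta(None))
        for fx_node in [node_0, node_1, node_2, node_3, node_4]:
            root.add_leaf_node(_LeafNode(fx_node))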

    Construction step
    -----------------

    The second step is to build the actual `call_module` node and the sub `fx.GraphModule`.
    This is done recursively from the leaf `_ModuleNode` to the root.

    For example, the first submodule to be built is `GPT.block1.Attention1.MLP`. The pair below
    is generated from `_ModuleNode(MLP)`.

        fx.GraphModule(GPT.block1.Attention1.MLP)
            graph:
                node_2

        new_mlp_node = `call_module[GPT.block1.Attention1.MLP](...)`

    Next, the `GPT.block1.Attention1` submodule is built. The pair below is generated from
    `_ModuleNode(Attention1)`.

        fx.GraphModule(GPT.block1.Attention1)
            graph:
                new_mlp_node
                node_3

        new_attention1_node = `call_module[GPT.block1.Attention1](...)`

    This repeats until every submodule is built and the new modularized `fx.GraphModule`
    is complete.

    Alternatives
    ------------

    The current algorithm adopts a top-down approach. A bottom-up approach is similar.
    In contrast to these two, an alternative flat order approach is also possible, where
    each node is traversed and copied to the corresponding submodule.

    The advantage of the current approach lies in the encapsulation of the fx.GraphModule
    construction for each individual submodule within a single `build_module` method, which
    can be called separately once the analysis phase is completed, making debugging more
    convenient.

    Regarding the construction step, an alternative implementation is to utilize `fx.Interpreter`
    for traversing all the nodes under the flattened root module and copying the nodes
    into their respective submodule under construction. This approach is not adopted because

    1. It uses the flat order approach discussed above. This means one cannot individually
       construct a submodule and examine it while debugging.

    2. The graph execution functionality of `fx.Interpreter` is not necessary for the
       purpose of this pass. Ignoring that, `fx.Interpreter.run` achieves the same effect
       as a for loop over all the nodes.
    """

    @property
    @abc.abstractmethod
    def stack_meta(self) -> _ModuleStackMeta:
        """The module stack meta data associated with this node."""
        ...

    @property
    @abc.abstractmethod
    def stack_trace(self) -> str | None:
        """The stack trace associated with this node."""
        ...


class _ModuleNode(_IRNode):
    """Representing a sequence of fx.Nodes to be formed into a fx.GraphModule.

    This class encapsulates metadata and provides building block methods to construct this
    layered abstraction from a sequence of flat fx.Nodes.

    Attributes:
        - _stack_meta: Metadata of the module stack.
        - _nodes: List of IR nodes in the module.
        - _reference_module: Reference to the root flat fx.GraphModule instance.
    """

    def __init__(
        self, reference_root_module: torch.fx.GraphModule, stack_meta: _ModuleStackMeta
    ):
        self._stack_meta = stack_meta
        self._nodes: list[_IRNode] = []
        self._reference_module = reference_root_module

    @property
    def stack_meta(self) -> _ModuleStackMeta:
        return self._stack_meta

    @property
    def stack_trace(self) -> str | None:
        assert self._nodes
        return self._nodes[0].stack_trace

    def __str__(self) -> str:
        return f"ModuleNode({self._stack_meta})"

    def is_same_module_as(self, node: _IRNode) -> bool:
        """Determines if the provided node pertains to the same module as this node."""
        return self.stack_meta == node.stack_meta

    def is_parent_module_of(self, node: _IRNode) -> bool:
        """Determines if this node represents a parent module of the provided node."""
        return node.stack_meta.is_superset_of(self.stack_meta)

    def add_leaf_node(self, leaf_node: _LeafNode) -> None:
        """Adds a leaf node to the module.

        The leaf node must belong to the same or a child module. This method will recursively
        construct `_ModuleNode` instances based on the stack_meta information of the leaf node.
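
        Example:

            Reusing the hypothetical stacks from [NOTE: Modularize Pass Implementation]:

                self.stack_meta      == [GPT, block1]
                leaf_node.stack_meta == [GPT, block1, Attention1, MLP]

            A child `_ModuleNode` with stack [GPT, block1, Attention1] is created (unless the
            last child already covers it), and `add_leaf_node` recurses into that child.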
491 """ 492 if self.is_same_module_as(leaf_node) or leaf_node.fx_op == "call_module": 493 self._nodes.append(leaf_node) 494 elif leaf_node.fx_op == "placeholder": 495 # Although the original placeholder has empty nn_module_stack, the placeholder lifted 496 # from get_attr nodes by exported program has their original nn_module_stack. Here 497 # we need to avoid them building submodule. 498 self._nodes.append(leaf_node) 499 elif self.is_parent_module_of(leaf_node): 500 # This node belongs in a submodule. 501 # Check if the last node is a submodule and if it is the parent of this node. 502 last_node = self._nodes[-1] if self._nodes else None 503 if isinstance(last_node, _ModuleNode) and ( 504 last_node.is_parent_module_of(leaf_node) 505 or last_node.is_same_module_as(leaf_node) 506 ): 507 # This node belongs to the last_node. 508 last_node.add_leaf_node(leaf_node) 509 else: 510 # Create a new SubmoduleNode for the immediate child module of the current 511 # module. The leaf node may be a grandchild of the current module. 512 # Example: 513 # self.stack_meta = [A, B, C] 514 # leaf_node.stack_meta = [A, B, C, D, E, F] 515 # Create a new ModuleNode with stack_meta = [A, B, C, D] and add leaf_node to it. 516 stack_meta = copy.deepcopy(self.stack_meta) 517 stack_meta.push(leaf_node.stack_meta[len(self.stack_meta)]) 518 last_node = _ModuleNode( 519 self._reference_module, 520 stack_meta, 521 ) 522 self._nodes.append(last_node) 523 last_node.add_leaf_node(leaf_node) 524 else: 525 raise AssertionError( 526 f"Node {leaf_node} ({leaf_node.stack_meta}) does not belong to module " 527 f"{self._stack_meta}." 528 ) 529 530 def fx_nodes(self) -> Generator[torch.fx.Node, None, None]: 531 """Returns an iterator for the sequence of fx nodes this instance holds.""" 532 for node in self._nodes: 533 if isinstance(node, _ModuleNode): 534 yield from node.fx_nodes() 535 else: 536 assert isinstance(node, _LeafNode) 537 yield node.fx_node 538 539 def module_inputs(self) -> Sequence[torch.fx.Node]: 540 """Extract module inputs from the sequence of fx nodes this instance holds. 541 542 All node args that are produced by nodes outside of the module are considered module 543 inputs. The order of returned module inputs is the same as the their use order. 544 545 ### Known limitations 546 547 The original ordering of module inputs is not preserved. There is no meta information 548 to be found from the `fx.GraphModule` that can be used to recover the original ordering. 549 550 Returns: 551 Sequence of module inputs. 552 """ 553 nodes = list(self.fx_nodes()) 554 assert len(nodes) > 0, "Cannot extract module inputs from empty nodes." 555 module_inputs: dict[torch.fx.Node, None] = {} 556 node_set: set[torch.fx.Node] = set(nodes) 557 558 def _extract_arg_if_node_outside_module(arg: Any): 559 if isinstance(arg, torch.fx.Node) and arg not in node_set: 560 module_inputs[arg] = None 561 562 for node in nodes: 563 pytree.tree_map(_extract_arg_if_node_outside_module, node.args) 564 pytree.tree_map(_extract_arg_if_node_outside_module, node.kwargs) 565 return list(module_inputs.keys()) 566 567 def module_outputs(self) -> Sequence[torch.fx.Node]: 568 """Extract module outputs from the sequence of fx nodes this instance holds. 569 570 All nodes that are used by nodes outside of the module are considered module 571 outputs. The order of returned module outputs is the same as the their creation order. 572 573 ### Known limitations 574 575 The original ordering of module outputs is not preserved. 

        ### Known limitations

        The original ordering of module outputs is not preserved. There is no meta information
        to be found from the `fx.GraphModule` that can be used to recover the original ordering.

        Returns:
            Sequence of module outputs.
        """
        nodes = list(self.fx_nodes())
        assert len(nodes) > 0, "Cannot extract module outputs from empty nodes."
        # Need ordered set. Emulate with dict.
        module_outputs: dict[torch.fx.Node, None] = {}
        node_set: set[torch.fx.Node] = set(nodes)

        for node in nodes:
            if any(user not in node_set for user in node.users):
                module_outputs[node] = None
        return list(module_outputs.keys())

    def build_module(self, module_names: dict[str, int]) -> torch.fx.GraphModule:
        """Constructs the fx.GraphModule for this node, registering submodules as necessary.

        Args:
            module_names: A dictionary of module names and their counts. This is used to
                generate unique module names for submodules. This should be an empty
                dictionary when the method is called on a root module.
        """
        module_class_name = self._stack_meta.qualified_module_class_name
        fx_graph = torch.fx.Graph()
        copy_env: dict[torch.fx.Node, torch.fx.Node] = {}

        def _arg_transform(node: torch.fx.Node) -> torch.fx.Node:
            return copy_env[node]

        ref_inputs = self.module_inputs()
        for node in ref_inputs:
            copy_env[node] = fx_graph.placeholder(node.name, node.type)
            copy_env[node].meta = copy.copy(node.meta)

        for ir_node in self._nodes:
            if isinstance(ir_node, _LeafNode):
                fx_node = ir_node.fx_node
                copy_env[fx_node] = fx_graph.node_copy(
                    fx_node, arg_transform=_arg_transform
                )
                continue

            assert isinstance(ir_node, _ModuleNode)
            # Create fx.GraphModule for the child submodule.
            submodule = ir_node.build_module(module_names)
            ref_submodule_inputs = ir_node.module_inputs()
            ref_submodule_outputs = ir_node.module_outputs()
            unique_submodule_name = _get_unique_module_name(
                module_names, ir_node.stack_meta.module_display_name
            )
            # Link the newly generated sub fx.GraphModule with the root reference module.
            # This step is essential to meet the needs of the subsequent fx.GraphModule initialization
            # for the fx.GraphModule being created by this method.
            # The initialization of fx.GraphModule will replicate all necessary attributes from a reference
            # fx.GraphModule for the fx.Graph. While the root reference module possesses all
            # parameters and buffers, it does not include the newly created sub fx.GraphModule.
            # Therefore, it's necessary to register it under the root reference at this stage.
            self._reference_module.add_submodule(unique_submodule_name, submodule)

            # Create the call_module fx.Node.
            submodule_node = fx_graph.call_module(
                unique_submodule_name,
                tuple(_arg_transform(node) for node in ref_submodule_inputs),
            )
            if len(ref_submodule_outputs) > 1:
                # Module node has multiple outputs. Create a 'getitem' node for each output.
                submodule_node.meta["val"] = tuple(
                    ref_output.meta.get("val") for ref_output in ref_submodule_outputs
                )
                for i, ref_output in enumerate(ref_submodule_outputs):
                    getitem_node = fx_graph.call_function(
                        operator.getitem,
                        args=(submodule_node, i),
                        type_expr=ref_output.type,
                    )
                    getitem_node.meta = copy.copy(ref_output.meta)
                    # Make a copy for "nn_module_stack" since the current module will be
                    # popped from the stack for this 'getitem' node.
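                    # (The 'getitem' node re-exposes one element of the submodule's tuple
                    # output to the parent graph, so it belongs to the parent module's
                    # stack; hence the popitem() below.)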
                    getitem_node.meta["nn_module_stack"] = copy.copy(
                        ref_output.meta["nn_module_stack"]
                    )
                    # The node is associated with the parent module.
                    getitem_node.meta["nn_module_stack"].popitem()
                    copy_env[ref_output] = getitem_node
            else:
                # Module node has a single output. Use the module node directly.
                copy_env[ref_submodule_outputs[0]] = submodule_node
                submodule_node.meta = copy.copy(ref_submodule_outputs[0].meta)

            # Update meta for the new call_module node.
            if (stack_trace := ir_node.stack_trace) is not None:
                submodule_node.meta["stack_trace"] = stack_trace
            raw_module_stack_meta = ir_node.stack_meta.raw_meta
            assert raw_module_stack_meta is not None
            submodule_node.meta["nn_module_stack"] = copy.copy(raw_module_stack_meta)
            # The node is associated with the parent module.
            submodule_node.meta["nn_module_stack"].popitem()

        new_nodes = fx_graph.nodes
        # Skip if the last node is already 'output'. This is the case for the root module.
        # Otherwise create an 'output' node for the inferred outputs.
        if next(iter(reversed(new_nodes))).op != "output":
            ref_submodule_outputs = self.module_outputs()
            new_outputs = [copy_env[ref_output] for ref_output in ref_submodule_outputs]
            fx_graph.output(new_outputs[0] if len(new_outputs) == 1 else new_outputs)

        graph_module = torch.fx.GraphModule(
            self._reference_module, fx_graph, module_class_name
        )
        if (module_class := self._stack_meta.module_class) is not None:
            graph_module.meta["onnx"] = _pass.GraphModuleOnnxMeta(
                _pass.PackageInfo.from_python_class(module_class)
            )
        return graph_module


class _LeafNode(_IRNode):
    """Representing a single fx.Node."""

    def __init__(self, node: torch.fx.Node, is_exported_program: bool = False):
        self._node = node
        self._stack_meta = _module_stack_meta_from_node(
            node, is_exported_program=is_exported_program
        )

    @property
    def fx_op(self) -> str:
        """Syntax sugar for self.fx_node.op."""
        return self._node.op

    @property
    def fx_node(self) -> torch.fx.Node:
        """Returns the fx.Node this instance represents."""
        return self._node

    @property
    def stack_meta(self) -> _ModuleStackMeta:
        """Returns the module stack meta data associated with this node."""
        return self._stack_meta

    @property
    def stack_trace(self) -> str | None:
        """Returns the stack trace associated with this node."""
        return self.fx_node.meta.get("stack_trace")

    def __str__(self) -> str:
        return f"LeafNode({self._node})"


class Modularize(_pass.Transform):
    """Transforms a flattened `fx.GraphModule` into a modular structure.

    In the flattened `fx.GraphModule`, each `nn.Module` forward call has been traced as
    a sequence of `fx.Node`s. All these `fx.Node`s are flattened and reside in the same
    `fx.GraphModule`. The `fx.GraphModule` could come from `torch.export.ExportedProgram`
    or be directly generated by `torch._dynamo.export` from a torch.nn.Module.

    This pass generates a new `fx.GraphModule`. It groups the flattened `fx.Node`s that belong
    to the same `nn.Module` forward call into a sub `fx.GraphModule`. It then replaces the
    sequence of flattened `fx.Node`s with a single `call_module` node, which is linked with
    the sub `fx.GraphModule` by `node.target`. The sub `fx.GraphModule` is registered as a
    submodule of the new `fx.GraphModule`.
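
    An illustrative sketch of the effect (node and module names are hypothetical): a
    flattened graph

        embedding -> relu -> linear

    where `embedding` and `relu` were traced from one `nn.Module` forward call becomes

        call_module[layer_1] -> linear

    with the `embedding -> relu` sequence moved into the sub `fx.GraphModule` registered
    under the generated name `layer_1`.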

    The process is done based on information from the `nn_module_stack` metadata of each node, i.e.
    `node.meta["nn_module_stack"]`. For more implementation details, see [NOTE: Modularize Pass Implementation].

    An fx submodule under this context can typically be interpreted in three different ways:

        1. As an embodiment of an nn.Module class, which is considered stateless.
        Its execution path can vary depending on the configuration of module initialization,
        which should also be part of the inputs.

        2. As a representation of an nn.Module instance. It maintains the state initialized in the module.
        The execution path can vary based on actual input data.

        3. As a captured call of an nn.Module instance, where the execution path
        is set.

    The generality decreases along this list. Within the scope of this function, the pass
    creates fx submodules according to the third interpretation.

    The first interpretation is the most general case. It requires complex analysis and additional
    metadata and code information to construct its general form. Consider an example nn.Module
    that generates arbitrary submodules based on an initialization configuration file. It's impractical
    to extract this logic for the generated fx submodule to function with arbitrary configurations.

    The second interpretation demands less analysis and is sturdier than the
    first. In most use cases, it's equivalent to the third. It only differs in exceptional situations
    where a complex nn.Module instance is called multiple times, each with a different set of inputs
    leading to a unique execution branching path.

    The third interpretation is the most specific scenario. It necessitates the minimum
    analysis and creates the most stable representation. The drawback is that it
    generates more redundancy than the other two methods. If needed, a subsequent post-processing
    pass can be applied to consolidate completely identical functions and reduce duplication.

    ### Known constraints
    Two successive calls to the same module instance will be conflated; they are indistinguishable.
    This is due to limitations of the current fx metadata "nn_module_stack".

    [NOTE: Modularize pass ordering]
    This pass groups fx nodes into subgraphs that reside within `call_module` fx nodes.
    Other fx passes (including some outside the exporter) might not recognize `call_module`.
    They may assume that all nodes are flattened. Hence it is recommended to invoke this pass
    as the last fx pass before ONNX export. If not for this consideration, this operation could
    potentially be relocated anywhere earlier in the pipeline.
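
    When the input graph comes from `torch.export` (an `ExportedProgram`'s graph module),
    pass `is_exported_program=True` so the extra call-stack layer added by
    `torch.export.export` is skipped; an illustrative sketch:

        Modularize(
            diagnostic_context, exported_program.graph_module, is_exported_program=True
        ).run()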

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> from torch.onnx._internal.fx import passes
        >>> from torch.onnx._internal.diagnostics import infra
        >>>
        >>> class CustomModule(torch.nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.embedding = torch.nn.Embedding(10, 32)
        >>>         self.relu = torch.nn.ReLU()
        >>>
        >>>     def forward(self, x):
        >>>         out = self.embedding(x)
        >>>         out = self.relu(out)
        >>>         return out
        >>>
        >>> class TestModule(torch.nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.layer = CustomModule()
        >>>         self.linear = torch.nn.Linear(32, 10)
        >>>
        >>>     def forward(self, x):
        >>>         out = self.layer(x)
        >>>         out = self.linear(out)
        >>>         return out
        >>>
        >>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(
        ...     torch.tensor([0, 1, 2])
        ... )
        >>> gm.print_readable()

        >>> gm = passes.Modularize(infra.DiagnosticContext("test_context", "1.0"), gm).run()
        >>> gm.print_readable()

    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        is_exported_program: bool = False,
    ):
        super().__init__(diagnostic_context, module)
        self.module = module
        self.is_exported_program = is_exported_program

    def _run(self) -> torch.fx.GraphModule:
        # DCE to remove unused nodes.
        # If a submodule is unused, it is hard to analyze which nodes constitute the submodule
        # outputs. But since it is unused, we can just remove it.
        self.module.graph.eliminate_dead_code()

        reference_module = torch.fx.GraphModule(self.module, self.module.graph)
        root_module_node = _ModuleNode(
            reference_module,
            _ModuleStackMeta(
                nn_module_stack_meta=None, is_exported_program=self.is_exported_program
            ),
        )
        for fx_node in self.module.graph.nodes:
            root_module_node.add_leaf_node(
                _LeafNode(fx_node, is_exported_program=self.is_exported_program)
            )
        return root_module_node.build_module({})