# mypy: allow-untyped-defs
from __future__ import annotations

import abc
import collections
import copy
import operator
from typing import Any, Dict, Final, Generator, Iterator, Sequence, Tuple

import torch
import torch.fx
from torch.onnx._internal.fx import _pass, diagnostics
from torch.utils import _pytree as pytree


_FX_TRACER_NN_MODULE_META_TYPE = Tuple[str, type]
"""Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer."""
_FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict
"""Legacy type of `node.meta["nn_module_stack"]` produced by FX symbolic tracer."""

_DYNAMO_NN_MODULE_META_TYPE = Tuple[str, Tuple[str, type]]
"""Type of item from `node.meta["nn_module_stack"].items()` produced by FX dynamo tracer."""
_DYNAMO_NN_MODULE_STACK_META_TYPE = Dict[str, _DYNAMO_NN_MODULE_META_TYPE]
"""Type of `node.meta["nn_module_stack"]` produced by FX dynamo tracer."""
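
# For illustration only (module names and classes below are hypothetical examples),
# a single item of `node.meta["nn_module_stack"].items()` looks like:
#   FX symbolic tracer: ("h.1.attn", GPT2Attention)
#   Dynamo tracer:      ("L__self___h_1_attn", ("L['self'].h[1].attn", GPT2Attention))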


class _ModuleMeta:
    """Meta information about a module.

    This class is used to represent the module information in a more structured way.
    It parses raw module information from a single item from
    `node.meta["nn_module_stack"].items()`.

    See the uses of `from_raw_meta`, `from_fx_tracer_produced_raw_meta`, and
    `from_dynamo_produced_raw_meta` for how to create an instance.

    Attributes:
        _module_class: The class of the module. E.g. `torch.nn.modules.sparse.Embedding`.
        _module_name: The name of the module. E.g. `L__self___h_1_mlp_c_proj`.
        _raw_meta: The raw meta '(module_name, node.meta["nn_module_stack"][module_name])'.
    """

    _module_class: Final[type | str | None]  # type: ignore[misc]
    _module_name: Final[str]  # type: ignore[misc]
    _raw_meta: Final[tuple[Any, Any]]  # type: ignore[misc]

    def __init__(
        self,
        module_name: str,
        module_class: type | str | None,
        raw_meta: tuple[Any, Any],
    ):
        self._module_name = module_name
        self._module_class = module_class
        self._raw_meta = raw_meta

    @property
    def module_display_name(self) -> str:
        """The display name of the module.

        E.g. `h_1_mlp_c_proj`.
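
        A minimal sketch (the module name below is a hypothetical example):

            >>> meta = _ModuleMeta("L__self___h_1_mlp_c_proj", None, ("", None))
            >>> meta.module_display_name
            'h_1_mlp_c_proj'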
        """
        # E.g., from 'L__self___h_1_mlp_c_proj' to 'h_1_mlp_c_proj'.
        name = self.module_name
        if name.startswith("L__self___"):
            name = name[len("L__self___") :]
        return name

    @property
    def qualified_module_class_name(self) -> str:
        """Qualified name of the module class.

        E.g. `torch_nn_modules_sparse_Embedding`.
        """
        if self._module_class is None:
            return ""
        mod_cls = self._module_class
        if isinstance(mod_cls, type):
            mod_cls = mod_cls.__module__ + "." + mod_cls.__qualname__
        return mod_cls.replace(".", "_")

    @property
    def module_class_name(self) -> str:
        """Name of the module class.

        E.g. `Embedding`.
        """
        if self._module_class is None:
            return ""
        if isinstance(self._module_class, type):
            return self._module_class.__name__
        return self._module_class

    @property
    def module_name(self) -> str:
        """Name of the module.

        E.g. `L__self___h_1_mlp_c_proj`.
        """
        return self._module_name

    @property
    def raw_meta(self) -> tuple[Any, Any]:
        """Returns the raw module meta data.

        I.e. (module_name, node.meta['nn_module_stack'][module_name]).
        """
        return self._raw_meta

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, _ModuleMeta):
            return False
        return (
            self._module_name == __value._module_name
            and self._module_class == __value._module_class
        )

    def __hash__(self) -> int:
        return hash((self._module_name, self._module_class))

    def __repr__(self) -> str:
        return f"ModuleMeta(name={self._module_name}, class={self._module_class})"

    @classmethod
    def create_root(cls) -> _ModuleMeta:
        """Create an empty module meta representing the root module."""
        return _ModuleMeta("", None, ("", None))

    @classmethod
    def from_fx_tracer_produced_raw_meta(
        cls, raw_meta: _FX_TRACER_NN_MODULE_META_TYPE
    ) -> _ModuleMeta:
        """Create a module meta from raw meta produced by FX symbolic tracer."""
        module_name, module_class = raw_meta
        return _ModuleMeta(module_name, module_class, raw_meta)

    @classmethod
    def from_dynamo_produced_raw_meta(
        cls, raw_meta: _DYNAMO_NN_MODULE_META_TYPE
    ) -> _ModuleMeta:
        """Create a module meta from raw meta produced by FX dynamo tracer."""
        module_name, (qualified_name, module_class) = raw_meta
        return _ModuleMeta(module_name, module_class, raw_meta)

    @classmethod
    def from_raw_meta(
        cls,
        raw_meta: _FX_TRACER_NN_MODULE_META_TYPE | _DYNAMO_NN_MODULE_META_TYPE,
    ) -> _ModuleMeta:
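        """Create a module meta from a raw meta item of either supported format.

        A minimal sketch of the dispatch (module names below are hypothetical):

            >>> _ModuleMeta.from_raw_meta(("h.1", torch.nn.Linear)).module_name
            'h.1'
            >>> _ModuleMeta.from_raw_meta(
            ...     ("L__self___h_1", ("L['self'].h[1]", torch.nn.Linear))
            ... ).module_name
            'L__self___h_1'
        """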
        if (
            isinstance(raw_meta, tuple)
            and len(raw_meta) == 2
            and isinstance(raw_meta[1], type)
        ):
            # Equivalent to `isinstance(raw_meta, _FX_TRACER_NN_MODULE_META_TYPE)`.
            return _ModuleMeta.from_fx_tracer_produced_raw_meta(raw_meta)
        if (
            isinstance(raw_meta, tuple)
            and len(raw_meta) == 2
            and isinstance(raw_meta[1], tuple)
        ):
            # Equivalent to `isinstance(raw_meta, _DYNAMO_NN_MODULE_META_TYPE)`.
            return _ModuleMeta.from_dynamo_produced_raw_meta(raw_meta)
        raise TypeError(
            f"Unknown type of raw meta item from node.meta['nn_module_stack'].items(): {type(raw_meta)}"
        )


class _ModuleStackMeta:
    """Meta information about the module call stack.

    This class is used to represent the module call stack information in a more
    structured way. It parses raw module stack information from `node.meta["nn_module_stack"]`.

    Example of raw module stack information:

        If produced by dynamo:

            {
                'L__self___h_1': (
                    "L['self'].h[1]",
                    <class 'transformers.models.gpt2.modeling_gpt2.GPT2Block'>
                ),
                'L__self___h_1_attn': (
                    "L['self'].h[1].attn",
                    <class 'transformers.models.gpt2.modeling_gpt2.GPT2Attention'>
                )
            }

        If produced by fx.symbolic_trace:

            {
                'h.1': <class 'transformers.models.gpt2.modeling_gpt2.GPT2Block'>,
                'h.1.attn': <class 'transformers.models.gpt2.modeling_gpt2.GPT2Attention'>
            }
    """

    _module_stack: Final[list[_ModuleMeta]]  # type: ignore[misc]

    def __init__(
        self,
        nn_module_stack_meta: _FX_TRACER_NN_MODULE_STACK_META_TYPE
        | _DYNAMO_NN_MODULE_STACK_META_TYPE
        | None,
        is_exported_program: bool = True,
    ):
        self._module_stack = []
        if nn_module_stack_meta is None:
            return
        raw_meta = copy.copy(nn_module_stack_meta)
        for item in raw_meta.items():
            # If produced by torch.export.export, there is another call stack layer
            # that we need to skip
            if is_exported_program:
                is_exported_program = False
                continue
            self.push(_ModuleMeta.from_raw_meta(item))  # type: ignore[arg-type]

    def __len__(self) -> int:
        return len(self._module_stack)

    def __getitem__(self, index: int) -> _ModuleMeta:
        return self._module_stack[index]

    def __iter__(self) -> Iterator[_ModuleMeta]:
        return iter(self._module_stack)

    def is_empty_or_root(self) -> bool:
        return len(self._module_stack) == 0

    def top(self) -> _ModuleMeta:
        """Returns the top module meta in the stack, i.e., the meta for the leaf module.

        Example:

            Consider the following module stack:

            stack = [GPT, block1, Attention_1, MLP]

            stack.top() == MLP
        """
        if self.is_empty_or_root():
            return _ModuleMeta.create_root()
        return self._module_stack[-1]

    def is_superset_of(
        self,
        module_stack: _ModuleStackMeta,
    ) -> bool:
        """Determines if self is a superset of the provided module stack.

        That is, self includes all elements of the provided module stack, plus additional
        elements on top. If self is empty or root, this method always returns False.

        Example:

            Consider the following module stack:

            stack_1 = [GPT, block1, Attention_1, MLP]
            stack_2 = [GPT, block1]

            stack_1.is_superset_of(stack_2) == True
            stack_2.is_superset_of(stack_1) == False

            stack_3 = [GPT, block2, Attention_1]

            stack_1.is_superset_of(stack_3) == False
            stack_3.is_superset_of(stack_1) == False
        """
        if self.is_empty_or_root():
            return False

        if module_stack.is_empty_or_root():
            return True

        if len(self) <= len(module_stack):
            return False

        for i, parent_key in enumerate(module_stack):
            if self[i] != parent_key:
                return False

        return True

    def push(self, module_meta: _ModuleMeta) -> None:
        """Pushes a module meta to the stack."""
        self._module_stack.append(module_meta)

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, _ModuleStackMeta):
            return False
        return self._module_stack == __value._module_stack

    @property
    def raw_meta(self) -> dict[str, tuple[str, type]] | None:
        """Returns the raw module stack meta data, i.e. node.meta['nn_module_stack']."""
        return {
            module_meta.raw_meta[0]: module_meta.raw_meta[1]
            for module_meta in self._module_stack
        }

    def __repr__(self) -> str:
        return f"ModuleStackMeta({self._module_stack})"

    @property
    def module_display_name(self) -> str:
        """Returns the module display name of the top module."""
        return self.top().module_display_name

    @property
    def qualified_module_class_name(self) -> str:
        """Returns the qualified module class name of the top module."""
        return self.top().qualified_module_class_name

    @property
    def module_class(self) -> type | str | None:
        """Returns the module class of the top module."""
        return self.top()._module_class


def _module_stack_meta_from_node(
    node: torch.fx.Node, is_exported_program: bool = False
) -> _ModuleStackMeta:
    return _ModuleStackMeta(
        node.meta.get("nn_module_stack"), is_exported_program=is_exported_program
    )


def _get_unique_module_name(module_names: dict[str, int], module_name: str) -> str:
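    """Return a unique submodule name by appending a per-name occurrence count.

    A minimal usage sketch (the counts dict is mutated in place):

        >>> counts: dict[str, int] = {}
        >>> _get_unique_module_name(counts, "mlp")
        'mlp_1'
        >>> _get_unique_module_name(counts, "mlp")
        'mlp_2'
    """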
    module_names.setdefault(module_name, 0)
    module_names[module_name] += 1
    return f"{module_name}_{module_names[module_name]}"


class _IRNode(abc.ABC):
    """Base class for IR nodes.

    IR nodes are used for Modularize pass only. They add a layer of abstraction on top of
    torch.fx.Node.

    [NOTE: Modularize Pass Implementation]
    The main job of the pass is to group `fx.Node`s that belong to the same `nn.Module`
    forward call, and then create `call_module` node and sub `fx.GraphModule` from them.
    Each `fx.Node` possesses an `nn_module_stack` meta data that contains information
    about the module call stack. See `_ModuleStackMeta` for examples.

    Analysis step
    -------------

    Each module call is identified by a set of base stack layers. For each module call,
    the pass creates a `_ModuleNode` and groups the sequence of nodes that share the
    same base stack layers.

    For example,

        stack_of_node_0 = [GPT, block0]
        stack_of_node_1 = [GPT, block1]
        stack_of_node_2 = [GPT, block1, Attention1, MLP]
        stack_of_node_3 = [GPT, block1, Attention1]
        stack_of_node_4 = [GPT, block2]

    All nodes belong to the `GPT` module call, since they share the base stack layers [GPT].
    [node_1, node_2, node_3] are grouped for `GPT.block1`, because they share the base
    stack layers [GPT, block1]. And [node_2, node_3] for `GPT.block1.Attention1`, [node_0]
    for `GPT.block0`, and [node_4] for `GPT.block2` respectively.

    After the analysis step, a hierarchical representation is generated.

    For the above example, the representation is:

        _ModuleNode(GPT)
            _ModuleNode(block0)
                _LeafNode(node_0)
            _ModuleNode(block1)
                _LeafNode(node_1)
                _ModuleNode(Attention1)
                    _ModuleNode(MLP)
                        _LeafNode(node_2)
                    _LeafNode(node_3)
            _ModuleNode(block2)
                _LeafNode(node_4)

    Construction step
    -----------------

    The second step is to build the actual `call_module` node and the sub `fx.GraphModule`.
    This is done recursively from the leaf `_ModuleNode` to the root.

    For example, the first submodule to be built is `GPT.block1.Attention1.MLP`. The pair
    below is generated from `_ModuleNode(MLP)`.

        fx.GraphModule(GPT.block1.Attention1.MLP)
            graph:
                node_2

        new_mlp_node = `call_module[GPT.block1.Attention1.MLP](...)`

    Next, the `GPT.block1.Attention1` submodule is built. The following is generated from
    `_ModuleNode(Attention1)`.

        fx.GraphModule(GPT.block1.Attention1)
            graph:
                new_mlp_node
                node_3

        new_attention1_node = `call_module[GPT.block1.Attention1](...)`

    This continues until every submodule is built, yielding the new modularized
    `fx.GraphModule`.

    Alternatives
    ------------

    The current algorithm adopts a top down approach. A bottom up approach is similar.
    In contrast to these two, an alternative flat order approach is also possible, where
    each node is traversed and copied to the corresponding submodule.

    The advantage of the current approach lies in the encapsulation of the fx.GraphModule
    construction for each individual submodule within a single `build_module` method, which
    can be called separately once the analysis phase is completed, making debugging more
    convenient.

    Regarding the construction step, an alternative implementation is to utilize `fx.Interpreter`
    for traversing all the nodes under the flattened root module and copying the nodes
    into their respective submodule under construction. This approach is not adopted because

        1. It uses the flat order approach discussed above. This means one cannot individually
    construct a submodule and examine it while debugging.

        2. The graph execution functionality of `fx.Interpreter` is not necessary for the
    purpose of this pass. Ignoring that, `fx.Interpreter.run` achieves the same effect
    as a for loop over all the nodes.
    """

    @property
    @abc.abstractmethod
    def stack_meta(self) -> _ModuleStackMeta:
        """The module stack meta data associated with this node."""
        ...

    @property
    @abc.abstractmethod
    def stack_trace(self) -> str | None:
        """The stack trace associated with this node."""
        ...


class _ModuleNode(_IRNode):
    """Representing a sequence of fx.Nodes to be formed into a fx.GraphModule.

    This class encapsulates metadata and provides building block methods to construct this
    layered abstraction from a sequence of flat fx.Nodes.

    Attributes:
    - _stack_meta: Metadata of the module stack.
    - _nodes: List of IR nodes in the module.
    - _reference_root_module: Reference to the root flat fx.GraphModule instance.
    """

    def __init__(
        self, reference_root_module: torch.fx.GraphModule, stack_meta: _ModuleStackMeta
    ):
        self._stack_meta = stack_meta
        self._nodes: list[_IRNode] = []
        self._reference_module = reference_root_module

    @property
    def stack_meta(self) -> _ModuleStackMeta:
        return self._stack_meta

    @property
    def stack_trace(self) -> str | None:
        assert self._nodes
        return self._nodes[0].stack_trace

    def __str__(self) -> str:
        return f"ModuleNode({self._stack_meta})"

    def is_same_module_as(self, node: _IRNode) -> bool:
        """Determines if the provided node pertains to the same module as this node."""
        return self.stack_meta == node.stack_meta

    def is_parent_module_of(self, node: _IRNode) -> bool:
        """Determines if this node represents a parent module of the provided node."""
        return node.stack_meta.is_superset_of(self.stack_meta)

    def add_leaf_node(self, leaf_node: _LeafNode) -> None:
        """Adds a leaf node to the module.

        The leaf node must belong to the same or a child module. This method will recursively
        construct _ModuleNode instances based on the stack_meta information of the leaf node.
        """
        if self.is_same_module_as(leaf_node) or leaf_node.fx_op == "call_module":
            self._nodes.append(leaf_node)
        elif leaf_node.fx_op == "placeholder":
            # Although original placeholders have an empty nn_module_stack, placeholders
            # lifted from get_attr nodes by the exported program keep their original
            # nn_module_stack. Here we need to avoid building submodules for them.
            self._nodes.append(leaf_node)
        elif self.is_parent_module_of(leaf_node):
            # This node belongs in a submodule.
            # Check if the last node is a submodule and if it is the parent of this node.
            last_node = self._nodes[-1] if self._nodes else None
            if isinstance(last_node, _ModuleNode) and (
                last_node.is_parent_module_of(leaf_node)
                or last_node.is_same_module_as(leaf_node)
            ):
                # This node belongs to the last_node.
                last_node.add_leaf_node(leaf_node)
            else:
                # Create a new SubmoduleNode for the immediate child module of the current
                # module. The leaf node may be a grandchild of the current module.
                # Example:
                #   self.stack_meta = [A, B, C]
                #   leaf_node.stack_meta = [A, B, C, D, E, F]
                # Create a new ModuleNode with stack_meta = [A, B, C, D] and add leaf_node to it.
                stack_meta = copy.deepcopy(self.stack_meta)
                stack_meta.push(leaf_node.stack_meta[len(self.stack_meta)])
                last_node = _ModuleNode(
                    self._reference_module,
                    stack_meta,
                )
                self._nodes.append(last_node)
                last_node.add_leaf_node(leaf_node)
        else:
            raise AssertionError(
                f"Node {leaf_node} ({leaf_node.stack_meta}) does not belong to module "
                f"{self._stack_meta}."
            )

    def fx_nodes(self) -> Generator[torch.fx.Node, None, None]:
        """Returns an iterator for the sequence of fx nodes this instance holds."""
        for node in self._nodes:
            if isinstance(node, _ModuleNode):
                yield from node.fx_nodes()
            else:
                assert isinstance(node, _LeafNode)
                yield node.fx_node

    def module_inputs(self) -> Sequence[torch.fx.Node]:
        """Extract module inputs from the sequence of fx nodes this instance holds.

        All node args that are produced by nodes outside of the module are considered module
        inputs. The order of returned module inputs is the same as their use order.

        ### Known limitations

        The original ordering of module inputs is not preserved. There is no meta information
        to be found from the `fx.GraphModule` that can be used to recover the original ordering.

        Returns:
            Sequence of module inputs.
        """
        nodes = list(self.fx_nodes())
        assert len(nodes) > 0, "Cannot extract module inputs from empty nodes."
        module_inputs: dict[torch.fx.Node, None] = {}
        node_set: set[torch.fx.Node] = set(nodes)

        def _extract_arg_if_node_outside_module(arg: Any):
            if isinstance(arg, torch.fx.Node) and arg not in node_set:
                module_inputs[arg] = None

        for node in nodes:
            pytree.tree_map(_extract_arg_if_node_outside_module, node.args)
            pytree.tree_map(_extract_arg_if_node_outside_module, node.kwargs)
        return list(module_inputs.keys())

    def module_outputs(self) -> Sequence[torch.fx.Node]:
        """Extract module outputs from the sequence of fx nodes this instance holds.

        All nodes that are used by nodes outside of the module are considered module
        outputs. The order of returned module outputs is the same as their creation order.

        ### Known limitations

        The original ordering of module outputs is not preserved. There is no meta information
        to be found from the `fx.GraphModule` that can be used to recover the original ordering.

        Returns:
            Sequence of module outputs.
        """
        nodes = list(self.fx_nodes())
        assert len(nodes) > 0, "Cannot extract module outputs from empty nodes."
        # Need ordered set. Emulate with dict.
        module_outputs: dict[torch.fx.Node, None] = {}
        node_set: set[torch.fx.Node] = set(nodes)

        for node in nodes:
            if any(user not in node_set for user in node.users):
                module_outputs[node] = None
        return list(module_outputs.keys())

    def build_module(self, module_names: dict[str, int]) -> torch.fx.GraphModule:
        """
        Constructs the fx.GraphModule for this node, registering submodules as necessary.

        Args:
            module_names: A dictionary of module names and their counts. This is used to
                generate unique module names for submodules. This should be an empty
                dictionary when the method is called on a root module.
        """
        module_class_name = self._stack_meta.qualified_module_class_name
        fx_graph = torch.fx.Graph()
        copy_env: dict[torch.fx.Node, torch.fx.Node] = {}

        def _arg_transform(node: torch.fx.Node) -> torch.fx.Node:
            return copy_env[node]

        ref_inputs = self.module_inputs()
        for node in ref_inputs:
            copy_env[node] = fx_graph.placeholder(node.name, node.type)
            copy_env[node].meta = copy.copy(node.meta)

        for ir_node in self._nodes:
            if isinstance(ir_node, _LeafNode):
                fx_node = ir_node.fx_node
                copy_env[fx_node] = fx_graph.node_copy(
                    fx_node, arg_transform=_arg_transform
                )
                continue

            assert isinstance(ir_node, _ModuleNode)
            # Create fx.GraphModule for child submodule.
            submodule = ir_node.build_module(module_names)
            ref_submodule_inputs = ir_node.module_inputs()
            ref_submodule_outputs = ir_node.module_outputs()
            unique_submodule_name = _get_unique_module_name(
                module_names, ir_node.stack_meta.module_display_name
            )
            # Link the newly generated sub fx.GraphModule with the root reference module.
            # This is required by the subsequent fx.GraphModule initialization for the
            # fx.GraphModule being created by this method: that initialization replicates
            # all necessary attributes for the fx.Graph from a reference fx.GraphModule.
            # While the root reference module possesses all parameters and buffers, it does
            # not include the newly created sub fx.GraphModule, so it must be registered
            # under the root reference at this stage.
            self._reference_module.add_submodule(unique_submodule_name, submodule)

            # create call_module fx.Node
            submodule_node = fx_graph.call_module(
                unique_submodule_name,
                tuple(_arg_transform(node) for node in ref_submodule_inputs),
            )
            if len(ref_submodule_outputs) > 1:
                # Module node has multiple outputs. Create a 'getitem' node for each output.
                submodule_node.meta["val"] = tuple(
                    ref_output.meta.get("val") for ref_output in ref_submodule_outputs
                )
                for i, ref_output in enumerate(ref_submodule_outputs):
                    getitem_node = fx_graph.call_function(
                        operator.getitem,
                        args=(submodule_node, i),
                        type_expr=ref_output.type,
                    )
                    getitem_node.meta = copy.copy(ref_output.meta)
                    # Make a copy for "nn_module_stack" since the current module will be
                    # popped from the stack for this 'getitem' node.
                    getitem_node.meta["nn_module_stack"] = copy.copy(
                        ref_output.meta["nn_module_stack"]
                    )
                    # The node is associated with the parent module.
                    getitem_node.meta["nn_module_stack"].popitem()
                    copy_env[ref_output] = getitem_node
            else:
                # Module node has a single output. Use the module node directly.
                copy_env[ref_submodule_outputs[0]] = submodule_node
                submodule_node.meta = copy.copy(ref_submodule_outputs[0].meta)

            # Update meta for new call_module node.
            if (stack_trace := ir_node.stack_trace) is not None:
                submodule_node.meta["stack_trace"] = stack_trace
            raw_module_stack_meta = ir_node.stack_meta.raw_meta
            assert raw_module_stack_meta is not None
            submodule_node.meta["nn_module_stack"] = copy.copy(raw_module_stack_meta)
            # The node is associated with the parent module.
            submodule_node.meta["nn_module_stack"].popitem()

        new_nodes = fx_graph.nodes
        # Skip if the last node is already 'output'. This is the case for the root module.
        # Otherwise create an 'output' node for the inferred outputs.
        if next(iter(reversed(new_nodes))).op != "output":
            ref_submodule_outputs = self.module_outputs()
            new_outputs = [
                copy_env[ref_output] for ref_output in ref_submodule_outputs
            ]
            fx_graph.output(
                new_outputs[0] if len(new_outputs) == 1 else new_outputs
            )

        graph_module = torch.fx.GraphModule(
            self._reference_module, fx_graph, module_class_name
        )
        if (module_class := self._stack_meta.module_class) is not None:
            graph_module.meta["onnx"] = _pass.GraphModuleOnnxMeta(
                _pass.PackageInfo.from_python_class(module_class)
            )
        return graph_module


class _LeafNode(_IRNode):
    """Representing a single fx.Node."""

    def __init__(self, node: torch.fx.Node, is_exported_program: bool = False):
        self._node = node
        self._stack_meta = _module_stack_meta_from_node(
            node, is_exported_program=is_exported_program
        )

    @property
    def fx_op(self) -> str:
        """Syntax sugar for self.fx_node.op."""
        return self._node.op

    @property
    def fx_node(self) -> torch.fx.Node:
        """Returns the fx.Node this instance represents."""
        return self._node

    @property
    def stack_meta(self) -> _ModuleStackMeta:
        """Returns the module stack meta data associated with this node."""
        return self._stack_meta

    @property
    def stack_trace(self) -> str | None:
        """Returns the stack trace associated with this node."""
        return self.fx_node.meta.get("stack_trace")

    def __str__(self) -> str:
        return f"LeafNode({self._node})"


class Modularize(_pass.Transform):
    """Transforms a flattened `fx.GraphModule` into a modular structure.

    In the flattened `fx.GraphModule`, each `nn.Module` forward call has been traced as
    a sequence of `fx.Node`s. All these `fx.Node`s are flattened and reside in the same
    `fx.GraphModule`. The `fx.GraphModule` may come from `torch.export.ExportedProgram` or
    be generated directly by `torch._dynamo.export` on a `torch.nn.Module`.

    This pass generates a new `fx.GraphModule`. It groups the flattened `fx.Node`s that belong
    to the same `nn.Module` forward call into a sub `fx.GraphModule`. It then replaces the
    sequence of flattened `fx.Node`s with a single `call_module` node, which is linked with
    the sub `fx.GraphModule` by `node.target`. The sub `fx.GraphModule` is registered as a
    submodule of the new `fx.GraphModule`.

    The process is done based on information from the `nn_module_stack` metadata of each node, i.e.
    `node.meta["nn_module_stack"]`. For more implementation details, see [NOTE: Modularize Pass Implementation].

    An fx submodule under this context can typically be interpreted in three different ways:

        1. As an embodiment of an nn.Module class, which is considered stateless.
        Its execution path can vary depending on the configuration of module initialization,
        which should also be part of the inputs.

        2. As a representation of an nn.Module instance. It maintains the state initialized in the module.
        The execution path can vary based on actual input data.

        3. As a captured call of an nn.Module instance, where the execution path
        is set.

    The generality decreases along this list. Within the scope of this function, the pass
    creates fx submodules according to the third interpretation.

    The first interpretation is the most general case. It requires complex analysis and additional
    metadata and code information to construct its general form. Consider an example nn.Module
    that generates arbitrary submodules based on an initialization configuration file. It's impractical
    to extract this logic so that the generated fx submodule can function with an arbitrary configuration.

    The second interpretation demands less analysis and is sturdier than the
    first. In most use cases, it's equivalent to the third. It only differs in exceptional situations
    where a complex nn.Module instance is called multiple times, each with a different set of inputs
    leading to a unique execution branching path.

    The third interpretation is the most specific scenario. It necessitates the minimum
    analysis and creates the most stable representation. The drawback is that it
    generates more redundancy than the other two methods. If needed, a subsequent post-processing
    pass can be applied to consolidate completely identical functions and reduce duplication.

    ### Known constraints
    Two successive calls to the same module instance will be conflated. They are indistinguishable.
    This is due to limitations of the current fx metadata "nn_module_stack".

    [NOTE: Modularize pass ordering]
    This pass groups fx nodes into subgraphs that reside within the `call_module` fx node.
    Other fx passes (including some outside the exporter) might not recognize `call_module`.
    They may assume that all nodes are flattened. Hence it is recommended to invoke this pass
    as the last fx pass before ONNX export. If not for this consideration, this operation could
    potentially be relocated anywhere earlier in the pipeline.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> from torch.onnx._internal.fx import passes
        >>> from torch.onnx._internal.diagnostics import infra
        >>>
        >>> class CustomModule(torch.nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.embedding = torch.nn.Embedding(10, 32)
        >>>         self.relu = torch.nn.ReLU()
        >>>
        >>>     def forward(self, x):
        >>>         out = self.embedding(x)
        >>>         out = self.relu(out)
        >>>         return out
        >>>
        >>> class TestModule(torch.nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.layer = CustomModule()
        >>>         self.linear = torch.nn.Linear(32, 10)
        >>>
        >>>     def forward(self, x):
        >>>         out = self.layer(x)
        >>>         out = self.linear(out)
        >>>         return out
        >>>
        >>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(
        ...     torch.tensor([0, 1, 2])
        ... )
        >>> gm.print_readable()

        >>> gm = passes.Modularize(infra.DiagnosticContext("test_context", "1.0"), gm).run()
        >>> gm.print_readable()

    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        is_exported_program: bool = False,
    ):
        super().__init__(diagnostic_context, module)
        self.module = module
        self.is_exported_program = is_exported_program

    def _run(self) -> torch.fx.GraphModule:
        # DCE to remove unused nodes.
        # If a submodule is unused, it is hard to analyze which nodes constitute the submodule
        # outputs. But since it is unused, we can just remove it.
        self.module.graph.eliminate_dead_code()

        reference_module = torch.fx.GraphModule(self.module, self.module.graph)
        root_module_node = _ModuleNode(
            reference_module,
            _ModuleStackMeta(
                nn_module_stack_meta=None, is_exported_program=self.is_exported_program
            ),
        )
        for fx_node in self.module.graph.nodes:
            root_module_node.add_leaf_node(
                _LeafNode(fx_node, is_exported_program=self.is_exported_program)
            )
        return root_module_node.build_module({})