xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/fx/passes/readability.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4from typing import Sequence
5
6import torch
7from torch.onnx._internal.fx import _pass, diagnostics
8
9
class RestoreParameterAndBufferNames(_pass.Transform):
    """Restore parameter and buffer names from original nn.module.

    This pass is useful for readability of the exported ONNX graph. It restores the
    parameter and buffer names from the original nn.module. For example, if the original
    nn.module has a parameter named `root.linear.0.weight`, and the parameter is renamed to
    `_param_constant9` by FX, this pass will rename it back (with "." replaced by "/",
    because an attribute name on the root module cannot contain ".").

    This pass must be run after `Decompose` pass. Because this pass is expected to be called on
    `fx.GraphModule` produced by `proxy_tensor.make_fx`, where all parameters and buffers
    are registered at root level.
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        fx_module: torch.fx.GraphModule,
        original_nn_module: torch.nn.Module,
    ):
        """Initialize the pass.

        Args:
            diagnostic_context: Forwarded to the base `Transform` constructor.
            fx_module: The graph module to transform; forwarded to the base
                `Transform` constructor.
            original_nn_module: The module whose `named_parameters()` and
                `named_buffers()` supply the readable names.
        """
        super().__init__(diagnostic_context, fx_module)
        self.original_nn_module = original_nn_module

    def _rename_param_and_buffer(
        self,
        diagnostic: diagnostics.Diagnostic,
        nodes: Sequence[torch.fx.Node],
        new_name: str,
    ) -> None:
        """Rename the parameter/buffer and replace corresponding nodes with new nodes of updated target.

        Args:
            diagnostic: Diagnostic that receives an info message describing the rename.
            nodes: Non-empty sequence of `get_attr` nodes that all share the same
                `target` (the current attribute name on `self.module`).
            new_name: The readable name to restore. Every "." is replaced by "/"
                before it is used as the attribute name.

        Raises:
            AssertionError: If `nodes` is empty, the nodes do not share a single
                target, or the target is not a `str`.
        """
        assert len(nodes) > 0, "`nodes` cannot be empty"
        assert (
            len({node.target for node in nodes}) == 1
        ), "`nodes` must all have same `target`"
        old_name = nodes[0].target
        assert isinstance(old_name, str), f"Expected str, got {type(old_name)}"
        # Parameter/buffer name cannot contain "."
        normalized_name = new_name.replace(".", "/")
        if normalized_name == old_name:
            # The attribute already carries its readable name. Proceeding would
            # `setattr` and then `delattr` the very same attribute, removing it
            # from the module and leaving the graph pointing at nothing.
            return

        attr_value = getattr(self.module, old_name)
        setattr(self.module, normalized_name, attr_value)
        delattr(self.module, old_name)

        graph = self.module.graph
        for node in nodes:
            # Insert the replacement right before the node it supersedes so the
            # relative order of nodes in the graph is preserved.
            with graph.inserting_before(node):
                new_node = graph.get_attr(normalized_name)
                new_node.meta = node.meta
                node.replace_all_uses_with(new_node)
                graph.erase_node(node)

        diagnostic.info(
            "Renamed 'self.%s' to 'self.%s', "
            "normalized from original parameter name '%s'.",
            old_name,
            normalized_name,
            new_name,
        )

    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        """Restore parameter and buffer names from original module.

        For each `get_attr` node, if the target is a str representing a parameter or buffer
        under `self.module`, we rename the parameter or buffer to its original name.
        The parameters and buffers between `self.module` and `self.original_nn_module` refer
        to the same objects, allowing us to use it as key to retrieve the original name.

        Attributes without a readable name are left unchanged and reported through the
        in-flight diagnostic. The diagnostic level is set to WARNING if any such
        attribute is an `nn.Parameter`, and to NONE if only non-parameter attributes
        were unresolved.

        Returns:
            `self.module`, mutated in place.

        Raises:
            AssertionError: If any positional or keyword argument is given, or a
                `get_attr` target is not a `str`.
            RuntimeError: If a `get_attr` target contains ".", i.e. the attribute is
                not registered at root level.
        """
        assert len(args) == 0, "RestoreParameterAndBufferNames does not take any args"
        assert (
            len(kwargs) == 0
        ), "RestoreParameterAndBufferNames does not take any kwargs"
        # state_to_readable_name[parameter/buffer] returns the original readable name of
        # the parameter/buffer. E.g., "self.linear.weight".
        state_to_readable_name: dict[torch.nn.Parameter | torch.Tensor, str] = {}
        state_to_readable_name.update(
            {v: k for k, v in self.original_nn_module.named_parameters()}
        )
        state_to_readable_name.update(
            {v: k for k, v in self.original_nn_module.named_buffers()}
        )
        diagnostic = self.diagnostic_context.inflight_diagnostic()

        # old_name_to_nodes[old_name] returns a tuple of (nodes, new_name)
        # where `nodes` is a list of `get_attr` nodes with `old_name` as `target` and
        # `new_name` is the new readable name.
        old_name_to_nodes: dict[str, tuple[list[torch.fx.Node], str]] = {}

        # Once an unresolved parameter raised the level to WARNING, a later
        # unresolved non-parameter attribute must not silently reset it.
        unresolved_parameter_seen = False

        for node in self.module.graph.nodes:
            if node.op != "get_attr":
                continue
            assert isinstance(
                node.target, str
            ), f"Expected str, got {type(node.target)}"
            if "." in node.target:
                raise RuntimeError(
                    f"Unexpected target {node.target} in get_attr, found '.' in target. "
                    f"All parameters and buffers are expected to be registered at root level, "
                    f"i.e., self.module. "
                )
            if node.target in old_name_to_nodes:
                # We have already processed this parameter/buffer.
                old_name_to_nodes[node.target][0].append(node)
                continue
            attr_value = getattr(self.module, node.target)
            if (
                isinstance(attr_value, (torch.nn.Parameter, torch.Tensor))
                and attr_value in state_to_readable_name
            ):
                readable_name = state_to_readable_name[attr_value]
                old_name_to_nodes[node.target] = ([node], readable_name)
                continue

            diagnostic.info(
                "Cannot find readable name for self.%s: %s. The name is unchanged.",
                node.target,
                type(attr_value),
            )
            if isinstance(attr_value, torch.nn.Parameter):
                # If it is a parameter we treat it more seriously.
                diagnostic.level = diagnostics.levels.WARNING
                unresolved_parameter_seen = True
            elif not unresolved_parameter_seen:
                diagnostic.level = diagnostics.levels.NONE

        # Renaming mutates the graph, so it is deferred until iteration is done.
        for nodes, new_name in old_name_to_nodes.values():
            self._rename_param_and_buffer(diagnostic, nodes, new_name)

        return self.module
131