xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/fx/passes/virtualization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4from typing import TYPE_CHECKING
5
6import torch
7from torch.onnx._internal.fx import _pass
8
9
10if TYPE_CHECKING:
11    import torch.fx
12
13
14class MovePlaceholderToFront(_pass.Transform):
15    """This pass move all placeholder nodes to the front of the graph node list.
16
17    In torch.fx.Graph, placeholder is a special assignment node. If it's not
18    executed in the beginning, it could overwrite values computed by upstream
19    nodes.
20    """
21
22    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
23        graph_module = self.module
24        graph = graph_module.graph
25        placeholders = []
26        first_not_placeholder = None
27        for node in graph.nodes:
28            if node.op == "placeholder":
29                placeholders.append(node)
30            if first_not_placeholder is None and node.op != "placeholder":
31                first_not_placeholder = node
32        if first_not_placeholder is None:
33            return graph_module
34        for placeholder in placeholders:
35            first_not_placeholder.prepend(placeholder)
36        return graph_module
37
38
39class ReplaceGetAttrWithPlaceholder(_pass.Transform):
40    """Replace get_attr with placeholder.
41
42    The parameters and buffers accessed by the original get_attr are returned;
43    they are useful when creating random inputs for the modified graph_module.
44    """
45
46    _replaced_attrs: tuple[torch.Tensor, ...] | None
47
48    @property
49    def replaced_attrs(self) -> tuple[torch.Tensor, ...]:
50        """The list of replaced weight tensors."""
51        assert (
52            self._replaced_attrs is not None
53        ), "Must run ReplaceGetAttrWithPlaceholder first"
54        return self._replaced_attrs
55
56    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
57        graph_module = self.module
58        graph = graph_module.graph
59        replaced_attrs: list[torch.Tensor] = []
60        for node in graph.nodes:
61            if node.op == "get_attr":
62                replaced_attr: torch.Tensor | None = None
63                # get_attr could retrieve either parameter or buffer, so
64                # we need to try both.
65                try:
66                    replaced_attr = graph_module.get_parameter(node.target)
67                except AttributeError:
68                    # It's possible that model author use buffer instead of
69                    # parameter to store trainable weights. In this case,
70                    # 1. get_parameter will throw something like
71                    #    AttributeError: `bias` is not an nn.Parameter.
72                    # 2. get_buffer should work.
73                    replaced_attr = graph_module.get_buffer(node.target)
74
75                # Reassign op type so that get_attr node becomes placeholder node.
76                node.op = "placeholder"
77                # The target name in placeholder must be a valid Python identifier.
78                # Thus, we replace, e.g., "module.submodule.weight" with
79                # "module_submodule_weight".
80                node.target = node.target.replace(".", "_")
81                # Default value is None. This is needed as long as the "graph_module"
82                # has optional inputs. Assume the original forward signature is
83                #  def forward(self, x, y=None)
84                # and the replaced get_attr node has target "z". Then, the modified
85                # signature should be
86                #  def forward(self, x, y=None, z=None)
87                # Without the following line, the signature will be
88                #  def forward(self, x, y=None, z)
89                # , which is not valid Python code.
90                node.args = (None,)
91
92                replaced_attrs.append(replaced_attr)
93
94        self._replaced_attrs = tuple(replaced_attrs)
95
96        return graph_module
97