1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4from typing import TYPE_CHECKING 5 6import torch 7from torch.onnx._internal.fx import _pass 8 9 10if TYPE_CHECKING: 11 import torch.fx 12 13 14class MovePlaceholderToFront(_pass.Transform): 15 """This pass move all placeholder nodes to the front of the graph node list. 16 17 In torch.fx.Graph, placeholder is a special assignment node. If it's not 18 executed in the beginning, it could overwrite values computed by upstream 19 nodes. 20 """ 21 22 def _run(self, *args, **kwargs) -> torch.fx.GraphModule: 23 graph_module = self.module 24 graph = graph_module.graph 25 placeholders = [] 26 first_not_placeholder = None 27 for node in graph.nodes: 28 if node.op == "placeholder": 29 placeholders.append(node) 30 if first_not_placeholder is None and node.op != "placeholder": 31 first_not_placeholder = node 32 if first_not_placeholder is None: 33 return graph_module 34 for placeholder in placeholders: 35 first_not_placeholder.prepend(placeholder) 36 return graph_module 37 38 39class ReplaceGetAttrWithPlaceholder(_pass.Transform): 40 """Replace get_attr with placeholder. 41 42 The parameters and buffers accessed by the original get_attr are returned; 43 they are useful when creating random inputs for the modified graph_module. 44 """ 45 46 _replaced_attrs: tuple[torch.Tensor, ...] | None 47 48 @property 49 def replaced_attrs(self) -> tuple[torch.Tensor, ...]: 50 """The list of replaced weight tensors.""" 51 assert ( 52 self._replaced_attrs is not None 53 ), "Must run ReplaceGetAttrWithPlaceholder first" 54 return self._replaced_attrs 55 56 def _run(self, *args, **kwargs) -> torch.fx.GraphModule: 57 graph_module = self.module 58 graph = graph_module.graph 59 replaced_attrs: list[torch.Tensor] = [] 60 for node in graph.nodes: 61 if node.op == "get_attr": 62 replaced_attr: torch.Tensor | None = None 63 # get_attr could retrieve either parameter or buffer, so 64 # we need to try both. 65 try: 66 replaced_attr = graph_module.get_parameter(node.target) 67 except AttributeError: 68 # It's possible that model author use buffer instead of 69 # parameter to store trainable weights. In this case, 70 # 1. get_parameter will throw something like 71 # AttributeError: `bias` is not an nn.Parameter. 72 # 2. get_buffer should work. 73 replaced_attr = graph_module.get_buffer(node.target) 74 75 # Reassign op type so that get_attr node becomes placeholder node. 76 node.op = "placeholder" 77 # The target name in placeholder must be a valid Python identifier. 78 # Thus, we replace, e.g., "module.submodule.weight" with 79 # "module_submodule_weight". 80 node.target = node.target.replace(".", "_") 81 # Default value is None. This is needed as long as the "graph_module" 82 # has optional inputs. Assume the original forward signature is 83 # def forward(self, x, y=None) 84 # and the replaced get_attr node has target "z". Then, the modified 85 # signature should be 86 # def forward(self, x, y=None, z=None) 87 # Without the following line, the signature will be 88 # def forward(self, x, y=None, z) 89 # , which is not valid Python code. 90 node.args = (None,) 91 92 replaced_attrs.append(replaced_attr) 93 94 self._replaced_attrs = tuple(replaced_attrs) 95 96 return graph_module 97