# mypy: allow-untyped-defs
from __future__ import annotations

from typing import Sequence

import torch
from torch.onnx._internal.fx import _pass, diagnostics


class RestoreParameterAndBufferNames(_pass.Transform):
    """Restore parameter and buffer names from original nn.module.

    This pass is useful for readability of the exported ONNX graph. It restores the
    parameter and buffer names from the original nn.module. For example, if the original
    nn.module has a parameter named `root.linear.0.weight`, and the parameter is renamed to
    `_param_constant9` by FX, this pass will rename it back.

    This pass must be run after `Decompose` pass. Because this pass is expected to be called on
    `fx.GraphModule` produced by `proxy_tensor.make_fx`, where all parameters and buffers
    are registered at root level.
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        fx_module: torch.fx.GraphModule,
        original_nn_module: torch.nn.Module,
    ):
        """Initialize the pass.

        Args:
            diagnostic_context: Context used to emit diagnostics during the pass.
            fx_module: The FX graph module whose parameter/buffer names are restored.
            original_nn_module: The source nn.Module providing the readable names.
        """
        super().__init__(diagnostic_context, fx_module)
        self.original_nn_module = original_nn_module

    def _rename_param_and_buffer(
        self,
        diagnostic: diagnostics.Diagnostic,
        nodes: Sequence[torch.fx.Node],
        new_name: str,
    ) -> None:
        """Rename the parameter/buffer and replace corresponding nodes with new nodes of updated target.

        Args:
            diagnostic: Diagnostic used to record the rename.
            nodes: Non-empty sequence of `get_attr` nodes that all share one target.
            new_name: Readable name from the original module, e.g. "linear.weight".
        """
        assert len(nodes) > 0, "`nodes` cannot be empty"
        assert (
            len({node.target for node in nodes}) == 1
        ), "`nodes` must all have same `target`"
        old_name = nodes[0].target
        # NOTE: fixed assertion message; it previously rendered the literal text
        # "type(<value>)" instead of the actual type of the value.
        assert isinstance(old_name, str), f"Expected str, got {type(old_name)}"
        # Parameter/buffer name cannot contain ".", so substitute "/" to keep the
        # attribute registered at root level while preserving the hierarchy info.
        normalized_name = new_name.replace(".", "/")
        attr_value = getattr(self.module, old_name)
        setattr(self.module, normalized_name, attr_value)
        # Remove the stale FX-generated attribute so only the readable name remains.
        delattr(self.module, old_name)
        for node in nodes:
            # Replace each `get_attr` node with a fresh one pointing at the new
            # attribute name, preserving metadata, then drop the old node.
            with self.module.graph.inserting_before(node):
                new_node = self.module.graph.get_attr(normalized_name)
                new_node.meta = node.meta
                node.replace_all_uses_with(new_node)
                self.module.graph.erase_node(node)
        diagnostic.info(
            "Renamed 'self.%s' to 'self.%s', "
            "normalized from original parameter name '%s'.",
            old_name,
            normalized_name,
            new_name,
        )

    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        """Restore parameter and buffer names from original module.

        For each `get_attr` node, if the target is a str representing a parameter or buffer
        under `self.module`, we rename the parameter or buffer to its original name.
        The parameters and buffers between `self.module` and `self.original_nn_module` refer
        to the same objects, allowing us to use it as key to retrieve the original name.

        Returns:
            The (mutated in place) `self.module` with readable attribute names.
        """
        assert len(args) == 0, "RestoreParameterAndBufferNames does not take any args"
        assert (
            len(kwargs) == 0
        ), "RestoreParameterAndBufferNames does not take any kwargs"
        # state_to_readable_name[parameter/buffer] returns the original readable name of
        # the parameter/buffer. E.g., "self.linear.weight". Keyed on the tensor objects
        # themselves, which are shared between `self.module` and `self.original_nn_module`.
        state_to_readable_name: dict[torch.nn.Parameter | torch.Tensor, str] = {}
        state_to_readable_name.update(
            {v: k for k, v in self.original_nn_module.named_parameters()}
        )
        state_to_readable_name.update(
            {v: k for k, v in self.original_nn_module.named_buffers()}
        )
        diagnostic = self.diagnostic_context.inflight_diagnostic()

        # old_name_to_nodes[old_name] returns a tuple of (nodes, new_name)
        # where `nodes` is a list of `get_attr` nodes with `old_name` as `target` and
        # `new_name` is the new readable name.
        old_name_to_nodes: dict[str, tuple[list[torch.fx.Node], str]] = {}

        for node in self.module.graph.nodes:
            if node.op == "get_attr":
                # NOTE: fixed assertion message; it previously rendered the literal
                # text "type(<value>)" instead of the actual type of the target.
                assert isinstance(
                    node.target, str
                ), f"Expected str, got {type(node.target)}"
                if node.target.find(".") != -1:
                    raise RuntimeError(
                        f"Unexpected target {node.target} in get_attr, found '.' in target. "
                        "All parameters and buffers are expected to be registered at root level, "
                        "i.e., self.module. "
                    )
                if node.target in old_name_to_nodes:
                    # We have already processed this parameter/buffer.
                    old_name_to_nodes[node.target][0].append(node)
                    continue
                attr_value = getattr(self.module, node.target)
                if (
                    isinstance(attr_value, (torch.nn.Parameter, torch.Tensor))
                    and attr_value in state_to_readable_name
                ):
                    readable_name = state_to_readable_name[attr_value]
                    old_name_to_nodes[node.target] = ([node], readable_name)
                    continue

                diagnostic.info(
                    "Cannot find readable name for self.%s: %s. The name is unchanged.",
                    node.target,
                    type(attr_value),
                )
                if isinstance(attr_value, torch.nn.Parameter):
                    # If it is a parameter we treat it more seriously.
                    diagnostic.level = diagnostics.levels.WARNING
                else:
                    diagnostic.level = diagnostics.levels.NONE

        # Perform the renames after the scan so graph mutation does not interfere
        # with iteration over `self.module.graph.nodes`.
        for nodes, new_name in old_name_to_nodes.values():
            self._rename_param_and_buffer(diagnostic, nodes, new_name)

        return self.module