1# mypy: allow-untyped-defs 2import collections 3import warnings 4from typing import Any, Dict, List, Union 5 6import torch 7from torch._export.verifier import SpecViolationError 8from torch._guards import detect_fake_mode 9from torch._library.fake_class_registry import FakeScriptObject 10from torch._subclasses.fake_tensor import unset_fake_temporarily 11from torch.export.exported_program import ( 12 ArgumentSpec, 13 CustomObjArgument, 14 ExportGraphSignature, 15 InputKind, 16 InputSpec, 17 TensorArgument, 18) 19 20 21class ConstantAttrMap(collections.abc.MutableMapping): 22 """A mapping class that understands how to use module constants (tensors, 23 ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally, 24 but ScriptObjects are stored by hash, because different torch.ScriptObjects can point to 25 the same underlying value (but we guarantee that they will `hash()` to the same value 26 if that's the case). 27 """ 28 29 def __init__(self) -> None: 30 # Underlying dict that we use to implement this mapping. 31 self._constant_attrs: Dict[ 32 Union[int, torch.Tensor, FakeScriptObject], List[Any] 33 ] = {} 34 # Map from the hash(ScriptObject) to the ScriptObject itself. Used for 35 # APIs like `__iter__` that should look like they're returning the 36 # original ScriptObjects. 37 self._script_object_map: Dict[int, torch.ScriptObject] = {} 38 39 def __getitem__( 40 self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject] 41 ) -> Any: 42 real_key = hash(key) if isinstance(key, torch.ScriptObject) else key 43 assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject)) 44 return self._constant_attrs[real_key] 45 46 def __setitem__(self, key: Union[torch.Tensor, torch.ScriptObject], value): 47 # we shouldn't actually call this, should go to add() instead to handle aliasing 48 raise NotImplementedError( 49 """Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead. 50The same key can be mapped to multiple values, for handling constant aliasing.""" 51 ) 52 53 def add( 54 self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject], value: Any 55 ) -> None: 56 if isinstance(key, torch.ScriptObject): 57 if hash(key) not in self._constant_attrs: 58 self._constant_attrs[hash(key)] = [] 59 self._constant_attrs[hash(key)].append(value) 60 self._script_object_map[hash(key)] = key 61 elif isinstance(key, (torch.Tensor, FakeScriptObject)): 62 if key not in self._constant_attrs: 63 self._constant_attrs[key] = [] 64 self._constant_attrs[key].append(value) 65 else: 66 raise TypeError( 67 f"Expected key to be a tensor or ScriptObject, got {type(key)}" 68 ) 69 70 def __delitem__(self, key): 71 real_key = hash(key) if isinstance(key, torch.ScriptObject) else key 72 73 del self._constant_attrs[real_key] 74 75 def __iter__(self): 76 for key in self._constant_attrs: 77 if isinstance(key, int): 78 yield self._script_object_map[key] 79 else: 80 yield key 81 82 def __len__(self): 83 return len(self._constant_attrs) 84 85 def __contains__(self, key: object) -> bool: 86 real_key = hash(key) if isinstance(key, torch.ScriptObject) else key 87 return real_key in self._constant_attrs 88 89 90def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str: 91 # The FQN of the constant tensor in the state dict should 92 # correspond to the module where the constant tensor was 93 # originally used. 94 if len(node.meta["nn_module_stack"]) == 0: 95 return constant_name 96 parent_fqn = list(node.meta["nn_module_stack"].values())[-1][0] 97 if len(parent_fqn) > 0: 98 return f"{parent_fqn}.{constant_name}" 99 else: 100 return constant_name 101 102 103def _get_first_fqn( 104 const_attrs: ConstantAttrMap, 105 key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject], 106) -> Any: 107 fqns = const_attrs.get(key) 108 return fqns[0] if fqns else None 109 110 111def lift_constants_pass( 112 gm: torch.fx.GraphModule, 113 graph_signature: ExportGraphSignature, 114 constant_attrs: ConstantAttrMap, 115) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]: 116 """ 117 Takes a graph module, graph signature, and modifies them implace to lift any 118 constants (tensors or custom classes) as inputs to the graph. Returns a 119 dictionary of names to constants. 120 121 Arguments: 122 gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift. 123 graph_signature (ExportGraphSignature): This graph signature will be 124 mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs. 125 constant_attrs (ConstantAttr): A mapping from a constant value to its 126 fully-qualified path in `gm`. This is used to maintain consistent 127 location of constants between the original module and the exported 128 version. 129 130 Returns: 131 A dictionary of fqn => constant value. 132 """ 133 all_constants: Dict[ 134 str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject] 135 ] = {} 136 137 inputs = graph_signature.input_specs 138 num_custom_obj = sum( 139 input_specs.kind == InputKind.CUSTOM_OBJ for input_specs in inputs 140 ) 141 num_tensor_constants = sum( 142 input_specs.kind == InputKind.CONSTANT_TENSOR for input_specs in inputs 143 ) 144 145 fake_mode = detect_fake_mode( 146 tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder") 147 ) 148 149 first_user_input_loc, first_user_input = 0, None 150 for node in gm.graph.nodes: 151 if node.op == "placeholder" and node.name in graph_signature.user_inputs: 152 first_user_input = node 153 break 154 first_user_input_loc += 1 155 156 lifted_objs = ConstantAttrMap() 157 for node in gm.graph.nodes: 158 if node.op == "get_attr": 159 constant_val = getattr(gm, node.target) 160 if constant_val in lifted_objs: 161 # We already lifted this constant elsewhere. Just rewrite uses 162 # of this get_attr to point to the already-existing placeholder 163 # node. 164 const_placeholder_node = _get_first_fqn(lifted_objs, constant_val) 165 node.replace_all_uses_with(const_placeholder_node) 166 gm.graph.erase_node(node) 167 continue 168 169 # For ScriptObject, Tensor and FakeScriptObject constants: 170 # First check if the constant was an attribute on some module by 171 # consulting `constant_attrs` map. If it is, use the fqn that keeps 172 # its location consistent with the eager module. 173 # 174 # If it's not in the `constant_attrs` map, that means it's an inline 175 # constant (e.g. x + torch.tensor(0)), and thus did not have a 176 # specific location in the eager module. In that case, just generate 177 # some name and attach it to the module in which it was used. 178 if isinstance(constant_val, (torch.ScriptObject, FakeScriptObject)): 179 constant_kind = InputKind.CUSTOM_OBJ 180 constant_fqn = _get_first_fqn(constant_attrs, constant_val) 181 if constant_fqn is not None: 182 constant_name = constant_fqn.replace(".", "_") 183 else: 184 constant_name = f"lifted_custom_{num_custom_obj}" 185 constant_fqn = get_constant_fqn(node, constant_name) 186 num_custom_obj += 1 187 elif isinstance(constant_val, torch.Tensor): 188 # Remove the parameterness of constant_val 189 if isinstance(constant_val, torch.nn.Parameter): 190 warnings.warn( 191 f"{node.target} created when tracing {node.meta['stack_trace']} is a parameter. But" 192 f"it's not registered with register_parameter(). export will treat it as a constant tensor" 193 ) 194 # We get the real data out of the parameter by disabling the surrounding fake mode. 195 with unset_fake_temporarily(): 196 constant_val = constant_val.data 197 constant_kind = InputKind.CONSTANT_TENSOR 198 constant_fqn = _get_first_fqn(constant_attrs, constant_val) 199 if constant_fqn is not None: 200 constant_name = constant_fqn.replace(".", "_") 201 else: 202 constant_name = f"lifted_tensor_{num_tensor_constants}" 203 constant_fqn = get_constant_fqn(node, constant_name) 204 num_tensor_constants += 1 205 elif isinstance(constant_val, torch.fx.GraphModule): 206 continue 207 elif "LoweredBackendModule" in type(constant_val).__name__: 208 continue 209 else: 210 raise SpecViolationError( 211 f"getattr node {node} referencing unsupported type {type(constant_val)}" 212 ) 213 214 with gm.graph.inserting_before(first_user_input): 215 # Insert the constant node before the first user input 216 const_placeholder_node = gm.graph.placeholder(constant_name) 217 # match target name with its node name in case there is name collision 218 # and suffix is added to node name in fx 219 const_placeholder_node.target = const_placeholder_node.name 220 221 for k, v in node.meta.items(): 222 const_placeholder_node.meta[k] = v 223 224 # Once the FQN has been used, remove nn_module_stack, stack_trace 225 const_placeholder_node.meta.pop("nn_module_stack") 226 const_placeholder_node.meta.pop("stack_trace", None) 227 228 input_spec_arg: ArgumentSpec 229 if isinstance(constant_val, torch.Tensor): 230 if fake_mode is not None: 231 const_placeholder_node.meta["val"] = fake_mode.from_tensor( 232 constant_val, static_shapes=True 233 ) 234 const_placeholder_node.meta["val"].constant = constant_val 235 else: 236 const_placeholder_node.meta["val"] = constant_val 237 input_spec_arg = TensorArgument(name=const_placeholder_node.name) 238 elif isinstance(constant_val, torch._C.ScriptObject): 239 class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined] 240 const_placeholder_node.meta["val"] = CustomObjArgument( 241 constant_fqn, class_fqn 242 ) 243 input_spec_arg = CustomObjArgument( 244 name=const_placeholder_node.name, class_fqn=class_fqn 245 ) 246 elif isinstance(constant_val, FakeScriptObject): 247 class_fqn = constant_val.script_class_name 248 const_placeholder_node.meta["val"] = CustomObjArgument( 249 constant_fqn, class_fqn, constant_val 250 ) 251 input_spec_arg = CustomObjArgument( 252 name=const_placeholder_node.name, 253 class_fqn=class_fqn, 254 fake_val=constant_val, 255 ) 256 else: 257 raise SpecViolationError( 258 f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}" 259 ) 260 261 lifted_objs.add(constant_val, const_placeholder_node) 262 node.replace_all_uses_with(const_placeholder_node) 263 gm.graph.erase_node(node) 264 265 # Add the constant as a buffer to the graph signature 266 graph_signature.input_specs.insert( 267 first_user_input_loc, 268 InputSpec( 269 kind=constant_kind, 270 arg=input_spec_arg, 271 target=constant_fqn, 272 ), 273 ) 274 if constant_val in constant_attrs: 275 for fqn in constant_attrs[constant_val]: 276 all_constants[fqn] = constant_val 277 else: 278 all_constants[constant_fqn] = constant_val 279 first_user_input_loc += 1 280 281 return all_constants 282 283 284def rewrite_script_object_meta( 285 gm: torch.fx.GraphModule, 286) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]: 287 """When tracing, we produce a graph with FakeScriptObject in the 288 meta["val"]. 289 290 For now, we rewrie meta["val"] to be a placeholder CustomObjArgument 291 """ 292 constants: Dict[ 293 str, 294 Union[ 295 torch.Tensor, 296 torch.ScriptObject, 297 FakeScriptObject, 298 ], 299 ] = {} 300 for node in gm.graph.nodes: 301 if "val" not in node.meta: 302 continue 303 304 old_meta = node.meta["val"] 305 306 if isinstance(old_meta, torch.ScriptObject): 307 class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined] 308 new_meta = CustomObjArgument(node.name, class_fqn) 309 constants[node.name] = old_meta 310 node.meta["val"] = new_meta 311 312 elif isinstance(old_meta, FakeScriptObject): 313 class_fqn = old_meta.script_class_name # type: ignore[attr-defined] 314 new_meta = CustomObjArgument(node.name, class_fqn, old_meta) 315 constants[node.name] = old_meta 316 node.meta["val"] = new_meta 317 318 return constants 319