1# mypy: allow-untyped-defs 2import ast 3import dataclasses 4import inspect 5import math 6import operator 7import re 8from inspect import Parameter 9from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING 10 11import torch 12from torch._guards import detect_fake_mode 13from torch._subclasses.fake_tensor import FakeTensor 14 15 16if TYPE_CHECKING: 17 from torch._export.passes.lift_constants_pass import ConstantAttrMap 18 from torch.export import ExportedProgram 19 from torch.export.graph_signature import ExportGraphSignature 20 21from torch.export.graph_signature import InputKind, OutputKind 22from torch.utils._pytree import ( 23 _register_pytree_node, 24 Context, 25 FlattenFunc, 26 FromDumpableContextFn, 27 GetAttrKey, 28 KeyPath, 29 keystr, 30 MappingKey, 31 SequenceKey, 32 ToDumpableContextFn, 33 tree_flatten_with_path, 34 UnflattenFunc, 35) 36 37 38placeholder_prefixes = { 39 InputKind.USER_INPUT: "", 40 InputKind.PARAMETER: "p_", 41 InputKind.BUFFER: "b_", 42 InputKind.CONSTANT_TENSOR: "c_", 43 InputKind.CUSTOM_OBJ: "obj_", 44 InputKind.TOKEN: "token", 45} 46 47 48def _collect_and_set_constant_attrs( 49 graph_signature, constants, mod 50) -> "ConstantAttrMap": 51 # the exported module will store constants & non-persistent buffers such that 52 # retracing treats them as persistent buffers, so we inform the constants lifting pass 53 # and overwrite the new graph signature using the previous program. This is intended to only be used 54 # in run_decompositions where we still have access to original EP. 55 from torch._export.passes.lift_constants_pass import ConstantAttrMap 56 57 constant_attrs = ConstantAttrMap() 58 non_persistent_buffers = { 59 spec.target 60 for spec in graph_signature.input_specs 61 if spec.kind == InputKind.BUFFER and not spec.persistent 62 } 63 for name, value in constants.items(): 64 if name in non_persistent_buffers: 65 continue 66 # recursive getattr 67 _mod = mod 68 *atoms, attr = name.split(".") 69 for atom in atoms: 70 _mod = getattr(_mod, atom) 71 # remove as buffer, reassign as constant/non-persistent buffer 72 _mod._buffers.pop(attr, None) 73 setattr(_mod, attr, value) 74 constant_attrs.add(value, name) 75 return constant_attrs 76 77 78def _overwrite_signature_for_non_persistent_buffers( 79 old_sig: "ExportGraphSignature", new_sig: "ExportGraphSignature" 80): 81 # overwrite signature for non-persistent buffers 82 non_persistent_buffers = { 83 spec.target 84 for spec in old_sig.input_specs 85 if spec.kind == InputKind.BUFFER and not spec.persistent 86 } 87 88 for spec in new_sig.input_specs: 89 if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers: 90 spec.persistent = False 91 return new_sig 92 93 94def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]: 95 """ 96 Param/buffer metadata needs to be saved before lowering to aten IR 97 because aten IR lifts them, as a result, automatic preservation doesn't work. 98 This is intended to be called on the strict mode tracing right before lowering to 99 aten IR OR run_decomposition pass. 100 """ 101 params_buffers_to_node_meta = {} 102 103 def _getattr(model: torch.fx.GraphModule, attr_name: str): 104 *prefix, field = attr_name.split(".") 105 t = model 106 for item in prefix: 107 t = getattr(t, item, None) # type: ignore[assignment] 108 assert t is not None 109 110 return getattr(t, field) 111 112 for node in mod.graph.nodes: 113 target = node.target 114 meta = node.meta 115 if node.op == "call_module": 116 submodule = _getattr(mod, target) 117 if isinstance(submodule, torch.nn.Module): 118 for name, _ in submodule.named_parameters( 119 recurse=True, remove_duplicate=False 120 ): 121 params_buffers_to_node_meta[target + "." + name] = meta 122 123 for name, _ in submodule.named_buffers( 124 recurse=True, remove_duplicate=False 125 ): 126 params_buffers_to_node_meta[target + "." + name] = meta 127 128 if node.op == "get_attr": 129 submodule = _getattr(mod, target) 130 if not isinstance(submodule, torch.fx.GraphModule): 131 params_buffers_to_node_meta[target] = meta 132 133 # If the call_function uses param as input, we also need to update params' meta 134 # with this call_function node's meta. 135 # This is basically the same flow as torch.fx.traceback.preserve_meta() 136 if node.op == "call_function" and not isinstance( 137 node.target, torch._ops.HigherOrderOperator 138 ): 139 for arg in node._input_nodes: 140 if arg.op == "get_attr": 141 for entry in torch.fx.proxy._COPY_META_FIELDS: 142 if entry in meta: 143 params_buffers_to_node_meta[arg.target][entry] = meta[entry] 144 145 return params_buffers_to_node_meta 146 147 148def _populate_param_buffer_metadata_to_new_gm( 149 params_buffers_to_node_meta: Dict[str, Any], 150 gm: torch.fx.GraphModule, 151 new_sig: "ExportGraphSignature", 152) -> None: 153 """ 154 Given that we collected param'buffer metadata before, we put them back in 155 newly traced graph module 156 """ 157 # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes 158 for metadata in params_buffers_to_node_meta.values(): 159 metadata.pop("nn_module_stack", None) 160 metadata.pop("stack_trace", None) 161 162 for node in gm.graph.nodes: 163 if node.op == "placeholder": 164 if node.target in new_sig.inputs_to_parameters: 165 param_name = new_sig.inputs_to_parameters[node.target] 166 if param_name in params_buffers_to_node_meta: 167 for k, v in params_buffers_to_node_meta[param_name].items(): 168 node.meta[k] = v 169 if node.target in new_sig.inputs_to_buffers: 170 buffer_name = new_sig.inputs_to_buffers[node.target] 171 if buffer_name in params_buffers_to_node_meta: 172 for k, v in params_buffers_to_node_meta[buffer_name].items(): 173 node.meta[k] = v 174 175 176def _get_shape_env_from_gm(gm: torch.fx.GraphModule): 177 vals = [ 178 node.meta["val"] 179 for node in gm.graph.nodes 180 if node.meta.get("val", None) is not None 181 ] 182 183 fake_mode = _detect_fake_mode_from_gm(gm) 184 if fake_mode is not None: 185 return fake_mode.shape_env 186 for v in vals: 187 if isinstance(v, torch.SymInt): 188 return v.node.shape_env 189 190 191def _rename_without_collisions( 192 name_map: Dict[str, str], 193 orig_name: str, 194 name: str, 195 is_placeholder: bool = False, 196): 197 """ 198 Renames nodes to avoid name collisions, with suffixing. 199 name_map: map from original name to new name 200 orig_name: mapping key 201 name: candidate name (potentially suffixed, e.g. mul_2) 202 is_placeholder: if the node is a placeholder, avoid detecting suffix 203 """ 204 if name in name_map.values(): 205 # non-placeholder nodes may be suffixed with the count 206 # instead of adding another suffix, we will try to increment it 207 match = re.match(r"(.*)_(\d+)", name) 208 if match and not is_placeholder: 209 name, n = match.group(1), int(match.group(2)) 210 else: 211 n = 0 212 while (dup_name := f"{name}_{n + 1}") in name_map.values(): 213 n += 1 214 name_map[orig_name] = dup_name 215 else: 216 name_map[orig_name] = name 217 return name_map[orig_name] 218 219 220def _check_input_constraints_for_graph( 221 input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints 222): 223 def get_keystr(key_path: KeyPath) -> str: 224 """For a given index into the flat_args, return a human readable string 225 describing how to access it, e.g. "*args["foo"][0].bar" 226 """ 227 # Prefix the keypath with "*args" or "**kwargs" to make it clearer where 228 # the arguments come from. Ultimately we ought to serialize the 229 # original arg names for the best error message here. 230 args_kwargs_key_path = key_path[0] 231 assert isinstance(args_kwargs_key_path, SequenceKey) 232 if args_kwargs_key_path.idx == 0: 233 return f"*args{keystr(key_path[1:])}" 234 else: 235 kwarg_key = key_path[1] 236 assert isinstance(kwarg_key, MappingKey) 237 name = str(kwarg_key)[1:-1] # get rid of the enclosed [] 238 return f"{name}{keystr(key_path[2:])}" 239 240 import sympy 241 242 from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( 243 _convert_range_to_int, 244 ) 245 from torch.utils._sympy.solve import try_solve 246 247 if len(flat_args_with_path) != len(input_placeholders): 248 raise RuntimeError( 249 "Unexpected number of inputs " 250 f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})" 251 ) 252 # NOTE: export already guarantees that the same symbol is used in metadata 253 # for all InputDims related by equality constraints, so we can just unify 254 # symbols with given input dimension values to check equality constraints. 255 unification_map: Dict[sympy.Symbol, Any] = {} 256 for (key_path, arg), node in zip(flat_args_with_path, input_placeholders): 257 node_val = node.meta.get("val") 258 if isinstance(node_val, FakeTensor): 259 if not isinstance(arg, torch.Tensor): 260 raise RuntimeError( 261 f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}", 262 ) 263 264 if len(node_val.shape) != len(arg.shape): 265 raise RuntimeError( 266 f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape " 267 f"(expected {node_val.shape}, got {arg.shape})" 268 ) 269 270 for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)): 271 # TODO(avik): Assert the following property in the IR verifier: 272 # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr 273 if ( 274 isinstance(node_dim, torch.SymInt) 275 and len(node_dim.node.expr.free_symbols) == 1 276 ): 277 symbol = next(iter(node_dim.node.expr.free_symbols)) 278 if symbol in unification_map: 279 existing_dim = node_dim.node.expr.subs(unification_map) 280 if arg_dim != existing_dim: 281 raise RuntimeError( 282 f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " 283 f"{existing_dim}, but got {arg_dim}", 284 ) 285 else: 286 if ( 287 isinstance(arg_dim, torch.SymInt) 288 and not arg_dim.node.expr.is_number 289 ): 290 # This can happen when, say, arg is a fake tensor. 291 # We do not run checks on symbolic shapes of fake inputs as 292 # such checks can affect the shape env. 293 pass 294 else: 295 if isinstance(node_dim.node.expr, sympy.Symbol): 296 # Short cut for try_solve below. Also useful in cases where 297 # sympy.Eq(node_dim.node.expr, arg_dim) would evaluate to False 298 # purely because symbol is constrained to be size-like, 299 # e.g., when node_dim.node.expr = symbol and arg_dim = 0. 300 unification_map[symbol] = int(arg_dim) 301 else: 302 solution = try_solve( 303 sympy.Eq(node_dim.node.expr, arg_dim), symbol 304 ) 305 if solution is None: 306 raise RuntimeError( # noqa: B904 307 f"Expected input {node.name}.shape[{j}] = {arg_dim} to be " 308 f"of the form {node_dim.node.expr}, where {symbol} is an integer" 309 ) 310 else: 311 unification_map[symbol] = int(solution[1]) 312 313 if node_dim.node.expr in range_constraints: 314 min_val, max_val = _convert_range_to_int( 315 range_constraints[node_dim.node.expr] 316 ) 317 # NOTE: we allow dimensions to be 0/1 at runtime 318 if min_val > 2: 319 if arg_dim < min_val: 320 raise RuntimeError( 321 f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= " 322 f"{min_val}, but got {arg_dim}", 323 ) 324 if max_val < math.inf: 325 if arg_dim > max_val: 326 raise RuntimeError( 327 f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= " 328 f"{max_val}, but got {arg_dim}", 329 ) 330 else: 331 if arg_dim != node_dim: 332 if ( 333 isinstance(node_dim, torch.SymInt) 334 and not node_dim.node.expr.is_number 335 ): 336 # this means we deferred a guard from export analysis to runtime, let this pass 337 # we'll add a runtime assert checking equality to this replacement expression 338 continue 339 raise RuntimeError( 340 f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " 341 f"{node_dim}, but got {arg_dim}", 342 ) 343 elif isinstance(node_val, (int, float, str)): 344 if type(arg) != type(node_val) or arg != node_val: 345 raise RuntimeError( 346 f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", 347 ) 348 349 350def register_dataclass_as_pytree_node( 351 cls: Type[Any], 352 flatten_fn: Optional[FlattenFunc] = None, 353 unflatten_fn: Optional[UnflattenFunc] = None, 354 *, 355 serialized_type_name: Optional[str] = None, 356 to_dumpable_context: Optional[ToDumpableContextFn] = None, 357 from_dumpable_context: Optional[FromDumpableContextFn] = None, 358 return_none_fields: bool = False, 359) -> None: 360 assert dataclasses.is_dataclass( 361 cls 362 ), f"Only dataclasses can be registered with this function: {cls}" 363 364 def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: 365 flattened = [] 366 flat_names = [] 367 none_names = [] 368 for f in dataclasses.fields(obj): 369 name, val = f.name, getattr(obj, f.name) 370 if val is not None or return_none_fields: 371 flattened.append(val) 372 flat_names.append(name) 373 else: 374 none_names.append(name) 375 return flattened, [flat_names, none_names] 376 377 def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: 378 flat_names, none_names = context 379 return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) 380 381 def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]: 382 flattened, (flat_names, none_names) = flatten_fn(obj) # type: ignore[misc] 383 return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names 384 385 flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn 386 unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn 387 388 if (to_dumpable_context is None) ^ (from_dumpable_context is None): 389 raise ValueError( 390 f"Both to_dumpable_context and from_dumpable_context for {cls} must " 391 "be None or registered." 392 ) 393 394 _register_pytree_node( 395 cls, 396 flatten_fn, 397 unflatten_fn, 398 serialized_type_name=serialized_type_name, 399 flatten_with_keys_fn=default_flatten_fn_with_keys, 400 to_dumpable_context=to_dumpable_context, 401 from_dumpable_context=from_dumpable_context, 402 ) 403 404 405def is_param(program: "ExportedProgram", node: torch.fx.Node) -> bool: 406 """ 407 Checks if the given node is a parameter within the exported program 408 """ 409 410 return node.name in program.graph_signature.inputs_to_parameters 411 412 413def get_param( 414 program: "ExportedProgram", 415 node: torch.fx.Node, 416) -> Optional[torch.nn.Parameter]: 417 """ 418 Returns the parameter associated with the given node in the exported program. 419 Returns None if the node is not a parameter within the exported program 420 """ 421 422 if is_param(program, node): 423 parameter_name = program.graph_signature.inputs_to_parameters[node.name] 424 return program.state_dict[parameter_name] 425 426 return None 427 428 429def is_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool: 430 """ 431 Checks if the given node is a buffer within the exported program 432 """ 433 434 return node.name in program.graph_signature.inputs_to_buffers 435 436 437def get_buffer( 438 program: "ExportedProgram", 439 node: torch.fx.Node, 440) -> Optional[torch.Tensor]: 441 """ 442 Returns the buffer associated with the given node in the exported program. 443 Returns None if the node is not a buffer within the exported program 444 """ 445 446 if is_buffer(program, node): 447 buffer_name = program.graph_signature.inputs_to_buffers[node.name] 448 if buffer_name in program.graph_signature.non_persistent_buffers: 449 return program.constants[buffer_name] 450 else: 451 return program.state_dict[buffer_name] 452 453 return None 454 455 456def is_lifted_tensor_constant( 457 program: "ExportedProgram", 458 node: torch.fx.Node, 459) -> bool: 460 """ 461 Checks if the given node is a lifted tensor constant within the exported program 462 """ 463 464 return node.name in program.graph_signature.inputs_to_lifted_tensor_constants 465 466 467def get_lifted_tensor_constant( 468 program: "ExportedProgram", 469 node: torch.fx.Node, 470) -> Optional[torch.Tensor]: 471 """ 472 Returns the lifted tensor constant associated with the given node in the exported program. 473 Returns None if the node is not a lifted tensor constant within the exported program 474 """ 475 476 if is_lifted_tensor_constant(program, node): 477 lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[ 478 node.name 479 ] 480 return program.constants[lifted_tensor_name] 481 482 return None 483 484 485def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule: 486 """ 487 sequential_split creates a new graph module that splits the input graph module into multiple submodules 488 based on the node_call_back. It doesn't mutate the input graph module. The node_call_back should return 489 True if the node is a delimiter. Delimiter will be the first node in the next submodule. 490 """ 491 from torch.fx.passes.split_module import split_module 492 493 split_map = {} 494 split_id = 0 495 for node in gm.graph.nodes: 496 if node_call_back(node): 497 split_id += 1 498 split_map[node] = split_id 499 500 new_gm = split_module( 501 gm, 502 gm, 503 lambda node: split_map[node], 504 keep_original_order=True, 505 keep_original_node_name=True, 506 ) 507 # Keep the codegen from original graph module to preserve e.g. pytree info. 508 new_gm.graph._codegen = gm.graph._codegen 509 new_gm.recompile() 510 return new_gm 511 512 513def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: 514 """Returns the nodes that match the node_call_back as a list.""" 515 return [node for node in nodes if node_call_back(node)] 516 517 518def nodes_first( 519 nodes: List[torch.fx.Node], node_call_back=None 520) -> Optional[torch.fx.Node]: 521 """ 522 Returns the first node that matches the node_call_back. If no node matches, returns None. 523 When node_call_back is None, returns the first node in the node list. 524 """ 525 ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True) 526 if len(ret) > 0: 527 return ret[0] 528 return None 529 530 531def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int: 532 """Returns the number of nodes that match the node_call_back.""" 533 return len(nodes_filter(nodes, node_call_back)) 534 535 536def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: 537 """ 538 Sequentially visit the nodes list and invoke node_call_back on each element. 539 Returns the nodes list after the node_call_back is invoked on each element. 540 """ 541 for node in nodes: 542 node_call_back(node) 543 return nodes 544 545 546def node_replace_(old_node: torch.fx.Node, new_node: torch.fx.Node) -> None: 547 """ 548 Replace all uses of old_node with new_node. 549 """ 550 old_node.replace_all_uses_with(new_node) 551 old_node.users.clear() 552 old_node.graph.erase_node(old_node) 553 554 555def node_inline_(call_mod_node: torch.fx.Node) -> None: 556 """ 557 Inline the submodule of the given node into the parent module. 558 Note: we only support the case where submodule takes tensors inputs. 559 """ 560 assert call_mod_node.op == "call_module" 561 gm = call_mod_node.graph.owning_module 562 563 assert isinstance(call_mod_node.target, str) 564 sub_gm = getattr(gm, call_mod_node.target) 565 566 phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder") 567 body = ( 568 node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output") 569 ) 570 output = [node for node in sub_gm.graph.nodes if node.op == "output"] 571 572 for ph, arg in zip(phs, call_mod_node.args): 573 assert isinstance(arg, torch.fx.Node) 574 node_replace_(ph, arg) 575 576 with gm.graph.inserting_before(call_mod_node): 577 for node in body: 578 new_node = gm.graph.node_copy(node) 579 node_replace_(node, new_node) 580 581 if len(output) > 0: 582 assert len(output) == 1 and len(output[0].args) == 1 583 new_output = output[0].args[0] 584 585 if isinstance(new_output, torch.fx.Node): 586 # Clear the users of the output node and set 587 # the users to be the users of original call_module node. 588 new_output.users.clear() 589 node_replace_(call_mod_node, new_output) 590 elif isinstance(new_output, (list, tuple)): 591 # Pop subgraph output node from users. 592 for node in new_output: 593 node.users.pop(output[0]) 594 595 # Inline the get_item calls for the output node. 596 get_item_users = nodes_filter( 597 list(call_mod_node.users.keys()), 598 lambda node: node.op == "call_function" 599 and node.target == operator.getitem, 600 ) 601 # get_item_node.args[1] is the idx referring to new_output[idx] 602 nodes_map( 603 get_item_users, 604 lambda get_item_node: node_replace_( 605 get_item_node, 606 new_output[get_item_node.args[1]], 607 ), 608 ) 609 call_mod_node.graph.erase_node(call_mod_node) 610 else: 611 raise NotImplementedError( 612 f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes." 613 ) 614 else: 615 call_mod_node.graph.erase_node(call_mod_node) 616 617 gm.delete_all_unused_submodules() 618 gm.recompile() 619 return gm 620 621 622def _get_torch_jit_trace_forward_signature(mod: torch.nn.Module): 623 """ 624 Get source code and parse argument names using AST. The function returns 625 a signature of the forward() function. 626 627 # TODO: Directly provide inspect.signature compatible TS-d module. 628 """ 629 ast_mod = ast.parse(mod.code) 630 ast_func_def: ast.FunctionDef = ast_mod.body[0] # type: ignore[assignment] 631 632 # FIXME(jiashenc): TorchScript should only allow positional or keywords arguments. 633 arg_type_map = {"args": Parameter.POSITIONAL_OR_KEYWORD} 634 635 # Traverse all argument types in AST tree and create associated parameters. 636 param_list = [] 637 for arg_type, param_type in arg_type_map.items(): 638 arg_name_list = [a.arg for a in getattr(ast_func_def.args, arg_type)] 639 for arg_name in arg_name_list: 640 if arg_name == "self": 641 continue # Skip self argument. 642 param_list.append(inspect.Parameter(arg_name, param_type)) 643 644 return inspect.Signature(parameters=param_list) 645 646 647def _bind_signature_to_inputs(mod, fake_args, fake_kwargs): 648 if isinstance(mod, (torch.jit.ScriptModule, torch.jit.TracedModule)): 649 sig = _get_torch_jit_trace_forward_signature(mod) 650 651 # Sanity check for placeholder names coming from TorchScript. 652 assert len(sig.parameters) == len(fake_args) + len(fake_kwargs), ( 653 "Arguments other than POSITIONAL_OR_KEYWORD kinds in forward() " 654 "are not supported in _get_torch_jit_trace_forward_signature" 655 ) 656 else: 657 sig = inspect.signature(mod.forward) 658 659 return sig.bind(*fake_args, **fake_kwargs).arguments 660 661 662def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: 663 """ 664 Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, 665 and handle collisions with non-placeholders by count suffixing. 666 Different HOO subgraph types have different input schemas, so we first enumerate them 667 and gather the top-level named placeholder nodes. 668 """ 669 # gather all HOO subgraphs and their top-level named placeholder nodes 670 subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = [] 671 for node in gm.graph.nodes: 672 if node.op == "call_function" and isinstance( 673 node.target, torch._ops.HigherOrderOperator 674 ): 675 # HOO subgraphs have varying input schemas, so we enumerate them there 676 if node.target._name == "cond": 677 _, true_graph, false_graph, cond_args = node._args 678 subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args)) 679 subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args)) 680 elif node.target._name == "wrap_with_set_grad_enabled": 681 subgraph, phs = node._args[1], node._args[2:] 682 subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs)) 683 elif node.target._name == "map_impl": 684 body_graph, array, args = node._args 685 subgraph_ph_tuples.append( 686 (getattr(gm, body_graph.target), array + args) 687 ) 688 689 # propagate names 690 for subgraph, hoo_phs in subgraph_ph_tuples: 691 name_map: Dict[str, str] = {} 692 for i, node in enumerate(subgraph.graph.nodes): 693 if i < len(hoo_phs): # placeholder, retain name 694 name_map[node.name] = hoo_phs[i].name 695 node.name = node.target = hoo_phs[i].name 696 else: # non-placeholder, check for collisions 697 node.name = _rename_without_collisions(name_map, node.name, node.name) 698 699 # recurse and recompile 700 _name_hoo_subgraph_placeholders(subgraph) 701 subgraph.recompile() 702 703 704def placeholder_naming_pass( 705 gm: torch.fx.GraphModule, 706 export_graph_signature: "ExportGraphSignature", 707 mod: torch.nn.Module, 708 fake_args, 709 fake_kwargs, 710 fake_params_buffers, 711 constants: Dict[str, Any], 712) -> None: 713 """ 714 This pass is run at the end of _export_non_strict() to assign better placeholder node names: 715 - User inputs: 716 These follow the signature of mod.forward(), e.g. forward(x, y) produces nodes x, y. 717 For nested inputs from dictionaries, lists, tuples, or dataclasses, 718 the names are a concatenation of the path to the tensor. 719 e.g. x = { 720 'a': torch.randn(), 721 'b': [torch.randn(), torch.randn()] 722 } 723 produces nodes x_a, x_b_0, x_b_1. 724 - Parameters/buffers/constants/custom objects: 725 These follow the FQN of the object, prefixed by "p", "b", "c", "obj" respectively. 726 e.g. self.bar.l0.weight produces "p_bar_l0_weight". 727 - Effect tokens: 728 These are named token, token_1, ... 729 """ 730 731 def _strip_name(x): 732 if x.startswith("L__self___"): 733 x = x[len("L__self___") :] 734 elif x.startswith("self_"): 735 x = x[len("self_") :] 736 x = re.sub(r"[^a-zA-Z0-9]", "_", x) 737 return x 738 739 def _extract_pytree_key(x): 740 if isinstance(x, MappingKey): 741 x = re.sub(r"[^a-zA-Z0-9]", "_", str(x.key)) 742 return x 743 elif isinstance(x, SequenceKey): 744 return str(x.idx) 745 elif isinstance(x, GetAttrKey): 746 return x.name 747 else: 748 raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}") 749 750 name_map: Dict[str, str] = {} 751 752 # map user input names with mod.forward() signature 753 combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs) 754 755 flat_args_with_path, _ = tree_flatten_with_path(combined_args) 756 user_input_names = [ 757 spec.arg.name 758 for spec in export_graph_signature.input_specs 759 if spec.kind == InputKind.USER_INPUT 760 ] 761 762 # use pytree path to name nested user inputs 763 for (arg_path, arg), user_input_name in zip(flat_args_with_path, user_input_names): 764 if user_input_name: 765 _rename_without_collisions( 766 name_map, 767 user_input_name, 768 placeholder_prefixes[InputKind.USER_INPUT] 769 + "_".join(_extract_pytree_key(x).lower() for x in arg_path), 770 is_placeholder=True, 771 ) 772 773 # use graph signature input specs to map param/buffer/constant names 774 # name effect tokens as token, token_1, ... (these aren't visible to user) 775 for spec in export_graph_signature.input_specs: 776 if spec.kind == InputKind.USER_INPUT: 777 continue 778 if spec.kind == InputKind.TOKEN: 779 base_name = "" 780 else: 781 base_name = _strip_name(spec.target).lower() 782 base_name = re.sub(r"[^a-zA-Z0-9]", "_", base_name) 783 784 _rename_without_collisions( 785 name_map, 786 spec.arg.name, 787 placeholder_prefixes[spec.kind] + base_name, 788 is_placeholder=True, 789 ) 790 791 # handle naming collisions with call_function/get_attr inputs. 792 # here, we want to prioritize user input names over call_function names 793 # e.g. not have forward(self, mul): lead to a placeholder node called mul_13, 794 # so we increment the suffix of call_function nodes as needed 795 for node in gm.graph.nodes: 796 if node.op == "placeholder": 797 continue 798 _rename_without_collisions(name_map, node.name, node.name) 799 800 # assign new node names 801 for node in gm.graph.nodes: 802 if node.op == "placeholder": 803 assert node.name in name_map 804 node.name = node.target = name_map[node.name] 805 elif node.name in name_map: 806 node.name = name_map[node.name] 807 808 # propagate names to higher order op subgraphs 809 _name_hoo_subgraph_placeholders(gm) 810 811 # re-generate graph module code 812 gm.recompile() 813 814 # modify graph signature (input specs, output specs, user input mutations) 815 for spec in export_graph_signature.input_specs: 816 assert spec.arg.name in name_map 817 spec.arg.name = name_map[spec.arg.name] 818 if ( # handle targets for custom objects 819 spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map 820 ): 821 spec.target = name_map[spec.target][4:] # strip obj_ prefix 822 823 for spec in export_graph_signature.output_specs: 824 if spec.arg.name in name_map: 825 spec.arg.name = name_map[spec.arg.name] 826 if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map: 827 spec.target = name_map[spec.target] 828 829 # rename keys in constants dict for custom objects 830 for name in list(constants.keys()): 831 constant = constants[name] 832 if name in name_map and not isinstance( 833 constant, torch.Tensor 834 ): # rename custom objects with generic names 835 new_name = name_map[name] 836 if ( 837 new_name != name 838 and re.match(r"arg(\d+)_1", name) 839 and new_name != placeholder_prefixes[InputKind.CUSTOM_OBJ] + name 840 ): 841 constants[new_name] = constant 842 del constants[name] 843 844 845def remove_proxy_from_state_dict(state_dict: Dict, in_place: bool) -> Dict: 846 """ 847 If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`. 848 `v` is the values in the dictionary. 849 If `in_place` is true, modify `state_dict` in place. 850 """ 851 if in_place: 852 for k, v in state_dict.items(): 853 if hasattr(v, "proxy"): 854 delattr(state_dict[k], "proxy") 855 return state_dict 856 else: 857 new_state_dict = {} 858 for k, v in state_dict.items(): 859 if hasattr(v, "proxy"): 860 new_state_dict[k] = v.clone().detach() 861 else: 862 new_state_dict[k] = v 863 return new_state_dict 864 865 866def _detect_fake_mode_from_gm( 867 gm: torch.fx.GraphModule, 868) -> torch._subclasses.fake_tensor.FakeTensorMode: 869 """ 870 For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs. 871 Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes. 872 If no fake mode is found, we return None for fake_mode. 873 """ 874 875 fake_inps: List[torch.Tensor] = [] 876 fake_vals: List[torch.Tensor] = [] 877 for node in gm.graph.nodes: 878 if node.op == "placeholder" and "val" in node.meta: 879 fake_val = node.meta["val"] 880 if fake_val is not None and isinstance(fake_val, torch.Tensor): 881 fake_inps.append(fake_val) 882 elif len(fake_inps) == 0 and ( 883 "example_value" in node.meta or "val" in node.meta 884 ): 885 fake_val = None 886 if "example_value" in node.meta: 887 fake_val = node.meta["example_value"] 888 elif "val" in node.meta: 889 fake_val = node.meta["val"] 890 if fake_val is not None and isinstance(fake_val, torch.Tensor): 891 fake_vals.append(fake_val) 892 893 return detect_fake_mode(fake_inps + fake_vals) 894