1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3import dataclasses 4import functools 5import inspect 6import logging 7import re 8import time 9import warnings 10from contextlib import contextmanager, nullcontext 11from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union 12 13import torch 14import torch._dynamo 15import torch.fx 16import torch.utils._pytree as pytree 17from torch._dispatch.python import enable_python_dispatcher 18from torch._dynamo.exc import UserError, UserErrorType 19from torch._export.db.logging import ( 20 exportdb_error_message, 21 get_class_if_classified_error, 22) 23from torch._export.non_strict_utils import ( 24 _fakify_script_objects, 25 _gather_constant_attrs, 26 _NonStrictTorchFunctionHandler, 27 make_constraints, 28 make_fake_inputs, 29 produce_guards_and_solve_constraints, 30) 31from torch._export.passes._node_metadata_hook import ( 32 _node_metadata_hook, 33 _set_node_metadata_hook, 34) 35from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass 36from torch._export.passes.lift_constants_pass import ( 37 ConstantAttrMap, 38 lift_constants_pass, 39 rewrite_script_object_meta, 40) 41from torch._export.utils import ( 42 _collect_param_buffer_metadata, 43 _get_shape_env_from_gm, 44 _populate_param_buffer_metadata_to_new_gm, 45 placeholder_naming_pass, 46 placeholder_prefixes, 47) 48from torch._export.verifier import SpecViolationError 49from torch._export.wrappers import _wrap_submodules 50from torch._functorch._aot_autograd.input_output_analysis import ( 51 _graph_input_names, 52 _graph_output_names, 53) 54from torch._functorch._aot_autograd.traced_function_transforms import ( 55 create_functional_call, 56) 57from torch._functorch._aot_autograd.utils import create_tree_flattened_fn 58from torch._functorch.aot_autograd import aot_export_module 59from torch._guards import detect_fake_mode 60from torch._library.fake_class_registry import FakeScriptObject 61from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode 62from torch._utils_internal import log_export_usage 63from torch.export.dynamic_shapes import ( 64 _check_dynamic_shapes, 65 _combine_args, 66 _transform_shapes_for_default_dynamic, 67) 68from torch.export.exported_program import OutputKind 69from torch.fx._utils import first_call_function_nn_module_stack 70from torch.fx.experimental.proxy_tensor import make_fx 71from torch.fx.experimental.symbolic_shapes import ( 72 ConstraintViolationError, 73 free_unbacked_symbols, 74 GuardOnDataDependentSymNode, 75 ShapeEnv, 76) 77from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo 78from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts 79from torch.utils._pytree import TreeSpec 80from torch.utils._sympy.value_ranges import ValueRangeError 81 82from ._safeguard import AutogradStateOpsFailSafeguard 83from .exported_program import ( 84 _disable_prexisiting_fake_mode, 85 ExportedProgram, 86 InputKind, 87 ModuleCallEntry, 88 ModuleCallSignature, 89) 90from .graph_signature import _convert_to_export_graph_signature, ExportGraphSignature 91 92 93log = logging.getLogger(__name__) 94 95 96@dataclasses.dataclass 97class ExportDynamoConfig: 98 """ 99 Manage Export-specific configurations of Dynamo. 100 """ 101 102 allow_rnn: bool = True 103 reorderable_logging_functions: Set[Callable] = dataclasses.field( 104 default_factory=set 105 ) 106 # Emit runtime asserts after AOTAutograd instead. 
    # This isn't really necessary, and isn't much more efficient since the runtime asserts pass does CSE,
    # but if we want to reason more about what guards/runtime asserts to emit,
    # this makes it a bit cleaner to do from the export side. Also no real point in running this twice.
    do_not_emit_runtime_asserts = True


@dataclasses.dataclass
class ATenExportArtifact:
    gm: torch.fx.GraphModule
    sig: ExportGraphSignature
    constants: Dict[
        str,
        Union[
            torch.Tensor,
            FakeScriptObject,
            torch.ScriptObject,
        ],
    ]


@dataclasses.dataclass(frozen=True)
class ExportArtifact:
    aten: ATenExportArtifact
    out_spec: TreeSpec
    fake_mode: FakeTensorMode
    module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]]


DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = {
    logging.critical,
    logging.debug,
    logging.error,
    logging.exception,
    logging.info,
    logging.log,
    logging.warning,
    print,
    warnings.warn,
}


@contextmanager
def _ignore_backend_decomps():
    orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False)
    orig_nnpack_flag = torch.backends.nnpack.set_flags(False)
    try:
        yield
    finally:
        torch.backends.mkldnn.set_flags(*orig_mkldnn_flag)
        torch.backends.nnpack.set_flags(*orig_nnpack_flag)


def _fixup_key(x):
    return "L__self__" + _strip_root(x)


def _strip_root(x):
    if isinstance(x, str) and x.startswith("_export_root"):
        stripped = x[len("_export_root") :]
        return stripped[1:] if stripped.startswith(".") else stripped
    return x


def _rewrite_tracepoint_node(gm: torch.fx.GraphModule):
    """
    In-place modify the input graph module by replacing each export tracepoint with a new
    node that has the same target and args, but with _export_root stripped from the path.
    """
    for node in gm.graph.nodes:
        if node.target == torch.ops.higher_order._export_tracepoint:
            if "path" in node.kwargs:
                path = _strip_root(node.kwargs["path"])
                with gm.graph.inserting_before(node):
                    new_node = gm.graph.create_node(
                        "call_function",
                        torch.ops.higher_order._export_tracepoint,
                        args=node.args,
                        kwargs={
                            "path": path,
                            "kind": node.kwargs["kind"],
                        },
                    )
                    new_node.meta = node.meta
                    node.replace_all_uses_with(new_node)
                    gm.graph.erase_node(node)


def _extract_fake_inputs(gm, args, kwargs):
    """
    Given a graph module, extract fakified input tensors from the metadata of
    its placeholders, and map them to the structure of given args and kwargs.
    Also return the fake mode used to fakify those inputs.
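    For example (illustrative): if ``args = (x, [y])`` where ``x`` and ``y`` are
    tensors, the returned ``fake_args`` keeps the same ``(x, [y])`` nesting, with
    each tensor leaf replaced by the fake tensor recorded on the corresponding
    placeholder node.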
200 """ 201 202 fake_inps: List[torch.Tensor] = [] 203 fake_vals: List[torch.Tensor] = [] 204 for node in gm.graph.nodes: 205 if node.op == "placeholder" and "val" in node.meta: 206 fake_val = node.meta["val"] 207 if fake_val is not None and isinstance(fake_val, torch.Tensor): 208 fake_inps.append(fake_val) 209 elif "example_value" in node.meta: 210 fake_val = node.meta["example_value"] 211 if fake_val is not None and isinstance(fake_val, torch.Tensor): 212 fake_vals.append(fake_val) 213 214 if detected_fake_mode := detect_fake_mode(fake_inps + fake_vals): 215 fake_mode = detected_fake_mode 216 else: 217 fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True) 218 219 count = 0 220 221 def lookup_fake(x): 222 nonlocal count 223 val = fake_inps[count] 224 count += 1 225 return val 226 227 fake_args = pytree.tree_map_only(torch.Tensor, lookup_fake, args) 228 fake_kwargs = pytree.tree_map_only(torch.Tensor, lookup_fake, kwargs) 229 230 return fake_args, fake_kwargs, fake_mode 231 232 233def _replace_param_buffer_names(param_buffer_table, sig): 234 for spec in sig.input_specs: 235 if spec.kind in ( 236 InputKind.PARAMETER, 237 InputKind.BUFFER, 238 ): 239 spec.target = param_buffer_table[spec.target] 240 for spec in sig.output_specs: 241 if spec.kind in ( 242 OutputKind.BUFFER_MUTATION, 243 OutputKind.GRADIENT_TO_PARAMETER, 244 ): 245 spec.target = param_buffer_table[spec.target] 246 247 248def _convert_to_positional_args(orig_arg_names, args, kwargs): 249 assert len(orig_arg_names) == len(args) + len(kwargs), ( 250 f"Total number of arg names is expected to be {len(orig_arg_names)} " 251 f"but got {len(args)} positional args, {len(kwargs)} kwargs." 252 ) 253 reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]] 254 return ( 255 *args, 256 *reordered_kwargs, 257 ) 258 259 260def _normalize_nn_module_stack(gm_torch_level, root_cls): 261 # Append a root module to every nn_module_stack. 262 root = "L['self']" 263 root_key = re.sub(r"[^a-zA-Z0-9]", "_", root) 264 for gm in gm_torch_level.modules(): 265 if not isinstance(gm, torch.fx.GraphModule): 266 continue 267 for node in gm.graph.nodes: 268 if node.op in ["placeholder", "output"]: 269 continue 270 add_root = True 271 if nn_module_stack := node.meta.get("nn_module_stack", {}): 272 path, ty = next(iter(nn_module_stack.values())) 273 # After deserializing the class `ty` might not exist anymore so 274 # it could be a string 275 if inspect.isclass(ty) and issubclass(ty, torch.nn.Module): 276 # TODO Figure out why sometimes we have root sometimes we don't. 277 if path == root and ty is root_cls: 278 add_root = False 279 else: 280 assert isinstance(ty, str) 281 if add_root: 282 283 def normalize_path(path): 284 try: 285 parts = [] 286 287 class Path: 288 def __getattr__(self, name): 289 parts.append(name) 290 return self 291 292 def __getitem__(self, idx): 293 parts.append(str(idx)) 294 return self 295 296 eval(path, {"L": {"self": Path()}}) 297 return ".".join(parts) 298 except Exception: # TODO(zhxchen17) Remove this. 299 return path 300 301 nn_module_stack = { 302 root_key: (root, root_cls.__module__ + "." 
+ root_cls.__qualname__), 303 **nn_module_stack, 304 } 305 node.meta["nn_module_stack"] = { 306 key: (normalize_path(path), ty) 307 for key, (path, ty) in nn_module_stack.items() 308 } 309 310 311def _get_param_buffer_mapping( 312 original_module: torch.nn.Module, 313 traced_module: torch.nn.Module, 314) -> Dict[str, str]: 315 """ 316 Returns a mapping of parameter/buffer names from the new module to the 317 original model. This is to help with restoring the FQN for parameter/buffers 318 of a traced module to what the original module contains. 319 """ 320 321 param_lookup: Dict[int, str] = {} 322 buffer_lookup: Dict[int, str] = {} 323 for name, param in original_module.named_parameters(remove_duplicate=False): 324 param_lookup[id(param)] = name 325 for name, buffer in original_module.named_buffers(remove_duplicate=False): 326 buffer_lookup[id(buffer)] = name 327 328 param_buffer_table: Dict[str, str] = {} 329 for dynamo_name, dynamo_param in traced_module.named_parameters( 330 remove_duplicate=False 331 ): 332 assert dynamo_name not in param_buffer_table 333 if id(dynamo_param) in param_lookup: 334 param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)] 335 336 for dynamo_name, dynamo_buffer in traced_module.named_buffers( 337 remove_duplicate=False 338 ): 339 assert dynamo_name not in param_buffer_table 340 if id(dynamo_buffer) in buffer_lookup: 341 param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)] 342 343 return param_buffer_table 344 345 346def _preserve_requires_grad_pass( 347 gm: torch.fx.GraphModule, 348 sig: ExportGraphSignature, 349 fake_params_buffers: Dict[str, torch.Tensor], 350 constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], 351 flat_fake_args: List[Any], 352): 353 placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] 354 assert len(sig.input_specs) == len(placeholders) 355 i = 0 356 for node, spec in zip(placeholders, sig.input_specs): 357 if spec.kind in ( 358 InputKind.PARAMETER, 359 InputKind.BUFFER, 360 ): 361 assert spec.target is not None 362 node.meta["val"].requires_grad = fake_params_buffers[ 363 spec.target 364 ].requires_grad 365 elif spec.kind == InputKind.USER_INPUT: 366 fake_arg = flat_fake_args[i] 367 if isinstance(fake_arg, torch.Tensor): 368 node.meta["val"].requires_grad = fake_arg.requires_grad 369 i += 1 370 elif spec.kind == InputKind.CONSTANT_TENSOR: 371 assert spec.target is not None 372 constant = constants[spec.target] 373 if isinstance(constant, torch.Tensor): 374 # If the tensor is not leaf, it should already have a correct requires grad field 375 if node.meta["val"].is_leaf: 376 node.meta["val"].requires_grad = constant.requires_grad 377 else: 378 assert node.meta["val"].requires_grad == constant.requires_grad 379 elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN): 380 continue 381 else: 382 raise AssertionError(spec.kind) 383 384 385def _remap_constants( 386 orig_constant_attrs: ConstantAttrMap, 387 graph_signature: ExportGraphSignature, 388 constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], 389) -> None: 390 """Rewrite the graph signature and constants table to use the FQN from the original module.""" 391 remap_table: Dict[str, List[str]] = {} 392 for name, value in constants.items(): 393 if value in orig_constant_attrs: 394 remap_table[name] = orig_constant_attrs[value] 395 396 for spec in graph_signature.input_specs: 397 if spec.kind in ( 398 InputKind.CONSTANT_TENSOR, 399 InputKind.CUSTOM_OBJ, 400 ): 401 orig_target = 
spec.target
            assert orig_target is not None
            targets = remap_table.get(orig_target, [orig_target])
            spec.target = targets[0]

            constant = constants[orig_target]
            del constants[orig_target]
            for target in targets:
                constants[target] = constant


def _rename_constants_nodes(
    gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
) -> None:
    """
    For strict mode, rename constants nodes that were previously annotated as buffers.
    """
    # handle name collisions with existing constants
    node_names = {node.name for node in gm.graph.nodes}

    def rename_constant(name):
        if name in node_names:
            n = 1
            while (dup_name := f"{name}_{n}") in node_names:
                n += 1
            name = dup_name
        node_names.add(name)
        return name

    # use input specs to map names from buffers to constants
    buffer_prefix = placeholder_prefixes[InputKind.BUFFER]
    const_prefix = placeholder_prefixes[InputKind.CONSTANT_TENSOR]
    buffer_to_constant = {}
    for spec in graph_signature.input_specs:
        if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
            const_prefix
        ):
            if spec.arg.name.startswith(buffer_prefix):  # map from buffer to constants
                c_name = rename_constant(
                    const_prefix + spec.arg.name[len(buffer_prefix) :]
                )
            else:  # lifted constant
                c_name = rename_constant(const_prefix + spec.arg.name)
            buffer_to_constant[spec.arg.name] = c_name
            spec.arg.name = c_name
    for spec in graph_signature.output_specs:
        if spec.arg.name in buffer_to_constant:
            spec.arg.name = buffer_to_constant[spec.arg.name]

    # Rename constants nodes for all modules
    for mod in gm.modules():
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            if node.name in buffer_to_constant:
                node.name = node.target = buffer_to_constant[node.name]
        mod.recompile()


def _restore_state_dict(
    original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
) -> None:
    """
    Restores the state dict of the traced module to that of the original module.
    """
    param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
    # Since the graph module is flattened (no module hierarchy), we
    # need to normalize the FQNs by replacing "." with "_". If we
    # don't, it will try to save the weight to a submodule which no
    # longer exists.
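    # For example (illustrative): an original FQN such as "linear.weight" is stored
    # on the flat traced module as "linear_weight".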
    for name, fqn in param_buffer_table.items():
        param_buffer_table[name] = fqn.replace(".", "_")

    # Replace state dict attr names with the fqn
    for name, fqn in param_buffer_table.items():
        if not hasattr(traced_module, name):
            continue

        attr = getattr(traced_module, name)
        if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter):
            traced_module.register_buffer(fqn, attr)
        else:
            setattr(traced_module, fqn, attr)
        delattr(traced_module, name)

    # Replace graph getattr nodes with the correct name
    for node in traced_module.graph.nodes:
        if node.op == "get_attr":
            attr_name = node.target
            if attr_name in param_buffer_table:
                node.target = param_buffer_table[attr_name]

    traced_module.recompile()


def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]:
    return {
        name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False)
    }


def _make_module_call_graph(
    module_hierarchy: Dict[str, str],
    in_spec: TreeSpec,
    out_spec: TreeSpec,
    module_call_signatures: Dict[str, ModuleCallSignature],
) -> List[ModuleCallEntry]:
    ret = [
        ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
        for fqn in module_hierarchy
    ]
    assert ret[0].fqn == ""
    ret[0].signature = ModuleCallSignature(
        inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
    )
    return ret


def _export_to_torch_ir(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    *,
    preserve_module_call_signature: Tuple[str, ...] = (),
    disable_constraint_solver: bool = False,
    allow_complex_guards_as_runtime_asserts: bool = False,
    restore_fqn: bool = True,
    _log_export_usage: bool = True,
    same_signature: bool = True,
) -> torch.fx.GraphModule:
    """
    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside and produces a torch.fx.GraphModule in torch IR.
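
    A minimal usage sketch (module and example inputs are illustrative)::

        gm = _export_to_torch_ir(MyModule(), (torch.randn(2, 3),))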
    """

    if _log_export_usage:
        log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})

    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
        )

    kwargs = kwargs or {}
    combined_args = _combine_args(f, args, kwargs)
    _check_dynamic_shapes(combined_args, dynamic_shapes)
    transformed_dynamic_shapes = _transform_shapes_for_default_dynamic(
        combined_args, dynamic_shapes
    )

    with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
        try:
            module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
            with _wrap_submodules(
                f, preserve_module_call_signature, module_call_specs
            ), _ignore_backend_decomps():
                gm_torch_level, _ = torch._dynamo.export(
                    f,
                    dynamic_shapes=transformed_dynamic_shapes,  # type: ignore[arg-type]
                    tracing_mode="symbolic",
                    disable_constraint_solver=disable_constraint_solver,
                    # currently the following 2 flags are tied together for export purposes,
                    # but untangle for sake of dynamo export api
                    prefer_deferred_runtime_asserts_over_guards=True,
                    allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
                    _log_export_usage=_log_export_usage,
                    same_signature=same_signature,
                )(
                    *args,
                    **kwargs,
                )
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
        except GuardOnDataDependentSymNode as e:
            raise UserError(  # noqa: B904
                UserErrorType.ANTI_PATTERN,
                f"Consider annotating your code using torch._check*(). {str(e)}",
                case_name="constrain_as_size_example",
            )

    gm_torch_level.meta["module_call_specs"] = module_call_specs

    if isinstance(f, torch.nn.Module) and restore_fqn:
        _restore_state_dict(f, gm_torch_level)

    return gm_torch_level


def _export_to_aten_ir(
    mod: torch.nn.Module,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    constant_attrs: ConstantAttrMap,
    produce_guards_callback=None,
    *,
    transform=lambda x: x,  # TODO(zhxchen17) Revisit if this is needed later.
    pre_dispatch=False,
    decomp_table=None,
    _check_autograd_state=True,
    _is_torch_jit_trace=False,
) -> ATenExportArtifact:
    # [NOTE] If the user is exporting under training mode, we want to detect if there is any
    # state change in the autograd global state and error. If the user is exporting under inference
    # mode, we don't care. At predispatch level, we don't care about the state change.
    is_grad_enabled = torch._C.is_grad_enabled()
    grad_safe_guard = nullcontext()
    # _export_to_aten_ir is called when we decompose the ep into inference IR.
    # In that setting, we shouldn't check for state changes at this point,
    # because the intention is specializing to inference.
614 if _check_autograd_state: 615 if not pre_dispatch and is_grad_enabled: 616 grad_safe_guard = AutogradStateOpsFailSafeguard() # type: ignore[assignment] 617 618 @contextmanager 619 def _compiling_state_context(): 620 old_value = torch.compiler._is_compiling_flag 621 try: 622 torch.compiler._is_compiling_flag = True 623 yield 624 finally: 625 torch.compiler._is_compiling_flag = old_value 626 627 # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, 628 # otherwise aot_export_module will error out because it sees a mix of fake_modes. 629 # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. 630 with torch.nn.utils.stateless._reparametrize_module( 631 mod, 632 fake_params_buffers, 633 tie_weights=True, 634 strict=True, 635 stack_weights=True, 636 ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined] 637 gm, graph_signature = transform(aot_export_module)( 638 mod, 639 fake_args, 640 trace_joint=False, 641 pre_dispatch=pre_dispatch, 642 decompositions=decomp_table, 643 kwargs=fake_kwargs, 644 ) 645 646 def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm): 647 if isinstance(old_gm, torch.fx.GraphModule): 648 if hasattr(old_gm, "meta"): 649 new_gm.meta.update(old_gm.meta) 650 old_output_node = list(old_gm.graph.nodes)[-1] 651 new_output_node = list(new_gm.graph.nodes)[-1] 652 assert old_output_node.op == "output" and new_output_node.op == "output" 653 # make sure we don't override any meta 654 assert len(new_output_node.meta) == 0 655 new_output_node.meta.update(old_output_node.meta) 656 657 # TODO unfortunately preserving graph-level metadata and output node's meta 658 # is not working well with aot_export. So we manually copy it. 659 # (The node-level meta is addressed above.) 660 _maybe_fixup_gm_and_output_node_meta(mod, gm) 661 662 # Run produce guards before we handle runtime asserts. 663 # This means we run the export solver before the runtime asserts pass. 664 # Right now this doesn't mean much - the export solver is only there for suggested fixes, 665 # and we won't even get to constraint solving if that's needed. 666 # But if in future we want to control what runtime asserts are emitted for export, 667 # or rely on produce_guards + solver for some simplification on runtime asserts, this probably makes sense. 668 if produce_guards_callback: 669 try: 670 produce_guards_callback(gm) 671 except (ConstraintViolationError, ValueRangeError) as e: 672 raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 673 674 # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature. 675 # Overwrite output specs afterwards. 
676 flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs)) 677 if not torch._dynamo.config.do_not_emit_runtime_asserts: 678 stack_trace = ( 679 'File "torch/fx/passes/runtime_assert.py", line 24, ' 680 "in insert_deferred_runtime_asserts" 681 ) 682 with _set_node_metadata_hook( 683 gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) 684 ): 685 shape_env = _get_shape_env_from_gm(gm) 686 if shape_env: 687 insert_deferred_runtime_asserts( 688 gm, 689 shape_env, 690 f"exported program: {first_call_function_nn_module_stack(gm.graph)}", 691 export=True, 692 ) 693 694 # update output specs 695 gm.recompile() 696 graph_signature.user_outputs = _graph_output_names(gm) 697 698 # NOTE: aot_export adds symint metadata for placeholders with int values; 699 # since these become specialized, we replace such metadata with the original values 700 index = 0 701 total_non_user_inputs = ( 702 len(graph_signature.parameters) 703 + len(graph_signature.buffers) 704 + len(graph_signature.input_tokens) 705 ) 706 for node in gm.graph.nodes: 707 if node.op == "placeholder": 708 if index >= total_non_user_inputs: 709 user_arg = flat_fake_args[index - total_non_user_inputs] 710 if not isinstance(user_arg, torch.Tensor): 711 node.meta["val"] = user_arg 712 index += 1 713 714 export_graph_signature = _convert_to_export_graph_signature( 715 graph_signature, gm, _get_non_persistent_buffers(mod) 716 ) 717 718 constants = rewrite_script_object_meta(gm) 719 constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) 720 721 if pre_dispatch: 722 from torch._export.passes.replace_autocast_with_hop_pass import ( 723 replace_autocast_with_hop_pass, 724 ) 725 from torch._export.passes.replace_set_grad_with_hop_pass import ( 726 replace_set_grad_with_hop_pass, 727 ) 728 729 # Note: replace_set_grad_with_hop_pass need to be after lift_constant_pass because 730 # a getattr of a constant tensor doesn't have meta["val"] until after lift_constant_pass. 731 # If replace_set_grad_with_hop_pass is before lift_constant_pass, 732 # and the constant_tensor is passed as input of the set grad hop, the placeholder's 733 # meta["val"] will be None and fails our verifier for placeholder. 734 gm, export_graph_signature = replace_set_grad_with_hop_pass( 735 gm, export_graph_signature 736 ) 737 738 gm, export_graph_signature = replace_autocast_with_hop_pass( 739 gm, export_graph_signature 740 ) 741 742 # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes. 743 for _mod in gm.modules(): 744 if not isinstance(_mod, torch.fx.GraphModule): 745 continue 746 for node in _mod.graph.nodes: 747 if node.op in ["placeholder", "output"]: 748 node.meta.pop("nn_module_stack", None) 749 node.meta.pop("stack_trace", None) 750 751 # Prettify names for placeholder nodes. 
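    # (Illustrative: parameter/buffer/constant placeholders get kind-specific prefixes
    # from `placeholder_prefixes`, e.g. a parameter "linear.weight" typically becomes
    # a placeholder named "p_linear_weight".)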
752 placeholder_naming_pass( 753 gm, 754 export_graph_signature, 755 mod, 756 fake_args, 757 fake_kwargs, 758 fake_params_buffers, 759 constants, 760 ) 761 762 _preserve_requires_grad_pass( 763 gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args 764 ) 765 766 return ATenExportArtifact( 767 gm, 768 export_graph_signature, 769 constants, 770 ) 771 772 773def _fakify_params_buffers( 774 fake_mode: FakeTensorMode, 775 mod: torch.nn.Module, 776) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]: 777 params_buffers = { 778 **dict(mod.named_parameters(remove_duplicate=False)), 779 **dict(mod.named_buffers(remove_duplicate=False)), 780 } 781 782 faked_params_buffers = {} 783 memo: Dict[int, FakeTensor] = {} 784 for key, value in params_buffers.items(): 785 if id(value) in memo: 786 fake_tensor = memo[id(value)] 787 else: 788 fake_tensor = fake_mode.from_tensor(value, static_shapes=True) 789 memo[id(value)] = fake_tensor 790 faked_params_buffers[key] = fake_tensor 791 return faked_params_buffers # type: ignore[return-value] 792 793 794def _get_forward_arg_names( 795 mod: torch.nn.Module, 796 args: Tuple[Any, ...], 797 kwargs: Optional[Dict[str, Any]] = None, 798) -> List[str]: 799 """ 800 Gets the argument names to forward that are used, for restoring the 801 original signature when unlifting the exported program module. 802 - Positional args: retain the original argument names, and enumerate 803 *args as args_0, args_1, ... 804 - Keyword args: retain the original kwarg names in the order specified 805 by the user. This order seems to matter for the current state of 806 export lifted modules. 807 """ 808 sig = inspect.signature(mod.forward) 809 _args = sig.bind_partial(*args).arguments 810 811 names: List[str] = [] 812 for name, value in _args.items(): 813 # handle variable number of positional args 814 if sig.parameters[name].kind == inspect._ParameterKind.VAR_POSITIONAL: 815 names.extend([f"{name}_{i}" for i, _ in enumerate(value)]) 816 else: 817 names.append(name) 818 # order of kwargs matters for input spec 819 if kwargs: 820 names.extend([kwarg for kwarg, _ in kwargs.items()]) 821 822 return names 823 824 825def _get_non_persistent_buffers(mod: torch.nn.Module) -> Set[str]: 826 """ 827 Returns set of non-persistent buffers in a module and its submodules. 828 """ 829 result = set() 830 for name, m in mod.named_modules(): 831 for b in m._non_persistent_buffers_set: 832 result.add(f"{name}.{b}" if name else b) 833 return result 834 835 836def _rewrite_dynamo_tensor_constants( 837 orig_mod_buffers: Set[torch.Tensor], 838 traced_mod_buffers: Dict[str, torch.Tensor], 839 graph_signature: ExportGraphSignature, 840 constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], 841): 842 """ 843 Dynamo erroneously marks tensor attributes on modules as buffers. 844 Rewrite them to be tensor constants. 845 """ 846 for spec in graph_signature.input_specs: 847 if spec.kind == InputKind.BUFFER: 848 assert spec.target is not None 849 value = traced_mod_buffers[spec.target] 850 if value not in orig_mod_buffers: 851 # This was a tensor constant erroneously marked as a buffer. 852 # Convert it into a constant in the graph signature, and add its 853 # value to the constants table. 
                spec.kind = InputKind.CONSTANT_TENSOR
                constants[spec.target] = value  # type: ignore[arg-type]


def _move_non_persistent_buffers_to_tensor_constants(
    orig_mod: torch.nn.Module,
    graph_signature: ExportGraphSignature,
    constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
):
    """
    Moves non-persistent buffers to tensor constants.
    """
    for spec in graph_signature.input_specs:
        if spec.kind == InputKind.BUFFER and not spec.persistent:
            assert spec.target is not None
            assert spec.target not in constants
            constants[spec.target] = orig_mod.get_buffer(spec.target)  # type: ignore[arg-type]


def _verify_nn_module_stack(graph_module: torch.fx.GraphModule) -> None:
    """
    Perform nn_module_stack checks on the graph.
    Current constraints:
        For the top level graph:
        - populated for 'call_function', 'get_attr'
        - None for 'placeholder', 'output'
        For submodule graphs:
        - None for 'placeholder', 'output'

    TODO(pianpwk): make this a consistent node-level check once nn_module_stack is populated for cond submodules.
    """
    # Check top-level graph for all nodes, all graphs for placeholder & output nodes
    for i, mod in enumerate([graph_module] + list(graph_module.modules())):
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            if node.op in ["call_function", "get_attr"]:
                if i == 0:
                    if (
                        nn_module_stack := node.meta.get("nn_module_stack", None)
                    ) is None:
                        raise SpecViolationError(
                            f"Node {node} of type {node.op} is missing nn_module_stack metadata"
                        )
                    if not all(
                        isinstance(k, str)
                        and isinstance(v, tuple)
                        and len(v) == 2
                        and all(isinstance(x, str) for x in v)
                        for k, v in nn_module_stack.items()
                    ):
                        raise SpecViolationError(
                            f"Node {node} of type {node.op} has incorrect nn_module_stack metadata format, "
                            f"expected Dict[str, Tuple[str, str]], but got {nn_module_stack}"
                        )
            elif node.op in ["placeholder", "output"]:
                if node.meta.get("nn_module_stack", None):
                    raise SpecViolationError(
                        f"Node {node} of type {node.op} contains nn_module_stack metadata, this should be None"
                    )


def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None:
    """
    Perform stack trace checks on the graph.
    Constraints:
        - None or non-empty str for 'call_function', 'get_attr'
        - None for 'placeholder', 'output'
    """
    for i, mod in enumerate([graph_module] + list(graph_module.modules())):
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            stack_trace = node.meta.get("stack_trace", None)
            if node.op in ["call_function", "get_attr"]:
                if not (stack_trace is None or isinstance(stack_trace, str)):
                    raise SpecViolationError(
                        f"Node {node} of type {node.op} has invalid stack_trace metadata, "
                        f"expected a string or None but instead found: {stack_trace}"
                    )
            elif node.op in ["placeholder", "output"]:
                if stack_trace:
                    raise SpecViolationError(
                        f"Node {node} of type {node.op} contains stack_trace metadata, "
                        f"expected None but instead found: {stack_trace}"
                    )


def _verify_placeholder_names(gm: torch.fx.GraphModule, sig: ExportGraphSignature):
    """
    Performs a sanity check on the placeholder node names.
    - User input nodes: no restrictions, should match the original forward() signature
    - Params/buffers/constants/custom_obj/token nodes: should start with prefixes defined in <placeholder_prefixes>
    """
    name_to_kind = {spec.arg.name: spec.kind for spec in sig.input_specs}
    for mod in gm.modules():
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            if node.op == "placeholder":
                if node.name not in name_to_kind:
                    continue
                node_kind = name_to_kind[node.name]
                prefix = placeholder_prefixes[node_kind]
                if not node.name.startswith(prefix):
                    raise SpecViolationError(
                        f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}"
                    )


def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
    op_count = 0
    op_set = set()
    for m in ep.graph_module.modules():
        if not isinstance(m, torch.fx.GraphModule):
            continue
        for node in m.graph.nodes:
            if node.op != "call_function":
                continue
            op_count += 1
            assert hasattr(node.target, "__module__")
            assert hasattr(node.target, "__name__")
            op_set.add(f"{node.target.__module__}.{node.target.__name__}")
    return {"op_count": op_count, "op_set": op_set}


_EXPORT_FLAGS: Optional[Set[str]] = None
_EXPORT_MODULE_HIERARCHY: Optional[Dict[str, str]] = None


def _log_export_wrapper(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
        try:
            start = time.time()
            ep = fn(*args, **kwargs)
            end = time.time()
            log_export_usage(
                event="export.time",
                metrics=end - start,
                flags=_EXPORT_FLAGS,
                **get_ep_stats(ep),
            )
        except Exception as e:
            t = type(e)
            error_type = t.__module__ + "."
+ t.__qualname__ 1001 case_name = get_class_if_classified_error(e) 1002 if case_name is not None: 1003 log.error(exportdb_error_message(case_name)) 1004 log_export_usage( 1005 event="export.error.classified", 1006 type=error_type, 1007 message=str(e), 1008 flags=_EXPORT_FLAGS, 1009 ) 1010 else: 1011 log_export_usage( 1012 event="export.error.unclassified", 1013 type=error_type, 1014 message=str(e), 1015 flags=_EXPORT_FLAGS, 1016 ) 1017 raise e 1018 finally: 1019 _EXPORT_FLAGS = None 1020 _EXPORT_MODULE_HIERARCHY = None 1021 1022 return ep 1023 1024 return wrapper 1025 1026 1027def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs): 1028 if not isinstance(example_inputs, (tuple, list, dict)): 1029 example_inputs = (example_inputs,) 1030 1031 elif isinstance(example_inputs, list): 1032 example_inputs = tuple(example_inputs) 1033 1034 elif ( 1035 isinstance(example_inputs, (torch.Tensor, dict)) 1036 and example_kwarg_inputs is None 1037 ): 1038 example_inputs = (example_inputs,) 1039 1040 if example_kwarg_inputs is None: 1041 example_kwarg_inputs = {} 1042 return example_inputs, example_kwarg_inputs 1043 1044 1045def _process_export_inputs(mod, args, kwargs, dynamic_shapes): 1046 original_state_dict = mod.state_dict(keep_vars=True) 1047 1048 if not isinstance(args, tuple): 1049 raise UserError( 1050 UserErrorType.INVALID_INPUT, 1051 f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}", 1052 ) 1053 kwargs = kwargs if kwargs is not None else {} 1054 _, original_in_spec = pytree.tree_flatten((args, kwargs)) 1055 1056 if isinstance(dynamic_shapes, torch.export.ShapesCollection): 1057 dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) 1058 1059 return args, kwargs, original_in_spec, original_state_dict, dynamic_shapes 1060 1061 1062def _get_module_call_graph( 1063 export_artifact: ExportArtifact, 1064 original_in_spec: TreeSpec, 1065 preserve_module_call_signature: Tuple[str, ...], 1066 strict_mode_export: bool, 1067): 1068 """ 1069 In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and 1070 return module_call_graph. 1071 """ 1072 gm: torch.fx.GraphModule = export_artifact.aten.gm 1073 export_graph_signature: ExportGraphSignature = export_artifact.aten.sig 1074 module_call_specs: Dict[ 1075 str, Dict[str, TreeSpec] 1076 ] = export_artifact.module_call_specs 1077 out_spec: TreeSpec = export_artifact.out_spec 1078 1079 # Make module signatures. 
1080 module_call_signatures = {} 1081 for fqn, specs in module_call_specs.items(): 1082 mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn 1083 module_call_signatures[mod_fqn] = ModuleCallSignature( 1084 inputs=[], outputs=[], **specs 1085 ) 1086 1087 if len(preserve_module_call_signature) > 0: 1088 if not strict_mode_export: 1089 _rewrite_tracepoint_node(gm) 1090 res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm) 1091 assert res is not None 1092 gm = res.graph_module 1093 1094 assert _EXPORT_MODULE_HIERARCHY is not None 1095 module_call_graph = _make_module_call_graph( 1096 _EXPORT_MODULE_HIERARCHY, 1097 original_in_spec, 1098 out_spec, 1099 module_call_signatures, 1100 ) 1101 return gm, module_call_graph 1102 1103 1104def _get_range_constraints( 1105 export_artifact: ExportArtifact, combined_args: Dict[str, Any], dynamic_shapes 1106): 1107 gm: torch.fx.GraphModule = export_artifact.aten.gm 1108 export_graph_signature: ExportGraphSignature = export_artifact.aten.sig 1109 fake_mode: FakeTensorMode = export_artifact.fake_mode 1110 num_lifted = next( 1111 ( 1112 i 1113 for i, s in enumerate(export_graph_signature.input_specs) 1114 if s.kind == InputKind.USER_INPUT 1115 ), 1116 len(export_graph_signature.input_specs), 1117 ) 1118 range_constraints = make_constraints( 1119 fake_mode, 1120 gm, 1121 combined_args, 1122 dynamic_shapes, 1123 num_lifted, 1124 ) 1125 return range_constraints 1126 1127 1128def _get_inline_constraints(fake_mode: FakeTensorMode): 1129 assert fake_mode.shape_env is not None 1130 return { 1131 k: v 1132 for k, v in fake_mode.shape_env.var_to_range.items() 1133 if free_unbacked_symbols(k) 1134 } 1135 1136 1137@contextmanager 1138def patch_forward(obj: torch.nn.Module, new_method): 1139 """Helper method to make it easier to cleanly torch.export() a method on a 1140 module that is not `forward`. 
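
    A usage sketch (module and method names are illustrative)::

        with patch_forward(mod, type(mod).other_method):
            ep = torch.export.export(mod, example_args)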
1141 """ 1142 # Save the original method 1143 original_method = obj.forward 1144 1145 # Patch the method 1146 obj.forward = new_method.__get__(obj, obj.__class__) 1147 1148 try: 1149 yield 1150 finally: 1151 # Restore the original method 1152 obj.forward = original_method 1153 1154 1155@contextmanager 1156def _temp_disable_texpr_fuser(): 1157 original_state = torch._C._jit_texpr_fuser_enabled() 1158 torch._C._jit_set_texpr_fuser_enabled(False) 1159 try: 1160 yield 1161 finally: 1162 torch._C._jit_set_texpr_fuser_enabled(original_state) 1163 1164 1165class _WrapperModule(torch.nn.Module): 1166 def __init__(self, f): 1167 super().__init__() 1168 self.f = f 1169 1170 def forward(self, *args, **kwargs): 1171 return self.f(*args, **kwargs) 1172 1173 1174def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): 1175 with _temp_disable_texpr_fuser(): 1176 from torch.jit._trace import TopLevelTracedModule 1177 1178 export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs) 1179 1180 if isinstance(traced_callable, (TopLevelTracedModule, torch._C.ScriptModule)): # type: ignore[operator] 1181 return _export( 1182 traced_callable, 1183 export_args, 1184 export_kwargs, 1185 strict=False, 1186 _is_torch_jit_trace=True, 1187 ).module() 1188 1189 elif isinstance(traced_callable, torch.ScriptMethod) and isinstance( 1190 traced_callable.owner(), (torch._C.ScriptModule, torch.nn.Module) # type: ignore[operator] 1191 ): 1192 with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator] 1193 return _export( 1194 traced_callable.owner(), # type: ignore[operator] 1195 export_args, 1196 export_kwargs, 1197 strict=False, 1198 _is_torch_jit_trace=True, 1199 ).module() 1200 1201 else: 1202 return _export( 1203 _WrapperModule(traced_callable), 1204 export_args, 1205 export_kwargs, 1206 strict=False, 1207 _is_torch_jit_trace=True, 1208 ).module() 1209 1210 1211def _strict_export( 1212 mod: torch.nn.Module, 1213 args: Tuple[Any, ...], 1214 kwargs: Dict[str, Any], 1215 dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]], 1216 preserve_module_call_signature: Tuple[str, ...], 1217 pre_dispatch: bool, 1218 original_state_dict: Dict[str, Any], 1219 orig_in_spec: TreeSpec, 1220 allow_complex_guards_as_runtime_asserts: bool, 1221 _is_torch_jit_trace: bool, 1222) -> ExportArtifact: 1223 lower_to_aten = functools.partial(_export_to_aten_ir, pre_dispatch=pre_dispatch) 1224 return _strict_export_lower_to_aten_ir( 1225 mod=mod, 1226 args=args, 1227 kwargs=kwargs, 1228 dynamic_shapes=dynamic_shapes, 1229 preserve_module_call_signature=preserve_module_call_signature, 1230 pre_dispatch=pre_dispatch, 1231 original_state_dict=original_state_dict, 1232 orig_in_spec=orig_in_spec, 1233 allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, 1234 _is_torch_jit_trace=_is_torch_jit_trace, 1235 lower_to_aten_callback=lower_to_aten, 1236 ) 1237 1238 1239def _strict_export_lower_to_aten_ir( 1240 mod: torch.nn.Module, 1241 args: Tuple[Any, ...], 1242 kwargs: Dict[str, Any], 1243 dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]], 1244 preserve_module_call_signature: Tuple[str, ...], 1245 pre_dispatch: bool, 1246 original_state_dict: Dict[str, Any], 1247 orig_in_spec: TreeSpec, 1248 allow_complex_guards_as_runtime_asserts: bool, 1249 _is_torch_jit_trace: bool, 1250 lower_to_aten_callback: Callable, 1251) -> ExportArtifact: 1252 gm_torch_level = _export_to_torch_ir( 1253 mod, 1254 args, 1255 kwargs, 1256 dynamic_shapes, 
        preserve_module_call_signature=preserve_module_call_signature,
        restore_fqn=False,  # don't need to restore because we will do it later
        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
        _log_export_usage=False,
    )

    # We detect the fake_mode by looking at gm_torch_level's placeholders; this is the fake_mode created in dynamo.
    (
        fake_args,
        fake_kwargs,
        dynamo_fake_mode,
    ) = _extract_fake_inputs(gm_torch_level, args, kwargs)

    fake_params_buffers = _fakify_params_buffers(dynamo_fake_mode, gm_torch_level)

    # First, we want to pass through the graph to try populating
    # val field for getattr if there is anything missing.
    # This can happen when quantization adds extra params and forgets
    # to update "val"
    for node in gm_torch_level.graph.nodes:
        if node.op == "get_attr" and "val" not in node.meta:
            attr = getattr(gm_torch_level, node.target)
            # Checks if it is not a HigherOrderOp branch or a module
            if not isinstance(attr, torch.nn.Module):
                assert (
                    dynamo_fake_mode is not None
                ), "Cannot find dynamo_fake_mode. This could be due to the exported graph module having no placeholders."
                node.meta["val"] = dynamo_fake_mode.from_tensor(
                    attr, static_shapes=True
                )

    # Fix the graph output signature to be tuple if scalar
    out_spec = orig_out_spec = gm_torch_level._out_spec

    # Used to get rid of lint type error.
    assert out_spec is not None
    assert orig_out_spec is not None

    # aot_export expects the return type to always be a tuple.
    if out_spec.type not in (list, tuple):
        out_spec = pytree.TreeSpec(tuple, None, [out_spec])

    orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]

    gm_torch_level.graph._codegen = _PyTreeCodeGen(
        _PyTreeInfo(
            orig_arg_names,
            gm_torch_level._in_spec,
            out_spec,
        )
    )
    gm_torch_level.recompile()

    _normalize_nn_module_stack(gm_torch_level, type(mod))

    params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level)

    # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace)
    # from the param nodes as they are treated as fresh inputs.
    # Therefore, we manually extract them before calling into aot_export.
    # params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level)

    constant_attrs = _gather_constant_attrs(mod)
    param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)

    # Dynamo does not track which buffers were registered as non-persistent. This info
    # is available in the original module, so we transfer it to the traced module. Also,
    # since we didn't restore original param/buffer names yet, we must use traced names.
1325 non_persistent_buffers = _get_non_persistent_buffers(mod) 1326 reverse_name_lookup = {orig: traced for traced, orig in param_buffer_table.items()} 1327 gm_torch_level._non_persistent_buffers_set = { 1328 reverse_name_lookup[name] 1329 for name in non_persistent_buffers 1330 if name in reverse_name_lookup 1331 } 1332 with dynamo_fake_mode: 1333 aten_export_artifact = lower_to_aten_callback( 1334 gm_torch_level, 1335 # NOTE: graph module expects only positional args 1336 _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs), 1337 {}, 1338 fake_params_buffers, 1339 constant_attrs, 1340 ) 1341 1342 # Decompose for readability. 1343 gm = aten_export_artifact.gm 1344 export_graph_signature = aten_export_artifact.sig 1345 constants = aten_export_artifact.constants 1346 1347 _populate_param_buffer_metadata_to_new_gm( 1348 params_buffers_to_node_meta, gm, export_graph_signature 1349 ) 1350 1351 # Do some cleanups on the graph module to restore the state dict to the 1352 # expected form. Each of these steps should probably get fixed upstream. 1353 # 1. Remove tensor constants that were added as buffers. 1354 _rewrite_dynamo_tensor_constants( 1355 orig_mod_buffers=set(mod.buffers()), 1356 traced_mod_buffers=dict(gm_torch_level.named_buffers()), 1357 graph_signature=export_graph_signature, 1358 constants=constants, 1359 ) 1360 # 2. Restore FQN of param/buffers 1361 _replace_param_buffer_names(param_buffer_table, export_graph_signature) 1362 1363 # 3. Move non-persistent buffers to tensor constants 1364 _move_non_persistent_buffers_to_tensor_constants( 1365 mod, export_graph_signature, constants 1366 ) 1367 1368 # 4. Rewrite constants to have the same FQN as the original module. 1369 _remap_constants(constant_attrs, export_graph_signature, constants) 1370 1371 # 5. 
Rename constants nodes in graph module from buffers to constants 1372 _rename_constants_nodes(gm, export_graph_signature) 1373 1374 return ExportArtifact( 1375 aten=aten_export_artifact, 1376 out_spec=orig_out_spec, 1377 fake_mode=dynamo_fake_mode, 1378 module_call_specs=gm_torch_level.meta["module_call_specs"], 1379 ) 1380 1381 1382def _export_to_aten_ir_make_fx( 1383 mod: torch.nn.Module, 1384 fake_args, 1385 fake_kwargs, 1386 fake_params_buffers, 1387 constant_attrs: ConstantAttrMap, 1388 produce_guards_callback=None, 1389 transform=lambda x: x, 1390) -> ATenExportArtifact: 1391 @contextmanager 1392 def _compiling_state_context(): 1393 old_value = torch.compiler._is_compiling_flag 1394 try: 1395 torch.compiler._is_compiling_flag = True 1396 yield 1397 finally: 1398 torch.compiler._is_compiling_flag = old_value 1399 1400 def _make_fx_helper(mod, args, kwargs, **flags): 1401 from torch._functorch._aot_autograd.schemas import GraphSignature 1402 1403 kwargs = kwargs or {} 1404 1405 named_parameters = dict(mod.named_parameters(remove_duplicate=False)) 1406 named_buffers = dict(mod.named_buffers(remove_duplicate=False)) 1407 1408 params_and_buffers = {**named_parameters, **named_buffers} 1409 params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers) 1410 params_and_buffers_flat = tuple(params_and_buffers_flat) 1411 1412 param_len = len(named_parameters) 1413 buffer_len = len(named_buffers) 1414 params_len = len(params_and_buffers) 1415 1416 functional_call = create_functional_call( 1417 mod, params_spec, params_len, store_orig_mod=True 1418 ) 1419 1420 params_buffers_args: List[Any] = [] 1421 params_buffers_args.extend(params_and_buffers_flat) 1422 params_buffers_args.extend(args) 1423 1424 flat_fn, out_spec = create_tree_flattened_fn( 1425 functional_call, params_buffers_args, kwargs 1426 ) 1427 flat_args, in_spec = pytree.tree_flatten((params_buffers_args, kwargs)) 1428 1429 @functools.wraps(flat_fn) 1430 def wrapped_fn(*args): 1431 return tuple(flat_fn(*args)) 1432 1433 with enable_python_dispatcher(): 1434 gm = make_fx( 1435 wrapped_fn, 1436 record_module_stack=True, 1437 pre_dispatch=True, 1438 )(*flat_args) 1439 gm.graph.eliminate_dead_code() 1440 1441 # create graph signature 1442 input_names = _graph_input_names(gm) 1443 output_names = _graph_output_names(gm) 1444 sig = GraphSignature( 1445 parameters=list(named_parameters), 1446 buffers=list(named_buffers), 1447 user_inputs=input_names[params_len:], 1448 user_outputs=output_names, 1449 inputs_to_parameters=dict(zip(input_names[0:param_len], named_parameters)), 1450 inputs_to_buffers=dict( 1451 zip(input_names[param_len : param_len + buffer_len], named_buffers) 1452 ), 1453 buffers_to_mutate={}, 1454 user_inputs_to_mutate={}, 1455 in_spec=in_spec, 1456 out_spec=out_spec, # type: ignore[arg-type] 1457 backward_signature=None, 1458 input_tokens=[], 1459 output_tokens=[], 1460 ) 1461 return gm, sig 1462 1463 # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, 1464 # otherwise aot_export_module will error out because it sees a mix of fake_modes. 1465 # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. 
1466 with torch.nn.utils.stateless._reparametrize_module( 1467 mod, 1468 fake_params_buffers, 1469 tie_weights=True, 1470 strict=True, 1471 stack_weights=True, 1472 ), _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined] 1473 param_len = len(dict(mod.named_parameters(remove_duplicate=False))) 1474 buffer_len = len(dict(mod.named_buffers(remove_duplicate=False))) 1475 params_len = param_len + buffer_len 1476 1477 gm, graph_signature = transform(_make_fx_helper)( 1478 mod, 1479 fake_args, 1480 trace_joint=False, 1481 kwargs=fake_kwargs, 1482 ) 1483 1484 if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"): 1485 gm.meta.update(mod.meta) 1486 1487 flat_args = pytree.tree_leaves((fake_args, fake_kwargs)) 1488 index = 0 1489 for node in gm.graph.nodes: 1490 if node.op == "placeholder": 1491 if index >= params_len: 1492 user_arg = flat_args[index - params_len] 1493 if not isinstance(user_arg, torch.Tensor): 1494 node.meta["val"] = user_arg 1495 index += 1 1496 1497 export_graph_signature = _convert_to_export_graph_signature( 1498 graph_signature, gm, _get_non_persistent_buffers(mod) 1499 ) 1500 1501 # See comment in _export_to_aten_ir() 1502 if produce_guards_callback: 1503 try: 1504 produce_guards_callback(gm) 1505 except (ConstraintViolationError, ValueRangeError) as e: 1506 raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 1507 1508 fake_mode = detect_fake_mode(flat_args) 1509 1510 if not torch._dynamo.config.do_not_emit_runtime_asserts: 1511 stack_trace = ( 1512 'File "torch/fx/passes/runtime_assert.py", line 24, ' 1513 "in insert_deferred_runtime_asserts" 1514 ) 1515 with _set_node_metadata_hook( 1516 gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) 1517 ): 1518 insert_deferred_runtime_asserts( 1519 gm, 1520 fake_mode.shape_env, 1521 f"exported program: {first_call_function_nn_module_stack(gm.graph)}", 1522 export=True, 1523 ) 1524 1525 # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes. 1526 for _mod in gm.modules(): 1527 if not isinstance(_mod, torch.fx.GraphModule): 1528 continue 1529 for node in _mod.graph.nodes: 1530 if node.op in ["placeholder", "output"]: 1531 node.meta.pop("nn_module_stack", None) 1532 node.meta.pop("stack_trace", None) 1533 1534 constants = rewrite_script_object_meta(gm) 1535 constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) 1536 1537 _preserve_requires_grad_pass( 1538 gm, export_graph_signature, fake_params_buffers, constants, flat_args 1539 ) 1540 1541 # Prettify names for placeholder nodes. 1542 placeholder_naming_pass( 1543 gm, 1544 export_graph_signature, 1545 mod, 1546 fake_args, 1547 fake_kwargs, 1548 fake_params_buffers, 1549 constants, 1550 ) 1551 1552 return ATenExportArtifact( 1553 gm, 1554 export_graph_signature, 1555 constants, 1556 ) 1557 1558 1559def _non_strict_export( 1560 mod: torch.nn.Module, 1561 args: Tuple[Any, ...], 1562 kwargs: Dict[str, Any], 1563 dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]], 1564 preserve_module_call_signature: Tuple[str, ...], 1565 pre_dispatch: bool, 1566 original_state_dict: Dict[str, Any], 1567 orig_in_spec: TreeSpec, 1568 allow_complex_guards_as_runtime_asserts: bool, 1569 _is_torch_jit_trace: bool, 1570 dispatch_tracing_mode: str = "aot_export", 1571) -> ExportArtifact: 1572 """ 1573 ``dispatch_tracing_mode`` can be either "make_fx” or “aot_export”, corresponding to 1574 _export_to_aten_ir_make_fx and _export_to_aten_ir, respectively. 
1575 """ 1576 assert dispatch_tracing_mode in ["make_fx", "aot_export"] 1577 out_spec: Optional[TreeSpec] = None 1578 1579 module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {} 1580 1581 def _tuplify_outputs(aot_export): 1582 def _aot_export_non_strict(mod, args, kwargs=None, **flags): 1583 kwargs = kwargs or {} 1584 1585 class Wrapper(torch.nn.Module): 1586 def __init__(self, mod): 1587 super().__init__() 1588 self._export_root = mod 1589 1590 def forward(self, *args, **kwargs): 1591 nonlocal out_spec 1592 if isinstance(self._export_root, torch.fx.GraphModule): 1593 with torch.fx.traceback.preserve_node_meta(): 1594 tree_out = torch.fx.Interpreter(self._export_root).run( 1595 *args, **kwargs 1596 ) 1597 else: 1598 tree_out = self._export_root(*args, **kwargs) 1599 flat_outs, out_spec = pytree.tree_flatten(tree_out) 1600 return tuple(flat_outs) 1601 1602 wrapped_mod = Wrapper(mod) 1603 # Patch export_root to the signatures so that wrapper module correctly populates the 1604 # in/out spec 1605 new_preserved_call_signatures = [ 1606 "_export_root." + i for i in preserve_module_call_signature 1607 ] 1608 with _wrap_submodules( 1609 wrapped_mod, new_preserved_call_signatures, module_call_specs 1610 ): 1611 gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags) 1612 log.debug("Exported program from AOTAutograd:\n%s", gm) 1613 1614 sig.parameters = pytree.tree_map(_strip_root, sig.parameters) 1615 sig.buffers = pytree.tree_map(_strip_root, sig.buffers) 1616 sig.inputs_to_buffers = pytree.tree_map(_strip_root, sig.inputs_to_buffers) 1617 sig.inputs_to_parameters = pytree.tree_map( 1618 _strip_root, sig.inputs_to_parameters 1619 ) 1620 sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate) 1621 1622 for node in gm.graph.nodes: 1623 if "nn_module_stack" in node.meta: 1624 nn_module_stack = node.meta["nn_module_stack"] 1625 node.meta["nn_module_stack"] = { 1626 _fixup_key(key): val 1627 for key, val in pytree.tree_map( 1628 _strip_root, nn_module_stack 1629 ).items() 1630 } 1631 1632 return gm, sig 1633 1634 return _aot_export_non_strict 1635 1636 ( 1637 fake_mode, 1638 fake_args, 1639 fake_kwargs, 1640 equalities_inputs, 1641 original_signature, 1642 transformed_dynamic_shapes, 1643 ) = make_fake_inputs( 1644 mod, 1645 args, 1646 kwargs, 1647 dynamic_shapes, 1648 _is_torch_jit_trace=_is_torch_jit_trace, 1649 allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, # for shape env initialization 1650 ) 1651 1652 fake_params_buffers = _fakify_params_buffers(fake_mode, mod) 1653 1654 def _produce_guards_callback(gm): 1655 return produce_guards_and_solve_constraints( 1656 fake_mode=fake_mode, 1657 gm=gm, 1658 dynamic_shapes=transformed_dynamic_shapes, 1659 equalities_inputs=equalities_inputs, 1660 original_signature=original_signature, 1661 _is_torch_jit_trace=_is_torch_jit_trace, 1662 ) 1663 1664 with fake_mode, _NonStrictTorchFunctionHandler(), torch._dynamo.config.patch( 1665 assume_static_by_default=False 1666 ): 1667 with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as ( 1668 patched_mod, 1669 new_fake_args, 1670 new_fake_kwargs, 1671 new_fake_constant_attrs, 1672 map_fake_to_real, 1673 ): 1674 _to_aten_func = ( 1675 _export_to_aten_ir_make_fx 1676 if dispatch_tracing_mode == "make_fx" 1677 else functools.partial( 1678 _export_to_aten_ir, 1679 pre_dispatch=pre_dispatch, 1680 _is_torch_jit_trace=_is_torch_jit_trace, 1681 ) 1682 ) 1683 aten_export_artifact = _to_aten_func( # type: ignore[operator] 1684 patched_mod, 1685 
                new_fake_args,
                new_fake_kwargs,
                fake_params_buffers,
                new_fake_constant_attrs,
                produce_guards_callback=_produce_guards_callback,
                transform=_tuplify_outputs,
            )
            # aten_export_artifact.constants contains only fake script objects;
            # map them back to the corresponding real objects.
            aten_export_artifact.constants = {
                fqn: map_fake_to_real[obj] if isinstance(obj, FakeScriptObject) else obj
                for fqn, obj in aten_export_artifact.constants.items()
            }

    _move_non_persistent_buffers_to_tensor_constants(
        mod, aten_export_artifact.sig, aten_export_artifact.constants
    )

    assert out_spec is not None

    return ExportArtifact(
        aten=aten_export_artifact,
        out_spec=out_spec,
        fake_mode=fake_mode,
        module_call_specs=module_call_specs,
    )


# TODO (tmanlaibaatar) We need to preserve aten.to here somehow
@_log_export_wrapper
@_disable_prexisiting_fake_mode
def _export_for_training(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    *,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
    global _EXPORT_MODULE_HIERARCHY
    _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)

    (
        args,
        kwargs,
        orig_in_spec,
        original_state_dict,
        dynamic_shapes,
    ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)

    export_func = (
        functools.partial(
            _strict_export_lower_to_aten_ir,
            lower_to_aten_callback=_export_to_aten_ir_make_fx,
        )
        if strict
        else functools.partial(
            _non_strict_export,
            dispatch_tracing_mode="make_fx",
        )
    )
    export_artifact = export_func(  # type: ignore[operator]
        mod=mod,
        args=args,
        kwargs=kwargs,
        dynamic_shapes=dynamic_shapes,
        preserve_module_call_signature=preserve_module_call_signature,
        pre_dispatch=False,
        original_state_dict=original_state_dict,
        orig_in_spec=orig_in_spec,
        allow_complex_guards_as_runtime_asserts=False,
        _is_torch_jit_trace=False,
    )

    export_graph_signature = export_artifact.aten.sig

    forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
    inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
    # The unbacked symint symbols are updated in aot_export,
    # so we serialize them here instead of inside dynamo.
    # Note: _get_range_constraints depends on "inline_constraints" being set.
    export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints
    range_constraints = _get_range_constraints(
        export_artifact,
        _combine_args(mod, args, kwargs, _is_torch_jit_trace=False),
        dynamic_shapes,
    )
    # The returned gm is modified in place.
    gm, module_call_graph = _get_module_call_graph(
        export_artifact, orig_in_spec, preserve_module_call_signature, strict
    )

    # Add forward args metadata.
    gm.meta["forward_arg_names"] = forward_arg_names

    _verify_nn_module_stack(gm)
    _verify_stack_trace(gm)
    _verify_placeholder_names(gm, export_graph_signature)

    from torch._export.verifier import TrainingIRVerifier

    exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=export_graph_signature,
        state_dict=original_state_dict,
        range_constraints=range_constraints,
        module_call_graph=module_call_graph,
        example_inputs=(args, kwargs),
        constants=export_artifact.aten.constants,
        verifiers=[TrainingIRVerifier],
    )

    return exported_program


@_log_export_wrapper
@_disable_prexisiting_fake_mode
def _export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    *,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
    pre_dispatch: bool = False,
    allow_complex_guards_as_runtime_asserts: bool = False,
    _is_torch_jit_trace: bool = False,
) -> ExportedProgram:
    """
    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside, and produces an ExportedProgram.

    Args:
        mod: the `nn.Module` to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes:
            An optional argument whose type should either be:
            1) a dict from argument names of ``mod`` to their dynamic shape specifications,
            2) a tuple that specifies dynamic shape specifications for each input in original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order
            that is defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
            not required to include static dimension indices in this dict, but when they are,
            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
            recursively specified by using mappings or sequences of contained specifications.

        preserve_module_call_signature: A list of submodule paths for which the original
            calling conventions are preserved as metadata.

        allow_complex_guards_as_runtime_asserts:
            With the current dynamic shapes language for dims and derived dims, we can run into
            constraints that are not expressible with the language. For example, flattening a matrix
            and adding it to a vector, both fully dynamic (i.e. x.reshape([-1]) + y), emits a guard
            s0 * s1 = s2, which is not expressible. By default, we either raise a constraint violation
            error or specialize to static values. If this flag is set to True, we avoid erroring out
            and instead allow complex constraints to exist as runtime assertions in the graph. The
            sympy interpreter (torch/utils/_sympy/interp.py) will produce the math ops required to
            compute and assert the value of the guard (e.g. sym_size_int, eq, _assert_scalar).
            Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow
            complex constraints while not emitting runtime asserts, returning a cleaner graph with
            fewer guarantees around dynamic shapes.

    Returns:
        An ExportedProgram containing the traced method.
    """

    global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
    _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)

    flags = set()
    flags.add("strict" if strict else "non_strict")
    flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch")
    _EXPORT_FLAGS = flags

    log_export_usage(event="export.enter", flags=_EXPORT_FLAGS)

    (
        args,
        kwargs,
        original_in_spec,
        original_state_dict,
        dynamic_shapes,
    ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)

    # Call the appropriate export function based on the strictness of tracing.
    export_func = _strict_export if strict else _non_strict_export

    export_artifact = export_func(  # type: ignore[operator]
        mod,
        args,
        kwargs,
        dynamic_shapes,
        preserve_module_call_signature,
        pre_dispatch,
        original_state_dict,
        original_in_spec,
        allow_complex_guards_as_runtime_asserts,
        _is_torch_jit_trace,
    )
    export_graph_signature: ExportGraphSignature = export_artifact.aten.sig

    forward_arg_names = (
        _get_forward_arg_names(mod, args, kwargs) if not _is_torch_jit_trace else None
    )
    inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
    # The unbacked symint symbols are updated in aot_export,
    # so we serialize them here instead of inside dynamo.
    # Note: this step must happen before _get_range_constraints.
    export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints
    range_constraints = _get_range_constraints(
        export_artifact,
        _combine_args(mod, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace),
        dynamic_shapes,
    )
    gm, module_call_graph = _get_module_call_graph(
        export_artifact, original_in_spec, preserve_module_call_signature, strict
    )

    # Add forward args metadata.
    gm.meta["forward_arg_names"] = forward_arg_names

    _verify_nn_module_stack(gm)
    _verify_stack_trace(gm)
    if not _is_torch_jit_trace:
        _verify_placeholder_names(gm, export_graph_signature)

    # Remove Proxy objects because they cannot be deep-copied or pickled.
    torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True)

    from torch._export.verifier import Verifier

    if (
        isinstance(mod, torch.fx.GraphModule)
        and hasattr(mod, "meta")
        and "custom" in mod.meta
    ):
        gm.meta.update({"custom": mod.meta["custom"]})

    exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=export_graph_signature,
        state_dict=original_state_dict,
        range_constraints=range_constraints,
        module_call_graph=module_call_graph,
        example_inputs=(args, kwargs),
        constants=export_artifact.aten.constants,
        verifiers=[Verifier],
    )

    return exported_program
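

# Illustrative sketch added for exposition; it is not part of this module's API.
# It shows how the `dynamic_shapes` format documented in `_export` above is
# typically supplied through the public entry point ``torch.export.export``,
# which routes into `_export`. The toy module and dimension name below are
# hypothetical examples, not names used elsewhere in this file.
def _example_dynamic_shapes_usage():  # pragma: no cover
    from torch.export import Dim, export

    class MatMul(torch.nn.Module):
        def forward(self, x, y):
            return x @ y

    # Mark dim 0 of `x` as dynamic; static dims may be omitted or mapped to None.
    batch = Dim("batch")
    ep = export(
        MatMul(),
        (torch.randn(4, 8), torch.randn(8, 16)),
        dynamic_shapes={
            "x": {0: batch},  # per-dim spec: {dim index: Dim}
            "y": None,  # `y` is fully static
        },
    )
    return ep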