# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import dataclasses
import functools
import inspect
import logging
import re
import time
import warnings
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch._dynamo
import torch.fx
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.exc import UserError, UserErrorType
from torch._export.db.logging import (
    exportdb_error_message,
    get_class_if_classified_error,
)
from torch._export.non_strict_utils import (
    _fakify_script_objects,
    _gather_constant_attrs,
    _NonStrictTorchFunctionHandler,
    make_constraints,
    make_fake_inputs,
    produce_guards_and_solve_constraints,
)
from torch._export.passes._node_metadata_hook import (
    _node_metadata_hook,
    _set_node_metadata_hook,
)
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
from torch._export.passes.lift_constants_pass import (
    ConstantAttrMap,
    lift_constants_pass,
    rewrite_script_object_meta,
)
from torch._export.utils import (
    _collect_param_buffer_metadata,
    _get_shape_env_from_gm,
    _populate_param_buffer_metadata_to_new_gm,
    placeholder_naming_pass,
    placeholder_prefixes,
)
from torch._export.verifier import SpecViolationError
from torch._export.wrappers import _wrap_submodules
from torch._functorch._aot_autograd.input_output_analysis import (
    _graph_input_names,
    _graph_output_names,
)
from torch._functorch._aot_autograd.traced_function_transforms import (
    create_functional_call,
)
from torch._functorch._aot_autograd.utils import create_tree_flattened_fn
from torch._functorch.aot_autograd import aot_export_module
from torch._guards import detect_fake_mode
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._utils_internal import log_export_usage
from torch.export.dynamic_shapes import (
    _check_dynamic_shapes,
    _combine_args,
    _transform_shapes_for_default_dynamic,
)
from torch.export.exported_program import OutputKind
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    free_unbacked_symbols,
    GuardOnDataDependentSymNode,
    ShapeEnv,
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch.utils._pytree import TreeSpec
from torch.utils._sympy.value_ranges import ValueRangeError

from ._safeguard import AutogradStateOpsFailSafeguard
from .exported_program import (
    _disable_prexisiting_fake_mode,
    ExportedProgram,
    InputKind,
    ModuleCallEntry,
    ModuleCallSignature,
)
from .graph_signature import _convert_to_export_graph_signature, ExportGraphSignature


log = logging.getLogger(__name__)


@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.
    """

    allow_rnn: bool = True
    reorderable_logging_functions: Set[Callable] = dataclasses.field(
        default_factory=set
    )
    # Emit runtime asserts after AOTAutograd instead.
    # This isn't really necessary, and isn't much more efficient since the runtime asserts pass does CSE,
    # but if we want to reason more about what guards/runtime asserts to emit,
    # this makes it a bit cleaner to do from the export side. Also no real point in running this twice.
    do_not_emit_runtime_asserts = True


@dataclasses.dataclass
class ATenExportArtifact:
    gm: torch.fx.GraphModule
    sig: ExportGraphSignature
    constants: Dict[
        str,
        Union[
            torch.Tensor,
            FakeScriptObject,
            torch.ScriptObject,
        ],
    ]


@dataclasses.dataclass(frozen=True)
class ExportArtifact:
    aten: ATenExportArtifact
    out_spec: TreeSpec
    fake_mode: FakeTensorMode
    module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]]


DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = {
    logging.critical,
    logging.debug,
    logging.error,
    logging.exception,
    logging.info,
    logging.log,
    logging.warning,
    print,
    warnings.warn,
}


@contextmanager
def _ignore_backend_decomps():
    orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False)
    orig_nnpack_flag = torch.backends.nnpack.set_flags(False)
    try:
        yield
    finally:
        torch.backends.mkldnn.set_flags(*orig_mkldnn_flag)
        torch.backends.nnpack.set_flags(*orig_nnpack_flag)


def _fixup_key(x):
    return "L__self__" + _strip_root(x)


def _strip_root(x):
    if isinstance(x, str) and x.startswith("_export_root"):
        stripped = x[len("_export_root") :]
        return stripped[1:] if stripped.startswith(".") else stripped
    return x

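# Illustrative example (hypothetical paths): how the two helpers above rewrite access
# paths produced under the non-strict export wrapper.
#
#   >>> _strip_root("_export_root.sub.linear")
#   'sub.linear'
#   >>> _fixup_key("_export_root.sub.linear")
#   'L__self__sub.linear'
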
def _rewrite_tracepoint_node(gm: torch.fx.GraphModule):
    """
    Modify the input graph module in place by replacing each export tracepoint with a new
    node that has the same target and args, but with _export_root stripped from the path.
    """
    for node in gm.graph.nodes:
        if node.target == torch.ops.higher_order._export_tracepoint:
            if "path" in node.kwargs:
                path = _strip_root(node.kwargs["path"])
                with gm.graph.inserting_before(node):
                    new_node = gm.graph.create_node(
                        "call_function",
                        torch.ops.higher_order._export_tracepoint,
                        args=node.args,
                        kwargs={
                            "path": path,
                            "kind": node.kwargs["kind"],
                        },
                    )
                    new_node.meta = node.meta
                    node.replace_all_uses_with(new_node)
                    gm.graph.erase_node(node)


def _extract_fake_inputs(gm, args, kwargs):
    """
    Given a graph module, extract fakified input tensors from the metadata of
    its placeholders, and map them to the structure of the given args and kwargs.
    Also return the fake mode used to fakify those inputs.
    """

    fake_inps: List[torch.Tensor] = []
    fake_vals: List[torch.Tensor] = []
    for node in gm.graph.nodes:
        if node.op == "placeholder" and "val" in node.meta:
            fake_val = node.meta["val"]
            if fake_val is not None and isinstance(fake_val, torch.Tensor):
                fake_inps.append(fake_val)
        elif "example_value" in node.meta:
            fake_val = node.meta["example_value"]
            if fake_val is not None and isinstance(fake_val, torch.Tensor):
                fake_vals.append(fake_val)

    if detected_fake_mode := detect_fake_mode(fake_inps + fake_vals):
        fake_mode = detected_fake_mode
    else:
        fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)

    count = 0

    def lookup_fake(x):
        nonlocal count
        val = fake_inps[count]
        count += 1
        return val

    fake_args = pytree.tree_map_only(torch.Tensor, lookup_fake, args)
    fake_kwargs = pytree.tree_map_only(torch.Tensor, lookup_fake, kwargs)

    return fake_args, fake_kwargs, fake_mode


def _replace_param_buffer_names(param_buffer_table, sig):
    for spec in sig.input_specs:
        if spec.kind in (
            InputKind.PARAMETER,
            InputKind.BUFFER,
        ):
            spec.target = param_buffer_table[spec.target]
    for spec in sig.output_specs:
        if spec.kind in (
            OutputKind.BUFFER_MUTATION,
            OutputKind.GRADIENT_TO_PARAMETER,
        ):
            spec.target = param_buffer_table[spec.target]


def _convert_to_positional_args(orig_arg_names, args, kwargs):
    assert len(orig_arg_names) == len(args) + len(kwargs), (
        f"Total number of arg names is expected to be {len(orig_arg_names)} "
        f"but got {len(args)} positional args, {len(kwargs)} kwargs."
    )
    reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]]
    return (
        *args,
        *reordered_kwargs,
    )

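# Illustrative example (hypothetical argument names and values): _convert_to_positional_args
# reorders keyword arguments to follow the original argument order.
#
#   >>> _convert_to_positional_args(["x", "y", "z"], (1,), {"z": 3, "y": 2})
#   (1, 2, 3)
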
def _normalize_nn_module_stack(gm_torch_level, root_cls):
    # Prepend a root module entry to every nn_module_stack.
    root = "L['self']"
    root_key = re.sub(r"[^a-zA-Z0-9]", "_", root)
    for gm in gm_torch_level.modules():
        if not isinstance(gm, torch.fx.GraphModule):
            continue
        for node in gm.graph.nodes:
            if node.op in ["placeholder", "output"]:
                continue
            add_root = True
            if nn_module_stack := node.meta.get("nn_module_stack", {}):
                path, ty = next(iter(nn_module_stack.values()))
                # After deserializing, the class `ty` might not exist anymore,
                # so it could be a string.
                if inspect.isclass(ty) and issubclass(ty, torch.nn.Module):
                    # TODO: Figure out why we sometimes have the root and sometimes don't.
                    if path == root and ty is root_cls:
                        add_root = False
                else:
                    assert isinstance(ty, str)
            if add_root:

                def normalize_path(path):
                    try:
                        parts = []

                        class Path:
                            def __getattr__(self, name):
                                parts.append(name)
                                return self

                            def __getitem__(self, idx):
                                parts.append(str(idx))
                                return self

                        eval(path, {"L": {"self": Path()}})
                        return ".".join(parts)
                    except Exception:  # TODO(zhxchen17) Remove this.
                        return path

                nn_module_stack = {
                    root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),
                    **nn_module_stack,
                }
                node.meta["nn_module_stack"] = {
                    key: (normalize_path(path), ty)
                    for key, (path, ty) in nn_module_stack.items()
                }

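# Illustrative example (hypothetical access path): the eval-based `normalize_path` trick
# above converts a dynamo-style access path into a dotted FQN by recording every attribute
# and index access:
#
#   "L['self'].blocks[0].attn"  ->  "blocks.0.attn"
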
def _get_param_buffer_mapping(
    original_module: torch.nn.Module,
    traced_module: torch.nn.Module,
) -> Dict[str, str]:
    """
    Returns a mapping of parameter/buffer names from the traced module to the
    original module. This helps restore the FQNs of a traced module's
    parameters/buffers to those of the original module.
    """

    param_lookup: Dict[int, str] = {}
    buffer_lookup: Dict[int, str] = {}
    for name, param in original_module.named_parameters(remove_duplicate=False):
        param_lookup[id(param)] = name
    for name, buffer in original_module.named_buffers(remove_duplicate=False):
        buffer_lookup[id(buffer)] = name

    param_buffer_table: Dict[str, str] = {}
    for dynamo_name, dynamo_param in traced_module.named_parameters(
        remove_duplicate=False
    ):
        assert dynamo_name not in param_buffer_table
        if id(dynamo_param) in param_lookup:
            param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)]

    for dynamo_name, dynamo_buffer in traced_module.named_buffers(
        remove_duplicate=False
    ):
        assert dynamo_name not in param_buffer_table
        if id(dynamo_buffer) in buffer_lookup:
            param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)]

    return param_buffer_table

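# Illustrative example (hypothetical names): if dynamo renames the parameter
# "sub.linear.weight" to something like "L__self___sub_linear_weight", the id()-based
# lookup above recovers
#
#   {"L__self___sub_linear_weight": "sub.linear.weight"}
#
# because both names still refer to the same parameter object.
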
def _preserve_requires_grad_pass(
    gm: torch.fx.GraphModule,
    sig: ExportGraphSignature,
    fake_params_buffers: Dict[str, torch.Tensor],
    constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
    flat_fake_args: List[Any],
):
    placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
    assert len(sig.input_specs) == len(placeholders)
    i = 0
    for node, spec in zip(placeholders, sig.input_specs):
        if spec.kind in (
            InputKind.PARAMETER,
            InputKind.BUFFER,
        ):
            assert spec.target is not None
            node.meta["val"].requires_grad = fake_params_buffers[
                spec.target
            ].requires_grad
        elif spec.kind == InputKind.USER_INPUT:
            fake_arg = flat_fake_args[i]
            if isinstance(fake_arg, torch.Tensor):
                node.meta["val"].requires_grad = fake_arg.requires_grad
            i += 1
        elif spec.kind == InputKind.CONSTANT_TENSOR:
            assert spec.target is not None
            constant = constants[spec.target]
            if isinstance(constant, torch.Tensor):
                # If the tensor is not a leaf, it should already have the correct requires_grad field.
                if node.meta["val"].is_leaf:
                    node.meta["val"].requires_grad = constant.requires_grad
                else:
                    assert node.meta["val"].requires_grad == constant.requires_grad
        elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN):
            continue
        else:
            raise AssertionError(spec.kind)


def _remap_constants(
    orig_constant_attrs: ConstantAttrMap,
    graph_signature: ExportGraphSignature,
    constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
) -> None:
    """Rewrite the graph signature and constants table to use the FQN from the original module."""
    remap_table: Dict[str, List[str]] = {}
    for name, value in constants.items():
        if value in orig_constant_attrs:
            remap_table[name] = orig_constant_attrs[value]

    for spec in graph_signature.input_specs:
        if spec.kind in (
            InputKind.CONSTANT_TENSOR,
            InputKind.CUSTOM_OBJ,
        ):
            orig_target = spec.target
            assert orig_target is not None
            targets = remap_table.get(orig_target, [orig_target])
            spec.target = targets[0]

            constant = constants[orig_target]
            del constants[orig_target]
            for target in targets:
                constants[target] = constant


def _rename_constants_nodes(
    gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
) -> None:
    """
    For strict mode, rename constants nodes that were previously annotated as buffers.
    """
    # handle name collisions with existing constants
    node_names = {node.name for node in gm.graph.nodes}

    def rename_constant(name):
        if name in node_names:
            n = 1
            while (dup_name := f"{name}_{n}") in node_names:
                n += 1
            name = dup_name
        node_names.add(name)
        return name

    # use input specs to map names from buffers to constants
    buffer_prefix = placeholder_prefixes[InputKind.BUFFER]
    const_prefix = placeholder_prefixes[InputKind.CONSTANT_TENSOR]
    buffer_to_constant = {}
    for spec in graph_signature.input_specs:
        if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
            const_prefix
        ):
            if spec.arg.name.startswith(buffer_prefix):  # map from buffer to constants
                c_name = rename_constant(
                    const_prefix + spec.arg.name[len(buffer_prefix) :]
                )
            else:  # lifted constant
                c_name = rename_constant(const_prefix + spec.arg.name)
            buffer_to_constant[spec.arg.name] = c_name
            spec.arg.name = c_name
    for spec in graph_signature.output_specs:
        if spec.arg.name in buffer_to_constant:
            spec.arg.name = buffer_to_constant[spec.arg.name]

    # Rename constants nodes for all modules
    for mod in gm.modules():
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            if node.name in buffer_to_constant:
                node.name = node.target = buffer_to_constant[node.name]
        mod.recompile()


def _restore_state_dict(
    original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
) -> None:
    """
    Restores the state dict of the traced module to that of the original module.
    """
    param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
    # Since the graph module is flattened (no module hierarchy), we
    # need to normalize the names by replacing "." with "_". If we
    # don't, it will try to save the weight to a submodule which no
    # longer exists.
    for name, fqn in param_buffer_table.items():
        param_buffer_table[name] = fqn.replace(".", "_")

    # Replace state dict attr names with the fqn
    for name, fqn in param_buffer_table.items():
        if not hasattr(traced_module, name):
            continue

        attr = getattr(traced_module, name)
        if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter):
            traced_module.register_buffer(fqn, attr)
        else:
            setattr(traced_module, fqn, attr)
        delattr(traced_module, name)

    # Replace graph getattr nodes with the correct name
    for node in traced_module.graph.nodes:
        if node.op == "get_attr":
            attr_name = node.target
            if attr_name in param_buffer_table:
                node.target = param_buffer_table[attr_name]

    traced_module.recompile()


def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]:
    return {
        name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False)
    }

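# Illustrative example (hypothetical model): for a module holding a single submodule
# `self.sub = torch.nn.Linear(2, 2)`, _get_module_hierarchy returns something like
#
#   {"": "MyModel", "sub": "Linear"}
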
def _make_module_call_graph(
    module_hierarchy: Dict[str, str],
    in_spec: TreeSpec,
    out_spec: TreeSpec,
    module_call_signatures: Dict[str, ModuleCallSignature],
) -> List[ModuleCallEntry]:
    ret = [
        ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
        for fqn in module_hierarchy
    ]
    assert ret[0].fqn == ""
    ret[0].signature = ModuleCallSignature(
        inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
    )
    return ret


def _export_to_torch_ir(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    *,
    preserve_module_call_signature: Tuple[str, ...] = (),
    disable_constraint_solver: bool = False,
    allow_complex_guards_as_runtime_asserts: bool = False,
    restore_fqn: bool = True,
    _log_export_usage: bool = True,
    same_signature: bool = True,
) -> torch.fx.GraphModule:
    """
    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside, and produces a torch.fx.GraphModule in torch IR.
    """

    if _log_export_usage:
        log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})

    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
        )

    kwargs = kwargs or {}
    combined_args = _combine_args(f, args, kwargs)
    _check_dynamic_shapes(combined_args, dynamic_shapes)
    transformed_dynamic_shapes = _transform_shapes_for_default_dynamic(
        combined_args, dynamic_shapes
    )

    with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
        try:
            module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
            with _wrap_submodules(
                f, preserve_module_call_signature, module_call_specs
            ), _ignore_backend_decomps():
                gm_torch_level, _ = torch._dynamo.export(
                    f,
                    dynamic_shapes=transformed_dynamic_shapes,  # type: ignore[arg-type]
                    tracing_mode="symbolic",
                    disable_constraint_solver=disable_constraint_solver,
                    # Currently the following two flags are tied together for export purposes,
                    # but are kept separate for the sake of the dynamo export API.
                    prefer_deferred_runtime_asserts_over_guards=True,
                    allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
                    _log_export_usage=_log_export_usage,
                    same_signature=same_signature,
                )(
                    *args,
                    **kwargs,
                )
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
        except GuardOnDataDependentSymNode as e:
            raise UserError(  # noqa: B904
                UserErrorType.ANTI_PATTERN,
                f"Consider annotating your code using torch._check*(). {str(e)}",
                case_name="constrain_as_size_example",
            )

    gm_torch_level.meta["module_call_specs"] = module_call_specs

    if isinstance(f, torch.nn.Module) and restore_fqn:
        _restore_state_dict(f, gm_torch_level)

    return gm_torch_level

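# Illustrative usage sketch (private API; hypothetical module and inputs): tracing a toy
# module to torch IR with dynamo before lowering to ATen.
#
#   gm_torch_ir = _export_to_torch_ir(MyModel(), (torch.randn(2, 2),))
#   print(gm_torch_ir.graph)
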
def _export_to_aten_ir(
    mod: torch.nn.Module,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    constant_attrs: ConstantAttrMap,
    produce_guards_callback=None,
    *,
    transform=lambda x: x,  # TODO(zhxchen17) Revisit if this is needed later.
    pre_dispatch=False,
    decomp_table=None,
    _check_autograd_state=True,
    _is_torch_jit_trace=False,
) -> ATenExportArtifact:
    # [NOTE] If the user is exporting under training mode, we want to detect any change to the
    # global autograd state and raise an error. If the user is exporting under inference mode,
    # we don't care. At the pre-dispatch level, we don't care about the state change either.
    is_grad_enabled = torch._C.is_grad_enabled()
    grad_safe_guard = nullcontext()
    # _export_to_aten_ir is also called when we decompose the exported program into inference IR.
    # In that setting we shouldn't check for state changes, because the intention at that point
    # is to specialize to inference.
    if _check_autograd_state:
        if not pre_dispatch and is_grad_enabled:
            grad_safe_guard = AutogradStateOpsFailSafeguard()  # type: ignore[assignment]

    @contextmanager
    def _compiling_state_context():
        old_value = torch.compiler._is_compiling_flag
        try:
            torch.compiler._is_compiling_flag = True
            yield
        finally:
            torch.compiler._is_compiling_flag = old_value

    # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
    # otherwise aot_export_module will error out because it sees a mix of fake_modes.
    # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
    with torch.nn.utils.stateless._reparametrize_module(
        mod,
        fake_params_buffers,
        tie_weights=True,
        strict=True,
        stack_weights=True,
    ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context():  # type: ignore[attr-defined]
        gm, graph_signature = transform(aot_export_module)(
            mod,
            fake_args,
            trace_joint=False,
            pre_dispatch=pre_dispatch,
            decompositions=decomp_table,
            kwargs=fake_kwargs,
        )

    def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm):
        if isinstance(old_gm, torch.fx.GraphModule):
            if hasattr(old_gm, "meta"):
                new_gm.meta.update(old_gm.meta)
            old_output_node = list(old_gm.graph.nodes)[-1]
            new_output_node = list(new_gm.graph.nodes)[-1]
            assert old_output_node.op == "output" and new_output_node.op == "output"
            # make sure we don't override any meta
            assert len(new_output_node.meta) == 0
            new_output_node.meta.update(old_output_node.meta)

    # TODO: unfortunately, preserving graph-level metadata and the output node's meta
    # is not working well with aot_export, so we manually copy it.
    # (The node-level meta is addressed above.)
    _maybe_fixup_gm_and_output_node_meta(mod, gm)

    # Run produce_guards before we handle runtime asserts.
    # This means we run the export solver before the runtime asserts pass.
    # Right now this doesn't mean much - the export solver is only there for suggested fixes,
    # and we won't even get to constraint solving if that's needed.
    # But if in the future we want to control which runtime asserts are emitted for export,
    # or rely on produce_guards + solver to simplify runtime asserts, this ordering probably makes sense.
    if produce_guards_callback:
        try:
            produce_guards_callback(gm)
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904

    # Run the runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect the output signature.
    # Overwrite output specs afterwards.
    flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs))
    if not torch._dynamo.config.do_not_emit_runtime_asserts:
        stack_trace = (
            'File "torch/fx/passes/runtime_assert.py", line 24, '
            "in insert_deferred_runtime_asserts"
        )
        with _set_node_metadata_hook(
            gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
        ):
            shape_env = _get_shape_env_from_gm(gm)
            if shape_env:
                insert_deferred_runtime_asserts(
                    gm,
                    shape_env,
                    f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
                    export=True,
                )

    # update output specs
    gm.recompile()
    graph_signature.user_outputs = _graph_output_names(gm)

    # NOTE: aot_export adds symint metadata for placeholders with int values;
    # since these become specialized, we replace such metadata with the original values.
    index = 0
    total_non_user_inputs = (
        len(graph_signature.parameters)
        + len(graph_signature.buffers)
        + len(graph_signature.input_tokens)
    )
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if index >= total_non_user_inputs:
                user_arg = flat_fake_args[index - total_non_user_inputs]
                if not isinstance(user_arg, torch.Tensor):
                    node.meta["val"] = user_arg
            index += 1

    export_graph_signature = _convert_to_export_graph_signature(
        graph_signature, gm, _get_non_persistent_buffers(mod)
    )

    constants = rewrite_script_object_meta(gm)
    constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))

    if pre_dispatch:
        from torch._export.passes.replace_autocast_with_hop_pass import (
            replace_autocast_with_hop_pass,
        )
        from torch._export.passes.replace_set_grad_with_hop_pass import (
            replace_set_grad_with_hop_pass,
        )

        # Note: replace_set_grad_with_hop_pass needs to run after lift_constants_pass because
        # a getattr of a constant tensor doesn't have meta["val"] until after lift_constants_pass.
        # If replace_set_grad_with_hop_pass runs before lift_constants_pass and the constant
        # tensor is passed as an input to the set_grad HOP, the placeholder's meta["val"] will
        # be None and will fail our placeholder verifier.
        gm, export_graph_signature = replace_set_grad_with_hop_pass(
            gm, export_graph_signature
        )

        gm, export_graph_signature = replace_autocast_with_hop_pass(
            gm, export_graph_signature
        )

    # Remove nn_module_stack, stack_trace metadata from all placeholder/output nodes.
    for _mod in gm.modules():
        if not isinstance(_mod, torch.fx.GraphModule):
            continue
        for node in _mod.graph.nodes:
            if node.op in ["placeholder", "output"]:
                node.meta.pop("nn_module_stack", None)
                node.meta.pop("stack_trace", None)

    # Prettify names for placeholder nodes.
    placeholder_naming_pass(
        gm,
        export_graph_signature,
        mod,
        fake_args,
        fake_kwargs,
        fake_params_buffers,
        constants,
    )

    _preserve_requires_grad_pass(
        gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args
    )

    return ATenExportArtifact(
        gm,
        export_graph_signature,
        constants,
    )


def _fakify_params_buffers(
    fake_mode: FakeTensorMode,
    mod: torch.nn.Module,
) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]:
    params_buffers = {
        **dict(mod.named_parameters(remove_duplicate=False)),
        **dict(mod.named_buffers(remove_duplicate=False)),
    }

    faked_params_buffers = {}
    memo: Dict[int, FakeTensor] = {}
    for key, value in params_buffers.items():
        if id(value) in memo:
            fake_tensor = memo[id(value)]
        else:
            fake_tensor = fake_mode.from_tensor(value, static_shapes=True)
            memo[id(value)] = fake_tensor
        faked_params_buffers[key] = fake_tensor
    return faked_params_buffers  # type: ignore[return-value]


def _get_forward_arg_names(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
) -> List[str]:
    """
    Gets the names of the forward() arguments that are used, for restoring the
    original signature when unlifting the exported program module.
    - Positional args: retain the original argument names, and enumerate
        *args as args_0, args_1, ...
    - Keyword args: retain the original kwarg names in the order specified
        by the user. This order seems to matter for the current state of
        export lifted modules.
    """
    sig = inspect.signature(mod.forward)
    _args = sig.bind_partial(*args).arguments

    names: List[str] = []
    for name, value in _args.items():
        # handle variable number of positional args
        if sig.parameters[name].kind == inspect._ParameterKind.VAR_POSITIONAL:
            names.extend([f"{name}_{i}" for i, _ in enumerate(value)])
        else:
            names.append(name)
    # order of kwargs matters for input spec
    if kwargs:
        names.extend([kwarg for kwarg, _ in kwargs.items()])

    return names

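# Illustrative example (hypothetical signature): for `def forward(self, x, *ys, flag=True)`
# called as `mod(a, b, c, flag=False)`, _get_forward_arg_names returns
#
#   ["x", "ys_0", "ys_1", "flag"]
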
def _get_non_persistent_buffers(mod: torch.nn.Module) -> Set[str]:
    """
    Returns the set of non-persistent buffers in a module and its submodules.
    """
    result = set()
    for name, m in mod.named_modules():
        for b in m._non_persistent_buffers_set:
            result.add(f"{name}.{b}" if name else b)
    return result

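# Illustrative example (hypothetical model): a buffer registered on submodule `sub` via
# `self.register_buffer("mask", torch.zeros(4), persistent=False)` is reported by
# _get_non_persistent_buffers as
#
#   {"sub.mask"}
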
def _rewrite_dynamo_tensor_constants(
    orig_mod_buffers: Set[torch.Tensor],
    traced_mod_buffers: Dict[str, torch.Tensor],
    graph_signature: ExportGraphSignature,
    constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
):
    """
    Dynamo erroneously marks tensor attributes on modules as buffers.
    Rewrite them to be tensor constants.
    """
    for spec in graph_signature.input_specs:
        if spec.kind == InputKind.BUFFER:
            assert spec.target is not None
            value = traced_mod_buffers[spec.target]
            if value not in orig_mod_buffers:
                # This was a tensor constant erroneously marked as a buffer.
                # Convert it into a constant in the graph signature, and add its
                # value to the constants table.
                spec.kind = InputKind.CONSTANT_TENSOR
                constants[spec.target] = value  # type: ignore[arg-type]


def _move_non_persistent_buffers_to_tensor_constants(
    orig_mod: torch.nn.Module,
    graph_signature: ExportGraphSignature,
    constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
):
    """
    Moves non-persistent buffers to tensor constants.
    """
    for spec in graph_signature.input_specs:
        if spec.kind == InputKind.BUFFER and not spec.persistent:
            assert spec.target is not None
            assert spec.target not in constants
            constants[spec.target] = orig_mod.get_buffer(spec.target)  # type: ignore[arg-type]


def _verify_nn_module_stack(graph_module: torch.fx.GraphModule) -> None:
    """
    Perform nn_module_stack checks on the graph.
    Current constraints:
        For the top level graph:
        - populated for 'call_function', 'get_attr'
        - None for 'placeholder', 'output'
        For submodule graphs:
        - None for 'placeholder', 'output'

    TODO(pianpwk): make this a consistent node-level check once nn_module_stack is populated for cond submodules.
    """
    # Check top-level graph for all nodes, all graphs for placeholder & output nodes
    for i, mod in enumerate([graph_module] + list(graph_module.modules())):
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            if node.op in ["call_function", "get_attr"]:
                if i == 0:
                    if (
                        nn_module_stack := node.meta.get("nn_module_stack", None)
                    ) is None:
                        raise SpecViolationError(
                            f"Node {node} of type {node.op} is missing nn_module_stack metadata"
                        )
                    if not all(
                        isinstance(k, str)
                        and isinstance(v, tuple)
                        and len(v) == 2
                        and all(isinstance(x, str) for x in v)
                        for k, v in nn_module_stack.items()
                    ):
                        raise SpecViolationError(
                            f"Node {node} of type {node.op} has incorrect nn_module_stack metadata format, "
                            f"expected Dict[str, Tuple[str, str]], but got {nn_module_stack}"
                        )
            elif node.op in ["placeholder", "output"]:
                if node.meta.get("nn_module_stack", None):
                    raise SpecViolationError(
                        f"Node {node} of type {node.op} contains nn_module_stack metadata, this should be None"
                    )


def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None:
    """
    Perform stack trace checks on the graph.
    Constraints:
        - None or non-empty str for 'call_function', 'get_attr'
        - None for 'placeholder', 'output'
    """
    for i, mod in enumerate([graph_module] + list(graph_module.modules())):
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            stack_trace = node.meta.get("stack_trace", None)
            if node.op in ["call_function", "get_attr"]:
                if not (stack_trace is None or isinstance(stack_trace, str)):
                    raise SpecViolationError(
                        f"Node {node} of type {node.op} has invalid stack_trace metadata, "
                        f"expected a string or None but instead found: {stack_trace}"
                    )
            elif node.op in ["placeholder", "output"]:
                if stack_trace:
                    raise SpecViolationError(
                        f"Node {node} of type {node.op} contains stack_trace metadata, "
                        f"expected None but instead found: {stack_trace}"
                    )


def _verify_placeholder_names(gm: torch.fx.GraphModule, sig: ExportGraphSignature):
    """
    Performs a sanity check on the placeholder node names.
    - User input nodes: no restrictions, should match the original forward() signature
    - Params/buffers/constants/custom_obj/token nodes: should start with prefixes defined in <placeholder_prefixes>
    """
    name_to_kind = {spec.arg.name: spec.kind for spec in sig.input_specs}
    for mod in gm.modules():
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            if node.op == "placeholder":
                if node.name not in name_to_kind:
                    continue
                node_kind = name_to_kind[node.name]
                prefix = placeholder_prefixes[node_kind]
                if not node.name.startswith(prefix):
                    raise SpecViolationError(
                        f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}"
                    )


def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
    op_count = 0
    op_set = set()
    for m in ep.graph_module.modules():
        if not isinstance(m, torch.fx.GraphModule):
            continue
        for node in m.graph.nodes:
            if node.op != "call_function":
                continue
            op_count += 1
            assert hasattr(node.target, "__module__")
            assert hasattr(node.target, "__name__")
            op_set.add(f"{node.target.__module__}.{node.target.__name__}")
    return {"op_count": op_count, "op_set": op_set}

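# Illustrative example (hypothetical exported program): get_ep_stats summarizes the
# call_function nodes, returning something like
#
#   {"op_count": 3, "op_set": {"torch.ops.aten.add.Tensor", ...}}
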
_EXPORT_FLAGS: Optional[Set[str]] = None
_EXPORT_MODULE_HIERARCHY: Optional[Dict[str, str]] = None


def _log_export_wrapper(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
        try:
            start = time.time()
            ep = fn(*args, **kwargs)
            end = time.time()
            log_export_usage(
                event="export.time",
                metrics=end - start,
                flags=_EXPORT_FLAGS,
                **get_ep_stats(ep),
            )
        except Exception as e:
            t = type(e)
            error_type = t.__module__ + "." + t.__qualname__
            case_name = get_class_if_classified_error(e)
            if case_name is not None:
                log.error(exportdb_error_message(case_name))
                log_export_usage(
                    event="export.error.classified",
                    type=error_type,
                    message=str(e),
                    flags=_EXPORT_FLAGS,
                )
            else:
                log_export_usage(
                    event="export.error.unclassified",
                    type=error_type,
                    message=str(e),
                    flags=_EXPORT_FLAGS,
                )
            raise e
        finally:
            _EXPORT_FLAGS = None
            _EXPORT_MODULE_HIERARCHY = None

        return ep

    return wrapper


def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs):
    if not isinstance(example_inputs, (tuple, list, dict)):
        example_inputs = (example_inputs,)

    elif isinstance(example_inputs, list):
        example_inputs = tuple(example_inputs)

    elif (
        isinstance(example_inputs, (torch.Tensor, dict))
        and example_kwarg_inputs is None
    ):
        example_inputs = (example_inputs,)

    if example_kwarg_inputs is None:
        example_kwarg_inputs = {}
    return example_inputs, example_kwarg_inputs

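# Illustrative examples (hypothetical inputs): a bare tensor example input is wrapped into
# a 1-tuple and missing kwargs become an empty dict:
#
#   _process_jit_trace_inputs_for_export(torch.randn(2), None)  ->  ((tensor,), {})
#   _process_jit_trace_inputs_for_export([t1, t2], None)        ->  ((t1, t2), {})
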
def _process_export_inputs(mod, args, kwargs, dynamic_shapes):
    original_state_dict = mod.state_dict(keep_vars=True)

    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
        )
    kwargs = kwargs if kwargs is not None else {}
    _, original_in_spec = pytree.tree_flatten((args, kwargs))

    if isinstance(dynamic_shapes, torch.export.ShapesCollection):
        dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs)

    return args, kwargs, original_in_spec, original_state_dict, dynamic_shapes


def _get_module_call_graph(
    export_artifact: ExportArtifact,
    original_in_spec: TreeSpec,
    preserve_module_call_signature: Tuple[str, ...],
    strict_mode_export: bool,
):
    """
    Modify the graph module in export_artifact in place, removing _export_tracepoint
    nodes, and return the module_call_graph.
    """
    gm: torch.fx.GraphModule = export_artifact.aten.gm
    export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
    module_call_specs: Dict[
        str, Dict[str, TreeSpec]
    ] = export_artifact.module_call_specs
    out_spec: TreeSpec = export_artifact.out_spec

    # Make module signatures.
    module_call_signatures = {}
    for fqn, specs in module_call_specs.items():
        mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn
        module_call_signatures[mod_fqn] = ModuleCallSignature(
            inputs=[], outputs=[], **specs
        )

    if len(preserve_module_call_signature) > 0:
        if not strict_mode_export:
            _rewrite_tracepoint_node(gm)
        res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm)
        assert res is not None
        gm = res.graph_module

    assert _EXPORT_MODULE_HIERARCHY is not None
    module_call_graph = _make_module_call_graph(
        _EXPORT_MODULE_HIERARCHY,
        original_in_spec,
        out_spec,
        module_call_signatures,
    )
    return gm, module_call_graph


def _get_range_constraints(
    export_artifact: ExportArtifact, combined_args: Dict[str, Any], dynamic_shapes
):
    gm: torch.fx.GraphModule = export_artifact.aten.gm
    export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
    fake_mode: FakeTensorMode = export_artifact.fake_mode
    num_lifted = next(
        (
            i
            for i, s in enumerate(export_graph_signature.input_specs)
            if s.kind == InputKind.USER_INPUT
        ),
        len(export_graph_signature.input_specs),
    )
    range_constraints = make_constraints(
        fake_mode,
        gm,
        combined_args,
        dynamic_shapes,
        num_lifted,
    )
    return range_constraints


def _get_inline_constraints(fake_mode: FakeTensorMode):
    assert fake_mode.shape_env is not None
    return {
        k: v
        for k, v in fake_mode.shape_env.var_to_range.items()
        if free_unbacked_symbols(k)
    }


@contextmanager
def patch_forward(obj: torch.nn.Module, new_method):
    """Helper context manager to make it easier to cleanly torch.export() a method
    on a module that is not `forward`.
    """
    # Save the original method
    original_method = obj.forward

    # Patch the method
    obj.forward = new_method.__get__(obj, obj.__class__)

    try:
        yield
    finally:
        # Restore the original method
        obj.forward = original_method

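# Illustrative usage sketch (hypothetical module and method): exporting a method other
# than forward() by temporarily patching it in as forward.
#
#   with patch_forward(mod, MyModel.encode):
#       ep = torch.export.export(mod, (torch.randn(2, 2),))
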
@contextmanager
def _temp_disable_texpr_fuser():
    original_state = torch._C._jit_texpr_fuser_enabled()
    torch._C._jit_set_texpr_fuser_enabled(False)
    try:
        yield
    finally:
        torch._C._jit_set_texpr_fuser_enabled(original_state)


class _WrapperModule(torch.nn.Module):
    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, *args, **kwargs):
        return self.f(*args, **kwargs)

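# Illustrative example (hypothetical callable): _WrapperModule lets a plain callable be
# exported as if it were a module.
#
#   wrapped = _WrapperModule(lambda x: x.sin() + 1)
#   out = wrapped(torch.randn(3))
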
1174def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
1175    with _temp_disable_texpr_fuser():
1176        from torch.jit._trace import TopLevelTracedModule
1177
1178        export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs)
1179
1180        if isinstance(traced_callable, (TopLevelTracedModule, torch._C.ScriptModule)):  # type: ignore[operator]
1181            return _export(
1182                traced_callable,
1183                export_args,
1184                export_kwargs,
1185                strict=False,
1186                _is_torch_jit_trace=True,
1187            ).module()
1188
1189        elif isinstance(traced_callable, torch.ScriptMethod) and isinstance(
1190            traced_callable.owner(), (torch._C.ScriptModule, torch.nn.Module)  # type: ignore[operator]
1191        ):
1192            with patch_forward(traced_callable.owner(), traced_callable):  # type: ignore[operator]
1193                return _export(
1194                    traced_callable.owner(),  # type: ignore[operator]
1195                    export_args,
1196                    export_kwargs,
1197                    strict=False,
1198                    _is_torch_jit_trace=True,
1199                ).module()
1200
1201        else:
1202            return _export(
1203                _WrapperModule(traced_callable),
1204                export_args,
1205                export_kwargs,
1206                strict=False,
1207                _is_torch_jit_trace=True,
1208            ).module()
1209
1210
1211def _strict_export(
1212    mod: torch.nn.Module,
1213    args: Tuple[Any, ...],
1214    kwargs: Dict[str, Any],
1215    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
1216    preserve_module_call_signature: Tuple[str, ...],
1217    pre_dispatch: bool,
1218    original_state_dict: Dict[str, Any],
1219    orig_in_spec: TreeSpec,
1220    allow_complex_guards_as_runtime_asserts: bool,
1221    _is_torch_jit_trace: bool,
1222) -> ExportArtifact:
1223    lower_to_aten = functools.partial(_export_to_aten_ir, pre_dispatch=pre_dispatch)
1224    return _strict_export_lower_to_aten_ir(
1225        mod=mod,
1226        args=args,
1227        kwargs=kwargs,
1228        dynamic_shapes=dynamic_shapes,
1229        preserve_module_call_signature=preserve_module_call_signature,
1230        pre_dispatch=pre_dispatch,
1231        original_state_dict=original_state_dict,
1232        orig_in_spec=orig_in_spec,
1233        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
1234        _is_torch_jit_trace=_is_torch_jit_trace,
1235        lower_to_aten_callback=lower_to_aten,
1236    )
1237
1238
1239def _strict_export_lower_to_aten_ir(
1240    mod: torch.nn.Module,
1241    args: Tuple[Any, ...],
1242    kwargs: Dict[str, Any],
1243    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
1244    preserve_module_call_signature: Tuple[str, ...],
1245    pre_dispatch: bool,
1246    original_state_dict: Dict[str, Any],
1247    orig_in_spec: TreeSpec,
1248    allow_complex_guards_as_runtime_asserts: bool,
1249    _is_torch_jit_trace: bool,
1250    lower_to_aten_callback: Callable,
1251) -> ExportArtifact:
1252    gm_torch_level = _export_to_torch_ir(
1253        mod,
1254        args,
1255        kwargs,
1256        dynamic_shapes,
1257        preserve_module_call_signature=preserve_module_call_signature,
1258        restore_fqn=False,  # don't need to restore because we will do it later
1259        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
1260        _log_export_usage=False,
1261    )
1262
1263    # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
1264    (
1265        fake_args,
1266        fake_kwargs,
1267        dynamo_fake_mode,
1268    ) = _extract_fake_inputs(gm_torch_level, args, kwargs)
1269
1270    fake_params_buffers = _fakify_params_buffers(dynamo_fake_mode, gm_torch_level)
1271
1272    # First, we want to pass through the graph to try populating
1273    # val field for getattr if there is anything missing.
1274    # This can happen when quantization adds extra params and forgets
1275    # to update "val"
1276    for node in gm_torch_level.graph.nodes:
1277        if node.op == "get_attr" and "val" not in node.meta:
1278            attr = getattr(gm_torch_level, node.target)
1279            # Checks if it is not a HigherOrderOp branch or a module
1280            if not isinstance(attr, torch.nn.Module):
1281                assert (
1282                    dynamo_fake_mode is not None
1283                ), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
1284                node.meta["val"] = dynamo_fake_mode.from_tensor(
1285                    attr, static_shapes=True
1286                )
1287
1288    # Fix the graph output signature to be tuple if scalar
1289    out_spec = orig_out_spec = gm_torch_level._out_spec
1290
1291    # Used to get rid of lint type error.
1292    assert out_spec is not None
1293    assert orig_out_spec is not None
1294
1295    # aot_export expect the return type to always be a tuple.
1296    if out_spec.type not in (list, tuple):
1297        out_spec = pytree.TreeSpec(tuple, None, [out_spec])
1298
1299    orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]
1300
1301    gm_torch_level.graph._codegen = _PyTreeCodeGen(
1302        _PyTreeInfo(
1303            orig_arg_names,
1304            gm_torch_level._in_spec,
1305            out_spec,
1306        )
1307    )
1308    gm_torch_level.recompile()
1309
1310    _normalize_nn_module_stack(gm_torch_level, type(mod))
1311
1312    params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level)
1313
1314    # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace)
1315    # from the param nodes as they are treated as fresh inputs
1316    # Therefore, we manually extract them before calling into aot_export
1317    # params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level)
1318
1319    constant_attrs = _gather_constant_attrs(mod)
1320    param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)
1321
1322    # Dynamo does not track which buffers were registered as non-persistent. This info
1323    # is available in the original module, so we transfer it to the traced module. Also,
1324    # since we didn't restore original param/buffer names yet, we must use traced names.
1325    non_persistent_buffers = _get_non_persistent_buffers(mod)
1326    reverse_name_lookup = {orig: traced for traced, orig in param_buffer_table.items()}
1327    gm_torch_level._non_persistent_buffers_set = {
1328        reverse_name_lookup[name]
1329        for name in non_persistent_buffers
1330        if name in reverse_name_lookup
1331    }
1332    with dynamo_fake_mode:
1333        aten_export_artifact = lower_to_aten_callback(
1334            gm_torch_level,
1335            # NOTE: graph module expects only positional args
1336            _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs),
1337            {},
1338            fake_params_buffers,
1339            constant_attrs,
1340        )
1341
1342    # Decompose for readability.
1343    gm = aten_export_artifact.gm
1344    export_graph_signature = aten_export_artifact.sig
1345    constants = aten_export_artifact.constants
1346
1347    _populate_param_buffer_metadata_to_new_gm(
1348        params_buffers_to_node_meta, gm, export_graph_signature
1349    )
1350
1351    # Do some cleanups on the graph module to restore the state dict to the
1352    # expected form. Each of these steps should probably get fixed upstream.
1353    # 1. Remove tensor constants that were added as buffers.
1354    _rewrite_dynamo_tensor_constants(
1355        orig_mod_buffers=set(mod.buffers()),
1356        traced_mod_buffers=dict(gm_torch_level.named_buffers()),
1357        graph_signature=export_graph_signature,
1358        constants=constants,
1359    )
1360    # 2. Restore FQN of param/buffers
1361    _replace_param_buffer_names(param_buffer_table, export_graph_signature)
1362
1363    # 3. Move non-persistent buffers to tensor constants
1364    _move_non_persistent_buffers_to_tensor_constants(
1365        mod, export_graph_signature, constants
1366    )
1367
1368    # 4. Rewrite constants to have the same FQN as the original module.
1369    _remap_constants(constant_attrs, export_graph_signature, constants)
1370
1371    # 5. Rename constants nodes in graph module from buffers to constants
1372    _rename_constants_nodes(gm, export_graph_signature)
1373
1374    return ExportArtifact(
1375        aten=aten_export_artifact,
1376        out_spec=orig_out_spec,
1377        fake_mode=dynamo_fake_mode,
1378        module_call_specs=gm_torch_level.meta["module_call_specs"],
1379    )
1380
1381
1382def _export_to_aten_ir_make_fx(
1383    mod: torch.nn.Module,
1384    fake_args,
1385    fake_kwargs,
1386    fake_params_buffers,
1387    constant_attrs: ConstantAttrMap,
1388    produce_guards_callback=None,
1389    transform=lambda x: x,
1390) -> ATenExportArtifact:
1391    @contextmanager
1392    def _compiling_state_context():
1393        old_value = torch.compiler._is_compiling_flag
1394        try:
1395            torch.compiler._is_compiling_flag = True
1396            yield
1397        finally:
1398            torch.compiler._is_compiling_flag = old_value
1399
1400    def _make_fx_helper(mod, args, kwargs, **flags):
1401        from torch._functorch._aot_autograd.schemas import GraphSignature
1402
1403        kwargs = kwargs or {}
1404
1405        named_parameters = dict(mod.named_parameters(remove_duplicate=False))
1406        named_buffers = dict(mod.named_buffers(remove_duplicate=False))
1407
1408        params_and_buffers = {**named_parameters, **named_buffers}
1409        params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
1410        params_and_buffers_flat = tuple(params_and_buffers_flat)
1411
1412        param_len = len(named_parameters)
1413        buffer_len = len(named_buffers)
1414        params_len = len(params_and_buffers)
1415
1416        functional_call = create_functional_call(
1417            mod, params_spec, params_len, store_orig_mod=True
1418        )
1419
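        # Flattened parameters are passed first, then buffers, then the user args; the
        # GraphSignature built below assumes this placeholder ordering.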
1420        params_buffers_args: List[Any] = []
1421        params_buffers_args.extend(params_and_buffers_flat)
1422        params_buffers_args.extend(args)
1423
1424        flat_fn, out_spec = create_tree_flattened_fn(
1425            functional_call, params_buffers_args, kwargs
1426        )
1427        flat_args, in_spec = pytree.tree_flatten((params_buffers_args, kwargs))
1428
1429        @functools.wraps(flat_fn)
1430        def wrapped_fn(*args):
1431            return tuple(flat_fn(*args))
1432
1433        with enable_python_dispatcher():
1434            gm = make_fx(
1435                wrapped_fn,
1436                record_module_stack=True,
1437                pre_dispatch=True,
1438            )(*flat_args)
1439            gm.graph.eliminate_dead_code()
1440
1441        # create graph signature
1442        input_names = _graph_input_names(gm)
1443        output_names = _graph_output_names(gm)
1444        sig = GraphSignature(
1445            parameters=list(named_parameters),
1446            buffers=list(named_buffers),
1447            user_inputs=input_names[params_len:],
1448            user_outputs=output_names,
1449            inputs_to_parameters=dict(zip(input_names[0:param_len], named_parameters)),
1450            inputs_to_buffers=dict(
1451                zip(input_names[param_len : param_len + buffer_len], named_buffers)
1452            ),
1453            buffers_to_mutate={},
1454            user_inputs_to_mutate={},
1455            in_spec=in_spec,
1456            out_spec=out_spec,  # type: ignore[arg-type]
1457            backward_signature=None,
1458            input_tokens=[],
1459            output_tokens=[],
1460        )
1461        return gm, sig
1462
1463    # This _reparametrize_module makes sure the inputs and module.params/buffers have the same fake_mode;
1464    # otherwise tracing will error out because it sees a mix of fake modes.
1465    # We also want tracing to use the fake tensor mode from Dynamo, to keep the pipeline easy to reason about.
1466    with torch.nn.utils.stateless._reparametrize_module(
1467        mod,
1468        fake_params_buffers,
1469        tie_weights=True,
1470        strict=True,
1471        stack_weights=True,
1472    ), _ignore_backend_decomps(), _compiling_state_context():  # type: ignore[attr-defined]
1473        param_len = len(dict(mod.named_parameters(remove_duplicate=False)))
1474        buffer_len = len(dict(mod.named_buffers(remove_duplicate=False)))
1475        params_len = param_len + buffer_len
1476
1477        gm, graph_signature = transform(_make_fx_helper)(
1478            mod,
1479            fake_args,
1480            trace_joint=False,
1481            kwargs=fake_kwargs,
1482        )
1483
1484        if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
1485            gm.meta.update(mod.meta)
1486
1487    flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
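    # For non-tensor user inputs, record the concrete value as the placeholder's
    # "val" metadata; tensor placeholders already carry fake-tensor vals from tracing.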
1488    index = 0
1489    for node in gm.graph.nodes:
1490        if node.op == "placeholder":
1491            if index >= params_len:
1492                user_arg = flat_args[index - params_len]
1493                if not isinstance(user_arg, torch.Tensor):
1494                    node.meta["val"] = user_arg
1495            index += 1
1496
1497    export_graph_signature = _convert_to_export_graph_signature(
1498        graph_signature, gm, _get_non_persistent_buffers(mod)
1499    )
1500
1501    # See comment in _export_to_aten_ir()
1502    if produce_guards_callback:
1503        try:
1504            produce_guards_callback(gm)
1505        except (ConstraintViolationError, ValueRangeError) as e:
1506            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
1507
1508    fake_mode = detect_fake_mode(flat_args)
1509
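    # Unless disabled, materialize deferred runtime assertions from the shape env
    # (e.g. guards on unbacked / data-dependent symbols) as assert ops in the graph.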
1510    if not torch._dynamo.config.do_not_emit_runtime_asserts:
1511        stack_trace = (
1512            'File "torch/fx/passes/runtime_assert.py", line 24, '
1513            "in insert_deferred_runtime_asserts"
1514        )
1515        with _set_node_metadata_hook(
1516            gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
1517        ):
1518            insert_deferred_runtime_asserts(
1519                gm,
1520                fake_mode.shape_env,
1521                f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
1522                export=True,
1523            )
1524
1525    # Remove nn_module_stack and stack_trace metadata from all placeholder and output nodes.
1526    for _mod in gm.modules():
1527        if not isinstance(_mod, torch.fx.GraphModule):
1528            continue
1529        for node in _mod.graph.nodes:
1530            if node.op in ["placeholder", "output"]:
1531                node.meta.pop("nn_module_stack", None)
1532                node.meta.pop("stack_trace", None)
1533
1534    constants = rewrite_script_object_meta(gm)
1535    constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))
1536
1537    _preserve_requires_grad_pass(
1538        gm, export_graph_signature, fake_params_buffers, constants, flat_args
1539    )
1540
1541    # Prettify names for placeholder nodes.
1542    placeholder_naming_pass(
1543        gm,
1544        export_graph_signature,
1545        mod,
1546        fake_args,
1547        fake_kwargs,
1548        fake_params_buffers,
1549        constants,
1550    )
1551
1552    return ATenExportArtifact(
1553        gm,
1554        export_graph_signature,
1555        constants,
1556    )
1557
1558
1559def _non_strict_export(
1560    mod: torch.nn.Module,
1561    args: Tuple[Any, ...],
1562    kwargs: Dict[str, Any],
1563    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
1564    preserve_module_call_signature: Tuple[str, ...],
1565    pre_dispatch: bool,
1566    original_state_dict: Dict[str, Any],
1567    orig_in_spec: TreeSpec,
1568    allow_complex_guards_as_runtime_asserts: bool,
1569    _is_torch_jit_trace: bool,
1570    dispatch_tracing_mode: str = "aot_export",
1571) -> ExportArtifact:
1572    """
1573    ``dispatch_tracing_mode`` can be either "make_fx" or "aot_export", corresponding to
1574    _export_to_aten_ir_make_fx and _export_to_aten_ir, respectively.
1575    """
1576    assert dispatch_tracing_mode in ["make_fx", "aot_export"]
1577    out_spec: Optional[TreeSpec] = None
1578
1579    module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
1580
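    # Wrap the module so that its (possibly nested) outputs are flattened to a tuple
    # before tracing; the captured out_spec lets us restore the original output structure later.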
1581    def _tuplify_outputs(aot_export):
1582        def _aot_export_non_strict(mod, args, kwargs=None, **flags):
1583            kwargs = kwargs or {}
1584
1585            class Wrapper(torch.nn.Module):
1586                def __init__(self, mod):
1587                    super().__init__()
1588                    self._export_root = mod
1589
1590                def forward(self, *args, **kwargs):
1591                    nonlocal out_spec
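                    # If the exported root is already a GraphModule, re-run it with an
                    # fx.Interpreter under preserve_node_meta() so node metadata from the
                    # original graph carries over to the new trace.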
1592                    if isinstance(self._export_root, torch.fx.GraphModule):
1593                        with torch.fx.traceback.preserve_node_meta():
1594                            tree_out = torch.fx.Interpreter(self._export_root).run(
1595                                *args, **kwargs
1596                            )
1597                    else:
1598                        tree_out = self._export_root(*args, **kwargs)
1599                    flat_outs, out_spec = pytree.tree_flatten(tree_out)
1600                    return tuple(flat_outs)
1601
1602            wrapped_mod = Wrapper(mod)
1603            # Prefix the preserved call signatures with "_export_root." so that the
1604            # wrapper module correctly populates the in/out specs.
1605            new_preserved_call_signatures = [
1606                "_export_root." + i for i in preserve_module_call_signature
1607            ]
1608            with _wrap_submodules(
1609                wrapped_mod, new_preserved_call_signatures, module_call_specs
1610            ):
1611                gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
1612                log.debug("Exported program from AOTAutograd:\n%s", gm)
1613
1614            sig.parameters = pytree.tree_map(_strip_root, sig.parameters)
1615            sig.buffers = pytree.tree_map(_strip_root, sig.buffers)
1616            sig.inputs_to_buffers = pytree.tree_map(_strip_root, sig.inputs_to_buffers)
1617            sig.inputs_to_parameters = pytree.tree_map(
1618                _strip_root, sig.inputs_to_parameters
1619            )
1620            sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate)
1621
1622            for node in gm.graph.nodes:
1623                if "nn_module_stack" in node.meta:
1624                    nn_module_stack = node.meta["nn_module_stack"]
1625                    node.meta["nn_module_stack"] = {
1626                        _fixup_key(key): val
1627                        for key, val in pytree.tree_map(
1628                            _strip_root, nn_module_stack
1629                        ).items()
1630                    }
1631
1632            return gm, sig
1633
1634        return _aot_export_non_strict
1635
1636    (
1637        fake_mode,
1638        fake_args,
1639        fake_kwargs,
1640        equalities_inputs,
1641        original_signature,
1642        transformed_dynamic_shapes,
1643    ) = make_fake_inputs(
1644        mod,
1645        args,
1646        kwargs,
1647        dynamic_shapes,
1648        _is_torch_jit_trace=_is_torch_jit_trace,
1649        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,  # for shape env initialization
1650    )
1651
1652    fake_params_buffers = _fakify_params_buffers(fake_mode, mod)
1653
1654    def _produce_guards_callback(gm):
1655        return produce_guards_and_solve_constraints(
1656            fake_mode=fake_mode,
1657            gm=gm,
1658            dynamic_shapes=transformed_dynamic_shapes,
1659            equalities_inputs=equalities_inputs,
1660            original_signature=original_signature,
1661            _is_torch_jit_trace=_is_torch_jit_trace,
1662        )
1663
1664    with fake_mode, _NonStrictTorchFunctionHandler(), torch._dynamo.config.patch(
1665        assume_static_by_default=False
1666    ):
1667        with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
1668            patched_mod,
1669            new_fake_args,
1670            new_fake_kwargs,
1671            new_fake_constant_attrs,
1672            map_fake_to_real,
1673        ):
1674            _to_aten_func = (
1675                _export_to_aten_ir_make_fx
1676                if dispatch_tracing_mode == "make_fx"
1677                else functools.partial(
1678                    _export_to_aten_ir,
1679                    pre_dispatch=pre_dispatch,
1680                    _is_torch_jit_trace=_is_torch_jit_trace,
1681                )
1682            )
1683            aten_export_artifact = _to_aten_func(  # type: ignore[operator]
1684                patched_mod,
1685                new_fake_args,
1686                new_fake_kwargs,
1687                fake_params_buffers,
1688                new_fake_constant_attrs,
1689                produce_guards_callback=_produce_guards_callback,
1690                transform=_tuplify_outputs,
1691            )
1692            # aten_export_artifact.constants may contain fake script objects; map them back to the real ones.
1693            aten_export_artifact.constants = {
1694                fqn: map_fake_to_real[obj] if isinstance(obj, FakeScriptObject) else obj
1695                for fqn, obj in aten_export_artifact.constants.items()
1696            }
1697
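    # Non-persistent buffers are not part of the module's state dict, so represent
    # them as tensor constants in the exported program (mirroring step 3 of the strict path).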
1698    _move_non_persistent_buffers_to_tensor_constants(
1699        mod, aten_export_artifact.sig, aten_export_artifact.constants
1700    )
1701
1702    assert out_spec is not None
1703
1704    return ExportArtifact(
1705        aten=aten_export_artifact,
1706        out_spec=out_spec,
1707        fake_mode=fake_mode,
1708        module_call_specs=module_call_specs,
1709    )
1710
1711
1712# TODO (tmanlaibaatar) We need to preserve aten.to here somehow
1713@_log_export_wrapper
1714@_disable_prexisiting_fake_mode
1715def _export_for_training(
1716    mod: torch.nn.Module,
1717    args: Tuple[Any, ...],
1718    kwargs: Optional[Dict[str, Any]] = None,
1719    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
1720    *,
1721    strict: bool = True,
1722    preserve_module_call_signature: Tuple[str, ...] = (),
1723) -> ExportedProgram:
1724    global _EXPORT_MODULE_HIERARCHY
1725    _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
1726
1727    (
1728        args,
1729        kwargs,
1730        orig_in_spec,
1731        original_state_dict,
1732        dynamic_shapes,
1733    ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)
1734
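    # Training IR export always lowers via make_fx: the strict path runs Dynamo first
    # and then make_fx, while the non-strict path runs make_fx directly on the module.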
1735    export_func = (
1736        functools.partial(
1737            _strict_export_lower_to_aten_ir,
1738            lower_to_aten_callback=_export_to_aten_ir_make_fx,
1739        )
1740        if strict
1741        else functools.partial(
1742            _non_strict_export,
1743            dispatch_tracing_mode="make_fx",
1744        )
1745    )
1746    export_artifact = export_func(  # type: ignore[operator]
1747        mod=mod,
1748        args=args,
1749        kwargs=kwargs,
1750        dynamic_shapes=dynamic_shapes,
1751        preserve_module_call_signature=preserve_module_call_signature,
1752        pre_dispatch=False,
1753        original_state_dict=original_state_dict,
1754        orig_in_spec=orig_in_spec,
1755        allow_complex_guards_as_runtime_asserts=False,
1756        _is_torch_jit_trace=False,
1757    )
1758
1759    export_graph_signature = export_artifact.aten.sig
1760
1761    forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
1762    inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
1763    # The unbacked symint symbols are updated in aot_export
1764    # so we serialize them here instead of inside dynamo.
1765    # Note: _get_range_constraints depends on "inline_constraints" to be set.
1766    export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints
1767    range_constraints = _get_range_constraints(
1768        export_artifact,
1769        _combine_args(mod, args, kwargs, _is_torch_jit_trace=False),
1770        dynamic_shapes,
1771    )
1772    # The returned gm is modified in place.
1773    gm, module_call_graph = _get_module_call_graph(
1774        export_artifact, orig_in_spec, preserve_module_call_signature, strict
1775    )
1776
1777    # Add forward args metadata.
1778    gm.meta["forward_arg_names"] = forward_arg_names
1779
1780    _verify_nn_module_stack(gm)
1781    _verify_stack_trace(gm)
1782    _verify_placeholder_names(gm, export_graph_signature)
1783
1784    from torch._export.verifier import TrainingIRVerifier
1785
1786    exported_program = ExportedProgram(
1787        root=gm,
1788        graph=gm.graph,
1789        graph_signature=export_graph_signature,
1790        state_dict=original_state_dict,
1791        range_constraints=range_constraints,
1792        module_call_graph=module_call_graph,
1793        example_inputs=(args, kwargs),
1794        constants=export_artifact.aten.constants,
1795        verifiers=[TrainingIRVerifier],
1796    )
1797
1798    return exported_program
1799
1800
1801@_log_export_wrapper
1802@_disable_prexisiting_fake_mode
1803def _export(
1804    mod: torch.nn.Module,
1805    args: Tuple[Any, ...],
1806    kwargs: Optional[Dict[str, Any]] = None,
1807    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
1808    *,
1809    strict: bool = True,
1810    preserve_module_call_signature: Tuple[str, ...] = (),
1811    pre_dispatch: bool = False,
1812    allow_complex_guards_as_runtime_asserts: bool = False,
1813    _is_torch_jit_trace: bool = False,
1814) -> ExportedProgram:
1815    """
1816    Traces either an nn.Module's forward function or just a callable with PyTorch
1817    operations inside and produces an ExportedProgram.
1818
1819    Args:
1820        mod: the `nn.Module` to trace.
1821
1822        args: example positional inputs.
1823
1824        kwargs: optional example keyword inputs.
1825
1826        dynamic_shapes:
1827         An optional argument where the type should either be:
1828         1) a dict from argument names of ``mod.forward`` to their dynamic shape specifications,
1829         2) a tuple that specifies dynamic shape specifications for each input in original order.
1830         If you are specifying dynamism on keyword args, you will need to pass them in the order that
1831         is defined in the original function signature.
1832
1833         The dynamic shape of a tensor argument can be specified as either
1834         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
1835         not required to include static dimension indices in this dict, but when they are,
1836         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
1837         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
1838         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
1839         recursively specified by using mappings or sequences of contained specifications.
1840
1841        preserve_module_call_signature: A list of submodule paths for which the original
1842            calling conventions are preserved as metadata.
1843
1844        allow_complex_guards_as_runtime_asserts:
1845         With the current dynamic shapes language for dims and derived dims, we can run into constraints
1846         that are not expressible with the language. For example, flattening a matrix and adding it to a vector,
1847         both fully dynamic (i.e. x.reshape([-1]) + y), emits a guard s0 * s1 = s2, which is not expressible.
1848         By default, we either raise a constraint violation error or specialize to static values.
1849         If this flag is set to True, we avoid erroring out and instead allow complex constraints to exist as runtime
1850         assertions in the graph. The sympy interpreter (torch/utils/_sympy/interp.py) will produce the math ops
1851         required to compute and assert the value of the guard (e.g. sym_size_int, eq, _assert_scalar).
1852         Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow complex constraints
1853         while not emitting runtime asserts, returning a cleaner graph with fewer guarantees around dynamic shapes.
1854
1855    Returns:
1856        An ExportedProgram containing the traced method.
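
    Example:
        A minimal sketch of a ``dynamic_shapes`` spec; the module ``M`` and its
        ``forward`` argument name ``x`` are assumed here for illustration::

            from torch.export import Dim

            batch = Dim("batch")
            # Mark dim 0 of the tensor argument "x" as dynamic.
            ep = _export(M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}})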
1857    """
1858
1859    global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
1860    _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
1861
1862    flags = set()
1863    flags.add("strict" if strict else "non_strict")
1864    flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch")
1865    _EXPORT_FLAGS = flags
1866
1867    log_export_usage(event="export.enter", flags=_EXPORT_FLAGS)
1868
1869    (
1870        args,
1871        kwargs,
1872        original_in_spec,
1873        original_state_dict,
1874        dynamic_shapes,
1875    ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)
1876
1877    # Call the appropriate export function based on the strictness of tracing.
1878    export_func = _strict_export if strict else _non_strict_export
1879
1880    export_artifact = export_func(  # type: ignore[operator]
1881        mod,
1882        args,
1883        kwargs,
1884        dynamic_shapes,
1885        preserve_module_call_signature,
1886        pre_dispatch,
1887        original_state_dict,
1888        original_in_spec,
1889        allow_complex_guards_as_runtime_asserts,
1890        _is_torch_jit_trace,
1891    )
1892    export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
1893
1894    forward_arg_names = (
1895        _get_forward_arg_names(mod, args, kwargs) if not _is_torch_jit_trace else None
1896    )
1897    inline_constraints = _get_inline_constraints(export_artifact.fake_mode)
1898    # The unbacked symint symbols are updated in aot_export
1899    # so we serialize them here instead of inside dynamo.
1900    # Note: this step must be before _get_range_constraints.
1901    export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints
1902    range_constraints = _get_range_constraints(
1903        export_artifact,
1904        _combine_args(mod, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace),
1905        dynamic_shapes,
1906    )
1907    gm, module_call_graph = _get_module_call_graph(
1908        export_artifact, original_in_spec, preserve_module_call_signature, strict
1909    )
1910
1911    # Add forward args metadata.
1912    gm.meta["forward_arg_names"] = forward_arg_names
1913
1914    _verify_nn_module_stack(gm)
1915    _verify_stack_trace(gm)
1916    if not _is_torch_jit_trace:
1917        _verify_placeholder_names(gm, export_graph_signature)
1918
1919    # Remove Proxy because they cannot be deepcopied or pickled.
1920    torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True)
1921
1922    from torch._export.verifier import Verifier
1923
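    # Propagate any user-provided "custom" metadata from an input GraphModule onto
    # the exported graph module.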
1924    if (
1925        isinstance(mod, torch.fx.GraphModule)
1926        and hasattr(mod, "meta")
1927        and "custom" in mod.meta
1928    ):
1929        gm.meta.update({"custom": mod.meta["custom"]})
1930
1931    exported_program = ExportedProgram(
1932        root=gm,
1933        graph=gm.graph,
1934        graph_signature=export_graph_signature,
1935        state_dict=original_state_dict,
1936        range_constraints=range_constraints,
1937        module_call_graph=module_call_graph,
1938        example_inputs=(args, kwargs),
1939        constants=export_artifact.aten.constants,
1940        verifiers=[Verifier],
1941    )
1942
1943    return exported_program
1944