# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import contextlib
import copy
import dataclasses
import functools
import operator
import types
import warnings
from collections import namedtuple
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    final,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import FakeScriptObject
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.

    import sympy

    from torch.utils._sympy.value_ranges import ValueRanges

import torch
import torch.utils._pytree as pytree
from torch._export.utils import (
    _collect_and_set_constant_attrs,
    _collect_param_buffer_metadata,
    _detect_fake_mode_from_gm,
    _name_hoo_subgraph_placeholders,
    _overwrite_signature_for_non_persistent_buffers,
    _populate_param_buffer_metadata_to_new_gm,
    _rename_without_collisions,
)
from torch._export.verifier import Verifier
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

from .graph_signature import (  # noqa: F401
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymIntArgument,
    TensorArgument,
    TokenArgument,
)


__all__ = [
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
]


PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


@dataclasses.dataclass
class ModuleCallSignature:
    inputs: List[ArgumentSpec]
    outputs: List[ArgumentSpec]
    in_spec: pytree.TreeSpec
    out_spec: pytree.TreeSpec

    def replace_all_uses_with(self, original_node, new_node):
        for i in self.inputs:
            if i.name == original_node.name:
                i.name = new_node.name
        for o in self.outputs:
            if o.name == original_node.name:
                o.name = new_node.name


@dataclasses.dataclass
class ModuleCallEntry:
    fqn: str
    signature: Optional[ModuleCallSignature] = None


def _disable_preexisting_fake_mode(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with unset_fake_temporarily():
            return fn(*args, **kwargs)

    return wrapper


def _fx_collection_equivalence_fn(
    spec1_type: Optional[type],
    spec1_context: pytree.Context,
    spec2_type: Optional[type],
    spec2_context: pytree.Context,
) -> bool:
    """Treat containers and their immutable variants as the same type. Otherwise
    compare as normal.
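
    For example (an illustrative sketch), a spec recorded from an FX-produced
    ``immutable_list`` matches one built from a plain ``list``::

        >>> _, spec1 = pytree.tree_flatten([1, 2])
        >>> _, spec2 = pytree.tree_flatten(immutable_list([1, 2]))
        >>> _fx_collection_equivalence_fn(
        ...     spec1.type, spec1.context, spec2.type, spec2.context
        ... )
        True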
    """
    if spec1_type is None or spec2_type is None:
        return spec1_type is spec2_type and spec1_context == spec2_context

    if issubclass(spec1_type, (dict, immutable_dict)) and issubclass(
        spec2_type, (dict, immutable_dict)
    ):
        return spec1_context == spec2_context

    if issubclass(spec1_type, (list, immutable_list)) and issubclass(
        spec2_type, (list, immutable_list)
    ):
        return spec1_context == spec2_context

    return spec1_type is spec2_type and spec1_context == spec2_context


def _register_cia_to_meta(*args, **kwargs):
    kernel = kwargs.pop("kernel")

    assert torch._C._dispatch_has_kernel_for_dispatch_key(
        kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd
    )

    return kernel._op_dk(
        torch._C.DispatchKey.CompositeImplicitAutograd, *args, **kwargs
    )

156
157# This list is compiled from DispatchKey.cpp.
158# The idea is that we use these keys to override
159# CIA decomp in export
160_AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [
161    torch._C.DispatchKey.AutogradCPU,
162    torch._C.DispatchKey.AutogradCUDA,
163    torch._C.DispatchKey.AutogradMeta,
164    torch._C.DispatchKey.AutogradXLA,
165    torch._C.DispatchKey.AutogradLazy,
166    torch._C.DispatchKey.AutogradIPU,
167    torch._C.DispatchKey.AutogradXPU,
168    torch._C.DispatchKey.AutogradMPS,
169    torch._C.DispatchKey.AutogradHPU,
170    torch._C.DispatchKey.AutogradPrivateUse1,
171    torch._C.DispatchKey.AutogradPrivateUse2,
172    torch._C.DispatchKey.AutogradPrivateUse3,
173]
174
175
@contextmanager
def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True):
    # This function overrides the CompositeImplicitAutograd decomp for the
    # functional composite ops that the user specified. Ideally we would like
    # to not decompose ALL composite ops, but today's C++ functionalization
    # relies on working with the opset that remains after decomp is run.
    # Hence we can only do this for functional ops. One caveat is that some
    # composite ops lie about their schema (claim to be functional but are
    # not really, e.g. dropout); for these cases, we just decompose.

    # When safe=False, we assume that ops_to_preserve can be mutating/aliasing
    # and that their usual decompositions need to be shadowed rather than
    # overridden. Thus we avoid asserting that they are valid to preserve, and
    # do not replace their CompositeImplicitAutograd kernels with NotImplemented.
    # The only current users of this mode are variants of aten::to that we will
    # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__.
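    #
    # For example (illustrative): preserving torch.ops.aten.linear.default keeps
    # aten.linear nodes in the exported graph instead of letting them decompose
    # into aten.addmm / aten.mm calls.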

    saved_tables = {}
    patched_ops = set()
    removed_decomps = {}
    for op_overload in ops_to_preserve:
        # Our strategy for deciding whether we can preserve CIA is as follows:
        # 1. The op must be statically known to be functional.
        # 2. If it is maybe aliasing, we decompose, because we must know whether
        #    an op is mutating or aliasing.
        # TODO (tmanlaibaatar) make this a utility function and share it with the
        # functional_tensor decomp part. (https://github.com/pytorch/pytorch/issues/129431)
        def assert_valid_to_preserve(op_overload):
            if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
                raise RuntimeError(
                    f"We can't statically detect {op_overload} as a functional op, so we can't preserve it"
                )
            if op_overload in FunctionalTensor.metadata_fns:
                raise RuntimeError(
                    f"{op_overload} is a metadata query function; "
                    "it will be preserved implicitly by our tracing system. "
                    "Please file an issue on GitHub if you see otherwise"
                )

            # Count the arguments that carry alias annotations in the schema.
            num_aliased_args = len(
                [i for i in op_overload._schema.arguments if i.alias_info is not None]
            )

            is_mutating_or_aliasing = (
                num_aliased_args != 0 or op_overload._schema.is_mutable
            )

            if is_mutating_or_aliasing:
                raise RuntimeError(
                    f"{op_overload} is a mutating/aliasing op, we can't preserve it as is"
                )

            if not torch._C._dispatch_has_kernel(op_overload.name()):
                raise RuntimeError(
                    f"{op_overload} is a TorchScript op, we can't preserve it as is"
                )

            return True

        if safe:
            # Raises if the op cannot be preserved as-is.
            assert_valid_to_preserve(op_overload)

        saved_tables[op_overload] = op_overload.py_kernels.copy()
        patched_ops.add(op_overload)

        for override_dispatch_key in _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE:
            if override_dispatch_key not in op_overload.py_kernels:
                # TODO (tmanlaibaatar) https://github.com/pytorch/pytorch/issues/129430
                op_overload.py_impl(override_dispatch_key)(
                    autograd_not_implemented(op_overload, deferred_error=True)
                )
        if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels:
            del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd]

        if safe:

            def _(*args, **kwargs):
                return NotImplemented

            op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_)

        # For fake tensor prop, we do want to register meta kernel directly
        if torch._C.DispatchKey.Meta not in op_overload.py_kernels:
            op_overload.py_impl(torch._C.DispatchKey.Meta)(
                functools.partial(_register_cia_to_meta, kernel=op_overload)
            )

        if op_overload in decomp_table:
            removed_decomps[op_overload] = decomp_table[op_overload]
            del decomp_table[op_overload]

    try:
        yield
    finally:
        for op in patched_ops:
            op.py_kernels.clear()
            op.py_kernels.update(saved_tables[op])
            op._dispatch_cache.clear()

        for op, decomp in removed_decomps.items():
            decomp_table[op] = decomp


@contextmanager
def _override_decomp_aten_to_variants():
    # Preserve variants of aten::to, with the understanding that they are
    # mutating/aliasing and that their CompositeImplicitAutograd kernels will
    # not become NotImplemented. We will later replace them with aten._to_copy
    # when functionalizing.
    with _override_composite_implicit_decomp(
        (torch.ops.aten.to.dtype_layout, torch.ops.aten.to.dtype),
        {},
        safe=False,
    ):
        yield


def _decompose_and_get_gm_with_new_signature_constants(
    ep,
    *,
    decomp_table: Dict[torch._ops.OperatorBase, Callable],
    _preserve_ops: Tuple[torch._ops.OpOverload],
    joint_loss_index: Optional[int],
):
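    # Retraces/decomposes ep's graph using decomp_table (while keeping
    # _preserve_ops un-decomposed) and returns the new GraphModule together
    # with an updated graph signature.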
    from torch._functorch.aot_autograd import aot_export_module
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.export._trace import (
        _export_to_aten_ir,
        _fakify_params_buffers,
        _ignore_backend_decomps,
        _verify_nn_module_stack,
        _verify_placeholder_names,
        _verify_stack_trace,
    )
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # TODO Merge this path with inference IR decomp, but it will require some additional work
    # so I will leave it for now. T200307782
    if ep.verifier.dialect == "TRAINING":
        mod = ep.module()

        fake_args = []
        for node in mod.graph.nodes:
            if node.op == "placeholder":
                fake_args.append(node.meta["val"])

        fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec)
        fake_mode = _detect_fake_mode_from_gm(mod)
        if fake_mode is None:
            fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)

        # Fix the graph output signature to be a tuple if it is scalar
        out_spec = mod._out_spec

        orig_arg_names = mod.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]

        # aot_export expects the return type to always be a tuple, so a scalar
        # output is wrapped as a 1-tuple.
        if out_spec.type not in (list, tuple):
            out_spec = pytree.TreeSpec(tuple, None, [out_spec])

        mod.graph._codegen = _PyTreeCodeGen(
            _PyTreeInfo(
                orig_arg_names,
                mod._in_spec,
                out_spec,
            )
        )

        mod.recompile()

        # The exported module stores constants & non-persistent buffers in a way
        # that makes retracing treat them as persistent buffers, so we inform the
        # constants-lifting pass and overwrite the new graph signature using the
        # previous program.
        constant_attrs = _collect_and_set_constant_attrs(
            ep.graph_signature, ep.constants, mod
        )

        # get params & buffers after excluding constants
        fake_params_buffers = _fakify_params_buffers(fake_mode, mod)

        params_buffers_to_node_meta = _collect_param_buffer_metadata(mod)

        with _ignore_backend_decomps(), (
            fake_mode
        ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp(
            _preserve_ops,
            decomp_table,
        ):
            aten_export_artifact = _export_to_aten_ir(
                mod,
                # this requires empty kwargs, and args that are not in
                # pytree-flattened format
                (
                    *fake_args_unwrapped[0],
                    *fake_args_unwrapped[1].values(),
                ),
                {},
                fake_params_buffers,
                constant_attrs,
                decomp_table=decomp_table,
                _check_autograd_state=False,
            )

        gm = aten_export_artifact.gm
        new_graph_signature = aten_export_artifact.sig

        _populate_param_buffer_metadata_to_new_gm(
            params_buffers_to_node_meta, gm, new_graph_signature
        )

        # overwrite the signature for non-persistent buffers
        new_graph_signature = _overwrite_signature_for_non_persistent_buffers(
            ep.graph_signature, new_graph_signature
        )

        _verify_nn_module_stack(gm)
        _verify_stack_trace(gm)
        _verify_placeholder_names(gm, new_graph_signature)

        return _remove_unnecessary_copy_op_pass(gm, new_graph_signature)

    old_placeholders = [
        node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
    ]
    fake_args = [node.meta["val"] for node in old_placeholders]

    buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()]
    for name in buffers_to_remove:
        delattr(ep.graph_module, name)

    # TODO(zhxhchen17) Return the new graph_signature directly.
    fake_mode = detect_fake_mode(fake_args)
    fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode
    with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp(
        _preserve_ops,
        decomp_table,
    ):
        gm, graph_signature = aot_export_module(
            ep.graph_module,
            fake_args,
            decompositions=decomp_table,
            trace_joint=joint_loss_index is not None,
            output_loss_index=joint_loss_index,
        )

    # Update the signatures with the new placeholder names in case they
    # changed when calling aot_export
    def update_arg(old_arg, new_ph):
        if isinstance(old_arg, ConstantArgument):
            return old_arg
        elif isinstance(old_arg, TensorArgument):
            return TensorArgument(name=new_ph.name)
        elif isinstance(old_arg, SymIntArgument):
            return SymIntArgument(name=new_ph.name)
        raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")

    new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
    new_outputs = list(gm.graph.nodes)[-1].args[0]

    # rename the placeholders
    assert len(new_placeholders) == len(old_placeholders)
    for old_ph, new_ph in zip(old_placeholders, new_placeholders):
        new_ph.name = new_ph.target = old_ph.name

    # handle name collisions with newly decomposed graph nodes
    name_map = {ph.name: ph.name for ph in new_placeholders}
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue
        node.name = _rename_without_collisions(name_map, node.name, node.name)

    # propagate names to higher order op subgraphs
    _name_hoo_subgraph_placeholders(gm)

    # Run this pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
    # Overwrite output specs afterwards.
    from torch._export.passes._node_metadata_hook import (
        _node_metadata_hook,
        _set_node_metadata_hook,
    )
    from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names

    if not torch._dynamo.config.do_not_emit_runtime_asserts:
        stack_trace = (
            'File "torch/fx/passes/runtime_assert.py", line 24, '
            "in insert_deferred_runtime_asserts"
        )
        shape_env = _get_shape_env(gm)
        if shape_env is not None:
            with _set_node_metadata_hook(
                gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
            ):
                insert_deferred_runtime_asserts(
                    gm,
                    shape_env,
                    f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
                    export=True,
                )

    # update output specs
    gm.recompile()
    for i, name in enumerate(_graph_output_names(gm)):
        if isinstance(new_outputs[i], torch.fx.Node):
            new_outputs[i].name = name

    # To match output targets with the correct inputs for input mutations,
    # we need the old-to-new placeholder map
    old_new_placeholder_map = {
        spec.arg.name: new_placeholders[i].name
        for i, spec in enumerate(ep.graph_signature.input_specs)
        if not isinstance(spec.arg, ConstantArgument)
    }

    input_specs = [
        InputSpec(
            spec.kind,
            update_arg(spec.arg, new_placeholders[i]),
            spec.target,
            spec.persistent,
        )
        for i, spec in enumerate(ep.graph_signature.input_specs)
    ]
    output_specs = [
        OutputSpec(
            spec.kind,
            update_arg(spec.arg, new_outputs[i]),
            old_new_placeholder_map.get(spec.target, spec.target),
        )
        for i, spec in enumerate(ep.graph_signature.output_specs)
    ]

    if joint_loss_index is not None:
        assert graph_signature.backward_signature is not None
        gradients = graph_signature.backward_signature.gradients_to_user_inputs
        assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs)
        specs = {
            graph_signature.user_inputs[i]: spec
            for i, spec in enumerate(ep.graph_signature.input_specs)
            if isinstance(spec.arg, TensorArgument)
        }
        for i, node in enumerate(new_outputs[len(output_specs) :]):
            source = gradients[node.name]
            spec = specs[source]  # type: ignore[index]
            if spec.kind == InputKind.PARAMETER:
                kind = OutputKind.GRADIENT_TO_PARAMETER
                target = spec.target
            elif spec.kind == InputKind.USER_INPUT:
                kind = OutputKind.GRADIENT_TO_USER_INPUT
                target = source
            else:
                raise AssertionError(f"Unknown input kind: {spec.kind}")
            output_specs.append(
                OutputSpec(
                    kind,
                    TensorArgument(name=node.name),
                    target,
                )
            )

    assert len(new_placeholders) == len(old_placeholders)

    new_graph_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )
    # NOTE: aot_export adds symint metadata for placeholders with int
    # values; since these become specialized, we replace such metadata with
    # the original values.
    # Also, set the param/buffer metadata back to the placeholders.
    for old_node, new_node in zip(old_placeholders, new_placeholders):
        if not isinstance(old_node.meta["val"], torch.Tensor):
            new_node.meta["val"] = old_node.meta["val"]

        if (
            new_node.target in new_graph_signature.inputs_to_parameters
            or new_node.target in new_graph_signature.inputs_to_buffers
        ):
            for k, v in old_node.meta.items():
                new_node.meta[k] = v
    return gm, new_graph_signature


def _remove_unnecessary_copy_op_pass(
    gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
    """
    Removes redundant ``copy_`` nodes introduced for mutated buffers.
    """
    with gm._set_replace_hook(new_graph_signature.get_replace_hook()):
        for node in gm.graph.nodes:
            if node.op == "output":
                args, _ = pytree.tree_flatten(node.args)
                for out in args:
                    if (
                        isinstance(out, torch.fx.Node)
                        and out.name in new_graph_signature.buffers_to_mutate
                        and out.op == "call_function"
                        and out.target == torch.ops.aten.copy.default
                    ):
                        out.replace_all_uses_with(out.args[1])  # type: ignore[arg-type]
                        gm.graph.erase_node(out)
        gm.recompile()
    return gm, new_graph_signature


def _common_getitem_elimination_pass(
    gm: torch.fx.GraphModule, graph_signature, module_call_graph
):
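    # Deduplicate identical operator.getitem calls: if two nodes extract the
    # same index from the same source node, the later one is replaced by the
    # earlier one (updating module_call_graph signatures accordingly).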
    with gm._set_replace_hook(graph_signature.get_replace_hook()):
        for module in gm.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue

            node_id: Dict[torch.fx.Node, str] = {}
            getitems: Dict[str, torch.fx.Node] = {}
            for node in list(module.graph.nodes):
                if node.op == "call_function" and node.target == operator.getitem:
                    source, idx = node.args
                    new_id = f"{node_id[source]}.{idx}"
                    if new_id in getitems:
                        node.replace_all_uses_with(getitems[new_id])
                        for entry in module_call_graph:
                            if entry.signature is not None:
                                entry.signature.replace_all_uses_with(
                                    node, getitems[new_id]
                                )
                        module.graph.erase_node(node)
                    else:
                        getitems[new_id] = node
                        node_id[node] = new_id
                else:
                    node_id[node] = node.name


def _decompose_exported_program(
    ep,
    *,
    decomp_table: Dict[torch._ops.OperatorBase, Callable],
    _preserve_ops: Tuple[torch._ops.OpOverload],
    joint_loss_index: Optional[int],
):
    gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
        ep,
        decomp_table=decomp_table,
        _preserve_ops=_preserve_ops,
        joint_loss_index=joint_loss_index,
    )

    # TODO: unfortunately, preserving graph-level metadata is not
    # working well with aot_export, so we manually copy it.
    # (The node-level meta is addressed above.)
    gm.meta.update(ep.graph_module.meta)

    new_range_constraints = _get_updated_range_constraints(
        gm,
        ep.range_constraints,
    )

    exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=new_graph_signature,
        state_dict=ep.state_dict,
        range_constraints=new_range_constraints,
        module_call_graph=copy.deepcopy(ep.module_call_graph),
        example_inputs=ep.example_inputs,
        constants=ep.constants,
    )
    return exported_program


class ExportedProgram:
    """
    Package of a program from :func:`export`. It contains
    an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
    tensor values of all lifted parameters and buffers, and various metadata.

    The graph module returned by ``.module()`` can be called like the original
    callable traced by :func:`export`, with the same calling convention.

    To perform transformations on the graph, use the ``.module()`` method to access
    an :class:`torch.fx.GraphModule`. You can then use
    `FX transformations <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
    to rewrite the graph. Afterwards, you can simply use :func:`export`
    again to construct a correct ExportedProgram.
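
    Example (a minimal sketch; the module and inputs are illustrative)::

        import torch

        class M(torch.nn.Module):
            def forward(self, x):
                return x + 1

        ep = torch.export.export(M(), (torch.randn(2),))
        gm = ep.module()          # torch.fx.GraphModule with params/buffers inlined
        out = gm(torch.randn(2))  # same calling convention as M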
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: torch.fx.Graph,
        graph_signature: ExportGraphSignature,
        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
        range_constraints: "Dict[sympy.Symbol, Any]",
        module_call_graph: List[ModuleCallEntry],
        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
        constants: Optional[
            Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
        ] = None,
        *,
        verifiers: Optional[List[Type[Verifier]]] = None,
    ):
        # Remove codegen-related things from the graph. It should just be a flat graph.
        graph._codegen = torch.fx.graph.CodeGen()
        self._graph_module = _create_graph_module_for_export(root, graph)
        if isinstance(root, torch.fx.GraphModule):
            self._graph_module.meta.update(root.meta)

        _common_getitem_elimination_pass(
            self._graph_module, graph_signature, module_call_graph
        )
        self._graph_signature: ExportGraphSignature = graph_signature
        self._state_dict: Dict[str, Any] = state_dict
        self._range_constraints: "Dict[sympy.Symbol, ValueRanges]" = range_constraints
        assert module_call_graph is not None
        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
        self._example_inputs = example_inputs

        self._constants = constants or {}

        verifiers = verifiers or [Verifier]
        assert all(issubclass(v, Verifier) for v in verifiers)
        self._verifiers = verifiers
        # Validation should always be the last step of the constructor.
        self.validate()

    @property
    @compatibility(is_backward_compatible=False)
    def graph_module(self):
        return self._graph_module

    @property
    @compatibility(is_backward_compatible=False)
    def graph(self):
        return self.graph_module.graph

    @property
    @compatibility(is_backward_compatible=False)
    def graph_signature(self):
        return self._graph_signature

    @property
    @compatibility(is_backward_compatible=False)
    def state_dict(self):
        return self._state_dict

    @compatibility(is_backward_compatible=False)
    def parameters(self) -> Iterator[torch.nn.Parameter]:
        """
        Returns an iterator over the original module's parameters.
        """
        for _, param in self.named_parameters():
            yield param

    @compatibility(is_backward_compatible=False)
    def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Returns an iterator over the original module's parameters, yielding
        both the name of the parameter as well as the parameter itself.
        """
        for param_name in self.graph_signature.parameters:
            yield param_name, self.state_dict[param_name]

    @compatibility(is_backward_compatible=False)
    def buffers(self) -> Iterator[torch.Tensor]:
        """
        Returns an iterator over the original module's buffers.
        """
        for _, buf in self.named_buffers():
            yield buf

    @compatibility(is_backward_compatible=False)
    def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Returns an iterator over the original module's buffers, yielding
        both the name of the buffer as well as the buffer itself.
        """
        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        for buffer_name in self.graph_signature.buffers:
            if buffer_name in non_persistent_buffers:
                yield buffer_name, self.constants[buffer_name]
            else:
                yield buffer_name, self.state_dict[buffer_name]

    @property
    @compatibility(is_backward_compatible=False)
    def range_constraints(self):
        return self._range_constraints

    @property
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self):
        return self._module_call_graph

    @property
    @compatibility(is_backward_compatible=False)
    def example_inputs(self):
        return self._example_inputs

    @property
    @compatibility(is_backward_compatible=False)
    def call_spec(self):
        CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])

        if len(self.module_call_graph) == 0:
            return CallSpec(in_spec=None, out_spec=None)
        assert self.module_call_graph[0].fqn == ""
        return CallSpec(
            in_spec=self.module_call_graph[0].signature.in_spec,
            out_spec=self.module_call_graph[0].signature.out_spec,
        )

    @property
    @compatibility(is_backward_compatible=False)
    def verifier(self) -> Any:
        return self._verifiers[0]

    @property
    @compatibility(is_backward_compatible=False)
    def dialect(self) -> str:
        assert self._verifiers is not None
        return self._verifiers[0].dialect

    @property
    @compatibility(is_backward_compatible=False)
    def verifiers(self):
        return self._verifiers

    @property
    @compatibility(is_backward_compatible=False)
    def tensor_constants(self):
        return self._constants

    @property
    @compatibility(is_backward_compatible=False)
    def constants(self):
        return self._constants

    def _get_flat_args_with_check(self, args, kwargs):
        """Flatten args and kwargs using pytree, then check the specs.

        Args:
            args: List[Any] original args passed to __call__
            kwargs: Dict[str, Any] original kwargs passed to __call__

        Returns:
            A tuple of (flat_args, received_spec)
            flat_args is the flattened args / kwargs
            received_spec is the pytree spec produced while flattening the
            tuple (args, kwargs)
        """
        in_spec = self.call_spec.in_spec
        if in_spec is not None:
            kwargs = reorder_kwargs(kwargs, in_spec)
        flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
            (args, kwargs)
        )  # type: ignore[possibly-undefined]
        self._check_input_constraints(flat_args_with_path)
        flat_args = tuple(x[1] for x in flat_args_with_path)
        return flat_args, received_spec

    def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any:
        """Transform args, kwargs of __call__ to args for graph_module.

        self.graph_module takes entries from the state dict as inputs.
        The invariant for an ExportedProgram ``ep`` is:
        ep(args, kwargs) ==
          ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
        """

        in_spec = self.call_spec.in_spec
        flat_args, received_spec = self._get_flat_args_with_check(args, kwargs)
        if in_spec is not None and not is_equivalent(
            received_spec, in_spec, _fx_collection_equivalence_fn
        ):
            raise ValueError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )

        additional_inputs = []
        for input_ in self.graph_signature.input_specs:
            if input_.kind == InputKind.USER_INPUT:
                continue
            elif input_.kind in (
                InputKind.PARAMETER,
                InputKind.BUFFER,
            ):
                if input_.persistent is False:
                    # This is a non-persistent buffer, grab it from our
                    # constants instead of the state dict.
                    additional_inputs.append(self.constants[input_.target])
                else:
                    additional_inputs.append(self.state_dict[input_.target])
            elif input_.kind in (
                InputKind.CONSTANT_TENSOR,
                InputKind.CUSTOM_OBJ,
            ):
                additional_inputs.append(self.constants[input_.target])
        additional_inputs = tuple(additional_inputs)

        # NOTE: calling convention is first params, then buffers, then args as user supplied them.
        # See: torch/_functorch/aot_autograd.py#L1034
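        # For example (illustrative): for a module with one parameter p and one
        # buffer b, a user call module(x) executes graph_module(p, b, x).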
        return additional_inputs + flat_args

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise RuntimeError(
            "Unable to call ExportedProgram directly. "
            "You should use `exported_program.module()` instead."
        )

    def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs):
        """Process potential mutations to the input.

        Because self.graph_module is functional, mutations have to be written
        back after the execution of graph_module.
        """
        import torch._export.error as error

        flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs)
        if self.call_spec.out_spec is not None:
            buffer_mutation = self.graph_signature.buffers_to_mutate
            user_input_mutation = self.graph_signature.user_inputs_to_mutate
            num_mutated = len(buffer_mutation) + len(user_input_mutation)
            mutated_values = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(  # noqa: B904
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                user_inputs = [
                    spec
                    for spec in self.graph_signature.input_specs
                    if spec.kind == InputKind.USER_INPUT
                ]
                for i, value in enumerate(mutated_values):
                    output_spec = self.graph_signature.output_specs[i]
                    if output_spec.kind == OutputKind.BUFFER_MUTATION:
                        assert output_spec.target is not None
                        self.state_dict[output_spec.target] = value
                    elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
                        assert output_spec.target is not None
                        index = next(
                            i
                            for i, spec in enumerate(user_inputs)
                            if spec.arg.name == output_spec.target
                        )
                        flat_args[index].copy_(value)
                    else:
                        raise AssertionError(f"Unexpected kind: {output_spec.kind}")
        return res

    def __str__(self) -> str:
        graph_module = self.graph_module.print_readable(
            print_output=False, colored=False
        ).replace("\n", "\n    ")
        string = (
            "ExportedProgram:\n"
            f"    {graph_module}\n"
            f"Graph signature: {self.graph_signature}\n"
            f"Range constraints: {self.range_constraints}\n"
        )
        return string

    def module(self) -> torch.nn.Module:
        """
        Returns a self-contained GraphModule with all the parameters/buffers inlined.
        """
        from ._unlift import _unlift_exported_program_lifted_states

        module = _unlift_exported_program_lifted_states(self)

        def _train(self, mode: bool = True):
            raise NotImplementedError("Calling train() is not supported yet.")

        def _eval(self, mode: bool = True):
            raise NotImplementedError("Calling eval() is not supported yet.")

        module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
        module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
        return module

    def _num_lifted_params_buffers(self):
        return next(
            (
                i
                for i, s in enumerate(self._graph_signature.input_specs)
                if s.kind == InputKind.USER_INPUT
            ),
            len(self._graph_signature.input_specs),
        )

    @_disable_preexisting_fake_mode
    def run_decompositions(
        self,
        decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
        _preserve_ops: Tuple[torch._ops.OpOverload, ...] = (),
    ) -> "ExportedProgram":
        """
        Runs a set of decompositions on the exported program and returns a new
        exported program. By default we run the Core ATen decompositions to
        get operators in the
        `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.

        For now, we do not decompose joint graphs.
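
        Example (a minimal sketch; ``ep`` is assumed to be a previously
        exported :class:`ExportedProgram`)::

            from torch._decomp import core_aten_decompositions

            core_ep = ep.run_decompositions(core_aten_decompositions())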
        """
        from torch._decomp import core_aten_decompositions

        if decomp_table is None:
            decomp_table = core_aten_decompositions()

        return _decompose_exported_program(
            self,
            decomp_table=decomp_table,
            _preserve_ops=_preserve_ops,  # type: ignore[arg-type]
            joint_loss_index=None,
        )

    def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
        pm = PassManager(list(passes))
        # Since we abstractly run the passes, we need to disable backend decomp here
        # again.
        from torch.export._trace import _ignore_backend_decomps

        with _ignore_backend_decomps():
            res = pm(self.graph_module)
        transformed_gm = res.graph_module if res is not None else self.graph_module
        assert transformed_gm is not None

        if transformed_gm is self.graph_module and not res.modified:
            return self

        # TODO(zhxchen17) Remove this.
        def _get_updated_graph_signature(
            old_signature: ExportGraphSignature,
            new_gm: torch.fx.GraphModule,
        ) -> ExportGraphSignature:
            """
            Update the graph signature's user_input/user_outputs.
            """
            new_input_specs = []
            for i, node in enumerate(new_gm.graph.nodes):
                if node.op != "placeholder":
                    break

                assert i < len(
                    old_signature.input_specs
                ), "Number of inputs changed after transformation"
                old_input_spec = old_signature.input_specs[i]
                arg = (
                    old_input_spec.arg
                    if isinstance(
                        old_input_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_input_spec.arg)(node.name)
                )
                new_input_specs.append(
                    InputSpec(
                        old_input_spec.kind,
                        arg,
                        old_input_spec.target,
                        old_input_spec.persistent,
                    )
                )

            output_node = list(new_gm.graph.nodes)[-1]
            assert output_node.op == "output"

            new_output_specs = []
            for i, node in enumerate(output_node.args[0]):
                assert i < len(
                    old_signature.output_specs
                ), "Number of outputs changed after transformation"
                old_output_spec = old_signature.output_specs[i]
                arg = (
                    old_output_spec.arg
                    if isinstance(
                        old_output_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_output_spec.arg)(node.name)
                )
                new_output_specs.append(
                    OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
                )

            new_signature = ExportGraphSignature(
                input_specs=new_input_specs, output_specs=new_output_specs
            )
            return new_signature

        transformed_ep = ExportedProgram(
            root=transformed_gm,
            graph=transformed_gm.graph,
            graph_signature=_get_updated_graph_signature(
                self.graph_signature, transformed_gm
            ),
            state_dict=self.state_dict,
            range_constraints=_get_updated_range_constraints(
                transformed_gm,
                self.range_constraints,
            ),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            constants=self.constants,
            verifiers=self.verifiers,
        )
        transformed_ep.graph_module.meta.update(self.graph_module.meta)
        transformed_ep.graph_module.meta.update(res.graph_module.meta)
        return transformed_ep

    def _check_input_constraints(self, flat_args_with_path):
        from torch._export.utils import _check_input_constraints_for_graph

        placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
        input_placeholders = [
            p
            for p, s in zip(placeholders, self.graph_signature.input_specs)
            if s.kind == InputKind.USER_INPUT
        ]
        _check_input_constraints_for_graph(
            input_placeholders, flat_args_with_path, self.range_constraints
        )

    @compatibility(is_backward_compatible=False)
    def validate(self):
        self._validate()

    # TODO: remove this
    @final
    def _validate(self):
        assert (
            len(self.verifiers) > 0
        ), "ExportedProgram must have at least one verifier."
        for v in self.verifiers:
            v().check(self)

    # TODO(zhxchen17) Formalize this.
    def _update(
        self, graph_module, graph_signature, *, state_dict=None, verifiers=None
    ) -> "ExportedProgram":
        return ExportedProgram(
            root=graph_module,
            graph=graph_module.graph,
            graph_signature=graph_signature,
            state_dict=state_dict if state_dict is not None else self.state_dict,
            range_constraints=copy.deepcopy(self.range_constraints),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            constants=self.constants,
            verifiers=verifiers if verifiers is not None else self.verifiers,
        )


def _get_shape_env(gm):
    vals = [
        node.meta["val"]
        for node in gm.graph.nodes
        if node.meta.get("val", None) is not None
    ]

    fake_mode = detect_fake_mode(vals)
    if fake_mode is not None:
        return fake_mode.shape_env
    for v in vals:
        if isinstance(v, torch.SymInt):
            return v.node.shape_env


def _get_updated_range_constraints(
    gm: torch.fx.GraphModule,
    old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None,
) -> "Dict[sympy.Symbol, Any]":
    assert old_range_constraints is not None

    shape_env = _get_shape_env(gm)
    if shape_env is None:
        return {}

    range_constraints = copy.copy(old_range_constraints)
    range_constraints = {
        k: v for k, v in range_constraints.items() if k not in shape_env.replacements
    }
    # Only when we have an unbacked symint that is used as a constructor input
    # does runtime_var_to_range make a difference compared to var_to_range,
    # e.g. [2, oo) -> [0, oo)
    for k, v in shape_env.var_to_range.items():
        if k not in shape_env.replacements and k not in range_constraints:
            range_constraints[k] = v
    return range_constraints


def _create_graph_module_for_export(root, graph):
    try:
        gm = torch.fx.GraphModule(root, graph)
    except SyntaxError:
        # If custom objects stored in memory are being used in the graph,
        # the generated python code will result in a syntax error on the custom
        # object, since it is unable to parse the in-memory object. However
        # we can still run the graph eagerly through torch.fx.Interpreter,
        # so we will bypass this error.
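        # (Illustratively, one can still do torch.fx.Interpreter(gm).run(*args)
        # on the resulting module.)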
        warnings.warn(
            "Unable to execute the generated python source code from "
            "the graph. The graph module will no longer be directly callable, "
            "but you can still run the ExportedProgram, and if needed, you can "
            "run the graph module eagerly using torch.fx.Interpreter."
        )
        gm = torch.fx.GraphModule(root, torch.fx.Graph())
        gm._graph = graph

    return gm