# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import json
import traceback
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

import executorch.extension.pytree as ex_pytree
import torch
import torch._dynamo as torchdynamo
import torch.fx as fx

import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree

from executorch.exir.common import (
    extract_out_arguments,
    format_schema_name,
    no_dispatch,
    setting_python_recursive_limit,
)
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.graph_module import LeafValue
from executorch.exir.operator.convert import is_out_variant
from executorch.exir.types import ValueSpec

from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass  # @manual
from torch._decomp import get_decompositions
from torch._dynamo.guards import Guard
from torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor
from torch.export import default_decompositions
from torch.func import functionalize
from torch.fx.operator_schemas import normalize_function
from torch.utils._pytree import TreeSpec

from typing_extensions import TypeAlias


Value: TypeAlias = Union[
    LeafValue,
    Tuple["Value", ...],
    List["Value"],
    Dict[str, "Value"],
]

torchdynamo_enabled = False


def get_stacktrace() -> List[Dict[str, str]]:
    """
    Get the current stacktrace (between trace() and __torch_dispatch__()),
    including the filename, function name, line number, and the source code
    from the start of the function to the given instruction.

    Return:
        A list of stacktrace entries, one per instruction, along with the
        source code context surrounding each instruction.
    """

    stacktrace = traceback.extract_stack()

    # The stacktrace typically looks like this:
    #
    #   1. I stack frames from the top level runner (e.g., the
    #      test suite runner)
    #   2. J frames in executorch/exir/tracer.py setting up tracing
    #      (call this INIT_EXIR)
    #   3. K frames in user model code (this is what we want to save!)
    #   4. 1 frame in executorch/exir/tracer.py __torch_function__
    #      returning to tracer (call this TRACE_EXIR)
    #   5. H frames in executorch/exir/tracer.py AND torch/_tensor.py
    #      doing all of the internal tracer handling
    #
    # The PyE tests assert that executorch/exir/tracer.py never shows
    # up in the user provided stack traces, so we must oblige them.
    #
    # Assumptions:
    #   - Reentrant tracing is not a thing.  Thus, the first time
    #     executorch/exir/tracer.py shows up in the trace, we know
    #     THAT is the point at which we start tracing.  (An alternative
    #     is that the tracer entry point could record the stack trace
    #     at this time, but I didn't do this.)
    #
    # Our plan is to run a miniature state machine over these stack
    # frames.

    # Remove parts before the trace function and parts after entering
    # __torch_dispatch__.  Defaults to returning the entire stack trace.
    init_exir_end = 0
    trace_exir_start = None
    # A miniature state machine, referring to the frame segments described
    # above.  The recorded locations form a closed-open interval.
    FIND_INIT_EXIR_START, FIND_INIT_EXIR_END, FIND_TRACE_EXIR_START = range(3)
    state = FIND_INIT_EXIR_START
    for i, frame in enumerate(stacktrace):
        if state == FIND_INIT_EXIR_START:
            if "executorch/exir/tracer.py" in frame.filename:
                state = FIND_INIT_EXIR_END
        elif state == FIND_INIT_EXIR_END:
            if "executorch/exir/tracer.py" not in frame.filename:
                init_exir_end = i
                state = FIND_TRACE_EXIR_START
        elif state == FIND_TRACE_EXIR_START:
            if "executorch/exir/tracer.py" in frame.filename:
                trace_exir_start = i
                break

    stacktrace = stacktrace[init_exir_end:trace_exir_start]

    # Gather the source code context surrounding each frame's line
    contexts: List[str] = []
    for s in stacktrace:
        try:
            with open(s.filename) as file:
                # pyre-fixme[6]: For 1st param expected `Union[SupportsTrunc, bytes,
                #  str, SupportsInt, SupportsIndex]` but got `Optional[int]`.
                lineno = int(s.lineno)
                # Get the source code 5 lines above/below the current instruction
                file_contents = [
                    str(index + 1) + line for index, line in enumerate(file.readlines())
                ]
                file_contents_above = "".join(
                    file_contents[max(lineno - 5, 0) : lineno]
                )
                file_contents_below = "".join(
                    file_contents[lineno : min(lineno + 5, len(file_contents))]
                )
                context = (
                    file_contents_above
                    + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
                    + file_contents_below
                )
                contexts.append(context)
        except FileNotFoundError:
            contexts.append("<unknown file: unknown line>")

    # torch.fx's stack preservation logic expects strings to be passed
    # around. Working with dictionaries makes it much easier to convert
    # to and from strings.
    frames: List[Dict[str, str]] = []
    for i, frame in enumerate(stacktrace):
        frames.append(
            {
                "filename": str(frame.filename),
                "lineno": str(frame.lineno),
                "name": str(frame.name),
                "line": str(frame.line),
                "context": contexts[i],
            }
        )

    return frames

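# Illustrative sketch (not part of the original file): how the frames returned
# by get_stacktrace() are typically consumed.  Each entry is a Dict[str, str],
# so it round-trips through json for fx node metadata, as __torch_dispatch__
# does below with `json.dumps(get_stacktrace())`.
#
#     frames = get_stacktrace()
#     for f in frames:
#         print(f'{f["filename"]}:{f["lineno"]} in {f["name"]}')
#         print(f["context"])  # +/-5 lines of source around the instruction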

def unwrap_functional(t: torch.Tensor) -> torch.Tensor:
    assert isinstance(t, torch.Tensor)
    return _maybe_unwrap_functional_tensor(t, reapply_views=False)


def unwrap_proxy(t: LeafValue) -> Union[LeafValue, torch.fx.Proxy]:
    if not isinstance(t, torch.Tensor):
        return t
    t = unwrap_functional(t)
    return t.proxy if isinstance(t, PythonTensor) else t


def single_return(
    output: LeafValue,
    proxy: torch.fx.Proxy,
    wrapper: Callable[..., LeafValue],
) -> LeafValue:
    if isinstance(output, torch.Tensor):
        return wrapper(output, proxy)

    return output


def tree_return(
    outputs: Value,
    proxy: torch.fx.Proxy,
    wrapper: Callable[..., LeafValue],
    meta_type: Callable[..., Iterable[ValueSpec]] = tuple,
) -> Value:
    i: int = 0

    def wrap(o: LeafValue) -> LeafValue:
        nonlocal i
        ret = single_return(o, proxy[i], wrapper)
        i += 1
        return ret

    return pytree.tree_map(wrap, outputs)

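# Illustrative sketch (not part of the original file): tree_return pairs the
# i-th flattened leaf of `outputs` with the indexed proxy `proxy[i]`.  For an
# op returning a pair of tensors, conceptually:
#
#     out = tree_return((t0, t1), proxy_out, wrap_with_proxy)
#     # leaf 0 is wrapped with proxy_out[0], leaf 1 with proxy_out[1]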

class DummyProxy:
    def __init__(self) -> None:
        class DummyNode:
            def __init__(self):
                self.meta = {}

        self.node = DummyNode()

    def __getitem__(self, key: str) -> "DummyProxy":
        return DummyProxy()


class PythonTensor(torch.Tensor):
    """
    A wrapper tensor subclass used in the DispatchTracer to keep track of
    proxies to construct the FX graph.

    Wrapping something in PythonTensor implicitly detaches gradients.  If
    something requires grad, we will collect it as if it were a leaf.  A
    consequence of detaching in this way is that you need to maintain a
    parameter cache when translating tensors into PythonTensor, so that you
    don't create multiple copies of a gradient (they are aliased, but they
    would count as independent leaves).  An alternate strategy would be to
    avoid implicitly detaching and instead "catch" gradients as they exit
    the PythonTensor boundary.
    """

    __slots__ = ["proxy", "is_immutable"]

    @staticmethod
    def __new__(
        cls, elem: torch.Tensor, proxy: torch.fx.Proxy, is_immutable: bool = False
    ) -> torch.Tensor:
        # assert not elem.requires_grad or not torch.is_grad_enabled()

        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        assert isinstance(r, PythonTensor)
        r.is_immutable: bool = is_immutable
        r.update_proxy(proxy)
        return r

    def update_proxy(self, proxy: torch.fx.Proxy) -> None:
        self.proxy = proxy

    def __repr__(self, *, tensor_contents: None = None) -> str:
        with no_dispatch():
            return f"PythonTensor({self.as_subclass(torch.Tensor)})"

    @classmethod
    def __torch_function__(
        cls,
        # pyre-ignore: Missing parameter annotation [2]
        func,
        # pyre-ignore: Missing parameter annotation [2]
        types,
        args: Tuple[Value, ...] = (),
        kwargs: Optional[Dict[str, Value]] = None,
    ) -> Value:
        if kwargs is None:
            kwargs = {}
        if torch.is_inference_mode_enabled():
            if func is torch.nn.functional.layer_norm:
                args, kwargs = normalize_function(func, args, kwargs)  # pyre-fixme[23]
                input, normalized_shape = args
                normalized_shape = list(normalized_shape)
                return cls.__torch_dispatch__(
                    torch.ops.aten.layer_norm.default,
                    types,
                    (input, normalized_shape),
                    kwargs,
                )
            elif func is torch.nn.functional.linear:
                return cls.__torch_dispatch__(
                    torch.ops.aten.linear.default, types, args, kwargs
                )
        with DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(  # noqa: C901
        cls,
        func_overload: torch._ops.OpOverload,
        # pyre-ignore: Missing parameter annotation [2]
        types,
        args: Tuple[Value, ...] = (),
        kwargs: Optional[Dict[str, Value]] = None,
    ) -> Value:
        """
        This function is invoked every time an aten operation is called.

        Args:
            func_overload: The function that was called that invoked this
                torch_dispatch call
            types: The types of the tensor-subclass arguments (unused here)
            args: Arguments that were passed into the function. Each argument
                has type PythonTensor.
            kwargs: Keyword arguments that were passed into the function. Each
                argument has type PythonTensor.
        """
        func = func_overload.overloadpacket

        kwargs = kwargs or {}
        if is_out_variant(func._qualified_op_name, func_overload._overloadname):
            out_args = extract_out_arguments(func_overload._schema, kwargs)
            out_args_iter = [out_args] if not isinstance(out_args, list) else out_args
            for out_arg_name, out_arg_val in out_args_iter:
                if isinstance(out_arg_val, PythonTensor) and out_arg_val.is_immutable:
                    raise RuntimeError(
                        "Immutable tensor `{}` is potentially getting modified by {}".format(
                            out_arg_name, format_schema_name(func_overload._schema)
                        )
                    )

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`.
        proxy_args = ex_pytree.tree_map(unwrap_proxy, args)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`.
        proxy_kwargs = ex_pytree.tree_map(unwrap_proxy, kwargs)

        # Create the proxy output by re-invoking the op on the unwrapped
        # proxies (this is what records the node in the FX graph)
        g = _EnableTorchFunction()
        try:
            proxy_out = (
                func_overload(*proxy_args, **proxy_kwargs)
                if DispatchTracer.get() or torchdynamo_enabled
                # Disable node creation when no tracer is active.
                else DummyProxy()
            )
        finally:
            del g

        with no_dispatch():
            real_out = func_overload(*args, **kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
            if isinstance(args[0], PythonTensor):
                args[0].proxy = proxy_out

        if not torch.fx.traceback.has_preserved_node_meta():
            proxy_out.node.meta["stack_trace"] = json.dumps(get_stacktrace())

        # Wrap the output tensors with the PythonTensor subclass to propagate to
        # future tracing
        def wrap_with_proxy(e: LeafValue, proxy: torch.fx.Proxy) -> LeafValue:
            # Some ops (like native_batch_norm_backward) return undefined tensors that get
            # converted into None in python.
            # As the function signature expects tensors, if we directly return these None
            # tensors back to C++, we'll error.
            if e is None:
                e = torch.empty(())

            # Inplace and out-variant ops may return one of their arguments, which is
            # already a PythonTensor. In this case, we need to update the PythonTensor's
            # associated proxy to the newly created proxy. This check must come before
            # the generic torch.Tensor check, since PythonTensor is a subclass of
            # torch.Tensor and would otherwise be re-wrapped.
            if isinstance(e, PythonTensor):
                e.update_proxy(proxy)
                return e

            if isinstance(e, torch.Tensor):
                return PythonTensor(e, proxy)

            return e

        retval = None
        if not isinstance(real_out, (list, tuple)):
            retval = single_return(real_out, proxy_out, wrap_with_proxy)
        else:
            retval = tree_return(real_out, proxy_out, wrap_with_proxy, type(real_out))
        return retval

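# Illustrative sketch (not part of the original file): wrapping a tensor so
# that subsequent aten ops are recorded through __torch_dispatch__.  Here `ph`
# is assumed to be a placeholder proxy created by an active DispatchTracer.
#
#     wrapped = PythonTensor(torch.randn(2, 2), ph, is_immutable=True)
#     y = wrapped + 1  # dispatches through __torch_dispatch__, emitting an
#                      # aten.add node whose proxy is attached to `y`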

@contextmanager
def using_tracer(tracer: Optional["DispatchTracer"]) -> Generator[None, None, None]:
    """
    Set the "current" global tracer within the scope of the using_tracer
    context manager.

    Since various things we want to capture today with torch_dispatch
    do not really "trap" into the dispatcher (for example, cond() and
    shape()), we need a separate singleton tracer exposed to user space,
    in addition to the dispatcher, to trigger graph capturing.
    """
    global TRACER
    TRACER, prev = tracer, TRACER
    try:
        yield
    finally:
        TRACER = prev

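# Illustrative sketch (not part of the original file): the swap-and-restore
# pattern above means the previous tracer is always restored, even if tracing
# raises, so nested uses are safe.
#
#     with using_tracer(my_tracer):   # TRACER == my_tracer in this scope
#         ...                         # DispatchTracer.get() sees my_tracer
#     # TRACER is restored to its previous value on exit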

class DispatchTracer(fx.Tracer):
    def __init__(self) -> None:
        super().__init__()
        self.root: torch.nn.Module = torch.nn.Module()
        self.tensor_attrs: Dict[torch.Tensor, str] = {}
        self.submodules: Dict[fx.GraphModule, str] = {}

    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Value],
        args: Tuple[Value, ...],
        kwargs: Dict[str, Value],
    ) -> Value:
        return forward(*args, **kwargs)

    def _module_getattr(
        self, attr: str, attr_val: Value, parameter_proxy_cache: Dict[str, torch.Tensor]
    ) -> Value:
        if isinstance(attr_val, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        proxy = self.create_proxy("get_attr", n, (), {})
                        parameter_proxy_cache[n] = PythonTensor(attr_val, proxy)
                    return parameter_proxy_cache[n]
            return attr_val
        return attr_val

    def create_arg(self, a: Value) -> torch.fx.Node:  # noqa: C901
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node("get_attr", n, (), {})
            qualname: Optional[str] = None

            if not qualname:
                i = 0
                while True:
                    qualname = f"_param_constant{i}"
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        if isinstance(a, torch.Tensor):
            qualname: Optional[str] = self.tensor_attrs.get(a)

            if not qualname:
                i = 0
                while True:
                    qualname = f"_tensor_constant{i}"
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                self.tensor_attrs[a] = qualname
                self.root.register_buffer(qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        # higher-order operator
        if isinstance(a, fx.GraphModule):
            if a not in self.submodules:
                name_submodule = f"submodule_{len(self.submodules)}"
                self.root.add_module(name_submodule, a)
                self.submodules[a] = name_submodule
            return self.create_node("get_attr", self.submodules[a], (), {})

        return super().create_arg(a)  # pyre-fixme[7]

    @staticmethod
    def get() -> "DispatchTracer":
        return TRACER

    def trace(  # pyre-fixme[14,15]
        self,
        root: Callable[..., Value],
        concrete_args: Tuple[Value, ...] = (),
        in_spec: Optional[TreeSpec] = None,
    ) -> Value:
        """
        Traces the given callable, setting this tracer as the current global
        tracer for the duration of the trace.
        """
        with using_tracer(self):
            return self._trace(root, concrete_args=concrete_args, in_spec=in_spec)

    def _trace(
        self,
        root: Callable[..., Value],
        concrete_args: Tuple[Value, ...],
        in_spec: Optional[TreeSpec],
    ) -> Value:
        self.root = torch.nn.Module()
        root_fn = root

        tracer_cls = getattr(self, "__class__", None)
        self.graph = fx.Graph(tracer_cls=tracer_cls)
        # We don't support tracing modules directly, so tensor_attrs is
        # always empty.
        self.tensor_attrs = {}

        # Wrap all inputs as a PythonTensor subclass and insert them into the FX
        # graph as placeholder nodes
        def wrap(arg: Value, i: int) -> Value:
            placeholder = self.create_proxy("placeholder", f"ph_{i}", (), {})
            if isinstance(arg, torch.Tensor):
                return PythonTensor(arg, placeholder, is_immutable=True)
            else:
                # torch._assert(
                #     placeholder == arg,
                #     f"ph_{i} has been specialized to have value {arg}",
                # )
                return arg

        tree_args = [wrap(arg, i) for i, arg in enumerate(concrete_args)]
        if in_spec:
            tree_args = pytree.tree_unflatten(tree_args, in_spec)

        tree_out = root_fn(*tree_args)

        out_args, _ = pytree.tree_flatten(tree_out)

        def unwrap(out: LeafValue) -> Union[LeafValue, torch.fx.Proxy]:
            # It's legitimate for a model to return a list of items, some of
            # which are None.
            if out is None:
                return None
            if not isinstance(out, torch.Tensor):
                raise TypeError(
                    f"Expect model to return torch.Tensor, got type: '{type(out)}' (value: {out})."
                )
            return unwrap_proxy(out)

        returns = [unwrap(out) for out in out_args]

        return_annotation = None
        # Some ops, like torch.sub, don't have annotations.
        if hasattr(root_fn, "__annotations__"):
            return_annotation = root_fn.__annotations__.get("return", None)

        self.create_proxy(
            "output",
            "output",
            (returns,),
            {},
            type_expr=return_annotation,
        )

        self.submodule_paths = None

        return tree_out

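# Illustrative sketch (not part of the original file): tracing a toy function
# with DispatchTracer directly.  The recorded graph lives on `tracer.graph`,
# with one placeholder per flattened input and a single flattened output node.
#
#     tracer = DispatchTracer()
#     out = tracer.trace(lambda x, y: x + y, (torch.randn(2), torch.randn(2)))
#     gm = torch.fx.GraphModule(tracer.root, tracer.graph)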

TRACER: Optional[DispatchTracer] = None
TORCHDYNAMO_ENABLED: bool = False


@contextmanager
def using_dynamo(val: bool) -> Generator[None, None, None]:
    global TORCHDYNAMO_ENABLED
    TORCHDYNAMO_ENABLED, prev = val, TORCHDYNAMO_ENABLED
    try:
        yield
    finally:
        TORCHDYNAMO_ENABLED = prev


def flattened_dispatch_trace(
    f: Callable[..., Value],
    args: Tuple[LeafValue, ...],
    guards: Set[Guard],
    in_spec: Optional[TreeSpec] = None,
    enable_functionalization: bool = True,
) -> Tuple[torch.fx.GraphModule, Value]:
    if not isinstance(args, tuple):
        raise TypeError(f"Expecting 'args' to be a tuple, got: {type(args)}")

    tracer = DispatchTracer()

    if enable_functionalization:
        f = functionalize(f, remove="mutations_and_views")
    tree_out = tracer.trace(f, concrete_args=args, in_spec=in_spec)

    name = type(f).__name__ if isinstance(f, torch.nn.Module) else f.__name__
    gm = torch.fx.GraphModule(tracer.root, tracer.graph, name)

    return (gm, tree_out)

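# Illustrative sketch (not part of the original file): a minimal call.  The
# `guards` argument is threaded through for API compatibility; an empty set
# works when no dynamo guards are available.
#
#     def f(x: torch.Tensor) -> torch.Tensor:
#         return x * 2
#
#     gm, tree_out = flattened_dispatch_trace(f, (torch.randn(3),), set())
#     print(gm.graph)  # placeholder -> mul -> output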

@dataclass
class ExirDynamoConfig:
    """
    Manage Exir-specific configurations of Dynamo.
    """

    allow_rnn: bool = True
    verbose: bool = True
    assume_static_by_default: bool = False

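# Illustrative sketch (not part of the original file): being a dataclass, the
# config converts to a plain dict with dataclasses.asdict(), which is how
# dynamo_trace below applies it via torchdynamo.config.patch.
#
#     cfg = ExirDynamoConfig(assume_static_by_default=True)
#     with torchdynamo.config.patch(asdict(cfg)):
#         ...  # dynamo runs with the patched config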

def flatten_output(gm: torch.fx.GraphModule) -> None:
    """
    Modifies the output nodes in the submodules to return the result
    as a flattened list. This keeps it consistent with the result of
    EXIR's tracer.
    """
    for node in reversed(gm.graph.nodes):
        if node.op == "output":
            assert len(node.args) == 1
            outputs = node.args[0]
            returns, _ = pytree.tree_flatten(outputs)
            node.args = (returns,)
            return
    raise RuntimeError(f"Could not find an output node in {gm.graph}")

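# Illustrative sketch (not part of the original file): the effect on the
# output node.  A graph whose output node returned a nested structure, e.g.
# `output((a, (b, c)))`, is rewritten in place to return the flattened list
# `output([a, b, c])`; the saved out_spec is what lets callers unflatten
# the results later.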

def _default_decomposition_table(
    _use_old_decomp_table=False,
) -> Dict[torch._ops.OpOverload, Callable[..., Value]]:
    if _use_old_decomp_table:
        decomp_opset = [
            torch.ops.aten.log_sigmoid_forward,
            torch.ops.aten.ones,
            torch.ops.aten.arange.default,
            torch.ops.aten.arange.start,
            torch.ops.aten.transpose,
        ]
        # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
        return get_decompositions(decomp_opset)
    # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
    return default_decompositions()


def dynamo_trace(
    f: Callable[..., Value],
    # pyre-ignore
    args: Tuple[Any, ...],
    aten_graph: bool,
    tracing_mode: str = "real",
    dynamo_config: Optional[ExirDynamoConfig] = None,
    # pyre-ignore
    dynamic_shapes: Optional[List[Any]] = None,
    _use_old_decomp_table: bool = False,
) -> Tuple[torch.fx.GraphModule, Set[Guard]]:
    """
    TODO: Once we fully migrate to the torchdynamo frontend, we will remove
    this config option altogether.  For now, it helps with quick
    experiments when playing around with TorchDynamo.
    """
    if dynamo_config is None:
        dynamo_config = ExirDynamoConfig()

    with torchdynamo.config.patch(
        asdict(dynamo_config)
    ), setting_python_recursive_limit(2000):
        torchdynamo.reset()
        try:
            # TODO merge executorch functionalization with official
            # functionalization
            # pyre-fixme[7]: Expected `Tuple[GraphModule, Set[Guard]]` but got
            #  `ExportResult`.
            return torchdynamo.export(
                f,
                aten_graph=aten_graph,
                tracing_mode=tracing_mode,
                assume_static_by_default=dynamo_config.assume_static_by_default,
                decomposition_table=(
                    _default_decomposition_table(_use_old_decomp_table)
                    if aten_graph
                    else None
                ),
                dynamic_shapes=dynamic_shapes,
            )(
                *copy.deepcopy(args),
            )
        except torchdynamo.exc.Unsupported as exc:
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                "The user code is using a feature we don't support. "
                "Please try torchdynamo.explain() to get the possible reasons",
            ) from exc
        except Exception as exc:
            raise InternalError(
                "torchdynamo internal error occurred. Please see the above stacktrace"
            ) from exc

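# Illustrative sketch (not part of the original file): exporting with
# dynamo_trace.  With aten_graph=True the graph is lowered to aten ops using
# the default decomposition table; the returned guards describe the input
# assumptions dynamo made while tracing.
#
#     gm, guards = dynamo_trace(f, (torch.randn(3),), aten_graph=True)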

def dispatch_trace(
    f: Callable[..., Value],
    args: Tuple[Value, ...],
) -> torch.fx.GraphModule:
    """
    Executes a given callable `f` with a given tuple of arguments. During
    execution, Tensor operations are recorded in a fx.GraphModule, which is
    then returned.

    Args:
        f: A `nn.Module` or a Python function that implements an ML program.
        args: A tuple of arguments of any type to be used as inputs for the tracing run.

    Returns:
        EXIR contained in a fx.GraphModule
    """
    trace_func = f
    guards = set()
    if TORCHDYNAMO_ENABLED:
        trace_func, guards = dynamo_trace(trace_func, args, False)

    trace_args, in_spec = pytree.tree_flatten(args)

    # Copying args is safer in case downstream implementations of trace_func
    # mutate them.
    in_args = copy.deepcopy(tuple(trace_args))
    gm, tree_out = flattened_dispatch_trace(
        trace_func,
        in_args,
        guards,
        in_spec,
        enable_functionalization=False,
    )

    _, out_spec = pytree.tree_flatten(tree_out)

    # pyre-fixme[16]: `GraphModule` has no attribute `in_spec`.
    gm.in_spec = in_spec
    # pyre-fixme[16]: `GraphModule` has no attribute `out_spec`.
    gm.out_spec = out_spec

    # TODO (tmanlaibaatar) This is a bit clowny, but our dispatch_trace
    # sometimes creates an unused node that breaks functionalization. It
    # seems like too much trouble to fix it properly, since dispatch_trace
    # will be deprecated soon. Basically, dispatch_trace struggles on:
    # def f(x: torch.Tensor) -> torch.Tensor:
    #    return torch.ones(6, dtype=x.dtype)
    changed = gm.graph.eliminate_dead_code()
    if changed:
        gm.recompile()

    in_args = copy.deepcopy(tuple(trace_args))
    assert callable(gm)

    # This wrapper is used for preserving the stacktrace
    # during the second round of tracing.
    # pyre-ignore
    def graph_with_interpreter(*args):
        try:
            args = fx_pytree.tree_flatten_spec(args, gm.in_spec)  # type: ignore[assignment]
        except Exception:
            _, received_spec = pytree.tree_flatten(args)
            raise RuntimeError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{gm.in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )
        with torch.fx.traceback.preserve_node_meta():
            res = gm(*args)

        if gm.out_spec is not None:
            try:
                res = pytree.tree_unflatten(res, gm.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise RuntimeError(
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{gm.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
        return res

    gm, tree_out = flattened_dispatch_trace(
        graph_with_interpreter,
        in_args,
        guards,
        in_spec,
        enable_functionalization=True,
    )

    gm.in_spec = in_spec
    gm.out_spec = out_spec

    return gm
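
# Illustrative sketch (not part of the original file): end-to-end usage of
# dispatch_trace.  The first tracing pass records the program; the second pass
# (via graph_with_interpreter) re-traces with functionalization enabled while
# preserving per-node stack traces.
#
#     def model(x: torch.Tensor) -> torch.Tensor:
#         return torch.relu(x) + 1
#
#     gm = dispatch_trace(model, (torch.randn(4),))
#     print(gm.graph)  # functionalized aten graph with flattened outputs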
789