xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/proxy_tensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# Copyright (c) Facebook, Inc. and its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8from __future__ import annotations
9
10import functools
11import inspect
12import logging
13import operator
14import traceback
15import typing
16import typing_extensions
17import warnings
18import weakref
19from collections import defaultdict
20from contextlib import contextmanager, ExitStack, nullcontext
21from dataclasses import dataclass
22from typing import (
23    Any,
24    Callable,
25    Dict,
26    Generator,
27    List,
28    Mapping,
29    Optional,
30    overload,
31    Protocol,
32    Sequence,
33    Tuple,
34    Type,
35    TYPE_CHECKING,
36    TypeVar,
37    Union,
38)
39from typing_extensions import Concatenate, ParamSpec, Self
40from weakref import WeakKeyDictionary
41
42import torch
43import torch._ops
44import torch.fx as fx
45import torch.fx.traceback as fx_traceback
46import torch.utils._pytree as pytree
47from torch import SymBool, SymInt, Tensor
48from torch._dispatch.python import enable_python_dispatcher
49from torch._library.fake_class_registry import FakeScriptObject
50from torch._subclasses.fake_impls import fast_detach
51from torch._subclasses.fake_tensor import (
52    FakeTensor,
53    FakeTensorMode,
54    is_fake,
55    unset_fake_temporarily,
56)
57from torch._subclasses.meta_utils import is_sparse_any
58from torch.fx import GraphModule, Proxy, Tracer
59from torch.fx.graph_module import _assign_attr
60from torch.fx.node import _side_effectful_need_to_be_preserved_pre_dispatch
61from torch.fx.passes.shape_prop import _extract_tensor_metadata
62from torch.nn import Module
63from torch.overrides import TorchFunctionMode
64from torch.utils._python_dispatch import (
65    _disable_infra_mode,
66    _push_mode,
67    _unset_infra_mode,
68    TorchDispatchMode,
69)
70from torch.utils._stats import count
71from torch.utils._thunk import Thunk
72from torch.utils._traceback import CapturedTraceback
73from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary
74
75from ._backward_state import BackwardState
76from .sym_node import SymNode
77
78
79if TYPE_CHECKING:
80    import types
81    from collections.abc import MutableMapping
82
83    import sympy
84
85    from torch._ops import OpOverload
86    from torch.fx._symbolic_trace import PHBase
87    from torch.types import IntLikeType
88
89__all__ = [
90    "PythonKeyTracer",
91    "dispatch_trace",
92    "make_fx",
93    "DecompositionInterpreter",
94    "py_sym_types",
95    "get_innermost_proxy_mode",
96    "get_proxy_mode",
97    "handle_sym_dispatch",
98    "maybe_enable_thunkify",
99    "maybe_disable_thunkify",
100]
101
102_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
103
104_AnyScriptObject = (torch.ScriptObject, FakeScriptObject)
105_AnyScriptObjectType = Union[torch.ScriptObject, FakeScriptObject]
106
107aten = torch.ops.aten
108prim = torch.ops.prim
109
110log = logging.getLogger(__name__)
111not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
112
113CURRENT_DECOMPOSITION_TABLE: Mapping[OpOverload, Callable] = {}
114
115CONSTANT_NUMEL_LIMIT = 1
116
117T = TypeVar("T")
118U = TypeVar("U")
119_P = ParamSpec("_P")
120R = TypeVar("R")
121
122null_ctx_type = type(nullcontext)
123# We currently convert all SymInt to proxies before we use them.
124# This could plausibly be handled at the Dynamo level.
125pytree.register_pytree_node(
126    torch.Size,
127    lambda xs: (list(xs), None),
128    lambda xs, _: tuple(xs),
129    flatten_with_keys_fn=lambda xs: (
130        [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
131        None,
132    ),
133)
134
135
136def fake_signature(fn: Callable[_P, R], nargs: int) -> Callable[_P, R]:
137    """FX gets confused by varargs, de-confuse it"""
138    argnames = ",".join(f"arg{i}" for i in range(nargs))
139    return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn})
140
141
142@contextmanager
143def decompose(
144    decomposition_table: Optional[Mapping[OpOverload, Callable]]
145) -> Generator[Mapping[OpOverload, Callable], None, None]:
146    global CURRENT_DECOMPOSITION_TABLE
147    old_decomposition_table = CURRENT_DECOMPOSITION_TABLE
148    CURRENT_DECOMPOSITION_TABLE = decomposition_table or {}
149    try:
150        yield CURRENT_DECOMPOSITION_TABLE
151    finally:
152        CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
153
154
155# ensure we cannot collide with other properties
156proxy_slot = object()
157
158
159class _NoDefault:
160    pass
161
162
163no_default = _NoDefault()
164
165from torch.types import py_sym_types, PySymType
166
167
168class _HasMeta(Protocol):
169    meta: Dict[str, PySymType]
170
171
172def is_sym_node(node: _HasMeta) -> bool:
173    assert hasattr(node, "meta"), "All nodes traced with proxy_tensor should have meta"
174    return "val" in node.meta and isinstance(node.meta["val"], py_sym_types)
175
176
177@overload
178def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None:
179    ...
180
181
182@overload
183def set_proxy_slot(
184    obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy
185) -> None:
186    ...
187
188
189@overload
190def set_proxy_slot(
191    obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType
192) -> None:
193    ...
194
195
196def set_proxy_slot(
197    obj: Union[PySymType, _AnyScriptObjectType, Tensor],
198    tracer: _ProxyTracer,
199    proxy: object,
200) -> None:
201    log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy)
202    if isinstance(obj, Tensor):
203        # We DO want to clobber proxies whenever we run an inplace operation
204        # on a tensor, and it affects the metadata on the proxy.
205        assert isinstance(proxy, _ProxyTensor)
206        tracer.tensor_tracker[obj] = proxy
207    elif isinstance(obj, (_AnyScriptObject)):
208        # We DO want to clobber proxies, with a similar rationale as for tensors.
209        assert isinstance(proxy, Proxy)
210        tracer.script_object_tracker[obj] = proxy
211    else:
212        # NB: Never clobber pre-existing proxy.  Although the proxies
213        # are in principle equivalent, when we do graph partitioning
214        # we need there not to be spurious dependencies on tangent inputs.
215        # This works because primals get their SymInts set first, and
216        # THEN later we allocate tangent inputs.  Make sure if a SymInt
217        # is derivable from a primal that we use that.
218        assert isinstance(obj, py_sym_types), type(obj)
219        if obj not in tracer.symnode_tracker:
220            tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy)
221
222            # WAR: python test/dynamo/test_subclasses.py
223            # TestNestedTensor.test_basic_autograd
224            #
225            # AOTAutograd doesn't pass the "outer sizes" as an actual argument
226            # to make_fx, but it is made use of internally in AOTAutograd's
227            # call to tensor unflatten.  Because the outer sizes isn't passed
228            # as an argument, it is therefore untracked.  However, it turns
229            # out you luck out, because *Dynamo* will manually add the outer
230            # sizes as an argument so you can fix up the proxy'ness.
231            #
232            # This is probably fixed in
233            # https://github.com/pytorch/pytorch/pull/125941/
234            import sympy
235
236            if isinstance(obj.node.expr, sympy.Symbol):
237                tracer.sympy_expr_tracker[obj.node.expr] = proxy
238
239
240def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
241    assert isinstance(obj, (Tensor, SymNode)), type(obj)
242    return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
243
244
245_PySymProxyType = Thunk[Proxy]
246
247
248@overload
249def get_proxy_slot(
250    obj: Tensor,
251    tracer: _ProxyTracer,
252) -> _ProxyTensor:
253    ...
254
255
256@overload
257def get_proxy_slot(
258    obj: Tensor,
259    tracer: _ProxyTracer,
260    default: U,
261) -> Union[_ProxyTensor, U]:
262    ...
263
264
265@overload
266def get_proxy_slot(
267    obj: Tensor,
268    tracer: _ProxyTracer,
269    default: U,
270    transform: Callable[[_ProxyTensor], R],
271) -> Union[R, U]:
272    ...
273
274
275@overload
276def get_proxy_slot(
277    obj: _AnyScriptObjectType,
278    tracer: _ProxyTracer,
279) -> Proxy:
280    ...
281
282
283@overload
284def get_proxy_slot(
285    obj: _AnyScriptObjectType,
286    tracer: _ProxyTracer,
287    default: U,
288) -> Union[Proxy, U]:
289    ...
290
291
292@overload
293def get_proxy_slot(
294    obj: _AnyScriptObjectType,
295    tracer: _ProxyTracer,
296    default: U,
297    transform: Callable[[Proxy], R],
298) -> Union[R, U]:
299    ...
300
301
302@overload
303def get_proxy_slot(
304    obj: PySymType,
305    tracer: _ProxyTracer,
306) -> _PySymProxyType:
307    ...
308
309
310@overload
311def get_proxy_slot(
312    obj: PySymType,
313    tracer: _ProxyTracer,
314    default: T,
315) -> Union[T, _PySymProxyType]:
316    ...
317
318
319@overload
320def get_proxy_slot(
321    obj: PySymType,
322    tracer: _ProxyTracer,
323    default: U,
324    transform: Callable[[_PySymProxyType], R],
325) -> Union[R, U]:
326    ...
327
328
329# the default argument is what to return if the slot is not set.
330# the transform argument is handy if you need to extract a subfield from
331# the successfully looked up result (but NOT the default.)
332def get_proxy_slot(
333    obj: Union[Tensor, _AnyScriptObjectType, PySymType],
334    tracer: _ProxyTracer,
335    default: object = no_default,
336    transform: Callable = lambda x: x,
337) -> object:
338    tracker: Any
339    if isinstance(obj, Tensor):
340        tracker = tracer.tensor_tracker
341    elif isinstance(obj, _AnyScriptObject):
342        tracker = tracer.script_object_tracker
343    else:
344        assert isinstance(obj, py_sym_types), type(obj)
345        tracker = tracer.symnode_tracker
346
347    if obj not in tracker:
348        # Last ditch
349        if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
350            value = tracer.sympy_expr_tracker[obj.node.expr]
351        else:
352            if isinstance(default, _NoDefault):
353                raise RuntimeError(
354                    f"{obj} ({id(obj)})is not tracked with proxy for {tracer}"
355                )
356            return default
357    else:
358        value = tracker[obj]
359    res = transform(value)
360    return res
361
362
363def snapshot_fake(val: Tensor) -> Optional[Tensor]:
364    # val.detach() will also eventually call fast_detach(),
365    # but this saves us a full trip into __torch_dispatch__
366    # (snapshot_fake is called a lot)
367    if isinstance(val, FakeTensor):
368        return fast_detach(val.fake_mode, val)
369    else:
370        return val.detach()
371
372
373_ExtractValType = Optional[
374    Union[
375        PySymType,
376        _AnyScriptObjectType,
377        BackwardState,
378        List["_ExtractValType"],
379        Tuple["_ExtractValType", ...],
380        Dict[str, "_ExtractValType"],
381        Tensor,
382        int,
383        float,
384        bool,
385    ]
386]
387
388
389def extract_val(val: _ExtractValType) -> _ExtractValType:
390    if is_fake(val):
391        return snapshot_fake(val)
392    elif isinstance(val, py_sym_types):
393        return val
394    elif isinstance(val, _AnyScriptObject):
395        return val
396    elif isinstance(val, BackwardState):
397        return val
398    elif isinstance(val, (list, tuple)):
399        return val.__class__([extract_val(x) for x in val])
400    elif isinstance(val, dict):
401        return {k: extract_val(v) for k, v in val.items()}
402    elif isinstance(val, Tensor):
403        if not val.is_sparse:
404            # NB: Kinda hacky, but we should try to get val as the metadata
405            # everywhere
406            # TODO: This doesn't properly track storages.  A more robust
407            # approach would be to maintain a per-trace FakeTensorMode and
408            # from_real_tensor to create fake values (don't forget to
409            # snapshot_fake)
410            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
411            with fake_tensor_mode:
412                return torch.empty_strided(
413                    val.shape, val.stride(), device=val.device, dtype=val.dtype
414                )
415        else:
416            return None
417    elif isinstance(val, (int, float, bool)):
418        return val
419    elif val is None:
420        return None
421
422    typing_extensions.assert_never(val)
423
424
425@contextmanager
426def _enable_thunkify(
427    tracer: _ProxyTracer, *, enable: bool = True
428) -> Generator[None, None, None]:
429    """
430    Enable thunkification inside the context manager.  Thunkification prevents
431    SymNode computation from directly being traced into an FX graph; instead,
432    the compute is only added to the graph if it is actually used.  This helps
433    us track SymNode compute when it is computed (since we need /something/
434    to put in the tracker) even if it is unlikely to be used.
435    """
436    old = tracer.enable_thunkify
437    tracer.enable_thunkify = enable
438    try:
439        yield
440    finally:
441        tracer.enable_thunkify = old
442
443
444@contextmanager
445def maybe_disable_thunkify() -> Generator[None, None, None]:
446    """Within a context, disable thunkification.  See :func:`maybe_enable_thunkify`
447    for more details.  This is helpful if you have a wrapper function which
448    you want to enable thunkification on, but in some segment on the inside (say,
449    the original user function), you want to disable thunkification as you know
450    it is not needed there.
451    """
452    proxy_mode = get_proxy_mode()
453    if proxy_mode is not None:
454        with _enable_thunkify(proxy_mode.tracer, enable=False):
455            yield
456    else:
457        yield
458
459
460@contextmanager
461def maybe_enable_thunkify() -> Generator[None, None, None]:
462    """Within this context manager, if you are doing make_fx tracing, we will thunkify
463    all SymNode compute and avoid tracing it into the graph unless it is actually needed.
464    You should prefer to avoid using this as much as possible, as lazy evaluation of
465    SymNode tracing can lead to long chains of thunks which will stack overflow
466    if you evaluate them.  However, this is currently sometimes necessary as there
467    are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error
468    due to insufficient tracing of SymNode computation.
469    """
470    proxy_mode = get_proxy_mode()
471    if proxy_mode is not None:
472        with _enable_thunkify(proxy_mode.tracer):
473            yield
474    else:
475        yield
476
477
478# Note [invariants for node meta 'val']
479# What invariants do we have for the 'val' set on the FX node?  It has accurate
480# metadata... but only for metadata that exists "below" all other subsystems
481# (most notably autograd, but also vmap, functorch transforms, etc).  This means
482# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad,
483# grad_fn, _base (_base actually may be set due to recursive call to
484# ADInplaceOrView, but you shouldn't rely on it.)
485def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy:
486    proxy.node.meta["val"] = extract_val(val)
487
488    with _enable_thunkify(proxy.tracer):  # type: ignore[arg-type]
489        # Best effort tensor_meta setting; prefer using val!
490        if is_fake(val):
491            proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val)
492        elif isinstance(val, Tensor) and not val.is_sparse:
493            proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val)
494    return proxy
495
496
497def thunkify(
498    tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs
499) -> Thunk[R]:
500    """
501    Delays computation of f until it's called again
502    Also caches the result
503    """
504    if tracer.enable_thunkify:
505        return Thunk(functools.partial(f, *args, **kwargs))
506    else:
507        r = f(*args, **kwargs)
508        return Thunk(lambda: r)
509
510
511def track_tensor(
512    tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer
513) -> None:
514    def try_set_proxy_slot(
515        outer_s: IntLikeType,
516        proxy_callable: Callable[Concatenate[PySymType, _P], Proxy],
517        *args: _P.args,
518        **kwargs: _P.kwargs,
519    ) -> None:
520        assert callable(proxy_callable)
521        if isinstance(outer_s, SymInt):
522            with _enable_thunkify(tracer):
523                set_proxy_slot(
524                    outer_s,
525                    tracer,
526                    thunkify(tracer, proxy_callable, outer_s, *args, **kwargs),
527                )
528
529    # The basic idea is that we need to associate each tensor/SymInt
530    # with a Proxy.  How do we setup this association?  We just store
531    # the proxy on the proxy slot of the object, keyed on the tracer
532    # (so that if we have multiple tracers at the same time, they
533    # don't clobber each other.)
534    for i, s in enumerate(tensor.shape):
535        try_set_proxy_slot(
536            s,
537            lambda x, i: set_meta(
538                tracer.create_proxy(
539                    "call_function", torch.ops.aten.sym_size.int, (proxy, i), {}
540                ),
541                x,
542            ),
543            i,
544        )
545
546    if not is_sparse_any(tensor):
547        for i, s in enumerate(tensor.stride()):
548            try_set_proxy_slot(
549                s,
550                lambda x, i: set_meta(
551                    tracer.create_proxy(
552                        "call_function", torch.ops.aten.sym_stride.int, (proxy, i), {}
553                    ),
554                    x,
555                ),
556                i,
557            )
558
559    try_set_proxy_slot(
560        tensor.numel(),
561        lambda x: set_meta(
562            tracer.create_proxy(
563                "call_function", torch.ops.aten.sym_numel.default, (proxy,), {}
564            ),
565            x,
566        ),
567    )
568    if not is_sparse_any(tensor):
569        try_set_proxy_slot(
570            tensor.storage_offset(),
571            lambda x: set_meta(
572                tracer.create_proxy(
573                    "call_function",
574                    torch.ops.aten.sym_storage_offset.default,
575                    (proxy,),
576                    {},
577                ),
578                x,
579            ),
580        )
581    set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
582
583
584_NestedProxys = Union[
585    Proxy, Sequence["_NestedProxys"], Mapping[object, "_NestedProxys"]
586]
587_NestedTensors = Union[
588    Tensor, Sequence["_NestedTensors"], Mapping[object, "_NestedTensors"]
589]
590
591
592def track_tensor_tree(
593    inner_res: T,
594    proxy_res: _NestedProxys,
595    *,
596    constant: Optional[_NestedTensors],
597    tracer: _ProxyTracer,
598) -> T:
599    # NB: We call set_unbacked_bindings only on the *topmost* call to
600    # track_tensor_tree, not recursive calls.  This is because there must
601    # be only ONE unbacked_binding proxy call, and it should be the one
602    # where all of the unbacked SymInts actually first come into existence.
603    # If you call this again on the inner proxies for the tuple projections,
604    # you will have multiple unbacked_bindings for the same symbol, but
605    # they're not going to show up anywhere.
606    #
607    # I was briefly deceived into setting unbacked bindings recursively when
608    # working on https://github.com/pytorch/pytorch/pull/133585 because I
609    # observed that some extra unbacked bindings were needed to handle some
610    # higher order operator code.  But actually it looks like this was
611    # just an unrelated bug that needed to be fixed separately.
612    _set_unbacked_bindings(inner_res, proxy_res)
613
614    def wrap_with_proxy(
615        e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors]
616    ) -> None:
617        if isinstance(e, Tensor):
618            assert isinstance(proxy, Proxy)
619            assert constant is None or isinstance(constant, Tensor)
620            track_tensor(e, proxy, tracer=tracer, constant=constant)
621            set_meta(proxy, e)
622        elif isinstance(e, py_sym_types):
623            assert isinstance(proxy, Proxy)
624            # NB: eagerly set meta here, so that the numbering is in order
625            set_meta(proxy, e)
626            set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy))
627        elif isinstance(e, _AnyScriptObject):
628            assert isinstance(proxy, Proxy)
629            set_proxy_slot(e, tracer, proxy)
630            set_meta(proxy, e)
631        elif isinstance(e, (tuple, list)):
632            # example use case: allreduce_ returns ([tensor], work)
633            if isinstance(proxy, fx.Proxy):
634                set_meta(proxy, e)
635
636            def get_constant(
637                c: Optional[_NestedTensors], idx: int
638            ) -> Optional[_NestedTensors]:
639                if c is None:
640                    return None
641                else:
642                    assert isinstance(c, (list, tuple))
643                    return c[idx]
644
645            for idx, ee in enumerate(e):
646                # Use an indexer here - if proxy is a List then it will unwrap
647                # it. If it's a Proxy then it will proxy the getelem.
648                wrap_with_proxy(ee, proxy[idx], get_constant(constant, idx))  # type: ignore[index]
649
650        elif isinstance(e, dict):
651            # example use case: triton_kernel_wrapper takes arguments as kwargs
652
653            # In theory we could support const-prop when proxy-tensor-tracing
654            # operators that returns dicts of tensors, but we have no use case
655            # for it today (since the only op we currently trace that can
656            # return a dict is triton_kernel_wrapper_functional/mutation,
657            # which does not participate in const-prop)
658            assert constant is None
659
660            if isinstance(proxy, fx.Proxy):
661                set_meta(proxy, e)
662
663            for key, val in e.items():
664                wrap_with_proxy(val, proxy[key], None)  # type: ignore[index]
665
666        elif isinstance(e, BackwardState):
667            assert isinstance(proxy, Proxy)
668            set_meta(proxy, e)
669            e.proxy = proxy
670        else:
671            # intentionally pass on primitives
672            pass
673
674    wrap_with_proxy(inner_res, proxy_res, constant)
675
676    return inner_res
677
678
679@dataclass
680class _ProxyTensor:
681    proxy: Proxy
682    constant: Optional[Tensor]
683
684
685def fetch_sym_proxy(
686    tracer: _ProxyTracer,
687) -> Callable[[PySymType], Union[bool, int, float, Proxy]]:
688    def inner(e: PySymType) -> Union[int, bool, float, Proxy]:
689        n = e.node
690        if n.constant is not None:
691            return n.constant
692        if e.node.expr.is_number:
693            if isinstance(e, SymBool):
694                return bool(e.node.expr)
695            elif isinstance(e, SymInt):
696                return int(e.node.expr)
697            return float(e.node.expr)
698        else:
699            assert isinstance(e, py_sym_types)
700            # NB: we REQUIRE all symints to be tracked
701            return get_proxy_slot(e, tracer).force()
702
703    return inner
704
705
706@overload
707def fetch_object_proxy(tracer: _ProxyTracer, t: Tensor) -> Union[_ProxyTensor, Tensor]:
708    ...
709
710
711@overload
712def fetch_object_proxy(
713    tracer: _ProxyTracer, t: _AnyScriptObjectType
714) -> Union[Proxy, _AnyScriptObjectType]:
715    ...
716
717
718@overload
719def fetch_object_proxy(
720    tracer: _ProxyTracer, t: PySymType
721) -> Union[_PySymProxyType, PySymType]:
722    ...
723
724
725def fetch_object_proxy(
726    tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType]
727) -> object:
728    return get_proxy_slot(t, tracer, t)
729
730
731HANDLED_TYPES = (Tensor, torch.nn.Parameter, FakeTensor)
732
733
734def _maybe_record_pointwise_barrier(
735    func: object, proxy_mode: ProxyTorchDispatchMode
736) -> None:
737    """
738    Records pointwise operators in user program (non decomposed) that were output in fp16/bf16
739    """
740    if proxy_mode.decomp_layers or not proxy_mode.emulate_precision_casts:
741        return
742
743    if (
744        not isinstance(func, torch._ops.OpOverload)
745        or torch.Tag.pointwise not in func.tags
746    ):
747        return
748
749    last_node = next(iter(reversed(proxy_mode.tracer.graph.nodes)))
750    t = last_node.meta.get("val")
751    if not isinstance(t, torch.Tensor) or t.dtype not in (
752        torch.bfloat16,
753        torch.float16,
754    ):
755        return
756
757    last_node.meta["low_precision_pointwise_barrier"] = True
758
759
760def proxy_call(
761    proxy_mode: ProxyTorchDispatchMode,
762    func: OpOverload,
763    pre_dispatch: bool,
764    args: Tuple[object, ...],
765    kwargs: Dict[str, object],
766) -> object:
767    unrecognized_types: List[Type] = []
768    flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs))
769
770    def can_handle_tensor(x: Tensor) -> bool:
771        r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
772        if proxy_mode._allow_fake_constant:
773            r = r or type(x) in (torch._subclasses.FakeTensor,)
774        if not r:
775            unrecognized_types.append(type(x))
776        return r
777
778    # If there are any tensor subclasses, we need to handle those tensor subclasses first
779    # TODO: we could use types to test this
780    if not all(can_handle_tensor(x) for x in flat_args_kwargs if isinstance(x, Tensor)):
781        not_implemented_log.debug(
782            "ProxyTensorMode tensors without proxy had unrecognized subclasses: %s",
783            unrecognized_types,
784        )
785        return NotImplemented
786
787    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
788    if r is not NotImplemented:
789        _maybe_record_pointwise_barrier(func, proxy_mode)
790        return r
791
792    # For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
793    if not pre_dispatch and func not in [
794        torch.ops.aten.size.default,
795        torch.ops.aten.stride.default,
796        torch.ops.aten.storage_offset.default,
797    ]:
798        with proxy_mode:
799            r = func.decompose(*args, **kwargs)
800            if r is not NotImplemented:
801                return r
802
803    tracer = proxy_mode.tracer
804    f_flat_args_kwargs = [
805        (
806            fetch_object_proxy(tracer, x)
807            if isinstance(x, (Tensor, _AnyScriptObject))
808            else x
809        )
810        for x in flat_args_kwargs
811    ]
812
813    # If there are SymInts, we also should not consider this constant.
814    # However, fake tensor handling of SymInts is sufficiently broken that
815    # I couldn't write a test for this case
816    all_constant = (
817        not any(
818            t.constant is None
819            for t in f_flat_args_kwargs
820            if isinstance(t, _ProxyTensor)
821        )
822        # TODO: maybe constant SymInts should also be allowed?  Not sure if
823        # this can happen
824        and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs)
825    )
826
827    if torch.Tag.data_dependent_output in func.tags:
828        # Check if all of the Tensor inputs are constants
829        if all_constant:
830            const_flat_args_kwargs = [
831                t.constant if isinstance(t, _ProxyTensor) else t
832                for t in f_flat_args_kwargs
833            ]
834            const_args, const_kwargs = pytree.tree_unflatten(
835                const_flat_args_kwargs, spec
836            )
837            with unset_fake_temporarily():
838                return func(*const_args, **const_kwargs)
839        # If any of the Tensor inputs are "real" (not FakeTensor), we may
840        # incorrectly burn in constants by allowing this access.  Raise
841        # an error in this case
842        if proxy_mode._error_on_data_dependent_ops and pytree.tree_all_only(
843            Tensor, lambda t: not is_fake(t), (args, kwargs)
844        ):
845            raise RuntimeError(
846                f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! "
847                "It's likely that this is caused by data-dependent control flow or similar.  "
848                "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' "
849                "in your make_fx call."
850            )
851
852    proxy_flat_args_kwargs = [
853        e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs
854    ]
855    proxy_flat_args_kwargs = [
856        (fetch_sym_proxy(proxy_mode.tracer)(e) if isinstance(e, py_sym_types) else e)
857        for e in proxy_flat_args_kwargs
858    ]
859    proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec)
860
861    # When we trace through a torch.tensor invocation, you never actually
862    # see a torch.ops.aten.tensor call. Instead, the way this function is
863    # implemented internally is that we allocate a plain tensor (this is
864    # *guaranteed* to be a plain tensor, we disable all modes when doing
865    # so), and then call at::lift_fresh on it (to give modes a chance to do
866    # their stuff).  Furthermore, the tensor argument to lift_fresh is guaranteed
867    # to be freshly allocated, so we want lift_fresh to be a no-op (directly
868    # returning the input argument).
869    #
870    # Here is the basic problem: when we trace this sequence of executions
871    # into an FX graph, what happens to this call sequence?  Traditionally,
872    # tensor constants get interned as buffers on the FX GraphModule.  But
873    # this is dangerous.  Consider:
874    #
875    #       x = torch.tensor(1)
876    #       x.add_(2)
877    #
878    # Naively, this traces into:
879    #
880    #       t = self._tensor_constant0  # initialized to torch.tensor(1)
881    #       x = torch.ops.aten.lift_fresh(t)
882    #       x.add_(2)
883    #
884    # If lift_fresh returns t directly, the subsequent add_ call will
885    # modify the tensor constant. Really, the problem is we've violated
886    # the invariant the argument to lift is fresh.  So what we should
887    # preserve the invariant by replacing lift_fresh with lift_fresh_copy:
888    #
889    #       t = self._tensor_constant0  # initialized to torch.tensor(1)
890    #       x = torch.ops.aten.lift_fresh_copy(t)
891    #       x.add_(2)
892    #
893    # This is what the overload modification does.
894    if func is torch.ops.aten.lift_fresh.default:
895        func = torch.ops.aten.lift_fresh_copy.default
896
897    proxy_out = proxy_mode.tracer.create_proxy(
898        "call_function",
899        func,
900        proxy_args,
901        proxy_kwargs,
902        name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),
903    )
904
905    with _enable_thunkify(proxy_mode.tracer):
906        out = func(*args, **kwargs)
907
908    # In some circumstances, we will be tracing in a situation where a tensor
909    # is *statically* known to be a constant (currently, this only happens if
910    # you run torch.tensor; deterministic factory functions like torch.arange
911    # don't get this treatment).  When the tensor in question is small, it's
912    # helpful to due constant propagation in case we call item() (in which
913    # case we can return the constant value that is known, rather than give
914    # an error.)  The logic here tests if constant propagation is possible
915    # (because all of the inputs are constant).  If so, we disable fake tensor
916    # mode (if it is on) and do true compute on the constant.
917    #
918    # It's worth highlighting that we're making a policy decision here.
919    # There is a potential that the tensor is actually quite large, and we
920    # don't actually want to run the compute.  The tensor being quite large
921    # is one of the reasons why factory functions don't get this treatment
922    # (since they can be quite large; if a parameter is initialized to a
923    # constant value it will be!)  Similarly, there is also a potential
924    # to run an operator that blows up the size of a small tensor; we don't
925    # protect against this case, but we could force, e.g., only single
926    # element constant computation by testing the numel of the result before
927    # propagating const-ness.  Similarly, we don't require the constant to
928    # live on CPU, but we could.
929    any_constant = any(
930        t.constant is not None
931        for t in f_flat_args_kwargs
932        if isinstance(t, _ProxyTensor)
933    )
934
935    constant = None
936
937    def tensor_numel_in_limit(t: Tensor) -> bool:
938        return t.numel() <= CONSTANT_NUMEL_LIMIT
939
940    # If this is a lift, the input tensor is guaranteed to be a
941    # constant, so we keep a copy of the original argument along so
942    # we can query it if we're asked to item() it at some later point
943    if (
944        func is torch.ops.aten.lift_fresh_copy.default
945        and out.numel() <= CONSTANT_NUMEL_LIMIT
946    ):
947        with unset_fake_temporarily():
948            assert isinstance(args[0], (Proxy, Tensor)), type(args[0])
949            constant = args[0].clone()
950    elif (
951        torch.Tag.nondeterministic_seeded not in func.tags
952        and all_constant
953        and any_constant
954        and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out)
955    ):
956        # NB: do NOT include factories as constants
957        with unset_fake_temporarily():
958            const_flat_args_kwargs = [
959                t.constant if isinstance(t, _ProxyTensor) else t
960                for t in f_flat_args_kwargs
961            ]
962            const_args, const_kwargs = pytree.tree_unflatten(
963                const_flat_args_kwargs, spec
964            )
965            constant = func(*const_args, **const_kwargs)
966    else:
967        constant = None
968
969    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
970    _maybe_record_pointwise_barrier(func, proxy_mode)
971    return out
972
973
974class _SymNodeDict:
975    """
976    Wrapper around a dictionary that will hash SymInts with their nodes
977    """
978
979    def __init__(self) -> None:
980        self.sym_node_dict: Dict[PySymType, _PySymProxyType] = {}
981
982    def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None:
983        self.sym_node_dict[key.node] = value
984
985    def __getitem__(self, key: PySymType) -> _PySymProxyType:
986        return self.sym_node_dict[key.node]
987
988    def __contains__(self, key: PySymType) -> bool:
989        return key.node in self.sym_node_dict
990
991    def get(
992        self, key: PySymType, default: Optional[_PySymProxyType] = None
993    ) -> _PySymProxyType:
994        # dict.get()'s annotation doesn't accept `None` when the value type
995        # isn't Optional.
996        return self.sym_node_dict.get(key.node, default)  # type: ignore[arg-type]
997
998    def __iter__(self) -> Any:
999        raise NotImplementedError
1000
1001    def __len__(self) -> int:
1002        return len(self.sym_node_dict)
1003
1004
1005class PythonKeyTracer(Tracer):
1006    script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
1007    symnode_tracker: _SymNodeDict
1008    sympy_expr_tracker: Dict[sympy.Symbol, object]
1009    tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
1010    torch_fn_counts: Dict[OpOverload, int]
1011    enable_thunkify: bool = False
1012
1013    def __init__(self) -> None:
1014        super().__init__(autowrap_modules=())  # type: ignore[arg-type]
1015        self.tensor_tracker = WeakTensorKeyDictionary()
1016        self.symnode_tracker = _SymNodeDict()
1017        self.script_object_tracker = WeakIdKeyDictionary(
1018            dict=None, ref_type=_WeakHashRef
1019        )
1020        self.sympy_expr_tracker = dict()
1021
1022        # Stores the torch function that was called during tracing
1023        self.torch_fn_metadata = None
1024        # Stores the counts for every torch function called. This is to help
1025        # distinguish between different calls to the same torch function.
1026        self.torch_fn_counts = {}
1027        self.enable_thunkify = False
1028
1029    # In general, we don't want to make modules leaves. In principle, users of
1030    # this tracer might want to override this in order to turn a couple specific
1031    # modules into leaves in the traced graph.
1032    def call_module(
1033        self,
1034        m: Module,
1035        forward: Callable[..., Any],
1036        args: Tuple[Any, ...],
1037        kwargs: Dict[str, Any],
1038    ) -> Any:
1039        return forward(*args, **kwargs)
1040
1041    # We don't want to turn getattr calls into proxies. So we just return the actual value.
1042    def getattr(
1043        self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy]
1044    ) -> object:
1045        return attr_val
1046
1047    def create_arg(self, a: object) -> fx.node.Node:
1048        if isinstance(a, torch.nn.Parameter):
1049            for n, p in self.root.named_parameters():
1050                if a is p:
1051                    return self.create_node("get_attr", n, (), {})
1052
1053            qualname = self.get_fresh_qualname("_param_constant")
1054            setattr(self.root, qualname, a)
1055
1056            return self.create_node("get_attr", qualname, (), {})
1057        elif isinstance(a, py_sym_types):
1058            assert a.node.constant is not None
1059            return a.node.constant
1060        return super().create_arg(a)  # type: ignore[return-value]
1061
1062    @overload
1063    def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]:
1064        ...
1065
1066    @overload
1067    def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]:
1068        ...
1069
1070    @overload
1071    def unwrap_proxy(
1072        self, e: _AnyScriptObjectType
1073    ) -> Union[Proxy, _AnyScriptObjectType]:
1074        ...
1075
1076    def unwrap_proxy(self, e: T) -> object:
1077        if isinstance(e, Tensor):
1078            return get_proxy_slot(e, self, e, lambda x: x.proxy)
1079        elif isinstance(e, py_sym_types):
1080            return get_proxy_slot(e, self, e, lambda e: e.force())
1081        elif isinstance(e, _AnyScriptObject):
1082            return get_proxy_slot(e, self, e)
1083        else:
1084            return e
1085
1086
1087@contextmanager
1088def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]:
1089    from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
1090
1091    temp_elements = []
1092    pre_dispatch_mode = None
1093
1094    while _len_torch_function_stack() > 0:
1095        mode = _pop_mode()
1096        if isinstance(mode, PreDispatchTorchFunctionMode):
1097            pre_dispatch_mode = mode
1098            break
1099        else:
1100            temp_elements.append(mode)
1101
1102    for mode in reversed(temp_elements):
1103        _push_mode(mode)
1104
1105    try:
1106        yield
1107
1108    finally:
1109        if pre_dispatch_mode is not None:
1110            count = len(temp_elements)
1111            while count > 0:
1112                mode = _pop_mode()
1113                count -= 1
1114
1115            temp_elements.append(pre_dispatch_mode)
1116
1117            for mode in reversed(temp_elements):
1118                _push_mode(mode)
1119
1120
1121@torch._disable_dynamo
1122def dispatch_trace(
1123    root: Union[Module, Callable],
1124    tracer: Tracer,
1125    concrete_args: Optional[Tuple[Any, ...]] = None,
1126) -> GraphModule:
1127    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
1128
1129    # NB: be careful not to DCE .item() calls
1130    def impure_pred(n: fx.Node) -> bool:
1131        from .symbolic_shapes import is_accessor_node
1132
1133        # Always defer to the built-in notion of impure
1134        if n.is_impure():
1135            return True
1136
1137        # Accessors always OK to DCE
1138        if is_accessor_node(n):
1139            return False
1140
1141        # If the operator in question takes SymInt args to SymInt output,
1142        # we assume it's pure and OK to DCE
1143        if (
1144            isinstance(n.meta.get("val"), py_sym_types)
1145            and
1146            # NB: constant args ok
1147            all(
1148                isinstance(a.meta.get("val"), py_sym_types)
1149                for a in n.args
1150                if isinstance(a, fx.Node)
1151            )
1152        ):
1153            return False
1154
1155        # No idea, just assume it's not OK
1156        return True
1157
1158    graph.eliminate_dead_code(impure_pred)
1159    from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints
1160
1161    dedupe_symints(graph)
1162    name = root.__class__.__name__ if isinstance(root, Module) else root.__name__
1163    return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name)
1164
1165
1166def wrap_key(
1167    f: Callable[_P, R], tensors: _P.args, tracer: _ProxyTracer, pre_dispatch: bool
1168) -> Callable[_P, R]:
1169    flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
1170
1171    @functools.wraps(f)
1172    def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R:
1173        flat_proxies, proxies_spec = pytree.tree_flatten(proxies)
1174        assert len(flat_proxies) == len(flat_tensors)
1175        with disable_proxy_modes_tracing() as m:
1176            assert isinstance(m, ProxyTorchDispatchMode)
1177            track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
1178
1179        def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
1180            return get_proxy_slot(t, tracer, t, lambda x: x.proxy)
1181
1182        out = f(*tensors)
1183        out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out)
1184        out = pytree.tree_map_only(
1185            _AnyScriptObject, lambda t: get_proxy_slot(t, tracer, t, lambda x: x), out
1186        )
1187
1188        def get_sym_proxy_slot(t: PySymType) -> Proxy:
1189            return get_proxy_slot(t, tracer).force()
1190
1191        out = pytree.tree_map_only(py_sym_types, get_sym_proxy_slot, out)
1192        return out
1193
1194    return wrapped
1195
1196
1197# TODO: Make downstream users of this work with OperatorBase
1198ORIGINAL_ATEN: Optional[object] = None
1199
1200
1201@contextmanager
1202def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]:
1203    global ORIGINAL_ATEN
1204    if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta():
1205        ORIGINAL_ATEN = func
1206        fx_traceback.current_meta["original_aten"] = func
1207        try:
1208            yield
1209        finally:
1210            ORIGINAL_ATEN = None
1211            fx_traceback.current_meta["original_aten"] = None
1212    else:
1213        yield
1214
1215
1216class TorchFunctionMetadataMode(TorchFunctionMode):
1217    def __init__(self, tracer: _ProxyTracer) -> None:
1218        self.tracer = tracer
1219
1220    def __torch_function__(
1221        self,
1222        func: OpOverload,
1223        types: Tuple[torch._C._TensorMeta, ...],
1224        args: Tuple[object, ...] = (),
1225        kwargs: Optional[Dict[str, object]] = None,
1226    ) -> object:
1227        kwargs = kwargs or {}
1228        self.tracer.torch_fn_metadata = func
1229        self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1
1230        return func(*args, **kwargs)
1231
1232
1233# This mode is **only** used for pre_dispatch tracing.
1234# In particular, we need to make sure that autograd/autocast API's
1235# that do not desugar into dispatcher operators stay in the graph.
1236class PreDispatchTorchFunctionMode(TorchFunctionMode):
1237    def __init__(self, tracer: _ProxyTracer) -> None:
1238        self.tracer = tracer
1239
1240    def __torch_function__(
1241        self,
1242        func: OpOverload,
1243        types: Tuple[torch._C._TensorMeta, ...],
1244        args: Tuple[object, ...] = (),
1245        kwargs: Optional[Dict[str, object]] = None,
1246    ) -> object:
1247        kwargs = kwargs or {}
1248        if func in _side_effectful_need_to_be_preserved_pre_dispatch:
1249            # It's for passing the export verifier which needs to verify the meta['val']
1250            # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier,
1251            # instead of hardcoding it here.
1252            node = self.tracer.create_node("call_function", func, args, {})  # type: ignore[arg-type]
1253            if func is torch._C._set_grad_enabled:
1254                node.meta["val"] = None
1255            return node
1256            # Don't actually run the function! We just want to trace the calls
1257            # into a graph. We don't actualy want to change global autograd state.
1258        return func(*args, **kwargs)
1259
1260
1261class ProxyTorchDispatchMode(TorchDispatchMode):
1262    # Ensure this is read-only; this exists only for legacy reasons
1263    @property
1264    def enable_tracing(self) -> bool:
1265        return True
1266
1267    def __init__(
1268        self,
1269        tracer: _ProxyTracer,
1270        tracing_mode: str,
1271        pre_dispatch: bool = False,
1272        _allow_fake_constant: bool = False,
1273        _error_on_data_dependent_ops: bool = True,
1274    ) -> None:
1275        dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
1276        super().__init__(dk)
1277        self.tracer = tracer
1278        self.tracing_mode = tracing_mode
1279        self.pre_dispatch = pre_dispatch
1280        self._allow_fake_constant = _allow_fake_constant
1281        self._error_on_data_dependent_ops = _error_on_data_dependent_ops
1282        # Indicates to our torch_dispatch dispatching infra that
1283        # this is an "infra" mode with lower dispatching precedence.
1284        self._mode_key = torch._C._TorchDispatchModeKey.PROXY
1285        # Every time we enter a mode, we maintain a stack telling us what the previous
1286        # ProxyTorchDispatchMode state was (if there was any).
1287        # This lets us properly reset the state on exit.
1288        self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = []
1289        self.decomp_layers = 0
1290        from torch._inductor import config
1291
1292        self.emulate_precision_casts = config.emulate_precision_casts
1293
1294    @count
1295    def __torch_dispatch__(
1296        self,
1297        func: OpOverload,
1298        types: Tuple[torch._C._TensorMeta, ...],
1299        args: Tuple[object, ...] = (),
1300        kwargs: Optional[Dict[str, object]] = None,
1301    ) -> object:
1302        with set_original_aten_op(func):
1303            kwargs = kwargs or {}
1304
1305            if func in (prim.device.default,):
1306                return func(*args, **kwargs)
1307
1308            return proxy_call(self, func, self.pre_dispatch, args, kwargs)
1309
1310    def __enter__(self) -> Self:
1311        # Stash and store the previous proxy mode (there may or may not be one)
1312        maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
1313        self.enter_stack.append(maybe_prev_proxy_mode)
1314        return super().__enter__()
1315
1316    def __exit__(
1317        self,
1318        exc_type: Optional[Type[BaseException]],
1319        exc_value: Optional[BaseException],
1320        traceback: Optional[types.TracebackType],
1321    ) -> Optional[bool]:
1322        b = super().__exit__(exc_type, exc_value, traceback)
1323
1324        # Re-enable the previous proxy mode, if there was one.
1325        mb_previous_proxy_mode = self.enter_stack.pop()
1326        if mb_previous_proxy_mode is not None:
1327            _push_mode(mb_previous_proxy_mode)
1328
1329        return b
1330
1331    @classmethod
1332    def is_infra_mode(cls) -> bool:
1333        return True
1334
1335    def _compute_proxy(
1336        self, func: OpOverload, args: Tuple[object, ...], out: PySymType
1337    ) -> Proxy:
1338        n_args = tuple(
1339            get_proxy_slot(a, self.tracer).force().node
1340            if isinstance(a, py_sym_types)
1341            else a
1342            for a in args
1343        )
1344
1345        # func doesn't have a __torch_function__ that Proxy can interpose, so
1346        # we gotta do it manually
1347        n_out = self.tracer.create_node("call_function", func, n_args, {})  # type: ignore[arg-type]
1348        p_out = fx.Proxy(n_out, self.tracer)
1349        set_meta(p_out, out)
1350        return p_out
1351
1352    def __sym_dispatch__(
1353        self,
1354        func: OpOverload,
1355        types: Tuple[torch._C._TensorMeta, ...],
1356        args: Tuple[object, ...],
1357        kwargs: Dict[str, object],
1358    ) -> object:
1359        # Peephole optimize multiply by one
1360        # NB: be careful not to trigger guards here!
1361        if func == operator.mul:
1362            if isinstance(args[1], int) and args[1] == 1:
1363                return args[0]
1364            elif isinstance(args[0], int) and args[0] == 1:
1365                return args[1]
1366
1367        # For speed, we assume there are no nested data structures
1368        # (otherwise we could use tree_map)
1369        # We also assume there are no keyword arguments.
1370        assert not kwargs
1371        out = func(*args, **kwargs)
1372
1373        # If func returned a constant, we don't need to trace; we have
1374        # determined that the result is constant (no matter if the inputs
1375        # were symbolic) and it is no longer necessary to trace the
1376        # computation.  This could occur if func triggered some guards.
1377        if isinstance(out, py_sym_types):
1378            p_out_thunk = thunkify(
1379                self.tracer, self._compute_proxy, func=func, args=args, out=out
1380            )
1381            set_proxy_slot(out, self.tracer, p_out_thunk)
1382
1383        return out
1384
1385
1386class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
1387    script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
1388    symnode_tracker: MutableMapping[PySymType, _PySymProxyType]
1389    tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
1390    sympy_expr_tracker: Dict[sympy.Symbol, object]
1391    torch_fn_metadata: Optional[OpOverload]
1392    torch_fn_counts: Dict[OpOverload, int]
1393    enable_thunkify: bool = False
1394
1395    def __init__(self, graph: fx.graph.Graph) -> None:
1396        super().__init__(graph)
1397        self.symnode_tracker = weakref.WeakKeyDictionary()
1398        self.tensor_tracker = WeakTensorKeyDictionary()
1399        self.sympy_expr_tracker = {}
1400        self.script_object_tracker = WeakIdKeyDictionary(
1401            dict=None, ref_type=_WeakHashRef
1402        )
1403        # Stores the torch function that was called during tracing
1404        self.torch_fn_metadata = None
1405        # Stores the counts for every torch function called. This is to help
1406        # distinguish between different calls to the same torch function.
1407        self.torch_fn_counts = {}
1408
1409
1410# TODO: I'm not sure what the point of this class is; you can just
1411# make_fx through a regular Interpreter
1412class DecompositionInterpreter(fx.Interpreter):
1413    def __init__(
1414        self,
1415        module: fx.GraphModule,
1416        new_graph: fx.Graph,
1417        decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
1418        **kwargs: object,
1419    ) -> None:
1420        super().__init__(module, **kwargs)  # type: ignore[arg-type]
1421        self.new_graph = new_graph
1422        self.tracer = _GraphAppendingTracerEx(self.new_graph)
1423        # Blegh
1424        self.decomposition_table = decomposition_table or {}
1425        self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
1426
1427    def placeholder(
1428        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
1429    ) -> object:
1430        out = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
1431        proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
1432        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
1433        # TODO handle case where the first character of target is '*'
1434        return out
1435
1436    def get_attr(
1437        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
1438    ) -> object:
1439        out = super().get_attr(target, args, kwargs)  # type: ignore[arg-type]
1440        proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
1441        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
1442        return out
1443
1444    # call_function, call_method, call_module get traced automatically by the outer mode.
1445
1446    def output(
1447        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
1448    ) -> object:
1449        out = super().output(target, args, kwargs)  # type: ignore[arg-type]
1450
1451        def get_proxy_node(x: _ProxyTensor) -> fx.node.Node:
1452            return x.proxy.node
1453
1454        def unwrap(e: Tensor) -> Union[Tensor, fx.Node]:
1455            return get_proxy_slot(e, self.tracer, e, get_proxy_node)
1456
1457        self.new_graph.output(pytree.tree_map(unwrap, out))
1458        return out
1459
1460    def run(self, *args: object, **kwargs: object) -> object:
1461        # Should enter the mode at least once for being able to restore it later
1462        # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025
1463        with decompose(self.decomposition_table), self.mode:
1464            return super().run(*args, **kwargs)  # type: ignore[arg-type]
1465
1466
1467def wrapper_and_args_for_make_fx(
1468    func: Callable[..., R], args: Tuple[object, ...], kwargs: Dict[str, object]
1469) -> Tuple[Callable[[List[object]], R], List[object]]:
1470    # make_fx doesn't support kwargs, so we need to do this flattening
1471    # and then unflatten the args before calling func
1472    flat_args, spec = pytree.tree_flatten((args, kwargs))
1473
1474    def wrapped(flat_args: List[object]) -> R:
1475        fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec)
1476        return func(*fn_args, **fn_kwargs)
1477
1478    return wrapped, flat_args
1479
1480
1481@contextmanager
1482def disable_autocast_cache() -> Generator[None, None, None]:
1483    old_value = torch.is_autocast_cache_enabled()
1484    torch.set_autocast_cache_enabled(False)
1485    try:
1486        yield
1487    finally:
1488        torch.set_autocast_cache_enabled(old_value)
1489
1490
1491class _ModuleNotInstalledAsSubmoduleError(NameError):
1492    pass
1493
1494
1495# Base class for inline _ModuleStackTracer.__init__.AttrProxy
1496class _AttrProxy:
1497    def reset_proxy_mapping(self, base: Module, path: str) -> None:
1498        pass
1499
1500
1501class _ModuleStackTracer(PythonKeyTracer):
1502    r"""Customized version of PythonKeyTracer that retains module stack
1503    information in node.meta["nn_module_stack"].
1504
1505    FX symbolic trace actually does this already, but it relies on `self.root`
1506    being the actual module being traced. Since make_fx traces a lambda of our
1507    creation, things don't work properly.
1508
1509    So for this version we hold onto a reference to the original module
1510    (scope_root) and use that to match the path. Also when we see,
1511            A
1512           / \
1513          B   C
1514           \ /
1515            D
1516    we want to record the path as A.B.D by recording only one path.
1517    See Note [Preserving the nn module stack metadata during export non-strict mode]  # noqa: W605
1518    """
1519
1520    def __init__(self, scope_root: GraphModule) -> None:
1521        super().__init__()
1522        self.scope_root = scope_root
1523        self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary()
1524        self.attr_proxy_map: WeakKeyDictionary[Module, _AttrProxy] = WeakKeyDictionary()
1525        self.proxy_modules: WeakKeyDictionary[_AttrProxy, Module] = WeakKeyDictionary()
1526        self.counter = 0
1527
1528        self.module_id_cache = defaultdict(list)
1529        for name, mod in self.scope_root.named_modules(remove_duplicate=False):
1530            self.module_id_cache[id(mod)].append(name)
1531
1532        # Build a wrapper around _AttrProxy to provide the tracer. We can't
1533        # store it on _AttrProxy itself beceause we mimic the underlying class
1534        # (including its attributes).
1535        tracer = self
1536
1537        class AttrProxy(_AttrProxy):
1538            def __init__(self, base: Module, path: str) -> None:
1539                # Class is modified to be a subclass of torch.nn.Module
1540                # Warning: We blow away our own attributes here to mimic the base class
1541                # - so don't expect `self.x` to do anything useful.
1542                self.__class__ = type(
1543                    base.__class__.__name__,
1544                    (self.__class__, base.__class__),
1545                    {},
1546                )
1547                self.__dict__ = base.__dict__
1548                self.__class__.__module__ = base.__class__.__module__
1549                self.__class__.__qualname__ = base.__class__.__qualname__
1550                self.reset_proxy_mapping(base, path)
1551
1552            def reset_proxy_mapping(self, base: Module, path: str) -> None:
1553                tracer.proxy_paths[self] = path
1554                tracer.proxy_modules[self] = base
1555
1556            def __getattr__(self, name: str) -> AttrProxy:
1557                assert isinstance(self, Module)
1558                # Calling into torch.nn.Module.__getattr__ with super(),
1559                # That __getattr__ is patched to be module_getattr_wrapper in _symbolic_trace.py.
1560                # which then calls into _ModuleStackTracer.getattr
1561                attr_val = super().__getattr__(name)  # type: ignore[misc]
1562                if isinstance(attr_val, AttrProxy):
1563                    attr_val = tracer.proxy_modules[attr_val]
1564                elif not isinstance(attr_val, Module):
1565                    return attr_val
1566                if attr_val not in tracer.attr_proxy_map:
1567                    tracer.attr_proxy_map[attr_val] = AttrProxy(
1568                        attr_val, tracer.proxy_paths[self] + "." + name
1569                    )
1570                else:
1571                    # NOTE [caching AttrProxy]. Caching ensures a 1-1 mapping between AttrProxy and the actual attr_val.
1572                    # 1. We reset the proxy_mapping to solve the diamond shape reference problem: we want to record the
1573                    # path as A.B.D instead of A.C.D (the purpose of _ModuleStackTracer).
1574                    # 2. Instead of creating a new AttrProxy, we just reset the proxy_mapping of existing one. This is to avoid
1575                    # dynamo creating multiple guards for the same attr_val but different AttrProxy when exporting
1576                    # a model that calls torch.compile (e.g when a model uses torch.cond.)
1577                    tracer.attr_proxy_map[attr_val].reset_proxy_mapping(
1578                        attr_val, tracer.proxy_paths[self] + "." + name
1579                    )
1580                return tracer.attr_proxy_map[attr_val]
1581
1582            def get_base(self) -> Module:
1583                return tracer.proxy_modules[self]
1584
1585            @property
1586            def _modules(self) -> Dict[str, AttrProxy]:
1587                assert "_modules" in self.__dict__
1588                submodules = self.__dict__["_modules"]
1589                assert isinstance(submodules, dict)
1590                return {
1591                    key: AttrProxy(value, tracer.proxy_paths[self] + "." + str(key))
1592                    for key, value in submodules.items()
1593                }
1594
1595        self.proxy_type = AttrProxy
1596
1597    def path_of_module(self, mod: Module) -> str:
1598        """
1599        Use tracked access path during tracing instead of the default BFS behavior.
1600        Still use all the possible module paths to verify the result.
1601        """
1602        if mod is self.scope_root:
1603            return ""
1604
1605        if isinstance(mod, _AttrProxy):
1606            return self.proxy_paths[mod]
1607
1608        try:
1609            return Tracer.path_of_module(self, mod)
1610        except NameError as e:
1611            raise _ModuleNotInstalledAsSubmoduleError from e
1612
1613    def getattr(
1614        self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy]
1615    ) -> object:
1616        if not isinstance(attr_val, Module) or isinstance(attr_val, fx.GraphModule):
1617            return super().getattr(attr, attr_val, parameter_proxy_cache)
1618        if isinstance(attr_val, _AttrProxy):
1619            return attr_val
1620
1621        # See NOTE [caching AttrProxy].
1622        if attr_val not in self.attr_proxy_map:
1623            self.attr_proxy_map[attr_val] = self.proxy_type(attr_val, attr)
1624        else:
1625            self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr)
1626        return self.attr_proxy_map[attr_val]
1627
1628    def trace(  # type: ignore[override]
1629        self, root: Union[Module, Callable], concrete_args: Optional[Dict[str, object]]
1630    ) -> fx.Graph:
1631        res = super().trace(root, concrete_args)
1632
1633        # Since we are making _AttrProxy mimic the original
1634        # submodule, when someone registers a module directly
1635        # to the tracer while tracing, the proxy object gets registered
1636        # first. So we need to replace the proxy modules with the real ones
1637        # This can happen during HOO tracing
1638        proxy_module_names_to_be_replaced: List[Tuple[str, _AttrProxy]] = []
1639        for name, module in self.root.named_modules():
1640            if module in self.proxy_modules:
1641                proxy_module_names_to_be_replaced.append((name, module))
1642
1643        def _delete_proxy_attr(obj: Module, target: str) -> bool:
1644            # Copied from fx/graph_module.py
1645            # Customized it for proxy type
1646            atoms = target.split(".")
1647            path, target_submod = atoms[:-1], atoms[-1]
1648            assert isinstance(obj, Module)
1649            mod = obj
1650
1651            # Get the parent module
1652            for item in path:
1653                if not hasattr(mod, item):
1654                    return False
1655
1656                mod = getattr(mod, item)
1657
1658                if not isinstance(mod, (_AttrProxy, Module)):
1659                    return False
1660
1661            if not hasattr(mod, target_submod):
1662                return False
1663
1664            # At least the leaf module should be proxy type.
1665            if not isinstance(getattr(mod, target_submod), _AttrProxy):
1666                return False
1667
1668            delattr(mod, target_submod)
1669            return True
1670
1671        for proxy_module_name, proxy_module in proxy_module_names_to_be_replaced:
1672            _delete_proxy_attr(self.root, proxy_module_name)
1673            actual_module = self.proxy_modules[proxy_module]
1674            _assign_attr(actual_module, self.root, proxy_module_name)
1675
1676        return res
1677
1678    def call_module(
1679        self,
1680        m: Module,
1681        forward: Callable,
1682        args: Tuple[object, ...],
1683        kwargs: Dict[str, object],
1684    ) -> None:
1685        """PythonKeyTracer overrides call_module to avoid the scope handling,
1686        but we actually want it.
1687        """
1688        from torch._dynamo import OptimizedModule
1689
1690        # FIXME (tmanlaibaatar)
1691        # When we call torch.compile inside HOO, we will end up
1692        # invoking a module that is not registered on the root. For
1693        # now, we just inline them. But once we start supporting
1694        # mark_strict in export, we do need to properly handle this.
1695        # Right now, it doesn't matter because current non-strict
1696        # use cases don't need to work with HOO.
1697        if isinstance(m, (OptimizedModule, GraphModule)):
1698            return forward(*args, **kwargs)
1699
1700        try:
1701            return Tracer.call_module(self, m, forward, args, kwargs)
1702        except _ModuleNotInstalledAsSubmoduleError as e:
1703            warnings.warn(
1704                f"Unable to find the path of the module {m}. "
1705                "This might be because the module was not properly registered "
1706                "as a submodule, which is not good practice. We will trace "
1707                "through the module without recording stack information."
1708            )
1709            return forward(*args, **kwargs)
1710
1711    def is_leaf_module(self, m: Module, module_qualified_name: str) -> bool:
1712        return False
1713
1714    def create_node(self, *args: object, **kwargs: object) -> fx.node.Node:
1715        """
1716        Create node and add on metadata.
1717        Add nn_module_stack here instead of TracerBase,
1718        since calls to make_fx() might not want to record module stack metadata.
1719        Add torch_fn by looking at torch_fn_metadata and torch_fn_counts.
1720        Add stack_trace by filtering out forward() stack frames.
1721        """
1722        node = super().create_node(*args, **kwargs)  # type: ignore[arg-type]
1723
1724        # nn_module_stack
1725        if node.op not in ["placeholder", "output"]:
1726            if "nn_module_stack" not in node.meta:
1727                node.meta["nn_module_stack"] = self.module_stack
1728            # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]]
1729            for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items():
1730                if isinstance(mod_cls, type):
1731                    node.meta["nn_module_stack"][key] = (
1732                        fqn,
1733                        mod_cls.__module__ + "." + mod_cls.__qualname__,
1734                    )
1735
1736        # torch_fn
1737        if (
1738            node.op == "call_function"
1739            and self.torch_fn_metadata is not None
1740            and "torch_fn" not in node.meta
1741        ):
1742            node.meta["torch_fn"] = (
1743                f"{self.torch_fn_metadata.__name__}_{self.torch_fn_counts[self.torch_fn_metadata]}",
1744                f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}",
1745            )
1746
1747        # stack_trace
1748        if "stack_trace" not in node.meta and node.op not in ["placeholder", "output"]:
1749            user_frame_summary = CapturedTraceback.extract().summary()
1750            if user_frame_summary:
1751                # we retain frames from forward() calls, or ops
1752                # located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap)
1753                stack_trace = [
1754                    frame
1755                    for frame in user_frame_summary
1756                    if (
1757                        frame.name == "forward"
1758                        or frame.filename.endswith("torch/__init__.py")
1759                    )
1760                ]
1761                # filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py
1762                # this is hardcoded, but leads to a much cleaner stack trace
1763                stack_trace = [
1764                    frame
1765                    for frame in stack_trace
1766                    if not (
1767                        frame.filename.endswith("fx/_symbolic_trace.py")
1768                        or frame.filename.endswith("export/_trace.py")
1769                    )
1770                ]
1771                if (
1772                    stack_trace
1773                ):  # empty list for strict mode, dynamo should handle stack_trace
1774                    stack_trace = traceback.StackSummary.from_list(stack_trace)
1775                    node.meta["stack_trace"] = "".join(stack_trace.format()).strip()
1776
1777        return node
1778
1779
1780class _MakefxTracer:
1781    def __init__(
1782        self,
1783        decomposition_table: Optional[Mapping[OpOverload, Callable]],
1784        tracing_mode: str,
1785        _allow_non_fake_inputs: bool,
1786        pre_dispatch: bool,
1787        record_module_stack: bool,
1788        _allow_fake_constant: bool,
1789        _error_on_data_dependent_ops: bool,
1790    ) -> None:
1791        # Configurations that are used to initialize the context managers and their states.
1792        # Should not modify them during tracing.
1793        self.decomposition_table: Dict[OpOverload, Callable] = dict(
1794            decomposition_table or {}
1795        )
1796        self.decomposition_table.setdefault(
1797            torch.ops.aten.sym_numel.default, torch._decomp.decompositions.sym_numel
1798        )
1799        self.tracing_mode: str = tracing_mode
1800        self._allow_non_fake_inputs: bool = _allow_non_fake_inputs
1801        self.pre_dispatch: bool = pre_dispatch
1802        self.record_module_stack: bool = record_module_stack
1803        self._allow_fake_constant: bool = _allow_fake_constant
1804        self._error_on_data_dependent_ops: bool = _error_on_data_dependent_ops
1805
1806        # All context managers and their states should be initialized before tracing based on the inputs
1807        # and configurations. After tracing, their states should be cleaned except for shape_env.
1808        # Remember to specify how to intialize it from user inputs and from parent tracer whenever
1809        # adding new modes in _MakefxTracer.
1810        self.fake_tensor_mode: Optional[FakeTensorMode] = None
1811        self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext()
1812        self.proxy_function_mode: Union[
1813            nullcontext, PreDispatchTorchFunctionMode
1814        ] = nullcontext()
1815        self.fx_tracer: Optional[PythonKeyTracer] = None
1816        self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext()
1817        self.torch_fn_metadata_mode: Union[
1818            nullcontext, TorchFunctionMetadataMode
1819        ] = nullcontext()
1820
1821    def _checkpoint_modes(self) -> List[Any]:
1822        return [
1823            self.fake_tensor_mode,
1824            self.proxy_mode,
1825            self.proxy_function_mode,
1826            self.fx_tracer,
1827            self.python_dispatcher_mode,
1828            self.torch_fn_metadata_mode,
1829        ]
1830
1831    def _restore_modes(
1832        self,
1833        prev_fake_tensor_mode: Optional[FakeTensorMode],
1834        prev_proxy_mode: Union[nullcontext, ProxyTorchDispatchMode],
1835        prev_proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode],
1836        prev_fx_tracer: Optional[PythonKeyTracer],
1837        prev_python_dispatcher_mode: Union[nullcontext, Any],
1838        prev_torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode],
1839    ) -> None:
1840        self.fake_tensor_mode = prev_fake_tensor_mode
1841        self.proxy_mode = prev_proxy_mode
1842        self.proxy_function_mode = prev_proxy_function_mode
1843        self.fx_tracer = prev_fx_tracer
1844        self.python_dispatcher_mode = prev_python_dispatcher_mode
1845        self.torch_fn_metadata_mode = prev_torch_fn_metadata_mode
1846
1847    @contextmanager
1848    def _init_modes_from_inputs(
1849        self, f: Callable, args: Tuple[object, ...]
1850    ) -> Generator[None, None, None]:
1851        prev_modes = self._checkpoint_modes()
1852        try:
1853            # Avoid importing sympy at a module level
1854            from .symbolic_shapes import ShapeEnv
1855
1856            if hasattr(f, "_orig_mod") and self.record_module_stack:
1857                scope_root = f._orig_mod
1858                self.fx_tracer = _ModuleStackTracer(scope_root)
1859            else:
1860                self.fx_tracer = PythonKeyTracer()
1861
1862            if self.tracing_mode == "fake":
1863                import torch._dynamo
1864
1865                fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
1866                if fake_tensor_mode is None:
1867                    import torch._functorch.config as _config
1868
1869                    with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
1870                        fake_tensor_mode = FakeTensorMode(
1871                            allow_fallback_kernels=True,
1872                            allow_non_fake_inputs=self._allow_non_fake_inputs,
1873                            shape_env=ShapeEnv(),
1874                            static_shapes=True,
1875                        )
1876                self.fake_tensor_mode = fake_tensor_mode
1877            elif self.tracing_mode == "symbolic":
1878                import torch._dynamo
1879
1880                fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
1881                if fake_tensor_mode is None:
1882                    shape_env = ShapeEnv()
1883                    import torch._functorch.config as _config
1884
1885                    with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
1886                        fake_tensor_mode = FakeTensorMode(
1887                            allow_fallback_kernels=False,
1888                            allow_non_fake_inputs=self._allow_non_fake_inputs,
1889                            shape_env=shape_env,
1890                        )
1891                assert (
1892                    fake_tensor_mode.shape_env is not None
1893                ), "shape_env should be set if tracing with 'symbolic'"
1894                self.fake_tensor_mode = fake_tensor_mode
1895            else:
1896                if not self.tracing_mode == "real":
1897                    raise AssertionError(
1898                        f"Unexpected tracing type: {self.tracing_mode}"
1899                    )
1900
1901            self._construct_modes_with_fx_tracer(self.fx_tracer)
1902            yield
1903        finally:
1904            self._restore_modes(*prev_modes)
1905
1906    def _construct_modes_with_fx_tracer(self, fx_tracer: _ProxyTracer) -> None:
1907        self.proxy_mode = ProxyTorchDispatchMode(
1908            fx_tracer,
1909            self.tracing_mode,
1910            pre_dispatch=self.pre_dispatch,
1911            _allow_fake_constant=self._allow_fake_constant,
1912            _error_on_data_dependent_ops=self._error_on_data_dependent_ops,
1913        )
1914
1915        if self.pre_dispatch:
1916            self.proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer)
1917
1918        # pre-autograd tracing uses per-dispatch-key modes,
1919        # which requires the python dispatcher
1920        if self.tracing_mode == "symbolic" or self.pre_dispatch:
1921            self.python_dispatcher_mode = enable_python_dispatcher()
1922
1923        self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer)
1924
1925    @contextmanager
1926    def _init_modes_from_parent(
1927        self, parent_tracer: _MakefxTracer
1928    ) -> Generator[None, None, None]:
1929        # By default, subtracer creates new modes based on parent tracer's config.
1930        # However, there are cases where we want to share the same modes with parent tracer
1931        # For example, fake_tensor_mode, we want the example value's fake_mode of parent graph and subgraphs to be the same.
1932        prev_modes = self._checkpoint_modes()
1933        try:
1934            self.fake_tensor_mode = parent_tracer.fake_tensor_mode
1935
1936            def _create_sub_fx_tracer(parent_tracer: _ProxyTracer) -> PythonKeyTracer:
1937                if type(parent_tracer) == PythonKeyTracer:
1938                    return PythonKeyTracer()
1939                elif type(parent_tracer) == _ModuleStackTracer:
1940                    return _ModuleStackTracer(parent_tracer.scope_root)
1941                else:
1942                    raise RuntimeError(
1943                        f"Unexpected tracer type: {type(parent_tracer)}."
1944                    )
1945
1946            assert parent_tracer.fx_tracer is not None
1947            self.fx_tracer = _create_sub_fx_tracer(parent_tracer.fx_tracer)
1948            self._construct_modes_with_fx_tracer(self.fx_tracer)
1949            yield
1950        finally:
1951            self._restore_modes(*prev_modes)
1952
1953    def _trace_inner(self, f: Callable, *args: object) -> GraphModule:
1954        phs = pytree.tree_map(lambda _: torch.fx._symbolic_trace.PH, args)
1955
1956        def _wrap_fake(args: T) -> T:
1957            arg_count = 0
1958
1959            def inner_wrap_fake(x: object) -> object:
1960                nonlocal arg_count
1961                # TODO: it would be nice to line these up with the names
1962                # FX will choose for the placeholders, but we don't
1963                # actually know what the names will be at this point yet
1964                # NB: the Source here is actually meaningless
1965                from torch._dynamo.source import ConstantSource
1966
1967                assert self.fake_tensor_mode is not None
1968                source = ConstantSource(f"input{arg_count}")
1969                if isinstance(x, Tensor):
1970                    arg_count += 1
1971                    return self.fake_tensor_mode.from_tensor(x, source=source)
1972                # NB: don't match on bools
1973                elif type(x) is int and self.tracing_mode == "symbolic":
1974                    assert (
1975                        self.fake_tensor_mode.shape_env is not None
1976                    ), "shape_env should be set if tracing with 'symbolic'"
1977                    return self.fake_tensor_mode.shape_env.create_symintnode(
1978                        self.fake_tensor_mode.shape_env.create_symbol(
1979                            x, source, positive=None
1980                        ),
1981                        hint=x,
1982                        source=source,
1983                    )
1984                elif isinstance(x, torch.ScriptObject):
1985                    return torch._library.fake_class_registry.maybe_to_fake_obj(
1986                        self.fake_tensor_mode, x
1987                    )
1988
1989                assert not isinstance(
1990                    x, FakeScriptObject
1991                ), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
1992                return x
1993
1994            wrap_fn_map = {
1995                "real": lambda x: x,
1996                "fake": inner_wrap_fake,
1997                "symbolic": inner_wrap_fake,
1998            }
1999            return pytree.tree_map(wrap_fn_map[self.tracing_mode], args)
2000
2001        def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]:
2002            if (
2003                not hasattr(inspect.unwrap(f), "__code__")
2004                or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS
2005            ):
2006                # FX doesn't support varargs, so we gotta fake up a wrapper
2007                # TODO: Would be nice to fix this at the source...
2008                return fake_signature(f, len(phs))
2009            return f
2010
2011        args = _wrap_fake(args)
2012        func = _wrap_func(f, phs)
2013        # We disable the autocast cache as the autocast cache causes type conversions on parameters to
2014        # check a cache, which introduces untracked tensors into the graph
2015        #
2016        # We also disable tracing by any other tensor proxy-based tracers except the current. The
2017        # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
2018        # thus irrelevant to any external functional trace.
2019        proxy_mode: ProxyTorchDispatchMode = typing.cast(
2020            ProxyTorchDispatchMode, self.proxy_mode
2021        )
2022        with ExitStack() as stack:
2023            stack.enter_context(decompose(self.decomposition_table))
2024            if self.fake_tensor_mode:
2025                stack.enter_context(self.fake_tensor_mode)
2026            stack.enter_context(self.python_dispatcher_mode)
2027            stack.enter_context(self.proxy_function_mode)
2028            stack.enter_context(self.torch_fn_metadata_mode)
2029            stack.enter_context(proxy_mode)
2030            stack.enter_context(disable_autocast_cache())
2031            stack.enter_context(_set_make_fx_tracer(self))
2032
2033            assert self.fx_tracer is not None
2034            t = dispatch_trace(
2035                wrap_key(func, args, self.fx_tracer, self.pre_dispatch),
2036                tracer=self.fx_tracer,
2037                concrete_args=tuple(phs),
2038            )
2039
2040        # TODO: kind of a bad way to do it, should maybe figure out a better way
2041        if self.tracing_mode == "symbolic":
2042            assert self.fake_tensor_mode is not None
2043            t.shape_env = self.fake_tensor_mode.shape_env
2044        return t
2045
2046    def trace(self, f: Callable, *args: object) -> fx.GraphModule:
2047        with self._init_modes_from_inputs(f, args):
2048            return self._trace_inner(f, *args)
2049
2050    def trace_subgraph(self, f: Callable, *args: object) -> GraphModule:
2051        # Create a new tracer based on parent's config
2052        sub_tracer = _MakefxTracer(
2053            self.decomposition_table,
2054            "real",
2055            self._allow_non_fake_inputs,
2056            self.pre_dispatch,
2057            self.record_module_stack,
2058            self._allow_fake_constant,
2059            self._error_on_data_dependent_ops,
2060        )
2061        with sub_tracer._init_modes_from_parent(self):
2062            return sub_tracer._trace_inner(f, *args)
2063
2064
2065_CURRENT_MAKE_FX_TRACER: Optional[_MakefxTracer] = None
2066
2067
2068@contextmanager
2069def _set_make_fx_tracer(tracer: _MakefxTracer) -> Generator[None, None, None]:
2070    global _CURRENT_MAKE_FX_TRACER
2071    prev_tracer = _CURRENT_MAKE_FX_TRACER
2072    try:
2073        _CURRENT_MAKE_FX_TRACER = tracer
2074        yield
2075    finally:
2076        _CURRENT_MAKE_FX_TRACER = prev_tracer
2077
2078
2079def make_fx(
2080    f: Callable,
2081    decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
2082    tracing_mode: str = "real",
2083    _allow_non_fake_inputs: bool = False,
2084    *,
2085    pre_dispatch: bool = False,
2086    record_module_stack: bool = False,
2087    _allow_fake_constant: bool = False,
2088    _error_on_data_dependent_ops: bool = True,
2089) -> Callable[..., GraphModule]:
2090    """
2091    Given a function f, return a new function which when executed with valid
2092    arguments to f, returns an FX GraphModule representing the set of operations that
2093    were executed during the course of execution.
2094    """
2095
2096    assert tracing_mode in ["real", "fake", "symbolic"]
2097
2098    make_fx_tracer = _MakefxTracer(
2099        decomposition_table,
2100        tracing_mode,
2101        _allow_non_fake_inputs,
2102        pre_dispatch,
2103        record_module_stack,
2104        _allow_fake_constant,
2105        _error_on_data_dependent_ops,
2106    )
2107
2108    @functools.wraps(f)
2109    def wrapped(*args: object) -> GraphModule:
2110        return make_fx_tracer.trace(f, *args)
2111
2112    return wrapped
2113
2114
2115def get_torch_dispatch_modes() -> List[TorchDispatchMode]:
2116    return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
2117
2118
2119# TODO: this is a legacy name, there is only ever one proxy mode as it's an
2120# infra mode
2121def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
2122    return get_proxy_mode()
2123
2124
2125def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
2126    """
2127    Current the currently active proxy tracing mode, or None if
2128    we are not currently tracing.  This includes pre-dispatch proxy
2129    tracing.
2130    """
2131    pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch(
2132        torch._C._TorchDispatchModeKey.PROXY
2133    )
2134    mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
2135    assert (
2136        pre_dispatch_mode is None or mode is None
2137    ), f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
2138    return pre_dispatch_mode or mode
2139
2140
2141def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) -> R:
2142    """
2143    Call into the currently active proxy tracing mode to do a
2144    SymInt/SymFloat/SymBool dispatch trace on a function that operates on
2145    these arguments.
2146    """
2147    mode = get_proxy_mode()
2148    assert mode
2149    # Have to do it manually, because we're not doing the normal torch
2150    # dispatch machinery which disables it for us
2151    with disable_proxy_modes_tracing():
2152        # TODO: properly compute types
2153        types: List[Type] = []
2154        return mode.__sym_dispatch__(func, types, args, kwargs)  # type: ignore[arg-type, return-value]
2155
2156
2157@contextmanager
2158def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, None]:
2159    return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
2160
2161
2162def maybe_handle_decomp(
2163    proxy_mode: ProxyTorchDispatchMode,
2164    op: OpOverload,
2165    args: Tuple[object, ...],
2166    kwargs: Dict[str, object],
2167) -> object:
2168    if op in CURRENT_DECOMPOSITION_TABLE:
2169        with proxy_mode:
2170            proxy_mode.decomp_layers += 1
2171            out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
2172            proxy_mode.decomp_layers -= 1
2173            return out
2174
2175    return NotImplemented
2176
2177
2178def get_isolated_graphmodule(
2179    func: Callable,
2180    args: Tuple[object, ...],
2181    kwargs: Dict[str, object],
2182    tracing_mode: str = "real",
2183    decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
2184) -> GraphModule:
2185    """A helper function used to get the GraphModule for the given func.
2186
2187    It's expected to be used in the ProxyTensor tracing context.
2188    It detaches the args and kwargs from the current tracer so that the trace of
2189    the current graph module can be created without any side-effects.
2190    """
2191    wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)
2192
2193    with disable_proxy_modes_tracing():
2194        gm = make_fx(
2195            wrapped, decomposition_table=decomposition_table, tracing_mode=tracing_mode
2196        )(all_args)
2197    return gm
2198
2199
2200def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
2201    """A helper function for setting up unbacked_bindings on the destination FX graph."""
2202    from .symbolic_shapes import compute_unbacked_bindings
2203
2204    # Can't use detect_fake_mode here,
2205    #
2206    # python test/distributed/_tensor/test_dtensor_compile.py -k
2207    # test_tp_compile_fullgraph_is_seq_parallel_False
2208    #
2209    # will fail.  Very strange, it probably isn't right for them to be using
2210    # two fake modes there...
2211    fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
2212    if fake_mode and fake_mode.shape_env:
2213        if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out):
2214            assert isinstance(out_proxy, Proxy), out_proxy
2215            out_proxy.node.meta["unbacked_bindings"] = symbol_to_path
2216