# mypy: allow-untyped-decorators
from __future__ import annotations

import atexit
import contextlib
import dataclasses
import functools
import logging
import math
import os
import traceback
import typing
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    TypeVar,
    Union,
)
from typing_extensions import Self, TypeGuard
from weakref import ReferenceType

import torch
from torch import SymBool, SymFloat, SymInt, Tensor
from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor
from torch._prims_common import suggest_memory_format
from torch._subclasses.meta_utils import (
    assert_eq,
    assert_metadata_eq,
    is_sparse_any,
    is_sparse_compressed,
    MetaConverter,
)
from torch._utils import render_call
from torch.fx.immutable_collections import immutable_dict
from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
from torch.overrides import TorchFunctionMode
from torch.types import IntLikeType, py_sym_types
from torch.utils._backport_slots import dataclass_slots
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import (
    is_traceable_wrapper_subclass,
    TorchDispatchMode,
)
from torch.utils._pytree import PyTree, tree_map, tree_map_, TreeSpec
from torch.utils._stats import count
from torch.utils._traceback import CapturedTraceback

from ._fake_tensor_utils import _CacheKeyState, _PySymInputStub, _SymIntOutputStub


if TYPE_CHECKING:
    from types import TracebackType

    from torch._guards import Source
    from torch._ops import OpOverload
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext

log = logging.getLogger(__name__)

# TODO: Hack to unblock https://github.com/pytorch/pytorch/pull/108186
# Proper fix tracked by https://github.com/pytorch/pytorch/issues/120105
try:
    not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
except ValueError as e:
    if "'not_implemented' not registered" in str(e):
        import logging as not_implemented_log
    else:
        raise e


class _Unassigned:
    pass


_UNASSIGNED = _Unassigned()

DimList = List

pytree = torch.utils._pytree
T = TypeVar("T")

aten = torch._ops.ops.aten

CONSTANT_NUMEL_LIMIT = 1

RECURSION_COUNT = 0


# Small helper that increments the recursion count and resets it when the
# object goes out of scope.  Useful if you don't want to increase
# indentation, which is what a context manager would do.
class IncrementRecursionCount:
    def __init__(self) -> None:
        global RECURSION_COUNT
        RECURSION_COUNT += 1

    def __del__(self) -> None:
        global RECURSION_COUNT
        RECURSION_COUNT -= 1


@dataclass
class UnsupportedFakeTensorException(RuntimeError):
    reason: str


@dataclass
class DynamicOutputShapeException(RuntimeError):
    func: OpOverload


@dataclass
class DataDependentOutputException(RuntimeError):
    func: OpOverload


@dataclass
class UnsupportedOperatorException(RuntimeError):
    func: OpOverload


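# A dict with all-True values doubles as an insertion-ordered set (Python
# dicts preserve insertion order), e.g.
# ordered_set("a", "b") == {"a": True, "b": True}.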
def ordered_set(*items: T) -> Dict[T, Literal[True]]:
    return dict.fromkeys(items, True)


@contextlib.contextmanager
def unset_fake_temporarily() -> Generator[Optional[TorchDispatchMode], None, None]:
    old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
    try:
        yield old
    finally:
        if old is not None:
            torch._C._set_dispatch_mode(old)

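# A hedged usage sketch: drop out of fake mode to run a real computation,
# then restore the mode on exit (illustrative, not executed at import time):
#
#   with FakeTensorMode():
#       with unset_fake_temporarily():
#           real = torch.ones(2)   # a real tensor, not a FakeTensor
#       fake = torch.ones(2)       # a FakeTensor again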

def get_plain_tensors(subclass: Tensor) -> List[Tensor]:
    assert is_traceable_wrapper_subclass(subclass)
    plain_tensors: List[Tensor] = []
    todo = [subclass]
    while todo:
        curr = todo.pop()
        if not is_traceable_wrapper_subclass(curr):
            assert isinstance(curr, Tensor)
            plain_tensors.append(curr)
            continue

        inner_keys, _ = curr.__tensor_flatten__()
        for key in reversed(inner_keys):
            todo.append(getattr(curr, key))

    return plain_tensors

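# Illustrative sketch for a traceable wrapper subclass holding two inner
# tensors (`TwoTensor` is a hypothetical subclass used only for exposition):
#
#   get_plain_tensors(TwoTensor(a, b))  # -> [a, b], flattened depth-first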

def is_fake(x: object) -> TypeGuard[Tensor]:
    if isinstance(x, FakeTensor):
        return True
    if is_traceable_wrapper_subclass(x):
        attrs, _ = type(x).__tensor_flatten__(x)
        flattened_tensors = [getattr(x, attr) for attr in attrs]
        all_fake = all(is_fake(x) for x in flattened_tensors)
        any_fake = any(is_fake(x) for x in flattened_tensors)
        assert all_fake == any_fake, "got mixed fake and real tensors!"
        return all_fake
    elif isinstance(x, Tensor) and torch._is_functional_tensor(x):
        reapply_views = torch._C._functionalization_reapply_views_tls()
        unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views)
        return is_fake(unwrapped)
    elif isinstance(x, Tensor) and is_functorch_wrapped_tensor(x):
        unwrapped = torch._C._functorch.get_unwrapped(x)
        return is_fake(unwrapped)
    return False

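# Illustrative: `is_fake` distinguishes real tensors from fake ones (sketch,
# not executed at import time):
#
#   assert not is_fake(torch.ones(2))
#   with FakeTensorMode() as mode:
#       assert is_fake(mode.from_tensor(torch.ones(2)))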

def maybe_get_fake_mode(t: object) -> Optional[FakeTensorMode]:
    if isinstance(t, FakeTensor):
        return t.fake_mode
    if is_traceable_wrapper_subclass(t):
        inner_tensor_names, _ = t.__tensor_flatten__()
        modes = [
            maybe_get_fake_mode(getattr(t, t_name)) for t_name in inner_tensor_names
        ]
        m = modes[0]
        assert all(m is x for x in modes)
        return m
    elif isinstance(t, Tensor) and torch._is_functional_tensor(t):
        reapply_views = torch._C._functionalization_reapply_views_tls()
        unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views)
        return maybe_get_fake_mode(unwrapped)
    elif isinstance(t, Tensor) and is_functorch_wrapped_tensor(t):
        unwrapped = torch._C._functorch.get_unwrapped(t)
        return maybe_get_fake_mode(unwrapped)
    return None


@functools.lru_cache(None)
def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo:
    return torch._C._SchemaInfo(func._schema)


# Many of the decompositions registered to torch/_prims do not at the moment
# model aliasing or strides, so as an incremental step, just enable the
# decompositions in torch/_decomp/decompositions.py.
# Decomps are used for AOT autograd tracing, so we would like to unify on
# their implementation and add additional testing to them.
@functools.lru_cache(None)
def torch_decomp_decompositions(func: OpOverload) -> bool:
    from torch._decomp import decomposition_table

    decompositions = torch._decomp.decompositions
    # Note that the function in the decomposition table might be
    # different from the one in the module because of the difference
    # in out handling in aten API and torch public API
    return decomposition_table[func].__module__.startswith(
        "torch._decomp"
    ) and decomposition_table[func].__name__ in dir(decompositions)


def tree_flatten_only(ty: Type[T], tree: PyTree) -> List[T]:
    flat_vals = pytree.tree_leaves(tree)
    return [elem for elem in flat_vals if isinstance(elem, ty)]


def _is_plain_tensor(t: object) -> bool:
    return (
        type(t) is Tensor
        and t.layout == torch.strided
        and not (
            t.is_sparse
            or t.is_nested
            or is_functorch_wrapped_tensor(t)
            or is_legacy_batchedtensor(t)
            or torch._is_functional_tensor(t)
        )
    )


# Similar to `MetaConverter`, this is a class for converting
# multiple tensors into fake tensors which share the same view/storage
# structure. Like `MetaConverter`, it uses `WeakIdRef` to
# hold a weak reference for all memoized tensors.
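#
# For example (illustrative; `mode` is a hypothetical FakeTensorMode), two
# views of one real tensor convert to fake tensors with aliased storage:
#
#   base = torch.ones(4)
#   converter = FakeTensorConverter()
#   f1 = converter.from_real_tensor(mode, base)
#   f2 = converter.from_real_tensor(mode, base[:2])
#   # f1 and f2 mirror the view/storage relationship of base and base[:2]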
class FakeTensorConverter:
    @property
    def tensor_memo(
        self,
    ) -> weakref.WeakValueDictionary:
        # not valid until py3.10
        # weakref.WeakValueDictionary["torch._subclasses.meta_utils.MetaTensorId", Optional["FakeTensor"]]
        return self.meta_converter.tensor_memo

    meta_converter: MetaConverter
    constant_storage_mapping: Dict[StorageWeakRef, List[ReferenceType]]
    export: bool

    def __init__(self, *, copy_data: bool = False, export: bool = False) -> None:
        self.meta_converter = MetaConverter(copy_data=copy_data)
        self.export = export

        # map from storage to corresponding constant tensors
        self.constant_storage_mapping = {}

    def add_constant_storage_mapping(self, fake_tensor: FakeTensor) -> None:
        # when you have a constant, aliased tensor:
        # const_tensor.add_(torch.rand([1]))
        # all aliases of it must become no longer const
        assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None
        weak_st = StorageWeakRef(fake_tensor.constant._typed_storage())

        # we need a map from a weak storage to all of its corresponding
        # constant tensors. Python doesn't have a weak-value equivalent of
        # defaultdict(list), so we keep a plain dict from StorageWeakRef to
        # a list of weak references instead
        if weak_st not in self.constant_storage_mapping:
            self.constant_storage_mapping[weak_st] = []
        self.constant_storage_mapping[weak_st].append(weakref.ref(fake_tensor))

    def invalidate_constant_aliases(self, tensor: Tensor) -> None:
        assert not isinstance(tensor, FakeTensor)

        weak_st = StorageWeakRef(tensor._typed_storage())
        if weak_st not in self.constant_storage_mapping:
            return

        for weak_tensor_ref in self.constant_storage_mapping[weak_st]:
            ten = weak_tensor_ref()
            if ten is not None:
                ten._fix_weakref()
                ten.constant = None

        del self.constant_storage_mapping[weak_st]

    def _get_memo(self, t: Tensor) -> Optional[FakeTensor]:
        tid = self.meta_converter.describer.lookup_tensor.get(t)
        if tid is None:
            return None
        return self.tensor_memo.get(tid)

    def set_tensor_memo(self, t: Tensor, v: FakeTensor) -> None:
        tid = self.meta_converter.describer.get_tensor_id(t)
        self.meta_converter.tensor_memo[tid] = v

    # You can have a real tensor that you need to convert into a fake tensor.
    # If you have a meta tensor already, call from_meta_and_device.
    #
    # You're allowed to pass a meta tensor to be turned into a fake
    # tensor; although an odd thing to do, this can occur if you're doing
    # cross ref testing and the inner test is already operating on meta tensors.
    def from_real_tensor(
        self,
        fake_mode: FakeTensorMode,
        t: Tensor,
        make_constant: bool = False,
        shape_env: Optional[ShapeEnv] = None,
        *,
        source: Optional[Source] = None,
        symbolic_context: Optional[SymbolicContext] = None,
        trace: bool = True,
    ) -> FakeTensor:
        # see note [Tensor Fakification and Symbol Caching]
        if not symbolic_context and not source and shape_env:
            if tracing_context := torch._guards.TracingContext.try_get():
                if t in tracing_context.tensor_to_context:
                    symbolic_context = tracing_context.tensor_to_context[t]
                    from torch.fx.experimental.symbolic_shapes import (
                        StatefulSymbolicContext,
                    )

                    assert isinstance(symbolic_context, StatefulSymbolicContext)
                    source = symbolic_context.tensor_source

        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
        existing_device = t.device
        # not yet supported in metatensors
        if t.is_quantized:
            raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
        if type(t) is torch.nn.Parameter:
            assert not make_constant

        def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
            # NB: don't use in_kernel_invocation_manager, to ensure
            # FakeTensor can internally do constant computation as
            # necessary.  The invocation manager is "more correct", as
            # it works for more operators in make_meta_t, but the
            # invariant is that make_meta_t only calls factories
            # for which it is not strictly necessary to use the
            # invocation manager (I think!)
            with no_dispatch():
                return FakeTensor(
                    fake_mode,
                    make_meta_t(),
                    existing_device,
                    # TODO: callback might be used in recursive contexts, in
                    # which case using t is wrong!  BUG!
                    constant=t if make_constant else None,
                )

        out = self.meta_converter(
            t,
            shape_env=shape_env,
            callback=mk_fake_tensor,
            source=source,
            symbolic_context=symbolic_context,
            trace=trace,
        )
        if out is NotImplemented:
            raise UnsupportedFakeTensorException("meta converter nyi")

        from torch._dynamo.source import RandomValueSource

        value = None
        if (
            not self.export
            and _is_plain_tensor(t)  # mostly, we want to know if item() works
            and t.dim() == 0
            and t.device.type == "cpu"
            # All integer types are fair game, because signed overflow is UB
            # (and even int64 can overflow, since integers in Python are
            # arbitrary precision). But only float64 is OK for float, because
            # switching between float32 and float64 changes semantics in an
            # observable way without hitting UB.
            and t.dtype
            in [torch.int64, torch.int32, torch.int16, torch.int8, torch.float64]
            and source is not None
            # Impede setting up item() on things coming from random.  These
            # are not "real" item() calls, instead UnspecializedPythonVariable
            # is unsafely pretending an int is a tensor, which can sometimes
            # implicitly cause an item call.  The problem is this is pretty
            # unsound: there's no reason substituting an int with a Tensor is
            # going to give the same results.  Today, you mostly get around
            # this by typically not having capture_scalar_outputs on and graph
            # breaking when someone tries to use the unspec variable in an
            # int-y context.  But allowing it through here would break that.
            # So don't.
            #
            # Once random values are setup to be represented as
            # SymNodeVariable, this condition can be removed.  To check if
            # you've done it right, this is a good test:
            #
            #   PYTORCH_TEST_WITH_DYNAMO=1 python test/test_reductions.py -k
            #   TestReductionsCPU.test_dim_reduction_fns_fn_name_amax_cpu_bfloat16
            and not isinstance(source, RandomValueSource)
            # In Dynamo, shape_env is never none (even with static shapes).
            # However, FakeTensorMode can be used by hand and in some cases
            # ShapeEnv is not allocated.
            and shape_env is not None
        ):
            from torch._dynamo.source import CallMethodItemSource, FloatTensorSource
            from torch.fx.experimental.symbolic_shapes import DimDynamic

            with no_dispatch():
                value = t.item()
            if not math.isnan(value):
                # Peephole strip out unnecessary torch.as_tensor(x).item()
                if isinstance(source, FloatTensorSource):
                    item_source = source.base
                else:
                    item_source = CallMethodItemSource(source)
                symbol = shape_env.create_unspecified_symbol(
                    value,
                    source=item_source,
                    dynamic_dim=DimDynamic.DYNAMIC,
                )
                # NB: reusing item_memo here ensures that we invalidate on
                # mutation
                if t.dtype == torch.int64:
                    out.item_memo = shape_env.create_symintnode(
                        symbol,
                        hint=value,
                        source=item_source,
                    )
                elif t.dtype == torch.float64:
                    out.item_memo = shape_env.create_symfloatnode(
                        symbol,
                        hint=value,
                        source=item_source,
                    )
        if make_constant:
            self.add_constant_storage_mapping(out)
        # NB: the meta_converter sets the memo
        return out

    # If you specify the device, it MUST be a meta tensor.
    def from_meta_and_device(
        self, fake_mode: FakeTensorMode, t: Tensor, device: torch.device
    ) -> FakeTensor:
        assert (
            t.device.type == "meta"
        ), f"tensor's device must be `meta`, got {t.device.type} instead"
        # This is a bit abusive (this is not the "real" tensor) but whatever,
        # the meta tensor should be fresh so there's no way to get it wrong
        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
        out = FakeTensor(fake_mode, t, device)
        self.set_tensor_memo(t, out)
        return out


@functools.lru_cache(None)
def init_gpu_context() -> None:
    # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first
    if torch.cuda.is_available():
        (
            torch.empty(1, device="cuda")
            if torch.version.hip is None
            else torch.zeros(1, device="cuda")
        )

    if torch.xpu.is_available():
        (torch.empty(1, device="xpu"))


@contextlib.contextmanager
def in_kernel_invocation_manager(
    fake_mode: FakeTensorMode,
) -> Generator[None, None, None]:
    # See: note [Fake Tensor Dispatch Keys]
    prev_in_kernel = fake_mode.in_kernel_invocation
    meta_in_tls = torch._C._meta_in_tls_dispatch_include()
    assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"

    with torch._C._DisableTorchDispatch():
        fake_mode.in_kernel_invocation = True
        # Unfortunately _set_meta_in_tls_dispatch_include(False) can leave
        # `Dense` turned on (because it's implied by `Meta`)
        with torch._C._PreserveDispatchKeyGuard():
            torch._C._set_meta_in_tls_dispatch_include(True)
            try:
                yield
            finally:
                fake_mode.in_kernel_invocation = prev_in_kernel
                # torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)

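# While the manager is active, a FakeTensor reports device "meta" instead of
# its fake device (illustrative sketch; `fake_mode` and `fake_x` are assumed
# to exist rather than defined here):
#
#   with in_kernel_invocation_manager(fake_mode):
#       assert fake_x.device.type == "meta"
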
# Return whether the function allows Python numbers to bind to Tensors
def should_allow_numbers_as_tensors(func: OpOverload) -> bool:
    return torch._C._should_allow_numbers_as_tensors(
        func.name().split("::")[-1].split(".")[0]
    )


class FakeTensorConfig:
    debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", "0") == "1"


# This memoizes the unbacked SymInt representing quantities like the number
# of nonzero elements in this tensor.  There is one instance of the descriptor
# per particular quantity to memoize.
#
# Memoization is helpful if you do something like x[mask] and y[mask];
# mask.nonzero() gets repeatedly called and should give a consistent unbacked
# SymInt.  It needs to be invalidated in the same way constant is.
#
# Making this a descriptor may seem overly fancy, but actually it's the most
# convenient way to make sure we have access to the FakeTensor during access,
# which is required for checking version counter and epoch validity.
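#
# For example (illustrative), repeated data-dependent queries on the same
# fake tensor reuse one unbacked SymInt until the tensor is mutated:
#
#   u0 = fake_x.nonzero().shape[0]  # allocates an unbacked SymInt, memoized
#   u1 = fake_x.nonzero().shape[0]  # memo hit: same SymInt as u0
#   fake_x.add_(1)                  # bumps _version, invalidating the memo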
class SymIntMemoDescriptor:
    _name: str

    # By default, SymInts in this memo are invalidated across versions/epochs.
    # nested_ints however are preserved across epochs and across versions.
    # Preserving across versions is okay for nested int since the association
    # of a nested int is agnostic to the underlying data and nested ints are not
    # shared across multiple distinct tensors.
    _is_nested_int: bool

    def __init__(self, *, is_nested_int: bool = False) -> None:
        self._is_nested_int = is_nested_int

    def __set_name__(self, owner: Type[FakeTensor], name: str) -> None:
        self._name = name

    def _memo(self, obj: FakeTensor) -> str:
        return f"_{self._name}"

    def _memo_vc(self, obj: FakeTensor) -> str:
        return f"_{self._name}_vc"

    # When we retrace, we need to invalidate all the memos so that we can
    # accurately identify the first time unbacked SymInts are allocated.
    # This is only relevant for inputs; for intermediates, they will get fresh
    # fake tensors so you won't have a memo anyway
    def _memo_epoch(self, obj: FakeTensor) -> str:
        return f"_{self._name}_epoch"

    def __get__(
        self, obj: FakeTensor, objtype: Optional[Type[FakeTensor]] = None
    ) -> Optional[torch.SymInt]:
        if (r := getattr(obj, self._memo(obj))) is None:
            return None
        # Version counter based tracking isn't 100% sound but it's close
        # enough
        if (
            not self._is_nested_int and getattr(obj, self._memo_vc(obj)) != obj._version
        ) or (
            not self._is_nested_int
            and getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch
        ):
            setattr(obj, self._memo(obj), None)
            return None
        return r

    def __set__(self, obj: FakeTensor, value: Optional[torch.SymInt]) -> None:
        if value is None:
            setattr(obj, self._memo(obj), None)
            setattr(obj, self._memo_vc(obj), None)
            setattr(obj, self._memo_epoch(obj), None)
        elif not obj.is_inference() or self._is_nested_int:
            setattr(obj, self._memo(obj), value)
            if not self._is_nested_int:
                setattr(obj, self._memo_vc(obj), obj._version)
            setattr(obj, self._memo_epoch(obj), obj.fake_mode.epoch)


class FakeTensor(Tensor):
    """
    Meta tensors give you the ability to run PyTorch code without having to
    actually do computation through tensors allocated on a `meta` device.
    Because the device is `meta`, meta tensors do not model device propagation.
    FakeTensor extends MetaTensors to also carry an additional `fake_device`
    which tracks devices that would have been used.
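
    A minimal usage sketch (illustrative; CUDA need not be present, since the
    fake tensor only records the device)::

        with FakeTensorMode():
            x = torch.empty(2, 2, device="cuda")  # no real CUDA allocation
            assert x.device.type == "cuda"        # the fake device is reported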
    """

    fake_device: torch.device
    fake_mode: FakeTensorMode
    constant: Optional[Tensor]
    real_tensor: Optional[Tensor]

    # TODO: Generalize this as needed, e.g., into a trie of memos, if
    # you do something like x[0].item()  (x[0] is fresh each time, so
    # memo mechanism here won't work)
    nonzero_memo = SymIntMemoDescriptor()
    item_memo = SymIntMemoDescriptor()
    unique_memo = SymIntMemoDescriptor()

    # We expect nested_int_memo to be None when an offsets is a graph
    # intermediate, or an input that has never been associated with a
    # nested int.
    nested_int_memo = SymIntMemoDescriptor(is_nested_int=True)

    # Indicates to our torch_dispatch dispatching infra that
    # this is an "infra" mode with lower dispatching precedence.
    _mode_key = torch._C._TorchDispatchModeKey.FAKE

    @property
    def device(self) -> torch.device:
        if self.fake_mode.in_kernel_invocation:
            return torch.device("meta")
        else:
            return self.fake_device

    @device.setter
    def device(self, _: torch.device) -> None:
        raise NotImplementedError

    # Note: [Fake Tensor Dispatch Keys]
    # In order to model the behavior of device-specific autocast
    # and autograd logic, we update the dispatch keys of FakeTensors
    # to reflect their fake device. This includes the BackendComponent
    # (DispatchKey::Meta -> DispatchKey::CUDA), and also the BackendComponent
    # related Autocast and Autograd keys. __torch_dispatch__ sits below
    # Autocast and Autograd, and is only invoked when we are at the
    # kernel for the BackendComponent. Then, we add Meta to the
    # thread-local dispatch include set to hit the meta kernel
    # instead of the kernel of the BackendComponent for the fake device.
    # The `device_for_backend_keys` does that below
    # NOTE: this probably will not do the right thing for backends
    # that have dispatch keys which are higher than the "meta" key:
    # https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L189

    # We don't support named tensors; graph break
    @property
    def names(self) -> List[str]:
        raise UnsupportedFakeTensorException(
            "torch.compile doesn't support named tensors"
        )

    @names.setter
    def names(self, _: List[str]) -> None:
        raise NotImplementedError

    @staticmethod
    def __new__(
        cls,
        fake_mode: FakeTensorMode,
        elem: Tensor,
        device: torch.device,
        constant: Optional[Tensor] = None,
        real_tensor: Optional[Tensor] = None,
    ) -> Self:
        self = Tensor._make_subclass(
            cls,
            elem,
            elem.requires_grad,
            dispatch_device=True,
            device_for_backend_keys=device,
        )
        if not fake_mode._allow_unsafe_data_ptr_access:
            torch._C._set_throw_on_mutable_data_ptr(self)
        else:
            torch._C._set_warn_deprecated_on_mutable_data_ptr(self)

        assert elem.device.type == "meta", elem.device.type
        device = device if isinstance(device, torch.device) else torch.device(device)
        # NB: it is fine, if a little confusing, for device to be meta
        # (we are faking a meta tensor in that case).  However, it often
        # indicates some sort of confusion (e.g., you accidentally passed
        # in a meta tensor when you should have passed in the real tensor).
        # So by default we disallow meta, and if you are working in a situation
        # where it is helpful (e.g., crossref testing) you can turn it back
        # on
        if not fake_mode.allow_meta:
            assert device.type != "meta"
        # normalize device.
        if device.type in ["cuda", "xpu"]:
            init_gpu_context()

        if (
            device.type
            in ["cuda", "hpu", "xpu", torch._C._get_privateuse1_backend_name()]
            and device.index is None
        ):
            if getattr(torch, device.type).is_initialized():
                device = torch.device(
                    f"{device.type}:{getattr(torch, device.type).current_device()}"
                )
            else:
                device = torch.device(f"{device.type}:0")
        self.fake_device = device
        self.fake_mode = fake_mode
        self.constant = constant
        assert not isinstance(real_tensor, FakeTensor)
        self.real_tensor = real_tensor
        self.nonzero_memo = None
        self.item_memo = None
        self.unique_memo = None
        self.nested_int_memo = None

        if FakeTensorConfig.debug:
            self._debug_trace = CapturedTraceback.extract()  # type: ignore[attr-defined]
        return self

    # In some circumstances, a conventional Tensor constructor
    # will get rewritten to call into FakeTensor.  We must provide an
    # __init__ method that can accept the Python interpreter's initialization
    # in such a situation; we must also be able to handle direct fake
    # tensor construction via FakeTensor().
    #
    # In particular, the __init__ call will look funny in the following case:
    #
    #   with FakeTensorMode():
    #       x = Tensor([1, 2, 3])
    #
    # this desugars into:
    #
    #   with FakeTensorMode():
    #       x = Tensor.__new__([1, 2, 3])
    #       # NB: x is a fake tensor, because of the mode!
    #       x.__init__([1, 2, 3])  # not the normal fake tensor args!
    #
    def __init__(self, *args: object, **kwargs: object) -> None:
        super().__init__()

    @staticmethod
    def from_tensor(t: Tensor, fake_mode: FakeTensorMode) -> FakeTensor:
        return fake_mode.from_tensor(t)

    @classmethod
    @count
    def __torch_dispatch__(
        cls,
        func: OpOverload,
        types: Sequence[Type],
        args: Sequence[object] = (),
        kwargs: Mapping[str, object] = immutable_dict(),
    ) -> object:
        # need to handle here to avoid infinite recursion
        # see [in_kernel_invocation]
        if func == torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                return torch.device("meta")
            else:
                return args[0].fake_device

        # this handler must be done inside FakeTensor subclass, not mode, because
        # we can end up dispatching here when we have a fake tensor with
        # symbolic sizes running under in_kernel_invocation_manager.
        # The subclass is asked to handle this query because size (not
        # sym_size) was called, but we are unable to serve it directly because
        # there are symbolic sizes in the class.  The use of
        # in_kernel_invocation_manager means it's incorrect to activate a
        # mode to actually handle this (this caused
        # https://github.com/pytorch/pytorch/issues/122772).
        if handler := _DISPATCH_META_HANDLERS.get(func):
            return handler(args)

        # Because fake mode can return NotImplemented (if it sees a subclass
        # it doesn't know how to deal with), this test here is important
        # because the next dispatch after a fake mode will attempt to use
        # subclasses of tensors to dispatch, and any FakeTensor arguments
        # will be considered eligible.
        unrecognized_types = [
            t for t in types if not issubclass(t, FakeTensor) and t is not Tensor
        ]
        if unrecognized_types:
            not_implemented_log.debug(
                "FakeTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        fake_mode = None
        for arg in pytree.arg_tree_leaves(*args, **kwargs):
            if isinstance(arg, FakeTensor):
                fake_mode = arg.fake_mode
                break

        assert fake_mode is not None

        # If the fake mode is already active, don't try to reapply it!
        # NotImplemented is the right thing to return here, because the
        # typical situation this can occur is if ProxyTensorMode returned a
        # NotImplemented because of a not implemented subclass; we may have
        # unluckily attempted to hit FakeTensor's dispatch first,
        # NotImplemented lets us keep chaining until we find the actual
        # subclass
        maybe_cur_fake_mode = torch._C._get_dispatch_mode(
            torch._C._TorchDispatchModeKey.FAKE
        )
        if maybe_cur_fake_mode:
            not_implemented_log.debug(
                "FakeTensor mode already active: %s in %s",
                fake_mode,
                maybe_cur_fake_mode,
            )
            return NotImplemented

        assert not fake_mode.in_kernel_invocation

        with fake_mode:
            return func(*args, **kwargs)

    @staticmethod
    def _find_common_device(
        func: OpOverload, flat_args: Sequence[object]
    ) -> Tuple[torch.device, bool]:
        # Returns: (common_device, has_scalar_only_inputs)

        # cpu - zero-dim tensors can be called in cuda kernels, so overwrite
        # the common_device if the only existing device comes from a cpu
        # zero-dim tensor
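        # e.g. (illustrative) torch.add(x_cuda, torch.tensor(2.0)) resolves
        # to device "cuda" even though the scalar tensor lives on cpu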
        common_device = None
        has_scalar_only_inputs = False
        is_cpu_zero_dim = None

        def cpu_zero_dim(t: Tensor) -> bool:
            return t.device.type == "cpu" and t.dim() == 0

        def merge_devices(t: object) -> None:
            nonlocal common_device
            nonlocal is_cpu_zero_dim
            if not isinstance(t, FakeTensor):
                return

            if common_device is None:
                common_device = t.device
                is_cpu_zero_dim = cpu_zero_dim(t)
                return

            t_is_cpu_zero_dim = cpu_zero_dim(t)
            if t.device == common_device:
                if is_cpu_zero_dim:
                    is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # mismatching devices!
            # if current tensor is cpu 0 dim, defer to existing device
            if t_is_cpu_zero_dim:
                return

            # current device is from cpu 0 dim tensor, overwrite
            if is_cpu_zero_dim:
                common_device = t.device
                is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # mismatching devices of non-zero dim tensors, throw
            # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
            raise RuntimeError(
                f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
            )

        for arg in flat_args:
            merge_devices(arg)

        # Some functions allow Python numbers to bind to Tensors.
        # If we have failed to find a device and we're running one of these
        # operators, we must have scalar-only inputs.
        if should_allow_numbers_as_tensors(func) and common_device is None:
            # ops with scalar only inputs always have result on cpu
            has_scalar_only_inputs = True
            common_device = torch.device("cpu")

        assert common_device is not None, f"Could not find common device for {func}"

        return common_device, has_scalar_only_inputs

    def get_nested_int(
        self,
        *,
        coeff: Union[int, torch.SymInt] = 1,
    ) -> torch.SymInt:
        if self.nested_int_memo is None:
            self.nested_int_memo = self.fake_mode.create_symbolic_nested_int(
                nt_tensor_id=None
            )
        return self.nested_int_memo * coeff

    # Similar to FunctionalTensor.tolist
    def tolist(self) -> Any:
        if self.dim() == 0:
            return self.item()
        elif self.dim() == 1:
            return [elem.item() for elem in self]
        else:
            return [elem.tolist() for elem in self]


_MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"]


@dataclass_slots
@dataclass
class TensorMetadata:
    """
    The Tensor metadata relevant to hashing FakeTensors when caching.
    """

    dtype: torch.dtype
    shape: Tuple[_MetadataIntLike, ...]
    stride: Tuple[_MetadataIntLike, ...]
    device: torch.device
    layout: torch.layout
    memory_format: Optional[torch.memory_format]
    storage_offset: _MetadataIntLike
    storage_bytes: Optional[_MetadataIntLike]
    requires_grad: bool
    is_quantized: bool
    is_conj: bool
    is_neg: bool
    is_inference: bool
    is_sparse: bool  # read: is sparse COO
    is_coalesced: Optional[bool]
    dense_dim: Optional[int]
    sparse_dim: Optional[int]

    def _flatten_into(
        self,
        result: List[object],
        mode: FakeTensorMode,
        state: _CacheKeyState,
    ) -> None:
        # Flatten the TensorMetadata out into `result`.  Make sure to call
        # state.convert_sym_int() on any SymInts.
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if isinstance(value, (tuple, list, torch.Size)):
                # This will recursively flatten the iterable, calling
                # convert_sym_int() as necessary.
                mode._prep_args_for_hash(result, value, state)
            elif isinstance(value, SymInt):
                state.convert_sym_int(result, value)
            else:
                result.append(value)


def extract_tensor_metadata(t: Tensor) -> TensorMetadata:
    """
    Extract the TensorMetadata of a tensor.
    """
    memory_format: Optional[torch.memory_format] = suggest_memory_format(t)
    # Don't call is_contiguous() on a Tensor which has symbolic sizes or things
    # will go badly (guards will be messed up?)
    if (
        t._has_symbolic_sizes_strides
        or is_sparse_any(t)
        or not t.is_contiguous(memory_format=memory_format)
    ):
        memory_format = None

    storage_offset = t.storage_offset()

    return TensorMetadata(
        t.dtype,
        t.shape,
        t.stride() if t.layout == torch.strided else (),
        t.device,
        t.layout,
        memory_format,
        storage_offset,
        # Only set storage_bytes for tensors that have storage (not sparse)
        t.untyped_storage().nbytes() if not is_sparse_any(t) else None,
        t.requires_grad,
        t.is_quantized,
        t.is_conj(),
        t.is_neg(),
        t.is_inference(),
        t.is_sparse,
        t.is_coalesced() if t.is_sparse else None,
        t.dense_dim() if is_sparse_any(t) else None,
        t.sparse_dim() if is_sparse_any(t) else None,
    )

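# Illustrative sketch of the extracted metadata for a plain strided CPU
# tensor (values shown as comments, not executed here):
#
#   meta = extract_tensor_metadata(torch.ones(2, 3))
#   # meta.dtype == torch.float32, meta.shape == (2, 3), meta.stride == (3, 1)
#   # meta.memory_format == torch.contiguous_format, meta.is_sparse is False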

@dataclass_slots
@dataclass
class _DispatchCacheKey:
    """
    Key for the FakeTensor dispatch cache.
    """

    key: Tuple[object, ...]
    hashvalue: int

    def __init__(self, tup: Tuple[object, ...]) -> None:
        self.key = tup
        self.hashvalue = hash(tup)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, _DispatchCacheKey) and self.key == other.key

    def __hash__(self) -> int:
        return self.hashvalue

    def strip_shape_env(self) -> None:
        # We need to strip the ShapeEnv from any values before we store in the
        # cache so the cache doesn't keep our ShapeEnvs alive.
        for v in self.key:
            if isinstance(v, _PySymInputStub):
                v.strip_shape_env()


@dataclass_slots
@dataclass(frozen=True)
class _DispatchCacheEntry:
    """
    Entry type for the FakeTensor dispatch cache. Accounts for two possibilities:
    1) The op is inplace, and a hit means we need to alias the argument at a
       given index.
    2) We need to synthesize a new FakeTensor given tensor metadata. For view
       ops, we further capture the index of the arg to alias.
    """

    inplace_idx: Optional[int]
    metadata: Optional[TensorMetadata]
    view_idx: Optional[int]


@dataclass_slots
@dataclass(frozen=True)
class _BypassDispatchCache(Exception):
    """
    Signals cases that should skip FakeTensor caching.
    """

    reason: str


@dataclass_slots
@dataclass(frozen=True)
class DispatchCacheInfo:
    """
    Information about the state of the FakeTensor dispatch cache.
    """

    hits: int
    misses: int
    bypasses: Dict[str, int]
    size: int


# We keep one instantiation of `fake_tensor_converter` active
# for the duration of `with FakeTensorMode()`.
# This allows accurate storage aliasing across invocations of
# different operators. While this will keep all freshly allocated
# tensors alive during `FakeTensorMode`, there will be no
# new allocations of Tensors which have non-meta storage, so
# memory should not significantly increase.

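# A hedged end-to-end sketch of FakeTensorMode (illustrative; shapes and
# dtypes propagate with no real computation):
#
#   with FakeTensorMode():
#       a = torch.randn(8, 8)   # a FakeTensor on "cpu"
#       b = a @ a               # meta kernels compute the output shape
#   assert isinstance(b, FakeTensor) and b.shape == (8, 8)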

class FakeTensorMode(TorchDispatchMode):
    cache: Dict[_DispatchCacheKey, _DispatchCacheEntry] = {}
    cache_hits: int = 0
    cache_misses: int = 0
    cache_bypasses: Dict[str, int] = defaultdict(int)
    # Every time you retrace using the same fake tensor mode, you should
    # advance the epoch so we don't reuse unbacked memos
    epoch: int = 0
    in_kernel_invocation: bool = False
    static_shapes: bool
    shape_env: Optional[ShapeEnv]
    _stack: Optional[str]
    allow_meta: bool

    # NestedTensor uses a tensor_id_counter to uniquely identify offsets.
    # This counter is incremented when an offsets is used to create an NJT
    # for the first time. To avoid mutating eager state if we construct NJT
    # during tracing, we maintain a separate counter on the FakeTensorMode.
    # The initial count is set to the current eager tensor_id_counter value
    # upon initialization, and every time you retrace using the same fake tensor
    # mode, you should reset the counter to the initial count.
    nt_tensor_id_counter: int = -1
    nt_tensor_id_initial_count: int = -1

    def __init__(
        self,
        *,
        allow_fallback_kernels: bool = True,
        allow_non_fake_inputs: bool = False,
        shape_env: Optional[ShapeEnv] = None,
        static_shapes: Optional[bool] = None,
        # TODO: This is a temporary measure, see
        # https://github.com/pytorch/pytorch/pull/126245#discussion_r1604185748
        # We're currently solely using this to impede population of
        # item_memo for 0d scalar tensor inputs during export, because this
        # causes things that used to be deferred runtime asserts to turn into
        # guards, and then the guards are just lost.  We can potentially fix
        # this by ensuring guards also get put in the graph, but this is
        # pending a rework of how deferred runtime asserts work in export.
        # Once that's done, we can remove this.
        export: bool = False,
    ) -> None:
        log.debug("create_mode 0x%x", id(self))
        super().__init__()
        self.allow_fallback_kernels = allow_fallback_kernels

        import torch._dynamo.config
        import torch._functorch.config

        self.propagate_real_tensors = (
            torch._functorch.config.fake_tensor_propagate_real_tensors
        )
        self.fake_tensor_converter = FakeTensorConverter(
            copy_data=self.propagate_real_tensors,
            export=export,
        )

        if static_shapes is not None:
            self.static_shapes = static_shapes
        else:
            self.static_shapes = shape_env is None

        # This is temporarily patched to True in Dynamo to grandfather in some
        # places where we unconditionally allow scalar outputs, TO BE REMOVED
        self.allow_scalar_outputs = False

        self._allow_unsafe_data_ptr_access = (
            torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access
        )
        self.allow_meta = torch._functorch.config.fake_tensor_allow_meta
        self.cache_enabled = (
            torch._dynamo.config.fake_tensor_cache_enabled
            and not self.propagate_real_tensors
        )
        self.cache_crosscheck_enabled = (
            torch._dynamo.config.fake_tensor_cache_crosscheck_enabled
        )

        # A flag that controls whether we want to invoke ops on a mix of
        # real weights/global variables and fake inputs
        self.allow_non_fake_inputs = allow_non_fake_inputs

        # [in_kernel_invocation]
        # when FakeTensor is invoked in user code, .device should return
        # the fake_device of the tensor so that code such as `if x.is_cuda`
        # or torch.zeros([10, 10], device=x.device) continues to execute as if
        # the FakeTensor were real. However, within kernel execution, we return
        # the `Meta` device because all computation within the kernels should
        # behave as if the Tensors are on meta devices. Kernels should allocate
        # new tensors on meta devices, and checks like `is_meta` should return true.
        # within python refs, we always return the real device by defining
        # the device property
        self.in_kernel_invocation = False

        # True if we enter'ed and actually enabled fake tensor mode,
        # false if it was a no-op.  Not thread safe but neither is
        # in_kernel_invocation
        # If another fake mode was already active when we enter, we also stash it here.
        # That way when we exit, we know to re-enable the previous fake mode.
        self.enter_stack: List[
            Tuple[bool, Optional[TorchDispatchMode], Optional[bool]]
        ] = []

        self.shape_env = shape_env

        self._stack_trace = traceback.extract_stack()
        self._stack = None

        # Indicates to our torch_dispatch dispatching infra that
        # this is an "infra" mode with lower dispatching precedence.
        self._mode_key = torch._C._TorchDispatchModeKey.FAKE

        import torch.nested._internal.nested_tensor

        self.nt_tensor_id_initial_count = (
            torch.nested._internal.nested_tensor._tensor_id_counter
        )
        self.nt_tensor_id_counter = self.nt_tensor_id_initial_count

    def reset_nt_tensor_id_counter(self) -> None:
        self.nt_tensor_id_counter = self.nt_tensor_id_initial_count

    # Typically, there is only one fake tensor mode and you test for it by
    # doing an isinstance test.  However, in some situations, there might be
    # TWO fake tensor modes.  The canonical example of this is exporting
    # a fake model: there is an outer fake mode created by the user, and
    # an inner fake mode created by Dynamo.  The two phase process is required
    # because the outer fake mode typically won't have a ShapeEnv, even if
    # the user is interested in exporting with dynamic shapes (so the inner
    # fake mode will actually have a ShapeEnv and swap in symbolic sizes.)
    #
    # In this case, it's insufficient to test only one FakeTensor: you need
    # to distinguish between our fake tensor and other fake tensors.  That's
    # what this function does.
    def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]:
        return isinstance(t, FakeTensor) and t.fake_mode is self

    # If we should avoid device init. This changes the behavior of various APIs:
    # - We avoid constant-prop on Tensors with ops that move them to another device
    # - We change the torch.tensor ctor contract to never materialize
    #   tensors on device
    #   (see NOTE: [torch.tensor, lift_fresh, and device movement])
    @property
    def avoid_device_init(self) -> bool:
        if torch.xpu._is_compiled():
            assert not torch.cuda._is_compiled()
            return not torch.xpu.is_available()

        return not torch.cuda.is_available()

    @property
    def stack(self) -> str:
        if self._stack is None:
            self._stack = "".join(traceback.format_list(self._stack_trace))
        return self._stack

    @count
    def __torch_dispatch__(
        self,
        func: OpOverload,
        types: Sequence[Type],
        args: Sequence[object] = (),
        kwargs: Mapping[str, object] = immutable_dict(),
    ) -> object:
        # FakeTensorMode should not be set when we're inside of it.
        assert (
            torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None
        ), func
        try:
            return self.dispatch(func, types, args, kwargs)
        except TypeError:
            log.exception("fake tensor raised TypeError")
            raise

    # No-op if FakeTensorMode is already in use
    def __enter__(self) -> Self:
        import torch.nested._internal.nested_tensor

        prev_only_lift_cpu_tensors = None
        if self.avoid_device_init:
            # See NOTE: [torch.tensor, lift_fresh, and device movement]
            prev_only_lift_cpu_tensors = torch._C._only_lift_cpu_tensors()
            torch._C._set_only_lift_cpu_tensors(True)
        maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key)
        if self is not maybe_prev_fake_mode:
            self.enter_stack.append(
                (True, maybe_prev_fake_mode, prev_only_lift_cpu_tensors)
            )
            return super().__enter__()
        else:
            # no-op (still need to re-set the fake mode though since we unset it)
            torch._C._set_dispatch_mode(self)
            self.enter_stack.append((False, None, prev_only_lift_cpu_tensors))
        return self

    def __exit__(
        self,
        a: Optional[Type[BaseException]],
        b: Optional[BaseException],
        c: Optional[TracebackType],
    ) -> None:
        (
            live,
            maybe_prev_fake_mode,
            maybe_prev_only_lift_cpu_tensors,
        ) = self.enter_stack.pop()
        if live:
            out = super().__exit__(a, b, c)
            # Re-enable the previous fake mode, if there was one.
            if maybe_prev_fake_mode is not None:
                torch._C._set_dispatch_mode(maybe_prev_fake_mode)
            if maybe_prev_only_lift_cpu_tensors is not None:
                torch._C._set_only_lift_cpu_tensors(maybe_prev_only_lift_cpu_tensors)

    @classmethod
    def is_infra_mode(cls) -> bool:
        return True

    @classmethod
    def cache_info(cls) -> DispatchCacheInfo:
        """
        Query the state of the dispatch cache.
        """
        return DispatchCacheInfo(
            FakeTensorMode.cache_hits,
            FakeTensorMode.cache_misses,
            dict(FakeTensorMode.cache_bypasses),
            len(FakeTensorMode.cache),
        )

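    # Illustrative usage of the class-level cache counters (sketch, not
    # executed here):
    #
    #   FakeTensorMode.cache_clear()
    #   ...  # run some ops under a FakeTensorMode
    #   info = FakeTensorMode.cache_info()
    #   print(info.hits, info.misses, info.bypasses, info.size)
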
1299    @classmethod
1300    def cache_clear(cls) -> None:
1301        """
1302        Clear the dispatch cache.
1303        """
1304        cls.cache_hits = 0
1305        cls.cache_misses = 0
1306        cls.cache_bypasses.clear()
1307        cls.cache.clear()
1308
1309    def _cached_dispatch_impl(
1310        self,
1311        func: OpOverload,
1312        types: Sequence[Type],
1313        args: Sequence[object],
1314        kwargs: Mapping[str, object],
1315    ) -> object:
1316        """
1317        Lookup a cache entry for the given arguments. If none exists, dispatch
1318        and cache the result (if the result is eligible for caching).
1319        """
1320        output: object = _UNASSIGNED
1321        try:
1322            state = _CacheKeyState(self.shape_env)
1323            key = self._cache_key(state, func, args, kwargs)
1324            if state.cache_on_shape_env():
1325                assert state.shape_env is not None
1326                cache = state.shape_env.fake_tensor_cache
1327            else:
1328                cache = FakeTensorMode.cache
1329            entry = cache.get(key, None)
1330            if entry is not None:
1331                output = self._output_from_cache_entry(state, entry, key, func, args)
1332                FakeTensorMode.cache_hits += 1
1333                if self.cache_crosscheck_enabled:
1334                    # For debugging / testing: Validate that the output synthesized
1335                    # from the cache matches the output created by normal dispatch.
1336                    self._crosscheck_cache_output(output, func, types, args, kwargs)
1337            else:
1338                self._validate_cache_key(func, args, kwargs)
1339                output = self._dispatch_impl(func, types, args, kwargs)
1340                entry = self._make_cache_entry(state, key, func, args, kwargs, output)
1341                key.strip_shape_env()
1342                cache[key] = entry
1343                FakeTensorMode.cache_misses += 1
1344        except _BypassDispatchCache as e:
1345            FakeTensorMode.cache_bypasses[e.reason] += 1
1346
1347        if output is _UNASSIGNED:
1348            output = self._dispatch_impl(func, types, args, kwargs)
1349
1350        return output
1351
1352    def _cache_key(
1353        self,
1354        state: _CacheKeyState,
1355        func: OpOverload,
1356        args: Sequence[object],
1357        kwargs: Mapping[str, object],
1358    ) -> _DispatchCacheKey:
1359        """
1360        Create a cache key given the dispatch args. Raises _BypassDispatchCache
1361        for any situation that precludes caching.
1362        """
1363        key_values = [
1364            func,
1365            # Capture the default_dtype mode since that can affect the output tensor,
1366            # e.g., when operating on constant float values.
1367            torch.get_default_dtype(),
1368            # Capture the current device to support, e.g., cache tensor creation,
1369            # where there isn't necessarily a tensor to take the device from.
1370            torch._C._get_default_device(),
1371            # We want to create tensors from cached metadata only when the inference
1372            # mode is the same.
1373            torch.is_inference_mode_enabled(),
1374            # Shape env settings could affect behavior. One example seen in the wild:
1375            # Disallowing dynamic shapes can introduce a DynamicOutputShapeException
1376            # where it wasn't seen on a previous instance of the same op.
1377            self.shape_env.settings if self.shape_env else None,
1378        ]
1379        # Translate any FakeTensor args to metadata.
1380        if args:
1381            self._prep_args_for_hash(key_values, args, state)
1382        if kwargs:
1383            self._prep_args_for_hash(key_values, kwargs, state)
1384        return _DispatchCacheKey(tuple(key_values))
1385
1386    def _validate_cache_key(
1387        self,
1388        func: OpOverload,
1389        args: Sequence[object],
1390        kwargs: Mapping[str, object],
1391    ) -> None:
1392        """
1393        Validate that the cache key generated by _cache_key will be
1394        reasonable.
1395        """
1396        # Avoid caching for any ops that would require a more sophisticated
1397        # caching implementation, e.g., data dependent ops or ops that modify
1398        # the inputs.
1399        if torch.Tag.data_dependent_output in func.tags:
1400            raise _BypassDispatchCache("data dependent output")
1401
1402        if torch.Tag.dynamic_output_shape in func.tags:
1403            raise _BypassDispatchCache("dynamic output shape")
1404
1405        if torch.Tag.inplace_view in func.tags:
1406            raise _BypassDispatchCache("inplace view")
1407
1408        if func == aten._unsafe_view.default:
1409            raise _BypassDispatchCache("unsafe view")
1410
1411        if func in self.lift_fns:
1412            raise _BypassDispatchCache("lift")
1413
1414        if func.name() == "inductor::resize_storage_bytes_":
1415            raise _BypassDispatchCache("inductor::resize_storage_bytes_")
1416
1417        if not torch._library.utils.is_builtin(func):
1418            raise _BypassDispatchCache("non-builtin")
1419
1420        # In order to handle storage aliasing, we need to establish the alias
1421        # for any view op on a cache hit. But CompositeImplicitAutograd ops may
1422        # or may not alias the input, so just punt on caching these.
1423        if func.is_view and torch._C._dispatch_has_kernel_for_dispatch_key(
1424            func.name(), torch._C.DispatchKey.CompositeImplicitAutograd
1425        ):
1426            raise _BypassDispatchCache("CompositeImplicitAutograd")
1427
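    # For example, aten.nonzero (tagged dynamic_output_shape) and
    # aten._local_scalar_dense (tagged data_dependent_output) should bypass the
    # cache via the tag checks above (editor's note; the authoritative tag
    # assignments live in native_functions.yaml).
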
1428    def _prep_args_for_hash(
1429        self,
1430        result: List[object],
1431        args: Union[Mapping[str, object], Sequence[object], Iterable[object]],
1432        state: _CacheKeyState,
1433    ) -> None:
1434        """
1435        Translate the provided args into a form suitable for caching at FakeTensor
1436        dispatch, i.e., convert unhashable types like lists & dicts into tuples and
1437        convert FakeTensors into metadata. Raises _BypassDispatchCache to signal
1438        unsupported cases that should bypass caching.
1439        """
1440        if isinstance(args, dict):
1441            self._prep_args_for_hash(result, args.keys(), state)
1442            self._prep_args_for_hash(result, args.values(), state)
1443            return
1444
1445        for arg in args:
1446            if isinstance(arg, FakeTensor):
1447                if not self.is_our_fake(arg):
1448                    raise _BypassDispatchCache("not our fake")
1449                if arg.constant is not None:
1450                    raise _BypassDispatchCache("constant attribute")
1451                if is_sparse_any(arg):
1452                    raise _BypassDispatchCache(f"{arg.layout} tensor")
1453                # FIXME: For now, back out of caching when there are symbolic nbytes
1454                # - this doesn't seem to play nice with set(). See T196779132 for examples.
1455                if isinstance(arg.untyped_storage().nbytes(), SymInt):
1456                    raise _BypassDispatchCache("symbolic nbytes")
1457                metadata = extract_tensor_metadata(arg)
1458                metadata._flatten_into(result, self, state)
1459            elif isinstance(arg, Tensor):
1460                raise _BypassDispatchCache("non-fake tensor")
1461            elif isinstance(arg, SymInt):
1462                state.convert_sym_int(result, arg)
1463            elif isinstance(arg, (SymBool, SymFloat)):
1464                raise _BypassDispatchCache("symbolic shape")
1465            elif isinstance(arg, (list, tuple, dict)):
1466                self._prep_args_for_hash(result, arg, state)
1467            else:
1468                # It's important to capture the type of the arg since, e.g., 1 and 1.0
1469                # hash to the same value, but can produce different dtypes for the
1470                # output tensor.
1471                result.append(type(arg))
1472                result.append(arg)
1473
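    # Why type(arg) is hashed alongside the value (editor's illustration; the
    # dtypes below follow standard type promotion for Python scalars):
    #
    #   t = torch.ones(2, dtype=torch.int64)
    #   (t + 1).dtype    # torch.int64   -- an int scalar keeps the integer dtype
    #   (t + 1.0).dtype  # torch.float32 -- a float scalar promotes the output
    #
    # hash(1) == hash(1.0) in Python, so without the type the two calls would
    # collide into one cache entry.
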
1474    def _make_cache_entry(
1475        self,
1476        state: _CacheKeyState,
1477        key: _DispatchCacheKey,
1478        func: OpOverload,
1479        args: Sequence[object],
1480        kwargs: Mapping[str, object],
1481        output: Optional[FakeTensor],
1482    ) -> _DispatchCacheEntry:
1483        """
1484        Make a cache entry object for the given 'output' Tensor. Raises
1485        _BypassDispatchCache if the output tensor has characteristics that
1486        prevent caching it.
1487        """
1488        if output is None:
1489            return _DispatchCacheEntry(inplace_idx=None, metadata=None, view_idx=None)
1490
1491        # Some ops return tuples of Tensors, but it's rare, so avoid
1492        # the complexity of caching other types.
1493        if not isinstance(output, FakeTensor):
1494            raise _BypassDispatchCache("non-FakeTensor output")
1495
1496        # Avoid caching FakeTensors with constants attached since those
1497        # can be invalidated.
1498        if output.constant is not None:
1499            raise _BypassDispatchCache("constant attribute")
1500
1501        # TODO: support caching sparse outputs?
1502        if output.is_sparse:
1503            raise _BypassDispatchCache("sparse output")
1504
1505        if is_sparse_compressed(output):
1506            raise _BypassDispatchCache("sparse compressed output")
1507
1508        # Can an in-place op really reference a kwarg? If so, then we need
1509        # to extend the implementation to handle it.
1510        for kval in kwargs.values():
1511            if id(kval) == id(output):
1512                raise _BypassDispatchCache("kwarg aliases output")
1513
1514        # If this is an in-place op, the entry records which input arg is aliased.
1515        for idx in range(len(args)):
1516            if id(args[idx]) == id(output):
1517                return _DispatchCacheEntry(
1518                    inplace_idx=idx, metadata=None, view_idx=None
1519                )
1520
1521        # Otherwise, create an entry that records the output tensor's metadata.
1522        view_idx = None
1523        if func.is_view:
1524            idxs = [i for i, t in enumerate(args) if isinstance(t, Tensor)]
1525            assert len(idxs) == 1
1526            view_idx = idxs[0]
1527
1528        metadata = extract_tensor_metadata(output)
1529        metadata.shape = tuple(state.convert_output(v) for v in metadata.shape)
1530        metadata.stride = tuple(state.convert_output(v) for v in metadata.stride)
1531        metadata.storage_offset = state.convert_output(metadata.storage_offset)
1532        metadata.storage_bytes = (
1533            None
1534            if metadata.storage_bytes is None
1535            else state.convert_output(metadata.storage_bytes)
1536        )
1537
1538        entry = _DispatchCacheEntry(
1539            inplace_idx=None,
1540            metadata=metadata,
1541            view_idx=view_idx,
1542        )
1543
1544        # N.B.: Some checks for bypassing the cache would be performed on the
1545        # output tensor synthesized from the cached metadata. As an optimization,
1546        # we can synthesize a tensor here and do the checks on that instance.
1547        # This approach keeps the (more frequent) cache-hit path as lightweight
1548        # as possible.
1549        synth_output = self._output_from_cache_entry(state, entry, key, func, args)
1550
1551        # Make sure the dispatch_key_set from the synthesized output tensor will
1552        # be the same.
1553        synth_key_set = torch._C._dispatch_key_set(synth_output)
1554        key_set = torch._C._dispatch_key_set(output)
1555        if synth_key_set != key_set:
1556            raise _BypassDispatchCache("dispatch_key_set mismatch")
1557
1558        return entry
1559
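    # An entry therefore takes one of three shapes (editor's summary of the
    # logic above):
    #   - inplace_idx=i: the output aliases args[i]; a hit returns that arg
    #   - inplace_idx=None, metadata=None: the op returned None
    #   - metadata=<TensorMetadata> (plus view_idx for view ops): a hit
    #     synthesizes a fresh FakeTensor from the recorded metadata
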
1560    def _output_from_cache_entry(
1561        self,
1562        state: _CacheKeyState,
1563        entry: _DispatchCacheEntry,
1564        key: _DispatchCacheKey,
1565        func: OpOverload,
1566        args: Sequence[object],
1567    ) -> Optional[FakeTensor]:
1568        """
1569        Create a new FakeTensor from the cache entry.
1570        """
1571        if entry.inplace_idx is not None:
1572            # This is an in-place op; return the aliased arg.
1573            inplace_arg = args[entry.inplace_idx]
1574            assert isinstance(inplace_arg, FakeTensor)
1575            return inplace_arg
1576
1577        # Synthesize a new FakeTensor with the cached metadata.
1578        metadata = entry.metadata
1579        if metadata is None:
1580            return None
1581
1582        assert not is_sparse_any(metadata)
1583
1584        def check_value(
1585            value: _MetadataIntLike, state: _CacheKeyState
1586        ) -> IntLikeType:
1587            if isinstance(value, _SymIntOutputStub):
1588                assert state.shape_env is not None
1589                return value.extract(key, state.shape_env)
1590            else:
1591                assert not isinstance(value, _PySymInputStub)
1592                return value
1593
1594        shape = tuple(check_value(v, state) for v in metadata.shape)
1595        stride = tuple(check_value(v, state) for v in metadata.stride)
1596        storage_offset = check_value(metadata.storage_offset, state)
1597        storage_bytes = (
1598            None
1599            if metadata.storage_bytes is None
1600            else check_value(metadata.storage_bytes, state)
1601        )
1602
1603        maybe_suppress: Callable[[], typing.ContextManager] = contextlib.nullcontext
1604        if self.shape_env is not None:
1605            maybe_suppress = self.shape_env.suppress_guards
1606
1607        with in_kernel_invocation_manager(self), maybe_suppress():
1608            empty = torch.empty_strided(
1609                shape,
1610                stride,
1611                dtype=metadata.dtype,
1612                layout=metadata.layout,
1613                device="meta",
1614                requires_grad=metadata.requires_grad,
1615            )
1616
1617        if metadata.is_conj:
1618            torch._C._set_conj(empty, True)
1619        if metadata.is_neg:
1620            torch._C._set_neg(empty, True)
1621
1622        if func.is_view:
1623            # For view ops, the storage should be the same as the tensor input.
1624            view_arg = args[cast(int, entry.view_idx)]
1625            assert isinstance(view_arg, FakeTensor)
1626            storage = view_arg.untyped_storage()
1627            with in_kernel_invocation_manager(self), maybe_suppress():
1628                empty.set_(storage, storage_offset, shape, stride)
1629
1630        return FakeTensor(self, empty, metadata.device)
1631
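    # On a hit for a view op, the synthesized output shares storage with the
    # input, roughly as follows (editor's sketch of the set_ call above; `base`
    # stands in for the FakeTensor view argument):
    #
    #   out = torch.empty_strided(shape, stride, dtype=..., device="meta")
    #   out.set_(base.untyped_storage(), storage_offset, shape, stride)
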
1632    def _crosscheck_cache_output(
1633        self,
1634        output: Optional[FakeTensor],
1635        func: OpOverload,
1636        types: Sequence[Type],
1637        args: Sequence[object],
1638        kwargs: Mapping[str, object],
1639    ) -> None:
1640        """
1641        Helper to validate that the output synthesized from the cache matches
1642        the output created by normal dispatch.
1643        """
1644        try:
1645            true_output = self._dispatch_impl(func, types, args, kwargs)
1646        except Exception as e:
1647            raise RuntimeError(
1648                f"FakeTensor cache crosscheck failure: func={func}, "
1649                f"args={args}, kwargs={kwargs}: Dispatch raised={e}"
1650            ) from e
1651        try:
1652            if (true_output is not None) and (output is not None):
1653                assert_metadata_eq(assert_eq, true_output, output)
1654            else:
1655                assert true_output is None
1656                assert output is None
1657        except Exception as e:
1658            raise RuntimeError(
1659                f"FakeTensor cache crosscheck failure: func={func}, "
1660                f"args={args}, kwargs={kwargs}"
1661            ) from e
1662
1663    def dispatch(
1664        self,
1665        func: OpOverload,
1666        types: Sequence[Type],
1667        args: Sequence[object] = (),
1668        kwargs: Mapping[str, object] = immutable_dict(),
1669    ) -> object:
1670        kwargs = kwargs or {}
1671        with no_dispatch():
1672            log.debug("%s %s %s", func, args, kwargs)
1673
1674        if func in _DISPATCH_META_HANDLERS:
1675            return _DISPATCH_META_HANDLERS[func](args)
1676
1677        if log.getEffectiveLevel() <= logging.DEBUG:
1678            log.debug(
1679                "%sFakeTensorMode.__torch_dispatch__: %s", " " * RECURSION_COUNT, func
1680            )
1681            # NOTE: incr is intentionally unused for a RAII pattern
1682            incr = IncrementRecursionCount()
1683
1684        # Some attribute queries that can be serviced directly
1685        # See Note [is_coalesced is dispatched]
1686        if func in _DISPATCH_HANDLE_DIRECTLY:
1687            # NB: no_dispatch is ok here too, this func is very simple
1688            with in_kernel_invocation_manager(self):
1689                return func(*args, **kwargs)
1690
1691        if self.cache_enabled:
1692            return self._cached_dispatch_impl(func, types, args, kwargs)
1693        else:
1694            return self._dispatch_impl(func, types, args, kwargs)
1695
1696    def _dispatch_impl(
1697        self,
1698        func: OpOverload,
1699        types: Sequence[Type],
1700        args: Sequence[object],
1701        kwargs: Mapping[str, object],
1702    ) -> Optional[FakeTensor]:
1703        flat_args, args_spec = pytree.tree_flatten((args, kwargs))
1704
1705        # DO NOT PUT LOGIC BEFORE UNRECOGNIZED TYPE CHECKING
1706        # We must return NotImplemented in case of unrecognized types to handle subclasses.
1707        # Returning it passes control to the next __torch_dispatch__.
1708        # See [subclass inputs] below
1709        # NB: If you're seeing a mysterious infinite loop involving fake
1710        # tensor, it might be related to this line.  Though I'm not sure
1711        # how you'll know to read this comment, as this line won't show up
1712        # in the stack trace.
1713        has_unrecognized_types = _check_for_subclass(flat_args)
1714        if has_unrecognized_types:
1715            unrecognized_types = [
1716                type(x) for x in flat_args if _check_for_subclass_arg(x)
1717            ]
1718            not_implemented_log.debug(
1719                "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types
1720            )
1721            return NotImplemented
1722
1723        flat_arg_fake_tensors = [t for t in flat_args if self.is_our_fake(t)]
1724        has_symbolic_sizes = any(
1725            i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors
1726        ) or any(isinstance(a, SymInt) for a in flat_args)
1727
1728        converter = self.fake_tensor_converter
1729
1730        is_lift_func = func in self.lift_fns
1731
1732        # To constant propagate through these functions:
1733        # 1. If this is a lift due to a torch.tensor call,
1734        #    the input tensor is guaranteed to be a
1735        #    constant, so we keep a copy of the original argument around so
1736        #    we can query it if we're asked to item() it at some later point.
1737        #    (Note that you can always call a lift fn manually, so we do
1738        #    have to check if there are any fake tensors!)
1739        # 2. Some functions allow Python numbers to bind to Tensors, e.g., torch.div
1740        if (is_lift_func and not flat_arg_fake_tensors) or (
1741            should_allow_numbers_as_tensors(func)
1742            and not has_symbolic_sizes
1743            and not flat_arg_fake_tensors
1744        ):
1745            assert all(
1746                t.constant is not None for t in flat_arg_fake_tensors
1747            ), f"{func} should not have fake inputs without constants"
1748            const_flat_args = [
1749                a.constant if self.is_our_fake(a) else a for a in flat_args
1750            ]
1751            const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec)
1752            out = func(*const_args, **const_kwargs)
1753            if type(out) is Tensor and self.may_turn_const(out):
1754                # NB: not in_kernel_invocation_manager because we're doing real
1755                # compute here
1756                # NB: no_dispatch() here is VERY DANGEROUS (like, segfault
1757                # dangerous) if this is actually a wrapper subclass tensor,
1758                # therefore the exact type test above
1759                with no_dispatch():
1760                    out = out.clone()
1761                return converter.from_real_tensor(self, out, make_constant=True)
1762
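        # Concretely (editor's illustration; CONSTANT_NUMEL_LIMIT is 1, so only
        # single-element tensors are retained as constants):
        #
        #   with FakeTensorMode():
        #       x = torch.tensor(1.0)  # lift keeps the real 1.0 alongside the fake
        #       x.item()               # answered from the stored constant
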
1763        # If we are in the dispatch mode, we will enter this function even if the inputs
1764        # are not FakeTensors. For now, throw if there are any non-FakeTensor inputs
1765        # and just support constructors.
1766
1767        # this is generated from torch.tensor(), which does not use the
1768        # dispatcher, to allow wrapper subclasses to wrap the new tensor
1769        if is_lift_func:
1770            assert len(kwargs) == 0 and len(args) == 1, f"{args} {kwargs}"
1771
1772            if type(args[0]) is Tensor:
1773                return converter.from_real_tensor(self, args[0])
1774
1775        # If we are trying to avoid device init, then we need to avoid constant
1776        # prop on constant tensors for ops that change devices.
1777        avoiding_device_init = False
1778        if self.avoid_device_init:
1779            if (
1780                func == torch.ops.aten._to_copy.default
1781                and "device" in kwargs
1782                and kwargs["device"] != "cpu"
1783            ):
1784                avoiding_device_init = True
1785            if func == torch.ops.prims.device_put.default:
1786                avoiding_device_init = True
1787
1788        # Recompute flat_arg_fake_tensors here in case some of the inputs
1789        # were real tensors and were fakified in validate_and_convert_non_fake_tensors.
1790        (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(
1791            func, converter, flat_args, args_spec
1792        )
1793        del args, kwargs  # Invalidated
1794
1795        # The current constant handling only supports tracing systems
1796        # (aot autograd, torchdynamo) where each operation is run consecutively.
1797        # Because each operation is run in order, we can trace out and support
1798        # sequences like: x = torch.tensor(0.); y = x.add_(1)
1799        # Whenever a constant is written to, but with inputs that cannot be evaluated
1800        # statically, such as random_(), we invalidate all constants that alias the input.
1801        # We will rely on functionalization for the use of fake tensor constants as
1802        # persistent objects on an FX Graph.
1803
1804        # We dispatch size/stride/numel on the FakeTensor, not its constant, so bail on inplace_view
1805        all_constant = all(e.constant is not None for e in flat_arg_fake_tensors)
1806        if (
1807            torch.Tag.nondeterministic_seeded not in func.tags
1808            and torch.Tag.inplace_view not in func.tags
1809            and all_constant
1810            and len(flat_arg_fake_tensors) != 0
1811            and not has_symbolic_sizes
1812            and not avoiding_device_init
1813        ):
1814            const_flat_args = [
1815                a.constant if self.is_our_fake(a) else a for a in flat_args
1816            ]
1817            const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec)
1818
1819            # NB: not in_kernel_invocation_manager(self) as we want to do REAL
1820            # compute
1821            with no_dispatch():
1822                out = func(*const_args, **const_kwargs)
1823
1824            flat_out = pytree.tree_leaves(out)
1825            flat_out_tensors = [t for t in flat_out if isinstance(t, Tensor)]
1826            all_constant = all(self.may_turn_const(t) for t in flat_out_tensors)
1827
1828            if all_constant:
1829                return pytree.tree_map_only(
1830                    Tensor,
1831                    lambda t: converter.from_real_tensor(self, t, make_constant=True),
1832                    out,
1833                )
1834
1835            # We weren't able to turn the outputs into constants,
1836            # so invalidate all constants that might be aliases of the outputs.
1837            for ten in flat_out_tensors:
1838                converter.invalidate_constant_aliases(ten)
1839
1840        # We are falling through to running with non-constant tensors; any input
1841        # constant that is written to must be invalidated.
1842        args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
1843        self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
1844
1845        def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
1846            if isinstance(t, FakeTensor):
1847                return t.real_tensor
1848            elif isinstance(t, py_sym_types):
1849                assert self.shape_env is not None
1850                return t.node.pytype(
1851                    t.node.expr.xreplace(self.shape_env.var_to_val).xreplace(
1852                        self.shape_env.unbacked_var_to_val
1853                    )
1854                )
1855            else:
1856                return t
1857
1858        from torch.fx.experimental.symbolic_shapes import (
1859            compute_unbacked_bindings,
1860            free_unbacked_symbols,
1861        )
1862
1863        nil = object()
1864
1865        real_out = nil
1866        if (
1867            self.propagate_real_tensors
1868            and all(e.real_tensor is not None for e in flat_arg_fake_tensors)
1869            # TODO: Handle SymFloat/SymBool
1870            and not any(
1871                (
1872                    isinstance(a, SymInt)
1873                    and (syms := free_unbacked_symbols(a))
1874                    and self.shape_env is not None
1875                    and any(s not in self.shape_env.unbacked_var_to_val for s in syms)
1876                )
1877                for a in flat_args
1878            )
1879        ):
1880            real_flat_args = [maybe_to_real_tensor(a) for a in flat_args]
1881            real_args, real_kwargs = pytree.tree_unflatten(real_flat_args, args_spec)
1882            real_out = func(*real_args, **real_kwargs)
1883        elif self.propagate_real_tensors:
1884            # This can occasionally happen legitimately, specifically when you
1885            # are inside the meta of a data dependent operation and you create
1886            # a tensor on an unbacked SymInt; at this point in time we don't
1887            # know what the unbacked SymInt is, but we will know later.
1888            # However, if there's a bug in the condition above, this condition
1889            # will also trigger.
1890            log.debug(
1891                "propagate_real_tensors skipped %s(%s, %s) %s",
1892                func,
1893                flat_arg_fake_tensors,
1894                flat_args,
1895                self.shape_env.unbacked_var_to_val if self.shape_env else None,
1896            )
1897
1898        def maybe_propagate_real_tensors(fake_out: T) -> T:
1899            import sympy
1900
1901            def go(t: object, real_t: Tensor) -> None:
1902                if isinstance(t, FakeTensor):
1903                    # NB: unconditionally overwrite
1904                    t.real_tensor = real_t
1905                elif isinstance(t, py_sym_types) and free_unbacked_symbols(t):
1906                    if isinstance(t.node.expr, sympy.Symbol):
1907                        assert self.shape_env is not None
1908                        self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t)
1909
1910            if real_out is not nil:
1911                tree_map_(go, fake_out, real_out)
1912
1913                # If a data-dependent op is used in a decomposition, we
1914                # may need to get the unbacked settings "early"
1915                # TODO: Is this really needed?
1916                compute_unbacked_bindings(self.shape_env, fake_out, peek=True)
1917
1918            return fake_out
1919
1920        # Try for fastpath
1921        if has_symbolic_sizes:
1922            fast_impl = get_fast_op_impls().get(func)
1923            if fast_impl is not None:
1924                return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
1925
1926        # If there's a Python meta, prefer that over the decomposition
1927        from torch._decomp import meta_table as meta_table
1928
1929        if func not in meta_table and not self.cpp_meta_supports_symint(func):
1930            from torch._decomp import decomposition_table
1931
1932            # Prefer Python decompositions over C++ ones
1933            if func in decomposition_table and (
1934                has_symbolic_sizes
1935                or (
1936                    # TODO: Remove these exclusions, so that we can remove
1937                    # this leg entirely
1938                    torch_decomp_decompositions(func)
1939                    and all(not is_sparse_any(e) for e in flat_arg_fake_tensors)
1940                )
1941            ):
1942                with self:
1943                    return decomposition_table[func](*args, **kwargs)
1944
1945            with self:
1946                # Decomposes CompositeImplicitAutograd ops
1947                r = func.decompose(*args, **kwargs)
1948                if r is not NotImplemented:
1949                    return r
1950
1951        # prims already wrap FakeTensor inputs to FakeTensor outputs
1952        # and do device logic, so we don't need to do anything but run them
1953        # and ensure that Meta kernels are dispatched to
1954        # (see Fake Tensor Dispatch Keys).
1955        # TODO - we should use the prim aten impl
1956        # TODO - fix prims complex ops
1957        if (
1958            "prims::" in func._schema.name
1959            and hasattr(func, "prim_meta_impl")
1960            and not stride_incorrect_op(func)
1961        ):
1962            with self:
1963                return maybe_propagate_real_tensors(
1964                    func.prim_meta_impl(*args, **kwargs)
1965                )
1966
1967        # Users can register FakeTensor rules for custom operators.
1968        # Call them if they exist.
1969        maybe_fake_impl = torch._library.simple_registry.singleton.find(
1970            func.name()
1971        ).fake_impl.kernel
1972        if maybe_fake_impl:
1973            ctx = torch._library.fake_impl.FakeImplCtx(self, func)
1974            with torch._library.fake_impl.set_ctx_getter(lambda: ctx), self:
1975                result = maybe_fake_impl(*args, **kwargs)
1976                return maybe_propagate_real_tensors(result)
1977
1978        # Special handling for funcs registered through `register_op_impl`,
1979        # e.g., manipulating args on constructor calls to construct meta tensors
1980        # and then afterwards wrapping them in a FakeTensor
1981        for run_impl_check, op_impl in op_implementations_checks:
1982            if run_impl_check(func):
1983                op_impl_out = op_impl(self, func, *args, **kwargs)
1984                if op_impl_out is not NotImplemented:
1985                    return maybe_propagate_real_tensors(op_impl_out)
1986
1987        def maybe_run_unsafe_fallback(
1988            error: Optional[RuntimeError] = None,
1989        ) -> Optional[FakeTensor]:
1990            # We infer the meta of custom ops that return None to just
1991            # return None. Custom ops are not allowed to mutate the metadata
1992            # of their inputs, so this is safe.
1993            if torch._library.utils.can_generate_trivial_fake_impl(func):
1994                return None
1995            # No meta kernel registered; fall back to the kernel for the device.
1996            if has_symbolic_sizes or not self.can_run_unsafe_fallback(func):
1997                raise UnsupportedOperatorException(func)
1998            if error is None:
1999                error = UnsupportedOperatorException(func)
2000            return run_fallback_kernel(self, func, flat_args, args_spec, error)
2001
2002        # Optimization: If there is no Meta kernel, it takes a surprisingly long
2003        # time to catch the NotImplementedError, so we check for that here.
2004        if not has_meta(func):
2005            fallback = maybe_run_unsafe_fallback()
2006            return maybe_propagate_real_tensors(fallback)
2007
2008        # Run the kernel registered to meta for func, which includes
2009        # python meta registrations, prims, decomps, and c++ meta fns (structured kernels).
2010        # It's possible that the kernel will raise NotImplementedError.
2011        try:
2012            with in_kernel_invocation_manager(self):
2013                r = func(*args, **kwargs)
2014        except NotImplementedError as not_implemented_error:
2015            return maybe_run_unsafe_fallback(not_implemented_error)
2016        except Exception:
2017            log.exception("failed while attempting to run meta for %s", func)
2018            raise
2019
2020        return maybe_propagate_real_tensors(
2021            self.wrap_meta_outputs_with_default_device_logic(
2022                r, func, flat_args, device=kwargs.get("device")
2023            )
2024        )
2025
2026    # WARNING: DO NOT add any additional namespaces/operators here if they refer to operators
2027    # outside of the pytorch/pytorch library! Any pre-existing things here
2028    # are either in the pytorch/pytorch library or have been grandfathered in.
2029    # The fallback does not always work and MAY CRASH and emit unreadable error messages
2030    # so it should not be allowed by default.
2031    _can_run_unsafe_fallback_allowed_namespaces = ordered_set(
2032        "debugprims",
2033        "prims",
2034        "aten",
2035        "xla",
2036        "vision",
2037        "torchtext",
2038        "torchaudio",
2039        "quantized",
2040    )
2041
2042    def can_run_unsafe_fallback(self, func: OpOverload) -> bool:
2043        if not self.allow_fallback_kernels:
2044            return False
2045        # It's OK to try the fallback for built-in ops (e.g. aten, prims)
2046        # because we control and test these, but the fallback can lead to unexpected
2047        # behavior with user-defined custom ops.
2048        return (
2049            func.namespace in self._can_run_unsafe_fallback_allowed_namespaces
2050            or func.name() == "fbgemm::gmm"
2051        )
2052
2053    def validate_and_convert_non_fake_tensors(
2054        self,
2055        func: OpOverload,
2056        converter: FakeTensorConverter,
2057        flat_args: Sequence[object],
2058        args_spec: TreeSpec,
2059    ) -> Tuple[List[object], List[FakeTensor]]:
2060        """
2061        Check whether the tensors in flat_args are fake tensors and, if not,
2062        try to convert them to fake tensors. Returns the validated flat args
2063        along with a list of the args that are FakeTensors.
2064        """
2065        flat_arg_fake_tensors: List[FakeTensor] = []
2066
2067        def validate(x: T) -> Union[T, FakeTensor]:
2068            if not isinstance(x, Tensor):
2069                return x
2070
2071            nonlocal flat_arg_fake_tensors
2072            if not self.is_our_fake(x):
2073                if torch.Tag.inplace_view in func.tags:
2074                    args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
2075                    raise AssertionError(
2076                        f"Can't call metadata mutating ops on non-Fake Tensor inputs. Found in {render_call(func, args, kwargs)}"
2077                    )
2078                if not self.allow_non_fake_inputs:
2079                    if isinstance(x, FakeTensor) and x.fake_mode is not self:
2080                        raise AssertionError("Mixing fake modes NYI")
2081                    args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
2082                    raise AssertionError(
2083                        f"Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode "
2084                        f"with 'allow_non_fake_inputs'. Found in {render_call(func, args, kwargs)}"
2085                    )
2086
2087                out = converter.from_real_tensor(self, x)
2088            else:
2089                out = x
2090
2091            flat_arg_fake_tensors.append(out)
2092            return out
2093
2094        validated_args = [validate(a) for a in flat_args]
2095        return validated_args, flat_arg_fake_tensors
2096
2097    def wrap_meta_outputs_with_default_device_logic(
2098        self,
2099        r: object,
2100        func: OpOverload,
2101        flat_args: Sequence[object],
2102        device: Optional[torch.device],
2103    ) -> PyTree:
2104        converter = self.fake_tensor_converter
2105
2106        # Lazily initialized, in case there are no tensor returns
2107        common_device = None
2108        has_scalar_only_inputs = False
2109
2110        def wrap(e: T) -> Union[T, FakeTensor]:
2111            nonlocal common_device
2112            nonlocal has_scalar_only_inputs
2113
2114            if not isinstance(e, Tensor):
2115                return e
2116
2117            if common_device is None:
2118                (
2119                    common_device,
2120                    has_scalar_only_inputs,
2121                ) = FakeTensor._find_common_device(func, flat_args)
2122
2123            is_our_fake = self.is_our_fake(e)
2124            if is_our_fake:
2125                torch._check(
2126                    e.device == common_device,
2127                    lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}",
2128                )
2129                return cast(T, e)
2130            elif converter is not None:
2131                if has_scalar_only_inputs:
2132                    # Under FakeTensorMode, an op that accepts scalar-only inputs, such as
2133                    # aten.add/sub/mul/div, returns a real scalar tensor on CPU. See TensorMeta()
2134                    # in _prims/__init__.py for details. We thus directly convert the real tensor to a fake tensor.
2135                    return converter.from_real_tensor(self, e)
2136                else:
2137                    return converter.from_meta_and_device(
2138                        self, e, device or common_device
2139                    )
2140            else:
2141                return e
2142
2143        return tree_map(wrap, r)
2144
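    # E.g., an op whose fake inputs report device "cuda:0" produces a meta
    # tensor here, which gets wrapped as a FakeTensor carrying device cuda:0
    # even though no GPU compute ever ran (editor's note on the logic above).
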
2145    def create_symbolic_nested_int(
2146        self, *, nt_tensor_id: Optional[int] = None
2147    ) -> torch.SymInt:
2148        # See Note: [Creating symbolic nested int]
2149        # Returned nested int always has coeff=1; multiply the result by coeff if needed
2150        import torch.nested._internal.nested_tensor
2151
2152        if nt_tensor_id is None:
2153            nt_tensor_id = self.nt_tensor_id_counter
2154            assert self.enter_stack, "should only be called while FakeTensorMode is active"
2155            self.nt_tensor_id_counter += 1
2156        hint = torch._C._get_nested_int(nt_tensor_id, 1)
2157
2158        src = torch._dynamo.source.EphemeralSource("intermediate_offsets_or_lengths")
2159        assert self.shape_env is not None
2160        ret = self.shape_env.create_symintnode(
2161            sym=self.shape_env.create_symbol(
2162                val=hint,
2163                source=src,
2164            ),
2165            hint=hint,
2166            source=src,
2167        )
2168        return ret
2169
2170    _cpp_meta_supports_symint = ordered_set(
2171        aten.empty.memory_format,
2172        aten.empty_strided.default,
2173        aten.as_strided_scatter.default,
2174        aten.as_strided.default,
2175        aten.as_strided_.default,
2176        aten.zeros.default,
2177        aten.detach.default,
2178        aten.view_as_real.default,
2179        aten.view_as_complex.default,
2180        aten.set_.source_Storage_storage_offset,
2181        aten._sparse_coo_tensor_with_dims_and_tensors.default,
2182    )
2183
2184    def cpp_meta_supports_symint(self, func: OpOverload) -> bool:
2185        if torch.Tag.view_copy in func.tags:
2186            return True
2187        return func in self._cpp_meta_supports_symint
2188
2189    lift_fns = ordered_set(aten.lift_fresh.default, aten.lift_fresh_copy.default)
2190
2191    def may_turn_const(self, t: Tensor) -> bool:
2192        return (
2193            t.numel() <= CONSTANT_NUMEL_LIMIT
2194            and not is_sparse_any(t)
2195            and not self.is_our_fake(t)
2196            and not t.device.type == "meta"
2197        )
2198
2199    def invalidate_written_to_constants(
2200        self,
2201        func: OpOverload,
2202        flat_arg_fake_tensors: Sequence[FakeTensor],
2203        args: Sequence[object],
2204        kwargs: Mapping[str, object],
2205    ) -> None:
2206        any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)
2207        schema_info = get_schema_info(func)
2208        if any_constant and schema_info.is_mutable():
2209            _, new_kwargs = normalize_function(  # type: ignore[misc]
2210                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True  # type: ignore[arg-type]
2211            )
2212            for k, v in new_kwargs.items():
2213                k = k if (k != "input" or schema_info.has_argument(k)) else "self"
2214                if (
2215                    self.is_our_fake(v)
2216                    and schema_info.is_mutable(k)
2217                    and v.constant is not None
2218                ):
2219                    self.fake_tensor_converter.invalidate_constant_aliases(v.constant)
2220
2221    def from_tensor(
2222        self,
2223        tensor: Tensor,
2224        *,
2225        static_shapes: Optional[bool] = None,
2226        source: Optional[Source] = None,
2227        symbolic_context: Optional[SymbolicContext] = None,
2228        trace: bool = True,
2229    ) -> FakeTensor:
2230        shape_env: Optional[ShapeEnv] = self.shape_env
2231        if static_shapes is None:
2232            static_shapes = self.static_shapes
2233        if static_shapes:
2234            assert (
2235                symbolic_context is None
2236            ), "cannot set both static_shapes and symbolic_context"
2237            shape_env = None
2238        return self.fake_tensor_converter.from_real_tensor(
2239            self,
2240            tensor,
2241            shape_env=shape_env,
2242            source=source,
2243            symbolic_context=symbolic_context,
2244            trace=trace,
2245        )
2246
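    # Typical usage (editor's illustration; `mode` is a hypothetical instance):
    #
    #   mode = FakeTensorMode()
    #   fake = mode.from_tensor(torch.randn(3, 4), static_shapes=True)
    #   assert isinstance(fake, FakeTensor) and fake.shape == (3, 4)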
2247
2248_StoragePointer = object
2249
2250
2251# NB: returns fake tensors
2252def run_fallback_kernel(
2253    fake_mode: FakeTensorMode,
2254    func: OpOverload,
2255    flat_args: Sequence[object],
2256    args_spec: PyTree,
2257    orig_not_implemented_exception: RuntimeError,
2258) -> FakeTensor:
2259    # These should all be supported; just to be safe, avoid the fallback
2260    # for operators which modify metadata in place, because the input
2261    # fake tensors would be left unmodified.
2262    if torch.Tag.inplace_view in func.tags:
2263        raise orig_not_implemented_exception
2264
2265    inp_impls = {}
2266
2267    # Don't use in_kernel_invocation_manager(fake_mode) as we want to do
2268    # REAL compute (not with meta device)
2269    with no_dispatch():
2270
2271        def to_real_tensor(e: T) -> Union[T, Tensor]:
2272            if fake_mode.is_our_fake(e):
2273                out = torch.zeros_like(e, device=e.fake_device)
2274                if e.is_sparse:
2275                    out._coalesced_(e.is_coalesced())
2276                inp_impls[id(out)] = e
2277                return out
2278            return e
2279
2280        flat_args = [to_real_tensor(a) for a in flat_args]
2281        args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
2282
2283        r = func(*args, **kwargs)
2284
2285    storages: Set[_StoragePointer] = set()
2286
2287    for e in flat_args:
2288        if isinstance(e, Tensor):
2289            if not is_sparse_any(e):
2290                storages.add(e._typed_storage()._cdata)
2291
2292    # TODO: also check metadata change on inputs
2293    # The proper aliasing/metadata relationship between outputs and inputs will
2294    # not be set up, because of the conversion to device, unless we can reuse an
2295    # input impl.
2296
2297    def map_out(e: T) -> Union[T, FakeTensor]:
2298        if id(e) not in inp_impls and (
2299            isinstance(e, Tensor)
2300            and not is_sparse_any(e)
2301            and e._typed_storage()._cdata in storages
2302        ):
2303            raise orig_not_implemented_exception
2304
2305        if isinstance(e, Tensor):
2306            if id(e) in inp_impls:
2307                return inp_impls[id(e)]
2308            else:
2309                return fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, e)
2310        else:
2311            return e
2312
2313    return pytree.tree_map(map_out, r)
2314
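# The fallback above runs the op for real on zero-filled stand-ins, so it can
# only produce trustworthy metadata for ops whose output shape/stride/dtype do
# not depend on input *values* (editor's note); inplace_view ops are rejected
# up front because the real compute would leave the fake inputs' metadata stale.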
2315
2316# Just for use to allow copying a module to fake tensors;
2317# it does not apply elsewhere.
2318class FakeCopyMode(TorchFunctionMode):
2319    def __init__(self, fake_mode: FakeTensorMode) -> None:
2320        self.fake_mode = fake_mode
2321
2322    def __torch_function__(
2323        self,
2324        func: OpOverload,
2325        types: Sequence[Type],
2326        args: Sequence[object] = (),
2327        kwargs: Optional[Mapping[str, object]] = None,
2328    ) -> FakeTensor:
2329        kwargs = kwargs if kwargs else {}
2330
2331        # clone will get called in Parameter deepcopy
2332        if func == torch._C.TensorBase.clone:
2333            assert isinstance(args[0], Tensor)
2334            return func(
2335                self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs
2336            )
2337        elif func == Tensor.__deepcopy__:
2338            assert len(args) == 2 and len(kwargs) == 0
2339            tensor = cast(Tensor, args[0])
2340            memo = cast(Dict[int, FakeTensor], args[1])
2341
2342            if id(tensor) in memo:
2343                return memo[id(tensor)]
2344
2345            out = self.fake_mode.from_tensor(tensor, static_shapes=True)
2346            memo[id(tensor)] = out
2347            return out
2348        else:
2349            with torch._C.DisableTorchFunctionSubclass():
2350                return func(*args, **kwargs)
2351
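# Typical usage (editor's illustration; `mod` is a hypothetical nn.Module and
# `copy` is the stdlib module):
#
#   mode = FakeTensorMode()
#   with FakeCopyMode(mode):
#       fake_mod = copy.deepcopy(mod)  # parameters/buffers become FakeTensors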
2352
2353def _device_handler(args: Sequence[object]) -> torch.device:
2354    # NB: Don't use is_our_fake, just serve the fake information
2355    # as is.  Notice we don't use 'self'; we use args[0].fake_mode
2356    # because they may not be the same.  It would also be possible
2357    # to return NotImplemented here, in which case the FakeTensor
2358    # handler on args[0] would handle it, but we're being nice and
2359    # short-circuiting quickly.
2360    assert len(args) == 1 and isinstance(args[0], FakeTensor)
2361    if args[0].fake_mode.in_kernel_invocation:
2362        return torch.device("meta")
2363    else:
2364        return args[0].fake_device
2365
2366
2367# [subclass inputs]
2368# Suppose we enable fake tensor mode.  This means that fake tensor
2369# mode will run first.  But what if we do an operation that
2370# involves a tensor subclass that will desugar into normal tensor
2371# operations?  Without returning NotImplemented, fake tensor mode would run first,
2372# decide that a conversion was made (since there was a non-fake
2373# tensor argument), and report an error that converting a non-fake
2374# tensor is not supported.  What we actually wanted to happen
2375# was to give the subclass a chance to figure out what it wants to do
2376# before erroring out. Returning NotImplemented here allows this.
2377def _check_for_subclass(flat_args: Sequence[object]) -> bool:
2378    return any(_check_for_subclass_arg(x) for x in flat_args)
2379
2380
2381def _check_for_subclass_arg(x: object) -> bool:
2382    return (
2383        not isinstance(x, FakeTensor)
2384        and isinstance(x, Tensor)
2385        and type(x) is not Tensor
2386        and type(x) is not torch.nn.Parameter
2387    )
2388
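# Editor's illustration of what _check_for_subclass_arg treats as "unrecognized":
# a non-fake tensor whose type is neither Tensor nor nn.Parameter, e.g. a
# hypothetical user-defined subclass:
#
#   class MySubclass(torch.Tensor): ...
#   t = torch.ones(1).as_subclass(MySubclass)
#   _check_for_subclass_arg(t)  # True -> dispatch returns NotImplemented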
2389
2390_DISPATCH_META_HANDLERS = {
2391    torch.ops.prim.device.default: _device_handler,
2392    torch.ops.aten.size.default: lambda args: tuple(
2393        int(s) for s in cast(Tensor, args[0]).size()
2394    ),
2395    torch.ops.aten.stride.default: lambda args: tuple(
2396        int(s) for s in cast(Tensor, args[0]).stride()
2397    ),
2398    torch.ops.aten.storage_offset.default: lambda args: int(
2399        cast(Tensor, args[0]).storage_offset()
2400    ),
2401}
2402
2403_DISPATCH_HANDLE_DIRECTLY = ordered_set(
2404    torch.ops.aten.is_coalesced.default,
2405    torch.ops.aten.dense_dim.default,
2406    torch.ops.aten.sparse_dim.default,
2407)
2408
2409from torch._subclasses.fake_impls import (  # noqa: F401
2410    _device_not_kwarg_ops,
2411    _is_tensor_constructor,
2412    _like_tensor_constructors,
2413    contains_tensor_types,
2414    get_fast_op_impls,
2415    has_meta,
2416    op_implementations_checks,
2417    stride_incorrect_op,
2418)
2419
2420
2421@atexit.register
2422def dump_cache_stats() -> None:
2423    log.info("FakeTensor cache stats:")
2424    log.info("  cache_hits: %s", FakeTensorMode.cache_hits)
2425    log.info("  cache_misses: %s", FakeTensorMode.cache_misses)
2426    bypasses = FakeTensorMode.cache_bypasses
2427    if bypasses:
2428        log.info("  cache_bypasses:")
2429        width = max(len(k) for k in bypasses)
2430        for k, v in sorted(bypasses.items(), key=lambda i: -i[1]):
2431            log.info("    %-*s %s", width + 1, f"{k}:", v)
2432