# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
import dataclasses
import warnings
import weakref
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    ClassVar,
    ContextManager,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import TypeAlias

import torch
from torch._C._autograd import CreationMeta
from torch._C._functorch import (
    _add_batch_dim,
    _unwrap_functional_tensor,
    _wrap_functional_tensor,
    get_unwrapped,
    is_batchedtensor,
    is_functorch_wrapped_tensor,
    is_gradtrackingtensor,
    is_legacy_batchedtensor,
    maybe_get_bdim,
    maybe_get_level,
    peek_interpreter_stack,
)
from torch._logging import trace_structured
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import WeakIdKeyDictionary


if TYPE_CHECKING:
    from torch._C._functorch import CInterpreter
    from torch._guards import Source

    # Import here to avoid cycle
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Import the following modules during type checking to enable code intelligence
    # features. Do not import them unconditionally, as they import sympy, and
    # importing sympy is very slow.
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext

DimList = List


def safe_is_leaf(t):
    try:
        return t.is_leaf
    except RuntimeError:
        # inference mode can trigger this
        return False


def safe_grad(t):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
        return t.grad


def assert_eq(a, b):
    assert a == b, f"{a} != {b}"


def assert_metadata_eq(
    assert_eq,
    m1: Union[MetaTensorDesc, torch.Tensor],
    m2: torch.Tensor,
    *,
    skip_symbolic=False,
    skip_leaf=False,
):
    if isinstance(m1, torch.Tensor):
        m1 = MetaTensorDescriber().describe_tensor(m1)

    def go(m1, m2):
        assert_eq(m1.dtype, m2.dtype)
        if not skip_symbolic:
            assert_eq(m1.shape, m2.shape)
        assert_eq(m1.requires_grad, m2.requires_grad)
        if not skip_leaf:
            assert_eq(m1.is_leaf, m2.is_leaf)
        # MetaTensorDesc doesn't store grad_fn; inferred from leaf
        # assert_eq(m1.grad_fn is None, m2.grad_fn is None)
        assert_eq(m1.is_sparse, m2.is_sparse)
        assert_eq(m1.is_inference, m2.is_inference())
        assert_eq(m1.is_conj, m2.is_conj())
        assert_eq(m1.is_neg, m2.is_neg())
        assert_eq(m1.grad is not None, safe_grad(m2) is not None)
        if m1.grad is not None:
            go(m1.grad, safe_grad(m2))
        # TODO: move "assert_eq(m1.layout, m2.layout)" out of sparse
        #       branches (but not ready for prime time yet)...
        if m1.is_sparse:
            assert_eq(m1.layout, m2.layout)
            assert_eq(m1.dense_dim, m2.dense_dim())
            assert_eq(m1.sparse_dim, m2.sparse_dim())
            assert_eq(m1.is_coalesced, m2.is_coalesced())
        elif is_sparse_compressed(m1):
            assert_eq(m1.layout, m2.layout)
            assert_eq(m1.dense_dim, m2.dense_dim())
            assert_eq(m1.sparse_dim, m2.sparse_dim())
        else:
            if not skip_symbolic:
                assert_eq(m1.stride, m2.stride())
                assert_eq(m1.storage_offset, m2.storage_offset())
            assert_eq(m1.is_view, m2._is_view())
            if m1.is_view:
                go(m1.base, m2._base)
        # TODO: test if is resizable (no direct query for this atm)
        # TODO: audit AutogradMeta to see if it matches
        # TODO: test forward AD

    return go(m1, m2)


def is_sparse_coo(t):
    return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo


def is_sparse_compressed_layout(layout):
    return layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }


def is_sparse_compressed(t):
    return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout)


def is_sparse_any(t):
    return is_sparse_coo(t) or is_sparse_compressed(t)
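

# A small illustration of these predicates (a sketch; `to_sparse()` and
# `to_sparse_csr()` produce COO and CSR layouts respectively):
#
#   coo = torch.eye(3).to_sparse()      # layout == torch.sparse_coo
#   csr = torch.eye(3).to_sparse_csr()  # layout == torch.sparse_csr
#   is_sparse_coo(coo)        -> True    is_sparse_compressed(coo) -> False
#   is_sparse_compressed(csr) -> True    is_sparse_any(csr)        -> True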


# Don't use id() directly, because those can get reallocated over time.
MetaStorageId: TypeAlias = int
MetaTensorId: TypeAlias = int


DESCRIBER_NEXT_ID = 0


class MetaTensorDescriber:
    """
    Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc
    for it, which is enough information to reconstruct a meta tensor/fake tensor
    corresponding to a Tensor as faithfully as possible.

    This is a stateful conversion object because we keep track of the IDs
    of the tensors/storages passed to us, so we can consistently give
    the same ID when we see the same tensor/storage.
    """

    def __init__(self, *, copy_data=False):
        global DESCRIBER_NEXT_ID
        self.id = DESCRIBER_NEXT_ID
        DESCRIBER_NEXT_ID += 1
        self.next_tensor_id: MetaTensorId = 0
        self.next_storage_id: MetaStorageId = 0
        # Tensor -> int
        self.lookup_tensor = WeakIdKeyDictionary()
        # Storage -> int
        self.lookup_storage = WeakIdKeyDictionary()
        self.copy_data = copy_data
        self.traced_tensors = set()
        self.traced_storages = set()

    def get_tensor_id(self, t: torch.Tensor):
        if t not in self.lookup_tensor:
            self.lookup_tensor[t] = self.next_tensor_id
            self.next_tensor_id += 1
        return self.lookup_tensor[t]

    def get_storage_id(self, s: torch.UntypedStorage):
        if s not in self.lookup_storage:
            self.lookup_storage[s] = self.next_storage_id
            self.next_storage_id += 1
        return self.lookup_storage[s]

    def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False):
        r = MetaStorageDesc(
            id=self.get_storage_id(s),
            size=s.size(),
            # NB: We don't do the copy yet; copy happens when we start
            # creating the new storages
            data=s if self.copy_data else None,
        )
        if trace and r.id not in self.traced_storages:
            trace_structured(
                "describe_storage",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_storages.add(r.id)
        return r

    def describe_tensor(
        self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
    ):
        is_leaf = safe_is_leaf(t)
        is_view = t._is_view()
        is_sparse = t.is_sparse
        layout = t.layout
        is_nested = t.is_nested
        is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t)
        is_functorch_wrapped = is_functorch_wrapped_tensor(t)
        is_mkldnn = t.is_mkldnn
        is_batchedtensor_v = is_batchedtensor(t)
        is_legacy_batchedtensor_v = is_legacy_batchedtensor(t)
        is_gradtrackingtensor_v = is_gradtrackingtensor(t)
        is_functorch_batched_or_grad = is_batchedtensor_v or is_gradtrackingtensor_v
        is_functional = torch._is_functional_tensor(t)

        storage = None
        # NB: For compatibility, I default this to zero, as people sometimes
        # stuff zero into the storage offset even though the tensor doesn't
        # meaningfully have an offset
        storage_offset = 0
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
            or is_mkldnn
            # TODO: TBH, functorch wrapped tensors probably should have
            # storage associated with them
            or is_functorch_wrapped
            or is_legacy_batchedtensor_v
        ):
            # NB: We actually don't use storage to do views, but might as well
            # put it in for accuracy
            storage = self.describe_storage(t.untyped_storage(), trace=trace)
            storage_offset = t.storage_offset()  # type: ignore[assignment]

        stride = None
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
        ):
            # stride/storage_offset are called from is_functorch_wrapped,
            # view_from_base, empty_create_subclass,
            # sym_sizes_strides_storage_offset (empty_create)
            stride = t.stride()

        # NB: this technically should refer to the functorch unwrapped tensor,
        # but I am (perhaps abusively) using it to store both the functorch and
        # non-functorch functional tensor
        unwrapped = None
        autograd_meta_from = None
        current_level = None
        if is_batchedtensor_v or is_gradtrackingtensor_v:
            unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace)
        # xla and lazy tensors present as functional tensors, but we want them
        # to be handled specially
        elif is_functional and t.device.type not in ("xla", "lazy"):
            if t._is_view():
                raise RuntimeError(
                    "Cannot safely fakify a view because this process drops the view information right now."
                )
            if not is_functorch_wrapped:
                torch._sync(t)
                unwrapped = self.describe_tensor(
                    torch._from_functional_tensor(t), trace=trace
                )
                autograd_meta_from = t
            else:
                reapply_views = torch._C._functionalization_reapply_views_tls()
                # NB: has side effects!
                unwrapped = self.describe_tensor(
                    _unwrap_functional_tensor(t, reapply_views), trace=trace
                )
                # TODO: It's pretty suspicious that functional tensors don't
                # have a valid level, so we just grab whatever the current
                # level is
                current_level = torch._C._functorch.current_level()

        maybe_functorch_stack = None
        if is_functorch_wrapped:
            with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack:
                pass

        attrs = None
        ctx = None
        type_v = None
        if is_traceable_wrapper_subclass_v:
            assert hasattr(t, "__tensor_flatten__")
            raw_attrs, ctx = t.__tensor_flatten__()
            attrs = {
                attr: self.describe_tensor(getattr(t, attr), trace=trace)
                for attr in raw_attrs
            }
            type_v = type(t)

        from torch.nested._internal.nested_tensor import _tensor_symint_registry

        # TODO: Is it important to enable torch.inference_mode before querying
        # these values?
        r = MetaTensorDesc(
            id=self.get_tensor_id(t),
            storage=storage,
            is_inference=t.is_inference(),
            is_leaf=is_leaf,
            requires_grad=t.requires_grad,
            # NB: ndim should be OK too but there is a disaster at
            # python test/dynamo/test_subclasses.py -k test_user_overidden_property_unsupported
            # Actually, this means that we have a little bit of a problem
            # here, which is that there is some sensitivity to how exactly an
            # access is done if you have a __torch_function__ subclass.  Maybe
            # we should disable torch function before doing accesses?
            ndim=t.dim(),
            dtype=t.dtype,
            is_sparse=is_sparse,
            is_mkldnn=is_mkldnn,
            is_functorch_wrapped=is_functorch_wrapped,
            is_batchedtensor=is_batchedtensor_v,
            is_legacy_batchedtensor=is_legacy_batchedtensor_v,
            is_gradtrackingtensor=is_gradtrackingtensor_v,
            is_view=is_view,
            is_conj=t.is_conj(),
            is_neg=t.is_neg(),
            is_parameter=isinstance(t, torch.nn.Parameter),
            is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v,
            is_nested=is_nested,
            nested_int=(
                _tensor_symint_registry[t].node.nested_int()
                if t in _tensor_symint_registry
                else None
            ),
            is_functional=is_functional,
            layout=layout,
            device=t.device,
            size=t.size(),
            stride=stride,
            storage_offset=storage_offset,
            dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
            sparse_dim=(
                t.sparse_dim() if t.is_sparse or is_sparse_compressed(t) else None
            ),
            dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None,
            is_coalesced=t.is_coalesced() if t.is_sparse else None,
            # TODO: I actually think recursing here is correct, but we have at
            # least an infinite cycle from base -> values -> base
            # https://github.com/pytorch/pytorch/issues/122089
            crow_indices=(
                self.describe_tensor(t.crow_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
                else None
            ),
            col_indices=(
                self.describe_tensor(t.col_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
                else None
            ),
            ccol_indices=(
                self.describe_tensor(t.ccol_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
                else None
            ),
            row_indices=(
                self.describe_tensor(t.row_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
                else None
            ),
            values=(
                self.describe_tensor(t.values(), recurse=False, trace=trace)
                if recurse and is_sparse_compressed(t)
                else None
            ),
            grad=(
                self.describe_tensor(safe_grad(t), trace=trace)
                if safe_grad(t) is not None
                else None
            ),
            creation_meta=(
                torch._C._autograd._get_creation_meta(t) if t._is_view() else None
            ),
            unwrapped=unwrapped,
            level=(
                maybe_get_level(t)
                if is_batchedtensor_v or is_gradtrackingtensor_v
                else None
            ),
            bdim=maybe_get_bdim(t) if is_batchedtensor_v else None,
            base=(
                self.describe_tensor(t._base, trace=trace)
                if recurse and t._is_view() and t._base is not None
                else None
            ),
            fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t),
            view_func=t._view_func_unsafe,
            attrs=attrs,
            ctx=ctx,
            type=type_v,
            # NB: even if functorch is enabled, don't actually save the
            # interpreter stack here unless we are actually functorch wrapped;
            # it's irrelevant for non-functorch stuff
            functorch_stack=maybe_functorch_stack,
            autograd_meta_from=autograd_meta_from,
            current_level=current_level,
            data=t if self.copy_data else None,
        )
        if trace and r.id not in self.traced_tensors:
            trace_structured(
                "describe_tensor",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_tensors.add(r.id)
        return r


@dataclass(frozen=True)
class MetaStorageDesc:
    id: MetaStorageId
    size: int
    # NB: this is only populated when copy_data is True; it is not directly
    # serializable in JSON, so you want to do something special here anyway
    data: Optional[torch.UntypedStorage]

    def as_json(self, describer_id):
        return {
            "id": self.id,
            "describer_id": describer_id,
            "size": self.size if isinstance(self.size, int) else repr(self.size),
        }
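
    # For example (a sketch with hypothetical values):
    #
    #   MetaStorageDesc(id=0, size=64, data=None).as_json(describer_id=3)
    #   # -> {"id": 0, "describer_id": 3, "size": 64}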


@dataclass(frozen=True)
class MetaTensorDesc:
    id: MetaTensorId
    ndim: int
    dtype: torch.dtype
    device: torch.device

    # NB: Sometimes, size, stride and storage_offset contain SymInt, in which
    # case this is NOT serializable.  That only happens when you're
    # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we
    # can get rid of this use case entirely.  Notably, even if we are
    # fakeifying a real tensor into a fake tensor with symbolic shapes, the
    # size here is NOT dynamic
    # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic
    # goes through this codepath.  But it really should not LOL.
    # NB: size could potentially be None as you can override it and make it
    # throw an error, but we don't currently have any subclasses that do this
    # except C++ nested tensors, but we're going to have nested ints to make
    # this defined on NJT
    size: Tuple[int, ...]
    dynamo_dynamic_indices: List[int]

    layout: torch.layout = torch.strided
    is_inference: bool = False
    is_leaf: bool = False
    requires_grad: bool = False
    is_sparse: bool = False
    is_mkldnn: bool = False
    is_functorch_wrapped: bool = False
    is_batchedtensor: bool = False
    is_legacy_batchedtensor: bool = False
    is_gradtrackingtensor: bool = False
    is_view: bool = False
    is_nested: bool = False
    # We eagerly symbolicize the associated nested int for e.g. offsets / lengths
    # metadata if the offsets tensor is already associated with a nested int.
    # See test_construct_from_jagged_with_input_offsets_mixed_case.
    nested_int: Optional[int] = None
    is_traceable_wrapper_subclass: bool = False
    is_functional: bool = False
    is_conj: bool = False
    is_neg: bool = False
    is_parameter: bool = False
    stride: Optional[Tuple[int, ...]] = None
    storage_offset: int = 0
    # NB: We have a choice whether to store the id or a direct pointer
    # to the data structure.  For ease of use, we store the data structure,
    # but this means that when we serialize, we have to swizzle these pointers
    # back into ids (so we have accurate aliasing relationships)
    storage: Optional[MetaStorageDesc] = None
    sparse_dim: Optional[int] = None  # is_sparse, is_sparse_compressed
    dense_dim: Optional[int] = None  # is_sparse, is_sparse_compressed
    is_coalesced: Optional[bool] = None  # is_sparse
    crow_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    col_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    ccol_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    row_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    values: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    unwrapped: Optional[MetaTensorDesc] = None  # is_functorch_wrapped
    bdim: Optional[int] = None  # is_functorch_wrapped
    base: Optional[MetaTensorDesc] = None  # is_view
    attrs: Optional[Dict[str, MetaTensorDesc]] = None  # is_traceable_wrapper_subclass
    creation_meta: Optional[CreationMeta] = None
    grad: Optional[MetaTensorDesc] = None

    # Everything below is NOT serializable; it needs some more work

    _UNSERIALIZABLE: ClassVar[List[str]] = [
        "ctx",
        "type",
        "fake_mode",
        "view_func",
        "level",
        "current_level",
        "functorch_stack",
        "autograd_meta_from",
        "data",
        "nested_int",
    ]

    ctx: Optional[object] = None  # is_traceable_wrapper_subclass
    type: Optional[Type] = None  # is_traceable_wrapper_subclass
    fake_mode: Optional[FakeTensorMode] = None
    view_func: Optional[
        Callable[
            [
                torch.Tensor,
                Callable[[int], int],
                Callable[[torch.Tensor], torch.Tensor],
            ],
            torch.Tensor,
        ]
    ] = None
    # level looks serializable, but actually it is meaningless without
    # the functorch_stack below
    level: Optional[int] = None  # is_functorch_wrapped
    current_level: Optional[int] = None
    functorch_stack: Optional[List[CInterpreter]] = None
    autograd_meta_from: Optional[torch.Tensor] = None

    # This is only populated on copy_data, and typically is not used at all,
    # except for some of our meta-ification paths that don't properly use
    # storage (pro-tip: you should use storage)
    data: Optional[torch.Tensor] = None

    # Faithfully serializing functorch tensors will not be too difficult.
    # We only need to consider grad/vmap interpreters, and their internal
    # state is only bools (mostly what the grad enabled/disabled state
    # should be in the lower layer).  Beyond that, tensors just need to
    # precisely indicate which particular interpreter they correspond
    # to (we then replace level with a pointer to the interpreter stack.)
    # However, this use of functorch is very "non-lexical" so it's not
    # entirely clear how to make it all lexical again, so we haven't done
    # it for now.

    # NB: This will reference numeric IDs, and it is assumed that you've
    # already serialized everything this recursively references
    def as_json(self, describer_id):
        def json(k, v):
            # Some best-effort debugging serialization for unserializable
            # fields (feel free to add other special cases as appropriate)
            if k in ["data", "autograd_meta_from"]:
                return None  # never repr these
            if k in set(MetaTensorDesc._UNSERIALIZABLE):
                return repr(v)
            if isinstance(v, (torch.device, torch.dtype, torch.layout)):
                return repr(v)
            if isinstance(v, torch.SymInt):
                return repr(v)
            if isinstance(v, (tuple, list)):
                return [json(k, v1) for v1 in v]
            if isinstance(v, (MetaStorageDesc, MetaTensorDesc)):
                return v.id
            if isinstance(v, CreationMeta):
                return str(v)
            if k == "attrs" and isinstance(v, dict):
                return {k1: v1.id for k1, v1 in v.items()}
            return v

        r = {
            field.name: json(field.name, getattr(self, field.name))
            for field in dataclasses.fields(self)
            if not (
                getattr(self, field.name) is field.default
                or (
                    field.name == "dynamo_dynamic_indices"
                    and not getattr(self, field.name)
                )
            )
        }
        r.update({"describer_id": describer_id})
        return r

    @property
    def shape(self):
        return self.size

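# Sketch of what MetaTensorDesc.as_json() emits (hypothetical values): nested
# MetaTensorDesc/MetaStorageDesc fields are collapsed to their numeric ids,
# fields still at their dataclass defaults are omitted, and the describer id
# is appended, e.g. for a 2x3 strided CPU float tensor:
#
#   {"id": 0, "ndim": 2, "dtype": "torch.float32",
#    "device": "device(type='cpu')", "size": [2, 3], "stride": [3, 1],
#    "storage": 0, "is_leaf": True, ..., "describer_id": 3}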

# A more faithful reproduction would do a copy on the entire
# storage, but this needs to be done carefully because the
# underlying storage could have larger extent than is implied
# by size/stride.  The real fix is to properly call
# meta_storage recursively here.
#
# These "safe" functions are intended to be used under no_dispatch() mode.
# The no_dispatch() here is intended to prevent ambient fake tensor mode from
# fakeifying the operation.  But if we are given an honest to goodness
# FakeTensor as src, we MUST NOT run the copy/clone operation.  A better way
# to do this would be to not use no_dispatch and instead just disable fake
# tensor mode only (allowing for subclass dispatch to occur)
def _safe_copy(dst, src):
    if type(src) is not torch.Tensor:
        return
    dst.copy_(src)


def _safe_clone(src):
    if type(src) is not torch.Tensor:
        return None
    return src.clone()


# This is a class for converting multiple tensors into meta tensors which
# share the same view/storage structure.  The operational model is that you
# allocate one of these, and then call it repeatedly on all the tensors you
# want to convert.  It's important to use the same object for tensors you
# want to share storage because this is how we correlate shared storages to
# the same meta storages.  This class will hold weak references to cached
# tensors and tensor storages.
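#
# A minimal usage sketch (hypothetical; assumes real tensors `a` and `b` that
# share storage, e.g. b = a.view(-1)):
#
#   converter = MetaConverter()
#   meta_a = converter(a)
#   meta_b = converter(b)
#   # Because the same converter saw both tensors, their meta versions are
#   # correlated to the same meta storage, mirroring the aliasing of the inputs.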
630class MetaConverter:
631    def __init__(self, *, copy_data: bool = False):
632        # Maps MetaStorageId to UntypedStorage
633        self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
634        # Maps MetaTensorId to torch.Tensor (typically a meta tensor or
635        # FakeTensor)
636        self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
637        self.hit = 0
638        self.miss = 0
639        self.del_hook = None
640        self.arg_cnt = 0
641        # Ensures real_storage/real_tensor are populated on the resulting
642        # metaified storage/tensor.  The naming of this attribute is load
643        # bearing: FakeTensor relies on real tensor being set to exactly this
644        # value
645        self.copy_data = copy_data
646        self.describer = MetaTensorDescriber(copy_data=copy_data)
647
648    def successful(self):
649        return self.hit > 0 and self.miss == 0
650
651    def get_tensor_memo(self, t: MetaTensorDesc):
652        return self.tensor_memo.get(t.id, None)
653
654    def set_tensor_memo(self, t: MetaTensorDesc, v):
655        self.tensor_memo[t.id] = v
656
657    def get_storage_memo(self, s: MetaStorageDesc):
658        return self.storage_memo.get(s.id, None)
659
660    def set_storage_memo(self, s: MetaStorageDesc, v):
661        self.storage_memo[s.id] = v
662
663    def meta_storage(self, s: MetaStorageDesc, callback):
664        # If we are fakeifying a tensor that has a secretly-zero-sized storage,
665        # Need to make sure to resize the meta storage too.
666        if self.get_storage_memo(s) is None:
667            r_s = callback(
668                lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
669            ).untyped_storage()
670            if self.copy_data:
671                # NB: no_dispatch is needed because internally storage copy is
672                # implemented as Tensor operations
673                with torch.no_grad(), no_dispatch():
674                    assert s.data is not None
675                    r_s.real_storage = s.data.clone()
676            self.set_storage_memo(s, r_s)
677            return r_s
678        else:
679            return self.get_storage_memo(s)
680
681    # This function assumes that it's possible to do the conversion
682    # NB: name here is used in a conventional way by Dynamo; it corresponds
683    # precisely to the Source.name() of the tensor we're fakeifying and
684    # corresponds to a valid Python expression.  When we construct sub-names
685    # as part of this process, we will maintain this invariant!  (Even though
686    # other users of this may not need it this property to be upheld.)
687    def meta_tensor(
688        self,
689        t: MetaTensorDesc,
690        shape_env: Optional[ShapeEnv] = None,
691        callback=lambda t: t(),
692        source: Optional[Source] = None,
693        symbolic_context: Optional[SymbolicContext] = None,
694    ):
695        if source is None:
696            from torch._dynamo.source import ConstantSource
697
698            # TODO: make a dedicated UnknownSource for this?
699            source = ConstantSource(
700                f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
701            )
702
703        # This indicates you set no_dispatch() before calling into this
704        # function.  This is an error: we may be creating fake tensors and
705        # will perform operations on them which need fake tensor mode to
706        # be active.  You will segfault if you are in a no_dispatch() block.
707        assert not torch._C._dispatch_tls_local_exclude_set().has(
708            torch._C.DispatchKey.Python
709        )
710        arg_cnt = self.arg_cnt
711        self.arg_cnt += 1
712
713        # When we make as_strided calls, we end up generating a guard
714        # that the new as_strided tensor is in bounds for the old storage
715        # for the base (since as_strided calls can "bust" out of their
716        # bounding box.)  This guard is unnecessary: if a user is able
717        # to provide us a tensor with the view base setup this way, we
718        # don't need to produce a guard, because the fact that they
719        # were able to produce the view base means its in bounds.
720        #
721        # Now, ordinarily, this guard would be harmless.  However, the
722        # generated guard refers to variables bound on the base variable.
723        # At the moment, Dynamo doesn't actually guard on x._base, because
724        # according to Voz this results in a lot of spurious invalidations,
725        # and also if the user doesn't directly make use of _base, its
726        # pointless anyway (because programs should be parametric over
727        # whether or not the input tensor is a view or not--unless you're
728        # mutating the input, but that's a whole 'nother ballgame).  So
729        # for expediency, we suppress these guards so we don't have to
730        # deal with this (yet, anyway.)
731        #
732        # NB: An old version of this code suppressed guards for ALL operations
733        # happening during meta conversion, not just as_strided calls.
734        # This is too aggressive: we do duck sizing and 0/1 simplification
735        # as we allocate variables, and we do need to register guards for
736        # these cases.
737        maybe_suppress: Callable[[], Any] = contextlib.nullcontext
738        if shape_env is not None:
739            maybe_suppress = shape_env.suppress_guards
740
741        def sym_sizes_strides_storage_offset(
742            t: MetaTensorDesc, src, symbolic_context=symbolic_context
743        ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
744            assert t.stride is not None
745            if shape_env is not None:
746                fake_mode = t.fake_mode
747                if fake_mode is not None and fake_mode.shape_env is shape_env:
748                    # Don't reallocate the sizes; the shape envs are the same,
749                    # so reuse the old sizes/strides/etc
750                    return (t.size, t.stride, t.storage_offset)
751                else:
752                    # TODO: deduplicate this
753                    t_size = tuple(
754                        shape_env._maybe_specialize_sym_int_with_hint(sz)
755                        for sz in t.size
756                    )
757                    t_stride = tuple(
758                        shape_env._maybe_specialize_sym_int_with_hint(sd)
759                        for sd in t.stride
760                    )
761                    t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint(
762                        t.storage_offset
763                    )
764                    return shape_env._create_symbolic_sizes_strides_storage_offset(
765                        t_size,
766                        t_stride,
767                        t_storage_offset,
768                        [d in t.dynamo_dynamic_indices for d in range(t.ndim)],
769                        src,
770                        symbolic_context=symbolic_context,
771                    )
772            else:
773                return (t.size, t.stride, t.storage_offset)
774
775        def empty_create(
776            inner_t: MetaTensorDesc, inner_src, symbolic_context=symbolic_context
777        ):
778            (
779                inner_sizes,
780                inner_strides,
781                inner_storage_offset,
782            ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
783            return torch.empty_strided(
784                inner_sizes,
785                inner_strides,
786                dtype=inner_t.dtype,
787                device="meta",
788            )
789
790        # Creates a subclass instance with empty inner tensors according to the specified
791        # symbolic context.
792        def empty_create_subclass(
793            t: MetaTensorDesc,
794            outer_size,
795            outer_stride,
796            symbolic_context=symbolic_context,
797            callback=callback,
798            source=source,
799        ):
800            from torch._dynamo.source import AttrSource
801            from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext
802
803            assert t.attrs is not None
804            assert t.type is not None
805            # NB: t.ctx could be None if the subclass in question has no
806            # meaningful context
807
808            # Note: transform_subclass will use __tensor_unflatten__ to generate
809            # a fresh subclass wrapper with outer sizes / strides according to the
810            # outer symbolic context (passed in to this function). Inner size / stride
811            # / storage offset symbols are allocated according to the appropriate inner
812            # symbolic contexts, after which the checks in transform_subclass() will
813            # relate them to the outer metadata as possible.
814            #
815            # Morally, the code here is same as transform_subclass, but we've
816            # written it from scratch to read EmptyCreateSubclass
817            outer_size = outer_size if outer_size is not None else t.size
818            outer_stride = outer_stride if outer_stride is not None else t.stride
819
820            assert symbolic_context is None or isinstance(
821                symbolic_context, SubclassSymbolicContext
822            )
823
824            def _empty_create_subclass(
825                t, outer_size, outer_stride, symbolic_context, callback, source
826            ):
827                # We are hitting plain meta_desc tensor so actually
828                # create a tensor here.
829                if t.attrs is None:
830                    return self.meta_tensor(
831                        t,
832                        shape_env=shape_env,
833                        callback=callback,
834                        source=source,
835                        symbolic_context=symbolic_context,
836                    )
837
838                inner_tensors = {}
839                for attr, meta_tensor_desc in t.attrs.items():
840                    current_context = None
841                    if symbolic_context is not None:
842                        current_context = symbolic_context.inner_contexts[attr]
843
844                    current_source = AttrSource(source, attr)
845                    new_empty_tensor = _empty_create_subclass(
846                        meta_tensor_desc,
847                        meta_tensor_desc.size,
848                        meta_tensor_desc.stride,
849                        current_context,
850                        callback,
851                        current_source,
852                    )
853                    inner_tensors[attr] = new_empty_tensor
854
855                return t.type.__tensor_unflatten__(
856                    inner_tensors, t.ctx, outer_size, outer_stride
857                )
858
859            sub = _empty_create_subclass(
860                t, outer_size, outer_stride, symbolic_context, callback, source
861            )
862
863            # NB: Purposefully guard here to simplify the inner / outer symbols.
864            # Using sym_eq() for symbolic comparison can result in an expression that's too
865            # difficult to guard on, so we use == here.
866            assert sub.shape == outer_size, (
867                f"Expected return value from {t.type}__tensor_unflatten__() to have "
868                f"shape equal to {outer_size}, but got: {sub.shape}"
869            )
870            assert sub.stride() == outer_stride, (
871                f"Expected return value from {t.type}__tensor_unflatten__() to have "
872                f"stride equal to {outer_stride}, but got: {sub.stride()}"
873            )
874
875            return sub
876
877        # Returns an all-dynamic symbolic context used for metafying the given tensor with
878        # fully dynamic dims. This is useful when fake-ifying intermediate tensors in
879        # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
880        # don't want to over-specialize during view replay.
881        def all_dynamic_symbolic_context(
882            t: MetaTensorDesc, source, shape_env, callback
883        ):
884            from torch._dynamo.source import AttrSource
885            from torch.fx.experimental.symbolic_shapes import (
886                DimDynamic,
887                StatelessSymbolicContext,
888                SubclassSymbolicContext,
889            )
890
891            view_base_context: Optional[SymbolicContext] = None
892            if t.is_view:
893                assert t.base is not None
894                view_base_context = all_dynamic_symbolic_context(
895                    t.base, AttrSource(source, "_base"), shape_env, callback
896                )
897
898            t_symbolic_context: SymbolicContext
899            t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
900            if t.is_traceable_wrapper_subclass:
901                assert t.attrs is not None
902                inner_contexts: Dict[str, SymbolicContext] = {}
903                for attr, inner in t.attrs.items():
904                    assert isinstance(attr, str)
905                    inner_contexts[attr] = all_dynamic_symbolic_context(
906                        inner, AttrSource(source, attr), shape_env, callback
907                    )
908                t_symbolic_context = SubclassSymbolicContext(
909                    dynamic_sizes=t_dynamic_sizes,
910                    constraint_sizes=[None] * t.ndim,
911                    inner_contexts=inner_contexts,  # type: ignore[arg-type]
912                    tensor_source=source,
913                    view_base_context=view_base_context,
914                )
915            else:
916                t_symbolic_context = StatelessSymbolicContext(
917                    dynamic_sizes=t_dynamic_sizes,
918                    constraint_sizes=[None] * t.ndim,
919                    view_base_context=view_base_context,
920                )
921
922            return t_symbolic_context
923
924        # Returns a fake-ified version of an input view tensor t, given an already fake-ified
925        # base. At a high level, we want two things:
926        #   1. fake_t should have the same view relationship to the given fake base as the
927        #      input t has to its _base.
928        #   2. fake_t should have symbolic sizes / strides / storage offset according to the
929        #      appropriate symbolic context (i.e. from the automatic dynamic algorithm).
930        #
931        # We currently take different strategies across view types:
932        #   * For dense -> dense views, accomplish both (1) and (2) simultaneously via an
933        #     as_strided() call on the fake-ified base, passing symbolic metadata.
934        #   * For views involving subclasses, perform view replay using view funcs to
935        #     achieve (1). It's necessary for (2) to swap out any closed-over state in
936        #     the view funcs with symbolicized SymInts and fake-ified tensors. Doing this
937        #     avoids specialization (and thus over-eager simplification of symbols) that
938        #     could occur during view replay on the fake-ified base.
939        #
940        # Examples:
941        #   * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled
942        #     with an as_strided() call on the fake base passing symbolic metadata.
943        #   * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg
944        #     is made symbolic to avoid invalid specialization and view replay is then
945        #     done to reconstruct the view.
946        #   * _nested_from_jagged(values, offsets) is a dense -> subclass view
947        #     that returns a subclass instance from a dense values tensor. The offsets
948        #     tensor is closed over in the view func, as it can be considered view metadata.
949        #     First, the offsets tensor is fake-ified according to the inner symbolic
950        #     context and with the correct relationship to the outer size / stride metadata.
951        #     Then view replay is done, swapping in the fake offsets so the view replay output
952        #     is fully fake with no invalid specialization.
953        def view_from_base(
954            base: torch.Tensor, t: MetaTensorDesc, source=source, shape_env=shape_env
955        ):
956            # fake-ify t's metadata according to the outer symbolic context
957            (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
958                t, source
959            )
960            if (
961                not t.is_traceable_wrapper_subclass
962                and not is_traceable_wrapper_subclass(base)
963            ):
964                # Dense -> Dense view case uses as_strided() to construct view relationship.
965                # TODO: Change this logic to use view replay for consistency?
966                # It's likely there is no view func available.
967                with maybe_suppress():
968                    return base.as_strided(sizes, strides, storage_offset)
969
970            from torch._dynamo.source import EphemeralSource
971            from torch.fx.experimental.symbolic_shapes import (
972                StatelessSymbolicContext,
973                sym_eq,
974            )
975
976            def symint_visitor_fn(s):
977                nonlocal symbolic_context
978                from torch.fx.experimental.symbolic_shapes import DimDynamic
979
980                all_static_sizes = (
981                    symbolic_context is not None
982                    and isinstance(symbolic_context, StatelessSymbolicContext)
983                    and all(
984                        x is DimDynamic.STATIC for x in symbolic_context.dynamic_sizes
985                    )
986                )
987                # Can't just rely on shape env being None - dynamo always initializes it
988                if all_static_sizes or shape_env is None:
989                    return s
990
991                # NB: The symbol here is expected to be simplified out because we a priori
992                # allocate inner and outer symbols according to the appropriate symbolic
993                # contexts and prefer those over this symbol during symbol simplification
994                # (via usage of EphemeralSource below). This -shouldn't- happen, but if
995                # this symbol somehow leaks out beyond the view tensor's shape metadata, our
996                # assumption of it being simplified out will fail and it may be guarded on,
997                # which will hard error.
998                sym_source = EphemeralSource("symint_visitor_fn")
999
1000                symbol = shape_env.create_symbol(s, sym_source, positive=None)
1001                return shape_env.create_symintnode(symbol, hint=s, source=sym_source)
1002
1003            real_to_fake_mapping = {}
1004            if t.is_traceable_wrapper_subclass:
1005                assert t.attrs is not None
1006                # NB: t.ctx could be None if the subclass in question has no
1007                # meaningful context
1008                assert t.type is not None
1009
1010                # Fake-ify t naively here; this is only done so we can get fake-ified inner
1011                # tensors with the correct relationships to the outer sizes / strides for use
1012                # in view replay. It's done beforehand here because it's not easy to do when
1013                # visiting tensors one-by-one during view replay.
1014                #
1015                # Example:
1016                #   Consider a Dense -> NJT view. NJT has (values, offsets) components and we
1017                #   want a view of values with the offsets closed over. As the offsets component
1018                #   is needed to describe the output view, it's important that it's fakeified
1019                #   correctly.
1020                fake_t = empty_create_subclass(
1021                    t, outer_size=sizes, outer_stride=strides
1022                )
1023                attrs, _ = fake_t.__tensor_flatten__()
1024                for attr in attrs:
1025                    real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr)
1026
1027            def tensor_visitor_fn(
1028                visited_t: torch.Tensor,
1029                # These arguments are never passed, we just use them to close
1030                # over these relevant values
1031                shape_env=shape_env,
1032                callback=callback,
1033            ):
1034                # It's possible to close over an undefined tensor (e.g. NJT's lengths).
1035                if visited_t is None:
1036                    return None
1037
1038                # NB: visited_t being a Tensor here is very naughty!  Should
1039                # have already been described
1040
1041                # Fake inner tensors of view subclasses will come from the mapping built above.
1042                visited_id = self.describer.get_tensor_id(visited_t)
1043                fake_visited_t = real_to_fake_mapping.get(visited_id, None)
1044                if fake_visited_t is not None:
1045                    return fake_visited_t
1046
1047                visited_desc = self.describer.describe_tensor(visited_t)
1048
1049                # For other closed-over tensor state, fake-ify it as all dynamic with an
1050                # ephemeral source. This avoids invalid specialization during view replay.
1051                # If we find that in practice the usage of ephemeral sources isn't enough
1052                # to guarantee that we don't have guards on these symbols, we may need to
1053                # explicitly suppress guards (as is done for _base in the dense -> dense
1054                # view case).
1055                temp_source = EphemeralSource("tensor_visitor_fn")
1056                return self.meta_tensor(
1057                    visited_desc,
1058                    shape_env,
1059                    callback,
1060                    source=temp_source,
1061                    symbolic_context=all_dynamic_symbolic_context(
1062                        visited_desc, temp_source, shape_env, callback
1063                    ),
1064                )
1065
1066            # Replay the view, swapping out any non-symbolic SymInts or real tensors
1067            # for symbolic SymInts or fake tensors.
1068            assert t.view_func is not None
1069            # NB: we do NOT suppress guards here, we need to remove ephemeral
1070            # sources
1071            fake_t = t.view_func(base, symint_visitor_fn, tensor_visitor_fn)
1072
1073            # Ensure the output has symbolic shapes according to the outer symbolic context.
1074            # These checks should simplify out any symbols created for closed-over view func
1075            # SymInts.
1076            torch._check(sym_eq(fake_t.size(), sizes))
1077            torch._check(sym_eq(fake_t.stride(), strides))
1078            torch._check(sym_eq(fake_t.storage_offset(), storage_offset))
1079            return fake_t
1080
1081        if self.get_tensor_memo(t) is None:
1082            GRAD_TENSOR_SENTINEL_VALUE = -2
1083
1084            with torch.inference_mode(t.is_inference):
1085                if t.is_sparse:
1086                    is_leaf = t.is_leaf
1087
1088                    # The lambda function below is similar to
1089                    # `t.to(device='meta')` except the latter
1090                    # preserves nnz value
1091                    r = callback(
1092                        lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
1093                            t.sparse_dim,
1094                            t.dense_dim,
1095                            t.size,
1096                            dtype=t.dtype,
1097                            layout=torch.sparse_coo,
1098                            device="meta",
1099                        )
1100                    )
1101                    if self.copy_data:
1102                        # Pray that sparse clone doesn't lose information
1103                        assert t.data is not None
1104                        with torch.no_grad(), no_dispatch():
1105                            r.real_tensor = _safe_clone(t.data)
1106                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
1107                    # Note [is_coalesced is dispatched]
1108                    # Strangely enough, is_coalesced() is a dispatched operator,
1109                    # which means that it will get caught by fake tensor mode.
1110                    # Ordinarily this would error, but there's some logic in
1111                    # fake tensor ensure this doesn't happen.
1112                    r._coalesced_(t.is_coalesced)
1113                    if t.requires_grad:
1114                        r.requires_grad = True
1115                    if t.requires_grad and not is_leaf:
1116                        # This should probably use DelayedError,
1117                        # but clone is fine for now for sparse tensors.
1118                        # (DelayedError does not work for sparse because it causes
1119                        # the Fake sparse tensor to "lose" its fakeness)
1120                        r = r.clone()
1121                        with torch.enable_grad():
1122                            r._coalesced_(t.is_coalesced)
1123                elif is_sparse_compressed_layout(t.layout):
1124                    is_leaf = t.is_leaf
1125
1126                    if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
1127                        assert t.sparse_dim is not None
1128                        assert t.dense_dim is not None
1129                        assert t.values is not None
1130                        batch_dim = t.ndim - t.sparse_dim - t.dense_dim
1131                        blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
1132                    else:
1133                        blocksize = ()
1134                    if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
1135                        assert t.crow_indices is not None
1136                        index_dtype = t.crow_indices.dtype
1137                    else:
1138                        assert t.ccol_indices is not None
1139                        index_dtype = t.ccol_indices.dtype
1140
1141                    r = callback(
1142                        lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
1143                            0,
1144                            t.dense_dim,
1145                            t.shape,
1146                            blocksize,
1147                            index_dtype,
1148                            layout=t.layout,
1149                            dtype=t.dtype,
1150                            device="meta",
1151                        )
1152                    )
1153                    if self.copy_data:
1154                        # Pray sparse clone doesn't lose information
1155                        assert t.data is not None
1156                        with torch.no_grad(), no_dispatch():
1157                            r.real_tensor = _safe_clone(t.data)
1158                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
1159                    if t.requires_grad:
1160                        r.requires_grad = True
1161                    if t.requires_grad and not is_leaf:
1162                        r = torch._C._functions.DelayedError(
1163                            "Internal error: Tried to backward() through example input",
1164                            1,
1165                        )(r)
1166                elif t.is_nested and not t.is_traceable_wrapper_subclass:
1167                    # TODO: Handle this better in Dynamo?
1168                    # There are checks there now, but this can still be triggered by a dense
1169                    # tensor graph input that is a view of a strided NT.
1170                    from torch._dynamo.exc import unimplemented
1171
1172                    unimplemented(
1173                        "strided nested tensors are not supported by meta conversion"
1174                    )
1175                elif t.is_mkldnn:
1176                    is_leaf = t.is_leaf
1177                    sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
1178                        t, source
1179                    )
1180                    # TODO: This doesn't seem right, where's the MKLDNN'ness
1181                    # lol
1182                    r = callback(
1183                        lambda: torch.empty_strided(
1184                            sizes, strides, dtype=t.dtype, device="meta"
1185                        )
1186                    )
1187                    if self.copy_data:
1188                        with torch.no_grad(), no_dispatch():
1189                            assert t.size is not None
1190                            assert t.stride is not None
1191                            r.real_tensor = torch.empty_strided(
1192                                t.size, t.stride, dtype=t.dtype, device=t.device
1193                            )
1194                            assert t.data is not None
1195                            _safe_copy(r.real_tensor, t.data)
1196                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
1197                    if t.requires_grad:
1198                        r.requires_grad = True
1199                    if t.requires_grad and not is_leaf:
1200                        r = torch._C._functions.DelayedError(
1201                            "Internal error: Tried to backward() through example input",
1202                            1,
1203                        )(r)
1204                elif t.is_functorch_wrapped:
1205                    if t.is_view:
1206                        from torch._dynamo.exc import unimplemented
1207
1208                        unimplemented(
1209                            "view functorch tensors are not supported by meta conversion"
1210                        )
1211
1212                    # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
1213                    # in a FakeTensor
1214                    def _to_fake_tensor(t: MetaTensorDesc):
1215                        # TODO: why aren't the recursive calls going to
1216                        # meta_tensor
1217                        if t.is_batchedtensor:
1218                            assert t.unwrapped is not None
1219                            assert t.level is not None
1220                            assert t.bdim is not None
1221                            ft = _to_fake_tensor(t.unwrapped)
1222                            lvl = t.level
1223                            bdim = t.bdim
1224                            # You cannot create functorch tensors without
1225                            # having the ambient functorch interpreter stack
1226                            # available, as the level refers to things in the
1227                            # stack
1228                            with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
1229                                t.functorch_stack
1230                            ):
1231                                r = _add_batch_dim(ft, bdim, lvl)
1232                        elif t.is_gradtrackingtensor:
1233                            assert t.unwrapped is not None
1234                            assert t.level is not None
1235                            disable_functorch = torch._C._DisableFuncTorch
1236                            with disable_functorch():
1237                                ft = _to_fake_tensor(t.unwrapped)
1238                            lvl = t.level
1239                            if lvl == GRAD_TENSOR_SENTINEL_VALUE:
1240                                r = ft
1241                            else:
1242                                with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
1243                                    t.functorch_stack
1244                                ):
1245                                    r = torch._C._functorch._wrap_for_grad(ft, lvl)
1246
1247                            is_leaf = t.is_leaf
1248                            if t.requires_grad and safe_is_leaf(r):
1249                                r.requires_grad = True
1250                            elif t.requires_grad and not is_leaf:
1251                                r = torch._C._functions.DelayedError(  # type: ignore[assignment]
1252                                    "Internal error: Tried to backward() through example input",
1253                                    1,
1254                                )(
1255                                    r  # type: ignore[arg-type]
1256                                )
1257                        elif t.is_functional:
1258                            assert t.unwrapped is not None
1259                            assert t.current_level is not None
1260                            ft = self.meta_tensor(
1261                                t.unwrapped,
1262                                shape_env=shape_env,
1263                                callback=callback,
1264                                # NB: reuse these exactly; we treat the
1265                                # functional tensor as "invisible".
1266                                # TODO: Actually this all probably doesn't
1267                                # work, take a closer look.
1268                                source=source,
1269                                symbolic_context=symbolic_context,
1270                            )
1271                            r = _wrap_functional_tensor(ft, t.current_level)
1272                            # TODO: is_leaf/requires_grad?
1273                        else:
1274                            assert t.stride is not None
1275
1276                            sizes = t.size
1277                            strides = t.stride
1278                            r = callback(
1279                                lambda: torch.empty_strided(
1280                                    sizes,
1281                                    strides,
1282                                    dtype=t.dtype,
1283                                    device="meta",
1284                                )
1285                            )
1286                            if self.copy_data:
1287                                with torch.no_grad(), no_dispatch():
1288                                    r.real_tensor = torch.empty_strided(  # type: ignore[attr-defined]
1289                                        t.size,
1290                                        t.stride,
1291                                        dtype=t.dtype,
1292                                        device=t.device,
1293                                    )
1294                                    assert t.data is not None
1295                                    _safe_copy(r.real_tensor, t.data)  # type: ignore[attr-defined]
1296                        return r
1297
1298                    r = _to_fake_tensor(t)
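                    # For context, a sketch of how a functorch-wrapped input arises
                    # (illustrative only, not executed here):
                    #
                    #   def f(x):
                    #       # inside vmap, x is a BatchedTensor at the current level
                    #       assert torch._C._functorch.is_batchedtensor(x)
                    #       return x.sum()
                    #
                    #   torch.func.vmap(f)(torch.randn(4, 3))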
1299
1300                elif t.is_functional and t.device.type not in ["xla", "lazy"]:
1301                    assert t.unwrapped is not None
1302                    assert not t.is_functorch_wrapped  # handled above
1303                    unwrapped = self.meta_tensor(
1304                        t.unwrapped,
1305                        shape_env=shape_env,
1306                        callback=callback,
1307                        source=source,
1308                        symbolic_context=symbolic_context,
1309                    )
1310                    r = torch._to_functional_tensor(unwrapped)
1311                    torch._mirror_autograd_meta_to(t.autograd_meta_from, r)  # type: ignore[attr-defined]
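                    # Minimal sketch of a C++ functional tensor, using the same API
                    # as above (illustrative only, not executed here):
                    #
                    #   ft = torch._to_functional_tensor(torch.randn(3))
                    #   assert torch._is_functional_tensor(ft)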
1312
1313                elif t.is_view:
1314                    # Construct views in two steps: recursively meta-fy their
1315                    # base, and then create view(s) off that.  NB: doing it
1316                    # directly from storage is WRONG because this won't cause
1317                    # version counters to get shared.
1318
1319                    assert t.base is not None
1320
1321                    base_symbolic_context = None
1322                    if shape_env and symbolic_context is not None:
1323                        from torch.fx.experimental.symbolic_shapes import (
1324                            StatelessSymbolicContext,
1325                        )
1326
1327                        assert isinstance(symbolic_context, StatelessSymbolicContext)
1328                        # NB: This should generally be set when the input is a view,
1329                        # but the exception right now is for fake-ifying grads, which is
1330                        # a work in progress.
1331                        if symbolic_context.view_base_context is not None:
1332                            base_symbolic_context = symbolic_context.view_base_context
1333
1334                    base = self.meta_tensor(
1335                        t.base,
1336                        shape_env,
1337                        callback,
1338                        source=torch._dynamo.source.AttrSource(source, "_base"),
1339                        symbolic_context=base_symbolic_context,
1340                    )
1341
1342                    def is_c_of_r(complex_dtype, real_dtype):
1343                        return (
1344                            utils.is_complex_dtype(complex_dtype)
1345                            and utils.corresponding_real_dtype(complex_dtype)
1346                            == real_dtype
1347                        )
1348
1349                    # In some situations, MetaConverter may be called in a
1350                    # context where autograd is disabled.  For the _is_view
1351                    # assert to pass, we have to set up the autograd view
1352                    # metadata anyway.  Do this by reenabling the
1353                    # ADInplaceOrView key.  This is kind of a hack.
1354                    old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
1355                        torch._C.DispatchKey.ADInplaceOrView
1356                    )
1357                    torch._C._dispatch_tls_set_dispatch_key_excluded(
1358                        torch._C.DispatchKey.ADInplaceOrView, False
1359                    )
1360                    try:
1361                        if base.dtype == t.dtype:
1362                            pass
1363                        elif is_c_of_r(base.dtype, t.dtype):
1364                            base = torch.view_as_real(base)
1365                        elif is_c_of_r(t.dtype, base.dtype):
1366                            base = torch.view_as_complex(base)
1367                        else:
1368                            # This is not guaranteed to succeed.  If it fails, it
1369                            # means there is another dtype-converting view function
1370                            # that hasn't been handled here.
1371                            base = base.view(t.dtype)
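                        # Concrete dtype relationships handled above (sketch, not
                        # executed here):
                        #
                        #   c = torch.zeros(2, dtype=torch.complex64)
                        #   torch.view_as_real(c).dtype     # torch.float32
                        #   f = torch.zeros(2, 2, dtype=torch.float32)
                        #   torch.view_as_complex(f).dtype  # torch.complex64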
1372
1373                        # This is very tricky.  Naively, you might expect this
1374                        # to hold:
1375                        #
1376                        #   if t.requires_grad and not safe_is_leaf(t)
1377                        #       assert t._base.requires_grad
1378                        #
1379                        # But it's not true!  As you can see in the following
1380                        # program:
1381                        #
1382                        #   x = torch.zeros(4)
1383                        #   y = x.view(1, 4)
1384                        #   y.requires_grad = True
1385                        #   z = y.view(1, 1, 4)
1386                        #   assert z._base is x
1387                        #
1388                        # So we may have to do *two* views out of the base to
1389                        # recreate this situation.
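                        # (In the sketch above, z.requires_grad is True even though
                        # z._base is x, which does not require grad; that is exactly
                        # the case the naive assert would reject.)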
1390                        if t.is_leaf:
1391                            # Leaf views that track view metadata are produced by
1392                            # taking a view inside a no_grad block
1393                            with torch.no_grad():
1394                                r = view_from_base(base, t)
1395                            # As it's a leaf, we can directly assign requires_grad
1396                            r.requires_grad = t.requires_grad
1397                        else:
1398                            if t.base.requires_grad == t.requires_grad:
1399                                # Easy case, just run the view op
1400                                with torch.enable_grad():
1401                                    r = view_from_base(base, t)
1402
1403                                # NB: We don't actually faithfully replicate
1404                                # autograd connectivity, but that doesn't matter
1405                                # today. See the following for more info:
1406                                # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
1407                            else:
1408                                # Obscure case.  Create a leaf view and give it the
1409                                # correct requires_grad, then do the final view.
1410                                # NB: Can't have a non-leaf without requiring grad!
1411                                assert t.requires_grad
1412                                with torch.no_grad():
1413                                    mid = base.view(base.shape)
1414                                mid.requires_grad = t.requires_grad
1415                                with torch.enable_grad():
1416                                    r = view_from_base(mid, t)
1417                        # The CreationMeta influences whether or not inplace
1418                        # mutation is an error or not.  So we need to make
1419                        # sure we properly propagate this as well.
1420                        assert t.creation_meta is not None
1421                        torch._C._autograd._set_creation_meta(r, t.creation_meta)
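                        # For context, CreationMeta is what makes in-place mutation of
                        # certain views an error, e.g. (sketch, not executed here):
                        #
                        #   x = torch.randn(3, requires_grad=True)
                        #   with torch.no_grad():
                        #       v = x.view(3)
                        #   v.mul_(2)  # raises RuntimeError (view created in no_grad mode)
                        #
                        # so the fake view must carry the same flag to error
                        # consistently with the real tensor.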
1422                    finally:
1423                        torch._C._dispatch_tls_set_dispatch_key_excluded(
1424                            torch._C.DispatchKey.ADInplaceOrView, old_exclude
1425                        )
1426
1427                else:
1428                    is_leaf = t.is_leaf
1429
1430                    # Graph break for wrapped tensors
1431                    if (
1432                        not (t.is_batchedtensor or t.is_gradtrackingtensor)
1433                        and t.is_functorch_wrapped
1434                    ) or t.is_legacy_batchedtensor:
1435                        return NotImplemented
1436
1437                    (
1438                        sizes,
1439                        strides,
1440                        storage_offset,
1441                    ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
1442
1443                    # If we have a subclass that desugars into dense tensors,
1444                    # perform our callback on each inner tensor.
1445                    if t.is_traceable_wrapper_subclass:
1446                        r = empty_create_subclass(
1447                            t, outer_size=sizes, outer_stride=strides
1448                        )
1449                    else:
1450                        r = callback(
1451                            lambda: torch.empty_strided(
1452                                sizes,
1453                                strides,
1454                                dtype=t.dtype,
1455                                device="meta",
1456                            )
1457                        )
1458                        if self.copy_data:
1459                            with torch.no_grad(), no_dispatch():
1460                                assert t.size is not None
1461                                assert t.stride is not None
1462                                r.real_tensor = torch.empty_strided(
1463                                    t.size, t.stride, dtype=t.dtype, device=t.device
1464                                )
1465                                _safe_copy(r.real_tensor, t.data)
1466
1467                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
1468                    if t.requires_grad:
1469                        r.requires_grad = t.requires_grad
1470                        if not is_leaf:
1471                            # Fake up some autograd history.
1472                            # Note: we *used* to call .clone() here to mock up some autograd history.
1473                            # This is bad for subclasses.
1474                            # Consider the case where you have a wrapper subclass that is contiguous,
1475                            # but its inner tensor is non-contiguous.
1476                            # .clone() (or other ops) will have the side effect of changing
1477                            # the metadata of the inner tensor.
1478                            # So instead, we now have a dedicated fn to set autograd history,
1479                            # without inadvertently changing other metadata.
1480                            r = torch._C._functions.DelayedError(
1481                                "Internal error: Tried to backward() through example input",
1482                                1,
1483                            )(r)
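                            # What DelayedError gives us here (sketch, not executed):
                            # the output gains a grad_fn, so it is no longer a leaf
                            # (matching the original), and backward() through it
                            # raises the message above:
                            #
                            #   x = torch.randn(3, requires_grad=True)
                            #   y = torch._C._functions.DelayedError("boom", 1)(x)
                            #   assert y.grad_fn is not None and not y.is_leaf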
1484
1485                    s = t.storage
1486                    assert s is not None
1487                    if s.id not in self.storage_memo and (
1488                        r.is_nested
1489                        or (
1490                            r.stride() == strides
1491                            and r.storage_offset() == storage_offset
1492                        )
1493                    ):
1494                        # You're normal and happy; install the fresh storage into the memo
1495                        self.set_storage_memo(s, r.untyped_storage())
1496                        if self.copy_data:
1497                            r.untyped_storage().real_storage = (
1498                                r.real_tensor.untyped_storage()
1499                            )
1500                    else:
1501                        # You're in crazy town; somehow you gave us a tensor
1502                        # that wasn't a view, but had nonzero storage offset,
1503                        # nontrivial strides (such that clone() couldn't
1504                        # preserve them), or already aliases with another
1505                        # tensor's storage.  The most typical way to end
1506                        # up here is with set_.  So use set_ to bludgeon this
1507                        # in.
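                        # One way to land here (sketch, not executed): aliasing a
                        # "non-view" tensor onto existing storage with set_, e.g.
                        #
                        #   a = torch.randn(4)
                        #   b = torch.empty(0)
                        #   b.set_(a.untyped_storage(), 2, (2,), (1,))
                        #
                        # b is not a view of a, yet it shares a's storage and has a
                        # nonzero storage offset.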
1508                        r_s = self.meta_storage(s, callback=callback)
1509                        # NB: In principle, this should always work, but there
1510                        # is some subtle difference in the autograd metadata
1511                        # that means we will backprop the set_ call, even if
1512                        # r is declared as an input to grad.
1513                        # See https://github.com/pytorch/pytorch/issues/87956
1514                        # for the reproducer.
1515                        # NB: The in_kernel_invocation_manager here is necessary
1516                        # for fake tensor.  If we run the set_ call with fake
1517                        # tensor on, r will improperly report that it is NOT a
1518                        # meta tensor but a cpu tensor, and then the set_ call
1519                        # will fail due to device mismatch.  no_dispatch() is
1520                        # not enough, because the fake tensor will still claim
1521                        # to be a CPU tensor and you'll end up in the CPU
1522                        # kernel.  Arguably this is a hack; a cleaner way to
1523                        # solve this is to have a FakeStorage concept which
1524                        # would report its CPU device--no problem now!  But
1525                        # this is difficult to do because we don't have storage
1526                        # subclasses.  Relevant test is
1527                        # DynamicShapesFunctionTests::test_add_dynamic_shapes in
1528                        # test/dynamo/test_dynamic_shapes.py
1529                        maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
1530                        from torch._subclasses.fake_tensor import (
1531                            in_kernel_invocation_manager,
1532                            maybe_get_fake_mode,
1533                        )
1534
1535                        mb_fake_mode = maybe_get_fake_mode(r)
1536                        if mb_fake_mode is not None:
1537                            maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
1538                        with torch.no_grad(), maybe_suppress():
1539                            with maybe_fake_mgr:
1540                                r.set_(r_s, storage_offset, sizes, strides)
1541                            if self.copy_data:
1542                                with torch.no_grad(), no_dispatch():
1543                                    r.real_tensor.set_(
1544                                        r_s.real_storage,
1545                                        t.storage_offset,
1546                                        t.size,
1547                                        t.stride,
1548                                    )
1549
1550                if t.grad is not None:
1551                    from torch._dynamo.source import AttrSource
1552
1553                    # TODO: Use a valid grad-specific symbolic context instead of recycling
1554                    # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
1555                    r.grad = self.meta_tensor(
1556                        t.grad,
1557                        shape_env,
1558                        callback,
1559                        source=AttrSource(source, "grad"),
1560                        symbolic_context=symbolic_context,
1561                    )
1562                torch._C._set_conj(r, t.is_conj)
1563                torch._C._set_neg(r, t.is_neg)
1564            # This consistency check can be skipped if necessary for performance reasons
1565            skip_leaf = (
1566                t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
1567            )
1568            assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
1569            # Thanks to storage resizing, it's possible to end up with a tensor
1570            # that advertises a real size, but has a storage that actually has zero bytes.
1571            # Need to reflect this in the generated FakeTensor.
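            # Sketch of how such a tensor can arise (not executed here):
            #
            #   x = torch.randn(4)
            #   x.untyped_storage().resize_(0)
            #   assert x.shape == (4,) and x.untyped_storage().nbytes() == 0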
1572            if t.storage is not None and t.storage.size == 0:
1573                r.untyped_storage().resize_(0)
1574
1575            if t.is_parameter:
1576                r._is_param = True
1577
1578            # See Note: [Creating symbolic nested int]
1579            if t.nested_int is not None:
1580                r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
1581                    nt_tensor_id=t.nested_int
1582                )
1583
1584            self.set_tensor_memo(t, r)
1585
1586        return self.get_tensor_memo(t)
1587
1588    def __call__(
1589        self,
1590        t,
1591        shape_env=None,
1592        *,
1593        callback=lambda t: t(),
1594        source=None,
1595        symbolic_context=None,
1596        # Controls whether or not we should dump the tensor metadata to structured logs
1597        # when source is not None.  Because we refakify after Dynamo is done,
1598        # we don't want to dump info again from AOTAutograd; it is redundant.
1599        trace=True,
1600    ):
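        # Typical usage of the converter (sketch; assumes the default constructor
        # needs no arguments in this revision, not executed here):
        #
        #   converter = MetaConverter()
        #   m = converter(torch.randn(2, 3))
        #   assert m.device.type == "meta" and m.shape == (2, 3)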
1601        # TODO: zero tensors?  We appear to have eliminated them by
1602        # excluding complex for now
1603
1604        # Filter out cases we don't support
1605        # TODO: This can probably be simplified quite a bit
1606        if isinstance(t, torch.Tensor):
1607            if (
1608                # Lazy tensors are not supported.  Note that XLA, although
1609                # implemented on top of lazy tensors, is not excluded here;
1610                # we have special handling for it to support the XLA Dynamo
1611                # integration.
1612                t.device.type == "lazy"
1613                or
1614                # Quantization is not supported
1615                t.is_quantized
1616                or
1617                # Views out of sparse tensors are not currently supported
1618                # (plain sparse is supported, though)
1619                (t._is_view() and t._base is not None and t._base.is_sparse)
1620            ):
1621                self.miss += 1
1622                return NotImplemented
1623            else:
1624                self.hit += 1
1625        elif torch.overrides.is_tensor_like(t):
1626            self.miss += 1
1627            return NotImplemented
1628        else:
1629            # non-Tensor types don't count as hit or miss
1630            return t
1631
1632        if source is None:
1633            trace = False
1634
1635        # Describe the tensor.  NB: do NOT disable ambient modes; we may need
1636        # to query them when figuring out what to put in here
1637        t_desc = self.describer.describe_tensor(t, trace=trace)
1638
1639        if trace:
1640            trace_structured(
1641                "describe_source",
1642                metadata_fn=lambda: {
1643                    "describer_id": self.describer.id,
1644                    "id": t_desc.id,
1645                    "source": source.name(),
1646                },
1647            )
1648
1649        # Do the meta-fication.  Here, we disable all the ambient modes to
1650        # better simulate what it would be like to re-fakeify from a fresh
1651        # process.
1652        with contextlib.ExitStack() as exit_stack:
1653            exit_stack.enter_context(torch._dispatch.python.suspend_functionalization())
1654            st = peek_interpreter_stack()
1655            if st is not None:
1656                exit_stack.enter_context(
1657                    torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack()
1658                )
1659
1660            r = self.meta_tensor(
1661                t_desc,
1662                shape_env=shape_env,
1663                callback=callback,
1664                source=source,
1665                symbolic_context=symbolic_context,
1666            )
1667
1668        if type(t) is torch.nn.Parameter:
1669            # NB: Cannot directly use Parameter constructor
1670            # because that would force a detach, which is not desirable
1671            r._is_param = True
1672
1673        # TODO: return the description for later
1674        return r
1675
1676
1677import torch._prims_common as utils
1678