# mypy: allow-untyped-defs
import contextlib
import warnings
import weakref
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union

import torch
import torch._inductor.config as inductor_config
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
    _detect_infra_mode,
    _disable_infra_mode,
    return_and_correct_aliasing,
    TorchDispatchMode,
)


not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")


# NOTE: Some special handling for tensor conversion during export is needed.
# Normally, when tracing through the model with tensor.to(), the maybe-aliasing
# relationship between input and output tensors will be baked into the graph.
# For example, if we have a tensor with device cpu and call tensor.to("cpu"),
# it will become a no-op in the graph. For a whole-graph capture, this is not
# sound, so we need to do something different. Instead, in export we will try to
# preserve the tensor conversion by forcing a non-semantic-breaking aten::_to_copy
# operator to be traced in the graph, and subsequently banning mutations on all
# such converted tensors.
# In addition to patching the .to() method call in functionalization, we also have
# to patch other similar methods like float() and cpu(), because they intentionally
# don't fall back to .to(), but have the same behavior as .to() according to the
# PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.Tensor.float.html
# Thus we simply force them to go through a .to() call.
def _conversion_method_template(**extra_kwargs):
    def _(self, *args, **kwargs):
        return self.to(*args, **{**kwargs, **extra_kwargs})

    return _


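# A rough usage sketch of the behavior described in the NOTE above (illustrative
# only; the real rerouting happens on FunctionalTensor below): under export,
# x.to("cpu") is kept in the graph as aten._to_copy instead of being dropped as
# a no-op, and conversion methods are forced through .to(), e.g.
#
#     x.float()  # behaves like x.to(dtype=torch.float32)
#     x.cpu()    # behaves like x.to(device=torch.device("cpu"))
#
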
class FunctionalTensor(torch.Tensor):
    """
    Functional tensors represent tensors used to remove mutations
    from a program. If you perform a mutable operation on a functional tensor,
    it will re-dispatch to the functional variant of that operation.

    Historically, functionalization is implemented in C++ in the dispatcher.
    This class is a lightweight Python shim around the C++ functionalization logic.

    FunctionalTensor is required to be used with a corresponding
    FunctionalTensorMode active, because it relies
    on using the mode for dispatch (which can properly handle factory functions).
    """

    elem: torch.Tensor
    # Indicates to our torch_dispatch dispatching infra that
    # this is an "infra" mode with lower dispatching precedence.
    _mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL

    # Note: The reason we add these extra keys to our FunctionalTensor subclass
    # is to mirror the behavior of C++ functionalization (we can choose to change this
    # later, as long as it doesn't break anything).
    # FunctionalTensorWrapper copies **all** dispatch keys from the inner tensor
    # to the wrapper, excluding functorch and python dispatch keys.
    # Here I'm trying to re-use the keyset the functorch wrapper subclasses copy,
    # except that they don't include ZeroTensor so I'm manually adding it in.
    _extra_dispatch_keys = torch._C._additional_keys_to_prop_for_wrapper_tensors.add(
        torch._C.DispatchKey.ZeroTensor
    )

    # These are all aten ops that correspond to metadata queries.
    # We want FunctionalTensor to be able to handle them directly.
    metadata_fns = [
        torch.ops.aten.is_contiguous.default,  # type: ignore[has-type]
        torch.ops.aten.is_contiguous.memory_format,  # type: ignore[has-type]
        torch.ops.aten.is_strides_like_format.default,  # type: ignore[has-type]
        torch.ops.aten.is_non_overlapping_and_dense.default,  # type: ignore[has-type]
        torch.ops.aten.size.default,  # type: ignore[has-type]
        torch.ops.aten.sym_size.default,  # type: ignore[has-type]
        torch.ops.aten.stride.default,  # type: ignore[has-type]
        torch.ops.aten.sym_stride.default,  # type: ignore[has-type]
        torch.ops.aten.storage_offset.default,  # type: ignore[has-type]
        torch.ops.aten.sym_storage_offset.default,  # type: ignore[has-type]
        torch.ops.aten.numel.default,  # type: ignore[has-type]
        torch.ops.aten.sym_numel.default,  # type: ignore[has-type]
        torch.ops.aten.dim.default,  # type: ignore[has-type]
        torch.ops.prim.device.default,  # type: ignore[has-type]
    ]

    # These are ops that claim to be functional, but are actually maybe-mutating/maybe-aliasing
    # TODO (tmanlaibaatar) make it a tag
    maybe_aliasing_or_mutating_ops = [
        torch.ops.aten.dropout.default,  # type: ignore[has-type]
        torch.ops.aten.batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.native_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten._batch_norm_impl_index.default,  # type: ignore[has-type]
        torch.ops.aten.cudnn_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.miopen_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.atleast_1d.default,  # type: ignore[has-type]
        torch.ops.aten.atleast_2d.default,  # type: ignore[has-type]
        torch.ops.aten.atleast_3d.default,  # type: ignore[has-type]
        torch.ops.aten.cartesian_prod.default,  # type: ignore[has-type]
        torch.ops.aten.conj_physical.default,  # type: ignore[has-type]
        torch.ops.aten.alpha_dropout.default,  # type: ignore[has-type]
        torch.ops.aten.feature_dropout.default,  # type: ignore[has-type]
        torch.ops.aten.feature_alpha_dropout.default,  # type: ignore[has-type]
        torch.ops.aten.unsafe_chunk.default,  # type: ignore[has-type]
    ]

    # Used by auto_functionalize to determine the base of tensors during inference mode.
    _inference_mode_base: Optional["FunctionalTensor"] = None

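    # A rough illustration of why some ops land in the list above (an assumption
    # for exposition, not an exhaustive rule): aten.dropout can return its input
    # unchanged (an alias) when it is a no-op, and aten.native_batch_norm mutates
    # its running-stats arguments, so neither is safe to treat as purely
    # functional without decomposing it first.
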
    def __new__(cls, elem, mode):
        assert torch._is_functional_tensor(elem)

        # In general, we'd like our functional tensor subclass to only be in charge of functionalization,
        # and defer to the inner subclass for all other functionality.
        # Example: If our inner tensor is a ZeroTensor, we would want to defer running the ZeroTensor fallback
        # until after we redispatch to our inner ZeroTensor.
        # However, there are a few keys that we need to mirror between the inner and outer tensors.
        #   Conjugate
        #   Negative
        # Why? These keys are used to test metadata queries, like `.is_conj()` and `.is_neg()`.
        # We **need** calls to is_conj() to return the same thing on the outer and inner tensors,
        # because user code / framework code that branches like so needs to do the same thing
        # when it sees the outer FunctionalTensor:
        #     if (x.is_conj()) {
        #         return at::view_as_real(x.resolve_conj());
        #     } else {
        #         return at::view_as_real(x);
        #     }
        extra_dispatch_keys = (
            FunctionalTensor._extra_dispatch_keys & torch._C._dispatch_keys(elem)
        )

        out = torch.Tensor._make_wrapper_subclass(  # type: ignore[arg-type, attr-defined]
            # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
            # Calling the overload that has kwargs causes us to go down the first overload path,
            # which will **always** specialize sizes.
            # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
            cls,
            elem.shape,  # sizes
            elem.stride() if not is_sparse_any(elem) else None,  # strides
            (
                elem.storage_offset() if not is_sparse_any(elem) else None
            ),  # storage_offset
            None,  # memory_format
            elem.dtype,  # dtype
            elem.layout,  # layout
            elem.device,  # device
            False,  # pin_memory
            elem.requires_grad,  # requires_grad
            None,  # dispatch_sizes_strides_policy
            False,  # dispatch_device
            False,  # dispatch_layout
            extra_dispatch_keys,  # _extra_dispatch_keys
        )
        torch._C._set_throw_on_mutable_data_ptr(out)
        out.elem = elem

        if (
            torch.is_inference_mode_enabled()
            and torch._inductor.config.enable_auto_functionalized_v2
        ):
            if out.is_base_tensor():
                out._inference_mode_base = None
                # This assumes that the FunctionalTensor.elem does not change its storage after this point.
                # Otherwise this would be invalid.
                mode._storage_to_base[out.elem.untyped_storage()] = out
            else:
                out._inference_mode_base = mode._storage_to_base[
                    out.elem.untyped_storage()
                ]
                assert out._inference_mode_base is not None
        return out

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        unrecognized_types = [
            t
            for t in types
            if t not in [torch.Tensor, torch._subclasses.FakeTensor, FunctionalTensor]
        ]
        if unrecognized_types:
            not_implemented_log.debug(
                "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        if kwargs is None:
            kwargs = {}

        # FunctionalTensor needs to plumb all metadata requests to the inner tensor.
        # In theory we don't have to do this - but if we want to service metadata requests here,
        # we need to carefully make sure all metadata is accurate (including metadata mutations)
        if func in FunctionalTensor.metadata_fns:
            # All metadata accesses should be plumbed to the inner tensor, that way we don't have to worry
            # about the problem of keeping metadata in sync between the wrapper and inner tensor.
            # This also relieves us of having to manually handle metadata mutations on the wrapper.
            assert len(kwargs) == 0
            if func in [
                torch.ops.aten.is_strides_like_format.default,
                torch.ops.aten.is_contiguous.memory_format,
            ]:
                assert len(args) == 2 and isinstance(args[0], FunctionalTensor)
                return func(torch._from_functional_tensor(args[0].elem), args[1])
            assert len(args) == 1 and isinstance(args[0], FunctionalTensor)

            return func(torch._from_functional_tensor(args[0].elem))
        # Originally I tried to implement my subclass without giving it a torch_dispatch, but I gave up:
        # - _make_wrapper_subclass requires a __torch_dispatch__
        # - If we want to use _make_subclass(), we have a problem: the subclass will share a TensorImpl with the inner tensor,
        #   which is of type FunctionalTensorWrapper! We explicitly do not want our wrapper to be a FunctionalTensorWrapper.
        # - If we use the default tensor.__new__(), we have another problem: it returns inner_tensor.alias(),
        #   which causes every subclass created above autograd to have autograd view metadata
        #   (in addition to also being a FunctionalTensorWrapper).
        raise RuntimeError(
            "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()"
        )

    def __repr__(self):
        return f"FunctionalTensor({repr(self.elem)})"

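    # Illustrative sketch of the metadata fast path above (assumes an active
    # FunctionalTensorMode, e.g. inside AOTAutograd tracing):
    #
    #     ft = FunctionalTensor.to_functional(torch.randn(2, 3))
    #     ft.shape            # torch.Size([2, 3]), answered by the inner tensor
    #     ft.is_contiguous()  # True, also plumbed to the inner tensor
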
    @staticmethod
    def to_functional(x):
        # We will do the wrapping for the user.

        assert not torch._is_functional_tensor(x)
        # The only autograd metadata we care about on the FunctionalTensor is:
        # - requires_grad (so autograd runs)
        # - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine)
        #   this is handled by FunctionalTensor.to_functional
        x_functional = torch._to_functional_tensor(x)
        # Technically the FunctionalTensorMode here is unnecessary,
        # but it avoids spurious NotImplemented logs during `ProxyTorchDispatchMode` tracing.
        # _mirror_autograd_meta_to queries tensor sizes,
        # and otherwise the sym_size() call will go to the proxy mode before hitting
        # FunctionalTensor.__torch_dispatch__

        functional_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        assert functional_mode is not None

        with functional_mode:
            torch._mirror_autograd_meta_to(x, x_functional)  # type: ignore[attr-defined]
            out = FunctionalTensor(x_functional, functional_mode)
            torch._mirror_autograd_meta_to(x_functional, out)  # type: ignore[attr-defined]
        return out

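    # A minimal round-trip sketch (assumes a FunctionalTensorMode is already on
    # the mode stack, which is how AOTAutograd drives this class):
    #
    #     with FunctionalTensorMode():
    #         ft = FunctionalTensor.to_functional(torch.zeros(3))
    #         ft.add_(1)                     # the mutation is functionalized
    #         plain = ft.from_functional()   # tensor([1., 1., 1.])
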
    def from_functional(self):
        torch._sync(self)
        return torch._from_functional_tensor(self.elem)

    def is_base_tensor(self) -> bool:
        return torch._is_functional_tensor_base(self.elem)

    def replace_(self, output) -> None:
        torch._functionalize_replace(self.elem, output)

    def commit_update(self) -> None:
        torch._functionalize_commit_update(self.elem)

    def sync(self) -> None:
        torch._functionalize_sync(self.elem)

    def mark_mutation_hidden_from_autograd(self) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(self.elem)

    def tolist(self) -> Any:
        if self.elem.dim() == 0:
            return self.elem.item()
        elif self.elem.dim() == 1:
            return [elem.item() for elem in self.elem]
        else:
            return [elem.tolist() for elem in self.elem]

    def to(self, *args, **kwargs):
        if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export:
            # If copy is specified as pos arg, it's always the second one.
            if len([arg for arg in args if isinstance(arg, bool)]) <= 1:
                return super().to(*args, **{**kwargs, "copy": True})
        return super().to(*args, **kwargs)

    def cuda(self, device=None, *args, **kwargs):
        device = device or torch.cuda.current_device()
        if len(args) > 0:
            return self.to(device, *args, **kwargs)
        else:
            return self.to(device=device, **kwargs)

    char = _conversion_method_template(dtype=torch.int8)
    cpu = _conversion_method_template(device=torch.device("cpu"))
    bfloat16 = _conversion_method_template(dtype=torch.bfloat16)
    byte = _conversion_method_template(dtype=torch.uint8)
    double = _conversion_method_template(dtype=torch.float64)
    float = _conversion_method_template(dtype=torch.float32)
    bool = _conversion_method_template(dtype=torch.bool)
    half = _conversion_method_template(dtype=torch.float16)
    int = _conversion_method_template(dtype=torch.int32)
    long = _conversion_method_template(dtype=torch.int64)

    # TODO(sparse-team): fixes #133174 but can we do without the relay?
    def to_dense(self):
        return self.elem.to_dense()

    @property
    def layout(self):
        return self.elem.layout


class FunctionalTensorMode(TorchDispatchMode):
    def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
        super().__init__()
        self.export = export
        self.is_on_stack = False
        self.enter_stack = []
        # Indicates to our torch_dispatch dispatching infra that
        # this is an "infra" mode with lower dispatching precedence.
        self._mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL
        self.pre_dispatch = pre_dispatch
        # This will be turned off later for pre-dispatch functionalization
        self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None  # type: ignore[attr-defined]
        # Map of effect type (e.g. _EffectType.ORDERED) to a token. The tokens help keep
        # track of the ordering between side-effectful operations.
        self._tokens: Dict[Any, torch.Tensor] = {}

        # Filled after forward tracing.
        self._tokens_forward_output: Dict[Any, torch.Tensor] = {}

        # Functionalization runs twice in AOTAutograd, once in
        # `run_functionalized_fw_and_collect_metadata` to collect metadata to
        # see which tensors need to be functionalized and discover how many
        # tokens we need, and another time in `make_fx`, which does the actual
        # tracing to replace ops with their functional variants and handle
        # side-effectful ops. In the second stage there should be no token
        # discovery. This flag distinguishes between the two stages.
        self._allow_token_discovery = _allow_token_discovery

        self._storage_to_base: weakref.WeakKeyDictionary[
            torch.storage.UntypedStorage, Optional[FunctionalTensor]
        ] = weakref.WeakKeyDictionary()

    # No-op if FunctionalTensorMode is already in use
    def __enter__(self):
        def _get_prev_mode():
            if self._dispatch_key == torch._C.DispatchKey.PreDispatch:
                return _get_dispatch_mode_pre_dispatch(
                    torch._C._TorchDispatchModeKey.FUNCTIONAL
                )
            return torch._C._get_dispatch_mode(
                torch._C._TorchDispatchModeKey.FUNCTIONAL
            )

        if _get_prev_mode() is None:
            self.enter_stack.append(True)
            return super().__enter__()
        else:
            self.enter_stack.append(False)
            return self

    def __exit__(self, a, b, c):
        is_on_stack = self.enter_stack.pop()
        if is_on_stack:
            super().__exit__(a, b, c)

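    # Re-entrancy sketch (illustrative): entering an already-active functional
    # mode is a no-op push; only the outermost __enter__ puts the mode on the
    # dispatch mode stack, and the matching __exit__ leaves it in place.
    #
    #     mode = FunctionalTensorMode()
    #     with mode:
    #         with mode:  # records False on enter_stack, does not push again
    #             pass
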
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if self.export:
            # We need to make sure that we don't decompose to() as usual in export mode,
            # because it can get optimized away. Instead we always replace it with _to_copy().
            if func == torch.ops.aten.to.dtype_layout:
                kwargs.pop("copy", None)
                return self.__torch_dispatch__(
                    torch.ops.aten._to_copy.default, types, args, kwargs
                )
            if func == torch.ops.aten.to.dtype:
                schema = tuple(arg.name for arg in func._schema.arguments)
                for arg, name in zip(args[1:], schema[1:]):
                    kwargs[name] = arg
                kwargs.pop("copy", None)
                return self.__torch_dispatch__(
                    torch.ops.aten._to_copy.default, types, args[:1], kwargs
                )

        unrecognized_types = [
            t
            for t in types
            if not issubclass(t, torch._subclasses.FakeTensor)
            and t not in [torch.Tensor, FunctionalTensor]
        ]

        if unrecognized_types:
            not_implemented_log.debug(
                "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        def _can_decompose(func):
            # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832
            # Never decompose dropout in export
            if self.export and func == torch.ops.aten.dropout.default:
                return False

            # We unconditionally decompose ops that are maybe aliasing or mutating ops
            if func in FunctionalTensor.maybe_aliasing_or_mutating_ops:
                return True

            # (1) we unconditionally decompose maybe-aliasing or maybe-mutating ops,
            # because we must know statically whether an op mutates or aliases in order to functionalize it properly
            # (2) for mutating ops that have CompositeImplicit decomps, we choose to decompose them today.
            # In theory, we could walk this back and avoid decomposing them later if we need to.
            alias_info_present = any(arg.alias_info for arg in func._schema.arguments)
            if alias_info_present or func._schema.is_mutable:
                return True

            # If we are here, it means we are seeing a functional composite op.
            # For pre-dispatch IR or export inference IR, we won't decompose them
            if (self.export or self.pre_dispatch) and func._can_decompose():
                if func.namespace not in ["aten", "prim"]:
                    # TODO (tmanlaibaatar) check if the op is PT2 compliant
                    warnings.warn(
                        f"At pre-dispatch tracing, we assume that any custom op marked with "
                        f"CompositeImplicitAutograd and having a functional schema is safe to not decompose. "
                        f"Found {func} to be one such op."
                    )
                return False

            # In normal torch.compile IR, we decompose functional composite ops
            return True

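        # Rough examples of the policy above (illustrative, not a spec):
        # a maybe-aliasing composite such as aten.dropout is decomposed here
        # (except for the export special case), while under export/pre-dispatch
        # tracing a purely functional CompositeImplicitAutograd custom op is kept
        # in the graph instead of being decomposed.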
        if (
            func not in FunctionalTensor.metadata_fns
            and _can_decompose(func)
            # Not all funcs from __torch_dispatch__ are actual dispatcher ops,
            # e.g. prim.device
            and torch._C._dispatch_has_kernel(func.name())
        ):
            with self:
                r = func.decompose(*args, **kwargs)
                if r is not NotImplemented:
                    return r

        def wrap(x):
            # Only wrap our outputs in subclasses if the inner functionalization call
            # also wrapped outputs into FunctionalTensorWrappers.
            # When can this happen? e.g. `torch.div(2, 2)`
            assert not isinstance(x, FunctionalTensor)
            if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
                return FunctionalTensor(x, self)
            return x

        def unwrap(x):
            return x.elem

        from torch._higher_order_ops.auto_functionalize import (
            can_auto_functionalize,
            do_auto_functionalize,
            do_auto_functionalize_v2,
        )

        if can_auto_functionalize(
            func
        ) and not torch._C._dispatch_has_kernel_for_dispatch_key(
            func.name(), torch._C.DispatchKey.Functionalize
        ):
            # it doesn't matter what mode we use here because
            # the implementation of do_auto_functionalize doesn't
            # interact with FunctionalTensorMode at all
            if self.export or not inductor_config.enable_auto_functionalized_v2:
                return do_auto_functionalize(func, args, kwargs)
            else:
                return do_auto_functionalize_v2(func, args, kwargs)

        from torch._higher_order_ops.effects import handle_effects, has_effects

        if has_effects(func, args, kwargs):
            assert not torch._C._dispatch_has_kernel_for_dispatch_key(
                func.name(), torch._C.DispatchKey.Functionalize
            )
            return handle_effects(
                self._allow_token_discovery, self._tokens, func, args, kwargs
            )

        args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
            FunctionalTensor, unwrap, (args, kwargs)
        )

        # Expectation: functionalization should not **already** be enabled above our mode.
        # Why would that be bad? When we return a FunctionalTensor here, we don't want functionalization
        # to run above this mode and further wrap that output in **another** C++ FunctionalTensorWrapper.
        is_included = torch._C._dispatch_tls_is_dispatch_key_included(
            torch._C.DispatchKey.Functionalize
        )
        is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.Functionalize
        )
        assert is_excluded or not is_included
        include_to_set = (
            torch._C._dispatch_tls_local_include_set()
            | torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )
        exclude_to_set = (
            torch._C._dispatch_tls_local_exclude_set().remove(
                torch._C.DispatchKey.Functionalize
            )
            - FunctionalTensor._extra_dispatch_keys
        )

        # All we want to do here is re-use the existing C++ functionalization logic.
        # This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
            try:
                # By default for Python functionalization (for AOTAutograd), we reapply views.
                old_apply_views = torch._functionalize_enable_reapply_views(True)  # type: ignore[attr-defined]

                # Sometimes these functions cannot be dispatched directly to the Functionalize key
                # because the args are sometimes not functional tensors for some reason?
                if func in FunctionalTensor.metadata_fns:
                    outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
                    outs_wrapped = pytree.tree_map_only(
                        torch.Tensor, wrap, outs_unwrapped
                    )
                else:
                    # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
                    # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
                    # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
                    # from the TLS in order to avoid infinite looping, but this would prevent us from coming
                    # back to PreDispatch later
                    outs_unwrapped = func._op_dk(
                        torch._C.DispatchKey.Functionalize,
                        *args_unwrapped,
                        **kwargs_unwrapped,
                    )
                    # We don't allow any mutation on the result of dropout or _to_copy
                    if self.export:
                        if func in (
                            torch.ops.aten.dropout.default,
                            torch.ops.aten._to_copy.default,
                        ):
                            torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
                    outs_wrapped = pytree.tree_map_only(
                        torch.Tensor, wrap, outs_unwrapped
                    )
            finally:
                torch._disable_functionalization()
                torch._functionalize_enable_reapply_views(old_apply_views)  # type: ignore[attr-defined]

        is_included = torch._C._dispatch_tls_is_dispatch_key_included(
            torch._C.DispatchKey.Functionalize
        )
        is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.Functionalize
        )
        assert is_excluded or not is_included

        if (
            # If no outputs are our functional subclass, then don't try to fix up aliasing
            not any(
                isinstance(x, FunctionalTensor)
                for x in pytree.tree_leaves(outs_wrapped)
            )
            # Since lift_fresh lifts its argument into a functional tensor, we can skip the
            # aliasing correction step. Otherwise, we would be setting the storage of a
            # lifted tensor to that of an unlifted tensor.
            # Ref: https://github.com/pytorch/pytorch/issues/111506
            or func == torch.ops.aten.lift_fresh.default
        ):
            return outs_wrapped
        # For metadata mutations, we need to manually mutate the metadata of the FunctionalTensor wrapper
        if (
            torch.Tag.inplace_view in func.tags
            and func is not torch.ops.aten.set_.source_Tensor
        ):
            with torch.utils._mode_utils.no_dispatch():
                func(*args, **kwargs)
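        # Illustrative example of the branch above (an assumption about typical
        # usage, not a spec): for a metadata-mutating view op such as
        # ft.transpose_(0, 1), the C++ kernel updates the inner functional
        # tensor, and replaying the op under no_dispatch() updates the wrapper's
        # own sizes/strides so the two stay in sync.
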
        # Wrapper tensor subclasses do not have correct aliasing info! Use this util to manually correct the output aliasing.
        # inplace ops like `aten.add_()` are expected to return inputs **directly**, instead of creating fresh tensor objects.
        # Use this util to figure out the right thing to return.
        # If none of our inputs were wrapped, then we have no FunctionalTensor outputs that we need to fix up storages for.
        return return_and_correct_aliasing(func, args, kwargs, outs_wrapped)

    @classmethod
    def is_infra_mode(cls) -> bool:
        return True


@contextlib.contextmanager
def disable_functional_mode():
    with _disable_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) as mode:
        yield mode


# This is similar to torch.func.functionalize, but:
# - It uses FunctionalTensorMode, and FunctionalTensor (a Python subclass).
#   One important advantage of using this mode is that it will let us
#   run functionalization underneath __torch_dispatch__,
#   which we need in AOTAutograd.
# - Doing so means that it does not automatically compose with other
#   functorch transforms, since these transforms always run above __torch_dispatch__.
#   That's why this util lives here, and not in functorch.
def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMode()):
    # TODO: pull these from aot autograd
    def to_fun(t):
        if isinstance(t, torch.Tensor):
            return FunctionalTensor.to_functional(t)
        return t

    def from_fun(t):
        if not isinstance(t, FunctionalTensor):
            # quick sanity assert
            if isinstance(t, torch.Tensor):
                assert not torch._is_functional_tensor(t)
            return t
        torch._sync(t)
        return torch._from_functional_tensor(t.elem)

    def inner(*args, **kwargs):
        disable_above = torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )
        with disable_above, mode:
            func_args = pytree.tree_map_only(torch.Tensor, to_fun, args)
            func_kwargs = pytree.tree_map_only(torch.Tensor, to_fun, kwargs)
            func_outputs = func(*func_args, **func_kwargs)
            outputs = pytree.tree_map_only(FunctionalTensor, from_fun, func_outputs)

            return outputs

    return inner


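# Example usage (a minimal sketch; AOTAutograd is the main real caller):
#
#     def f(x):
#         y = x.clone()
#         y.add_(1)   # runs on FunctionalTensors under the hood
#         return y
#
#     functional_f = dispatch_functionalize(f)
#     out = functional_f(torch.zeros(3))   # tensor([1., 1., 1.])
#
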
class BaseFunctionalizeAPI(ABC):
    @abstractmethod
    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        pass

    @abstractmethod
    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Any:
        pass

    @abstractmethod
    def functionalize(self, inner_f: Callable) -> Callable:
        pass

    @abstractmethod
    def redispatch_to_next(self) -> ContextManager:
        pass

    @abstractmethod
    def replace(self, input_tensor, output_tensor) -> None:
        pass

    @abstractmethod
    def commit_update(self, tensor) -> None:
        pass

    @abstractmethod
    def sync(self, tensor) -> None:
        pass

    @abstractmethod
    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        pass


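# The three classes below expose the same BaseFunctionalizeAPI surface over
# different functionalization implementations (python-mode, C++, and functorch).
# A rough sketch of the shared usage pattern (illustrative only):
#
#     api = PythonFunctionalizeAPI()
#     wrapped = api.wrap_tensors((torch.randn(2),))
#     unwrapped = api.unwrap_tensors(wrapped)
#
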
class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
    def __init__(
        self, mode: Optional[FunctionalTensorMode] = None, pre_dispatch: bool = False
    ) -> None:
        super().__init__()
        self.mode = mode if mode else FunctionalTensorMode()
        self.pre_dispatch = pre_dispatch

    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        with self.mode:
            return torch.utils._pytree.tree_map_only(
                torch.Tensor, FunctionalTensor.to_functional, args
            )

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]
    ) -> Any:
        return torch.utils._pytree.tree_map_only(
            FunctionalTensor, FunctionalTensor.from_functional, args
        )

    def functionalize(self, inner_f: Callable) -> Callable:
        return dispatch_functionalize(inner_f, self.mode)

    def redispatch_to_next(self) -> ContextManager:
        # [NOTE] We don't do anything here because at the time
        # we exercise this path, we would have already popped the
        # FunctionalTensorMode from the mode stack. Since FunctionalTensorMode
        # is now stateful, it is better to explicitly pass in the correct mode
        # directly instead of setting it globally.
        return contextlib.nullcontext()

    def replace(self, input_tensor, output_tensor) -> None:
        assert isinstance(input_tensor, FunctionalTensor)
        assert not isinstance(output_tensor, FunctionalTensor)
        input_tensor.replace_(output_tensor)

    def commit_update(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.commit_update()

    def sync(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.sync()

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.mark_mutation_hidden_from_autograd()


class CppFunctionalizeAPI(BaseFunctionalizeAPI):
    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional

        return _wrap_all_tensors_to_functional(args, level=0)

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        from torch._functorch.eager_transforms import (
            _unwrap_all_tensors_from_functional,
        )

        return _unwrap_all_tensors_from_functional(args, reapply_views=_reapply_views())

    def functionalize(self, inner_f: Callable) -> Callable:
        return torch.func.functionalize(inner_f)

    def redispatch_to_next(self) -> ContextManager:
        return torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )

    def replace(self, input_tensor, output_tensor) -> None:
        torch._functionalize_replace(input_tensor, output_tensor)

    def commit_update(self, tensor) -> None:
        torch._functionalize_commit_update(tensor)

    def sync(self, tensor) -> None:
        torch._functionalize_sync(tensor)

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(tensor)


class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
    def __init__(self, interpreter):
        self.interpreter = interpreter

    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional

        return _wrap_all_tensors_to_functional(args, level=self.interpreter.level())

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        from torch._functorch.eager_transforms import (
            _unwrap_all_tensors_from_functional,
        )

        return _unwrap_all_tensors_from_functional(
            args, reapply_views=self.interpreter.functionalize_add_back_views()
        )

    def functionalize(self, inner_f: Callable) -> Callable:
        return torch.func.functionalize(
            inner_f,
            remove=(
                "mutations_and_views"
                if self.interpreter.functionalize_add_back_views()
                else "mutations"
            ),
        )

    def redispatch_to_next(self) -> ContextManager:
        return self.interpreter.lower()

    def replace(self, input_tensor, output_tensor) -> None:
        torch._functionalize_replace(input_tensor, output_tensor)

    def commit_update(self, tensor) -> None:
        torch._functionalize_commit_update(tensor)

    def sync(self, tensor) -> None:
        torch._functionalize_sync(tensor)

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(tensor)


def mb_unwrap_functional_tensor(tensor: torch.Tensor):
    if isinstance(tensor, FunctionalTensor):
        return torch._from_functional_tensor(tensor.elem)
    return tensor

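
# Illustrative example (a sketch; constructing the FunctionalTensor assumes an
# active FunctionalTensorMode):
#
#     plain = torch.randn(3)
#     assert mb_unwrap_functional_tensor(plain) is plain
#     ft = FunctionalTensor.to_functional(plain)
#     inner = mb_unwrap_functional_tensor(ft)  # the plain, non-functional tensor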