# mypy: allow-untyped-defs
"""
This module is one of the analysis modules - it takes as input a function or graph
and some preexisting properties, and returns some data that is useful for deciding
how to further proceed with compilation or construct runtime wrappers.

In particular, the analysis here constructs view and mutation metadata from running
a functionalized version of the graph under compilation.
"""

import collections
import contextlib
import logging
from functools import wraps
from typing import Callable, DefaultDict, Dict, List, Optional

import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._guards import detect_fake_mode
from torch._logging import getArtifactLogger
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch._subclasses.meta_utils import safe_is_leaf
from torch.fx.experimental.symbolic_shapes import is_concrete_int
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
    is_traceable_wrapper_subclass,
    transform_subclass,
)

from .functional_utils import (
    are_all_mutations_hidden_from_autograd,
    are_all_mutations_under_no_grad_or_inference_mode,
    from_fun,
    has_data_mutation,
    has_metadata_mutation,
    has_same_metadata,
    to_fun,
    was_inductor_storage_resized,
)
from .schemas import (
    FunctionalTensorMetadataEq,
    InputAliasInfo,
    MutationType,
    OutputAliasInfo,
    OutputType,
    ViewAndMutationMeta,
)
from .subclass_utils import create_subclass_meta
from .utils import _get_autocast_states, KNOWN_TYPES, strict_zip


zip = strict_zip

log = logging.getLogger(__name__)
static_input_logger = getArtifactLogger("torch._dynamo", "cudagraph_static_inputs")


# Note [Tangents must be contiguous]
# We force tangents to be contiguous today.
# The idea is that we are technically making a guess about the strides of our tangents,
# while we trace out the joint.
# Today, we force this guess to be correct by additionally calling contiguous()
# on all tangents at runtime.
# In the future, you could imagine lifting this restriction, since these contiguous()
# calls can have noticeable perf overhead depending on the model.
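#
# A rough usage sketch of coerce_tangent (illustrative only, not executed here):
#
#   t = torch.randn(4, 4, requires_grad=True).t()  # non-contiguous tangent guess
#   tangent = coerce_tangent(t)
#   assert tangent.is_contiguous()
#   assert not tangent.requires_grad  # detach() drops the autograd history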
def coerce_tangent(x):
    if not isinstance(x, Tensor):
        return x
    out = x.detach().contiguous()
    # Note [Tangents must be contiguous, Part 2]
    # In the same way that "what strides do we assign to our tangents" is a question
    # that we cannot answer (and therefore have to guess) as we trace the backward ahead-of-time,
    # the same applies to any tensor subclass metadata, when we have tangents that are subclasses.
    # To handle this situation, we have two new methods that a tensor subclass can implement:
    # (1) __coerce_tangent_metadata__(self)
    #     Given a subclass with "non-standard" metadata, turn it into a new subclass with "normal" metadata.
    #     The main example here is a DTensor with the "_Partial" placement.
    #     If we have a forward output with a _Partial placement, and a corresponding tangent
    #     with a Replicate/Shard placement, we have no way to convert the tangent "back" to a _Partial placement.
    #     This method lets us avoid the problem entirely by allowing subclasses to ensure that we can never
    #     have a tangent with "problematic" metadata that we cannot convert to.
    # (2) __coerce_same_metadata_as_tangent__(self, metadata)
    #     Given a subclass, and a target differing metadata,
    #     convert self to have the same metadata as the target.
    #     With DTensor being the main example, we can use this to convert a DTensor with a Replicate()
    #     placement into one with a Shard() placement, in the case that we "guessed wrong",
    #     and traced tangents with a Shard() placement at compile time.
    #
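    # A minimal sketch of this protocol (hypothetical subclass, for illustration only;
    # both hooks are optional and duck-typed - they are only used if the subclass defines them):
    #
    #   class MyTensorSubclass(torch.Tensor):
    #       def __coerce_tangent_metadata__(self):
    #           # return a copy of self whose metadata is "standard"
    #           ...
    #
    #       def __coerce_same_metadata_as_tangent__(self, expected_metadata):
    #           # return a copy of self converted to match expected_metadata
    #           ...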
    if is_traceable_wrapper_subclass(out) and hasattr(
        out, "__coerce_tangent_metadata__"
    ):
        out = out.__coerce_tangent_metadata__()  # type: ignore[attr-defined]
    # It's possible to have a subclass that advertises itself as contiguous,
    # but has noncontiguous inner tensors.
    # Force these to be contiguous too.
    if is_traceable_wrapper_subclass(out):
        for attr in out.__tensor_flatten__()[0]:  # type: ignore[attr-defined]
            elem = getattr(out, attr)
            if not elem.is_contiguous():
                elem_contig = elem.contiguous()
                setattr(out, attr, elem_contig)
    return out


# This is a version of functionalization that is specifically designed
# for the AOTAutograd use case.
#
# Unlike functorch's variant, this doesn't use the functorch level system;
# instead it directly uses PyTorch's conventional dispatcher to hit the
# functionalization key.  In particular, this means that FunctionalTensorWrapper
# can have autograd data stored directly on it.
#
# In typical AOTAutograd usage, the dispatch key order will look like:
#
#   Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor
#       outer tensor                        inner tensor
#
# Returns:
# - ViewAndMutationMeta, telling us metadata about the inputs and outputs, and
#   the list of outputs from the forward, but **only** the outputs that we need
#   to pass in as tangents into the backward.
#   Specifically, aliased outputs from the forward get regenerated, and don't participate
#   in the compiled backward function.
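#
# A rough usage sketch (illustrative; the traced function and inputs are hypothetical,
# and in practice this runs under AOTAutograd's fake tensor mode):
#
#   def fn(x):
#       return x.mul(2), x.view(-1)
#
#   meta = run_functionalized_fw_and_collect_metadata(
#       fn, keep_input_mutations=True, is_train=False
#   )(torch.randn(4, requires_grad=True))
#   # meta.output_info[0] should be recorded as a non-aliasing output, and
#   # meta.output_info[1] should record that the output aliases input 0.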
def run_functionalized_fw_and_collect_metadata(
    f,
    *,
    keep_input_mutations: bool,
    # TODO: refactor to kill this flag
    is_train: bool = False,
    # Note: this is guaranteed to be set when running under dynamo
    static_input_indices: Optional[List[int]] = None,
    pre_dispatch: bool = False,
) -> Callable[..., ViewAndMutationMeta]:
    memo: Dict[Tensor, Tensor] = {}

    def _to_fun(t):
        if isinstance(t, Tensor):
            if t in memo:
                return memo[t]
            r = to_fun(t)
            memo[t] = r
            return r
        else:
            return t

    @wraps(f)
    def inner(*flat_args):
        # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
        assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args)

        input_info: List[InputAliasInfo] = []
        output_info: List[OutputAliasInfo] = []

        prior_grad_enabled = torch.is_grad_enabled()
        prior_autocast_states = _get_autocast_states()

        # See Note [Disabling Functionalize TLS Above Python Functionalization]
        disable_above = torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )

        # It doesn't matter if we run this under predispatch or not because it is
        # only for figuring out metadata
        mode = FunctionalTensorMode(_allow_token_discovery=True)
        suppress_pending = contextlib.nullcontext()
        fake_mode = detect_fake_mode()
        if fake_mode and (shape_env := fake_mode.shape_env):
            suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
        with disable_above, mode, suppress_pending:
            # precondition: The passed in function already handles unflattening inputs + flattening outputs
            flat_f_args = pytree.tree_map(_to_fun, flat_args)
            flat_f_outs = f(*flat_f_args)
            # We didn't do any tracing, so we don't need to process the
            # unbacked symbols, they will just disappear into the ether.
            # Also, prevent memoization from applying.
            if fake_mode:
                fake_mode.epoch += 1
                fake_mode.reset_nt_tensor_id_counter()

        if prior_autocast_states != _get_autocast_states():
            raise RuntimeError(
                "AOTAutograd does not support tracing graphs that mutate the autocast state. "
                "Dynamo will only insert autocast context managers (e.g. with torch.autocast(..)) into the graph, "
                "which will unwind all of their mutations to autocast state before the graph exits. "
                "If you encounter this error while using torch.compile, please file a bug."
            )

        # Inspect the state of the input tensor functional wrapper to detect input mutation info
        # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
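        # Rough examples of the distinctions detected below (illustrative):
        #   x.mul_(2)          -> data mutation             (mutates_data=True)
        #   x.transpose_(0, 1) -> metadata-only mutation    (mutates_metadata=True)
        #   x.set_(y)          -> storage metadata mutation (mutates_storage_metadata=True)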
        for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
            # NB: Mutation of a non-contiguous tensor subclass input can result in a mismatch in
            # strides between the functionalized arg inner tensors and non-functionalized arg inner
            # tensors. This is a problem as the inner tensor stride change may not be reflected
            # correctly in the outer tensor, so disallow this for now.
            mutates_data = has_data_mutation(f_arg)
            if (
                mutates_data
                and not arg.is_contiguous()
                and is_traceable_wrapper_subclass(arg)
            ):
                raise RuntimeError(
                    "Mutations on non-contiguous inputs are currently not allowed on "
                    "tensor subclasses"
                )

            if not isinstance(arg, Tensor):
                new_arg = arg
            else:
                new_arg = from_fun(f_arg)
            mutates_metadata = has_metadata_mutation(
                f_arg, arg, check_only_storage_mutation=False
            )
            if mutates_metadata and is_traceable_wrapper_subclass(arg):
                raise RuntimeError(
                    "Metadata mutations are currently not allowed on tensor subclasses"
                )
            mutates_storage_metadata = has_metadata_mutation(
                f_arg, arg, check_only_storage_mutation=True
            )
            mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd(
                f_arg
            )
            mutations_under_no_grad_or_inference_mode = (
                mutates_data
                and are_all_mutations_under_no_grad_or_inference_mode(f_arg)
            )
            mutation_inductor_storage_resize = was_inductor_storage_resized(f_arg)

            if mutates_storage_metadata:
                mutates_data = False

            requires_grad = isinstance(f_arg, torch.Tensor) and f_arg.requires_grad

            input_info.append(
                InputAliasInfo(
                    is_leaf=isinstance(arg, Tensor) and safe_is_leaf(arg),
                    mutates_data=mutates_data,
                    mutates_metadata=mutates_metadata,
                    mutations_hidden_from_autograd=mutations_hidden_from_autograd,
                    mutates_storage_metadata=mutates_storage_metadata,
                    mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode,
                    mutation_inductor_storage_resize=mutation_inductor_storage_resize,
                    requires_grad=requires_grad,
                    keep_input_mutations=keep_input_mutations,
                )
            )

        # If a function creates a tensor and returns a view of it, such that the view's _base is the intermediate,
        # we need to make sure our graph returns the _base as a graph output, and we manually recreate the view
        # to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad
        # on the base tensor, but we are obligated to properly set requires-gradness on the real output.
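        # For example (illustrative, assuming x requires grad):
        #   def f(x):
        #       intermediate = x.mul(2)
        #       return intermediate[0], intermediate[1]
        # Here both outputs alias the same intermediate, so the compiled graph also
        # returns `intermediate` as an extra output, and the two views are regenerated
        # from it outside of the compiled autograd.Function.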

        inp_storage_refs = {
            StorageWeakRef(inpt.untyped_storage()): idx
            for idx, inpt in enumerate(flat_f_args)
            if isinstance(inpt, Tensor)
        }

        # We need inp tensor ids to be able to tell if any outputs **are** inputs.
        inp_tensor_ids = {id(inpt) for inpt in flat_f_args if isinstance(inpt, Tensor)}
        # We need output tensor ids to tell if any output._base attributes **are** other outputs.
        # (This is also a dict because we need to know that output's index, so we can regenerate
        # the alias from it).
        out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)}

        # Keep track of which outputs alias other outputs
        out_tensor_alias_counts: DefaultDict = collections.defaultdict(int)
        # This tells us, for a given group of outputs that alias each other,
        # whether they e.g. all came from an unbind call
        num_aliased_tensors_that_are_multi_output_views: DefaultDict = (
            collections.defaultdict(int)
        )
        out_storage_to_tensors: DefaultDict = collections.defaultdict(set)
        curr_storage = None
        for o in flat_f_outs:
            if isinstance(o, torch.Tensor):
                curr_storage = StorageWeakRef(o.untyped_storage())
                out_tensor_alias_counts[curr_storage] += 1
                # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
                # This is an optimization on top of the "alias of intermediates" logic,
                # which you can read more about under Note [AOT Autograd: outputs aliasing inputs or intermediates!]
                #
                # Before describing the optimization: this is important for AOTAutograd to have good
                # perf around multi-output views. HOWEVER:
                # - There is a more generic change to AOTAutograd that we'd like to make, that subsumes this case,
                #   around using pre-dispatch tracing to partition out a graph so we can faithfully replay all
                #   views without having to regenerate them at runtime.
                # - It's loosely described in this doc (more details will be added soon):
                #   https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit
                # - Once that change lands, we should just rip out this "optimization", since:
                #   (1) It will be fully unnecessary
                #   (2) Although it is only a few lines of code, it is a bit difficult to reason about
                #       its correctness with the autograd engine in all cases.
                #
                #
                # What is this optimization? Consider the below case:
                # def f(x):
                #     intermediate = x.mul(2)
                #     # x and intermediate here require grad
                #     o1, o2, ... o10 = intermediate.unbind(-1)
                #     return intermediate, o1, o2, ... o10
                # Now, the "intermediate base" handling in AOTAutograd implies that we must do the following:
                #   (1) return "intermediate" as an extra output of the compiled graph
                #   (2) regenerate each aliased output off of "intermediate", **outside** of the autograd.Function.
                # The reason AOTAutograd ordinarily does this is for safety: the autograd engine needs to know
                # that o1 through o10 are all aliased, and if we blindly return o1 through o10 from the autograd.Function,
                # this information will be hidden.
                # In particular, mutating one alias might require autograd to update autograd metadata on the other aliases
                # (like their grad_fn, for example, when the autograd engine needs to do view-replay).
                #
                # However, intermediate_base logic can be bad for backward performance (we sometimes generate
                # as_strided calls during the intermediate base logic, which can have a slow backward formula).
                # Is it possible to find a set of conditions where it is **safe** to hide the output aliasing from autograd?
                #
                # For a set of outputs of the graph that alias each other, o_1...o_k, consider:
                # (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0)
                # (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate),
                #     **at most** 1 can escape from the graph (e.g. there is not some other graph input/output
                #     o_other, that aliases these outputs)
                # (3) o_1...o_k all require_grad, they all share the same ._base, and their ._base requires grad.
                #     This condition is important because it's what causes slowness in the intermediate_base
                #     codepath of aot_autograd. Ordinarily, o_1...o_k would all get a grad_fn, and
                #     aot_autograd's view-replay might give each output an AsStridedBackward as its grad_fn.
                #     "K" AsStridedBackward calls will be *much* slower than a single UnbindBackward.
                # In this setup, is it possible to mutate one of the outputs o_i in a way that would affect the autograd meta
                # of the other aliases?
                #
                # Claim: No! Consider a few examples (which I'm pretty sure cover all cases of mutation w.r.t. autograd):
                # (a) What happens if we mutate any of o_1 through o_k directly?
                #     Autograd raises an error:
                #     "RuntimeError: Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is
                #      the output of a function that returns multiple views. Such functions do not allow the output
                #      views to be modified inplace. You should replace the inplace operation by an out-of-place one."
                # (b) What if we take a view of o_k and mutate it, o_k.view(o_k.shape).mul_(2)?
                #     Autograd raises the same error - the "multi-output-view"ness of an alias propagates to future views.
                # (c) What if we mutate o_k under no_grad?
                #     Autograd raises the same error
                # (d) What if we detach and mutate, e.g. o_k.detach().mul_(2)?
                #     Autograd allows this, *but* autograd updates all aliases' grad_fns to be error functions when accessed,
                #     so later access raises the same error.
                # (e) What if we try to mutate another alias of o_1...o_k, that was **not** created from a multi-output view?
                #     We promised that there is at most **one** such alias, e.g. intermediate in the example above.
                #     You can mutate intermediate, but in eager mode this will change the grad_fn of o_1...o_k
                #     to be error fn's.
                #     Since intermediate was the *only* non-multi-output-alias, there are no other aliases
                #     of `intermediate` around that were produced by the compiled fn and have a valid grad_fn.
                #
                # Coming back to this optimization:
                # Given that it is not possible for mutating one of these aliases to affect the autograd metadata of another alias
                # without causing an error in eager mode, we will simply hide the aliasing from autograd during torch.compile
                # if all of the above conditions are met.
                # This has the slight downside that it's possible to write some "bad" code that autograd will raise an error on
                # in eager mode but not under torch.compile, but it has the benefit that this code has much better performance.
                # NOTE: if and when we eventually update AOTAutograd to do the "view graph slicing" defined here:
                # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit,
                # then this optimization will probably matter less and might be ok to remove.
                is_cur_tensor_multi_out_view = isinstance(
                    o, FunctionalTensor
                ) and torch._functionalize_is_multi_output_view(  # type: ignore[attr-defined]
                    o.elem
                )
                if is_cur_tensor_multi_out_view:
                    num_aliased_tensors_that_are_multi_output_views[curr_storage] += 1
                out_storage_to_tensors[curr_storage].add(o)

        # maps the id of an intermediate base to its index in the output of the compiled forward
        intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {}
        intermediate_bases: List[torch.Tensor] = []
        # Why Do We Care If Storage Changed?
        # It's important to understand the implications of storage changes in complex scenarios. Take this example:
        #
        # def f(x):
        #     x_storage = x.untyped_storage()
        #     non_leaf_tensor = torch.ones(4, requires_grad=True).clone()
        #
        #     # Using no_grad() and _unsafe_preserve_version_counter to simulate the .data = operation
        #     with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x):
        #         x.set_(non_leaf_tensor.untyped_storage())
        #
        #     out = x.view(-1)
        #
        #     # Restoring x to its original storage, again simulating .data = operation
        #     with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x):
        #         x.set_(x_storage)
        #
        #     return out
        #
        # In this scenario, 'x' and 'out' have different shapes and are stored at different memory addresses, aka no aliasing.
        # However, because the functionalization of set_() is defined to preserve eager semantics,
        # the autograd engine mistakenly assumes that 'x' and 'out' are aliased, treating 'x' as 'out._base'.
        # This misinterpretation leads to an 'alias_of_input' flag, causing an unnecessary as_strided() call to be generated,
        # which could lead to issues later in the code.
        for o in flat_f_outs:
            functional_tensor_storage_changed = isinstance(
                o, FunctionalTensor
            ) and torch._functionalize_was_storage_changed(  # type: ignore[attr-defined]
                o.elem
            )
            curr_storage = (
                None
                if not isinstance(o, torch.Tensor)
                else StorageWeakRef(o.untyped_storage())
            )
            outs_with_identical_metadata_that_require_grad = (
                []
                if not isinstance(o, Tensor)
                else [
                    curr
                    for curr in out_storage_to_tensors[curr_storage]
                    if has_same_metadata(o, curr)
                    and curr.requires_grad
                    and o is not curr
                ]
            )

            # See Note [Accessing .grad_fn on FunctionalTensor]
            # In-place operations on views will trigger a lazy rebase of the autograd graph;
            # this runs during access to the .grad_fn. The rebase logic will invoke view ops
            # on FunctionalTensors, so we must enable a FunctionalTensorMode here to ensure
            # these op calls succeed.
            grad_fn = None
            if isinstance(o, Tensor):
                with FunctionalTensorMode():
                    grad_fn = o.grad_fn

            is_result_of_custom_autograd_fn = False
            # Need to check for both custom cpp (CppFunction) and python (BackwardCFunction)
            # autograd fns
            if type(grad_fn).__name__ == "CppFunction":
                is_result_of_custom_autograd_fn = True
            if isinstance(grad_fn, torch.autograd.function.BackwardCFunction):
                is_result_of_custom_autograd_fn = True

            if not isinstance(o, Tensor):
                output_type = OutputType.non_alias
                base_idx = None
            elif (
                curr_storage in inp_storage_refs
                and grad_fn is not None
                and is_result_of_custom_autograd_fn
            ):
                output_type = OutputType.custom_function_view
                base_idx = None
            elif (
                curr_storage in inp_storage_refs
                and not functional_tensor_storage_changed
            ):
                base_idx = inp_storage_refs[curr_storage]
                is_input_tensor = id(o) in inp_tensor_ids
                num_aliased_outs = out_tensor_alias_counts[curr_storage]
                num_multi_output_view_outs = (
                    num_aliased_tensors_that_are_multi_output_views[curr_storage]
                )
                num_aliased_outs_that_are_not_multi_output_views = (
                    num_aliased_outs - num_multi_output_view_outs
                )
                if (
                    grad_fn is not None
                    and num_aliased_outs_that_are_not_multi_output_views == 0
                ):
                    # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
                    # In particular, given:
                    # def f(x):
                    #     return list(x.unbind(0))
                    # The main reason we ordinarily try to regenerate these output aliases outside of the
                    # compiled autograd.Function is because if any of the outputs are later mutated,
                    # autograd needs to perform view-replay to regenerate them.
                    # However, autograd does not allow users to mutate multi-output views
                    # in any way that can change the autograd metadata of other aliases.
                    # So we hide this aliasing from autograd here.
                    log.debug(
                        "Encountered AOTAutograd case: differentiable outputs that \
alias each other from a multi-output view call"
                    )
                    output_type = OutputType.non_alias
                elif is_input_tensor:
                    output_type = OutputType.is_input
                else:
                    output_type = OutputType.alias_of_input
            elif functional_tensor_storage_changed and id(o) in inp_tensor_ids:
                # When there is a set_() on an input, we cannot rely on checking storages
                # to detect if we are returning an input (since the input's storage is different)
                assert curr_storage is not None
                base_idx = inp_storage_refs[curr_storage]
                output_type = OutputType.is_input

            # We only need to handle the intermediate base case when both
            # the intermediate base and the output require gradients.
            # See Note [AOT Autograd: outputs aliasing inputs or intermediates!]
            elif o._base is not None and o.requires_grad and o._base.requires_grad:
                num_aliased_outs = out_tensor_alias_counts[curr_storage]
                num_multi_output_view_outs = (
                    num_aliased_tensors_that_are_multi_output_views[curr_storage]
                )
                num_aliased_outs_that_are_not_multi_output_views = (
                    num_aliased_outs - num_multi_output_view_outs
                )
                # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
                if (
                    out_tensor_alias_counts[curr_storage] == 1
                    or num_aliased_outs_that_are_not_multi_output_views <= 1
                ):
                    # Note [Intermediate Bases Optimization]
                    # Normally if we have an output that aliases an intermediate,
                    # we need to add the extra "intermediate base" logic further down
                    # to prevent autograd from yelling at us if the user later tries to
                    # mutate that output.
                    # However, the common case here is if we have an output that aliases an intermediate,
                    # but doesn't alias any other outputs.
                    # In that case, autograd shouldn't have to worry about the aliasing at all
                    # (if that output is mutated, there are no other live aliases for autograd to worry about).
                    # The "intermediate bases" can hurt inductor perf by forcing more variables to become outputs.
                    # So as an optimization, we won't do intermediate base handling in this case.
                    # Instead, we'll hide the aliasing from autograd using aten._unsafe_view().
                    if (
                        out_tensor_alias_counts[curr_storage] != 1
                        and num_aliased_outs_that_are_not_multi_output_views <= 1
                    ):
                        log.debug(
                            "Encountered AOTAutograd case: differentiable outputs that alias each other \
from a multi-output view call"
                        )
                    output_type = OutputType.unsafe_view_alias
                    base_idx = None
                else:
                    # First, check if o's ._base is an existing output
                    maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None)
                    if maybe_existing_out_idx is not None:
                        # Special case where the output is an alias of a graph intermediate, but that intermediate
                        # is itself also a user output.
                        output_type = (
                            OutputType.alias_of_intermediate_base_is_user_output
                        )
                        base_idx = maybe_existing_out_idx
                    else:
                        # Next, check if o's ._base is an intermediate base that we already returned
                        maybe_existing_base_output_idx = (
                            intermediate_base_tensor_id_to_output_idx.get(
                                id(o._base), None
                            )
                        )
                        if maybe_existing_base_output_idx is not None:
                            output_type = OutputType.alias_of_intermediate
                            base_idx = maybe_existing_base_output_idx
                        else:
                            # Otherwise, take o._base and explicitly return it as an output in the compiled graph
                            new_out_idx = len(intermediate_bases)
                            base_idx = new_out_idx
                            # Indicate to the logic later on (when we trace the joint)
                            # that this particular output should get its ._base appended to the forward graph outputs
                            output_type = (
                                OutputType.alias_of_intermediate_save_as_output
                            )
                            intermediate_base_tensor_id_to_output_idx[
                                id(o._base)
                            ] = new_out_idx
                            intermediate_bases.append(o._base)
            elif (
                # See https://github.com/pytorch/pytorch/issues/100348 for this case.
                # This protects against the specific case where a user fn returns (output, output.detach())
                out_tensor_alias_counts[curr_storage] > 1
                and len(outs_with_identical_metadata_that_require_grad) > 0
                and not o.requires_grad
            ):
                # In theory we could use any of these tensors to regenerate the aliased outputs from,
                # since they all alias each other and have identical metadata
                out_alias = outs_with_identical_metadata_that_require_grad[0]
                existing_out_idx = out_tensor_ids[id(out_alias)]
                output_type = OutputType.alias_of_intermediate_base_is_user_output
                base_idx = existing_out_idx
            else:
                output_type = OutputType.non_alias
                base_idx = None

            if isinstance(o, torch.Tensor):
                dynamic_dims = {
                    i for i, s in enumerate(o.shape) if not is_concrete_int(s)
                }
            else:
                dynamic_dims = None

            # Save the current FunctionalTensor output.
            #
            # This will be used at runtime for reconstructing output views from
            # their respective base tensors.
            #
            # The FunctionalTensor will be saved if one of the 2 conditions below
            # is true:
            functional_tensor = None
            if (
                # 1. If the output_type is either of:
                #    (i) alias_of_intermediate;
                #    (ii) alias_of_intermediate_save_as_output; or
                #    (iii) alias_of_intermediate_base_is_user_output.
                #
                # No need to worry about in-place view operations here, since
                # this functionalization step eliminates mutations.
                #
                # i.e. we have access to the actual base tensor, before the
                # in-place operation was applied.
                output_type
                in (
                    OutputType.alias_of_intermediate,
                    OutputType.alias_of_intermediate_save_as_output,
                    OutputType.alias_of_intermediate_base_is_user_output,
                )
            ) or (
                # 2. If the output_type is alias_of_input, and no in-place view
                #    operation was run on the input (base tensor).
                #
                # In this case, we need to check for metadata mutation because
                # the runtime explicitly reconstructs the inputs, before actually
                # reconstructing the outputs. Due to in-place view operations, the
                # fully reconstructed input may not be this output's base tensor
                # anymore.
                output_type == OutputType.alias_of_input
                and base_idx is not None
                and not input_info[base_idx].mutates_metadata
            ):
                if isinstance(o, FunctionalTensor):
                    functional_tensor = FunctionalTensorMetadataEq(o.elem)

            out_info = OutputAliasInfo(
                output_type=output_type,
                raw_type=type(o),
                base_idx=base_idx,
                dynamic_dims=dynamic_dims,
                requires_grad=isinstance(o, torch.Tensor) and o.requires_grad,
                functional_tensor=functional_tensor,
            )
            output_info.append(out_info)

        # See Note [AOT Autograd: Views to avoid tangents aliasing inputs]
        def view_avoid_dupes_with_primals(t):
            if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
                return transform_subclass(
                    t, lambda _, inner_t: view_avoid_dupes_with_primals(inner_t)
                )
            if isinstance(t, Tensor):
                return t.view(t.shape)
            return t

        # This analysis function returns *only* the outputs that are meant to be tangents to the backward.
        # Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
        # is *regenerated* later, and not used directly in the autograd graph
        f_input_tangents = [
            inp
            for inp, info in zip(flat_f_args, input_info)
            if info.mutation_type == MutationType.MUTATED_OUT_GRAPH
            and info.mutates_data
            and info.requires_grad
        ]
        f_output_tangents = [
            o
            for o, info in zip(flat_f_outs, output_info)
            if info.output_type
            in [
                OutputType.non_alias,
                OutputType.unsafe_view_alias,
                OutputType.custom_function_view,
            ]
            and issubclass(info.raw_type, torch.Tensor)
            and info.requires_grad
        ]
        # intermediate bases are also included in the backward graph
        f_tangents = f_input_tangents + f_output_tangents + intermediate_bases
        traced_tangents = pytree.tree_map(from_fun, f_tangents)
        traced_tangents = pytree.tree_map(
            view_avoid_dupes_with_primals, traced_tangents
        )
        # See Note [Tangents must be contiguous]
        traced_tangents = pytree.tree_map(
            coerce_tangent,
            traced_tangents,
        )
        user_outs = pytree.tree_map(from_fun, f_output_tangents)

        nonlocal static_input_indices
        static_input_indices = static_input_indices or []
        if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
            passed_indices = set(static_input_indices)
            static_input_indices = [
                i
                for i, arg in enumerate(flat_args)
                if (isinstance(arg, torch.nn.Parameter) or i in passed_indices)
            ]

        static_input_logger.debug(
            "static input indices metadata analysis: %s", static_input_indices
        )

        f_mutated_inputs = [
            inp
            for inp, info in zip(flat_f_args, input_info)
            if info.mutation_type == MutationType.MUTATED_OUT_GRAPH
        ]
        f_metadata_mutated_inputs = [
            inp for inp, info in zip(flat_f_args, input_info) if info.mutates_metadata
        ]
        # This logic (annoyingly) re-figures out exactly what the outputs to the compiled fw graph will be.
        # When handling subclasses, we need info about **all** outputs of the compiled forward graph,
        # so we know precisely which graph outputs to wrap back into tensor subclasses.
        # Ideally we would refactor this so as not to have an is_train flag, and have the separate
        # inference and training paths decide which inputs/outputs to ask for subclass info on.
        # However, we currently stash indexing information on each SubclassMeta about its order
        # in the graph outputs list.
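        # Roughly (illustrative), the forward graph outputs assembled below are laid out as:
        #   training (is_train=True):              [mutated inputs handled outside the graph] + [user fw outs] + [intermediate bases]
        #   inference, keep_input_mutations=False: [mutated inputs handled outside the graph] + [user fw outs]
        #   inference, keep_input_mutations=True:  [metadata-only mutated inputs] + [user fw outs]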
        f_fw_graph_outs = list(flat_f_outs)
        if is_train or not keep_input_mutations:
            f_fw_graph_outs = f_mutated_inputs + f_fw_graph_outs
        else:
            # even when "keep_input_mutations" is True,
            # we never keep metadata-only mutations in the fw graph
            f_fw_graph_outs = f_metadata_mutated_inputs + f_fw_graph_outs
        if is_train:
            f_fw_graph_outs = f_fw_graph_outs + intermediate_bases
        fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs)

        grad_enabled_mutation = None
        if torch.is_grad_enabled() != prior_grad_enabled:
            grad_enabled_mutation = torch.is_grad_enabled()
            torch.set_grad_enabled(
                prior_grad_enabled
            )  # Restore the prior state after tracing it
            log.debug(
                (
                    "grad_mode mutation encountered in graph. "
                    "Will emit mutation epilogue, to set grad_mode=%s"
                ),
                grad_enabled_mutation,
            )

        metadata = ViewAndMutationMeta(
            input_info=input_info,
            output_info=output_info,
            num_intermediate_bases=len(intermediate_bases),
            keep_input_mutations=keep_input_mutations,
            traced_tangents=traced_tangents,
            subclass_inp_meta=create_subclass_meta(flat_args),
            subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs),
            subclass_tangent_meta=create_subclass_meta(traced_tangents),
            is_train=is_train,
            grad_enabled_mutation=grad_enabled_mutation,
            static_input_indices=static_input_indices,
            tokens=mode._tokens,
        )
        return metadata

    return inner