xref: /aosp_15_r20/external/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""
3This module defines runtime wrappers, which, based on previous analysis, attempt to:
41. process the inputs and outputs
52. apply mutations
63. handle functionalized randomness
74. deduplicate inputs and consolidate views into their bases (see input_output_analysis)
8"""
9import builtins
10import collections
11import pprint
12from contextlib import nullcontext
13from dataclasses import dataclass, field
14from functools import wraps
15from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
16
17import torch
18import torch.utils.dlpack
19from torch import Tensor
20from torch._guards import (
21    compile_context,
22    CompileContext,
23    detect_fake_mode,
24    DuplicateInputs,
25    tracing,
26    TracingContext,
27)
28from torch._prims_common import CUDARngStateHelper
29from torch._subclasses import FakeTensor
30from torch.fx.experimental._backward_state import BackwardState
31from torch.multiprocessing.reductions import StorageWeakRef
32from torch.utils._python_dispatch import is_traceable_wrapper_subclass
33
34from .. import config
35from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
36from .functional_utils import gen_alias_from_base
37from .input_output_analysis import (
38    compute_overlapping_inputs,
39    create_synthetic_base_metadata,
40    remove_dupe_metadata,
41)
42from .logging_utils import describe_input, format_guard_bug_msg, track_graph_compiling
43from .schemas import (
44    AOTConfig,
45    InputAliasInfo,
46    MutationType,
47    OutputType,
48    SubclassCreationMeta,
49    SubclassMeta,
50    TensorAlias,
51    ViewAndMutationMeta,
52)
53from .subclass_utils import (
54    get_types_for_subclass,
55    requires_subclass_dispatch,
56    unwrap_tensor_subclasses,
57    wrap_tensor_subclasses,
58)
59from .traced_function_transforms import aot_dispatch_subclass
60from .utils import (
61    call_func_at_runtime_with_args,
62    make_boxed_func,
63    normalize_as_list,
64    partial_flatten_asdict,
65    strict_zip,
66)
67
68
69zip = strict_zip
70
71
72class CompilerWrapper:
73    """
74    A wrapper around the inputs and outputs to the compiler_fn. We separate these into two parts:
75
76    1. The prologue, which edits the inputs to the compiler_fn (flat_fn, flat_args, etc.)
77    2. The epilogue, which edits the outputs of the compiler_fn (compiled_fn, real arguments)
78
79    Each wrapper below should be implemented as a CompilerWrapper, so that we can facilitate
80    caching on the compiled output, and re-wrapping the output via epilogues.
81    Extra metadata that is needed to compute pre or post compile can be passed in via attributes.
82    """
83
84    def pre_compile(
85        self,
86        flat_fn,
87        flat_args: List[Tensor],
88        aot_config: AOTConfig,
89        *,
90        fw_metadata: ViewAndMutationMeta,
91    ) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
92        """
93        Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs.
94        Args:
95        flat_fn: The function to compile
96        flat_args: Metadata from example inputs of the function to compile
97        aot_config: AOTConfig passed in at compile time
98        fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args
99        """
100        return flat_fn, flat_args, fw_metadata
101
102    def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable:
103        """
104        Given an output of the compiler, wrap it with information received from prologue.
105        Args:
106        compiled_fn: Callable after calling compiler_fn
107        aot_config: AOTConfig after calling prologue
108        runtime_metadata: ViewAndMutationMeta after calling all wrappers' pre_compile steps.
109        Example:
110
111        def wrapped_compiled_fn(args):
112            # do something with args, aot_config, fw_metadata
113            return compiled_fn(args)
114
115        return wrapped_compiled_fn
116        """
117        return compiled_fn
118
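# Illustrative sketch only (hypothetical driver, not the actual AOTAutograd entry point):
# a stack of CompilerWrappers is typically applied by threading the function/args through
# each pre_compile, invoking the underlying compiler once, and then applying each
# post_compile in reverse order:
#
#   def apply_wrappers(wrappers, flat_fn, flat_args, aot_config, fw_metadata, compiler_fn):
#       for w in wrappers:
#           flat_fn, flat_args, fw_metadata = w.pre_compile(
#               flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
#           )
#       compiled_fn = compiler_fn(flat_fn, flat_args)
#       for w in reversed(wrappers):
#           compiled_fn = w.post_compile(
#               compiled_fn, aot_config, runtime_metadata=fw_metadata
#           )
#       return compiled_fn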
119
120# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic
121# that needs to run after the compiled function.
122#
123# This function accepts a trace_joint flag, indicating whether we're generating the runtime
124# epilogue for a forward-only inference graph or for an autograd.Function.apply function.
125# This is because there are some minor differences in how we treat these cases at runtime:
126# - resize_() is currently handled in the inference case, but not fully handled in the autograd case.
127# - the autograd case inserts TensorAlias wrapper objects for outputs that alias inputs
128@dataclass
129class RuntimeWrapper(CompilerWrapper):
130    indices_of_inps_to_detach: List[int]
131    trace_joint: bool
132    disable_amp: bool
133
134    def post_compile(
135        self,
136        compiled_fn,
137        aot_config: AOTConfig,
138        *,
139        runtime_metadata: ViewAndMutationMeta,
140    ):
141        return _create_runtime_wrapper(
142            compiled_fn,
143            runtime_metadata=runtime_metadata,
144            indices_of_inps_to_detach=self.indices_of_inps_to_detach,
145            trace_joint=self.trace_joint,
146            keep_input_mutations=aot_config.keep_inference_input_mutations,
147            disable_amp=self.disable_amp,
148        )
149
150
151class NoopAliasHandler:
152    def __init__(self, info, runtime_metadata, trace_joint):
153        pass
154
155    def __call__(self, orig_inputs, fw_outs, out):
156        return out
157
158
159def _unwrap_tensoralias(x):
160    assert isinstance(x, TensorAlias)
161    return x.alias
162
163
164def _identity(x):
165    return x
166
167
168class AliasOfInputHandler:
169    def __init__(self, info, runtime_metadata, trace_joint):
170        self.base_idx = info.base_idx
171        self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
172        self.requires_grad = info.requires_grad
173        self.functional_tensor = info.functional_tensor
174        self.replay_views = config.view_replay_for_aliased_outputs
175
176    def __call__(self, orig_inputs, fw_outs, out):
177        aliased_base_tensor = orig_inputs[self.base_idx]
178        return gen_alias_from_base(
179            aliased_base_tensor,
180            self.unwrap_out(out),
181            self.requires_grad,
182            self.functional_tensor,
183            replay_views=self.replay_views,
184        )
185
186
187class IsInputHandler:
188    def __init__(self, info, runtime_metadata, trace_joint):
189        self.base_idx = info.base_idx
190        self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
191
192    def __call__(self, orig_inputs, fw_outs, out):
193        aliased_base_tensor = orig_inputs[self.base_idx]
194        return aliased_base_tensor
195
196
197class AliasOfIntermediateHandler:
198    def __init__(self, info, runtime_metadata, trace_joint):
199        if info.output_type in (
200            OutputType.alias_of_intermediate,
201            OutputType.alias_of_intermediate_save_as_output,
202        ):
203            num_user_outputs = len(runtime_metadata.output_info)
204            self.base_idx = info.base_idx + num_user_outputs
205        else:
206            self.base_idx = info.base_idx
207
208        self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
209        self.requires_grad = info.requires_grad
210        self.functional_tensor = info.functional_tensor
211        self.replay_views = config.view_replay_for_aliased_outputs
212
213    def __call__(self, orig_inputs, fw_outs, out):
214        aliased_base_tensor = fw_outs[self.base_idx]
215        return gen_alias_from_base(
216            aliased_base_tensor,
217            self.unwrap_out(out),
218            self.requires_grad,
219            self.functional_tensor,
220            replay_views=self.replay_views,
221        )
222
223
224_HANDLER_MAP = {
225    OutputType.non_alias: NoopAliasHandler,
226    OutputType.unsafe_view_alias: NoopAliasHandler,
227    OutputType.custom_function_view: NoopAliasHandler,
228    OutputType.alias_of_input: AliasOfInputHandler,
229    OutputType.is_input: IsInputHandler,
230    OutputType.alias_of_intermediate: AliasOfIntermediateHandler,
231    OutputType.alias_of_intermediate_save_as_output: AliasOfIntermediateHandler,
232    OutputType.alias_of_intermediate_base_is_user_output: AliasOfIntermediateHandler,
233}
234
235
236def make_output_handler(info, runtime_metadata, trace_joint):
237    handler_type = _HANDLER_MAP[info.output_type]
238    return handler_type(info, runtime_metadata, trace_joint)
239
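# Illustrative example only: for an output whose OutputInfo has
# output_type == OutputType.alias_of_input, the dispatch above yields an
# AliasOfInputHandler, which regenerates the user-visible output as a view of the
# stashed original input rather than returning the compiled graph's output directly:
#
#   handler = make_output_handler(info, runtime_metadata, trace_joint=False)
#   ret_out = handler(orig_inputs, fw_outs, out)
#   # ~ gen_alias_from_base(orig_inputs[info.base_idx], out, info.requires_grad, ...)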
240
241def _create_runtime_wrapper(
242    compiled_fn,
243    *,
244    runtime_metadata: ViewAndMutationMeta,
245    indices_of_inps_to_detach: List[int],
246    trace_joint: bool,
247    keep_input_mutations: bool,
248    disable_amp: bool,
249):
250    if not hasattr(compiled_fn, "_boxed_call"):
251        compiled_fn = make_boxed_func(compiled_fn)
252
253    # Note [Inputs needed in runtime epilogue after list clearing]
254    # In Python functions, you can't free the input arguments of a function within the scope of that function. A workaround is to
255    # wrap the input arguments in a list, and clear the list from within the function.
256    # Here, this is implemented as `call_func_at_runtime_with_args(..., steal_args=True)`.
257    #
258    # This is needed for Compiled Autograd since some of the inputs (activations) should be freed early.
259    # However, we cannot blindly clear the entire list, because AOTAutograd may need access to some of the graph inputs
260    # **after** the compiled function has finished running. There are two main cases:
261    #   (1) Input mutations: If there are input mutations that we must run outside of the graph, we need access to the input.
262    #   (2) Output aliasing: Outputs that alias graph inputs generally must be regenerated outside of the `autograd.Function`,
263    #       and doing so requires accessing the corresponding input after the compiled artifact has run.
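    #
    # A minimal sketch of the boxed-call convention assumed here (illustrative only):
    #
    #   def boxed_compiled_fn(args: List[Any]):
    #       a, b = args[0], args[1]
    #       args.clear()      # the caller's list no longer keeps the inputs alive
    #       out = a + b
    #       del a, b          # inputs can now be freed before we return
    #       return [out]
    #
    # which is why `orig_inputs` is stashed below before the compiled function runs.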
264    epilogue_args_idx = []
265    epilogue_args_idx.extend(runtime_metadata.mutated_inp_runtime_indices)
266    for info in runtime_metadata.output_info:
267        if (
268            info.output_type == OutputType.alias_of_input
269            or info.output_type == OutputType.is_input
270        ):
271            assert isinstance(info.base_idx, int)
272            epilogue_args_idx.append(info.base_idx)
273
274    if config.unlift_effect_tokens:
275        assert len(runtime_metadata.tokens) == 0
276
277    replay_views = config.view_replay_for_aliased_outputs
278    if runtime_metadata.num_outputs_aliased > 0:
279        output_handlers = tuple(
280            make_output_handler(info, runtime_metadata, trace_joint)
281            for info in runtime_metadata.output_info
282        )
283
284    def runtime_wrapper(args: List[Any]):
285        # stash a ref to each input tensor we plan to use after the compiled function
286        orig_inputs = {i: args[i] for i in epilogue_args_idx}
287
288        if keep_input_mutations:
289            mutated_args = (
290                args[i]
291                for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd
292            )
293            torch.autograd.graph.increment_version(mutated_args)
294
295        if trace_joint:
296            args_ = list(args)
297            # See Note [Detaching inputs that never need gradients]
298            for idx in indices_of_inps_to_detach:
299                if isinstance(args_[idx], torch.Tensor):
300                    args_[idx] = args_[idx].detach()
301
302            # It's possible to have trace_joint inside a user-specified no_grad() region,
303            # if there is a nested enable_grad() that forces some outputs to require gradients.
304            # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution.
305            with torch.autograd._force_original_view_tracking(
306                True
307            ), torch.enable_grad():
308                all_outs = call_func_at_runtime_with_args(
309                    compiled_fn, args_, disable_amp=disable_amp, steal_args=True
310                )
311        else:
312            # When we have an inference graph, we run with grad disabled.
313            # It's possible to get an inference graph with inputs that require grad,
314            # in which case we want to make sure autograd is disabled
315            # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on)
316            # NOTE: We use _set_grad_enabled directly to reduce runtime overhead
317            grad_enabled = torch.is_grad_enabled()
318            try:
319                if grad_enabled:
320                    torch._C._set_grad_enabled(False)
321                all_outs = call_func_at_runtime_with_args(
322                    compiled_fn, args, disable_amp=disable_amp, steal_args=True
323                )
324            finally:
325                if grad_enabled:
326                    torch._C._set_grad_enabled(True)
327        del args
328
329        num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices
330        num_intermediate_bases = runtime_metadata.num_intermediate_bases
331
332        assert (
333            len(all_outs)
334            == num_mutated_runtime_inps
335            + runtime_metadata.num_outputs
336            + num_intermediate_bases
337        )
338
339        # Step 3: After running the compiled fw, apply updates to mutated inputs
340        num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices
341        if num_mutations_to_apply > 0:
342            updated_inputs = all_outs[:num_mutations_to_apply]
343            fw_outs = all_outs[num_mutations_to_apply:]
344
345            for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices):
346                meta = runtime_metadata.input_info[inpt_idx]
347                if not meta.mutates_data and not meta.mutates_metadata:
348                    continue
349                original_inpt = orig_inputs[inpt_idx]
350                updated_inpt = updated_inputs[i]
351                if meta.mutates_storage_metadata:
352                    # See Note [set_() Input Mutations in AOTAutograd]
353                    # mutates_storage_metadata means our input saw a x.set_(y) call.
354                    # What if x **also** saw a data and/or a metadata mutation?
355                    # (1) If the [meta]data mutation occurred after the set_(),
356                    #     then there is no need to copy_() the data.
357                    #     When we perform x.set_(x_updated), we are guaranteed that
358                    #     x_updated already has the final version of the data/metadata
359                    # (2) If a data mutation occurred before the set_().
360                    #     This case seems very difficult to support.
361                    #     TODO: discuss on the PR and decide if we want to try to
362                    #     either support it, or detect and ban it.
363                    if trace_joint:
364                        assert isinstance(updated_inpt, TensorAlias)
365                        updated_inpt = updated_inpt.alias
366                    with torch.no_grad():
367                        original_inpt.set_(updated_inpt)
368                    continue
369                if meta.mutates_metadata and not meta.mutates_data:
370                    if trace_joint:
371                        assert isinstance(updated_inpt, TensorAlias)
372                        updated_inpt = updated_inpt.alias
373                    # We need to grab the size/stride/storage_offset from the compiled forward,
374                    # and use that to mutate the metadata of the input
375                    original_inpt.as_strided_(
376                        updated_inpt.size(),
377                        updated_inpt.stride(),
378                        updated_inpt.storage_offset(),
379                    )
380                else:
381                    if meta.mutates_data and meta.mutates_metadata:
382                        original_inpt.as_strided_(
383                            updated_inpt.size(),
384                            updated_inpt.stride(),
385                            updated_inpt.storage_offset(),
386                        )
387                    else:
388                        assert meta.mutates_data
389                    if meta.is_leaf and original_inpt.requires_grad:
390                        # We can hit this situation in this case:
391                        #   def f(x):
392                        #       x.detach().mul_(2)
393                        #       return x + 1
394                        # AOTAutograd will see a mutation in the above case, and try to
395                        # apply a copy_() here, in the epilogue.
396                        # But if x requires gradients and is a leaf, then autograd
397                        # will yell at us for trying to mutate it.
398                        # However, it's only possible to end up in this scenario (like the above)
399                        # if all of the mutations to the leaf input were non-autograd-tracking mutations
400                        # (aka mutations under no_grad(), or on detached views).
401                        # In that case, we fully want to hide the mutation from autograd, so detaching is ok.
402                        original_inpt.detach().copy_(updated_inpt)
403                    else:
404                        original_inpt.copy_(updated_inpt)
405        else:
406            fw_outs = all_outs
407
408        # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of
409        # compiling them.
410        if runtime_metadata.num_outputs_aliased > 0:
411            # The compiled forward also returned intermediate bases. We don't want to return them to the user.
412            expect_num_outputs = (
413                len(output_handlers) + runtime_metadata.num_intermediate_bases
414            )
415            assert len(fw_outs) == expect_num_outputs
416            ret_outs = [
417                handler(orig_inputs, fw_outs, out)
418                for out, handler in builtins.zip(fw_outs, output_handlers)
419            ]
420        else:
421            ret_outs = fw_outs
422
423        if runtime_metadata.dynamic_outputs:
424            for t, o in zip(ret_outs, runtime_metadata.output_info):
425                if o.dynamic_dims is None:
426                    continue
427                if hasattr(t, "_dynamo_weak_dynamic_indices"):
428                    t._dynamo_weak_dynamic_indices |= o.dynamic_dims
429                else:
430                    t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy()
431        if runtime_metadata.grad_enabled_mutation is not None:
432            torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation)
433        return ret_outs
434
435    return runtime_wrapper
436
437
438@dataclass
439class FunctionalizedRngRuntimeWrapper(CompilerWrapper):
440    # TODO: I would love to get rid of this argument, but it's
441    # wrapped pretty tightly around our aot_dispatch_autograd logic.
442    # Specifically, tensors_saved_for_backwards_slice's value is used both for calculating indices
443    # for setting placeholder strides (which is done before runtime, before this wrapper runs)
444    # and for saving tensors for backward (which is done during runtime, after this wrapper runs).
445    # So in aot_dispatch_autograd, this wrapper can't edit the set of outs without making one
446    # of those two indices incorrect.
447    return_new_outs: bool = True
448
449    def pre_compile(
450        self,
451        flat_fn,
452        flat_args,
453        aot_config,
454        *,
455        fw_metadata,
456    ) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
457        if config.functionalize_rng_ops:
458            # Update example inputs for the fw_compiler
459            fake_mode = detect_fake_mode()
460            seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
461            flat_args.extend([seed, offset])
462            # We are not clearing flat_args here because
463            # 1) There is a check in the debug compiler at the end
464            # 2) It does not matter as these are fake tensors
465        return flat_fn, flat_args, fw_metadata
466
467    def post_compile(
468        self,
469        compiled_fn,
470        aot_config: AOTConfig,
471        *,
472        runtime_metadata: ViewAndMutationMeta,
473    ):
474        @wraps(compiled_fn)
475        def wrapper(runtime_args: List[Any]):
476            if runtime_metadata.is_rng_op_functionalized:
477                # Add the seed and offset to args
478                seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
479                runtime_args.extend([seed, offset])
480                out = compiled_fn(runtime_args)
481                out = self._functionalized_rng_runtime_epilogue(
482                    runtime_metadata,
483                    out,
484                    # TODO: this won't be right for the backward when we convert the call_compiled_backward to use the wrapper
485                    runtime_metadata.num_forward_returns,
486                )
487                return out
488            return compiled_fn(runtime_args)
489
490        return wrapper
491
492    # Calling convention: If we are running functionalized RNG, then outs consists
493    # of (user_outs, rng_offset)
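    #
    # Illustrative example only: with two user outputs,
    #   outs == [out0, out1, new_rng_offset] and offset_index == 2,
    # so the epilogue advances the CUDA RNG offset to new_rng_offset and
    # (when return_new_outs is True) hands back just [out0, out1].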
494    def _functionalized_rng_runtime_epilogue(
495        self,
496        metadata: ViewAndMutationMeta,
497        outs,
498        offset_index,
499    ):
500        if metadata.is_rng_op_functionalized:
501            assert metadata.num_outputs_rng_offset == 1
502            new_rng_offset = outs[offset_index]
503            CUDARngStateHelper.set_new_offset(new_rng_offset)
504            if self.return_new_outs:
505                user_outs = outs[:offset_index] + outs[offset_index + 1 :]
506                return user_outs
507            else:
508                return outs
509
510        return outs
511
512
513@dataclass
514class FakifiedOutWrapper(CompilerWrapper):
515    out_metas: List[torch.Tensor] = field(default_factory=list)
516    # TracingContext.fwd_output_strides
517    # Generated from actually doing compile
518    fwd_output_strides: Optional[List[List[int]]] = None
519    needs_post_compile: bool = True
520
521    def pre_compile(
522        self,
523        fw_module,  # Must be fw_module from aot_dispatch_*_graph
524        flat_args,
525        aot_config,
526        *,
527        fw_metadata,
528    ) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
529        tracing_context = torch._guards.TracingContext.try_get()
530        if tracing_context and tracing_context.fakify_first_call:
531            self.out_metas = [
532                n.meta["val"] for n in (list(fw_module.graph.nodes)[-1].args[0])
533            ]
534        else:
535            self.needs_post_compile = False
536        return fw_module, flat_args, fw_metadata
537
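    # Illustrative example only: if a traced fake output has shape (2, 3) with
    # contiguous strides (3, 1) but inductor reports fwd_output_strides[i] == [1, 2],
    # the method below re-lays it out via
    #   out[i] = out[i].as_strided((2, 3), (1, 2))
    # so that the fakified first-call output matches the strides the compiled
    # kernel will actually produce.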
538    def _compute_output_meta_with_inductor_strides(self):
539        out = self.out_metas
540        fwd_output_strides = self.fwd_output_strides
541        if not fwd_output_strides:
542            return out
543
544        from torch.fx.experimental.symbolic_shapes import statically_known_true
545
546        for i in range(len(out)):
547            if not isinstance(out[i], Tensor):
548                continue
549            if all(
550                statically_known_true(s1 == s2)
551                for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])
552            ):
553                continue
554            out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i])
555        return out
556
557    # To be called post compile
558    def set_fwd_output_strides(self, fwd_output_strides):
559        self.fwd_output_strides = fwd_output_strides
560
561    def post_compile(
562        self,
563        compiled_fn,
564        aot_config: AOTConfig,
565        *,
566        runtime_metadata: ViewAndMutationMeta,
567    ):
568        if self.needs_post_compile:
569            assert self.fwd_output_strides is not None
570            fakified_out = self._compute_output_meta_with_inductor_strides()
571
572            @wraps(compiled_fn)
573            def wrapper(runtime_args):
574                nonlocal fakified_out
575                if fakified_out is not None:
576                    out = fakified_out
577                    fakified_out = None
578                    return out
579                return compiled_fn(runtime_args)
580
581            return wrapper
582        # If we don't need to fakify, we can just return the original compiled function
583        return compiled_fn
584
585
586# This wrapper handles the AOTDispatch runtime logic for tensor subclasses.
587# At runtime, we have a compiled function that knows how to operate on the domain of DenseTensor -> DenseTensor,
588# But the user might have passed us some tensor subclass inputs (or expect some subclass tensor outputs).
589# This function handles the wrapping and unwrapping of tensor subclasses at runtime.
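#
# A rough sketch of the round trip performed at runtime (illustrative only):
#
#   dense_args = unwrap_tensor_subclasses(args, is_joint_structure=False)
#   dense_outs = compiled_fn(dense_args)      # compiled graph sees only dense tensors
#   outs = wrap_tensor_subclasses(
#       dense_outs,
#       subclass_metas=runtime_metadata.subclass_fw_graph_out_meta,
#       num_fw_outs_saved_for_bw=None,        # hypothetical value; set per-graph below
#       is_runtime=True,
#   )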
590@dataclass
591class AOTDispatchSubclassWrapper(CompilerWrapper):
592    trace_joint: bool
593    fw_only: Optional[Callable]  # Not cached, only used in pre_compile
594    maybe_subclass_meta: Optional[SubclassMeta]
595    num_fw_outs_saved_for_bw: Optional[int]
596
597    def pre_compile(
598        self,
599        flat_fn,
600        flat_args: List[Tensor],
601        aot_config: AOTConfig,
602        *,
603        fw_metadata: ViewAndMutationMeta,
604    ):
605        (new_flat_fn, new_flat_args, subclass_meta) = aot_dispatch_subclass(
606            flat_fn,
607            flat_args,
608            is_joint_structure=self.trace_joint,
609            meta=fw_metadata,
610            fw_only=self.fw_only,  # type: ignore[arg-type]
611        )
612        self.maybe_subclass_meta = subclass_meta
613        return new_flat_fn, new_flat_args, fw_metadata
614
615    def post_compile(
616        self,
617        compiled_fn,
618        _aot_config: AOTConfig,
619        *,
620        runtime_metadata: ViewAndMutationMeta,
621    ):
622        if self.maybe_subclass_meta is None:
623            return compiled_fn
624
625        subclass_metas = runtime_metadata.subclass_fw_graph_out_meta
626
627        @wraps(compiled_fn)
628        def inner_fn(args: List[Any]):
629            unwrapped_args = unwrap_tensor_subclasses(
630                args, is_joint_structure=self.trace_joint
631            )
632            args.clear()
633            # expectation: runtime_fn is a boxed fn
634            unwrapped_outs = compiled_fn(unwrapped_args)
635            wrapped_outs = wrap_tensor_subclasses(
636                unwrapped_outs,
637                subclass_metas=subclass_metas,
638                num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw,
639                is_runtime=True,
640            )
641            return wrapped_outs
642
643        # box it
644        inner_fn._boxed_call = True  # type: ignore[attr-defined]
645        return inner_fn
646
647
648@dataclass
649class EffectTokensWrapper(CompilerWrapper):
650    def post_compile(
651        self,
652        compiled_fn,
653        _aot_config,
654        *,
655        runtime_metadata: ViewAndMutationMeta,
656    ):
657        num_tokens = len(runtime_metadata.tokens)
658
659        @wraps(compiled_fn)
660        def inner_fn(args: List[Any]):
661            if num_tokens > 0:
662                # Pass in forward effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
663                old_args = args
664                args = [*([None] * num_tokens), *args]
665                old_args.clear()
666
667            outs = compiled_fn(args)
668
669            # Inductor cache DummyModule can return None
670            if outs is None:
671                return None
672            # Toss out the effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
673            return outs[num_tokens:] if num_tokens != 0 else outs
674
675        # box it
676        inner_fn._boxed_call = True  # type: ignore[attr-defined]
677        return inner_fn
678
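# Illustrative example only: with num_tokens == 2, the wrapper above turns a runtime call
#
#   compiled([a, b])
#
# into
#
#   compiled_fn([None, None, a, b])
#
# and then strips the two leading token outputs, so the caller still sees only its own
# user outputs.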
679
680# MOTIVATION:
681#
682# When tracing functions for future execution, one must be careful not to pass
683# in the same input tensor multiple times (e.g., f(x, x)), as this can result
684# in graphs that are ONLY valid if you later pass a new tensor in exactly the
685# same way (e.g., f(y, y)).  (NB: we really mean duplicate; two distinct
686# tensors that alias each other are a different situation that is covered by
687# aot_dispatch_deduplicated_autograd). Here are two examples:
688#
689# (1) Suppose you have a function:
690#
691#   def f(x, y):
692#       return x + y
693#
694# If you make_fx(f)(x, x), you will trace out:
695#
696#   def f(x, y):
697#       return y + y
698#
699# Oops!
700#
701# (2) For most tensors x and y, you can compute f's gradient with respect to
702# these two inputs by saying torch.autograd.grad(f(x, y), (x, y)).  However,
703# if x is y, you will trace out a program that gets incorrect gradients:
704#
705#   >>> x = torch.randn(1, requires_grad=True)
706#   >>> torch.autograd.grad(x + x, (x, x))
707#   (tensor([2.]), tensor([2.]))
708#
709# In other words, the gradient is double-counted.  Deduplicating the arguments
710# gives you an appropriate gradient:
711#
712#   >>> y = torch.randn(1, requires_grad=True)
713#   >>> torch.autograd.grad(x + y, (x, y))
714#   (tensor([1.]), tensor([1.]))
715#
716# HOW TO DEDUPLICATE:
717#
718# There are a few strategies, in order of preference:
719#
720# 1. For every duplicate argument to the function, detach it into
721#    a separate leaf tensor, so that it is no longer duplicated.
722#
723#       PRO: The resulting compiled graph works for any configuration
724#       of duplicated arguments.
725#
726#       CON: It does not (naively) work if you mutate the metadata of inputs:
727#
728#           def f(x, y):
729#               x.transpose_(0, 1)
730#               y.transpose_(0, 2)
731#
732#           x = torch.randn(2, 3, 4)
733#           f(x, x)
734#
735#       The ordering of the transposes inside f dictates whether or not
736#       you get [4, 2, 3] or [3, 4, 2].  This means that you cannot precompute
737#       what metadata mutations should get applied to each input; you need to
738#       assume they aren't duplicates (what we do today) or preserve
739#       the original metadata mutations exactly in order, so that they work
740#       for any duplicate configuration.
741#
742#       CON: It does not (naively) work if you mutate the data of inputs.
743#       In particular, leaf tensors that require grad cannot be mutated,
744#       this makes it impossible to differentiate with respect to the original
745#       base.
746#
747# 2. For every duplicate argument to the function, remove it, so it is
748#    no longer part of the "true" signature:
749#
750#       PRO: Implemented naively, it still works for metadata/data mutation.
751#
752#       CON: The resulting compiled graph is duplicate-specialized: it only
753#       works if future calls duplicate arguments in exactly the same way.
754#       Horribly, Dynamo doesn't guard on this at the moment.  But even if
755#       it did, you could still end up recompiling a bunch, once for each duplicate configuration.
756#
757# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if
758# Dynamo's guards are not enough.  In practice, this seems to cover
759# everything.
760#
761@dataclass
762class AOTDedupeWrapper(CompilerWrapper):
763    keep_arg_mask: List[bool] = field(default_factory=list)
764    add_dupe_map: List[int] = field(default_factory=list)
765    old_input_metadata: List[InputAliasInfo] = field(default_factory=list)
766    needs_post_compile: bool = True
767
768    # NB: Hot path, avoid set lookups here
769    # TODO: Can avoid the zip here too, probably
770    def remove_dupe_args(self, args):
771        return [t for t, keep in zip(args, self.keep_arg_mask) if keep]
772
773    def add_dupe_args(self, args):
774        return [args[i] for i in self.add_dupe_map]
775
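    # Illustrative example (mirroring the [a, b, a, c] walkthrough in pre_compile below):
    #   keep_arg_mask = [True, True, False, True]
    #   add_dupe_map  = [0, 1, 0, 2]
    #   remove_dupe_args([a, b, a, c]) == [a, b, c]
    #   add_dupe_args([a, b, c])       == [a, b, a, c]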
776    def pre_compile(
777        self,
778        flat_fn,
779        flat_args: List[Tensor],
780        aot_config: AOTConfig,
781        *,
782        fw_metadata: ViewAndMutationMeta,
783    ) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
784        # Use information about whether or not flat_fn mutates its arguments
785        # to handle dupe args
786
787        # Strategy 1: For any input that is not mutated, we can leafify it if we
788        # need to remove a duplicate.
789        leaf_flat_args = []
790        args_set = set()
791        ok = True
792
793        for i, a in enumerate(flat_args):
794            if not isinstance(a, torch.Tensor):
795                leaf_flat_args.append(a)
796            elif a not in args_set:
797                args_set.add(a)
798                leaf_flat_args.append(a)
799            elif (
800                not fw_metadata.input_info[i].mutates_data
801                and not fw_metadata.input_info[i].mutates_metadata
802            ):
803                leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))
804            else:
805                ok = False
806                break
807
808        if ok:
809            self.needs_post_compile = False
810            return flat_fn, leaf_flat_args, fw_metadata
811
812        if requires_subclass_dispatch(leaf_flat_args, fw_metadata):
813            raise RuntimeError(
814                """\
815        Encountered duplicate inputs that are mutated in the graph, but at least one input/output
816        to the graph is a tensor subclass. This is not supported today. You can try to
817        remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
818            )
819
820        # export path: ban duplicate inputs for now, add later if requested.
821        if aot_config.is_export:
822            raise RuntimeError(
823                f"""\
824        Encountered duplicated inputs that are mutated in the graph you are trying to export.
825        This functionality is currently not supported. If needed, please file a github issue.
826
827        fw_metadata={str(fw_metadata)}
828            """
829            )
830
831        # Strategy 2: Duplicate specialize.
832        #
833        # In Haskell types, suppose you have:
834        #
835        #   add_dupe_args :: DedupedArgs -> Args
836        #   remove_dupe_args :: Args -> DedupedArgs
837        #
838        #   compiler_fn
839        #       :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)
840        #   deped_compiler_fn
841        #       :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)
842        #
843        # Then the code below can be written in point-free style as:
844        #
845        #   deduped_compiler_fn f a c =
846        #       compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args
847        #
848        # Suppose you have:
849        #
850        #   [a, b, a, c]
851        #
852        # We want:
853        #
854        #   remove_dupe_args([a, b, a, c]) == [a, b, c]
855        #   add_dupe_args([a, b, c]) == [a, b, a, c]
856        #
857        # This is done via (respectively):
858        #
859        #   seen_args = {a: 0, b: 1, c: 2}
860        #   enumerate(add_dupe_map) = [  # how to get args from the deduped list
861        #       (0, 0),
862        #       (1, 1),
863        #       (2, 0),
864        #       (3, 2),
865        #   ]
866        #   keep_arg_mask = [True, True, False, True]
867
868        seen_args: Dict[Tensor, int] = {}
869        keep_arg_mask: List[bool] = []
870        # Implicitly map duped arg position (list index) to de-duped arg position
871        add_dupe_map: List[int] = []
872        duped_arg_len = len(flat_args)
873
874        j = 0  # index into deduped_flat_args
875        for t in flat_args:
876            if isinstance(t, torch.Tensor):
877                if t in seen_args:
878                    keep_arg_mask.append(False)
879                    add_dupe_map.append(seen_args[t])
880                    continue
881                seen_args[t] = j
882
883            keep_arg_mask.append(True)
884            add_dupe_map.append(j)
885            j += 1
886        assert (
887            len(add_dupe_map) == duped_arg_len
888        ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
889
890        self.keep_arg_mask = keep_arg_mask
891        self.add_dupe_map = add_dupe_map
892
893        deduped_flat_args = self.remove_dupe_args(flat_args)
894
895        # Update our input metadata to remove duped input metadata.
896        updated_fw_metadata = remove_dupe_metadata(
897            fw_metadata, keep_arg_mask, add_dupe_map
898        )
899
900        if (
901            (tracing_context := TracingContext.try_get())
902            and aot_config.aot_autograd_arg_pos_to_source
903        ):
904            # TODO(voz): This structure is 1:1, we could consider an alternate structure like
905            # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,
906            # which feels like needless complexity for a tiny bit of efficiency at this point.
907            for dupe_arg_pos, (kept_pos, keep_arg) in enumerate(
908                zip(add_dupe_map, keep_arg_mask)
909            ):
910                if not keep_arg:
911                    dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[
912                        dupe_arg_pos
913                    ]
914                    kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[
915                        kept_pos
916                    ]
917                    tracing_context.guards_context.aotautograd_guards.append(  # type: ignore[attr-defined]
918                        DuplicateInputs(kept_arg_source, dupe_arg_source)
919                    )
920
921        @wraps(flat_fn)
922        def wrapped_flat_fn(*args):
923            return flat_fn(*self.add_dupe_args(args))
924
925        if config.debug_assert:
926            ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
927                wrapped_flat_fn,
928                static_input_indices=aot_config.static_input_indices,
929                keep_input_mutations=fw_metadata.keep_input_mutations,
930                is_train=fw_metadata.is_train,
931            )(*deduped_flat_args)
932            assert (
933                ref_fw_metadata == updated_fw_metadata
934            ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
935
936        return wrapped_flat_fn, deduped_flat_args, updated_fw_metadata
937
938    def post_compile(
939        self,
940        compiled_fn,
941        aot_config: AOTConfig,
942        *,
943        runtime_metadata: ViewAndMutationMeta,
944    ):
945        if not self.needs_post_compile:
946            return compiled_fn
947
948        @wraps(compiled_fn)
949        def wrapped_compiled_fn(args: List[Any]):
950            deduped_args = self.remove_dupe_args(args)
951            args.clear()
952            return compiled_fn(deduped_args)
953
954        wrapped_compiled_fn._boxed_call = True  # type: ignore[attr-defined]
955
956        # This can be uncommented when we properly guard for duplicates,
957        # but right now we must not do it.
958        # if not config.debug_assert:
959        #     return wrapped_compiled_fn
960
961        @wraps(wrapped_compiled_fn)
962        def debugged_compiled_fn(args):
963            # Test that the computed remove/add arg functions are an inverse
964            new_args = self.add_dupe_args(self.remove_dupe_args(args))
965            seen: Dict[Any, None] = {}
966            for i, (x, y) in enumerate(zip(new_args, args)):
967                seen[y] = None
968                assert x is y, format_guard_bug_msg(
969                    aot_config,
970                    f"{describe_input(i, aot_config)} would be a duplicate of "
971                    f"{describe_input(self.add_dupe_map[i], aot_config)}",
972                )
973            # This is only an error if there is metadata mutation on both of
974            # the duped arguments; in this case, we need to know what order
975            # the metadata mutation applies in.  You'll get the correct result
976            # otherwise, because a graph that assumes distinct inputs works if
977            # you dupe the inputs (the gradient contributions from each input
978            # will get summed up appropriately.)
979            #
980            # TODO: work out how to setup this assert correctly
981            """
982            assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
983                f"there would be {unique_args} distinct arguments"
984            )
985            """
986            return wrapped_compiled_fn(args)
987
988        debugged_compiled_fn._boxed_call = True  # type: ignore[attr-defined]
989
990        return debugged_compiled_fn
991
992
993# This layer handles the situation where you have two inputs that alias each other,
994# and one of the inputs is mutated.
995# We need to take special care to ensure that the mutation is applied to the other aliases in the graph.
996#
997# pre-condition: AOTDedupeWrapper has already run.
998# (This function will in theory work if there are duplicate args.
999# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs
1000# would cause us to hit that path more frequently).
1001@dataclass
1002class AOTSyntheticBaseWrapper(CompilerWrapper):
1003    # Currently, the only reason we need to plumb this bool is because
1004    # the synthetic base code prohibits more cases in the autograd case than the inference case.
1005    trace_joint: bool  # TODO: refactor trace_joint
1006    needs_post_compile: bool = True
1007    aliased_arg_idx_with_metadata_mutations: List[int] = field(default_factory=list)
1008
1009    def pre_compile(
1010        self,
1011        flat_fn,
1012        flat_args: List[Any],
1013        aot_config: AOTConfig,
1014        *,
1015        fw_metadata: ViewAndMutationMeta,
1016    ) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
1017        is_inference = not self.trace_joint
1018        flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
1019            flat_args,
1020            fw_metadata.input_info,
1021            is_inference=is_inference,
1022        )
1023
1024        # Happy path: we don't need synthetic bases
1025        if synthetic_base_info is None:
1026            self.needs_post_compile = False
1027            return flat_fn, flat_args, fw_metadata
1028
1029        # subclass path: ban synthetic bases with tensor subclasses for now, add later if requested.
1030        if requires_subclass_dispatch(flat_args, fw_metadata):
1031            raise RuntimeError(
1032                """\
1033        Encountered aliased inputs that are mutated in the graph, but at least one input/output
1034        to the graph is a tensor subclass. This is not supported today. You can try to
1035        remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
1036            )
1037
1038        if aot_config.is_export:
1039            raise RuntimeError(
1040                f"""\
1041        Encountered aliased inputs that are mutated in the graph you are trying to export.
1042        This functionality is currently not supported. If needed, please file a github issue.
1043
1044        synthetic_base_info={str(synthetic_base_info)}
1045
1046        fw_metadata={str(fw_metadata)}
1047                """
1048            )
1049
1050        assert len(fw_metadata.input_info) == len(synthetic_base_info)
1051
1052        # Update our forward metadata to take synthetic bases into account
1053        (
1054            fw_metadata_updated,
1055            aliased_arg_idx_with_metadata_mutations,
1056        ) = create_synthetic_base_metadata(
1057            fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases
1058        )
1059        # Save old input info for post-compile
1060        self.old_input_info = fw_metadata.input_info
1061
1062        self.aliased_arg_idx_with_metadata_mutations = (
1063            aliased_arg_idx_with_metadata_mutations
1064        )
1065
1066        num_aliased_args_with_metadata_mutations = len(
1067            aliased_arg_idx_with_metadata_mutations
1068        )
1069
1070        replay_views = config.view_replay_for_aliased_outputs
1071
1072        def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]:
1073            f_args_inner = []
1074            for inner_idx_or_tuple in synthetic_base_info:
1075                if isinstance(inner_idx_or_tuple, int):
1076                    f_args_inner.append(primals[inner_idx_or_tuple])
1077                else:
1078                    inner_base_idx, view_tensor = inner_idx_or_tuple
1079                    base = primals[inner_base_idx]
1080                    view_arg = gen_alias_from_base(
1081                        base,
1082                        view_tensor,
1083                        view_tensor.requires_grad,
1084                        replay_views=replay_views,
1085                    )
1086                    f_args_inner.append(view_arg)
1087            return f_args_inner
1088
1089        @wraps(flat_fn)
1090        def wrapped_flat_fn(*args):
1091            unpacked_args = _unpack_synthetic_bases(args)
1092            # This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases)
1093            # is to relieve the downstream logic from having to reason about mutations on inputs that alias
1094            # each other, by replacing aliased inputs with a synthetic base.
1095            # One area where this breaks down a bit however is if one of those aliased inputs
1096            # experienced a metadata mutation.
1097            # We are now obligated to reapply the metadata mutation directly to the user's input;
1098            # it isn't enough to apply mutations back to the synthetic base in the downstream logic.
1099            #
1100            # The way we handle this is by pretending that those aliased inputs that experience metadata mutations
1101            # are additional outputs in the user's forward function.
1102            # The downstream logic will just treat these as "user outputs that alias inputs".
1103            # However, we will manually grab them at runtime here, use them to reapply the metadata mutation
1104            # to the user inputs, and not return them to the user.
1105            aliased_args_with_metadata_mutations = [
1106                x
1107                for i, x in enumerate(unpacked_args)
1108                if i in self.aliased_arg_idx_with_metadata_mutations
1109            ]
1110            if len(aliased_args_with_metadata_mutations) > 0:
1111                return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations
1112            else:
1113                return flat_fn(*unpacked_args)
1114
1115        if config.debug_assert:
1116            ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
1117                wrapped_flat_fn,
1118                static_input_indices=aot_config.static_input_indices,
1119                keep_input_mutations=fw_metadata.keep_input_mutations,
1120                is_train=fw_metadata.is_train,
1121            )(*flat_args_with_synthetic_bases)
1122            assert ref_fw_metadata == fw_metadata_updated, (
1123                f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, "
1124                f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}"
1125            )
1126        return (
1127            wrapped_flat_fn,
1128            flat_args_with_synthetic_bases,
1129            fw_metadata_updated,
1130        )
1131
1132    def post_compile(
1133        self,
1134        compiled_fn,
1135        aot_config: AOTConfig,
1136        *,
1137        runtime_metadata: ViewAndMutationMeta,
1138    ):
1139        if not self.needs_post_compile:
1140            return compiled_fn
1141
1142        is_inference = not self.trace_joint
1143
1144        @wraps(compiled_fn)
1145        def wrapped_compiled_fn(args):
1146            args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
1147                args, self.old_input_info, is_inference=is_inference
1148            )
1149            assert synthetic_base_info is not None
1150            aliased_args_w_metadata_mutations = [
1151                args[i] for i in self.aliased_arg_idx_with_metadata_mutations
1152            ]
1153            num_aliased_args_with_metadata_mutations = len(
1154                aliased_args_w_metadata_mutations
1155            )
1156            args.clear()
1157            outs = compiled_fn(args_with_synthetic_bases)
1158            if num_aliased_args_with_metadata_mutations > 0:
1159                # This code does not handle **all** input metadata mutations.
1160                # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases
1161                # (which only happens if at least one aliased input experienced a data mutation).
1162                # e.g:
1163                # def f(a, b):
1164                #     a.mul_(2)
1165                #     b.transpose_(1, 0)
1166                # f(x.view(2, 2), x.view(2, 2))
1167                mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:]
1168                user_outs = outs[:-num_aliased_args_with_metadata_mutations]
1169                for inp, mutated_inp in zip(
1170                    aliased_args_w_metadata_mutations, mutated_metadata_inps
1171                ):
1172                    inp.as_strided_(
1173                        mutated_inp.size(),
1174                        mutated_inp.stride(),
1175                        mutated_inp.storage_offset(),
1176                    )
1177                return user_outs
1178            return outs
1179
1180        return wrapped_compiled_fn
1181
1182
1183# Note [Handling mutations on an input that aliases other inputs]
1184# The easiest example to show-case this edge case is here:
1185#
1186# def f(a, b):
1187#     a.mul_(2)
1188#     out = a + b
1189#     return out
1190# b = torch.ones(...)
1191# a = b.view(-1)
1192# f(a, b)
1193#
1194# In this situation, if a and b happened to be aliased, we need to trace something different!
1195# Suppose we had b = a.view(-1)
1196# (In this case, that means that `a._base is b`)
1197#
1198# We need to ensure that the aliasing relationship between a and b is preserved.
1199# We do that by detecting the specific situation above (mutating an input that aliases another input),
1200# and when we do that, we create a synthetic base argument. Then inside of the traced forward,
1201# we regenerate a and b off of that base.
1202# The complete example of the transformed function looks like this:
1203#
1204# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views
1205# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph
1206# def traced_forward(base):
1207#     a = base.as_strided(...)
1208#     b = base.as_strided(...)
1209#     a_updated = a.mul(2)
1210#     base_updated = torch.as_strided_scatter(base, a_updated, ...)
1211#     b_updated = base_updated.as_strided(...)
1212#     out = a_updated + b_updated
1213#     return a_updated, out
1214#
1215# def compiled_fn(a, b):
1216#     // we detect that a is the "differentiable base" here
1217#     base = a
1218#     // In other situations, we might do either:
1219#     // (1) a and b are both views off of some larger differentiable base
1220#     //     assert a._base is b._base and a._base is not None
1221#     //     base = a._base
1222#     // (2) a and b both don't require gradients. Create a base from the storage
1223#     //     assert a._base is None and b._base is None
1224#     //     base = torch.Tensor(a.storage())
1225#     a_updated, out = traced_forward(base)
1226#     a.copy_(a_updated)
1227#     return out
1228#
1229# This function:
1230# (1) Merges input views into a synthetic base argument, when any of those input views are mutated
1231# (2) Returns metadata telling the autograd.Function how to modify their arguments properly,
1232#     to respect the new calling convention.
1233#
1234# The calling convention is as follows.
1235# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base.
1236# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN],
1237# Where the ordering of the bases is determined from the ordering of the original view args.
1238# baseA will come before baseB if the earliest original argument coming from baseA
1239# showed up earlier in the argument list than the earliest original argument coming from baseB.
1240#
1241# Example, given some tensors a, b, c, d
1242# call site:
1243#   f(a, c.view(-1), b.view(-1), b, c, d)
1244# Modified argument list:
1245#   c_base comes first because the first c view came earlier in arg list than the first b view
1246#   a and d still show up in the modified arg list, but b and c don't; they're regenerated from their bases
1247#   b_base = torch.Tensor(b.storage())
1248#   c_base = torch.Tensor(c.storage())
1249#   f(c_base, b_base, a, d)
1250def merge_view_inputs(
1251    fwd_inputs: List[Any],
1252    mutated_input_info: List[InputAliasInfo],
1253    *,
1254    # The autograd case currently has more restrictions than the inference case.
1255    is_inference: bool,
1256) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]:
1257    def _are_differentiable_views(view1, view2):
1258        if view1 is view2:
1259            return True
1260        if view1._base is None and view2._base is None:
1261            return False
1262        if view1._base is view2._base or view1._base is view2 or view1 is view2._base:
1263            return True
1264        return False
1265
1266    def _same_dtype_views(view1, view2):
1267        if view1.dtype != view2.dtype:
1268            return False
1269        if view1._base is not None and view1.dtype != view1._base.dtype:
1270            return False
1271        if view2._base is not None and view2.dtype != view2._base.dtype:
1272            return False
1273        return True
1274
1275    assert len(fwd_inputs) == len(mutated_input_info)
1276    if not [info for info in mutated_input_info if info.mutates_data]:
1277        # Return early when there are no mutations.
1278        return fwd_inputs, None
1279
1280    storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list)
1281    base_args = []
1282    other_args = []
1283    for i, inpt in enumerate(fwd_inputs):
1284        if isinstance(inpt, Tensor):
1285            storage_ref = StorageWeakRef(inpt.untyped_storage())
1286            storage_ref_to_idx[storage_ref].append(i)
1287        else:
1288            other_args.append(inpt)
1289    # Note [Synthetic Base Info Metadata]
1290    # This list contains metadata that tells you what the i'th argument in the inner calling convention should be.
1291    # It's either:
1292    # - another int (corresponding to the index in the argument list of the element from the outer calling convention)
1293    # - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx])
1294    #   idx corresponds to which synthetic base from the outer calling context to view
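    #
    # Illustrative (hypothetical) example: for f(a, b) where a and b are overlapping views
    # of the same storage and a is mutated, both inputs are replaced by one synthetic base,
    # and the metadata could look like
    #   {0: (0, a), 1: (0, b)}
    # i.e. inner args 0 and 1 are both regenerated as views of outer arg 0.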
1295    inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {}
1296    for aliased_input_indices in storage_ref_to_idx.values():
1297        if len(aliased_input_indices) <= 1 or not any(
1298            # We only care about mutations that affect all aliases,
1299            # so metadata mutations on an input don't require us to do synthetic base handling.
1300            mutated_input_info[inpt_idx].mutates_data
1301            for inpt_idx in aliased_input_indices
1302        ):
1303            for curr_idx in aliased_input_indices:
1304                other_args.append(fwd_inputs[curr_idx])
1305            continue
1306
1307        # Here, we attempt a more complicated check to detect false aliasing
1308        # (e.g. if all the tensors have the same storage, but don't actually overlap).
1309        # In theory, we could have a large group of tensors that all share storages, where only *some* of them
1310        # have overlapping memory.
1311        # I don't bother with that case for now: here, we only bail out early if we detect that **every** pair
1312        # of tensors in the current group that shares a storage is non-overlapping.
1313        aliased_input_indices_no_false_sharing = compute_overlapping_inputs(
1314            fwd_inputs, aliased_input_indices
1315        )
1316        if len(aliased_input_indices_no_false_sharing) <= 1:
1317            for curr_idx in aliased_input_indices:
1318                other_args.append(fwd_inputs[curr_idx])
1319            continue
1320
1321        # We detected an input that was mutated AND aliases another input.
1322        # We need to replace this set of aliased inputs with a single synthetic base.
1323        # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
1324        # and error out. We can fix them later.
1325        # These checks are transitive, so we don't need to check every pair.
1326        for idx1, idx2 in zip(
1327            aliased_input_indices, aliased_input_indices[1:], strict=False
1328        ):
1329            view1 = fwd_inputs[idx1]
1330            view2 = fwd_inputs[idx2]
1331            # The "inputs that are aliased but have different differentiable bases" case
1332            # is more complicated and hopefully pretty rare. Not currently handled.
1333            if not is_inference:
1334                assert _are_differentiable_views(
1335                    view1, view2
1336                ), "aot_autograd() does not yet handle non-differentiable view input mutations."
1337            # Regenerating views when reinterpreting complex / real tensors seems non-trivial,
1338            # not handling for now
1339            assert _same_dtype_views(
1340                view1, view2
1341            ), "aot_autograd() does not yet handle input mutations on views with different dtypes."
1342        non_none_bases = [
1343            fwd_inputs[i]._base
1344            for i in aliased_input_indices
1345            if fwd_inputs[i]._base is not None
1346        ]
1347        aliases_with_none_bases = [
1348            fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None
1349        ]
1350        if len(non_none_bases) == 0:
1351            # Case where none of the aliases have a ._base
1352            # we generate a synthetic base without gradients, and generate views off of it
1353            # We hit this case when we have input tensors to the graph that share a storage,
1354            # but do not have a ._base field.
1355            # Wondering when we hit this case?
1356            # The _base field simply says that autograd knows about the aliasing relationship,
1357            # but sometimes we create tensors which are aliased out of the same storage but guaranteed
1358            # to be disjoint. In these cases, we will skip setting up the _base relationship
1359            # for performance reasons (because the fact that the tensors share the same storage
1360            # is unobservable unless you (1) do naughty things with resize_/as_strided
1361            # or (2) look at the storage--as we are doing here.)
1362            # One particular example of this is optimizer steps on the LSTM module:
1363            # LSTM parameters are packed into a contiguous storage for efficiency reasons when
1364            # calling cuDNN kernels, so when these parameters get passed to the optimizer we will
1365            # find they share the same storage, but do not have _base set since they are all disjoint.
1366            #
1367            # NOTE: There is one case where this is unsafe:
1368            # torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily
1369            # the same shape as the "actual" base that the tensor came from.
1370            # For the most part this is fine, because we always use as_strided()
1371            # to generate the original aliased inputs again.
1372            # If we were to use view-replay though, this could cause the aliased views
1373            # to have incorrect sizes.
1374            example_idx = aliased_input_indices[0]
1375            example_alias = fwd_inputs[example_idx]
1376            # Note that this function is re-used at both trace time and runtime.
1377            # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor.
1378            synthetic_base = torch.empty(
1379                (0,), dtype=example_alias.dtype, device=example_alias.device
1380            )
1381            # We don't actually have a convenient way of going from storage -> tensor,
1382            # so we use set_() here (we suffer some minor overhead, but this case is rare).
1383            synthetic_base.set_(example_alias.untyped_storage())
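            # Illustrative sketch (hypothetical sizes): if example_alias views a storage holding
            # 16 elements, synthetic_base is now a 1D tensor of 16 elements aliasing that same
            # storage; its 1D shape is never observed directly, because the original inputs are
            # regenerated from it later (via as_strided / _view_func, as noted above).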
1384        else:
1385            # Case where the aliases have a ._base; we require that they all share the same one (asserted below).
1386            synthetic_base = non_none_bases[0]
1387            for other_base in non_none_bases[1:]:
1388                assert (
1389                    other_base is synthetic_base
1390                ), "aot_autograd() does not yet handle non-differentiable view input mutations."
1391            for alias in aliases_with_none_bases:
1392                assert (
1393                    alias is synthetic_base
1394                ), "aot_autograd() does not yet handle non-differentiable view input mutations."
1395        base_args.append(synthetic_base)
1396        for curr_view_idx in aliased_input_indices:
1397            curr_view = fwd_inputs[curr_view_idx]
1398            base_idx = len(base_args) - 1
1399            # We store just enough info here so that we can regenerate the view later.
1400            # Regeneration: curr_view._view_func(args[base_idx])
1401            inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view)
1402    if len(base_args) == 0:
1403        assert len(other_args) == len(fwd_inputs)
1404        # If no synthetic bases are necessary, just return the original inputs.
1405        return fwd_inputs, None
1406    else:
1407        # Otherwise, return:
1408        # (1) The new args according to the updated calling convention: (synthetic_bases, other_args)
1409        # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention.
1410        #     We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention.
1411        args_to_functionalization = base_args + other_args
1412        arg_to_old_idx_map = {arg: i for (i, arg) in enumerate(fwd_inputs)}
1413        for i, other_arg in enumerate(other_args):
1414            new_idx = len(base_args) + i
1415            old_idx = arg_to_old_idx_map[other_arg]
1416            inner_calling_convention_meta[old_idx] = new_idx
1417        # post process into a list
1418        post_processed_calling_convention_meta: List[
1419            Union[int, Tuple[int, torch.Tensor]]
1420        ] = [-1 for _ in range(len(inner_calling_convention_meta))]
1421        for k, v in inner_calling_convention_meta.items():
1422            post_processed_calling_convention_meta[k] = v
1423        # Quick assert: every argument in the inner calling convention should be accounted for.
1424        for x in post_processed_calling_convention_meta:
1425            assert x != -1
1426        return args_to_functionalization, post_processed_calling_convention_meta
1427
1428
1429@dataclass
1430class AutogradLazyBackwardCompileInfo:
1431    bw_module: Callable
1432    placeholder_list: List[Any]
1433    saved_context: Optional[TracingContext]
1434    saved_compile_context: Optional[CompileContext]
1435
1436
1437# This is wrapped in a class just for namespacing purposes.
1438# No need to make it into an actual CompilerWrapper, because it doesn't fit the abstraction as cleanly.
1439class AOTDispatchAutograd:
1440    @staticmethod
1441    def _force_contiguous(x):
1442        if not isinstance(x, torch.Tensor):
1443            return x
1444        x = x.contiguous()
1445        if not is_traceable_wrapper_subclass(x):
1446            return x
1447        for attr in x.__tensor_flatten__()[0]:  # type: ignore[attr-defined]
1448            elem = getattr(x, attr)
1449            if not elem.is_contiguous():
1450                setattr(x, attr, elem.contiguous())
1451        return x
1452
1453    # See Note [Tangents must be contiguous, Part 2]
1454    @staticmethod
1455    def coerce_runtime_tangent(x, metadata):
1456        if not isinstance(x, torch.Tensor):
1457            return x
1458        if not is_traceable_wrapper_subclass(x):
1459            return x
1460        assert metadata is not None
1461        (_, expected_tangent_metadata) = metadata
1462        _, runtime_tangent_metadata = x.__tensor_flatten__()  # type: ignore[attr-defined]
1463        if runtime_tangent_metadata == expected_tangent_metadata:
1464            return x
1465        if not hasattr(x, "__coerce_same_metadata_as_tangent__"):
1466            raise RuntimeError(
1467                f"""
1468During the backward, we encountered a tensor subclass where we guessed its
1469metadata incorrectly.
1470
1471Expected metadata: {str(expected_tangent_metadata)}
1472
1473Runtime metadata: {str(runtime_tangent_metadata)}
1474
1475shape: {str(cast(torch.Tensor, x).shape)}
1476To fix this, your tensor subclass must implement the dunder method __coerce_same_metadata_as_tangent__.
1477"""
1478            )
1479        return x.__coerce_same_metadata_as_tangent__(expected_tangent_metadata)  # type: ignore[attr-defined]
1480
1481    @staticmethod
1482    def post_compile(
1483        compiled_fw_func,  # fw_module after compilation + wrappers
1484        compiled_bw_func,  # bw_module after compilation + wrappers
1485        maybe_subclass_meta: Optional[SubclassMeta],
1486        num_symints_saved_for_bw_: int,
1487        backward_state_indices: List[int],
1488        disable_amp: bool,
1489        indices_of_inps_to_detach: List[int],
1490        lazy_backward_info: Optional[AutogradLazyBackwardCompileInfo],
1491        aot_config: AOTConfig,
1492        *,
1493        fw_metadata: ViewAndMutationMeta,  # runtime metadata
1494        try_save_cache_entry: Optional[Callable],  # Save cache entry after compilation
1495    ):
1496        class CompiledFunction(torch.autograd.Function):
1497            compiled_fw = compiled_fw_func
1498            compiled_bw = compiled_bw_func
1499            metadata: ViewAndMutationMeta = fw_metadata  # type: ignore[assignment]
1500            maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta
1501            num_symints_saved_for_bw = num_symints_saved_for_bw_
1502            _compiled_autograd_should_lift = False
1503            _aot_id = aot_config.aot_id
1504            _lazy_backward_info = lazy_backward_info
1505
1506            @staticmethod
1507            def _compiled_autograd_key(ctx):
1508                return (ctx._autograd_function_id, *ctx.symints)
1509
1510            @staticmethod
1511            def forward(ctx, *deduped_flat_tensor_args):
1512                args = deduped_flat_tensor_args
1513                if backward_state_indices:
1514                    bw_state = args[backward_state_indices[0]]
1515                    assert isinstance(bw_state, BackwardState)
1516                    ctx._compiled_autograd_backward_state = bw_state
1517
1518                # There is a pretty complicated calling convention around what the compiled fw returns.
1519                # The full list of outputs and their relative order is:
1520                # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
1521                # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version
1522                #   of the original view, and not the synthetic base
1523                # - Note that donated buffer logic requires (*saved_tensors, *saved_symints) showing up last
1524                #   in the fw output order.
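                # Illustrative sketch (hypothetical counts, ignoring tokens): with 1 mutated input,
                # 2 user outputs, 2 saved tensors, and 1 saved symint, fw_outs might look like
                #   [mut_inp0, out0, out1, saved_t0, saved_t1, saved_symint0]
                # where tensors_saved_for_backwards_slice selects [saved_t0, saved_t1] and
                # symints_saved_for_backwards_slice selects [saved_symint0].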
1525                fw_outs = call_func_at_runtime_with_args(
1526                    CompiledFunction.compiled_fw,
1527                    args,
1528                    disable_amp=disable_amp,
1529                )
1530
1531                num_outputs = CompiledFunction.metadata.num_outputs
1532                num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased
1533                num_mutated_runtime_inps = (
1534                    CompiledFunction.metadata.num_mutated_inp_runtime_indices
1535                )
1536                num_forward_returns = CompiledFunction.metadata.num_forward_returns
1537
1538                # Partitioners must put symint arguments at the end separate from tensor arguments
1539                tensors_saved_for_backwards = fw_outs[
1540                    CompiledFunction.metadata.tensors_saved_for_backwards_slice
1541                ]
1542                assert all(
1543                    isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards
1544                )
1545                # See Note [Detaching saved tensors in AOTAutograd]
1546                ctx.save_for_backward(
1547                    *(
1548                        x.detach() if x._is_view() else x
1549                        for x in tensors_saved_for_backwards
1550                    )
1551                )
1552                symint_outs = fw_outs[
1553                    CompiledFunction.metadata.symints_saved_for_backwards_slice
1554                ]
1555                assert all(
1556                    isinstance(x, (int, float, torch.SymInt, torch.SymFloat))
1557                    for x in symint_outs
1558                ), str([type(x) for x in symint_outs])
1559                ctx.symints = symint_outs
1560
1561                raw_returns = fw_outs[0:num_forward_returns]
1562
1563                # Wrap all autograd.Function.forward() outputs that are aliases
1564                # so that autograd.Function doesn't treat them as tensors
1565                if num_mutated_runtime_inps > 0:
1566                    for i, idx in enumerate(
1567                        CompiledFunction.metadata.mutated_inp_runtime_indices
1568                    ):
1569                        # We could make this faster by only looping over inputs with metadata-only mutations
1570                        # (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many.
1571                        info = CompiledFunction.metadata.input_info[idx]
1572                        if info.mutates_metadata and not info.mutates_data:
1573                            raw_return_idx = i
1574                            raw_returns[raw_return_idx] = TensorAlias(
1575                                raw_returns[raw_return_idx]
1576                            )
1577
1578                    if config.debug_assert:
1579                        user_mutated_inputs_raw = raw_returns[
1580                            0:num_mutated_runtime_inps
1581                        ]
1582                        mut_inp_infos = [
1583                            x
1584                            for x in CompiledFunction.metadata.input_info
1585                            if x.mutates_data or x.mutates_metadata
1586                        ]
1587                        assert len(user_mutated_inputs_raw) == len(mut_inp_infos)
1588
1589                if CompiledFunction.metadata.num_unsafe_view_outputs > 0:
1590                    for idx in CompiledFunction.metadata.unsafe_view_out_indices:
1591                        raw_return_idx = num_mutated_runtime_inps + idx
1592                        o = raw_returns[raw_return_idx]
1593                        raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view(
1594                            o, o.shape
1595                        )
1596
1597                if num_outputs_aliased > 0:
1598                    for idx in CompiledFunction.metadata.aliased_out_indices:
1599                        raw_return_idx = num_mutated_runtime_inps + idx
1600                        raw_returns[raw_return_idx] = TensorAlias(
1601                            raw_returns[raw_return_idx]
1602                        )
1603
1604                    if config.debug_assert:
1605                        intermediates_raw = raw_returns[
1606                            num_mutated_runtime_inps + num_outputs :
1607                        ]
1608                        assert not any(
1609                            isinstance(x, TensorAlias) for x in intermediates_raw
1610                        )
1611
1612                # invariant: intermediate bases always require gradients, so we don't have to
1613                # consider marking them as non-differentiable.
1614                raw_returns_not_including_intermediate_bases = raw_returns[
1615                    : num_mutated_runtime_inps + num_outputs
1616                ]
1617                raw_returns_meta = [
1618                    x
1619                    for x in CompiledFunction.metadata.input_info
1620                    if x.mutation_type == MutationType.MUTATED_OUT_GRAPH
1621                ] + CompiledFunction.metadata.output_info
1622
1623                fw_outs_not_requiring_grad = [
1624                    x
1625                    for (i, x) in enumerate(
1626                        raw_returns_not_including_intermediate_bases
1627                    )
1628                    if isinstance(x, torch.Tensor)
1629                    and not raw_returns_meta[i].requires_grad
1630                ]
1631                ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)
1632                ctx._materialize_non_diff_grads = False
1633                return tuple(raw_returns)
1634
1635            @staticmethod
1636            def backward(ctx, *flat_args):
1637                # Calling convention: we expect a grad_out passed to the backward:
1638                # - for every output of the fw that does *not* alias an input or graph intermediate
1639                # - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations)
1640                # - for every graph intermediate that we need to use to generate an output later.
1641                # The other outputs in the autograd.Function.forward that do *not* show up in the backward include:
1642                # - outputs that alias inputs or graph intermediates
1643                # - updated inputs due to metadata-only mutations.
1644                # We need to return them in the forward, but ensure that none of them get gradients in the backward;
1645                # here we filter them out before passing the remaining grad_outputs into the compiled backward.
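                # Illustrative sketch (hypothetical counts): with 2 mutated inputs (one data-mutated
                # and requiring grad, one metadata-only), 1 plain output, and 1 intermediate base,
                # flat_args would be
                #   [inp_tangent0, inp_tangent1, out_tangent0, intermediate_base_tangent0]
                # and the filtering below would drop inp_tangent1 before calling the compiled backward.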
1646                num_intermediate_bases = (
1647                    CompiledFunction.metadata.num_intermediate_bases
1648                )
1649                num_mutated_runtime_inps = (
1650                    CompiledFunction.metadata.num_mutated_inp_runtime_indices
1651                )
1652                expected_grad_outs = (
1653                    CompiledFunction.metadata.num_outputs
1654                    + num_mutated_runtime_inps
1655                    + num_intermediate_bases
1656                )
1657                deterministic = CompiledFunction.metadata.deterministic
1658                global_deterministic = torch.are_deterministic_algorithms_enabled()
1659                if deterministic is not None:
1660                    torch._check(
1661                        not (not deterministic and global_deterministic),
1662                        lambda: (
1663                            "This compiled backward function is being run with "
1664                            "torch.use_deterministic_algorithms(True), "
1665                            "but it was previously generated during the forward function while "
1666                            "torch.use_deterministic_algorithms(False) was set."
1667                        ),
1668                    )
1669
1670                assert len(flat_args) == expected_grad_outs
1671                out_info = CompiledFunction.metadata.output_info
1672
1673                inp_tangents, out_tangents, intermediate_base_tangents = (
1674                    flat_args[:num_mutated_runtime_inps],
1675                    flat_args[
1676                        num_mutated_runtime_inps : num_mutated_runtime_inps
1677                        + CompiledFunction.metadata.num_outputs
1678                    ],
1679                    flat_args[
1680                        num_mutated_runtime_inps
1681                        + CompiledFunction.metadata.num_outputs :
1682                    ],
1683                )
1684                # input_info contains info on *every* input,
1685                # but in the backward() we are only given grad outputs for the mutated inputs.
1686                # We then need to filter out the grad outputs that correspond to metadata-only mutations or that don't require grad.
1687                input_info = CompiledFunction.metadata.input_info
1688                inp_tangents_filtered = [
1689                    x
1690                    for x, info_idx in zip(
1691                        inp_tangents,
1692                        CompiledFunction.metadata.mutated_inp_runtime_indices,
1693                    )
1694                    if input_info[info_idx].mutates_data
1695                    and input_info[info_idx].requires_grad
1696                ]
1697                # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
1698                out_tangents_filtered = [
1699                    x
1700                    for x, info in zip(out_tangents, out_info)
1701                    if info.output_type
1702                    in [
1703                        OutputType.non_alias,
1704                        OutputType.unsafe_view_alias,
1705                        OutputType.custom_function_view,
1706                    ]
1707                    and issubclass(info.raw_type, torch.Tensor)
1708                    and info.requires_grad
1709                ]
1710                # intermediate bases always require gradients, and always participate in the backward graph.
1711                flat_bw_args_with_grads = [
1712                    *inp_tangents_filtered,
1713                    *out_tangents_filtered,
1714                    *intermediate_base_tangents,
1715                ]
1716                num_flat_bw_args_with_grads = len(flat_bw_args_with_grads)
1717
1718                # sanity asserts
1719                # metadata_only_inps = [
1720                #     x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
1721                #     if not input_info[info_idx].mutates_data
1722                # ]
1723                # aliased_outputs = [
1724                #     x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
1725                # assert all(x is None for x in metadata_only_inps)
1726                # assert all(x is None for x in aliased_outputs)
1727                # TODO: replace this with FunctionalizedRngRuntimeWrapper
1728                rng_args = []
1729                if CompiledFunction.metadata.is_rng_op_functionalized:
1730                    # Add the seed and offset to args
1731                    rng_args = CUDARngStateHelper.get_torch_state_as_tuple()
1732
1733                bw_tokens = [None] * CompiledFunction.metadata.num_backward_tokens
1734
1735                # - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first
1736                #   in the bw input (all_args) order.
1737
1738                # Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls,
1739                # and there are tests that count these calls, so dereference it once and save the result to a local variable.
1740                ctx_saved_tensors = ctx.saved_tensors
1741                num_ctx_saved_tensors = len(ctx_saved_tensors)
1742                all_args = [
1743                    *ctx.symints,
1744                    *ctx_saved_tensors,
1745                    *flat_bw_args_with_grads,
1746                    *bw_tokens,
1747                    *rng_args,
1748                ]
1749                del ctx_saved_tensors
1750
1751                # Note: [AOTAutograd Backward Guards]
1752                # During AOTDispatch, we eagerly create and trace out a joint fw-bw graph.
1753                # Doing so requires us to "guess" about some of the metadata of our grad_outputs.
1754                #
1755                # In particular: if an output to the forward is a plain tensor or a subclass,
1756                # its corresponding grad_output in the backward **may or may not** be
1757                # a plain tensor or a subclass. The main cases are:
1758                # (1) If an output is a plain tensor, its grad_out will also be a plain tensor,
1759                #     *unless* the output is used in some subclass compute later in the forward graph,
1760                #     which will cause its grad_output to become a subclass
1761                # (2) If an output is a subclass, its grad_out will also be a subclass,
1762                #     *unless* the output of the forward did not actually participate in the gradient computation,
1763                #     in which case autograd will insert a plain tensor of zeros for the grad_output.
1764                #     We could avoid this case with `torch.autograd.Function.set_materialize_grads`,
1765                #     although this is not turned on today in AOTAutograd and would require more work.
1766                #
1767                # Today, we make a guess on subclass-ness based on the above examples,
1768                # and hard-error in the backward if we guessed wrong.
1769                #
1770                # In the future, we should add backward guards that would allow us to
1771                # properly handle this case instead of erroring: we would need to retrace the backward graph,
1772                # since we might produce an entirely different trace if our grad_outputs are subclass or not.
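                # Illustrative sketch (hypothetical subclass): if the forward returned a plain tensor
                # whose value later fed into compute on some wrapper subclass, its grad_output can
                # arrive here as that subclass rather than a plain torch.Tensor (case (1) above), and
                # the asserts below will fire with the mismatched output_types.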
1773                assert (
1774                    len(CompiledFunction.metadata.output_types)
1775                    == num_flat_bw_args_with_grads
1776                )
1777
1778                grad_output_types = [type(x) for x in flat_bw_args_with_grads]
1779                # In general, we can add more asserts/guards here for when we partitioned
1780                # with incorrect assumptions about the grad_outputs.
1781                # Normalize FakeTensor -> torch.Tensor
1782                # - during tracing our types are FakeTensor
1783                # - at runtime in the backward our types are torch.Tensor...
1784                # - unless we're running compiled backward, in which case they are also FakeTensor
1785                grad_output_types_ = [
1786                    torch.Tensor if x is FakeTensor else x for x in grad_output_types
1787                ]
1788                assert (
1789                    grad_output_types_ == CompiledFunction.metadata.output_types
1790                ), f"""\
1791    We incorrectly attempted to compile the backward with incorrect subclass metadata.
1792    If you run into this error, please file an issue.
1793    Expected grad_output types: {str(CompiledFunction.metadata.output_types)}
1794    Got grad_output types: {str(grad_output_types)}"""
1795
1796                del flat_bw_args_with_grads
1797
1798                tangents_start_idx = (
1799                    len(all_args)
1800                    - num_flat_bw_args_with_grads
1801                    - len(rng_args)
1802                    - len(bw_tokens)
1803                )
1804                assert tangents_start_idx == len(ctx.symints) + num_ctx_saved_tensors
1805                tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens)
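                # At this point all_args is laid out as
                #   [*symints, *saved_tensors, *tangents, *bw_tokens, *rng_args]
                # (see its construction above), so the tangents occupy
                # [tangents_start_idx, tangents_end_idx).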
1806
1807                # TODO: figure out how to refactor the backward properly
1808                # so I can use aot_dispatch_subclass_wrapper() here.
1809                if CompiledFunction.maybe_subclass_metadata is not None:
1810                    tangents = all_args[tangents_start_idx:tangents_end_idx]
1811
1812                    def get_types_for_tangents(tangents):
1813                        infos = []
1814                        idx = 0
1815                        for a in tangents:
1816                            if isinstance(a, Tensor) and is_traceable_wrapper_subclass(
1817                                a
1818                            ):
1819                                infos.append(get_types_for_subclass(a))
1820                            else:
1821                                infos.append(idx)
1822                            idx += 1
1823                        return infos
1824
1825                    runtime_subclass_info = get_types_for_tangents(tangents)
1826
1827                    if len(runtime_subclass_info) != len(
1828                        CompiledFunction.metadata.subclass_tangent_meta
1829                    ):
1830                        raise RuntimeError(
1831                            "The number of grad inputs should match the number of forward output tangents"
1832                        )
1833                    for a, b in zip(
1834                        runtime_subclass_info,
1835                        CompiledFunction.metadata.subclass_tangent_meta,
1836                    ):
1837                        # Types should match between runtime and traced tangents.
1838                        # TODO (tmanlaibaatar) Should actually call coerce_runtime_tangent
1839                        if isinstance(a, list) and (
1840                            isinstance(b, SubclassCreationMeta) and b.subclass_type
1841                        ):
1842                            if a != b.subclass_type:
1843                                raise RuntimeError(
1844                                    "The grad inputs should have the same tensor subclass type as the forward output"
1845                                )
1846
1847                    # Get the number of tangents after unwrapping
1848                    len_tangents = len(
1849                        unwrap_tensor_subclasses(
1850                            tangents,
1851                            is_joint_structure=False,
1852                        )
1853                    )
1854                    assert CompiledFunction.metadata.traced_tangent_metas is not None
1855                    all_args = [
1856                        (
1857                            AOTDispatchAutograd.coerce_runtime_tangent(
1858                                t,
1859                                CompiledFunction.metadata.traced_tangent_metas[
1860                                    i - tangents_start_idx
1861                                ],
1862                            )
1863                            if tangents_start_idx <= i < tangents_end_idx
1864                            else t
1865                        )
1866                        for i, t in enumerate(all_args)
1867                    ]
1868                    all_args = unwrap_tensor_subclasses(
1869                        all_args, is_joint_structure=False
1870                    )
1871                    tangents_start_idx = (
1872                        len(all_args) - len_tangents - len(rng_args) - len(bw_tokens)
1873                    )
1874                    tangents_end_idx = tangents_start_idx + len_tangents
1875
1876                # Make the tangents contiguous. Note that we must do this after subclass desugaring
1877                # because inputs to inductor have to be contiguous
1878                all_args = [
1879                    (
1880                        AOTDispatchAutograd._force_contiguous(t)
1881                        if (tangents_start_idx <= i < tangents_end_idx)
1882                        else t
1883                    )
1884                    for i, t in enumerate(all_args)
1885                ]
1886
1887                def call_compiled_backward():
1888                    if ctx._is_compiled_autograd_tracing():
1889                        if lazy_backward_info is None:
1890                            raise RuntimeError(
1891                                """This compiled backward function was saved by AOTAutogradCache, which does not support
1892                            compiled autograd. Please turn off AOTAutogradCache using `ENABLE_AOT_AUTOGRAD_CACHE=0` to continue."""
1893                            )
1894                        bw_module = lazy_backward_info.bw_module
1895                        # For compiled autograd, run raw FX graph so that it can be inlined into the larger graph
1896                        symints = ctx._get_compiled_autograd_symints()
1897                        assert len(symints) == len(ctx.symints)
1898                        all_args[: len(symints)] = symints
1899                        if backward_state_indices:
1900                            assert (
1901                                ctx._compiled_autograd_backward_state.proxy is not None
1902                            )
1903                            all_args.append(ctx._compiled_autograd_backward_state)
1904                        context = (
1905                            torch._C._DisableAutocast if disable_amp else nullcontext
1906                        )
1907                        with context():
1908                            out = normalize_as_list(bw_module(*all_args))
1909                        # TODO: replace with post_compile wrapper
1910                        out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue(
1911                            CompiledFunction.metadata, out, offset_index=len(out) - 1
1912                        )
1913                        return tuple(out)
1914                    assert (
1915                        not backward_state_indices
1916                    ), "BackwardState requires CompiledAutograd"
1917                    ctx.maybe_clear_saved_tensors()
1918
1919                    saved_tensors_use_once = (
1920                        not torch._C._autograd._get_current_graph_task_keep_graph()
1921                    )
1922
1923                    if CompiledFunction.compiled_bw is None:
1924                        assert lazy_backward_info is not None
1925
1926                        if not saved_tensors_use_once:
1927                            fw_metadata.bw_donated_idxs = []
1928                            # Update bw_donated_idxs if using lazy_backward_info from `aot_dispatch_autograd`
1929                            if (
1930                                hasattr(lazy_backward_info, "saved_context")
1931                                and hasattr(
1932                                    lazy_backward_info.saved_context, "fw_metadata"
1933                                )
1934                                and hasattr(
1935                                    lazy_backward_info.saved_context.fw_metadata,  # type: ignore[union-attr]
1936                                    "bw_donated_idxs",
1937                                )
1938                            ):
1939                                lazy_backward_info.saved_context.fw_metadata.bw_donated_idxs = (  # type: ignore[union-attr]
1940                                    []
1941                                )
1942
1943                        bw_module = lazy_backward_info.bw_module
1944                        placeholder_list = lazy_backward_info.placeholder_list
1945                        saved_context = lazy_backward_info.saved_context
1946                        saved_compile_context = lazy_backward_info.saved_compile_context
1947
1948                        context = (
1949                            torch._C._DisableAutocast if disable_amp else nullcontext
1950                        )
1951                        with tracing(saved_context), compile_context(
1952                            saved_compile_context
1953                        ), context(), track_graph_compiling(aot_config, "backward"):
1954                            CompiledFunction.compiled_bw = aot_config.bw_compiler(
1955                                bw_module, placeholder_list
1956                            )
1957                            # Maybe save cache entry
1958                            if try_save_cache_entry is not None:
1959                                try_save_cache_entry(
1960                                    CompiledFunction.compiled_bw, fw_metadata
1961                                )
1962
1963                    if (
1964                        torch._functorch.config.donated_buffer
1965                        and not saved_tensors_use_once
1966                        and fw_metadata.bw_donated_idxs != []
1967                    ):
1968                        torch._check(
1969                            False,
1970                            lambda: (
1971                                "This backward function was compiled with non-empty donated "
1972                                "buffers which requires create_graph=False and retain_graph=False. "
1973                                "Please keep backward(create_graph=False, retain_graph=False) "
1974                                "across all backward() function calls, or set "
1975                                "torch._functorch.config.donated_buffer=False to disable "
1976                                "donated buffer."
1977                            ),
1978                        )
1979
1980                    out = call_func_at_runtime_with_args(
1981                        CompiledFunction.compiled_bw,
1982                        all_args,
1983                        steal_args=True,
1984                        disable_amp=disable_amp,
1985                    )
1986
1987                    # Toss out the backward output tokens
1988                    num_bw_tokens = CompiledFunction.metadata.num_backward_tokens
1989                    if num_bw_tokens > 0:
1990                        out = out[:-num_bw_tokens]
1991
1992                    # TODO: replace this with FunctionalizedRngRuntimeWrapper.post_compile
1993                    out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue(
1994                        CompiledFunction.metadata, out, offset_index=len(out) - 1
1995                    )
1996                    return tuple(out)
1997
1998                # Backward with forward input mutations is not supported in double backward.
1999                if (
2000                    torch.is_grad_enabled()
2001                    and CompiledFunction.metadata.indices_of_inputs_that_requires_grad_with_mutations_in_bw
2002                ):
2003                    raise RuntimeError(
2004                        "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True"
2005                    )
2006
2007                if torch.is_grad_enabled() and any(
2008                    t.requires_grad for t in all_args if isinstance(t, torch.Tensor)
2009                ):
2010                    # Ensure that the graph is connected, and error if double backward is performed.
2011                    # See comment for why once_differentiable is not sufficient:
2012                    # https://github.com/pytorch/pytorch/pull/92348/files#r1072962107
2013                    class CompiledFunctionBackward(torch.autograd.Function):
2014                        # CompiledFunctionBackward is not yet supported in dynamo skipfiles
2015                        _compiled_autograd_should_lift = False
2016                        _aot_id = aot_config.aot_id
2017
2018                        @staticmethod
2019                        def forward(ctx, *unused_args):
2020                            outs = call_compiled_backward()
2021                            # TODO: figure out how to refactor the backward properly
2022                            # so I can use aot_dispatch_subclass_wrapper() here.
2023                            if CompiledFunction.maybe_subclass_metadata is not None:
2024                                assert (
2025                                    CompiledFunction.maybe_subclass_metadata.grad_input_metas
2026                                    is not None
2027                                )
2028                                outs_wrapped = wrap_tensor_subclasses(
2029                                    outs,
2030                                    subclass_metas=CompiledFunction.maybe_subclass_metadata.grad_input_metas,
2031                                )
2032                                return outs_wrapped
2033                            return outs
2034
2035                        @staticmethod
2036                        def backward(ctx, *args):
2037                            raise RuntimeError(
2038                                "torch.compile with aot_autograd does not currently support double backward"
2039                            )
2040
2041                    CompiledFunctionBackward._compiled_autograd_key = (  # type: ignore[method-assign]
2042                        CompiledFunction._compiled_autograd_key
2043                    )
2044
2045                    # Pass args even though they're unused, so that the graph is built
2046                    out = CompiledFunctionBackward.apply(*all_args)
2047                else:
2048                    out = call_compiled_backward()
2049
2050                # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
2051                if CompiledFunction.maybe_subclass_metadata is not None:
2052                    assert (
2053                        CompiledFunction.maybe_subclass_metadata.grad_input_metas
2054                        is not None
2055                    )
2056                    outs_wrapped = wrap_tensor_subclasses(
2057                        out,
2058                        subclass_metas=CompiledFunction.maybe_subclass_metadata.grad_input_metas,
2059                    )
2060                    return outs_wrapped
2061                return out
2062
2063        compiled_function = RuntimeWrapper(
2064            indices_of_inps_to_detach=indices_of_inps_to_detach,
2065            trace_joint=True,
2066            disable_amp=disable_amp,
2067        ).post_compile(
2068            CompiledFunction.apply,
2069            aot_config,
2070            runtime_metadata=fw_metadata,
2071        )
2072
2073        return compiled_function
2074
2075
2076@dataclass
2077class DebugAssertWrapper(CompilerWrapper):
2078    flat_requires_grad: List[Optional[bool]] = field(default_factory=list)
2079
2080    def post_compile(
2081        self,
2082        compiled_fn,
2083        aot_config: AOTConfig,
2084        *,
2085        runtime_metadata: ViewAndMutationMeta,
2086    ):
2087        @wraps(compiled_fn)
2088        def debug_compiled_function(args: List[Any]):
2089            # TODO: Check aliasing relationships
2090            # TODO: Check strides for metadata mutation
2091            # (NB: ideally, this logic is factored out of this function and
2092            # you move these debug checks there)
2093
2094            # Check requires grad.  Bad case is when we compiled with
2095            # requires_grad = False, but input requires_grad = True
2096            # (vice versa is OK; we compute a gradient and then throw
2097            # it away when it hits the input.)
2098            for i, a in enumerate(args):
2099                can_require_grad = self.flat_requires_grad[i]
2100                if can_require_grad is None:
2101                    assert not isinstance(a, Tensor)
2102                elif not can_require_grad:
2103                    assert not a.requires_grad, format_guard_bug_msg(
2104                        aot_config,
2105                        f"{describe_input(i, aot_config)} would not require grad",
2106                    )
2107
2108            return compiled_fn(args)
2109
2110        return debug_compiled_function
2111
2112
2113def pre_compile(
2114    wrappers: List[CompilerWrapper],
2115    flat_fn: Callable,
2116    flat_args: List[Any],
2117    aot_config: AOTConfig,
2118    *,
2119    fw_metadata: ViewAndMutationMeta,
2120) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
2121    """
2122    Runs a sequence of wrappers on the given function and arguments.
2123    Mutates wrappers in place.
2124    """
2125    for wrapper in wrappers:
2126        flat_fn, flat_args, fw_metadata = wrapper.pre_compile(
2127            flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
2128        )
2129    return flat_fn, flat_args, fw_metadata
2130
2131
2132def post_compile(
2133    wrappers: List[CompilerWrapper],
2134    compiled_fn: Callable,
2135    aot_config: AOTConfig,
2136    *,
2137    runtime_metadata: ViewAndMutationMeta,
2138) -> Tuple[Callable, ViewAndMutationMeta]:
2139    """
2140    Runs a sequence of wrappers on the given function. Should be called after pre_compile()
2141    """
2142    for wrapper in reversed(wrappers):
2143        compiled_fn = wrapper.post_compile(
2144            compiled_fn, aot_config, runtime_metadata=runtime_metadata
2145        )
2146    return compiled_fn, runtime_metadata
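
# Illustrative usage sketch (hypothetical wrapper list and compiler function): pre_compile and
# post_compile are meant to bracket the actual compilation, e.g.
#   flat_fn, flat_args, fw_metadata = pre_compile(
#       wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
#   )
#   compiled_fn = some_compiler(flat_fn, flat_args)
#   compiled_fn, fw_metadata = post_compile(
#       wrappers, compiled_fn, aot_config, runtime_metadata=fw_metadata
#   )
# post_compile applies the wrappers in reverse so that epilogues unwind in the opposite
# order of the prologues.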
2147
2148
2149def make_runtime_safe(
2150    fw_metadata: ViewAndMutationMeta,
2151    maybe_subclass_meta: Optional[SubclassMeta],
2152):
2153    """
2154    Calls make_runtime_safe on all ViewAndMutationMetas.
2155    Modifies both arguments. Allows ViewAndMutationMetas to
2156    be safely cached in AOTAutogradCache.
2157    """
2158    fw_metadata.make_runtime_safe()
2159    if maybe_subclass_meta is not None:
2160        maybe_subclass_meta.fw_metadata.make_runtime_safe()
2161