xref: /aosp_15_r20/external/pytorch/torch/_functorch/_aot_autograd/schemas.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""
3The various dataclasses, Enums, namedtuples etc used in AOTAutograd. This includes
4input/output types, metadata, config, function signatures etc.
5"""
6
7import collections
8import dataclasses
9import functools
10from dataclasses import dataclass, field
11from enum import Enum
12from typing import Any, Callable, Dict, List, NewType, Optional, Set, Union
13
14import torch
15import torch.utils._pytree as pytree
16from torch._guards import Source
17from torch._ops import OpOverload
18from torch._subclasses import FakeTensor
19from torch._subclasses.fake_tensor import is_fake
20from torch.utils._python_dispatch import is_traceable_wrapper_subclass
21
22from .. import config
23from .functional_utils import (
24    _check_if_mutation_can_be_in_graph,
25    FunctionalTensorMetadataEq,
26)
27from .utils import strict_zip
28
29
30zip = strict_zip
31
32OutputType = Enum(
33    "OutputType",
34    (
35        # output is not an alias
36        "non_alias",
37        # output aliases an input
38        "alias_of_input",
39        # output **is** an input tensor
40        "is_input",
41        # output has a ._base tensor, which is a graph intermediate.
42        # We need to return its ._base as a graph output,
43        # so its requires_grad info is populated correctly.
44        # Instructs the runtime code to regenerate the current output
45        # from a base tensor, graph_intermediates[base_idx]
46        "alias_of_intermediate_save_as_output",
47        # Same as above; but we don't need to explicitly add its ._base
48        # as a graph output, because it already **is** a graph output.
49        "alias_of_intermediate",
50        # Same as above; but the output's ._base is **already** a user output.
51        # Instructs the runtime code to regenerate the current output from
52        # a base tensor, user_outputs[base_idx]
53        "alias_of_intermediate_base_is_user_output",
54        # See Note [Intermediate Bases Optimization]
55        "unsafe_view_alias",
56        # output is an alias, but has a custom autograd.Function backward.
57        # In this case, we don't want to do view-replay, since we won't be able to replay the custom function.
58        # Instead, we'll treat this output "normally", and trace its backward into the graph.
59        "custom_function_view",
60    ),
61)
62
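# Illustrative sketch (not part of the library): a quick eager-mode tour of the
# aliasing patterns that the OutputType categories above describe. The helper
# name `_sketch_output_type_examples` is hypothetical and exists only for
# illustration.
def _sketch_output_type_examples():
    x = torch.randn(4)
    # "is_input": the function returns an input tensor itself.
    out_is_input = x
    assert out_is_input is x
    # "alias_of_input": the output is a view whose ._base is a user input.
    out_alias_of_input = x.view(2, 2)
    assert out_alias_of_input._base is x
    # "alias_of_intermediate*": the output is a view of a tensor created inside
    # the function, so its ._base is a graph intermediate rather than an input.
    intermediate = x + 1
    out_alias_of_intermediate = intermediate.view(2, 2)
    assert out_alias_of_intermediate._base is intermediate
    # "non_alias": a freshly allocated output with no ._base at all.
    out_non_alias = x * 2
    assert out_non_alias._base is None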
63
64# This class stores info about every user output.
65@dataclass(frozen=True)
66class OutputAliasInfo:
67    # Tells us if this output is:
68    # (1) a regular (non-aliased) output
69    # (2) an alias of a forward input
70    # (3) **is** a forward input (special case of "alias_of_input")
71    # (4) an alias of an intermediate (aka an alias of an output of the inner traced forward)
72    # (5) an alias of an intermediate, that explicitly requires returning the intermediate
73    #     as a graph output
74    # (6) an alias of an intermediate, where that intermediate is also a user output
75    output_type: OutputType
76    # The raw type of the output (torch.Tensor, SymInt, etc)
77    raw_type: type
78    # If (1) above, then
79    # - base_idx is None
80    # If (2) or (3) above, then
81    # - Tells us that the base of this alias is user_fwd_input[base_idx]
82    #   (This is an index into the inputs *before* we make synthetic bases)
83    # If (4) or (5) above, then
84    # - Tells us that the base of this alias is output_graph_intermediates[base_idx]
85    #   (here, this indexes into the intermediates of the *directly* traced forward)
86    # If (6) above, then:
87    # - Tells us that the base of this alias is output_user_fwds[base_idx]
88    #   (here, this indexes into the user outputs of the *directly* traced forward)
89    base_idx: Optional[int]
90    # If it is a Tensor, what the dynamic dims are (otherwise is None)
91    dynamic_dims: Optional[Set[int]]
92    # requires_grad
93    requires_grad: bool
94    # FunctionalTensorWrapper that represents this output.
95    #
96    # Provides us the means to replay views from it.
97    #
98    # We need to wrap the actual FunctionalTensorWrapper with this class so that
99    # we only compare the tensor's metadata. That's because with the transformations
100    # of the model throughout AOTAutograd, the sequence of ViewMeta and the base
101    # tensor might change.
102    functional_tensor: Optional[FunctionalTensorMetadataEq] = None
103
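# Illustrative sketch (not part of the library): what an OutputAliasInfo record
# might look like for an output that is a view of the first user input. All of
# the concrete field values below are made up for illustration only.
def _sketch_output_alias_info_for_view_of_input():
    return OutputAliasInfo(
        output_type=OutputType.alias_of_input,
        raw_type=torch.Tensor,
        # The alias's base is user_fwd_input[0] (indexing inputs before synthetic bases).
        base_idx=0,
        # No dynamic dims in this made-up example.
        dynamic_dims=set(),
        requires_grad=False,
        # functional_tensor stays at its default (None); AOTAutograd fills it in
        # when view replay is enabled.
    )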
104
105class MutationType(Enum):
106    NOT_MUTATED = 1
107    MUTATED_IN_GRAPH = 2
108    MUTATED_OUT_GRAPH = 3
109
110
111# This class stores info about every user input.
112@dataclass(frozen=True)
113class InputAliasInfo:
114    is_leaf: bool
115    mutates_data: bool
116    mutates_metadata: bool
117    mutations_hidden_from_autograd: bool
118    mutations_under_no_grad_or_inference_mode: bool
119    mutation_inductor_storage_resize: bool
120    mutates_storage_metadata: bool
121    requires_grad: bool
122    keep_input_mutations: bool
123
124    def __post_init__(self):
125        if self.mutates_storage_metadata:
126            # For convenience, we guarantee that this is always true.
127            # In practice, if we call .set_(), then at runtime there is no need
128            # to additionally fix up the tensor metadata, since our runtime
129            # call to inp.set_(updated_inp) will already give inp the right metadata.
130            assert self.mutates_metadata
131
132    @functools.cached_property
133    def mutation_type(self) -> MutationType:
134        if (
135            (not self.mutates_data)
136            and (not self.mutates_metadata)
137            and not (self.mutation_inductor_storage_resize)
138        ):
139            return MutationType.NOT_MUTATED
140
141        if _check_if_mutation_can_be_in_graph(
142            self.keep_input_mutations,
143            self.mutates_data,
144            self.mutates_metadata,
145            self.mutations_hidden_from_autograd,
146            self.mutations_under_no_grad_or_inference_mode,
147            self.mutates_storage_metadata,
148            self.mutation_inductor_storage_resize,
149            self.requires_grad,
150        ):
151            return MutationType.MUTATED_IN_GRAPH
152
153        return MutationType.MUTATED_OUT_GRAPH
154
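# Illustrative sketch (not part of the library): how an InputAliasInfo for a
# plain data-mutated input is classified. Whether the mutation is kept in the
# graph (MUTATED_IN_GRAPH) or deferred to the runtime epilogue
# (MUTATED_OUT_GRAPH) is decided by _check_if_mutation_can_be_in_graph; the
# field values below are made up for illustration only.
def _sketch_input_mutation_classification():
    info = InputAliasInfo(
        is_leaf=True,
        mutates_data=True,
        mutates_metadata=False,
        mutations_hidden_from_autograd=False,
        mutations_under_no_grad_or_inference_mode=False,
        mutation_inductor_storage_resize=False,
        mutates_storage_metadata=False,
        requires_grad=False,
        keep_input_mutations=True,
    )
    # For an unmutated input the answer is always NOT_MUTATED; otherwise it is
    # one of MUTATED_IN_GRAPH / MUTATED_OUT_GRAPH.
    return info.mutation_type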
155
156@dataclass
157class SubclassCreationMeta:
158    """
159    Used for AOTDispatch.
160    This dataclass gives us the information we need to reconstruct a tensor subclass
161    from our flat inputs.
162    Why is this important? The graph that we'd like to trace out contains flat tensor inputs,
163    but the user's original model may have subclass inputs and outputs.
164    So we need to wrap/unwrap subclasses as necessary to translate between the user's
165    view (subclass inps/outs), and the backend compiler's view (graph with no subclass args).
166
167    Complications arise mostly from the fact that a subclass can hold more than one inner tensor;
168    so for a given subclass input/output, we need to carefully track which indices map
169    to the subclass tensor in the corresponding "dense-tensor-only" graph.
170    """
171
172    # In the inner graph that only takes in dense tensor inputs,
173    # this maps to the first index of "tensors that should go in this subclass wrapper"
174    flat_tensor_start_idx: int
175    # arg_count is inclusive of the arg_counts of any
176    # inner tensor subclasses: If I have a TwoTensor and
177    # both of its inner elements are TwoTensors, then the
178    # arg_count of the outer-most subclass will be 4
179    arg_count: int
180    # meta and attrs are produced by the subclass's __tensor_flatten__.
181    # We need to keep them around along with outer_size / outer_stride to plumb them
182    # into __tensor_unflatten__
183    attrs: Dict[str, Union["SubclassCreationMeta", None]]
184    outer_size: List[int]
185    outer_stride: List[int]
186    meta: Any
187    # Stores the original subclass itself.
188    # This is needed because we need the autograd metadata on the original subclass
189    # (this is guaranteed to be a wrapper subclass that holds a fake tensor,
190    #  so holding onto this at runtime shouldn't leak memory)
191    # This field is nulled out after calling make_runtime_safe()
192    original_subclass: Optional[torch.Tensor]
193
194    # Used at runtime to determine the subclass type, so we don't need to save the original subclass
195    original_subclass_type: Optional[type] = None
196
197    def creation_fn(self, all_args, *, is_runtime: bool):
198        inner_tensors = {}
199
200        curr_start_idx = self.flat_tensor_start_idx
201        for attr, creation_meta in self.attrs.items():
202            if creation_meta is None:
203                subclass = all_args[curr_start_idx]
204                curr_start_idx += 1
205            else:
206                subclass = creation_meta.creation_fn(all_args, is_runtime=is_runtime)
207                curr_start_idx += creation_meta.arg_count
208            inner_tensors[attr] = subclass
209
210        if is_runtime:
211            assert self.original_subclass_type is not None
212            original_subclass_type = self.original_subclass_type
213        else:
214            original_subclass_type = type(self.original_subclass)
215
216        rebuilt = original_subclass_type.__tensor_unflatten__(  # type: ignore[attr-defined]
217            inner_tensors, self.meta, self.outer_size, self.outer_stride
218        )
219
220        if not is_runtime:
221            # After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper
222            # has correct autograd metadata, since we'll be tracing through the autograd engine with the subclass.
223            # We don't trace through the autograd engine at runtime though, so no need
224            # to compute this extra metadata then!
225            torch._mirror_autograd_meta_to(self.original_subclass, rebuilt)  # type: ignore[attr-defined]
226
227        return rebuilt
228
229    def make_runtime_safe(self):
230        assert self.original_subclass is not None
231        self.original_subclass_type = type(self.original_subclass)
232        self.original_subclass = None
233        # Recurse on nested subclass info
234        for creation_meta in self.attrs.values():
235            if creation_meta is not None:
236                creation_meta.make_runtime_safe()
237
238    def __post_init__(self):
239        # sanity assert to make sure we don't leak memory
240        assert is_fake(self.original_subclass)
241
242        # This saves the types of the subclass's nested structure so we can compare
243        # them against runtime tangent inputs. We want to compute this at AOT
244        # time, since the comparison happens on the runtime hot path.
245        from .subclass_utils import get_types_for_subclass
246
247        self.subclass_type = get_types_for_subclass(self.original_subclass)
248
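# Illustrative sketch (not part of the library): the flat-index bookkeeping that
# SubclassCreationMeta.creation_fn performs, written against a simplified nested
# dict instead of real SubclassCreationMeta objects (constructing those requires
# fake wrapper-subclass tensors). Keys mapped to None stand for plain inner
# tensors; keys mapped to a dict stand for nested subclasses. The helper name is
# hypothetical.
def _sketch_flat_arg_layout(attrs, start_idx=0):
    """Return ({attr: flat index or nested layout}, total arg_count)."""
    layout = {}
    idx = start_idx
    for attr, nested in attrs.items():
        if nested is None:
            layout[attr] = idx
            idx += 1
        else:
            layout[attr], consumed = _sketch_flat_arg_layout(nested, idx)
            idx += consumed
    return layout, idx - start_idx


# Example: a TwoTensor whose two inner elements are themselves TwoTensors
# consumes 4 flat args, matching the arg_count comment above:
#   _sketch_flat_arg_layout({"a": {"a": None, "b": None}, "b": {"a": None, "b": None}})
#   -> ({"a": {"a": 0, "b": 1}, "b": {"a": 2, "b": 3}}, 4)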
249
250# This class encapsulates all aliasing + mutation info we need about the forward graph
251# See a more detailed overview of the edge case handling at
252# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
253@dataclass(eq=False)
254class ViewAndMutationMeta:
255    # length = # user inputs
256    # This gives us info about every input, and what sort of mutation happened to it (if any)
257    input_info: List[InputAliasInfo]
258
259    # length = # user outputs
260    # This gives us info about every output (mostly around whether it aliases other tensors)
261    output_info: List[OutputAliasInfo]
262
263    # length = the number of intermediate bases appended as outputs to the end of the forward graph.
264    # Note: this is not necessarily the same thing as:
265    #   len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate])
266    # Because outputs might share a ._base, or an output's ._base might itself be
267    # another user output (in both cases, we won't redundantly append bases to the end of the graph)
268    num_intermediate_bases: int
269
270    # For inference only: instructs us to keep data-only input mutations directly in the graph
271    keep_input_mutations: bool
272
273    # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
274    #        + (# intermediate bases)
275    # These are the FakeTensor (or potential SymInt) outputs that we traced from our
276    # metadata pass of the user's forward function.
277    # Their only use today is to pass them as a best-guess for tangents when tracing the joint.
278    # Stashing them as part of our "metadata" makes it simpler if we want to run our analysis
279    # pass once, and re-use the output throughout AOTAutograd
280    traced_tangents: List[Any]
281
282    # Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs
283    # They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors,
284    # given a (potentially larger) list of plain torch tensors.
285
286    # Taking subclass_inp_meta as an example:
287    #   subclass_inp_meta[i] = j (an int) tells us:
288    #     "The i'th user input is not a subclass, and corresponds to inputs[j] of the plain-tensor graph."
289    #   subclass_inp_meta[i] = SubclassCreationMeta(flat_tensor_start_idx=3, arg_count=2)
290    #     "The i'th user input is subclass holding two inner tensors, which are
291    #      inputs[3] and inputs[4] of the plain-tensor graph".
292
293    # length = # user inputs
294    subclass_inp_meta: List[Union[int, SubclassCreationMeta]]
295    # So, the full set of outputs to the forward graph looks something like:
296    # (*mutated_inps, *user_outs, *intermediate_bases, *saved_for_bw_tensors)
297    # where the first 3 of those 4 can be subclasses
298    # (but not saved_for_bw tensors, since these are internal to the compiler
299    # and not user visible, so there's no point in wrapping/unwrapping them at runtime).
300    # This list contains subclass information on all of the fw graph outputs
301    # except for saved_for_bw_tensors.
302    subclass_fw_graph_out_meta: List[Union[int, SubclassCreationMeta]]
303    # length = # backward graph inputs
304    subclass_tangent_meta: List[Union[int, SubclassCreationMeta]]
305    # TODO: we should kill this
306    # (need to default it to not break internal)
307    is_train: bool = False
308
309    # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
310    #        + (# intermediate bases)
311    # At runtime, we don't keep the traced_tangents around since they're not serializable.
312    # Instead, we keep whatever subclass metadata is necessary to describe each traced_tangent.
313    # This list is generated after calling make_runtime_safe().
314    traced_tangent_metas: Optional[List[Any]] = None
315
316    num_symints_saved_for_bw: Optional[int] = None
317
318    # The grad_enabled mutation that will be emitted in the runtime_wrapper epilogue
319    # NOTE: AOTAutograd will assume that the ambient `is_grad_enabled` is the grad mode
320    # that is intended to be in effect prior to running the graph, in keeping with
321    # equivalence to eager mode. It is the responsibility of upstream graph acquisition
322    # to reset the grad mode to its pre-graph value prior to calling aot_autograd.
323    grad_enabled_mutation: Optional[bool] = None
324
325    # Keeps track of whether `torch.use_deterministic_algorithms` was turned on
326    # when the forward was run. If deterministic mode was turned off during the
327    # forward, but is turned on during the backward call, then an error is
328    # raised
329    deterministic: Optional[bool] = None
330
331    # Keeps track of which input indices store parameters (which we will treat as static)
332    static_input_indices: List[int] = field(default_factory=list)
333
334    # Map of effect type (e.g. _EffectType.ORDERED) to token. If there are
335    # side-effectful operators, FunctionalTensorMode will populate this
336    # dictionary telling us how many tokens we will need during tracing.
337    tokens: Dict[Any, torch.Tensor] = field(default_factory=dict)
338
339    # Only filled in if/when we trace the joint function
340    # If an input requires grad and is mutated in the backward, it is only safe to keep the mutation
341    # in the graph if gradients are disabled while the backward runs
342    # (grad mode is disabled by default when users run the backward, but can be turned on with create_graph=True)
343    # At runtime during the backward, we use this list of indices to error properly if we find out
344    # that it was not safe to include a backward mutation in the graph.
345    indices_of_inputs_that_requires_grad_with_mutations_in_bw: List[int] = field(
346        default_factory=list
347    )
348
349    # Indices of saved tensors that are donated buffers.
350    # A donated buffer is a saved tensor that is not an alias of any forward user input,
351    # forward user output, or backward output.
352    bw_donated_idxs: Optional[List[int]] = None
353
354    # Number of tokens used in backward, appended at the end of backward outputs.
355    # Filled in after tracing the joint function.
356    num_backward_tokens: int = 0
357
358    def __post_init__(self):
359        # pre-compute the indices of the inputs that are mutated.
360        # When keep_input_mutations is set, we don't need to worry about our epilogue
361        # handling data-only mutations, because we keep them directly in the graph.
362
363        mutated_inp_runtime_indices = [
364            i
365            for i, m in enumerate(self.input_info)
366            if (m.mutation_type == MutationType.MUTATED_OUT_GRAPH)
367        ]
368
369        mutated_graph_handled_indices = [
370            i
371            for i, m in enumerate(self.input_info)
372            if m.mutation_type == MutationType.MUTATED_IN_GRAPH
373        ]
374        self.mutated_graph_handled_indices = mutated_graph_handled_indices
375        self.num_mutated_graph_handled_indices = len(self.mutated_graph_handled_indices)
376
377        mutated_graph_handled_indices_seen_by_autograd = [
378            i
379            for i in mutated_graph_handled_indices
380            if not self.input_info[i].mutations_hidden_from_autograd
381        ]
382
383        self.mutated_graph_handled_indices_seen_by_autograd = (
384            mutated_graph_handled_indices_seen_by_autograd
385        )
386        self.num_mutated_graph_handled_indices_seen_by_autograd = len(
387            self.mutated_graph_handled_indices_seen_by_autograd
388        )
389
390        aliased_out_indices = [
391            i
392            for i, m in enumerate(self.output_info)
393            if m.output_type
394            not in [
395                OutputType.non_alias,
396                OutputType.unsafe_view_alias,
397                OutputType.custom_function_view,
398            ]
399        ]
400        unsafe_view_out_indices = [
401            i
402            for i, m in enumerate(self.output_info)
403            if m.output_type is OutputType.unsafe_view_alias
404        ]
405
406        # This is pre-computed in post_init for perf.
407        # It contains the index of every element
408        # of input_info that corresponds to a mutation (data or metadata or both)
409        self.mutated_inp_runtime_indices = mutated_inp_runtime_indices
410        self.num_mutated_inp_runtime_indices = len(self.mutated_inp_runtime_indices)
411
412        # This is pre-computed for perf.
413        # It contains the index of every element
414        # of output_info that corresponds to an alias (either of an input or intermediate)
415        self.aliased_out_indices = aliased_out_indices
416        self.unsafe_view_out_indices = unsafe_view_out_indices
417        self.num_outputs = len(self.output_info)
418        self.num_outputs_non_aliased = len(
419            [
420                x
421                for x in self.output_info
422                if x.output_type
423                in [
424                    OutputType.non_alias,
425                    OutputType.unsafe_view_alias,
426                    OutputType.custom_function_view,
427                ]
428            ]
429        )
430        self.num_outputs_aliased_to_inputs = len(
431            [
432                x
433                for x in self.output_info
434                if x.output_type
435                in [
436                    OutputType.alias_of_input,
437                    OutputType.is_input,
438                ]
439            ]
440        )
441        self.num_unsafe_view_outputs = len(self.unsafe_view_out_indices)
442        self.num_outputs_aliased_to_intermediates = len(
443            [
444                x
445                for x in self.output_info
446                if x.output_type
447                in [
448                    OutputType.alias_of_intermediate,
449                    OutputType.alias_of_intermediate_save_as_output,
450                    OutputType.alias_of_intermediate_base_is_user_output,
451                ]
452            ]
453        )
454        self.num_outputs_aliased = (
455            self.num_outputs_aliased_to_inputs
456            + self.num_outputs_aliased_to_intermediates
457        )
458
459        self.dynamic_outputs = any(o.dynamic_dims for o in self.output_info)
460        # See Note: [AOTAutograd Backward Guards]
461        # This is pre-computed for fast asserts on the types of our grad_outputs in the backward.
462        # Eventually, we should kill this and replace with real backward guards.
463        # (we want to precompute the "runtime" types, so replace FakeTensor with torch.Tensor)
464        self.output_types = [
465            torch.Tensor if isinstance(x, FakeTensor) else type(x)
466            for x in self.traced_tangents
467        ]
468
469        self.is_rng_op_functionalized = config.functionalize_rng_ops
470        # All of the above metadata is collected by tracing the fw function.
471        # However, extra outputs for rng offsets behave differently. Both fwd
472        # and bwd graphs have their own outputs for the total consumed offsets.
473        # Unlike mutated inputs, we don't have to worry about sending the right
474        # set of tensors between fwd and bwd. Fwd and bwd offsets are
475        # independent and simpler to handle. Therefore, we track them
476        # separately.
477        self.num_outputs_rng_offset = 1 if self.is_rng_op_functionalized else 0
478
479        # Our forward() returns (tokens, mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints); see the sketch after this class.
480        # Tokens are split out before mutation/view handling and we do not count them here.
481        self.num_forward_returns = (
482            self.num_mutated_inp_runtime_indices
483            + self.num_outputs
484            + self.num_intermediate_bases
485        )
486        # In case of functionalization of rng ops, the fw_module returns one
487        # additional output for rng offset. This rng offset is used right
488        # away to advance the rng state, and is not passed on to the raw
489        # outputs. However, we need to know the exact boundary to identify
490        # which tensors to be saved for the bwd graph.  num_forward captures
491        # this information.
492        self.num_forward = self.num_forward_returns + self.num_outputs_rng_offset
493
494    def make_runtime_safe(self):
495        """
496        There are various fields in ViewAndMutationMeta that aren't serializable. This function is called after all tracing
497        is completed to simplify certain fields in the metadata so that they can be safely cached.
498
499        Doing so may lose information (in the case of traced_tangents), but none of the information is needed at runtime.
500        """
501        # TODO: This function is only a best effort: there are other fields that may not be cache safe
502        # (e.g., there's no guarantee that tensor_flatten() returns a serializable result, or that
503        # SubclassCreationMeta is cache safe).
504        assert self.traced_tangent_metas is None
505
506        def extract_metadata(t):
507            if isinstance(t, torch.Tensor) and is_traceable_wrapper_subclass(t):
508                (inner_tensors, flatten_spec) = t.__tensor_flatten__()  # type: ignore[attr-defined]
509                # Technically, we only need the flatten_spec, not the inner tensors.
510                # However, some Tensor subclasses (like TwoTensor) may have flatten_spec = None.
511                # And we want to be able to assert that this metadata is non-None,
512                # to distinguish between "this was a tensor subclass with no metadata" vs.
513                # "this wasn't a tensor subclass at all".
514                return (inner_tensors, flatten_spec)
515            else:
516                return None
517
518        self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents]
519        # Clear traced tangents at runtime
520        self.traced_tangents = []
521        new_output_info = []
522        for out in self.output_info:
523            if config.view_replay_for_aliased_outputs:
524                new_out = out
525            else:
526                # If we're not using view_replay, remove the functional tensor.
527                # Functional tensors are unfortunately not serializable,
528                # so doing this is required for AOTAutograd caching.
529                new_out = dataclasses.replace(out, functional_tensor=None)
530            new_output_info.append(new_out)
531        self.output_info = new_output_info
532        for inp_meta in self.subclass_inp_meta:
533            if isinstance(inp_meta, SubclassCreationMeta):
534                inp_meta.make_runtime_safe()
535        for inp_meta in self.subclass_fw_graph_out_meta:
536            if isinstance(inp_meta, SubclassCreationMeta):
537                inp_meta.make_runtime_safe()
538        for inp_meta in self.subclass_tangent_meta:
539            if isinstance(inp_meta, SubclassCreationMeta):
540                inp_meta.make_runtime_safe()
541
542    @property
543    def tensors_saved_for_backwards_slice(self):
544        assert self.num_symints_saved_for_bw is not None
545        if self.num_symints_saved_for_bw > 0:
546            return slice(self.num_forward, -self.num_symints_saved_for_bw)
547        else:
548            return slice(self.num_forward, None)
549
550    @property
551    def symints_saved_for_backwards_slice(self):
552        assert self.num_symints_saved_for_bw is not None
553        if self.num_symints_saved_for_bw > 0:
554            return slice(-self.num_symints_saved_for_bw, None)
555        else:
556            return slice(0, 0)  # empty slice
557
558    def __eq__(self, other):
559        if not isinstance(other, ViewAndMutationMeta):
560            return NotImplemented
561        return (
562            self.input_info == other.input_info
563            and self.output_info == other.output_info
564            and self.num_intermediate_bases == other.num_intermediate_bases
565            and self.keep_input_mutations == other.keep_input_mutations
566            and self.is_rng_op_functionalized == other.is_rng_op_functionalized
567            and self.num_outputs_rng_offset == other.num_outputs_rng_offset
568            and len(self.traced_tangents) == len(other.traced_tangents)
569            and all(
570                x.shape == y.shape and x.dtype == y.dtype
571                for x, y in zip(self.traced_tangents, other.traced_tangents)
572            )
573            and self.num_backward_tokens == other.num_backward_tokens
574        )
575
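# Illustrative sketch (not part of the library): how the flat forward outputs
# are carved up using the counts that ViewAndMutationMeta precomputes. The
# numbers below are made up; the slicing mirrors num_forward,
# tensors_saved_for_backwards_slice and symints_saved_for_backwards_slice above.
def _sketch_forward_output_partition():
    num_mutated_inp_runtime_indices = 1
    num_outputs = 2
    num_intermediate_bases = 1
    num_outputs_rng_offset = 0  # 1 if rng ops are functionalized
    num_symints_saved_for_bw = 2

    num_forward_returns = (
        num_mutated_inp_runtime_indices + num_outputs + num_intermediate_bases
    )  # -> 4
    num_forward = num_forward_returns + num_outputs_rng_offset  # -> 4

    # Pretend flat forward outputs (tokens already split off):
    flat_outs = ["mut_inp0", "out0", "out1", "base0", "saved_t0", "saved_t1", "s0", "s1"]

    user_visible = flat_outs[:num_forward_returns]  # runtime-facing returns
    saved_tensors = flat_outs[num_forward:-num_symints_saved_for_bw]
    saved_symints = flat_outs[-num_symints_saved_for_bw:]

    assert user_visible == ["mut_inp0", "out0", "out1", "base0"]
    assert saved_tensors == ["saved_t0", "saved_t1"]
    assert saved_symints == ["s0", "s1"]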
576
577@dataclass(eq=False)
578class SubclassMeta:
579    # A copy of all forward metadata, but computed on the *dense* tensor forward (after desugaring subclasses).
580    # So for example, if the user had a model containing two `TwoTensor` inputs,
581    # then `SubclassMeta.fw_metadata.input_info` would have length 4 here.
582    fw_metadata: ViewAndMutationMeta
583
584    # Note: [Computing Subclass Metadata about grad_inputs]
585    # Given a list of flattened, plain tensor grad_inputs, this tells us how to reconstruct the grad_input subclasses
586    #
587    # You might think: why not just assume that all grad_inputs will have the same subclass-ness as the original inputs?
588    # (AOTAutograd generally assumes other properties, e.g. that grad_outputs are contiguous)
589    #
590    # This doesn't really work though. take this example:
591    #
592    # def f(DoubleTensor, DenseTensor):
593    #     return DoubleTensor  * DenseTensor
594    #
595    # In the above example, the .grad field of *both* DoubleTensor and DenseTensor will be a DoubleTensor.
596    # When we trace out a joint fw-bw graph, we'll end up returning two subclasses for the two grad_inputs.
597    # This means that our backward graph will return 4 outputs (two dense tensors for each DoubleTensor grad_input)
598    # and we need to properly store the metadata that tells us how to turn these 4 outputs back into DoubleTensors.
599    #
600    # Note that this info **cannot** easily be figured out from ViewAndMutationMeta.
601    # We can only compute this info by tracing the entire joint and examining the grad_inputs that we computed.
602    #
603    # See Note: [AOTAutograd Backward Guards]
604    # This will also eventually require us to install backward guards,
605    # in case we made incorrect assumptions about the subclass-ness of our grad_outputs
606    #
607    # Optional field because we don't compute this for inference graphs
608    grad_input_metas: Optional[List[Union[int, SubclassCreationMeta]]] = None
609
610    def __init__(self) -> None:
611        # The fields in this class get set after its construction.
612        pass
613
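# Illustrative sketch (not part of the library): the counting implied by
# Note [Computing Subclass Metadata about grad_inputs]. Each entry of
# grad_input_metas is either an int (a plain grad_input, one flat backward
# output) or a SubclassCreationMeta (arg_count flat backward outputs). The
# helper below is hypothetical and exists only to illustrate that counting.
def _sketch_num_flat_bw_outputs(grad_input_metas):
    return sum(
        1 if isinstance(meta, int) else meta.arg_count for meta in grad_input_metas
    )


# For the f(DoubleTensor, DenseTensor) example above, both grad_inputs come back
# as DoubleTensors holding two dense tensors each, so the backward graph returns
# _sketch_num_flat_bw_outputs([meta_a, meta_b]) == 4 flat outputs.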
614
615# This class exists because:
616# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs
617# - we only care about the metadata on those aliases, so we can regenerate them.
618#   We do not want them to participate in the autograd.Function.
619# We do that by wrapping them in an opaque class, so the autograd.Function
620# does not know to treat them as tensors.
621@dataclass(frozen=True)
622class TensorAlias:
623    alias: torch.Tensor
624
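# Illustrative sketch (not part of the library): because TensorAlias is an
# opaque dataclass rather than a Tensor subclass, autograd.Function.forward can
# return it without autograd treating the wrapped alias as a tensor output; the
# caller later unwraps .alias and regenerates the view.
def _sketch_tensor_alias_wrapping():
    x = torch.randn(4, requires_grad=True)
    wrapped = TensorAlias(alias=x.view(2, 2))
    assert not isinstance(wrapped, torch.Tensor)
    assert wrapped.alias._base is x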
625
626@dataclass
627class BackwardSignature:
628    """
629    Provides information about the backward section of an exported
630    joint forward-backward graph.
631    For a particular fx GraphModule, this class contains information on:
632    (1) A mapping from each gradient (backwards output) to the parameter
633        it corresponds to (forward input)
634    (2) A mapping from each gradient (backwards output) to the user input
635        it corresponds to (forward input)
636    (3) Which of the forward outputs corresponds to the loss that we backprop on.
637
638    Each string name is the `node.name` of the corresponding node in the fx graph.
639    """
640
641    gradients_to_parameters: Dict[str, str]
642    gradients_to_user_inputs: Dict[str, str]
643    loss_output: str
644
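# Illustrative sketch (not part of the library): a BackwardSignature for a tiny
# joint graph. All node names below are hypothetical placeholders, not real fx
# node names.
def _sketch_backward_signature():
    return BackwardSignature(
        gradients_to_parameters={"grad_weight": "arg0_1"},
        gradients_to_user_inputs={"grad_x": "arg2_1"},
        loss_output="loss",
    )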
645
646GraphOutputName = NewType("GraphOutputName", str)
647GraphInputName = NewType("GraphInputName", str)
648FQN = NewType("FQN", str)
649
650
651@dataclass
652class GraphSignature:
653    """
654    Provides information about an exported module.
655    For a particular fx GraphModule, this class contains information on:
656    (1) Which graph inputs are parameters, buffers, or user inputs
657    (2) (for params/buffers) a mapping from the name of each graph argument
658        to its parameter/buffer FQN in the original nn.Module.
659    (3) If there are input mutations, these are represented as extra outputs
660        in the fx GraphModule. We provide a mapping from these
661        extra output names to the names of the actual inputs.
662    (4) The pytree metadata on how to flatten/unflatten inputs and outputs.
663        The corresponding FX GraphModule only accepts and returns
664        pytree-flattened inputs/outputs.
665    (5) (Optionally) if the FX graph is a joint forward-backward graph, we provide
666        a signature on the backward section of the joint graph.
667    """
668
669    parameters: List[FQN]
670    buffers: List[FQN]
671
672    user_inputs: List[GraphInputName]
673    user_outputs: List[GraphOutputName]
674    inputs_to_parameters: Dict[GraphInputName, FQN]
675    inputs_to_buffers: Dict[GraphInputName, FQN]
676
677    # If the user's module mutates a buffer,
678    # it's represented in the graph as an extra graph output.
679    # This dict is a mapping from
680    # "graph outputs that correspond to updated buffers"
681    # to the FQN names of those mutated buffers.
682    buffers_to_mutate: Dict[GraphOutputName, FQN]
683    user_inputs_to_mutate: Dict[GraphOutputName, GraphInputName]
684
685    in_spec: pytree.TreeSpec
686    out_spec: pytree.TreeSpec
687
688    backward_signature: Optional[BackwardSignature]
689
690    input_tokens: List[GraphInputName]
691    output_tokens: List[GraphOutputName]
692
693    @classmethod
694    def from_tracing_metadata(
695        cls,
696        *,
697        in_spec: pytree.TreeSpec,
698        out_spec: pytree.TreeSpec,
699        graph_input_names: List[str],
700        graph_output_names: List[str],
701        view_mutation_metadata: ViewAndMutationMeta,
702        named_parameters: List[str],
703        named_buffers: List[str],
704        num_user_inputs: int,
705        num_user_outputs: int,
706        loss_index: Optional[int],
707        backward_signature: Optional[BackwardSignature],
708    ) -> "GraphSignature":
709        graph_inputs = graph_input_names
710        graph_outputs = graph_output_names
711        parameters = list(named_parameters)
712        buffers = list(named_buffers)
713        num_tokens = len(view_mutation_metadata.tokens)
714
715        # Calling convention assumptions:
716        # (1) graph inputs = (input_tokens, params, buffers, user_inputs)
717        # (2) graph outputs = (output_tokens, mutated_inputs, user_outs, param_gradients)
718        # (If we are capturing an inference graph, this convention is identical
719        #  except that param_gradients is empty)
720        # See Note [Side-Effectful Tokens in AOTAutograd] for information on tokens; the input partition is sketched after this class.
721
722        # Address input calling conventions:
723        start, stop = 0, num_tokens
724        input_tokens = graph_inputs[start:stop]
725
726        start, stop = stop, stop + len(parameters)
727        inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters))
728
729        start, stop = stop, stop + len(buffers)
730        inputs_to_buffers = dict(
731            zip(
732                graph_inputs[start:stop],
733                buffers,
734            )
735        )
736
737        start, stop = stop, stop + num_user_inputs
738        user_inputs = graph_inputs[start:stop]
739
740        # We should've gone through all the inputs now
741        assert len(graph_inputs) - stop == 0
742
743        # Address output calling conventions:
744        start, stop = 0, num_tokens
745        output_tokens = graph_outputs[start:stop]
746
747        names = [*input_tokens, *parameters, *buffers, *user_inputs]
748        mutations = []
749        for idx, input_info in enumerate(view_mutation_metadata.input_info):
750            if input_info.mutates_data:
751                # Only buffers can be mutated, not parameters
752                assert idx >= len(parameters)
753                mutations.append(names[idx + num_tokens])
754
755        assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices
756
757        start, stop = (
758            stop,
759            stop + view_mutation_metadata.num_mutated_inp_runtime_indices,
760        )
761        outputs_to_mutations = dict(zip(graph_outputs[start:stop], mutations))
762
763        user_inputs_to_mutate = {}
764        buffers_to_mutate = {}
765        for output_name, mutation_name in outputs_to_mutations.items():
766            if mutation_name in user_inputs:
767                user_inputs_to_mutate[output_name] = mutation_name
768            else:
769                assert mutation_name in buffers
770                buffers_to_mutate[output_name] = mutation_name
771
772        start, stop = stop, stop + num_user_outputs
773        user_outputs = graph_outputs[start:stop]
774
775        unused_outputs = len(graph_outputs) - stop
776        if backward_signature is not None:
777            unused_outputs -= len(backward_signature.gradients_to_parameters) + len(
778                backward_signature.gradients_to_user_inputs
779            )
780        assert unused_outputs == 0
781
782        return GraphSignature(
783            parameters=parameters,  # type: ignore[arg-type]
784            buffers=buffers,  # type: ignore[arg-type]
785            user_inputs=user_inputs,  # type: ignore[arg-type]
786            user_outputs=user_outputs,  # type: ignore[arg-type]
787            inputs_to_buffers=inputs_to_buffers,  # type: ignore[arg-type]
788            inputs_to_parameters=inputs_to_parameters,  # type: ignore[arg-type]
789            user_inputs_to_mutate=user_inputs_to_mutate,
790            buffers_to_mutate=buffers_to_mutate,  # type: ignore[arg-type]
791            in_spec=in_spec,
792            out_spec=out_spec,
793            backward_signature=backward_signature,
794            input_tokens=input_tokens,  # type: ignore[arg-type]
795            output_tokens=output_tokens,  # type: ignore[arg-type]
796        )
797
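# Illustrative sketch (not part of the library): the graph-input calling
# convention that GraphSignature.from_tracing_metadata assumes,
#   graph inputs = (input_tokens, params, buffers, user_inputs),
# shown with made-up names and plain list slicing instead of real metadata.
def _sketch_graph_input_partition():
    graph_inputs = ["token0", "p_weight", "p_bias", "b_running_mean", "x"]
    num_tokens, num_params, num_buffers = 1, 2, 1

    start, stop = 0, num_tokens
    input_tokens = graph_inputs[start:stop]
    start, stop = stop, stop + num_params
    params = graph_inputs[start:stop]
    start, stop = stop, stop + num_buffers
    buffers = graph_inputs[start:stop]
    user_inputs = graph_inputs[stop:]

    assert input_tokens == ["token0"]
    assert params == ["p_weight", "p_bias"]
    assert buffers == ["b_running_mean"]
    assert user_inputs == ["x"]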
798
799@dataclass
800class AOTConfig:
801    """
802    Configuration for AOTDispatcher
803    """
804
805    fw_compiler: Callable
806    bw_compiler: Callable
807    partition_fn: Callable
808    decompositions: Dict[OpOverload, Callable]
809    num_params_buffers: int
810    aot_id: int
811    keep_inference_input_mutations: bool
812    is_export: bool = False
813    no_tangents: bool = False
814    dynamic_shapes: bool = False
815    aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
816    static_input_indices: Optional[List[int]] = None
817    inference_compiler: Optional[Callable] = None
818    enable_log: bool = True
819    # This is always False outside of export.
820    pre_dispatch: bool = False
821
822    # Key to use for AOTAutogradCache
823    cache_key: Optional[str] = None
824
825    def __post_init__(self):
826        if self.pre_dispatch:
827            assert self.is_export, "Can only have pre_dispatch IR for export."
828
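# Illustrative sketch (not part of the library): a minimal AOTConfig. The
# compiler/partitioner callables here are hypothetical stand-ins; real callers
# pass e.g. Inductor's forward/backward compilers and a partition function.
def _sketch_minimal_aot_config():
    def _identity_compiler(gm, example_inputs):
        # Placeholder "compiler" that just runs the graph module as-is.
        return gm.forward

    def _no_partition(joint_gm, joint_inputs, **kwargs):
        raise NotImplementedError("placeholder partitioner for illustration only")

    return AOTConfig(
        fw_compiler=_identity_compiler,
        bw_compiler=_identity_compiler,
        partition_fn=_no_partition,
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )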
829
830SubclassTracingInfo = collections.namedtuple(
831    "SubclassTracingInfo",
832    ["plain_tensor_trace_fn", "plain_tensor_args", "maybe_subclass_meta"],
833)
834