# mypy: allow-untyped-defs
"""
This file contains utilities for tracing through __torch_dispatch__ based tensor subclasses and modes.
AOTAutograd's responsibility is to trace through all pytorch capabilities that live in the pytorch dispatcher,
and this includes tensor subclasses that implement __torch_dispatch__.
"""

import typing
from typing import Any, List, Optional, Tuple, Union

import torch.utils._pytree as pytree
from torch import Tensor
from torch._subclasses.fake_tensor import get_plain_tensors
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from .schemas import MutationType, SubclassCreationMeta, ViewAndMutationMeta
from .utils import strict_zip


zip = strict_zip


def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool:
    args_flattened = pytree.arg_tree_leaves(*args)
    any_subclass_args = any(
        is_traceable_wrapper_subclass(x)
        for x in args_flattened
        if isinstance(x, Tensor)
    )
    from torch._functorch._aot_autograd.schemas import SubclassCreationMeta

    any_subclass_outputs = any(
        type(x) is SubclassCreationMeta for x in fw_metadata.subclass_fw_graph_out_meta
    )
    # This tells us whether or not we need to perform any unwrapping/wrapping of tensor subclasses at runtime.
    return any_subclass_args or any_subclass_outputs
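
# Illustrative sketch of requires_subclass_dispatch (not part of the original file):
# assuming a hypothetical traceable wrapper subclass `TwoTensor` that wraps two plain
# tensors, something like
#   requires_subclass_dispatch((TwoTensor(a, b), plain), fw_metadata)
# would return True, while passing only plain tensors (with no SubclassCreationMeta
# entries in fw_metadata.subclass_fw_graph_out_meta) would return False.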


def create_subclass_metadata(a, start_idx):
    if not is_traceable_wrapper_subclass(a):
        return None, start_idx + 1

    inner_keys, metadata = a.__tensor_flatten__()
    new_start_idx = start_idx
    attrs = {}
    for key in inner_keys:
        new_subclass_meta, new_start_idx = create_subclass_metadata(
            getattr(a, key), new_start_idx
        )
        attrs[key] = new_subclass_meta

    # `a` *must* be a Tensor here, because is_traceable_wrapper_subclass() returned True - but mypy is not smart enough to know that.
    assert isinstance(a, Tensor)

    return (
        SubclassCreationMeta(
            flat_tensor_start_idx=start_idx,
            arg_count=new_start_idx - start_idx,
            attrs=attrs,
            meta=metadata,
            outer_size=a.size(),  # type: ignore[attr-defined, arg-type]
            outer_stride=a.stride(),  # type: ignore[arg-type]
            original_subclass=a,
        ),
        new_start_idx,
    )
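
# Illustrative sketch of create_subclass_metadata (not part of the original file):
# assuming a hypothetical subclass whose __tensor_flatten__() returns (["a", "b"], meta),
# where both inner attributes are plain tensors, create_subclass_metadata(subclass, 0)
# would return roughly
#   (SubclassCreationMeta(flat_tensor_start_idx=0, arg_count=2,
#                         attrs={"a": None, "b": None}, meta=meta, ...), 2)
# i.e. the subclass consumes two slots of the flattened dense-tensor list, and the
# returned index (2) is where the next argument's slots would begin.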


# Given a real tensor subclass, returns a flattened list of plain tensor types for its inner tensors
def get_types_for_subclass(tensor_subclass):
    if not is_traceable_wrapper_subclass(tensor_subclass):
        return ["Tensor"]
    inner_keys, _ = tensor_subclass.__tensor_flatten__()
    result = []
    for key in inner_keys:
        inner_tensor = getattr(tensor_subclass, key)
        result.extend(get_types_for_subclass(inner_tensor))
    return result
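
# Illustrative sketch of get_types_for_subclass (not part of the original file):
# a plain tensor yields ["Tensor"]; a hypothetical subclass wrapping two plain tensors
# yields ["Tensor", "Tensor"]; a subclass wrapping (plain, subclass-of-two-plains)
# is flattened recursively into ["Tensor", "Tensor", "Tensor"].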


# Given a flat list of arguments, some of which may be tensor subclasses,
# computes metadata about "how to reconstruct the current list of subclasses,
# if we were given their flattened dense tensors instead"
def create_subclass_meta(
    curr_args: Union[List[Any], Tuple[Any, ...]]
) -> List[Union[int, SubclassCreationMeta]]:
    idx = 0
    infos: List[Union[int, SubclassCreationMeta]] = []
    for a in curr_args:
        if is_traceable_wrapper_subclass(a):
            assert isinstance(a, Tensor)
            start_idx = idx
            subclass_meta, _ = create_subclass_metadata(a, start_idx)
            infos.append(subclass_meta)
            cnt = subclass_meta.arg_count
        else:
            infos.append(idx)
            cnt = 1
        idx += cnt
    return infos
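
# Illustrative sketch of create_subclass_meta (not part of the original file): given
#   curr_args = [plain0, two_inner_subclass, plain1]
# where two_inner_subclass flattens into two dense tensors, the result would be roughly
#   [0, SubclassCreationMeta(flat_tensor_start_idx=1, arg_count=2, ...), 3]
# Plain tensors are recorded by their index into the flattened dense-tensor list, while
# each subclass entry records which contiguous slice of that list it owns.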


# Output structure:
# - List[Tensor] if tracing an inference graph
# - Tuple[List[Tensor], List[Tensor]] if tracing a joint graph.
# This function effectively concats each inner list of subclass tensors
# into a (potentially longer) list of inner tensors.
#
# This function takes in a pytree of arguments and unwraps any tensor subclasses.
# Annoyingly, we can't use pytrees to perform the unwrapping, because unwrapping returns
# a list of tensors that we would then need to concat together.
# Instead, we specialize the logic for the inference vs. joint graph case.
# NOTE: this function is hot, since we unwrap tensor subclass inputs at runtime
def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
    def concat_inner_tensors_from_subclasses(xs):
        xs_inner = []
        for x in xs:
            if is_traceable_wrapper_subclass(x):
                xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
            else:
                xs_inner.append(x)
        return xs_inner

    if is_joint_structure:
        assert isinstance(wrapped_args, tuple) and len(wrapped_args) == 2
        assert isinstance(wrapped_args[0], (tuple, list)) and isinstance(
            wrapped_args[1], (tuple, list)
        )
        unwrapped_args_fw = concat_inner_tensors_from_subclasses(wrapped_args[0])
        unwrapped_args_tangents = concat_inner_tensors_from_subclasses(wrapped_args[1])
        unwrapped_args = (unwrapped_args_fw, unwrapped_args_tangents)
    else:
        assert isinstance(wrapped_args, (list, tuple))
        unwrapped_args_fw = concat_inner_tensors_from_subclasses(wrapped_args)
        unwrapped_args = unwrapped_args_fw
    return unwrapped_args
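
# Illustrative sketch of unwrap_tensor_subclasses (not part of the original file):
# in the inference case, something like
#   unwrap_tensor_subclasses([plain, subclass_of(a, b)], is_joint_structure=False)
# would return [plain, a, b]; in the joint case the same flattening is applied
# independently to the primals list and the tangents list, preserving the
# (primals, tangents) tuple structure.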


def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
    static_input_indices = set(static_input_indices)
    new_ind = 0
    remapped_static_indices = []
    for i, arg in enumerate(wrapped_args):
        num_indices = 1
        if is_traceable_wrapper_subclass(arg):
            num_indices = len(get_plain_tensors(typing.cast(Tensor, arg)))

        for _ in range(num_indices):
            if i in static_input_indices:
                remapped_static_indices.append(new_ind)

            new_ind += 1

    return remapped_static_indices
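
# Illustrative sketch of remap_unwrapped_subclass_arg_indices (not part of the original
# file): with wrapped_args = [plain, subclass_of(a, b), plain] and
# static_input_indices = [1], the subclass at outer index 1 occupies unwrapped positions
# 1 and 2, so the result would be [1, 2]: each static outer index is expanded to all of
# the flattened dense-tensor positions it covers.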


# Turns a flattened list of tensor arguments into (maybe) subclass tensors.
# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in.
def wrap_tensor_subclasses(
    unwrapped_args: Union[Tuple[Any, ...], List[Any]],
    *,
    subclass_metas: List[Union[int, SubclassCreationMeta]],
    num_fw_outs_saved_for_bw: Optional[int] = None,
    is_runtime: bool = False,
) -> Tuple[Any, ...]:
    wrapped_args = []
    num_args_tallied = 0
    for subclass_meta in subclass_metas:
        if isinstance(subclass_meta, int):
            wrapped_args.append(unwrapped_args[subclass_meta])
            num_args_tallied += 1
        else:
            assert isinstance(subclass_meta, SubclassCreationMeta)
            wrapped_args.append(
                subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime)
            )
            num_args_tallied += subclass_meta.arg_count

    # Note: [Partitioner handling for Subclasses, Part 2]
    # At the beginning of AOTAutograd, we collect metadata on the inputs and outputs of the user fw,
    # to figure out which inputs/outputs are subclasses, and how to reconstruct the subclasses after flattening them.
    #
    # When this function is called at runtime in the forward,
    # we have been passed a list of (flattened) dense-tensor fw-outs, and need to reconstruct any subclass fw outs.
    #
    # One reasonable question that you should ask: when should the dense_tensor -> subclass_tensor wrapping happen?
    # Answer: we do it **inside of our compiled autograd.Function**.
    # This seems like morally the right place: autograd happens above subclass desugaring,
    # so autograd should see actual tensor subclasses at runtime, and not flattened dense tensors.
    #
    # This causes a tricky interaction though: when we run the min-cut partitioner to divvy up the joint graph
    # into a forward and backward graph, we end up with some activations that show up as extra outputs
    # in the compiled forward graph, that are **not** user outputs.
    # These activations are not visible to the user, and so there's no need for us to wrap them back into subclasses.
    #
    # On top of that, when we first computed subclass metadata (in `run_functionalized_fw_and_collect_metadata`),
    # we computed subclass metadata on every forward output, but this did **not** include activations
    # created by the partitioner.
    # As a result, `unwrapped_args` here will correspond to (*unwrapped_user_fw_outs, *activations),
    # but `subclass_metas` will only correspond to subclass metadata on `user_fw_outs`.
    # We then need to make sure that we return (*wrapped_user_fw_outs, *activations).
    if num_fw_outs_saved_for_bw is not None:
        assert len(unwrapped_args) == num_args_tallied + num_fw_outs_saved_for_bw, (
            f"Expected the number of actual unwrapped-subclass outputs ({len(unwrapped_args)}) to equal "
            f"the number of args calculated from subclasses ({num_args_tallied}) plus the number of "
            f"additional activations saved for the backward pass ({num_fw_outs_saved_for_bw})"
        )
        activations = unwrapped_args[num_args_tallied:]
        if isinstance(wrapped_args, tuple) and isinstance(activations, tuple):
            return wrapped_args + activations
        return tuple(list(wrapped_args) + list(activations))
    else:
        assert len(unwrapped_args) == num_args_tallied
        return tuple(wrapped_args)
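
# Illustrative sketch of wrap_tensor_subclasses (not part of the original file): with
#   subclass_metas = [0, SubclassCreationMeta(flat_tensor_start_idx=1, arg_count=2, ...)]
# and unwrapped_args = (plain, a, b, act0, act1), plus num_fw_outs_saved_for_bw=2,
# the result would be roughly (plain, rebuilt_subclass, act0, act1): the metadata-covered
# prefix is re-wrapped via each SubclassCreationMeta's creation_fn, and the trailing
# partitioner-created activations are passed through untouched.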


# Given a bunch of "dense" tensor arguments, this function (potentially) wraps them into tensor subclasses.
# This function carefully handles the inference vs. joint cases:
# - when is_joint_structure is True, args is (primals, tangents)
# - when is_joint_structure is False, args is [*primals]
def wrap_tensor_subclasses_maybe_joint(
    unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta
) -> Union[Tuple[Any, ...], List[Any]]:
    # Since this function is re-used for both inference and joint graphs,
    # we handle the (primals, tangents) structure of the joint case separately.
    if is_joint_structure:
        assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2
        assert isinstance(unwrapped_args[0], (tuple, list)) and isinstance(
            unwrapped_args[1], (tuple, list)
        )
        primals, tangents = unwrapped_args[0], unwrapped_args[1]
        wrapped_primals = wrap_tensor_subclasses(
            primals, subclass_metas=meta.subclass_inp_meta
        )
        wrapped_tangents = wrap_tensor_subclasses(
            tangents, subclass_metas=meta.subclass_tangent_meta
        )
        return (wrapped_primals, wrapped_tangents)
    else:
        wrapped_args = wrap_tensor_subclasses(
            unwrapped_args, subclass_metas=meta.subclass_inp_meta
        )
        return wrapped_args
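
# Illustrative sketch of wrap_tensor_subclasses_maybe_joint (not part of the original
# file): in the joint case, primals are re-wrapped using meta.subclass_inp_meta and
# tangents using meta.subclass_tangent_meta, so roughly
#   wrap_tensor_subclasses_maybe_joint((primals, tangents), is_joint_structure=True, meta=meta)
#   -> (wrapped_primals, wrapped_tangents)
# while the inference case applies meta.subclass_inp_meta to the single flat list.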


# TODO: UNUSED. delete?
def create_metadata_for_subclass(meta: ViewAndMutationMeta) -> ViewAndMutationMeta:
    # input infos
    input_info = []
    for inp, subclass_meta in zip(meta.input_info, meta.subclass_inp_meta):
        num_inps = 1 if isinstance(subclass_meta, int) else subclass_meta.arg_count
        for _ in range(num_inps):
            input_info.append(inp)

    # output infos
    output_info = []
    subclass_out_meta_user_outs_only = meta.subclass_fw_graph_out_meta[
        meta.num_mutated_inp_runtime_indices :
    ]
    if meta.num_intermediate_bases > 0:
        subclass_out_meta_user_outs_only = subclass_out_meta_user_outs_only[
            : -meta.num_intermediate_bases
        ]
    # sanity assert
    assert len(meta.output_info) == len(subclass_out_meta_user_outs_only)
    # Assume that the information on the output is shared by all of its inner tensors.
    for out, subclass_meta in zip(meta.output_info, subclass_out_meta_user_outs_only):
        num_outs = 1 if isinstance(subclass_meta, int) else subclass_meta.arg_count
        for _ in range(num_outs):
            output_info.append(out)

    # A bit hacky, but we don't actually care about all of the metadata here.
    # This metadata is used **underneath** both autograd and subclass de-sugaring,
    # so all we really care about is stuff like:
    # - num inputs/outputs (needed by the partitioner)
    # - input mutations (**not** used today, since we don't handle input mutations inside the subclass,
    #   although we should handle this eventually)
    #   TODO: add a test case to assert we error when this happens, instead of getting silent correctness issues
    num_intermediate_bases = None
    keep_input_mutations = meta.keep_input_mutations
    traced_tangents = None
    subclass_inp_meta = None
    subclass_fw_graph_out_meta = None
    subclass_tangent_meta = None

    metadata = ViewAndMutationMeta(
        input_info=input_info,  # type: ignore[arg-type]
        output_info=output_info,  # type: ignore[arg-type]
        num_intermediate_bases=num_intermediate_bases,  # type: ignore[arg-type]
        keep_input_mutations=keep_input_mutations,  # type: ignore[arg-type]
        traced_tangents=traced_tangents,  # type: ignore[arg-type]
        subclass_inp_meta=subclass_inp_meta,  # type: ignore[arg-type]
        subclass_fw_graph_out_meta=subclass_fw_graph_out_meta,  # type: ignore[arg-type]
        subclass_tangent_meta=subclass_tangent_meta,  # type: ignore[arg-type]
    )
    return metadata


def compute_inner_mutated_inp_indices_from_subclass_meta(
    fw_metadata: ViewAndMutationMeta,
    inner_metadata: ViewAndMutationMeta,
) -> List[int]:
    # Note: [Recomputing subclass mutation handling]
    #
    # Generally, if a subclass requires grad, its components will not require grad.
    # But for the purposes of tracking returned tensors, we should treat those component
    # tensors as if they require grad.
    #
    # For example, if the subclass tensor requires grad and will be mutated in a way that
    # requires us to handle the mutation outside of the graph, we need to return it
    # from the forward graph. The inner_metadata won't consider the component tensors
    # as if they need to be returned, because they don't require grad; but really, we
    # should handle those tensors the same way we handle the subclass tensor itself; i.e.
    # if we'd include the subclass tensor as part of the outputs, then we should also
    # include the component tensors.
    #
    # To do this, we patch num_mutated_inp_runtime_indices below by expanding the inputs
    # from the outer subclass tensors and propagating each outer input's info to all of
    # its inner component tensors.

    updated_input_info = []
    inner_idx = 0
    if not fw_metadata.subclass_inp_meta:
        # Sometimes we don't have subclass info, e.g. synthetic_base codepaths
        return inner_metadata.mutated_inp_runtime_indices
    assert len(fw_metadata.subclass_inp_meta) == len(fw_metadata.input_info)
    for outer_idx, inp_meta in enumerate(fw_metadata.subclass_inp_meta):
        if isinstance(inp_meta, int):
            assert outer_idx < len(fw_metadata.input_info)
            if inner_metadata is not None:
                assert inner_idx < len(inner_metadata.input_info)
                assert (
                    inner_metadata.input_info[inner_idx]
                    == fw_metadata.input_info[outer_idx]
                )
            updated_input_info.append(fw_metadata.input_info[outer_idx])
            inner_idx += 1
        else:
            for _ in range(inp_meta.arg_count):
                updated_input_info.append(fw_metadata.input_info[outer_idx])
                inner_idx += 1
    if inner_metadata is not None:
        assert len(inner_metadata.input_info) == len(updated_input_info)

    return [
        i
        for i, inp in enumerate(updated_input_info)
        if inp.mutation_type == MutationType.MUTATED_OUT_GRAPH
    ]
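
# Illustrative sketch of compute_inner_mutated_inp_indices_from_subclass_meta (not part
# of the original file): if fw_metadata.subclass_inp_meta = [0, SubclassCreationMeta(arg_count=2, ...)]
# and only the second (subclass) input is mutated in a way handled outside the graph
# (MUTATED_OUT_GRAPH), its outer input_info entry is duplicated across its two inner
# dense tensors, so the result would be [1, 2].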
348