xref: /aosp_15_r20/external/pytorch/torch/_functorch/_aot_autograd/input_output_analysis.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""
3This module is one of the analysis modules - it takes as input a function or graph
4and some preexisting properties, and returns some data that is useful for deciding
5how to further proceed with compilation or construct runtime wrappers.
6
7In particular, the following analyses are provided:
81. Refine the view and mutation metadata collected previously - removing duplicate
9   inputs or mapping views to their bases.
102. We also analyze the function signature for export graphs.
11"""
12
13import itertools
14from typing import Any, Dict, List, Optional, Tuple, Union
15
16import torch
17import torch.utils._pytree as pytree
18from torch import Tensor
19from torch._subclasses.functional_tensor import FunctionalTensor
20from torch.fx.experimental.symbolic_shapes import is_concrete_int
21
22from .. import config
23from .collect_metadata_analysis import coerce_tangent
24from .schemas import (
25    BackwardSignature,
26    GraphSignature,
27    InputAliasInfo,
28    OutputAliasInfo,
29    OutputType,
30    ViewAndMutationMeta,
31)
32from .utils import strict_zip
33
34
# Rebind the module-level name `zip` to `strict_zip` (imported from `.utils` above), so
# every `zip(...)` call later in this module resolves to `strict_zip`, not the builtin.
# NOTE(review): presumably `strict_zip` rejects iterables of unequal length — confirm in `.utils`.
zip = strict_zip
36
37
def remove_dupe_metadata(
    m: ViewAndMutationMeta,
    keep_arg_mask: List[bool],
    add_dupe_map: List[int],
) -> ViewAndMutationMeta:
    """Build the metadata that corresponds to ``m`` once duplicate inputs are dropped.

    Args:
        m: Metadata describing the original argument list (duplicates included).
        keep_arg_mask: One flag per entry of ``m.input_info``; a truthy flag means
            that input survives de-duplication.
        add_dupe_map: Indexed by an original input position; the stored value
            replaces ``base_idx`` on outputs that refer to that input.

    Returns:
        A fresh ``ViewAndMutationMeta`` with filtered ``input_info``, remapped
        ``output_info`` and filtered ``traced_tangents``. The three subclass
        metadata lists are always emitted empty.
    """
    assert len(m.input_info) == len(keep_arg_mask)
    # Easy invariant: the first argument should never be a dupe (it will be kept)
    assert len(keep_arg_mask) > 0 and keep_arg_mask[0]

    # The leading `num_data_mutations` traced tangents are treated as belonging to
    # data-mutated inputs; only those can be dropped together with a dupe'd input.
    num_data_mutations = sum(1 for info in m.input_info if info.mutates_data)
    mutated_inp_tangents = m.traced_tangents[:num_data_mutations]
    remaining_tangents = m.traced_tangents[num_data_mutations:]

    kept_inp_tangents = []
    for pos, tangent in enumerate(mutated_inp_tangents):
        # See Note [Tangents must be contiguous]
        if keep_arg_mask[m.mutated_inp_runtime_indices[pos]]:
            kept_inp_tangents.append(tangent)

    kept_input_info = []
    for arg_idx, info in enumerate(m.input_info):
        if keep_arg_mask[arg_idx]:
            kept_input_info.append(info)

    # For outputs that are views of inputs, we store the index of the input that the
    # output was generated from. That index must be remapped to account for removed dupes.
    remapped_output_info = []
    for out in m.output_info:
        remapped_base = None if out.base_idx is None else add_dupe_map[out.base_idx]
        remapped_output_info.append(
            OutputAliasInfo(
                output_type=out.output_type,
                raw_type=out.raw_type,
                dynamic_dims=out.dynamic_dims,
                base_idx=remapped_base,
                requires_grad=out.requires_grad,
                functional_tensor=out.functional_tensor,
            )
        )

    return ViewAndMutationMeta(
        input_info=kept_input_info,
        output_info=remapped_output_info,
        num_intermediate_bases=m.num_intermediate_bases,
        keep_input_mutations=m.keep_input_mutations,
        traced_tangents=kept_inp_tangents + remaining_tangents,
        # We are guaranteed not to get here, since dupes are not supported today with subclass inputs.
        subclass_inp_meta=[],
        subclass_fw_graph_out_meta=[],
        subclass_tangent_meta=[],
        is_train=m.is_train,
    )
83
84
85# Given our ViewAndMutation metadata, this fn constructs a new set of metadata,
86# after adding synthetic base arguments to the function.
87# Most of the work in this fn is slogging through all of the metadata corresponding to inputs,
88# and updating it with our synthetic base calling convention.
89#
90# When config.debug_assert is set, we automatically regenerate the metadata
91# and compare it to this output for sanity.
92#
93# In addition to the updated metadata, also return the list of input indices
94# that will need to be updated in the synthetic base epilogue
95
96
97# Given our ViewAndMutation metadata, this fn constructs a new set of metadata,
98# after adding synthetic base arguments to the function.
99# Most of the work in this fn is slogging through all of the metadata corresponding to inputs,
100# and updating it with our synthetic base calling convention.
101#
102# When config.debug_assert is set, we automatically regenerate the metadata
103# and compare it to this output for sanity.
104#
105# In addition to the updated metadata, also return the list of input indices
106# that will need to be updated in the synthetic base epilogue
def create_synthetic_base_metadata(
    m: ViewAndMutationMeta,
    # Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a
    # synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata)
    synthetic_base_info: List[Union[int, Tuple[int, torch.Tensor]]],
    outer_args: List[Any],
    inner_args: List[Any],
) -> Tuple[ViewAndMutationMeta, List[int]]:
    """Re-express ``m`` in terms of the inner (post-synthetic-base) arguments.

    Args:
        m: Metadata computed against the outer (user-visible) arguments.
        synthetic_base_info: One entry per outer argument. An ``int`` is the inner
            argument index the outer argument maps to directly; a tuple's first
            element is the inner index of the synthetic base it is derived from.
        outer_args: The outer arguments; their ``shape`` and ``requires_grad`` are
            read for outer args that get an extra output entry.
        inner_args: The inner arguments; entries whose new input info both mutates
            data and requires grad are passed through ``coerce_tangent``.

    Returns:
        A pair ``(new_metadata, outer_indices)`` where ``outer_indices`` lists the
        outer argument indices that have a metadata mutation and whose
        ``synthetic_base_info`` entry is not an ``int``. One ``alias_of_input``
        output entry is appended to ``new_metadata.output_info`` for each of them.
    """
    # Maps inner arg indices to the (ascending) outer arg indices that use them.
    # Built in a single pass over synthetic_base_info; every inner index gets an
    # entry, and references to indices outside of inner_args are ignored.
    synthetic_base_to_indices: Dict[int, List[int]] = {
        inner_idx: [] for inner_idx in range(len(inner_args))
    }
    for outer_idx, inner_idx_or_tuple in enumerate(synthetic_base_info):
        if isinstance(inner_idx_or_tuple, int):
            mapped_inner_idx = inner_idx_or_tuple
        elif isinstance(inner_idx_or_tuple, tuple):
            mapped_inner_idx = inner_idx_or_tuple[0]
        else:
            continue
        aliased_outer_indices = synthetic_base_to_indices.get(mapped_inner_idx)
        if aliased_outer_indices is not None:
            aliased_outer_indices.append(outer_idx)

    # given the requires_grad info on mutated inputs,
    # generate the requires_grad info on those same mutated inputs, but after constructing synthetic bases.
    input_infos = []
    for outer_indices in synthetic_base_to_indices.values():
        # leaf-ness should be all-or-nothing for aliased tensor.
        # (aka if "a" and "b" are views, then a.is_leaf == b.is_leaf)
        any_leaf = any(m.input_info[x].is_leaf for x in outer_indices)
        all_leaf = all(m.input_info[x].is_leaf for x in outer_indices)
        assert any_leaf == all_leaf

        mutates_data = (
            True
            if len(outer_indices) > 1
            else m.input_info[outer_indices[0]].mutates_data
        )
        mutates_metadata = (
            False
            if len(outer_indices) > 1
            else m.input_info[outer_indices[0]].mutates_metadata
        )
        mutates_storage_metadata = (
            False
            if len(outer_indices) > 1
            else m.input_info[outer_indices[0]].mutates_storage_metadata
        )
        requires_grad = any(m.input_info[x].requires_grad for x in outer_indices)
        mutations_hidden_from_autograd = all(
            m.input_info[x].mutations_hidden_from_autograd for x in outer_indices
        )
        mutations_under_no_grad_or_inference_mode = all(
            m.input_info[x].mutations_under_no_grad_or_inference_mode
            for x in outer_indices
        )
        mutation_inductor_storage_resize = all(
            m.input_info[x].mutation_inductor_storage_resize for x in outer_indices
        )

        inpt_info = InputAliasInfo(
            # If len(outer_indices) > 1, then this input is a synthetic base.
            # The invariant is that to the rest of aot autograd, synthetic bases only show up if
            # one of their aliases gets a data mutation. And if any of their aliases get metadata
            # mutations, they will be hidden from the rest of aot autograd.
            mutates_data=mutates_data,
            mutates_metadata=mutates_metadata,
            mutations_hidden_from_autograd=mutations_hidden_from_autograd,
            mutates_storage_metadata=mutates_storage_metadata,
            mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode,
            mutation_inductor_storage_resize=mutation_inductor_storage_resize,
            is_leaf=any_leaf,
            requires_grad=requires_grad,
            keep_input_mutations=m.keep_input_mutations,
        )
        input_infos.append(inpt_info)

    # Find any inputs that fulfill the following criteria:
    # (1) They are part of a synthetic base (because they alias another input,
    #      and at least one input experiences a data mutation)
    # (2) They experience a metadata mutation
    outer_aliased_arg_idx_with_metadata_mutations = [
        outer_idx
        for outer_idx, inpt_info in enumerate(m.input_info)
        if inpt_info.mutates_metadata
        and not isinstance(synthetic_base_info[outer_idx], int)
    ]

    # Each such input gets an extra output entry recording it as an alias of its
    # synthetic base (the list of these indices is also returned to the caller).
    input_metadata_output_info = [
        OutputAliasInfo(
            output_type=OutputType.alias_of_input,
            raw_type=FunctionalTensor,
            dynamic_dims={
                i
                for i, s in enumerate(outer_args[outer_idx].shape)
                if not is_concrete_int(s)
            },
            base_idx=synthetic_base_info[outer_idx][0],  # type: ignore[index]
            requires_grad=outer_args[outer_idx].requires_grad,
        )
        for outer_idx in outer_aliased_arg_idx_with_metadata_mutations
    ]

    existing_output_infos = []
    for o in m.output_info:
        # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases
        new_base_idx: Optional[int]
        if o.base_idx is None:
            new_base_idx = None
        else:
            base_entry = synthetic_base_info[o.base_idx]
            new_base_idx = base_entry if isinstance(base_entry, int) else base_entry[0]
        # If base_idx is changed for OutputType.is_input, we need to update the output type to reflect the change
        new_output_type = (
            OutputType.alias_of_input
            if o.output_type == OutputType.is_input and o.base_idx != new_base_idx
            else o.output_type
        )
        existing_output_infos.append(
            OutputAliasInfo(
                output_type=new_output_type,
                raw_type=o.raw_type,
                dynamic_dims=o.dynamic_dims,
                base_idx=new_base_idx,
                requires_grad=o.requires_grad,
                functional_tensor=o.functional_tensor,
            )
        )

    inner_mutated_tangents = [
        # See Note [Tangents must be contiguous]
        coerce_tangent(x)
        for inner_idx, x in enumerate(inner_args)
        if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad
    ]

    output_info = existing_output_infos + input_metadata_output_info
    # Regenerate traced tangents to include mutated inputs including synthetic bases
    traced_tangents = (
        inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents) :]
    )

    return (
        ViewAndMutationMeta(
            input_info=input_infos,
            output_info=output_info,
            num_intermediate_bases=m.num_intermediate_bases,
            keep_input_mutations=m.keep_input_mutations,
            traced_tangents=traced_tangents,
            # We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs.
            subclass_inp_meta=[],
            subclass_fw_graph_out_meta=[],
            subclass_tangent_meta=[],
            is_train=m.is_train,
        ),
        outer_aliased_arg_idx_with_metadata_mutations,
    )
266
267
def _get_last_mem_address(x):
    """Return the largest storage index addressed by ``x`` (in elements).

    Computed as ``storage_offset + sum((size - 1) * stride)`` over every dim, so
    it is only meaningful when no dim has size 0.
    """
    out = x.storage_offset()
    for size, stride in zip(x.size(), x.stride()):
        out += (size - 1) * stride
    return out


# Assumption: x and y are known to share a storage, and we are trying to determine
# if their memory is actually completely disjoint, based on sizes/strides/storage_offset
def _tensors_definitely_do_not_overlap(x, y):
    """Return True only when ``x`` and ``y`` are proven to touch disjoint elements.

    The check is conservative: ``False`` means "could not prove disjointness",
    not necessarily that the tensors overlap. Only sizes, strides and storage
    offsets are inspected.
    """
    if x is y:
        return False
    if x.numel() == 0 or y.numel() == 0:
        return True

    # Make x always on the left
    if x.storage_offset() > y.storage_offset():
        x, y = y, x
    # Short-circuit in the "obvious" overlapping case: both tensors are contiguous
    if x.is_contiguous() and y.is_contiguous():
        if x.storage_offset() + x.numel() > y.storage_offset():
            # definitely overlap
            return False
        else:
            # definitely no overlap
            return True

    # Short-circuit: if last memory address of x is < start of y, then not overlapping.
    x_last = _get_last_mem_address(x)
    if x_last < y.storage_offset():
        return True

    if x.dim() == 2 and y.dim() == 2 and x.stride(1) == 1 and y.stride(1) == 1:
        # This case is needed for the shampoo optimizer.
        # All tensors are 2d (non-contiguous), have the same outer stride, and have an inner stride of 1
        # (so rows are contiguous)
        if x.stride(0) == y.stride(0):
            offset_delta = y.storage_offset() - x.storage_offset()
            if offset_delta < x.size(1):
                # definitely overlaps (row 0 of y overlaps with row 0 of x)
                # Example:
                #   base = torch.arange(32).reshape(4, 8)
                #   x = base.narrow(1, 0, 4)
                #     x: size=(4, 4), stride=(8, 1), offset=0
                #   y = base.narrow(1, 3, 4)
                #     y: size=(4, 4), stride=(8, 1), offset=3
                return False
            x_total_elems_covered = x.stride(0) * (x.size(0) - 1) + x.size(1)
            if x_total_elems_covered <= offset_delta:
                # definitely does not overlap (last element of x is before start of y)
                # Example:
                #   x: size=(4, 4), stride=(8, 1), offset=0 (last element is 27)
                #   y: size=(4, 4), stride=(8, 1), offset=28 (first element is 28)
                return True
            # At this point, we want to check if the 0th row of y
            # overlaps with **some** row of x.
            # We can check this by shifting y backward by the shared stride, repeatedly,
            # until the first row of y starts within one stride of the first row of x.
            # Then we can check if these rows overlap.
            # We can accomplish this by modding our offset by the stride.
            offset_delta_mod = offset_delta % x.stride(0)
            # Example:
            # 0 1 2 3
            # 9 10 11 12
            # 18 19 20 21
            # 27 28 29 30
            #   x: size=(4, 4), stride=(9, 1), offset=0
            #   y: size=(4, 4), stride=(9, 1), offset=19 (this would overlap)
            #   y: size=(4, 4), stride=(9, 1), offset=22 (this would not overlap)
            #   y: size=(4, 4), stride=(9, 1), offset=23 (this would not overlap)
            #   y: size=(4, 4), stride=(9, 1), offset=24 (this would overlap)
            #   y: size=(4, 4), stride=(9, 1), offset=25 (this would overlap)
            # Within one stride-sized period, x's rows occupy [0, x.size(1)) and y's rows
            # occupy [offset_delta_mod, offset_delta_mod + y.size(1)). The tensors are
            # disjoint only if y's interval sits entirely in the gap after x's row:
            # it must start at or after the end of x's row AND end before the next
            # row of x begins.
            if (
                offset_delta_mod >= x.size(1)
                and offset_delta_mod + y.size(1) <= x.stride(0)
            ):
                return True
    return False
344
345
def compute_overlapping_inputs(fwd_inputs, aliased_input_indices):
    """Return the subset of ``aliased_input_indices`` that may actually overlap.

    Args:
        fwd_inputs: The forward inputs; indexed by the values in
            ``aliased_input_indices``.
        aliased_input_indices: Indices into ``fwd_inputs`` of inputs that share
            a storage.

    Returns:
        A set holding every index that is part of at least one pair for which
        ``_tensors_definitely_do_not_overlap`` could not prove disjointness.

    Raises:
        AssertionError: when the dynamic-shape restriction is active (killswitch
            on, or more aliases than the configured limit) and any aliased input
            has a ``torch.SymInt`` in its shape, strides or storage offset.
    """
    max_aliased_inps_w_dyn_shapes = (
        config._max_aliased_inputs_with_dynamic_shapes_enabled
    )
    # If the JK is false / not set, we will fall back to obeying the config above
    # If it is true, we will always error when there are aliased + mutated inps with dynamic shapes
    definitely_error_on_dyn_shapes = (
        torch._utils_internal.justknobs_check(
            "pytorch/dynamo:disable_aliased_inputs_with_mutation_and_dyn_shapes"
        )
        if torch._inductor.config.is_fbcode()
        else False
    )

    def _has_symbolic_layout(inp):
        # True if any size, stride, or the storage offset is a SymInt.
        layout_values = itertools.chain(
            inp.shape, inp.stride(), [inp.storage_offset()]
        )
        return any(isinstance(value, torch.SymInt) for value in layout_values)

    num_aliases = len(aliased_input_indices)
    over_limit = (
        definitely_error_on_dyn_shapes or num_aliases > max_aliased_inps_w_dyn_shapes
    )
    # >= 2 check because num_aliases==1 means no aliasing
    if num_aliases >= 2 and over_limit:
        dynamic_shape_indices = {
            inp_idx
            for inp_idx in aliased_input_indices
            if _has_symbolic_layout(fwd_inputs[inp_idx])
        }
        assert (
            len(dynamic_shape_indices) == 0
        ), f"""\
Encountered a graph where:
- {num_aliases} graph inputs all share the same storage (input indices: {str(aliased_input_indices)})
- at least one of these aliased inputs was mutated
- at least one of these inputs is being compiled with dynamic shapes (indices: {str(dynamic_shape_indices)})

Current limit: {str(max_aliased_inps_w_dyn_shapes)}
Killswitch enabled: {str(definitely_error_on_dyn_shapes)}

The most common way to run into this situation is when your model parameters are allocated as one giant buffer
and are all mutated by the optimizer, and some of your parameters end up getting compiled with dynamic shapes.

You can avoid this problem by marking your parameters so they explicitly do not participate in dynamic shapes,
by marking each dim of your parameter static:

torch._dynamo.mark_static(param, 0) # (1, 2, ... for every dimension on the parameter).

If you are running into this issue in a situation where your parameters are static but some other inputs
are aliased and mutated, and they should be dynamic, please file an issue.
"""

    # Compare every input against each input listed before it; both members of a
    # pair that cannot be proven disjoint are recorded.
    actual_aliased_indices = set()
    for position, later_idx in enumerate(aliased_input_indices):
        for earlier_idx in itertools.islice(aliased_input_indices, position):
            if _tensors_definitely_do_not_overlap(
                fwd_inputs[earlier_idx], fwd_inputs[later_idx]
            ):
                continue
            actual_aliased_indices.add(earlier_idx)
            actual_aliased_indices.add(later_idx)
    return actual_aliased_indices
405
406
def _graph_input_names(gm):
    """Return the ``name`` of each node yielded by ``gm.graph.find_nodes(op="placeholder")``."""
    names = []
    for placeholder in gm.graph.find_nodes(op="placeholder"):
        names.append(placeholder.name)
    return names
409
410
def _graph_output_names(gm):
    """Return one entry per value returned by the last node of ``gm.graph``.

    The last node must have ``op == "output"`` and exactly one positional arg
    (the collection of returned values). Each entry is the value's ``name``
    attribute, or ``None`` when the value has no such attribute.
    """
    last_node = next(iter(reversed(gm.graph.nodes)))
    assert last_node.op == "output" and len(last_node.args) == 1
    names = []
    for returned_value in last_node.args[0]:
        names.append(getattr(returned_value, "name", None))
    return names
416
417
def create_graph_signature(
    fx_g: torch.fx.GraphModule,
    fw_metadata: ViewAndMutationMeta,
    in_spec: pytree.TreeSpec,
    out_spec: pytree.TreeSpec,
    *,
    user_args_flat: List[Tensor],
    params_and_buffers_flat: List[Tensor],
    param_names: List[str],
    buffer_names: List[str],
    trace_joint: bool,
    num_user_fw_outs: Optional[int],
    loss_index: Optional[int],
) -> GraphSignature:
    """Assemble a ``GraphSignature`` for ``fx_g`` from its node names and ``fw_metadata``.

    When ``trace_joint`` is true, ``num_user_fw_outs`` must not be ``None``. The
    graph outputs after the first ``num_user_fw_outs +
    fw_metadata.num_mutated_inp_runtime_indices`` are treated as gradients and
    are paired, in order, first with the entries of ``params_and_buffers_flat``
    that require grad and then with the entries of ``user_args_flat`` that
    require grad. When ``trace_joint`` is false no backward signature is built
    and ``num_user_fw_outs`` is derived from the graph output count.
    """
    input_names = _graph_input_names(fx_g)
    output_names = _graph_output_names(fx_g)

    num_params_buffers = len(param_names) + len(buffer_names)
    num_tokens = len(fw_metadata.tokens)
    # We have enough restrictions on the graph (no de-duping, synthetic bases, etc),
    # Such that # graph inps = # user inps + # params + # buffers
    num_user_args = len(input_names) - num_params_buffers - num_tokens

    backward_signature = None
    if trace_joint:
        assert num_user_fw_outs is not None
        num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inp_runtime_indices
        backward_output_names = output_names[num_fw_outs:]

        # Position of the next unclaimed gradient output; shared by both loops so
        # user-input gradients continue where parameter gradients stopped.
        next_grad = 0

        gradients_to_parameters = {}
        for param_pos, param in enumerate(params_and_buffers_flat):
            if not param.requires_grad:
                continue
            grad_name = backward_output_names[next_grad]
            next_grad += 1
            gradients_to_parameters[grad_name] = param_names[param_pos]

        gradients_to_user_inputs = {}
        for arg_pos, user_input in enumerate(user_args_flat):
            if not user_input.requires_grad:
                continue
            grad_name = backward_output_names[next_grad]
            next_grad += 1
            gradients_to_user_inputs[grad_name] = input_names[
                arg_pos + len(params_and_buffers_flat)
            ]

        # Check that we have fully accounted for all backward graph outputs
        assert len(gradients_to_parameters) + len(gradients_to_user_inputs) == len(
            backward_output_names
        )

        backward_signature = BackwardSignature(
            gradients_to_parameters,
            gradients_to_user_inputs,
            output_names[loss_index],
        )
    else:
        num_user_fw_outs = (
            len(output_names)
            - fw_metadata.num_mutated_inp_runtime_indices
            - num_tokens
        )

    return GraphSignature.from_tracing_metadata(
        in_spec=in_spec,
        out_spec=out_spec,
        graph_input_names=input_names,
        graph_output_names=output_names,
        view_mutation_metadata=fw_metadata,
        named_parameters=param_names,
        named_buffers=buffer_names,
        num_user_inputs=num_user_args,
        num_user_outputs=num_user_fw_outs,
        loss_index=loss_index,
        backward_signature=backward_signature,
    )
494