# xref: /aosp_15_r20/external/pytorch/torch/_functorch/_aot_autograd/functional_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
"""
This file contains utilities related to functionalization in AOTAutograd:
1. converting to/from functional tensors
2. detecting Tensor mutations - both metadata and Tensor value
3. regenerating/replaying views from their base
4. checking if a graph is functional i.e. whether it contains any mutation ops
"""
from __future__ import annotations

from typing import Optional

import torch
from torch import Tensor
from torch._logging import getArtifactLogger
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import definitely_true, sym_eq
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
    is_traceable_wrapper_subclass,
    transform_subclass,
)


aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")


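# Wrap a tensor in a FunctionalTensor. For traceable wrapper subclasses we
# recurse into each inner tensor; non-tensor values are returned unchanged.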
def to_fun(t):
    if isinstance(t, Tensor):
        if is_traceable_wrapper_subclass(t):
            # See Note [Functionalization always runs last]
            # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
            # goes at the bottom.
            # recurse here, so we can support nested wrapper subclasses
            out = transform_subclass(t, lambda _, inner_t: to_fun(inner_t))
            torch._mirror_autograd_meta_to(t, out)  # type: ignore[attr-defined]
            return out
        else:
            return FunctionalTensor.to_functional(t)
    else:
        return t


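# Propagate any pending functionalization updates on t down to its underlying
# tensor(s), recursing into traceable wrapper subclasses.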
def sync_functional_tensor(t):
    if is_traceable_wrapper_subclass(t):
        attrs, ctx = t.__tensor_flatten__()  # type: ignore[attr-defined]
        for attr in attrs:
            sync_functional_tensor(getattr(t, attr))
    else:
        torch._sync(t)


# When subclasses are involved, t here will usually look something like:
# SubclassA(SubclassB(FunctionalTensor(_to_functional_tensor(FakeTensor))))
def from_fun(t):
    if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
        # See Note [Functionalization always runs last]
        # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
        # goes at the bottom.
        # recurse here, so we can support nested wrapper subclasses
        out = transform_subclass(t, lambda _, inner_t: from_fun(inner_t))
        torch._mirror_autograd_meta_to(t, out)  # type: ignore[attr-defined]
        return out

    if not isinstance(t, FunctionalTensor):
        # quick sanity assert
        if isinstance(t, torch.Tensor):
            assert not torch._is_functional_tensor(t)  # type: ignore[attr-defined]
        return t
    sync_functional_tensor(t)
    return torch._from_functional_tensor(t.elem)


def is_fun(t):
    if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
        # See Note [Functionalization always runs last]
        # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
        # goes at the bottom.
        # recurse here, so we can support nested wrapper subclasses
        t_attrs, _ = t.__tensor_flatten__()  # type: ignore[attr-defined]
        t_inners = [getattr(t, attr) for attr in t_attrs]
        any_fun = any(is_fun(x) for x in t_inners)
        all_fun = all(is_fun(x) for x in t_inners)
        assert any_fun == all_fun
        return any_fun

    return isinstance(t, FunctionalTensor)


# t here is either
# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor))
# (2) A traceable tensor subclass that holds a FunctionalTensor
# (3) Not a tensor
def has_data_mutation(t):
    if is_traceable_wrapper_subclass(t):
        attrs, _ = t.__tensor_flatten__()
        # A tensor subclass was updated if any of its inner elements were updated
        return any(has_data_mutation(getattr(t, attr)) for attr in attrs)
    else:
        if isinstance(t, torch.Tensor):
            assert isinstance(t, FunctionalTensor)
            return torch._functionalize_has_data_mutation(t.elem)  # type: ignore[attr-defined]
        return False


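# t here follows the same convention as has_data_mutation above. Returns True
# only if every mutation recorded on t (or on all of its inner tensors, for
# subclasses) was hidden from autograd.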
def are_all_mutations_hidden_from_autograd(t):
    if is_traceable_wrapper_subclass(t):
        attrs, _ = t.__tensor_flatten__()
        # If all inner elements are mutations hidden from autograd, then it is a mutation hidden from autograd.
        return all(
            are_all_mutations_hidden_from_autograd(getattr(t, attr)) for attr in attrs
        )
    elif isinstance(t, torch.Tensor):
        assert isinstance(t, FunctionalTensor)
        return torch._functionalize_are_all_mutations_hidden_from_autograd(t.elem)
    else:
        return False


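# Same convention as above: returns True only if every mutation recorded on t
# (or on all of its inner tensors) happened under no_grad or inference_mode.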
def are_all_mutations_under_no_grad_or_inference_mode(t):
    if is_traceable_wrapper_subclass(t):
        attrs, _ = t.__tensor_flatten__()
        return all(
            are_all_mutations_under_no_grad_or_inference_mode(getattr(t, attr))
            for attr in attrs
        )
    else:
        assert isinstance(t, FunctionalTensor)
        return torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode(
            t.elem
        )


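# Returns True if functionalization recorded a storage resize on t (e.g. from an
# inductor storage-resize op). Storage resizing on tensor subclasses is not
# supported and raises.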
def was_inductor_storage_resized(t):
    if is_traceable_wrapper_subclass(t):
        attrs, _ = t.__tensor_flatten__()
        if any(was_inductor_storage_resized(getattr(t, attr)) for attr in attrs):
            raise RuntimeError(
                f"storage resizing is not supported on tensor subclass: {type(t)}"
            )
        return False
    elif not isinstance(t, torch.Tensor):
        return False
    else:
        assert isinstance(t, FunctionalTensor)
        return torch._functionalize_was_inductor_storage_resized(t.elem)


# f_arg here is either
# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor))
# (2) A traceable tensor subclass that holds a FunctionalTensor
# (3) Not a tensor
# Assumption: arg promises to be the "original" tensor wrapped by f_arg
# Note: "storage mutations" coming from set_() are a type of metadata mutation. So:
# - check_only_storage_mutation=True: only return true if there was a storage mutation
# - check_only_storage_mutation=False: return true if there was any metadata mutation (including a storage mutation)
def has_metadata_mutation(f_arg, arg, *, check_only_storage_mutation: bool):
    if is_traceable_wrapper_subclass(f_arg):
        attrs, _ = f_arg.__tensor_flatten__()
        # A tensor subclass was updated if any of its inner elements were updated
        f_inner_ts = [getattr(f_arg, attr) for attr in attrs]
        inner_ts = [getattr(arg, attr) for attr in attrs]
        return any(
            has_metadata_mutation(
                f_inner_t,
                inner_t,
                check_only_storage_mutation=check_only_storage_mutation,
            )
            for f_inner_t, inner_t in zip(f_inner_ts, inner_ts)
        )
    else:
        if not isinstance(f_arg, torch.Tensor):
            assert not isinstance(arg, torch.Tensor)
            return False
        assert isinstance(f_arg, FunctionalTensor)
        assert isinstance(arg, FakeTensor)

        arg_after = torch._from_functional_tensor(f_arg.elem)
        # This is true if the current tensor experienced at least one set_() call
        maybe_storage_changed = torch._functionalize_was_storage_changed(f_arg.elem)  # type: ignore[attr-defined]
        # However, multiple set_() calls can cancel out. So we also check whether the
        # storage of the tensor has changed.
        # Note: if an input experienced two set_() calls that cancel out, **and**
        # it experienced a data mutation, we pessimistically assume that the set_()
        # call is necessary here. We could in theory fix this, but this will
        # hopefully never happen in user code, and is not needed for FSDP.
        if is_sparse_any(arg):
            # TODO: add sparse tensor support to functionalization
            same_storages = False
        else:
            same_storages = StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(
                arg_after.untyped_storage()
            )
        has_storage_metadata_mutation = maybe_storage_changed and not same_storages
        if check_only_storage_mutation:
            return has_storage_metadata_mutation

        # storage metadata mutation is a type of metadata mutation, so return true if we saw one
        if has_storage_metadata_mutation:
            return True

        maybe_metadata_mutated = torch._functionalize_has_metadata_mutation(f_arg.elem)  # type: ignore[attr-defined]
        # This is true if the current tensor experienced at least one metadata mutation.
        # So if false, we know there was no metadata mutation
        if not maybe_metadata_mutated:
            return False

        # However, multiple metadata mutations can cancel out.
        # So we also check if the concrete sizes/strides on the tensor have changed.
        same_sizes = arg.shape == arg_after.shape
        same_strides = arg.stride() == arg_after.stride()
        same_offsets = arg.storage_offset() == arg_after.storage_offset()
        has_metadata_mutation_ = maybe_metadata_mutated and not (
            same_sizes and same_strides and same_offsets
        )
        # We only consider the tensor to have been metadata-mutated if the mutation
        # is still visible, i.e. its sizes, strides, or storage offset actually changed.
        return has_metadata_mutation_


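# Regenerate an output that aliases its base tensor. We prefer replaying the
# original view ops (either the ViewMeta sequence recorded on the functional
# tensor, or autograd's _view_func), and fall back to a raw as_strided() call
# when neither is available or applicable.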
def gen_alias_from_base(
    aliased_base_tensor,
    target_meta_tensor,
    target_requires_grad,
    target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None,
    *,
    replay_views,
):
    # Patch the correct requires_grad field of the output tensor, depending on whether:
    # (i) the reconstructed output (out) came from a tensor that requires grad; and
    # (ii) the concrete returned output requires grad.
    def patch_requires_grad(out):
        if aliased_base_tensor.requires_grad and not target_requires_grad:
            out = out.detach()
        elif not aliased_base_tensor.requires_grad and target_requires_grad:
            out.requires_grad_(True)
        return out

    # If provided, use the target functional tensor for replaying the views.
    #
    # In summary, we use the fact that FunctionalTensorWrapper saves the view
    # functions applied to itself (collected during functionalization) so as
    # to replay them (view functions) on the aliased_base_tensor.
    if (
        replay_views
        and target_functional_tensor is not None
        and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
    ):
        functional_tensor = target_functional_tensor.tensor

        out = torch._functionalize_apply_view_metas(
            functional_tensor, aliased_base_tensor
        )
        # If re-applying the ViewMeta sequence succeeded, there should be no more
        # problems going forward. We just check we got to the target shape and
        # patch requires_grad flag.
        assert out.shape == target_meta_tensor.shape, (
            "incorrect out shape after application of ViewMeta sequence: "
            f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)"
        )
        return patch_requires_grad(out)

    # Try to do view-replay if possible.
    # Fall back to .as_strided() if we can't.
    if target_meta_tensor._base is not None:
        # The base that we want to replay our view off of might have a different shape than the view's original base.
        b = target_meta_tensor._base
        abt = aliased_base_tensor
        # Don't unnecessarily call as_strided if nothing changed; as_strided's
        # backward is poorly implemented and slow
        if abt is not b and (
            abt.size() != b.size()
            or abt.stride() != b.stride()
            or abt.storage_offset() != b.storage_offset()
        ):
            reshaped_base_tensor = aliased_base_tensor.as_strided(
                b.size(), b.stride(), b.storage_offset()
            )
        else:
            reshaped_base_tensor = aliased_base_tensor
        out = target_meta_tensor._view_func(reshaped_base_tensor)
        # This shape mismatch can happen due to a bug in inplace/view handling in autograd.
        # Try putting a breakpoint here and running
        # `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types`
        # Also, https://github.com/pytorch/pytorch/issues/49825
        #
        # As a stopgap, we'll fall back to as_strided.
        if out is not None and out.shape == target_meta_tensor.shape:
            return patch_requires_grad(out)

    size = target_meta_tensor.size()
    stride = target_meta_tensor.stride()
    storage_offset = target_meta_tensor.storage_offset()
    if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex():
        aliased_out = torch.view_as_real(aliased_base_tensor).as_strided(
            size, stride, storage_offset
        )
    elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex():
        aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided(
            size, stride, storage_offset
        )
    else:
        aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset)
    # For outputs aliasing inputs, we need to check if the requires-gradness has changed.
    aliased_out = patch_requires_grad(aliased_out)
    # For outputs aliasing inputs, we need to check if the dtype has changed.
    # as_strided() is the "most generic" view, but it does not cover cross-dtype views
    if aliased_out.dtype != target_meta_tensor.dtype:
        aliased_out = aliased_out.view(target_meta_tensor.dtype)
    return aliased_out


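# Metadata equality check used when comparing aliases: sizes, layout, conjugate
# and negative bits must match; strides and storage offsets are only compared
# for non-sparse tensors.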
def has_same_metadata(t1, t2):
    return (
        definitely_true(sym_eq(t1.size(), t2.size()))
        and definitely_true(t1.layout == t2.layout)
        and (
            is_sparse_any(t1)
            or (
                definitely_true(sym_eq(t1.stride(), t2.stride()))
                and definitely_true(t1.storage_offset() == t2.storage_offset())
            )
        )
        and t1.is_conj() == t2.is_conj()
        and t1.is_neg() == t2.is_neg()
    )


# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata
# after applying all the ViewMeta operations.
class FunctionalTensorMetadataEq:
    def __init__(self, tensor: torch.Tensor) -> None:
        assert torch._is_functional_tensor(tensor)
        self.tensor = tensor

    def __eq__(self, other: object) -> bool:
        # If other is None, then it probably means that we weren't able to recreate
        # the FunctionalTensorMetadataEq. One of these cases is when we update the
        # view metadata by calling: create_synthetic_base_metadata.
        if other is None:
            return True

        # Comparison against any other type is not implemented.
        if not isinstance(other, FunctionalTensorMetadataEq):
            return NotImplemented

        return has_same_metadata(self.tensor, other.tensor)


# new_arg and arg here are either:
# (1) both a FakeTensor
# (2) both a traceable tensor subclass that holds a FakeTensor
# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
# When we run functionalization and wrap our inputs into FunctionalTensors,
# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed
#
# Normally it would be enough just to check whether `arg is new_arg`, which confirms
# that inputs were not mutated when running the user's model with functionalization on.
# But when we have subclass inputs, we can't rely on that:
# `from_fun(to_fun(x)) is x` will return False, because the call to `from_fun` constructs
# a brand new subclass instance: we are calling __tensor_unflatten__, and going
# from Subclass(FakeTensor) to Subclass(FunctionalTensor(FakeTensor))
def was_tensor_updated(arg, new_arg):
    if is_traceable_wrapper_subclass(arg):
        assert is_traceable_wrapper_subclass(new_arg)
        attrs, _ = arg.__tensor_flatten__()
        new_attrs, _ = new_arg.__tensor_flatten__()
        assert attrs == new_attrs
        # A tensor subclass was updated if any of its inner elements were updated
        return any(
            was_tensor_updated(getattr(arg, attr), getattr(new_arg, attr))
            for attr in attrs
        )
    else:
        return arg is not new_arg


# new_arg and arg here are either:
# (1) both a FakeTensor
# (2) both a traceable tensor subclass that holds a FakeTensor
# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
# When we run functionalization and wrap our inputs into FunctionalTensors,
# we can detect whether an input's metadata was mutated by checking to see if the
# inner tensor has changed, but still shares storage with the old input
def was_tensor_metadata_updated(arg, new_arg):
    if is_traceable_wrapper_subclass(arg):
        assert is_traceable_wrapper_subclass(new_arg)
        attrs, _ = arg.__tensor_flatten__()
        new_attrs, _ = new_arg.__tensor_flatten__()
        assert attrs == new_attrs
        # A tensor subclass was updated if any of its inner elements were updated
        return any(
            was_tensor_metadata_updated(getattr(arg, attr), getattr(new_arg, attr))
            for attr in attrs
        )
    else:
        return arg is not new_arg and StorageWeakRef(
            arg.untyped_storage()
        ) == StorageWeakRef(new_arg.untyped_storage())


# Checks that the graph is functional except for a small set of allowed input
# mutation ops, and returns the number of detected mutations (copy_ / set_ calls).
def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
    allowed_mutation_ops = [
        torch.ops.aten.copy_.default,
        torch.ops.aten.set_.source_Tensor,
    ]
    if hasattr(torch.ops.fsdp, "set_"):
        allowed_mutation_ops.append(torch.ops.fsdp.set_.default)

    placeholders = set()
    mutation_count = 0
    # NB: It would also be nice to verify that the mutations all happen at the
    # end, but we also do some administrative views after mutations so this
    # isn't actually true.  (TODO: Could this cause problems for Inductor?)
    for n in fx_g.nodes:
        if n.op == "placeholder":
            placeholders.add(n)
        if isinstance(n.target, torch._ops.OpOverload):
            if n.target in allowed_mutation_ops:
                # Can only copy_/set_ into an input.
                # The exemption below is mostly a hack to avoid failing XLA tests.
                # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113
                if "set_buffer_donor_" not in str(n.args[0]):
                    assert (
                        n.args[0] in placeholders
                    ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
                mutation_count += 1
            else:
                assert (
                    not n.target._schema.is_mutable
                ), f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
    return mutation_count


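# The copy_ nodes that functionalization inserts for input mutations have no
# stack_trace metadata of their own, so copy it over from the node that produced
# the value being copied into the input.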
def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None:
    placeholders = set()
    for n in fx_g.nodes:
        if n.op == "placeholder":
            placeholders.add(n)
        if isinstance(n.target, torch._ops.OpOverload):
            if n.target is torch.ops.aten.copy_.default:
                # Can only copy_ into an input, and can only do so once
                if "set_buffer_donor_" not in str(n.args[0]):
                    assert (
                        n.args[0] in placeholders
                    ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
                    placeholders.remove(n.args[0])
                copy_from_node = n.args[1]
                # Pre-condition: every node has a "stack_trace" field in its meta,
                # but copy_() nodes do not (since we manually added them during functionalization).
                # Instead, we manually propagate here.
                if "stack_trace" in copy_from_node.meta:
                    n.meta["stack_trace"] = copy_from_node.meta["stack_trace"]


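# Given the mutation properties recorded for a single graph input, decide whether
# its mutation can be kept directly in the graph (as a copy_/set_/resize_ at the
# end of the graph) rather than being handled outside the graph.
# See Note [set_() Input Mutations in AOTAutograd].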
def _check_if_mutation_can_be_in_graph(
    keep_input_mutations: bool,
    mutates_data,
    mutates_metadata,
    mutations_hidden_from_autograd,
    mutations_under_no_grad_or_inference_mode,
    mutates_storage_metadata,
    mutation_inductor_storage_resize,
    requires_grad,
):
    if keep_input_mutations:
        in_graph = (
            mutates_data or mutates_storage_metadata or mutation_inductor_storage_resize
        ) and (
            (not mutates_metadata and not requires_grad)
            or mutations_hidden_from_autograd
            or mutations_under_no_grad_or_inference_mode
        )
    else:
        in_graph = False
    # See Note [set_() Input Mutations in AOTAutograd]
    # If there was a `set_()`, we require that all mutations were under no_grad,
    # so we can (safely) emit the set_() in the graph at runtime
    # resize_() gets the same treatment
    if mutation_inductor_storage_resize or mutates_storage_metadata:
        op_name = "resize_" if mutation_inductor_storage_resize else "set_"
        assert in_graph, f"""\
Encountered a {op_name} on a graph input, but the input has other mutations that we cannot
keep in the graph. This is not supported today. Current state:
  keep_input_mutations={keep_input_mutations}
  mutates_data={mutates_data}
  mutates_metadata={mutates_metadata}
  mutations_hidden_from_autograd={mutations_hidden_from_autograd}
  mutations_under_no_grad_or_inference_mode={mutations_under_no_grad_or_inference_mode}
  mutation_inductor_storage_resize={mutation_inductor_storage_resize}
  requires_grad={requires_grad}"""
    return in_graph