# mypy: ignore-errors

import functools
import warnings
from typing import Callable, Union

import torch
import torch.utils._pytree as pytree
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    tree_flatten_only,
    UnsupportedFakeTensorException,
)
from torch.utils._python_dispatch import TorchDispatchMode


aten = torch._ops.ops.aten


def outputs_alias_inputs(outputs, inputs):
    input_storages = {
        inp._typed_storage()._cdata
        for inp in tree_flatten_only(torch.Tensor, inputs)
        if torch._C._has_storage(inp)
    }
    return any(
        torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
        for out in tree_flatten_only(torch.Tensor, outputs)
    )
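

# Illustrative sketch: outputs_alias_inputs is True when any output tensor shares a
# typed storage with any input tensor, i.e. the op returned a view of (or wrote into)
# one of its inputs.  _example_outputs_alias_inputs is a hypothetical helper that only
# demonstrates the expected behavior; nothing in this module calls it.
def _example_outputs_alias_inputs():
    base = torch.randn(4)
    view = base[:2]  # a slice is a view and shares `base`'s storage
    assert outputs_alias_inputs((view,), (base,))
    assert not outputs_alias_inputs((base.clone(),), (base,))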


def outputs_are_inputs(outputs, inputs):
    input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
    return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))
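

# Illustrative sketch: outputs_are_inputs checks for object identity, i.e. the op
# returned one of its input tensors itself (as in-place ops do), not merely a view of
# it.  _example_outputs_are_inputs is a hypothetical helper that only demonstrates the
# expected behavior; nothing in this module calls it.
def _example_outputs_are_inputs():
    t = torch.randn(2)
    assert outputs_are_inputs((t,), (t,))  # the very same Tensor object is returned
    assert not outputs_are_inputs((t.view(2),), (t,))  # a view is a distinct object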


def output_alias_each_other(outputs):
    storages = set()
    for out in tree_flatten_only(torch.Tensor, outputs):
        if not torch._C._has_storage(out):
            continue
        stor = out._typed_storage()._cdata
        if stor in storages:
            return True
        storages.add(stor)
    return False
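

# Illustrative sketch: output_alias_each_other is True when two or more outputs share a
# storage, e.g. an op that returns several views of the same buffer.
# _example_output_alias_each_other is a hypothetical helper that only demonstrates the
# expected behavior; nothing in this module calls it.
def _example_output_alias_each_other():
    base = torch.randn(4)
    assert output_alias_each_other((base[:2], base[2:]))  # both halves view `base`
    assert not output_alias_each_other((base, torch.randn(4)))  # independent storages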


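# The SDPA ops below return RNG state tensors (philox seed/offset) at fixed output
# positions: indices 6/7 for the flash and cuDNN variants, indices 2/3 for the
# efficient variants.  Under FakeTensorMode these auxiliary outputs may be reported on
# a different device than in the real run, so a "Devices" metadata mismatch at exactly
# those indices is treated as a known, benign discrepancy and skipped by the caller.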
def is_sdpa_error(func, idx, e):
    if (
        (
            func is aten._scaled_dot_product_flash_attention.default
            or func is aten._flash_attention_forward.default
        )
        and idx in (6, 7)
        and "Devices" in repr(e)
    ):
        return True
    if (
        (
            func is aten._scaled_dot_product_efficient_attention.default
            or func is aten._efficient_attention_forward.default
        )
        and idx in (2, 3)
        and "Devices" in repr(e)
    ):
        return True
    if (
        func is aten._scaled_dot_product_cudnn_attention.default
        and idx in (6, 7)
        and "Devices" in repr(e)
    ):
        return True
    return False


class CrossRefFakeMode(TorchDispatchMode):
    def __init__(
        self,
        ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
        *,
        check_strides=True,
        check_aliasing=True,
    ):
        super().__init__()
        self.ignore_op_fn = (
            ignore_op_fn if ignore_op_fn is not None else lambda fn: False
        )
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        fake_r = None

        # empty_like is excluded for now due to sparse complex
        # aten._to_dense.default: this one is getting called with csc
        if (
            func
            not in (
                aten.lift_fresh.default,
                aten.lift_fresh_copy.default,
                aten.set_.source_Storage_storage_offset,
            )
            and not self.ignore_op_fn(func)
            and torch.Tag.dynamic_output_shape not in func.tags
            and torch.Tag.inplace_view not in func.tags
            and torch.Tag.data_dependent_output not in func.tags
        ):
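            # Re-run the op on fake tensors so that the metadata of the fake result can
            # be compared against the real execution further below.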
            # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            try:
                # TODO: enable_python_dispatcher() here
                with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
                    fake_args, fake_kwargs = pytree.tree_map_only(
                        torch.Tensor,
                        functools.partial(fake_mode.from_tensor, static_shapes=True),
                        (args, kwargs),
                    )
                    with warnings.catch_warnings():
                        fake_r = func(*fake_args, **fake_kwargs)
            except UnsupportedFakeTensorException:
                pass

        context = (
            f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
            f"found"
        )
        r = func(*args, **kwargs)
        if fake_r is not None:
            r_flat = pytree.tree_leaves(r)
            f_flat = pytree.tree_leaves(fake_r)
            assert len(f_flat) == len(
                r_flat
            ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"

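            # The fake run must agree with the real run on aliasing behavior; a mismatch
            # here usually indicates an incorrect meta/fake implementation for the op.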
            if self.check_aliasing:
                r_aliasing = outputs_alias_inputs(r, (args, kwargs))
                f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_aliasing == f_aliasing
                ), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"

                r_identity_eq = outputs_are_inputs(r, (args, kwargs))
                f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_identity_eq == f_identity_eq
                ), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"

                r_output_alias_each_other = output_alias_each_other(r)
                f_output_alias_each_other = output_alias_each_other(fake_r)
                assert r_output_alias_each_other == f_output_alias_each_other, (
                    f"{context} mismatch in outputs_alias_each_other check "
                    f"{f_output_alias_each_other} != {r_output_alias_each_other}"
                )

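            # Compare the real and fake outputs pairwise: tensor-ness, requires_grad,
            # storage offset, and the remaining tensor metadata (dtype, shape, device
            # and, optionally, strides).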
            for idx, (r_out, fake_out) in enumerate(
                zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
            ):
                r_is_ten = isinstance(r_out, torch.Tensor)
                assert r_is_ten == isinstance(
                    fake_out, torch.Tensor
                ), f"{context} mismatched number of tensor outputs"
                if r_is_ten:
                    assert r_out.requires_grad == fake_out.requires_grad, (
                        f"{context} mismatched requires_grad-ness of outputs. "
                        f"This usually means that you have added autograd support "
                        f"for your operator at a dispatch key other than Autograd, "
                        f"which will lead to problems"
                    )
                    if torch._C._has_storage(r_out):
                        r_offset = r_out.storage_offset()
                        f_offset = fake_out.storage_offset()
                        assert (
                            r_offset == f_offset
                        ), f"{context} mismatched storage offset"

                    try:
                        torch._prims.utils.compare_tensor_meta(
                            r_out,
                            fake_out,
                            check_strides=self.check_strides,
                            allow_rhs_unbacked=True,
                        )
                    except Exception as e:
                        if is_sdpa_error(func, idx, e):
                            continue
                        error_message = (
                            f"{context} mismatched tensor metadata: {e}"
                            if len(r_flat) == 1
                            else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
                        )
                        raise RuntimeError(error_message) from e
        return r

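
# Illustrative usage sketch: CrossRefFakeMode intercepts every dispatched op, re-runs it
# under a fresh FakeTensorMode, and asserts that the fake result matches the real one in
# output count, aliasing behavior, and tensor metadata.  _example_cross_ref_fake_mode is
# a hypothetical helper that only demonstrates how the mode is entered; nothing in this
# module calls it.
def _example_cross_ref_fake_mode():
    with CrossRefFakeMode(check_strides=False):
        x = torch.randn(3, 4)
        y = torch.randn(4, 5)
        # Each aten op issued here (randn, mm, ...) is cross-checked against its fake
        # implementation before the real result is returned.
        return x @ y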