xref: /aosp_15_r20/external/pytorch/torch/_subclasses/schema_check_mode.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3from collections import namedtuple
4from copy import deepcopy
5from itertools import combinations
6
7import torch
8from torch.fx.operator_schemas import normalize_function
9from torch.utils import _pytree as pytree
10from torch.utils._python_dispatch import TorchDispatchMode
11from torch.utils._pytree import tree_map
12
13
# Record types collected by SchemaCheckMode.
# Mutation: the op with schema name ``op_name`` mutated its argument ``arg_name``.
Mutation = namedtuple("Mutation", "op_name arg_name")
# Aliasing: argument ``arg_name`` of op ``op_name`` overlapped an output;
# ``output_number`` is a label of the form "output_<index>".
Aliasing = namedtuple("Aliasing", "op_name arg_name output_number")

# Short module-level names for the C++ schema bindings exposed on torch._C.
SchemaInfo = torch._C._SchemaInfo
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
22
# SchemaCheckMode is a TorchDispatchMode subclass used to verify op schemas.
# It currently:
#  - Records the called ops
#  - Checks for mutations on all inputs
#  - Checks for aliasing on all inputs
28
29
30# move these 2 functions here to avoid numpy dependency in testing/_internal/common_utils.py
31
32
def is_iterable_of_tensors(iterable):
    """Return True if ``iterable`` is a non-empty, sized iterable of tensors.

    A ``torch.Tensor`` is itself iterable, so it is rejected up front.
    Anything that does not support ``len()`` or iteration, anything empty,
    and anything containing a non-tensor element yields False.
    """
    if isinstance(iterable, torch.Tensor):
        return False
    try:
        # len() and iteration both raise TypeError on unsupported objects.
        if len(iterable) == 0:
            return False
        return all(isinstance(item, torch.Tensor) for item in iterable)
    except TypeError:
        return False
46
47
def clone_inputs(args):
    """Return a list mirroring ``args`` with every tensor detached and cloned.

    A tensor becomes ``tensor.detach().clone()``; an iterable made up only
    of tensors becomes a list of such clones; any other value is passed
    through as the same object.
    """

    def _snapshot(value):
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        if is_iterable_of_tensors(value):
            return [element.detach().clone() for element in value]
        return value

    return [_snapshot(value) for value in args]
60
61
class SchemaCheckMode(TorchDispatchMode):
    """Dispatch mode that checks each dispatched op against its declared schema.

    For every op that reaches ``__torch_dispatch__`` the mode:
      - records the op's schema name in ``self.ops``;
      - compares each argument before and after the call and raises if an
        argument the schema does not mark as mutable was mutated;
      - checks each (output, argument) pair and each pair of outputs for
        overlap and raises if the schema does not allow that aliasing.

    Mutations and argument/output aliasing that the schema does allow are
    recorded in ``self.mutated`` and ``self.aliasing``.
    """

    def __init__(self) -> None:
        # Let the base class initialise its own state first.
        super().__init__()
        # Information recorded for testing purposes. For example:
        #  - incorrect schemas
        #  - overly conservative schemas
        self.ops = []
        self.mutated = []
        self.aliasing = []

    def reset_cache(self):
        """Empty the recorded ops, mutations and aliasing lists in place."""
        self.ops.clear()
        self.mutated.clear()
        self.aliasing.clear()

    def display_ops(self):
        """Print the recorded op names on one line, separated by commas."""
        print(*self.ops, sep=",")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and validate the observed mutation/aliasing behaviour.

        Args:
            func: the op being dispatched; its ``_schema`` is what is checked.
            types: unused here.
            args: positional arguments for ``func``.
            kwargs: keyword arguments for ``func``; ``None`` is treated as empty.

        Returns:
            Whatever ``func(*args, **kwargs)`` returns, unmodified.

        Raises:
            RuntimeError: if an argument was mutated, aliases an output, or was
                returned directly without the schema permitting it, or if two
                outputs alias each other without the schema permitting it.
        """
        # ``kwargs`` defaults to None, which cannot be expanded with ``**``.
        if kwargs is None:
            kwargs = {}

        def bitwise_equal(lhs, rhs):
            if lhs.is_quantized:
                # TODO: This is only OK if can't have NaN quantized; idk if
                # this is actually true
                return torch.equal(lhs, rhs)
            # NOTE: despite the name, non-quantized tensors are compared with
            # torch.allclose (NaNs in matching positions count as equal).
            return torch.allclose(lhs, rhs, equal_nan=True)

        def has_mutated(before, after, md):
            # Only exact torch.Tensor instances are compared; ``md`` is the
            # (stride, storage handle) pair captured before the call.
            are_tensors = type(before) is torch.Tensor and type(after) is torch.Tensor
            if (
                are_tensors
                and before.layout != torch.sparse_csr
                and after.layout != torch.sparse_csr
            ):
                return not (
                    before.size() == after.size()
                    and bitwise_equal(before, after)
                    and md[0] == after.stride()
                    and md[1] == after._typed_storage()._cdata
                )
            return False

        def has_aliased(lhs, rhs):
            try:
                return torch._C._overlaps(lhs, rhs)
            except Exception as exception:
                # Values that cannot be inspected are treated as not aliasing.
                if str(exception).startswith("Cannot inspect value of type "):
                    return False
                raise

        def standardize_name(name):
            return name if name != "self" else "input"

        def unwrap(e):
            # Tensor subclasses that expose ``elem`` are replaced by it;
            # everything else is returned unchanged.
            if isinstance(e, torch.Tensor) and type(e) is not torch.Tensor:
                try:
                    return e.elem
                except AttributeError:
                    return e
            return e

        def parse_metadata(e):
            # Returns (copy of stride, storage handle) for a tensor, or None
            # when that metadata is not available.
            if isinstance(e, torch.Tensor):
                if type(e) is not torch.Tensor:
                    try:
                        current = e.elem
                        return (
                            deepcopy(current.stride()),
                            current._typed_storage()._cdata,
                        )
                    except AttributeError:
                        return None
                # Sparse CSR tensors do not have strides or storage
                elif e.layout != torch.sparse_csr:
                    return (deepcopy(e.stride()), e._typed_storage()._cdata)
            return None

        schema = func._schema
        self.ops.append(schema.name)

        # Clone and process arguments and outputs
        pre_arguments = normalize_function(
            func, args, kwargs, normalize_to_only_use_kwargs=True
        ).kwargs

        # Snapshot argument values and metadata before the op runs.
        c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
        cloned_arguments = {
            name: tree_map(unwrap, value) for name, value in c_p_args.items()
        }
        cloned_metadata = {
            name: [parse_metadata(leaf) for leaf in pytree.tree_leaves(value)]
            for name, value in pre_arguments.items()
        }

        out = func(*args, **kwargs)
        arguments = {
            name: tree_map(unwrap, value) for name, value in pre_arguments.items()
        }
        tuple_out = out if isinstance(out, tuple) else (out,)
        tuple_out = tree_map(unwrap, tuple_out)

        schema_info = SchemaInfo(schema)
        schema_info.add_argument_values(pre_arguments)

        # These ops are intended to have incorrect aliasing notation (hence
        # unsafe), so aliasing between their outputs and inputs is not checked.
        unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")

        # Process arguments with outputs
        for i, arg in enumerate(schema.arguments):
            name = standardize_name(arg.name)
            after = arguments.get(name)
            if after is None:
                continue
            before = cloned_arguments.get(name)
            md = cloned_metadata.get(name)
            for j, output in enumerate(tuple_out):
                if has_aliased(output, after) and schema.name not in unsafe_ops:
                    if not schema_info.may_contain_alias(
                        SchemaArgument(SchemaArgType.output, j),
                        SchemaArgument(SchemaArgType.input, i),
                    ):
                        raise RuntimeError(
                            f"Argument {name} is not defined to alias output but was aliasing"
                        )
                    self.aliasing.append(Aliasing(schema.name, name, f"output_{j}"))
                if after is output and isinstance(after, torch.Tensor):
                    # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
                    if not schema_info.is_mutable(
                        SchemaArgument(SchemaArgType.input, i)
                    ) and func not in (
                        torch.ops.aten.lift.default,
                        torch.ops.aten.lift_fresh.default,
                    ):
                        raise RuntimeError(
                            "Dispatcher operators below autograd are not allowed to directly return inputs.\n"
                            f"However, we found that `outputs[{j}] is {name}"
                        )
            if any(
                has_mutated(a, b, c)
                for a, b, c in zip(
                    pytree.tree_leaves(before), pytree.tree_leaves(after), md
                )
            ):
                if not schema_info.is_mutable(SchemaArgument(SchemaArgType.input, i)):
                    raise RuntimeError(
                        f"Argument {name} is not defined as mutable but was mutated"
                    )
                self.mutated.append(Mutation(schema.name, name))

        # Aliasing between outputs
        for i, j in combinations(range(len(schema.returns)), 2):
            if has_aliased(tuple_out[i], tuple_out[j]):
                if not schema_info.may_contain_alias(
                    SchemaArgument(SchemaArgType.output, i),
                    SchemaArgument(SchemaArgType.output, j),
                ):
                    raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")

        return out
231