# Owner(s): ["oncall: jit"]

import os
import sys
import torch
from torch.utils._pytree import tree_map
import unittest

from torch.testing._internal.common_utils import run_tests, TEST_WITH_TORCHDYNAMO
from torch.fx.operator_schemas import normalize_function
from torch._subclasses.schema_check_mode import SchemaCheckMode
from torch.utils._python_dispatch import TorchDispatchMode
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

def secretly_aliasing(x):
    return x.view(-1)

def secretly_mutating(x):
    x.mul_(2)
    return x * 3

def output_is_input(x):
    return x

custom_lib = torch.library.Library("bad_schemas", "DEF")  # noqa: TOR901
custom_lib.define("secretly_aliasing(Tensor x) -> Tensor")
custom_lib.define("secretly_mutating(Tensor x) -> Tensor")
custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)")

custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU")  # noqa: TOR901
custom_lib_cpu.impl("secretly_aliasing", secretly_aliasing)
custom_lib_cpu.impl("secretly_mutating", secretly_mutating)
custom_lib_cpu.impl("output_is_input", output_is_input)

custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta")  # noqa: TOR901
custom_lib_meta.impl("secretly_aliasing", secretly_aliasing)
custom_lib_meta.impl("secretly_mutating", secretly_mutating)
custom_lib_meta.impl("output_is_input", output_is_input)

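
# A minimal sketch of why these mis-annotated ops exist (it mirrors the
# test_alias_check_fail_custom_ops_* tests below; defined here purely for
# illustration and never executed at import time): under SchemaCheckMode,
# calling an op whose runtime behavior contradicts its registered schema
# raises a RuntimeError.
def _example_bad_schema_detection():
    x = torch.rand((3, 3))
    try:
        with SchemaCheckMode():
            # Returns a view of x, but the schema declares no aliasing.
            torch.ops.bad_schemas.secretly_aliasing(x)
    except RuntimeError as e:
        assert "aliasing" in str(e)
    else:
        raise AssertionError("expected SchemaCheckMode to reject the op")
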
# This __torch_dispatch__ tensor subclass simulates ops whose runtime behavior
# contradicts their schema; the tests below use it to check that
# SchemaCheckMode flags the mismatch.

class IncorrectAliasTensor(torch.Tensor):
    ALIAS_ARG_OUT = {"aten::add"}
    ALIAS_OUT_OUT = {"aten::aminmax"}
    MUTATE_ARGS_OUT = {"aten::sub"}

    elem: torch.Tensor

    __slots__ = ['elem']

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (IncorrectAliasTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem.detach() if r.requires_grad else elem
        return r

    def __repr__(self):
        return super().__repr__(tensor_contents=f"{self.elem}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e
        unwrapped_args = tree_map(unwrap, args)
        out = func(*unwrapped_args, **tree_map(unwrap, kwargs or {}))  # kwargs may be None
        if func._schema.name in IncorrectAliasTensor.ALIAS_ARG_OUT:
            args[0].elem = out
        if func._schema.name in IncorrectAliasTensor.MUTATE_ARGS_OUT:
            args[0].elem = torch.rand(args[0].elem.shape)
        if func._schema.name in IncorrectAliasTensor.ALIAS_OUT_OUT:
            incorrect_out = list(out)
            incorrect_out[0] = incorrect_out[1]
            return tree_map(wrap, tuple(incorrect_out))

        return tree_map(wrap, out)

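
# Illustrative sketch (it mirrors the test_*_check_fail tests below; defined
# for reference, not executed at import time): wrapping inputs in
# IncorrectAliasTensor makes aten::add secretly alias its first argument,
# which SchemaCheckMode flags because add's schema declares no
# input/output aliasing.
def _example_incorrect_alias_detection():
    x, y = torch.rand((3, 3)), torch.rand((3, 3))
    try:
        with SchemaCheckMode():
            IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)
    except RuntimeError as e:
        assert "aliasing" in str(e)
    else:
        raise AssertionError("expected SchemaCheckMode to reject the op")
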
# Tests various schema checking functionalities.
class TestSchemaCheck(JitTestCase):
    def setUp(self):
        if TEST_WITH_TORCHDYNAMO:
            self.skipTest("SchemaCheckMode is ignored by dynamo")
        super().setUp()

    # Tests that SchemaCheckMode records operator order with grad
    def test_schema_check_mode_operator_order(self):
        with SchemaCheckMode() as schema_check:
            x = torch.rand((3, 3), requires_grad=True)
            x.relu().sin()
        self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)

    # Tests that SchemaCheckMode records operator order without grad
    def test_schema_check_mode_operator_order_without_grad(self):
        with SchemaCheckMode() as schema_check:
            x = torch.rand((3, 3), requires_grad=False)
            x.relu().sin()
        self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)

    # Tests that SchemaCheckMode records mutations and aliases with none expected
    def test_schema_check_mode_mutated_aliasing_none(self):
        # NB: previously requires_grad=True, but this induces a detach for
        # saved variable
        x = torch.rand((3, 3))
        with SchemaCheckMode() as schema_check:
            actual = x.relu().sin()
        self.assertEqual([], schema_check.mutated)
        self.assertEqual([], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with mutation expected
    def test_schema_check_mode_mutated_aliasing_mutation(self):
        actual = torch.rand((3, 3), requires_grad=False)
        with SchemaCheckMode() as schema_check:
            actual.sinh_()
        self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated)
        self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with resize_
    def test_schema_check_mode_mutated_aliasing_resize_(self):
        actual = torch.rand((3, 3), requires_grad=False)
        with SchemaCheckMode() as schema_check:
            actual.resize_(9)
        self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
        self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with aliasing inputs
    def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
        actual = torch.rand((3, 3))
        y = actual
        with SchemaCheckMode() as schema_check:
            actual.add_(y)
        self.assertEqual(
            [
                ('aten::add_', 'input'),
                ('aten::add_', 'other')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::add_', 'input', 'output_0'),
                ('aten::add_', 'other', 'output_0')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and aliases with as_strided_
    def test_schema_check_mode_mutated_aliasing_as_strided(self):
        x = torch.rand((3, 6, 4))
        with SchemaCheckMode() as schema_check:
            x.as_strided_([3, 6, 4], [9, 1, 1])
        self.assertEqual(
            [
                ('aten::as_strided_', 'input')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::as_strided_', 'input', 'output_0')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and aliases with multiple outputs
    def test_schema_check_mode_mutated_aliasing_multiple_outputs(self):
        x = torch.arange(9.)
        m_actual = torch.arange(9.)
        e_actual = torch.zeros([9], dtype=torch.int32)
        with SchemaCheckMode() as schema_check:
            torch.frexp(x, out=(m_actual, e_actual))
        self.assertEqual(
            [
                ('aten::frexp', 'mantissa'),
                ('aten::frexp', 'exponent')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::frexp', 'mantissa', 'output_0'),
                ('aten::frexp', 'exponent', 'output_1')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and aliases with aliasing outputs
    def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
        x = torch.rand((3, 3))
        actual = torch.zeros(3)
        with SchemaCheckMode() as schema_check:
            torch.aminmax(x, dim=0, out=[actual, actual])
        self.assertEqual(
            [
                ('aten::aminmax', 'min'),
                ('aten::aminmax', 'max')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::aminmax', 'min', 'output_0'),
                ('aten::aminmax', 'min', 'output_1'),
                ('aten::aminmax', 'max', 'output_0'),
                ('aten::aminmax', 'max', 'output_1')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode wraps torch.Tensor
    def test_schema_check_mode_functionality(self):
        x = torch.rand((3, 3), requires_grad=True)
        expected = x.relu().sin()
        with SchemaCheckMode():
            actual = x.relu().sin()
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when an argument's default is overridden
    def test_schema_check_mode_functionality_default_replaced(self):
        x = torch.rand((3, 3), requires_grad=True)
        expected = x.add(x, alpha=2)
        with SchemaCheckMode():
            actual = x.add(x, alpha=2)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when there is a Tensor[] argument
    def test_schema_check_mode_functionality_list_input(self):
        a = torch.rand((3, 3))
        b = torch.rand((3, 3))
        c = torch.rand((3, 3))
        expected = torch.linalg.multi_dot([a, b, c])
        with SchemaCheckMode():
            actual = torch.linalg.multi_dot([a, b, c])
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with an op that has the (a -> *) notation
    def test_schema_check_mode_functionality_wildcard_after(self):
        x = torch.rand((3, 3))
        expected = x.chunk(6)
        with SchemaCheckMode():
            actual = x.chunk(6)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when there is a kwarg tensor input
    @unittest.skipIf(not torch._C.has_spectral, "ATen not built with FFT.")
    def test_schema_check_mode_functionality_kwarg_tensor(self):
        x = torch.rand((3, 5))
        w = torch.rand(4)
        expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
        with SchemaCheckMode():
            actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with a mutable op
    def test_schema_check_mode_functionality_mutable_inputs(self):
        expected = torch.rand((3, 3), requires_grad=False)
        actual = torch.clone(expected)
        expected.sinh_()
        with SchemaCheckMode():
            actual.sinh_()
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when inputs alias
    def test_schema_check_mode_functionality_aliasing_inputs(self):
        expected = torch.rand((3, 3))
        x = expected
        actual = torch.clone(expected)
        y = actual
        expected.add_(x)
        with SchemaCheckMode():
            actual.add_(y)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with multiple tensor outputs
    def test_schema_check_mode_functionality_with_multiple_outputs(self):
        x = torch.arange(9.)
        m_expected, e_expected = torch.frexp(x)
        m_actual = torch.arange(9.)
        e_actual = torch.zeros([9], dtype=torch.int32)
        with SchemaCheckMode():
            torch.frexp(x, out=(m_actual, e_actual))
        self.assertEqual(m_expected, m_actual)
        self.assertEqual(e_expected, e_actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with aliasing outputs due to aliasing inputs
    def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
        x = torch.rand((3, 3))
        actual = torch.zeros(3)
        with SchemaCheckMode():
            torch.aminmax(x, dim=0, out=[actual, actual])
        self.assertEqual(torch.amax(x, dim=0), actual)

    # Tests that SchemaCheckMode wraps torch.Tensor in ops with a real Device input
    def test_schema_check_mode_functionality_device_input(self):
        with SchemaCheckMode():
            x = torch.rand((3, 3), device="cpu", dtype=torch.double)
            y = x + x
        self.assertEqual(x + x, y)

    # Tests that SchemaCheckMode wraps torch.Tensor in a special training op edge case
    def test_schema_check_mode_functionality_training_op(self):
        x = torch.rand((3, 3), requires_grad=True)
        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
        expected = batch(x)
        with SchemaCheckMode():
            actual = batch(x)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor in a nested training op edge case
    def test_schema_check_mode_functionality_nested_training_op(self):
        actual = torch.rand((3, 3))
        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
        expected = torch.clone(actual)
        expected.sinh_()
        expected.tanh_()
        expected.relu_()
        expected = batch(expected)

        with SchemaCheckMode():
            actual.sinh_()
            actual.tanh_()
            actual.relu_()
            actual = batch(actual)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with an empty list input
    def test_schema_check_mode_empty_list_input(self):
        expected = torch.atleast_1d([])
        with SchemaCheckMode():
            actual = torch.atleast_1d([])
        self.assertEqual(expected, actual)

    # Tests that an exception is raised for a mismatching mutation
    def test_mutation_check_fail(self):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
            x = torch.rand((3, 3))
            y = torch.rand((3, 3))
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))

    # Tests that an exception is raised for a mismatching mutation over multiple ops
    def test_mutation_check_fail_multiple_operators(self):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
            x = torch.rand((3, 3))
            y = torch.rand((3, 3))
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))

    # Tests that an exception is raised for a mismatching alias
    def test_alias_check_fail_simple(self):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            x = torch.rand((3, 3), requires_grad=True)
            y = torch.rand((3, 3))
            with SchemaCheckMode():
                IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)

    # Tests that an exception is raised for a mismatching alias over multiple ops
    def test_alias_check_fail_multiple_operators(self):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            x = torch.rand((3, 3), requires_grad=True)
            y = torch.zeros((3, 3), requires_grad=True)
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2)

    # Tests that an exception is raised for a centered mismatching alias over multiple ops
    def test_alias_check_fail_multiple_operators_centered(self):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            x = torch.rand((3, 3), requires_grad=True)
            y = torch.zeros((3, 3), requires_grad=True)
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu()

    # Tests that an exception is raised when outputs alias unexpectedly
    def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
        with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
            x = torch.rand((3, 3))
            with SchemaCheckMode() as s:
                IncorrectAliasTensor(x).aminmax(dim=0)

    # When this file was written, python op registration didn't exist.
    # It's probably worth re-writing the entire file to use it,
    # but instead I just added extra tests.
    def test_alias_check_fail_custom_ops_secretly_aliasing(self):
        def f(x):
            return torch.ops.bad_schemas.secretly_aliasing(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"):
            with SchemaCheckMode() as s:
                out = f(x)

    def test_alias_check_fail_custom_ops_secretly_mutating(self):
        def f(x):
            return torch.ops.bad_schemas.secretly_mutating(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"):
            with SchemaCheckMode() as s:
                out = f(x)

    def test_alias_check_fail_custom_ops_output_is_input(self):
        def f(x):
            return torch.ops.bad_schemas.output_is_input(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"):
            with SchemaCheckMode() as s:
                out = f(x)

    # Tests that is_alias_of returns as expected
    def test_is_alias_of_basic(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.rand((3, 3), requires_grad=True)
        y = x.add(x, alpha=2)
        self.assertTrue(torch._C._is_alias_of(x, x))
        self.assertFalse(torch._C._is_alias_of(x, y))

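    # Companion sketch (an added illustration, assuming standard view
    # semantics): a view shares storage with its base, so _is_alias_of
    # reports True, unlike the freshly allocated result of add above.
    def test_is_alias_of_view(self):
        x = torch.rand((3, 3))
        self.assertTrue(torch._C._is_alias_of(x, x.view(-1)))
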
    # Tests that is_alias_of returns as expected with empty containers
    def test_is_alias_of_empty_container(self):
        x = []
        y = torch.rand((3, 3), requires_grad=True)
        self.assertFalse(torch._C._is_alias_of(x, x))
        self.assertFalse(torch._C._is_alias_of(x, y))

    # Tests that overlaps returns as expected
    def test_overlaps_basic(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.rand((3, 3), requires_grad=True)
        z = [x, y]
        self.assertTrue(torch._C._overlaps(x, x))
        self.assertFalse(torch._C._overlaps(x, y))
        self.assertTrue(torch._C._overlaps(z, x))
        self.assertTrue(torch._C._overlaps(z, y))

    # Tests that overlaps returns correctly with empty containers
    def test_overlaps_empty_container(self):
        x = []
        y = [torch.rand((3, 3), requires_grad=True)]
        # Empty containers return false
        self.assertFalse(torch._C._overlaps(y, x))
        self.assertTrue(torch._C._overlaps(y, y))

    # Tests that SchemaInfo bindings work as expected
    def test_schema_info_bind_basic(self):
        class SchemaInfoBindTestMode(TorchDispatchMode):
            def __init__(self, test_self):
                super().__init__()
                self.test_self = test_self

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                named_arg_list = normalize_function(
                    func,
                    args,
                    kwargs,
                    normalize_to_only_use_kwargs=True
                ).kwargs
                schema_info_value_test = torch._C._SchemaInfo(func._schema)
                schema_info_values_test = torch._C._SchemaInfo(func._schema)
                self.test_self.assertFalse(schema_info_value_test.may_alias(
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
                self.test_self.assertFalse(schema_info_values_test.may_alias(
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
                for i in named_arg_list:
                    schema_info_value_test.add_argument_value(i, named_arg_list[i])
                schema_info_values_test.add_argument_values(named_arg_list)
                self.test_self.assertTrue(schema_info_value_test.may_alias(
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
                self.test_self.assertTrue(schema_info_values_test.may_alias(
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))

                return func(*args, **kwargs)
        x = torch.rand((3, 3))
        with SchemaInfoBindTestMode(self) as schemaInfoCheck:
            x.add(x)

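
# Standalone sketch of the _SchemaInfo flow exercised above (these are
# private torch._C bindings; illustrative only, assuming aten::add.Tensor's
# schema with arguments "self" and "other"): binding the same tensor to
# both inputs is what flips may_alias to True.
def _example_schema_info_may_alias():
    info = torch._C._SchemaInfo(torch.ops.aten.add.Tensor._schema)
    arg0 = torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0)
    arg1 = torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)
    assert not info.may_alias(arg0, arg1)  # no concrete values bound yet
    t = torch.rand((3, 3))
    info.add_argument_value("self", t)
    info.add_argument_value("other", t)
    assert info.may_alias(arg0, arg1)  # same tensor bound to both inputs
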
class TestSchemaCheckModeOpInfo(JitTestCase):
    @ops(op_db, dtypes=OpDTypes.supported)
    def test_schema_correctness(self, device, dtype, op):
        # torch.equal currently isn't supported for torch.complex32, and
        # there are also errors with complex64 and complex128.
        if dtype == torch.complex32:
            self.skipTest("torch.equal is not supported for torch.complex32")
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            with SchemaCheckMode():
                op(sample.input, *sample.args, **sample.kwargs)

instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda"))

if __name__ == '__main__':
    run_tests()