xref: /aosp_15_r20/external/pytorch/test/test_schema_check.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map
7*da0073e9SAndroid Build Coastguard Workerimport unittest
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TEST_WITH_TORCHDYNAMO
10*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.operator_schemas import normalize_function
11*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.schema_check_mode import SchemaCheckMode
12*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import TorchDispatchMode
13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import op_db
14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests
# Append the parent of the directory containing this file to sys.path so that
# modules living there can be imported by the tests below.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
18*da0073e9SAndroid Build Coastguard Worker
def secretly_aliasing(x):
    """Return a flattened view of ``x`` that shares its storage.

    Registered below as the kernel for ``bad_schemas::secretly_aliasing``,
    whose schema declares a plain ``Tensor`` output with no alias annotation.
    """
    flattened = x.view(-1)
    return flattened
21*da0073e9SAndroid Build Coastguard Worker
def secretly_mutating(x):
    """Double ``x`` in place, then return a new tensor holding ``x * 3``.

    Registered below as the kernel for ``bad_schemas::secretly_mutating``,
    whose schema does not mark ``x`` as mutable.
    """
    doubled = x.mul_(2)  # in-place: ``doubled`` is ``x`` itself
    return doubled * 3
25*da0073e9SAndroid Build Coastguard Worker
def output_is_input(x):
    """Identity kernel: give back the exact object that was passed in."""
    result = x
    return result
28*da0073e9SAndroid Build Coastguard Worker
# Operators in the "bad_schemas" namespace whose declared schemas do not match
# what their Python kernels (defined above) actually do.
custom_lib = torch.library.Library("bad_schemas", "DEF")  # noqa: TOR901
custom_lib.define("secretly_aliasing(Tensor x) -> Tensor")
custom_lib.define("secretly_mutating(Tensor x) -> Tensor")
custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)")


def _register_bad_kernels(lib):
    """Register the same Python kernel for every bad_schemas op on ``lib``."""
    lib.impl("secretly_aliasing", secretly_aliasing)
    lib.impl("secretly_mutating", secretly_mutating)
    lib.impl("output_is_input", output_is_input)


custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU")  # noqa: TOR901
_register_bad_kernels(custom_lib_cpu)

custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta")  # noqa: TOR901
_register_bad_kernels(custom_lib_meta)
43*da0073e9SAndroid Build Coastguard Worker
# IncorrectAliasTensor is a __torch_dispatch__ tensor subclass that simulates
# operators with incorrect schemas (unexpected aliasing or mutation); it is used
# to test that SchemaCheckMode detects such violations.
46*da0073e9SAndroid Build Coastguard Worker
class IncorrectAliasTensor(torch.Tensor):
    """Wrapper subclass that forwards ops to ``elem`` but breaks some schemas.

    Every op runs on the wrapped ``elem`` tensor. For the operators listed in
    the class-level sets, ``__torch_dispatch__`` then deliberately violates
    the aliasing/mutation contract of the operator's schema:

    * ``ALIAS_ARG_OUT``: the first argument is rebound to the output, so an
      input aliases the output.
    * ``MUTATE_ARGS_OUT``: the first argument's data is replaced with fresh
      random values, so a non-mutable input appears mutated.
    * ``ALIAS_OUT_OUT``: output 0 is replaced with output 1, so two outputs
      alias each other.
    """

    ALIAS_ARG_OUT = {"aten::add"}
    ALIAS_OUT_OUT = {"aten::aminmax"}
    MUTATE_ARGS_OUT = {"aten::sub"}

    elem: torch.Tensor

    __slots__ = ['elem']

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (IncorrectAliasTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem.detach() if r.requires_grad else elem
        return r

    def __repr__(self):
        return super().__repr__(tensor_contents=f"{self.elem}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # The signature allows kwargs=None; unwrapping None and splatting it
        # with ``**`` would raise, so normalize to an empty dict first.
        kwargs = kwargs or {}

        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        op_name = func._schema.name

        if op_name in cls.ALIAS_ARG_OUT:
            # Make the first input alias the output.
            args[0].elem = out
        if op_name in cls.MUTATE_ARGS_OUT:
            # Swap in new data so the first input looks mutated. Keep the
            # original dtype/device so the replacement still matches the
            # metadata the wrapper advertises.
            victim = args[0].elem
            args[0].elem = torch.rand(
                victim.shape, dtype=victim.dtype, device=victim.device
            )
        if op_name in cls.ALIAS_OUT_OUT:
            # Make output 0 the same tensor as output 1.
            incorrect_out = list(out)
            incorrect_out[0] = incorrect_out[1]
            return tree_map(wrap, tuple(incorrect_out))

        return tree_map(wrap, out)
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker# Tests various schema checking functionalities.
96*da0073e9SAndroid Build Coastguard Workerclass TestSchemaCheck(JitTestCase):
97*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
98*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_TORCHDYNAMO:
99*da0073e9SAndroid Build Coastguard Worker            self.skipTest("SchemaCheckMode is ignored by dynamo")
100*da0073e9SAndroid Build Coastguard Worker        super().setUp()
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode records operator order with grad
103*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_operator_order(self):
104*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode() as schema_check:
105*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((3, 3), requires_grad=True)
106*da0073e9SAndroid Build Coastguard Worker            x.relu().sin()
107*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode records operator order without grad
110*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_operator_order_without_grad(self):
111*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode() as schema_check:
112*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((3, 3), requires_grad=False)
113*da0073e9SAndroid Build Coastguard Worker            x.relu().sin()
114*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode records mutations and aliases with none expected
117*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_mutated_aliasing_none(self):
118*da0073e9SAndroid Build Coastguard Worker        # NB: previously requires_grad=True, but this induces a detach for
119*da0073e9SAndroid Build Coastguard Worker        # saved variable
120*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3))
121*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode() as schema_check:
122*da0073e9SAndroid Build Coastguard Worker            actual = x.relu().sin()
123*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([], schema_check.mutated)
124*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([], schema_check.aliasing)
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode records mutations and aliases with mutation expected
127*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_mutated_aliasing_mutation(self):
128*da0073e9SAndroid Build Coastguard Worker        actual = torch.rand((3, 3), requires_grad=False)
129*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode() as schema_check:
130*da0073e9SAndroid Build Coastguard Worker            actual.sinh_()
131*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated)
132*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing)
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode records mutations and aliases with resize_
135*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_mutated_aliasing_resize_(self):
136*da0073e9SAndroid Build Coastguard Worker        actual = torch.rand((3, 3), requires_grad=False)
137*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode() as schema_check:
138*da0073e9SAndroid Build Coastguard Worker            actual.resize_(9)
139*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
140*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode records mutations and aliases with aliasing inputs
143*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
144*da0073e9SAndroid Build Coastguard Worker        actual = torch.rand((3, 3))
145*da0073e9SAndroid Build Coastguard Worker        y = actual
146*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode() as schema_check:
147*da0073e9SAndroid Build Coastguard Worker            actual.add_(y)
148*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
149*da0073e9SAndroid Build Coastguard Worker            [
150*da0073e9SAndroid Build Coastguard Worker                ('aten::add_', 'input'),
151*da0073e9SAndroid Build Coastguard Worker                ('aten::add_', 'other')
152*da0073e9SAndroid Build Coastguard Worker            ],
153*da0073e9SAndroid Build Coastguard Worker            schema_check.mutated
154*da0073e9SAndroid Build Coastguard Worker        )
155*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
156*da0073e9SAndroid Build Coastguard Worker            [
157*da0073e9SAndroid Build Coastguard Worker                ('aten::add_', 'input', 'output_0'),
158*da0073e9SAndroid Build Coastguard Worker                ('aten::add_', 'other', 'output_0')
159*da0073e9SAndroid Build Coastguard Worker            ],
160*da0073e9SAndroid Build Coastguard Worker            schema_check.aliasing
161*da0073e9SAndroid Build Coastguard Worker        )
162*da0073e9SAndroid Build Coastguard Worker
    # Tests that SchemaCheckMode records mutations and aliases with as_strided_
164*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_mutated_aliasing_as_strided(self):
165*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 6, 4))
166*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode() as schema_check:
167*da0073e9SAndroid Build Coastguard Worker            x.as_strided_([3, 6, 4], [9, 1, 1])
168*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
169*da0073e9SAndroid Build Coastguard Worker            [
170*da0073e9SAndroid Build Coastguard Worker                ('aten::as_strided_', 'input')
171*da0073e9SAndroid Build Coastguard Worker            ],
172*da0073e9SAndroid Build Coastguard Worker            schema_check.mutated
173*da0073e9SAndroid Build Coastguard Worker        )
174*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
175*da0073e9SAndroid Build Coastguard Worker            [
176*da0073e9SAndroid Build Coastguard Worker                ('aten::as_strided_', 'input', 'output_0')
177*da0073e9SAndroid Build Coastguard Worker            ],
178*da0073e9SAndroid Build Coastguard Worker            schema_check.aliasing
179*da0073e9SAndroid Build Coastguard Worker        )
180*da0073e9SAndroid Build Coastguard Worker
181*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode records mutations and aliases with multiple outputs
182*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_mutated_aliasing_multiple_outputs(self):
183*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(9.)
184*da0073e9SAndroid Build Coastguard Worker        m_actual = torch.arange(9.)
185*da0073e9SAndroid Build Coastguard Worker        e_actual = torch.zeros([9], dtype=torch.int32)
186*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode() as schema_check:
187*da0073e9SAndroid Build Coastguard Worker            torch.frexp(x, out=(m_actual, e_actual))
188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
189*da0073e9SAndroid Build Coastguard Worker            [
190*da0073e9SAndroid Build Coastguard Worker                ('aten::frexp', 'mantissa'),
191*da0073e9SAndroid Build Coastguard Worker                ('aten::frexp', 'exponent')
192*da0073e9SAndroid Build Coastguard Worker            ],
193*da0073e9SAndroid Build Coastguard Worker            schema_check.mutated
194*da0073e9SAndroid Build Coastguard Worker        )
195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
196*da0073e9SAndroid Build Coastguard Worker            [
197*da0073e9SAndroid Build Coastguard Worker                ('aten::frexp', 'mantissa', 'output_0'),
198*da0073e9SAndroid Build Coastguard Worker                ('aten::frexp', 'exponent', 'output_1')
199*da0073e9SAndroid Build Coastguard Worker            ],
200*da0073e9SAndroid Build Coastguard Worker            schema_check.aliasing
201*da0073e9SAndroid Build Coastguard Worker        )
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode records mutations and aliases with aliasing outputs
204*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
205*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3))
206*da0073e9SAndroid Build Coastguard Worker        actual = torch.zeros(3)
207*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode() as schema_check:
208*da0073e9SAndroid Build Coastguard Worker            torch.aminmax(x, dim=0, out=[actual, actual])
209*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
210*da0073e9SAndroid Build Coastguard Worker            [
211*da0073e9SAndroid Build Coastguard Worker                ('aten::aminmax', 'min'),
212*da0073e9SAndroid Build Coastguard Worker                ('aten::aminmax', 'max')
213*da0073e9SAndroid Build Coastguard Worker            ],
214*da0073e9SAndroid Build Coastguard Worker            schema_check.mutated
215*da0073e9SAndroid Build Coastguard Worker        )
216*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
217*da0073e9SAndroid Build Coastguard Worker            [
218*da0073e9SAndroid Build Coastguard Worker                ('aten::aminmax', 'min', 'output_0'),
219*da0073e9SAndroid Build Coastguard Worker                ('aten::aminmax', 'min', 'output_1'),
220*da0073e9SAndroid Build Coastguard Worker                ('aten::aminmax', 'max', 'output_0'),
221*da0073e9SAndroid Build Coastguard Worker                ('aten::aminmax', 'max', 'output_1')
222*da0073e9SAndroid Build Coastguard Worker            ],
223*da0073e9SAndroid Build Coastguard Worker            schema_check.aliasing
224*da0073e9SAndroid Build Coastguard Worker        )
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode wraps torch.Tensor
227*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality(self):
228*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3), requires_grad=True)
229*da0073e9SAndroid Build Coastguard Worker        expected = x.relu().sin()
230*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
231*da0073e9SAndroid Build Coastguard Worker            actual = x.relu().sin()
232*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
233*da0073e9SAndroid Build Coastguard Worker
    # Tests that SchemaCheckMode wraps torch.Tensor when an argument's default is overridden
235*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_default_replaced(self):
236*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3), requires_grad=True)
237*da0073e9SAndroid Build Coastguard Worker        expected = x.add(x, alpha=2)
238*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
239*da0073e9SAndroid Build Coastguard Worker            actual = x.add(x, alpha=2)
240*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode wraps torch.Tensor when there is a Tensor[] argument
243*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_list_input(self):
244*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((3, 3))
245*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((3, 3))
246*da0073e9SAndroid Build Coastguard Worker        c = torch.rand((3, 3))
247*da0073e9SAndroid Build Coastguard Worker        expected = torch.linalg.multi_dot([a, b, c])
248*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
249*da0073e9SAndroid Build Coastguard Worker            actual = torch.linalg.multi_dot([a, b, c])
250*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
251*da0073e9SAndroid Build Coastguard Worker
252*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode wraps torch.Tensor with an op that has the (a -> *) notation
253*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_wildcard_after(self):
254*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3))
255*da0073e9SAndroid Build Coastguard Worker        expected = x.chunk(6)
256*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
257*da0073e9SAndroid Build Coastguard Worker            actual = x.chunk(6)
258*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode wraps torch.Tensor when there is a kwarg tensor input
261*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch._C.has_spectral, "ATen not built with FFT.")
262*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_kwarg_tensor(self):
263*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 5))
264*da0073e9SAndroid Build Coastguard Worker        w = torch.rand(4)
265*da0073e9SAndroid Build Coastguard Worker        expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
266*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
267*da0073e9SAndroid Build Coastguard Worker            actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
268*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
269*da0073e9SAndroid Build Coastguard Worker
270*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaCheckMode wraps torch.Tensor with a mutable op
271*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_mutable_inputs(self):
272*da0073e9SAndroid Build Coastguard Worker        expected = torch.rand((3, 3), requires_grad=False)
273*da0073e9SAndroid Build Coastguard Worker        actual = torch.clone(expected)
274*da0073e9SAndroid Build Coastguard Worker        expected.sinh_()
275*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
276*da0073e9SAndroid Build Coastguard Worker            actual.sinh_()
277*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
278*da0073e9SAndroid Build Coastguard Worker
    # Tests that SchemaCheckMode wraps torch.Tensor when inputs alias
280*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_aliasing_inputs(self):
281*da0073e9SAndroid Build Coastguard Worker        expected = torch.rand((3, 3))
282*da0073e9SAndroid Build Coastguard Worker        x = expected
283*da0073e9SAndroid Build Coastguard Worker        actual = torch.clone(expected)
284*da0073e9SAndroid Build Coastguard Worker        y = actual
285*da0073e9SAndroid Build Coastguard Worker        expected.add_(x)
286*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
287*da0073e9SAndroid Build Coastguard Worker            actual.add_(y)
288*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
289*da0073e9SAndroid Build Coastguard Worker
    # Tests that SchemaCheckMode wraps torch.Tensor with multiple tensor outputs
291*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_with_multiple_outputs(self):
292*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(9.)
293*da0073e9SAndroid Build Coastguard Worker        m_expected, e_expected = torch.frexp(x)
294*da0073e9SAndroid Build Coastguard Worker        m_actual = torch.arange(9.)
295*da0073e9SAndroid Build Coastguard Worker        e_actual = torch.zeros([9], dtype=torch.int32)
296*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
297*da0073e9SAndroid Build Coastguard Worker            torch.frexp(x, out=(m_actual, e_actual))
298*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_expected, m_actual)
299*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(e_expected, e_actual)
300*da0073e9SAndroid Build Coastguard Worker
    # Tests that SchemaCheckMode wraps torch.Tensor with aliasing outputs caused by passing the same out= tensor twice
302*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
303*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3))
304*da0073e9SAndroid Build Coastguard Worker        actual = torch.zeros(3)
305*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
306*da0073e9SAndroid Build Coastguard Worker            torch.aminmax(x, dim=0, out=[actual, actual])
307*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.amax(x, dim=0), actual)
308*da0073e9SAndroid Build Coastguard Worker
    # Tests that SchemaCheckMode wraps torch.Tensor in ops with a real Device input
310*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_device_input(self):
311*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
312*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((3, 3), device="cpu", dtype=torch.double)
313*da0073e9SAndroid Build Coastguard Worker            y = x + x
314*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x + x, y)
315*da0073e9SAndroid Build Coastguard Worker
    # Tests that SchemaCheckMode wraps torch.Tensor in the special training-op edge case
317*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_training_op(self):
318*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3), requires_grad=True)
319*da0073e9SAndroid Build Coastguard Worker        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
320*da0073e9SAndroid Build Coastguard Worker        expected = batch(x)
321*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
322*da0073e9SAndroid Build Coastguard Worker            actual = batch(x)
323*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
324*da0073e9SAndroid Build Coastguard Worker
    # Tests that SchemaCheckMode wraps torch.Tensor in the nested training-op edge case
326*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_functionality_nested_training_op(self):
327*da0073e9SAndroid Build Coastguard Worker        actual = torch.rand((3, 3))
328*da0073e9SAndroid Build Coastguard Worker        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
329*da0073e9SAndroid Build Coastguard Worker        expected = torch.clone(actual)
330*da0073e9SAndroid Build Coastguard Worker        expected.sinh_()
331*da0073e9SAndroid Build Coastguard Worker        expected.tanh_()
332*da0073e9SAndroid Build Coastguard Worker        expected.relu_()
333*da0073e9SAndroid Build Coastguard Worker        expected = batch(expected)
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
336*da0073e9SAndroid Build Coastguard Worker            actual.sinh_()
337*da0073e9SAndroid Build Coastguard Worker            actual.tanh_()
338*da0073e9SAndroid Build Coastguard Worker            actual.relu_()
339*da0073e9SAndroid Build Coastguard Worker            actual = batch(actual)
340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
341*da0073e9SAndroid Build Coastguard Worker
    # Tests that SchemaCheckMode wraps torch.Tensor with an empty list input
343*da0073e9SAndroid Build Coastguard Worker    def test_schema_check_mode_empty_list_input(self):
344*da0073e9SAndroid Build Coastguard Worker        expected = torch.atleast_1d([])
345*da0073e9SAndroid Build Coastguard Worker        with SchemaCheckMode():
346*da0073e9SAndroid Build Coastguard Worker            actual = torch.atleast_1d([])
347*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker    # Tests that an exception is raised for a mismatching mutation
350*da0073e9SAndroid Build Coastguard Worker    def test_mutation_check_fail(self):
351*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
352*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((3, 3))
353*da0073e9SAndroid Build Coastguard Worker            y = torch.rand((3, 3))
354*da0073e9SAndroid Build Coastguard Worker            with SchemaCheckMode():
355*da0073e9SAndroid Build Coastguard Worker                IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))
356*da0073e9SAndroid Build Coastguard Worker
    # Tests that an exception is raised for a mismatching mutation over multiple ops
358*da0073e9SAndroid Build Coastguard Worker    def test_mutation_check_fail_multiple_operators(self):
359*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
360*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((3, 3))
361*da0073e9SAndroid Build Coastguard Worker            y = torch.rand((3, 3))
362*da0073e9SAndroid Build Coastguard Worker            with SchemaCheckMode():
363*da0073e9SAndroid Build Coastguard Worker                IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard Worker    # Tests that an exception is raised for a mismatching alias
366*da0073e9SAndroid Build Coastguard Worker    def test_alias_check_fail_simple(self):
367*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
368*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((3, 3), requires_grad=True)
369*da0073e9SAndroid Build Coastguard Worker            y = torch.rand((3, 3))
370*da0073e9SAndroid Build Coastguard Worker            with SchemaCheckMode():
371*da0073e9SAndroid Build Coastguard Worker                IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker    # Tests that an exception is raised for a mismatching alias over multiple ops
374*da0073e9SAndroid Build Coastguard Worker    def test_alias_check_fail_multiple_operators(self):
375*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
376*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((3, 3), requires_grad=True)
377*da0073e9SAndroid Build Coastguard Worker            y = torch.zeros((3, 3), requires_grad=True)
378*da0073e9SAndroid Build Coastguard Worker            with SchemaCheckMode():
379*da0073e9SAndroid Build Coastguard Worker                IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2)
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker    # Tests that an exception is raised for a centered mismatching alias over multiple ops
382*da0073e9SAndroid Build Coastguard Worker    def test_alias_check_fail_multiple_operators_centered(self):
383*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
384*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((3, 3), requires_grad=True)
385*da0073e9SAndroid Build Coastguard Worker            y = torch.zeros((3, 3), requires_grad=True)
386*da0073e9SAndroid Build Coastguard Worker            with SchemaCheckMode():
387*da0073e9SAndroid Build Coastguard Worker                IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu()
388*da0073e9SAndroid Build Coastguard Worker
    # Tests that an exception is raised when two outputs of an op unexpectedly alias each other
390*da0073e9SAndroid Build Coastguard Worker    def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
391*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
392*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((3, 3))
393*da0073e9SAndroid Build Coastguard Worker            with SchemaCheckMode() as s:
394*da0073e9SAndroid Build Coastguard Worker                IncorrectAliasTensor(x).aminmax(dim=0)
395*da0073e9SAndroid Build Coastguard Worker
    # When this file was written, Python op registration didn't exist.
    # It's probably worth rewriting the entire file to use it,
    # but instead I just added extra tests.
399*da0073e9SAndroid Build Coastguard Worker    def test_alias_check_fail_custom_ops_secretly_aliasing(self):
400*da0073e9SAndroid Build Coastguard Worker        def f(x):
401*da0073e9SAndroid Build Coastguard Worker            return torch.ops.bad_schemas.secretly_aliasing(x)
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3))
404*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"):
405*da0073e9SAndroid Build Coastguard Worker            with SchemaCheckMode() as s:
406*da0073e9SAndroid Build Coastguard Worker                out = f(x)
407*da0073e9SAndroid Build Coastguard Worker
408*da0073e9SAndroid Build Coastguard Worker    def test_alias_check_fail_custom_ops_secretly_mutating(self):
409*da0073e9SAndroid Build Coastguard Worker        def f(x):
410*da0073e9SAndroid Build Coastguard Worker            return torch.ops.bad_schemas.secretly_mutating(x)
411*da0073e9SAndroid Build Coastguard Worker
412*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3))
413*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"):
414*da0073e9SAndroid Build Coastguard Worker            with SchemaCheckMode() as s:
415*da0073e9SAndroid Build Coastguard Worker                out = f(x)
416*da0073e9SAndroid Build Coastguard Worker
417*da0073e9SAndroid Build Coastguard Worker    def test_alias_check_fail_custom_ops_output_is_input(self):
418*da0073e9SAndroid Build Coastguard Worker        def f(x):
419*da0073e9SAndroid Build Coastguard Worker            return torch.ops.bad_schemas.output_is_input(x)
420*da0073e9SAndroid Build Coastguard Worker
421*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3))
422*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"):
423*da0073e9SAndroid Build Coastguard Worker            with SchemaCheckMode() as s:
424*da0073e9SAndroid Build Coastguard Worker                out = f(x)
425*da0073e9SAndroid Build Coastguard Worker
426*da0073e9SAndroid Build Coastguard Worker    # Tests that is_alias_of returns as expected
427*da0073e9SAndroid Build Coastguard Worker    def test_is_alias_of_basic(self):
428*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3), requires_grad=True)
429*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((3, 3), requires_grad=True)
430*da0073e9SAndroid Build Coastguard Worker        y = x.add(x, alpha=2)
431*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._is_alias_of(x, x))
432*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch._C._is_alias_of(x, y))
433*da0073e9SAndroid Build Coastguard Worker
434*da0073e9SAndroid Build Coastguard Worker    # Tests that is_alias_of returns as expected with empty containers
435*da0073e9SAndroid Build Coastguard Worker    def test_is_alias_of_empty_container(self):
436*da0073e9SAndroid Build Coastguard Worker        x = []
437*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((3, 3), requires_grad=True)
438*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch._C._is_alias_of(x, x))
439*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch._C._is_alias_of(x, y))
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker    # Tests that overlaps returns as expected
442*da0073e9SAndroid Build Coastguard Worker    def test_overlaps_basic(self):
443*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3), requires_grad=True)
444*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((3, 3), requires_grad=True)
445*da0073e9SAndroid Build Coastguard Worker        z = [x, y]
446*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._overlaps(x, x))
447*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch._C._overlaps(x, y))
448*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._overlaps(z, x))
449*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._overlaps(z, y))
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker    # Tests that overlaps returns correctly with empty containers
452*da0073e9SAndroid Build Coastguard Worker    def test_overlaps_empty_container(self):
453*da0073e9SAndroid Build Coastguard Worker        x = []
454*da0073e9SAndroid Build Coastguard Worker        y = [torch.rand((3, 3), requires_grad=True)]
455*da0073e9SAndroid Build Coastguard Worker        # Empty containers return false
456*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch._C._overlaps(y, x))
457*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._overlaps(y, y))
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker    # Tests that SchemaInfo Bindings work as expected
460*da0073e9SAndroid Build Coastguard Worker    def test_schema_info_bind_basic(self):
461*da0073e9SAndroid Build Coastguard Worker        class SchemaInfoBindTestMode(TorchDispatchMode):
462*da0073e9SAndroid Build Coastguard Worker            def __init__(self, test_self):
463*da0073e9SAndroid Build Coastguard Worker                self.test_self = test_self
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
466*da0073e9SAndroid Build Coastguard Worker                named_arg_list = normalize_function(
467*da0073e9SAndroid Build Coastguard Worker                    func,
468*da0073e9SAndroid Build Coastguard Worker                    args,
469*da0073e9SAndroid Build Coastguard Worker                    kwargs,
470*da0073e9SAndroid Build Coastguard Worker                    normalize_to_only_use_kwargs=True
471*da0073e9SAndroid Build Coastguard Worker                ).kwargs
472*da0073e9SAndroid Build Coastguard Worker                schema_info_value_test = torch._C._SchemaInfo(func._schema)
473*da0073e9SAndroid Build Coastguard Worker                schema_info_values_test = torch._C._SchemaInfo(func._schema)
474*da0073e9SAndroid Build Coastguard Worker                self.test_self.assertFalse(schema_info_value_test.may_alias(
475*da0073e9SAndroid Build Coastguard Worker                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
476*da0073e9SAndroid Build Coastguard Worker                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
477*da0073e9SAndroid Build Coastguard Worker                self.test_self.assertFalse(schema_info_values_test.may_alias(
478*da0073e9SAndroid Build Coastguard Worker                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
479*da0073e9SAndroid Build Coastguard Worker                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
480*da0073e9SAndroid Build Coastguard Worker                for i in named_arg_list:
481*da0073e9SAndroid Build Coastguard Worker                    schema_info_value_test.add_argument_value(i, named_arg_list[i])
482*da0073e9SAndroid Build Coastguard Worker                schema_info_values_test.add_argument_values(named_arg_list)
483*da0073e9SAndroid Build Coastguard Worker                self.test_self.assertTrue(schema_info_value_test.may_alias(
484*da0073e9SAndroid Build Coastguard Worker                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
485*da0073e9SAndroid Build Coastguard Worker                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
486*da0073e9SAndroid Build Coastguard Worker                self.test_self.assertTrue(schema_info_values_test.may_alias(
487*da0073e9SAndroid Build Coastguard Worker                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
488*da0073e9SAndroid Build Coastguard Worker                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
491*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3))
492*da0073e9SAndroid Build Coastguard Worker        with SchemaInfoBindTestMode(self) as schemaInfoCheck:
493*da0073e9SAndroid Build Coastguard Worker            x.add(x)
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Workerclass TestSchemaCheckModeOpInfo(JitTestCase):
497*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, dtypes=OpDTypes.supported)
498*da0073e9SAndroid Build Coastguard Worker    def test_schema_correctness(self, device, dtype, op):
499*da0073e9SAndroid Build Coastguard Worker        # Currently torch.equal isn't supported with torch.complex32
500*da0073e9SAndroid Build Coastguard Worker        # There's also errors with complex64 and complex128
501*da0073e9SAndroid Build Coastguard Worker        if (dtype == torch.complex32):
502*da0073e9SAndroid Build Coastguard Worker            return
503*da0073e9SAndroid Build Coastguard Worker        for sample in op.sample_inputs(device, dtype, requires_grad=False):
504*da0073e9SAndroid Build Coastguard Worker            with SchemaCheckMode():
505*da0073e9SAndroid Build Coastguard Worker                op(sample.input, *sample.args, **sample.kwargs)
506*da0073e9SAndroid Build Coastguard Worker
# Generate per-device copies (CPU and CUDA only) of the generic OpInfo test
# class and register them in this module's namespace.
instantiate_device_type_tests(
    TestSchemaCheckModeOpInfo,
    globals(),
    only_for=("cpu", "cuda"),
)

if __name__ == "__main__":
    run_tests()
511