xref: /aosp_15_r20/external/pytorch/test/test_meta.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: decompositions"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport itertools
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerimport os
6*da0073e9SAndroid Build Coastguard Workerimport numpy as np
7*da0073e9SAndroid Build Coastguard Workerfrom enum import Enum
8*da0073e9SAndroid Build Coastguard Workerfrom torch.overrides import resolve_name
9*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
10*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
11*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
12*da0073e9SAndroid Build Coastguard Workerimport torch.utils._python_dispatch
13*da0073e9SAndroid Build Coastguard Workerfrom torch._dispatch.python import enable_python_dispatcher
14*da0073e9SAndroid Build Coastguard Workerfrom torch._ops import OpOverload, OpOverloadPacket
15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import unMarkDynamoStrictTest
17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
18*da0073e9SAndroid Build Coastguard Worker    TestCase,
19*da0073e9SAndroid Build Coastguard Worker    skipIfCrossRef,
20*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
21*da0073e9SAndroid Build Coastguard Worker    suppress_warnings,
22*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ASAN,
23*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TORCHDYNAMO,
24*da0073e9SAndroid Build Coastguard Worker    run_tests,
25*da0073e9SAndroid Build Coastguard Worker    dtype_abbrs,
26*da0073e9SAndroid Build Coastguard Worker    parametrize
27*da0073e9SAndroid Build Coastguard Worker)
28*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
29*da0073e9SAndroid Build Coastguard Worker    ops,
30*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
31*da0073e9SAndroid Build Coastguard Worker    onlyCUDA,
32*da0073e9SAndroid Build Coastguard Worker    onlyCPU,
33*da0073e9SAndroid Build Coastguard Worker    OpDTypes,
34*da0073e9SAndroid Build Coastguard Worker)
35*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import (
36*da0073e9SAndroid Build Coastguard Worker    binary_ufuncs, op_db, foreach_unary_op_db, foreach_binary_op_db,
37*da0073e9SAndroid Build Coastguard Worker    foreach_pointwise_op_db, foreach_reduce_op_db, foreach_other_op_db)
38*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.opinfo.core import S, SampleInput
39*da0073e9SAndroid Build Coastguard Workerfrom torchgen.yaml_utils import YamlLoader
40*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import OperatorName
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Workerimport copy
43*da0073e9SAndroid Build Coastguard Workerimport sys
44*da0073e9SAndroid Build Coastguard Workerimport yaml
45*da0073e9SAndroid Build Coastguard Workerimport atexit
46*da0073e9SAndroid Build Coastguard Workerimport re
47*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict
48*da0073e9SAndroid Build Coastguard Workerfrom collections.abc import Iterable
49*da0073e9SAndroid Build Coastguard Workerimport unittest
50*da0073e9SAndroid Build Coastguard Workerimport warnings
51*da0073e9SAndroid Build Coastguard Workerimport weakref
52*da0073e9SAndroid Build Coastguard Workerfrom functools import partial, wraps
53*da0073e9SAndroid Build Coastguard Worker
# Short dtype aliases, matching the abbreviations produced by
# dtype_abbrs (imported above); they keep per-op expected-dtype
# tables in this file compact.
bf16 = torch.bfloat16
f64 = torch.float64
f32 = torch.float32
f16 = torch.float16
c32 = torch.complex32
c64 = torch.complex64
c128 = torch.complex128
i8 = torch.int8
i16 = torch.int16
i32 = torch.int32
i64 = torch.int64
b8 = torch.bool
u8 = torch.uint8
u16 = torch.uint16
u32 = torch.uint32
u64 = torch.uint64
70*da0073e9SAndroid Build Coastguard Worker
# All foreach OpInfo entries collected into one database so the
# foreach tests below can iterate over a single list.
foreach_op_db = (
    foreach_unary_op_db +
    foreach_binary_op_db +
    foreach_pointwise_op_db +
    foreach_reduce_op_db +
    foreach_other_op_db
)
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker
class TestMetaConverter(TestCase):
    """Tests that MetaConverter reproduces tensor metadata faithfully:
    shape/stride/offset, requires_grad, leaf-ness, view relationships,
    and the weakref-based memoization of converted tensors/storages."""

    def assertSameVersionCounter(self, m1, m2):
        """Assert m1 and m2 share a storage, observed indirectly: mutating
        m1's base must bump both tensors' version counters in lockstep."""
        # Cannot easily test m1 and m2 have same storage due to
        # lack of Storage bindings.  Use version counter.
        vc = m1._version
        self.assertEqual(m2._version, vc)
        # Doing it this way ensures that we get VC bump even with leaves
        with torch.no_grad():
            m1._base.add_(3)
        self.assertNotEqual(m1._version, vc)
        self.assertEqual(m2._version, m1._version)

    def assertMetadataMatches(self, m1, m2):
        """Assert meta tensor m1 carries the same metadata as real tensor m2."""
        assert_metadata_eq(self.assertEqual, m1, m2)

    def test_view_of_non_leaf(self):
        # Two views of the same non-leaf must convert to distinct meta
        # tensors that still share one (non-leaf) base.
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        z1 = y[:]
        z2 = y[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)

        # check the test is actually testing what it claims
        self.assertTrue(m1._is_view())
        self.assertFalse(m1._base.is_leaf)

        self.assertIsNot(m1, m2)
        self.assertMetadataMatches(m1, z1)
        self.assertMetadataMatches(m2, z2)
        self.assertSameVersionCounter(m1, m2)

    def test_view_of_leaf(self):
        # Same as above, but the base is a leaf requiring grad.
        x = torch.randn(4, requires_grad=True)
        z1 = x[:]
        z2 = x[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)

        # check the test is actually testing what it claims
        self.assertTrue(m1._is_view())
        self.assertTrue(m1._base.is_leaf)

        self.assertIsNot(m1, m2)
        self.assertMetadataMatches(m1, z1)
        self.assertMetadataMatches(m2, z2)
        self.assertSameVersionCounter(m1, m2)

    def test_view_of_view_of_leaf(self):
        # View chain: x (leaf, no grad) -> y (leaf, requires_grad set
        # after the fact) -> z (non-leaf view).
        x = torch.randn(8)
        y = x.view(2, 4)
        y.requires_grad = True
        z = y.view(2, 2, 2)

        to_meta = MetaConverter()
        mx = to_meta(x)
        mz = to_meta(z)

        self.assertFalse(z.is_leaf)

        self.assertMetadataMatches(mx, x)
        self.assertMetadataMatches(mz, z)

    def test_leaf(self):
        x = torch.randn(4, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.is_leaf)
        self.assertTrue(m.requires_grad)

        self.assertMetadataMatches(m, x)

    def test_non_leaf(self):
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        to_meta = MetaConverter()
        m = to_meta(y)

        # check the test is actually testing what it claims
        self.assertFalse(m.is_leaf)
        self.assertTrue(m.requires_grad)

        self.assertMetadataMatches(m, y)

    def test_requires_grad_false(self):
        x = torch.randn(4, requires_grad=False)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertFalse(m.requires_grad)

        self.assertMetadataMatches(m, x)

    def test_channels_last(self):
        # Non-contiguous (channels_last) strides must survive conversion.
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_channels_last_leaf(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_channels_last_non_leaf(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
        y = x + 2

        # sanity
        self.assertEqual(x.stride(), y.stride())
        self.assertFalse(y.is_leaf)

        to_meta = MetaConverter()
        m = to_meta(y)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertFalse(m.is_leaf)

        self.assertMetadataMatches(m, y)

        # Check that we can autograd with m as input without erroring;
        # see https://github.com/pytorch/pytorch/issues/87956
        loss = m.sum()
        torch.autograd.grad(loss, m)

    def test_empty_strided_non_dense_leaf(self):
        # Strides (4, 2) over a (2, 2) tensor leave gaps in storage;
        # the converter must preserve such non-dense layouts.
        x = torch.empty_strided((2, 2), (4, 2), requires_grad=True)

        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_view_mutate(self):
        # In-place mutation through a converted view must not error.
        x = torch.zeros(4)
        y = x.view(2, 2)

        to_meta = MetaConverter()
        m = to_meta(y)

        y.add_(torch.randn(2, 2, requires_grad=True))
        m.add_(torch.randn(2, 2, device='meta', requires_grad=True))

    def test_non_leaf_torture(self):
        # A leaf whose storage/sizes/strides were rebound via set_()
        # under no_grad; exercises the converter's set_ handling.
        x = torch.empty(20, requires_grad=True)
        with torch.no_grad():
            x.set_(x.storage(), 10, (2,), (2,))

        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    # NB: complex stuff is not actually exercised right now because
    # we have a blanket exclusion for complex conversion

    def test_view_as_real(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = torch.view_as_real(x)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_complex_noncontiguous_bug(self):
        x = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :]
        m = MetaConverter()(x)
        self.assertMetadataMatches(m, x)

    def test_view_as_complex(self):
        x = torch.randn((4, 2), dtype=torch.float32)
        y = torch.view_as_complex(x)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_view_dtype(self):
        x = torch.randn(4, dtype=torch.float32)
        y = x.view(dtype=torch.int32)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_imag(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = x.imag
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_inplace_set_storage(self):
        # set_() on a tensor must not resize the donor storage.
        # NOTE(review): upstream constructs this tensor with
        # device='meta'; here it is a CPU tensor — confirm which
        # variant this revision intends.
        x = torch.tensor([0, 1], dtype=torch.int64)
        storage = x.untyped_storage()
        ssize = storage.size()
        meta = torch.empty((), dtype=torch.int64)
        meta.set_(storage, 0, (), ())
        self.assertEqual(storage.size(), ssize)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_weakref(self):
        # Memoization: converting the same tensor twice returns the
        # same meta tensor, and all memo tables are weak — entries
        # vanish when the real tensors / fake tensors are collected.
        x = torch.randn(4, 4, 4)
        m = MetaConverter()
        y = m(x)
        z = m(x)
        self.assertIs(y, z)
        self.assertEqual(len(m.tensor_memo), 1)
        self.assertEqual(len(m.storage_memo), 1)
        self.assertEqual(len(m.describer.lookup_tensor), 1)
        self.assertEqual(len(m.describer.lookup_storage), 1)
        del x
        # Entries from Tensor -> int get deallocated when the real tensor
        # disappears...
        self.assertEqual(len(m.describer.lookup_tensor), 0)
        self.assertEqual(len(m.describer.lookup_storage), 0)
        del y
        del z
        # ... but the int -> FakeTensor entries don't die until the fake
        # tensors themselves die (because the user may have held onto the
        # int key and are expecting to get a consistent fake tensor in
        # this case)
        self.assertEqual(len(m.tensor_memo), 0)
        self.assertEqual(len(m.storage_memo), 0)
        li = []
        r = []
        for i in range(4):
            li.append(torch.rand([i]))
            r.append(m(li[-1]))
        self.assertEqual(len(m.tensor_memo), 4)
        self.assertEqual(len(m.storage_memo), 4)
        self.assertEqual(len(m.describer.lookup_tensor), 4)
        self.assertEqual(len(m.describer.lookup_storage), 4)
        del li
        self.assertEqual(len(m.describer.lookup_tensor), 0)
        self.assertEqual(len(m.describer.lookup_storage), 0)
        del r
        self.assertEqual(len(m.tensor_memo), 0)
        self.assertEqual(len(m.storage_memo), 0)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_tensor_outlives_converter(self):
        # A converted tensor must not keep the converter alive.
        m = MetaConverter()
        ref = weakref.ref(m)
        x = torch.randn([4, 4])
        y = m(x)
        del m
        self.assertIs(ref(), None)
345*da0073e9SAndroid Build Coastguard Worker
aten = torch.ops.aten

# Explicit override tables consulted by should_check_strides (defined
# later in this file), in priority order: exact-stride checks, then
# significant-stride checks, then skips.

# Ops whose outputs must match *significant* strides.
CHECK_STRIDES = {
    torch.Tensor.__getitem__,
}

# Ops whose outputs must match strides exactly (all dims).
CHECK_ALL_STRIDES = {
    aten.unsqueeze.default
}

# Ops for which stride comparison is skipped entirely.
CHECK_STRIDES_SKIPS = {
    aten._conj_physical.default,
    aten._fft_c2c.default,
    aten._fft_c2r.default,
    aten._fft_r2c.default,
    aten._linalg_svd.default,
    aten.binary_cross_entropy.default,
    aten.complex.default,
    aten.polar.default,
    aten.copysign.Tensor,
    aten.div.Tensor_mode,
    aten.floor_divide.default,
    aten.heaviside.default,
    aten.lerp.Scalar,
    aten.lerp.Tensor,
    aten.logaddexp.default,
    aten.logical_and.default,
    aten.logical_or.default,
    aten.logical_xor.default,
    aten.pow.Scalar,
    aten.prelu.default,
    aten.special_xlog1py.default,
    aten.xlogy.Tensor,
    aten.nll_loss2d_forward.default,

    # channel_last and channel_last_3d related failures
    aten.convolution.default,

    # following ops fails if include_storage_offset = True, but these are a bit edge casey
    # we should still fix them, leaving them here for tracking.
    # aten._reshape_alias.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_matmul_cuda_float32
    # aten.view.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_unflatten_cuda_float32
}

# Ops for which the conjugate bit is not compared.
CHECK_CONJ_SKIPS = {
    # The conj bit is not copied, see:
    # https://github.com/pytorch/pytorch/pull/101836
    aten.linalg_lu_solve.out,
}
395*da0073e9SAndroid Build Coastguard Worker
class CheckStrides(Enum):
    """How strictly meta-output strides are compared to real-output strides."""
    NONE = 0         # do not compare strides
    SIGNIFICANT = 1  # compare strides only for dims where they matter
    ALL = 2          # compare strides for every dim
400*da0073e9SAndroid Build Coastguard Worker
def should_check_strides(func):
    """Return the CheckStrides level to use when comparing ``func``'s
    meta outputs to its real outputs.

    Explicit per-op overrides take priority (exact, then significant,
    then skip); otherwise non-OpOverload callables are skipped, prims
    and everything else default to significant-stride checking.
    """
    # Per-op overrides, consulted in priority order.
    overrides = (
        (CHECK_ALL_STRIDES, CheckStrides.ALL),
        (CHECK_STRIDES, CheckStrides.SIGNIFICANT),
        (CHECK_STRIDES_SKIPS, CheckStrides.NONE),
    )
    for table, verdict in overrides:
        if func in table:
            return verdict
    if not isinstance(func, torch._ops.OpOverload):
        return CheckStrides.NONE
    # Prims are expected to model strides correctly
    if func.namespace == "prims":
        return CheckStrides.SIGNIFICANT
    # Check if it's a view, by testing if any of the returns have
    # a non-empty alias set
    is_view = any(
        ret.alias_info.before_set
        for ret in func._schema.returns
        if ret.alias_info
    )
    if is_view:
        return CheckStrides.SIGNIFICANT
    # TODO: check for TensorIterator
    return CheckStrides.SIGNIFICANT
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Workerdef assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
421*da0073e9SAndroid Build Coastguard Worker    flat_meta_rs = pytree.tree_leaves(meta_rs)
422*da0073e9SAndroid Build Coastguard Worker    flat_rs = pytree.tree_leaves(rs)
423*da0073e9SAndroid Build Coastguard Worker    test_case.assertEqual(len(flat_meta_rs), len(flat_rs))
424*da0073e9SAndroid Build Coastguard Worker    for i, meta_r, r in zip(range(len(flat_rs)), flat_meta_rs, flat_rs):
425*da0073e9SAndroid Build Coastguard Worker        def test_assert(cond, msg):
426*da0073e9SAndroid Build Coastguard Worker            if not cond:
427*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(f"output {i}: {msg_callable(msg)}")
428*da0073e9SAndroid Build Coastguard Worker        if not isinstance(r, torch.Tensor):
429*da0073e9SAndroid Build Coastguard Worker            continue
430*da0073e9SAndroid Build Coastguard Worker        test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor")
431*da0073e9SAndroid Build Coastguard Worker        test_assert(meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
432*da0073e9SAndroid Build Coastguard Worker        test_assert(meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
433*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/78050
434*da0073e9SAndroid Build Coastguard Worker        if should_check_strides(func) == CheckStrides.ALL:
435*da0073e9SAndroid Build Coastguard Worker            same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
436*da0073e9SAndroid Build Coastguard Worker            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
437*da0073e9SAndroid Build Coastguard Worker        elif should_check_strides(func) == CheckStrides.SIGNIFICANT:
438*da0073e9SAndroid Build Coastguard Worker            same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
439*da0073e9SAndroid Build Coastguard Worker            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
440*da0073e9SAndroid Build Coastguard Worker        test_assert(
441*da0073e9SAndroid Build Coastguard Worker            meta_r.storage_offset() == r.storage_offset(),
442*da0073e9SAndroid Build Coastguard Worker            f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
443*da0073e9SAndroid Build Coastguard Worker        test_assert(meta_r.requires_grad == r.requires_grad,
444*da0073e9SAndroid Build Coastguard Worker                    f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
445*da0073e9SAndroid Build Coastguard Worker        if func not in CHECK_CONJ_SKIPS:
446*da0073e9SAndroid Build Coastguard Worker            test_assert(meta_r.is_conj() == r.is_conj(),
447*da0073e9SAndroid Build Coastguard Worker                        f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
448*da0073e9SAndroid Build Coastguard Worker        test_assert(meta_r.is_neg() == r.is_neg(), f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker# This environment variable controls whether or not we print expected failure
452*da0073e9SAndroid Build Coastguard Worker# lists at the end of a test suite run.  The intended usage looks like this:
453*da0073e9SAndroid Build Coastguard Worker#
454*da0073e9SAndroid Build Coastguard Worker# 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_meta.py` on a CUDA build
455*da0073e9SAndroid Build Coastguard Worker#    of PyTorch that has LAPACK/MAGMA installed.  You can filter `-k test_meta`
456*da0073e9SAndroid Build Coastguard Worker#    or `-k test_dispatch_meta` to only focus on one or another list
457*da0073e9SAndroid Build Coastguard Worker# 2. Given the printed skip/xfail list, add them to the corresponding lists;
458*da0073e9SAndroid Build Coastguard Worker#    torch.* entries go in meta_function and aten.* entries go in meta_dispatch.
459*da0073e9SAndroid Build Coastguard Worker#    If there are preexisting entries, you need to merge in the entries.
460*da0073e9SAndroid Build Coastguard Worker#
461*da0073e9SAndroid Build Coastguard Worker# This is somewhat manual but typically you shouldn't need to do this, unless
462*da0073e9SAndroid Build Coastguard Worker# you've made a major change (e.g., added a new dtype to PyTorch) and need to
463*da0073e9SAndroid Build Coastguard Worker# refresh the lists.  If you want to do it from scratch, just clear out the
464*da0073e9SAndroid Build Coastguard Worker# preexisting lists before running.
465*da0073e9SAndroid Build Coastguard Worker#
466*da0073e9SAndroid Build Coastguard Worker# WARNING: Python dict literals will silently ignore duplicate keys
467*da0073e9SAndroid Build Coastguard WorkerCOLLECT_EXPECT = os.getenv('PYTORCH_COLLECT_EXPECT', '0') == '1'
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Workerseen_succeeded = {}
470*da0073e9SAndroid Build Coastguard Workerseen_failed = {}
471*da0073e9SAndroid Build Coastguard Workerfailed_reasons = defaultdict(set)
472*da0073e9SAndroid Build Coastguard Workerdef print_seen():
473*da0073e9SAndroid Build Coastguard Worker    expected_failures = []
474*da0073e9SAndroid Build Coastguard Worker    skips = []
475*da0073e9SAndroid Build Coastguard Worker
476*da0073e9SAndroid Build Coastguard Worker    def fmt_dtypes(dtypes):
477*da0073e9SAndroid Build Coastguard Worker        r = ', '.join(sorted(dtype_abbrs[d] for d in dtypes))
478*da0073e9SAndroid Build Coastguard Worker        return '{' + r + '}'
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker    for op, failed_dtypes in seen_failed.items():
481*da0073e9SAndroid Build Coastguard Worker        ops = resolve_name(op)
482*da0073e9SAndroid Build Coastguard Worker        succeeded_dtypes = seen_succeeded.get(op, set())
483*da0073e9SAndroid Build Coastguard Worker        expected_failures_dtypes = failed_dtypes - succeeded_dtypes
484*da0073e9SAndroid Build Coastguard Worker        skips_dtypes = failed_dtypes & succeeded_dtypes
485*da0073e9SAndroid Build Coastguard Worker        reasons = ""
486*da0073e9SAndroid Build Coastguard Worker        if failed_reasons[op]:
487*da0073e9SAndroid Build Coastguard Worker            reasons = "  # " + ", ".join(sorted(failed_reasons[op]))
488*da0073e9SAndroid Build Coastguard Worker        if expected_failures_dtypes:
489*da0073e9SAndroid Build Coastguard Worker            expected_failures.append(f"    {ops}: {fmt_dtypes(expected_failures_dtypes)},{reasons}")
490*da0073e9SAndroid Build Coastguard Worker        if skips_dtypes:
491*da0073e9SAndroid Build Coastguard Worker            skips.append(f"    {ops}: {fmt_dtypes(skips_dtypes)},")
492*da0073e9SAndroid Build Coastguard Worker    expected_failures.sort()
493*da0073e9SAndroid Build Coastguard Worker    skips.sort()
494*da0073e9SAndroid Build Coastguard Worker    nl = '\n'
495*da0073e9SAndroid Build Coastguard Worker    print(f"""\
496*da0073e9SAndroid Build Coastguard Workerexpected_failures = {{
497*da0073e9SAndroid Build Coastguard Worker{nl.join(expected_failures)}
498*da0073e9SAndroid Build Coastguard Worker}}
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Workerskips = {{
501*da0073e9SAndroid Build Coastguard Worker{nl.join(skips)}
502*da0073e9SAndroid Build Coastguard Worker}}
503*da0073e9SAndroid Build Coastguard Worker""")
504*da0073e9SAndroid Build Coastguard Workerif COLLECT_EXPECT:
505*da0073e9SAndroid Build Coastguard Worker    atexit.register(print_seen)
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker# Success forces pass; failure forces fail; skip unconditionally skips testing
508*da0073e9SAndroid Build Coastguard WorkerTestExpect = Enum("TestExpect", ("SUCCESS", "XFAILURE", "SKIP"))
509*da0073e9SAndroid Build Coastguard Worker
510*da0073e9SAndroid Build Coastguard Worker# unlike print produce strides
511*da0073e9SAndroid Build Coastguard Workerdef verbose_print(e):
512*da0073e9SAndroid Build Coastguard Worker    class Lit:
513*da0073e9SAndroid Build Coastguard Worker        def __init__(self, s):
514*da0073e9SAndroid Build Coastguard Worker            self.s = s
515*da0073e9SAndroid Build Coastguard Worker
516*da0073e9SAndroid Build Coastguard Worker        def __repr__(self):
517*da0073e9SAndroid Build Coastguard Worker            return self.s
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Worker    def go(t):
520*da0073e9SAndroid Build Coastguard Worker        if is_sparse_any(t):
521*da0073e9SAndroid Build Coastguard Worker            return t
522*da0073e9SAndroid Build Coastguard Worker        elif isinstance(t, torch.Tensor):
523*da0073e9SAndroid Build Coastguard Worker            return Lit(f"{t} stride={t.stride()}")
524*da0073e9SAndroid Build Coastguard Worker        else:
525*da0073e9SAndroid Build Coastguard Worker            return t
526*da0073e9SAndroid Build Coastguard Worker
527*da0073e9SAndroid Build Coastguard Worker    return repr(tree_map(go, e))
528*da0073e9SAndroid Build Coastguard Worker
529*da0073e9SAndroid Build Coastguard Workerdef run_meta_crossref(
530*da0073e9SAndroid Build Coastguard Worker    test_case,
531*da0073e9SAndroid Build Coastguard Worker    test_expect,
532*da0073e9SAndroid Build Coastguard Worker    func,
533*da0073e9SAndroid Build Coastguard Worker    args,
534*da0073e9SAndroid Build Coastguard Worker    kwargs,
535*da0073e9SAndroid Build Coastguard Worker    *,
536*da0073e9SAndroid Build Coastguard Worker    dtype,
537*da0073e9SAndroid Build Coastguard Worker    device_type,
538*da0073e9SAndroid Build Coastguard Worker    run_symbolic_meta: bool
539*da0073e9SAndroid Build Coastguard Worker):
540*da0073e9SAndroid Build Coastguard Worker    to_meta = MetaConverter()
541*da0073e9SAndroid Build Coastguard Worker    do_meta = test_expect is not TestExpect.SKIP
542*da0073e9SAndroid Build Coastguard Worker    if do_meta:
543*da0073e9SAndroid Build Coastguard Worker        try:
544*da0073e9SAndroid Build Coastguard Worker            meta_args = tree_map(to_meta, args)
545*da0073e9SAndroid Build Coastguard Worker            meta_kwargs = tree_map(to_meta, kwargs)
546*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
547*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
548*da0073e9SAndroid Build Coastguard Worker                f"failed to convert args to meta; "
549*da0073e9SAndroid Build Coastguard Worker                f"originally (*{args}, **{kwargs})") from e
550*da0073e9SAndroid Build Coastguard Worker    try:
551*da0073e9SAndroid Build Coastguard Worker        rs = func(*args, **kwargs)
552*da0073e9SAndroid Build Coastguard Worker    except Exception as e:
553*da0073e9SAndroid Build Coastguard Worker        raise AssertionError("Original OpInfo is broken") from e
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker    # TODO: also handle cases where func raise an exception
556*da0073e9SAndroid Build Coastguard Worker
557*da0073e9SAndroid Build Coastguard Worker    # For now, only attempt if we managed to convert all tensor types
558*da0073e9SAndroid Build Coastguard Worker    # (if any of them failed, we're in a mixed device situation and
559*da0073e9SAndroid Build Coastguard Worker    # this isn't well supported)
560*da0073e9SAndroid Build Coastguard Worker    if do_meta and to_meta.successful():
561*da0073e9SAndroid Build Coastguard Worker        # Special cases
562*da0073e9SAndroid Build Coastguard Worker        if func is torch.tensor_split:
563*da0073e9SAndroid Build Coastguard Worker            # Use original indices_or_sections, this argument is data dependent
564*da0073e9SAndroid Build Coastguard Worker            meta_args = (meta_args[0], args[1]) + meta_args[2:]
565*da0073e9SAndroid Build Coastguard Worker        elif func is torch.Tensor.__getitem__:
566*da0073e9SAndroid Build Coastguard Worker            # Ensure boolean tensors use original
567*da0073e9SAndroid Build Coastguard Worker            assert len(args) == 2
568*da0073e9SAndroid Build Coastguard Worker            flat_args = pytree.tree_leaves(args[1])
569*da0073e9SAndroid Build Coastguard Worker            flat_meta_args, spec = tree_flatten(meta_args[1])
570*da0073e9SAndroid Build Coastguard Worker            flat_new_args = []
571*da0073e9SAndroid Build Coastguard Worker            for a, ma in zip(flat_args, flat_meta_args):
572*da0073e9SAndroid Build Coastguard Worker                flat_new_args.append(a if isinstance(a, torch.Tensor) and a.dtype in [torch.int8, torch.bool] else ma)
573*da0073e9SAndroid Build Coastguard Worker            meta_args = (meta_args[0], tree_unflatten(flat_new_args, spec))
574*da0073e9SAndroid Build Coastguard Worker        elif func in (torch.ops.aten.repeat_interleave.Tensor, torch.ops.aten.repeat_interleave.Tensor_out):
575*da0073e9SAndroid Build Coastguard Worker            if kwargs.get("output_size", None) is None:
576*da0073e9SAndroid Build Coastguard Worker                meta_args = args
577*da0073e9SAndroid Build Coastguard Worker            if func is torch.ops.aten.repeat_interleave.Tensor_out:
578*da0073e9SAndroid Build Coastguard Worker                meta_kwargs["out"] = kwargs["out"]
579*da0073e9SAndroid Build Coastguard Worker        elif func in (torch.ops.aten.index.Tensor, torch.ops.aten.index.Tensor_out):
580*da0073e9SAndroid Build Coastguard Worker            # Don't convert boolean tensors to meta as they will have nonzero
581*da0073e9SAndroid Build Coastguard Worker            # called on them
582*da0073e9SAndroid Build Coastguard Worker            indices = []
583*da0073e9SAndroid Build Coastguard Worker            for meta_index, real_index in zip(meta_args[1], args[1]):
584*da0073e9SAndroid Build Coastguard Worker                if meta_index is not None and meta_index.dtype in [torch.int8, torch.bool]:
585*da0073e9SAndroid Build Coastguard Worker                    indices.append(real_index)
586*da0073e9SAndroid Build Coastguard Worker                else:
587*da0073e9SAndroid Build Coastguard Worker                    indices.append(meta_index)
588*da0073e9SAndroid Build Coastguard Worker            meta_args = (meta_args[0], indices)
589*da0073e9SAndroid Build Coastguard Worker        elif func is torch.nn.functional.ctc_loss and all([isinstance(args[2], list), isinstance(args[3], list)]):
590*da0073e9SAndroid Build Coastguard Worker            # torch.ops.aten._ctc_loss.IntList has a meta kernel but
591*da0073e9SAndroid Build Coastguard Worker            # torch.ops.aten._ctc_loss.Tensor does not
592*da0073e9SAndroid Build Coastguard Worker            test_expect = TestExpect.SUCCESS
593*da0073e9SAndroid Build Coastguard Worker
594*da0073e9SAndroid Build Coastguard Worker        if kwargs.get("device", None) is not None:
595*da0073e9SAndroid Build Coastguard Worker            meta_kwargs["device"] = "meta"
596*da0073e9SAndroid Build Coastguard Worker
597*da0073e9SAndroid Build Coastguard Worker        try:
598*da0073e9SAndroid Build Coastguard Worker            # Suppress warnings, this doesn't matter for test_meta.py
599*da0073e9SAndroid Build Coastguard Worker            # but it does matter if you want to use this decorator
600*da0073e9SAndroid Build Coastguard Worker            # for cross-ref testing, as some tests may be looking at
601*da0073e9SAndroid Build Coastguard Worker            # errors
602*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings():
603*da0073e9SAndroid Build Coastguard Worker                warnings.simplefilter("ignore")
604*da0073e9SAndroid Build Coastguard Worker                if run_symbolic_meta:
605*da0073e9SAndroid Build Coastguard Worker                    # Run the decomps and meta kernels registered
606*da0073e9SAndroid Build Coastguard Worker                    # to the python dispatcher instead of the regular dispatcher.
607*da0073e9SAndroid Build Coastguard Worker                    # This should be the same set of kernels
608*da0073e9SAndroid Build Coastguard Worker                    # that fake tensor runs in dynamic shapes mode.
609*da0073e9SAndroid Build Coastguard Worker                    with enable_python_dispatcher():
610*da0073e9SAndroid Build Coastguard Worker                        meta_rs = func(*meta_args, **meta_kwargs)
611*da0073e9SAndroid Build Coastguard Worker                else:
612*da0073e9SAndroid Build Coastguard Worker                    meta_rs = func(*meta_args, **meta_kwargs)
613*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
614*da0073e9SAndroid Build Coastguard Worker            if test_expect is TestExpect.XFAILURE:
615*da0073e9SAndroid Build Coastguard Worker                return rs
616*da0073e9SAndroid Build Coastguard Worker            seen_failed.setdefault(func, set()).add(dtype)
617*da0073e9SAndroid Build Coastguard Worker            if isinstance(e, NotImplementedError):
618*da0073e9SAndroid Build Coastguard Worker                m = RE_NOT_IMPLEMENTED_MSG.search(e.args[0])
619*da0073e9SAndroid Build Coastguard Worker                if m:
620*da0073e9SAndroid Build Coastguard Worker                    failed_reasons[func].add(m.group(1))
621*da0073e9SAndroid Build Coastguard Worker            if COLLECT_EXPECT:
622*da0073e9SAndroid Build Coastguard Worker                return rs
623*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f"""\
624*da0073e9SAndroid Build Coastguard Workerfailed to run: {resolve_name(func)}(
625*da0073e9SAndroid Build Coastguard Worker*{verbose_print(meta_args)},
626*da0073e9SAndroid Build Coastguard Worker**{verbose_print(meta_kwargs)}
627*da0073e9SAndroid Build Coastguard Worker)""") from e
628*da0073e9SAndroid Build Coastguard Worker        else:
629*da0073e9SAndroid Build Coastguard Worker            try:
630*da0073e9SAndroid Build Coastguard Worker                delim = ',\n  '
631*da0073e9SAndroid Build Coastguard Worker                assert_ref_meta_equal(test_case, func, meta_rs, rs, lambda msg: f"""\
632*da0073e9SAndroid Build Coastguard Workermeta disagrees with real impl:
633*da0073e9SAndroid Build Coastguard Worker{resolve_name(func)}(
634*da0073e9SAndroid Build Coastguard Worker  {delim.join(map(verbose_print, meta_args))},
635*da0073e9SAndroid Build Coastguard Worker  {delim.join(k + ": " + verbose_print(v) for k, v in meta_kwargs.items())}
636*da0073e9SAndroid Build Coastguard Worker) = (
637*da0073e9SAndroid Build Coastguard Worker  {verbose_print(meta_rs)}
638*da0073e9SAndroid Build Coastguard Worker)
639*da0073e9SAndroid Build Coastguard Worker{msg}
640*da0073e9SAndroid Build Coastguard Worker""")
641*da0073e9SAndroid Build Coastguard Worker            except Exception:
642*da0073e9SAndroid Build Coastguard Worker                if test_expect is TestExpect.XFAILURE:
643*da0073e9SAndroid Build Coastguard Worker                    return rs
644*da0073e9SAndroid Build Coastguard Worker                seen_failed.setdefault(func, set()).add(dtype)
645*da0073e9SAndroid Build Coastguard Worker                if COLLECT_EXPECT:
646*da0073e9SAndroid Build Coastguard Worker                    return rs
647*da0073e9SAndroid Build Coastguard Worker                raise
648*da0073e9SAndroid Build Coastguard Worker            else:
649*da0073e9SAndroid Build Coastguard Worker                seen_succeeded.setdefault(func, set()).add(dtype)
650*da0073e9SAndroid Build Coastguard Worker                if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
651*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(f"unexpected success {resolve_name(func)} {meta_args} {meta_kwargs}")
652*da0073e9SAndroid Build Coastguard Worker
653*da0073e9SAndroid Build Coastguard Worker    return rs
654*da0073e9SAndroid Build Coastguard Worker
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker
657*da0073e9SAndroid Build Coastguard WorkerRE_NOT_IMPLEMENTED_MSG = re.compile(r"Could not run '([^']+)' with arguments ")
658*da0073e9SAndroid Build Coastguard Worker
659*da0073e9SAndroid Build Coastguard Workermeta_function_expected_failures = {
660*da0073e9SAndroid Build Coastguard Worker    torch.Tensor.to_sparse : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
661*da0073e9SAndroid Build Coastguard Worker    torch.allclose : {f64, f16, c128, c64, bf16, f32},
662*da0073e9SAndroid Build Coastguard Worker    torch.argwhere : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
663*da0073e9SAndroid Build Coastguard Worker    torch.combinations : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
664*da0073e9SAndroid Build Coastguard Worker    torch.corrcoef : {f64, i32, c128, i64, i16, u8, c64, bf16, f16, i8, f32},
665*da0073e9SAndroid Build Coastguard Worker    torch.cov : {f64, i32, c128, i64, i16, u8, c64, bf16, i8, f32, f16},
666*da0073e9SAndroid Build Coastguard Worker    torch.functional.istft : {f64, c64, c128, f32},
667*da0073e9SAndroid Build Coastguard Worker    torch.geqrf : {f64, c64, c128, f32},
668*da0073e9SAndroid Build Coastguard Worker    torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
669*da0073e9SAndroid Build Coastguard Worker    torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
670*da0073e9SAndroid Build Coastguard Worker    torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
671*da0073e9SAndroid Build Coastguard Worker    torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
672*da0073e9SAndroid Build Coastguard Worker    torch.bincount : {i32, i64, u8, i16, i8},
673*da0073e9SAndroid Build Coastguard Worker    torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32, u16, u32, u64},
674*da0073e9SAndroid Build Coastguard Worker    torch.functional.unique_consecutive : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32, u16, u32, u64},
675*da0073e9SAndroid Build Coastguard Worker    torch.histogram : {f64, f32},
676*da0073e9SAndroid Build Coastguard Worker    torch.histogramdd : {f64, f32},
677*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.ctc_loss : {f64, f32},
678*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
679*da0073e9SAndroid Build Coastguard Worker    torch.linalg.lstsq : {f64, f32, c128, c64},
680*da0073e9SAndroid Build Coastguard Worker}
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Workermeta_function_expected_failures_conditional = {
683*da0073e9SAndroid Build Coastguard Worker    torch.repeat_interleave : (lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats", None), int)),
684*da0073e9SAndroid Build Coastguard Worker}
685*da0073e9SAndroid Build Coastguard Worker
686*da0073e9SAndroid Build Coastguard Worker"""
687*da0073e9SAndroid Build Coastguard Worker# This is some sample code for how we could dump these dicts into YAML
688*da0073e9SAndroid Build Coastguard Worker# file for easier reading/writing
689*da0073e9SAndroid Build Coastguard Workerimport yaml
690*da0073e9SAndroid Build Coastguard Workerprint(yaml.dump(
691*da0073e9SAndroid Build Coastguard Worker  {resolve_name(k): [dtype_abbrs[d] for d in v]
692*da0073e9SAndroid Build Coastguard Worker   for k, v in meta_function_expected_failures.items()}, default_flow_style=None))
693*da0073e9SAndroid Build Coastguard Workerimport sys
694*da0073e9SAndroid Build Coastguard Workersys.exit()
695*da0073e9SAndroid Build Coastguard Worker"""
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Workermeta_function_skips = {
698*da0073e9SAndroid Build Coastguard Worker    torch.Tensor.__rmatmul__ : {bf16, c128, f64, f32, f16, c64},
699*da0073e9SAndroid Build Coastguard Worker    torch.Tensor.matmul : {f64, f32, c128, c64},
700*da0073e9SAndroid Build Coastguard Worker    torch.functional.atleast_2d : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
701*da0073e9SAndroid Build Coastguard Worker    torch.functional.atleast_3d : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
702*da0073e9SAndroid Build Coastguard Worker    torch.functional.cartesian_prod : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
703*da0073e9SAndroid Build Coastguard Worker    torch.functional.einsum : {bf16, c128, f64, f32, f16, c64},
704*da0073e9SAndroid Build Coastguard Worker    torch.inner : {f16, bf16, i8, i64, u8, c128, f64, i16, f32, i32, c64},
705*da0073e9SAndroid Build Coastguard Worker    torch.linalg.matrix_norm : {c128, f32, c64, f64},
706*da0073e9SAndroid Build Coastguard Worker    torch.linalg.matrix_rank : {c128, c64},
707*da0073e9SAndroid Build Coastguard Worker    torch.linalg.svd : {c128, c64},
708*da0073e9SAndroid Build Coastguard Worker    torch.matmul : {bf16, c128, f64, f32, f16, c64},
709*da0073e9SAndroid Build Coastguard Worker    torch.nanquantile : {f64, f32},
710*da0073e9SAndroid Build Coastguard Worker    torch.narrow : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c32, c64},
711*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.batch_norm : {f64, f32},
712*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.binary_cross_entropy : {bf16, f64, f32, f16},
713*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.dropout3d : {bf16, f64, f32, f16},
714*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.local_response_norm : {bf16, f64, f32, f16},
715*da0073e9SAndroid Build Coastguard Worker    torch.svd : {c128, c64},
716*da0073e9SAndroid Build Coastguard Worker    torch.take_along_dim : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
717*da0073e9SAndroid Build Coastguard Worker    torch.vstack : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
718*da0073e9SAndroid Build Coastguard Worker    torch.diff : {b8},
719*da0073e9SAndroid Build Coastguard Worker    torch.equal : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
720*da0073e9SAndroid Build Coastguard Worker    torch.nanmean : {bf16, f64, f32, f16, c32, c64, c128},
721*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.cross_entropy : {bf16, f64, f32},
722*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.nll_loss : {bf16, f64, f32},
723*da0073e9SAndroid Build Coastguard Worker    torch.linalg.cond : {c128, c64, f32, f64},
724*da0073e9SAndroid Build Coastguard Worker    torch.linalg.vecdot : {bf16, f64, f32, f16},
725*da0073e9SAndroid Build Coastguard Worker    torch.empty : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
726*da0073e9SAndroid Build Coastguard Worker    torch.Tensor.addbmm_: {bf16, c128, c64, f32, f64, i16, i32, i64, i8, u8},
727*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.one_hot : {i64},
728*da0073e9SAndroid Build Coastguard Worker}
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Workermeta_function_device_expected_failures = defaultdict(dict)
732*da0073e9SAndroid Build Coastguard Workermeta_function_device_expected_failures_only_outplace = defaultdict(dict)
733*da0073e9SAndroid Build Coastguard Workermeta_function_device_skips = defaultdict(dict)
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Workermeta_function_device_expected_failures['cpu'] = {
736*da0073e9SAndroid Build Coastguard Worker    # TODO: The decomps for these batch norm ops return different dtypes depending
737*da0073e9SAndroid Build Coastguard Worker    # on the device. We should make this work better with meta tensors.
738*da0073e9SAndroid Build Coastguard Worker    torch.native_batch_norm: {bf16, f16},
739*da0073e9SAndroid Build Coastguard Worker    torch._native_batch_norm_legit: {bf16, f16},
740*da0073e9SAndroid Build Coastguard Worker    torch.ops.aten._batch_norm_with_update: {bf16, f16},
741*da0073e9SAndroid Build Coastguard Worker    torch.native_layer_norm: {bf16, f16},
742*da0073e9SAndroid Build Coastguard Worker}
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Workermeta_function_device_expected_failures['cuda'] = {
745*da0073e9SAndroid Build Coastguard Worker    torch.corrcoef: {bf16, f16},  # aten::_local_scalar_dense
746*da0073e9SAndroid Build Coastguard Worker    torch.cov: {f16},  # aten::_local_scalar_dense
747*da0073e9SAndroid Build Coastguard Worker    torch.functional.unique: {f16},  # aten::_unique2, aten::unique_dim
748*da0073e9SAndroid Build Coastguard Worker    torch.functional.unique_consecutive: {f16},  # aten::unique_consecutive
749*da0073e9SAndroid Build Coastguard Worker    torch.geqrf: {f32, f64},  # aten::geqrf
750*da0073e9SAndroid Build Coastguard Worker}
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Workermeta_function_device_skips['cpu'] = {
753*da0073e9SAndroid Build Coastguard Worker    # TODO: The decomps for these batch norm ops return different dtypes depending
754*da0073e9SAndroid Build Coastguard Worker    # on the device. We should make this work better with meta tensors.
755*da0073e9SAndroid Build Coastguard Worker    torch.native_batch_norm: {f32, f64},
756*da0073e9SAndroid Build Coastguard Worker    torch._native_batch_norm_legit: {f32, f64},
757*da0073e9SAndroid Build Coastguard Worker    torch.ops.aten._batch_norm_with_update: {f32, f64},
758*da0073e9SAndroid Build Coastguard Worker}
759*da0073e9SAndroid Build Coastguard Worker
760*da0073e9SAndroid Build Coastguard Workermeta_function_device_skips['cuda'] = {
761*da0073e9SAndroid Build Coastguard Worker    torch.inner: {f16},
762*da0073e9SAndroid Build Coastguard Worker    torch.linalg.matrix_rank: {f32, f64},
763*da0073e9SAndroid Build Coastguard Worker    torch.linalg.svd: {f32, f64},
764*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.cross_entropy: {f16},
765*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.interpolate: {f16},
766*da0073e9SAndroid Build Coastguard Worker    torch.nn.functional.nll_loss: {f16},
767*da0073e9SAndroid Build Coastguard Worker    torch.svd: {f32, f64},
768*da0073e9SAndroid Build Coastguard Worker}
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker# This is a __torch_function__ mode that, when enabled, interposes every
771*da0073e9SAndroid Build Coastguard Worker# Torch API call and runs the operator as normal, and then reruns it
772*da0073e9SAndroid Build Coastguard Worker# with meta inputs, and then checks that everything about the output agrees.
773*da0073e9SAndroid Build Coastguard Worker# Most of the logic deals with faithfully replicating the original tensor
774*da0073e9SAndroid Build Coastguard Worker# as a meta tensor, which is nontrivial because there are a lot of subsystems
775*da0073e9SAndroid Build Coastguard Worker# that may potentially be exercised.
776*da0073e9SAndroid Build Coastguard Worker#
777*da0073e9SAndroid Build Coastguard Worker# That being said, this class is a little overkill for what it is doing in
778*da0073e9SAndroid Build Coastguard Worker# this test file (since I could have just inlined __torch_function__ on the
779*da0073e9SAndroid Build Coastguard Worker# OpInfo call, and OpInfos generally have very regular inputs), but it will be
780*da0073e9SAndroid Build Coastguard Worker# useful for more comprehensive testing e.g., as seen in
781*da0073e9SAndroid Build Coastguard Worker# https://github.com/pytorch/pytorch/pull/75994  The big benefit is it is
782*da0073e9SAndroid Build Coastguard Worker# A LOT more efficient that torch dispatch mode (at the cost of less coverage)
783*da0073e9SAndroid Build Coastguard Workerclass MetaCrossRefFunctionMode(torch.overrides.TorchFunctionMode):
784*da0073e9SAndroid Build Coastguard Worker    test_case: TestCase
785*da0073e9SAndroid Build Coastguard Worker    device_type: str
786*da0073e9SAndroid Build Coastguard Worker    dtype: torch.dtype
787*da0073e9SAndroid Build Coastguard Worker
788*da0073e9SAndroid Build Coastguard Worker    def __init__(self, test_case, *, device, dtype, inplace):
789*da0073e9SAndroid Build Coastguard Worker        self.test_case = test_case
790*da0073e9SAndroid Build Coastguard Worker        self.device_type = torch.device(device).type
791*da0073e9SAndroid Build Coastguard Worker        self.dtype = dtype
792*da0073e9SAndroid Build Coastguard Worker        self.inplace = inplace
793*da0073e9SAndroid Build Coastguard Worker
794*da0073e9SAndroid Build Coastguard Worker    def __torch_function__(self, func, types, args=(), kwargs=None):
795*da0073e9SAndroid Build Coastguard Worker        kwargs = kwargs or {}
796*da0073e9SAndroid Build Coastguard Worker
797*da0073e9SAndroid Build Coastguard Worker        if (
798*da0073e9SAndroid Build Coastguard Worker            torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod) or
799*da0073e9SAndroid Build Coastguard Worker            # meta converter doesn't work correctly when no_dispatch() is on, so
800*da0073e9SAndroid Build Coastguard Worker            # skip running the crossref test in this case
801*da0073e9SAndroid Build Coastguard Worker            torch._C._dispatch_tls_local_exclude_set().has(torch._C.DispatchKey.Python)
802*da0073e9SAndroid Build Coastguard Worker        ):
803*da0073e9SAndroid Build Coastguard Worker            return func(*args, **kwargs)
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker        if self.dtype in meta_function_skips.get(func, set()):
806*da0073e9SAndroid Build Coastguard Worker            test_expect = TestExpect.SKIP
807*da0073e9SAndroid Build Coastguard Worker        elif self.dtype in meta_function_device_skips[self.device_type].get(func, set()):
808*da0073e9SAndroid Build Coastguard Worker            test_expect = TestExpect.SKIP
809*da0073e9SAndroid Build Coastguard Worker        elif self.dtype in meta_function_expected_failures.get(func, set()):
810*da0073e9SAndroid Build Coastguard Worker            test_expect = TestExpect.XFAILURE
811*da0073e9SAndroid Build Coastguard Worker        elif self.dtype in meta_function_device_expected_failures[self.device_type].get(func, set()):
812*da0073e9SAndroid Build Coastguard Worker            test_expect = TestExpect.XFAILURE
813*da0073e9SAndroid Build Coastguard Worker        elif meta_function_expected_failures_conditional.get(func, lambda *_, **__: False)(self.dtype, *args, **kwargs):
814*da0073e9SAndroid Build Coastguard Worker            test_expect = TestExpect.XFAILURE
815*da0073e9SAndroid Build Coastguard Worker        elif not self.inplace and \
816*da0073e9SAndroid Build Coastguard Worker                self.dtype in meta_function_device_expected_failures_only_outplace[self.device_type].get(func, set()):
817*da0073e9SAndroid Build Coastguard Worker            test_expect = TestExpect.XFAILURE
818*da0073e9SAndroid Build Coastguard Worker        else:
819*da0073e9SAndroid Build Coastguard Worker            test_expect = TestExpect.SUCCESS
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker        return run_meta_crossref(
822*da0073e9SAndroid Build Coastguard Worker            self.test_case, test_expect, func, args,
823*da0073e9SAndroid Build Coastguard Worker            kwargs, dtype=self.dtype, device_type=self.device_type, run_symbolic_meta=False
824*da0073e9SAndroid Build Coastguard Worker        )
825*da0073e9SAndroid Build Coastguard Worker
# these always fail
# Table consumed by MetaCrossRefDispatchMode._get_expected_test_result:
# maps an aten OpOverload to the set of dtypes for which the dispatch-level
# meta crossref is expected to fail (TestExpect.XFAILURE) on every device.
meta_dispatch_expected_failures = {
    aten.allclose.default: {f16, bf16, f32, f64, c64, c128},  # NotImplementedError: 'aten::_local_scalar_dense'
    aten.geqrf.default : {c64, c128, f64, f32},
    aten.linalg_lstsq.default : {c64, c128, f64, f32},
    aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
    aten.nonzero.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
    aten._to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._ctc_loss.Tensor : {f32, f64},  # Shape of second output depends on data.
    aten._histogramdd_bin_edges.default : {f32, f64},
    aten._histogramdd_from_bin_cts.default : {f32, f64},
    aten._histogramdd_from_bin_tensors.default : {f32, f64},
    aten._local_scalar_dense.default : {c32, c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._unique2.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.bincount.default : {i64, i8, i32, i16, u8},
    aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.histogram.bin_ct : {f32, f64},
    aten.histogram.bins_tensor : {f32, f64},
    aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.upsample_nearest3d.vec : {bf16, f32, f64, u8},

}
852*da0073e9SAndroid Build Coastguard Worker
# these sometimes pass and sometimes fail
# Nondeterministic outcomes: _get_expected_test_result maps these to
# TestExpect.SKIP so the crossref is not run at all for the listed dtypes.
meta_dispatch_skips = {
    aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128},  # at::nonzero doesn't have a Meta function
    aten._to_copy.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128},
    aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
    aten.addbmm_.default: {bf16, c128, c64, f32, f64, i16, i32, i64, i8, u8},
}
860*da0073e9SAndroid Build Coastguard Worker
# For CompositeImplicitAutograd functions that fail before hitting the Mode.
# NOTE: a plain set literal — the previous set({...}) built a set only to
# copy it into another set (redundant wrapper, flake8-comprehensions C405).
meta_dispatch_early_skips = {
    torch.Tensor.float_power_,
    # Errors out in one of the tests, while ProxyTensor passes...
    torch.Tensor.cumprod_,
    torch.Tensor.cumsum_,
}
868*da0073e9SAndroid Build Coastguard Worker
# Inplace variants to skip entirely.
# NOTE: a plain set literal — the previous set({...}) built a set only to
# copy it into another set (redundant wrapper, flake8-comprehensions C405).
meta_inplace_skips = {
    # Errors out in one of the tests, while ProxyTensor passes...
    torch.Tensor.cumprod_,
    torch.Tensor.cumsum_,
}
874*da0073e9SAndroid Build Coastguard Worker
# Per-device tables: device type string ("cpu"/"cuda") -> {OpOverload: set of
# dtypes}. Looked up by MetaCrossRefDispatchMode._get_expected_test_result
# through self.device_type; defaultdict(dict) makes unknown device types
# resolve to an empty table.
meta_dispatch_device_expected_failures = defaultdict(dict)
meta_dispatch_device_skips = defaultdict(dict)

meta_dispatch_device_expected_failures['cpu'] = {
    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    aten.native_batch_norm.default: {bf16, f16},
    aten._native_batch_norm_legit.default: {bf16, f16},
    aten._native_batch_norm_legit.no_stats: {bf16, f16},
    aten._batch_norm_with_update.default: {bf16, f16},

    aten.native_layer_norm.default: {bf16, f16},
}
888*da0073e9SAndroid Build Coastguard Worker
# Dispatch-level expected failures that only apply when running on CUDA.
meta_dispatch_device_expected_failures['cuda'] = {
    aten._unique2.default: {f16},  # aten::_unique2
    aten._use_cudnn_ctc_loss.default: {f32, f64},  # aten::_use_cudnn_ctc_loss
    aten._use_cudnn_ctc_loss.Tensor: {f32, f64},  # aten::_use_cudnn_ctc_loss.Tensor
    aten.cudnn_grid_sampler.default: {f16, f32, f64},  # aten::cudnn_grid_sampler
    aten.geqrf.default: {f32, f64},  # aten::geqrf
    aten.linalg_eigvalsh.out: {f32, f64},  # aten::linalg_eigvalsh.out
    aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
    aten.log_sigmoid_forward.output : {bf16, f16, f64, f32},  # aten::log_sigmoid_forward.output
    aten.unique_consecutive.default: {f16},  # aten::unique_consecutive
    aten.unique_dim.default: {f16},  # aten::unique_dim
    aten.upsample_nearest3d.vec: {f16},  # aten::upsample_nearest3d.vec
}
902*da0073e9SAndroid Build Coastguard Worker
# Dispatch-level skips (crossref not run) that only apply on CPU.
meta_dispatch_device_skips['cpu'] = {
    aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},

    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    aten.native_batch_norm.default: {f32, f64},
    aten._native_batch_norm_legit.default: {f32, f64},
    aten._native_batch_norm_legit.no_stats: {f32, f64},
    aten._batch_norm_with_update.default: {f32, f64},

    # If the computation dtype is different from the input
    # dtype this will fail. CPU execution may also have
    # a different output from other devices.
    aten.native_batch_norm.out: {bf16, f16, f32, f64}
}
918*da0073e9SAndroid Build Coastguard Worker
# Dispatch-level skips (crossref not run) that only apply on CUDA.
# FIX: aten.softmax.int appeared twice with the same value; in a dict literal
# the duplicate key silently overwrites the first — the duplicate is removed.
meta_dispatch_device_skips['cuda'] = {
    aten._conj.default: {c32, f16},  # file issue
    aten._linalg_svd.default: {c64, c128},  # aten::linalg_eigvalsh.out
    aten.cudnn_batch_norm.default: {f32, f64},
    aten.log_softmax.int : {c32, c64},
    aten.softmax.int : {c32, c64},

    # ROCm stuff; technically this should be expected failure but it's
    # not worth it; these should get unified anyway
    aten.miopen_batch_norm.default: {f32},
}
931*da0073e9SAndroid Build Coastguard Worker
def get_strided_args(args):
    """Yield tuples of arguments where each dense, contiguous tensor argument
    is replaced by equal-valued variants with different striding.

    Variants per eligible tensor, in order: the original (contiguous),
    transposed (ndim > 1), non-dense/gapped (ndim > 0), channels-last
    (ndim == 4), channels-last-3d (ndim == 5), and optionally a copy living
    at storage offset 1. Non-tensor, sparse-CSR, and non-contiguous
    arguments are passed through unchanged. Yields the cartesian product of
    all per-argument variant lists.
    """

    def make_variants(src, include_storage_offset=False):
        # original contiguous tensor
        out = [src]

        # transposed: allocate a reversed-shape buffer, permute it back into
        # the original shape, and fill it — values match, strides don't
        if src.ndim > 1:
            order = list(reversed(range(src.ndim)))
            flipped = torch.empty(
                src.shape[::-1], device=src.device, dtype=src.dtype, requires_grad=src.requires_grad
            )
            out.append(flipped.permute(order).copy_(src))

        # nondense: double the last dim then view every other element,
        # leaving gaps in the underlying storage
        if src.ndim > 0:
            out.append(torch.repeat_interleave(src, 2, dim=-1)[..., ::2])

        # memory-format variants only exist at specific ranks
        if src.ndim == 4:
            out.append(src.contiguous(memory_format=torch.channels_last))
        if src.ndim == 5:
            out.append(src.contiguous(memory_format=torch.channels_last_3d))

        # same shape/strides but the data starts one element into the buffer
        if include_storage_offset:
            storage = torch.empty(src.numel() + 1, device=src.device, dtype=src.dtype, requires_grad=src.requires_grad)
            shifted = storage.as_strided(src.shape, src.stride(), storage_offset=1)
            shifted.copy_(src)
            out.append(shifted)

        return out

    per_arg_choices = [
        make_variants(a)
        if isinstance(a, torch.Tensor) and not a.is_sparse_csr and a.is_contiguous()
        else [a]
        for a in args
    ]
    yield from itertools.product(*per_arg_choices)
979*da0073e9SAndroid Build Coastguard Worker
class MetaCrossRefDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
    """Torch dispatch mode that cross-checks each dispatched op against its
    meta implementation via run_meta_crossref.

    Unlike MetaCrossRefFunctionMode, this interposes at the dispatcher level
    and, when the functional op succeeded, additionally tries to resolve and
    exercise a matching out= overload (see __torch_dispatch__).
    """
    test_case: TestCase
    device: torch.device
    dtype: torch.dtype
    # Class-level (shared across instances) cache of OpOverloadPackets known
    # to have no out= overload, so try_resolve_aten_out_overload can bail
    # out early on repeat lookups.
    aten_olp_no_out_overload: set = set()

    def __init__(self, test_case, *, device, dtype, symbolic_meta: bool, inplace: bool, supports_out: bool):
        self.test_case = test_case
        # save TLS
        self.precision = test_case.precision
        self.rel_tol = test_case.rel_tol
        self.device_type = torch.device(device).type
        self.dtype = dtype
        # whether run_meta_crossref should run the symbolic meta path
        self.symbolic_meta = symbolic_meta
        self.inplace = inplace
        self.supports_out = supports_out

    @staticmethod
    def try_resolve_aten_out_overload(ol, args, kwargs, num_outputs):
        """Given functional overload ``ol`` called with (args, kwargs), find a
        compatible out= overload in the same overload packet.

        Returns ``(out_overload, out_param_names, new_kwargs)`` on success —
        ``new_kwargs`` are the kwargs to call the out overload with, not yet
        including the out tensors — or ``(None, None, None)`` when no
        compatible out= overload exists.
        """

        ol_args = ol._schema.arguments
        olp: OpOverloadPacket = ol._overloadpacket

        # Fast path: this packet was already determined to have no out= overload.
        if olp in MetaCrossRefDispatchMode.aten_olp_no_out_overload:
            return (None, None, None)

        # Collect every overload in the packet whose schema declares an out arg.
        candidate_ols = []
        for candidate_ol_name in olp.overloads():
            candidate_ol = getattr(olp, candidate_ol_name)
            if any(arg.is_out for arg in candidate_ol._schema.arguments):
                candidate_ols.append(candidate_ol)

        if not candidate_ols:
            # Remember the negative result in the shared cache.
            MetaCrossRefDispatchMode.aten_olp_no_out_overload.add(olp)
            return (None, None, None)

        # Now match based on args, kwargs and number of required outputs
        candidate_ol: OpOverload = None
        for candidate_ol in candidate_ols:
            candidate_ol_args = candidate_ol._schema.arguments

            # Candidate must declare more arguments than were supplied
            # positionally (at minimum, its trailing out argument(s)).
            if (len(args) >= len(candidate_ol_args)):
                continue

            # Positional arguments must have the same type
            if not all(
                ol_args[pos_arg_ind].type == candidate_ol_args[pos_arg_ind].type
                for pos_arg_ind in range(len(args))
            ):
                continue

            # Number of outputs must match
            # (i.e. the last num_outputs schema arguments are all out args)
            candidate_out_names = [out_arg.name for out_arg in candidate_ol_args[-num_outputs:] if out_arg.is_out]
            if len(candidate_out_names) != num_outputs:
                continue

            # Now try and match kwargs. Just need to ensure that the
            # remaining kwargs allow an out overload to be called. For example
            # we can throw away parameters like `dtype` that may be passed to the
            # functional version of the op since the `dtype` will already be present
            # in the `out` argument
            new_kwargs = {}
            kwargs_match = True
            for arg in candidate_ol_args[len(args):-num_outputs]:
                if arg.name not in kwargs:
                    if arg.has_default_value():
                        new_kwargs[arg.name] = arg.default_value
                    elif isinstance(arg.type, torch.OptionalType):
                        # Optional arg with no default: synthesize a neutral
                        # value (False for bool?, None otherwise).
                        if isinstance(arg.type.getElementType(), torch.BoolType):
                            new_kwargs[arg.name] = False
                        else:
                            new_kwargs[arg.name] = None
                    else:
                        # Required argument we cannot synthesize — this
                        # candidate cannot be called; try the next one.
                        kwargs_match = False
                        break
                else:
                    new_kwargs[arg.name] = kwargs[arg.name]

            if kwargs_match:
                return candidate_ol, candidate_out_names, new_kwargs

        return None, None, None

    def _get_expected_test_result(self, func: OpOverload):
        """Look up whether the crossref for ``func`` at the current
        dtype/device should be skipped, is expected to fail, or must succeed.
        Skip tables take precedence over expected-failure tables."""
        if self.dtype in meta_dispatch_skips.get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_device_skips[self.device_type].get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_expected_failures.get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif self.dtype in meta_dispatch_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS
        return test_expect

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Cross-check ``func`` against meta, then (for successful functional
        ops without an out parameter) also exercise a resolvable out= aten
        overload using the real results as the out tensors."""
        kwargs = kwargs or {}
        # Restore the tolerance values captured in __init__ onto the test case.
        self.test_case.precision = self.precision
        self.test_case.rel_tol = self.rel_tol

        test_expect = self._get_expected_test_result(func)

        expected = run_meta_crossref(
            self.test_case,
            test_expect,
            func,
            args,
            kwargs,
            dtype=self.dtype,
            device_type=self.device_type,
            run_symbolic_meta=self.symbolic_meta,
        )

        # This is to test torch ops that do not have an out parameter but have
        # aten op overloads that have out parameters. Additionally, Python decompositions
        # may register OpOverloadPacket's so decompositions need to be tested
        # to ensure all OpOverloads still function for the Meta key (e.g. if a python decomposition
        # is registered for an aten op aten.foo with overloads [default, out], the python
        # function needs to support receiving `out` arguments)
        if (
            not self.inplace and
            not self.supports_out and
            test_expect == TestExpect.SUCCESS and
            (torch.is_tensor(expected) or isinstance(expected, Iterable))
        ):

            # check to see if there is a potential out overload
            num_outputs = 1 if torch.is_tensor(expected) else len(expected)
            func_out_overload, out_param_names, kwargs = self.try_resolve_aten_out_overload(func, args, kwargs, num_outputs)

            if func_out_overload:

                # Feed the real results back in as the out tensors.
                if num_outputs == 1:
                    kwargs[out_param_names[0]] = expected
                else:
                    for ind, out_param_name in enumerate(out_param_names):
                        kwargs[out_param_name] = expected[ind]

                # The out overload has its own entry in the skip/xfail tables.
                test_expect = self._get_expected_test_result(func_out_overload)

                run_meta_crossref(
                    self.test_case,
                    test_expect,
                    func_out_overload,
                    args,
                    kwargs,
                    dtype=self.dtype,
                    device_type=self.device_type,
                    run_symbolic_meta=self.symbolic_meta,
                )

        return expected
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker# NB: we're running these tests only on CUDA because there are some
1135*da0073e9SAndroid Build Coastguard Worker# inconsistencies between CUDA and CPU, and running on CUDA makes it easier
1136*da0073e9SAndroid Build Coastguard Worker# to ignore the CPU case when inconsistencies arise.  Ideally we deal
1137*da0073e9SAndroid Build Coastguard Worker# with the inconsistencies but this takes time.
1138*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest
1139*da0073e9SAndroid Build Coastguard Workerclass TestMeta(TestCase):
1140*da0073e9SAndroid Build Coastguard Worker    # Copies inputs to inplace operations to avoid inplace modifications
1141*da0073e9SAndroid Build Coastguard Worker    #   to leaves requiring gradient
1142*da0073e9SAndroid Build Coastguard Worker    def _get_safe_inplace(self, inplace_variant):
1143*da0073e9SAndroid Build Coastguard Worker        @wraps(inplace_variant)
1144*da0073e9SAndroid Build Coastguard Worker        def _fn(t, *args, **kwargs):
1145*da0073e9SAndroid Build Coastguard Worker            if isinstance(t, list):
1146*da0073e9SAndroid Build Coastguard Worker                return inplace_variant([x.clone() for x in t], *args, **kwargs)
1147*da0073e9SAndroid Build Coastguard Worker            else:
1148*da0073e9SAndroid Build Coastguard Worker                return inplace_variant(t.clone(), *args, **kwargs)
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker        return _fn
1151*da0073e9SAndroid Build Coastguard Worker
1152*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1153*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1154*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1155*da0073e9SAndroid Build Coastguard Worker    @ops(itertools.chain(op_db, foreach_op_db))
1156*da0073e9SAndroid Build Coastguard Worker    def test_meta_outplace(self, device, dtype, op):
1157*da0073e9SAndroid Build Coastguard Worker        if "_scaled_mm" in op.name:
1158*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("_scaled_mm dose not support meta device")
1159*da0073e9SAndroid Build Coastguard Worker        skip_op_names = (
1160*da0073e9SAndroid Build Coastguard Worker            "fft.ihfft",
1161*da0073e9SAndroid Build Coastguard Worker            "fft.ihfft2",
1162*da0073e9SAndroid Build Coastguard Worker            "linalg.lu_solve",
1163*da0073e9SAndroid Build Coastguard Worker        )
1164*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_TORCHDYNAMO and op.name in skip_op_names:
1165*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("flaky")
1166*da0073e9SAndroid Build Coastguard Worker        # run the OpInfo sample inputs, cross-referencing them with the
1167*da0073e9SAndroid Build Coastguard Worker        # meta implementation and check the results are the same.  All
1168*da0073e9SAndroid Build Coastguard Worker        # the heavy lifting happens in MetaCrossRefFunctionMode
1169*da0073e9SAndroid Build Coastguard Worker        func = op.get_op()
1170*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=False)
1171*da0073e9SAndroid Build Coastguard Worker        for sample_input in samples:
1172*da0073e9SAndroid Build Coastguard Worker            args = [sample_input.input] + list(sample_input.args)
1173*da0073e9SAndroid Build Coastguard Worker            kwargs = sample_input.kwargs
1174*da0073e9SAndroid Build Coastguard Worker            with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=False):
1175*da0073e9SAndroid Build Coastguard Worker                expected = func(*args, **kwargs)
1176*da0073e9SAndroid Build Coastguard Worker                if isinstance(expected, torch.Tensor) and op.supports_out:
1177*da0073e9SAndroid Build Coastguard Worker                    func(*args, **kwargs, out=expected)
1178*da0073e9SAndroid Build Coastguard Worker
1179*da0073e9SAndroid Build Coastguard Worker            # Special test for functions taking "device" kwarg
1180*da0073e9SAndroid Build Coastguard Worker            # The crossref tests that replacing the device with "meta" works
1181*da0073e9SAndroid Build Coastguard Worker            # This part makes sure that *_like functions work well with a "meta"
1182*da0073e9SAndroid Build Coastguard Worker            # Tensor and their original device argument.
1183*da0073e9SAndroid Build Coastguard Worker            if "device" in kwargs and "_like" in op.name:
1184*da0073e9SAndroid Build Coastguard Worker                with torch.random.fork_rng():
1185*da0073e9SAndroid Build Coastguard Worker                    torch.manual_seed(123)
1186*da0073e9SAndroid Build Coastguard Worker                    ref = func(*args, **kwargs)
1187*da0073e9SAndroid Build Coastguard Worker
1188*da0073e9SAndroid Build Coastguard Worker                # *_like functions take a Tensor as first argument
1189*da0073e9SAndroid Build Coastguard Worker                assert isinstance(args[0], torch.Tensor)
1190*da0073e9SAndroid Build Coastguard Worker                with torch.random.fork_rng():
1191*da0073e9SAndroid Build Coastguard Worker                    torch.manual_seed(123)
1192*da0073e9SAndroid Build Coastguard Worker                    args[0] = args[0].to(device="meta")
1193*da0073e9SAndroid Build Coastguard Worker                    meta = func(*args, **kwargs)
1194*da0073e9SAndroid Build Coastguard Worker
1195*da0073e9SAndroid Build Coastguard Worker                # empty_like is not deterministic
1196*da0073e9SAndroid Build Coastguard Worker                if op.name != "empty_like":
1197*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(ref, meta)
1198*da0073e9SAndroid Build Coastguard Worker
1199*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1200*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1201*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1202*da0073e9SAndroid Build Coastguard Worker    @ops(itertools.chain(op_db, foreach_op_db))
1203*da0073e9SAndroid Build Coastguard Worker    def test_meta_inplace(self, device, dtype, op):
1204*da0073e9SAndroid Build Coastguard Worker        func = op.get_inplace()
1205*da0073e9SAndroid Build Coastguard Worker        if not func:
1206*da0073e9SAndroid Build Coastguard Worker            self.skipTest("No inplace variable for this op")
1207*da0073e9SAndroid Build Coastguard Worker        if op.promotes_int_to_float and not dtype.is_floating_point:
1208*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
1209*da0073e9SAndroid Build Coastguard Worker        if func in meta_inplace_skips:
1210*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skipped")
1211*da0073e9SAndroid Build Coastguard Worker        func = self._get_safe_inplace(func)
1212*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=False)
1213*da0073e9SAndroid Build Coastguard Worker        for sample_input in samples:
1214*da0073e9SAndroid Build Coastguard Worker            if sample_input.broadcasts_input:
1215*da0073e9SAndroid Build Coastguard Worker                continue
1216*da0073e9SAndroid Build Coastguard Worker            args = [sample_input.input] + list(sample_input.args)
1217*da0073e9SAndroid Build Coastguard Worker            kwargs = sample_input.kwargs
1218*da0073e9SAndroid Build Coastguard Worker            with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=True):
1219*da0073e9SAndroid Build Coastguard Worker                expected = func(*args, **kwargs)
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker    def _run_dispatch_meta_test(self, device, dtype, op, symbolic_meta, inplace, all_stride_variants=False):
1222*da0073e9SAndroid Build Coastguard Worker        if "_scaled_mm" in op.name:
1223*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("_scaled_mm dose not support meta device")
1224*da0073e9SAndroid Build Coastguard Worker        if inplace:
1225*da0073e9SAndroid Build Coastguard Worker            func = op.get_inplace()
1226*da0073e9SAndroid Build Coastguard Worker            if not func:
1227*da0073e9SAndroid Build Coastguard Worker                self.skipTest("No inplace variable for this op")
1228*da0073e9SAndroid Build Coastguard Worker            if op.promotes_int_to_float and not dtype.is_floating_point:
1229*da0073e9SAndroid Build Coastguard Worker                self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
1230*da0073e9SAndroid Build Coastguard Worker        else:
1231*da0073e9SAndroid Build Coastguard Worker            func = op.get_op()
1232*da0073e9SAndroid Build Coastguard Worker
1233*da0073e9SAndroid Build Coastguard Worker        if func in meta_dispatch_early_skips:
1234*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Function is in dispatch early skips")
1235*da0073e9SAndroid Build Coastguard Worker
1236*da0073e9SAndroid Build Coastguard Worker        if inplace:
1237*da0073e9SAndroid Build Coastguard Worker            func = self._get_safe_inplace(func)
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=False)
1240*da0073e9SAndroid Build Coastguard Worker        for sample_input in samples:
1241*da0073e9SAndroid Build Coastguard Worker            if inplace and sample_input.broadcasts_input:
1242*da0073e9SAndroid Build Coastguard Worker                continue
1243*da0073e9SAndroid Build Coastguard Worker
1244*da0073e9SAndroid Build Coastguard Worker            sample_args = [sample_input.input] + list(sample_input.args)
1245*da0073e9SAndroid Build Coastguard Worker            kwargs = sample_input.kwargs
1246*da0073e9SAndroid Build Coastguard Worker
1247*da0073e9SAndroid Build Coastguard Worker            if all_stride_variants and sum(isinstance(arg, torch.Tensor) for arg in sample_args) <= 5:
1248*da0073e9SAndroid Build Coastguard Worker                # test inputs <= 5 tensors to avoid combinatorial explosion
1249*da0073e9SAndroid Build Coastguard Worker                strided_args = get_strided_args(sample_args)
1250*da0073e9SAndroid Build Coastguard Worker            else:
1251*da0073e9SAndroid Build Coastguard Worker                strided_args = [sample_args]
1252*da0073e9SAndroid Build Coastguard Worker
1253*da0073e9SAndroid Build Coastguard Worker            for args in strided_args:
1254*da0073e9SAndroid Build Coastguard Worker                with MetaCrossRefDispatchMode.push(
1255*da0073e9SAndroid Build Coastguard Worker                    self, dtype=dtype, device=device,
1256*da0073e9SAndroid Build Coastguard Worker                    symbolic_meta=symbolic_meta, inplace=inplace,
1257*da0073e9SAndroid Build Coastguard Worker                     supports_out=op.supports_out):
1258*da0073e9SAndroid Build Coastguard Worker                    expected = func(*args, **kwargs)
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker                    if not inplace and isinstance(expected, torch.Tensor) and op.supports_out:
1261*da0073e9SAndroid Build Coastguard Worker                        func(*args, **kwargs, out=expected)
1262*da0073e9SAndroid Build Coastguard Worker
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1265*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1266*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1267*da0073e9SAndroid Build Coastguard Worker    @ops(itertools.chain(op_db, foreach_op_db))
1268*da0073e9SAndroid Build Coastguard Worker    def test_dispatch_meta_outplace(self, device, dtype, op):
1269*da0073e9SAndroid Build Coastguard Worker        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=False, inplace=False)
1270*da0073e9SAndroid Build Coastguard Worker
1271*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1272*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1273*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1274*da0073e9SAndroid Build Coastguard Worker    @ops(itertools.chain(op_db, foreach_op_db))
1275*da0073e9SAndroid Build Coastguard Worker    def test_dispatch_meta_inplace(self, device, dtype, op):
1276*da0073e9SAndroid Build Coastguard Worker        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=False, inplace=True)
1277*da0073e9SAndroid Build Coastguard Worker
1278*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1279*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1280*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1281*da0073e9SAndroid Build Coastguard Worker    @ops(itertools.chain(op_db, foreach_op_db))
1282*da0073e9SAndroid Build Coastguard Worker    def test_dispatch_symbolic_meta_outplace(self, device, dtype, op):
1283*da0073e9SAndroid Build Coastguard Worker        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False)
1284*da0073e9SAndroid Build Coastguard Worker
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1287*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1288*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1289*da0073e9SAndroid Build Coastguard Worker    @ops(itertools.chain(op_db, foreach_op_db))
1290*da0073e9SAndroid Build Coastguard Worker    def test_dispatch_symbolic_meta_inplace(self, device, dtype, op):
1291*da0073e9SAndroid Build Coastguard Worker        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True)
1292*da0073e9SAndroid Build Coastguard Worker
1293*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1294*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1295*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1296*da0073e9SAndroid Build Coastguard Worker    # only test one dtype, as output stride behavior is the same for all dtypes
1297*da0073e9SAndroid Build Coastguard Worker    @ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
1298*da0073e9SAndroid Build Coastguard Worker    # Only test on CUDA, as CUDA kernel's stride is the reference
1299*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1300*da0073e9SAndroid Build Coastguard Worker    def test_dispatch_symbolic_meta_outplace_all_strides(self, device, dtype, op):
1301*da0073e9SAndroid Build Coastguard Worker        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False, all_stride_variants=True)
1302*da0073e9SAndroid Build Coastguard Worker
1303*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1304*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1305*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1306*da0073e9SAndroid Build Coastguard Worker    # only test one dtype, as output stride behavior is the same for all dtypes
1307*da0073e9SAndroid Build Coastguard Worker    @ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
1308*da0073e9SAndroid Build Coastguard Worker    # Only test on CUDA, as CUDA kernel's stride is the reference
1309*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1310*da0073e9SAndroid Build Coastguard Worker    def test_dispatch_symbolic_meta_inplace_all_strides(self, device, dtype, op):
1311*da0073e9SAndroid Build Coastguard Worker        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True, all_stride_variants=True)
1312*da0073e9SAndroid Build Coastguard Worker
1313*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1314*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1315*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1316*da0073e9SAndroid Build Coastguard Worker    # only test one dtype, as output stride behavior is the same for all dtypes
1317*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
1318*da0073e9SAndroid Build Coastguard Worker    # Only test on CUDA, as CUDA kernel's stride is the reference
1319*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1320*da0073e9SAndroid Build Coastguard Worker    def test_binary_ufuncs_mixed_dtype(self, device, dtype, op):
1321*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(
1322*da0073e9SAndroid Build Coastguard Worker            make_tensor,
1323*da0073e9SAndroid Build Coastguard Worker            device=device,
1324*da0073e9SAndroid Build Coastguard Worker        )
1325*da0073e9SAndroid Build Coastguard Worker
1326*da0073e9SAndroid Build Coastguard Worker        def sample_input(op, device, dtype, requires_grad, **kwargs):
1327*da0073e9SAndroid Build Coastguard Worker            yield SampleInput(
1328*da0073e9SAndroid Build Coastguard Worker                make_arg((S,), dtype=dtype), make_arg((S,), dtype=torch.float16)
1329*da0073e9SAndroid Build Coastguard Worker            )
1330*da0073e9SAndroid Build Coastguard Worker
1331*da0073e9SAndroid Build Coastguard Worker        op = copy.copy(op)
1332*da0073e9SAndroid Build Coastguard Worker        op.sample_inputs_func = sample_input
1333*da0073e9SAndroid Build Coastguard Worker
1334*da0073e9SAndroid Build Coastguard Worker        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False)
1335*da0073e9SAndroid Build Coastguard Worker
1336*da0073e9SAndroid Build Coastguard Worker
1337*da0073e9SAndroid Build Coastguard Worker    def test_empty_quantized(self):
1338*da0073e9SAndroid Build Coastguard Worker        r = torch.empty(2 ** 52, device='meta', dtype=torch.qint8)
1339*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.device.type, 'meta')
1340*da0073e9SAndroid Build Coastguard Worker
1341*da0073e9SAndroid Build Coastguard Worker    def test_nan_to_num(self):
1342*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14], device='meta')
1343*da0073e9SAndroid Build Coastguard Worker        r = t.nan_to_num()
1344*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.device.type, 'meta')
1345*da0073e9SAndroid Build Coastguard Worker
1346*da0073e9SAndroid Build Coastguard Worker    def test_inplace_masked_fill_error(self):
1347*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(3, 3, device='meta')
1348*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
1349*da0073e9SAndroid Build Coastguard Worker            t.masked_fill_((t > 0).unsqueeze(0), 0.1)
1350*da0073e9SAndroid Build Coastguard Worker
1351*da0073e9SAndroid Build Coastguard Worker    def test_inplace_bin_ops_error(self):
1352*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(3, 3, device='meta')
1353*da0073e9SAndroid Build Coastguard Worker        for op in (torch.Tensor.add_, torch.Tensor.sub_, torch.Tensor.mul_, torch.Tensor.div_,
1354*da0073e9SAndroid Build Coastguard Worker                   torch.Tensor.logical_and_, torch.Tensor.logical_or_, torch.Tensor.logical_xor_):
1355*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
1356*da0073e9SAndroid Build Coastguard Worker                op(t, t.clone().unsqueeze(0))
1357*da0073e9SAndroid Build Coastguard Worker
1358*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1359*da0073e9SAndroid Build Coastguard Worker    def test_meta_autograd_no_error(self):
1360*da0073e9SAndroid Build Coastguard Worker        with torch.library._scoped_library("meta_test", "DEF") as lib:
1361*da0073e9SAndroid Build Coastguard Worker            with torch.library._scoped_library("meta_test", "IMPL", "CPU") as impl_cpu:
1362*da0073e9SAndroid Build Coastguard Worker                with torch.library._scoped_library("meta_test", "IMPL", "Meta") as impl_meta:
1363*da0073e9SAndroid Build Coastguard Worker                    def foo_impl(x):
1364*da0073e9SAndroid Build Coastguard Worker                        return x + 1
1365*da0073e9SAndroid Build Coastguard Worker
1366*da0073e9SAndroid Build Coastguard Worker                    lib.define("foo(Tensor a) -> Tensor")
1367*da0073e9SAndroid Build Coastguard Worker                    impl_meta.impl("foo", foo_impl)
1368*da0073e9SAndroid Build Coastguard Worker                    impl_cpu.impl("foo", foo_impl)
1369*da0073e9SAndroid Build Coastguard Worker
1370*da0073e9SAndroid Build Coastguard Worker                    a = torch.ones(2, device='meta')
1371*da0073e9SAndroid Build Coastguard Worker                    # The point of the test is that this should not error:
1372*da0073e9SAndroid Build Coastguard Worker                    # We have a fallthrough kernel registered to the AutogradMeta
1373*da0073e9SAndroid Build Coastguard Worker                    # key for custom ops, so it's fine that `foo()` doesn't have
1374*da0073e9SAndroid Build Coastguard Worker                    # an autograd kernel.
1375*da0073e9SAndroid Build Coastguard Worker                    b = torch.ops.meta_test.foo.default(a)
1376*da0073e9SAndroid Build Coastguard Worker
1377*da0073e9SAndroid Build Coastguard Worker    def test_huber_loss_backward(self):
1378*da0073e9SAndroid Build Coastguard Worker        inps = [torch.rand(2**52, device='meta') for _ in range(3)]
1379*da0073e9SAndroid Build Coastguard Worker        r = torch.ops.aten.huber_loss_backward(*inps, 0, 1.0)
1380*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.device.type, 'meta')
1381*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.shape, inps[0].shape)
1382*da0073e9SAndroid Build Coastguard Worker
1383*da0073e9SAndroid Build Coastguard Worker    def _norm_backwards_test_helper(self, op, args, output_mask, expected_shapes):
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
1386*da0073e9SAndroid Build Coastguard Worker        device = "meta"
1387*da0073e9SAndroid Build Coastguard Worker
1388*da0073e9SAndroid Build Coastguard Worker        # test functional call
1389*da0073e9SAndroid Build Coastguard Worker        grads = op(*args, output_mask)
1390*da0073e9SAndroid Build Coastguard Worker
1391*da0073e9SAndroid Build Coastguard Worker        def assertEqualShapes(res, exp):
1392*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(res) if exp is None else self.assertEqual(exp, res.shape)
1393*da0073e9SAndroid Build Coastguard Worker
1394*da0073e9SAndroid Build Coastguard Worker        assertEqualShapes(grads[0], expected_shapes[0])
1395*da0073e9SAndroid Build Coastguard Worker        assertEqualShapes(grads[1], expected_shapes[1])
1396*da0073e9SAndroid Build Coastguard Worker        assertEqualShapes(grads[2], expected_shapes[2])
1397*da0073e9SAndroid Build Coastguard Worker
1398*da0073e9SAndroid Build Coastguard Worker        out_kwargs = {
1399*da0073e9SAndroid Build Coastguard Worker            f"out{i}": torch.empty(0, device=device, dtype=dtype)
1400*da0073e9SAndroid Build Coastguard Worker            for i in range(len(output_mask))
1401*da0073e9SAndroid Build Coastguard Worker        }
1402*da0073e9SAndroid Build Coastguard Worker
1403*da0073e9SAndroid Build Coastguard Worker        # test call with out parameters
1404*da0073e9SAndroid Build Coastguard Worker        grads = op(*args, output_mask, **out_kwargs)
1405*da0073e9SAndroid Build Coastguard Worker
1406*da0073e9SAndroid Build Coastguard Worker        def assertEqualShapes(res, exp):
1407*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(exp, res.shape) if exp is not None else True
1408*da0073e9SAndroid Build Coastguard Worker
1409*da0073e9SAndroid Build Coastguard Worker        assertEqualShapes(out_kwargs["out0"], expected_shapes[0])
1410*da0073e9SAndroid Build Coastguard Worker        assertEqualShapes(out_kwargs["out1"], expected_shapes[1])
1411*da0073e9SAndroid Build Coastguard Worker        assertEqualShapes(out_kwargs["out2"], expected_shapes[2])
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1414*da0073e9SAndroid Build Coastguard Worker    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
1415*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_backward(self, output_mask):
1416*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_methods_invocations import sample_inputs_layer_norm
1417*da0073e9SAndroid Build Coastguard Worker
1418*da0073e9SAndroid Build Coastguard Worker        device = "meta"
1419*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
1420*da0073e9SAndroid Build Coastguard Worker
1421*da0073e9SAndroid Build Coastguard Worker        samples = sample_inputs_layer_norm(None, device, dtype, requires_grad=False)
1422*da0073e9SAndroid Build Coastguard Worker
1423*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
1424*da0073e9SAndroid Build Coastguard Worker            with self.subTest(sample=sample):
1425*da0073e9SAndroid Build Coastguard Worker                # handle optional weight and bias
1426*da0073e9SAndroid Build Coastguard Worker                if len(sample.args) != 3:
1427*da0073e9SAndroid Build Coastguard Worker                    sample.args = (*sample.args, *([None] * (3 - len(sample.args))))
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker                grad_out = torch.ones_like(sample.input)
1430*da0073e9SAndroid Build Coastguard Worker                normalized_shape, weight, bias = sample.args
1431*da0073e9SAndroid Build Coastguard Worker                ndims_after_reduction = sample.input.ndim - len(normalized_shape)
1432*da0073e9SAndroid Build Coastguard Worker                mean_shape = grad_out.shape[:ndims_after_reduction]
1433*da0073e9SAndroid Build Coastguard Worker                mean = torch.zeros(mean_shape, device=device, dtype=dtype)
1434*da0073e9SAndroid Build Coastguard Worker                rstd = torch.zeros(mean_shape, device=device, dtype=dtype)
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker                expected_shapes = (
1437*da0073e9SAndroid Build Coastguard Worker                    sample.input.shape if output_mask[0] else None,
1438*da0073e9SAndroid Build Coastguard Worker                    weight.shape if output_mask[1] and weight is not None else None,
1439*da0073e9SAndroid Build Coastguard Worker                    bias.shape if output_mask[2] and bias is not None else None)
1440*da0073e9SAndroid Build Coastguard Worker
1441*da0073e9SAndroid Build Coastguard Worker                args = [grad_out, sample.input, normalized_shape, mean, rstd, weight, bias]
1442*da0073e9SAndroid Build Coastguard Worker
1443*da0073e9SAndroid Build Coastguard Worker                self._norm_backwards_test_helper(torch.ops.aten.native_layer_norm_backward,
1444*da0073e9SAndroid Build Coastguard Worker                                                 args, output_mask, expected_shapes)
1445*da0073e9SAndroid Build Coastguard Worker
1446*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1447*da0073e9SAndroid Build Coastguard Worker    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
1448*da0073e9SAndroid Build Coastguard Worker    def test_group_norm_backward(self, output_mask):
1449*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_methods_invocations import sample_inputs_group_norm
1450*da0073e9SAndroid Build Coastguard Worker
1451*da0073e9SAndroid Build Coastguard Worker        # input, (args) num_groups, (kwargs) weight, bias eps
1452*da0073e9SAndroid Build Coastguard Worker        device = "meta"
1453*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
1454*da0073e9SAndroid Build Coastguard Worker        samples = sample_inputs_group_norm(None, device, dtype, requires_grad=False)
1455*da0073e9SAndroid Build Coastguard Worker
1456*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
1457*da0073e9SAndroid Build Coastguard Worker            with self.subTest(sample=sample):
1458*da0073e9SAndroid Build Coastguard Worker                grad_out = torch.ones_like(sample.input)
1459*da0073e9SAndroid Build Coastguard Worker                N, C = sample.input.shape[:2]
1460*da0073e9SAndroid Build Coastguard Worker                HxW = torch.prod(torch.as_tensor(sample.input.shape[2:]), dtype=torch.int32).item()
1461*da0073e9SAndroid Build Coastguard Worker                group = sample.args[0]
1462*da0073e9SAndroid Build Coastguard Worker                mean = torch.zeros((N, group), device=device, dtype=dtype)
1463*da0073e9SAndroid Build Coastguard Worker                rstd = torch.zeros((N, group), device=device, dtype=dtype)
1464*da0073e9SAndroid Build Coastguard Worker                weight = torch.zeros((C), device=device, dtype=dtype)
1465*da0073e9SAndroid Build Coastguard Worker
1466*da0073e9SAndroid Build Coastguard Worker                args = [grad_out, sample.input, mean, rstd, weight, N, C, HxW, group]
1467*da0073e9SAndroid Build Coastguard Worker
1468*da0073e9SAndroid Build Coastguard Worker                expected_shapes = (
1469*da0073e9SAndroid Build Coastguard Worker                    sample.input.shape if output_mask[0] else None,
1470*da0073e9SAndroid Build Coastguard Worker                    weight.shape if output_mask[1] else None,
1471*da0073e9SAndroid Build Coastguard Worker                    weight.shape if output_mask[2] else None)
1472*da0073e9SAndroid Build Coastguard Worker
1473*da0073e9SAndroid Build Coastguard Worker                # test functional call
1474*da0073e9SAndroid Build Coastguard Worker                self._norm_backwards_test_helper(torch.ops.aten.native_group_norm_backward,
1475*da0073e9SAndroid Build Coastguard Worker                                                 args, output_mask, expected_shapes)
1476*da0073e9SAndroid Build Coastguard Worker
1477*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1478*da0073e9SAndroid Build Coastguard Worker    @parametrize("output_mask", list(itertools.product([True], [True, False], [True, False])))
1479*da0073e9SAndroid Build Coastguard Worker    def test_batch_norm_backward(self, output_mask):
1480*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm
1481*da0073e9SAndroid Build Coastguard Worker
1482*da0073e9SAndroid Build Coastguard Worker        # input, (args) num_groups, (kwargs) weight, bias eps
1483*da0073e9SAndroid Build Coastguard Worker        device = "meta"
1484*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
1485*da0073e9SAndroid Build Coastguard Worker        samples = sample_inputs_batch_norm(None, device, dtype, requires_grad=False)
1486*da0073e9SAndroid Build Coastguard Worker
1487*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
1488*da0073e9SAndroid Build Coastguard Worker            with self.subTest(sample=sample):
1489*da0073e9SAndroid Build Coastguard Worker
1490*da0073e9SAndroid Build Coastguard Worker                if sample.input.dim() < 2:
1491*da0073e9SAndroid Build Coastguard Worker                    continue
1492*da0073e9SAndroid Build Coastguard Worker
1493*da0073e9SAndroid Build Coastguard Worker                grad_out = torch.ones_like(sample.input)
1494*da0073e9SAndroid Build Coastguard Worker                running_mean, running_var, weight, bias = sample.args
1495*da0073e9SAndroid Build Coastguard Worker                train = sample.kwargs.get("training", True)
1496*da0073e9SAndroid Build Coastguard Worker                save_mean = torch.zeros((sample.input.shape[1], ), device=device, dtype=dtype) if train else None
1497*da0073e9SAndroid Build Coastguard Worker                save_invstd = torch.zeros((sample.input.shape[1], ), device=device, dtype=dtype) if train else None
1498*da0073e9SAndroid Build Coastguard Worker
1499*da0073e9SAndroid Build Coastguard Worker                args = [grad_out, sample.input, weight, running_mean, running_var,
1500*da0073e9SAndroid Build Coastguard Worker                        save_mean, save_invstd, train, sample.kwargs.get("eps", 1e-5)]
1501*da0073e9SAndroid Build Coastguard Worker
1502*da0073e9SAndroid Build Coastguard Worker                expected_shapes = (
1503*da0073e9SAndroid Build Coastguard Worker                    sample.input.shape,
1504*da0073e9SAndroid Build Coastguard Worker                    torch.Size([sample.input.shape[1]]) if output_mask[1] else None,
1505*da0073e9SAndroid Build Coastguard Worker                    torch.Size([sample.input.shape[1]]) if output_mask[2] else None)
1506*da0073e9SAndroid Build Coastguard Worker
1507*da0073e9SAndroid Build Coastguard Worker                self._norm_backwards_test_helper(torch.ops.aten.native_batch_norm_backward,
1508*da0073e9SAndroid Build Coastguard Worker                                                 args, output_mask, expected_shapes)
1509*da0073e9SAndroid Build Coastguard Worker
1510*da0073e9SAndroid Build Coastguard Worker    def test_fill__alias_relationship(self):
1511*da0073e9SAndroid Build Coastguard Worker        inps = torch.rand(2**52, device='meta')
1512*da0073e9SAndroid Build Coastguard Worker        r = torch.ops.aten.fill_(inps, 1.0)
1513*da0073e9SAndroid Build Coastguard Worker        # aten.fill_ returns an aliase
1514*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(id(inps), id(r))
1515*da0073e9SAndroid Build Coastguard Worker
1516*da0073e9SAndroid Build Coastguard Worker        # aten.fill returns a new tensor
1517*da0073e9SAndroid Build Coastguard Worker        r2 = torch.ops.aten.fill(inps, 1.0)
1518*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(id(inps), id(r2))
1519*da0073e9SAndroid Build Coastguard Worker
1520*da0073e9SAndroid Build Coastguard Worker    def test_meta__fused_moving_avg_obs_fq_helper(self, device):
1521*da0073e9SAndroid Build Coastguard Worker        from torch.ao.quantization import FusedMovingAvgObsFakeQuantize
1522*da0073e9SAndroid Build Coastguard Worker        to_meta = MetaConverter()
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, device=device)
1525*da0073e9SAndroid Build Coastguard Worker        running_min_op = torch.tensor(float("inf"), device=device)
1526*da0073e9SAndroid Build Coastguard Worker        running_max_op = torch.tensor(float("-inf"), device=device)
1527*da0073e9SAndroid Build Coastguard Worker        avg_const = 0.01
1528*da0073e9SAndroid Build Coastguard Worker        scale = torch.tensor([1.0], device=device)
1529*da0073e9SAndroid Build Coastguard Worker        zero_point = torch.tensor([0], dtype=torch.int, device=device)
1530*da0073e9SAndroid Build Coastguard Worker
1531*da0073e9SAndroid Build Coastguard Worker        mod = FusedMovingAvgObsFakeQuantize()
1532*da0073e9SAndroid Build Coastguard Worker        torch.ao.quantization.enable_fake_quant(mod)
1533*da0073e9SAndroid Build Coastguard Worker        torch.ao.quantization.enable_observer(mod)
1534*da0073e9SAndroid Build Coastguard Worker        mod.to(device)
1535*da0073e9SAndroid Build Coastguard Worker
1536*da0073e9SAndroid Build Coastguard Worker        meta_x = to_meta(x)
1537*da0073e9SAndroid Build Coastguard Worker
1538*da0073e9SAndroid Build Coastguard Worker        args = [
1539*da0073e9SAndroid Build Coastguard Worker            x,
1540*da0073e9SAndroid Build Coastguard Worker            mod.observer_enabled,
1541*da0073e9SAndroid Build Coastguard Worker            mod.fake_quant_enabled,
1542*da0073e9SAndroid Build Coastguard Worker            running_min_op,
1543*da0073e9SAndroid Build Coastguard Worker            running_max_op,
1544*da0073e9SAndroid Build Coastguard Worker            scale,
1545*da0073e9SAndroid Build Coastguard Worker            zero_point,
1546*da0073e9SAndroid Build Coastguard Worker            avg_const,
1547*da0073e9SAndroid Build Coastguard Worker            0,
1548*da0073e9SAndroid Build Coastguard Worker            255,
1549*da0073e9SAndroid Build Coastguard Worker            0,
1550*da0073e9SAndroid Build Coastguard Worker        ]
1551*da0073e9SAndroid Build Coastguard Worker
1552*da0073e9SAndroid Build Coastguard Worker        meta_args = args.copy()
1553*da0073e9SAndroid Build Coastguard Worker        meta_args[0] = meta_x
1554*da0073e9SAndroid Build Coastguard Worker
1555*da0073e9SAndroid Build Coastguard Worker        kwargss = [
1556*da0073e9SAndroid Build Coastguard Worker            {},
1557*da0073e9SAndroid Build Coastguard Worker            {"per_row_fake_quant": False, "symmetric_quant": False},
1558*da0073e9SAndroid Build Coastguard Worker            {"per_row_fake_quant": False, "symmetric_quant": True},
1559*da0073e9SAndroid Build Coastguard Worker        ]
1560*da0073e9SAndroid Build Coastguard Worker
1561*da0073e9SAndroid Build Coastguard Worker        for kwargs in kwargss:
1562*da0073e9SAndroid Build Coastguard Worker            ref_out = aten._fused_moving_avg_obs_fq_helper.default(*args, **kwargs)
1563*da0073e9SAndroid Build Coastguard Worker            meta_out = aten._fused_moving_avg_obs_fq_helper.default(*meta_args, **kwargs)
1564*da0073e9SAndroid Build Coastguard Worker
1565*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref_out[0].size(), meta_out[0].size())
1566*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref_out[0].stride(), meta_out[0].stride())
1567*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref_out[1].size(), meta_out[1].size())
1568*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref_out[1].stride(), meta_out[1].stride())
1569*da0073e9SAndroid Build Coastguard Worker
1570*da0073e9SAndroid Build Coastguard Worker    def test_cdist_forward(self, device):
1571*da0073e9SAndroid Build Coastguard Worker        to_meta = MetaConverter()
1572*da0073e9SAndroid Build Coastguard Worker        x1 = torch.rand([3, 2], device=device)
1573*da0073e9SAndroid Build Coastguard Worker        x2 = torch.rand([2, 2], device=device)
1574*da0073e9SAndroid Build Coastguard Worker        p = 2.0
1575*da0073e9SAndroid Build Coastguard Worker        for compute_mode in (None, 1, 2):
1576*da0073e9SAndroid Build Coastguard Worker            ref = aten._cdist_forward.default(x1, x2, p, compute_mode)
1577*da0073e9SAndroid Build Coastguard Worker            res = aten._cdist_forward.default(to_meta(x1), to_meta(x2), p, compute_mode)
1578*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.device.type, 'meta')
1579*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref.shape, res.shape)
1580*da0073e9SAndroid Build Coastguard Worker
1581*da0073e9SAndroid Build Coastguard Worker    def test_quantized_embedding_bag(self):
1582*da0073e9SAndroid Build Coastguard Worker        tab_shape = [8, 128]
1583*da0073e9SAndroid Build Coastguard Worker        emb_size, ind_len, off_len = tab_shape[0], 32, 33
1584*da0073e9SAndroid Build Coastguard Worker        f_table = torch.from_numpy((np.random.random_sample(tab_shape) + 1).astype(np.float32))
1585*da0073e9SAndroid Build Coastguard Worker        q_table = torch.ops.quantized.embedding_bag_byte_prepack(f_table)
1586*da0073e9SAndroid Build Coastguard Worker        indices = torch.from_numpy(np.random.randint(low=0, high=emb_size, size=ind_len)).int()
1587*da0073e9SAndroid Build Coastguard Worker        max_length = len(indices) // (off_len - 1)
1588*da0073e9SAndroid Build Coastguard Worker        if max_length > 20:
1589*da0073e9SAndroid Build Coastguard Worker            max_length = 20
1590*da0073e9SAndroid Build Coastguard Worker        np_lengths = np.random.randint(0, max_length + 1, size=off_len - 1).astype(np.int32)
1591*da0073e9SAndroid Build Coastguard Worker        offsets = torch.cat([torch.zeros([1]), torch.cumsum(torch.from_numpy(np_lengths), 0)]).int()
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker        eb = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
1594*da0073e9SAndroid Build Coastguard Worker            q_table.to(device="meta"),
1595*da0073e9SAndroid Build Coastguard Worker            indices.to(device="meta"),
1596*da0073e9SAndroid Build Coastguard Worker            offsets.to(device="meta"),
1597*da0073e9SAndroid Build Coastguard Worker            mode=0,  # sum
1598*da0073e9SAndroid Build Coastguard Worker            per_sample_weights=None,
1599*da0073e9SAndroid Build Coastguard Worker            include_last_offset=True,
1600*da0073e9SAndroid Build Coastguard Worker        )
1601*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eb.shape, [32, 128])
1602*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eb.dtype, torch.float32)
1603*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eb.untyped_storage().data_ptr(), 0)
1604*da0073e9SAndroid Build Coastguard Worker
1605*da0073e9SAndroid Build Coastguard Worker    # Tests mean and max.
1606*da0073e9SAndroid Build Coastguard Worker    # Can't easily test sum, because there is a fast path for sum which
1607*da0073e9SAndroid Build Coastguard Worker    # causes offset2bag to not get allocated... but the backward function
1608*da0073e9SAndroid Build Coastguard Worker    # needs it, and the offset2bag computation lives inside the
1609*da0073e9SAndroid Build Coastguard Worker    # derivatives.yaml formula directly, so there is no way to access it.
1610*da0073e9SAndroid Build Coastguard Worker    # To test sum, need to manually compute offset2bag
1611*da0073e9SAndroid Build Coastguard Worker    @parametrize("mode", [1, 2])
1612*da0073e9SAndroid Build Coastguard Worker    def test_embedding_bag_dense_backward(self, mode):
1613*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(4, 3, requires_grad=True)
1614*da0073e9SAndroid Build Coastguard Worker        indices = torch.tensor([1, 0, 2, 1, 3])
1615*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 3, 5])
1616*da0073e9SAndroid Build Coastguard Worker        scale_grad_by_freq = False
1617*da0073e9SAndroid Build Coastguard Worker        sparse = False
1618*da0073e9SAndroid Build Coastguard Worker        per_sample_weights = None
1619*da0073e9SAndroid Build Coastguard Worker        include_last_offset = False
1620*da0073e9SAndroid Build Coastguard Worker        padding_idx = -1
1621*da0073e9SAndroid Build Coastguard Worker
1622*da0073e9SAndroid Build Coastguard Worker        output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
1623*da0073e9SAndroid Build Coastguard Worker            weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx
1624*da0073e9SAndroid Build Coastguard Worker        )
1625*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn_like(output)
1626*da0073e9SAndroid Build Coastguard Worker
1627*da0073e9SAndroid Build Coastguard Worker        # Call the function with example inputs
1628*da0073e9SAndroid Build Coastguard Worker        grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
1629*da0073e9SAndroid Build Coastguard Worker            grad, indices, offset2bag, bag_size, maximum_indices, weight.size(0),
1630*da0073e9SAndroid Build Coastguard Worker            scale_grad_by_freq, mode, per_sample_weights, padding_idx
1631*da0073e9SAndroid Build Coastguard Worker        )
1632*da0073e9SAndroid Build Coastguard Worker        meta_grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
1633*da0073e9SAndroid Build Coastguard Worker            grad.to('meta'), indices.to('meta'), offset2bag.to('meta'), bag_size.to('meta'),
1634*da0073e9SAndroid Build Coastguard Worker            maximum_indices.to('meta'), weight.size(0),
1635*da0073e9SAndroid Build Coastguard Worker            scale_grad_by_freq, mode, per_sample_weights, padding_idx
1636*da0073e9SAndroid Build Coastguard Worker        )
1637*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_weight.to('meta'), meta_grad_weight)
1638*da0073e9SAndroid Build Coastguard Worker
1639*da0073e9SAndroid Build Coastguard Worker    def test_embedding_bag_dense_backward_per_sample_weights(self):
1640*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(4, 3, requires_grad=True)
1641*da0073e9SAndroid Build Coastguard Worker        indices = torch.tensor([1, 0, 2, 1, 3])
1642*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 3, 5])
1643*da0073e9SAndroid Build Coastguard Worker        scale_grad_by_freq = False
1644*da0073e9SAndroid Build Coastguard Worker        sparse = False
1645*da0073e9SAndroid Build Coastguard Worker        mode = 0
1646*da0073e9SAndroid Build Coastguard Worker        per_sample_weights = torch.randn(5, requires_grad=True)
1647*da0073e9SAndroid Build Coastguard Worker        include_last_offset = False
1648*da0073e9SAndroid Build Coastguard Worker        padding_idx = -1
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker        output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
1651*da0073e9SAndroid Build Coastguard Worker            weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx
1652*da0073e9SAndroid Build Coastguard Worker        )
1653*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn_like(output)
1654*da0073e9SAndroid Build Coastguard Worker
1655*da0073e9SAndroid Build Coastguard Worker        # Call the function with example inputs
1656*da0073e9SAndroid Build Coastguard Worker        grad_weight = torch.ops.aten._embedding_bag_per_sample_weights_backward.default(
1657*da0073e9SAndroid Build Coastguard Worker            grad, weight, indices, offsets, offset2bag, mode, padding_idx
1658*da0073e9SAndroid Build Coastguard Worker        )
1659*da0073e9SAndroid Build Coastguard Worker        meta_grad_weight = torch.ops.aten._embedding_bag_per_sample_weights_backward.default(
1660*da0073e9SAndroid Build Coastguard Worker            grad.to('meta'), weight.to('meta'), indices.to('meta'),
1661*da0073e9SAndroid Build Coastguard Worker            offsets.to('meta'), offset2bag.to('meta'), mode, padding_idx
1662*da0073e9SAndroid Build Coastguard Worker        )
1663*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_weight.to('meta'), meta_grad_weight)
1664*da0073e9SAndroid Build Coastguard Worker
1665*da0073e9SAndroid Build Coastguard Worker    # opinfo test is using aten.fill_, it's not testing aten.fill
1666*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1667*da0073e9SAndroid Build Coastguard Worker    def test_fill_stride(self):
1668*da0073e9SAndroid Build Coastguard Worker        to_meta = MetaConverter()
1669*da0073e9SAndroid Build Coastguard Worker        sample_args = [torch.rand(2, 2, 2, 2), 1.0]
1670*da0073e9SAndroid Build Coastguard Worker
1671*da0073e9SAndroid Build Coastguard Worker        for args in get_strided_args(sample_args):
1672*da0073e9SAndroid Build Coastguard Worker            meta_args = to_meta(args)
1673*da0073e9SAndroid Build Coastguard Worker            ref_out = torch.ops.aten.fill(*args)
1674*da0073e9SAndroid Build Coastguard Worker            meta_out = torch.ops.aten.fill(*meta_args)
1675*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref_out.size(), meta_out.size())
1676*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref_out.stride(), meta_out.stride())
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker
1679*da0073e9SAndroid Build Coastguard Worker    def test_map_location_deserialize(self):
1680*da0073e9SAndroid Build Coastguard Worker        import io
1681*da0073e9SAndroid Build Coastguard Worker
1682*da0073e9SAndroid Build Coastguard Worker        t = torch.rand(10)
1683*da0073e9SAndroid Build Coastguard Worker        b = io.BytesIO()
1684*da0073e9SAndroid Build Coastguard Worker
1685*da0073e9SAndroid Build Coastguard Worker        torch.save(t, b)
1686*da0073e9SAndroid Build Coastguard Worker        b.seek(0)
1687*da0073e9SAndroid Build Coastguard Worker        r = torch.load(b, map_location=torch.device("meta"))
1688*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.device.type, 'meta')
1689*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.shape, t.shape)
1690*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.dtype, t.dtype)
1691*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.storage().data_ptr(), 0)
1692*da0073e9SAndroid Build Coastguard Worker
1693*da0073e9SAndroid Build Coastguard Worker    def test_embedding_bag_byte_prepack(self):
1694*da0073e9SAndroid Build Coastguard Worker        batch_size = 10
1695*da0073e9SAndroid Build Coastguard Worker        num_embeddings = 80
1696*da0073e9SAndroid Build Coastguard Worker        embedding_dim = [128, 256, 512]
1697*da0073e9SAndroid Build Coastguard Worker        res_shape = [[batch_size, num_embeddings, ed + 8] for ed in embedding_dim]
1698*da0073e9SAndroid Build Coastguard Worker        for ed, rs in zip(embedding_dim, res_shape):
1699*da0073e9SAndroid Build Coastguard Worker            weight = torch.randn(batch_size, num_embeddings, ed, dtype=torch.float32)
1700*da0073e9SAndroid Build Coastguard Worker            res = torch.ops.quantized.embedding_bag_byte_prepack(weight.to(device="meta"))
1701*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.shape, rs)
1702*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.dtype, torch.float32)
1703*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.untyped_storage().data_ptr(), 0)
1704*da0073e9SAndroid Build Coastguard Worker
1705*da0073e9SAndroid Build Coastguard Worker    def test_embedding_bag_byte_unpack(self):
1706*da0073e9SAndroid Build Coastguard Worker        batch_size = 10
1707*da0073e9SAndroid Build Coastguard Worker        num_embeddings = 80
1708*da0073e9SAndroid Build Coastguard Worker        embedding_dim = [128, 256, 512]
1709*da0073e9SAndroid Build Coastguard Worker        res_shape = [[batch_size, num_embeddings, ed] for ed in embedding_dim]
1710*da0073e9SAndroid Build Coastguard Worker        for ed, rs in zip(embedding_dim, res_shape):
1711*da0073e9SAndroid Build Coastguard Worker            packed_weight = torch.randn(batch_size, num_embeddings, ed + 8, dtype=torch.float32)
1712*da0073e9SAndroid Build Coastguard Worker            res = torch.ops.quantized.embedding_bag_byte_unpack(packed_weight.to(device="meta"))
1713*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.shape, rs)
1714*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.dtype, torch.float32)
1715*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.untyped_storage().data_ptr(), 0)
1716*da0073e9SAndroid Build Coastguard Worker
1717*da0073e9SAndroid Build Coastguard Worker    def test_index_select_out(self):
1718*da0073e9SAndroid Build Coastguard Worker        def f():
1719*da0073e9SAndroid Build Coastguard Worker            input = torch.randn([8, 16], device='meta')
1720*da0073e9SAndroid Build Coastguard Worker            index = torch.tensor([2, 1, 6, 7, 3, 1, 7, 5, 6, 7], device='meta')
1721*da0073e9SAndroid Build Coastguard Worker            out = torch.empty([10, 16], device='meta')
1722*da0073e9SAndroid Build Coastguard Worker            return torch.index_select(input=input, dim=0, index=index, out=out)
1723*da0073e9SAndroid Build Coastguard Worker        with enable_python_dispatcher():
1724*da0073e9SAndroid Build Coastguard Worker            out = f()
1725*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.shape, [10, 16])
1726*da0073e9SAndroid Build Coastguard Worker
1727*da0073e9SAndroid Build Coastguard Worker    def test_local_scalar_dense_call(self):
1728*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "cannot be called on meta tensors"):
1729*da0073e9SAndroid Build Coastguard Worker            meta_tensor = torch.randn(1, device='meta')
1730*da0073e9SAndroid Build Coastguard Worker            meta_tensor.item()
1731*da0073e9SAndroid Build Coastguard Worker
1732*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestMeta, globals())
1733*da0073e9SAndroid Build Coastguard Worker
def print_op_str_if_not_supported(op_str):
    """Print *op_str*'s overload if it is skipped or expected to fail
    under meta dispatch (appending '# SKIP' for skips)."""
    # Resolve the schema string to a concrete aten overload object.
    parsed = OperatorName.parse(op_str)
    packet = getattr(torch.ops.aten, str(parsed.name))
    overload = getattr(packet, parsed.overload_name or "default")

    skip_tables = (meta_dispatch_skips, meta_dispatch_device_skips['cuda'])
    if any(overload in table for table in skip_tables):
        print(f"{overload}  # SKIP")

    xfail_tables = (meta_dispatch_expected_failures,
                    meta_dispatch_device_expected_failures['cuda'])
    if any(overload in table for table in xfail_tables):
        print(overload)
1742*da0073e9SAndroid Build Coastguard Worker
1743*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # PYTORCH_COMPARE_XLA: given an XLA codegen yaml, print which of its ops
    # lack meta support, then exit without running the test suite.
    xla_yaml_path = os.getenv('PYTORCH_COMPARE_XLA', None)
    if xla_yaml_path is not None:
        with open(xla_yaml_path) as f:
            d = yaml.load(f, Loader=YamlLoader)
            all_ops = (d.get("full_codegen", [])
                       + d.get("supported", [])
                       + d.get("autograd", []))
            for op_str in all_ops:
                print_op_str_if_not_supported(op_str)
        sys.exit(0)

    # PYTORCH_COMPARE_TEXT: same report, but from a plain text file with one
    # op string per line.
    op_list_path = os.getenv('PYTORCH_COMPARE_TEXT', None)
    if op_list_path is not None:
        with open(op_list_path) as f:
            for line in f:
                print_op_str_if_not_supported(line.strip())
        sys.exit(0)

    run_tests()
1762