# Owner(s): ["module: decompositions"]

import itertools
import torch
import os
import numpy as np
from enum import Enum
from torch.overrides import resolve_name
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
import torch.utils._python_dispatch
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import OpOverload, OpOverloadPacket
from torch.testing import make_tensor
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
from torch.testing._internal.common_utils import (
    TestCase,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_TORCHDYNAMO,
    run_tests,
    dtype_abbrs,
    parametrize
)
from torch.testing._internal.common_device_type import (
    ops,
    instantiate_device_type_tests,
    onlyCUDA,
    onlyCPU,
    OpDTypes,
)
from torch.testing._internal.common_methods_invocations import (
    binary_ufuncs, op_db, foreach_unary_op_db, foreach_binary_op_db,
    foreach_pointwise_op_db, foreach_reduce_op_db, foreach_other_op_db)
from torch.testing._internal.opinfo.core import S, SampleInput
from torchgen.yaml_utils import YamlLoader
from torchgen.model import OperatorName

import copy
import sys
import yaml
import atexit
import re
from collections import defaultdict
from collections.abc import Iterable
import unittest
import warnings
import weakref
from functools import partial, wraps

bf16 = torch.bfloat16
f64 = torch.float64
f32 = torch.float32
f16 = torch.float16
c32 = torch.complex32
c64 = torch.complex64
c128 = torch.complex128
i8 = torch.int8
i16 = torch.int16
i32 = torch.int32
i64 = torch.int64
b8 = torch.bool
u8 = torch.uint8
u16 = torch.uint16
u32 = torch.uint32
u64 = torch.uint64

foreach_op_db = (
    foreach_unary_op_db +
    foreach_binary_op_db +
    foreach_pointwise_op_db +
    foreach_reduce_op_db +
    foreach_other_op_db
)


class TestMetaConverter(TestCase):
    def assertSameVersionCounter(self, m1, m2):
        # Cannot easily test m1 and m2 have same storage due to
        # lack of Storage bindings. Use version counter.
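        # (In-place ops bump a tensor's ._version, and a view shares its
        # base's version counter, so two tensors whose counters move in
        # lockstep when the shared base is mutated must alias the same
        # storage.)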
        vc = m1._version
        self.assertEqual(m2._version, vc)
        # Doing it this way ensures that we get VC bump even with leaves
        with torch.no_grad():
            m1._base.add_(3)
        self.assertNotEqual(m1._version, vc)
        self.assertEqual(m2._version, m1._version)

    def assertMetadataMatches(self, m1, m2):
        assert_metadata_eq(self.assertEqual, m1, m2)

    def test_view_of_non_leaf(self):
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        z1 = y[:]
        z2 = y[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)

        # check the test is actually testing what it claims
        self.assertTrue(m1._is_view())
        self.assertFalse(m1._base.is_leaf)

        self.assertIsNot(m1, m2)
        self.assertMetadataMatches(m1, z1)
        self.assertMetadataMatches(m2, z2)
        self.assertSameVersionCounter(m1, m2)

    def test_view_of_leaf(self):
        x = torch.randn(4, requires_grad=True)
        z1 = x[:]
        z2 = x[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)

        # check the test is actually testing what it claims
        self.assertTrue(m1._is_view())
        self.assertTrue(m1._base.is_leaf)

        self.assertIsNot(m1, m2)
        self.assertMetadataMatches(m1, z1)
        self.assertMetadataMatches(m2, z2)
        self.assertSameVersionCounter(m1, m2)

    def test_view_of_view_of_leaf(self):
        x = torch.randn(8)
        y = x.view(2, 4)
        y.requires_grad = True
        z = y.view(2, 2, 2)

        to_meta = MetaConverter()
        mx = to_meta(x)
        mz = to_meta(z)

        self.assertFalse(z.is_leaf)

        self.assertMetadataMatches(mx, x)
        self.assertMetadataMatches(mz, z)

    def test_leaf(self):
        x = torch.randn(4, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.is_leaf)
        self.assertTrue(m.requires_grad)

        self.assertMetadataMatches(m, x)

    def test_non_leaf(self):
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        to_meta = MetaConverter()
        m = to_meta(y)

        # check the test is actually testing what it claims
        self.assertFalse(m.is_leaf)
        self.assertTrue(m.requires_grad)

        self.assertMetadataMatches(m, y)

    def test_requires_grad_false(self):
        x = torch.randn(4, requires_grad=False)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertFalse(m.requires_grad)

        self.assertMetadataMatches(m, x)

    def test_channels_last(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

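    # The channels_last tests below check that conversion preserves the
    # channels_last striding of the original tensor (assertMetadataMatches
    # compares strides among the other metadata).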
    def test_channels_last_leaf(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_channels_last_non_leaf(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
        y = x + 2

        # sanity
        self.assertEqual(x.stride(), y.stride())
        self.assertFalse(y.is_leaf)

        to_meta = MetaConverter()
        m = to_meta(y)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertFalse(m.is_leaf)

        self.assertMetadataMatches(m, y)

        # Check that we can autograd with m as input without erroring;
        # see https://github.com/pytorch/pytorch/issues/87956
        loss = m.sum()
        torch.autograd.grad(loss, m)

    def test_empty_strided_non_dense_leaf(self):
        x = torch.empty_strided((2, 2), (4, 2), requires_grad=True)

        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_view_mutate(self):
        x = torch.zeros(4)
        y = x.view(2, 2)

        to_meta = MetaConverter()
        m = to_meta(y)

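        # In-place mutation through the view should behave the same for the
        # real tensor and its meta copy; using a requires_grad operand also
        # exercises autograd's in-place checks on the meta device.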
        y.add_(torch.randn(2, 2, requires_grad=True))
        m.add_(torch.randn(2, 2, device='meta', requires_grad=True))

    def test_non_leaf_torture(self):
        x = torch.empty(20, requires_grad=True)
        with torch.no_grad():
            x.set_(x.storage(), 10, (2,), (2,))

        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    # NB: complex stuff is not actually exercised right now because
    # we have a blanket exclusion for complex conversion

    def test_view_as_real(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = torch.view_as_real(x)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_complex_noncontiguous_bug(self):
        x = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :]
        m = MetaConverter()(x)
        self.assertMetadataMatches(m, x)

    def test_view_as_complex(self):
        x = torch.randn((4, 2), dtype=torch.float32)
        y = torch.view_as_complex(x)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_view_dtype(self):
        x = torch.randn(4, dtype=torch.float32)
        y = x.view(dtype=torch.int32)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_imag(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = x.imag
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

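    # set_ that points a tensor at an existing storage must not resize
    # (or otherwise perturb) the donor storage.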
    def test_inplace_set_storage(self):
        x = torch.tensor([0, 1], dtype=torch.int64)
        storage = x.untyped_storage()
        ssize = storage.size()
        meta = torch.empty((), dtype=torch.int64)
        meta.set_(storage, 0, (), ())
        self.assertEqual(storage.size(), ssize)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_weakref(self):
        x = torch.randn(4, 4, 4)
        m = MetaConverter()
        y = m(x)
        z = m(x)
        self.assertIs(y, z)
        self.assertEqual(len(m.tensor_memo), 1)
        self.assertEqual(len(m.storage_memo), 1)
        self.assertEqual(len(m.describer.lookup_tensor), 1)
        self.assertEqual(len(m.describer.lookup_storage), 1)
        del x
        # Entries from Tensor -> int get deallocated when the real tensor
        # disappears...
        self.assertEqual(len(m.describer.lookup_tensor), 0)
        self.assertEqual(len(m.describer.lookup_storage), 0)
        del y
        del z
        # ... but the int -> FakeTensor entries don't die until the fake
        # tensors themselves die (because the user may have held onto the
        # int key and is expecting to get a consistent fake tensor in
        # this case)
        self.assertEqual(len(m.tensor_memo), 0)
        self.assertEqual(len(m.storage_memo), 0)
        li = []
        r = []
        for i in range(4):
            li.append(torch.rand([i]))
            r.append(m(li[-1]))
        self.assertEqual(len(m.tensor_memo), 4)
        self.assertEqual(len(m.storage_memo), 4)
        self.assertEqual(len(m.describer.lookup_tensor), 4)
        self.assertEqual(len(m.describer.lookup_storage), 4)
        del li
        self.assertEqual(len(m.describer.lookup_tensor), 0)
        self.assertEqual(len(m.describer.lookup_storage), 0)
        del r
        self.assertEqual(len(m.tensor_memo), 0)
        self.assertEqual(len(m.storage_memo), 0)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_tensor_outlives_converter(self):
        m = MetaConverter()
        ref = weakref.ref(m)
        x = torch.randn([4, 4])
        y = m(x)
        del m
        self.assertIs(ref(), None)

aten = torch.ops.aten

CHECK_STRIDES = {
    torch.Tensor.__getitem__,
}

CHECK_ALL_STRIDES = {
    aten.unsqueeze.default
}

CHECK_STRIDES_SKIPS = {
    aten._conj_physical.default,
    aten._fft_c2c.default,
    aten._fft_c2r.default,
    aten._fft_r2c.default,
    aten._linalg_svd.default,
    aten.binary_cross_entropy.default,
    aten.complex.default,
    aten.polar.default,
    aten.copysign.Tensor,
    aten.div.Tensor_mode,
    aten.floor_divide.default,
    aten.heaviside.default,
    aten.lerp.Scalar,
    aten.lerp.Tensor,
    aten.logaddexp.default,
    aten.logical_and.default,
    aten.logical_or.default,
    aten.logical_xor.default,
    aten.pow.Scalar,
    aten.prelu.default,
    aten.special_xlog1py.default,
    aten.xlogy.Tensor,
    aten.nll_loss2d_forward.default,

    # channels_last and channels_last_3d related failures
    aten.convolution.default,

    # The following ops fail if include_storage_offset = True, but these are a bit edge casey;
    # we should still fix them, leaving them here for tracking.
    # aten._reshape_alias.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_matmul_cuda_float32
    # aten.view.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_unflatten_cuda_float32
}

CHECK_CONJ_SKIPS = {
    # The conj bit is not copied, see:
    # https://github.com/pytorch/pytorch/pull/101836
    aten.linalg_lu_solve.out,
}

class CheckStrides(Enum):
    NONE = 0
    SIGNIFICANT = 1
    ALL = 2

def should_check_strides(func):
    if func in CHECK_ALL_STRIDES:
        return CheckStrides.ALL
    if func in CHECK_STRIDES:
        return CheckStrides.SIGNIFICANT
    if func in CHECK_STRIDES_SKIPS:
        return CheckStrides.NONE
    if not isinstance(func, torch._ops.OpOverload):
        return CheckStrides.NONE
    # Prims are expected to model strides correctly
    if func.namespace == "prims":
        return CheckStrides.SIGNIFICANT
    # Check if it's a view, by testing if any of the returns have
    # a non-empty alias set
    if any(r.alias_info.before_set for r in func._schema.returns if r.alias_info):
        return CheckStrides.SIGNIFICANT
    # TODO: check for TensorIterator
    return CheckStrides.SIGNIFICANT

def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
    flat_meta_rs = pytree.tree_leaves(meta_rs)
    flat_rs = pytree.tree_leaves(rs)
    test_case.assertEqual(len(flat_meta_rs), len(flat_rs))
    for i, meta_r, r in zip(range(len(flat_rs)), flat_meta_rs, flat_rs):
        def test_assert(cond, msg):
            if not cond:
                raise RuntimeError(f"output {i}: {msg_callable(msg)}")
        if not isinstance(r, torch.Tensor):
            continue
        test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor")
        test_assert(meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
        test_assert(meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
        # See https://github.com/pytorch/pytorch/issues/78050
        if should_check_strides(func) == CheckStrides.ALL:
            same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
        elif should_check_strides(func) == CheckStrides.SIGNIFICANT:
            same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
        test_assert(
            meta_r.storage_offset() == r.storage_offset(),
            f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
        test_assert(meta_r.requires_grad == r.requires_grad,
                    f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
        if func not in CHECK_CONJ_SKIPS:
            test_assert(meta_r.is_conj() == r.is_conj(),
                        f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
        test_assert(meta_r.is_neg() == r.is_neg(), f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")


# This environment variable controls whether or not we print expected failure
# lists at the end of a test suite run. The intended usage looks like this:
#
# 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_meta.py` on a CUDA build
#    of PyTorch that has LAPACK/MAGMA installed. You can filter `-k test_meta`
#    or `-k test_dispatch_meta` to only focus on one or the other list.
# 2. Given the printed skip/xfail list, add the entries to the corresponding lists;
#    torch.* entries go in meta_function and aten.* entries go in meta_dispatch.
#    If there are preexisting entries, you need to merge in the new entries.
#
# This is somewhat manual but typically you shouldn't need to do this, unless
# you've made a major change (e.g., added a new dtype to PyTorch) and need to
# refresh the lists. If you want to do it from scratch, just clear out the
# preexisting lists before running.
#
# WARNING: Python dict literals will silently ignore duplicate keys
COLLECT_EXPECT = os.getenv('PYTORCH_COLLECT_EXPECT', '0') == '1'

seen_succeeded = {}
seen_failed = {}
failed_reasons = defaultdict(set)
def print_seen():
    expected_failures = []
    skips = []

    def fmt_dtypes(dtypes):
        r = ', '.join(sorted(dtype_abbrs[d] for d in dtypes))
        return '{' + r + '}'

    for op, failed_dtypes in seen_failed.items():
        ops = resolve_name(op)
        succeeded_dtypes = seen_succeeded.get(op, set())
        expected_failures_dtypes = failed_dtypes - succeeded_dtypes
        skips_dtypes = failed_dtypes & succeeded_dtypes
        reasons = ""
        if failed_reasons[op]:
            reasons = "  # " + ", ".join(sorted(failed_reasons[op]))
        if expected_failures_dtypes:
            expected_failures.append(f"    {ops}: {fmt_dtypes(expected_failures_dtypes)},{reasons}")
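        # A dtype that both failed and succeeded for the same op is flaky,
        # so it goes in skips rather than expected_failures.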
        if skips_dtypes:
            skips.append(f"    {ops}: {fmt_dtypes(skips_dtypes)},")
    expected_failures.sort()
    skips.sort()
    nl = '\n'
    print(f"""\
expected_failures = {{
{nl.join(expected_failures)}
}}

skips = {{
{nl.join(skips)}
}}
""")
if COLLECT_EXPECT:
    atexit.register(print_seen)

# Success forces pass; failure forces fail; skip unconditionally skips testing
TestExpect = Enum("TestExpect", ("SUCCESS", "XFAILURE", "SKIP"))

# Unlike print, this also produces strides
def verbose_print(e):
    class Lit:
        def __init__(self, s):
            self.s = s

        def __repr__(self):
            return self.s

    def go(t):
        if is_sparse_any(t):
            return t
        elif isinstance(t, torch.Tensor):
            return Lit(f"{t} stride={t.stride()}")
        else:
            return t

    return repr(tree_map(go, e))

def run_meta_crossref(
    test_case,
    test_expect,
    func,
    args,
    kwargs,
    *,
    dtype,
    device_type,
    run_symbolic_meta: bool
):
    to_meta = MetaConverter()
    do_meta = test_expect is not TestExpect.SKIP
    if do_meta:
        try:
            meta_args = tree_map(to_meta, args)
            meta_kwargs = tree_map(to_meta, kwargs)
        except Exception as e:
            raise RuntimeError(
                f"failed to convert args to meta; "
                f"originally (*{args}, **{kwargs})") from e
    try:
        rs = func(*args, **kwargs)
    except Exception as e:
        raise AssertionError("Original OpInfo is broken") from e

    # TODO: also handle cases where func raises an exception

    # For now, only attempt if we managed to convert all tensor types
    # (if any of them failed, we're in a mixed device situation and
    # this isn't well supported)
    if do_meta and to_meta.successful():
        # Special cases
        if func is torch.tensor_split:
            # Use original indices_or_sections, this argument is data dependent
            meta_args = (meta_args[0], args[1]) + meta_args[2:]
        elif func is torch.Tensor.__getitem__:
            # Ensure boolean tensors use original
            assert len(args) == 2
            flat_args = pytree.tree_leaves(args[1])
            flat_meta_args, spec = tree_flatten(meta_args[1])
            flat_new_args = []
            for a, ma in zip(flat_args, flat_meta_args):
                flat_new_args.append(a if isinstance(a, torch.Tensor) and a.dtype in [torch.int8, torch.bool] else ma)
            meta_args = (meta_args[0], tree_unflatten(flat_new_args, spec))
        elif func in (torch.ops.aten.repeat_interleave.Tensor, torch.ops.aten.repeat_interleave.Tensor_out):
            if kwargs.get("output_size", None) is None:
                meta_args = args
            if func is torch.ops.aten.repeat_interleave.Tensor_out:
                meta_kwargs["out"] = kwargs["out"]
        elif func in (torch.ops.aten.index.Tensor, torch.ops.aten.index.Tensor_out):
            # Don't convert boolean tensors to meta as they will have nonzero
            # called on them
            indices = []
            for meta_index, real_index in zip(meta_args[1], args[1]):
                if meta_index is not None and meta_index.dtype in [torch.int8, torch.bool]:
                    indices.append(real_index)
                else:
                    indices.append(meta_index)
            meta_args = (meta_args[0], indices)
        elif func is torch.nn.functional.ctc_loss and all([isinstance(args[2], list), isinstance(args[3], list)]):
            # torch.ops.aten._ctc_loss.IntList has a meta kernel but
            # torch.ops.aten._ctc_loss.Tensor does not
            test_expect = TestExpect.SUCCESS

        if kwargs.get("device", None) is not None:
            meta_kwargs["device"] = "meta"

        try:
            # Suppress warnings, this doesn't matter for test_meta.py
            # but it does matter if you want to use this decorator
            # for cross-ref testing, as some tests may be looking at
            # errors
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if run_symbolic_meta:
                    # Run the decomps and meta kernels registered
                    # to the python dispatcher instead of the regular dispatcher.
                    # This should be the same set of kernels
                    # that fake tensor runs in dynamic shapes mode.
                    with enable_python_dispatcher():
                        meta_rs = func(*meta_args, **meta_kwargs)
                else:
                    meta_rs = func(*meta_args, **meta_kwargs)
        except Exception as e:
            if test_expect is TestExpect.XFAILURE:
                return rs
            seen_failed.setdefault(func, set()).add(dtype)
            if isinstance(e, NotImplementedError):
                m = RE_NOT_IMPLEMENTED_MSG.search(e.args[0])
                if m:
                    failed_reasons[func].add(m.group(1))
            if COLLECT_EXPECT:
                return rs
            raise RuntimeError(f"""\
failed to run: {resolve_name(func)}(
*{verbose_print(meta_args)},
**{verbose_print(meta_kwargs)}
)""") from e
        else:
            try:
                delim = ',\n  '
                assert_ref_meta_equal(test_case, func, meta_rs, rs, lambda msg: f"""\
meta disagrees with real impl:
{resolve_name(func)}(
  {delim.join(map(verbose_print, meta_args))},
  {delim.join(k + ": " + verbose_print(v) for k, v in meta_kwargs.items())}
) = (
  {verbose_print(meta_rs)}
)
{msg}
""")
            except Exception:
                if test_expect is TestExpect.XFAILURE:
                    return rs
                seen_failed.setdefault(func, set()).add(dtype)
                if COLLECT_EXPECT:
                    return rs
                raise
            else:
                seen_succeeded.setdefault(func, set()).add(dtype)
                if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
                    raise RuntimeError(f"unexpected success {resolve_name(func)} {meta_args} {meta_kwargs}")

    return rs



RE_NOT_IMPLEMENTED_MSG = re.compile(r"Could not run '([^']+)' with arguments ")

meta_function_expected_failures = {
    torch.Tensor.to_sparse : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.allclose : {f64, f16, c128, c64, bf16, f32},
    torch.argwhere : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.combinations : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.corrcoef : {f64, i32, c128, i64, i16, u8, c64, bf16, f16, i8, f32},
    torch.cov : {f64, i32, c128, i64, i16, u8, c64, bf16, i8, f32, f16},
    torch.functional.istft : {f64, c64, c128, f32},
    torch.geqrf : {f64, c64, c128, f32},
    torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
    torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
    torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
    torch.bincount : {i32, i64, u8, i16, i8},
    torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32, u16, u32, u64},
    torch.functional.unique_consecutive : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32, u16, u32, u64},
    torch.histogram : {f64, f32},
    torch.histogramdd : {f64, f32},
    torch.nn.functional.ctc_loss : {f64, f32},
    torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
    torch.linalg.lstsq : {f64, f32, c128, c64},
}

meta_function_expected_failures_conditional = {
    torch.repeat_interleave : (lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats", None), int)),
}

"""
# This is some sample code for how we could dump these dicts into a YAML
# file for easier reading/writing
import yaml
print(yaml.dump(
  {resolve_name(k): [dtype_abbrs[d] for d in v]
   for k, v in meta_function_expected_failures.items()}, default_flow_style=None))
import sys
sys.exit()
"""

meta_function_skips = {
    torch.Tensor.__rmatmul__ : {bf16, c128, f64, f32, f16, c64},
    torch.Tensor.matmul : {f64, f32, c128, c64},
    torch.functional.atleast_2d : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.functional.atleast_3d : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.functional.cartesian_prod : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.functional.einsum : {bf16, c128, f64, f32, f16, c64},
    torch.inner : {f16, bf16, i8, i64, u8, c128, f64, i16, f32, i32, c64},
    torch.linalg.matrix_norm : {c128, f32, c64, f64},
    torch.linalg.matrix_rank : {c128, c64},
    torch.linalg.svd : {c128, c64},
    torch.matmul : {bf16, c128, f64, f32, f16, c64},
    torch.nanquantile : {f64, f32},
    torch.narrow : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c32, c64},
    torch.nn.functional.batch_norm : {f64, f32},
    torch.nn.functional.binary_cross_entropy : {bf16, f64, f32, f16},
    torch.nn.functional.dropout3d : {bf16, f64, f32, f16},
    torch.nn.functional.local_response_norm : {bf16, f64, f32, f16},
    torch.svd : {c128, c64},
    torch.take_along_dim : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.vstack : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.diff : {b8},
    torch.equal : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.nanmean : {bf16, f64, f32, f16, c32, c64, c128},
    torch.nn.functional.cross_entropy : {bf16, f64, f32},
    torch.nn.functional.nll_loss : {bf16, f64, f32},
    torch.linalg.cond : {c128, c64, f32, f64},
    torch.linalg.vecdot : {bf16, f64, f32, f16},
    torch.empty : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.Tensor.addbmm_: {bf16, c128, c64, f32, f64, i16, i32, i64, i8, u8},
    torch.nn.functional.one_hot : {i64},
}


meta_function_device_expected_failures = defaultdict(dict)
meta_function_device_expected_failures_only_outplace = defaultdict(dict)
meta_function_device_skips = defaultdict(dict)

meta_function_device_expected_failures['cpu'] = {
    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    torch.native_batch_norm: {bf16, f16},
    torch._native_batch_norm_legit: {bf16, f16},
    torch.ops.aten._batch_norm_with_update: {bf16, f16},
    torch.native_layer_norm: {bf16, f16},
}

meta_function_device_expected_failures['cuda'] = {
    torch.corrcoef: {bf16, f16},  # aten::_local_scalar_dense
    torch.cov: {f16},  # aten::_local_scalar_dense
    torch.functional.unique: {f16},  # aten::_unique2, aten::unique_dim
    torch.functional.unique_consecutive: {f16},  # aten::unique_consecutive
    torch.geqrf: {f32, f64},  # aten::geqrf
}

meta_function_device_skips['cpu'] = {
    # TODO: The decomps for these batch norm ops return different dtypes depending
    torch.native_batch_norm: {f32, f64},
    torch._native_batch_norm_legit: {f32, f64},
    torch.ops.aten._batch_norm_with_update: {f32, f64},
}

meta_function_device_skips['cuda'] = {
    torch.inner: {f16},
    torch.linalg.matrix_rank: {f32, f64},
    torch.linalg.svd: {f32, f64},
    torch.nn.functional.cross_entropy: {f16},
    torch.nn.functional.interpolate: {f16},
    torch.nn.functional.nll_loss: {f16},
    torch.svd: {f32, f64},
}

# This is a __torch_function__ mode that, when enabled, interposes every
# Torch API call and runs the operator as normal, and then reruns it
# with meta inputs, and then checks that everything about the output agrees.
# Most of the logic deals with faithfully replicating the original tensor
# as a meta tensor, which is nontrivial because there are a lot of subsystems
# that may potentially be exercised.
#
# That being said, this class is a little overkill for what it is doing in
# this test file (since I could have just inlined __torch_function__ on the
# OpInfo call, and OpInfos generally have very regular inputs), but it will be
# useful for more comprehensive testing, e.g., as seen in
# https://github.com/pytorch/pytorch/pull/75994. The big benefit is that it is
# A LOT more efficient than torch dispatch mode (at the cost of less coverage).
class MetaCrossRefFunctionMode(torch.overrides.TorchFunctionMode):
    test_case: TestCase
    device_type: str
    dtype: torch.dtype

    def __init__(self, test_case, *, device, dtype, inplace):
        self.test_case = test_case
        self.device_type = torch.device(device).type
        self.dtype = dtype
        self.inplace = inplace

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        if (
            torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod) or
            # meta converter doesn't work correctly when no_dispatch() is on, so
            # skip running the crossref test in this case
            torch._C._dispatch_tls_local_exclude_set().has(torch._C.DispatchKey.Python)
        ):
            return func(*args, **kwargs)

        if self.dtype in meta_function_skips.get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_function_device_skips[self.device_type].get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_function_expected_failures.get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif self.dtype in meta_function_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif meta_function_expected_failures_conditional.get(func, lambda *_, **__: False)(self.dtype, *args, **kwargs):
            test_expect = TestExpect.XFAILURE
        elif not self.inplace and \
                self.dtype in meta_function_device_expected_failures_only_outplace[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS

        return run_meta_crossref(
            self.test_case, test_expect, func, args,
            kwargs, dtype=self.dtype, device_type=self.device_type, run_symbolic_meta=False
        )
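# A minimal usage sketch (hypothetical tensors; the real driver is
# TestMeta.test_meta_outplace below, which pushes this mode around each
# OpInfo sample):
#
#   x = torch.randn(3)
#   with MetaCrossRefFunctionMode(test_case, dtype=torch.float32, device="cpu", inplace=False):
#       torch.sin(x)  # runs for real, then re-runs on meta and compares output metadata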
# these always fail
meta_dispatch_expected_failures = {
    aten.allclose.default: {f16, bf16, f32, f64, c64, c128},  # NotImplementedError: 'aten::_local_scalar_dense'
    aten.geqrf.default : {c64, c128, f64, f32},
    aten.linalg_lstsq.default : {c64, c128, f64, f32},
    aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
    aten.nonzero.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
    aten._to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._ctc_loss.Tensor : {f32, f64},  # Shape of second output depends on data.
    aten._histogramdd_bin_edges.default : {f32, f64},
    aten._histogramdd_from_bin_cts.default : {f32, f64},
    aten._histogramdd_from_bin_tensors.default : {f32, f64},
    aten._local_scalar_dense.default : {c32, c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._unique2.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.bincount.default : {i64, i8, i32, i16, u8},
    aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.histogram.bin_ct : {f32, f64},
    aten.histogram.bins_tensor : {f32, f64},
    aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.upsample_nearest3d.vec : {bf16, f32, f64, u8},
}

# these sometimes pass and sometimes fail
meta_dispatch_skips = {
    aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128},  # at::nonzero doesn't have a Meta function
    aten._to_copy.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128},
    aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
    aten.addbmm_.default: {bf16, c128, c64, f32, f64, i16, i32, i64, i8, u8},
}

# For CompositeImplicitAutograd functions that fail before hitting the Mode
meta_dispatch_early_skips = set({
    torch.Tensor.float_power_,
    # Errors out in one of the tests, while ProxyTensor passes...
    torch.Tensor.cumprod_,
    torch.Tensor.cumsum_,
})
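# Note on table semantics: XFAILURE entries are asserted to actually fail, so
# an op that starts working shows up as an "unexpected success" and must be
# removed from its table; SKIP entries are simply not cross-checked at all.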
meta_inplace_skips = set({
    # Errors out in one of the tests, while ProxyTensor passes...
    torch.Tensor.cumprod_,
    torch.Tensor.cumsum_,
})

meta_dispatch_device_expected_failures = defaultdict(dict)
meta_dispatch_device_skips = defaultdict(dict)

meta_dispatch_device_expected_failures['cpu'] = {
    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    aten.native_batch_norm.default: {bf16, f16},
    aten._native_batch_norm_legit.default: {bf16, f16},
    aten._native_batch_norm_legit.no_stats: {bf16, f16},
    aten._batch_norm_with_update.default: {bf16, f16},

    aten.native_layer_norm.default: {bf16, f16},
}

meta_dispatch_device_expected_failures['cuda'] = {
    aten._unique2.default: {f16},  # aten::_unique2
    aten._use_cudnn_ctc_loss.default: {f32, f64},  # aten::_use_cudnn_ctc_loss
    aten._use_cudnn_ctc_loss.Tensor: {f32, f64},  # aten::_use_cudnn_ctc_loss.Tensor
    aten.cudnn_grid_sampler.default: {f16, f32, f64},  # aten::cudnn_grid_sampler
    aten.geqrf.default: {f32, f64},  # aten::geqrf
    aten.linalg_eigvalsh.out: {f32, f64},  # aten::linalg_eigvalsh.out
    aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
    aten.log_sigmoid_forward.output : {bf16, f16, f64, f32},  # aten::log_sigmoid_forward.output
    aten.unique_consecutive.default: {f16},  # aten::unique_consecutive
    aten.unique_dim.default: {f16},  # aten::unique_dim
    aten.upsample_nearest3d.vec: {f16},  # aten::upsample_nearest3d.vec
}

meta_dispatch_device_skips['cpu'] = {
    aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},

    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    aten.native_batch_norm.default: {f32, f64},
    aten._native_batch_norm_legit.default: {f32, f64},
    aten._native_batch_norm_legit.no_stats: {f32, f64},
    aten._batch_norm_with_update.default: {f32, f64},

    # If the computation dtype is different from the input
    # dtype this will fail. CPU execution may also have a
    # different output from other devices.
    aten.native_batch_norm.out: {bf16, f16, f32, f64}
}

meta_dispatch_device_skips['cuda'] = {
    aten._conj.default: {c32, f16},  # file issue
    aten._linalg_svd.default: {c64, c128},  # aten::linalg_eigvalsh.out
    aten.cudnn_batch_norm.default: {f32, f64},
    aten.log_softmax.int : {c32, c64},
    aten.softmax.int : {c32, c64},

    # ROCm stuff; technically this should be expected failure but it's
    # not worth it; these should get unified anyway
    aten.miopen_batch_norm.default: {f32},
}

def get_strided_args(args):

    def get_strided_variants(t, include_storage_offset=False):
        variants = []

        # contiguous
        variants.append(t)

        # transposed
        if t.ndim > 1:
            perm = list(reversed(range(t.ndim)))
            transposed = torch.empty(
                t.shape[::-1], device=t.device, dtype=t.dtype, requires_grad=t.requires_grad
            ).permute(perm).copy_(t)
            variants.append(transposed)

        # nondense
        if t.ndim > 0:
            nondense = torch.repeat_interleave(t, 2, dim=-1)[..., ::2]
            variants.append(nondense)

        # channel_last
        if t.ndim == 4:
            variants.append(t.contiguous(memory_format=torch.channels_last))

        # channel_last_3d
        if t.ndim == 5:
            variants.append(t.contiguous(memory_format=torch.channels_last_3d))

        # storage_offset
        if include_storage_offset:
            buffer = torch.empty(t.numel() + 1, device=t.device, dtype=t.dtype, requires_grad=t.requires_grad)
            buffer = buffer.as_strided(t.shape, t.stride(), storage_offset=1)
            buffer.copy_(t)
            variants.append(buffer)

        return variants

    strided_args = []
    for arg in args:
        if isinstance(arg, torch.Tensor) and not arg.is_sparse_csr and arg.is_contiguous():
            strided_arg_variants = get_strided_variants(arg)
        else:
            strided_arg_variants = [arg]
        strided_args.append(strided_arg_variants)

    yield from itertools.product(*strided_args)
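# A sketch of what get_strided_args generates (hypothetical input): for
#
#   args = (torch.randn(2, 3, 4, 5), 1.0)
#
# the contiguous 4-D tensor contributes contiguous, transposed, nondense, and
# channels_last variants while the scalar passes through unchanged, so the
# generator yields four argument tuples with identical values but different
# memory layouts.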
class MetaCrossRefDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
    test_case: TestCase
    device: torch.device
    dtype: torch.dtype
    aten_olp_no_out_overload: set = set()

    def __init__(self, test_case, *, device, dtype, symbolic_meta: bool, inplace: bool, supports_out: bool):
        self.test_case = test_case
        # save TLS
        self.precision = test_case.precision
        self.rel_tol = test_case.rel_tol
        self.device_type = torch.device(device).type
        self.dtype = dtype
        self.symbolic_meta = symbolic_meta
        self.inplace = inplace
        self.supports_out = supports_out

    @staticmethod
    def try_resolve_aten_out_overload(ol, args, kwargs, num_outputs):
        ol_args = ol._schema.arguments
        olp: OpOverloadPacket = ol._overloadpacket

        if olp in MetaCrossRefDispatchMode.aten_olp_no_out_overload:
            return (None, None, None)

        candidate_ols = []
        for candidate_ol_name in olp.overloads():
            candidate_ol = getattr(olp, candidate_ol_name)
            if any(arg.is_out for arg in candidate_ol._schema.arguments):
                candidate_ols.append(candidate_ol)

        if not candidate_ols:
            MetaCrossRefDispatchMode.aten_olp_no_out_overload.add(olp)
            return (None, None, None)

        # Now match based on args, kwargs and number of required outputs
        candidate_ol: OpOverload = None
        for candidate_ol in candidate_ols:
            candidate_ol_args = candidate_ol._schema.arguments

            if len(args) >= len(candidate_ol_args):
                continue

            # Positional arguments must have the same type
            if not all(
                ol_args[pos_arg_ind].type == candidate_ol_args[pos_arg_ind].type
                for pos_arg_ind in range(len(args))
            ):
                continue

            # Number of outputs must match
            candidate_out_names = [out_arg.name for out_arg in candidate_ol_args[-num_outputs:] if out_arg.is_out]
            if len(candidate_out_names) != num_outputs:
                continue

            # Now try and match kwargs. Just need to ensure that the
            # remaining kwargs allow an out overload to be called.
            # For example, we can throw away parameters like `dtype` that may
            # be passed to the functional version of the op, since the `dtype`
            # will already be present in the `out` argument.
            new_kwargs = {}
            kwargs_match = True
            for arg in candidate_ol_args[len(args):-num_outputs]:
                if arg.name not in kwargs:
                    if arg.has_default_value():
                        new_kwargs[arg.name] = arg.default_value
                    elif isinstance(arg.type, torch.OptionalType):
                        if isinstance(arg.type.getElementType(), torch.BoolType):
                            new_kwargs[arg.name] = False
                        else:
                            new_kwargs[arg.name] = None
                    else:
                        kwargs_match = False
                        break
                else:
                    new_kwargs[arg.name] = kwargs[arg.name]

            if kwargs_match:
                return candidate_ol, candidate_out_names, new_kwargs

        return None, None, None
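    # For instance (a sketch, not an exhaustive trace): resolving
    # aten.add.Tensor called as add(x, y), which has one tensor output, should
    # yield (aten.add.out, ["out"], {"alpha": 1}), letting __torch_dispatch__
    # below replay the same call through the out= overload.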
    def _get_expected_test_result(self, func: OpOverload):
        if self.dtype in meta_dispatch_skips.get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_device_skips[self.device_type].get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_expected_failures.get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif self.dtype in meta_dispatch_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS
        return test_expect

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.test_case.precision = self.precision
        self.test_case.rel_tol = self.rel_tol

        test_expect = self._get_expected_test_result(func)

        expected = run_meta_crossref(
            self.test_case,
            test_expect,
            func,
            args,
            kwargs,
            dtype=self.dtype,
            device_type=self.device_type,
            run_symbolic_meta=self.symbolic_meta,
        )

        # This is to test torch ops that do not have an out parameter but have
        # aten op overloads that have out parameters. Additionally, Python
        # decompositions may register OpOverloadPackets, so decompositions need
        # to be tested to ensure all OpOverloads still function for the Meta key
        # (e.g. if a Python decomposition is registered for an aten op aten.foo
        # with overloads [default, out], the Python function needs to support
        # receiving `out` arguments).
        if (
            not self.inplace and
            not self.supports_out and
            test_expect == TestExpect.SUCCESS and
            (torch.is_tensor(expected) or isinstance(expected, Iterable))
        ):
            # check to see if there is a potential out overload
            num_outputs = 1 if torch.is_tensor(expected) else len(expected)
            func_out_overload, out_param_names, kwargs = self.try_resolve_aten_out_overload(func, args, kwargs, num_outputs)

            if func_out_overload:
                if num_outputs == 1:
                    kwargs[out_param_names[0]] = expected
                else:
                    for ind, out_param_name in enumerate(out_param_names):
                        kwargs[out_param_name] = expected[ind]

                test_expect = self._get_expected_test_result(func_out_overload)

                run_meta_crossref(
                    self.test_case,
                    test_expect,
                    func_out_overload,
                    args,
                    kwargs,
                    dtype=self.dtype,
                    device_type=self.device_type,
                    run_symbolic_meta=self.symbolic_meta,
                )

        return expected

# NB: we're running these tests only on CUDA because there are some
# inconsistencies between CUDA and CPU, and running on CUDA makes it easier
# to ignore the CPU case when inconsistencies arise. Ideally we would deal
# with the inconsistencies, but this takes time.
@unMarkDynamoStrictTest
class TestMeta(TestCase):
    # Copies inputs to inplace operations to avoid inplace modifications
    # to leaves requiring gradient
    def _get_safe_inplace(self, inplace_variant):
        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            if isinstance(t, list):
                return inplace_variant([x.clone() for x in t], *args, **kwargs)
            else:
                return inplace_variant(t.clone(), *args, **kwargs)

        return _fn

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_meta_outplace(self, device, dtype, op):
        if "_scaled_mm" in op.name:
            raise unittest.SkipTest("_scaled_mm does not support meta device")
        skip_op_names = (
            "fft.ihfft",
            "fft.ihfft2",
            "linalg.lu_solve",
        )
        if TEST_WITH_TORCHDYNAMO and op.name in skip_op_names:
            raise unittest.SkipTest("flaky")
        # Run the OpInfo sample inputs, cross-referencing them with the
        # meta implementation, and check the results are the same.
        # All the heavy lifting happens in MetaCrossRefFunctionMode.
        func = op.get_op()
        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in samples:
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=False):
                expected = func(*args, **kwargs)
                if isinstance(expected, torch.Tensor) and op.supports_out:
                    func(*args, **kwargs, out=expected)

        # Special test for functions taking "device" kwarg
        # The crossref tests that replacing the device with "meta" works
        # This part makes sure that *_like functions work well with a "meta"
        # Tensor and their original device argument.
        if "device" in kwargs and "_like" in op.name:
            with torch.random.fork_rng():
                torch.manual_seed(123)
                ref = func(*args, **kwargs)

            # *_like functions take a Tensor as first argument
            assert isinstance(args[0], torch.Tensor)
            with torch.random.fork_rng():
                torch.manual_seed(123)
                args[0] = args[0].to(device="meta")
                meta = func(*args, **kwargs)

            # empty_like is not deterministic
            if op.name != "empty_like":
                self.assertEqual(ref, meta)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_meta_inplace(self, device, dtype, op):
        func = op.get_inplace()
        if not func:
            self.skipTest("No inplace variant for this op")
        if op.promotes_int_to_float and not dtype.is_floating_point:
            self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
        if func in meta_inplace_skips:
self.skipTest("Skipped") 1211*da0073e9SAndroid Build Coastguard Worker func = self._get_safe_inplace(func) 1212*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=False) 1213*da0073e9SAndroid Build Coastguard Worker for sample_input in samples: 1214*da0073e9SAndroid Build Coastguard Worker if sample_input.broadcasts_input: 1215*da0073e9SAndroid Build Coastguard Worker continue 1216*da0073e9SAndroid Build Coastguard Worker args = [sample_input.input] + list(sample_input.args) 1217*da0073e9SAndroid Build Coastguard Worker kwargs = sample_input.kwargs 1218*da0073e9SAndroid Build Coastguard Worker with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=True): 1219*da0073e9SAndroid Build Coastguard Worker expected = func(*args, **kwargs) 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker def _run_dispatch_meta_test(self, device, dtype, op, symbolic_meta, inplace, all_stride_variants=False): 1222*da0073e9SAndroid Build Coastguard Worker if "_scaled_mm" in op.name: 1223*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("_scaled_mm dose not support meta device") 1224*da0073e9SAndroid Build Coastguard Worker if inplace: 1225*da0073e9SAndroid Build Coastguard Worker func = op.get_inplace() 1226*da0073e9SAndroid Build Coastguard Worker if not func: 1227*da0073e9SAndroid Build Coastguard Worker self.skipTest("No inplace variable for this op") 1228*da0073e9SAndroid Build Coastguard Worker if op.promotes_int_to_float and not dtype.is_floating_point: 1229*da0073e9SAndroid Build Coastguard Worker self.skipTest("Op promotes to float, which is impossible for inplace with non-float input") 1230*da0073e9SAndroid Build Coastguard Worker else: 1231*da0073e9SAndroid Build Coastguard Worker func = op.get_op() 1232*da0073e9SAndroid Build Coastguard Worker 1233*da0073e9SAndroid Build Coastguard Worker if func in meta_dispatch_early_skips: 1234*da0073e9SAndroid Build Coastguard Worker self.skipTest("Function is in dispatch early skips") 1235*da0073e9SAndroid Build Coastguard Worker 1236*da0073e9SAndroid Build Coastguard Worker if inplace: 1237*da0073e9SAndroid Build Coastguard Worker func = self._get_safe_inplace(func) 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=False) 1240*da0073e9SAndroid Build Coastguard Worker for sample_input in samples: 1241*da0073e9SAndroid Build Coastguard Worker if inplace and sample_input.broadcasts_input: 1242*da0073e9SAndroid Build Coastguard Worker continue 1243*da0073e9SAndroid Build Coastguard Worker 1244*da0073e9SAndroid Build Coastguard Worker sample_args = [sample_input.input] + list(sample_input.args) 1245*da0073e9SAndroid Build Coastguard Worker kwargs = sample_input.kwargs 1246*da0073e9SAndroid Build Coastguard Worker 1247*da0073e9SAndroid Build Coastguard Worker if all_stride_variants and sum(isinstance(arg, torch.Tensor) for arg in sample_args) <= 5: 1248*da0073e9SAndroid Build Coastguard Worker # test inputs <= 5 tensors to avoid combinatorial explosion 1249*da0073e9SAndroid Build Coastguard Worker strided_args = get_strided_args(sample_args) 1250*da0073e9SAndroid Build Coastguard Worker else: 1251*da0073e9SAndroid Build Coastguard Worker strided_args = [sample_args] 1252*da0073e9SAndroid Build Coastguard Worker 1253*da0073e9SAndroid Build Coastguard Worker for args in strided_args: 1254*da0073e9SAndroid Build Coastguard Worker with 
                with MetaCrossRefDispatchMode.push(
                        self, dtype=dtype, device=device,
                        symbolic_meta=symbolic_meta, inplace=inplace,
                        supports_out=op.supports_out):
                    expected = func(*args, **kwargs)

                    if not inplace and isinstance(expected, torch.Tensor) and op.supports_out:
                        func(*args, **kwargs, out=expected)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_meta_outplace(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=False, inplace=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_meta_inplace(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=False, inplace=True)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_symbolic_meta_outplace(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_symbolic_meta_inplace(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    # only test one dtype, as output stride behavior is the same for all dtypes
    @ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
    # Only test on CUDA, as the CUDA kernel's strides are the reference
    @onlyCUDA
    def test_dispatch_symbolic_meta_outplace_all_strides(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False, all_stride_variants=True)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    # only test one dtype, as output stride behavior is the same for all dtypes
    @ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
    # Only test on CUDA, as the CUDA kernel's strides are the reference
    @onlyCUDA
    def test_dispatch_symbolic_meta_inplace_all_strides(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True, all_stride_variants=True)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    # only test one dtype, as output stride behavior is the same for all dtypes
    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
    # Only test on CUDA, as the CUDA kernel's strides are the reference
    @onlyCUDA
    def test_binary_ufuncs_mixed_dtype(self, device, dtype, op):
        make_arg = partial(
            make_tensor,
            device=device,
        )

        def sample_input(op, device, dtype, requires_grad, **kwargs):
            yield SampleInput(
                make_arg((S,), dtype=dtype), make_arg((S,), dtype=torch.float16)
            )

        op = copy.copy(op)
        op.sample_inputs_func = sample_input

        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False)

    def test_empty_quantized(self):
        r = torch.empty(2 ** 52, device='meta', dtype=torch.qint8)
        self.assertEqual(r.device.type, 'meta')

    def test_nan_to_num(self):
        t = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14], device='meta')
        r = t.nan_to_num()
        self.assertEqual(r.device.type, 'meta')

    def test_inplace_masked_fill_error(self):
        t = torch.randn(3, 3, device='meta')
        with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
            t.masked_fill_((t > 0).unsqueeze(0), 0.1)

    def test_inplace_bin_ops_error(self):
        t = torch.randn(3, 3, device='meta')
        for op in (torch.Tensor.add_, torch.Tensor.sub_, torch.Tensor.mul_, torch.Tensor.div_,
                   torch.Tensor.logical_and_, torch.Tensor.logical_or_, torch.Tensor.logical_xor_):
            with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
                op(t, t.clone().unsqueeze(0))

    @onlyCPU
    def test_meta_autograd_no_error(self):
        with torch.library._scoped_library("meta_test", "DEF") as lib:
            with torch.library._scoped_library("meta_test", "IMPL", "CPU") as impl_cpu:
                with torch.library._scoped_library("meta_test", "IMPL", "Meta") as impl_meta:
                    def foo_impl(x):
                        return x + 1

                    lib.define("foo(Tensor a) -> Tensor")
                    impl_meta.impl("foo", foo_impl)
                    impl_cpu.impl("foo", foo_impl)

                    a = torch.ones(2, device='meta')
                    # The point of the test is that this should not error:
                    # We have a fallthrough kernel registered to the AutogradMeta
                    # key for custom ops, so it's fine that `foo()` doesn't have
                    # an autograd kernel.
                    b = torch.ops.meta_test.foo.default(a)

    def test_huber_loss_backward(self):
        inps = [torch.rand(2**52, device='meta') for _ in range(3)]
        r = torch.ops.aten.huber_loss_backward(*inps, 0, 1.0)
        self.assertEqual(r.device.type, 'meta')
        self.assertEqual(r.shape, inps[0].shape)

    def _norm_backwards_test_helper(self, op, args, output_mask, expected_shapes):
        dtype = torch.float32
        device = "meta"

        # test functional call
        grads = op(*args, output_mask)

        def assertEqualShapes(res, exp):
            self.assertIsNone(res) if exp is None else self.assertEqual(exp, res.shape)

        assertEqualShapes(grads[0], expected_shapes[0])
        assertEqualShapes(grads[1], expected_shapes[1])
        assertEqualShapes(grads[2], expected_shapes[2])

        out_kwargs = {
            f"out{i}": torch.empty(0, device=device, dtype=dtype)
            for i in range(len(output_mask))
        }

        # test call with out parameters
        grads = op(*args, output_mask, **out_kwargs)

        def assertEqualShapes(res, exp):
            self.assertEqual(exp, res.shape) if exp is not None else True

        assertEqualShapes(out_kwargs["out0"], expected_shapes[0])
        assertEqualShapes(out_kwargs["out1"], expected_shapes[1])
        assertEqualShapes(out_kwargs["out2"], expected_shapes[2])
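    # In the norm-backward tests below, output_mask is a (grad_input,
    # grad_weight, grad_bias) triple: e.g. (True, False, True) asks the
    # backward op to materialize only the input and bias gradients, and the
    # helper above checks that the unrequested gradient comes back as None.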
    @onlyCPU
    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
    def test_layer_norm_backward(self, output_mask):
        from torch.testing._internal.common_methods_invocations import sample_inputs_layer_norm

        device = "meta"
        dtype = torch.float32

        samples = sample_inputs_layer_norm(None, device, dtype, requires_grad=False)

        for sample in samples:
            with self.subTest(sample=sample):
                # handle optional weight and bias
                if len(sample.args) != 3:
                    sample.args = (*sample.args, *([None] * (3 - len(sample.args))))

                grad_out = torch.ones_like(sample.input)
                normalized_shape, weight, bias = sample.args
                ndims_after_reduction = sample.input.ndim - len(normalized_shape)
                mean_shape = grad_out.shape[:ndims_after_reduction]
                mean = torch.zeros(mean_shape, device=device, dtype=dtype)
                rstd = torch.zeros(mean_shape, device=device, dtype=dtype)

                expected_shapes = (
                    sample.input.shape if output_mask[0] else None,
                    weight.shape if output_mask[1] and weight is not None else None,
                    bias.shape if output_mask[2] and bias is not None else None)

                args = [grad_out, sample.input, normalized_shape, mean, rstd, weight, bias]

                self._norm_backwards_test_helper(torch.ops.aten.native_layer_norm_backward,
                                                 args, output_mask, expected_shapes)

    @onlyCPU
    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
    def test_group_norm_backward(self, output_mask):
        from torch.testing._internal.common_methods_invocations import sample_inputs_group_norm

        # input, (args) num_groups, (kwargs) weight, bias, eps
        device = "meta"
        dtype = torch.float32
        samples = sample_inputs_group_norm(None, device, dtype, requires_grad=False)

        for sample in samples:
            with self.subTest(sample=sample):
                grad_out = torch.ones_like(sample.input)
                N, C = sample.input.shape[:2]
                HxW = torch.prod(torch.as_tensor(sample.input.shape[2:]), dtype=torch.int32).item()
                group = sample.args[0]

    @onlyCPU
    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
    def test_group_norm_backward(self, output_mask):
        from torch.testing._internal.common_methods_invocations import sample_inputs_group_norm

        # input, (args) num_groups, (kwargs) weight, bias, eps
        device = "meta"
        dtype = torch.float32
        samples = sample_inputs_group_norm(None, device, dtype, requires_grad=False)

        for sample in samples:
            with self.subTest(sample=sample):
                grad_out = torch.ones_like(sample.input)
                N, C = sample.input.shape[:2]
                HxW = torch.prod(torch.as_tensor(sample.input.shape[2:]), dtype=torch.int32).item()
                group = sample.args[0]
                # mean and rstd are reduced per (sample, group)
                mean = torch.zeros((N, group), device=device, dtype=dtype)
                rstd = torch.zeros((N, group), device=device, dtype=dtype)
                weight = torch.zeros((C,), device=device, dtype=dtype)

                args = [grad_out, sample.input, mean, rstd, weight, N, C, HxW, group]

                expected_shapes = (
                    sample.input.shape if output_mask[0] else None,
                    weight.shape if output_mask[1] else None,
                    weight.shape if output_mask[2] else None)

                # test functional call
                self._norm_backwards_test_helper(torch.ops.aten.native_group_norm_backward,
                                                 args, output_mask, expected_shapes)

    @onlyCPU
    @parametrize("output_mask", list(itertools.product([True], [True, False], [True, False])))
    def test_batch_norm_backward(self, output_mask):
        from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm

        # input, (args) running_mean, running_var, weight, bias, (kwargs) training, eps
        device = "meta"
        dtype = torch.float32
        samples = sample_inputs_batch_norm(None, device, dtype, requires_grad=False)

        for sample in samples:
            with self.subTest(sample=sample):
                if sample.input.dim() < 2:
                    continue

                grad_out = torch.ones_like(sample.input)
                running_mean, running_var, weight, bias = sample.args
                train = sample.kwargs.get("training", True)
                save_mean = torch.zeros((sample.input.shape[1],), device=device, dtype=dtype) if train else None
                save_invstd = torch.zeros((sample.input.shape[1],), device=device, dtype=dtype) if train else None

                args = [grad_out, sample.input, weight, running_mean, running_var,
                        save_mean, save_invstd, train, sample.kwargs.get("eps", 1e-5)]

                expected_shapes = (
                    sample.input.shape,
                    torch.Size([sample.input.shape[1]]) if output_mask[1] else None,
                    torch.Size([sample.input.shape[1]]) if output_mask[2] else None)

                self._norm_backwards_test_helper(torch.ops.aten.native_batch_norm_backward,
                                                 args, output_mask, expected_shapes)

    def test_fill__alias_relationship(self):
        inps = torch.rand(2**52, device='meta')
        r = torch.ops.aten.fill_(inps, 1.0)
        # aten.fill_ returns an alias of its input
        self.assertEqual(id(inps), id(r))

        # aten.fill returns a new tensor
        r2 = torch.ops.aten.fill(inps, 1.0)
        self.assertNotEqual(id(inps), id(r2))

    def test_meta__fused_moving_avg_obs_fq_helper(self, device):
        from torch.ao.quantization import FusedMovingAvgObsFakeQuantize
        to_meta = MetaConverter()

        x = torch.randn(5, 5, device=device)
        running_min_op = torch.tensor(float("inf"), device=device)
        running_max_op = torch.tensor(float("-inf"), device=device)
        avg_const = 0.01
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)

        mod = FusedMovingAvgObsFakeQuantize()
        torch.ao.quantization.enable_fake_quant(mod)
        torch.ao.quantization.enable_observer(mod)
        mod.to(device)

        meta_x = to_meta(x)

        args = [
            x,
            mod.observer_enabled,
            mod.fake_quant_enabled,
            running_min_op,
            running_max_op,
            scale,
            zero_point,
            avg_const,
            0,
            255,
            0,
        ]

        meta_args = args.copy()
        meta_args[0] = meta_x

        kwargss = [
            {},
            {"per_row_fake_quant": False, "symmetric_quant": False},
            {"per_row_fake_quant": False, "symmetric_quant": True},
        ]

        for kwargs in kwargss:
            ref_out = aten._fused_moving_avg_obs_fq_helper.default(*args, **kwargs)
            meta_out = aten._fused_moving_avg_obs_fq_helper.default(*meta_args, **kwargs)

            self.assertEqual(ref_out[0].size(), meta_out[0].size())
            self.assertEqual(ref_out[0].stride(), meta_out[0].stride())
            self.assertEqual(ref_out[1].size(), meta_out[1].size())
            self.assertEqual(ref_out[1].stride(), meta_out[1].stride())

    def test_cdist_forward(self, device):
        to_meta = MetaConverter()
        x1 = torch.rand([3, 2], device=device)
        x2 = torch.rand([2, 2], device=device)
        p = 2.0
        for compute_mode in (None, 1, 2):
            ref = aten._cdist_forward.default(x1, x2, p, compute_mode)
            res = aten._cdist_forward.default(to_meta(x1), to_meta(x2), p, compute_mode)
            self.assertEqual(res.device.type, 'meta')
            self.assertEqual(ref.shape, res.shape)
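
    def _meta_converter_sketch(self):
        # Hedged sketch, not collected as a test: the tests above rely on
        # MetaConverter mirroring a real tensor onto the meta device while
        # preserving shape, stride, and dtype, which is what makes the
        # structural real-vs-meta comparisons meaningful. All values here
        # are illustrative.
        to_meta = MetaConverter()
        x = torch.randn(4, 4).t()  # deliberately non-contiguous
        m = to_meta(x)
        assert m.device.type == "meta"
        assert m.shape == x.shape
        assert m.stride() == x.stride()
        assert m.dtype == x.dtype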
offsets.to(device="meta"), 1597*da0073e9SAndroid Build Coastguard Worker mode=0, # sum 1598*da0073e9SAndroid Build Coastguard Worker per_sample_weights=None, 1599*da0073e9SAndroid Build Coastguard Worker include_last_offset=True, 1600*da0073e9SAndroid Build Coastguard Worker ) 1601*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eb.shape, [32, 128]) 1602*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eb.dtype, torch.float32) 1603*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eb.untyped_storage().data_ptr(), 0) 1604*da0073e9SAndroid Build Coastguard Worker 1605*da0073e9SAndroid Build Coastguard Worker # Tests mean and max. 1606*da0073e9SAndroid Build Coastguard Worker # Can't easily test sum, because there is a fast path for sum which 1607*da0073e9SAndroid Build Coastguard Worker # causes offset2bag to not get allocated... but the backward function 1608*da0073e9SAndroid Build Coastguard Worker # needs it, and the offset2bag computation lives inside the 1609*da0073e9SAndroid Build Coastguard Worker # derivatives.yaml formula directly, so there is no way to access it. 1610*da0073e9SAndroid Build Coastguard Worker # To test sum, need to manually compute offset2bag 1611*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", [1, 2]) 1612*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_dense_backward(self, mode): 1613*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(4, 3, requires_grad=True) 1614*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor([1, 0, 2, 1, 3]) 1615*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 3, 5]) 1616*da0073e9SAndroid Build Coastguard Worker scale_grad_by_freq = False 1617*da0073e9SAndroid Build Coastguard Worker sparse = False 1618*da0073e9SAndroid Build Coastguard Worker per_sample_weights = None 1619*da0073e9SAndroid Build Coastguard Worker include_last_offset = False 1620*da0073e9SAndroid Build Coastguard Worker padding_idx = -1 1621*da0073e9SAndroid Build Coastguard Worker 1622*da0073e9SAndroid Build Coastguard Worker output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default( 1623*da0073e9SAndroid Build Coastguard Worker weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx 1624*da0073e9SAndroid Build Coastguard Worker ) 1625*da0073e9SAndroid Build Coastguard Worker grad = torch.randn_like(output) 1626*da0073e9SAndroid Build Coastguard Worker 1627*da0073e9SAndroid Build Coastguard Worker # Call the function with example inputs 1628*da0073e9SAndroid Build Coastguard Worker grad_weight = torch.ops.aten._embedding_bag_dense_backward.default( 1629*da0073e9SAndroid Build Coastguard Worker grad, indices, offset2bag, bag_size, maximum_indices, weight.size(0), 1630*da0073e9SAndroid Build Coastguard Worker scale_grad_by_freq, mode, per_sample_weights, padding_idx 1631*da0073e9SAndroid Build Coastguard Worker ) 1632*da0073e9SAndroid Build Coastguard Worker meta_grad_weight = torch.ops.aten._embedding_bag_dense_backward.default( 1633*da0073e9SAndroid Build Coastguard Worker grad.to('meta'), indices.to('meta'), offset2bag.to('meta'), bag_size.to('meta'), 1634*da0073e9SAndroid Build Coastguard Worker maximum_indices.to('meta'), weight.size(0), 1635*da0073e9SAndroid Build Coastguard Worker scale_grad_by_freq, mode, per_sample_weights, padding_idx 1636*da0073e9SAndroid Build Coastguard Worker ) 1637*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_weight.to('meta'), 

    def test_embedding_bag_dense_backward_per_sample_weights(self):
        weight = torch.randn(4, 3, requires_grad=True)
        indices = torch.tensor([1, 0, 2, 1, 3])
        offsets = torch.tensor([0, 2, 3, 5])
        scale_grad_by_freq = False
        sparse = False
        mode = 0
        per_sample_weights = torch.randn(5, requires_grad=True)
        include_last_offset = False
        padding_idx = -1

        output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
            weight, indices, offsets, scale_grad_by_freq, mode, sparse,
            per_sample_weights, include_last_offset, padding_idx
        )
        grad = torch.randn_like(output)

        # Call the function with example inputs
        grad_weight = torch.ops.aten._embedding_bag_per_sample_weights_backward.default(
            grad, weight, indices, offsets, offset2bag, mode, padding_idx
        )
        meta_grad_weight = torch.ops.aten._embedding_bag_per_sample_weights_backward.default(
            grad.to('meta'), weight.to('meta'), indices.to('meta'),
            offsets.to('meta'), offset2bag.to('meta'), mode, padding_idx
        )
        self.assertEqual(grad_weight.to('meta'), meta_grad_weight)

    # the opinfo test uses aten.fill_; it does not exercise aten.fill
    @onlyCUDA
    def test_fill_stride(self):
        to_meta = MetaConverter()
        sample_args = [torch.rand(2, 2, 2, 2), 1.0]

        for args in get_strided_args(sample_args):
            meta_args = to_meta(args)
            ref_out = torch.ops.aten.fill(*args)
            meta_out = torch.ops.aten.fill(*meta_args)
            self.assertEqual(ref_out.size(), meta_out.size())
            self.assertEqual(ref_out.stride(), meta_out.stride())

    def test_map_location_deserialize(self):
        import io

        t = torch.rand(10)
        b = io.BytesIO()

        torch.save(t, b)
        b.seek(0)
        r = torch.load(b, map_location=torch.device("meta"))
        self.assertEqual(r.device.type, 'meta')
        self.assertEqual(r.shape, t.shape)
        self.assertEqual(r.dtype, t.dtype)
        self.assertEqual(r.storage().data_ptr(), 0)

    def test_embedding_bag_byte_prepack(self):
        batch_size = 10
        num_embeddings = 80
        embedding_dim = [128, 256, 512]
        # byte prepacking appends per-row quantization parameters (scale and
        # zero point), growing the last dimension by 8
        res_shape = [[batch_size, num_embeddings, ed + 8] for ed in embedding_dim]
        for ed, rs in zip(embedding_dim, res_shape):
            weight = torch.randn(batch_size, num_embeddings, ed, dtype=torch.float32)
            res = torch.ops.quantized.embedding_bag_byte_prepack(weight.to(device="meta"))
            self.assertEqual(res.shape, rs)
            self.assertEqual(res.dtype, torch.float32)
            self.assertEqual(res.untyped_storage().data_ptr(), 0)

    def test_embedding_bag_byte_unpack(self):
        batch_size = 10
        num_embeddings = 80
        embedding_dim = [128, 256, 512]
        res_shape = [[batch_size, num_embeddings, ed] for ed in embedding_dim]
        for ed, rs in zip(embedding_dim, res_shape):
            packed_weight = torch.randn(batch_size, num_embeddings, ed + 8, dtype=torch.float32)
            res = torch.ops.quantized.embedding_bag_byte_unpack(packed_weight.to(device="meta"))
            self.assertEqual(res.shape, rs)
            self.assertEqual(res.dtype, torch.float32)
            self.assertEqual(res.untyped_storage().data_ptr(), 0)
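
    def _byte_prepack_roundtrip_sketch(self):
        # Hedged sketch, not collected as a test: composes the two tests
        # above. Prepacking grows the last dimension by 8 and unpacking
        # shrinks it by 8, so on the meta device the pair round-trips the
        # shape without allocating any real data. The sizes are illustrative.
        w = torch.randn(2, 4, 16, dtype=torch.float32, device="meta")
        packed = torch.ops.quantized.embedding_bag_byte_prepack(w)
        unpacked = torch.ops.quantized.embedding_bag_byte_unpack(packed)
        assert packed.shape == (2, 4, 24)
        assert unpacked.shape == w.shape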

    def test_index_select_out(self):
        def f():
            input = torch.randn([8, 16], device='meta')
            index = torch.tensor([2, 1, 6, 7, 3, 1, 7, 5, 6, 7], device='meta')
            out = torch.empty([10, 16], device='meta')
            return torch.index_select(input=input, dim=0, index=index, out=out)
        with enable_python_dispatcher():
            out = f()
        self.assertEqual(out.shape, [10, 16])

    def test_local_scalar_dense_call(self):
        with self.assertRaisesRegex(RuntimeError, "cannot be called on meta tensors"):
            meta_tensor = torch.randn(1, device='meta')
            meta_tensor.item()

instantiate_device_type_tests(TestMeta, globals())

def print_op_str_if_not_supported(op_str):
    op = OperatorName.parse(op_str)
    packet = getattr(torch.ops.aten, str(op.name))
    overload = getattr(packet, op.overload_name if op.overload_name else "default")
    if any(overload in d for d in [meta_dispatch_skips, meta_dispatch_device_skips['cuda']]):
        print(f"{overload} # SKIP")
    if any(overload in d for d in [meta_dispatch_expected_failures, meta_dispatch_device_expected_failures['cuda']]):
        print(overload)


if __name__ == "__main__":
    COMPARE_XLA = os.getenv('PYTORCH_COMPARE_XLA', None)
    if COMPARE_XLA is not None:
        with open(COMPARE_XLA) as f:
            d = yaml.load(f, Loader=YamlLoader)
            ops = d.get("full_codegen", []) + d.get("supported", []) + d.get("autograd", [])
            for op_str in ops:
                print_op_str_if_not_supported(op_str)
            sys.exit(0)

    COMPARE_TEXT = os.getenv('PYTORCH_COMPARE_TEXT', None)
    if COMPARE_TEXT is not None:
        with open(COMPARE_TEXT) as f:
            for op_str in f:
                print_op_str_if_not_supported(op_str.strip())
            sys.exit(0)

    run_tests()
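
# Usage note (illustrative; the file names below are examples, not fixed
# paths). Either environment variable switches this script from running the
# test suite to printing the ops from the given file that the meta backend
# skips or expects to fail, then exiting:
#
#   PYTORCH_COMPARE_XLA=path/to/xla_native_functions.yaml python test_meta.py
#   PYTORCH_COMPARE_TEXT=path/to/op_list.txt python test_meta.py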