# Owner(s): ["module: decompositions"]

import itertools
import torch
import os
import numpy as np
from enum import Enum
from torch.overrides import resolve_name
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
import torch.utils._python_dispatch
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import OpOverload, OpOverloadPacket
from torch.testing import make_tensor
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
from torch.testing._internal.common_utils import (
    TestCase,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_TORCHDYNAMO,
    run_tests,
    dtype_abbrs,
    parametrize
)
from torch.testing._internal.common_device_type import (
    ops,
    instantiate_device_type_tests,
    onlyCUDA,
    onlyCPU,
    OpDTypes,
)
from torch.testing._internal.common_methods_invocations import (
    binary_ufuncs, op_db, foreach_unary_op_db, foreach_binary_op_db,
    foreach_pointwise_op_db, foreach_reduce_op_db, foreach_other_op_db)
from torch.testing._internal.opinfo.core import S, SampleInput
from torchgen.yaml_utils import YamlLoader
from torchgen.model import OperatorName

import copy
import sys
import yaml
import atexit
import re
from collections import defaultdict
from collections.abc import Iterable
import unittest
import warnings
import weakref
from functools import partial, wraps

bf16 = torch.bfloat16
f64 = torch.float64
f32 = torch.float32
f16 = torch.float16
c32 = torch.complex32
c64 = torch.complex64
c128 = torch.complex128
i8 = torch.int8
i16 = torch.int16
i32 = torch.int32
i64 = torch.int64
b8 = torch.bool
u8 = torch.uint8
u16 = torch.uint16
u32 = torch.uint32
u64 = torch.uint64

foreach_op_db = (
    foreach_unary_op_db +
    foreach_binary_op_db +
    foreach_pointwise_op_db +
    foreach_reduce_op_db +
    foreach_other_op_db
)


class TestMetaConverter(TestCase):
    def assertSameVersionCounter(self, m1, m2):
        # Cannot easily test m1 and m2 have same storage due to
        # lack of Storage bindings.  Use version counter.
        vc = m1._version
        self.assertEqual(m2._version, vc)
        # Doing it this way ensures that we get VC bump even with leaves
        with torch.no_grad():
            m1._base.add_(3)
        self.assertNotEqual(m1._version, vc)
        self.assertEqual(m2._version, m1._version)

    def assertMetadataMatches(self, m1, m2):
        assert_metadata_eq(self.assertEqual, m1, m2)

    def test_view_of_non_leaf(self):
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        z1 = y[:]
        z2 = y[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)

        # check the test is actually testing what it claims
        self.assertTrue(m1._is_view())
        self.assertFalse(m1._base.is_leaf)

        self.assertIsNot(m1, m2)
        self.assertMetadataMatches(m1, z1)
        self.assertMetadataMatches(m2, z2)
        self.assertSameVersionCounter(m1, m2)

    def test_view_of_leaf(self):
        x = torch.randn(4, requires_grad=True)
        z1 = x[:]
        z2 = x[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)

        # check the test is actually testing what it claims
        self.assertTrue(m1._is_view())
        self.assertTrue(m1._base.is_leaf)

        self.assertIsNot(m1, m2)
        self.assertMetadataMatches(m1, z1)
        self.assertMetadataMatches(m2, z2)
        self.assertSameVersionCounter(m1, m2)

    def test_view_of_view_of_leaf(self):
        x = torch.randn(8)
        y = x.view(2, 4)
        y.requires_grad = True
        z = y.view(2, 2, 2)

        to_meta = MetaConverter()
        mx = to_meta(x)
        mz = to_meta(z)

        self.assertFalse(z.is_leaf)

        self.assertMetadataMatches(mx, x)
        self.assertMetadataMatches(mz, z)

    def test_leaf(self):
        x = torch.randn(4, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.is_leaf)
        self.assertTrue(m.requires_grad)

        self.assertMetadataMatches(m, x)

    def test_non_leaf(self):
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        to_meta = MetaConverter()
        m = to_meta(y)

        # check the test is actually testing what it claims
        self.assertFalse(m.is_leaf)
        self.assertTrue(m.requires_grad)

        self.assertMetadataMatches(m, y)

    def test_requires_grad_false(self):
        x = torch.randn(4, requires_grad=False)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertFalse(m.requires_grad)

        self.assertMetadataMatches(m, x)

    def test_channels_last(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_channels_last_leaf(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_channels_last_non_leaf(self):
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
        y = x + 2

        # sanity
        self.assertEqual(x.stride(), y.stride())
        self.assertFalse(y.is_leaf)

        to_meta = MetaConverter()
        m = to_meta(y)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertFalse(m.is_leaf)

        self.assertMetadataMatches(m, y)

        # Check that we can autograd with m as input without erroring;
        # see https://github.com/pytorch/pytorch/issues/87956
        loss = m.sum()
        torch.autograd.grad(loss, m)

    def test_empty_strided_non_dense_leaf(self):
        x = torch.empty_strided((2, 2), (4, 2), requires_grad=True)

        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_view_mutate(self):
        x = torch.zeros(4)
        y = x.view(2, 2)

        to_meta = MetaConverter()
        m = to_meta(y)

        y.add_(torch.randn(2, 2, requires_grad=True))
        m.add_(torch.randn(2, 2, device='meta', requires_grad=True))

    def test_non_leaf_torture(self):
        x = torch.empty(20, requires_grad=True)
        with torch.no_grad():
            x.set_(x.storage(), 10, (2,), (2,))

        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    # NB: complex stuff is not actually exercised right now because
    # we have a blanket exclusion for complex conversion

    def test_view_as_real(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = torch.view_as_real(x)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_complex_noncontiguous_bug(self):
        x = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :]
        m = MetaConverter()(x)
        self.assertMetadataMatches(m, x)

    def test_view_as_complex(self):
        x = torch.randn((4, 2), dtype=torch.float32)
        y = torch.view_as_complex(x)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_view_dtype(self):
        x = torch.randn(4, dtype=torch.float32)
        y = x.view(dtype=torch.int32)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_imag(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = x.imag
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_inplace_set_storage(self):
        x = torch.tensor([0, 1], dtype=torch.int64)
        storage = x.untyped_storage()
        ssize = storage.size()
        meta = torch.empty((), dtype=torch.int64)
        meta.set_(storage, 0, (), ())
        self.assertEqual(storage.size(), ssize)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_weakref(self):
        x = torch.randn(4, 4, 4)
        m = MetaConverter()
        y = m(x)
        z = m(x)
        self.assertIs(y, z)
        self.assertEqual(len(m.tensor_memo), 1)
        self.assertEqual(len(m.storage_memo), 1)
        self.assertEqual(len(m.describer.lookup_tensor), 1)
        self.assertEqual(len(m.describer.lookup_storage), 1)
        del x
        # Entries from Tensor -> int get deallocated when the real tensor
        # disappears...
        self.assertEqual(len(m.describer.lookup_tensor), 0)
        self.assertEqual(len(m.describer.lookup_storage), 0)
        del y
        del z
        # ... but the int -> FakeTensor entries don't die until the fake
        # tensors themselves die (because the user may have held onto the
        # int key and is expecting to get a consistent fake tensor in
        # this case)
        self.assertEqual(len(m.tensor_memo), 0)
        self.assertEqual(len(m.storage_memo), 0)
        li = []
        r = []
        for i in range(4):
            li.append(torch.rand([i]))
            r.append(m(li[-1]))
        self.assertEqual(len(m.tensor_memo), 4)
        self.assertEqual(len(m.storage_memo), 4)
        self.assertEqual(len(m.describer.lookup_tensor), 4)
        self.assertEqual(len(m.describer.lookup_storage), 4)
        del li
        self.assertEqual(len(m.describer.lookup_tensor), 0)
        self.assertEqual(len(m.describer.lookup_storage), 0)
        del r
        self.assertEqual(len(m.tensor_memo), 0)
        self.assertEqual(len(m.storage_memo), 0)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_tensor_outlives_converter(self):
        m = MetaConverter()
        ref = weakref.ref(m)
        x = torch.randn([4, 4])
        y = m(x)
        del m
        self.assertIs(ref(), None)
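
# The sketch below is illustrative only and is never called by the tests: it
# summarizes the MetaConverter behavior exercised above -- conversion to the
# meta device with matching metadata, plus memoization of repeat conversions.
def _meta_converter_usage_sketch():
    to_meta = MetaConverter()
    x = torch.randn(2, 3)
    m = to_meta(x)
    assert m.device.type == "meta" and m.shape == x.shape
    assert to_meta(x) is m  # repeated conversion returns the memoized meta tensor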

aten = torch.ops.aten

CHECK_STRIDES = {
    torch.Tensor.__getitem__,
}

CHECK_ALL_STRIDES = {
    aten.unsqueeze.default
}

CHECK_STRIDES_SKIPS = {
    aten._conj_physical.default,
    aten._fft_c2c.default,
    aten._fft_c2r.default,
    aten._fft_r2c.default,
    aten._linalg_svd.default,
    aten.binary_cross_entropy.default,
    aten.complex.default,
    aten.polar.default,
    aten.copysign.Tensor,
    aten.div.Tensor_mode,
    aten.floor_divide.default,
    aten.heaviside.default,
    aten.lerp.Scalar,
    aten.lerp.Tensor,
    aten.logaddexp.default,
    aten.logical_and.default,
    aten.logical_or.default,
    aten.logical_xor.default,
    aten.pow.Scalar,
    aten.prelu.default,
    aten.special_xlog1py.default,
    aten.xlogy.Tensor,
    aten.nll_loss2d_forward.default,

    # channel_last and channel_last_3d related failures
    aten.convolution.default,

    # the following ops fail if include_storage_offset = True, but these are a bit
    # edge casey; we should still fix them, leaving them here for tracking.
    # aten._reshape_alias.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_matmul_cuda_float32
    # aten.view.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_unflatten_cuda_float32
}

CHECK_CONJ_SKIPS = {
    # The conj bit is not copied, see:
    # https://github.com/pytorch/pytorch/pull/101836
    aten.linalg_lu_solve.out,
}

class CheckStrides(Enum):
    NONE = 0
    SIGNIFICANT = 1
    ALL = 2

def should_check_strides(func):
    if func in CHECK_ALL_STRIDES:
        return CheckStrides.ALL
    if func in CHECK_STRIDES:
        return CheckStrides.SIGNIFICANT
    if func in CHECK_STRIDES_SKIPS:
        return CheckStrides.NONE
    if not isinstance(func, torch._ops.OpOverload):
        return CheckStrides.NONE
    # Prims are expected to model strides correctly
    if func.namespace == "prims":
        return CheckStrides.SIGNIFICANT
    # Check if it's a view, by testing if any of the returns have
    # a non-empty alias set
    if any(r.alias_info.before_set for r in func._schema.returns if r.alias_info):
        return CheckStrides.SIGNIFICANT
    # TODO: check for TensorIterator
    return CheckStrides.SIGNIFICANT
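
# Illustrative only (never called by the tests): how the tiers above resolve
# for a few representative ops.
def _should_check_strides_sketch():
    assert should_check_strides(aten.unsqueeze.default) is CheckStrides.ALL
    # convolution is in CHECK_STRIDES_SKIPS
    assert should_check_strides(aten.convolution.default) is CheckStrides.NONE
    # the default for an ordinary non-view OpOverload is SIGNIFICANT
    assert should_check_strides(aten.add.Tensor) is CheckStrides.SIGNIFICANT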

def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
    flat_meta_rs = pytree.tree_leaves(meta_rs)
    flat_rs = pytree.tree_leaves(rs)
    test_case.assertEqual(len(flat_meta_rs), len(flat_rs))
    for i, meta_r, r in zip(range(len(flat_rs)), flat_meta_rs, flat_rs):
        def test_assert(cond, msg):
            if not cond:
                raise RuntimeError(f"output {i}: {msg_callable(msg)}")
        if not isinstance(r, torch.Tensor):
            continue
        test_assert(isinstance(meta_r, torch.Tensor), f"was not a Tensor but real {i}th result is a Tensor")
        test_assert(meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
        test_assert(meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
        # See https://github.com/pytorch/pytorch/issues/78050
        if should_check_strides(func) == CheckStrides.ALL:
            same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
        elif should_check_strides(func) == CheckStrides.SIGNIFICANT:
            same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
        test_assert(
            meta_r.storage_offset() == r.storage_offset(),
            f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
        test_assert(meta_r.requires_grad == r.requires_grad,
                    f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
        if func not in CHECK_CONJ_SKIPS:
            test_assert(meta_r.is_conj() == r.is_conj(),
                        f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
        test_assert(meta_r.is_neg() == r.is_neg(), f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")


# This environment variable controls whether or not we print expected failure
# lists at the end of a test suite run.  The intended usage looks like this:
#
# 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_meta.py` on a CUDA build
#    of PyTorch that has LAPACK/MAGMA installed.  You can filter `-k test_meta`
#    or `-k test_dispatch_meta` to only focus on one list or the other.
# 2. Given the printed skip/xfail list, add the entries to the corresponding
#    lists; torch.* entries go in meta_function and aten.* entries go in
#    meta_dispatch.  If there are preexisting entries, you need to merge in
#    the new entries.
#
# This is somewhat manual but typically you shouldn't need to do this, unless
# you've made a major change (e.g., added a new dtype to PyTorch) and need to
# refresh the lists.  If you want to do it from scratch, just clear out the
# preexisting lists before running.
#
# WARNING: Python dict literals will silently ignore duplicate keys
COLLECT_EXPECT = os.getenv('PYTORCH_COLLECT_EXPECT', '0') == '1'

seen_succeeded = {}
seen_failed = {}
failed_reasons = defaultdict(set)
def print_seen():
    expected_failures = []
    skips = []

    def fmt_dtypes(dtypes):
        r = ', '.join(sorted(dtype_abbrs[d] for d in dtypes))
        return '{' + r + '}'

    for op, failed_dtypes in seen_failed.items():
        ops = resolve_name(op)
        succeeded_dtypes = seen_succeeded.get(op, set())
        expected_failures_dtypes = failed_dtypes - succeeded_dtypes
        skips_dtypes = failed_dtypes & succeeded_dtypes
        reasons = ""
        if failed_reasons[op]:
            reasons = "  # " + ", ".join(sorted(failed_reasons[op]))
        if expected_failures_dtypes:
            expected_failures.append(f"    {ops}: {fmt_dtypes(expected_failures_dtypes)},{reasons}")
        if skips_dtypes:
            skips.append(f"    {ops}: {fmt_dtypes(skips_dtypes)},")
    expected_failures.sort()
    skips.sort()
    nl = '\n'
    print(f"""\
expected_failures = {{
{nl.join(expected_failures)}
}}

skips = {{
{nl.join(skips)}
}}
""")
if COLLECT_EXPECT:
    atexit.register(print_seen)

# Success forces pass; failure forces fail; skip unconditionally skips testing
TestExpect = Enum("TestExpect", ("SUCCESS", "XFAILURE", "SKIP"))

# unlike print, this also shows strides
def verbose_print(e):
    class Lit:
        def __init__(self, s):
            self.s = s

        def __repr__(self):
            return self.s

    def go(t):
        if is_sparse_any(t):
            return t
        elif isinstance(t, torch.Tensor):
            return Lit(f"{t} stride={t.stride()}")
        else:
            return t

    return repr(tree_map(go, e))
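
# Illustrative only (never called by the tests): verbose_print renders tensors
# with their strides appended, which is what makes the failure messages below
# actionable for stride mismatches.
def _verbose_print_sketch():
    out = verbose_print(torch.zeros(2))
    assert "stride=(1,)" in out  # e.g. "tensor([0., 0.]) stride=(1,)"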

def run_meta_crossref(
    test_case,
    test_expect,
    func,
    args,
    kwargs,
    *,
    dtype,
    device_type,
    run_symbolic_meta: bool
):
    to_meta = MetaConverter()
    do_meta = test_expect is not TestExpect.SKIP
    if do_meta:
        try:
            meta_args = tree_map(to_meta, args)
            meta_kwargs = tree_map(to_meta, kwargs)
        except Exception as e:
            raise RuntimeError(
                f"failed to convert args to meta; "
                f"originally (*{args}, **{kwargs})") from e
    try:
        rs = func(*args, **kwargs)
    except Exception as e:
        raise AssertionError("Original OpInfo is broken") from e

    # TODO: also handle cases where func raises an exception

    # For now, only attempt if we managed to convert all tensor types
    # (if any of them failed, we're in a mixed device situation and
    # this isn't well supported)
    if do_meta and to_meta.successful():
        # Special cases
        if func is torch.tensor_split:
            # Use original indices_or_sections, this argument is data dependent
            meta_args = (meta_args[0], args[1]) + meta_args[2:]
        elif func is torch.Tensor.__getitem__:
            # Ensure boolean tensors use original
            assert len(args) == 2
            flat_args = pytree.tree_leaves(args[1])
            flat_meta_args, spec = tree_flatten(meta_args[1])
            flat_new_args = []
            for a, ma in zip(flat_args, flat_meta_args):
                flat_new_args.append(a if isinstance(a, torch.Tensor) and a.dtype in [torch.int8, torch.bool] else ma)
            meta_args = (meta_args[0], tree_unflatten(flat_new_args, spec))
        elif func in (torch.ops.aten.repeat_interleave.Tensor, torch.ops.aten.repeat_interleave.Tensor_out):
            if kwargs.get("output_size", None) is None:
                meta_args = args
            if func is torch.ops.aten.repeat_interleave.Tensor_out:
                meta_kwargs["out"] = kwargs["out"]
        elif func in (torch.ops.aten.index.Tensor, torch.ops.aten.index.Tensor_out):
            # Don't convert boolean tensors to meta as they will have nonzero
            # called on them
            indices = []
            for meta_index, real_index in zip(meta_args[1], args[1]):
                if meta_index is not None and meta_index.dtype in [torch.int8, torch.bool]:
                    indices.append(real_index)
                else:
                    indices.append(meta_index)
            meta_args = (meta_args[0], indices)
        elif func is torch.nn.functional.ctc_loss and all([isinstance(args[2], list), isinstance(args[3], list)]):
            # torch.ops.aten._ctc_loss.IntList has a meta kernel but
            # torch.ops.aten._ctc_loss.Tensor does not
            test_expect = TestExpect.SUCCESS

        if kwargs.get("device", None) is not None:
            meta_kwargs["device"] = "meta"

        try:
            # Suppress warnings, this doesn't matter for test_meta.py
            # but it does matter if you want to use this decorator
            # for cross-ref testing, as some tests may be looking at
            # errors
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if run_symbolic_meta:
                    # Run the decomps and meta kernels registered
                    # to the python dispatcher instead of the regular dispatcher.
                    # This should be the same set of kernels
                    # that fake tensor runs in dynamic shapes mode.
                    with enable_python_dispatcher():
                        meta_rs = func(*meta_args, **meta_kwargs)
                else:
                    meta_rs = func(*meta_args, **meta_kwargs)
        except Exception as e:
            if test_expect is TestExpect.XFAILURE:
                return rs
            seen_failed.setdefault(func, set()).add(dtype)
            if isinstance(e, NotImplementedError):
                m = RE_NOT_IMPLEMENTED_MSG.search(e.args[0])
                if m:
                    failed_reasons[func].add(m.group(1))
            if COLLECT_EXPECT:
                return rs
            raise RuntimeError(f"""\
failed to run: {resolve_name(func)}(
*{verbose_print(meta_args)},
**{verbose_print(meta_kwargs)}
)""") from e
        else:
            try:
                delim = ',\n  '
                assert_ref_meta_equal(test_case, func, meta_rs, rs, lambda msg: f"""\
meta disagrees with real impl:
{resolve_name(func)}(
  {delim.join(map(verbose_print, meta_args))},
  {delim.join(k + ": " + verbose_print(v) for k, v in meta_kwargs.items())}
) = (
  {verbose_print(meta_rs)}
)
{msg}
""")
            except Exception:
                if test_expect is TestExpect.XFAILURE:
                    return rs
                seen_failed.setdefault(func, set()).add(dtype)
                if COLLECT_EXPECT:
                    return rs
                raise
            else:
                seen_succeeded.setdefault(func, set()).add(dtype)
                if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
                    raise RuntimeError(f"unexpected success {resolve_name(func)} {meta_args} {meta_kwargs}")

    return rs
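
# Illustrative only (hypothetical direct call; in practice the two modes below
# drive run_meta_crossref): run torch.add for real, replay it on meta inputs,
# compare the outputs, and hand back the real result.
def _run_meta_crossref_sketch(test_case):
    x = torch.randn(3)
    return run_meta_crossref(
        test_case, TestExpect.SUCCESS, torch.add, (x, x), {},
        dtype=x.dtype, device_type="cpu", run_symbolic_meta=False)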

RE_NOT_IMPLEMENTED_MSG = re.compile(r"Could not run '([^']+)' with arguments ")

meta_function_expected_failures = {
    torch.Tensor.to_sparse : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.allclose : {f64, f16, c128, c64, bf16, f32},
    torch.argwhere : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.combinations : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.corrcoef : {f64, i32, c128, i64, i16, u8, c64, bf16, f16, i8, f32},
    torch.cov : {f64, i32, c128, i64, i16, u8, c64, bf16, i8, f32, f16},
    torch.functional.istft : {f64, c64, c128, f32},
    torch.geqrf : {f64, c64, c128, f32},
    torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
    torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
    torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
    torch.bincount : {i32, i64, u8, i16, i8},
    torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32, u16, u32, u64},
    torch.functional.unique_consecutive : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32, u16, u32, u64},
    torch.histogram : {f64, f32},
    torch.histogramdd : {f64, f32},
    torch.nn.functional.ctc_loss : {f64, f32},
    torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
    torch.linalg.lstsq : {f64, f32, c128, c64},
}

meta_function_expected_failures_conditional = {
    torch.repeat_interleave : (lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats", None), int)),
}

"""
# This is some sample code for how we could dump these dicts into a YAML
# file for easier reading/writing
import yaml
print(yaml.dump(
    {resolve_name(k): [dtype_abbrs[d] for d in v]
     for k, v in meta_function_expected_failures.items()}, default_flow_style=None))
import sys
sys.exit()
"""

meta_function_skips = {
    torch.Tensor.__rmatmul__ : {bf16, c128, f64, f32, f16, c64},
    torch.Tensor.matmul : {f64, f32, c128, c64},
    torch.functional.atleast_2d : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.functional.atleast_3d : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.functional.cartesian_prod : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.functional.einsum : {bf16, c128, f64, f32, f16, c64},
    torch.inner : {f16, bf16, i8, i64, u8, c128, f64, i16, f32, i32, c64},
    torch.linalg.matrix_norm : {c128, f32, c64, f64},
    torch.linalg.matrix_rank : {c128, c64},
    torch.linalg.svd : {c128, c64},
    torch.matmul : {bf16, c128, f64, f32, f16, c64},
    torch.nanquantile : {f64, f32},
    torch.narrow : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c32, c64},
    torch.nn.functional.batch_norm : {f64, f32},
    torch.nn.functional.binary_cross_entropy : {bf16, f64, f32, f16},
    torch.nn.functional.dropout3d : {bf16, f64, f32, f16},
    torch.nn.functional.local_response_norm : {bf16, f64, f32, f16},
    torch.svd : {c128, c64},
    torch.take_along_dim : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.vstack : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.diff : {b8},
    torch.equal : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.nanmean : {bf16, f64, f32, f16, c32, c64, c128},
    torch.nn.functional.cross_entropy : {bf16, f64, f32},
    torch.nn.functional.nll_loss : {bf16, f64, f32},
    torch.linalg.cond : {c128, c64, f32, f64},
    torch.linalg.vecdot : {bf16, f64, f32, f16},
    torch.empty : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.Tensor.addbmm_: {bf16, c128, c64, f32, f64, i16, i32, i64, i8, u8},
    torch.nn.functional.one_hot : {i64},
}


meta_function_device_expected_failures = defaultdict(dict)
meta_function_device_expected_failures_only_outplace = defaultdict(dict)
meta_function_device_skips = defaultdict(dict)

meta_function_device_expected_failures['cpu'] = {
    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    torch.native_batch_norm: {bf16, f16},
    torch._native_batch_norm_legit: {bf16, f16},
    torch.ops.aten._batch_norm_with_update: {bf16, f16},
    torch.native_layer_norm: {bf16, f16},
}

meta_function_device_expected_failures['cuda'] = {
    torch.corrcoef: {bf16, f16},  # aten::_local_scalar_dense
    torch.cov: {f16},  # aten::_local_scalar_dense
    torch.functional.unique: {f16},  # aten::_unique2, aten::unique_dim
    torch.functional.unique_consecutive: {f16},  # aten::unique_consecutive
    torch.geqrf: {f32, f64},  # aten::geqrf
}

meta_function_device_skips['cpu'] = {
    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    torch.native_batch_norm: {f32, f64},
    torch._native_batch_norm_legit: {f32, f64},
    torch.ops.aten._batch_norm_with_update: {f32, f64},
}

meta_function_device_skips['cuda'] = {
    torch.inner: {f16},
    torch.linalg.matrix_rank: {f32, f64},
    torch.linalg.svd: {f32, f64},
    torch.nn.functional.cross_entropy: {f16},
    torch.nn.functional.interpolate: {f16},
    torch.nn.functional.nll_loss: {f16},
    torch.svd: {f32, f64},
}

# This is a __torch_function__ mode that, when enabled, interposes every
# Torch API call and runs the operator as normal, then reruns it
# with meta inputs, and then checks that everything about the output agrees.
# Most of the logic deals with faithfully replicating the original tensor
# as a meta tensor, which is nontrivial because there are a lot of subsystems
# that may potentially be exercised.
#
# That being said, this class is a little overkill for what it is doing in
# this test file (since I could have just inlined __torch_function__ on the
# OpInfo call, and OpInfos generally have very regular inputs), but it will be
# useful for more comprehensive testing e.g., as seen in
# https://github.com/pytorch/pytorch/pull/75994  The big benefit is it is
# A LOT more efficient than torch dispatch mode (at the cost of less coverage)
class MetaCrossRefFunctionMode(torch.overrides.TorchFunctionMode):
    test_case: TestCase
    device_type: str
    dtype: torch.dtype

    def __init__(self, test_case, *, device, dtype, inplace):
        self.test_case = test_case
        self.device_type = torch.device(device).type
        self.dtype = dtype
        self.inplace = inplace

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        if (
            torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod) or
            # meta converter doesn't work correctly when no_dispatch() is on, so
            # skip running the crossref test in this case
            torch._C._dispatch_tls_local_exclude_set().has(torch._C.DispatchKey.Python)
        ):
            return func(*args, **kwargs)

        if self.dtype in meta_function_skips.get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_function_device_skips[self.device_type].get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_function_expected_failures.get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif self.dtype in meta_function_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif meta_function_expected_failures_conditional.get(func, lambda *_, **__: False)(self.dtype, *args, **kwargs):
            test_expect = TestExpect.XFAILURE
        elif not self.inplace and \
                self.dtype in meta_function_device_expected_failures_only_outplace[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS

        return run_meta_crossref(
            self.test_case, test_expect, func, args,
            kwargs, dtype=self.dtype, device_type=self.device_type, run_symbolic_meta=False
        )
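
# Illustrative only (hypothetical driver mirroring test_meta_outplace below):
# while the mode is active, every torch API call runs for real and is then
# replayed on meta inputs by run_meta_crossref, which compares the outputs.
def _cross_ref_function_mode_sketch(test_case):
    x = torch.randn(3)
    with MetaCrossRefFunctionMode(test_case, dtype=x.dtype, device="cpu", inplace=False):
        return torch.add(x, x)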

# these always fail
meta_dispatch_expected_failures = {
    aten.allclose.default: {f16, bf16, f32, f64, c64, c128},  # NotImplementedError: 'aten::_local_scalar_dense'
    aten.geqrf.default : {c64, c128, f64, f32},
    aten.linalg_lstsq.default : {c64, c128, f64, f32},
    aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
    aten.nonzero.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
    aten._to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._ctc_loss.Tensor : {f32, f64},  # Shape of second output depends on data.
    aten._histogramdd_bin_edges.default : {f32, f64},
    aten._histogramdd_from_bin_cts.default : {f32, f64},
    aten._histogramdd_from_bin_tensors.default : {f32, f64},
    aten._local_scalar_dense.default : {c32, c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten._unique2.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.bincount.default : {i64, i8, i32, i16, u8},
    aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.histogram.bin_ct : {f32, f64},
    aten.histogram.bins_tensor : {f32, f64},
    aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
    aten.upsample_nearest3d.vec : {bf16, f32, f64, u8},
}

# these sometimes pass and sometimes fail
meta_dispatch_skips = {
    aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128},  # at::nonzero doesn't have a Meta function
    aten._to_copy.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128},
    aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
    aten.addbmm_.default: {bf16, c128, c64, f32, f64, i16, i32, i64, i8, u8},
}

# For CompositeImplicitAutograd functions that fail before hitting the Mode
meta_dispatch_early_skips = set({
    torch.Tensor.float_power_,
    # Errors out in one of the tests, while ProxyTensor passes...
    torch.Tensor.cumprod_,
    torch.Tensor.cumsum_,
})

meta_inplace_skips = set({
    # Errors out in one of the tests, while ProxyTensor passes...
    torch.Tensor.cumprod_,
    torch.Tensor.cumsum_,
})

meta_dispatch_device_expected_failures = defaultdict(dict)
meta_dispatch_device_skips = defaultdict(dict)

meta_dispatch_device_expected_failures['cpu'] = {
    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    aten.native_batch_norm.default: {bf16, f16},
    aten._native_batch_norm_legit.default: {bf16, f16},
    aten._native_batch_norm_legit.no_stats: {bf16, f16},
    aten._batch_norm_with_update.default: {bf16, f16},

    aten.native_layer_norm.default: {bf16, f16},
}

meta_dispatch_device_expected_failures['cuda'] = {
    aten._unique2.default: {f16},  # aten::_unique2
    aten._use_cudnn_ctc_loss.default: {f32, f64},  # aten::_use_cudnn_ctc_loss
    aten._use_cudnn_ctc_loss.Tensor: {f32, f64},  # aten::_use_cudnn_ctc_loss.Tensor
    aten.cudnn_grid_sampler.default: {f16, f32, f64},  # aten::cudnn_grid_sampler
    aten.geqrf.default: {f32, f64},  # aten::geqrf
    aten.linalg_eigvalsh.out: {f32, f64},  # aten::linalg_eigvalsh.out
    aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
    aten.log_sigmoid_forward.output : {bf16, f16, f64, f32},  # aten::log_sigmoid_forward.output
    aten.unique_consecutive.default: {f16},  # aten::unique_consecutive
    aten.unique_dim.default: {f16},  # aten::unique_dim
    aten.upsample_nearest3d.vec: {f16},  # aten::upsample_nearest3d.vec
}

meta_dispatch_device_skips['cpu'] = {
    aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},

    # TODO: The decomps for these batch norm ops return different dtypes depending
    # on the device. We should make this work better with meta tensors.
    aten.native_batch_norm.default: {f32, f64},
    aten._native_batch_norm_legit.default: {f32, f64},
    aten._native_batch_norm_legit.no_stats: {f32, f64},
    aten._batch_norm_with_update.default: {f32, f64},

    # If the computation dtype is different from the input dtype this
    # will fail. CPU execution may also have a different output from
    # other devices.
    aten.native_batch_norm.out: {bf16, f16, f32, f64}
}

meta_dispatch_device_skips['cuda'] = {
    aten._conj.default: {c32, f16},  # file issue
    aten._linalg_svd.default: {c64, c128},  # aten::linalg_eigvalsh.out
    aten.cudnn_batch_norm.default: {f32, f64},
    aten.log_softmax.int : {c32, c64},
    aten.softmax.int : {c32, c64},

    # ROCm stuff; technically this should be expected failure but it's
    # not worth it; these should get unified anyway
    aten.miopen_batch_norm.default: {f32},
}

def get_strided_args(args):

    def get_strided_variants(t, include_storage_offset=False):
        variants = []

        # contiguous
        variants.append(t)

        # transposed
        if t.ndim > 1:
            perm = list(reversed(range(t.ndim)))
            transposed = torch.empty(
                t.shape[::-1], device=t.device, dtype=t.dtype, requires_grad=t.requires_grad
            ).permute(perm).copy_(t)
            variants.append(transposed)

        # nondense
        if t.ndim > 0:
            nondense = torch.repeat_interleave(t, 2, dim=-1)[..., ::2]
            variants.append(nondense)

        # channel_last
        if t.ndim == 4:
            variants.append(t.contiguous(memory_format=torch.channels_last))

        # channel_last_3d
        if t.ndim == 5:
            variants.append(t.contiguous(memory_format=torch.channels_last_3d))

        # storage_offset
        if include_storage_offset:
            buffer = torch.empty(t.numel() + 1, device=t.device, dtype=t.dtype, requires_grad=t.requires_grad)
            buffer = buffer.as_strided(t.shape, t.stride(), storage_offset=1)
            buffer.copy_(t)
            variants.append(buffer)

        return variants

    strided_args = []
    for arg in args:
        if isinstance(arg, torch.Tensor) and not arg.is_sparse_csr and arg.is_contiguous():
            strided_arg_variants = get_strided_variants(arg)
        else:
            strided_arg_variants = [arg]
        strided_args.append(strided_arg_variants)

    yield from itertools.product(*strided_args)
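
# Illustrative only (never called by the tests): for a contiguous 2-D tensor,
# get_strided_variants produces the contiguous, transposed, and non-dense
# layouts, so one tensor argument yields three stride combinations here.
def _get_strided_args_sketch():
    variants = list(get_strided_args([torch.rand(2, 3), 1.0]))
    assert len(variants) == 3
    assert all(v[1] == 1.0 for v in variants)  # non-tensor args pass through unchanged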

class MetaCrossRefDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
    test_case: TestCase
    device: torch.device
    dtype: torch.dtype
    aten_olp_no_out_overload: set = set()

    def __init__(self, test_case, *, device, dtype, symbolic_meta: bool, inplace: bool, supports_out: bool):
        self.test_case = test_case
        # save TLS
        self.precision = test_case.precision
        self.rel_tol = test_case.rel_tol
        self.device_type = torch.device(device).type
        self.dtype = dtype
        self.symbolic_meta = symbolic_meta
        self.inplace = inplace
        self.supports_out = supports_out

    @staticmethod
    def try_resolve_aten_out_overload(ol, args, kwargs, num_outputs):

        ol_args = ol._schema.arguments
        olp: OpOverloadPacket = ol._overloadpacket

        if olp in MetaCrossRefDispatchMode.aten_olp_no_out_overload:
            return (None, None, None)

        candidate_ols = []
        for candidate_ol_name in olp.overloads():
            candidate_ol = getattr(olp, candidate_ol_name)
            if any(arg.is_out for arg in candidate_ol._schema.arguments):
                candidate_ols.append(candidate_ol)

        if not candidate_ols:
            MetaCrossRefDispatchMode.aten_olp_no_out_overload.add(olp)
            return (None, None, None)

        # Now match based on args, kwargs and number of required outputs
        candidate_ol: OpOverload = None
        for candidate_ol in candidate_ols:
            candidate_ol_args = candidate_ol._schema.arguments

            if len(args) >= len(candidate_ol_args):
                continue

            # Positional arguments must have the same type
            if not all(
                ol_args[pos_arg_ind].type == candidate_ol_args[pos_arg_ind].type
                for pos_arg_ind in range(len(args))
            ):
                continue

            # Number of outputs must match
            candidate_out_names = [out_arg.name for out_arg in candidate_ol_args[-num_outputs:] if out_arg.is_out]
            if len(candidate_out_names) != num_outputs:
                continue

            # Now try and match kwargs.  Just need to ensure that the
            # remaining kwargs allow an out overload to be called.  For example,
            # we can throw away parameters like `dtype` that may be passed to the
            # functional version of the op since the `dtype` will already be present
            # in the `out` argument.
            new_kwargs = {}
            kwargs_match = True
            for arg in candidate_ol_args[len(args):-num_outputs]:
                if arg.name not in kwargs:
                    if arg.has_default_value():
                        new_kwargs[arg.name] = arg.default_value
                    elif isinstance(arg.type, torch.OptionalType):
                        if isinstance(arg.type.getElementType(), torch.BoolType):
                            new_kwargs[arg.name] = False
                        else:
                            new_kwargs[arg.name] = None
                    else:
                        kwargs_match = False
                        break
                else:
                    new_kwargs[arg.name] = kwargs[arg.name]

            if kwargs_match:
                return candidate_ol, candidate_out_names, new_kwargs

        return None, None, None
    def _get_expected_test_result(self, func: OpOverload):
        if self.dtype in meta_dispatch_skips.get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_device_skips[self.device_type].get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_expected_failures.get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif self.dtype in meta_dispatch_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS
        return test_expect

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.test_case.precision = self.precision
        self.test_case.rel_tol = self.rel_tol

        test_expect = self._get_expected_test_result(func)

        expected = run_meta_crossref(
            self.test_case,
            test_expect,
            func,
            args,
            kwargs,
            dtype=self.dtype,
            device_type=self.device_type,
            run_symbolic_meta=self.symbolic_meta,
        )

        # This is to test torch ops that do not have an out parameter but have
        # aten op overloads that have out parameters.  Additionally, Python
        # decompositions may register OpOverloadPackets, so decompositions need
        # to be tested to ensure all OpOverloads still function for the Meta key
        # (e.g. if a python decomposition is registered for an aten op aten.foo
        # with overloads [default, out], the python function needs to support
        # receiving `out` arguments).
        if (
            not self.inplace and
            not self.supports_out and
            test_expect == TestExpect.SUCCESS and
            (torch.is_tensor(expected) or isinstance(expected, Iterable))
        ):

            # check to see if there is a potential out overload
            num_outputs = 1 if torch.is_tensor(expected) else len(expected)
            func_out_overload, out_param_names, kwargs = self.try_resolve_aten_out_overload(func, args, kwargs, num_outputs)

            if func_out_overload:

                if num_outputs == 1:
                    kwargs[out_param_names[0]] = expected
                else:
                    for ind, out_param_name in enumerate(out_param_names):
                        kwargs[out_param_name] = expected[ind]

                test_expect = self._get_expected_test_result(func_out_overload)

                run_meta_crossref(
                    self.test_case,
                    test_expect,
                    func_out_overload,
                    args,
                    kwargs,
                    dtype=self.dtype,
                    device_type=self.device_type,
                    run_symbolic_meta=self.symbolic_meta,
                )

        return expected
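
# Illustrative only (never called by the tests): resolving the out= overload
# for a functional op with one output should land on aten.add.out, with the
# remaining defaultable kwargs (here `alpha`) filled in.
def _out_overload_resolution_sketch():
    a, b = torch.randn(2), torch.randn(2)
    ol, out_names, new_kwargs = MetaCrossRefDispatchMode.try_resolve_aten_out_overload(
        aten.add.Tensor, (a, b), {}, num_outputs=1)
    assert ol is aten.add.out and out_names == ["out"]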

# NB: we're running these tests only on CUDA because there are some
# inconsistencies between CUDA and CPU, and running on CUDA makes it easier
# to ignore the CPU case when inconsistencies arise.  Ideally we deal
# with the inconsistencies but this takes time.
@unMarkDynamoStrictTest
class TestMeta(TestCase):
    # Copies inputs to inplace operations to avoid inplace modifications
    # to leaves requiring gradient
    def _get_safe_inplace(self, inplace_variant):
        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            if isinstance(t, list):
                return inplace_variant([x.clone() for x in t], *args, **kwargs)
            else:
                return inplace_variant(t.clone(), *args, **kwargs)

        return _fn
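
    # Illustrative only (not collected as a test, since the name does not start
    # with "test"): in-place ops on a leaf that requires grad raise, which is
    # why the wrapper above clones its inputs first.
    def _safe_inplace_sketch(self):
        t = torch.ones(2, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "leaf Variable"):
            t.add_(1)
        self._get_safe_inplace(torch.Tensor.add_)(t, 1)  # runs on a clone, so it is fine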

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_meta_outplace(self, device, dtype, op):
        if "_scaled_mm" in op.name:
            raise unittest.SkipTest("_scaled_mm does not support meta device")
        skip_op_names = (
            "fft.ihfft",
            "fft.ihfft2",
            "linalg.lu_solve",
        )
        if TEST_WITH_TORCHDYNAMO and op.name in skip_op_names:
            raise unittest.SkipTest("flaky")
        # run the OpInfo sample inputs, cross-referencing them with the
        # meta implementation and check the results are the same.  All
        # the heavy lifting happens in MetaCrossRefFunctionMode
        func = op.get_op()
        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in samples:
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=False):
                expected = func(*args, **kwargs)
                if isinstance(expected, torch.Tensor) and op.supports_out:
                    func(*args, **kwargs, out=expected)

            # Special test for functions taking "device" kwarg
            # The crossref tests that replacing the device with "meta" works
            # This part makes sure that *_like functions work well with a "meta"
            # Tensor and their original device argument.
            if "device" in kwargs and "_like" in op.name:
                with torch.random.fork_rng():
                    torch.manual_seed(123)
                    ref = func(*args, **kwargs)

                # *_like functions take a Tensor as first argument
                assert isinstance(args[0], torch.Tensor)
                with torch.random.fork_rng():
                    torch.manual_seed(123)
                    args[0] = args[0].to(device="meta")
                    meta = func(*args, **kwargs)

                # empty_like is not deterministic
                if op.name != "empty_like":
                    self.assertEqual(ref, meta)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_meta_inplace(self, device, dtype, op):
        func = op.get_inplace()
        if not func:
            self.skipTest("No inplace variant for this op")
        if op.promotes_int_to_float and not dtype.is_floating_point:
            self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
        if func in meta_inplace_skips:
            self.skipTest("Skipped")
        func = self._get_safe_inplace(func)
        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in samples:
            if sample_input.broadcasts_input:
                continue
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=True):
                expected = func(*args, **kwargs)

    def _run_dispatch_meta_test(self, device, dtype, op, symbolic_meta, inplace, all_stride_variants=False):
        if "_scaled_mm" in op.name:
            raise unittest.SkipTest("_scaled_mm does not support meta device")
        if inplace:
            func = op.get_inplace()
            if not func:
                self.skipTest("No inplace variant for this op")
            if op.promotes_int_to_float and not dtype.is_floating_point:
                self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
        else:
            func = op.get_op()

        if func in meta_dispatch_early_skips:
            self.skipTest("Function is in dispatch early skips")

        if inplace:
            func = self._get_safe_inplace(func)

        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in samples:
            if inplace and sample_input.broadcasts_input:
                continue

            sample_args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs

            if all_stride_variants and sum(isinstance(arg, torch.Tensor) for arg in sample_args) <= 5:
                # test inputs <= 5 tensors to avoid combinatorial explosion
                strided_args = get_strided_args(sample_args)
            else:
                strided_args = [sample_args]

            for args in strided_args:
                with MetaCrossRefDispatchMode.push(
                        self, dtype=dtype, device=device,
                        symbolic_meta=symbolic_meta, inplace=inplace,
                        supports_out=op.supports_out):
                    expected = func(*args, **kwargs)

                    if not inplace and isinstance(expected, torch.Tensor) and op.supports_out:
                        func(*args, **kwargs, out=expected)


    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_meta_outplace(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=False, inplace=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_meta_inplace(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=False, inplace=True)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_symbolic_meta_outplace(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False)


    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_symbolic_meta_inplace(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    # only test one dtype, as output stride behavior is the same for all dtypes
    @ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
    # Only test on CUDA, as CUDA kernel's stride is the reference
    @onlyCUDA
    def test_dispatch_symbolic_meta_outplace_all_strides(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False, all_stride_variants=True)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    # only test one dtype, as output stride behavior is the same for all dtypes
    @ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
    # Only test on CUDA, as CUDA kernel's stride is the reference
    @onlyCUDA
    def test_dispatch_symbolic_meta_inplace_all_strides(self, device, dtype, op):
        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True, all_stride_variants=True)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    # only test one dtype, as output stride behavior is the same for all dtypes
    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
    # Only test on CUDA, as CUDA kernel's stride is the reference
    @onlyCUDA
    def test_binary_ufuncs_mixed_dtype(self, device, dtype, op):
        make_arg = partial(
            make_tensor,
            device=device,
        )

        def sample_input(op, device, dtype, requires_grad, **kwargs):
            yield SampleInput(
                make_arg((S,), dtype=dtype), make_arg((S,), dtype=torch.float16)
            )

        op = copy.copy(op)
        op.sample_inputs_func = sample_input

        self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False)


    def test_empty_quantized(self):
        r = torch.empty(2 ** 52, device='meta', dtype=torch.qint8)
        self.assertEqual(r.device.type, 'meta')

    def test_nan_to_num(self):
        t = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14], device='meta')
        r = t.nan_to_num()
        self.assertEqual(r.device.type, 'meta')

    def test_inplace_masked_fill_error(self):
        t = torch.randn(3, 3, device='meta')
        with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
            t.masked_fill_((t > 0).unsqueeze(0), 0.1)

    def test_inplace_bin_ops_error(self):
        t = torch.randn(3, 3, device='meta')
        for op in (torch.Tensor.add_, torch.Tensor.sub_, torch.Tensor.mul_, torch.Tensor.div_,
                   torch.Tensor.logical_and_, torch.Tensor.logical_or_, torch.Tensor.logical_xor_):
            with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
                op(t, t.clone().unsqueeze(0))

    @onlyCPU
    def test_meta_autograd_no_error(self):
        with torch.library._scoped_library("meta_test", "DEF") as lib:
            with torch.library._scoped_library("meta_test", "IMPL", "CPU") as impl_cpu:
                with torch.library._scoped_library("meta_test", "IMPL", "Meta") as impl_meta:
                    def foo_impl(x):
                        return x + 1

                    lib.define("foo(Tensor a) -> Tensor")
                    impl_meta.impl("foo", foo_impl)
                    impl_cpu.impl("foo", foo_impl)

                    a = torch.ones(2, device='meta')
                    # The point of the test is that this should not error:
                    # We have a fallthrough kernel registered to the AutogradMeta
                    # key for custom ops, so it's fine that `foo()` doesn't have
                    # an autograd kernel.
                    b = torch.ops.meta_test.foo.default(a)

    def test_huber_loss_backward(self):
        inps = [torch.rand(2**52, device='meta') for _ in range(3)]
        r = torch.ops.aten.huber_loss_backward(*inps, 0, 1.0)
        self.assertEqual(r.device.type, 'meta')
        self.assertEqual(r.shape, inps[0].shape)

    def _norm_backwards_test_helper(self, op, args, output_mask, expected_shapes):

        dtype = torch.float32
        device = "meta"

        # test functional call
        grads = op(*args, output_mask)

        def assertEqualShapes(res, exp):
            if exp is None:
                self.assertIsNone(res)
            else:
                self.assertEqual(exp, res.shape)

        assertEqualShapes(grads[0], expected_shapes[0])
        assertEqualShapes(grads[1], expected_shapes[1])
        assertEqualShapes(grads[2], expected_shapes[2])

        out_kwargs = {
            f"out{i}": torch.empty(0, device=device, dtype=dtype)
            for i in range(len(output_mask))
        }

        # test call with out parameters
        grads = op(*args, output_mask, **out_kwargs)

        def assertEqualShapes(res, exp):
            if exp is not None:
                self.assertEqual(exp, res.shape)

        assertEqualShapes(out_kwargs["out0"], expected_shapes[0])
        assertEqualShapes(out_kwargs["out1"], expected_shapes[1])
        assertEqualShapes(out_kwargs["out2"], expected_shapes[2])

    @onlyCPU
    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
    def test_layer_norm_backward(self, output_mask):
        from torch.testing._internal.common_methods_invocations import sample_inputs_layer_norm

        device = "meta"
        dtype = torch.float32

        samples = sample_inputs_layer_norm(None, device, dtype, requires_grad=False)

        for sample in samples:
            with self.subTest(sample=sample):
                # handle optional weight and bias
                if len(sample.args) != 3:
                    sample.args = (*sample.args, *([None] * (3 - len(sample.args))))

                grad_out = torch.ones_like(sample.input)
                normalized_shape, weight, bias = sample.args
                ndims_after_reduction = sample.input.ndim - len(normalized_shape)
                mean_shape = grad_out.shape[:ndims_after_reduction]
                mean = torch.zeros(mean_shape, device=device, dtype=dtype)
                rstd = torch.zeros(mean_shape, device=device, dtype=dtype)

                expected_shapes = (
                    sample.input.shape if output_mask[0] else None,
                    weight.shape if output_mask[1] and weight is not None else None,
                    bias.shape if output_mask[2] and bias is not None else None)

                args = [grad_out, sample.input, normalized_shape, mean, rstd, weight, bias]

                self._norm_backwards_test_helper(torch.ops.aten.native_layer_norm_backward,
                                                 args, output_mask, expected_shapes)

    @onlyCPU
    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
    def test_group_norm_backward(self, output_mask):
        from torch.testing._internal.common_methods_invocations import sample_inputs_group_norm

        # input, (args) num_groups, (kwargs) weight, bias, eps
        device = "meta"
        dtype = torch.float32
        samples = sample_inputs_group_norm(None, device, dtype, requires_grad=False)

        for sample in samples:
            with self.subTest(sample=sample):
                grad_out = torch.ones_like(sample.input)
                N, C = sample.input.shape[:2]
                HxW = torch.prod(torch.as_tensor(sample.input.shape[2:]), dtype=torch.int32).item()
                group = sample.args[0]
                mean = torch.zeros((N, group), device=device, dtype=dtype)
                rstd = torch.zeros((N, group), device=device, dtype=dtype)
                weight = torch.zeros((C), device=device, dtype=dtype)

                args = [grad_out, sample.input, mean, rstd, weight, N, C, HxW, group]

                expected_shapes = (
                    sample.input.shape if output_mask[0] else None,
                    weight.shape if output_mask[1] else None,
                    weight.shape if output_mask[2] else None)

                # test functional call
                self._norm_backwards_test_helper(torch.ops.aten.native_group_norm_backward,
                                                 args, output_mask, expected_shapes)

    @onlyCPU
    @parametrize("output_mask", list(itertools.product([True], [True, False], [True, False])))
    def test_batch_norm_backward(self, output_mask):
        from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm

        # input, (args) running_mean, running_var, weight, bias
        device = "meta"
        dtype = torch.float32
        samples = sample_inputs_batch_norm(None, device, dtype, requires_grad=False)

        for sample in samples:
            with self.subTest(sample=sample):

                if sample.input.dim() < 2:
                    continue

                grad_out = torch.ones_like(sample.input)
                running_mean, running_var, weight, bias = sample.args
                train = sample.kwargs.get("training", True)
                save_mean = torch.zeros((sample.input.shape[1], ), device=device, dtype=dtype) if train else None
                save_invstd = torch.zeros((sample.input.shape[1], ), device=device, dtype=dtype) if train else None

                args = [grad_out, sample.input, weight, running_mean, running_var,
                        save_mean, save_invstd, train, sample.kwargs.get("eps", 1e-5)]

                expected_shapes = (
                    sample.input.shape,
                    torch.Size([sample.input.shape[1]]) if output_mask[1] else None,
                    torch.Size([sample.input.shape[1]]) if output_mask[2] else None)

                self._norm_backwards_test_helper(torch.ops.aten.native_batch_norm_backward,
                                                 args, output_mask, expected_shapes)

    def test_fill__alias_relationship(self):
        inps = torch.rand(2**52, device='meta')
        r = torch.ops.aten.fill_(inps, 1.0)
        # aten.fill_ returns an alias
        self.assertEqual(id(inps), id(r))

        # aten.fill returns a new tensor
        r2 = torch.ops.aten.fill(inps, 1.0)
        self.assertNotEqual(id(inps), id(r2))

    def test_meta__fused_moving_avg_obs_fq_helper(self, device):
        from torch.ao.quantization import FusedMovingAvgObsFakeQuantize
        to_meta = MetaConverter()

        x = torch.randn(5, 5, device=device)
        running_min_op = torch.tensor(float("inf"), device=device)
        running_max_op = torch.tensor(float("-inf"), device=device)
        avg_const = 0.01
        scale = torch.tensor([1.0], device=device)
        mod = FusedMovingAvgObsFakeQuantize()
        torch.ao.quantization.enable_fake_quant(mod)
        torch.ao.quantization.enable_observer(mod)
        mod.to(device)

        meta_x = to_meta(x)

        args = [
            x,
            mod.observer_enabled,
            mod.fake_quant_enabled,
            running_min_op,
            running_max_op,
            scale,
            zero_point,
            avg_const,
            0,    # quant_min
            255,  # quant_max
            0,    # ch_axis
        ]

        meta_args = args.copy()
        meta_args[0] = meta_x

        kwargss = [
            {},
            {"per_row_fake_quant": False, "symmetric_quant": False},
            {"per_row_fake_quant": False, "symmetric_quant": True},
        ]

        for kwargs in kwargss:
            ref_out = aten._fused_moving_avg_obs_fq_helper.default(*args, **kwargs)
            meta_out = aten._fused_moving_avg_obs_fq_helper.default(*meta_args, **kwargs)

            self.assertEqual(ref_out[0].size(), meta_out[0].size())
            self.assertEqual(ref_out[0].stride(), meta_out[0].stride())
            self.assertEqual(ref_out[1].size(), meta_out[1].size())
            self.assertEqual(ref_out[1].stride(), meta_out[1].stride())

    def test_cdist_forward(self, device):
        to_meta = MetaConverter()
        x1 = torch.rand([3, 2], device=device)
        x2 = torch.rand([2, 2], device=device)
        p = 2.0
        for compute_mode in (None, 1, 2):
            ref = aten._cdist_forward.default(x1, x2, p, compute_mode)
            res = aten._cdist_forward.default(to_meta(x1), to_meta(x2), p, compute_mode)
            self.assertEqual(res.device.type, 'meta')
            self.assertEqual(ref.shape, res.shape)

    def test_quantized_embedding_bag(self):
        tab_shape = [8, 128]
        emb_size, ind_len, off_len = tab_shape[0], 32, 33
        f_table = torch.from_numpy((np.random.random_sample(tab_shape) + 1).astype(np.float32))
        q_table = torch.ops.quantized.embedding_bag_byte_prepack(f_table)
        indices = torch.from_numpy(np.random.randint(low=0, high=emb_size, size=ind_len)).int()
        max_length = len(indices) // (off_len - 1)
        if max_length > 20:
            max_length = 20
        np_lengths = np.random.randint(0, max_length + 1, size=off_len - 1).astype(np.int32)
        offsets = torch.cat([torch.zeros([1]), torch.cumsum(torch.from_numpy(np_lengths), 0)]).int()

        eb = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
            q_table.to(device="meta"),
            indices.to(device="meta"),
            offsets.to(device="meta"),
            mode=0,  # sum
            per_sample_weights=None,
            include_last_offset=True,
        )
        self.assertEqual(eb.shape, [32, 128])
        self.assertEqual(eb.dtype, torch.float32)
        self.assertEqual(eb.untyped_storage().data_ptr(), 0)

    # Tests mean and max.
    # Can't easily test sum, because there is a fast path for sum which
    # causes offset2bag to not get allocated... but the backward function
    # needs it, and the offset2bag computation lives inside the
    # derivatives.yaml formula directly, so there is no way to access it.
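    # To test sum, need to manually compute offset2bag. A minimal sketch of
    # that computation (an assumption -- it is meant to mirror what aten's
    # internal make_offset2bag helper does; it is not exercised by this test):
    #
    #   offset2bag = torch.zeros(indices.numel() + 1, dtype=torch.long)
    #   offset2bag.index_add_(0, offsets, torch.ones_like(offsets))
    #   offset2bag[0] -= 1  # offsets starts with 0, so element 0 is in bag 0
    #   offset2bag = offset2bag.cumsum(0)[:indices.numel()]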
    @parametrize("mode", [1, 2])
    def test_embedding_bag_dense_backward(self, mode):
        weight = torch.randn(4, 3, requires_grad=True)
        indices = torch.tensor([1, 0, 2, 1, 3])
        offsets = torch.tensor([0, 2, 3, 5])
        scale_grad_by_freq = False
        sparse = False
        per_sample_weights = None
        include_last_offset = False
        padding_idx = -1

        output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
            weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx
        )
        grad = torch.randn_like(output)

        # Call the dense backward with example inputs
        grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
            grad, indices, offset2bag, bag_size, maximum_indices, weight.size(0),
            scale_grad_by_freq, mode, per_sample_weights, padding_idx
        )
        meta_grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
            grad.to('meta'), indices.to('meta'), offset2bag.to('meta'), bag_size.to('meta'),
            maximum_indices.to('meta'), weight.size(0),
            scale_grad_by_freq, mode, per_sample_weights, padding_idx
        )
        self.assertEqual(grad_weight.to('meta'), meta_grad_weight)

    def test_embedding_bag_dense_backward_per_sample_weights(self):
        weight = torch.randn(4, 3, requires_grad=True)
        indices = torch.tensor([1, 0, 2, 1, 3])
        offsets = torch.tensor([0, 2, 3, 5])
        scale_grad_by_freq = False
        sparse = False
        mode = 0
        per_sample_weights = torch.randn(5, requires_grad=True)
        include_last_offset = False
        padding_idx = -1

        output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
            weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx
        )
        grad = torch.randn_like(output)

        # Call the per-sample-weights backward with example inputs
        grad_weight = torch.ops.aten._embedding_bag_per_sample_weights_backward.default(
            grad, weight, indices, offsets, offset2bag, mode, padding_idx
        )
        meta_grad_weight = torch.ops.aten._embedding_bag_per_sample_weights_backward.default(
            grad.to('meta'), weight.to('meta'), indices.to('meta'),
            offsets.to('meta'), offset2bag.to('meta'), mode, padding_idx
        )
        self.assertEqual(grad_weight.to('meta'), meta_grad_weight)

    # The opinfo test uses aten.fill_, so it does not cover aten.fill
    @onlyCUDA
    def test_fill_stride(self):
        to_meta = MetaConverter()
        sample_args = [torch.rand(2, 2, 2, 2), 1.0]

        for args in get_strided_args(sample_args):
            meta_args = to_meta(args)
            ref_out = torch.ops.aten.fill(*args)
            meta_out = torch.ops.aten.fill(*meta_args)
            self.assertEqual(ref_out.size(), meta_out.size())
            self.assertEqual(ref_out.stride(), meta_out.stride())

    def test_map_location_deserialize(self):
        import io

        t = torch.rand(10)
        b = io.BytesIO()

        torch.save(t, b)
        b.seek(0)
        r = torch.load(b, map_location=torch.device("meta"))
        self.assertEqual(r.device.type, 'meta')
        self.assertEqual(r.shape, t.shape)
        self.assertEqual(r.dtype, t.dtype)
        self.assertEqual(r.storage().data_ptr(), 0)

    def test_embedding_bag_byte_prepack(self):
        batch_size = 10
        num_embeddings = 80
        embedding_dim = [128, 256, 512]
        res_shape = [[batch_size, num_embeddings, ed + 8] for ed in embedding_dim]
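        # The +8 accounts for the per-row quantization parameters that byte
        # prepacking appends (presumably a float32 scale and a float32
        # zero_point, i.e. 8 extra bytes per row).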
        for ed, rs in zip(embedding_dim, res_shape):
            weight = torch.randn(batch_size, num_embeddings, ed, dtype=torch.float32)
            res = torch.ops.quantized.embedding_bag_byte_prepack(weight.to(device="meta"))
            self.assertEqual(res.shape, rs)
            self.assertEqual(res.dtype, torch.float32)
            self.assertEqual(res.untyped_storage().data_ptr(), 0)

    def test_embedding_bag_byte_unpack(self):
        batch_size = 10
        num_embeddings = 80
        embedding_dim = [128, 256, 512]
        res_shape = [[batch_size, num_embeddings, ed] for ed in embedding_dim]
        for ed, rs in zip(embedding_dim, res_shape):
            packed_weight = torch.randn(batch_size, num_embeddings, ed + 8, dtype=torch.float32)
            res = torch.ops.quantized.embedding_bag_byte_unpack(packed_weight.to(device="meta"))
            self.assertEqual(res.shape, rs)
            self.assertEqual(res.dtype, torch.float32)
            self.assertEqual(res.untyped_storage().data_ptr(), 0)

    def test_index_select_out(self):
        def f():
            input = torch.randn([8, 16], device='meta')
            index = torch.tensor([2, 1, 6, 7, 3, 1, 7, 5, 6, 7], device='meta')
            out = torch.empty([10, 16], device='meta')
            return torch.index_select(input=input, dim=0, index=index, out=out)
        with enable_python_dispatcher():
            out = f()
        self.assertEqual(out.shape, [10, 16])

    def test_local_scalar_dense_call(self):
        with self.assertRaisesRegex(RuntimeError, "cannot be called on meta tensors"):
            meta_tensor = torch.randn(1, device='meta')
            meta_tensor.item()


instantiate_device_type_tests(TestMeta, globals())


def print_op_str_if_not_supported(op_str):
    op = OperatorName.parse(op_str)
    packet = getattr(torch.ops.aten, str(op.name))
    overload = getattr(packet, op.overload_name if op.overload_name else "default")
    if any(overload in d for d in [meta_dispatch_skips, meta_dispatch_device_skips['cuda']]):
        print(f"{overload} # SKIP")
    if any(overload in d for d in [meta_dispatch_expected_failures, meta_dispatch_device_expected_failures['cuda']]):
        print(overload)


if __name__ == "__main__":
    COMPARE_XLA = os.getenv('PYTORCH_COMPARE_XLA', None)
    if COMPARE_XLA is not None:
        with open(COMPARE_XLA) as f:
            d = yaml.load(f, Loader=YamlLoader)
            ops = d.get("full_codegen", []) + d.get("supported", []) + d.get("autograd", [])
            for op_str in ops:
                print_op_str_if_not_supported(op_str)
            sys.exit(0)

    COMPARE_TEXT = os.getenv('PYTORCH_COMPARE_TEXT', None)
    if COMPARE_TEXT is not None:
        with open(COMPARE_TEXT) as f:
            for op_str in f:
                print_op_str_if_not_supported(op_str.strip())
            sys.exit(0)

    run_tests()
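
# Example invocations of the two comparison modes above (file paths are
# hypothetical; point the env vars at real files in your checkout):
#
#   PYTORCH_COMPARE_XLA=/path/to/xla_native_functions.yaml python <this file>
#   PYTORCH_COMPARE_TEXT=/path/to/op_list.txt python <this file>
#
# Either mode prints the listed operators that are skipped or expected to
# fail under meta dispatch, then exits without running the test suite.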