# Owner(s): ["module: unknown"]

import contextlib
import copy
import inspect
import itertools
import os
import re
import unittest
import warnings
from collections import defaultdict
from collections.abc import Sequence
from functools import partial
from importlib import import_module
from typing import Dict, List

import torch
import torch._prims as prims
import torch.utils._pytree as pytree
from torch._prims.context import TorchRefsMode
from torch._prims_common.wrappers import _maybe_remove_out_wrapper
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._subclasses.fake_utils import outputs_alias_inputs
from torch.testing import make_tensor
from torch.testing._internal import composite_compliance, opinfo
from torch.testing._internal.common_device_type import (
    deviceCountAtLeast,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypesAnd,
    OpDTypes,
    ops,
    skipMeta,
)
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,
    floating_and_complex_types_and,
    integral_types_and,
)
from torch.testing._internal.common_methods_invocations import (
    BinaryUfuncInfo,
    op_db,
    ops_and_refs,
    python_ref_db,
    ReductionOpInfo,
    ReductionPythonRefInfo,
    skip,
    skipOps,
    SpectralFuncInfo,
    UnaryUfuncInfo,
    xfail,
)
from torch.testing._internal.common_utils import (
    clone_input_helper,
    first_sample,
    IS_CI,
    IS_FBCODE,
    is_iterable_of_tensors,
    IS_SANDCASTLE,
    IS_WINDOWS,
    noncontiguous_like,
    parametrize,
    run_tests,
    set_default_dtype,
    skipIfTorchInductor,
    slowTest,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TEST_WITH_TORCHDYNAMO,
    TEST_WITH_TORCHINDUCTOR,
    TEST_WITH_UBSAN,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map


assert torch.get_default_dtype() == torch.float32

# variant testing is only done with torch.float and torch.cfloat to avoid
# excessive test times and maximize signal to noise ratio
_variant_ops = partial(
    ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
)

# Get names of all the operators which have ref in their entry in OpInfo (testing infra)
# except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py),
# elementwise binary operators (separately implemented in test_binary_ufuncs.py),
# reduction operations (separately implemented in test_reductions.py),
# and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py)
_ref_test_ops = tuple(
    filter(
        lambda op: not isinstance(
            op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
        )
        and op.ref is not None,
        op_db,
    )
)


def reduction_dtype_filter(op):
    if (
        not isinstance(op, ReductionPythonRefInfo)
        or not op.supports_out
        or torch.int16 not in op.dtypes
    ):
        return False
    return "dtype" in inspect.getfullargspec(op.op).kwonlyargs


# Create a list of operators that are a subset of _ref_test_ops but don't have a
# numpy ref to compare them to. If both CPU and CUDA are compared to numpy,
# then they do not need to be compared to each other.
_ops_and_refs_with_no_numpy_ref = [op for op in ops_and_refs if op.ref is None]

aten = torch.ops.aten


# Tests that apply to all operators
and aren't related to any particular 124# system 125@unMarkDynamoStrictTest 126class TestCommon(TestCase): 127 exact_dtype = True 128 129 # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI 130 @classmethod 131 def tearDownClass(cls): 132 super().tearDownClass() 133 134 if IS_CI: 135 err_msg = ( 136 "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries." 137 "This is OK for testing, but be sure to set the dtypes manually before landing your PR!" 138 ) 139 # Assure no opinfo entry has dynamic_dtypes 140 filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db)) 141 for op in filtered_ops: 142 fmt_str = opinfo.utils.str_format_dynamic_dtype(op) 143 err_msg += "\n" + fmt_str 144 145 assert len(filtered_ops) == 0, err_msg 146 147 # Validates that each OpInfo works correctly on different CUDA devices 148 @onlyCUDA 149 @deviceCountAtLeast(2) 150 @ops(op_db, allowed_dtypes=(torch.float32, torch.long)) 151 def test_multiple_devices(self, devices, dtype, op): 152 for cuda_device_str in devices: 153 cuda_device = torch.device(cuda_device_str) 154 # NOTE: only tests on first sample 155 samples = op.sample_inputs(cuda_device, dtype) 156 sample = first_sample(self, samples) 157 result = op(sample.input, *sample.args, **sample.kwargs) 158 159 if isinstance(result, torch.Tensor): 160 self.assertTrue(result.device == cuda_device) 161 elif is_iterable_of_tensors(result): 162 self.assertTrue(all(t.device == cuda_device for t in result)) 163 else: 164 self.skipTest( 165 "Skipped! Only supports single tensor or iterable of tensor outputs." 166 ) 167 168 def test_pointwise_tag_coverage(self): 169 pytorch_dir = os.path.abspath(__file__ + "/../../") 170 files = [ 171 "aten/src/ATen/native/UnaryOps.cpp", 172 "aten/src/ATen/native/BinaryOps.cpp", 173 "aten/src/ATen/native/PointwiseOps.cpp", 174 "aten/src/ATen/native/TensorCompare.cpp", 175 ] 176 177 allowed_functions = ( 178 # reduction version of these operators 179 "aten.max.default", 180 "aten.max.dim", 181 "aten.max.dim_max", 182 "aten.max.names_dim", 183 "aten.max.names_dim_max", 184 "aten.max.unary_out", 185 "aten.min.default", 186 "aten.min.dim", 187 "aten.min.dim_min", 188 "aten.min.names_dim", 189 "aten.min.names_dim_min", 190 "aten.min.unary_out", 191 # not pointwise 192 "aten.isin.Tensor_Tensor", 193 "aten.isin.Tensor_Tensor_out", 194 "aten.isin.Tensor_Scalar", 195 "aten.isin.Tensor_Scalar_out", 196 "aten.isin.Scalar_Tensor", 197 "aten.isin.Scalar_Tensor_out", 198 "aten.mode.default", 199 "aten.mode.dimname", 200 "aten.mode.dimname_out", 201 "aten.mode.values", 202 ) 203 204 regex = re.compile(r"DEFINE_DISPATCH\(.*_stub") 205 206 def get_opoverloadpacket_from_dispatch(kernel): 207 if hasattr(torch.ops.aten, kernel): 208 return kernel 209 if hasattr(torch.ops.aten, f"__{kernel}__"): 210 return f"__{kernel}__" 211 if hasattr(torch.ops.aten, f"special_{kernel}"): 212 return f"special_{kernel}" 213 if "_" in kernel: 214 kernel_split = kernel.split("_") 215 new_kernel = "_".join(kernel_split[:-1]) 216 if hasattr(torch.ops.aten, new_kernel): 217 return new_kernel 218 219 # could not find op from kernel dispatch string 220 self.assertTrue(False) 221 222 for file_name in files: 223 with open(os.path.join(pytorch_dir, file_name)) as f: 224 lines = f.read() 225 matches = regex.findall(lines) 226 for match in matches: 227 kernel = match[len("DEFINE_DISPATCH(") : -len("_stub")] 228 229 # no op definition for it, but defined with DEFINE_DISPATCH ? 
                    if kernel == "trigamma":
                        continue

                    kernel = get_opoverloadpacket_from_dispatch(kernel)
                    overloadpacket = getattr(torch.ops.aten, kernel)

                    for overload_name in overloadpacket.overloads():
                        overload = getattr(overloadpacket, overload_name)

                        if not torch._C._dispatch_has_kernel(overload.name()):
                            continue

                        # TODO: tags are not propagated to generated overloads,
                        # and there's no way of specifying them
                        if torch.Tag.generated in overload.tags:
                            continue

                        if str(overload) in allowed_functions:
                            continue

                        self.assertTrue(torch.Tag.pointwise in overload.tags)

    # Tests that the function and its (ndarray-accepting) reference produce the same
    # values on the tensors from the sample_inputs func for the corresponding op.
    # This test runs in double and complex double precision because
    # NumPy does computation internally using double precision for many functions,
    # resulting in possible equality check failures.
    # skip windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947
    @onlyNativeDeviceTypesAnd(["hpu"])
    @suppress_warnings
    @ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128))
    def test_numpy_ref(self, device, dtype, op):
        if (
            TEST_WITH_TORCHINDUCTOR
            and op.formatted_name
            in ("signal_windows_exponential", "signal_windows_bartlett")
            and dtype == torch.float64
            and ("cuda" in device or "cpu" in device)
        ):  # noqa: E121
            raise unittest.SkipTest("XXX: raises tensor-likes are not close.")

        # Sets the default dtype to NumPy's default dtype of double
        with set_default_dtype(torch.double):
            for sample_input in op.reference_inputs(device, dtype):
                self.compare_with_reference(
                    op, op.ref, sample_input, exact_dtype=(dtype is not torch.long)
                )

    # Tests that the cpu and gpu results are consistent
    @onlyCUDA
    @suppress_warnings
    @slowTest
    @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
    def test_compare_cpu(self, device, dtype, op):
        def to_cpu(arg):
            if isinstance(arg, torch.Tensor):
                return arg.to(device="cpu")
            return arg

        samples = op.reference_inputs(device, dtype)

        for sample in samples:
            cpu_sample = sample.transform(to_cpu)
            cuda_results = op(sample.input, *sample.args, **sample.kwargs)
            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            # output_process_fn_grad has a very unfortunate name.
            # We use this function in linalg extensively to postprocess the outputs of functions
            # that are not completely well-defined. Think svd and multiplying the singular vectors by -1.
            # CPU and CUDA implementations of the SVD can return valid SVDs that are different.
            # We use this function to compare them.
            cuda_results = sample.output_process_fn_grad(cuda_results)
            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

            # Lower tolerance because we are running this as a `@slowTest`;
            # we don't want the periodic tests to fail frequently
            self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)

    # Tests that experimental Python References can propagate shape, dtype,
    # and device metadata properly.
    # See https://github.com/pytorch/pytorch/issues/78050 for a discussion of stride propagation.
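    # For example (an illustrative sketch, not exercised by this test): running a
    # ref such as torch._refs.sin on a FakeTensor input under FakeTensorMode should
    # yield a FakeTensor whose shape, dtype, and device match the real-tensor
    # result, which is what compare_tensor_meta verifies below.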
312 @onlyNativeDeviceTypesAnd(["hpu"]) 313 @ops(python_ref_db) 314 @skipIfTorchInductor("Takes too long for inductor") 315 def test_python_ref_meta(self, device, dtype, op): 316 CHECK_CONJ_SKIPS = { 317 torch._refs.linalg.svd, 318 } 319 320 with FakeTensorMode() as mode: 321 pass 322 323 def _to_tensormeta(x): 324 if isinstance(x, torch.Tensor): 325 out = FakeTensor.from_tensor(x, mode) 326 return out 327 return x 328 329 # TODO: iterate over requires_grad true/false 330 for sample in op.reference_inputs(device, dtype, requires_grad=False): 331 result = op(sample.input, *sample.args, **sample.kwargs) 332 333 meta_sample = sample.transform(_to_tensormeta) 334 try: 335 with mode: 336 meta_result = op( 337 meta_sample.input, *meta_sample.args, **meta_sample.kwargs 338 ) 339 except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: 340 continue 341 except torch._subclasses.fake_tensor.DataDependentOutputException: 342 continue 343 except torch._subclasses.fake_tensor.UnsupportedOperatorException: 344 continue 345 346 if isinstance(result, torch.Tensor): 347 self.assertTrue(isinstance(meta_result, FakeTensor)) 348 prims.utils.compare_tensor_meta( 349 result, meta_result, check_conj=op.op not in CHECK_CONJ_SKIPS 350 ) 351 elif isinstance(result, Sequence): 352 for a, b in zip(result, meta_result): 353 if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor): 354 self.assertTrue(isinstance(b, FakeTensor)) 355 prims.utils.compare_tensor_meta( 356 a, b, check_conj=op.op not in CHECK_CONJ_SKIPS 357 ) 358 359 def _ref_test_helper( 360 self, 361 ctx, 362 device, 363 dtype, 364 op, 365 skip_zero_numel=False, 366 skip_zero_dim=False, 367 skip_bfloat=False, 368 skip_view_consistency=False, 369 ): 370 # NOTE: this test works by comparing the reference 371 ex = None 372 for sample in op.reference_inputs(device, dtype, requires_grad=False): 373 if ( 374 isinstance(sample.input, torch.Tensor) 375 and sample.input.numel() == 0 376 and skip_zero_numel 377 ): 378 continue 379 if ( 380 isinstance(sample.input, torch.Tensor) 381 and sample.input.ndim == 0 382 and skip_zero_dim 383 ): 384 continue 385 386 if skip_bfloat and ( 387 ( 388 isinstance(sample.input, torch.Tensor) 389 and sample.input.dtype == torch.bfloat16 390 ) 391 or any( 392 isinstance(arg, torch.Tensor) and arg.dtype == torch.bfloat16 393 for arg in sample.args 394 ) 395 ): 396 continue 397 with ctx(): 398 ref_result = op(sample.input, *sample.args, **sample.kwargs) 399 torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs) 400 401 for a, b in zip( 402 pytree.tree_leaves(ref_result), pytree.tree_leaves(torch_result) 403 ): 404 if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor): 405 prims.utils.compare_tensor_meta(a, b) 406 if ( 407 getattr(op, "validate_view_consistency", True) 408 and not skip_view_consistency 409 ): 410 msg = ( 411 f"The torch implementation {'returns' if b._is_view() else 'does not return'} " 412 f"a view, while the reference {'does' if a._is_view() else 'does not'}" 413 ) 414 self.assertEqual(a._is_view(), b._is_view(), msg) 415 416 # Computes the dtype the more precise computatino would occur in 417 precise_dtype = torch.bool 418 if prims.utils.is_integer_dtype(dtype): 419 # Note: bool and integer dtypes do not have more 420 # precise dtypes -- they simply must be close 421 precise_dtype = dtype 422 if prims.utils.is_float_dtype(dtype): 423 precise_dtype = torch.double 424 if prims.utils.is_complex_dtype(dtype): 425 precise_dtype = torch.cdouble 426 427 # Checks if the results 
are close
            try:
                self.assertEqual(
                    ref_result,
                    torch_result,
                    exact_stride=False,
                    exact_device=True,
                    exact_layout=True,
                    exact_is_coalesced=True,
                )
            except AssertionError as e:
                # Raises the error if the precise dtype comparison wouldn't be
                # different
                if dtype is precise_dtype:
                    raise e

                ex = e

            # Goes to next sample if these results are close
            if not ex:
                continue

            # If the results are not close, checks that the
            # reference is more accurate than the torch op
            def _make_precise(x):
                if isinstance(x, torch.dtype):
                    return precise_dtype
                if isinstance(x, torch.Tensor) and x.dtype is dtype:
                    return x.to(precise_dtype)
                return x

            precise_sample = sample.transform(_make_precise)
            precise_result = op.torch_opinfo(
                precise_sample.input, *precise_sample.args, **precise_sample.kwargs
            )

            def _distance(a, b):
                # Special-cases boolean comparisons
                if prims.utils.is_boolean_dtype(a.dtype):
                    assert b.dtype is torch.bool
                    return (a ^ b).sum()

                same = a == b
                if prims.utils.is_float_dtype(a.dtype) or prims.utils.is_complex_dtype(
                    a.dtype
                ):
                    same = torch.logical_or(
                        same, torch.logical_and(torch.isnan(a), torch.isnan(b))
                    )

                actual_error = torch.where(same, 0, torch.abs(a - b)).sum()
                return actual_error

            ref_distance = 0
            for a, b in zip(
                pytree.tree_leaves(ref_result), pytree.tree_leaves(precise_result)
            ):
                ref_distance = ref_distance + _distance(a, b)

            torch_distance = 0
            for a, b in zip(
                pytree.tree_leaves(torch_result), pytree.tree_leaves(precise_result)
            ):
                torch_distance = torch_distance + _distance(a, b)

            # TODO: consider adding some tolerance to this comparison
            msg = (
                f"Reference result was farther ({ref_distance}) from the precise "
                f"computation than the torch result was ({torch_distance})!"
            )
            self.assertTrue(ref_distance <= torch_distance, msg=msg)

        # Reports numerical accuracy discrepancies
        if ex is not None:
            msg = "Test passed because the reference was more accurate than the torch operator."
            warnings.warn(msg)

    # Tests that experimental Python References perform the same computation
    # as the operators they reference, when operator calls in the torch
    # namespace are remapped to the refs namespace (torch.foo becomes refs.foo).
    @onlyNativeDeviceTypesAnd(["hpu"])
    @ops(python_ref_db)
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref(self, device, dtype, op):
        # In this test, primTorch refs call into the refs namespace.
        # For example, a ref with torch.foo in it will call refs.foo instead.
        # Direct calls to refs and prims are not affected.
        if (
            TEST_WITH_ROCM
            and (op.name == "_refs.fft.ihfftn" or op.name == "_refs.fft.ihfft2")
            and dtype == torch.float16
        ):
            self.skipTest("Skipped on ROCm")
        self._ref_test_helper(lambda: TorchRefsMode(strict=True), device, dtype, op)

    # Tests that experimental Python References perform the same computation
    # as the operators they reference, when operator calls in the torch
    # namespace are preserved (torch.foo remains torch.foo).
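    # For example (illustrative): under TorchRefsMode in the previous test, a
    # torch.add call inside a ref is rerouted to torch._refs.add, whereas in this
    # fallback test the same call runs the regular torch.add kernel, so both code
    # paths of each reference get exercised.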
525 @onlyNativeDeviceTypesAnd(["hpu"]) 526 @ops(python_ref_db) 527 @skipIfTorchInductor("Takes too long for inductor") 528 def test_python_ref_torch_fallback(self, device, dtype, op): 529 # In this test, refs call into the torch namespace (after the initial invocation) 530 # For example, a ref with torch.foo in it will call torch.foo instead of refs.foo 531 # Direct calls to refs and prims are not translated 532 if TEST_WITH_ROCM and op.name == "_refs.fft.ihfftn" and dtype == torch.float16: 533 self.skipTest("Skipped on ROCm") 534 self._ref_test_helper(contextlib.nullcontext, device, dtype, op) 535 536 @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") 537 @onlyCUDA 538 @ops(python_ref_db) 539 @parametrize("executor", ["aten"]) 540 @skipIfTorchInductor("Takes too long for inductor") 541 def test_python_ref_executor(self, device, dtype, op, executor): 542 if ( 543 TEST_WITH_ROCM 544 and (op.name == "_refs.fft.ihfftn" or op.name == "_refs.fft.ihfft2") 545 and dtype == torch.float16 546 ): 547 self.skipTest("Skipped on ROCm") 548 # skip zero-dim tensors for some composites of reduction operations and view 549 skip_zero_dim_ops = [ 550 "_refs.logsumexp", 551 "_refs.log_softmax", 552 "_refs.native_group_norm", 553 "_refs.softmax", 554 "_refs.sum_to_size", 555 "ops.nvprims.view", 556 ] 557 558 from copy import copy 559 560 from torch._prims.executor import make_traced 561 562 op = copy(op) 563 op.op = partial(make_traced(op.op), executor=executor) 564 self._ref_test_helper(contextlib.nullcontext, device, dtype, op) 565 566 @skipMeta 567 @onlyNativeDeviceTypesAnd(["hpu"]) 568 @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none) 569 def test_errors(self, device, op): 570 error_inputs = op.error_inputs(device) 571 for ei in error_inputs: 572 si = ei.sample_input 573 with self.assertRaisesRegex(ei.error_type, ei.error_regex): 574 out = op(si.input, *si.args, **si.kwargs) 575 self.assertFalse(isinstance(out, type(NotImplemented))) 576 577 @skipMeta 578 @onlyNativeDeviceTypesAnd(["hpu"]) 579 @ops( 580 [op for op in op_db if op.error_inputs_sparse_func is not None], 581 dtypes=OpDTypes.none, 582 ) 583 @parametrize( 584 "layout", 585 ( 586 torch.sparse_csr, 587 torch.sparse_csc, 588 torch.sparse_bsr, 589 torch.sparse_bsc, 590 torch.sparse_coo, 591 ), 592 ) 593 def test_errors_sparse(self, device, op, layout): 594 for ei in op.error_inputs_sparse(device, layout): 595 si = ei.sample_input 596 with self.assertRaisesRegex(ei.error_type, ei.error_regex): 597 out = op(si.input, *si.args, **si.kwargs) 598 self.assertFalse(isinstance(out, type(NotImplemented))) 599 600 @skipMeta 601 @onlyNativeDeviceTypesAnd(["hpu"]) 602 @ops( 603 [op for op in python_ref_db if op.error_inputs_func is not None], 604 dtypes=OpDTypes.none, 605 ) 606 @skipIfTorchInductor("Takes too long for inductor") 607 def test_python_ref_errors(self, device, op): 608 mode = FakeTensorMode() 609 with mode: 610 pass 611 612 def _to_tensormeta(x): 613 if isinstance(x, torch.Tensor): 614 return FakeTensor.from_tensor(x, mode) 615 return x 616 617 error_inputs = op.error_inputs(device) 618 for ei in error_inputs: 619 si = ei.sample_input 620 meta_sample = si.transform(_to_tensormeta) 621 with self.assertRaisesRegex(ei.error_type, ei.error_regex): 622 op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs) 623 624 # Tests that the function produces the same result when called with 625 # noncontiguous tensors. 
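    # For example (illustrative): op(x) and op(x.t().contiguous().t()) should agree,
    # since the second call sees the same values through non-standard strides;
    # sample_input.noncontiguous() below builds such inputs generically.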
    # TODO: get working with Windows by addressing failing operators
    # TODO: get working with ASAN by addressing failing operators
    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
    @onlyNativeDeviceTypesAnd(["hpu"])
    @suppress_warnings
    @ops(op_db, allowed_dtypes=(torch.float32, torch.long, torch.complex64))
    def test_noncontiguous_samples(self, device, dtype, op):
        test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
        for sample_input in sample_inputs:
            t_inp, t_args, t_kwargs = (
                sample_input.input,
                sample_input.args,
                sample_input.kwargs,
            )
            noncontig_sample = sample_input.noncontiguous()
            n_inp, n_args, n_kwargs = (
                noncontig_sample.input,
                noncontig_sample.args,
                noncontig_sample.kwargs,
            )

            # Validates forward
            expected = op(t_inp, *t_args, **t_kwargs)
            actual = op(n_inp, *n_args, **n_kwargs)

            self.assertEqual(actual, expected)

            # Validates backward
            # Short-circuits if the op doesn't support grad in this device x dtype
            if not test_grad:
                continue

            expected = sample_input.output_process_fn_grad(expected)
            actual = sample_input.output_process_fn_grad(actual)

            if isinstance(expected, torch.Tensor):
                grad_for_expected = torch.randn_like(expected)
                grad_for_actual = noncontiguous_like(grad_for_expected)
            elif isinstance(expected, Sequence):
                # Filter output elements that do not require grad
                expected = [
                    t
                    for t in expected
                    if isinstance(t, torch.Tensor) and t.requires_grad
                ]
                actual = [
                    n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad
                ]
                grad_for_expected = [torch.randn_like(t) for t in expected]
                grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected]
            else:
                # Nothing to do if it returns a scalar or things like that
                continue

            # Concatenate inputs into a tuple
            t_inputs = (
                (t_inp,) + t_args
                if isinstance(t_inp, torch.Tensor)
                else tuple(t_inp) + t_args
            )
            n_inputs = (
                (n_inp,) + n_args
                if isinstance(n_inp, torch.Tensor)
                else tuple(n_inp) + n_args
            )

            # Filter the elements that are tensors requiring grad
            t_input_tensors = [
                t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad
            ]
            n_input_tensors = [
                n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad
            ]

            self.assertEqual(len(t_input_tensors), len(n_input_tensors))

            # Some functions may not use all the inputs to generate gradients. One of the
            # few examples of this "odd" behaviour is F.hinge_embedding_loss
            t_grads = torch.autograd.grad(
                expected, t_input_tensors, grad_for_expected, allow_unused=True
            )
            n_grads = torch.autograd.grad(
                actual, n_input_tensors, grad_for_actual, allow_unused=True
            )

            msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}."
            for i, (t, n) in enumerate(zip(t_grads, n_grads)):
                self.assertEqual(t, n, msg=msg.format(i))

    # Separates one case from the following test_out because many ops don't properly
    # implement the incorrectly-sized out parameter warning yet.
    # Cases tested here:
    # - out= with the correct dtype and device, but the wrong shape
    @ops(ops_and_refs, dtypes=OpDTypes.none)
    def test_out_warning(self, device, op):
        if TEST_WITH_TORCHDYNAMO and op.name == "_refs.clamp":
            self.skipTest("flaky")
        # Prefers running in float32 but has a fallback for the first listed supported dtype
        supported_dtypes = op.supported_dtypes(self.device_type)
        if len(supported_dtypes) == 0:
            self.skipTest("Skipped! Op has no supported dtypes on this device.")
        dtype = (
            torch.float32
            if torch.float32 in supported_dtypes
            else next(iter(supported_dtypes))
        )

        # Ops from python_ref_db point to python decomps that are potentially
        # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these
        # ops before testing to avoid clashing with OpInfo.supports_out
        if not op.supports_out:
            op = copy.copy(op)
            op.op = _maybe_remove_out_wrapper(op.op)

        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            # Calls it normally to get the expected result
            expected = op(sample.input, *sample.args, **sample.kwargs)
            op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

            # Short-circuits if output is not a single tensor or an
            # iterable of tensors
            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
                expected, include_empty=True
            ):
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

            # Validates the op doesn't support out if it claims not to
            if not op.supports_out:
                with self.assertRaises(Exception):
                    assert op_out(out=expected) != NotImplemented
                return

            # A wrapper around map that works with single tensors and always
            # instantiates the map. Used below to apply transforms to
            # single tensor and iterable tensor outputs.
765 def _apply_out_transform(fn, out): 766 if isinstance(out, torch.Tensor): 767 return fn(out) 768 769 # assumes (see above) that out is an iterable of tensors 770 return tuple(map(fn, out)) 771 772 # Extracts strides from a tensor or iterable of tensors into a tuple 773 def _extract_strides(out): 774 if isinstance(out, torch.Tensor): 775 return (out.stride(),) 776 777 # assumes (see above) that out is an iterable of tensors 778 return tuple(t.stride() for t in out) 779 780 # Extracts data pointers from a tensor or iterable of tensors into a tuple 781 # NOTE: only extracts on the CPU and CUDA device types since some 782 # device types don't have storage 783 def _extract_data_ptrs(out): 784 if self.device_type != "cpu" and self.device_type != "cuda": 785 return () 786 787 if isinstance(out, torch.Tensor): 788 return (out.data_ptr(),) 789 790 # assumes (see above) that out is an iterable of tensors 791 return tuple(t.data_ptr() for t in out) 792 793 @suppress_warnings 794 def _compare_out(transform, *, compare_strides_and_data_ptrs=True): 795 out = _apply_out_transform(transform, expected) 796 original_strides = _extract_strides(out) 797 original_ptrs = _extract_data_ptrs(out) 798 799 op_out(out=out) 800 final_strides = _extract_strides(out) 801 final_ptrs = _extract_data_ptrs(out) 802 803 self.assertEqual(expected, out) 804 805 if compare_strides_and_data_ptrs: 806 stride_msg = ( 807 f"Strides are not the same! Original strides were {original_strides} " 808 f"and strides are now {final_strides}" 809 ) 810 self.assertEqual(original_strides, final_strides, msg=stride_msg) 811 self.assertEqual(original_ptrs, final_ptrs) 812 813 # Case Zero: out= with the correct dtype and device, but the wrong shape 814 # Expected behavior: if nonempty, resize with a warning. 815 def _case_zero_transform(t): 816 wrong_shape = list(t.shape) 817 818 if len(wrong_shape) == 0: 819 # Handles scalar tensor case (empty list) 820 wrong_shape = [2] 821 else: 822 wrong_shape[-1] = wrong_shape[-1] + 1 823 return make_tensor(wrong_shape, dtype=t.dtype, device=t.device) 824 825 # Verifies the out values are correct 826 _compare_out(_case_zero_transform, compare_strides_and_data_ptrs=False) 827 828 # Additionally validates that the appropriate warning is thrown if a nonempty 829 # tensor is resized. 830 def _any_nonempty(out): 831 if isinstance(out, torch.Tensor): 832 return out.numel() > 0 833 834 return any(x.numel() > 0 for x in out) 835 836 out = _apply_out_transform(_case_zero_transform, expected) 837 msg_fail = "Resized a non-empty tensor but did not warn about it." 
838 if _any_nonempty(out): 839 with self.assertWarnsRegex( 840 UserWarning, "An output with one or more elements", msg=msg_fail 841 ): 842 op_out(out=out) 843 844 # Validates ops implement the correct out= behavior 845 # See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch 846 # for a description of the correct behavior 847 # Validates the following cases: 848 # - Case 0: out has the correct shape, dtype, and device but is full of extremal values 849 # - Case 1: out has the correct shape, dtype, and device but is noncontiguous 850 # - Case 2: out has the correct dtype and device, but is zero elements 851 # - Case 3: out has the correct shape and dtype, but is on a different device type 852 # - Case 4: out has the correct shape and device, but a dtype that cannot 853 # "safely" cast to 854 # 855 # Case 3 and 4 are slightly different when the op is a factory function: 856 # - if device, dtype are NOT passed, any combination of dtype/device should be OK for out 857 # - if device, dtype are passed, device and dtype should match 858 @ops(ops_and_refs, dtypes=OpDTypes.any_one) 859 def test_out(self, device, dtype, op): 860 # Prefers running in float32 but has a fallback for the first listed supported dtype 861 samples = op.sample_inputs(device, dtype) 862 863 # Ops from python_ref_db point to python decomps that are potentially 864 # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these 865 # ops before testing to avoid clashing with OpInfo.supports_out 866 if not op.supports_out: 867 op = copy.copy(op) 868 op.op = _maybe_remove_out_wrapper(op.op) 869 870 for sample in samples: 871 # calls it normally to get the expected result 872 expected = op(sample.input, *sample.args, **sample.kwargs) 873 op_out = partial(op, sample.input, *sample.args, **sample.kwargs) 874 875 # Short-circuits if output is not a single tensor or an 876 # iterable of tensors 877 if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors( 878 expected, include_empty=True 879 ): 880 self.skipTest( 881 "Skipped! Only supports single tensor or iterable of tensor outputs." 882 ) 883 884 # Validates the op doesn't support out if it claims not to 885 if not op.supports_out: 886 with self.assertRaises(Exception): 887 assert op_out(out=expected) != NotImplemented 888 return 889 890 # A wrapper around map that works with single tensors and always 891 # instantiates the map. Used below to apply transforms to 892 # single tensor and iterable tensor outputs. 
893 def _apply_out_transform(fn, out): 894 if isinstance(out, torch.Tensor): 895 return fn(out) 896 897 # assumes (see above) that out is an iterable of tensors 898 return tuple(map(fn, out)) 899 900 # Extracts strides from a tensor or iterable of tensors into a tuple 901 def _extract_strides(out): 902 if isinstance(out, torch.Tensor): 903 return (out.stride(),) 904 905 # assumes (see above) that out is an iterable of tensors 906 return tuple(t.stride() for t in out) 907 908 # Extracts data pointers from a tensor or iterable of tensors into a tuple 909 # NOTE: only extracts on the CPU and CUDA device types since some 910 # device types don't have storage 911 def _extract_data_ptrs(out): 912 if self.device_type != "cpu" and self.device_type != "cuda": 913 return () 914 915 if isinstance(out, torch.Tensor): 916 return (out.data_ptr(),) 917 918 # assumes (see above) that out is an iterable of tensors 919 return tuple(t.data_ptr() for t in out) 920 921 def _compare_out(transform, *, compare_strides_and_data_ptrs=True): 922 out = _apply_out_transform(transform, expected) 923 original_strides = _extract_strides(out) 924 original_ptrs = _extract_data_ptrs(out) 925 926 op_out(out=out) 927 final_strides = _extract_strides(out) 928 final_ptrs = _extract_data_ptrs(out) 929 self.assertEqual(expected, out) 930 931 if compare_strides_and_data_ptrs: 932 stride_msg = ( 933 "Strides are not the same! " 934 f"Original strides were {original_strides} and strides are now {final_strides}" 935 ) 936 self.assertEqual(original_strides, final_strides, msg=stride_msg) 937 self.assertEqual(original_ptrs, final_ptrs) 938 939 # Case 0: out= with the correct shape, dtype, and device 940 # but NaN values for floating point and complex tensors, and 941 # maximum values for integer tensors. 942 # Expected behavior: out= values have no effect on the computation. 943 def _case_zero_transform(t): 944 try: 945 info = torch.iinfo(t.dtype) 946 return torch.full_like(t, info.max) 947 except TypeError as te: 948 # for non-integer types fills with NaN 949 return torch.full_like(t, float("nan")) 950 951 _compare_out(_case_zero_transform) 952 953 # Case 1: out= with the correct shape, dtype, and device, 954 # but noncontiguous. 955 # Expected behavior: strides are respected and `out` storage is not changed. 956 def _case_one_transform(t): 957 return make_tensor( 958 t.shape, dtype=t.dtype, device=t.device, noncontiguous=True 959 ) 960 961 _compare_out(_case_one_transform) 962 963 # Case 2: out= with the correct dtype and device, but has no elements. 964 # Expected behavior: resize without warning. 965 def _case_two_transform(t): 966 return make_tensor((0,), dtype=t.dtype, device=t.device) 967 968 _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False) 969 970 # Also validates that no warning is thrown when this out is resized 971 out = _apply_out_transform(_case_two_transform, expected) 972 with warnings.catch_warnings(record=True) as caught: 973 warnings.simplefilter("always") 974 op_out(out=out) 975 976 # Verifies no warning is a resize warning 977 for w in caught: 978 if "An output with one or more elements" in str(w.message): 979 self.fail( 980 "Resizing an out= argument with no elements threw a resize warning!" 981 ) 982 983 # Case 3: out= with correct shape and dtype, but wrong device. 
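            # For example (illustrative): calling op(cpu_input, out=cuda_tensor)
            # should raise a RuntimeError, unless the op is a factory function
            # that was not given an explicit device argument.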
984 wrong_device = None 985 if torch.device(device).type != "cpu": 986 wrong_device = "cpu" 987 elif torch.cuda.is_available(): 988 wrong_device = "cuda" 989 990 factory_fn_msg = ( 991 "\n\nNOTE: If your op is a factory function (i.e., it accepts TensorOptions) you should mark its " 992 "OpInfo with `is_factory_function=True`." 993 ) 994 if wrong_device is not None: 995 996 def _case_three_transform(t): 997 return make_tensor(t.shape, dtype=t.dtype, device=wrong_device) 998 999 out = _apply_out_transform(_case_three_transform, expected) 1000 1001 if op.is_factory_function and sample.kwargs.get("device", None) is None: 1002 op_out(out=out) 1003 else: 1004 msg_fail = ( 1005 f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}." 1006 ) + factory_fn_msg 1007 with self.assertRaises(RuntimeError, msg=msg_fail): 1008 op_out(out=out) 1009 1010 # Case 4: out= with correct shape and device, but a dtype 1011 # that output cannot be "safely" cast to (long). 1012 # Expected behavior: error. 1013 # NOTE: this case is filtered by dtype since some ops produce 1014 # bool tensors, for example, which can be safely cast to any 1015 # dtype. It is applied when single tensors are floating point or complex 1016 # dtypes, or if an op returns multiple tensors when at least one such 1017 # tensor is a floating point or complex dtype. 1018 _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16) 1019 if ( 1020 isinstance(expected, torch.Tensor) 1021 and expected.dtype in _dtypes 1022 or ( 1023 not isinstance(expected, torch.Tensor) 1024 and any(t.dtype in _dtypes for t in expected) 1025 ) 1026 ): 1027 1028 def _case_four_transform(t): 1029 return make_tensor(t.shape, dtype=torch.long, device=t.device) 1030 1031 out = _apply_out_transform(_case_four_transform, expected) 1032 msg_fail = "Expected RuntimeError when doing an unsafe cast!" 1033 msg_fail = ( 1034 msg_fail 1035 if not isinstance(expected, torch.Tensor) 1036 else ( 1037 "Expected RuntimeError when doing an unsafe cast from a result of dtype " 1038 f"{expected.dtype} into an out= with dtype torch.long" 1039 ) 1040 ) + factory_fn_msg 1041 1042 if op.is_factory_function and sample.kwargs.get("dtype", None) is None: 1043 op_out(out=out) 1044 else: 1045 with self.assertRaises(RuntimeError, msg=msg_fail): 1046 op_out(out=out) 1047 1048 @ops( 1049 [ 1050 op 1051 for op in op_db 1052 if op.supports_out and (op.supports_autograd or op.is_factory_function) 1053 ], 1054 dtypes=OpDTypes.supported, 1055 allowed_dtypes=[torch.float, torch.cfloat], 1056 ) 1057 def test_out_requires_grad_error(self, device, dtype, op): 1058 sample = first_sample(self, op.sample_inputs(device, dtype)) 1059 1060 # Call op to get prototype for out arguments 1061 expect = op(sample.input, *sample.args, **sample.kwargs) 1062 any_requires_grad = False 1063 1064 def set_requires_grad(x): 1065 nonlocal any_requires_grad 1066 if isinstance(x, torch.Tensor) and ( 1067 x.is_floating_point() or x.is_complex() 1068 ): 1069 any_requires_grad = True 1070 x.requires_grad_(True) 1071 return x 1072 1073 out = pytree.tree_map_(set_requires_grad, expect) 1074 if not any_requires_grad: 1075 # Skip ops without any floating point outputs, e.g. isnan 1076 return 1077 1078 msg = ( 1079 "functions with out=... arguments don't support automatic " 1080 "differentiation, but one of the arguments requires grad." 
1081 ) 1082 with self.assertRaises(RuntimeError, msg=msg): 1083 op(sample.input, *sample.args, **sample.kwargs, out=out) 1084 1085 @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,)) 1086 def test_out_integral_dtype(self, device, dtype, op): 1087 def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs): 1088 out = None 1089 try: 1090 if with_out: 1091 out = torch.empty(0, dtype=torch.int32, device=device) 1092 op_to_test(inputs, *args, out=out, **kwargs) 1093 else: 1094 out = op_to_test(inputs, *args, **kwargs) 1095 self.assertFalse(expectFail) 1096 except RuntimeError as err: 1097 self.assertEqual( 1098 str(err), "dtype argument and out dtype must match in reduction" 1099 ) 1100 self.assertTrue(expectFail) 1101 return out 1102 1103 samples = op.sample_inputs(device, dtype) 1104 for sample in samples: 1105 if "dtype" not in sample.kwargs: 1106 helper(False, False, op, sample.input, *sample.args, **sample.kwargs) 1107 helper(True, False, op, sample.input, *sample.args, **sample.kwargs) 1108 sample.kwargs["dtype"] = torch.int16 1109 helper(False, False, op, sample.input, *sample.args, **sample.kwargs) 1110 helper(True, True, op, sample.input, *sample.args, **sample.kwargs) 1111 sample.kwargs["dtype"] = torch.int32 1112 helper(False, False, op, sample.input, *sample.args, **sample.kwargs) 1113 helper(True, False, op, sample.input, *sample.args, **sample.kwargs) 1114 else: 1115 helper(False, False, op, sample.input, *sample.args, **sample.kwargs) 1116 helper( 1117 True, 1118 sample.kwargs["dtype"] != torch.int32, 1119 op, 1120 sample.input, 1121 *sample.args, 1122 **sample.kwargs, 1123 ) 1124 1125 # Tests that the forward and backward passes of operations produce the 1126 # same values for the cross-product of op variants (method, inplace) 1127 # against eager's gold standard op function variant 1128 @_variant_ops(op_db) 1129 def test_variant_consistency_eager(self, device, dtype, op): 1130 # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases) 1131 1132 method = op.method_variant 1133 inplace = op.inplace_variant 1134 operator = op.operator_variant 1135 inplace_operator = op.inplace_operator_variant 1136 1137 # list of all inplace ops: inplace variant + alias inplace variants if exist 1138 inplace_ops = [inplace, inplace_operator] 1139 variants = [method, inplace, operator, inplace_operator] 1140 operators = [operator, inplace_operator] 1141 1142 for a_op in op.aliases: 1143 variants.append(a_op.op) 1144 variants.append(a_op.method_variant) 1145 variants.append(a_op.inplace_variant) 1146 inplace_ops.append(a_op.inplace_variant) 1147 1148 inplace_variants = tuple(filter(None, inplace_ops)) 1149 variants = tuple(filter(None, variants)) 1150 operators = tuple(filter(None, operators)) 1151 1152 _requires_grad = dtype in op.supported_backward_dtypes( 1153 torch.device(device).type 1154 ) 1155 1156 include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex 1157 samples = op.sample_inputs( 1158 device, 1159 dtype, 1160 requires_grad=_requires_grad, 1161 include_conjugated_inputs=include_conjugated_inputs, 1162 ) 1163 samples = list(samples) 1164 1165 def _test_consistency_helper(samples, variants): 1166 for sample in samples: 1167 # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList 1168 tensor = ( 1169 sample.input 1170 if isinstance(sample.input, torch.Tensor) 1171 else sample.input[0] 1172 ) 1173 1174 # Computes function forward and backward values 1175 tensor.grad = 
None 1176 expected_forward = op(sample.input, *sample.args, **sample.kwargs) 1177 expected_grad = None 1178 1179 output_process_fn_grad = ( 1180 sample.output_process_fn_grad 1181 if sample.output_process_fn_grad 1182 else lambda x: x 1183 ) 1184 1185 # Skips inplace variants if the output dtype is not the same as 1186 # the input dtype 1187 skip_inplace = False 1188 if ( 1189 isinstance(expected_forward, torch.Tensor) 1190 and expected_forward.dtype is not tensor.dtype 1191 ): 1192 skip_inplace = True 1193 1194 # TODO: backward consistency only supported for single tensor outputs 1195 # TODO: backward consistency only checked on sample.input, not all 1196 # tensor inputs 1197 # TODO: update to handle checking grads of all tensor inputs as 1198 # derived from each tensor output 1199 if isinstance( 1200 expected_forward, torch.Tensor 1201 ) and dtype in op.supported_backward_dtypes(torch.device(device).type): 1202 out = output_process_fn_grad(expected_forward).sum() 1203 if out.dtype.is_complex: 1204 out = out.abs() 1205 out.backward() 1206 expected_grad = tensor.grad 1207 1208 # Test eager consistency 1209 for variant in variants: 1210 # Skips inplace ops 1211 if variant in inplace_ops and skip_inplace: 1212 continue 1213 1214 # Compares variant's forward 1215 # Note: copies the to-be-modified input when testing the inplace variant 1216 tensor.grad = None 1217 cloned = ( 1218 clone_input_helper(sample.input) 1219 if variant in inplace_ops 1220 else sample.input 1221 ) 1222 1223 if variant in inplace_ops and sample.broadcasts_input: 1224 with self.assertRaises( 1225 RuntimeError, 1226 msg=( 1227 "inplace variant either incorrectly allowed " 1228 f"resizing or you have marked the sample {sample.summary()}" 1229 " incorrectly with `broadcasts_self=True" 1230 ), 1231 ): 1232 variant_forward = variant( 1233 cloned, *sample.args, **sample.kwargs 1234 ) 1235 continue 1236 1237 if variant in operators and sample.kwargs: 1238 # skip samples with kwargs for operator variants 1239 continue 1240 1241 variant_forward = variant(cloned, *sample.args, **sample.kwargs) 1242 self.assertEqual(expected_forward, variant_forward) 1243 1244 # Compares variant's backward 1245 if expected_grad is not None and ( 1246 variant not in inplace_ops or op.supports_inplace_autograd 1247 ): 1248 out = output_process_fn_grad(variant_forward).sum() 1249 if out.dtype.is_complex: 1250 out = out.abs() 1251 out.backward() 1252 self.assertEqual(expected_grad, tensor.grad) 1253 1254 _test_consistency_helper(samples, variants) 1255 1256 def _test_inplace_preserve_storage(samples, variants): 1257 for sample in samples: 1258 # Skips inplace variants if the output dtype is not the same as 1259 # the input dtype 1260 expected_forward = op(sample.input, *sample.args, **sample.kwargs) 1261 tensor = ( 1262 sample.input 1263 if isinstance(sample.input, torch.Tensor) 1264 else sample.input[0] 1265 ) 1266 skip_inplace = False 1267 if ( 1268 isinstance(expected_forward, torch.Tensor) 1269 and expected_forward.dtype is not tensor.dtype 1270 ): 1271 skip_inplace = True 1272 if skip_inplace: 1273 return 1274 for variant in variants: 1275 cloned = ( 1276 clone_input_helper(sample.input) 1277 if variant in inplace_ops 1278 else sample.input 1279 ) 1280 inp_tensor = ( 1281 cloned if isinstance(cloned, torch.Tensor) else cloned[0] 1282 ) 1283 data_ptr = inp_tensor.data_ptr() 1284 if variant in operators and sample.kwargs: 1285 # skip samples with kwargs for operator variants 1286 continue 1287 1288 variant_forward = variant(cloned, *sample.args, 
**sample.kwargs) 1289 # TODO Support non-tensor outputs if they exist for inplace ops 1290 if isinstance(variant_forward, torch.Tensor): 1291 self.assertEqual( 1292 data_ptr, variant_forward.data_ptr(), atol=0, rtol=0 1293 ) 1294 else: 1295 self.assertTrue( 1296 False, 1297 "Non-tensor outputs for inplace ops are not supported", 1298 ) 1299 1300 if len(inplace_ops) > 0: 1301 inplace_samples = list( 1302 filter(lambda sample: not sample.broadcasts_input, samples) 1303 ) 1304 _test_inplace_preserve_storage(inplace_samples, inplace_variants) 1305 1306 # Reference testing for operations in complex32 against complex64. 1307 # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype. 1308 @ops(op_db, allowed_dtypes=(torch.complex32,)) 1309 def test_complex_half_reference_testing(self, device, dtype, op): 1310 if not op.supports_dtype(torch.complex32, device): 1311 unittest.skip("Does not support complex32") 1312 1313 for sample in op.sample_inputs(device, dtype): 1314 actual = op(sample.input, *sample.args, **sample.kwargs) 1315 # sample.transform applies the lambda to torch.Tensor and torch.dtype. 1316 # However, we only want to apply it to Tensors with dtype `torch.complex32`.. 1317 transformed_sample = sample.transform( 1318 lambda x: x.to(torch.complex64) 1319 if isinstance(x, torch.Tensor) and x.dtype is torch.complex32 1320 else x 1321 ) 1322 expected = op( 1323 transformed_sample.input, 1324 *transformed_sample.args, 1325 **transformed_sample.kwargs, 1326 ) 1327 # Since range of chalf is much less compared to cfloat, 1328 # we get `inf`s easily (eg. with `pow`, `exp`), 1329 # so we cast `cfloat` back to `chalf`. 1330 expected = tree_map( 1331 lambda x: x.to(torch.complex32) 1332 if isinstance(x, torch.Tensor) and x.dtype is torch.complex64 1333 else x, 1334 expected, 1335 ) 1336 1337 # `exact_dtype` is False because for ops like real, imag 1338 # we get different dtypes for `actual` and `expected` 1339 # `chalf` input -> `half` output 1340 # `cfloat` input -> `float` output 1341 self.assertEqual(actual, expected, exact_dtype=False) 1342 1343 @ops(op_db, allowed_dtypes=(torch.bool,)) 1344 @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior") 1345 def test_non_standard_bool_values(self, device, dtype, op): 1346 # Test boolean values other than 0x00 and 0x01 (gh-54789) 1347 def convert_boolean_tensors(x): 1348 if not isinstance(x, torch.Tensor) or x.dtype != torch.bool: 1349 return x 1350 1351 # Map False -> 0 and True -> Random value in [2, 255] 1352 true_vals = torch.randint( 1353 2, 255, x.shape, dtype=torch.uint8, device=x.device 1354 ) 1355 false_vals = torch.zeros((), dtype=torch.uint8, device=x.device) 1356 x_int = torch.where(x, true_vals, false_vals) 1357 1358 ret = x_int.view(torch.bool) 1359 self.assertEqual(ret, x) 1360 return ret 1361 1362 for sample in op.sample_inputs(device, dtype): 1363 expect = op(sample.input, *sample.args, **sample.kwargs) 1364 1365 transformed = sample.transform(convert_boolean_tensors) 1366 actual = op(transformed.input, *transformed.args, **transformed.kwargs) 1367 1368 self.assertEqual(expect, actual) 1369 1370 # Validates that each OpInfo specifies its forward and backward dtypes 1371 # correctly for CPU and CUDA devices 1372 @skipMeta 1373 @onlyNativeDeviceTypesAnd(["hpu"]) 1374 @ops(ops_and_refs, dtypes=OpDTypes.none) 1375 def test_dtypes(self, device, op): 1376 # Check complex32 support only if the op claims. 1377 # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally. 
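        # Illustrative summary of the strategy below: every dtype is tried on real
        # samples; a dtype the OpInfo claims but that fails on all samples ends up
        # in claimed_but_unsupported_forward, which fails this test with the
        # exception recorded in dtype_error.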
1378 device_type = torch.device(device).type 1379 include_complex32 = ( 1380 (torch.complex32,) 1381 if op.supports_dtype(torch.complex32, device_type) 1382 else () 1383 ) 1384 1385 # dtypes to try to backward in 1386 allowed_backward_dtypes = floating_and_complex_types_and( 1387 *((torch.half, torch.bfloat16) + include_complex32) 1388 ) 1389 1390 # lists for (un)supported dtypes 1391 supported_dtypes = set() 1392 unsupported_dtypes = set() 1393 supported_backward_dtypes = set() 1394 unsupported_backward_dtypes = set() 1395 dtype_error: Dict[torch.dtype, Exception] = {} 1396 1397 def unsupported(dtype, e): 1398 dtype_error[dtype] = e 1399 unsupported_dtypes.add(dtype) 1400 if dtype in allowed_backward_dtypes: 1401 unsupported_backward_dtypes.add(dtype) 1402 1403 for dtype in all_types_and_complex_and( 1404 *((torch.half, torch.bfloat16, torch.bool) + include_complex32) 1405 ): 1406 # tries to acquire samples - failure indicates lack of support 1407 requires_grad = dtype in allowed_backward_dtypes 1408 try: 1409 samples = tuple( 1410 op.sample_inputs(device, dtype, requires_grad=requires_grad) 1411 ) 1412 except Exception as e: 1413 unsupported(dtype, e) 1414 continue 1415 1416 for sample in samples: 1417 # tries to call operator with the sample - failure indicates 1418 # lack of support 1419 try: 1420 result = op(sample.input, *sample.args, **sample.kwargs) 1421 supported_dtypes.add(dtype) 1422 except Exception as e: 1423 # NOTE: some ops will fail in forward if their inputs 1424 # require grad but they don't support computing the gradient 1425 # in that type! This is a bug in the op! 1426 unsupported(dtype, e) 1427 continue 1428 1429 # Checks for backward support in the same dtype, if the input has 1430 # one or more tensors requiring grad 1431 def _tensor_requires_grad(x): 1432 if isinstance(x, dict): 1433 for v in x.values(): 1434 if _tensor_requires_grad(v): 1435 return True 1436 if isinstance(x, (list, tuple)): 1437 for a in x: 1438 if _tensor_requires_grad(a): 1439 return True 1440 if isinstance(x, torch.Tensor) and x.requires_grad: 1441 return True 1442 1443 return False 1444 1445 requires_grad = ( 1446 _tensor_requires_grad(sample.input) 1447 or _tensor_requires_grad(sample.args) 1448 or _tensor_requires_grad(sample.kwargs) 1449 ) 1450 if not requires_grad: 1451 continue 1452 1453 try: 1454 result = sample.output_process_fn_grad(result) 1455 if isinstance(result, torch.Tensor): 1456 backward_tensor = result 1457 elif isinstance(result, Sequence) and isinstance( 1458 result[0], torch.Tensor 1459 ): 1460 backward_tensor = result[0] 1461 else: 1462 continue 1463 1464 # Note: this grad may not have the same dtype as dtype 1465 # For functions like complex (float -> complex) or abs 1466 # (complex -> float) the grad tensor will have a 1467 # different dtype than the input. 1468 # For simplicity, this is still modeled as these ops 1469 # supporting grad in the input dtype. 
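                    # For example (illustrative): torch.abs on a complex64 input
                    # returns a float32 tensor, so backward_tensor and the grad
                    # created here are float32 even though dtype is complex64.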
1470 grad = torch.randn_like(backward_tensor) 1471 backward_tensor.backward(grad) 1472 supported_backward_dtypes.add(dtype) 1473 except Exception as e: 1474 dtype_error[dtype] = e 1475 unsupported_backward_dtypes.add(dtype) 1476 1477 # Checks that dtypes are listed correctly and generates an informative 1478 # error message 1479 1480 supported_forward = supported_dtypes - unsupported_dtypes 1481 partially_supported_forward = supported_dtypes & unsupported_dtypes 1482 unsupported_forward = unsupported_dtypes - supported_dtypes 1483 supported_backward = supported_backward_dtypes - unsupported_backward_dtypes 1484 partially_supported_backward = ( 1485 supported_backward_dtypes & unsupported_backward_dtypes 1486 ) 1487 unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes 1488 1489 device_type = torch.device(device).type 1490 1491 claimed_forward = set(op.supported_dtypes(device_type)) 1492 supported_but_unclaimed_forward = supported_forward - claimed_forward 1493 claimed_but_unsupported_forward = claimed_forward & unsupported_forward 1494 1495 claimed_backward = set(op.supported_backward_dtypes(device_type)) 1496 supported_but_unclaimed_backward = supported_backward - claimed_backward 1497 claimed_but_unsupported_backward = claimed_backward & unsupported_backward 1498 1499 # Partially supporting a dtype is not an error, but we print a warning 1500 if (len(partially_supported_forward) + len(partially_supported_backward)) > 0: 1501 msg = f"Some dtypes for {op.name} on device type {device_type} are only partially supported!\n" 1502 if len(partially_supported_forward) > 0: 1503 msg = ( 1504 msg 1505 + f"The following dtypes only worked on some samples during forward: {partially_supported_forward}.\n" 1506 ) 1507 if len(partially_supported_backward) > 0: 1508 msg = ( 1509 msg 1510 + f"The following dtypes only worked on some samples during backward: {partially_supported_backward}.\n" 1511 ) 1512 print(msg) 1513 1514 if ( 1515 len(supported_but_unclaimed_forward) 1516 + len(claimed_but_unsupported_forward) 1517 + len(supported_but_unclaimed_backward) 1518 + len(claimed_but_unsupported_backward) 1519 ) == 0: 1520 return 1521 1522 # Reference operators often support additional dtypes, and that's OK 1523 if op in python_ref_db: 1524 if ( 1525 len(claimed_but_unsupported_forward) 1526 + len(claimed_but_unsupported_backward) 1527 ) == 0: 1528 return 1529 1530 # Generates error msg 1531 msg = f"The supported dtypes for {op.name} on device type {device_type} are incorrect!\n" 1532 if len(supported_but_unclaimed_forward) > 0: 1533 msg = ( 1534 msg 1535 + "The following dtypes worked in forward but are not listed by the OpInfo: " 1536 + f"{supported_but_unclaimed_forward}.\n" 1537 ) 1538 if len(supported_but_unclaimed_backward) > 0: 1539 msg = ( 1540 msg 1541 + "The following dtypes worked in backward but are not listed by the OpInfo: " 1542 + f"{supported_but_unclaimed_backward}.\n" 1543 ) 1544 if len(claimed_but_unsupported_forward) > 0: 1545 msg = ( 1546 msg 1547 + "The following dtypes did not work in forward but are listed by the OpInfo: " 1548 + f"{claimed_but_unsupported_forward}.\n" 1549 ) 1550 if len(claimed_but_unsupported_backward) > 0: 1551 msg = ( 1552 msg 1553 + "The following dtypes did not work in backward " 1554 + f"but are listed by the OpInfo: {claimed_but_unsupported_backward}.\n" 1555 ) 1556 1557 all_claimed_but_unsupported = set.union( 1558 claimed_but_unsupported_backward, claimed_but_unsupported_forward 1559 ) 1560 if all_claimed_but_unsupported: 1561 msg += 
"Unexpected failures raised the following errors:\n" 1562 for dtype in all_claimed_but_unsupported: 1563 msg += f"{dtype} - {dtype_error[dtype]}\n" 1564 1565 self.fail(msg) 1566 1567 # Validates that each OpInfo that sets promotes_int_to_float=True does as it says 1568 @skipMeta 1569 @onlyNativeDeviceTypesAnd(["hpu"]) 1570 @ops( 1571 (op for op in op_db if op.promotes_int_to_float), 1572 allowed_dtypes=integral_types_and(torch.bool), 1573 ) 1574 def test_promotes_int_to_float(self, device, dtype, op): 1575 for sample in op.sample_inputs(device, dtype): 1576 output = op(sample.input, *sample.args, **sample.kwargs) 1577 if not output.dtype.is_floating_point: 1578 self.fail( 1579 f"The OpInfo sets `promotes_int_to_float=True`, but {dtype} was promoted to {output.dtype}." 1580 ) 1581 1582 1583@unMarkDynamoStrictTest 1584class TestCompositeCompliance(TestCase): 1585 # Checks if the operator (if it is composite) is written to support most 1586 # backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance" 1587 # in aten/src/ATen/native/README.md for more details 1588 @unittest.skipIf( 1589 IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" 1590 ) 1591 @ops(op_db, allowed_dtypes=(torch.float,)) 1592 def test_operator(self, device, dtype, op): 1593 samples = op.sample_inputs(device, dtype, requires_grad=False) 1594 1595 for sample in samples: 1596 args = [sample.input] + list(sample.args) 1597 kwargs = sample.kwargs 1598 composite_compliance.check_with_mode(op, args, kwargs, self.assertEqual) 1599 composite_compliance.check_all_permutations( 1600 op, args, kwargs, self.assertEqual 1601 ) 1602 1603 @unittest.skipIf( 1604 IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" 1605 ) 1606 @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) 1607 def test_backward(self, device, dtype, op): 1608 samples = op.sample_inputs(device, dtype, requires_grad=True) 1609 1610 for sample in samples: 1611 args = [sample.input] + list(sample.args) 1612 kwargs = sample.kwargs 1613 # We pass assertEqual so that decorators like `toleranceOverride` 1614 # actually work (otherwise they silently do nothing!) 1615 composite_compliance.check_backward_formula( 1616 op.get_op(), 1617 args, 1618 kwargs, 1619 sample.output_process_fn_grad, 1620 op.gradcheck_wrapper, 1621 self.assertEqual, 1622 ) 1623 1624 @unittest.skipIf( 1625 IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" 1626 ) 1627 @ops(op_db, allowed_dtypes=(torch.float,)) 1628 def test_forward_ad(self, device, dtype, op): 1629 if torch.float not in op.supported_backward_dtypes(device): 1630 raise unittest.SkipTest("Does not support autograd") 1631 1632 if not op.supports_forward_ad: 1633 raise unittest.SkipTest("Does not support forward_ad") 1634 1635 samples = op.sample_inputs(device, dtype, requires_grad=True) 1636 1637 for sample in samples: 1638 args = [sample.input] + list(sample.args) 1639 kwargs = sample.kwargs 1640 # We pass assertEqual so that decorators like `toleranceOverride` 1641 # actually work (otherwise they silently do nothing!) 
1642 composite_compliance.check_forward_ad_formula( 1643 op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual 1644 ) 1645 1646 @ops(op_db, allowed_dtypes=(torch.float,)) 1647 def test_cow_input(self, device, dtype, op): 1648 samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd) 1649 1650 def is_strided_tensor(arg): 1651 return torch.is_tensor(arg) and arg.layout == torch.strided 1652 1653 def check_ignore_materialize(idx_or_kw, allow_list): 1654 return (allow_list is not None) and (idx_or_kw in allow_list) 1655 1656 def check_cow_input( 1657 arg, 1658 arg_copy, 1659 idx_or_kw, 1660 backward_or_forward="forward", 1661 supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward, 1662 allow_list=op.allow_cow_input_materialize_forward, 1663 ): 1664 arg_name = ( 1665 f"Argument {idx_or_kw}" 1666 if isinstance(idx_or_kw, int) 1667 else f"Keyword argument '{idx_or_kw}'" 1668 ) + f" during {backward_or_forward} call" 1669 1670 if is_strided_tensor(arg): 1671 is_cow = torch._C._is_cow_tensor(arg) 1672 1673 if supports_cow_input_no_materialize and not check_ignore_materialize( 1674 idx_or_kw, allow_list 1675 ): 1676 self.assertTrue( 1677 is_cow, 1678 msg=( 1679 f"{arg_name} unexpectedly materializes. " 1680 f"Either set `supports_cow_input_no_materialize_{backward_or_forward}=False` " 1681 "in this operation's OpInfo, add the arg to the OpInfo's " 1682 f"`allow_cow_input_materialize_{backward_or_forward}` list, or change the " 1683 "implementation to avoid materialization." 1684 ), 1685 ) 1686 1687 if is_cow: 1688 self.assertTrue( 1689 torch.allclose(arg, arg_copy, rtol=0, atol=0, equal_nan=True), 1690 msg=( 1691 f"{arg_name} avoided materialization, " 1692 "but the operation mutated its data." 1693 ), 1694 ) 1695 1696 for sample in samples: 1697 args_raw = [sample.input] + list(sample.args) 1698 kwargs_raw = sample.kwargs 1699 args_copy = [] 1700 args = [] 1701 kwargs_copy = {} 1702 kwargs = {} 1703 1704 # Convert strided tensor inputs to COW tensors and make copies of 1705 # all inputs 1706 for idx, arg in enumerate(args_raw): 1707 if is_strided_tensor(arg): 1708 args_copy.append(arg.clone().detach()) 1709 args.append(torch._lazy_clone(arg)) 1710 else: 1711 if torch.is_tensor(arg): 1712 args_copy.append(arg.clone().detach()) 1713 else: 1714 args_copy.append(copy.deepcopy(arg)) 1715 args.append(arg) 1716 1717 for kw, arg in kwargs_raw.items(): 1718 if is_strided_tensor(arg): 1719 kwargs_copy[kw] = arg.clone().detach() 1720 kwargs[kw] = torch._lazy_clone(arg) 1721 else: 1722 if torch.is_tensor(arg): 1723 kwargs_copy[kw] = arg.clone().detach() 1724 else: 1725 kwargs_copy[kw] = copy.deepcopy(arg) 1726 kwargs[kw] = arg 1727 1728 leaf_tensors = composite_compliance.gather_leaf_tensors(args, kwargs) 1729 1730 # Call forward op 1731 results_raw = op.get_op()(*args, **kwargs) 1732 1733 # Check that COW inputs remain COW after the forward op is executed 1734 for idx, arg in enumerate(args): 1735 check_cow_input(arg, args_copy[idx], idx) 1736 1737 for kw, arg in kwargs.items(): 1738 check_cow_input(arg, kwargs_copy[kw], kw) 1739 1740 # Call backward op if it is supported. 
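            # (About the COW checks above and below: `torch._lazy_clone` produces a
            # copy-on-write clone that shares its source's storage until one of the
            # two tensors is written to. A minimal sketch of the property being
            # verified, illustrative only and not executed by this test:
            #
            #   a = torch.randn(4)
            #   b = torch._lazy_clone(a)
            #   assert torch._C._is_cow_tensor(b)   # still lazy, no copy has happened
            #   b.add_(1)                           # writing materializes the copy
            #   assert not torch._C._is_cow_tensor(b)
            #
            # An operator that only reads its inputs should leave them in the lazy,
            # unmaterialized state.)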
            # This part of the test is
            # based on `composite_compliance.check_backward_formula`
            if (
                op.supports_autograd
                and len(leaf_tensors) > 0
                and not op.skip_cow_input_backward
            ):
                if sample.output_process_fn_grad is not None:
                    results_raw = sample.output_process_fn_grad(results_raw)

                leaf_results = pytree.tree_leaves(results_raw)
                results = [
                    r
                    for r in leaf_results
                    if isinstance(r, torch.Tensor) and r.requires_grad
                ]

                all_results_strided = all(
                    is_strided_tensor(result) for result in results
                )

                # Only test backward if the results are strided tensors
                if all_results_strided:
                    output_grads_raw = [
                        torch.ones(r.shape, device=r.device, dtype=r.dtype)
                        for r in results
                    ]
                    output_grads_copy = []
                    output_grads = []

                    # Convert output grads to COW tensors and make copies
                    for output_grad in output_grads_raw:
                        output_grads_copy.append(output_grad.clone().detach())
                        output_grads.append(torch._lazy_clone(output_grad))

                    input_grads = torch.autograd.grad(
                        results,
                        leaf_tensors,
                        output_grads,
                        allow_unused=True,
                        retain_graph=True,
                    )

                    # Check that COW inputs remain COW after the backward op is executed
                    for idx, arg in enumerate(args):
                        check_cow_input(
                            arg,
                            args_copy[idx],
                            idx,
                            backward_or_forward="backward",
                            supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
                            allow_list=op.allow_cow_input_materialize_backward,
                        )

                    # Check that COW output grads remain COW after the backward op is executed
                    for idx, output_grad in enumerate(output_grads):
                        check_cow_input(
                            output_grad,
                            output_grads_copy[idx],
                            f"output grad {idx}",
                            backward_or_forward="backward",
                            supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
                            allow_list=op.allow_cow_input_materialize_backward,
                        )

    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_view_replay(self, device, dtype, op):
        def _assert_match_metadata(a, b):
            self.assertEqual(a.size(), b.size())
            self.assertEqual(a.stride(), b.stride())
            self.assertEqual(a.storage_offset(), b.storage_offset())
            self.assertEqual(a.device, b.device)
            self.assertEqual(a.dtype, b.dtype)

        # ensure view replay is enabled
        with torch.autograd._force_original_view_tracking(True):
            for sample in op.sample_inputs(device, dtype, requires_grad=False):
                inp = sample.input
                outs = op(inp, *sample.args, **sample.kwargs)
                if not isinstance(outs, (tuple, List)):
                    outs = [outs]

                # for all outputs that are views of the input, we should be able to replay the
                # forward and reverse views via a functioning view_func() / rev_view_func().
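                # A concrete sketch of what "replaying" a view means here,
                # illustrative only and not executed by this test:
                #
                #   base = torch.randn(2, 3)
                #   view = base.transpose(0, 1)              # a view of `base`
                #   replayed = view._view_func_unsafe(base.clone())
                #   # `replayed` applies the same transpose to the new base
                #   rebuilt = view._rev_view_func_unsafe(view.detach())
                #   # `rebuilt` is a view reconstructing the original base's metadata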
                for out in outs:
                    if not (
                        isinstance(out, torch.Tensor)
                        and out._is_view()
                        and out._base is inp
                    ):
                        continue

                    # forward view_func
                    new_inp = inp.clone()
                    _assert_match_metadata(new_inp, inp)
                    new_out = out._view_func_unsafe(new_inp)
                    _assert_match_metadata(new_out, out)
                    self.assertEqual(new_out, out)

                    # reverse view_func
                    new_out = out.detach()
                    new_inp = out._rev_view_func_unsafe(new_out)
                    _assert_match_metadata(new_inp, inp)
                    self.assertTrue(new_inp._is_view())
                    self.assertTrue(new_inp._base is new_out)


@unMarkDynamoStrictTest
class TestMathBits(TestCase):
    # Tests that
    # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors
    # produces the same value
    # 2. The gradients are the same in both cases mentioned in (1)
    # 3. If the operator's inplace variant is supported, tests that the inplace operation
    # produces the correct value when called on a conjugate/negative view tensor and that the output
    # has its conj/neg bit set to true
    # This test only runs for C -> R and C -> C functions
    # TODO: add tests for `R->C` functions
    # Note: This test runs for functions that take both tensors and tensorlists as input.
    def _test_math_view(
        self,
        device,
        dtype,
        op,
        samples,
        math_op_physical,
        math_op_view,
        is_bit_set,
        out_type,
    ):
        inplace_variant = op.inplace_variant

        # helper function to clone and conjugate/negate the input if it's a tensor,
        # else clone the sequence and conjugate/negate the first element in the sequence.
        # If a requires_grad argument is provided, the tensor being conjugated/negated will
        # have its requires_grad set to that value.
        def clone_and_perform_view(input, **kwargs):
            if isinstance(input, torch.Tensor):
                requires_grad = kwargs.get("requires_grad", input.requires_grad)
                with torch.no_grad():
                    # Ensure view represents the original sample input
                    input = math_op_physical(input)
                # Note: .conj() is not called under no_grad mode since it's not allowed to modify a
                # view created in no_grad mode. Here it's ok to do so, so as a workaround we call conj
                # before resetting the requires_grad field for input
                input = math_op_view(input)
                assert input.is_leaf
                return input.requires_grad_(requires_grad)

            if isinstance(input, Sequence):
                out = list(map(clone_input_helper, input))
                out[0] = clone_and_perform_view(out[0])
                return tuple(out)

        for sample in samples:
            tensor = (
                sample.input
                if isinstance(sample.input, torch.Tensor)
                else sample.input[0]
            )
            cloned1 = clone_and_perform_view(sample.input)

            # Computes the function's forward value with a physically conjugated/negated tensor and
            # a conj/neg view tensor and verifies that the output in both cases is equal.
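            # (Background, illustrative only: `Tensor.conj` returns a lazy view with
            # the conjugate bit set, while `Tensor.conj_physical` materializes the
            # conjugated values. Operators are expected to treat both the same way:
            #
            #   x = torch.randn(3, dtype=torch.cfloat)
            #   assert torch.is_conj(x.conj())                # lazy view, bit is set
            #   assert not torch.is_conj(x.conj_physical())   # materialized copy
            #   assert torch.equal(x.conj().resolve_conj(), x.conj_physical())
            #
            # The negative bit (`torch._neg_view` / `torch.is_neg`) behaves analogously.)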
1904 expected_forward = op(sample.input, *sample.args, **sample.kwargs) 1905 forward_with_mathview = op(cloned1, *sample.args, **sample.kwargs) 1906 self.assertEqual(expected_forward, forward_with_mathview) 1907 1908 # If the op has an inplace variant, and the input doesn't require broadcasting 1909 # and has the same dtype as output, verify that the inplace operation on a conjugated/negated 1910 # input produces correct output, and the output tensor has the conj/neg bit set to True 1911 if inplace_variant is not None and not sample.broadcasts_input: 1912 cloned2 = clone_and_perform_view(tensor, requires_grad=False) 1913 if ( 1914 isinstance(expected_forward, torch.Tensor) 1915 and expected_forward.dtype is tensor.dtype 1916 ): 1917 inplace_forward = inplace_variant( 1918 cloned2, *sample.args, **sample.kwargs 1919 ) 1920 self.assertTrue(is_bit_set(inplace_forward)) 1921 self.assertEqual(inplace_forward, expected_forward) 1922 1923 # TODO: backward consistency only supported for single tensor outputs 1924 # TODO: backward consistency only checked on sample.input, not all 1925 # tensor inputs 1926 # TODO: update to handle checking grads of all tensor inputs as 1927 # derived from each tensor output 1928 if ( 1929 isinstance(expected_forward, torch.Tensor) 1930 and expected_forward.requires_grad 1931 ): 1932 output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x) 1933 expected_forward = output_process_fn_grad(expected_forward) 1934 forward_with_mathview = output_process_fn_grad(forward_with_mathview) 1935 1936 tensor = ( 1937 sample.input 1938 if isinstance(sample.input, torch.Tensor) 1939 else sample.input[0] 1940 ) 1941 expected_forward.sum().abs().backward(retain_graph=True) 1942 forward_with_mathview.sum().abs().backward(retain_graph=True) 1943 if tensor.grad is not None: 1944 cloned1_tensor = ( 1945 cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0] 1946 ) 1947 self.assertEqual(tensor.grad, cloned1_tensor.grad) 1948 1949 tensor.grad, cloned1_tensor.grad = None, None 1950 1951 # a repeat of the above test if output is not complex valued 1952 if out_type(expected_forward): 1953 grad = torch.randn_like(expected_forward) 1954 expected_forward.backward(grad) 1955 forward_with_mathview.backward( 1956 math_op_view(math_op_physical(grad)) 1957 ) 1958 1959 self.assertEqual(tensor.grad, cloned1_tensor.grad) 1960 1961 @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,)) 1962 def test_conj_view(self, device, dtype, op): 1963 if not op.test_conjugated_samples: 1964 self.skipTest("Operation doesn't support conjugated inputs.") 1965 math_op_physical = torch.conj_physical 1966 math_op_view = torch.conj 1967 _requires_grad = torch.cfloat in op.supported_backward_dtypes( 1968 torch.device(device).type 1969 ) 1970 is_bit_set = torch.is_conj 1971 samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad) 1972 self._test_math_view( 1973 device, 1974 dtype, 1975 op, 1976 samples, 1977 math_op_physical, 1978 math_op_view, 1979 is_bit_set, 1980 torch.is_complex, 1981 ) 1982 1983 @ops(ops_and_refs, allowed_dtypes=(torch.double,)) 1984 def test_neg_view(self, device, dtype, op): 1985 if not op.test_neg_view: 1986 self.skipTest("Operation not tested with tensors with negative bit.") 1987 math_op_physical = torch.neg 1988 math_op_view = torch._neg_view 1989 is_bit_set = torch.is_neg 1990 samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd) 1991 self._test_math_view( 1992 device, 1993 dtype, 1994 op, 1995 samples, 1996 math_op_physical, 1997 
math_op_view, 1998 is_bit_set, 1999 lambda x: True, 2000 ) 2001 2002 @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,)) 2003 def test_neg_conj_view(self, device, dtype, op): 2004 if not op.test_neg_view: 2005 self.skipTest("Operation not tested with tensors with negative bit.") 2006 if not op.test_conjugated_samples: 2007 self.skipTest("Operation doesn't support conjugated inputs.") 2008 2009 def math_op_physical(x): 2010 return -x.conj_physical() 2011 2012 def math_op_view(x): 2013 return torch._neg_view(x).conj() 2014 2015 def is_bit_set(x): 2016 return torch.is_neg(x) and torch.is_conj(x) 2017 2018 _requires_grad = dtype in op.supported_backward_dtypes( 2019 torch.device(device).type 2020 ) 2021 samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad) 2022 # Only test one sample 2023 samples = itertools.islice(samples, 1) 2024 self._test_math_view( 2025 device, 2026 dtype, 2027 op, 2028 samples, 2029 math_op_physical, 2030 math_op_view, 2031 is_bit_set, 2032 torch.is_complex, 2033 ) 2034 2035 2036# input strides and size may have been altered due to the result of an inplace op 2037def check_inplace_view(func, input, rs, input_size, input_strides): 2038 if func is None: 2039 return 2040 # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm(_legit).out 2041 # which mutate not necessarily the first input. 2042 if isinstance(rs, torch.Tensor) and rs is input: 2043 unequal_size = rs.size() != input_size 2044 unequal_strides = rs.stride() != input_strides 2045 # resize_ should probably have inplace_view tag. Not adding the tag since it 2046 # breaks some codegen logic 2047 if unequal_size or unequal_strides: 2048 if isinstance(func, torch._ops.OpOverloadPacket): 2049 func = func.default 2050 # Reference: https://github.com/pytorch/pytorch/issues/78759 2051 if func is not torch.ops.aten.resize_.default: 2052 # TODO: use self.assertIn when we have separate tests for each tag 2053 assert torch.Tag.inplace_view in func.tags 2054 2055 2056# A mode that when enabled runs correctness checks to ensure 2057# that operators have expected tags based on their input and 2058# output tensor properties 2059class TestTagsMode(TorchDispatchMode): 2060 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 2061 if isinstance(args[0], torch.Tensor): 2062 old_size = args[0].size() 2063 old_stride = args[0].stride() 2064 rs = func(*args, **kwargs) 2065 check_inplace_view(func, args[0], rs, old_size, old_stride) 2066 else: 2067 rs = func(*args, **kwargs) 2068 return rs 2069 2070 2071# Test to verify the correctness for tags in `tags.yaml`, also available for access through `torch.Tags` 2072@unMarkDynamoStrictTest 2073class TestTags(TestCase): 2074 @onlyCPU 2075 @ops(ops_and_refs, dtypes=OpDTypes.any_one) 2076 def test_tags(self, device, dtype, op): 2077 samples = op.sample_inputs(device, dtype, requires_grad=False) 2078 for sample in samples: 2079 # TODO: Test tags for ops that return a list of tensors 2080 input = sample.input 2081 if isinstance(input, torch.Tensor): 2082 old_size = input.size() 2083 old_stride = input.stride() 2084 with TestTagsMode(): 2085 rs = op(input, *sample.args, **sample.kwargs) 2086 # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761 2087 aten_name = op.aten_name if op.aten_name is not None else op.name 2088 opoverloadpacket = getattr(torch.ops.aten, aten_name, None) 2089 check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride) 2090 2091 2092class TestSelfKwarg(TestCase): 2093 def 
test_self_kwargs(self): 2094 """Verify that we can call the aten ops with all kwargs even if the 2095 argument's name is "self" 2096 """ 2097 torch.ops.aten.reshape.default(self=torch.rand(1, 2), shape=[2]) 2098 torch.ops.aten.min.default(self=torch.rand(100)) 2099 2100 2101@unMarkDynamoStrictTest 2102class TestRefsOpsInfo(TestCase): 2103 import_paths = [ 2104 "_refs", 2105 "_refs.special", 2106 "_refs.nn.functional", 2107 "_refs.fft", 2108 "_refs._conversions", 2109 ] 2110 module_alls = [ 2111 (path, import_module(f"torch.{path}").__all__) for path in import_paths 2112 ] 2113 ref_ops_names = tuple( 2114 itertools.chain.from_iterable( 2115 [f"{path}.{op}" for op in module_all] for path, module_all in module_alls 2116 ) 2117 ) 2118 ref_db_names = {ref_op.name for ref_op in python_ref_db} 2119 2120 # TODO: References that do not have an entry in python_ref_db 2121 skip_ref_ops = { 2122 "_refs.alias", 2123 "_refs.bitwise_right_shift", 2124 "_refs.copy_to", 2125 "_refs.empty_permuted", 2126 "_refs.empty_strided", 2127 "_refs.equal", 2128 "_refs.full", 2129 "_refs.full_like", 2130 "_refs.is_complex", 2131 "_refs.to", 2132 "_refs.mvlgamma", 2133 "_refs.ones", 2134 "_refs.ones_like", 2135 "_refs.special.expit", 2136 "_refs.std_var", 2137 "_refs.swap_axes", 2138 "_refs.uniform", 2139 "_refs.scalar_tensor", 2140 "_refs.trunc_divide", 2141 "_refs.zero", 2142 "_refs.zeros", 2143 "_refs.zeros_like", 2144 "_refs.rfloordiv", 2145 "_refs.rtruediv", 2146 "_refs.rpow", 2147 # These should be tested with their out-of-place counterparts 2148 "_refs.index_add_", 2149 "_refs.index_copy_", 2150 "_refs.index_fill_", 2151 "_refs.native_group_norm", 2152 } 2153 2154 not_in_decomp_table = { 2155 # duplicated in _decomp and _refs 2156 "_refs.nn.functional.group_norm", 2157 "_refs.nn.functional.mse_loss", 2158 "_refs.floor_divide", 2159 # duplicated as refs do not have decent support for advanced indexing 2160 "_refs.index_copy", 2161 "_refs.index_copy_", 2162 "_refs.index_add", 2163 "_refs.index_add_", 2164 # these are not aten ops? 
2165 "_refs._conversions.bfloat16", 2166 "_refs._conversions.bool", 2167 "_refs._conversions.byte", 2168 "_refs._conversions.char", 2169 "_refs._conversions.double", 2170 "_refs._conversions.float", 2171 "_refs._conversions.half", 2172 "_refs._conversions.int", 2173 "_refs._conversions.long", 2174 "_refs._conversions.short", 2175 "_refs._conversions.chalf", 2176 "_refs._conversions.cfloat", 2177 "_refs._conversions.cdouble", 2178 "_refs.broadcast_shapes", 2179 "_refs.broadcast_tensors", 2180 "_refs.mvlgamma", 2181 "_refs.nn.functional.layer_norm", 2182 "_refs.nn.functional.tanhshrink", 2183 "_refs.nn.functional.triplet_margin_loss", 2184 "_refs.rfloordiv", 2185 "_refs.rtruediv", 2186 "_refs.rpow", 2187 # CompositeImplicitAutograd 2188 "_refs.allclose", 2189 "_refs.atleast_1d", 2190 "_refs.atleast_2d", 2191 "_refs.atleast_3d", 2192 "_refs.broadcast_to", 2193 "_refs.chunk", 2194 "_refs.column_stack", 2195 "_refs.contiguous", 2196 "_refs.dsplit", 2197 "_refs.dstack", 2198 "_refs.fill", 2199 "_refs.fill_", 2200 "_refs.flatten", 2201 "_refs.fliplr", 2202 "_refs.flipud", 2203 "_refs.float_power", 2204 "_refs.hsplit", 2205 "_refs.hstack", 2206 "_refs.isclose", 2207 "_refs.isfinite", 2208 "_refs.isreal", 2209 "_refs.istft", 2210 "_refs.log_softmax", 2211 "_refs.movedim", 2212 "_refs.narrow", 2213 "_refs.nn.functional.dropout", 2214 "_refs.nn.functional.l1_loss", 2215 "_refs.nn.functional.smooth_l1_loss", 2216 "_refs.nn.functional.log_softmax", 2217 "_refs.nn.functional.poisson_nll_loss", 2218 "_refs.nn.functional.softmax", 2219 "_refs.nn.functional.softmin", 2220 "_refs.positive", 2221 "_refs.ravel", 2222 "_refs.reshape", 2223 "_refs.softmax", 2224 "_refs.special.expit", 2225 "_refs.special.log_softmax", 2226 "_refs.special.softmax", 2227 "_refs.square", 2228 "_refs.stft", 2229 "_refs.T", 2230 "_refs.take_along_dim", 2231 "_refs.tensor_split", 2232 "_refs.to", 2233 "_refs.true_divide", 2234 "_refs.trunc_divide", 2235 "_refs.vsplit", 2236 "_refs.vstack", 2237 "_refs.linalg.matrix_norm", 2238 "_refs.linalg.norm", 2239 "_refs.linalg.svd", 2240 "_refs.linalg.svdvals", 2241 "_refs.unflatten", 2242 "_refs.sum_to_size", 2243 # ref implementation missing kwargs 2244 "_refs.full_like", # missing "layout" 2245 "_refs.scalar_tensor", # missing "layout" 2246 # other 2247 "_refs.block_diag", # only refs._block_diag_iterable is in decomposition table 2248 "_refs.empty", # intentional; direct empty is faster and has less guards 2249 "_refs.empty_permuted", # intentional; direct empty is faster and has less guards 2250 "_refs.expand_as", 2251 "_refs.as_strided", # _prims._as_strided_meta: "reduce() of empty sequence with no initial value" 2252 "_refs.copy_to", # torch._C._jit_get_operation: No such operator aten::copy_to 2253 "_refs.equal", # 'bool' object has no attribute 'dtype' 2254 "_refs.conj", # Calls _prims.conj 2255 "_refs.real", 2256 "_refs.imag", 2257 "_refs.reshape_as", 2258 "_refs.view_as", 2259 "_refs.view_as_complex", # TorchInductor does not support complex at the moment. 
        # the decompositions for these ops are slightly different
        # because of out handling
        "_refs.var_mean",
        "_refs.std_mean",
        "_refs.native_layer_norm",
    }

    @parametrize("op", ref_ops_names)
    def test_refs_are_in_python_ref_db(self, op):
        inplace = op[-1] == "_"
        if op in self.skip_ref_ops:
            raise unittest.SkipTest(f"{op} does not have an entry in python_ref_db")
        elif inplace:
            self.assertNotIn(
                op,
                self.ref_db_names,
                msg=f"{op} is an in-place operation and should not have an OpInfo",
            )
        else:
            # Intentionally don't use assertIn to avoid printing the
            # (very large) container
            self.assertTrue(op in self.ref_db_names, msg=f"{op} not in ref_db_names")

    @parametrize("op", ref_ops_names)
    def test_refs_are_in_decomp_table(self, op):
        path = op.split(".")
        module_path = ".".join(path[:-1])
        op_name = path[-1]
        op_impl = getattr(import_module(f"torch.{module_path}"), op_name)

        if op in self.not_in_decomp_table:
            self.assertNotIn(
                op_impl,
                torch._decomp.decomposition_table.values(),
                f"Unexpectedly found {op} in torch._decomp.decomposition_table.values()",
            )
        else:
            self.assertIn(
                op_impl,
                torch._decomp.decomposition_table.values(),
                f"Did not find {op} in torch._decomp.decomposition_table.values()",
            )


fake_skips = (
    "aminmax",  # failing input
    "cov",  # aweights cannot be negative
    "istft",  # window overlap add min: 0
    "linalg.eigvals",  # The tensor has a non-zero number of elements, but its data is not allocated yet
    "linalg.eigvalsh",  # aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend
    "linalg.matrix_power",  # Could not run 'aten::eye.m_out' with arguments from the 'Meta' backend
    # "linalg.pinv",  # Could not run 'aten::pinv.out' with arguments from the 'Meta' backend
    "linalg.matrix_rank.hermitian",  # Could not run 'aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend
    "linalg.pinv.hermitian",  # tensor.mH is only supported on matrices or batches of matrices.
Got 1-D tensor 2314 "linalg.solve", # Could not run 'aten::linalg_solve' with arguments from the 'Meta' backend 2315 "linalg.tensorsolve", # Could not run 'aten::linalg_solve' with arguments from the 'Meta' 2316 "lu_solve", # MALLOC ERROR: debug 2317 "multinomial", # Could not run 'aten::multinomial' with arguments from the 'Meta' backend 2318 "mvlgamma.mvlgamma_p_1", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend 2319 "mvlgamma.mvlgamma_p_3", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend 2320 "mvlgamma.mvlgamma_p_5", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend 2321 "nanmean", # logical_not() got an unexpected keyword argument 'out' 2322 "quantile", # quantile() q values must be in the range [0, 1] 2323 "nanquantile", # quantile() q values must be in the range [0, 1] 2324 "nn.functional.ctc_loss", # The tensor has a non-zero number of elements, but its data is not allocated yet 2325 "nn.functional.embedding_bag", # sometimes errors 2326 "nn.functional.nll_loss", # sometimes errors 2327 "nn.functional.max_pool1d", # The tensor has a non-zero number of elements 2328 "to_sparse", # Could not run 'aten::_to_sparse' with arguments from the 'Meta' backend 2329 "tensor_split", # The tensor has a non-zero number of elements, but its data is not allocated yet 2330 "repeat_interleave", # cannot repeat_interleave a meta tensor without output_size 2331 "sparse.sampled.addmm", # sparsity not supported 2332 # Can not infer total number of classes from meta. no way at present to throw DynamicOutputShapeException 2333 "nn.functional.one_hot", 2334 "narrow", # Fails only for one overload with DataDependentOutputException (hence skip). 2335) 2336 2337fake_autocast_device_skips = defaultdict(dict) 2338 2339# TODO: investigate/fix 2340fake_autocast_device_skips["cpu"] = {"linalg.pinv"} 2341fake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"} 2342 2343 2344dynamic_output_op_tests = ( 2345 "argwhere", 2346 "bincount", 2347 "combinations", 2348 "linalg.lstsq", 2349 "masked_select", 2350 "nonzero", 2351 "unique_consecutive", 2352 "unique", 2353 "linalg.lstsq.grad_oriented", 2354) 2355 2356# Ops that have dynamic output shapes that we can handle when 2357# allow_dynamic_shape_ops is True in fake tensor shape environment. 
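# For context, a rough sketch of how the plain fake-tensor mode differs from one
# whose ShapeEnv sets allow_dynamic_output_shape_ops=True (illustrative only, not
# executed at import time):
#
#   from torch._subclasses.fake_tensor import FakeTensorMode
#   from torch.fx.experimental.symbolic_shapes import ShapeEnv
#
#   t = torch.tensor([1.0, 0.0, 2.0])
#   with FakeTensorMode() as mode:
#       torch.nonzero(mode.from_tensor(t))   # raises DynamicOutputShapeException
#
#   shape_env = ShapeEnv(allow_dynamic_output_shape_ops=True)
#   with FakeTensorMode(shape_env=shape_env) as mode:
#       out = torch.nonzero(mode.from_tensor(t))   # succeeds; size is an unbacked SymInt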
2358supported_dynamic_output_op_tests = ( 2359 "nonzero", 2360 "unique", 2361 "repeat_interleave", 2362 "masked_select", 2363) 2364 2365# some inputs invoke dynamic output shape operators, some do not 2366sometimes_dynamic_output_op_test = ("__getitem__", "index_select") 2367 2368data_dependent_op_tests = ( 2369 "equal", 2370 "corrcoef", 2371 "nn.functional.gaussian_nll_loss", 2372 "allclose", 2373) 2374 2375aliasing_failures = ("histogramdd",) 2376 2377fake_backward_skips = { 2378 "linalg.cond", 2379 "linalg.matrix_norm", 2380 "linalg.norm", 2381 "linalg.svd", 2382 "linalg.svdvals", 2383 "pca_lowrank", 2384 "roll", 2385 "svd_lowrank", 2386 "sgn", 2387} 2388 2389fake_backward_xfails = {skip(s) for s in fake_backward_skips} | { 2390 xfail("fft.ihfftn"), # Mismatch in aten._conj_physical.default 2391 xfail("fft.ihfft2"), # Mismatch in aten._conj_physical.default 2392 skip("nn.functional.ctc_loss"), 2393} 2394 2395fake_autocast_backward_xfails = { 2396 skip("nn.functional.binary_cross_entropy"), 2397 skip("sparse.sampled_addmm"), 2398 skip("linalg.pinv"), 2399 skip("linalg.pinv", "hermitian"), 2400 skip("linalg.pinv", "singular"), 2401 skip("pinverse"), 2402} 2403 2404 2405@unMarkDynamoStrictTest 2406class TestFakeTensor(TestCase): 2407 def setUp(self): 2408 # Turn on FakeTensor caching and cross-checking for these tests: 2409 cache_enabled = unittest.mock.patch( 2410 "torch._dynamo.config.fake_tensor_cache_enabled", True 2411 ) 2412 cache_enabled.start() 2413 self.addCleanup(cache_enabled.stop) 2414 2415 cache_crosscheck = unittest.mock.patch( 2416 "torch._dynamo.config.fake_tensor_cache_crosscheck_enabled", True 2417 ) 2418 cache_crosscheck.start() 2419 self.addCleanup(cache_crosscheck.stop) 2420 2421 def _test_fake_helper(self, device, dtype, op, context): 2422 name = op.name 2423 if op.variant_test_name: 2424 name += "." 
+ op.variant_test_name 2425 if name in fake_skips or "sparse" in name or "jiterator" in name: 2426 self.skipTest("Skip failing test") 2427 2428 samples = op.sample_inputs(device, dtype, requires_grad=False) 2429 for sample in samples: 2430 mode = FakeTensorMode() 2431 2432 from torch.fx.experimental.symbolic_shapes import ShapeEnv 2433 2434 allow_dynamic_output_shape_shape_env = ShapeEnv( 2435 allow_dynamic_output_shape_ops=True 2436 ) 2437 2438 allow_dynamic_output_shape_mode = FakeTensorMode( 2439 shape_env=allow_dynamic_output_shape_shape_env 2440 ) 2441 2442 try: 2443 with context(): 2444 res = op(sample.input, *sample.args, **sample.kwargs) 2445 except Exception: 2446 continue 2447 2448 def run_with_fake_mode_and_verify(fake_mode, match_results=True): 2449 def map_to_fake(e): 2450 if isinstance(e, torch.Tensor): 2451 return fake_mode.from_tensor(e) 2452 else: 2453 return e 2454 2455 input = tree_map(map_to_fake, sample.input) 2456 args = tree_map(map_to_fake, sample.args) 2457 kwargs = tree_map(map_to_fake, sample.kwargs) 2458 2459 try: 2460 with context(): 2461 with fake_mode: 2462 res_fake = op(input, *args, **kwargs) 2463 2464 if not match_results: 2465 return 2466 2467 for fake_out, real_out in zip( 2468 pytree.tree_leaves(res_fake), pytree.tree_leaves(res) 2469 ): 2470 if not isinstance(fake_out, torch.Tensor): 2471 self.assertTrue(not isinstance(real_out, torch.Tensor)) 2472 self.assertEqual(fake_out, real_out) 2473 continue 2474 2475 self.assertTrue(isinstance(fake_out, FakeTensor)) 2476 # if you see a shape exception here, you may need to add 2477 # a `dynamic_output_shape` tag to an operator 2478 2479 if op.op not in [ 2480 torch.ops.aten._efficient_attention_forward, 2481 torch.ops.aten._flash_attention_forward, 2482 ]: 2483 # prims/decomps must correctly model strides, 2484 # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325 2485 2486 # note: the excluded ops have intentionally incorrect device; 2487 # see "Note [Seed and Offset]" (_meta_registrations.py) 2488 prims.utils.compare_tensor_meta(fake_out, real_out, True) 2489 2490 if name not in aliasing_failures: 2491 fake_aliasing = outputs_alias_inputs( 2492 (input, args, kwargs), res_fake 2493 ) 2494 real_aliasing = outputs_alias_inputs( 2495 (sample.input, sample, args, sample.kwargs), res 2496 ) 2497 self.assertEqual(fake_aliasing, real_aliasing) 2498 2499 self.assertTrue( 2500 name not in dynamic_output_op_tests 2501 and name not in data_dependent_op_tests 2502 ) 2503 2504 except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: 2505 pass 2506 except torch._subclasses.fake_tensor.UnsupportedOperatorException: 2507 pass 2508 except torch._subclasses.fake_tensor.DynamicOutputShapeException: 2509 self.assertTrue( 2510 name in dynamic_output_op_tests 2511 or name in sometimes_dynamic_output_op_test 2512 ) 2513 self.assertTrue( 2514 fake_mode.shape_env is None 2515 or not fake_mode.shape_env.allow_dynamic_output_shape_ops 2516 or name not in supported_dynamic_output_op_tests 2517 ) 2518 except torch._subclasses.fake_tensor.DataDependentOutputException: 2519 self.assertTrue(name in data_dependent_op_tests) 2520 2521 run_with_fake_mode_and_verify(mode) 2522 if name in supported_dynamic_output_op_tests: 2523 run_with_fake_mode_and_verify( 2524 allow_dynamic_output_shape_mode, match_results=False 2525 ) 2526 2527 @ops(op_db, dtypes=OpDTypes.any_one) 2528 def test_pointwise_ops(self, device, dtype, op): 2529 name = op.name 2530 if op.variant_test_name: 2531 name += "." 
+ op.variant_test_name 2532 if name in fake_skips or "sparse" in name or "jiterator" in name: 2533 self.skipTest("Skip failing test") 2534 2535 test_self = self 2536 2537 class TestPointwiseMode(TorchDispatchMode): 2538 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 2539 kwargs = kwargs or {} 2540 2541 out = func(*args, **kwargs) 2542 2543 if torch.Tag.pointwise in func.tags: 2544 shapes = [] 2545 for inp in pytree.arg_tree_leaves(*args, **kwargs): 2546 if isinstance(inp, torch.Tensor): 2547 shapes.append(inp.shape) 2548 2549 out_shape = torch._refs._broadcast_shapes(*shapes) 2550 2551 for out_elem in pytree.tree_leaves(out): 2552 if isinstance(out_elem, torch.Tensor): 2553 test_self.assertEqual(out_elem.shape, out_shape) 2554 2555 return out 2556 2557 samples = op.sample_inputs(device, dtype, requires_grad=False) 2558 for sample in samples: 2559 mode = FakeTensorMode() 2560 2561 def map_to_fake(e): 2562 if isinstance(e, torch.Tensor): 2563 return mode.from_tensor(e) 2564 else: 2565 return e 2566 2567 input = tree_map(map_to_fake, sample.input) 2568 args = tree_map(map_to_fake, sample.args) 2569 kwargs = tree_map(map_to_fake, sample.kwargs) 2570 2571 try: 2572 op(input, *args, **kwargs) 2573 except Exception as e: 2574 continue 2575 2576 with TestPointwiseMode(): 2577 with mode: 2578 op(input, *args, **kwargs) 2579 2580 @ops(op_db, dtypes=OpDTypes.any_one) 2581 def test_fake(self, device, dtype, op): 2582 self._test_fake_helper(device, dtype, op, contextlib.nullcontext) 2583 2584 @ops(op_db, dtypes=OpDTypes.any_one) 2585 def test_fake_autocast(self, device, dtype, op): 2586 device_type = torch.device(device).type 2587 if op.name in fake_autocast_device_skips[device_type]: 2588 self.skipTest("Skip failing test") 2589 2590 def context_fn(): 2591 return torch.amp.autocast(device_type) 2592 2593 self._test_fake_helper(device, dtype, op, context_fn) 2594 2595 def _test_fake_crossref_helper(self, device, dtype, op, context): 2596 samples = op.sample_inputs(device, dtype, requires_grad=True) 2597 2598 for iter, sample in enumerate(samples): 2599 args = [sample.input] + list(sample.args) 2600 kwargs = sample.kwargs 2601 2602 # skip these to speed up tests 2603 common_skip_ops = ( 2604 aten.detach.default, 2605 aten.empty_strided.default, 2606 aten.copy_.default, 2607 aten.is_same_size.default, 2608 ) 2609 2610 # TODO: enable check_aliasing, batch norm fails 2611 try: 2612 with torch._subclasses.CrossRefFakeMode( 2613 ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True 2614 ): 2615 with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled( 2616 False 2617 ): 2618 composite_compliance.compute_expected_grads( 2619 op.get_op(), 2620 args, 2621 kwargs, 2622 sample.output_process_fn_grad, 2623 op.gradcheck_wrapper, 2624 ) 2625 except torch._subclasses.fake_tensor.UnsupportedOperatorException: 2626 pass 2627 2628 @onlyCUDA 2629 @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) 2630 @skipOps( 2631 "TestFakeTensor", "test_fake_crossref_backward_no_amp", fake_backward_xfails 2632 ) 2633 def test_fake_crossref_backward_no_amp(self, device, dtype, op): 2634 self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext) 2635 2636 @onlyCUDA 2637 @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) 2638 @skipOps( 2639 "TestFakeTensor", 2640 "test_fake_crossref_backward_amp", 2641 fake_backward_xfails | fake_autocast_backward_xfails, 2642 ) 2643 def 
test_fake_crossref_backward_amp(self, device, dtype, op): 2644 self._test_fake_crossref_helper(device, dtype, op, torch.cuda.amp.autocast) 2645 2646 @ops([op for op in ops_and_refs if op.is_factory_function]) 2647 def test_strided_layout(self, device, dtype, op): 2648 samples = op.sample_inputs(device, dtype) 2649 for sample in samples: 2650 kwargs = sample.kwargs.copy() 2651 kwargs["layout"] = torch.strided 2652 strided_result = op(sample.input, *sample.args, **kwargs) 2653 self.assertEqual(strided_result.layout, torch.strided) 2654 2655 2656instantiate_device_type_tests(TestCommon, globals()) 2657instantiate_device_type_tests(TestCompositeCompliance, globals()) 2658instantiate_device_type_tests(TestMathBits, globals()) 2659instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu") 2660instantiate_device_type_tests(TestFakeTensor, globals()) 2661instantiate_device_type_tests(TestTags, globals()) 2662 2663if __name__ == "__main__": 2664 TestCase._default_dtype_check_enabled = True 2665 run_tests() 2666
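# Note: instantiate_device_type_tests generates per-device classes from the generic
# ones above (e.g. TestCommonCPU, TestCommonCUDA), so individual tests are selected
# by their instantiated names. One way to run a subset of this file, assuming a
# pytest-based invocation, is:
#
#   pytest test/test_ops.py -k "test_strided_layout and cpu"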