1# Owner(s): ["module: tests"] 2 3import collections 4import doctest 5import functools 6import importlib 7import inspect 8import itertools 9import math 10import os 11import re 12import subprocess 13import sys 14import unittest.mock 15from typing import Any, Callable, Iterator, List, Tuple 16 17import torch 18 19from torch.testing import make_tensor 20from torch.testing._internal.common_utils import \ 21 (IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest, 22 parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf) 23from torch.testing._internal.common_device_type import \ 24 (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes, 25 get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes, 26 deviceCountAtLeast, ops, expectedFailureMeta, OpDTypes) 27from torch.testing._internal.common_methods_invocations import op_db 28from torch.testing._internal import opinfo 29from torch.testing._internal.common_dtype import all_types_and_complex_and, floating_types 30from torch.testing._internal.common_modules import modules, module_db, ModuleInfo 31from torch.testing._internal.opinfo.core import SampleInput, DecorateInfo, OpInfo 32import operator 33 34# For testing TestCase methods and torch.testing functions 35class TestTesting(TestCase): 36 # Ensure that assertEqual handles numpy arrays properly 37 @dtypes(*all_types_and_complex_and(torch.bool, torch.half)) 38 def test_assertEqual_numpy(self, device, dtype): 39 S = 10 40 test_sizes = [ 41 (), 42 (0,), 43 (S,), 44 (S, S), 45 (0, S), 46 (S, 0)] 47 for test_size in test_sizes: 48 a = make_tensor(test_size, dtype=dtype, device=device, low=-5, high=5) 49 a_n = a.cpu().numpy() 50 msg = f'size: {test_size}' 51 self.assertEqual(a_n, a, rtol=0, atol=0, msg=msg) 52 self.assertEqual(a, a_n, rtol=0, atol=0, msg=msg) 53 self.assertEqual(a_n, a_n, rtol=0, atol=0, msg=msg) 54 55 def 
test_assertEqual_longMessage(self): 56 actual = "actual" 57 expected = "expected" 58 59 long_message = self.longMessage 60 try: 61 # Capture the default error message by forcing TestCase.longMessage = False 62 self.longMessage = False 63 try: 64 self.assertEqual(actual, expected) 65 except AssertionError as error: 66 default_msg = str(error) 67 else: 68 raise AssertionError("AssertionError not raised") 69 70 self.longMessage = True 71 extra_msg = "sentinel" 72 with self.assertRaisesRegex(AssertionError, re.escape(f"{default_msg}\n{extra_msg}")): 73 self.assertEqual(actual, expected, msg=extra_msg) 74 finally: 75 self.longMessage = long_message 76 77 def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05): 78 for test in tests: 79 a = torch.tensor((test[0],), device=device, dtype=dtype) 80 b = torch.tensor((test[1],), device=device, dtype=dtype) 81 82 actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol) 83 expected = test[2] 84 self.assertEqual(actual.item(), expected) 85 86 def test_isclose_bool(self, device): 87 tests = ( 88 (True, True, True), 89 (False, False, True), 90 (True, False, False), 91 (False, True, False), 92 ) 93 94 self._isclose_helper(tests, device, torch.bool, False) 95 96 @dtypes(torch.uint8, 97 torch.int8, torch.int16, torch.int32, torch.int64) 98 def test_isclose_integer(self, device, dtype): 99 tests = ( 100 (0, 0, True), 101 (0, 1, False), 102 (1, 0, False), 103 ) 104 105 self._isclose_helper(tests, device, dtype, False) 106 107 # atol and rtol tests 108 tests = [ 109 (0, 1, True), 110 (1, 0, False), 111 (1, 3, True), 112 ] 113 114 self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) 115 116 if dtype is torch.uint8: 117 tests = [ 118 (-1, 1, False), 119 (1, -1, False) 120 ] 121 else: 122 tests = [ 123 (-1, 1, True), 124 (1, -1, True) 125 ] 126 127 self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5) 128 129 @onlyNativeDeviceTypes 130 @dtypes(torch.float16, 
torch.float32, torch.float64) 131 def test_isclose_float(self, device, dtype): 132 tests = ( 133 (0, 0, True), 134 (0, -1, False), 135 (float('inf'), float('inf'), True), 136 (-float('inf'), float('inf'), False), 137 (float('inf'), float('nan'), False), 138 (float('nan'), float('nan'), False), 139 (0, float('nan'), False), 140 (1, 1, True), 141 ) 142 143 self._isclose_helper(tests, device, dtype, False) 144 145 # atol and rtol tests 146 eps = 1e-2 if dtype is torch.half else 1e-6 147 tests = ( 148 (0, 1, True), 149 (0, 1 + eps, False), 150 (1, 0, False), 151 (1, 3, True), 152 (1 - eps, 3, False), 153 (-.25, .5, True), 154 (-.25 - eps, .5, False), 155 (.25, -.5, True), 156 (.25 + eps, -.5, False), 157 ) 158 159 self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) 160 161 # equal_nan = True tests 162 tests = ( 163 (0, float('nan'), False), 164 (float('inf'), float('nan'), False), 165 (float('nan'), float('nan'), True), 166 ) 167 168 self._isclose_helper(tests, device, dtype, True) 169 170 @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle") 171 @dtypes(torch.complex64, torch.complex128) 172 def test_isclose_complex(self, device, dtype): 173 tests = ( 174 (complex(1, 1), complex(1, 1 + 1e-8), True), 175 (complex(0, 1), complex(1, 1), False), 176 (complex(1, 1), complex(1, 0), False), 177 (complex(1, 1), complex(1, float('nan')), False), 178 (complex(1, float('nan')), complex(1, float('nan')), False), 179 (complex(1, 1), complex(1, float('inf')), False), 180 (complex(float('inf'), 1), complex(1, float('inf')), False), 181 (complex(-float('inf'), 1), complex(1, float('inf')), False), 182 (complex(-float('inf'), 1), complex(float('inf'), 1), False), 183 (complex(float('inf'), 1), complex(float('inf'), 1), True), 184 (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False), 185 ) 186 187 self._isclose_helper(tests, device, dtype, False) 188 189 # atol and rtol tests 190 191 # atol and rtol tests 192 eps = 1e-6 193 tests = 
( 194 # Complex versions of float tests (real part) 195 (complex(0, 0), complex(1, 0), True), 196 (complex(0, 0), complex(1 + eps, 0), False), 197 (complex(1, 0), complex(0, 0), False), 198 (complex(1, 0), complex(3, 0), True), 199 (complex(1 - eps, 0), complex(3, 0), False), 200 (complex(-.25, 0), complex(.5, 0), True), 201 (complex(-.25 - eps, 0), complex(.5, 0), False), 202 (complex(.25, 0), complex(-.5, 0), True), 203 (complex(.25 + eps, 0), complex(-.5, 0), False), 204 # Complex versions of float tests (imaginary part) 205 (complex(0, 0), complex(0, 1), True), 206 (complex(0, 0), complex(0, 1 + eps), False), 207 (complex(0, 1), complex(0, 0), False), 208 (complex(0, 1), complex(0, 3), True), 209 (complex(0, 1 - eps), complex(0, 3), False), 210 (complex(0, -.25), complex(0, .5), True), 211 (complex(0, -.25 - eps), complex(0, .5), False), 212 (complex(0, .25), complex(0, -.5), True), 213 (complex(0, .25 + eps), complex(0, -.5), False), 214 ) 215 216 self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) 217 218 # atol and rtol tests for isclose 219 tests = ( 220 # Complex-specific tests 221 (complex(1, -1), complex(-1, 1), False), 222 (complex(1, -1), complex(2, -2), True), 223 (complex(-math.sqrt(2), math.sqrt(2)), 224 complex(-math.sqrt(.5), math.sqrt(.5)), True), 225 (complex(-math.sqrt(2), math.sqrt(2)), 226 complex(-math.sqrt(.501), math.sqrt(.499)), False), 227 (complex(2, 4), complex(1., 8.8523607), True), 228 (complex(2, 4), complex(1., 8.8523607 + eps), False), 229 (complex(1, 99), complex(4, 100), True), 230 ) 231 self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) 232 233 # equal_nan = True tests 234 tests = ( 235 (complex(1, 1), complex(1, float('nan')), False), 236 (complex(1, 1), complex(float('nan'), 1), False), 237 (complex(float('nan'), 1), complex(float('nan'), 1), True), 238 (complex(float('nan'), 1), complex(1, float('nan')), True), 239 (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), 
True), 240 ) 241 self._isclose_helper(tests, device, dtype, True) 242 243 # Tests that isclose with rtol or atol values less than zero throws a 244 # RuntimeError 245 @dtypes(torch.bool, torch.uint8, 246 torch.int8, torch.int16, torch.int32, torch.int64, 247 torch.float16, torch.float32, torch.float64) 248 def test_isclose_atol_rtol_greater_than_zero(self, device, dtype): 249 t = torch.tensor((1,), device=device, dtype=dtype) 250 251 with self.assertRaises(RuntimeError): 252 torch.isclose(t, t, atol=-1, rtol=1) 253 with self.assertRaises(RuntimeError): 254 torch.isclose(t, t, atol=1, rtol=-1) 255 with self.assertRaises(RuntimeError): 256 torch.isclose(t, t, atol=-1, rtol=-1) 257 258 def test_isclose_equality_shortcut(self): 259 # For values >= 2**53, integers differing by 1 can no longer differentiated by torch.float64 or lower precision 260 # floating point dtypes. Thus, even with rtol == 0 and atol == 0, these tensors would be considered close if 261 # they were not compared as integers. 262 a = torch.tensor(2 ** 53, dtype=torch.int64) 263 b = a + 1 264 265 self.assertFalse(torch.isclose(a, b, rtol=0, atol=0)) 266 267 @dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128) 268 def test_isclose_nan_equality_shortcut(self, device, dtype): 269 if dtype.is_floating_point: 270 a = b = torch.nan 271 else: 272 a = complex(torch.nan, 0) 273 b = complex(0, torch.nan) 274 275 expected = True 276 tests = [(a, b, expected)] 277 278 self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0) 279 280 # The following tests (test_cuda_assert_*) are added to ensure test suite terminates early 281 # when CUDA assert was thrown. Because all subsequent test will fail if that happens. 282 # These tests are slow because it spawn another process to run test suite. 
283 # See: https://github.com/pytorch/pytorch/issues/49019 284 @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts") 285 @onlyCUDA 286 @slowTest 287 def test_cuda_assert_should_stop_common_utils_test_suite(self, device): 288 # test to ensure common_utils.py override has early termination for CUDA. 289 stderr = TestCase.runWithPytorchAPIUsageStderr("""\ 290#!/usr/bin/env python3 291 292import torch 293from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest) 294 295class TestThatContainsCUDAAssertFailure(TestCase): 296 297 @slowTest 298 def test_throw_unrecoverable_cuda_exception(self): 299 x = torch.rand(10, device='cuda') 300 # cause unrecoverable CUDA exception, recoverable on CPU 301 y = x[torch.tensor([25])].cpu() 302 303 @slowTest 304 def test_trivial_passing_test_case_on_cpu_cuda(self): 305 x1 = torch.tensor([0., 1.], device='cuda') 306 x2 = torch.tensor([0., 1.], device='cpu') 307 self.assertEqual(x1, x2) 308 309if __name__ == '__main__': 310 run_tests() 311""") 312 # should capture CUDA error 313 self.assertIn('CUDA error: device-side assert triggered', stderr) 314 # should run only 1 test because it throws unrecoverable error. 315 self.assertIn('errors=1', stderr) 316 317 318 @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts") 319 @onlyCUDA 320 @slowTest 321 def test_cuda_assert_should_stop_common_device_type_test_suite(self, device): 322 # test to ensure common_device_type.py override has early termination for CUDA. 
323 stderr = TestCase.runWithPytorchAPIUsageStderr("""\ 324#!/usr/bin/env python3 325 326import torch 327from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest) 328from torch.testing._internal.common_device_type import instantiate_device_type_tests 329 330class TestThatContainsCUDAAssertFailure(TestCase): 331 332 @slowTest 333 def test_throw_unrecoverable_cuda_exception(self, device): 334 x = torch.rand(10, device=device) 335 # cause unrecoverable CUDA exception, recoverable on CPU 336 y = x[torch.tensor([25])].cpu() 337 338 @slowTest 339 def test_trivial_passing_test_case_on_cpu_cuda(self, device): 340 x1 = torch.tensor([0., 1.], device=device) 341 x2 = torch.tensor([0., 1.], device='cpu') 342 self.assertEqual(x1, x2) 343 344instantiate_device_type_tests( 345 TestThatContainsCUDAAssertFailure, 346 globals(), 347 only_for='cuda' 348) 349 350if __name__ == '__main__': 351 run_tests() 352""") 353 # should capture CUDA error 354 self.assertIn('CUDA error: device-side assert triggered', stderr) 355 # should run only 1 test because it throws unrecoverable error. 356 self.assertIn('errors=1', stderr) 357 358 359 @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts") 360 @onlyCUDA 361 @slowTest 362 def test_cuda_assert_should_not_stop_common_distributed_test_suite(self, device): 363 # test to ensure common_distributed.py override should not early terminate CUDA. 
364 stderr = TestCase.runWithPytorchAPIUsageStderr("""\ 365#!/usr/bin/env python3 366 367import torch 368from torch.testing._internal.common_utils import (run_tests, slowTest) 369from torch.testing._internal.common_device_type import instantiate_device_type_tests 370from torch.testing._internal.common_distributed import MultiProcessTestCase 371 372class TestThatContainsCUDAAssertFailure(MultiProcessTestCase): 373 374 @slowTest 375 def test_throw_unrecoverable_cuda_exception(self, device): 376 x = torch.rand(10, device=device) 377 # cause unrecoverable CUDA exception, recoverable on CPU 378 y = x[torch.tensor([25])].cpu() 379 380 @slowTest 381 def test_trivial_passing_test_case_on_cpu_cuda(self, device): 382 x1 = torch.tensor([0., 1.], device=device) 383 x2 = torch.tensor([0., 1.], device='cpu') 384 self.assertEqual(x1, x2) 385 386instantiate_device_type_tests( 387 TestThatContainsCUDAAssertFailure, 388 globals(), 389 only_for='cuda' 390) 391 392if __name__ == '__main__': 393 run_tests() 394""") 395 # we are currently disabling CUDA early termination for distributed tests. 396 self.assertIn('errors=2', stderr) 397 398 @expectedFailureMeta # This is only supported for CPU and CUDA 399 @onlyNativeDeviceTypes 400 def test_get_supported_dtypes(self, device): 401 # Test the `get_supported_dtypes` helper function. 402 # We acquire the dtypes for few Ops dynamically and verify them against 403 # the correct statically described values. 
404 ops_to_test = list(filter(lambda op: op.name in ['atan2', 'topk', 'xlogy'], op_db)) 405 406 for op in ops_to_test: 407 dynamic_dtypes = opinfo.utils.get_supported_dtypes(op, op.sample_inputs_func, self.device_type) 408 dynamic_dispatch = opinfo.utils.dtypes_dispatch_hint(dynamic_dtypes) 409 if self.device_type == 'cpu': 410 dtypes = op.dtypes 411 else: # device_type ='cuda' 412 dtypes = op.dtypesIfCUDA 413 414 self.assertTrue(set(dtypes) == set(dynamic_dtypes)) 415 self.assertTrue(set(dtypes) == set(dynamic_dispatch.dispatch_fn())) 416 417 @onlyCPU 418 @ops( 419 [ 420 op 421 for op in op_db 422 if len( 423 op.supported_dtypes("cpu").symmetric_difference( 424 op.supported_dtypes("cuda") 425 ) 426 ) 427 > 0 428 ][:1], 429 dtypes=OpDTypes.none, 430 ) 431 def test_supported_dtypes(self, device, op): 432 self.assertNotEqual(op.supported_dtypes("cpu"), op.supported_dtypes("cuda")) 433 self.assertEqual(op.supported_dtypes("cuda"), op.supported_dtypes("cuda:0")) 434 self.assertEqual( 435 op.supported_dtypes(torch.device("cuda")), 436 op.supported_dtypes(torch.device("cuda", index=1)), 437 ) 438 439instantiate_device_type_tests(TestTesting, globals()) 440 441 442class TestFrameworkUtils(TestCase): 443 444 @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows") 445 @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle") 446 def test_filtering_env_var(self): 447 # Test environment variable selected device type test generator. 
448 test_filter_file_template = """\ 449#!/usr/bin/env python3 450 451import torch 452from torch.testing._internal.common_utils import (TestCase, run_tests) 453from torch.testing._internal.common_device_type import instantiate_device_type_tests 454 455class TestEnvironmentVariable(TestCase): 456 457 def test_trivial_passing_test(self, device): 458 x1 = torch.tensor([0., 1.], device=device) 459 x2 = torch.tensor([0., 1.], device='cpu') 460 self.assertEqual(x1, x2) 461 462instantiate_device_type_tests( 463 TestEnvironmentVariable, 464 globals(), 465) 466 467if __name__ == '__main__': 468 run_tests() 469""" 470 test_bases_count = len(get_device_type_test_bases()) 471 # Test without setting env var should run everything. 472 env = dict(os.environ) 473 for k in ['CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY]: 474 if k in env.keys(): 475 del env[k] 476 _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) 477 self.assertIn(f'Ran {test_bases_count} test', stderr.decode('ascii')) 478 479 # Test with setting only_for should only run 1 test. 480 env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu' 481 _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) 482 self.assertIn('Ran 1 test', stderr.decode('ascii')) 483 484 # Test with setting except_for should run 1 less device type from default. 
485 del env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] 486 env[PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY] = 'cpu' 487 _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) 488 self.assertIn(f'Ran {test_bases_count-1} test', stderr.decode('ascii')) 489 490 # Test with setting both should throw exception 491 env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu' 492 _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) 493 self.assertNotIn('OK', stderr.decode('ascii')) 494 495 496def make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]: 497 """Makes inputs for :func:`torch.testing.assert_close` functions based on two examples. 498 499 Args: 500 actual (Any): Actual input. 501 expected (Any): Expected input. 502 503 Returns: 504 List[Tuple[Any, Any]]: Pair of example inputs, as well as the example inputs wrapped in sequences 505 (:class:`tuple`, :class:`list`), and mappings (:class:`dict`, :class:`~collections.OrderedDict`). 506 """ 507 return [ 508 (actual, expected), 509 # tuple vs. tuple 510 ((actual,), (expected,)), 511 # list vs. list 512 ([actual], [expected]), 513 # tuple vs. list 514 ((actual,), [expected]), 515 # dict vs. dict 516 ({"t": actual}, {"t": expected}), 517 # OrderedDict vs. OrderedDict 518 (collections.OrderedDict([("t", actual)]), collections.OrderedDict([("t", expected)])), 519 # dict vs. OrderedDict 520 ({"t": actual}, collections.OrderedDict([("t", expected)])), 521 # list of tuples vs. tuple of lists 522 ([(actual,)], ([expected],)), 523 # list of dicts vs. tuple of OrderedDicts 524 ([{"t": actual}], (collections.OrderedDict([("t", expected)]),)), 525 # dict of lists vs. OrderedDict of tuples 526 ({"t": [actual]}, collections.OrderedDict([("t", (expected,))])), 527 ] 528 529 530def assert_close_with_inputs(actual: Any, expected: Any) -> Iterator[Callable]: 531 """Yields :func:`torch.testing.assert_close` with predefined positional inputs based on two examples. 532 533 .. 
note:: 534 535 Every test that does not test for a specific input should iterate over this to maximize the coverage. 536 537 Args: 538 actual (Any): Actual input. 539 expected (Any): Expected input. 540 541 Yields: 542 Callable: :func:`torch.testing.assert_close` with predefined positional inputs. 543 """ 544 for inputs in make_assert_close_inputs(actual, expected): 545 yield functools.partial(torch.testing.assert_close, *inputs) 546 547 548class TestAssertClose(TestCase): 549 def test_mismatching_types_subclasses(self): 550 actual = torch.rand(()) 551 expected = torch.nn.Parameter(actual) 552 553 for fn in assert_close_with_inputs(actual, expected): 554 fn() 555 556 def test_mismatching_types_type_equality(self): 557 actual = torch.empty(()) 558 expected = torch.nn.Parameter(actual) 559 560 for fn in assert_close_with_inputs(actual, expected): 561 with self.assertRaisesRegex(TypeError, str(type(expected))): 562 fn(allow_subclasses=False) 563 564 def test_mismatching_types(self): 565 actual = torch.empty(2) 566 expected = actual.numpy() 567 568 for fn, allow_subclasses in itertools.product(assert_close_with_inputs(actual, expected), (True, False)): 569 with self.assertRaisesRegex(TypeError, str(type(expected))): 570 fn(allow_subclasses=allow_subclasses) 571 572 def test_unknown_type(self): 573 actual = "0" 574 expected = "0" 575 576 for fn in assert_close_with_inputs(actual, expected): 577 with self.assertRaisesRegex(TypeError, str(type(actual))): 578 fn() 579 580 def test_mismatching_shape(self): 581 actual = torch.empty(()) 582 expected = actual.clone().reshape((1,)) 583 584 for fn in assert_close_with_inputs(actual, expected): 585 with self.assertRaisesRegex(AssertionError, "shape"): 586 fn() 587 588 @unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.") 589 def test_unknown_layout(self): 590 actual = torch.empty((2, 2)) 591 expected = actual.to_mkldnn() 592 593 for fn in assert_close_with_inputs(actual, expected): 594 with 
self.assertRaisesRegex(ValueError, "layout"): 595 fn() 596 597 def test_meta(self): 598 actual = torch.empty((2, 2), device="meta") 599 expected = torch.empty((2, 2), device="meta") 600 601 for fn in assert_close_with_inputs(actual, expected): 602 fn() 603 604 def test_mismatching_layout(self): 605 strided = torch.empty((2, 2)) 606 sparse_coo = strided.to_sparse() 607 sparse_csr = strided.to_sparse_csr() 608 609 for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2): 610 for fn in assert_close_with_inputs(actual, expected): 611 with self.assertRaisesRegex(AssertionError, "layout"): 612 fn() 613 614 def test_mismatching_layout_no_check(self): 615 strided = torch.randn((2, 2)) 616 sparse_coo = strided.to_sparse() 617 sparse_csr = strided.to_sparse_csr() 618 619 for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2): 620 for fn in assert_close_with_inputs(actual, expected): 621 fn(check_layout=False) 622 623 def test_mismatching_dtype(self): 624 actual = torch.empty((), dtype=torch.float) 625 expected = actual.clone().to(torch.int) 626 627 for fn in assert_close_with_inputs(actual, expected): 628 with self.assertRaisesRegex(AssertionError, "dtype"): 629 fn() 630 631 def test_mismatching_dtype_no_check(self): 632 actual = torch.ones((), dtype=torch.float) 633 expected = actual.clone().to(torch.int) 634 635 for fn in assert_close_with_inputs(actual, expected): 636 fn(check_dtype=False) 637 638 def test_mismatching_stride(self): 639 actual = torch.empty((2, 2)) 640 expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1]) 641 642 for fn in assert_close_with_inputs(actual, expected): 643 with self.assertRaisesRegex(AssertionError, "stride"): 644 fn(check_stride=True) 645 646 def test_mismatching_stride_no_check(self): 647 actual = torch.rand((2, 2)) 648 expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1]) 649 for fn in 
assert_close_with_inputs(actual, expected): 650 fn() 651 652 def test_only_rtol(self): 653 actual = torch.empty(()) 654 expected = actual.clone() 655 656 for fn in assert_close_with_inputs(actual, expected): 657 with self.assertRaises(ValueError): 658 fn(rtol=0.0) 659 660 def test_only_atol(self): 661 actual = torch.empty(()) 662 expected = actual.clone() 663 664 for fn in assert_close_with_inputs(actual, expected): 665 with self.assertRaises(ValueError): 666 fn(atol=0.0) 667 668 def test_mismatching_values(self): 669 actual = torch.tensor(1) 670 expected = torch.tensor(2) 671 672 for fn in assert_close_with_inputs(actual, expected): 673 with self.assertRaises(AssertionError): 674 fn() 675 676 def test_mismatching_values_rtol(self): 677 eps = 1e-3 678 actual = torch.tensor(1.0) 679 expected = torch.tensor(1.0 + eps) 680 681 for fn in assert_close_with_inputs(actual, expected): 682 with self.assertRaises(AssertionError): 683 fn(rtol=eps / 2, atol=0.0) 684 685 def test_mismatching_values_atol(self): 686 eps = 1e-3 687 actual = torch.tensor(0.0) 688 expected = torch.tensor(eps) 689 690 for fn in assert_close_with_inputs(actual, expected): 691 with self.assertRaises(AssertionError): 692 fn(rtol=0.0, atol=eps / 2) 693 694 def test_matching(self): 695 actual = torch.tensor(1.0) 696 expected = actual.clone() 697 698 torch.testing.assert_close(actual, expected) 699 700 def test_matching_rtol(self): 701 eps = 1e-3 702 actual = torch.tensor(1.0) 703 expected = torch.tensor(1.0 + eps) 704 705 for fn in assert_close_with_inputs(actual, expected): 706 fn(rtol=eps * 2, atol=0.0) 707 708 def test_matching_atol(self): 709 eps = 1e-3 710 actual = torch.tensor(0.0) 711 expected = torch.tensor(eps) 712 713 for fn in assert_close_with_inputs(actual, expected): 714 fn(rtol=0.0, atol=eps * 2) 715 716 # TODO: the code that this test was designed for was removed in https://github.com/pytorch/pytorch/pull/56058 717 # We need to check if this test is still needed or if this behavior is now 
enabled by default. 718 def test_matching_conjugate_bit(self): 719 actual = torch.tensor(complex(1, 1)).conj() 720 expected = torch.tensor(complex(1, -1)) 721 722 for fn in assert_close_with_inputs(actual, expected): 723 fn() 724 725 def test_matching_nan(self): 726 nan = float("NaN") 727 728 tests = ( 729 (nan, nan), 730 (complex(nan, 0), complex(0, nan)), 731 (complex(nan, nan), complex(nan, 0)), 732 (complex(nan, nan), complex(nan, nan)), 733 ) 734 735 for actual, expected in tests: 736 for fn in assert_close_with_inputs(actual, expected): 737 with self.assertRaises(AssertionError): 738 fn() 739 740 def test_matching_nan_with_equal_nan(self): 741 nan = float("NaN") 742 743 tests = ( 744 (nan, nan), 745 (complex(nan, 0), complex(0, nan)), 746 (complex(nan, nan), complex(nan, 0)), 747 (complex(nan, nan), complex(nan, nan)), 748 ) 749 750 for actual, expected in tests: 751 for fn in assert_close_with_inputs(actual, expected): 752 fn(equal_nan=True) 753 754 def test_numpy(self): 755 tensor = torch.rand(2, 2, dtype=torch.float32) 756 actual = tensor.numpy() 757 expected = actual.copy() 758 759 for fn in assert_close_with_inputs(actual, expected): 760 fn() 761 762 def test_scalar(self): 763 number = torch.randint(10, size=()).item() 764 for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2): 765 check_dtype = type(actual) is type(expected) 766 767 for fn in assert_close_with_inputs(actual, expected): 768 fn(check_dtype=check_dtype) 769 770 def test_bool(self): 771 actual = torch.tensor([True, False]) 772 expected = actual.clone() 773 774 for fn in assert_close_with_inputs(actual, expected): 775 fn() 776 777 def test_none(self): 778 actual = expected = None 779 780 for fn in assert_close_with_inputs(actual, expected): 781 fn() 782 783 def test_none_mismatch(self): 784 expected = None 785 786 for actual in (False, 0, torch.nan, torch.tensor(torch.nan)): 787 for fn in assert_close_with_inputs(actual, expected): 788 with 
self.assertRaises(AssertionError): 789 fn() 790 791 792 def test_docstring_examples(self): 793 finder = doctest.DocTestFinder(verbose=False) 794 runner = doctest.DocTestRunner(verbose=False, optionflags=doctest.NORMALIZE_WHITESPACE) 795 globs = dict(torch=torch) 796 doctests = finder.find(torch.testing.assert_close, globs=globs)[0] 797 failures = [] 798 runner.run(doctests, out=lambda report: failures.append(report)) 799 if failures: 800 raise AssertionError(f"Doctest found {len(failures)} failures:\n\n" + "\n".join(failures)) 801 802 def test_default_tolerance_selection_mismatching_dtypes(self): 803 # If the default tolerances where selected based on the promoted dtype, i.e. float64, 804 # these tensors wouldn't be considered close. 805 actual = torch.tensor(0.99, dtype=torch.bfloat16) 806 expected = torch.tensor(1.0, dtype=torch.float64) 807 808 for fn in assert_close_with_inputs(actual, expected): 809 fn(check_dtype=False) 810 811 class UnexpectedException(Exception): 812 """The only purpose of this exception is to test ``assert_close``'s handling of unexpected exceptions. Thus, 813 the test should mock a component to raise this instead of the regular behavior. We avoid using a builtin 814 exception here to avoid triggering possible handling of them. 
815 """ 816 817 @unittest.mock.patch("torch.testing._comparison.TensorLikePair.__init__", side_effect=UnexpectedException) 818 def test_unexpected_error_originate(self, _): 819 actual = torch.tensor(1.0) 820 expected = actual.clone() 821 822 with self.assertRaisesRegex(RuntimeError, "unexpected exception"): 823 torch.testing.assert_close(actual, expected) 824 825 @unittest.mock.patch("torch.testing._comparison.TensorLikePair.compare", side_effect=UnexpectedException) 826 def test_unexpected_error_compare(self, _): 827 actual = torch.tensor(1.0) 828 expected = actual.clone() 829 830 with self.assertRaisesRegex(RuntimeError, "unexpected exception"): 831 torch.testing.assert_close(actual, expected) 832 833 834 835 836class TestAssertCloseMultiDevice(TestCase): 837 @deviceCountAtLeast(1) 838 def test_mismatching_device(self, devices): 839 for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2): 840 actual = torch.empty((), device=actual_device) 841 expected = actual.clone().to(expected_device) 842 for fn in assert_close_with_inputs(actual, expected): 843 with self.assertRaisesRegex(AssertionError, "device"): 844 fn() 845 846 @deviceCountAtLeast(1) 847 def test_mismatching_device_no_check(self, devices): 848 for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2): 849 actual = torch.rand((), device=actual_device) 850 expected = actual.clone().to(expected_device) 851 for fn in assert_close_with_inputs(actual, expected): 852 fn(check_device=False) 853 854 855instantiate_device_type_tests(TestAssertCloseMultiDevice, globals(), only_for="cuda") 856 857 858class TestAssertCloseErrorMessage(TestCase): 859 def test_identifier_tensor_likes(self): 860 actual = torch.tensor([1, 2, 3, 4]) 861 expected = torch.tensor([1, 2, 5, 6]) 862 863 for fn in assert_close_with_inputs(actual, expected): 864 with self.assertRaisesRegex(AssertionError, re.escape("Tensor-likes")): 865 fn() 866 867 def test_identifier_scalars(self): 868 actual 
= 3 869 expected = 5 870 for fn in assert_close_with_inputs(actual, expected): 871 with self.assertRaisesRegex(AssertionError, re.escape("Scalars")): 872 fn() 873 874 def test_not_equal(self): 875 actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32) 876 expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32) 877 878 for fn in assert_close_with_inputs(actual, expected): 879 with self.assertRaisesRegex(AssertionError, re.escape("not equal")): 880 fn(rtol=0.0, atol=0.0) 881 882 def test_not_close(self): 883 actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32) 884 expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32) 885 886 for fn, (rtol, atol) in itertools.product( 887 assert_close_with_inputs(actual, expected), ((1.3e-6, 0.0), (0.0, 1e-5), (1.3e-6, 1e-5)) 888 ): 889 with self.assertRaisesRegex(AssertionError, re.escape("not close")): 890 fn(rtol=rtol, atol=atol) 891 892 def test_mismatched_elements(self): 893 actual = torch.tensor([1, 2, 3, 4]) 894 expected = torch.tensor([1, 2, 5, 6]) 895 896 for fn in assert_close_with_inputs(actual, expected): 897 with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")): 898 fn() 899 900 def test_abs_diff(self): 901 actual = torch.tensor([[1, 2], [3, 4]]) 902 expected = torch.tensor([[1, 2], [5, 4]]) 903 904 for fn in assert_close_with_inputs(actual, expected): 905 with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at index (1, 0)")): 906 fn() 907 908 def test_abs_diff_scalar(self): 909 actual = 3 910 expected = 5 911 912 for fn in assert_close_with_inputs(actual, expected): 913 with self.assertRaisesRegex(AssertionError, re.escape("Absolute difference: 2")): 914 fn() 915 916 def test_rel_diff(self): 917 actual = torch.tensor([[1, 2], [3, 4]]) 918 expected = torch.tensor([[1, 4], [3, 4]]) 919 920 for fn in assert_close_with_inputs(actual, expected): 921 with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at 
index (0, 1)")): 922 fn() 923 924 def test_rel_diff_scalar(self): 925 actual = 2 926 expected = 4 927 928 for fn in assert_close_with_inputs(actual, expected): 929 with self.assertRaisesRegex(AssertionError, re.escape("Relative difference: 0.5")): 930 fn() 931 932 def test_zero_div_zero(self): 933 actual = torch.tensor([1.0, 0.0]) 934 expected = torch.tensor([2.0, 0.0]) 935 936 for fn in assert_close_with_inputs(actual, expected): 937 # Although it looks complicated, this regex just makes sure that the word 'nan' is not part of the error 938 # message. That would happen if the 0 / 0 is used for the mismatch computation although it matches. 939 with self.assertRaisesRegex(AssertionError, "((?!nan).)*"): 940 fn() 941 942 def test_rtol(self): 943 rtol = 1e-3 944 945 actual = torch.tensor((1, 2)) 946 expected = torch.tensor((2, 2)) 947 948 for fn in assert_close_with_inputs(actual, expected): 949 with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {rtol} allowed)")): 950 fn(rtol=rtol, atol=0.0) 951 952 def test_atol(self): 953 atol = 1e-3 954 955 actual = torch.tensor((1, 2)) 956 expected = torch.tensor((2, 2)) 957 958 for fn in assert_close_with_inputs(actual, expected): 959 with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {atol} allowed)")): 960 fn(rtol=0.0, atol=atol) 961 962 def test_msg_str(self): 963 msg = "Custom error message!" 
964 965 actual = torch.tensor(1) 966 expected = torch.tensor(2) 967 968 for fn in assert_close_with_inputs(actual, expected): 969 with self.assertRaisesRegex(AssertionError, msg): 970 fn(msg=msg) 971 972 def test_msg_callable(self): 973 msg = "Custom error message" 974 975 actual = torch.tensor(1) 976 expected = torch.tensor(2) 977 978 for fn in assert_close_with_inputs(actual, expected): 979 with self.assertRaisesRegex(AssertionError, msg): 980 fn(msg=lambda _: msg) 981 982 983class TestAssertCloseContainer(TestCase): 984 def test_sequence_mismatching_len(self): 985 actual = (torch.empty(()),) 986 expected = () 987 988 with self.assertRaises(AssertionError): 989 torch.testing.assert_close(actual, expected) 990 991 def test_sequence_mismatching_values_msg(self): 992 t1 = torch.tensor(1) 993 t2 = torch.tensor(2) 994 995 actual = (t1, t1) 996 expected = (t1, t2) 997 998 with self.assertRaisesRegex(AssertionError, re.escape("item [1]")): 999 torch.testing.assert_close(actual, expected) 1000 1001 def test_mapping_mismatching_keys(self): 1002 actual = {"a": torch.empty(())} 1003 expected = {} 1004 1005 with self.assertRaises(AssertionError): 1006 torch.testing.assert_close(actual, expected) 1007 1008 def test_mapping_mismatching_values_msg(self): 1009 t1 = torch.tensor(1) 1010 t2 = torch.tensor(2) 1011 1012 actual = {"a": t1, "b": t1} 1013 expected = {"a": t1, "b": t2} 1014 1015 with self.assertRaisesRegex(AssertionError, re.escape("item ['b']")): 1016 torch.testing.assert_close(actual, expected) 1017 1018 1019class TestAssertCloseSparseCOO(TestCase): 1020 def test_matching_coalesced(self): 1021 indices = ( 1022 (0, 1), 1023 (1, 0), 1024 ) 1025 values = (1, 2) 1026 actual = torch.sparse_coo_tensor(indices, values, size=(2, 2)).coalesce() 1027 expected = actual.clone() 1028 1029 for fn in assert_close_with_inputs(actual, expected): 1030 fn() 1031 1032 def test_matching_uncoalesced(self): 1033 indices = ( 1034 (0, 1), 1035 (1, 0), 1036 ) 1037 values = (1, 2) 1038 actual = 
torch.sparse_coo_tensor(indices, values, size=(2, 2)) 1039 expected = actual.clone() 1040 1041 for fn in assert_close_with_inputs(actual, expected): 1042 fn() 1043 1044 def test_mismatching_sparse_dims(self): 1045 t = torch.randn(2, 3, 4) 1046 actual = t.to_sparse() 1047 expected = t.to_sparse(2) 1048 1049 for fn in assert_close_with_inputs(actual, expected): 1050 with self.assertRaisesRegex(AssertionError, re.escape("number of sparse dimensions in sparse COO tensors")): 1051 fn() 1052 1053 def test_mismatching_nnz(self): 1054 actual_indices = ( 1055 (0, 1), 1056 (1, 0), 1057 ) 1058 actual_values = (1, 2) 1059 actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)) 1060 1061 expected_indices = ( 1062 (0, 1, 1,), 1063 (1, 0, 0,), 1064 ) 1065 expected_values = (1, 1, 1) 1066 expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2)) 1067 1068 for fn in assert_close_with_inputs(actual, expected): 1069 with self.assertRaisesRegex(AssertionError, re.escape("number of specified values in sparse COO tensors")): 1070 fn() 1071 1072 def test_mismatching_indices_msg(self): 1073 actual_indices = ( 1074 (0, 1), 1075 (1, 0), 1076 ) 1077 actual_values = (1, 2) 1078 actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)) 1079 1080 expected_indices = ( 1081 (0, 1), 1082 (1, 1), 1083 ) 1084 expected_values = (1, 2) 1085 expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2)) 1086 1087 for fn in assert_close_with_inputs(actual, expected): 1088 with self.assertRaisesRegex(AssertionError, re.escape("Sparse COO indices")): 1089 fn() 1090 1091 def test_mismatching_values_msg(self): 1092 actual_indices = ( 1093 (0, 1), 1094 (1, 0), 1095 ) 1096 actual_values = (1, 2) 1097 actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)) 1098 1099 expected_indices = ( 1100 (0, 1), 1101 (1, 0), 1102 ) 1103 expected_values = (1, 3) 1104 expected = torch.sparse_coo_tensor(expected_indices, 
expected_values, size=(2, 2)) 1105 1106 for fn in assert_close_with_inputs(actual, expected): 1107 with self.assertRaisesRegex(AssertionError, re.escape("Sparse COO values")): 1108 fn() 1109 1110 1111@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing") 1112class TestAssertCloseSparseCSR(TestCase): 1113 def test_matching(self): 1114 crow_indices = (0, 1, 2) 1115 col_indices = (1, 0) 1116 values = (1, 2) 1117 actual = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2)) 1118 expected = actual.clone() 1119 1120 for fn in assert_close_with_inputs(actual, expected): 1121 fn() 1122 1123 def test_mismatching_crow_indices_msg(self): 1124 actual_crow_indices = (0, 1, 2) 1125 actual_col_indices = (0, 1) 1126 actual_values = (1, 2) 1127 actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1128 1129 expected_crow_indices = (0, 2, 2) 1130 expected_col_indices = actual_col_indices 1131 expected_values = actual_values 1132 expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1133 1134 for fn in assert_close_with_inputs(actual, expected): 1135 with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR crow_indices")): 1136 fn() 1137 1138 def test_mismatching_col_indices_msg(self): 1139 actual_crow_indices = (0, 1, 2) 1140 actual_col_indices = (1, 0) 1141 actual_values = (1, 2) 1142 actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1143 1144 expected_crow_indices = actual_crow_indices 1145 expected_col_indices = (1, 1) 1146 expected_values = actual_values 1147 expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1148 1149 for fn in assert_close_with_inputs(actual, expected): 1150 with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR col_indices")): 1151 fn() 1152 1153 def 
test_mismatching_values_msg(self): 1154 actual_crow_indices = (0, 1, 2) 1155 actual_col_indices = (1, 0) 1156 actual_values = (1, 2) 1157 actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1158 1159 expected_crow_indices = actual_crow_indices 1160 expected_col_indices = actual_col_indices 1161 expected_values = (1, 3) 1162 expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1163 1164 for fn in assert_close_with_inputs(actual, expected): 1165 with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR values")): 1166 fn() 1167 1168 1169@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSC testing") 1170class TestAssertCloseSparseCSC(TestCase): 1171 def test_matching(self): 1172 ccol_indices = (0, 1, 2) 1173 row_indices = (1, 0) 1174 values = (1, 2) 1175 actual = torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=(2, 2)) 1176 expected = actual.clone() 1177 1178 for fn in assert_close_with_inputs(actual, expected): 1179 fn() 1180 1181 def test_mismatching_ccol_indices_msg(self): 1182 actual_ccol_indices = (0, 1, 2) 1183 actual_row_indices = (0, 1) 1184 actual_values = (1, 2) 1185 actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1186 1187 expected_ccol_indices = (0, 2, 2) 1188 expected_row_indices = actual_row_indices 1189 expected_values = actual_values 1190 expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1191 1192 for fn in assert_close_with_inputs(actual, expected): 1193 with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC ccol_indices")): 1194 fn() 1195 1196 def test_mismatching_row_indices_msg(self): 1197 actual_ccol_indices = (0, 1, 2) 1198 actual_row_indices = (1, 0) 1199 actual_values = (1, 2) 1200 actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, 
actual_values, size=(2, 2)) 1201 1202 expected_ccol_indices = actual_ccol_indices 1203 expected_row_indices = (1, 1) 1204 expected_values = actual_values 1205 expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1206 1207 for fn in assert_close_with_inputs(actual, expected): 1208 with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC row_indices")): 1209 fn() 1210 1211 def test_mismatching_values_msg(self): 1212 actual_ccol_indices = (0, 1, 2) 1213 actual_row_indices = (1, 0) 1214 actual_values = (1, 2) 1215 actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1216 1217 expected_ccol_indices = actual_ccol_indices 1218 expected_row_indices = actual_row_indices 1219 expected_values = (1, 3) 1220 expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1221 1222 for fn in assert_close_with_inputs(actual, expected): 1223 with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC values")): 1224 fn() 1225 1226 1227@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSR testing") 1228class TestAssertCloseSparseBSR(TestCase): 1229 def test_matching(self): 1230 crow_indices = (0, 1, 2) 1231 col_indices = (1, 0) 1232 values = ([[1]], [[2]]) 1233 actual = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(2, 2)) 1234 expected = actual.clone() 1235 1236 for fn in assert_close_with_inputs(actual, expected): 1237 fn() 1238 1239 def test_mismatching_crow_indices_msg(self): 1240 actual_crow_indices = (0, 1, 2) 1241 actual_col_indices = (0, 1) 1242 actual_values = ([[1]], [[2]]) 1243 actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1244 1245 expected_crow_indices = (0, 2, 2) 1246 expected_col_indices = actual_col_indices 1247 expected_values = actual_values 1248 expected = 
torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1249 1250 for fn in assert_close_with_inputs(actual, expected): 1251 with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR crow_indices")): 1252 fn() 1253 1254 def test_mismatching_col_indices_msg(self): 1255 actual_crow_indices = (0, 1, 2) 1256 actual_col_indices = (1, 0) 1257 actual_values = ([[1]], [[2]]) 1258 actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1259 1260 expected_crow_indices = actual_crow_indices 1261 expected_col_indices = (1, 1) 1262 expected_values = actual_values 1263 expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1264 1265 for fn in assert_close_with_inputs(actual, expected): 1266 with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR col_indices")): 1267 fn() 1268 1269 def test_mismatching_values_msg(self): 1270 actual_crow_indices = (0, 1, 2) 1271 actual_col_indices = (1, 0) 1272 actual_values = ([[1]], [[2]]) 1273 actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1274 1275 expected_crow_indices = actual_crow_indices 1276 expected_col_indices = actual_col_indices 1277 expected_values = ([[1]], [[3]]) 1278 expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1279 1280 for fn in assert_close_with_inputs(actual, expected): 1281 with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR values")): 1282 fn() 1283 1284 1285@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSC testing") 1286class TestAssertCloseSparseBSC(TestCase): 1287 def test_matching(self): 1288 ccol_indices = (0, 1, 2) 1289 row_indices = (1, 0) 1290 values = ([[1]], [[2]]) 1291 actual = torch.sparse_bsc_tensor(ccol_indices, row_indices, values, size=(2, 2)) 1292 expected = actual.clone() 1293 1294 for 
fn in assert_close_with_inputs(actual, expected): 1295 fn() 1296 1297 def test_mismatching_ccol_indices_msg(self): 1298 actual_ccol_indices = (0, 1, 2) 1299 actual_row_indices = (0, 1) 1300 actual_values = ([[1]], [[2]]) 1301 actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1302 1303 expected_ccol_indices = (0, 2, 2) 1304 expected_row_indices = actual_row_indices 1305 expected_values = actual_values 1306 expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1307 1308 for fn in assert_close_with_inputs(actual, expected): 1309 with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC ccol_indices")): 1310 fn() 1311 1312 def test_mismatching_row_indices_msg(self): 1313 actual_ccol_indices = (0, 1, 2) 1314 actual_row_indices = (1, 0) 1315 actual_values = ([[1]], [[2]]) 1316 actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1317 1318 expected_ccol_indices = actual_ccol_indices 1319 expected_row_indices = (1, 1) 1320 expected_values = actual_values 1321 expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1322 1323 for fn in assert_close_with_inputs(actual, expected): 1324 with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC row_indices")): 1325 fn() 1326 1327 def test_mismatching_values_msg(self): 1328 actual_ccol_indices = (0, 1, 2) 1329 actual_row_indices = (1, 0) 1330 actual_values = ([[1]], [[2]]) 1331 actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1332 1333 expected_ccol_indices = actual_ccol_indices 1334 expected_row_indices = actual_row_indices 1335 expected_values = ([[1]], [[3]]) 1336 expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1337 1338 for fn in assert_close_with_inputs(actual, expected): 1339 with 
self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC values")): 1340 fn() 1341 1342 1343class TestAssertCloseQuantized(TestCase): 1344 def test_mismatching_is_quantized(self): 1345 actual = torch.tensor(1.0) 1346 expected = torch.quantize_per_tensor(actual, scale=1.0, zero_point=0, dtype=torch.qint32) 1347 1348 for fn in assert_close_with_inputs(actual, expected): 1349 with self.assertRaisesRegex(AssertionError, "is_quantized"): 1350 fn() 1351 1352 def test_mismatching_qscheme(self): 1353 t = torch.tensor((1.0,)) 1354 actual = torch.quantize_per_tensor(t, scale=1.0, zero_point=0, dtype=torch.qint32) 1355 expected = torch.quantize_per_channel( 1356 t, 1357 scales=torch.tensor((1.0,)), 1358 zero_points=torch.tensor((0,)), 1359 axis=0, 1360 dtype=torch.qint32, 1361 ) 1362 1363 for fn in assert_close_with_inputs(actual, expected): 1364 with self.assertRaisesRegex(AssertionError, "qscheme"): 1365 fn() 1366 1367 def test_matching_per_tensor(self): 1368 actual = torch.quantize_per_tensor(torch.tensor(1.0), scale=1.0, zero_point=0, dtype=torch.qint32) 1369 expected = actual.clone() 1370 1371 for fn in assert_close_with_inputs(actual, expected): 1372 fn() 1373 1374 def test_matching_per_channel(self): 1375 actual = torch.quantize_per_channel( 1376 torch.tensor((1.0,)), 1377 scales=torch.tensor((1.0,)), 1378 zero_points=torch.tensor((0,)), 1379 axis=0, 1380 dtype=torch.qint32, 1381 ) 1382 expected = actual.clone() 1383 1384 for fn in assert_close_with_inputs(actual, expected): 1385 fn() 1386 1387 1388class TestMakeTensor(TestCase): 1389 supported_dtypes = dtypes( 1390 torch.bool, 1391 torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, 1392 torch.float16, torch.bfloat16, torch.float32, torch.float64, 1393 torch.complex32, torch.complex64, torch.complex128, 1394 ) 1395 1396 @supported_dtypes 1397 @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)]) 1398 @parametrize("splat_shape", [False, True]) 1399 def test_smoke(self, dtype, 
device, shape, splat_shape): 1400 t = torch.testing.make_tensor(*shape if splat_shape else shape, dtype=dtype, device=device) 1401 1402 self.assertIsInstance(t, torch.Tensor) 1403 self.assertEqual(t.shape, shape) 1404 self.assertEqual(t.dtype, dtype) 1405 self.assertEqual(t.device, torch.device(device)) 1406 1407 @supported_dtypes 1408 @parametrize("requires_grad", [False, True]) 1409 def test_requires_grad(self, dtype, device, requires_grad): 1410 make_tensor = functools.partial( 1411 torch.testing.make_tensor, 1412 dtype=dtype, 1413 device=device, 1414 requires_grad=requires_grad, 1415 ) 1416 1417 if not requires_grad or dtype.is_floating_point or dtype.is_complex: 1418 t = make_tensor() 1419 self.assertEqual(t.requires_grad, requires_grad) 1420 else: 1421 with self.assertRaisesRegex( 1422 ValueError, "`requires_grad=True` is not supported for boolean and integral dtypes" 1423 ): 1424 make_tensor() 1425 1426 @supported_dtypes 1427 @parametrize("noncontiguous", [False, True]) 1428 @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)]) 1429 def test_noncontiguous(self, dtype, device, noncontiguous, shape): 1430 numel = functools.reduce(operator.mul, shape, 1) 1431 1432 t = torch.testing.make_tensor(shape, dtype=dtype, device=device, noncontiguous=noncontiguous) 1433 self.assertEqual(t.is_contiguous(), not noncontiguous or numel < 2) 1434 1435 @supported_dtypes 1436 @parametrize( 1437 "memory_format_and_shape", 1438 [ 1439 (None, (2, 3, 4)), 1440 (torch.contiguous_format, (2, 3, 4)), 1441 (torch.channels_last, (2, 3, 4, 5)), 1442 (torch.channels_last_3d, (2, 3, 4, 5, 6)), 1443 (torch.preserve_format, (2, 3, 4)), 1444 ], 1445 ) 1446 def test_memory_format(self, dtype, device, memory_format_and_shape): 1447 memory_format, shape = memory_format_and_shape 1448 1449 t = torch.testing.make_tensor(shape, dtype=dtype, device=device, memory_format=memory_format) 1450 1451 self.assertTrue( 1452 t.is_contiguous(memory_format=torch.contiguous_format if 
memory_format is None else memory_format) 1453 ) 1454 1455 @supported_dtypes 1456 def test_noncontiguous_memory_format(self, dtype, device): 1457 with self.assertRaisesRegex(ValueError, "`noncontiguous` and `memory_format` are mutually exclusive"): 1458 torch.testing.make_tensor( 1459 (2, 3, 4, 5), 1460 dtype=dtype, 1461 device=device, 1462 noncontiguous=True, 1463 memory_format=torch.channels_last, 1464 ) 1465 1466 @supported_dtypes 1467 def test_exclude_zero(self, dtype, device): 1468 t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, exclude_zero=True, low=-1, high=2) 1469 1470 self.assertTrue((t != 0).all()) 1471 1472 @supported_dtypes 1473 def test_low_high_smoke(self, dtype, device): 1474 low_inclusive, high_exclusive = 0, 2 1475 1476 t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive) 1477 if dtype.is_complex: 1478 t = torch.view_as_real(t) 1479 1480 self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all()) 1481 1482 @supported_dtypes 1483 def test_low_high_default_smoke(self, dtype, device): 1484 low_inclusive, high_exclusive = { 1485 torch.bool: (0, 2), 1486 torch.uint8: (0, 10), 1487 **dict.fromkeys([torch.int8, torch.int16, torch.int32, torch.int64], (-9, 10)), 1488 }.get(dtype, (-9, 9)) 1489 1490 t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive) 1491 if dtype.is_complex: 1492 t = torch.view_as_real(t) 1493 1494 self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all()) 1495 1496 @parametrize("low_high", [(0, 0), (1, 0), (0, -1)]) 1497 @parametrize("value_types", list(itertools.product([int, float], repeat=2))) 1498 @supported_dtypes 1499 def test_low_ge_high(self, dtype, device, low_high, value_types): 1500 low, high = (value_type(value) for value, value_type in zip(low_high, value_types)) 1501 1502 if low == high and (dtype.is_floating_point or dtype.is_complex): 1503 with self.assertWarnsRegex( 1504 
FutureWarning, 1505 "Passing `low==high` to `torch.testing.make_tensor` for floating or complex types is deprecated", 1506 ): 1507 t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low, high=high) 1508 self.assertEqual(t, torch.full_like(t, complex(low, low) if dtype.is_complex else low)) 1509 else: 1510 with self.assertRaisesRegex(ValueError, "`low` must be less than `high`"): 1511 torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high) 1512 1513 @supported_dtypes 1514 @parametrize("low_high", [(None, torch.nan), (torch.nan, None), (torch.nan, torch.nan)]) 1515 def test_low_high_nan(self, dtype, device, low_high): 1516 low, high = low_high 1517 1518 with self.assertRaisesRegex(ValueError, "`low` and `high` cannot be NaN"): 1519 torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high) 1520 1521 @supported_dtypes 1522 def test_low_high_outside_valid_range(self, dtype, device): 1523 make_tensor = functools.partial(torch.testing.make_tensor, dtype=dtype, device=device) 1524 1525 def get_dtype_limits(dtype): 1526 if dtype is torch.bool: 1527 return 0, 1 1528 1529 info = (torch.finfo if dtype.is_floating_point or dtype.is_complex else torch.iinfo)(dtype) 1530 # We are using integer bounds here, because otherwise it would be impossible to pass `low` and `high` 1531 # outside their valid range. Python uses 64bit floating point numbers and thus trying to do something like 1532 # `torch.ffinfo(torch.float64)max * 2` will always result in `inf`. On the flipside, Pythons `int` is 1533 # unbounded. 
1534 return int(info.min), int(info.max) 1535 1536 lowest_inclusive, highest_inclusive = get_dtype_limits(dtype) 1537 1538 with self.assertRaisesRegex(ValueError, ""): 1539 low, high = (-2, -1) if lowest_inclusive == 0 else (lowest_inclusive * 4, lowest_inclusive * 2) 1540 make_tensor(low=low, high=high) 1541 1542 with self.assertRaisesRegex(ValueError, ""): 1543 make_tensor(low=highest_inclusive * 2, high=highest_inclusive * 4) 1544 1545 @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) 1546 def test_low_high_boolean_integral1(self, dtype, device): 1547 shape = (10_000,) 1548 eps = 1e-4 1549 1550 actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=-(1 - eps), high=1 - eps) 1551 expected = torch.zeros(shape, dtype=dtype, device=device) 1552 1553 torch.testing.assert_close(actual, expected) 1554 1555 @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) 1556 def test_low_high_boolean_integral2(self, dtype, device): 1557 shape = (10_000,) 1558 if dtype is torch.bool: 1559 low = 1 1560 elif dtype is torch.int64: 1561 # Due to its internals, `make_tensor` is not able to sample `torch.iinfo(torch.int64).max` 1562 low = torch.iinfo(dtype).max - 1 1563 else: 1564 low = torch.iinfo(dtype).max 1565 high = low + 1 1566 1567 actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high) 1568 expected = torch.full(shape, low, dtype=dtype, device=device) 1569 1570 torch.testing.assert_close(actual, expected) 1571 1572 1573instantiate_device_type_tests(TestMakeTensor, globals()) 1574 1575 1576def _get_test_names_for_test_class(test_cls): 1577 """ Convenience function to get all test names for a given test class. 
""" 1578 test_names = [f'{test_cls.__name__}.{key}' for key in test_cls.__dict__ 1579 if key.startswith('test_')] 1580 return sorted(test_names) 1581 1582 1583def _get_test_funcs_for_test_class(test_cls): 1584 """ Convenience function to get all (test function, parametrized_name) pairs for a given test class. """ 1585 test_funcs = [(getattr(test_cls, key), key) for key in test_cls.__dict__ if key.startswith('test_')] 1586 return test_funcs 1587 1588 1589class TestTestParametrization(TestCase): 1590 def test_default_names(self): 1591 1592 class TestParametrized(TestCase): 1593 @parametrize("x", range(5)) 1594 def test_default_names(self, x): 1595 pass 1596 1597 @parametrize("x,y", [(1, 2), (2, 3), (3, 4)]) 1598 def test_two_things_default_names(self, x, y): 1599 pass 1600 1601 instantiate_parametrized_tests(TestParametrized) 1602 1603 expected_test_names = [ 1604 'TestParametrized.test_default_names_x_0', 1605 'TestParametrized.test_default_names_x_1', 1606 'TestParametrized.test_default_names_x_2', 1607 'TestParametrized.test_default_names_x_3', 1608 'TestParametrized.test_default_names_x_4', 1609 'TestParametrized.test_two_things_default_names_x_1_y_2', 1610 'TestParametrized.test_two_things_default_names_x_2_y_3', 1611 'TestParametrized.test_two_things_default_names_x_3_y_4', 1612 ] 1613 test_names = _get_test_names_for_test_class(TestParametrized) 1614 self.assertEqual(expected_test_names, test_names) 1615 1616 def test_name_fn(self): 1617 1618 class TestParametrized(TestCase): 1619 @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias') 1620 def test_custom_names(self, bias): 1621 pass 1622 1623 @parametrize("x", [1, 2], name_fn=str) 1624 @parametrize("y", [3, 4], name_fn=str) 1625 @parametrize("z", [5, 6], name_fn=str) 1626 def test_three_things_composition_custom_names(self, x, y, z): 1627 pass 1628 1629 @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}') 1630 def 
test_two_things_custom_names_alternate(self, x, y): 1631 pass 1632 1633 instantiate_parametrized_tests(TestParametrized) 1634 1635 expected_test_names = [ 1636 'TestParametrized.test_custom_names_bias', 1637 'TestParametrized.test_custom_names_no_bias', 1638 'TestParametrized.test_three_things_composition_custom_names_1_3_5', 1639 'TestParametrized.test_three_things_composition_custom_names_1_3_6', 1640 'TestParametrized.test_three_things_composition_custom_names_1_4_5', 1641 'TestParametrized.test_three_things_composition_custom_names_1_4_6', 1642 'TestParametrized.test_three_things_composition_custom_names_2_3_5', 1643 'TestParametrized.test_three_things_composition_custom_names_2_3_6', 1644 'TestParametrized.test_three_things_composition_custom_names_2_4_5', 1645 'TestParametrized.test_three_things_composition_custom_names_2_4_6', 1646 'TestParametrized.test_two_things_custom_names_alternate_1__2', 1647 'TestParametrized.test_two_things_custom_names_alternate_1__3', 1648 'TestParametrized.test_two_things_custom_names_alternate_1__4', 1649 ] 1650 test_names = _get_test_names_for_test_class(TestParametrized) 1651 self.assertEqual(expected_test_names, test_names) 1652 1653 def test_subtest_names(self): 1654 1655 class TestParametrized(TestCase): 1656 @parametrize("bias", [subtest(True, name='bias'), 1657 subtest(False, name='no_bias')]) 1658 def test_custom_names(self, bias): 1659 pass 1660 1661 @parametrize("x,y", [subtest((1, 2), name='double'), 1662 subtest((1, 3), name='triple'), 1663 subtest((1, 4), name='quadruple')]) 1664 def test_two_things_custom_names(self, x, y): 1665 pass 1666 1667 instantiate_parametrized_tests(TestParametrized) 1668 1669 expected_test_names = [ 1670 'TestParametrized.test_custom_names_bias', 1671 'TestParametrized.test_custom_names_no_bias', 1672 'TestParametrized.test_two_things_custom_names_double', 1673 'TestParametrized.test_two_things_custom_names_quadruple', 1674 'TestParametrized.test_two_things_custom_names_triple', 1675 ] 
1676 test_names = _get_test_names_for_test_class(TestParametrized) 1677 self.assertEqual(expected_test_names, test_names) 1678 1679 def test_apply_param_specific_decorators(self): 1680 # Test that decorators can be applied on a per-param basis. 1681 1682 def test_dec(func): 1683 func._decorator_applied = True 1684 return func 1685 1686 class TestParametrized(TestCase): 1687 @parametrize("x", [subtest(1, name='one'), 1688 subtest(2, name='two', decorators=[test_dec]), 1689 subtest(3, name='three')]) 1690 def test_param(self, x): 1691 pass 1692 1693 instantiate_parametrized_tests(TestParametrized) 1694 1695 for test_func, name in _get_test_funcs_for_test_class(TestParametrized): 1696 self.assertEqual(hasattr(test_func, '_decorator_applied'), name == 'test_param_two') 1697 1698 def test_compose_param_specific_decorators(self): 1699 # Test that multiple per-param decorators compose correctly. 1700 1701 def test_dec(func): 1702 func._decorator_applied = True 1703 return func 1704 1705 class TestParametrized(TestCase): 1706 @parametrize("x", [subtest(1), 1707 subtest(2, decorators=[test_dec]), 1708 subtest(3)]) 1709 @parametrize("y", [subtest(False, decorators=[test_dec]), 1710 subtest(True)]) 1711 def test_param(self, x, y): 1712 pass 1713 1714 instantiate_parametrized_tests(TestParametrized) 1715 1716 for test_func, name in _get_test_funcs_for_test_class(TestParametrized): 1717 # Decorator should be applied whenever either x == 2 or y == False. 1718 should_apply = ('x_2' in name) or ('y_False' in name) 1719 self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply) 1720 1721 def test_modules_decorator_misuse_error(self): 1722 # Test that @modules errors out when used with instantiate_parametrized_tests(). 
1723 1724 class TestParametrized(TestCase): 1725 @modules(module_db) 1726 def test_modules(self, module_info): 1727 pass 1728 1729 with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'): 1730 instantiate_parametrized_tests(TestParametrized) 1731 1732 def test_ops_decorator_misuse_error(self): 1733 # Test that @ops errors out when used with instantiate_parametrized_tests(). 1734 1735 class TestParametrized(TestCase): 1736 @ops(op_db) 1737 def test_ops(self, module_info): 1738 pass 1739 1740 with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'): 1741 instantiate_parametrized_tests(TestParametrized) 1742 1743 def test_multiple_handling_of_same_param_error(self): 1744 # Test that multiple decorators handling the same param errors out. 1745 1746 class TestParametrized(TestCase): 1747 @parametrize("x", range(3)) 1748 @parametrize("x", range(5)) 1749 def test_param(self, x): 1750 pass 1751 1752 with self.assertRaisesRegex(RuntimeError, 'multiple parametrization decorators'): 1753 instantiate_parametrized_tests(TestParametrized) 1754 1755 @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3]) 1756 def test_subtest_expected_failure(self, x): 1757 if x == 2: 1758 raise RuntimeError('Boom') 1759 1760 @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3]) 1761 @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])]) 1762 def test_two_things_subtest_expected_failure(self, x, y): 1763 if x == 1 or y == 6: 1764 raise RuntimeError('Boom') 1765 1766 1767class TestTestParametrizationDeviceType(TestCase): 1768 def test_unparametrized_names(self, device): 1769 # This test exists to protect against regressions in device / dtype test naming 1770 # due to parametrization logic. 
1771 1772 device = self.device_type 1773 1774 class TestParametrized(TestCase): 1775 def test_device_specific(self, device): 1776 pass 1777 1778 @dtypes(torch.float32, torch.float64) 1779 def test_device_dtype_specific(self, device, dtype): 1780 pass 1781 1782 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1783 1784 device_cls = locals()[f'TestParametrized{device.upper()}'] 1785 expected_test_names = [name.format(device_cls.__name__, device) for name in ( 1786 '{}.test_device_dtype_specific_{}_float32', 1787 '{}.test_device_dtype_specific_{}_float64', 1788 '{}.test_device_specific_{}') 1789 ] 1790 test_names = _get_test_names_for_test_class(device_cls) 1791 self.assertEqual(expected_test_names, test_names) 1792 1793 def test_empty_param_names(self, device): 1794 # If no param names are passed, ensure things still work without parametrization. 1795 device = self.device_type 1796 1797 class TestParametrized(TestCase): 1798 @parametrize("", []) 1799 def test_foo(self, device): 1800 pass 1801 1802 @parametrize("", range(5)) 1803 def test_bar(self, device): 1804 pass 1805 1806 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1807 1808 device_cls = locals()[f'TestParametrized{device.upper()}'] 1809 expected_test_names = [name.format(device_cls.__name__, device) for name in ( 1810 '{}.test_bar_{}', 1811 '{}.test_foo_{}') 1812 ] 1813 test_names = _get_test_names_for_test_class(device_cls) 1814 self.assertEqual(expected_test_names, test_names) 1815 1816 def test_empty_param_list(self, device): 1817 # If no param values are passed, ensure a helpful error message is thrown. 1818 # In the wild, this could indicate reuse of an exhausted generator. 1819 device = self.device_type 1820 1821 generator = (a for a in range(5)) 1822 1823 class TestParametrized(TestCase): 1824 @parametrize("x", generator) 1825 def test_foo(self, device, x): 1826 pass 1827 1828 # Reuse generator from first test function. 
1829 @parametrize("y", generator) 1830 def test_bar(self, device, y): 1831 pass 1832 1833 with self.assertRaisesRegex(ValueError, 'An empty arg_values was passed'): 1834 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1835 1836 def test_default_names(self, device): 1837 device = self.device_type 1838 1839 class TestParametrized(TestCase): 1840 @parametrize("x", range(5)) 1841 def test_default_names(self, device, x): 1842 pass 1843 1844 @parametrize("x,y", [(1, 2), (2, 3), (3, 4)]) 1845 def test_two_things_default_names(self, device, x, y): 1846 pass 1847 1848 1849 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1850 1851 device_cls = locals()[f'TestParametrized{device.upper()}'] 1852 expected_test_names = [name.format(device_cls.__name__, device) for name in ( 1853 '{}.test_default_names_x_0_{}', 1854 '{}.test_default_names_x_1_{}', 1855 '{}.test_default_names_x_2_{}', 1856 '{}.test_default_names_x_3_{}', 1857 '{}.test_default_names_x_4_{}', 1858 '{}.test_two_things_default_names_x_1_y_2_{}', 1859 '{}.test_two_things_default_names_x_2_y_3_{}', 1860 '{}.test_two_things_default_names_x_3_y_4_{}') 1861 ] 1862 test_names = _get_test_names_for_test_class(device_cls) 1863 self.assertEqual(expected_test_names, test_names) 1864 1865 def test_default_name_non_primitive(self, device): 1866 device = self.device_type 1867 1868 class TestParametrized(TestCase): 1869 @parametrize("x", [1, .5, "foo", object()]) 1870 def test_default_names(self, device, x): 1871 pass 1872 1873 @parametrize("x,y", [(1, object()), (object(), .5), (object(), object())]) 1874 def test_two_things_default_names(self, device, x, y): 1875 pass 1876 1877 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1878 1879 device_cls = locals()[f'TestParametrized{device.upper()}'] 1880 expected_test_names = sorted(name.format(device_cls.__name__, device) for name in ( 1881 '{}.test_default_names_x_1_{}', 1882 
'{}.test_default_names_x_0_5_{}', 1883 '{}.test_default_names_x_foo_{}', 1884 '{}.test_default_names_x3_{}', 1885 '{}.test_two_things_default_names_x_1_y0_{}', 1886 '{}.test_two_things_default_names_x1_y_0_5_{}', 1887 '{}.test_two_things_default_names_x2_y2_{}') 1888 ) 1889 test_names = _get_test_names_for_test_class(device_cls) 1890 self.assertEqual(expected_test_names, test_names) 1891 1892 def test_name_fn(self, device): 1893 device = self.device_type 1894 1895 class TestParametrized(TestCase): 1896 @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias') 1897 def test_custom_names(self, device, bias): 1898 pass 1899 1900 @parametrize("x", [1, 2], name_fn=str) 1901 @parametrize("y", [3, 4], name_fn=str) 1902 @parametrize("z", [5, 6], name_fn=str) 1903 def test_three_things_composition_custom_names(self, device, x, y, z): 1904 pass 1905 1906 @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}') 1907 def test_two_things_custom_names_alternate(self, device, x, y): 1908 pass 1909 1910 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1911 1912 device_cls = locals()[f'TestParametrized{device.upper()}'] 1913 expected_test_names = [name.format(device_cls.__name__, device) for name in ( 1914 '{}.test_custom_names_bias_{}', 1915 '{}.test_custom_names_no_bias_{}', 1916 '{}.test_three_things_composition_custom_names_1_3_5_{}', 1917 '{}.test_three_things_composition_custom_names_1_3_6_{}', 1918 '{}.test_three_things_composition_custom_names_1_4_5_{}', 1919 '{}.test_three_things_composition_custom_names_1_4_6_{}', 1920 '{}.test_three_things_composition_custom_names_2_3_5_{}', 1921 '{}.test_three_things_composition_custom_names_2_3_6_{}', 1922 '{}.test_three_things_composition_custom_names_2_4_5_{}', 1923 '{}.test_three_things_composition_custom_names_2_4_6_{}', 1924 '{}.test_two_things_custom_names_alternate_1__2_{}', 1925 '{}.test_two_things_custom_names_alternate_1__3_{}', 1926 
'{}.test_two_things_custom_names_alternate_1__4_{}') 1927 ] 1928 test_names = _get_test_names_for_test_class(device_cls) 1929 self.assertEqual(expected_test_names, test_names) 1930 1931 def test_subtest_names(self, device): 1932 device = self.device_type 1933 1934 class TestParametrized(TestCase): 1935 @parametrize("bias", [subtest(True, name='bias'), 1936 subtest(False, name='no_bias')]) 1937 def test_custom_names(self, device, bias): 1938 pass 1939 1940 @parametrize("x,y", [subtest((1, 2), name='double'), 1941 subtest((1, 3), name='triple'), 1942 subtest((1, 4), name='quadruple')]) 1943 def test_two_things_custom_names(self, device, x, y): 1944 pass 1945 1946 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1947 1948 device_cls = locals()[f'TestParametrized{device.upper()}'] 1949 expected_test_names = [name.format(device_cls.__name__, device) for name in ( 1950 '{}.test_custom_names_bias_{}', 1951 '{}.test_custom_names_no_bias_{}', 1952 '{}.test_two_things_custom_names_double_{}', 1953 '{}.test_two_things_custom_names_quadruple_{}', 1954 '{}.test_two_things_custom_names_triple_{}') 1955 ] 1956 test_names = _get_test_names_for_test_class(device_cls) 1957 self.assertEqual(expected_test_names, test_names) 1958 1959 def test_ops_composition_names(self, device): 1960 device = self.device_type 1961 1962 class TestParametrized(TestCase): 1963 @ops(op_db) 1964 @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled') 1965 def test_op_parametrized(self, device, dtype, op, flag): 1966 pass 1967 1968 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1969 1970 device_cls = locals()[f'TestParametrized{device.upper()}'] 1971 expected_test_names = [] 1972 for op in op_db: 1973 for dtype in op.supported_dtypes(torch.device(device).type): 1974 for flag_part in ('flag_disabled', 'flag_enabled'): 1975 expected_name = 
f'{device_cls.__name__}.test_op_parametrized_{op.formatted_name}_{flag_part}_{device}_{dtype_name(dtype)}' # noqa: B950 1976 expected_test_names.append(expected_name) 1977 1978 test_names = _get_test_names_for_test_class(device_cls) 1979 self.assertEqual(sorted(expected_test_names), sorted(test_names)) 1980 1981 def test_modules_composition_names(self, device): 1982 device = self.device_type 1983 1984 class TestParametrized(TestCase): 1985 @modules(module_db) 1986 @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled') 1987 def test_module_parametrized(self, device, dtype, module_info, training, flag): 1988 pass 1989 1990 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1991 1992 device_cls = locals()[f'TestParametrized{device.upper()}'] 1993 expected_test_names = [] 1994 for module_info in module_db: 1995 for dtype in module_info.dtypes: 1996 for flag_part in ('flag_disabled', 'flag_enabled'): 1997 expected_train_modes = ( 1998 ['train_mode', 'eval_mode'] if module_info.train_and_eval_differ else ['']) 1999 for training_part in expected_train_modes: 2000 expected_name = '{}.test_module_parametrized_{}{}_{}_{}_{}'.format( 2001 device_cls.__name__, module_info.formatted_name, 2002 '_' + training_part if len(training_part) > 0 else '', 2003 flag_part, device, dtype_name(dtype)) 2004 expected_test_names.append(expected_name) 2005 2006 test_names = _get_test_names_for_test_class(device_cls) 2007 self.assertEqual(sorted(expected_test_names), sorted(test_names)) 2008 2009 def test_ops_decorator_applies_op_and_param_specific_decorators(self, device): 2010 # Test that decorators can be applied on a per-op / per-param basis. 2011 2012 # Create a test op, OpInfo entry, and decorator to apply. 
2013 def test_op(x): 2014 return -x 2015 2016 def test_dec(func): 2017 func._decorator_applied = True 2018 return func 2019 2020 test_op_info = OpInfo( 2021 'test_op', 2022 op=test_op, 2023 dtypes=floating_types(), 2024 sample_inputs_func=lambda _: [], 2025 decorators=[ 2026 DecorateInfo(test_dec, 'TestParametrized', 'test_op_param', 2027 device_type='cpu', dtypes=[torch.float64], 2028 active_if=lambda p: p['x'] == 2) 2029 ]) 2030 2031 class TestParametrized(TestCase): 2032 @ops(op_db + [test_op_info]) 2033 @parametrize("x", [2, 3]) 2034 def test_op_param(self, device, dtype, op, x): 2035 pass 2036 2037 @ops(op_db + [test_op_info]) 2038 @parametrize("y", [ 2039 subtest(4), 2040 subtest(5, decorators=[test_dec])]) 2041 def test_other(self, device, dtype, op, y): 2042 pass 2043 2044 @decorateIf(test_dec, lambda p: p['dtype'] == torch.int16) 2045 @ops(op_db) 2046 def test_three(self, device, dtype, op): 2047 pass 2048 2049 device = self.device_type 2050 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2051 device_cls = locals()[f'TestParametrized{device.upper()}'] 2052 2053 for test_func, name in _get_test_funcs_for_test_class(device_cls): 2054 should_apply = (name == 'test_op_param_test_op_x_2_cpu_float64' or 2055 ('test_other' in name and 'y_5' in name) or 2056 ('test_three' in name and name.endswith('_int16'))) 2057 self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply) 2058 2059 def test_modules_decorator_applies_module_and_param_specific_decorators(self, device): 2060 # Test that decorators can be applied on a per-module / per-param basis. 2061 2062 # Create a test module, ModuleInfo entry, and decorator to apply. 
2063 class TestModule(torch.nn.Module): 2064 def __init__(self) -> None: 2065 super().__init__() 2066 self.x = torch.nn.Parameter(torch.randn(3)) 2067 2068 def forward(self, y): 2069 return self.x + y 2070 2071 def test_dec(func): 2072 func._decorator_applied = True 2073 return func 2074 2075 test_module_info = ModuleInfo( 2076 TestModule, 2077 module_inputs_func=lambda _: [], 2078 decorators=[ 2079 DecorateInfo(test_dec, 'TestParametrized', 'test_module_param', 2080 device_type='cpu', dtypes=[torch.float64], 2081 active_if=lambda p: p['x'] == 2) 2082 ]) 2083 2084 class TestParametrized(TestCase): 2085 @modules(module_db + [test_module_info]) 2086 @parametrize("x", [2, 3]) 2087 def test_module_param(self, device, dtype, module_info, training, x): 2088 pass 2089 2090 @modules(module_db + [test_module_info]) 2091 @parametrize("y", [ 2092 subtest(4), 2093 subtest(5, decorators=[test_dec])]) 2094 def test_other(self, device, dtype, module_info, training, y): 2095 pass 2096 2097 @decorateIf(test_dec, lambda p: p['dtype'] == torch.float64) 2098 @modules(module_db) 2099 def test_three(self, device, dtype, module_info): 2100 pass 2101 2102 device = self.device_type 2103 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2104 device_cls = locals()[f'TestParametrized{device.upper()}'] 2105 2106 for test_func, name in _get_test_funcs_for_test_class(device_cls): 2107 should_apply = (name == 'test_module_param_TestModule_x_2_cpu_float64' or 2108 ('test_other' in name and 'y_5' in name) or 2109 ('test_three' in name and name.endswith('float64'))) 2110 self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply) 2111 2112 def test_param_specific_decoration(self, device): 2113 2114 def test_dec(func): 2115 func._decorator_applied = True 2116 return func 2117 2118 class TestParametrized(TestCase): 2119 @decorateIf(test_dec, lambda params: params["x"] == 1 and params["y"]) 2120 @parametrize("x", range(5)) 2121 @parametrize("y", [False, True]) 2122 
def test_param(self, x, y): 2123 pass 2124 2125 device = self.device_type 2126 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2127 device_cls = locals()[f'TestParametrized{device.upper()}'] 2128 2129 for test_func, name in _get_test_funcs_for_test_class(device_cls): 2130 should_apply = ('test_param_x_1_y_True' in name) 2131 self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply) 2132 2133 def test_dtypes_composition_valid(self, device): 2134 # Test checks that @parametrize and @dtypes compose as expected when @parametrize 2135 # doesn't set dtype. 2136 2137 device = self.device_type 2138 2139 class TestParametrized(TestCase): 2140 @dtypes(torch.float32, torch.float64) 2141 @parametrize("x", range(3)) 2142 def test_parametrized(self, x, dtype): 2143 pass 2144 2145 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2146 2147 device_cls = locals()[f'TestParametrized{device.upper()}'] 2148 expected_test_names = [name.format(device_cls.__name__, device) for name in ( 2149 '{}.test_parametrized_x_0_{}_float32', 2150 '{}.test_parametrized_x_0_{}_float64', 2151 '{}.test_parametrized_x_1_{}_float32', 2152 '{}.test_parametrized_x_1_{}_float64', 2153 '{}.test_parametrized_x_2_{}_float32', 2154 '{}.test_parametrized_x_2_{}_float64') 2155 ] 2156 test_names = _get_test_names_for_test_class(device_cls) 2157 self.assertEqual(sorted(expected_test_names), sorted(test_names)) 2158 2159 def test_dtypes_composition_invalid(self, device): 2160 # Test checks that @dtypes cannot be composed with parametrization decorators when they 2161 # also try to set dtype. 
2162 2163 device = self.device_type 2164 2165 class TestParametrized(TestCase): 2166 @dtypes(torch.float32, torch.float64) 2167 @parametrize("dtype", [torch.int32, torch.int64]) 2168 def test_parametrized(self, dtype): 2169 pass 2170 2171 with self.assertRaisesRegex(RuntimeError, "handled multiple times"): 2172 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2173 2174 # Verify proper error behavior with @ops + @dtypes, as both try to set dtype. 2175 2176 class TestParametrized(TestCase): 2177 @dtypes(torch.float32, torch.float64) 2178 @ops(op_db) 2179 def test_parametrized(self, op, dtype): 2180 pass 2181 2182 with self.assertRaisesRegex(RuntimeError, "handled multiple times"): 2183 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2184 2185 def test_multiple_handling_of_same_param_error(self, device): 2186 # Test that multiple decorators handling the same param errors out. 2187 # Both @modules and @ops handle the dtype param. 2188 2189 class TestParametrized(TestCase): 2190 @ops(op_db) 2191 @modules(module_db) 2192 def test_param(self, device, dtype, op, module_info, training): 2193 pass 2194 2195 with self.assertRaisesRegex(RuntimeError, "handled multiple times"): 2196 instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2197 2198 @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3]) 2199 def test_subtest_expected_failure(self, device, x): 2200 if x == 2: 2201 raise RuntimeError('Boom') 2202 2203 @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3]) 2204 @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])]) 2205 def test_two_things_subtest_expected_failure(self, device, x, y): 2206 if x == 1 or y == 6: 2207 raise RuntimeError('Boom') 2208 2209 2210instantiate_parametrized_tests(TestTestParametrization) 2211instantiate_device_type_tests(TestTestParametrizationDeviceType, globals()) 2212 2213 2214class TestImports(TestCase): 
2215 @classmethod 2216 def _check_python_output(cls, program) -> str: 2217 return subprocess.check_output( 2218 [sys.executable, "-W", "always", "-c", program], 2219 stderr=subprocess.STDOUT, 2220 # On Windows, opening the subprocess with the default CWD makes `import torch` 2221 # fail, so just set CWD to this script's directory 2222 cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8") 2223 2224 def test_circular_dependencies(self) -> None: 2225 """ Checks that all modules inside torch can be imported 2226 Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """ 2227 ignored_modules = ["torch.utils.tensorboard", # deps on tensorboard 2228 "torch.distributed.elastic.rendezvous", # depps on etcd 2229 "torch.backends._coreml", # depends on pycoreml 2230 "torch.contrib.", # something weird 2231 "torch.testing._internal.distributed.", # just fails 2232 "torch.ao.pruning._experimental.", # depends on pytorch_lightning, not user-facing 2233 "torch.onnx._internal", # depends on onnx-script 2234 "torch._inductor.runtime.triton_helpers", # depends on triton 2235 "torch._inductor.codegen.cuda", # depends on cutlass 2236 ] 2237 # See https://github.com/pytorch/pytorch/issues/77801 2238 if not sys.version_info >= (3, 9): 2239 ignored_modules.append("torch.utils.benchmark") 2240 if IS_WINDOWS or IS_MACOS or IS_JETSON: 2241 # Distributed should be importable on Windows(except nn.api.), but not on Mac 2242 if IS_MACOS or IS_JETSON: 2243 ignored_modules.append("torch.distributed.") 2244 else: 2245 ignored_modules.append("torch.distributed.nn.api.") 2246 ignored_modules.append("torch.distributed.optim.") 2247 ignored_modules.append("torch.distributed.rpc.") 2248 ignored_modules.append("torch.testing._internal.dist_utils") 2249 # And these both end up with transitive dependencies on distributed 2250 ignored_modules.append("torch.nn.parallel._replicated_tensor_ddp_interop") 2251 ignored_modules.append("torch.testing._internal.common_fsdp") 2252 
ignored_modules.append("torch.testing._internal.common_distributed") 2253 2254 torch_dir = os.path.dirname(torch.__file__) 2255 for base, folders, files in os.walk(torch_dir): 2256 prefix = os.path.relpath(base, os.path.dirname(torch_dir)).replace(os.path.sep, ".") 2257 for f in files: 2258 if not f.endswith(".py"): 2259 continue 2260 mod_name = f"{prefix}.{f[:-3]}" if f != "__init__.py" else prefix 2261 # Do not attempt to import executable modules 2262 if f == "__main__.py": 2263 continue 2264 if any(mod_name.startswith(x) for x in ignored_modules): 2265 continue 2266 try: 2267 mod = importlib.import_module(mod_name) 2268 except Exception as e: 2269 raise RuntimeError(f"Failed to import {mod_name}: {e}") from e 2270 self.assertTrue(inspect.ismodule(mod)) 2271 2272 @unittest.skipIf(IS_WINDOWS, "TODO enable on Windows") 2273 def test_lazy_imports_are_lazy(self) -> None: 2274 out = self._check_python_output("import sys;import torch;print(all(x not in sys.modules for x in torch._lazy_modules))") 2275 self.assertEqual(out.strip(), "True") 2276 2277 @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning") 2278 def test_no_warning_on_import(self) -> None: 2279 out = self._check_python_output("import torch") 2280 self.assertEqual(out, "") 2281 2282 def test_not_import_sympy(self) -> None: 2283 out = self._check_python_output("import torch;import sys;print('sympy' not in sys.modules)") 2284 self.assertEqual(out.strip(), "True", 2285 "PyTorch should not depend on SymPy at import time as importing SymPy is *very* slow.\n" 2286 "See the beginning of the following blog post for how to profile and find which file is importing sympy:\n" 2287 "https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589\n\n" 2288 "If you hit this error, you may want to:\n" 2289 " - Refactor your code to avoid depending on sympy files you may not need to depend\n" 2290 " - Use TYPE_CHECKING if you are using sympy + strings if you are using sympy on 
type annotations\n" 2291 " - Import things that depend on SymPy locally") 2292 2293 @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning") 2294 @parametrize('path', ['torch', 'functorch']) 2295 def test_no_mutate_global_logging_on_import(self, path) -> None: 2296 # Calling logging.basicConfig, among other things, modifies the global 2297 # logging state. It is not OK to modify the global logging state on 2298 # `import torch` (or other submodules we own) because users do not expect it. 2299 expected = 'abcdefghijklmnopqrstuvwxyz' 2300 commands = [ 2301 'import logging', 2302 f'import {path}', 2303 '_logger = logging.getLogger("torch_test_testing")', 2304 'logging.root.addHandler(logging.StreamHandler())', 2305 'logging.root.setLevel(logging.INFO)', 2306 f'_logger.info("{expected}")' 2307 ] 2308 out = self._check_python_output("; ".join(commands)) 2309 self.assertEqual(out.strip(), expected) 2310 2311class TestOpInfos(TestCase): 2312 def test_sample_input(self) -> None: 2313 a, b, c, d, e = (object() for _ in range(5)) 2314 2315 # Construction with natural syntax 2316 s = SampleInput(a, b, c, d=d, e=e) 2317 assert s.input is a 2318 assert s.args == (b, c) 2319 assert s.kwargs == dict(d=d, e=e) 2320 2321 # Construction with explicit args and kwargs 2322 s = SampleInput(a, args=(b,), kwargs=dict(c=c, d=d, e=e)) 2323 assert s.input is a 2324 assert s.args == (b,) 2325 assert s.kwargs == dict(c=c, d=d, e=e) 2326 2327 # Construction with a mixed form will error 2328 with self.assertRaises(AssertionError): 2329 s = SampleInput(a, b, c, args=(d, e)) 2330 2331 with self.assertRaises(AssertionError): 2332 s = SampleInput(a, b, c, kwargs=dict(d=d, e=e)) 2333 2334 with self.assertRaises(AssertionError): 2335 s = SampleInput(a, args=(b, c), d=d, e=e) 2336 2337 with self.assertRaises(AssertionError): 2338 s = SampleInput(a, b, c=c, kwargs=dict(d=d, e=e)) 2339 2340 # Mixing metadata into "natural" construction will error 2341 with 
self.assertRaises(AssertionError): 2342 s = SampleInput(a, b, name="foo") 2343 2344 with self.assertRaises(AssertionError): 2345 s = SampleInput(a, b, output_process_fn_grad=lambda x: x) 2346 2347 with self.assertRaises(AssertionError): 2348 s = SampleInput(a, b, broadcasts_input=True) 2349 2350 # But when only input is given, metadata is allowed for backward 2351 # compatibility 2352 s = SampleInput(a, broadcasts_input=True) 2353 assert s.input is a 2354 assert s.broadcasts_input 2355 2356 def test_sample_input_metadata(self) -> None: 2357 a, b = (object() for _ in range(2)) 2358 s1 = SampleInput(a, b=b) 2359 self.assertIs(s1.output_process_fn_grad(None), None) 2360 self.assertFalse(s1.broadcasts_input) 2361 self.assertEqual(s1.name, "") 2362 2363 s2 = s1.with_metadata( 2364 output_process_fn_grad=lambda x: a, 2365 broadcasts_input=True, 2366 name="foo", 2367 ) 2368 self.assertIs(s1, s2) 2369 self.assertIs(s2.output_process_fn_grad(None), a) 2370 self.assertTrue(s2.broadcasts_input) 2371 self.assertEqual(s2.name, "foo") 2372 2373 2374# Tests that validate the various sample generating functions on each OpInfo. 
class TestOpInfoSampleFunctions(TestCase):
    # Each test asserts that an OpInfo input-producing hook returns an Iterator
    # (e.g. a generator) rather than an eagerly built container such as a list.

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_opinfo_sample_generators(self, device, dtype, op):
        # Every OpInfo in op_db is checked through op.sample_inputs(device, dtype).
        self.assertIsInstance(op.sample_inputs(device, dtype), Iterator)

    @ops(list(filter(lambda info: info.reference_inputs_func is not None, op_db)),
         dtypes=OpDTypes.any_one)
    def test_opinfo_reference_generators(self, device, dtype, op):
        # Restricted to OpInfos that define reference_inputs_func.
        self.assertIsInstance(op.reference_inputs(device, dtype), Iterator)

    @ops(list(filter(lambda info: info.error_inputs_func is not None, op_db)),
         dtypes=OpDTypes.none)
    def test_opinfo_error_generators(self, device, op):
        # Restricted to OpInfos that define error_inputs_func. Instantiated with
        # OpDTypes.none, so the test takes no dtype argument.
        self.assertIsInstance(op.error_inputs(device), Iterator)


instantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
instantiate_parametrized_tests(TestImports)


if __name__ == '__main__':
    run_tests()