xref: /aosp_15_r20/external/pytorch/test/test_testing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: tests"]
2
3import collections
4import doctest
5import functools
6import importlib
7import inspect
8import itertools
9import math
10import os
11import re
12import subprocess
13import sys
14import unittest.mock
15from typing import Any, Callable, Iterator, List, Tuple
16
17import torch
18
19from torch.testing import make_tensor
20from torch.testing._internal.common_utils import \
21    (IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest,
22     parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf)
23from torch.testing._internal.common_device_type import \
24    (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
25     get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes,
26     deviceCountAtLeast, ops, expectedFailureMeta, OpDTypes)
27from torch.testing._internal.common_methods_invocations import op_db
28from torch.testing._internal import opinfo
29from torch.testing._internal.common_dtype import all_types_and_complex_and, floating_types
30from torch.testing._internal.common_modules import modules, module_db, ModuleInfo
31from torch.testing._internal.opinfo.core import SampleInput, DecorateInfo, OpInfo
32import operator
33
34# For testing TestCase methods and torch.testing functions
class TestTesting(TestCase):
    # Ensure that assertEqual handles numpy arrays properly
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
    def test_assertEqual_numpy(self, device, dtype):
        S = 10
        test_sizes = [
            (),
            (0,),
            (S,),
            (S, S),
            (0, S),
            (S, 0)]
        for test_size in test_sizes:
            a = make_tensor(test_size, dtype=dtype, device=device, low=-5, high=5)
            a_n = a.cpu().numpy()
            msg = f'size: {test_size}'
            # Exercise every pairing: ndarray vs. tensor, tensor vs. ndarray, ndarray vs. ndarray.
            self.assertEqual(a_n, a, rtol=0, atol=0, msg=msg)
            self.assertEqual(a, a_n, rtol=0, atol=0, msg=msg)
            self.assertEqual(a_n, a_n, rtol=0, atol=0, msg=msg)

    # With longMessage=True, a user-supplied msg must be appended (after a newline) to the
    # default assertEqual failure message instead of replacing it.
    def test_assertEqual_longMessage(self):
        actual = "actual"
        expected = "expected"

        long_message = self.longMessage
        try:
            # Capture the default error message by forcing TestCase.longMessage = False
            self.longMessage = False
            try:
                self.assertEqual(actual, expected)
            except AssertionError as error:
                default_msg = str(error)
            else:
                raise AssertionError("AssertionError not raised")

            self.longMessage = True
            extra_msg = "sentinel"
            with self.assertRaisesRegex(AssertionError, re.escape(f"{default_msg}\n{extra_msg}")):
                self.assertEqual(actual, expected, msg=extra_msg)
        finally:
            # Restore the flag so later tests in this instance are unaffected.
            self.longMessage = long_message

    def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
        """Check ``torch.isclose`` against a table of expected results.

        Args:
            tests: iterable of ``(a, b, expected)`` triples. ``a`` and ``b`` are each wrapped in a
                one-element tensor of ``dtype`` on ``device``; ``expected`` is the boolean that
                ``torch.isclose(a, b, ...).item()`` must equal.
            device: device the tensors are created on.
            dtype: dtype of the tensors.
            equal_nan: forwarded to ``torch.isclose``.
            atol: absolute tolerance forwarded to ``torch.isclose``.
            rtol: relative tolerance forwarded to ``torch.isclose``.
        """
        for test in tests:
            a = torch.tensor((test[0],), device=device, dtype=dtype)
            b = torch.tensor((test[1],), device=device, dtype=dtype)

            actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol)
            expected = test[2]
            self.assertEqual(actual.item(), expected)

    def test_isclose_bool(self, device):
        tests = (
            (True, True, True),
            (False, False, True),
            (True, False, False),
            (False, True, False),
        )

        self._isclose_helper(tests, device, torch.bool, False)

    @dtypes(torch.uint8,
            torch.int8, torch.int16, torch.int32, torch.int64)
    def test_isclose_integer(self, device, dtype):
        tests = (
            (0, 0, True),
            (0, 1, False),
            (1, 0, False),
        )

        self._isclose_helper(tests, device, dtype, False)

        # atol and rtol tests
        tests = [
            (0, 1, True),
            (1, 0, False),
            (1, 3, True),
        ]

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # uint8 cannot represent -1, so the negative cases get different expectations there.
        if dtype is torch.uint8:
            tests = [
                (-1, 1, False),
                (1, -1, False)
            ]
        else:
            tests = [
                (-1, 1, True),
                (1, -1, True)
            ]

        self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5)

    @onlyNativeDeviceTypes
    @dtypes(torch.float16, torch.float32, torch.float64)
    def test_isclose_float(self, device, dtype):
        tests = (
            (0, 0, True),
            (0, -1, False),
            (float('inf'), float('inf'), True),
            (-float('inf'), float('inf'), False),
            (float('inf'), float('nan'), False),
            (float('nan'), float('nan'), False),
            (0, float('nan'), False),
            (1, 1, True),
        )

        self._isclose_helper(tests, device, dtype, False)

        # atol and rtol tests
        # float16 needs a coarser perturbation to be representable past the tolerance boundary.
        eps = 1e-2 if dtype is torch.half else 1e-6
        tests = (
            (0, 1, True),
            (0, 1 + eps, False),
            (1, 0, False),
            (1, 3, True),
            (1 - eps, 3, False),
            (-.25, .5, True),
            (-.25 - eps, .5, False),
            (.25, -.5, True),
            (.25 + eps, -.5, False),
        )

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # equal_nan = True tests
        tests = (
            (0, float('nan'), False),
            (float('inf'), float('nan'), False),
            (float('nan'), float('nan'), True),
        )

        self._isclose_helper(tests, device, dtype, True)

    @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
    @dtypes(torch.complex64, torch.complex128)
    def test_isclose_complex(self, device, dtype):
        tests = (
            (complex(1, 1), complex(1, 1 + 1e-8), True),
            (complex(0, 1), complex(1, 1), False),
            (complex(1, 1), complex(1, 0), False),
            (complex(1, 1), complex(1, float('nan')), False),
            (complex(1, float('nan')), complex(1, float('nan')), False),
            (complex(1, 1), complex(1, float('inf')), False),
            (complex(float('inf'), 1), complex(1, float('inf')), False),
            (complex(-float('inf'), 1), complex(1, float('inf')), False),
            (complex(-float('inf'), 1), complex(float('inf'), 1), False),
            (complex(float('inf'), 1), complex(float('inf'), 1), True),
            (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False),
        )

        self._isclose_helper(tests, device, dtype, False)

        # atol and rtol tests
        eps = 1e-6
        tests = (
            # Complex versions of float tests (real part)
            (complex(0, 0), complex(1, 0), True),
            (complex(0, 0), complex(1 + eps, 0), False),
            (complex(1, 0), complex(0, 0), False),
            (complex(1, 0), complex(3, 0), True),
            (complex(1 - eps, 0), complex(3, 0), False),
            (complex(-.25, 0), complex(.5, 0), True),
            (complex(-.25 - eps, 0), complex(.5, 0), False),
            (complex(.25, 0), complex(-.5, 0), True),
            (complex(.25 + eps, 0), complex(-.5, 0), False),
            # Complex versions of float tests (imaginary part)
            (complex(0, 0), complex(0, 1), True),
            (complex(0, 0), complex(0, 1 + eps), False),
            (complex(0, 1), complex(0, 0), False),
            (complex(0, 1), complex(0, 3), True),
            (complex(0, 1 - eps), complex(0, 3), False),
            (complex(0, -.25), complex(0, .5), True),
            (complex(0, -.25 - eps), complex(0, .5), False),
            (complex(0, .25), complex(0, -.5), True),
            (complex(0, .25 + eps), complex(0, -.5), False),
        )

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # atol and rtol tests for isclose
        tests = (
            # Complex-specific tests
            (complex(1, -1), complex(-1, 1), False),
            (complex(1, -1), complex(2, -2), True),
            (complex(-math.sqrt(2), math.sqrt(2)),
             complex(-math.sqrt(.5), math.sqrt(.5)), True),
            (complex(-math.sqrt(2), math.sqrt(2)),
             complex(-math.sqrt(.501), math.sqrt(.499)), False),
            (complex(2, 4), complex(1., 8.8523607), True),
            (complex(2, 4), complex(1., 8.8523607 + eps), False),
            (complex(1, 99), complex(4, 100), True),
        )
        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # equal_nan = True tests
        tests = (
            (complex(1, 1), complex(1, float('nan')), False),
            (complex(1, 1), complex(float('nan'), 1), False),
            (complex(float('nan'), 1), complex(float('nan'), 1), True),
            (complex(float('nan'), 1), complex(1, float('nan')), True),
            (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
        )
        self._isclose_helper(tests, device, dtype, True)

    # Tests that isclose with rtol or atol values less than zero throws a
    #   RuntimeError
    @dtypes(torch.bool, torch.uint8,
            torch.int8, torch.int16, torch.int32, torch.int64,
            torch.float16, torch.float32, torch.float64)
    def test_isclose_atol_rtol_greater_than_zero(self, device, dtype):
        t = torch.tensor((1,), device=device, dtype=dtype)

        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=-1, rtol=1)
        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=1, rtol=-1)
        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=-1, rtol=-1)

    def test_isclose_equality_shortcut(self):
        # For values >= 2**53, integers differing by 1 can no longer be differentiated by torch.float64 or
        # lower-precision floating point dtypes. Thus, even with rtol == 0 and atol == 0, these tensors would be
        # considered close if they were not compared as integers.
        a = torch.tensor(2 ** 53, dtype=torch.int64)
        b = a + 1

        self.assertFalse(torch.isclose(a, b, rtol=0, atol=0))

    # With equal_nan=True and zero tolerances, NaN inputs must still compare close; for complex
    # dtypes the NaN sits in different components of the two values.
    @dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_isclose_nan_equality_shortcut(self, device, dtype):
        if dtype.is_floating_point:
            a = b = torch.nan
        else:
            a = complex(torch.nan, 0)
            b = complex(0, torch.nan)

        expected = True
        tests = [(a, b, expected)]

        self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0)

    # The following tests (test_cuda_assert_*) are added to ensure test suite terminates early
    # when CUDA assert was thrown. Because all subsequent test will fail if that happens.
    # These tests are slow because it spawn another process to run test suite.
    # See: https://github.com/pytorch/pytorch/issues/49019
    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_stop_common_utils_test_suite(self, device):
        # test to ensure common_utils.py override has early termination for CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)

class TestThatContainsCUDAAssertFailure(TestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self):
        x = torch.rand(10, device='cuda')
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self):
        x1 = torch.tensor([0., 1.], device='cuda')
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

if __name__ == '__main__':
    run_tests()
""")
        # should capture CUDA error
        self.assertIn('CUDA error: device-side assert triggered', stderr)
        # should run only 1 test because it throws unrecoverable error.
        self.assertIn('errors=1', stderr)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_stop_common_device_type_test_suite(self, device):
        # test to ensure common_device_type.py override has early termination for CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestThatContainsCUDAAssertFailure(TestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self, device):
        x = torch.rand(10, device=device)
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestThatContainsCUDAAssertFailure,
    globals(),
    only_for='cuda'
)

if __name__ == '__main__':
    run_tests()
""")
        # should capture CUDA error
        self.assertIn('CUDA error: device-side assert triggered', stderr)
        # should run only 1 test because it throws unrecoverable error.
        self.assertIn('errors=1', stderr)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_not_stop_common_distributed_test_suite(self, device):
        # test to ensure common_distributed.py override should not early terminate CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (run_tests, slowTest)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase

class TestThatContainsCUDAAssertFailure(MultiProcessTestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self, device):
        x = torch.rand(10, device=device)
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestThatContainsCUDAAssertFailure,
    globals(),
    only_for='cuda'
)

if __name__ == '__main__':
    run_tests()
""")
        # we are currently disabling CUDA early termination for distributed tests.
        self.assertIn('errors=2', stderr)

    @expectedFailureMeta  # This is only supported for CPU and CUDA
    @onlyNativeDeviceTypes
    def test_get_supported_dtypes(self, device):
        # Test the `get_supported_dtypes` helper function.
        # We acquire the dtypes for few Ops dynamically and verify them against
        # the correct statically described values.
        ops_to_test = list(filter(lambda op: op.name in ['atan2', 'topk', 'xlogy'], op_db))

        for op in ops_to_test:
            dynamic_dtypes = opinfo.utils.get_supported_dtypes(op, op.sample_inputs_func, self.device_type)
            dynamic_dispatch = opinfo.utils.dtypes_dispatch_hint(dynamic_dtypes)
            # Named `expected_dtypes` (not `dtypes`) to avoid shadowing the imported `dtypes` decorator.
            # Any non-CPU device is checked against the CUDA dtype table, as before.
            if self.device_type == 'cpu':
                expected_dtypes = op.dtypes
            else:
                expected_dtypes = op.dtypesIfCUDA

            # assertEqual (rather than assertTrue(a == b)) reports the differing sets on failure.
            self.assertEqual(set(expected_dtypes), set(dynamic_dtypes))
            self.assertEqual(set(expected_dtypes), set(dynamic_dispatch.dispatch_fn()))

    # Picks the first OpInfo whose CPU and CUDA dtype sets differ, then checks that
    # `supported_dtypes` distinguishes device types but ignores the device index.
    @onlyCPU
    @ops(
        [
            op
            for op in op_db
            if len(
                op.supported_dtypes("cpu").symmetric_difference(
                    op.supported_dtypes("cuda")
                )
            )
            > 0
        ][:1],
        dtypes=OpDTypes.none,
    )
    def test_supported_dtypes(self, device, op):
        self.assertNotEqual(op.supported_dtypes("cpu"), op.supported_dtypes("cuda"))
        self.assertEqual(op.supported_dtypes("cuda"), op.supported_dtypes("cuda:0"))
        self.assertEqual(
            op.supported_dtypes(torch.device("cuda")),
            op.supported_dtypes(torch.device("cuda", index=1)),
        )
438
# Generate device-specific variants of the generic TestTesting tests (whose methods take a
# `device` argument) and register them in this module's namespace via globals().
instantiate_device_type_tests(TestTesting, globals())
440
441
class TestFrameworkUtils(TestCase):
    """Tests for the test-framework helpers themselves (environment-driven device filtering)."""

    @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
    @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
    def test_filtering_env_var(self):
        # Test environment variable selected device type test generator.
        test_filter_file_template = """\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestEnvironmentVariable(TestCase):

    def test_trivial_passing_test(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestEnvironmentVariable,
    globals(),
)

if __name__ == '__main__':
    run_tests()
"""

        def run_script(env):
            # Run the template in a subprocess with `env` and return its stderr as text.
            # Decoding as UTF-8 with replacement (instead of strict ASCII) keeps a stray
            # non-ASCII byte in the output from masking the real result with a UnicodeDecodeError.
            _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
            return stderr.decode('utf-8', errors='replace')

        test_bases_count = len(get_device_type_test_bases())
        # Test without setting env var should run everything.
        env = dict(os.environ)
        for k in ['CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY]:
            env.pop(k, None)
        self.assertIn(f'Ran {test_bases_count} test', run_script(env))

        # Test with setting only_for should only run 1 test.
        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
        self.assertIn('Ran 1 test', run_script(env))

        # Test with setting except_for should run 1 less device type from default.
        del env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY]
        env[PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY] = 'cpu'
        self.assertIn(f'Ran {test_bases_count - 1} test', run_script(env))

        # Test with setting both should throw exception
        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
        self.assertNotIn('OK', run_script(env))
494
495
def make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]:
    """Build ``(actual, expected)`` input pairs for :func:`torch.testing.assert_close` from one example pair.

    The bare pair comes first. It is followed by the same values wrapped in sequences
    (:class:`tuple`, :class:`list`), in mappings (:class:`dict`, :class:`~collections.OrderedDict`)
    keyed by ``"t"``, and in nested combinations that mix container types between the two sides.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.

    Returns:
        List[Tuple[Any, Any]]: Ten input pairs, starting with the unwrapped ``(actual, expected)``.
    """
    def ordered(value: Any) -> collections.OrderedDict:
        # Single-entry OrderedDict under the same key the plain-dict variants use.
        return collections.OrderedDict([("t", value)])

    a, e = actual, expected

    sequence_pairs = [
        ((a,), (e,)),       # tuple vs. tuple
        ([a], [e]),         # list vs. list
        ((a,), [e]),        # tuple vs. list
    ]
    mapping_pairs = [
        ({"t": a}, {"t": e}),           # dict vs. dict
        (ordered(a), ordered(e)),       # OrderedDict vs. OrderedDict
        ({"t": a}, ordered(e)),         # dict vs. OrderedDict
    ]
    nested_pairs = [
        ([(a,)], ([e],)),               # list of tuples vs. tuple of lists
        ([{"t": a}], (ordered(e),)),    # list of dicts vs. tuple of OrderedDicts
        ({"t": [a]}, ordered((e,))),    # dict of lists vs. OrderedDict of tuples
    ]

    return [(a, e), *sequence_pairs, *mapping_pairs, *nested_pairs]
528
529
def assert_close_with_inputs(actual: Any, expected: Any) -> Iterator[Callable]:
    """Lazily produce :func:`torch.testing.assert_close` partials, one per input layout of an example pair.

    .. note::

        Tests that do not target one specific input container should loop over this generator so each
        check runs against the bare inputs and every wrapped variant from :func:`make_assert_close_inputs`.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.

    Yields:
        Callable: :func:`torch.testing.assert_close` with both positional inputs already bound; call it
        with any keyword options (for example ``rtol``/``atol``) to run the comparison.
    """
    yield from (
        functools.partial(torch.testing.assert_close, *pair)
        for pair in make_assert_close_inputs(actual, expected)
    )
546
547
548class TestAssertClose(TestCase):
549    def test_mismatching_types_subclasses(self):
550        actual = torch.rand(())
551        expected = torch.nn.Parameter(actual)
552
553        for fn in assert_close_with_inputs(actual, expected):
554            fn()
555
556    def test_mismatching_types_type_equality(self):
557        actual = torch.empty(())
558        expected = torch.nn.Parameter(actual)
559
560        for fn in assert_close_with_inputs(actual, expected):
561            with self.assertRaisesRegex(TypeError, str(type(expected))):
562                fn(allow_subclasses=False)
563
564    def test_mismatching_types(self):
565        actual = torch.empty(2)
566        expected = actual.numpy()
567
568        for fn, allow_subclasses in itertools.product(assert_close_with_inputs(actual, expected), (True, False)):
569            with self.assertRaisesRegex(TypeError, str(type(expected))):
570                fn(allow_subclasses=allow_subclasses)
571
572    def test_unknown_type(self):
573        actual = "0"
574        expected = "0"
575
576        for fn in assert_close_with_inputs(actual, expected):
577            with self.assertRaisesRegex(TypeError, str(type(actual))):
578                fn()
579
580    def test_mismatching_shape(self):
581        actual = torch.empty(())
582        expected = actual.clone().reshape((1,))
583
584        for fn in assert_close_with_inputs(actual, expected):
585            with self.assertRaisesRegex(AssertionError, "shape"):
586                fn()
587
588    @unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.")
589    def test_unknown_layout(self):
590        actual = torch.empty((2, 2))
591        expected = actual.to_mkldnn()
592
593        for fn in assert_close_with_inputs(actual, expected):
594            with self.assertRaisesRegex(ValueError, "layout"):
595                fn()
596
597    def test_meta(self):
598        actual = torch.empty((2, 2), device="meta")
599        expected = torch.empty((2, 2), device="meta")
600
601        for fn in assert_close_with_inputs(actual, expected):
602            fn()
603
604    def test_mismatching_layout(self):
605        strided = torch.empty((2, 2))
606        sparse_coo = strided.to_sparse()
607        sparse_csr = strided.to_sparse_csr()
608
609        for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
610            for fn in assert_close_with_inputs(actual, expected):
611                with self.assertRaisesRegex(AssertionError, "layout"):
612                    fn()
613
614    def test_mismatching_layout_no_check(self):
615        strided = torch.randn((2, 2))
616        sparse_coo = strided.to_sparse()
617        sparse_csr = strided.to_sparse_csr()
618
619        for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
620            for fn in assert_close_with_inputs(actual, expected):
621                fn(check_layout=False)
622
623    def test_mismatching_dtype(self):
624        actual = torch.empty((), dtype=torch.float)
625        expected = actual.clone().to(torch.int)
626
627        for fn in assert_close_with_inputs(actual, expected):
628            with self.assertRaisesRegex(AssertionError, "dtype"):
629                fn()
630
631    def test_mismatching_dtype_no_check(self):
632        actual = torch.ones((), dtype=torch.float)
633        expected = actual.clone().to(torch.int)
634
635        for fn in assert_close_with_inputs(actual, expected):
636            fn(check_dtype=False)
637
638    def test_mismatching_stride(self):
639        actual = torch.empty((2, 2))
640        expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
641
642        for fn in assert_close_with_inputs(actual, expected):
643            with self.assertRaisesRegex(AssertionError, "stride"):
644                fn(check_stride=True)
645
646    def test_mismatching_stride_no_check(self):
647        actual = torch.rand((2, 2))
648        expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
649        for fn in assert_close_with_inputs(actual, expected):
650            fn()
651
652    def test_only_rtol(self):
653        actual = torch.empty(())
654        expected = actual.clone()
655
656        for fn in assert_close_with_inputs(actual, expected):
657            with self.assertRaises(ValueError):
658                fn(rtol=0.0)
659
660    def test_only_atol(self):
661        actual = torch.empty(())
662        expected = actual.clone()
663
664        for fn in assert_close_with_inputs(actual, expected):
665            with self.assertRaises(ValueError):
666                fn(atol=0.0)
667
668    def test_mismatching_values(self):
669        actual = torch.tensor(1)
670        expected = torch.tensor(2)
671
672        for fn in assert_close_with_inputs(actual, expected):
673            with self.assertRaises(AssertionError):
674                fn()
675
676    def test_mismatching_values_rtol(self):
677        eps = 1e-3
678        actual = torch.tensor(1.0)
679        expected = torch.tensor(1.0 + eps)
680
681        for fn in assert_close_with_inputs(actual, expected):
682            with self.assertRaises(AssertionError):
683                fn(rtol=eps / 2, atol=0.0)
684
685    def test_mismatching_values_atol(self):
686        eps = 1e-3
687        actual = torch.tensor(0.0)
688        expected = torch.tensor(eps)
689
690        for fn in assert_close_with_inputs(actual, expected):
691            with self.assertRaises(AssertionError):
692                fn(rtol=0.0, atol=eps / 2)
693
694    def test_matching(self):
695        actual = torch.tensor(1.0)
696        expected = actual.clone()
697
698        torch.testing.assert_close(actual, expected)
699
700    def test_matching_rtol(self):
701        eps = 1e-3
702        actual = torch.tensor(1.0)
703        expected = torch.tensor(1.0 + eps)
704
705        for fn in assert_close_with_inputs(actual, expected):
706            fn(rtol=eps * 2, atol=0.0)
707
708    def test_matching_atol(self):
709        eps = 1e-3
710        actual = torch.tensor(0.0)
711        expected = torch.tensor(eps)
712
713        for fn in assert_close_with_inputs(actual, expected):
714            fn(rtol=0.0, atol=eps * 2)
715
716    # TODO: the code that this test was designed for was removed in https://github.com/pytorch/pytorch/pull/56058
717    #  We need to check if this test is still needed or if this behavior is now enabled by default.
718    def test_matching_conjugate_bit(self):
719        actual = torch.tensor(complex(1, 1)).conj()
720        expected = torch.tensor(complex(1, -1))
721
722        for fn in assert_close_with_inputs(actual, expected):
723            fn()
724
725    def test_matching_nan(self):
726        nan = float("NaN")
727
728        tests = (
729            (nan, nan),
730            (complex(nan, 0), complex(0, nan)),
731            (complex(nan, nan), complex(nan, 0)),
732            (complex(nan, nan), complex(nan, nan)),
733        )
734
735        for actual, expected in tests:
736            for fn in assert_close_with_inputs(actual, expected):
737                with self.assertRaises(AssertionError):
738                    fn()
739
740    def test_matching_nan_with_equal_nan(self):
741        nan = float("NaN")
742
743        tests = (
744            (nan, nan),
745            (complex(nan, 0), complex(0, nan)),
746            (complex(nan, nan), complex(nan, 0)),
747            (complex(nan, nan), complex(nan, nan)),
748        )
749
750        for actual, expected in tests:
751            for fn in assert_close_with_inputs(actual, expected):
752                fn(equal_nan=True)
753
754    def test_numpy(self):
755        tensor = torch.rand(2, 2, dtype=torch.float32)
756        actual = tensor.numpy()
757        expected = actual.copy()
758
759        for fn in assert_close_with_inputs(actual, expected):
760            fn()
761
762    def test_scalar(self):
763        number = torch.randint(10, size=()).item()
764        for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2):
765            check_dtype = type(actual) is type(expected)
766
767            for fn in assert_close_with_inputs(actual, expected):
768                fn(check_dtype=check_dtype)
769
770    def test_bool(self):
771        actual = torch.tensor([True, False])
772        expected = actual.clone()
773
774        for fn in assert_close_with_inputs(actual, expected):
775            fn()
776
777    def test_none(self):
778        actual = expected = None
779
780        for fn in assert_close_with_inputs(actual, expected):
781            fn()
782
783    def test_none_mismatch(self):
784        expected = None
785
786        for actual in (False, 0, torch.nan, torch.tensor(torch.nan)):
787            for fn in assert_close_with_inputs(actual, expected):
788                with self.assertRaises(AssertionError):
789                    fn()
790
791
792    def test_docstring_examples(self):
793        finder = doctest.DocTestFinder(verbose=False)
794        runner = doctest.DocTestRunner(verbose=False, optionflags=doctest.NORMALIZE_WHITESPACE)
795        globs = dict(torch=torch)
796        doctests = finder.find(torch.testing.assert_close, globs=globs)[0]
797        failures = []
798        runner.run(doctests, out=lambda report: failures.append(report))
799        if failures:
800            raise AssertionError(f"Doctest found {len(failures)} failures:\n\n" + "\n".join(failures))
801
802    def test_default_tolerance_selection_mismatching_dtypes(self):
803        # If the default tolerances where selected based on the promoted dtype, i.e. float64,
804        # these tensors wouldn't be considered close.
805        actual = torch.tensor(0.99, dtype=torch.bfloat16)
806        expected = torch.tensor(1.0, dtype=torch.float64)
807
808        for fn in assert_close_with_inputs(actual, expected):
809            fn(check_dtype=False)
810
    class UnexpectedException(Exception):
        """Exception used solely to test how ``assert_close`` handles unexpected exceptions.

        Tests mock an internal component so that it raises this exception instead of performing its regular
        behavior. A custom class is used rather than a builtin exception, so that any special handling of builtin
        exceptions is not triggered.
        """
816
817    @unittest.mock.patch("torch.testing._comparison.TensorLikePair.__init__", side_effect=UnexpectedException)
818    def test_unexpected_error_originate(self, _):
819        actual = torch.tensor(1.0)
820        expected = actual.clone()
821
822        with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
823            torch.testing.assert_close(actual, expected)
824
825    @unittest.mock.patch("torch.testing._comparison.TensorLikePair.compare", side_effect=UnexpectedException)
826    def test_unexpected_error_compare(self, _):
827        actual = torch.tensor(1.0)
828        expected = actual.clone()
829
830        with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
831            torch.testing.assert_close(actual, expected)
832
833
834
835
class TestAssertCloseMultiDevice(TestCase):
    """Checks how ``assert_close`` treats inputs that live on different devices."""

    @deviceCountAtLeast(1)
    def test_mismatching_device(self, devices):
        # Every ordered pair of distinct devices out of CPU plus the provided devices must be rejected with an error
        # that mentions the device.
        device_pairs = itertools.permutations(("cpu", *devices), 2)
        for source, target in device_pairs:
            on_source = torch.empty((), device=source)
            on_target = on_source.clone().to(target)
            for check in assert_close_with_inputs(on_source, on_target):
                with self.assertRaisesRegex(AssertionError, "device"):
                    check()

    @deviceCountAtLeast(1)
    def test_mismatching_device_no_check(self, devices):
        # With ``check_device=False`` the same values compare as close regardless of the device they are on.
        device_pairs = itertools.permutations(("cpu", *devices), 2)
        for source, target in device_pairs:
            on_source = torch.rand((), device=source)
            on_target = on_source.clone().to(target)
            for check in assert_close_with_inputs(on_source, on_target):
                check(check_device=False)
853
854
# Expand the generic multi-device tests into device-specific test classes in this module's namespace,
# restricted to the CUDA device type.
instantiate_device_type_tests(
    TestAssertCloseMultiDevice,
    globals(),
    only_for="cuda",
)
856
857
class TestAssertCloseErrorMessage(TestCase):
    """Checks the content of the error messages raised by ``torch.testing.assert_close``.

    Each test obtains ``assert_close`` callables from ``assert_close_with_inputs`` and checks that every one of them
    fails with a message containing the expected fragment.
    """

    def test_identifier_tensor_likes(self):
        actual = torch.tensor([1, 2, 3, 4])
        expected = torch.tensor([1, 2, 5, 6])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Tensor-likes")):
                fn()

    def test_identifier_scalars(self):
        actual = 3
        expected = 5

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Scalars")):
                fn()

    def test_not_equal(self):
        # With zero tolerances the message reports the inputs as "not equal".
        actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
        expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("not equal")):
                fn(rtol=0.0, atol=0.0)

    def test_not_close(self):
        # With non-zero tolerances (rtol only, atol only, or both) the message reports the inputs as "not close".
        actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
        expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32)

        for fn, (rtol, atol) in itertools.product(
            assert_close_with_inputs(actual, expected), ((1.3e-6, 0.0), (0.0, 1e-5), (1.3e-6, 1e-5))
        ):
            with self.assertRaisesRegex(AssertionError, re.escape("not close")):
                fn(rtol=rtol, atol=atol)

    def test_mismatched_elements(self):
        actual = torch.tensor([1, 2, 3, 4])
        expected = torch.tensor([1, 2, 5, 6])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
                fn()

    def test_abs_diff(self):
        actual = torch.tensor([[1, 2], [3, 4]])
        expected = torch.tensor([[1, 2], [5, 4]])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at index (1, 0)")):
                fn()

    def test_abs_diff_scalar(self):
        actual = 3
        expected = 5

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Absolute difference: 2")):
                fn()

    def test_rel_diff(self):
        actual = torch.tensor([[1, 2], [3, 4]])
        expected = torch.tensor([[1, 4], [3, 4]])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at index (0, 1)")):
                fn()

    def test_rel_diff_scalar(self):
        actual = 2
        expected = 4

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Relative difference: 0.5")):
                fn()

    def test_zero_div_zero(self):
        actual = torch.tensor([1.0, 0.0])
        expected = torch.tensor([2.0, 0.0])

        for fn in assert_close_with_inputs(actual, expected):
            # The second element matches (0 == 0). If it were still used for the relative-difference computation, the
            # resulting 0 / 0 would show up as 'nan' in the error message. This cannot be expressed with
            # `assertRaisesRegex`: it uses `re.search`, and a negative-lookahead pattern such as "((?!nan).)*" also
            # matches the empty string, so it would accept any message. Inspect the message explicitly instead.
            with self.assertRaises(AssertionError) as ctx:
                fn()
            self.assertNotIn("nan", str(ctx.exception))

    def test_rtol(self):
        # The relative tolerance in use is echoed in the message.
        rtol = 1e-3

        actual = torch.tensor((1, 2))
        expected = torch.tensor((2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {rtol} allowed)")):
                fn(rtol=rtol, atol=0.0)

    def test_atol(self):
        # The absolute tolerance in use is echoed in the message.
        atol = 1e-3

        actual = torch.tensor((1, 2))
        expected = torch.tensor((2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {atol} allowed)")):
                fn(rtol=0.0, atol=atol)

    def test_msg_str(self):
        msg = "Custom error message!"

        actual = torch.tensor(1)
        expected = torch.tensor(2)

        for fn in assert_close_with_inputs(actual, expected):
            # `msg` is literal text, not a pattern; escape it like every other expected fragment in this class.
            with self.assertRaisesRegex(AssertionError, re.escape(msg)):
                fn(msg=msg)

    def test_msg_callable(self):
        msg = "Custom error message"

        actual = torch.tensor(1)
        expected = torch.tensor(2)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(msg)):
                fn(msg=lambda _: msg)
981
982
983class TestAssertCloseContainer(TestCase):
984    def test_sequence_mismatching_len(self):
985        actual = (torch.empty(()),)
986        expected = ()
987
988        with self.assertRaises(AssertionError):
989            torch.testing.assert_close(actual, expected)
990
991    def test_sequence_mismatching_values_msg(self):
992        t1 = torch.tensor(1)
993        t2 = torch.tensor(2)
994
995        actual = (t1, t1)
996        expected = (t1, t2)
997
998        with self.assertRaisesRegex(AssertionError, re.escape("item [1]")):
999            torch.testing.assert_close(actual, expected)
1000
1001    def test_mapping_mismatching_keys(self):
1002        actual = {"a": torch.empty(())}
1003        expected = {}
1004
1005        with self.assertRaises(AssertionError):
1006            torch.testing.assert_close(actual, expected)
1007
1008    def test_mapping_mismatching_values_msg(self):
1009        t1 = torch.tensor(1)
1010        t2 = torch.tensor(2)
1011
1012        actual = {"a": t1, "b": t1}
1013        expected = {"a": t1, "b": t2}
1014
1015        with self.assertRaisesRegex(AssertionError, re.escape("item ['b']")):
1016            torch.testing.assert_close(actual, expected)
1017
1018
class TestAssertCloseSparseCOO(TestCase):
    """Checks ``assert_close`` on sparse COO tensors: identical inputs (coalesced or not) pass, and each kind of
    structural or value mismatch fails with a dedicated message."""

    def _expect_failure(self, actual, expected, fragment):
        # Every callable from ``assert_close_with_inputs`` must fail with a message containing ``fragment`` verbatim.
        for check in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                check()

    def test_matching_coalesced(self):
        coalesced = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2)).coalesce()
        for check in assert_close_with_inputs(coalesced, coalesced.clone()):
            check()

    def test_matching_uncoalesced(self):
        uncoalesced = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2))
        for check in assert_close_with_inputs(uncoalesced, uncoalesced.clone()):
            check()

    def test_mismatching_sparse_dims(self):
        # Same dense data, converted with the default number of sparse dims on one side and with 2 on the other.
        dense = torch.randn(2, 3, 4)
        self._expect_failure(
            dense.to_sparse(), dense.to_sparse(2), "number of sparse dimensions in sparse COO tensors"
        )

    def test_mismatching_nnz(self):
        # Two specified values on one side versus three (with the index (1, 0) duplicated) on the other.
        actual = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2))
        expected = torch.sparse_coo_tensor(((0, 1, 1), (1, 0, 0)), (1, 1, 1), size=(2, 2))
        self._expect_failure(actual, expected, "number of specified values in sparse COO tensors")

    def test_mismatching_indices_msg(self):
        actual = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2))
        expected = torch.sparse_coo_tensor(((0, 1), (1, 1)), (1, 2), size=(2, 2))
        self._expect_failure(actual, expected, "Sparse COO indices")

    def test_mismatching_values_msg(self):
        actual = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2))
        expected = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 3), size=(2, 2))
        self._expect_failure(actual, expected, "Sparse COO values")
1109
1110
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing")
class TestAssertCloseSparseCSR(TestCase):
    """Checks ``assert_close`` on 2x2 sparse CSR tensors: identical inputs pass, and a mismatch in
    ``crow_indices``, ``col_indices`` or ``values`` is reported by name."""

    @staticmethod
    def _csr(crow_indices, col_indices, values):
        # All tensors in this class share the same 2x2 shape.
        return torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))

    def _expect_failure(self, actual, expected, fragment):
        # Every callable from ``assert_close_with_inputs`` must fail with a message containing ``fragment`` verbatim.
        for check in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                check()

    def test_matching(self):
        matrix = self._csr((0, 1, 2), (1, 0), (1, 2))
        for check in assert_close_with_inputs(matrix, matrix.clone()):
            check()

    def test_mismatching_crow_indices_msg(self):
        col_indices, values = (0, 1), (1, 2)
        # One element per row versus both elements in row 0.
        actual = self._csr((0, 1, 2), col_indices, values)
        expected = self._csr((0, 2, 2), col_indices, values)
        self._expect_failure(actual, expected, "Sparse CSR crow_indices")

    def test_mismatching_col_indices_msg(self):
        crow_indices, values = (0, 1, 2), (1, 2)
        actual = self._csr(crow_indices, (1, 0), values)
        expected = self._csr(crow_indices, (1, 1), values)
        self._expect_failure(actual, expected, "Sparse CSR col_indices")

    def test_mismatching_values_msg(self):
        crow_indices, col_indices = (0, 1, 2), (1, 0)
        actual = self._csr(crow_indices, col_indices, (1, 2))
        expected = self._csr(crow_indices, col_indices, (1, 3))
        self._expect_failure(actual, expected, "Sparse CSR values")
1167
1168
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSC testing")
class TestAssertCloseSparseCSC(TestCase):
    """Checks ``assert_close`` on 2x2 sparse CSC tensors: identical inputs pass, and a mismatch in
    ``ccol_indices``, ``row_indices`` or ``values`` is reported by name."""

    @staticmethod
    def _csc(ccol_indices, row_indices, values):
        # All tensors in this class share the same 2x2 shape.
        return torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=(2, 2))

    def _expect_failure(self, actual, expected, fragment):
        # Every callable from ``assert_close_with_inputs`` must fail with a message containing ``fragment`` verbatim.
        for check in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                check()

    def test_matching(self):
        matrix = self._csc((0, 1, 2), (1, 0), (1, 2))
        for check in assert_close_with_inputs(matrix, matrix.clone()):
            check()

    def test_mismatching_ccol_indices_msg(self):
        row_indices, values = (0, 1), (1, 2)
        # One element per column versus both elements in column 0.
        actual = self._csc((0, 1, 2), row_indices, values)
        expected = self._csc((0, 2, 2), row_indices, values)
        self._expect_failure(actual, expected, "Sparse CSC ccol_indices")

    def test_mismatching_row_indices_msg(self):
        ccol_indices, values = (0, 1, 2), (1, 2)
        actual = self._csc(ccol_indices, (1, 0), values)
        expected = self._csc(ccol_indices, (1, 1), values)
        self._expect_failure(actual, expected, "Sparse CSC row_indices")

    def test_mismatching_values_msg(self):
        ccol_indices, row_indices = (0, 1, 2), (1, 0)
        actual = self._csc(ccol_indices, row_indices, (1, 2))
        expected = self._csc(ccol_indices, row_indices, (1, 3))
        self._expect_failure(actual, expected, "Sparse CSC values")
1225
1226
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSR testing")
class TestAssertCloseSparseBSR(TestCase):
    """Checks ``assert_close`` on 2x2 sparse BSR tensors built from 1x1 blocks: identical inputs pass, and a mismatch
    in ``crow_indices``, ``col_indices`` or ``values`` is reported by name."""

    @staticmethod
    def _bsr(crow_indices, col_indices, values):
        # All tensors in this class share the same 2x2 shape.
        return torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(2, 2))

    def _expect_failure(self, actual, expected, fragment):
        # Every callable from ``assert_close_with_inputs`` must fail with a message containing ``fragment`` verbatim.
        for check in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                check()

    def test_matching(self):
        matrix = self._bsr((0, 1, 2), (1, 0), ([[1]], [[2]]))
        for check in assert_close_with_inputs(matrix, matrix.clone()):
            check()

    def test_mismatching_crow_indices_msg(self):
        col_indices, values = (0, 1), ([[1]], [[2]])
        # One block per block-row versus both blocks in block-row 0.
        actual = self._bsr((0, 1, 2), col_indices, values)
        expected = self._bsr((0, 2, 2), col_indices, values)
        self._expect_failure(actual, expected, "Sparse BSR crow_indices")

    def test_mismatching_col_indices_msg(self):
        crow_indices, values = (0, 1, 2), ([[1]], [[2]])
        actual = self._bsr(crow_indices, (1, 0), values)
        expected = self._bsr(crow_indices, (1, 1), values)
        self._expect_failure(actual, expected, "Sparse BSR col_indices")

    def test_mismatching_values_msg(self):
        crow_indices, col_indices = (0, 1, 2), (1, 0)
        actual = self._bsr(crow_indices, col_indices, ([[1]], [[2]]))
        expected = self._bsr(crow_indices, col_indices, ([[1]], [[3]]))
        self._expect_failure(actual, expected, "Sparse BSR values")
1283
1284
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSC testing")
class TestAssertCloseSparseBSC(TestCase):
    """Checks ``assert_close`` on 2x2 sparse BSC tensors built from 1x1 blocks: identical inputs pass, and a mismatch
    in ``ccol_indices``, ``row_indices`` or ``values`` is reported by name."""

    @staticmethod
    def _bsc(ccol_indices, row_indices, values):
        # All tensors in this class share the same 2x2 shape.
        return torch.sparse_bsc_tensor(ccol_indices, row_indices, values, size=(2, 2))

    def _expect_failure(self, actual, expected, fragment):
        # Every callable from ``assert_close_with_inputs`` must fail with a message containing ``fragment`` verbatim.
        for check in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                check()

    def test_matching(self):
        matrix = self._bsc((0, 1, 2), (1, 0), ([[1]], [[2]]))
        for check in assert_close_with_inputs(matrix, matrix.clone()):
            check()

    def test_mismatching_ccol_indices_msg(self):
        row_indices, values = (0, 1), ([[1]], [[2]])
        # One block per block-column versus both blocks in block-column 0.
        actual = self._bsc((0, 1, 2), row_indices, values)
        expected = self._bsc((0, 2, 2), row_indices, values)
        self._expect_failure(actual, expected, "Sparse BSC ccol_indices")

    def test_mismatching_row_indices_msg(self):
        ccol_indices, values = (0, 1, 2), ([[1]], [[2]])
        actual = self._bsc(ccol_indices, (1, 0), values)
        expected = self._bsc(ccol_indices, (1, 1), values)
        self._expect_failure(actual, expected, "Sparse BSC row_indices")

    def test_mismatching_values_msg(self):
        ccol_indices, row_indices = (0, 1, 2), (1, 0)
        actual = self._bsc(ccol_indices, row_indices, ([[1]], [[2]]))
        expected = self._bsc(ccol_indices, row_indices, ([[1]], [[3]]))
        self._expect_failure(actual, expected, "Sparse BSC values")
1341
1342
class TestAssertCloseQuantized(TestCase):
    """Checks ``assert_close`` on quantized tensors (per-tensor and per-channel, ``qint32``)."""

    def test_mismatching_is_quantized(self):
        # A regular tensor must not be close to a quantized one.
        plain = torch.tensor(1.0)
        quantized = torch.quantize_per_tensor(plain, scale=1.0, zero_point=0, dtype=torch.qint32)

        for check in assert_close_with_inputs(plain, quantized):
            with self.assertRaisesRegex(AssertionError, "is_quantized"):
                check()

    def test_mismatching_qscheme(self):
        # Quantizing the same data per tensor and per channel must be reported as a quantization-scheme mismatch.
        data = torch.tensor((1.0,))
        per_tensor = torch.quantize_per_tensor(data, scale=1.0, zero_point=0, dtype=torch.qint32)
        per_channel = torch.quantize_per_channel(
            data,
            scales=torch.tensor((1.0,)),
            zero_points=torch.tensor((0,)),
            axis=0,
            dtype=torch.qint32,
        )

        for check in assert_close_with_inputs(per_tensor, per_channel):
            with self.assertRaisesRegex(AssertionError, "qscheme"):
                check()

    def test_matching_per_tensor(self):
        quantized = torch.quantize_per_tensor(torch.tensor(1.0), scale=1.0, zero_point=0, dtype=torch.qint32)
        for check in assert_close_with_inputs(quantized, quantized.clone()):
            check()

    def test_matching_per_channel(self):
        quantized = torch.quantize_per_channel(
            torch.tensor((1.0,)),
            scales=torch.tensor((1.0,)),
            zero_points=torch.tensor((0,)),
            axis=0,
            dtype=torch.qint32,
        )
        for check in assert_close_with_inputs(quantized, quantized.clone()):
            check()
1386
1387
class TestMakeTensor(TestCase):
    """Tests for ``torch.testing.make_tensor``.

    The test methods are generic over ``dtype`` (via the ``dtypes`` decorator) and take a ``device`` argument; they
    are turned into runnable device-specific tests by ``instantiate_device_type_tests``.
    """

    # Shared decorator listing every dtype the tests below are run with.
    supported_dtypes = dtypes(
        torch.bool,
        torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
        torch.float16, torch.bfloat16, torch.float32, torch.float64,
        torch.complex32, torch.complex64, torch.complex128,
    )

    @supported_dtypes
    @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
    @parametrize("splat_shape", [False, True])
    def test_smoke(self, dtype, device, shape, splat_shape):
        # The shape can be passed either as a single sequence or splatted as individual ints.
        t = torch.testing.make_tensor(*shape if splat_shape else shape, dtype=dtype, device=device)

        self.assertIsInstance(t, torch.Tensor)
        self.assertEqual(t.shape, shape)
        self.assertEqual(t.dtype, dtype)
        self.assertEqual(t.device, torch.device(device))

    @supported_dtypes
    @parametrize("requires_grad", [False, True])
    def test_requires_grad(self, dtype, device, requires_grad):
        make_tensor = functools.partial(
            torch.testing.make_tensor,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )

        if not requires_grad or dtype.is_floating_point or dtype.is_complex:
            t = make_tensor()
            self.assertEqual(t.requires_grad, requires_grad)
        else:
            with self.assertRaisesRegex(
                    ValueError, "`requires_grad=True` is not supported for boolean and integral dtypes"
            ):
                make_tensor()

    @supported_dtypes
    @parametrize("noncontiguous", [False, True])
    @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
    def test_noncontiguous(self, dtype, device, noncontiguous, shape):
        numel = functools.reduce(operator.mul, shape, 1)

        t = torch.testing.make_tensor(shape, dtype=dtype, device=device, noncontiguous=noncontiguous)
        # Tensors with fewer than two elements are contiguous no matter what was requested.
        self.assertEqual(t.is_contiguous(), not noncontiguous or numel < 2)

    @supported_dtypes
    @parametrize(
        "memory_format_and_shape",
        [
            (None, (2, 3, 4)),
            (torch.contiguous_format, (2, 3, 4)),
            (torch.channels_last, (2, 3, 4, 5)),
            (torch.channels_last_3d, (2, 3, 4, 5, 6)),
            (torch.preserve_format, (2, 3, 4)),
        ],
    )
    def test_memory_format(self, dtype, device, memory_format_and_shape):
        memory_format, shape = memory_format_and_shape

        t = torch.testing.make_tensor(shape, dtype=dtype, device=device, memory_format=memory_format)

        self.assertTrue(
            t.is_contiguous(memory_format=torch.contiguous_format if memory_format is None else memory_format)
        )

    @supported_dtypes
    def test_noncontiguous_memory_format(self, dtype, device):
        with self.assertRaisesRegex(ValueError, "`noncontiguous` and `memory_format` are mutually exclusive"):
            torch.testing.make_tensor(
                (2, 3, 4, 5),
                dtype=dtype,
                device=device,
                noncontiguous=True,
                memory_format=torch.channels_last,
            )

    @supported_dtypes
    def test_exclude_zero(self, dtype, device):
        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, exclude_zero=True, low=-1, high=2)

        self.assertTrue((t != 0).all())

    @supported_dtypes
    def test_low_high_smoke(self, dtype, device):
        low_inclusive, high_exclusive = 0, 2

        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive)
        if dtype.is_complex:
            # Check the real and imaginary parts individually.
            t = torch.view_as_real(t)

        self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all())

    @supported_dtypes
    def test_low_high_default_smoke(self, dtype, device):
        # Documented default sampling interval [low, high) per dtype, used when `low` and `high` are omitted.
        low_inclusive, high_exclusive = {
            torch.bool: (0, 2),
            torch.uint8: (0, 10),
            **dict.fromkeys([torch.int8, torch.int16, torch.int32, torch.int64], (-9, 10)),
        }.get(dtype, (-9, 9))

        # `low` and `high` are deliberately not passed, so that the defaults of `make_tensor` are what is tested.
        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device)
        if dtype.is_complex:
            # Check the real and imaginary parts individually.
            t = torch.view_as_real(t)

        self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all())

    @parametrize("low_high", [(0, 0), (1, 0), (0, -1)])
    @parametrize("value_types", list(itertools.product([int, float], repeat=2)))
    @supported_dtypes
    def test_low_ge_high(self, dtype, device, low_high, value_types):
        low, high = (value_type(value) for value, value_type in zip(low_high, value_types))

        if low == high and (dtype.is_floating_point or dtype.is_complex):
            # `low == high` is still accepted for floating / complex dtypes, but deprecated.
            with self.assertWarnsRegex(
                    FutureWarning,
                    "Passing `low==high` to `torch.testing.make_tensor` for floating or complex types is deprecated",
            ):
                t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low, high=high)
            self.assertEqual(t, torch.full_like(t, complex(low, low) if dtype.is_complex else low))
        else:
            with self.assertRaisesRegex(ValueError, "`low` must be less than `high`"):
                torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)

    @supported_dtypes
    @parametrize("low_high", [(None, torch.nan), (torch.nan, None), (torch.nan, torch.nan)])
    def test_low_high_nan(self, dtype, device, low_high):
        low, high = low_high

        with self.assertRaisesRegex(ValueError, "`low` and `high` cannot be NaN"):
            torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)

    @supported_dtypes
    def test_low_high_outside_valid_range(self, dtype, device):
        make_tensor = functools.partial(torch.testing.make_tensor, dtype=dtype, device=device)

        def get_dtype_limits(dtype):
            if dtype is torch.bool:
                return 0, 1

            info = (torch.finfo if dtype.is_floating_point or dtype.is_complex else torch.iinfo)(dtype)
            # We are using integer bounds here, because otherwise it would be impossible to pass `low` and `high`
            # outside their valid range. Python uses 64bit floating point numbers and thus trying to do something like
            # `torch.finfo(torch.float64).max * 2` will always result in `inf`. On the flip side, Python's `int` is
            # unbounded.
            return int(info.min), int(info.max)

        lowest_inclusive, highest_inclusive = get_dtype_limits(dtype)

        # Interval entirely below the representable range. When the lowest value is 0 (bool, uint8), scaling it would
        # not move it, so a negative interval is used instead.
        low, high = (-2, -1) if lowest_inclusive == 0 else (lowest_inclusive * 4, lowest_inclusive * 2)
        with self.assertRaises(ValueError):
            make_tensor(low=low, high=high)

        # Interval entirely above the representable range.
        with self.assertRaises(ValueError):
            make_tensor(low=highest_inclusive * 2, high=highest_inclusive * 4)

    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_low_high_boolean_integral1(self, dtype, device):
        # For boolean and integral dtypes the only integer inside (-1, 1) is 0, so every sample must be 0.
        shape = (10_000,)
        eps = 1e-4

        actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=-(1 - eps), high=1 - eps)
        expected = torch.zeros(shape, dtype=dtype, device=device)

        torch.testing.assert_close(actual, expected)

    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_low_high_boolean_integral2(self, dtype, device):
        # An interval of width one at the top of the dtype's range must produce only `low`.
        shape = (10_000,)
        if dtype is torch.bool:
            low = 1
        elif dtype is torch.int64:
            # Due to its internals, `make_tensor` is not able to sample `torch.iinfo(torch.int64).max`
            low = torch.iinfo(dtype).max - 1
        else:
            low = torch.iinfo(dtype).max
        high = low + 1

        actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
        expected = torch.full(shape, low, dtype=dtype, device=device)

        torch.testing.assert_close(actual, expected)
1571
1572
# Instantiate the device-type variants of the generic `TestMakeTensor` class, using this module's
# globals as the scope.
instantiate_device_type_tests(TestMakeTensor, globals())
1574
1575
1576def _get_test_names_for_test_class(test_cls):
1577    """ Convenience function to get all test names for a given test class. """
1578    test_names = [f'{test_cls.__name__}.{key}' for key in test_cls.__dict__
1579                  if key.startswith('test_')]
1580    return sorted(test_names)
1581
1582
1583def _get_test_funcs_for_test_class(test_cls):
1584    """ Convenience function to get all (test function, parametrized_name) pairs for a given test class. """
1585    test_funcs = [(getattr(test_cls, key), key) for key in test_cls.__dict__ if key.startswith('test_')]
1586    return test_funcs
1587
1588
class TestTestParametrization(TestCase):
    """Checks test naming and per-param decoration produced by ``@parametrize`` / ``subtest`` when
    expanded with ``instantiate_parametrized_tests`` (i.e. without device-type instantiation)."""

    def test_default_names(self):

        class TestParametrized(TestCase):
            @parametrize("x", range(5))
            def test_default_names(self, x):
                pass

            @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
            def test_two_things_default_names(self, x, y):
                pass

        instantiate_parametrized_tests(TestParametrized)

        expected_test_names = [
            'TestParametrized.test_default_names_x_0',
            'TestParametrized.test_default_names_x_1',
            'TestParametrized.test_default_names_x_2',
            'TestParametrized.test_default_names_x_3',
            'TestParametrized.test_default_names_x_4',
            'TestParametrized.test_two_things_default_names_x_1_y_2',
            'TestParametrized.test_two_things_default_names_x_2_y_3',
            'TestParametrized.test_two_things_default_names_x_3_y_4',
        ]
        test_names = _get_test_names_for_test_class(TestParametrized)
        self.assertEqual(expected_test_names, test_names)

    def test_name_fn(self):

        class TestParametrized(TestCase):
            @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
            def test_custom_names(self, bias):
                pass

            @parametrize("x", [1, 2], name_fn=str)
            @parametrize("y", [3, 4], name_fn=str)
            @parametrize("z", [5, 6], name_fn=str)
            def test_three_things_composition_custom_names(self, x, y, z):
                pass

            @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}')
            def test_two_things_custom_names_alternate(self, x, y):
                pass

        instantiate_parametrized_tests(TestParametrized)

        expected_test_names = [
            'TestParametrized.test_custom_names_bias',
            'TestParametrized.test_custom_names_no_bias',
            'TestParametrized.test_three_things_composition_custom_names_1_3_5',
            'TestParametrized.test_three_things_composition_custom_names_1_3_6',
            'TestParametrized.test_three_things_composition_custom_names_1_4_5',
            'TestParametrized.test_three_things_composition_custom_names_1_4_6',
            'TestParametrized.test_three_things_composition_custom_names_2_3_5',
            'TestParametrized.test_three_things_composition_custom_names_2_3_6',
            'TestParametrized.test_three_things_composition_custom_names_2_4_5',
            'TestParametrized.test_three_things_composition_custom_names_2_4_6',
            'TestParametrized.test_two_things_custom_names_alternate_1__2',
            'TestParametrized.test_two_things_custom_names_alternate_1__3',
            'TestParametrized.test_two_things_custom_names_alternate_1__4',
        ]
        test_names = _get_test_names_for_test_class(TestParametrized)
        self.assertEqual(expected_test_names, test_names)

    def test_subtest_names(self):

        class TestParametrized(TestCase):
            @parametrize("bias", [subtest(True, name='bias'),
                                  subtest(False, name='no_bias')])
            def test_custom_names(self, bias):
                pass

            @parametrize("x,y", [subtest((1, 2), name='double'),
                                 subtest((1, 3), name='triple'),
                                 subtest((1, 4), name='quadruple')])
            def test_two_things_custom_names(self, x, y):
                pass

        instantiate_parametrized_tests(TestParametrized)

        expected_test_names = [
            'TestParametrized.test_custom_names_bias',
            'TestParametrized.test_custom_names_no_bias',
            'TestParametrized.test_two_things_custom_names_double',
            'TestParametrized.test_two_things_custom_names_quadruple',
            'TestParametrized.test_two_things_custom_names_triple',
        ]
        test_names = _get_test_names_for_test_class(TestParametrized)
        self.assertEqual(expected_test_names, test_names)

    def test_apply_param_specific_decorators(self):
        # Test that decorators can be applied on a per-param basis.

        def test_dec(func):
            func._decorator_applied = True
            return func

        class TestParametrized(TestCase):
            @parametrize("x", [subtest(1, name='one'),
                               subtest(2, name='two', decorators=[test_dec]),
                               subtest(3, name='three')])
            def test_param(self, x):
                pass

        instantiate_parametrized_tests(TestParametrized)

        for test_func, name in _get_test_funcs_for_test_class(TestParametrized):
            self.assertEqual(hasattr(test_func, '_decorator_applied'), name == 'test_param_two')

    def test_compose_param_specific_decorators(self):
        # Test that multiple per-param decorators compose correctly.

        def test_dec(func):
            func._decorator_applied = True
            return func

        class TestParametrized(TestCase):
            @parametrize("x", [subtest(1),
                               subtest(2, decorators=[test_dec]),
                               subtest(3)])
            @parametrize("y", [subtest(False, decorators=[test_dec]),
                               subtest(True)])
            def test_param(self, x, y):
                pass

        instantiate_parametrized_tests(TestParametrized)

        for test_func, name in _get_test_funcs_for_test_class(TestParametrized):
            # Decorator should be applied whenever either x == 2 or y == False.
            should_apply = ('x_2' in name) or ('y_False' in name)
            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)

    def test_modules_decorator_misuse_error(self):
        # Test that @modules errors out when used with instantiate_parametrized_tests().

        class TestParametrized(TestCase):
            @modules(module_db)
            def test_modules(self, module_info):
                pass

        with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'):
            instantiate_parametrized_tests(TestParametrized)

    def test_ops_decorator_misuse_error(self):
        # Test that @ops errors out when used with instantiate_parametrized_tests().

        class TestParametrized(TestCase):
            # @ops supplies an `op` argument (not `module_info`, which is what @modules supplies).
            @ops(op_db)
            def test_ops(self, op):
                pass

        with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'):
            instantiate_parametrized_tests(TestParametrized)

    def test_multiple_handling_of_same_param_error(self):
        # Test that multiple decorators handling the same param errors out.

        class TestParametrized(TestCase):
            @parametrize("x", range(3))
            @parametrize("x", range(5))
            def test_param(self, x):
                pass

        with self.assertRaisesRegex(RuntimeError, 'multiple parametrization decorators'):
            instantiate_parametrized_tests(TestParametrized)

    # Only the x == 2 subtest is decorated with `unittest.expectedFailure`, so only it may raise.
    @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
    def test_subtest_expected_failure(self, x):
        if x == 2:
            raise RuntimeError('Boom')

    # Per-param `expectedFailure` decorators compose: any combination with x == 1 or y == 6 is expected to fail.
    @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
    @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
    def test_two_things_subtest_expected_failure(self, x, y):
        if x == 1 or y == 6:
            raise RuntimeError('Boom')
1765
1766
class TestTestParametrizationDeviceType(TestCase):
    """Checks test naming and decoration produced by ``@parametrize`` / ``subtest`` / ``@ops`` /
    ``@modules`` / ``@dtypes`` / ``decorateIf`` when expanded with ``instantiate_device_type_tests``.

    Each test instantiates its local generic class into an explicit ``scope`` dict rather than
    ``locals()``. Since Python 3.13 (PEP 667), every ``locals()`` call inside a function returns an
    independent snapshot. A device-specific class written into one snapshot therefore could not be read
    back through a later ``locals()[...]`` lookup.
    """

    def test_unparametrized_names(self, device):
        # This test exists to protect against regressions in device / dtype test naming
        # due to parametrization logic.

        device = self.device_type

        class TestParametrized(TestCase):
            def test_device_specific(self, device):
                pass

            @dtypes(torch.float32, torch.float64)
            def test_device_dtype_specific(self, device, dtype):
                pass

        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)

        device_cls = scope[f'TestParametrized{device.upper()}']
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
            '{}.test_device_dtype_specific_{}_float32',
            '{}.test_device_dtype_specific_{}_float64',
            '{}.test_device_specific_{}')
        ]
        test_names = _get_test_names_for_test_class(device_cls)
        self.assertEqual(expected_test_names, test_names)

    def test_empty_param_names(self, device):
        # If no param names are passed, ensure things still work without parametrization.
        device = self.device_type

        class TestParametrized(TestCase):
            @parametrize("", [])
            def test_foo(self, device):
                pass

            @parametrize("", range(5))
            def test_bar(self, device):
                pass

        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)

        device_cls = scope[f'TestParametrized{device.upper()}']
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
            '{}.test_bar_{}',
            '{}.test_foo_{}')
        ]
        test_names = _get_test_names_for_test_class(device_cls)
        self.assertEqual(expected_test_names, test_names)

    def test_empty_param_list(self, device):
        # If no param values are passed, ensure a helpful error message is thrown.
        # In the wild, this could indicate reuse of an exhausted generator.
        device = self.device_type

        generator = (a for a in range(5))

        class TestParametrized(TestCase):
            @parametrize("x", generator)
            def test_foo(self, device, x):
                pass

            # Reuse generator from first test function.
            @parametrize("y", generator)
            def test_bar(self, device, y):
                pass

        scope = {'TestParametrized': TestParametrized}
        with self.assertRaisesRegex(ValueError, 'An empty arg_values was passed'):
            instantiate_device_type_tests(TestParametrized, scope, only_for=device)

    def test_default_names(self, device):
        device = self.device_type

        class TestParametrized(TestCase):
            @parametrize("x", range(5))
            def test_default_names(self, device, x):
                pass

            @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
            def test_two_things_default_names(self, device, x, y):
                pass

        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)

        device_cls = scope[f'TestParametrized{device.upper()}']
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
            '{}.test_default_names_x_0_{}',
            '{}.test_default_names_x_1_{}',
            '{}.test_default_names_x_2_{}',
            '{}.test_default_names_x_3_{}',
            '{}.test_default_names_x_4_{}',
            '{}.test_two_things_default_names_x_1_y_2_{}',
            '{}.test_two_things_default_names_x_2_y_3_{}',
            '{}.test_two_things_default_names_x_3_y_4_{}')
        ]
        test_names = _get_test_names_for_test_class(device_cls)
        self.assertEqual(expected_test_names, test_names)

    def test_default_name_non_primitive(self, device):
        device = self.device_type

        class TestParametrized(TestCase):
            @parametrize("x", [1, .5, "foo", object()])
            def test_default_names(self, device, x):
                pass

            @parametrize("x,y", [(1, object()), (object(), .5), (object(), object())])
            def test_two_things_default_names(self, device, x, y):
                pass

        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)

        device_cls = scope[f'TestParametrized{device.upper()}']
        expected_test_names = sorted(name.format(device_cls.__name__, device) for name in (
            '{}.test_default_names_x_1_{}',
            '{}.test_default_names_x_0_5_{}',
            '{}.test_default_names_x_foo_{}',
            '{}.test_default_names_x3_{}',
            '{}.test_two_things_default_names_x_1_y0_{}',
            '{}.test_two_things_default_names_x1_y_0_5_{}',
            '{}.test_two_things_default_names_x2_y2_{}')
        )
        test_names = _get_test_names_for_test_class(device_cls)
        self.assertEqual(expected_test_names, test_names)

    def test_name_fn(self, device):
        device = self.device_type

        class TestParametrized(TestCase):
            @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
            def test_custom_names(self, device, bias):
                pass

            @parametrize("x", [1, 2], name_fn=str)
            @parametrize("y", [3, 4], name_fn=str)
            @parametrize("z", [5, 6], name_fn=str)
            def test_three_things_composition_custom_names(self, device, x, y, z):
                pass

            @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}')
            def test_two_things_custom_names_alternate(self, device, x, y):
                pass

        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)

        device_cls = scope[f'TestParametrized{device.upper()}']
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
            '{}.test_custom_names_bias_{}',
            '{}.test_custom_names_no_bias_{}',
            '{}.test_three_things_composition_custom_names_1_3_5_{}',
            '{}.test_three_things_composition_custom_names_1_3_6_{}',
            '{}.test_three_things_composition_custom_names_1_4_5_{}',
            '{}.test_three_things_composition_custom_names_1_4_6_{}',
            '{}.test_three_things_composition_custom_names_2_3_5_{}',
            '{}.test_three_things_composition_custom_names_2_3_6_{}',
            '{}.test_three_things_composition_custom_names_2_4_5_{}',
            '{}.test_three_things_composition_custom_names_2_4_6_{}',
            '{}.test_two_things_custom_names_alternate_1__2_{}',
            '{}.test_two_things_custom_names_alternate_1__3_{}',
            '{}.test_two_things_custom_names_alternate_1__4_{}')
        ]
        test_names = _get_test_names_for_test_class(device_cls)
        self.assertEqual(expected_test_names, test_names)

    def test_subtest_names(self, device):
        device = self.device_type

        class TestParametrized(TestCase):
            @parametrize("bias", [subtest(True, name='bias'),
                                  subtest(False, name='no_bias')])
            def test_custom_names(self, device, bias):
                pass

            @parametrize("x,y", [subtest((1, 2), name='double'),
                                 subtest((1, 3), name='triple'),
                                 subtest((1, 4), name='quadruple')])
            def test_two_things_custom_names(self, device, x, y):
                pass

        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)

        device_cls = scope[f'TestParametrized{device.upper()}']
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
            '{}.test_custom_names_bias_{}',
            '{}.test_custom_names_no_bias_{}',
            '{}.test_two_things_custom_names_double_{}',
            '{}.test_two_things_custom_names_quadruple_{}',
            '{}.test_two_things_custom_names_triple_{}')
        ]
        test_names = _get_test_names_for_test_class(device_cls)
        self.assertEqual(expected_test_names, test_names)

    def test_ops_composition_names(self, device):
        device = self.device_type

        class TestParametrized(TestCase):
            @ops(op_db)
            @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
            def test_op_parametrized(self, device, dtype, op, flag):
                pass

        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)

        device_cls = scope[f'TestParametrized{device.upper()}']
        expected_test_names = []
        for op in op_db:
            for dtype in op.supported_dtypes(torch.device(device).type):
                for flag_part in ('flag_disabled', 'flag_enabled'):
                    expected_name = f'{device_cls.__name__}.test_op_parametrized_{op.formatted_name}_{flag_part}_{device}_{dtype_name(dtype)}'  # noqa: B950
                    expected_test_names.append(expected_name)

        test_names = _get_test_names_for_test_class(device_cls)
        self.assertEqual(sorted(expected_test_names), sorted(test_names))

    def test_modules_composition_names(self, device):
        device = self.device_type

        class TestParametrized(TestCase):
            @modules(module_db)
            @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
            def test_module_parametrized(self, device, dtype, module_info, training, flag):
                pass

        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)

        device_cls = scope[f'TestParametrized{device.upper()}']
        expected_test_names = []
        for module_info in module_db:
            for dtype in module_info.dtypes:
                for flag_part in ('flag_disabled', 'flag_enabled'):
                    expected_train_modes = (
                        ['train_mode', 'eval_mode'] if module_info.train_and_eval_differ else [''])
                    for training_part in expected_train_modes:
                        expected_name = '{}.test_module_parametrized_{}{}_{}_{}_{}'.format(
                            device_cls.__name__, module_info.formatted_name,
                            '_' + training_part if len(training_part) > 0 else '',
                            flag_part, device, dtype_name(dtype))
                        expected_test_names.append(expected_name)

        test_names = _get_test_names_for_test_class(device_cls)
        self.assertEqual(sorted(expected_test_names), sorted(test_names))

    def test_ops_decorator_applies_op_and_param_specific_decorators(self, device):
        # Test that decorators can be applied on a per-op / per-param basis.

        # Create a test op, OpInfo entry, and decorator to apply.
        def test_op(x):
            return -x

        def test_dec(func):
            func._decorator_applied = True
            return func

        test_op_info = OpInfo(
            'test_op',
            op=test_op,
            dtypes=floating_types(),
            sample_inputs_func=lambda _: [],
            decorators=[
                DecorateInfo(test_dec, 'TestParametrized', 'test_op_param',
                             device_type='cpu', dtypes=[torch.float64],
                             active_if=lambda p: p['x'] == 2)
            ])

        class TestParametrized(TestCase):
            @ops(op_db + [test_op_info])
            @parametrize("x", [2, 3])
            def test_op_param(self, device, dtype, op, x):
                pass

            @ops(op_db + [test_op_info])
            @parametrize("y", [
                subtest(4),
                subtest(5, decorators=[test_dec])])
            def test_other(self, device, dtype, op, y):
                pass

            @decorateIf(test_dec, lambda p: p['dtype'] == torch.int16)
            @ops(op_db)
            def test_three(self, device, dtype, op):
                pass

        device = self.device_type
        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)
        device_cls = scope[f'TestParametrized{device.upper()}']

        for test_func, name in _get_test_funcs_for_test_class(device_cls):
            should_apply = (name == 'test_op_param_test_op_x_2_cpu_float64' or
                            ('test_other' in name and 'y_5' in name) or
                            ('test_three' in name and name.endswith('_int16')))
            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)

    def test_modules_decorator_applies_module_and_param_specific_decorators(self, device):
        # Test that decorators can be applied on a per-module / per-param basis.

        # Create a test module, ModuleInfo entry, and decorator to apply.
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.nn.Parameter(torch.randn(3))

            def forward(self, y):
                return self.x + y

        def test_dec(func):
            func._decorator_applied = True
            return func

        test_module_info = ModuleInfo(
            TestModule,
            module_inputs_func=lambda _: [],
            decorators=[
                DecorateInfo(test_dec, 'TestParametrized', 'test_module_param',
                             device_type='cpu', dtypes=[torch.float64],
                             active_if=lambda p: p['x'] == 2)
            ])

        class TestParametrized(TestCase):
            @modules(module_db + [test_module_info])
            @parametrize("x", [2, 3])
            def test_module_param(self, device, dtype, module_info, training, x):
                pass

            @modules(module_db + [test_module_info])
            @parametrize("y", [
                subtest(4),
                subtest(5, decorators=[test_dec])])
            def test_other(self, device, dtype, module_info, training, y):
                pass

            @decorateIf(test_dec, lambda p: p['dtype'] == torch.float64)
            @modules(module_db)
            def test_three(self, device, dtype, module_info):
                pass

        device = self.device_type
        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)
        device_cls = scope[f'TestParametrized{device.upper()}']

        for test_func, name in _get_test_funcs_for_test_class(device_cls):
            should_apply = (name == 'test_module_param_TestModule_x_2_cpu_float64' or
                            ('test_other' in name and 'y_5' in name) or
                            ('test_three' in name and name.endswith('float64')))
            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)

    def test_param_specific_decoration(self, device):

        def test_dec(func):
            func._decorator_applied = True
            return func

        class TestParametrized(TestCase):
            @decorateIf(test_dec, lambda params: params["x"] == 1 and params["y"])
            @parametrize("x", range(5))
            @parametrize("y", [False, True])
            def test_param(self, x, y):
                pass

        device = self.device_type
        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)
        device_cls = scope[f'TestParametrized{device.upper()}']

        for test_func, name in _get_test_funcs_for_test_class(device_cls):
            should_apply = ('test_param_x_1_y_True' in name)
            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)

    def test_dtypes_composition_valid(self, device):
        # Test checks that @parametrize and @dtypes compose as expected when @parametrize
        # doesn't set dtype.

        device = self.device_type

        class TestParametrized(TestCase):
            @dtypes(torch.float32, torch.float64)
            @parametrize("x", range(3))
            def test_parametrized(self, x, dtype):
                pass

        scope = {'TestParametrized': TestParametrized}
        instantiate_device_type_tests(TestParametrized, scope, only_for=device)

        device_cls = scope[f'TestParametrized{device.upper()}']
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
            '{}.test_parametrized_x_0_{}_float32',
            '{}.test_parametrized_x_0_{}_float64',
            '{}.test_parametrized_x_1_{}_float32',
            '{}.test_parametrized_x_1_{}_float64',
            '{}.test_parametrized_x_2_{}_float32',
            '{}.test_parametrized_x_2_{}_float64')
        ]
        test_names = _get_test_names_for_test_class(device_cls)
        self.assertEqual(sorted(expected_test_names), sorted(test_names))

    def test_dtypes_composition_invalid(self, device):
        # Test checks that @dtypes cannot be composed with parametrization decorators when they
        # also try to set dtype.

        device = self.device_type

        class TestParametrized(TestCase):
            @dtypes(torch.float32, torch.float64)
            @parametrize("dtype", [torch.int32, torch.int64])
            def test_parametrized(self, dtype):
                pass

        scope = {'TestParametrized': TestParametrized}
        with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
            instantiate_device_type_tests(TestParametrized, scope, only_for=device)

        # Verify proper error behavior with @ops + @dtypes, as both try to set dtype.

        class TestParametrized(TestCase):
            @dtypes(torch.float32, torch.float64)
            @ops(op_db)
            def test_parametrized(self, op, dtype):
                pass

        scope = {'TestParametrized': TestParametrized}
        with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
            instantiate_device_type_tests(TestParametrized, scope, only_for=device)

    def test_multiple_handling_of_same_param_error(self, device):
        # Test that multiple decorators handling the same param errors out.
        # Both @modules and @ops handle the dtype param.

        # Pass the device *type* to `only_for`, as every other test in this class does, rather than the raw
        # `device` argument (which may carry a device index).
        device = self.device_type

        class TestParametrized(TestCase):
            @ops(op_db)
            @modules(module_db)
            def test_param(self, device, dtype, op, module_info, training):
                pass

        scope = {'TestParametrized': TestParametrized}
        with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
            instantiate_device_type_tests(TestParametrized, scope, only_for=device)

    # Only the x == 2 subtest is decorated with `unittest.expectedFailure`, so only it may raise.
    @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
    def test_subtest_expected_failure(self, device, x):
        if x == 2:
            raise RuntimeError('Boom')

    # Per-param `expectedFailure` decorators compose: any combination with x == 1 or y == 6 is expected to fail.
    @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
    @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
    def test_two_things_subtest_expected_failure(self, device, x, y):
        if x == 1 or y == 6:
            raise RuntimeError('Boom')
2208
2209
# Expand the @parametrize / subtest decorators on the device-agnostic parametrization tests.
instantiate_parametrized_tests(TestTestParametrization)
# Instantiate the device-type variants of the device-specific parametrization tests, using this
# module's globals as the scope.
instantiate_device_type_tests(TestTestParametrizationDeviceType, globals())
2212
2213
2214class TestImports(TestCase):
2215    @classmethod
2216    def _check_python_output(cls, program) -> str:
2217        return subprocess.check_output(
2218            [sys.executable, "-W", "always", "-c", program],
2219            stderr=subprocess.STDOUT,
2220            # On Windows, opening the subprocess with the default CWD makes `import torch`
2221            # fail, so just set CWD to this script's directory
2222            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
2223
2224    def test_circular_dependencies(self) -> None:
2225        """ Checks that all modules inside torch can be imported
2226        Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """
2227        ignored_modules = ["torch.utils.tensorboard",  # deps on tensorboard
2228                           "torch.distributed.elastic.rendezvous",  # depps on etcd
2229                           "torch.backends._coreml",  # depends on pycoreml
2230                           "torch.contrib.",  # something weird
2231                           "torch.testing._internal.distributed.",  # just fails
2232                           "torch.ao.pruning._experimental.",  # depends on pytorch_lightning, not user-facing
2233                           "torch.onnx._internal",  # depends on onnx-script
2234                           "torch._inductor.runtime.triton_helpers",  # depends on triton
2235                           "torch._inductor.codegen.cuda",  # depends on cutlass
2236                           ]
2237        # See https://github.com/pytorch/pytorch/issues/77801
2238        if not sys.version_info >= (3, 9):
2239            ignored_modules.append("torch.utils.benchmark")
2240        if IS_WINDOWS or IS_MACOS or IS_JETSON:
2241            # Distributed should be importable on Windows(except nn.api.), but not on Mac
2242            if IS_MACOS or IS_JETSON:
2243                ignored_modules.append("torch.distributed.")
2244            else:
2245                ignored_modules.append("torch.distributed.nn.api.")
2246                ignored_modules.append("torch.distributed.optim.")
2247                ignored_modules.append("torch.distributed.rpc.")
2248            ignored_modules.append("torch.testing._internal.dist_utils")
2249            # And these both end up with transitive dependencies on distributed
2250            ignored_modules.append("torch.nn.parallel._replicated_tensor_ddp_interop")
2251            ignored_modules.append("torch.testing._internal.common_fsdp")
2252            ignored_modules.append("torch.testing._internal.common_distributed")
2253
2254        torch_dir = os.path.dirname(torch.__file__)
2255        for base, folders, files in os.walk(torch_dir):
2256            prefix = os.path.relpath(base, os.path.dirname(torch_dir)).replace(os.path.sep, ".")
2257            for f in files:
2258                if not f.endswith(".py"):
2259                    continue
2260                mod_name = f"{prefix}.{f[:-3]}" if f != "__init__.py" else prefix
2261                # Do not attempt to import executable modules
2262                if f == "__main__.py":
2263                    continue
2264                if any(mod_name.startswith(x) for x in ignored_modules):
2265                    continue
2266                try:
2267                    mod = importlib.import_module(mod_name)
2268                except Exception as e:
2269                    raise RuntimeError(f"Failed to import {mod_name}: {e}") from e
2270                self.assertTrue(inspect.ismodule(mod))
2271
2272    @unittest.skipIf(IS_WINDOWS, "TODO enable on Windows")
2273    def test_lazy_imports_are_lazy(self) -> None:
2274        out = self._check_python_output("import sys;import torch;print(all(x not in sys.modules for x in torch._lazy_modules))")
2275        self.assertEqual(out.strip(), "True")
2276
2277    @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
2278    def test_no_warning_on_import(self) -> None:
2279        out = self._check_python_output("import torch")
2280        self.assertEqual(out, "")
2281
2282    def test_not_import_sympy(self) -> None:
2283        out = self._check_python_output("import torch;import sys;print('sympy' not in sys.modules)")
2284        self.assertEqual(out.strip(), "True",
2285                         "PyTorch should not depend on SymPy at import time as importing SymPy is *very* slow.\n"
2286                         "See the beginning of the following blog post for how to profile and find which file is importing sympy:\n"
2287                         "https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589\n\n"
2288                         "If you hit this error, you may want to:\n"
2289                         "  - Refactor your code to avoid depending on sympy files you may not need to depend\n"
2290                         "  - Use TYPE_CHECKING if you are using sympy + strings if you are using sympy on type annotations\n"
2291                         "  - Import things that depend on SymPy locally")
2292
2293    @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
2294    @parametrize('path', ['torch', 'functorch'])
2295    def test_no_mutate_global_logging_on_import(self, path) -> None:
2296        # Calling logging.basicConfig, among other things, modifies the global
2297        # logging state. It is not OK to modify the global logging state on
2298        # `import torch` (or other submodules we own) because users do not expect it.
2299        expected = 'abcdefghijklmnopqrstuvwxyz'
2300        commands = [
2301            'import logging',
2302            f'import {path}',
2303            '_logger = logging.getLogger("torch_test_testing")',
2304            'logging.root.addHandler(logging.StreamHandler())',
2305            'logging.root.setLevel(logging.INFO)',
2306            f'_logger.info("{expected}")'
2307        ]
2308        out = self._check_python_output("; ".join(commands))
2309        self.assertEqual(out.strip(), expected)
2310
2311class TestOpInfos(TestCase):
2312    def test_sample_input(self) -> None:
2313        a, b, c, d, e = (object() for _ in range(5))
2314
2315        # Construction with natural syntax
2316        s = SampleInput(a, b, c, d=d, e=e)
2317        assert s.input is a
2318        assert s.args == (b, c)
2319        assert s.kwargs == dict(d=d, e=e)
2320
2321        # Construction with explicit args and kwargs
2322        s = SampleInput(a, args=(b,), kwargs=dict(c=c, d=d, e=e))
2323        assert s.input is a
2324        assert s.args == (b,)
2325        assert s.kwargs == dict(c=c, d=d, e=e)
2326
2327        # Construction with a mixed form will error
2328        with self.assertRaises(AssertionError):
2329            s = SampleInput(a, b, c, args=(d, e))
2330
2331        with self.assertRaises(AssertionError):
2332            s = SampleInput(a, b, c, kwargs=dict(d=d, e=e))
2333
2334        with self.assertRaises(AssertionError):
2335            s = SampleInput(a, args=(b, c), d=d, e=e)
2336
2337        with self.assertRaises(AssertionError):
2338            s = SampleInput(a, b, c=c, kwargs=dict(d=d, e=e))
2339
2340        # Mixing metadata into "natural" construction will error
2341        with self.assertRaises(AssertionError):
2342            s = SampleInput(a, b, name="foo")
2343
2344        with self.assertRaises(AssertionError):
2345            s = SampleInput(a, b, output_process_fn_grad=lambda x: x)
2346
2347        with self.assertRaises(AssertionError):
2348            s = SampleInput(a, b, broadcasts_input=True)
2349
2350        # But when only input is given, metadata is allowed for backward
2351        # compatibility
2352        s = SampleInput(a, broadcasts_input=True)
2353        assert s.input is a
2354        assert s.broadcasts_input
2355
2356    def test_sample_input_metadata(self) -> None:
2357        a, b = (object() for _ in range(2))
2358        s1 = SampleInput(a, b=b)
2359        self.assertIs(s1.output_process_fn_grad(None), None)
2360        self.assertFalse(s1.broadcasts_input)
2361        self.assertEqual(s1.name, "")
2362
2363        s2 = s1.with_metadata(
2364            output_process_fn_grad=lambda x: a,
2365            broadcasts_input=True,
2366            name="foo",
2367        )
2368        self.assertIs(s1, s2)
2369        self.assertIs(s2.output_process_fn_grad(None), a)
2370        self.assertTrue(s2.broadcasts_input)
2371        self.assertEqual(s2.name, "foo")
2372
2373
2374# Tests that validate the various sample generating functions on each OpInfo.
2375class TestOpInfoSampleFunctions(TestCase):
2376
2377    @ops(op_db, dtypes=OpDTypes.any_one)
2378    def test_opinfo_sample_generators(self, device, dtype, op):
2379        # Test op.sample_inputs doesn't generate multiple samples when called
2380        samples = op.sample_inputs(device, dtype)
2381        self.assertIsInstance(samples, Iterator)
2382
2383    @ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one)
2384    def test_opinfo_reference_generators(self, device, dtype, op):
2385        # Test op.reference_inputs doesn't generate multiple samples when called
2386        samples = op.reference_inputs(device, dtype)
2387        self.assertIsInstance(samples, Iterator)
2388
2389    @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
2390    def test_opinfo_error_generators(self, device, op):
2391        # Test op.error_inputs doesn't generate multiple inputs when called
2392        samples = op.error_inputs(device)
2393        self.assertIsInstance(samples, Iterator)
2394
2395
2396instantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
2397instantiate_parametrized_tests(TestImports)
2398
2399
2400if __name__ == '__main__':
2401    run_tests()
2402