xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/common_device_type.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import copy
4import gc
5import inspect
6import os
7import runpy
8import sys
9import threading
10import unittest
11from collections import namedtuple
12from enum import Enum
13from functools import partial, wraps
14from typing import Any, ClassVar, Dict, List, Optional, Sequence, Set, Tuple, Union
15
16import torch
17from torch.testing._internal.common_cuda import (
18    _get_torch_cuda_version,
19    _get_torch_rocm_version,
20    TEST_CUSPARSE_GENERIC,
21    TEST_HIPSPARSE_GENERIC,
22)
23from torch.testing._internal.common_dtype import get_all_dtypes
24from torch.testing._internal.common_utils import (
25    _TestParametrizer,
26    clear_tracked_input,
27    compose_parametrize_fns,
28    dtype_name,
29    get_tracked_input,
30    IS_FBCODE,
31    is_privateuse1_backend_available,
32    IS_REMOTE_GPU,
33    IS_SANDCASTLE,
34    IS_WINDOWS,
35    NATIVE_DEVICES,
36    PRINT_REPRO_ON_FAILURE,
37    skipCUDANonDefaultStreamIf,
38    skipIfTorchDynamo,
39    TEST_HPU,
40    TEST_MKL,
41    TEST_MPS,
42    TEST_WITH_ASAN,
43    TEST_WITH_MIOPEN_SUGGEST_NHWC,
44    TEST_WITH_ROCM,
45    TEST_WITH_TORCHINDUCTOR,
46    TEST_WITH_TSAN,
47    TEST_WITH_UBSAN,
48    TEST_XPU,
49    TestCase,
50)
51
52
53try:
54    import psutil  # type: ignore[import]
55
56    HAS_PSUTIL = True
57except ModuleNotFoundError:
58    HAS_PSUTIL = False
59    psutil = None
60
61# Note [Writing Test Templates]
62# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
63#
64# This note was written shortly after the PyTorch 1.9 release.
65# If you notice it's out-of-date or think it could be improved then please
66# file an issue.
67#
68# PyTorch has its own framework for instantiating test templates. That is, for
69#   taking test classes that look similar to unittest or pytest
70#   compatible test classes and optionally doing the following:
71#
72#     - instantiating a version of the test class for each available device type
73#         (often the CPU, CUDA, and META device types)
74#     - further instantiating a version of each test that's always specialized
75#         on the test class's device type, and optionally specialized further
76#         on datatypes or operators
77#
78# This functionality is similar to pytest's parametrize functionality
79#   (see https://docs.pytest.org/en/6.2.x/parametrize.html), but with considerable
80#   additional logic that specializes the instantiated test classes for their
81#   device types (see CPUTestBase and CUDATestBase below), supports a variety
82#   of composable decorators that allow for test filtering and setting
83#   tolerances, and allows tests parametrized by operators to instantiate
84#   only the subset of device type x dtype that operator supports.
85#
86# This framework was built to make it easier to write tests that run on
87#   multiple device types, multiple datatypes (dtypes), and for multiple
88#   operators. It's also useful for controlling which tests are run. For example,
89#   tests that use a CUDA device are only run on platforms where CUDA is available.
90#   Let's dive in with an example to get an idea for how it works:
91#
92# --------------------------------------------------------
93# A template class (looks like a regular unittest TestCase)
94# class TestClassFoo(TestCase):
95#
96#   # A template test that can be specialized with a device
97#   # NOTE: this test case is not runnable by unittest or pytest because it
98#   #   accepts an extra positional argument, "device", that they do not understand
99#   def test_bar(self, device):
100#     pass
101#
102# # Function that instantiates a template class and its tests
103# instantiate_device_type_tests(TestClassFoo, globals())
104# --------------------------------------------------------
105#
106# In the above code example we see a template class and a single test template
107#   that can be instantiated with a device. The function
108#   instantiate_device_type_tests(), called at file scope, instantiates
109#   new test classes, one per available device type, and new tests in those
110#   classes from these templates. It actually does this by removing
111#   the class TestClassFoo and replacing it with classes like TestClassFooCPU
112#   and TestClassFooCUDA, instantiated test classes that inherit from CPUTestBase
113#   and CUDATestBase respectively. Additional device types, like XLA,
114#   (see https://github.com/pytorch/xla) can further extend the set of
115#   instantiated test classes to create classes like TestClassFooXLA.
116#
117# The test template, test_bar(), is also instantiated. In this case the template
118#   is only specialized on a device, so (depending on the available device
119#   types) it might become test_bar_cpu() in TestClassFooCPU and test_bar_cuda()
120#   in TestClassFooCUDA. We can think of the instantiated test classes as
121#   looking like this:
122#
123# --------------------------------------------------------
124# # An instantiated test class for the CPU device type
125# class TestClassFooCPU(CPUTestBase):
126#
127#   # An instantiated test that calls the template with the string representation
128#   #   of a device from the test class's device type
129#   def test_bar_cpu(self):
130#     test_bar(self, 'cpu')
131#
132# # An instantiated test class for the CUDA device type
133# class TestClassFooCUDA(CUDATestBase):
134#
135#   # An instantiated test that calls the template with the string representation
136#   #   of a device from the test class's device type
137#   def test_bar_cuda(self):
138#     test_bar(self, 'cuda:0')
139# --------------------------------------------------------
140#
141# These instantiated test classes ARE discoverable and runnable by both
142#   unittest and pytest. One thing that may be confusing, however, is that
143#   attempting to run "test_bar" will not work, despite it appearing in the
144#   original template code. This is because "test_bar" is no longer discoverable
145#   after instantiate_device_type_tests() runs, as the above snippet shows.
146#   Instead "test_bar_cpu" and "test_bar_cuda" may be run directly, or both
147#   can be run with the option "-k test_bar".
148#
149# Removing the template class and adding the instantiated classes requires
150#   passing "globals()" to instantiate_device_type_tests(), because it
151#   edits the file's Python objects.
152#
153# As mentioned, tests can be additionally parametrized on dtypes or
154#   operators. Datatype parametrization uses the @dtypes decorator and
155#   require a test template like this:
156#
157# --------------------------------------------------------
158# # A template test that can be specialized with a device and a datatype (dtype)
159# @dtypes(torch.float32, torch.int64)
160# def test_car(self, device, dtype):
161#   pass
162# --------------------------------------------------------
163#
164# If the CPU and CUDA device types are available this test would be
165#   instantiated as 4 tests that cover the cross-product of the two dtypes
166#   and two device types:
167#
168#     - test_car_cpu_float32
169#     - test_car_cpu_int64
170#     - test_car_cuda_float32
171#     - test_car_cuda_int64
172#
173# The dtype is passed as a torch.dtype object.
174#
175# Tests parametrized on operators (actually on OpInfos, more on that in a
176#   moment...) use the @ops decorator and require a test template like this:
177# --------------------------------------------------------
178# # A template test that can be specialized with a device, dtype, and OpInfo
179# @ops(op_db)
180# def test_car(self, device, dtype, op):
181#   pass
182# --------------------------------------------------------
183#
184# See the documentation for the @ops decorator below for additional details
185#   on how to use it and see the note [OpInfos] in
186#   common_methods_invocations.py for more details on OpInfos.
187#
188# A test parametrized over the entire "op_db", which contains hundreds of
189#   OpInfos, will likely have hundreds or thousands of instantiations. The
190#   test will be instantiated on the cross-product of device types, operators,
191#   and the dtypes the operator supports on that device type. The instantiated
192#   tests will have names like:
193#
194#     - test_car_add_cpu_float32
195#     - test_car_sub_cuda_int64
196#
197# The first instantiated test calls the original test_car() with the OpInfo
198#   for torch.add as its "op" argument, the string 'cpu' for its "device" argument,
199#   and the dtype torch.float32 for is "dtype" argument. The second instantiated
200#   test calls the test_car() with the OpInfo for torch.sub, a CUDA device string
201#   like 'cuda:0' or 'cuda:1' for its "device" argument, and the dtype
202#   torch.int64 for its "dtype argument."
203#
204# In addition to parametrizing over device, dtype, and ops via OpInfos, the
205#   @parametrize decorator is supported for arbitrary parametrizations:
206# --------------------------------------------------------
207# # A template test that can be specialized with a device, dtype, and value for x
208# @parametrize("x", range(5))
209# def test_car(self, device, dtype, x):
210#   pass
211# --------------------------------------------------------
212#
213# See the documentation for @parametrize in common_utils.py for additional details
214#   on this. Note that the instantiate_device_type_tests() function will handle
215#   such parametrizations; there is no need to additionally call
216#   instantiate_parametrized_tests().
217#
218# Clever test filtering can be very useful when working with parametrized
219#   tests. "-k test_car" would run every instantiated variant of the test_car()
220#   test template, and "-k test_car_add" runs every variant instantiated with
221#   torch.add.
222#
223# It is important to use the passed device and dtype as appropriate. Use
224#   helper functions like make_tensor() that require explicitly specifying
225#   the device and dtype so they're not forgotten.
226#
227# Test templates can use a variety of composable decorators to specify
228#   additional options and requirements; some are listed here (a short usage sketch follows the list):
229#
230#     - @deviceCountAtLeast(<minimum number of devices to run test with>)
231#         Passes a list of strings representing all available devices of
232#         the test class's device type as the test template's "devices" argument.
233#         If there are fewer devices than the value passed to the decorator
234#         the test is skipped.
235#     - @dtypes(<list of tuples of dtypes>)
236#         In addition to accepting multiple dtypes, the @dtypes decorator
237#         can accept a sequence of tuple pairs of dtypes. The test template
238#         will be called with each tuple for its "dtypes" argument.
239#     - @onlyNativeDeviceTypes
240#         Skips the test if the device is not a native device type (currently CPU, CUDA, Meta, and PrivateUse1)
241#     - @onlyCPU
242#         Skips the test if the device is not a CPU device
243#     - @onlyCUDA
244#         Skips the test if the device is not a CUDA device
245#     - @onlyMPS
246#         Skips the test if the device is not an MPS device
247#     - @skipCPUIfNoLapack
248#         Skips the test if the device is a CPU device and LAPACK is not installed
249#     - @skipCPUIfNoMkl
250#         Skips the test if the device is a CPU device and MKL is not installed
251#     - @skipCUDAIfNoMagma
252#         Skips the test if the device is a CUDA device and MAGMA is not installed
253#     - @skipCUDAIfRocm
254#         Skips the test if the device is a CUDA device and ROCm is being used
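#
# As a minimal sketch (the test name and dtype choices below are illustrative, not from a
#   real test file), several of these decorators can be composed on one template:
#
# --------------------------------------------------------
# # Runs only when at least two devices of the class's device type are available,
# # and is instantiated once per dtype pair.
# @deviceCountAtLeast(2)
# @dtypes((torch.float32, torch.float64), (torch.int32, torch.int64))
# def test_baz(self, devices, dtypes):
#     pass
# --------------------------------------------------------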
255
256
257# Note [Adding a Device Type]
258# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
259#
260# To add a device type:
261#
262#   (1) Create a new "TestBase" extending DeviceTypeTestBase.
263#       See CPUTestBase and CUDATestBase below.
264#   (2) Define the "device_type" attribute of the base to be the
265#       appropriate string.
266#   (3) Add logic to this file that appends your base class to
267#       device_type_test_bases when your device type is available.
268#   (4) (Optional) Write setUpClass/tearDownClass class methods that
269#       instantiate dependencies (see MAGMA in CUDATestBase).
270#   (5) (Optional) Override the "instantiate_test" method for total
271#       control over how your class creates tests.
272#
273# setUpClass is called AFTER tests have been created and BEFORE (and ONLY IF)
274# they are run. This makes it useful for initializing devices and dependencies.
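#
# As a rough sketch of steps (1)-(3) (the "foo" backend and torch.foo module below are
# hypothetical, used only for illustration):
#
# --------------------------------------------------------
# class FooTestBase(DeviceTypeTestBase):
#     device_type = "foo"
#
# if hasattr(torch, "foo") and torch.foo.is_available():
#     device_type_test_bases.append(FooTestBase)
# --------------------------------------------------------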
275
276
277# Note [Overriding methods in generic tests]
278# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
279#
280# Device generic tests look a lot like normal test classes, but they differ
281# from ordinary classes in some important ways.  In particular, overriding
282# methods in generic tests doesn't work quite the way you expect.
283#
284#     class TestFooDeviceType(TestCase):
285#         # Intention is to override
286#         def assertEqual(self, x, y):
287#             # This DOESN'T WORK!
288#             super().assertEqual(x, y)
289#
290# If you try to run this code, you'll get an error saying that TestFooDeviceType
291# is not in scope.  This is because after instantiating our classes, we delete
292# it from the parent scope.  Instead, you need to hardcode a direct invocation
293# of the desired subclass call, e.g.,
294#
295#     class TestFooDeviceType(TestCase):
296#         # Intention is to override
297#         def assertEqual(self, x, y):
298#             TestCase.assertEqual(self, x, y)
299#
300# However, a less error-prone way of customizing the behavior of TestCase
301# is to either (1) add your functionality to TestCase and make it toggled
302# by a class attribute, or (2) create your own subclass of TestCase, and
303# then inherit from it for your generic test.
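#
# A sketch of option (2), using hypothetical class names:
#
#     class MyTestCaseBase(TestCase):
#         def assertEqual(self, x, y, **kwargs):
#             # custom behavior here, then defer to the standard implementation
#             super().assertEqual(x, y, **kwargs)
#
#     class TestFooDeviceType(MyTestCaseBase):
#         def test_bar(self, device):
#             ...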
304
305
306def _dtype_test_suffix(dtypes):
307    """Returns the test suffix for a dtype, sequence of dtypes, or None."""
308    if isinstance(dtypes, (list, tuple)):
309        if len(dtypes) == 0:
310            return ""
311        return "_" + "_".join(dtype_name(d) for d in dtypes)
312    elif dtypes:
313        return f"_{dtype_name(dtypes)}"
314    else:
315        return ""
316
317
318def _update_param_kwargs(param_kwargs, name, value):
319    """Adds a kwarg with the specified name and value to the param_kwargs dict."""
320    # Make name plural (e.g. devices / dtypes) if the value is composite.
321    plural_name = f"{name}s"
322
323    # Clear out old entries of the arg if any.
324    if name in param_kwargs:
325        del param_kwargs[name]
326    if plural_name in param_kwargs:
327        del param_kwargs[plural_name]
328
329    if isinstance(value, (list, tuple)):
330        param_kwargs[plural_name] = value
331    elif value is not None:
332        param_kwargs[name] = value
333
334    # Leave param_kwargs as-is when value is None.
335
336
337class DeviceTypeTestBase(TestCase):
338    device_type: str = "generic_device_type"
339
340    # Flag to disable test suite early due to unrecoverable error such as CUDA error.
341    _stop_test_suite = False
342
343    # Precision is a thread-local setting since it may be overridden per test
344    _tls = threading.local()
345    _tls.precision = TestCase._precision
346    _tls.rel_tol = TestCase._rel_tol
347
348    @property
349    def precision(self):
350        return self._tls.precision
351
352    @precision.setter
353    def precision(self, prec):
354        self._tls.precision = prec
355
356    @property
357    def rel_tol(self):
358        return self._tls.rel_tol
359
360    @rel_tol.setter
361    def rel_tol(self, prec):
362        self._tls.rel_tol = prec
363
364    # Returns a string representing the device that single device tests should use.
365    # Note: single device tests use this device exclusively.
366    @classmethod
367    def get_primary_device(cls):
368        return cls.device_type
369
370    @classmethod
371    def _init_and_get_primary_device(cls):
372        try:
373            return cls.get_primary_device()
374        except Exception:
375            # For CUDATestBase, XLATestBase, and possibly others, the primary device won't be available
376            # until setUpClass() sets it. Call that manually here if needed.
377            if hasattr(cls, "setUpClass"):
378                cls.setUpClass()
379            return cls.get_primary_device()
380
381    # Returns a list of strings representing all available devices of this
382    # device type. The primary device must be the first string in the list
383    # and the list must contain no duplicates.
384    # Note: UNSTABLE API. Will be replaced once PyTorch has a device generic
385    #   mechanism of acquiring all available devices.
386    @classmethod
387    def get_all_devices(cls):
388        return [cls.get_primary_device()]
389
390    # Returns the dtypes the test has requested.
391    # Prefers device-specific dtype specifications over generic ones.
392    @classmethod
393    def _get_dtypes(cls, test):
394        if not hasattr(test, "dtypes"):
395            return None
396
397        default_dtypes = test.dtypes.get("all")
398        msg = f"@dtypes is mandatory when using @dtypesIf; however, '{test.__name__}' didn't specify it"
399        assert default_dtypes is not None, msg
400
401        return test.dtypes.get(cls.device_type, default_dtypes)
402
403    def _get_precision_override(self, test, dtype):
404        if not hasattr(test, "precision_overrides"):
405            return self.precision
406        return test.precision_overrides.get(dtype, self.precision)
407
408    def _get_tolerance_override(self, test, dtype):
409        if not hasattr(test, "tolerance_overrides"):
410            return self.precision, self.rel_tol
411        return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol))
412
413    def _apply_precision_override_for_test(self, test, param_kwargs):
414        dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None
415        dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype
416        if dtype:
417            self.precision = self._get_precision_override(test, dtype)
418            self.precision, self.rel_tol = self._get_tolerance_override(test, dtype)
419
420    # Creates device-specific tests.
421    @classmethod
422    def instantiate_test(cls, name, test, *, generic_cls=None):
423        def instantiate_test_helper(
424            cls, name, *, test, param_kwargs=None, decorator_fn=lambda _: []
425        ):
426            # Add the device param kwarg if the test needs device or devices.
427            param_kwargs = {} if param_kwargs is None else param_kwargs
428            test_sig_params = inspect.signature(test).parameters
429            if "device" in test_sig_params or "devices" in test_sig_params:
430                device_arg: str = cls._init_and_get_primary_device()
431                if hasattr(test, "num_required_devices"):
432                    device_arg = cls.get_all_devices()
433                _update_param_kwargs(param_kwargs, "device", device_arg)
434
435            # Apply decorators based on param kwargs.
436            for decorator in decorator_fn(param_kwargs):
437                test = decorator(test)
438
439            # Constructs the test
440            @wraps(test)
441            def instantiated_test(self, param_kwargs=param_kwargs):
442                # Sets precision and runs test
443                # Note: precision is reset after the test is run
444                guard_precision = self.precision
445                guard_rel_tol = self.rel_tol
446                try:
447                    self._apply_precision_override_for_test(test, param_kwargs)
448                    result = test(self, **param_kwargs)
449                except RuntimeError as rte:
450                    # check if rte should stop entire test suite.
451                    self._stop_test_suite = self._should_stop_test_suite()
452                    # Check if test has been decorated with `@expectedFailure`
453                    # Using `__unittest_expecting_failure__` attribute, see
454                    # https://github.com/python/cpython/blob/ffa505b580464/Lib/unittest/case.py#L164
455                    # In that case, make it fail with "unexpected success" by suppressing exception
456                    if (
457                        getattr(test, "__unittest_expecting_failure__", False)
458                        and self._stop_test_suite
459                    ):
462                        print(
463                            "Suppressing fatal exception to trigger unexpected success",
464                            file=sys.stderr,
465                        )
466                        return
467                    # raise the runtime error as is for the test suite to record.
468                    raise rte
469                finally:
470                    self.precision = guard_precision
471                    self.rel_tol = guard_rel_tol
472
473                return result
474
475            assert not hasattr(cls, name), f"Redefinition of test {name}"
476            setattr(cls, name, instantiated_test)
477
478        def default_parametrize_fn(test, generic_cls, device_cls):
479            # By default, no parametrization is needed.
480            yield (test, "", {}, lambda _: [])
481
482        # Parametrization decorators set the parametrize_fn attribute on the test.
483        parametrize_fn = getattr(test, "parametrize_fn", default_parametrize_fn)
484
485        # If one of the @dtypes* decorators is present, also parametrize over the dtypes set by it.
486        dtypes = cls._get_dtypes(test)
487        if dtypes is not None:
488
489            def dtype_parametrize_fn(test, generic_cls, device_cls, dtypes=dtypes):
490                for dtype in dtypes:
491                    param_kwargs: Dict[str, Any] = {}
492                    _update_param_kwargs(param_kwargs, "dtype", dtype)
493
494                    # Note that an empty test suffix is set here so that the dtype can be appended
495                    # later after the device.
496                    yield (test, "", param_kwargs, lambda _: [])
497
498            parametrize_fn = compose_parametrize_fns(
499                dtype_parametrize_fn, parametrize_fn
500            )
501
502        # Instantiate the parametrized tests.
503        for (
504            test,  # noqa: B020
505            test_suffix,
506            param_kwargs,
507            decorator_fn,
508        ) in parametrize_fn(test, generic_cls, cls):
509            test_suffix = "" if test_suffix == "" else "_" + test_suffix
510            cls_device_type = (
511                cls.device_type
512                if cls.device_type != "privateuse1"
513                else torch._C._get_privateuse1_backend_name()
514            )
515            device_suffix = "_" + cls_device_type
516
517            # Note: device and dtype suffix placement
518            # Special handling here to place dtype(s) after device according to test name convention.
519            dtype_kwarg = None
520            if "dtype" in param_kwargs or "dtypes" in param_kwargs:
521                dtype_kwarg = (
522                    param_kwargs["dtypes"]
523                    if "dtypes" in param_kwargs
524                    else param_kwargs["dtype"]
525                )
526            test_name = (
527                f"{name}{test_suffix}{device_suffix}{_dtype_test_suffix(dtype_kwarg)}"
528            )
529
530            instantiate_test_helper(
531                cls=cls,
532                name=test_name,
533                test=test,
534                param_kwargs=param_kwargs,
535                decorator_fn=decorator_fn,
536            )
537
538    def run(self, result=None):
539        super().run(result=result)
540        # Early terminate test if _stop_test_suite is set.
541        if self._stop_test_suite:
542            result.stop()
543
544
545class CPUTestBase(DeviceTypeTestBase):
546    device_type = "cpu"
547
548    # No critical error should stop CPU test suite
549    def _should_stop_test_suite(self):
550        return False
551
552
553class CUDATestBase(DeviceTypeTestBase):
554    device_type = "cuda"
555    _do_cuda_memory_leak_check = True
556    _do_cuda_non_default_stream = True
557    primary_device: ClassVar[str]
558    cudnn_version: ClassVar[Any]
559    no_magma: ClassVar[bool]
560    no_cudnn: ClassVar[bool]
561
562    def has_cudnn(self):
563        return not self.no_cudnn
564
565    @classmethod
566    def get_primary_device(cls):
567        return cls.primary_device
568
569    @classmethod
570    def get_all_devices(cls):
571        primary_device_idx = int(cls.get_primary_device().split(":")[1])
572        num_devices = torch.cuda.device_count()
573
574        prim_device = cls.get_primary_device()
575        cuda_str = "cuda:{0}"
576        non_primary_devices = [
577            cuda_str.format(idx)
578            for idx in range(num_devices)
579            if idx != primary_device_idx
580        ]
581        return [prim_device] + non_primary_devices
582
583    @classmethod
584    def setUpClass(cls):
585        # has_magma shows up after cuda is initialized
586        t = torch.ones(1).cuda()
587        cls.no_magma = not torch.cuda.has_magma
588
589        # Determines if cuDNN is available and its version
590        cls.no_cudnn = not torch.backends.cudnn.is_acceptable(t)
591        cls.cudnn_version = None if cls.no_cudnn else torch.backends.cudnn.version()
592
593        # Acquires the current device as the primary (test) device
594        cls.primary_device = f"cuda:{torch.cuda.current_device()}"
595
596
597# See Note [Lazy Tensor tests in device agnostic testing]
598lazy_ts_backend_init = False
599
600
601class LazyTestBase(DeviceTypeTestBase):
602    device_type = "lazy"
603
604    def _should_stop_test_suite(self):
605        return False
606
607    @classmethod
608    def setUpClass(cls):
609        import torch._lazy
610        import torch._lazy.metrics
611        import torch._lazy.ts_backend
612
613        global lazy_ts_backend_init
614        if not lazy_ts_backend_init:
615            # Need to connect the TS backend to lazy key before running tests
616            torch._lazy.ts_backend.init()
617            lazy_ts_backend_init = True
618
619
620class MPSTestBase(DeviceTypeTestBase):
621    device_type = "mps"
622    primary_device: ClassVar[str]
623
624    @classmethod
625    def get_primary_device(cls):
626        return cls.primary_device
627
628    @classmethod
629    def get_all_devices(cls):
630        # currently only one device is supported on MPS backend
631        prim_device = cls.get_primary_device()
632        return [prim_device]
633
634    @classmethod
635    def setUpClass(cls):
636        cls.primary_device = "mps:0"
637
638    def _should_stop_test_suite(self):
639        return False
640
641
642class XPUTestBase(DeviceTypeTestBase):
643    device_type = "xpu"
644    primary_device: ClassVar[str]
645
646    @classmethod
647    def get_primary_device(cls):
648        return cls.primary_device
649
650    @classmethod
651    def get_all_devices(cls):
652        # currently only one device is supported on the XPU backend
653        prim_device = cls.get_primary_device()
654        return [prim_device]
655
656    @classmethod
657    def setUpClass(cls):
658        cls.primary_device = "xpu:0"
659
660    def _should_stop_test_suite(self):
661        return False
662
663
664class HPUTestBase(DeviceTypeTestBase):
665    device_type = "hpu"
666    primary_device: ClassVar[str]
667
668    @classmethod
669    def get_primary_device(cls):
670        return cls.primary_device
671
672    @classmethod
673    def setUpClass(cls):
674        cls.primary_device = "hpu:0"
675
676
677class PrivateUse1TestBase(DeviceTypeTestBase):
678    primary_device: ClassVar[str]
679    device_mod = None
680    device_type = "privateuse1"
681
682    @classmethod
683    def get_primary_device(cls):
684        return cls.primary_device
685
686    @classmethod
687    def get_all_devices(cls):
688        primary_device_idx = int(cls.get_primary_device().split(":")[1])
689        num_devices = cls.device_mod.device_count()
690        prim_device = cls.get_primary_device()
691        device_str = f"{cls.device_type}:{{0}}"
692        non_primary_devices = [
693            device_str.format(idx)
694            for idx in range(num_devices)
695            if idx != primary_device_idx
696        ]
697        return [prim_device] + non_primary_devices
698
699    @classmethod
700    def setUpClass(cls):
701        cls.device_type = torch._C._get_privateuse1_backend_name()
702        cls.device_mod = getattr(torch, cls.device_type, None)
703        assert (
704            cls.device_mod is not None
705        ), f"""torch has no module named `{cls.device_type}`; you should register
706            one via `torch._register_device_module`."""
707        cls.primary_device = f"{cls.device_type}:{cls.device_mod.current_device()}"
708
709
710# Gets the available device-type-specific test base classes
711def get_device_type_test_bases():
712    # set type to List[Any] due to mypy list-of-union issue:
713    # https://github.com/python/mypy/issues/3351
714    test_bases: List[Any] = []
715
716    if IS_SANDCASTLE or IS_FBCODE:
717        if IS_REMOTE_GPU:
718            # Skip if sanitizer is enabled
719            if not TEST_WITH_ASAN and not TEST_WITH_TSAN and not TEST_WITH_UBSAN:
720                test_bases.append(CUDATestBase)
721        else:
722            test_bases.append(CPUTestBase)
723    else:
724        test_bases.append(CPUTestBase)
725        if torch.cuda.is_available():
726            test_bases.append(CUDATestBase)
727
728        if is_privateuse1_backend_available():
729            test_bases.append(PrivateUse1TestBase)
730        # Disable MPS testing in generic device testing temporarily while we're
731        # ramping up support.
732        # elif torch.backends.mps.is_available():
733        #   test_bases.append(MPSTestBase)
734
735    return test_bases
736
737
738device_type_test_bases = get_device_type_test_bases()
739
740
741def filter_desired_device_types(device_type_test_bases, except_for=None, only_for=None):
742    # device type cannot appear in both except_for and only_for
743    intersect = set(except_for if except_for else []) & set(
744        only_for if only_for else []
745    )
746    assert (
747        not intersect
748    ), f"device ({intersect}) appeared in both except_for and only_for"
749
750    # Map the custom privateuse1 backend name back to 'privateuse1' so it matches device_type
751    if is_privateuse1_backend_available():
752        privateuse1_backend_name = torch._C._get_privateuse1_backend_name()
753        except_for = (
754            ["privateuse1" if x == privateuse1_backend_name else x for x in except_for]
755            if except_for is not None
756            else None
757        )
758        only_for = (
759            ["privateuse1" if x == privateuse1_backend_name else x for x in only_for]
760            if only_for is not None
761            else None
762        )
763
764    if except_for:
765        device_type_test_bases = filter(
766            lambda x: x.device_type not in except_for, device_type_test_bases
767        )
768    if only_for:
769        device_type_test_bases = filter(
770            lambda x: x.device_type in only_for, device_type_test_bases
771        )
772
773    return list(device_type_test_bases)
774
775
776# Note [How to extend DeviceTypeTestBase to add new test device]
777# The following logic optionally allows downstream projects like pytorch/xla to
778# add more test devices.
779# Instructions:
780#  - Add a python file (e.g. pytorch/xla/test/pytorch_test_base.py) in downstream project.
781#    - Inside the file, one should inherit from `DeviceTypeTestBase` class and define
782#      a new DeviceTypeTest class (e.g. `XLATestBase`) with proper implementation of
783#      `instantiate_test` method.
784#    - DO NOT import common_device_type inside the file.
785#      `runpy.run_path` with `globals()` already properly setup the context so that
786#      `DeviceTypeTestBase` is already available.
787#    - Set a top-level variable `TEST_CLASS` equal to your new class.
788#      E.g. TEST_CLASS = XLATensorBase
789#  - To run tests with new device type, set `TORCH_TEST_DEVICE` env variable to path
790#    to this file. Multiple paths can be separated by `:`.
791# See pytorch/xla/test/pytorch_test_base.py for a more detailed example.
792_TORCH_TEST_DEVICES = os.environ.get("TORCH_TEST_DEVICES", None)
793if _TORCH_TEST_DEVICES:
794    for path in _TORCH_TEST_DEVICES.split(":"):
795        # runpy (a stdlib module) lacks annotations
796        mod = runpy.run_path(path, init_globals=globals())  # type: ignore[func-returns-value]
797        device_type_test_bases.append(mod["TEST_CLASS"])
798
799
800PYTORCH_CUDA_MEMCHECK = os.getenv("PYTORCH_CUDA_MEMCHECK", "0") == "1"
801
802PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY = "PYTORCH_TESTING_DEVICE_ONLY_FOR"
803PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY = "PYTORCH_TESTING_DEVICE_EXCEPT_FOR"
804PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY = "PYTORCH_TESTING_DEVICE_FOR_CUSTOM"
805
806
807def get_desired_device_type_test_bases(
808    except_for=None, only_for=None, include_lazy=False, allow_mps=False, allow_xpu=False
809):
810    # allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy`
811    test_bases = device_type_test_bases.copy()
812    if allow_mps and TEST_MPS and MPSTestBase not in test_bases:
813        test_bases.append(MPSTestBase)
814    if allow_xpu and TEST_XPU and XPUTestBase not in test_bases:
815        test_bases.append(XPUTestBase)
816    if TEST_HPU and HPUTestBase not in test_bases:
817        test_bases.append(HPUTestBase)
818    # Filter out the device types based on user inputs
819    desired_device_type_test_bases = filter_desired_device_types(
820        test_bases, except_for, only_for
821    )
822    if include_lazy:
823        # Note [Lazy Tensor tests in device agnostic testing]
824        # Right now, test_view_ops.py runs with LazyTensor.
825        # We don't want to opt every device-agnostic test into using the lazy device,
826        # because many of them will fail.
827        # So instead, the only way to opt a specific device-agnostic test file into
828        # lazy tensor testing is with include_lazy=True
829        if IS_FBCODE:
830            print(
831                "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds",
832                file=sys.stderr,
833            )
834        else:
835            desired_device_type_test_bases.append(LazyTestBase)
836
837    def split_if_not_empty(x: str):
838        return x.split(",") if x else []
839
840    # run some cuda testcases on other devices if available
841    # Usage:
842    # export PYTORCH_TESTING_DEVICE_FOR_CUSTOM=privateuse1
843    env_custom_only_for = split_if_not_empty(
844        os.getenv(PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY, "")
845    )
846    if env_custom_only_for:
847        desired_device_type_test_bases += filter(
848            lambda x: x.device_type in env_custom_only_for, test_bases
849        )
850        desired_device_type_test_bases = list(set(desired_device_type_test_bases))
851
852    # Filter out the device types based on environment variables if available
853    # Usage:
854    # export PYTORCH_TESTING_DEVICE_ONLY_FOR=cuda,cpu
855    # export PYTORCH_TESTING_DEVICE_EXCEPT_FOR=xla
856    env_only_for = split_if_not_empty(
857        os.getenv(PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, "")
858    )
859    env_except_for = split_if_not_empty(
860        os.getenv(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, "")
861    )
862
863    return filter_desired_device_types(
864        desired_device_type_test_bases, env_except_for, env_only_for
865    )
866
867
868# Adds 'instantiated' device-specific test cases to the given scope.
869# The tests in these test cases are derived from the generic tests in
870# generic_test_class. This function should be used instead of
871# instantiate_parametrized_tests() if the test class contains
872# device-specific tests (NB: this supports additional @parametrize usage).
873#
874# See Note [Writing Test Templates]
875# TODO: remove the "allow_xpu" option after Intel GPU supports all test cases instantiated by this function.
876def instantiate_device_type_tests(
877    generic_test_class,
878    scope,
879    except_for=None,
880    only_for=None,
881    include_lazy=False,
882    allow_mps=False,
883    allow_xpu=False,
884):
885    # Removes the generic test class from its enclosing scope so its tests
886    # are not discoverable.
887    del scope[generic_test_class.__name__]
888
889    # Creates an 'empty' version of the generic_test_class
890    # Note: we don't inherit from the generic_test_class directly because
891    #   that would add its tests to our test classes and they would be
892    #   discovered (despite not being runnable). Inherited methods also
893    #   can't be removed later, and we can't rely on load_tests because
894    #   pytest doesn't support it (as of this writing).
895    empty_name = generic_test_class.__name__ + "_base"
896    empty_class = type(empty_name, generic_test_class.__bases__, {})
897
898    # Acquires member names
899    # See Note [Overriding methods in generic tests]
900    generic_members = set(generic_test_class.__dict__.keys()) - set(
901        empty_class.__dict__.keys()
902    )
903    generic_tests = [x for x in generic_members if x.startswith("test")]
904
905    # Creates device-specific test cases
906    for base in get_desired_device_type_test_bases(
907        except_for, only_for, include_lazy, allow_mps, allow_xpu
908    ):
909        class_name = generic_test_class.__name__ + base.device_type.upper()
910
911        # type set to Any and suppressed due to unsupported runtime class creation:
912        # https://github.com/python/mypy/wiki/Unsupported-Python-Features
913        device_type_test_class: Any = type(class_name, (base, empty_class), {})
914
915        for name in generic_members:
916            if name in generic_tests:  # Instantiates test member
917                test = getattr(generic_test_class, name)
918                # XLA-compat shim (XLA's instantiate_test doesn't take generic_cls)
919                sig = inspect.signature(device_type_test_class.instantiate_test)
920                if len(sig.parameters) == 3:
921                    # Instantiates the device-specific tests
922                    device_type_test_class.instantiate_test(
923                        name, copy.deepcopy(test), generic_cls=generic_test_class
924                    )
925                else:
926                    device_type_test_class.instantiate_test(name, copy.deepcopy(test))
927            else:  # Ports non-test member
928                assert (
929                    name not in device_type_test_class.__dict__
930                ), f"Redefinition of directly defined member {name}"
931                nontest = getattr(generic_test_class, name)
932                setattr(device_type_test_class, name, nontest)
933
934        # The dynamically-created test class derives from the device-specific base class
935        # and the empty class. Arrange for both setUpClass and tearDownClass methods
936        # to be called. This allows the parameterized test classes to support setup
937        # and teardown.
938        @classmethod
939        def _setUpClass(cls):
940            base.setUpClass()
941            empty_class.setUpClass()
942
943        @classmethod
944        def _tearDownClass(cls):
945            empty_class.tearDownClass()
946            base.tearDownClass()
947
948        device_type_test_class.setUpClass = _setUpClass
949        device_type_test_class.tearDownClass = _tearDownClass
950
951        # Mimics defining the instantiated class in the caller's file
952        # by setting its module to the given class's and adding
953        # the module to the given scope.
954        # This lets the instantiated class be discovered by unittest.
955        device_type_test_class.__module__ = generic_test_class.__module__
956        scope[class_name] = device_type_test_class
957
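# Example usage, as a sketch (TestFoo and test_bar are hypothetical); in a test file, after
# defining a template class whose tests accept a `device` argument:
#
#   class TestFoo(TestCase):
#       def test_bar(self, device):
#           pass
#
#   instantiate_device_type_tests(TestFoo, globals(), only_for=["cpu", "cuda"])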
958
959# Category of dtypes to run an OpInfo-based test for
960# Example use: @ops(op_db, dtypes=OpDTypes.supported)
961#
962# The categories are:
963# - supported: Every dtype supported by the operator. Use for exhaustive
964#              testing of all dtypes.
965# - unsupported: Run tests on dtypes not supported by the operator. e.g. for
966#                testing the operator raises an error and doesn't crash.
967# - supported_backward: Every dtype supported by the operator's backward pass.
968# - unsupported_backward: Run tests on dtypes not supported by the operator's backward pass.
969# - any_one: Runs a test for one dtype the operator supports. Prioritizes dtypes the
970#     operator supports in both forward and backward.
971# - none: Useful for tests that are not dtype-specific. No dtype will be passed to the test
972#         when this is selected.
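# - any_common_cpu_cuda_one: Runs a test for one dtype the operator supports on both
#         CPU and CUDA.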
973class OpDTypes(Enum):
974    supported = 0  # Test all supported dtypes (default)
975    unsupported = 1  # Test only unsupported dtypes
976    supported_backward = 2  # Test all supported backward dtypes
977    unsupported_backward = 3  # Test only unsupported backward dtypes
978    any_one = 4  # Test precisely one supported dtype
979    none = 5  # Instantiate no dtype variants (no dtype kwarg needed)
980    any_common_cpu_cuda_one = (
981        6  # Test precisely one supported dtype that is common to both cuda and cpu
982    )
983
984
985# Arbitrary order
986ANY_DTYPE_ORDER = (
987    torch.float32,
988    torch.float64,
989    torch.complex64,
990    torch.complex128,
991    torch.float16,
992    torch.bfloat16,
993    torch.long,
994    torch.int32,
995    torch.int16,
996    torch.int8,
997    torch.uint8,
998    torch.bool,
999)
1000
1001
1002def _serialize_sample(sample_input):
1003    # NB: For OpInfos, SampleInput.summary() prints in a cleaner way.
1004    if getattr(sample_input, "summary", None) is not None:
1005        return sample_input.summary()
1006    return str(sample_input)
1007
1008
1009# Decorator that defines the OpInfos a test template should be instantiated for.
1010#
1011# Example usage:
1012#
1013# @ops(unary_ufuncs)
1014# def test_numerics(self, device, dtype, op):
1015#   <test_code>
1016#
1017# This will instantiate variants of test_numerics for each given OpInfo,
1018# on each device the OpInfo's operator supports, and for every dtype supported by
1019# that operator. There are a few caveats to the dtype rule, explained below.
1020#
1021# The @ops decorator can accept two
1022# additional arguments, "dtypes" and "allowed_dtypes". If "dtypes" is specified
1023# then the test variants are instantiated for those dtypes, regardless of
1024# what the operator supports. If given "allowed_dtypes" then test variants
1025# are instantiated only for the intersection of allowed_dtypes and the dtypes
1026# they would otherwise be instantiated with. That is, allowed_dtypes composes
1027# with the options listed above and below.
1028#
1029# The "dtypes" argument can also accept additional values (see OpDTypes above):
1030#   OpDTypes.supported - the test is instantiated for all dtypes the operator
1031#     supports
1032#   OpDTypes.unsupported - the test is instantiated for all dtypes the operator
1033#     doesn't support
1034#   OpDTypes.supported_backward - the test is instantiated for all dtypes the
1035#     operator's gradient formula supports
1036#   OpDTypes.unsupported_backward - the test is instantiated for all dtypes the
1037#     operator's gradient formula doesn't support
1038#   OpDTypes.any_one - the test is instantiated for one dtype the
1039#     operator supports, preferring a dtype that supports both forward and backward if possible.
1040#   OpDTypes.none - the test is instantiated without any dtype. The test signature
1041#     should not include a dtype kwarg in this case.
1042#
1043# These options allow tests to have considerable control over the dtypes
1044#   they're instantiated for.
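#
# For example (a sketch; op_db and the test bodies are placeholders for real op lists and code):
#
# @ops(op_db, allowed_dtypes=(torch.float32, torch.complex64))
# def test_foo(self, device, dtype, op):
#   <test_code>
#
# @ops(op_db, dtypes=OpDTypes.none)
# def test_bar(self, device, op):
#   <test_code>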
1045
1046
1047class ops(_TestParametrizer):
1048    def __init__(
1049        self,
1050        op_list,
1051        *,
1052        dtypes: Union[OpDTypes, Sequence[torch.dtype]] = OpDTypes.supported,
1053        allowed_dtypes: Optional[Sequence[torch.dtype]] = None,
1054        skip_if_dynamo=True,
1055    ):
1056        self.op_list = list(op_list)
1057        self.opinfo_dtypes = dtypes
1058        self.allowed_dtypes = (
1059            set(allowed_dtypes) if allowed_dtypes is not None else None
1060        )
1061        self.skip_if_dynamo = skip_if_dynamo
1062
1063    def _parametrize_test(self, test, generic_cls, device_cls):
1064        """Parameterizes the given test function across each op and its associated dtypes."""
1065        if device_cls is None:
1066            raise RuntimeError(
1067                "The @ops decorator is only intended to be used in a device-specific "
1068                "context; use it with instantiate_device_type_tests() instead of "
1069                "instantiate_parametrized_tests()"
1070            )
1071
1072        op = check_exhausted_iterator = object()
1073        for op in self.op_list:
1074            # Determine the set of dtypes to use.
1075            dtypes: Union[Set[torch.dtype], Set[None]]
1076            if isinstance(self.opinfo_dtypes, Sequence):
1077                dtypes = set(self.opinfo_dtypes)
1078            elif self.opinfo_dtypes == OpDTypes.unsupported_backward:
1079                dtypes = set(get_all_dtypes()).difference(
1080                    op.supported_backward_dtypes(device_cls.device_type)
1081                )
1082            elif self.opinfo_dtypes == OpDTypes.supported_backward:
1083                dtypes = op.supported_backward_dtypes(device_cls.device_type)
1084            elif self.opinfo_dtypes == OpDTypes.unsupported:
1085                dtypes = set(get_all_dtypes()).difference(
1086                    op.supported_dtypes(device_cls.device_type)
1087                )
1088            elif self.opinfo_dtypes == OpDTypes.supported:
1089                dtypes = set(op.supported_dtypes(device_cls.device_type))
1090            elif self.opinfo_dtypes == OpDTypes.any_one:
1091                # Tries to pick a dtype that supports both forward and backward
1092                supported = op.supported_dtypes(device_cls.device_type)
1093                supported_backward = op.supported_backward_dtypes(
1094                    device_cls.device_type
1095                )
1096                supported_both = supported.intersection(supported_backward)
1097                dtype_set = supported_both if len(supported_both) > 0 else supported
1098                for dtype in ANY_DTYPE_ORDER:
1099                    if dtype in dtype_set:
1100                        dtypes = {dtype}
1101                        break
1102                else:
1103                    dtypes = set()
1104            elif self.opinfo_dtypes == OpDTypes.any_common_cpu_cuda_one:
1105                # Tries to pick a dtype that is supported on both CPU and CUDA
1106                supported = set(op.dtypes).intersection(op.dtypesIfCUDA)
1107                if supported:
1108                    dtypes = {
1109                        next(dtype for dtype in ANY_DTYPE_ORDER if dtype in supported)
1110                    }
1111                else:
1112                    dtypes = set()
1113
1114            elif self.opinfo_dtypes == OpDTypes.none:
1115                dtypes = {None}
1116            else:
1117                raise RuntimeError(f"Unknown OpDType: {self.opinfo_dtypes}")
1118
1119            if self.allowed_dtypes is not None:
1120                dtypes = dtypes.intersection(self.allowed_dtypes)
1121
1122            # Construct the test name; device / dtype parts are handled outside.
1123            # See [Note: device and dtype suffix placement]
1124            test_name = op.formatted_name
1125
1126            for dtype in dtypes:
1127                # Construct parameter kwargs to pass to the test.
1128                param_kwargs = {"op": op}
1129                _update_param_kwargs(param_kwargs, "dtype", dtype)
1130
1131                # NOTE: test_wrapper exists because we don't want to apply
1132                #   op-specific decorators to the original test.
1133                #   Test-specific decorators are applied to the original test,
1134                #   however.
1135                try:
1136
1137                    @wraps(test)
1138                    def test_wrapper(*args, **kwargs):
1139                        try:
1140                            return test(*args, **kwargs)
1141                        except unittest.SkipTest as e:
1142                            raise e
1143                        except Exception as e:
1144                            tracked_input = get_tracked_input()
1145                            if PRINT_REPRO_ON_FAILURE and tracked_input is not None:
1146                                e_tracked = Exception(  # noqa: TRY002
1147                                    f"Caused by {tracked_input.type_desc} "
1148                                    f"at index {tracked_input.index}: "
1149                                    f"{_serialize_sample(tracked_input.val)}"
1150                                )
1151                                e_tracked._tracked_input = tracked_input  # type: ignore[attr]
1152                                raise e_tracked from e
1153                            raise e
1154                        finally:
1155                            clear_tracked_input()
1156
1157                    if self.skip_if_dynamo and not TEST_WITH_TORCHINDUCTOR:
1158                        test_wrapper = skipIfTorchDynamo(
1159                            "Policy: we don't run OpInfo tests w/ Dynamo"
1160                        )(test_wrapper)
1161
1162                    # Initialize info for the last input seen. This is useful for tracking
1163                    # down which inputs caused a test failure. Note that TrackedInputIter is
1164                    # responsible for managing this.
1165                    test.tracked_input = None
1166
1167                    decorator_fn = partial(
1168                        op.get_decorators,
1169                        generic_cls.__name__,
1170                        test.__name__,
1171                        device_cls.device_type,
1172                        dtype,
1173                    )
1174
1175                    yield (test_wrapper, test_name, param_kwargs, decorator_fn)
1176                except Exception as ex:
1177                    # Provides an error message for debugging before rethrowing the exception
1178                    print(f"Failed to instantiate {test_name} for op {op.name}!")
1179                    raise ex
1180        if op is check_exhausted_iterator:
1181            raise ValueError(
1182                "An empty op_list was passed to @ops. "
1183                "Note that this may result from reuse of a generator."
1184            )
1185
1186
1187# Decorator that skips a test if the given condition is true.
1188# Notes:
1189#   (1) Skip conditions stack.
1190#   (2) Skip conditions can be bools or strings. If a string, the
1191#       test base must have defined the corresponding attribute to be False
1192#       for the test to run. If you want to use a string argument, you should
1193#       probably define a new decorator instead (see below).
1194#   (3) Prefer the existing decorators to defining the 'device_type' kwarg.
1195class skipIf:
1196    def __init__(self, dep, reason, device_type=None):
1197        self.dep = dep
1198        self.reason = reason
1199        self.device_type = device_type
1200
1201    def __call__(self, fn):
1202        @wraps(fn)
1203        def dep_fn(slf, *args, **kwargs):
1204            if self.device_type is None or self.device_type == slf.device_type:
1205                if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or (
1206                    isinstance(self.dep, bool) and self.dep
1207                ):
1208                    raise unittest.SkipTest(self.reason)
1209
1210            return fn(slf, *args, **kwargs)
1211
1212        return dep_fn
1213
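# Example usage, as a sketch (the test name is hypothetical; TEST_MKL is imported above):
#
#   @skipCPUIf(not TEST_MKL, "Test requires MKL")
#   def test_foo(self, device):
#       ...
#
# A string dependency instead names an attribute on the test base, e.g.
# skipIf("no_cudnn", "cuDNN is unavailable", device_type="cuda") skips when no_cudnn is True.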
1214
1215# Skips a test on CPU if the condition is true.
1216class skipCPUIf(skipIf):
1217    def __init__(self, dep, reason):
1218        super().__init__(dep, reason, device_type="cpu")
1219
1220
1221# Skips a test on CUDA if the condition is true.
1222class skipCUDAIf(skipIf):
1223    def __init__(self, dep, reason):
1224        super().__init__(dep, reason, device_type="cuda")
1225
1226
1227# Skips a test on XPU if the condition is true.
1228class skipXPUIf(skipIf):
1229    def __init__(self, dep, reason):
1230        super().__init__(dep, reason, device_type="xpu")
1231
1232
1233# Skips a test on Lazy if the condition is true.
1234class skipLazyIf(skipIf):
1235    def __init__(self, dep, reason):
1236        super().__init__(dep, reason, device_type="lazy")
1237
1238
1239# Skips a test on Meta if the condition is true.
1240class skipMetaIf(skipIf):
1241    def __init__(self, dep, reason):
1242        super().__init__(dep, reason, device_type="meta")
1243
1244
1245# Skips a test on MPS if the condition is true.
1246class skipMPSIf(skipIf):
1247    def __init__(self, dep, reason):
1248        super().__init__(dep, reason, device_type="mps")
1249
1250
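# Skips a test on HPU if the condition is true.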
1251class skipHPUIf(skipIf):
1252    def __init__(self, dep, reason):
1253        super().__init__(dep, reason, device_type="hpu")
1254
1255
1256# Skips a test on XLA if the condition is true.
1257class skipXLAIf(skipIf):
1258    def __init__(self, dep, reason):
1259        super().__init__(dep, reason, device_type="xla")
1260
1261
1262class skipPRIVATEUSE1If(skipIf):
1263    def __init__(self, dep, reason):
1264        device_type = torch._C._get_privateuse1_backend_name()
1265        super().__init__(dep, reason, device_type=device_type)
1266
1267
1268def _has_sufficient_memory(device, size):
1269    if torch.device(device).type == "cuda":
1270        if not torch.cuda.is_available():
1271            return False
1272        gc.collect()
1273        torch.cuda.empty_cache()
1274        # torch.cuda.mem_get_info, aka cudaMemGetInfo, returns a tuple of (free memory, total memory) of a GPU
1275        if device == "cuda":
1276            device = "cuda:0"
1277        return torch.cuda.memory.mem_get_info(device)[0] >= size
1278
1279    if device == "xla":
1280        raise unittest.SkipTest("TODO: Memory availability checks for XLA?")
1281
1282    if device == "xpu":
1283        raise unittest.SkipTest("TODO: Memory availability checks for Intel GPU?")
1284
1285    if device != "cpu":
1286        raise unittest.SkipTest("Unknown device type")
1287
1288    # CPU
1289    if not HAS_PSUTIL:
1290        raise unittest.SkipTest("Need psutil to determine if memory is sufficient")
1291
1292    # The sanitizers have significant memory overheads
1293    if TEST_WITH_ASAN or TEST_WITH_TSAN or TEST_WITH_UBSAN:
1294        effective_size = size * 10
1295    else:
1296        effective_size = size
1297
1298    if psutil.virtual_memory().available < effective_size:
1299        gc.collect()
1300    return psutil.virtual_memory().available >= effective_size
1301
1302
1303def largeTensorTest(size, device=None):
1304    """Skip test if the device has insufficient memory to run the test
1305
1306    size may be a number of bytes, a string of the form "N GB", or a callable
1307
1308    If the test is a device generic test, available memory on the primary device will be checked.
1309    It can also be overridden by the optional `device=` argument.
1310    In other tests, the `device=` argument needs to be specified.
1311    """
1312    if isinstance(size, str):
1313        assert size.endswith(("GB", "gb")), "only bytes or GB supported"
1314        size = 1024**3 * int(size[:-2])
1315
1316    def inner(fn):
1317        @wraps(fn)
1318        def dep_fn(self, *args, **kwargs):
1319            size_bytes = size(self, *args, **kwargs) if callable(size) else size
1320            _device = device if device is not None else self.get_primary_device()
1321            if not _has_sufficient_memory(_device, size_bytes):
1322                raise unittest.SkipTest(f"Insufficient {_device} memory")
1323
1324            return fn(self, *args, **kwargs)
1325
1326        return dep_fn
1327
1328    return inner
1329
1330
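# Marks a test as expected to fail on the given device type: the test is reported
# as passing if it raises, and as a failure if it unexpectedly succeeds.
# A device_type of None applies the expected failure to every device type.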
1331class expectedFailure:
1332    def __init__(self, device_type):
1333        self.device_type = device_type
1334
1335    def __call__(self, fn):
1336        @wraps(fn)
1337        def efail_fn(slf, *args, **kwargs):
1338            if (
1339                not hasattr(slf, "device_type")
1340                and hasattr(slf, "device")
1341                and isinstance(slf.device, str)
1342            ):
1343                target_device_type = slf.device
1344            else:
1345                target_device_type = slf.device_type
1346
1347            if self.device_type is None or self.device_type == target_device_type:
1348                try:
1349                    fn(slf, *args, **kwargs)
1350                except Exception:
1351                    return
1352                else:
1353                    slf.fail("expected test to fail, but it passed")
1354
1355            return fn(slf, *args, **kwargs)
1356
1357        return efail_fn
1358
1359
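# Skips a test unless it is running on the specified device type.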
1360class onlyOn:
1361    def __init__(self, device_type):
1362        self.device_type = device_type
1363
1364    def __call__(self, fn):
1365        @wraps(fn)
1366        def only_fn(slf, *args, **kwargs):
1367            if self.device_type != slf.device_type:
1368                reason = f"Only runs on {self.device_type}"
1369                raise unittest.SkipTest(reason)
1370
1371            return fn(slf, *args, **kwargs)
1372
1373        return only_fn
1374
1375
1376# Decorator that provides all available devices of the device type to the test
1377# as a list of strings instead of providing a single device string.
1378# Skips the test if the number of available devices of the variant's device
1379# type is less than the 'num_required_devices' arg.
1380class deviceCountAtLeast:
1381    def __init__(self, num_required_devices):
1382        self.num_required_devices = num_required_devices
1383
1384    def __call__(self, fn):
1385        assert not hasattr(
1386            fn, "num_required_devices"
1387        ), f"deviceCountAtLeast redefinition for {fn.__name__}"
1388        fn.num_required_devices = self.num_required_devices
1389
1390        @wraps(fn)
1391        def multi_fn(slf, devices, *args, **kwargs):
1392            if len(devices) < self.num_required_devices:
1393                reason = f"fewer than {self.num_required_devices} devices detected"
1394                raise unittest.SkipTest(reason)
1395
1396            return fn(slf, devices, *args, **kwargs)
1397
1398        return multi_fn
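
# Illustrative usage sketch for deviceCountAtLeast (the test name is hypothetical):
#
#   @deviceCountAtLeast(2)
#   def test_multi_device(self, devices):
#       # devices is a list like ["cuda:0", "cuda:1"]
#       ...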
1399
1400
1401# Only runs the test on the native device types (currently CPU, CUDA, Meta, and PRIVATEUSE1).
1402def onlyNativeDeviceTypes(fn):
1403    @wraps(fn)
1404    def only_fn(self, *args, **kwargs):
1405        if self.device_type not in NATIVE_DEVICES:
1406            reason = f"onlyNativeDeviceTypes: doesn't run on {self.device_type}"
1407            raise unittest.SkipTest(reason)
1408
1409        return fn(self, *args, **kwargs)
1410
1411    return only_fn
1412
1413
1414# Only runs the test on the native device types and any additional device types given in `devices`.
1415def onlyNativeDeviceTypesAnd(devices=None):
1416    def decorator(fn):
1417        @wraps(fn)
1418        def only_fn(self, *args, **kwargs):
1419            if (
1420                self.device_type not in NATIVE_DEVICES
1421                and self.device_type not in (devices or ())  # tolerate devices=None
1422            ):
1423                reason = f"onlyNativeDeviceTypesAnd {devices} : doesn't run on {self.device_type}"
1424                raise unittest.SkipTest(reason)
1425
1426            return fn(self, *args, **kwargs)
1427
1428        return only_fn
1429
1430    return decorator
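
# Illustrative usage sketch for onlyNativeDeviceTypesAnd (the extra device type and
# test name are hypothetical):
#
#   @onlyNativeDeviceTypesAnd(["hpu"])
#   def test_foo(self, device):
#       ...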
1431
1432
1433# Specifies per-dtype precision overrides.
1434# Ex.
1435#
1436# @precisionOverride({torch.half : 1e-2, torch.float : 1e-4})
1437# @dtypes(torch.half, torch.float, torch.double)
1438# def test_X(self, device, dtype):
1439#   ...
1440#
1441# When the test is instantiated its class's precision will be set to the
1442# corresponding override, if it exists.
1443# self.precision can be accessed directly, and it also controls the behavior of
1444# functions like self.assertEqual().
1445#
1446# Note that self.precision is a scalar value, so if you require multiple
1447# precisions (or are working with multiple dtypes) they should be specified
1448# explicitly and computed using self.precision (e.g.
1449# self.precision * 2, max(1, self.precision)).
1450class precisionOverride:
1451    def __init__(self, d):
1452        assert isinstance(
1453            d, dict
1454        ), "precisionOverride not given a dtype : precision dict!"
1455        for dtype in d.keys():
1456            assert isinstance(
1457                dtype, torch.dtype
1458            ), f"precisionOverride given unknown dtype {dtype}"
1459
1460        self.d = d
1461
1462    def __call__(self, fn):
1463        fn.precision_overrides = self.d
1464        return fn
1465
1466
1467# Specifies per-dtype tolerance overrides tol(atol, rtol). It has priority over
1468# precisionOverride.
1469# Ex.
1470#
1471# @toleranceOverride({torch.float: tol(atol=1e-2, rtol=1e-3),
1472#                     torch.double: tol(atol=1e-4, rtol=0)})
1473# @dtypes(torch.half, torch.float, torch.double)
1474# def test_X(self, device, dtype):
1475#   ...
1476#
1477# When the test is instantiated its class's tolerance will be set to the
1478# corresponding override, if it exists.
1479# self.rtol and self.precision can be accessed directly, and they also control
1480# the behavior of functions like self.assertEqual().
1481#
1482# The above example sets atol = 1e-2 and rtol = 1e-3 for torch.float and
1483# atol = 1e-4 and rtol = 0 for torch.double.
1484tol = namedtuple("tol", ["atol", "rtol"])
1485
1486
1487class toleranceOverride:
1488    def __init__(self, d):
1489        assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!"
1490        for dtype, prec in d.items():
1491            assert isinstance(
1492                dtype, torch.dtype
1493            ), f"toleranceOverride given unknown dtype {dtype}"
1494            assert isinstance(
1495                prec, tol
1496            ), "toleranceOverride not given a dtype : tol dict!"
1497
1498        self.d = d
1499
1500    def __call__(self, fn):
1501        fn.tolerance_overrides = self.d
1502        return fn
1503
1504
1505# Decorator that instantiates a variant of the test for each given dtype.
1506# Notes:
1507#   (1) Tests that accept the dtype argument MUST use this decorator.
1508#   (2) Can be overridden for CPU or CUDA, respectively, using dtypesIfCPU
1509#       or dtypesIfCUDA.
1510#   (3) Can accept an iterable of dtypes or an iterable of tuples
1511#       of dtypes.
1512# Examples:
1513# @dtypes(torch.float32, torch.float64)
1514# @dtypes((torch.long, torch.float32), (torch.int, torch.float64))
1515class dtypes:
1516    def __init__(self, *args, device_type="all"):
1517        if len(args) > 0 and isinstance(args[0], (list, tuple)):
1518            for arg in args:
1519                assert isinstance(arg, (list, tuple)), (
1520                    "When one dtype variant is a tuple or list, "
1521                    "all dtype variants must be. "
1522                    f"Received non-list non-tuple dtype {str(arg)}"
1523                )
1524                assert all(
1525                    isinstance(dtype, torch.dtype) for dtype in arg
1526                ), f"Unknown dtype in {str(arg)}"
1527        else:
1528            assert all(
1529                isinstance(arg, torch.dtype) for arg in args
1530            ), f"Unknown dtype in {str(args)}"
1531
1532        self.args = args
1533        self.device_type = device_type
1534
1535    def __call__(self, fn):
1536        d = getattr(fn, "dtypes", {})
1537        assert self.device_type not in d, f"dtypes redefinition for {self.device_type}"
1538        d[self.device_type] = self.args
1539        fn.dtypes = d
1540        return fn
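
# Illustrative sketch of combining @dtypes with a per-device override (the test name
# is hypothetical): the CUDA variant of the test below runs with torch.half and
# torch.float, while every other device type runs with torch.float only.
#
#   @dtypes(torch.float)
#   @dtypesIfCUDA(torch.half, torch.float)
#   def test_foo(self, device, dtype):
#       ...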
1541
1542
1543# Overrides specified dtypes on the CPU.
1544class dtypesIfCPU(dtypes):
1545    def __init__(self, *args):
1546        super().__init__(*args, device_type="cpu")
1547
1548
1549# Overrides specified dtypes on CUDA.
1550class dtypesIfCUDA(dtypes):
1551    def __init__(self, *args):
1552        super().__init__(*args, device_type="cuda")
1553
1554
1555class dtypesIfMPS(dtypes):
1556    def __init__(self, *args):
1557        super().__init__(*args, device_type="mps")
1558
1559
1560class dtypesIfHPU(dtypes):
1561    def __init__(self, *args):
1562        super().__init__(*args, device_type="hpu")
1563
1564
1565class dtypesIfPRIVATEUSE1(dtypes):
1566    def __init__(self, *args):
1567        super().__init__(*args, device_type=torch._C._get_privateuse1_backend_name())
1568
1569
1570def onlyCPU(fn):
1571    return onlyOn("cpu")(fn)
1572
1573
1574def onlyCUDA(fn):
1575    return onlyOn("cuda")(fn)
1576
1577
1578def onlyMPS(fn):
1579    return onlyOn("mps")(fn)
1580
1581
1582def onlyXPU(fn):
1583    return onlyOn("xpu")(fn)
1584
1585
1586def onlyHPU(fn):
1587    return onlyOn("hpu")(fn)
1588
1589
1590def onlyPRIVATEUSE1(fn):
1591    device_type = torch._C._get_privateuse1_backend_name()
1592    device_mod = getattr(torch, device_type, None)
1593    if device_mod is None:
1594        reason = f"Skip as torch has no module of {device_type}"
1595        return unittest.skip(reason)(fn)
1596    return onlyOn(device_type)(fn)
1597
1598
1599def onlyCUDAAndPRIVATEUSE1(fn):
1600    @wraps(fn)
1601    def only_fn(self, *args, **kwargs):
1602        if self.device_type not in ("cuda", torch._C._get_privateuse1_backend_name()):
1603            reason = f"onlyCUDAAndPRIVATEUSE1: doesn't run on {self.device_type}"
1604            raise unittest.SkipTest(reason)
1605
1606        return fn(self, *args, **kwargs)
1607
1608    return only_fn
1609
1610
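# Runs the test with cuDNN disabled when running on CUDA with cuDNN available;
# otherwise runs the test unchanged.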
1611def disablecuDNN(fn):
1612    @wraps(fn)
1613    def disable_cudnn(self, *args, **kwargs):
1614        if self.device_type == "cuda" and self.has_cudnn():
1615            with torch.backends.cudnn.flags(enabled=False):
1616                return fn(self, *args, **kwargs)
1617        return fn(self, *args, **kwargs)
1618
1619    return disable_cudnn
1620
1621
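# Runs the test with mkldnn disabled when mkldnn is available; otherwise runs the
# test unchanged.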
1622def disableMkldnn(fn):
1623    @wraps(fn)
1624    def disable_mkldnn(self, *args, **kwargs):
1625        if torch.backends.mkldnn.is_available():
1626            with torch.backends.mkldnn.flags(enabled=False):
1627                return fn(self, *args, **kwargs)
1628        return fn(self, *args, **kwargs)
1629
1630    return disable_mkldnn
1631
1632
1633def expectedFailureCPU(fn):
1634    return expectedFailure("cpu")(fn)
1635
1636
1637def expectedFailureCUDA(fn):
1638    return expectedFailure("cuda")(fn)
1639
1640
1641def expectedFailureXPU(fn):
1642    return expectedFailure("xpu")(fn)
1643
1644
1645def expectedFailureMeta(fn):
1646    return skipIfTorchDynamo()(expectedFailure("meta")(fn))
1647
1648
1649def expectedFailureMPS(fn):
1650    return expectedFailure("mps")(fn)
1651
1652
1653def expectedFailureXLA(fn):
1654    return expectedFailure("xla")(fn)
1655
1656
1657def expectedFailureHPU(fn):
1658    return expectedFailure("hpu")(fn)
1659
1660
1661# Skips a test on CPU if LAPACK is not available.
1662def skipCPUIfNoLapack(fn):
1663    return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without LAPACK")(fn)
1664
1665
1666# Skips a test on CPU if FFT is not available.
1667def skipCPUIfNoFFT(fn):
1668    return skipCPUIf(not torch._C.has_spectral, "PyTorch is built without FFT support")(
1669        fn
1670    )
1671
1672
1673# Skips a test on CPU if MKL is not available.
1674def skipCPUIfNoMkl(fn):
1675    return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn)
1676
1677
1678# Skips a test on CPU if MKL Sparse is not available (it's not linked on Windows).
1679def skipCPUIfNoMklSparse(fn):
1680    return skipCPUIf(
1681        IS_WINDOWS or not TEST_MKL, "PyTorch is built without MKL support"
1682    )(fn)
1683
1684
1685# Skips a test on CPU if mkldnn is not available.
1686def skipCPUIfNoMkldnn(fn):
1687    return skipCPUIf(
1688        not torch.backends.mkldnn.is_available(),
1689        "PyTorch is built without mkldnn support",
1690    )(fn)
1691
1692
1693# Skips a test on CUDA if MAGMA is not available.
1694def skipCUDAIfNoMagma(fn):
1695    return skipCUDAIf("no_magma", "no MAGMA library detected")(
1696        skipCUDANonDefaultStreamIf(True)(fn)
1697    )
1698
1699
1700def has_cusolver():
1701    return not TEST_WITH_ROCM
1702
1703
1704def has_hipsolver():
1705    rocm_version = _get_torch_rocm_version()
1706    # hipSOLVER is disabled on ROCM < 5.3
1707    return rocm_version >= (5, 3)
1708
1709
1710# Skips a test on CUDA/ROCM if cuSOLVER/hipSOLVER is not available
1711def skipCUDAIfNoCusolver(fn):
1712    return skipCUDAIf(
1713        not has_cusolver() and not has_hipsolver(), "cuSOLVER not available"
1714    )(fn)
1715
1716
1717# Skips a test if both cuSOLVER and MAGMA are not available
1718def skipCUDAIfNoMagmaAndNoCusolver(fn):
1719    if has_cusolver():
1720        return fn
1721    else:
1722        # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA
1723        return skipCUDAIfNoMagma(fn)
1724
1725
1726# Skips a test if both cuSOLVER/hipSOLVER and MAGMA are not available
1727def skipCUDAIfNoMagmaAndNoLinalgsolver(fn):
1728    if has_cusolver() or has_hipsolver():
1729        return fn
1730    else:
1731        # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA
1732        return skipCUDAIfNoMagma(fn)
1733
1734
1735# Skips a test on CUDA when using ROCm.
1736def skipCUDAIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"):
1737    def dec_fn(fn):
1738        reason = f"skipCUDAIfRocm: {msg}"
1739        return skipCUDAIf(TEST_WITH_ROCM, reason=reason)(fn)
1740
1741    if func:
1742        return dec_fn(func)
1743    return dec_fn
1744
1745
1746# Skips a test on CUDA when not using ROCm.
1747def skipCUDAIfNotRocm(fn):
1748    return skipCUDAIf(
1749        not TEST_WITH_ROCM, "test doesn't currently work on the CUDA stack"
1750    )(fn)
1751
1752
1753# Skips a test on CUDA if ROCm is unavailable or its version is lower than requested.
1754def skipCUDAIfRocmVersionLessThan(version=None):
1755    def dec_fn(fn):
1756        @wraps(fn)
1757        def wrap_fn(self, *args, **kwargs):
1758            if self.device_type == "cuda":
1759                if not TEST_WITH_ROCM:
1760                    reason = "ROCm not available"
1761                    raise unittest.SkipTest(reason)
1762                rocm_version_tuple = _get_torch_rocm_version()
1763                if (
1764                    rocm_version_tuple is None
1765                    or version is None
1766                    or rocm_version_tuple < tuple(version)
1767                ):
1768                    reason = (
1769                        f"ROCm {rocm_version_tuple} is available but {version} required"
1770                    )
1771                    raise unittest.SkipTest(reason)
1772
1773            return fn(self, *args, **kwargs)
1774
1775        return wrap_fn
1776
1777    return dec_fn
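
# Illustrative usage sketch (the required ROCm version and test name are examples only):
#
#   @skipCUDAIfRocmVersionLessThan((5, 3))
#   def test_foo(self, device):
#       ...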
1778
1779
1780# Skips a test on CUDA (ROCm) unless MIOpen's NHWC-suggest mode (TEST_WITH_MIOPEN_SUGGEST_NHWC) is enabled.
1781def skipCUDAIfNotMiopenSuggestNHWC(fn):
1782    return skipCUDAIf(
1783        not TEST_WITH_MIOPEN_SUGGEST_NHWC,
1784        "test doesn't currently work without MIOpen NHWC activation",
1785    )(fn)
1786
1787
1788# Skips a test for the specified CUDA versions, given as a list of (major, minor) tuples.
1789def skipCUDAVersionIn(versions: List[Tuple[int, int]] = None):
1790    def dec_fn(fn):
1791        @wraps(fn)
1792        def wrap_fn(self, *args, **kwargs):
1793            version = _get_torch_cuda_version()
1794            if version == (0, 0):  # cpu or rocm
1795                return fn(self, *args, **kwargs)
1796            if version in (versions or []):
1797                reason = f"test skipped for CUDA version {version}"
1798                raise unittest.SkipTest(reason)
1799            return fn(self, *args, **kwargs)
1800
1801        return wrap_fn
1802
1803    return dec_fn
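
# Illustrative usage sketch (the listed versions and test name are examples only):
#
#   @skipCUDAVersionIn([(11, 3), (11, 5)])
#   def test_foo(self, device):
#       ...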
1804
1805
1806# Skips a test for CUDA versions less than the specified (major, minor) version.
1807def skipCUDAIfVersionLessThan(versions: Tuple[int, int] = None):
1808    def dec_fn(fn):
1809        @wraps(fn)
1810        def wrap_fn(self, *args, **kwargs):
1811            version = _get_torch_cuda_version()
1812            if version == (0, 0):  # cpu or rocm
1813                return fn(self, *args, **kwargs)
1814            if version < versions:
1815                reason = f"test skipped for CUDA versions < {version}"
1816                raise unittest.SkipTest(reason)
1817            return fn(self, *args, **kwargs)
1818
1819        return wrap_fn
1820
1821    return dec_fn
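
# Illustrative usage sketch (the version threshold and test name are examples only):
#
#   @skipCUDAIfVersionLessThan((11, 6))
#   def test_foo(self, device):
#       ...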
1822
1823
1824# Skips a test on CUDA if cuDNN is unavailable or its version is lower than requested.
1825def skipCUDAIfCudnnVersionLessThan(version=0):
1826    def dec_fn(fn):
1827        @wraps(fn)
1828        def wrap_fn(self, *args, **kwargs):
1829            if self.device_type == "cuda":
1830                if self.no_cudnn:
1831                    reason = "cuDNN not available"
1832                    raise unittest.SkipTest(reason)
1833                if self.cudnn_version is None or self.cudnn_version < version:
1834                    reason = f"cuDNN version {self.cudnn_version} is available but {version} required"
1835                    raise unittest.SkipTest(reason)
1836
1837            return fn(self, *args, **kwargs)
1838
1839        return wrap_fn
1840
1841    return dec_fn
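
# Illustrative usage sketch (assumes cuDNN versions are compared as integers, e.g.
# 8000; the threshold and test name are examples only):
#
#   @skipCUDAIfCudnnVersionLessThan(8000)
#   def test_foo(self, device):
#       ...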
1842
1843
1844# Skips a test on CUDA if cuSparse generic API is not available
1845def skipCUDAIfNoCusparseGeneric(fn):
1846    return skipCUDAIf(not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")(
1847        fn
1848    )
1849
1850
1851def skipCUDAIfNoHipsparseGeneric(fn):
1852    return skipCUDAIf(
1853        not TEST_HIPSPARSE_GENERIC, "hipSparse Generic API not available"
1854    )(fn)
1855
1856
1857def skipCUDAIfNoSparseGeneric(fn):
1858    return skipCUDAIf(
1859        not (TEST_CUSPARSE_GENERIC or TEST_HIPSPARSE_GENERIC),
1860        "Sparse Generic API not available",
1861    )(fn)
1862
1863
1864def skipCUDAIfNoCudnn(fn):
1865    return skipCUDAIfCudnnVersionLessThan(0)(fn)
1866
1867
1868def skipCUDAIfMiopen(fn):
1869    return skipCUDAIf(torch.version.hip is not None, "Marked as skipped for MIOpen")(fn)
1870
1871
1872def skipCUDAIfNoMiopen(fn):
1873    return skipCUDAIf(torch.version.hip is None, "MIOpen is not available")(
1874        skipCUDAIfNoCudnn(fn)
1875    )
1876
1877
1878def skipLazy(fn):
1879    return skipLazyIf(True, "test doesn't work with lazy tensors")(fn)
1880
1881
1882def skipMeta(fn):
1883    return skipMetaIf(True, "test doesn't work with meta tensors")(fn)
1884
1885
1886def skipXLA(fn):
1887    return skipXLAIf(True, "Marked as skipped for XLA")(fn)
1888
1889
1890def skipMPS(fn):
1891    return skipMPSIf(True, "test doesn't work on MPS backend")(fn)
1892
1893
1894def skipHPU(fn):
1895    return skipHPUIf(True, "test doesn't work on HPU backend")(fn)
1896
1897
1898def skipPRIVATEUSE1(fn):
1899    return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn)
1900
1901
1902# TODO: the "all" in the name isn't true anymore for quite some time as we have also have for example XLA and MPS now.
1903#  This should probably enumerate all available device type test base classes.
1904def get_all_device_types() -> List[str]:
1905    return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
1906