1# mypy: ignore-errors 2 3import copy 4import gc 5import inspect 6import os 7import runpy 8import sys 9import threading 10import unittest 11from collections import namedtuple 12from enum import Enum 13from functools import partial, wraps 14from typing import Any, ClassVar, Dict, List, Optional, Sequence, Set, Tuple, Union 15 16import torch 17from torch.testing._internal.common_cuda import ( 18 _get_torch_cuda_version, 19 _get_torch_rocm_version, 20 TEST_CUSPARSE_GENERIC, 21 TEST_HIPSPARSE_GENERIC, 22) 23from torch.testing._internal.common_dtype import get_all_dtypes 24from torch.testing._internal.common_utils import ( 25 _TestParametrizer, 26 clear_tracked_input, 27 compose_parametrize_fns, 28 dtype_name, 29 get_tracked_input, 30 IS_FBCODE, 31 is_privateuse1_backend_available, 32 IS_REMOTE_GPU, 33 IS_SANDCASTLE, 34 IS_WINDOWS, 35 NATIVE_DEVICES, 36 PRINT_REPRO_ON_FAILURE, 37 skipCUDANonDefaultStreamIf, 38 skipIfTorchDynamo, 39 TEST_HPU, 40 TEST_MKL, 41 TEST_MPS, 42 TEST_WITH_ASAN, 43 TEST_WITH_MIOPEN_SUGGEST_NHWC, 44 TEST_WITH_ROCM, 45 TEST_WITH_TORCHINDUCTOR, 46 TEST_WITH_TSAN, 47 TEST_WITH_UBSAN, 48 TEST_XPU, 49 TestCase, 50) 51 52 53try: 54 import psutil # type: ignore[import] 55 56 HAS_PSUTIL = True 57except ModuleNotFoundError: 58 HAS_PSUTIL = False 59 psutil = None 60 61# Note [Writing Test Templates] 62# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 63# 64# This note was written shortly after the PyTorch 1.9 release. 65# If you notice it's out-of-date or think it could be improved then please 66# file an issue. 67# 68# PyTorch has its own framework for instantiating test templates. That is, for 69# taking test classes that look similar to unittest or pytest 70# compatible test classes and optionally doing the following: 71# 72# - instantiating a version of the test class for each available device type 73# (often the CPU, CUDA, and META device types) 74# - further instantiating a version of each test that's always specialized 75# on the test class's device type, and optionally specialized further 76# on datatypes or operators 77# 78# This functionality is similar to pytest's parametrize functionality 79# (see https://docs.pytest.org/en/6.2.x/parametrize.html), but with considerable 80# additional logic that specializes the instantiated test classes for their 81# device types (see CPUTestBase and CUDATestBase below), supports a variety 82# of composable decorators that allow for test filtering and setting 83# tolerances, and allows tests parametrized by operators to instantiate 84# only the subset of device type x dtype that operator supports. 85# 86# This framework was built to make it easier to write tests that run on 87# multiple device types, multiple datatypes (dtypes), and for multiple 88# operators. It's also useful for controlling which tests are run. For example, 89# only tests that use a CUDA device can be run on platforms with CUDA. 
# Let's dive in with an example to get an idea of how it works:
#
# --------------------------------------------------------
# A template class (looks like a regular unittest TestCase)
# class TestClassFoo(TestCase):
#
#   # A template test that can be specialized with a device
#   # NOTE: this test case is not runnable by unittest or pytest because it
#   #   accepts an extra positional argument, "device", that they do not understand
#   def test_bar(self, device):
#     pass
#
# # Function that instantiates a template class and its tests
# instantiate_device_type_tests(TestClassFoo, globals())
# --------------------------------------------------------
#
# In the above code example we see a template class and a single test template
# that can be instantiated with a device. The function
# instantiate_device_type_tests(), called at file scope, instantiates
# new test classes, one per available device type, and new tests in those
# classes from these templates. It actually does this by removing
# the class TestClassFoo and replacing it with classes like TestClassFooCPU
# and TestClassFooCUDA, instantiated test classes that inherit from CPUTestBase
# and CUDATestBase respectively. Additional device types, like XLA
# (see https://github.com/pytorch/xla), can further extend the set of
# instantiated test classes to create classes like TestClassFooXLA.
#
# The test template, test_bar(), is also instantiated. In this case the template
# is only specialized on a device, so (depending on the available device
# types) it might become test_bar_cpu() in TestClassFooCPU and test_bar_cuda()
# in TestClassFooCUDA. We can think of the instantiated test classes as
# looking like this:
#
# --------------------------------------------------------
# # An instantiated test class for the CPU device type
# class TestClassFooCPU(CPUTestBase):
#
#   # An instantiated test that calls the template with the string representation
#   #   of a device from the test class's device type
#   def test_bar_cpu(self):
#     test_bar(self, 'cpu')
#
# # An instantiated test class for the CUDA device type
# class TestClassFooCUDA(CUDATestBase):
#
#   # An instantiated test that calls the template with the string representation
#   #   of a device from the test class's device type
#   def test_bar_cuda(self):
#     test_bar(self, 'cuda:0')
# --------------------------------------------------------
#
# These instantiated test classes ARE discoverable and runnable by both
# unittest and pytest. One thing that may be confusing, however, is that
# attempting to run "test_bar" will not work, despite it appearing in the
# original template code. This is because "test_bar" is no longer discoverable
# after instantiate_device_type_tests() runs, as the above snippet shows.
# Instead "test_bar_cpu" and "test_bar_cuda" may be run directly, or both
# can be run with the option "-k test_bar".
#
# Removing the template class and adding the instantiated classes requires
# passing "globals()" to instantiate_device_type_tests(), because it
# edits the file's Python objects.
#
# As mentioned, tests can be additionally parametrized on dtypes or
# operators. Datatype parametrization uses the @dtypes decorator and
# requires a test template like this:
#
# --------------------------------------------------------
# # A template test that can be specialized with a device and a datatype (dtype)
# @dtypes(torch.float32, torch.int64)
# def test_car(self, device, dtype):
#   pass
# --------------------------------------------------------
#
# If the CPU and CUDA device types are available this test would be
# instantiated as 4 tests that cover the cross-product of the two dtypes
# and two device types:
#
# - test_car_cpu_float32
# - test_car_cpu_int64
# - test_car_cuda_float32
# - test_car_cuda_int64
#
# The dtype is passed as a torch.dtype object.
#
# Tests parametrized on operators (actually on OpInfos, more on that in a
# moment...) use the @ops decorator and require a test template like this:
# --------------------------------------------------------
# # A template test that can be specialized with a device, dtype, and OpInfo
# @ops(op_db)
# def test_car(self, device, dtype, op):
#   pass
# --------------------------------------------------------
#
# See the documentation for the @ops decorator below for additional details
# on how to use it and see the note [OpInfos] in
# common_methods_invocations.py for more details on OpInfos.
#
# A test parametrized over the entire "op_db", which contains hundreds of
# OpInfos, will likely have hundreds or thousands of instantiations. The
# test will be instantiated on the cross-product of device types, operators,
# and the dtypes the operator supports on that device type. The instantiated
# tests will have names like:
#
# - test_car_add_cpu_float32
# - test_car_sub_cuda_int64
#
# The first instantiated test calls the original test_car() with the OpInfo
# for torch.add as its "op" argument, the string 'cpu' for its "device" argument,
# and the dtype torch.float32 for its "dtype" argument. The second instantiated
# test calls test_car() with the OpInfo for torch.sub, a CUDA device string
# like 'cuda:0' or 'cuda:1' for its "device" argument, and the dtype
# torch.int64 for its "dtype" argument.
#
# In addition to parametrizing over device, dtype, and ops via OpInfos, the
# @parametrize decorator is supported for arbitrary parametrizations:
# --------------------------------------------------------
# # A template test that can be specialized with a device, dtype, and value for x
# @parametrize("x", range(5))
# def test_car(self, device, dtype, x):
#   pass
# --------------------------------------------------------
#
# See the documentation for @parametrize in common_utils.py for additional details
# on this. Note that the instantiate_device_type_tests() function will handle
# such parametrizations; there is no need to additionally call
# instantiate_parametrized_tests().
#
# Clever test filtering can be very useful when working with parametrized
# tests. "-k test_car" would run every instantiated variant of the test_car()
# test template, and "-k test_car_add" runs every variant instantiated with
# torch.add.
#
# It is important to use the passed device and dtype as appropriate. Use
# helper functions like make_tensor() that require explicitly specifying
# the device and dtype so they're not forgotten.
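#
# For example, a dtype- and device-parametrized template might look like this
# (a sketch; the test name is made up and make_tensor is assumed to be
# imported from torch.testing):
#
#   @dtypes(torch.float32, torch.float64)
#   def test_dot(self, device, dtype):
#     a = make_tensor((20,), device=device, dtype=dtype)
#     b = make_tensor((20,), device=device, dtype=dtype)
#     self.assertEqual(torch.dot(a, b), (a * b).sum())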
226# 227# Test templates can use a variety of composable decorators to specify 228# additional options and requirements, some are listed here: 229# 230# - @deviceCountAtLeast(<minimum number of devices to run test with>) 231# Passes a list of strings representing all available devices of 232# the test class's device type as the test template's "device" argument. 233# If there are fewer devices than the value passed to the decorator 234# the test is skipped. 235# - @dtypes(<list of tuples of dtypes>) 236# In addition to accepting multiple dtypes, the @dtypes decorator 237# can accept a sequence of tuple pairs of dtypes. The test template 238# will be called with each tuple for its "dtype" argument. 239# - @onlyNativeDeviceTypes 240# Skips the test if the device is not a native device type (currently CPU, CUDA, Meta) 241# - @onlyCPU 242# Skips the test if the device is not a CPU device 243# - @onlyCUDA 244# Skips the test if the device is not a CUDA device 245# - @onlyMPS 246# Skips the test if the device is not a MPS device 247# - @skipCPUIfNoLapack 248# Skips the test if the device is a CPU device and LAPACK is not installed 249# - @skipCPUIfNoMkl 250# Skips the test if the device is a CPU device and MKL is not installed 251# - @skipCUDAIfNoMagma 252# Skips the test if the device is a CUDA device and MAGMA is not installed 253# - @skipCUDAIfRocm 254# Skips the test if the device is a CUDA device and ROCm is being used 255 256 257# Note [Adding a Device Type] 258# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 259# 260# To add a device type: 261# 262# (1) Create a new "TestBase" extending DeviceTypeTestBase. 263# See CPUTestBase and CUDATestBase below. 264# (2) Define the "device_type" attribute of the base to be the 265# appropriate string. 266# (3) Add logic to this file that appends your base class to 267# device_type_test_bases when your device type is available. 268# (4) (Optional) Write setUpClass/tearDownClass class methods that 269# instantiate dependencies (see MAGMA in CUDATestBase). 270# (5) (Optional) Override the "instantiate_test" method for total 271# control over how your class creates tests. 272# 273# setUpClass is called AFTER tests have been created and BEFORE and ONLY IF 274# they are run. This makes it useful for initializing devices and dependencies. 275 276 277# Note [Overriding methods in generic tests] 278# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 279# 280# Device generic tests look a lot like normal test classes, but they differ 281# from ordinary classes in some important ways. In particular, overriding 282# methods in generic tests doesn't work quite the way you expect. 283# 284# class TestFooDeviceType(TestCase): 285# # Intention is to override 286# def assertEqual(self, x, y): 287# # This DOESN'T WORK! 288# super().assertEqual(x, y) 289# 290# If you try to run this code, you'll get an error saying that TestFooDeviceType 291# is not in scope. This is because after instantiating our classes, we delete 292# it from the parent scope. Instead, you need to hardcode a direct invocation 293# of the desired subclass call, e.g., 294# 295# class TestFooDeviceType(TestCase): 296# # Intention is to override 297# def assertEqual(self, x, y): 298# TestCase.assertEqual(x, y) 299# 300# However, a less error-prone way of customizing the behavior of TestCase 301# is to either (1) add your functionality to TestCase and make it toggled 302# by a class attribute, or (2) create your own subclass of TestCase, and 303# then inherit from it for your generic test. 
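#
# For example, option (2) from the note above might look like this (a sketch;
# the class names are made up for illustration):
#
#   class MyTestCaseBase(TestCase):
#     def assertEqual(self, x, y, *args, **kwargs):
#       # customize behavior here, then defer to the ordinary implementation
#       super().assertEqual(x, y, *args, **kwargs)
#
#   class TestFooDeviceType(MyTestCaseBase):
#     def test_baz(self, device):
#       ...
#
#   instantiate_device_type_tests(TestFooDeviceType, globals())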
304 305 306def _dtype_test_suffix(dtypes): 307 """Returns the test suffix for a dtype, sequence of dtypes, or None.""" 308 if isinstance(dtypes, (list, tuple)): 309 if len(dtypes) == 0: 310 return "" 311 return "_" + "_".join(dtype_name(d) for d in dtypes) 312 elif dtypes: 313 return f"_{dtype_name(dtypes)}" 314 else: 315 return "" 316 317 318def _update_param_kwargs(param_kwargs, name, value): 319 """Adds a kwarg with the specified name and value to the param_kwargs dict.""" 320 # Make name plural (e.g. devices / dtypes) if the value is composite. 321 plural_name = f"{name}s" 322 323 # Clear out old entries of the arg if any. 324 if name in param_kwargs: 325 del param_kwargs[name] 326 if plural_name in param_kwargs: 327 del param_kwargs[plural_name] 328 329 if isinstance(value, (list, tuple)): 330 param_kwargs[plural_name] = value 331 elif value is not None: 332 param_kwargs[name] = value 333 334 # Leave param_kwargs as-is when value is None. 335 336 337class DeviceTypeTestBase(TestCase): 338 device_type: str = "generic_device_type" 339 340 # Flag to disable test suite early due to unrecoverable error such as CUDA error. 341 _stop_test_suite = False 342 343 # Precision is a thread-local setting since it may be overridden per test 344 _tls = threading.local() 345 _tls.precision = TestCase._precision 346 _tls.rel_tol = TestCase._rel_tol 347 348 @property 349 def precision(self): 350 return self._tls.precision 351 352 @precision.setter 353 def precision(self, prec): 354 self._tls.precision = prec 355 356 @property 357 def rel_tol(self): 358 return self._tls.rel_tol 359 360 @rel_tol.setter 361 def rel_tol(self, prec): 362 self._tls.rel_tol = prec 363 364 # Returns a string representing the device that single device tests should use. 365 # Note: single device tests use this device exclusively. 366 @classmethod 367 def get_primary_device(cls): 368 return cls.device_type 369 370 @classmethod 371 def _init_and_get_primary_device(cls): 372 try: 373 return cls.get_primary_device() 374 except Exception: 375 # For CUDATestBase, XLATestBase, and possibly others, the primary device won't be available 376 # until setUpClass() sets it. Call that manually here if needed. 377 if hasattr(cls, "setUpClass"): 378 cls.setUpClass() 379 return cls.get_primary_device() 380 381 # Returns a list of strings representing all available devices of this 382 # device type. The primary device must be the first string in the list 383 # and the list must contain no duplicates. 384 # Note: UNSTABLE API. Will be replaced once PyTorch has a device generic 385 # mechanism of acquiring all available devices. 386 @classmethod 387 def get_all_devices(cls): 388 return [cls.get_primary_device()] 389 390 # Returns the dtypes the test has requested. 391 # Prefers device-specific dtype specifications over generic ones. 
392 @classmethod 393 def _get_dtypes(cls, test): 394 if not hasattr(test, "dtypes"): 395 return None 396 397 default_dtypes = test.dtypes.get("all") 398 msg = f"@dtypes is mandatory when using @dtypesIf however '{test.__name__}' didn't specify it" 399 assert default_dtypes is not None, msg 400 401 return test.dtypes.get(cls.device_type, default_dtypes) 402 403 def _get_precision_override(self, test, dtype): 404 if not hasattr(test, "precision_overrides"): 405 return self.precision 406 return test.precision_overrides.get(dtype, self.precision) 407 408 def _get_tolerance_override(self, test, dtype): 409 if not hasattr(test, "tolerance_overrides"): 410 return self.precision, self.rel_tol 411 return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol)) 412 413 def _apply_precision_override_for_test(self, test, param_kwargs): 414 dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None 415 dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype 416 if dtype: 417 self.precision = self._get_precision_override(test, dtype) 418 self.precision, self.rel_tol = self._get_tolerance_override(test, dtype) 419 420 # Creates device-specific tests. 421 @classmethod 422 def instantiate_test(cls, name, test, *, generic_cls=None): 423 def instantiate_test_helper( 424 cls, name, *, test, param_kwargs=None, decorator_fn=lambda _: [] 425 ): 426 # Add the device param kwarg if the test needs device or devices. 427 param_kwargs = {} if param_kwargs is None else param_kwargs 428 test_sig_params = inspect.signature(test).parameters 429 if "device" in test_sig_params or "devices" in test_sig_params: 430 device_arg: str = cls._init_and_get_primary_device() 431 if hasattr(test, "num_required_devices"): 432 device_arg = cls.get_all_devices() 433 _update_param_kwargs(param_kwargs, "device", device_arg) 434 435 # Apply decorators based on param kwargs. 436 for decorator in decorator_fn(param_kwargs): 437 test = decorator(test) 438 439 # Constructs the test 440 @wraps(test) 441 def instantiated_test(self, param_kwargs=param_kwargs): 442 # Sets precision and runs test 443 # Note: precision is reset after the test is run 444 guard_precision = self.precision 445 guard_rel_tol = self.rel_tol 446 try: 447 self._apply_precision_override_for_test(test, param_kwargs) 448 result = test(self, **param_kwargs) 449 except RuntimeError as rte: 450 # check if rte should stop entire test suite. 451 self._stop_test_suite = self._should_stop_test_suite() 452 # Check if test has been decorated with `@expectedFailure` 453 # Using `__unittest_expecting_failure__` attribute, see 454 # https://github.com/python/cpython/blob/ffa505b580464/Lib/unittest/case.py#L164 455 # In that case, make it fail with "unexpected success" by suppressing exception 456 if ( 457 getattr(test, "__unittest_expecting_failure__", False) 458 and self._stop_test_suite 459 ): 460 import sys 461 462 print( 463 "Suppressing fatal exception to trigger unexpected success", 464 file=sys.stderr, 465 ) 466 return 467 # raise the runtime error as is for the test suite to record. 468 raise rte 469 finally: 470 self.precision = guard_precision 471 self.rel_tol = guard_rel_tol 472 473 return result 474 475 assert not hasattr(cls, name), f"Redefinition of test {name}" 476 setattr(cls, name, instantiated_test) 477 478 def default_parametrize_fn(test, generic_cls, device_cls): 479 # By default, no parametrization is needed. 480 yield (test, "", {}, lambda _: []) 481 482 # Parametrization decorators set the parametrize_fn attribute on the test. 
483 parametrize_fn = getattr(test, "parametrize_fn", default_parametrize_fn) 484 485 # If one of the @dtypes* decorators is present, also parametrize over the dtypes set by it. 486 dtypes = cls._get_dtypes(test) 487 if dtypes is not None: 488 489 def dtype_parametrize_fn(test, generic_cls, device_cls, dtypes=dtypes): 490 for dtype in dtypes: 491 param_kwargs: Dict[str, Any] = {} 492 _update_param_kwargs(param_kwargs, "dtype", dtype) 493 494 # Note that an empty test suffix is set here so that the dtype can be appended 495 # later after the device. 496 yield (test, "", param_kwargs, lambda _: []) 497 498 parametrize_fn = compose_parametrize_fns( 499 dtype_parametrize_fn, parametrize_fn 500 ) 501 502 # Instantiate the parametrized tests. 503 for ( 504 test, # noqa: B020 505 test_suffix, 506 param_kwargs, 507 decorator_fn, 508 ) in parametrize_fn(test, generic_cls, cls): 509 test_suffix = "" if test_suffix == "" else "_" + test_suffix 510 cls_device_type = ( 511 cls.device_type 512 if cls.device_type != "privateuse1" 513 else torch._C._get_privateuse1_backend_name() 514 ) 515 device_suffix = "_" + cls_device_type 516 517 # Note: device and dtype suffix placement 518 # Special handling here to place dtype(s) after device according to test name convention. 519 dtype_kwarg = None 520 if "dtype" in param_kwargs or "dtypes" in param_kwargs: 521 dtype_kwarg = ( 522 param_kwargs["dtypes"] 523 if "dtypes" in param_kwargs 524 else param_kwargs["dtype"] 525 ) 526 test_name = ( 527 f"{name}{test_suffix}{device_suffix}{_dtype_test_suffix(dtype_kwarg)}" 528 ) 529 530 instantiate_test_helper( 531 cls=cls, 532 name=test_name, 533 test=test, 534 param_kwargs=param_kwargs, 535 decorator_fn=decorator_fn, 536 ) 537 538 def run(self, result=None): 539 super().run(result=result) 540 # Early terminate test if _stop_test_suite is set. 
541 if self._stop_test_suite: 542 result.stop() 543 544 545class CPUTestBase(DeviceTypeTestBase): 546 device_type = "cpu" 547 548 # No critical error should stop CPU test suite 549 def _should_stop_test_suite(self): 550 return False 551 552 553class CUDATestBase(DeviceTypeTestBase): 554 device_type = "cuda" 555 _do_cuda_memory_leak_check = True 556 _do_cuda_non_default_stream = True 557 primary_device: ClassVar[str] 558 cudnn_version: ClassVar[Any] 559 no_magma: ClassVar[bool] 560 no_cudnn: ClassVar[bool] 561 562 def has_cudnn(self): 563 return not self.no_cudnn 564 565 @classmethod 566 def get_primary_device(cls): 567 return cls.primary_device 568 569 @classmethod 570 def get_all_devices(cls): 571 primary_device_idx = int(cls.get_primary_device().split(":")[1]) 572 num_devices = torch.cuda.device_count() 573 574 prim_device = cls.get_primary_device() 575 cuda_str = "cuda:{0}" 576 non_primary_devices = [ 577 cuda_str.format(idx) 578 for idx in range(num_devices) 579 if idx != primary_device_idx 580 ] 581 return [prim_device] + non_primary_devices 582 583 @classmethod 584 def setUpClass(cls): 585 # has_magma shows up after cuda is initialized 586 t = torch.ones(1).cuda() 587 cls.no_magma = not torch.cuda.has_magma 588 589 # Determines if cuDNN is available and its version 590 cls.no_cudnn = not torch.backends.cudnn.is_acceptable(t) 591 cls.cudnn_version = None if cls.no_cudnn else torch.backends.cudnn.version() 592 593 # Acquires the current device as the primary (test) device 594 cls.primary_device = f"cuda:{torch.cuda.current_device()}" 595 596 597# See Note [Lazy Tensor tests in device agnostic testing] 598lazy_ts_backend_init = False 599 600 601class LazyTestBase(DeviceTypeTestBase): 602 device_type = "lazy" 603 604 def _should_stop_test_suite(self): 605 return False 606 607 @classmethod 608 def setUpClass(cls): 609 import torch._lazy 610 import torch._lazy.metrics 611 import torch._lazy.ts_backend 612 613 global lazy_ts_backend_init 614 if not lazy_ts_backend_init: 615 # Need to connect the TS backend to lazy key before running tests 616 torch._lazy.ts_backend.init() 617 lazy_ts_backend_init = True 618 619 620class MPSTestBase(DeviceTypeTestBase): 621 device_type = "mps" 622 primary_device: ClassVar[str] 623 624 @classmethod 625 def get_primary_device(cls): 626 return cls.primary_device 627 628 @classmethod 629 def get_all_devices(cls): 630 # currently only one device is supported on MPS backend 631 prim_device = cls.get_primary_device() 632 return [prim_device] 633 634 @classmethod 635 def setUpClass(cls): 636 cls.primary_device = "mps:0" 637 638 def _should_stop_test_suite(self): 639 return False 640 641 642class XPUTestBase(DeviceTypeTestBase): 643 device_type = "xpu" 644 primary_device: ClassVar[str] 645 646 @classmethod 647 def get_primary_device(cls): 648 return cls.primary_device 649 650 @classmethod 651 def get_all_devices(cls): 652 # currently only one device is supported on MPS backend 653 prim_device = cls.get_primary_device() 654 return [prim_device] 655 656 @classmethod 657 def setUpClass(cls): 658 cls.primary_device = "xpu:0" 659 660 def _should_stop_test_suite(self): 661 return False 662 663 664class HPUTestBase(DeviceTypeTestBase): 665 device_type = "hpu" 666 primary_device: ClassVar[str] 667 668 @classmethod 669 def get_primary_device(cls): 670 return cls.primary_device 671 672 @classmethod 673 def setUpClass(cls): 674 cls.primary_device = "hpu:0" 675 676 677class PrivateUse1TestBase(DeviceTypeTestBase): 678 primary_device: ClassVar[str] 679 device_mod = None 680 
device_type = "privateuse1" 681 682 @classmethod 683 def get_primary_device(cls): 684 return cls.primary_device 685 686 @classmethod 687 def get_all_devices(cls): 688 primary_device_idx = int(cls.get_primary_device().split(":")[1]) 689 num_devices = cls.device_mod.device_count() 690 prim_device = cls.get_primary_device() 691 device_str = f"{cls.device_type}:{{0}}" 692 non_primary_devices = [ 693 device_str.format(idx) 694 for idx in range(num_devices) 695 if idx != primary_device_idx 696 ] 697 return [prim_device] + non_primary_devices 698 699 @classmethod 700 def setUpClass(cls): 701 cls.device_type = torch._C._get_privateuse1_backend_name() 702 cls.device_mod = getattr(torch, cls.device_type, None) 703 assert ( 704 cls.device_mod is not None 705 ), f"""torch has no module of `{cls.device_type}`, you should register 706 a module by `torch._register_device_module`.""" 707 cls.primary_device = f"{cls.device_type}:{cls.device_mod.current_device()}" 708 709 710# Adds available device-type-specific test base classes 711def get_device_type_test_bases(): 712 # set type to List[Any] due to mypy list-of-union issue: 713 # https://github.com/python/mypy/issues/3351 714 test_bases: List[Any] = [] 715 716 if IS_SANDCASTLE or IS_FBCODE: 717 if IS_REMOTE_GPU: 718 # Skip if sanitizer is enabled 719 if not TEST_WITH_ASAN and not TEST_WITH_TSAN and not TEST_WITH_UBSAN: 720 test_bases.append(CUDATestBase) 721 else: 722 test_bases.append(CPUTestBase) 723 else: 724 test_bases.append(CPUTestBase) 725 if torch.cuda.is_available(): 726 test_bases.append(CUDATestBase) 727 728 if is_privateuse1_backend_available(): 729 test_bases.append(PrivateUse1TestBase) 730 # Disable MPS testing in generic device testing temporarily while we're 731 # ramping up support. 732 # elif torch.backends.mps.is_available(): 733 # test_bases.append(MPSTestBase) 734 735 return test_bases 736 737 738device_type_test_bases = get_device_type_test_bases() 739 740 741def filter_desired_device_types(device_type_test_bases, except_for=None, only_for=None): 742 # device type cannot appear in both except_for and only_for 743 intersect = set(except_for if except_for else []) & set( 744 only_for if only_for else [] 745 ) 746 assert ( 747 not intersect 748 ), f"device ({intersect}) appeared in both except_for and only_for" 749 750 # Replace your privateuse1 backend name with 'privateuse1' 751 if is_privateuse1_backend_available(): 752 privateuse1_backend_name = torch._C._get_privateuse1_backend_name() 753 except_for = ( 754 ["privateuse1" if x == privateuse1_backend_name else x for x in except_for] 755 if except_for is not None 756 else None 757 ) 758 only_for = ( 759 ["privateuse1" if x == privateuse1_backend_name else x for x in only_for] 760 if only_for is not None 761 else None 762 ) 763 764 if except_for: 765 device_type_test_bases = filter( 766 lambda x: x.device_type not in except_for, device_type_test_bases 767 ) 768 if only_for: 769 device_type_test_bases = filter( 770 lambda x: x.device_type in only_for, device_type_test_bases 771 ) 772 773 return list(device_type_test_bases) 774 775 776# Note [How to extend DeviceTypeTestBase to add new test device] 777# The following logic optionally allows downstream projects like pytorch/xla to 778# add more test devices. 779# Instructions: 780# - Add a python file (e.g. pytorch/xla/test/pytorch_test_base.py) in downstream project. 781# - Inside the file, one should inherit from `DeviceTypeTestBase` class and define 782# a new DeviceTypeTest class (e.g. 
#   `XLATestBase`) with proper implementation of
#   the `instantiate_test` method.
# - DO NOT import common_device_type inside the file.
#   `runpy.run_path` with `globals()` already properly sets up the context so that
#   `DeviceTypeTestBase` is already available.
# - Set a top-level variable `TEST_CLASS` equal to your new class.
#   E.g. TEST_CLASS = XLATestBase
# - To run tests with the new device type, set the `TORCH_TEST_DEVICES` env variable
#   to the path to this file. Multiple paths can be separated by `:`.
# See pytorch/xla/test/pytorch_test_base.py for a more detailed example.
_TORCH_TEST_DEVICES = os.environ.get("TORCH_TEST_DEVICES", None)
if _TORCH_TEST_DEVICES:
    for path in _TORCH_TEST_DEVICES.split(":"):
        # runpy (a stdlib module) lacks annotations
        mod = runpy.run_path(path, init_globals=globals())  # type: ignore[func-returns-value]
        device_type_test_bases.append(mod["TEST_CLASS"])


PYTORCH_CUDA_MEMCHECK = os.getenv("PYTORCH_CUDA_MEMCHECK", "0") == "1"

PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY = "PYTORCH_TESTING_DEVICE_ONLY_FOR"
PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY = "PYTORCH_TESTING_DEVICE_EXCEPT_FOR"
PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY = "PYTORCH_TESTING_DEVICE_FOR_CUSTOM"


def get_desired_device_type_test_bases(
    except_for=None, only_for=None, include_lazy=False, allow_mps=False, allow_xpu=False
):
    # allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy`
    test_bases = device_type_test_bases.copy()
    if allow_mps and TEST_MPS and MPSTestBase not in test_bases:
        test_bases.append(MPSTestBase)
    if allow_xpu and TEST_XPU and XPUTestBase not in test_bases:
        test_bases.append(XPUTestBase)
    if TEST_HPU and HPUTestBase not in test_bases:
        test_bases.append(HPUTestBase)
    # Filter out the device types based on user inputs
    desired_device_type_test_bases = filter_desired_device_types(
        test_bases, except_for, only_for
    )
    if include_lazy:
        # Note [Lazy Tensor tests in device agnostic testing]
        # Right now, test_view_ops.py runs with LazyTensor.
        # We don't want to opt every device-agnostic test into using the lazy device,
        # because many of them will fail.
        # So instead, the only way to opt a specific device-agnostic test file into
        # lazy tensor testing is with include_lazy=True
        if IS_FBCODE:
            print(
                "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds",
                file=sys.stderr,
            )
        else:
            desired_device_type_test_bases.append(LazyTestBase)

    def split_if_not_empty(x: str):
        return x.split(",") if x else []

    # run some cuda testcases on other devices if available
    # Usage:
    # export PYTORCH_TESTING_DEVICE_FOR_CUSTOM=privateuse1
    env_custom_only_for = split_if_not_empty(
        os.getenv(PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY, "")
    )
    if env_custom_only_for:
        desired_device_type_test_bases += filter(
            lambda x: x.device_type in env_custom_only_for, test_bases
        )
        desired_device_type_test_bases = list(set(desired_device_type_test_bases))

    # Filter out the device types based on environment variables if available
    # Usage:
    # export PYTORCH_TESTING_DEVICE_ONLY_FOR=cuda,cpu
    # export PYTORCH_TESTING_DEVICE_EXCEPT_FOR=xla
    env_only_for = split_if_not_empty(
        os.getenv(PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, "")
    )
    env_except_for = split_if_not_empty(
        os.getenv(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, "")
    )

    return filter_desired_device_types(
        desired_device_type_test_bases, env_except_for, env_only_for
    )


# Adds 'instantiated' device-specific test cases to the given scope.
# The tests in these test cases are derived from the generic tests in
# generic_test_class. This function should be used instead of
# instantiate_parametrized_tests() if the test class contains
# device-specific tests (NB: this supports additional @parametrize usage).
#
# See note "Writing Test Templates"
# TODO: remove the "allow_xpu" option once Intel GPU supports all test cases instantiated by this function.
def instantiate_device_type_tests(
    generic_test_class,
    scope,
    except_for=None,
    only_for=None,
    include_lazy=False,
    allow_mps=False,
    allow_xpu=False,
):
    # Removes the generic test class from its enclosing scope so its tests
    # are not discoverable.
    del scope[generic_test_class.__name__]

    # Creates an 'empty' version of the generic_test_class
    # Note: we don't inherit from the generic_test_class directly because
    # that would add its tests to our test classes and they would be
    # discovered (despite not being runnable). Inherited methods also
    # can't be removed later, and we can't rely on load_tests because
    # pytest doesn't support it (as of this writing).
    empty_name = generic_test_class.__name__ + "_base"
    empty_class = type(empty_name, generic_test_class.__bases__, {})

    # Acquires member names
    # See Note [Overriding methods in generic tests]
    generic_members = set(generic_test_class.__dict__.keys()) - set(
        empty_class.__dict__.keys()
    )
    generic_tests = [x for x in generic_members if x.startswith("test")]

    # Creates device-specific test cases
    for base in get_desired_device_type_test_bases(
        except_for, only_for, include_lazy, allow_mps, allow_xpu
    ):
        class_name = generic_test_class.__name__ + base.device_type.upper()

        # type set to Any and suppressed due to unsupported runtime class creation:
        # https://github.com/python/mypy/wiki/Unsupported-Python-Features
        device_type_test_class: Any = type(class_name, (base, empty_class), {})

        for name in generic_members:
            if name in generic_tests:  # Instantiates test member
                test = getattr(generic_test_class, name)
                # XLA-compat shim (XLA's instantiate_test doesn't take generic_cls)
                sig = inspect.signature(device_type_test_class.instantiate_test)
                if len(sig.parameters) == 3:
                    # Instantiates the device-specific tests
                    device_type_test_class.instantiate_test(
                        name, copy.deepcopy(test), generic_cls=generic_test_class
                    )
                else:
                    device_type_test_class.instantiate_test(name, copy.deepcopy(test))
            else:  # Ports non-test member
                assert (
                    name not in device_type_test_class.__dict__
                ), f"Redefinition of directly defined member {name}"
                nontest = getattr(generic_test_class, name)
                setattr(device_type_test_class, name, nontest)

        # The dynamically-created test class derives from the test template class
        # and the empty class. Arrange for both setUpClass and tearDownClass methods
        # to be called. This allows the parameterized test classes to support setup
        # and teardown.
        @classmethod
        def _setUpClass(cls):
            base.setUpClass()
            empty_class.setUpClass()

        @classmethod
        def _tearDownClass(cls):
            empty_class.tearDownClass()
            base.tearDownClass()

        device_type_test_class.setUpClass = _setUpClass
        device_type_test_class.tearDownClass = _tearDownClass

        # Mimics defining the instantiated class in the caller's file
        # by setting its module to the given class's and adding
        # the module to the given scope.
        # This lets the instantiated class be discovered by unittest.
        device_type_test_class.__module__ = generic_test_class.__module__
        scope[class_name] = device_type_test_class


# Category of dtypes to run an OpInfo-based test for
# Example use: @ops(dtypes=OpDTypes.supported)
#
# There are 7 categories:
# - supported: Every dtype supported by the operator. Use for exhaustive
#   testing of all dtypes.
# - unsupported: Run tests on dtypes not supported by the operator. e.g. for
#   testing the operator raises an error and doesn't crash.
# - supported_backward: Every dtype supported by the operator's backward pass.
# - unsupported_backward: Run tests on dtypes not supported by the operator's backward pass.
# - any_one: Runs a test for one dtype the operator supports. Prioritizes dtypes the
#   operator supports in both forward and backward.
# - any_common_cpu_cuda_one: Runs a test for one dtype the operator supports on both
#   CPU and CUDA.
# - none: Useful for tests that are not dtype-specific. No dtype will be passed to the test
#   when this is selected.
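#
# For example (a sketch; the test name is illustrative and op_db is the OpInfo
# list from common_methods_invocations.py), the "unsupported" category can be
# used to check that unsupported dtypes fail loudly instead of crashing:
#
#   @ops(op_db, dtypes=OpDTypes.unsupported)
#   def test_unsupported_dtypes_raise(self, device, dtype, op):
#     ...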
973class OpDTypes(Enum): 974 supported = 0 # Test all supported dtypes (default) 975 unsupported = 1 # Test only unsupported dtypes 976 supported_backward = 2 # Test all supported backward dtypes 977 unsupported_backward = 3 # Test only unsupported backward dtypes 978 any_one = 4 # Test precisely one supported dtype 979 none = 5 # Instantiate no dtype variants (no dtype kwarg needed) 980 any_common_cpu_cuda_one = ( 981 6 # Test precisely one supported dtype that is common to both cuda and cpu 982 ) 983 984 985# Arbitrary order 986ANY_DTYPE_ORDER = ( 987 torch.float32, 988 torch.float64, 989 torch.complex64, 990 torch.complex128, 991 torch.float16, 992 torch.bfloat16, 993 torch.long, 994 torch.int32, 995 torch.int16, 996 torch.int8, 997 torch.uint8, 998 torch.bool, 999) 1000 1001 1002def _serialize_sample(sample_input): 1003 # NB: For OpInfos, SampleInput.summary() prints in a cleaner way. 1004 if getattr(sample_input, "summary", None) is not None: 1005 return sample_input.summary() 1006 return str(sample_input) 1007 1008 1009# Decorator that defines the OpInfos a test template should be instantiated for. 1010# 1011# Example usage: 1012# 1013# @ops(unary_ufuncs) 1014# def test_numerics(self, device, dtype, op): 1015# <test_code> 1016# 1017# This will instantiate variants of test_numerics for each given OpInfo, 1018# on each device the OpInfo's operator supports, and for every dtype supported by 1019# that operator. There are a few caveats to the dtype rule, explained below. 1020# 1021# The @ops decorator can accept two 1022# additional arguments, "dtypes" and "allowed_dtypes". If "dtypes" is specified 1023# then the test variants are instantiated for those dtypes, regardless of 1024# what the operator supports. If given "allowed_dtypes" then test variants 1025# are instantiated only for the intersection of allowed_dtypes and the dtypes 1026# they would otherwise be instantiated with. That is, allowed_dtypes composes 1027# with the options listed above and below. 1028# 1029# The "dtypes" argument can also accept additional values (see OpDTypes above): 1030# OpDTypes.supported - the test is instantiated for all dtypes the operator 1031# supports 1032# OpDTypes.unsupported - the test is instantiated for all dtypes the operator 1033# doesn't support 1034# OpDTypes.supported_backward - the test is instantiated for all dtypes the 1035# operator's gradient formula supports 1036# OpDTypes.unsupported_backward - the test is instantiated for all dtypes the 1037# operator's gradient formula doesn't support 1038# OpDTypes.any_one - the test is instantiated for one dtype the 1039# operator supports. The dtype supports forward and backward if possible. 1040# OpDTypes.none - the test is instantiated without any dtype. The test signature 1041# should not include a dtype kwarg in this case. 1042# 1043# These options allow tests to have considerable control over the dtypes 1044# they're instantiated for. 
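#
# A couple of usage sketches (the test names are made up; op_db is the OpInfo
# list from common_methods_invocations.py):
#
#   # Only instantiate the float32 variants of each op that supports float32:
#   @ops(op_db, allowed_dtypes=[torch.float32])
#   def test_float32_behavior(self, device, dtype, op):
#     ...
#
#   # Instantiate one variant per op, with no dtype in the name or signature:
#   @ops(op_db, dtypes=OpDTypes.none)
#   def test_op_metadata(self, device, op):
#     ...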
1045 1046 1047class ops(_TestParametrizer): 1048 def __init__( 1049 self, 1050 op_list, 1051 *, 1052 dtypes: Union[OpDTypes, Sequence[torch.dtype]] = OpDTypes.supported, 1053 allowed_dtypes: Optional[Sequence[torch.dtype]] = None, 1054 skip_if_dynamo=True, 1055 ): 1056 self.op_list = list(op_list) 1057 self.opinfo_dtypes = dtypes 1058 self.allowed_dtypes = ( 1059 set(allowed_dtypes) if allowed_dtypes is not None else None 1060 ) 1061 self.skip_if_dynamo = skip_if_dynamo 1062 1063 def _parametrize_test(self, test, generic_cls, device_cls): 1064 """Parameterizes the given test function across each op and its associated dtypes.""" 1065 if device_cls is None: 1066 raise RuntimeError( 1067 "The @ops decorator is only intended to be used in a device-specific " 1068 "context; use it with instantiate_device_type_tests() instead of " 1069 "instantiate_parametrized_tests()" 1070 ) 1071 1072 op = check_exhausted_iterator = object() 1073 for op in self.op_list: 1074 # Determine the set of dtypes to use. 1075 dtypes: Union[Set[torch.dtype], Set[None]] 1076 if isinstance(self.opinfo_dtypes, Sequence): 1077 dtypes = set(self.opinfo_dtypes) 1078 elif self.opinfo_dtypes == OpDTypes.unsupported_backward: 1079 dtypes = set(get_all_dtypes()).difference( 1080 op.supported_backward_dtypes(device_cls.device_type) 1081 ) 1082 elif self.opinfo_dtypes == OpDTypes.supported_backward: 1083 dtypes = op.supported_backward_dtypes(device_cls.device_type) 1084 elif self.opinfo_dtypes == OpDTypes.unsupported: 1085 dtypes = set(get_all_dtypes()).difference( 1086 op.supported_dtypes(device_cls.device_type) 1087 ) 1088 elif self.opinfo_dtypes == OpDTypes.supported: 1089 dtypes = set(op.supported_dtypes(device_cls.device_type)) 1090 elif self.opinfo_dtypes == OpDTypes.any_one: 1091 # Tries to pick a dtype that supports both forward or backward 1092 supported = op.supported_dtypes(device_cls.device_type) 1093 supported_backward = op.supported_backward_dtypes( 1094 device_cls.device_type 1095 ) 1096 supported_both = supported.intersection(supported_backward) 1097 dtype_set = supported_both if len(supported_both) > 0 else supported 1098 for dtype in ANY_DTYPE_ORDER: 1099 if dtype in dtype_set: 1100 dtypes = {dtype} 1101 break 1102 else: 1103 dtypes = {} 1104 elif self.opinfo_dtypes == OpDTypes.any_common_cpu_cuda_one: 1105 # Tries to pick a dtype that supports both CPU and CUDA 1106 supported = set(op.dtypes).intersection(op.dtypesIfCUDA) 1107 if supported: 1108 dtypes = { 1109 next(dtype for dtype in ANY_DTYPE_ORDER if dtype in supported) 1110 } 1111 else: 1112 dtypes = {} 1113 1114 elif self.opinfo_dtypes == OpDTypes.none: 1115 dtypes = {None} 1116 else: 1117 raise RuntimeError(f"Unknown OpDType: {self.opinfo_dtypes}") 1118 1119 if self.allowed_dtypes is not None: 1120 dtypes = dtypes.intersection(self.allowed_dtypes) 1121 1122 # Construct the test name; device / dtype parts are handled outside. 1123 # See [Note: device and dtype suffix placement] 1124 test_name = op.formatted_name 1125 1126 for dtype in dtypes: 1127 # Construct parameter kwargs to pass to the test. 1128 param_kwargs = {"op": op} 1129 _update_param_kwargs(param_kwargs, "dtype", dtype) 1130 1131 # NOTE: test_wrapper exists because we don't want to apply 1132 # op-specific decorators to the original test. 1133 # Test-specific decorators are applied to the original test, 1134 # however. 
1135 try: 1136 1137 @wraps(test) 1138 def test_wrapper(*args, **kwargs): 1139 try: 1140 return test(*args, **kwargs) 1141 except unittest.SkipTest as e: 1142 raise e 1143 except Exception as e: 1144 tracked_input = get_tracked_input() 1145 if PRINT_REPRO_ON_FAILURE and tracked_input is not None: 1146 e_tracked = Exception( # noqa: TRY002 1147 f"Caused by {tracked_input.type_desc} " 1148 f"at index {tracked_input.index}: " 1149 f"{_serialize_sample(tracked_input.val)}" 1150 ) 1151 e_tracked._tracked_input = tracked_input # type: ignore[attr] 1152 raise e_tracked from e 1153 raise e 1154 finally: 1155 clear_tracked_input() 1156 1157 if self.skip_if_dynamo and not TEST_WITH_TORCHINDUCTOR: 1158 test_wrapper = skipIfTorchDynamo( 1159 "Policy: we don't run OpInfo tests w/ Dynamo" 1160 )(test_wrapper) 1161 1162 # Initialize info for the last input seen. This is useful for tracking 1163 # down which inputs caused a test failure. Note that TrackedInputIter is 1164 # responsible for managing this. 1165 test.tracked_input = None 1166 1167 decorator_fn = partial( 1168 op.get_decorators, 1169 generic_cls.__name__, 1170 test.__name__, 1171 device_cls.device_type, 1172 dtype, 1173 ) 1174 1175 yield (test_wrapper, test_name, param_kwargs, decorator_fn) 1176 except Exception as ex: 1177 # Provides an error message for debugging before rethrowing the exception 1178 print(f"Failed to instantiate {test_name} for op {op.name}!") 1179 raise ex 1180 if op is check_exhausted_iterator: 1181 raise ValueError( 1182 "An empty op_list was passed to @ops. " 1183 "Note that this may result from reuse of a generator." 1184 ) 1185 1186 1187# Decorator that skips a test if the given condition is true. 1188# Notes: 1189# (1) Skip conditions stack. 1190# (2) Skip conditions can be bools or strings. If a string the 1191# test base must have defined the corresponding attribute to be False 1192# for the test to run. If you want to use a string argument you should 1193# probably define a new decorator instead (see below). 1194# (3) Prefer the existing decorators to defining the 'device_type' kwarg. 1195class skipIf: 1196 def __init__(self, dep, reason, device_type=None): 1197 self.dep = dep 1198 self.reason = reason 1199 self.device_type = device_type 1200 1201 def __call__(self, fn): 1202 @wraps(fn) 1203 def dep_fn(slf, *args, **kwargs): 1204 if self.device_type is None or self.device_type == slf.device_type: 1205 if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or ( 1206 isinstance(self.dep, bool) and self.dep 1207 ): 1208 raise unittest.SkipTest(self.reason) 1209 1210 return fn(slf, *args, **kwargs) 1211 1212 return dep_fn 1213 1214 1215# Skips a test on CPU if the condition is true. 1216class skipCPUIf(skipIf): 1217 def __init__(self, dep, reason): 1218 super().__init__(dep, reason, device_type="cpu") 1219 1220 1221# Skips a test on CUDA if the condition is true. 1222class skipCUDAIf(skipIf): 1223 def __init__(self, dep, reason): 1224 super().__init__(dep, reason, device_type="cuda") 1225 1226 1227# Skips a test on XPU if the condition is true. 1228class skipXPUIf(skipIf): 1229 def __init__(self, dep, reason): 1230 super().__init__(dep, reason, device_type="xpu") 1231 1232 1233# Skips a test on Lazy if the condition is true. 1234class skipLazyIf(skipIf): 1235 def __init__(self, dep, reason): 1236 super().__init__(dep, reason, device_type="lazy") 1237 1238 1239# Skips a test on Meta if the condition is true. 
1240class skipMetaIf(skipIf): 1241 def __init__(self, dep, reason): 1242 super().__init__(dep, reason, device_type="meta") 1243 1244 1245# Skips a test on MPS if the condition is true. 1246class skipMPSIf(skipIf): 1247 def __init__(self, dep, reason): 1248 super().__init__(dep, reason, device_type="mps") 1249 1250 1251class skipHPUIf(skipIf): 1252 def __init__(self, dep, reason): 1253 super().__init__(dep, reason, device_type="hpu") 1254 1255 1256# Skips a test on XLA if the condition is true. 1257class skipXLAIf(skipIf): 1258 def __init__(self, dep, reason): 1259 super().__init__(dep, reason, device_type="xla") 1260 1261 1262class skipPRIVATEUSE1If(skipIf): 1263 def __init__(self, dep, reason): 1264 device_type = torch._C._get_privateuse1_backend_name() 1265 super().__init__(dep, reason, device_type=device_type) 1266 1267 1268def _has_sufficient_memory(device, size): 1269 if torch.device(device).type == "cuda": 1270 if not torch.cuda.is_available(): 1271 return False 1272 gc.collect() 1273 torch.cuda.empty_cache() 1274 # torch.cuda.mem_get_info, aka cudaMemGetInfo, returns a tuple of (free memory, total memory) of a GPU 1275 if device == "cuda": 1276 device = "cuda:0" 1277 return torch.cuda.memory.mem_get_info(device)[0] >= size 1278 1279 if device == "xla": 1280 raise unittest.SkipTest("TODO: Memory availability checks for XLA?") 1281 1282 if device == "xpu": 1283 raise unittest.SkipTest("TODO: Memory availability checks for Intel GPU?") 1284 1285 if device != "cpu": 1286 raise unittest.SkipTest("Unknown device type") 1287 1288 # CPU 1289 if not HAS_PSUTIL: 1290 raise unittest.SkipTest("Need psutil to determine if memory is sufficient") 1291 1292 # The sanitizers have significant memory overheads 1293 if TEST_WITH_ASAN or TEST_WITH_TSAN or TEST_WITH_UBSAN: 1294 effective_size = size * 10 1295 else: 1296 effective_size = size 1297 1298 if psutil.virtual_memory().available < effective_size: 1299 gc.collect() 1300 return psutil.virtual_memory().available >= effective_size 1301 1302 1303def largeTensorTest(size, device=None): 1304 """Skip test if the device has insufficient memory to run the test 1305 1306 size may be a number of bytes, a string of the form "N GB", or a callable 1307 1308 If the test is a device generic test, available memory on the primary device will be checked. 1309 It can also be overriden by the optional `device=` argument. 1310 In other tests, the `device=` argument needs to be specified. 
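
    Usage sketch (hypothetical test names on a device-generic test class):

        @largeTensorTest("12GB")
        def test_large_matmul(self, device):
            ...

        @largeTensorTest(2 * 1024**3, device="cpu")
        def test_large_cpu_op(self, device):
            ...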
1311 """ 1312 if isinstance(size, str): 1313 assert size.endswith(("GB", "gb")), "only bytes or GB supported" 1314 size = 1024**3 * int(size[:-2]) 1315 1316 def inner(fn): 1317 @wraps(fn) 1318 def dep_fn(self, *args, **kwargs): 1319 size_bytes = size(self, *args, **kwargs) if callable(size) else size 1320 _device = device if device is not None else self.get_primary_device() 1321 if not _has_sufficient_memory(_device, size_bytes): 1322 raise unittest.SkipTest(f"Insufficient {_device} memory") 1323 1324 return fn(self, *args, **kwargs) 1325 1326 return dep_fn 1327 1328 return inner 1329 1330 1331class expectedFailure: 1332 def __init__(self, device_type): 1333 self.device_type = device_type 1334 1335 def __call__(self, fn): 1336 @wraps(fn) 1337 def efail_fn(slf, *args, **kwargs): 1338 if ( 1339 not hasattr(slf, "device_type") 1340 and hasattr(slf, "device") 1341 and isinstance(slf.device, str) 1342 ): 1343 target_device_type = slf.device 1344 else: 1345 target_device_type = slf.device_type 1346 1347 if self.device_type is None or self.device_type == target_device_type: 1348 try: 1349 fn(slf, *args, **kwargs) 1350 except Exception: 1351 return 1352 else: 1353 slf.fail("expected test to fail, but it passed") 1354 1355 return fn(slf, *args, **kwargs) 1356 1357 return efail_fn 1358 1359 1360class onlyOn: 1361 def __init__(self, device_type): 1362 self.device_type = device_type 1363 1364 def __call__(self, fn): 1365 @wraps(fn) 1366 def only_fn(slf, *args, **kwargs): 1367 if self.device_type != slf.device_type: 1368 reason = f"Only runs on {self.device_type}" 1369 raise unittest.SkipTest(reason) 1370 1371 return fn(slf, *args, **kwargs) 1372 1373 return only_fn 1374 1375 1376# Decorator that provides all available devices of the device type to the test 1377# as a list of strings instead of providing a single device string. 1378# Skips the test if the number of available devices of the variant's device 1379# type is less than the 'num_required_devices' arg. 
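#
# A usage sketch (the test name is made up):
#
#   @deviceCountAtLeast(2)
#   def test_multi_device(self, devices):
#     # "devices" is a list like ['cuda:0', 'cuda:1']; the test is skipped if
#     # fewer than two devices of this device type are available.
#     ...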
1380class deviceCountAtLeast: 1381 def __init__(self, num_required_devices): 1382 self.num_required_devices = num_required_devices 1383 1384 def __call__(self, fn): 1385 assert not hasattr( 1386 fn, "num_required_devices" 1387 ), f"deviceCountAtLeast redefinition for {fn.__name__}" 1388 fn.num_required_devices = self.num_required_devices 1389 1390 @wraps(fn) 1391 def multi_fn(slf, devices, *args, **kwargs): 1392 if len(devices) < self.num_required_devices: 1393 reason = f"fewer than {self.num_required_devices} devices detected" 1394 raise unittest.SkipTest(reason) 1395 1396 return fn(slf, devices, *args, **kwargs) 1397 1398 return multi_fn 1399 1400 1401# Only runs the test on the native device type (currently CPU, CUDA, Meta and PRIVATEUSE1) 1402def onlyNativeDeviceTypes(fn): 1403 @wraps(fn) 1404 def only_fn(self, *args, **kwargs): 1405 if self.device_type not in NATIVE_DEVICES: 1406 reason = f"onlyNativeDeviceTypes: doesn't run on {self.device_type}" 1407 raise unittest.SkipTest(reason) 1408 1409 return fn(self, *args, **kwargs) 1410 1411 return only_fn 1412 1413 1414# Only runs the test on the native device types and devices specified in the devices list 1415def onlyNativeDeviceTypesAnd(devices=None): 1416 def decorator(fn): 1417 @wraps(fn) 1418 def only_fn(self, *args, **kwargs): 1419 if ( 1420 self.device_type not in NATIVE_DEVICES 1421 and self.device_type not in devices 1422 ): 1423 reason = f"onlyNativeDeviceTypesAnd {devices} : doesn't run on {self.device_type}" 1424 raise unittest.SkipTest(reason) 1425 1426 return fn(self, *args, **kwargs) 1427 1428 return only_fn 1429 1430 return decorator 1431 1432 1433# Specifies per-dtype precision overrides. 1434# Ex. 1435# 1436# @precisionOverride({torch.half : 1e-2, torch.float : 1e-4}) 1437# @dtypes(torch.half, torch.float, torch.double) 1438# def test_X(self, device, dtype): 1439# ... 1440# 1441# When the test is instantiated its class's precision will be set to the 1442# corresponding override, if it exists. 1443# self.precision can be accessed directly, and it also controls the behavior of 1444# functions like self.assertEqual(). 1445# 1446# Note that self.precision is a scalar value, so if you require multiple 1447# precisions (or are working with multiple dtypes) they should be specified 1448# explicitly and computed using self.precision (e.g. 1449# self.precision *2, max(1, self.precision)). 1450class precisionOverride: 1451 def __init__(self, d): 1452 assert isinstance( 1453 d, dict 1454 ), "precisionOverride not given a dtype : precision dict!" 1455 for dtype in d.keys(): 1456 assert isinstance( 1457 dtype, torch.dtype 1458 ), f"precisionOverride given unknown dtype {dtype}" 1459 1460 self.d = d 1461 1462 def __call__(self, fn): 1463 fn.precision_overrides = self.d 1464 return fn 1465 1466 1467# Specifies per-dtype tolerance overrides tol(atol, rtol). It has priority over 1468# precisionOverride. 1469# Ex. 1470# 1471# @toleranceOverride({torch.float : tol(atol=1e-2, rtol=1e-3}, 1472# torch.double : tol{atol=1e-4, rtol = 0}) 1473# @dtypes(torch.half, torch.float, torch.double) 1474# def test_X(self, device, dtype): 1475# ... 1476# 1477# When the test is instantiated its class's tolerance will be set to the 1478# corresponding override, if it exists. 1479# self.rtol and self.precision can be accessed directly, and they also control 1480# the behavior of functions like self.assertEqual(). 1481# 1482# The above example sets atol = 1e-2 and rtol = 1e-3 for torch.float and 1483# atol = 1e-4 and rtol = 0 for torch.double. 
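#
# A combined usage sketch (the test name is made up); the toleranceOverride
# entry takes priority for torch.float, while torch.half falls back to the
# precisionOverride value:
#
#   @precisionOverride({torch.half: 1e-2})
#   @toleranceOverride({torch.float: tol(atol=1e-5, rtol=1e-3)})
#   @dtypes(torch.half, torch.float)
#   def test_Y(self, device, dtype):
#     ...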
1484tol = namedtuple("tol", ["atol", "rtol"]) 1485 1486 1487class toleranceOverride: 1488 def __init__(self, d): 1489 assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!" 1490 for dtype, prec in d.items(): 1491 assert isinstance( 1492 dtype, torch.dtype 1493 ), f"toleranceOverride given unknown dtype {dtype}" 1494 assert isinstance( 1495 prec, tol 1496 ), "toleranceOverride not given a dtype : tol dict!" 1497 1498 self.d = d 1499 1500 def __call__(self, fn): 1501 fn.tolerance_overrides = self.d 1502 return fn 1503 1504 1505# Decorator that instantiates a variant of the test for each given dtype. 1506# Notes: 1507# (1) Tests that accept the dtype argument MUST use this decorator. 1508# (2) Can be overridden for CPU or CUDA, respectively, using dtypesIfCPU 1509# or dtypesIfCUDA. 1510# (3) Can accept an iterable of dtypes or an iterable of tuples 1511# of dtypes. 1512# Examples: 1513# @dtypes(torch.float32, torch.float64) 1514# @dtypes((torch.long, torch.float32), (torch.int, torch.float64)) 1515class dtypes: 1516 def __init__(self, *args, device_type="all"): 1517 if len(args) > 0 and isinstance(args[0], (list, tuple)): 1518 for arg in args: 1519 assert isinstance(arg, (list, tuple)), ( 1520 "When one dtype variant is a tuple or list, " 1521 "all dtype variants must be. " 1522 f"Received non-list non-tuple dtype {str(arg)}" 1523 ) 1524 assert all( 1525 isinstance(dtype, torch.dtype) for dtype in arg 1526 ), f"Unknown dtype in {str(arg)}" 1527 else: 1528 assert all( 1529 isinstance(arg, torch.dtype) for arg in args 1530 ), f"Unknown dtype in {str(args)}" 1531 1532 self.args = args 1533 self.device_type = device_type 1534 1535 def __call__(self, fn): 1536 d = getattr(fn, "dtypes", {}) 1537 assert self.device_type not in d, f"dtypes redefinition for {self.device_type}" 1538 d[self.device_type] = self.args 1539 fn.dtypes = d 1540 return fn 1541 1542 1543# Overrides specified dtypes on the CPU. 1544class dtypesIfCPU(dtypes): 1545 def __init__(self, *args): 1546 super().__init__(*args, device_type="cpu") 1547 1548 1549# Overrides specified dtypes on CUDA. 
1550class dtypesIfCUDA(dtypes): 1551 def __init__(self, *args): 1552 super().__init__(*args, device_type="cuda") 1553 1554 1555class dtypesIfMPS(dtypes): 1556 def __init__(self, *args): 1557 super().__init__(*args, device_type="mps") 1558 1559 1560class dtypesIfHPU(dtypes): 1561 def __init__(self, *args): 1562 super().__init__(*args, device_type="hpu") 1563 1564 1565class dtypesIfPRIVATEUSE1(dtypes): 1566 def __init__(self, *args): 1567 super().__init__(*args, device_type=torch._C._get_privateuse1_backend_name()) 1568 1569 1570def onlyCPU(fn): 1571 return onlyOn("cpu")(fn) 1572 1573 1574def onlyCUDA(fn): 1575 return onlyOn("cuda")(fn) 1576 1577 1578def onlyMPS(fn): 1579 return onlyOn("mps")(fn) 1580 1581 1582def onlyXPU(fn): 1583 return onlyOn("xpu")(fn) 1584 1585 1586def onlyHPU(fn): 1587 return onlyOn("hpu")(fn) 1588 1589 1590def onlyPRIVATEUSE1(fn): 1591 device_type = torch._C._get_privateuse1_backend_name() 1592 device_mod = getattr(torch, device_type, None) 1593 if device_mod is None: 1594 reason = f"Skip as torch has no module of {device_type}" 1595 return unittest.skip(reason)(fn) 1596 return onlyOn(device_type)(fn) 1597 1598 1599def onlyCUDAAndPRIVATEUSE1(fn): 1600 @wraps(fn) 1601 def only_fn(self, *args, **kwargs): 1602 if self.device_type not in ("cuda", torch._C._get_privateuse1_backend_name()): 1603 reason = f"onlyCUDAAndPRIVATEUSE1: doesn't run on {self.device_type}" 1604 raise unittest.SkipTest(reason) 1605 1606 return fn(self, *args, **kwargs) 1607 1608 return only_fn 1609 1610 1611def disablecuDNN(fn): 1612 @wraps(fn) 1613 def disable_cudnn(self, *args, **kwargs): 1614 if self.device_type == "cuda" and self.has_cudnn(): 1615 with torch.backends.cudnn.flags(enabled=False): 1616 return fn(self, *args, **kwargs) 1617 return fn(self, *args, **kwargs) 1618 1619 return disable_cudnn 1620 1621 1622def disableMkldnn(fn): 1623 @wraps(fn) 1624 def disable_mkldnn(self, *args, **kwargs): 1625 if torch.backends.mkldnn.is_available(): 1626 with torch.backends.mkldnn.flags(enabled=False): 1627 return fn(self, *args, **kwargs) 1628 return fn(self, *args, **kwargs) 1629 1630 return disable_mkldnn 1631 1632 1633def expectedFailureCPU(fn): 1634 return expectedFailure("cpu")(fn) 1635 1636 1637def expectedFailureCUDA(fn): 1638 return expectedFailure("cuda")(fn) 1639 1640 1641def expectedFailureXPU(fn): 1642 return expectedFailure("xpu")(fn) 1643 1644 1645def expectedFailureMeta(fn): 1646 return skipIfTorchDynamo()(expectedFailure("meta")(fn)) 1647 1648 1649def expectedFailureMPS(fn): 1650 return expectedFailure("mps")(fn) 1651 1652 1653def expectedFailureXLA(fn): 1654 return expectedFailure("xla")(fn) 1655 1656 1657def expectedFailureHPU(fn): 1658 return expectedFailure("hpu")(fn) 1659 1660 1661# Skips a test on CPU if LAPACK is not available. 1662def skipCPUIfNoLapack(fn): 1663 return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn) 1664 1665 1666# Skips a test on CPU if FFT is not available. 1667def skipCPUIfNoFFT(fn): 1668 return skipCPUIf(not torch._C.has_spectral, "PyTorch is built without FFT support")( 1669 fn 1670 ) 1671 1672 1673# Skips a test on CPU if MKL is not available. 1674def skipCPUIfNoMkl(fn): 1675 return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn) 1676 1677 1678# Skips a test on CPU if MKL Sparse is not available (it's not linked on Windows). 
def skipCPUIfNoMklSparse(fn):
    return skipCPUIf(
        IS_WINDOWS or not TEST_MKL,
        "MKL Sparse is not available (not linked on Windows or built without MKL)",
    )(fn)


# Skips a test on CPU if mkldnn is not available.
def skipCPUIfNoMkldnn(fn):
    return skipCPUIf(
        not torch.backends.mkldnn.is_available(),
        "PyTorch is built without mkldnn support",
    )(fn)


# Skips a test on CUDA if MAGMA is not available.
def skipCUDAIfNoMagma(fn):
    return skipCUDAIf("no_magma", "no MAGMA library detected")(
        skipCUDANonDefaultStreamIf(True)(fn)
    )


def has_cusolver():
    return not TEST_WITH_ROCM


def has_hipsolver():
    rocm_version = _get_torch_rocm_version()
    # hipSOLVER is disabled on ROCm < 5.3
    return rocm_version >= (5, 3)


# Skips a test on CUDA/ROCm if cuSOLVER/hipSOLVER is not available.
def skipCUDAIfNoCusolver(fn):
    return skipCUDAIf(
        not has_cusolver() and not has_hipsolver(), "cuSOLVER not available"
    )(fn)


# Skips a test if both cuSOLVER and MAGMA are not available.
def skipCUDAIfNoMagmaAndNoCusolver(fn):
    if has_cusolver():
        return fn
    else:
        # cuSOLVER is disabled on CUDA < 10.1.243; these tests then depend on MAGMA
        return skipCUDAIfNoMagma(fn)


# Skips a test if both cuSOLVER/hipSOLVER and MAGMA are not available.
def skipCUDAIfNoMagmaAndNoLinalgsolver(fn):
    if has_cusolver() or has_hipsolver():
        return fn
    else:
        # cuSOLVER is disabled on CUDA < 10.1.243; these tests then depend on MAGMA
        return skipCUDAIfNoMagma(fn)


# Skips a test on CUDA when using ROCm.
def skipCUDAIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"):
    def dec_fn(fn):
        reason = f"skipCUDAIfRocm: {msg}"
        return skipCUDAIf(TEST_WITH_ROCM, reason=reason)(fn)

    if func:
        return dec_fn(func)
    return dec_fn


# Skips a test on CUDA when not using ROCm.
def skipCUDAIfNotRocm(fn):
    return skipCUDAIf(
        not TEST_WITH_ROCM, "test doesn't currently work on the CUDA stack"
    )(fn)


# Skips a test on CUDA if ROCm is unavailable or its version is lower than requested.
def skipCUDAIfRocmVersionLessThan(version=None):
    def dec_fn(fn):
        @wraps(fn)
        def wrap_fn(self, *args, **kwargs):
            if self.device_type == "cuda":
                if not TEST_WITH_ROCM:
                    reason = "ROCm not available"
                    raise unittest.SkipTest(reason)
                rocm_version_tuple = _get_torch_rocm_version()
                if (
                    rocm_version_tuple is None
                    or version is None
                    or rocm_version_tuple < tuple(version)
                ):
                    reason = (
                        f"ROCm {rocm_version_tuple} is available but {version} required"
                    )
                    raise unittest.SkipTest(reason)

            return fn(self, *args, **kwargs)

        return wrap_fn

    return dec_fn


# Skips a test on CUDA/ROCm unless MIOpen NHWC (suggested memory format) support is enabled.
def skipCUDAIfNotMiopenSuggestNHWC(fn):
    return skipCUDAIf(
        not TEST_WITH_MIOPEN_SUGGEST_NHWC,
        "test doesn't currently work without MIOpen NHWC activation",
    )(fn)


# Skips a test for specified CUDA versions, given as a list of [major, minor] versions.
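# For illustration only (the versions and test name below are hypothetical):
#
#   @skipCUDAVersionIn([(11, 2), (11, 3)])
#   def test_foo(self, device):
#       ...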
def skipCUDAVersionIn(versions: Optional[List[Tuple[int, int]]] = None):
    def dec_fn(fn):
        @wraps(fn)
        def wrap_fn(self, *args, **kwargs):
            version = _get_torch_cuda_version()
            if version == (0, 0):  # cpu or rocm
                return fn(self, *args, **kwargs)
            if version in (versions or []):
                reason = f"test skipped for CUDA version {version}"
                raise unittest.SkipTest(reason)
            return fn(self, *args, **kwargs)

        return wrap_fn

    return dec_fn


# Skips a test for CUDA versions less than specified, given in the form of [major, minor].
def skipCUDAIfVersionLessThan(versions: Optional[Tuple[int, int]] = None):
    def dec_fn(fn):
        @wraps(fn)
        def wrap_fn(self, *args, **kwargs):
            version = _get_torch_cuda_version()
            if version == (0, 0):  # cpu or rocm
                return fn(self, *args, **kwargs)
            if version < versions:
                reason = f"test skipped for CUDA version {version}; requires CUDA >= {versions}"
                raise unittest.SkipTest(reason)
            return fn(self, *args, **kwargs)

        return wrap_fn

    return dec_fn


# Skips a test on CUDA if cuDNN is unavailable or its version is lower than requested.
def skipCUDAIfCudnnVersionLessThan(version=0):
    def dec_fn(fn):
        @wraps(fn)
        def wrap_fn(self, *args, **kwargs):
            if self.device_type == "cuda":
                if self.no_cudnn:
                    reason = "cuDNN not available"
                    raise unittest.SkipTest(reason)
                if self.cudnn_version is None or self.cudnn_version < version:
                    reason = f"cuDNN version {self.cudnn_version} is available but {version} required"
                    raise unittest.SkipTest(reason)

            return fn(self, *args, **kwargs)

        return wrap_fn

    return dec_fn


# Skips a test on CUDA if the cuSparse generic API is not available.
def skipCUDAIfNoCusparseGeneric(fn):
    return skipCUDAIf(not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")(
        fn
    )


def skipCUDAIfNoHipsparseGeneric(fn):
    return skipCUDAIf(
        not TEST_HIPSPARSE_GENERIC, "hipSparse Generic API not available"
    )(fn)


def skipCUDAIfNoSparseGeneric(fn):
    return skipCUDAIf(
        not (TEST_CUSPARSE_GENERIC or TEST_HIPSPARSE_GENERIC),
        "Sparse Generic API not available",
    )(fn)


def skipCUDAIfNoCudnn(fn):
    return skipCUDAIfCudnnVersionLessThan(0)(fn)


def skipCUDAIfMiopen(fn):
    return skipCUDAIf(torch.version.hip is not None, "Marked as skipped for MIOpen")(fn)


def skipCUDAIfNoMiopen(fn):
    return skipCUDAIf(torch.version.hip is None, "MIOpen is not available")(
        skipCUDAIfNoCudnn(fn)
    )


def skipLazy(fn):
    return skipLazyIf(True, "test doesn't work with lazy tensors")(fn)


def skipMeta(fn):
    return skipMetaIf(True, "test doesn't work with meta tensors")(fn)


def skipXLA(fn):
    return skipXLAIf(True, "Marked as skipped for XLA")(fn)


def skipMPS(fn):
    return skipMPSIf(True, "test doesn't work on MPS backend")(fn)


def skipHPU(fn):
    return skipHPUIf(True, "test doesn't work on HPU backend")(fn)


def skipPRIVATEUSE1(fn):
    return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn)


# TODO: the "all" in the name hasn't been accurate for quite some time, since we
# also have, for example, XLA and MPS now. This should probably enumerate all
# available device type test base classes.
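# For example, this returns ["cpu"] on a CPU-only build and ["cpu", "cuda"] when
# CUDA is available.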
def get_all_device_types() -> List[str]:
    return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
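# Putting it together (an illustrative sketch; the test name and tolerance values
# below are hypothetical): the decorators defined above stack on a single test
# template to restrict devices, skip unsupported configurations, select dtypes,
# and loosen tolerances in one place:
#
#   @onlyCUDA
#   @skipCUDAIfRocm
#   @dtypes(torch.float32, torch.float64)
#   @toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-5)})
#   def test_baz(self, device, dtype):
#       ...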