# mypy: ignore-errors

r"""Importing this file must **not** initialize CUDA context. test_distributed
relies on this assumption to properly run. This means that when this is imported
no CUDA calls shall be made, including torch.cuda.device_count(), etc.

torch.testing._internal.common_cuda.py can freely initialize CUDA context when imported.
"""

import argparse
import contextlib
import copy
import ctypes
import errno
import functools
import gc
import hashlib
import inspect
import io
import json
import logging
import math
import operator
import os
import platform
import random
import re
import shutil
import signal
import socket
import subprocess
import sys
import tempfile
import threading
import time
import types
import unittest
import warnings
from collections.abc import Mapping, Sequence
from contextlib import closing, contextmanager
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps
from itertools import product, chain
from pathlib import Path
from statistics import mean
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from unittest.mock import MagicMock

import expecttest
import numpy as np

import __main__  # type: ignore[import]
import torch
import torch.backends.cudnn
import torch.backends.mkl
import torch.backends.mps
import torch.backends.xnnpack
import torch.cuda
from torch import Tensor
from torch._C import ScriptDict, ScriptList  # type: ignore[attr-defined]
from torch._dynamo.trace_rules import _as_posix_path
from torch._utils_internal import get_writable_path
from torch.nn import (
    ModuleDict,
    ModuleList,
    ParameterDict,
    ParameterList,
    Sequential,
)
from torch.onnx import (
    register_custom_op_symbolic,
    unregister_custom_op_symbolic,
)
from torch.testing import make_tensor
from torch.testing._comparison import (
    BooleanPair,
    NonePair,
    NumberPair,
    Pair,
    TensorLikePair,
)
from torch.testing._comparison import not_close_error_metas
from torch.testing._internal.common_dtype import get_all_dtypes
from torch.utils._import_utils import _check_module_exists
import torch.utils._pytree as pytree

try:
    import pytest
    has_pytest = True
except ImportError:
    has_pytest = False


MI300_ARCH = ("gfx940", "gfx941", "gfx942")


def freeze_rng_state(*args, **kwargs):
    return torch.testing._utils.freeze_rng_state(*args, **kwargs)
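
# Illustrative usage (a sketch): the wrapped torch.testing._utils.freeze_rng_state
# is a context manager that snapshots the RNG state on entry and restores it on
# exit:
#
#   with freeze_rng_state():
#       torch.randn(3)  # RNG draws inside the block don't perturb later tests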

# Class to keep track of test flags configurable by environment variables.
# Flags set here are intended to be read-only and should not be modified after
# definition.
# TODO: Expand this class to handle arbitrary settings in addition to boolean flags?
class TestEnvironment:
    # Set of env vars to set for the repro command that is output on test failure.
    # Specifically, this includes env vars that are set to non-default values and
    # are not implied. Maps from env var name -> value (int)
    repro_env_vars: dict = {}

    # Defines a flag usable throughout the test suite, determining its value by querying
    # the specified environment variable.
    #
    # Args:
    #     name (str): The name of the flag. A global variable with this name will be set
    #         for convenient access throughout the test suite.
    #     env_var (str): The name of the primary environment variable from which to
    #         determine the value of this flag. If this is None or the environment variable
    #         is unset, the default value will be used unless otherwise implied (see
    #         implied_by_fn). Default: None
    #     default (bool): The default value to use for the flag if unset by the environment
    #         variable and unimplied. Default: False
    #     include_in_repro (bool): Indicates whether this flag should be included in the
    #         repro command that is output on test failure (i.e. whether it is possibly
    #         relevant to reproducing the test failure). Default: True
    #     enabled_fn (Callable): Callable returning whether the flag should be enabled
    #         given the environment variable value and the default value. Default: Lambda
    #         requiring "0" to disable if on by default OR "1" to enable if off by default.
    #     implied_by_fn (Callable): Thunk returning a bool to imply this flag as enabled
    #         by something outside of its primary environment variable setting. For example,
    #         this can be useful if the value of another environment variable implies the flag
    #         as enabled. Default: Lambda returning False to indicate no implications.
    @staticmethod
    def def_flag(
        name,
        env_var=None,
        default=False,
        include_in_repro=True,
        enabled_fn=lambda env_var_val, default: (
            (env_var_val != "0") if default else (env_var_val == "1")),
        implied_by_fn=lambda: False,
    ):
        enabled = default
        if env_var is not None:
            env_var_val = os.getenv(env_var)
            enabled = enabled_fn(env_var_val, default)
        implied = implied_by_fn()
        enabled = enabled or implied
        if include_in_repro and (env_var is not None) and (enabled != default) and not implied:
            TestEnvironment.repro_env_vars[env_var] = env_var_val

        # export flag globally for convenience
        assert name not in globals(), f"duplicate definition of flag '{name}'"
        globals()[name] = enabled
        return enabled

    # Defines a setting usable throughout the test suite, determining its value by querying
    # the specified environment variable. This differs from a flag in that it's not restricted
    # to a boolean value.
    #
    # Args:
    #     name (str): The name of the setting. A global variable with this name will be set
    #         for convenient access throughout the test suite.
    #     env_var (str): The name of the primary environment variable from which to
    #         determine the value of this setting. If this is None or the environment variable
    #         is unset, the default value will be used. Default: None
    #     default (Any): The default value to use for the setting if unset by the environment
    #         variable. Default: None
    #     include_in_repro (bool): Indicates whether this setting should be included in the
    #         repro command that is output on test failure (i.e. whether it is possibly
    #         relevant to reproducing the test failure). Default: True
    #     parse_fn (Callable): Callable parsing the env var string. Default value just uses
    #         the string itself.
    @staticmethod
    def def_setting(
        name,
        env_var=None,
        default=None,
        include_in_repro=True,
        parse_fn=lambda maybe_val_str: maybe_val_str,
    ):
        value = default if env_var is None else os.getenv(env_var)
        value = parse_fn(value)
        if include_in_repro and (value != default):
            TestEnvironment.repro_env_vars[env_var] = value

        # export setting globally for convenience
        assert name not in globals(), f"duplicate definition of setting '{name}'"
        globals()[name] = value
        return value

    # Returns a string prefix usable to set environment variables for any test
    # settings that should be explicitly set to match this instantiation of the
    # test suite.
    # Example: "PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_ROCM=1"
    @staticmethod
    def repro_env_var_prefix() -> str:
        return " ".join([f"{env_var}={value}"
                         for env_var, value in TestEnvironment.repro_env_vars.items()])
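
# Illustrative usage of TestEnvironment (a sketch; the flag and setting names
# below are hypothetical, not ones defined by this file):
#
#   MY_FLAG: bool = TestEnvironment.def_flag("MY_FLAG", env_var="PYTORCH_MY_FLAG")
#   MY_LEVEL = TestEnvironment.def_setting(
#       "MY_LEVEL", env_var="PYTORCH_MY_LEVEL", default=0,
#       parse_fn=lambda v: 0 if v is None else int(v))
#
# With PYTORCH_MY_FLAG=1 set in the environment, MY_FLAG (and the global of the
# same name) becomes True, and TestEnvironment.repro_env_var_prefix() includes
# "PYTORCH_MY_FLAG=1".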

log = logging.getLogger(__name__)
torch.backends.disable_global_flags()

FILE_SCHEMA = "file://"
if sys.platform == 'win32':
    FILE_SCHEMA = "file:///"

# NB: This flag differs semantically from others in that setting the env var to any
# non-empty value will cause it to be true:
#   CI=1, CI="true", CI=0, etc. all set the flag to be true.
#   CI= and an unset CI set the flag to be false.
# GitHub sets the value to CI="true" to enable it.
IS_CI: bool = TestEnvironment.def_flag(
    "IS_CI",
    env_var="CI",
    include_in_repro=False,
    enabled_fn=lambda env_var_value, _: bool(env_var_value),
)
IS_SANDCASTLE: bool = TestEnvironment.def_flag(
    "IS_SANDCASTLE",
    env_var="SANDCASTLE",
    implied_by_fn=lambda: os.getenv("TW_JOB_USER") == "sandcastle",
    include_in_repro=False,
)

_is_fbcode_default = (
    hasattr(torch._utils_internal, "IS_FBSOURCE") and
    torch._utils_internal.IS_FBSOURCE
)

IS_FBCODE: bool = TestEnvironment.def_flag(
    "IS_FBCODE",
    env_var="PYTORCH_TEST_FBCODE",
    default=_is_fbcode_default,
    include_in_repro=False,
)
IS_REMOTE_GPU: bool = TestEnvironment.def_flag(
    "IS_REMOTE_GPU",
    env_var="PYTORCH_TEST_REMOTE_GPU",
    include_in_repro=False,
)

DISABLE_RUNNING_SCRIPT_CHK: bool = TestEnvironment.def_flag(
    "DISABLE_RUNNING_SCRIPT_CHK",
    env_var="PYTORCH_DISABLE_RUNNING_SCRIPT_CHK",
    include_in_repro=False,
)
# NB: enabled by default unless in an fbcode context.
PRINT_REPRO_ON_FAILURE: bool = TestEnvironment.def_flag(
    "PRINT_REPRO_ON_FAILURE",
    env_var="PYTORCH_PRINT_REPRO_ON_FAILURE",
    default=(not IS_FBCODE),
    include_in_repro=False,
)

# possibly restrict OpInfo tests to a single sample input
OPINFO_SAMPLE_INPUT_INDEX: Optional[int] = TestEnvironment.def_setting(
    "OPINFO_SAMPLE_INPUT_INDEX",
    env_var="PYTORCH_OPINFO_SAMPLE_INPUT_INDEX",
    default=None,
    # Don't include the env var value in the repro command because the info will
    # be queried from the tracked sample input instead
    include_in_repro=False,
    parse_fn=lambda val: None if val is None else int(val),
)

DEFAULT_DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
DEFAULT_SLOW_TESTS_FILE = 'slow_tests.json'

disabled_tests_dict = {}
slow_tests_dict = {}

def maybe_load_json(filename):
    if os.path.isfile(filename):
        with open(filename) as fp:
            return json.load(fp)
    log.warning("Attempted to load json file '%s' but it does not exist.", filename)
    return {}

# set them here in case the tests are running in a subprocess that doesn't call run_tests
if os.getenv("SLOW_TESTS_FILE", ""):
    slow_tests_dict = maybe_load_json(os.getenv("SLOW_TESTS_FILE", ""))
if os.getenv("DISABLED_TESTS_FILE", ""):
    disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", ""))

NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', torch._C._get_privateuse1_backend_name())

check_names = ['orin', 'concord', 'galen', 'xavier', 'nano', 'jetson', 'tegra']
IS_JETSON = any(name in platform.platform() for name in check_names)

def gcIfJetson(fn):
    # Irregular Jetson host/device memory setup requires cleanup to avoid tests being killed
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if IS_JETSON:
            gc.collect()
            torch.cuda.empty_cache()
        fn(*args, **kwargs)
    return wrapper

# Tries to extract the current test function by crawling the stack.
# If unsuccessful, return None.
def extract_test_fn() -> Optional[Callable]:
    try:
        stack = inspect.stack()
        for frame_info in stack:
            frame = frame_info.frame
            if "self" not in frame.f_locals:
                continue
            self_val = frame.f_locals["self"]
            if isinstance(self_val, unittest.TestCase):
                test_id = self_val.id()
                test_name = test_id.split('.')[2]
                test_fn = getattr(self_val, test_name).__func__
                return test_fn
    except Exception:
        pass
    return None

# Contains tracked input data useful for debugging purposes
@dataclass
class TrackedInput:
    index: int
    val: Any
    type_desc: str

# Attempt to pull out tracked input information from the test function.
# A TrackedInputIter is used to insert this information.
def get_tracked_input() -> Optional[TrackedInput]:
    test_fn = extract_test_fn()
    if test_fn is None:
        return None
    if not hasattr(test_fn, "tracked_input"):
        return None
    return test_fn.tracked_input

def clear_tracked_input():
    test_fn = extract_test_fn()
    if test_fn is None:
        return
    if not hasattr(test_fn, "tracked_input"):
        return
    test_fn.tracked_input = None
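
# Illustrative flow (a sketch): TrackedInputIter below stores a TrackedInput on
# the running test function as it yields values; on failure, the harness can
# call get_tracked_input() to report something like
# TrackedInput(index=3, val=<sample>, type_desc="sample input"), and
# clear_tracked_input() resets the slot once the test finishes.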

# Wraps an iterator and tracks the most recent value the iterator produces
# for debugging purposes. Tracked values are stored on the test function.
class TrackedInputIter:
    def __init__(self, child_iter, input_type_desc,
                 callback=lambda x: x, set_seed=True, restrict_to_index=None):
        self.child_iter = enumerate(child_iter)
        # Input type describes the things we're tracking (e.g. "sample input", "error input").
        self.input_type_desc = input_type_desc
        # Callback is run on each iterated thing to get the thing to track.
        self.callback = callback
        self.test_fn = extract_test_fn()
        # Indicates whether the random seed should be set before each call to the iterator
        self.set_seed = set_seed
        # Indicates that iteration should be restricted to only the provided index.
        # If None, no restriction is done
        self.restrict_to_index = restrict_to_index

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            if self.set_seed:
                # use a test-name-specific hash for the seed if possible
                seed = (
                    int.from_bytes(hashlib.sha256(
                        self.test_fn.__qualname__.encode("utf-8")).digest()[:4], 'little')
                    if self.test_fn is not None else SEED
                )
                set_rng_seed(seed)

            # allow StopIteration to bubble up
            input_idx, input_val = next(self.child_iter)
            if (self.restrict_to_index is None) or (input_idx == self.restrict_to_index):
                break

        self._set_tracked_input(
            TrackedInput(
                index=input_idx, val=self.callback(input_val), type_desc=self.input_type_desc
            )
        )
        return input_val

    def _set_tracked_input(self, tracked_input: TrackedInput):
        if self.test_fn is None:
            return
        if not hasattr(self.test_fn, "tracked_input"):
            return
        self.test_fn.tracked_input = tracked_input
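
# Illustrative usage (a sketch; `sample_inputs` is a hypothetical iterable of
# inputs for the test under consideration):
#
#   tracked = TrackedInputIter(iter(sample_inputs), "sample input",
#                              restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX)
#   for sample in tracked:
#       ...  # on failure, the most recent sample is available via get_tracked_input()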

class _TestParametrizer:
    """
    Decorator class for parametrizing a test function, yielding a set of new tests spawned
    from the original generic test, each specialized for a specific set of test inputs. For
    example, parametrizing a test across the set of ops will result in a test function per op.

    The decision of how to parametrize / what to parametrize over is intended to be implemented
    by each derived class.

    In the details, the decorator adds a 'parametrize_fn' property to the test function. This function
    is intended to be called later by one of:
      * Device-specific test instantiation via instantiate_device_type_tests(). Note that for this
        case there is no need to explicitly parametrize over device type, as that is handled separately.
      * Device-agnostic parametrized test instantiation via instantiate_parametrized_tests().

    If the decorator is applied to a test function that already has a 'parametrize_fn' property, a new
    composite 'parametrize_fn' will be created that generates tests with the product of the parameters
    generated by the old and new parametrize_fns. This allows for convenient composability of decorators.
    """
    def _parametrize_test(self, test, generic_cls, device_cls):
        """
        Parametrizes the given test function across whatever dimension is specified by the derived class.
        Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all
        ops, all modules, or all ops + their associated dtypes.

        Args:
            test (fn): Test function to parametrize over
            generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
            device_cls (class): Device-specialized test class object (e.g. TestFooCPU); set to None
                if the tests are not part of a device-specific set

        Returns:
            Generator object returning 4-tuples of:
                test (fn): Parametrized test function; must support a device arg and args for any params
                test_name (str): Parametrized suffix for the test (e.g. opname_int64); will be appended to
                    the base name of the test
                param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64})
                decorator_fn (callable): Callable[[Dict], List] for list of decorators to apply given param_kwargs
        """
        raise NotImplementedError

    def __call__(self, fn):
        if hasattr(fn, 'parametrize_fn'):
            # Do composition with the product of args.
            old_parametrize_fn = fn.parametrize_fn
            new_parametrize_fn = self._parametrize_test
            fn.parametrize_fn = compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn)
        else:
            fn.parametrize_fn = self._parametrize_test
        return fn


def compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn):
    """
    Returns a parametrize_fn that parametrizes over the product of the parameters handled
    by the given parametrize_fns. Each given parametrize_fn should have the signature
    f(test, generic_cls, device_cls).

    The test names will be a combination of the names produced by the parametrize_fns in
    "<new_name>_<old_name>" order. This order is done to match intuition for constructed names
    when composing multiple decorators; the names will be built in top to bottom order when stacking
    parametrization decorators.

    Args:
        old_parametrize_fn (callable) - First parametrize_fn to compose.
        new_parametrize_fn (callable) - Second parametrize_fn to compose.
    """

    def composite_fn(test, generic_cls, device_cls,
                     old_parametrize_fn=old_parametrize_fn,
                     new_parametrize_fn=new_parametrize_fn):
        old_tests = list(old_parametrize_fn(test, generic_cls, device_cls))
        for (old_test, old_test_name, old_param_kwargs, old_dec_fn) in old_tests:
            for (new_test, new_test_name, new_param_kwargs, new_dec_fn) in \
                    new_parametrize_fn(old_test, generic_cls, device_cls):
                redundant_params = set(old_param_kwargs.keys()).intersection(new_param_kwargs.keys())
                if redundant_params:
                    raise RuntimeError('Parametrization over the same parameter by multiple parametrization '
                                       f'decorators is not supported. For test "{test.__name__}", the following parameters '
                                       f'are handled multiple times: {redundant_params}')
                full_param_kwargs = {**old_param_kwargs, **new_param_kwargs}
                merged_test_name = '{}{}{}'.format(new_test_name,
                                                   '_' if old_test_name != '' and new_test_name != '' else '',
                                                   old_test_name)

                def merged_decorator_fn(param_kwargs, old_dec_fn=old_dec_fn, new_dec_fn=new_dec_fn):
                    return list(old_dec_fn(param_kwargs)) + list(new_dec_fn(param_kwargs))

                yield (new_test, merged_test_name, full_param_kwargs, merged_decorator_fn)

    return composite_fn
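
# Illustrative composition (a sketch; assumes a TestCase subclass using the
# @parametrize decorator and instantiate_parametrized_tests defined below):
#
#   @parametrize("y", ["a", "b"])   # applied second -> "new" parametrize_fn
#   @parametrize("x", range(2))     # applied first -> "old" parametrize_fn
#   def test_foo(self, x, y):
#       ...
#
# Names compose as "<new>_<old>", yielding test_foo_y_a_x_0, test_foo_y_a_x_1,
# test_foo_y_b_x_0 and test_foo_y_b_x_1.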

def instantiate_parametrized_tests(generic_cls):
    """
    Instantiates tests that have been decorated with a parametrize_fn. This is generally performed by a
    decorator subclass of _TestParametrizer. The generic test will be replaced on the test class by
    parametrized tests with specialized names. This should be used instead of
    instantiate_device_type_tests() if the test class contains device-agnostic tests.

    You can also use it as a class decorator. E.g.

    ```
    @instantiate_parametrized_tests
    class TestFoo(TestCase):
        ...
    ```

    Args:
        generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
    """
    for attr_name in tuple(dir(generic_cls)):
        class_attr = getattr(generic_cls, attr_name)
        if not hasattr(class_attr, 'parametrize_fn'):
            continue

        # Remove the generic test from the test class.
        delattr(generic_cls, attr_name)

        # Add parametrized tests to the test class.
        def instantiate_test_helper(cls, name, test, param_kwargs):
            @wraps(test)
            def instantiated_test(self, param_kwargs=param_kwargs):
                test(self, **param_kwargs)

            assert not hasattr(generic_cls, name), f"Redefinition of test {name}"
            setattr(generic_cls, name, instantiated_test)

        for (test, test_suffix, param_kwargs, decorator_fn) in class_attr.parametrize_fn(
                class_attr, generic_cls=generic_cls, device_cls=None):
            full_name = f'{test.__name__}_{test_suffix}'

            # Apply decorators based on full param kwargs.
            for decorator in decorator_fn(param_kwargs):
                test = decorator(test)

            instantiate_test_helper(cls=generic_cls, name=full_name, test=test, param_kwargs=param_kwargs)
    return generic_cls


class subtest:
    """
    Explicit subtest case for use with test parametrization.
    Allows for explicit naming of individual subtest cases as well as applying
    decorators to the parametrized test.

    Args:
        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
            tuples of arg values (e.g. [(1, 2), (3, 4)]).
        name (str): Optional name to use for the test.
        decorators (iterable): Iterable of decorators to apply to the generated test.
    """
    __slots__ = ['arg_values', 'name', 'decorators']

    def __init__(self, arg_values, name=None, decorators=None):
        self.arg_values = arg_values
        self.name = name
        self.decorators = decorators if decorators else []

class parametrize(_TestParametrizer):
    """
    Decorator for applying generic test parametrizations.

    The interface for this decorator is modeled after `@pytest.mark.parametrize`.
    Basic usage between this decorator and pytest's is identical. The first argument
    should be a string containing comma-separated names of parameters for the test, and
    the second argument should be an iterable returning values or tuples of values for
    the case of multiple parameters.

    Beyond this basic usage, the decorator provides some additional functionality that
    pytest does not.

    1. Parametrized tests end up as generated test functions on unittest test classes.
       Since this differs from how pytest works, this decorator takes on the additional
       responsibility of naming these test functions. The default test names consist of
       the test's base name followed by each parameter name + value (e.g. "test_bar_x_1_y_foo"),
       but custom names can be defined using `name_fn` or the `subtest` structure (see below).

    2. The decorator specially handles parameter values of type `subtest`, which allows for
       more fine-grained control over both test naming and test execution. In particular, it can
       be used to tag subtests with explicit test names or apply arbitrary decorators (see examples
       below).

    Examples::

        @parametrize("x", range(5))
        def test_foo(self, x):
            ...

        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
        def test_bar(self, x, y):
            ...

        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')],
                     name_fn=lambda x, y: '{}_{}'.format(x, y))
        def test_bar_custom_names(self, x, y):
            ...

        @parametrize("x, y", [subtest((1, 2), name='double'),
                              subtest((1, 3), name='triple', decorators=[unittest.expectedFailure]),
                              subtest((1, 4), name='quadruple')])
        def test_baz(self, x, y):
            ...

    To actually instantiate the parametrized tests, one of instantiate_parametrized_tests() or
    instantiate_device_type_tests() should be called. The former is intended for test classes
    that contain device-agnostic tests, while the latter should be used for test classes that
    contain device-specific tests. Both support arbitrary parametrizations using the decorator.

    Args:
        arg_str (str): String of arg names separated by commas (e.g. "x,y").
        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
            tuples of arg values (e.g. [(1, 2), (3, 4)]).
        name_fn (Callable): Optional function that takes in parameters and returns subtest name.
    """
    def __init__(self, arg_str, arg_values, name_fn=None):
        self.arg_names: List[str] = [s.strip() for s in arg_str.split(',') if s != '']
        self.arg_values = arg_values
        self.name_fn = name_fn

    def _formatted_str_repr(self, idx, name, value):
        """ Returns a string representation for the given arg that is suitable for use in test function names. """
        if isinstance(value, torch.dtype):
            return dtype_name(value)
        elif isinstance(value, torch.device):
            return str(value)
        # Can't use isinstance as it would cause a circular import
        elif type(value).__name__ in {'OpInfo', 'ModuleInfo'}:
            return value.formatted_name
        elif isinstance(value, (int, float, str)):
            return f"{name}_{str(value).replace('.', '_')}"
        else:
            return f"{name}{idx}"

    def _default_subtest_name(self, idx, values):
        return '_'.join([self._formatted_str_repr(idx, a, v) for a, v in zip(self.arg_names, values)])

    def _get_subtest_name(self, idx, values, explicit_name=None):
        if explicit_name:
            subtest_name = explicit_name
        elif self.name_fn:
            subtest_name = self.name_fn(*values)
        else:
            subtest_name = self._default_subtest_name(idx, values)
        return subtest_name
    def _parametrize_test(self, test, generic_cls, device_cls):
        if len(self.arg_names) == 0:
            # No additional parameters needed for the test.
            test_name = ''
            yield (test, test_name, {}, lambda _: [])
        else:
            # Each "values" item is expected to be either:
            # * A tuple of values with one for each arg. For a single arg, a single item is expected.
            # * A subtest instance with arg_values matching the previous.
            values = check_exhausted_iterator = object()
            for idx, values in enumerate(self.arg_values):
                maybe_name = None

                decorators = []
                if isinstance(values, subtest):
                    sub = values
                    values = sub.arg_values
                    maybe_name = sub.name

                    @wraps(test)
                    def test_wrapper(*args, **kwargs):
                        return test(*args, **kwargs)

                    decorators = sub.decorators
                    gen_test = test_wrapper
                else:
                    gen_test = test

                values = list(values) if len(self.arg_names) > 1 else [values]
                if len(values) != len(self.arg_names):
                    raise RuntimeError(f'Expected # values == # arg names, but got: {len(values)} '
                                       f'values and {len(self.arg_names)} names for test "{test.__name__}"')

                param_kwargs = dict(zip(self.arg_names, values))

                test_name = self._get_subtest_name(idx, values, explicit_name=maybe_name)

                def decorator_fn(_, decorators=decorators):
                    return decorators

                yield (gen_test, test_name, param_kwargs, decorator_fn)

            if values is check_exhausted_iterator:
                raise ValueError(f'{test}: An empty arg_values was passed to @parametrize. '
                                 'Note that this may result from reuse of a generator.')
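
# Putting the pieces together (a sketch; assumes a test module built on these
# utilities):
#
#   @instantiate_parametrized_tests
#   class TestFoo(TestCase):
#       @parametrize("x", range(3))
#       def test_bar(self, x):
#           self.assertLess(x, 3)
#
# This replaces TestFoo.test_bar with test_bar_x_0, test_bar_x_1 and
# test_bar_x_2.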

class decorateIf(_TestParametrizer):
    """
    Decorator for applying parameter-specific conditional decoration.
    Composes with other test parametrizers (e.g. @modules, @ops, @parametrize, etc.).

    Examples::

        @decorateIf(unittest.skip, lambda params: params["x"] == 2)
        @parametrize("x", range(5))
        def test_foo(self, x):
            ...

        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
        @decorateIf(
            unittest.expectedFailure,
            lambda params: params["x"] == 3 and params["y"] == "baz"
        )
        def test_bar(self, x, y):
            ...

        @decorateIf(
            unittest.expectedFailure,
            lambda params: params["op"].name == "add" and params["dtype"] == torch.float16
        )
        @ops(op_db)
        def test_op_foo(self, device, dtype, op):
            ...

        @decorateIf(
            unittest.skip,
            lambda params: params["module_info"].module_cls is torch.nn.Linear and \
                           params["device"] == "cpu"
        )
        @modules(module_db)
        def test_module_foo(self, device, dtype, module_info):
            ...

    Args:
        decorator: Test decorator to apply if the predicate is satisfied.
        predicate_fn (Callable): Function taking in a dict of params and returning a boolean
            indicating whether the decorator should be applied or not.
    """
    def __init__(self, decorator, predicate_fn):
        self.decorator = decorator
        self.predicate_fn = predicate_fn

    def _parametrize_test(self, test, generic_cls, device_cls):

        # Leave test as-is and return the appropriate decorator_fn.
        def decorator_fn(params, decorator=self.decorator, predicate_fn=self.predicate_fn):
            if predicate_fn(params):
                return [decorator]
            else:
                return []

        @wraps(test)
        def test_wrapper(*args, **kwargs):
            return test(*args, **kwargs)

        test_name = ''
        yield (test_wrapper, test_name, {}, decorator_fn)


class ProfilingMode(Enum):
    LEGACY = 1
    SIMPLE = 2
    PROFILING = 3

def cppProfilingFlagsToProfilingMode():
    old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
    old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
    torch._C._jit_set_profiling_executor(old_prof_exec_state)
    torch._C._get_graph_executor_optimize(old_prof_mode_state)

    if old_prof_exec_state:
        if old_prof_mode_state:
            return ProfilingMode.PROFILING
        else:
            return ProfilingMode.SIMPLE
    else:
        return ProfilingMode.LEGACY

@contextmanager
def enable_profiling_mode_for_profiling_tests():
    if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
        old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
        old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
    try:
        yield
    finally:
        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
            torch._C._jit_set_profiling_executor(old_prof_exec_state)
            torch._C._get_graph_executor_optimize(old_prof_mode_state)

@contextmanager
def enable_profiling_mode():
    old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
    old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
    try:
        yield
    finally:
        torch._C._jit_set_profiling_executor(old_prof_exec_state)
        torch._C._get_graph_executor_optimize(old_prof_mode_state)

@contextmanager
def num_profiled_runs(num_runs):
    old_num_runs = torch._C._jit_set_num_profiled_runs(num_runs)
    try:
        yield
    finally:
        torch._C._jit_set_num_profiled_runs(old_num_runs)

func_call = torch._C.ScriptFunction.__call__
meth_call = torch._C.ScriptMethod.__call__

def prof_callable(callable, *args, **kwargs):
    if 'profile_and_replay' in kwargs:
        del kwargs['profile_and_replay']
        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
            with enable_profiling_mode_for_profiling_tests():
                callable(*args, **kwargs)
                return callable(*args, **kwargs)

    return callable(*args, **kwargs)

def prof_func_call(*args, **kwargs):
    return prof_callable(func_call, *args, **kwargs)

def prof_meth_call(*args, **kwargs):
    return prof_callable(meth_call, *args, **kwargs)

torch._C.ScriptFunction.__call__ = prof_func_call  # type: ignore[method-assign]
torch._C.ScriptMethod.__call__ = prof_meth_call  # type: ignore[method-assign]
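
# Illustrative usage of the profiling helpers above (a sketch; `fn` and `x` are
# hypothetical):
#
#   with num_profiled_runs(2), enable_profiling_mode_for_profiling_tests():
#       scripted = torch.jit.script(fn)
#       scripted(x)   # profiled runs...
#       scripted(x)
#       scripted(x)   # ...after which the optimized plan can kick in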

def _get_test_report_path():
    # allow users to override the test file location. We need this
    # because the distributed tests run the same test file multiple
    # times with different configurations.
    override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
    test_source = override if override is not None else 'python-unittest'
    return os.path.join('test-reports', test_source)

is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
parser.add_argument('--subprocess', action='store_true',
                    help='whether to run each test in a subprocess')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
parser.add_argument('--jit-executor', '--jit_executor', type=str)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
parser.add_argument('--use-pytest', action='store_true')
parser.add_argument('--save-xml', nargs='?', type=str,
                    const=_get_test_report_path(),
                    default=_get_test_report_path() if IS_CI else None)
parser.add_argument('--discover-tests', action='store_true')
parser.add_argument('--log-suffix', type=str, default="")
parser.add_argument('--run-parallel', type=int, default=1)
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
parser.add_argument('--rerun-disabled-tests', action='store_true')
parser.add_argument('--pytest-single-test', type=str, nargs=1)
if sys.version_info >= (3, 9):
    parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
else:
    parser.add_argument('--showlocals', action='store_true', default=False)
    parser.add_argument('--no-showlocals', dest='showlocals', action='store_false')

# Only run when -h or --help flag is active to display both unittest and parser help messages.
def run_unittest_help(argv):
    unittest.main(argv=argv)

if '-h' in sys.argv or '--help' in sys.argv:
    help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
    help_thread.start()
    help_thread.join()

args, remaining = parser.parse_known_args()
if args.jit_executor == 'legacy':
    GRAPH_EXECUTOR = ProfilingMode.LEGACY
elif args.jit_executor == 'profiling':
    GRAPH_EXECUTOR = ProfilingMode.PROFILING
elif args.jit_executor == 'simple':
    GRAPH_EXECUTOR = ProfilingMode.SIMPLE
else:
    # infer flags based on the default settings
    GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()

RERUN_DISABLED_TESTS = args.rerun_disabled_tests

SLOW_TESTS_FILE = args.import_slow_tests
DISABLED_TESTS_FILE = args.import_disabled_tests
LOG_SUFFIX = args.log_suffix
RUN_PARALLEL = args.run_parallel
TEST_BAILOUTS = args.test_bailouts
USE_PYTEST = args.use_pytest
PYTEST_SINGLE_TEST = args.pytest_single_test
TEST_DISCOVER = args.discover_tests
TEST_IN_SUBPROCESS = args.subprocess
TEST_SAVE_XML = args.save_xml
REPEAT_COUNT = args.repeat
SEED = args.seed
SHOWLOCALS = args.showlocals
if not getattr(expecttest, "ACCEPT", False):
    expecttest.ACCEPT = args.accept
UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)

# CI Prefix path used only on CI environment
CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))

def wait_for_process(p, timeout=None):
    try:
        return p.wait(timeout=timeout)
    except KeyboardInterrupt:
        # Give `p` a chance to handle KeyboardInterrupt. Without this,
        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
        exit_status = p.wait(timeout=5)
        if exit_status is not None:
            return exit_status
        else:
            p.kill()
            raise
    except subprocess.TimeoutExpired:
        # send SIGINT to give pytest a chance to make xml
        p.send_signal(signal.SIGINT)
        exit_status = None
        try:
            exit_status = p.wait(timeout=5)
        # handle the case where p.wait(timeout=5) times out as well; otherwise
        # the wait() call in the finally block can potentially hang
        except subprocess.TimeoutExpired:
            pass
        if exit_status is not None:
            return exit_status
        else:
            p.kill()
            raise
    except:  # noqa: B001,E722, copied from python core library
        p.kill()
        raise
    finally:
        # Always call p.wait() to ensure exit
        p.wait()

def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None):
    sys.stdout.flush()
    sys.stderr.flush()
    # The following snippet is adapted from the Py3 core library's subprocess.call,
    # with two changes:
    # 1. An `except KeyboardInterrupt` block is added for SIGINT handling.
    # 2. In Py2, subprocess.Popen doesn't return a context manager, so we do
    #    `p.wait()` in a `finally` block for the code to be portable.
    #
    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
    assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens"
    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
    return wait_for_process(p, timeout=timeout)


def retry_shell(
    command,
    cwd=None,
    env=None,
    stdout=None,
    stderr=None,
    timeout=None,
    retries=1,
    was_rerun=False,
) -> Tuple[int, bool]:
    # Returns exit code + whether it was rerun
    assert (
        retries >= 0
    ), f"Expecting non-negative number for number of retries, got {retries}"
    try:
        exit_code = shell(
            command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout
        )
        if exit_code == 0 or retries == 0:
            return exit_code, was_rerun
        print(
            f"Got exit code {exit_code}, retrying (retries left={retries})",
            file=stdout,
            flush=True,
        )
    except subprocess.TimeoutExpired:
        if retries == 0:
            print(
                f"Command took >{timeout // 60}min, returning 124",
                file=stdout,
                flush=True,
            )
            return 124, was_rerun
        print(
            f"Command took >{timeout // 60}min, retrying (retries left={retries})",
            file=stdout,
            flush=True,
        )
    return retry_shell(
        command,
        cwd=cwd,
        env=env,
        stdout=stdout,
        stderr=stderr,
        timeout=timeout,
        retries=retries - 1,
        was_rerun=True,
    )
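
# Illustrative usage (a sketch; the test file path is hypothetical):
#
#   exit_code, was_rerun = retry_shell(
#       [sys.executable, "test/test_foo.py"], timeout=15 * 60, retries=1)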
{err}") 1059 succeed = False 1060 return succeed 1061 1062 1063def get_report_path(argv=UNITTEST_ARGS, pytest=False): 1064 test_filename = sanitize_test_filename(argv[0]) 1065 test_report_path = TEST_SAVE_XML + LOG_SUFFIX 1066 test_report_path = os.path.join(test_report_path, test_filename) 1067 if pytest: 1068 test_report_path = test_report_path.replace('python-unittest', 'python-pytest') 1069 os.makedirs(test_report_path, exist_ok=True) 1070 test_report_path = os.path.join(test_report_path, f"{test_filename}-{os.urandom(8).hex()}.xml") 1071 return test_report_path 1072 os.makedirs(test_report_path, exist_ok=True) 1073 return test_report_path 1074 1075 1076def sanitize_pytest_xml(xml_file: str): 1077 # pytext xml is different from unittext xml, this function makes pytest xml more similar to unittest xml 1078 # consider somehow modifying the XML logger in conftest to do this instead 1079 import xml.etree.ElementTree as ET 1080 tree = ET.parse(xml_file) 1081 for testcase in tree.iter('testcase'): 1082 full_classname = testcase.attrib.get("classname") 1083 if full_classname is None: 1084 continue 1085 # The test prefix is optional 1086 regex_result = re.search(r"^(test\.)?(?P<file>.*)\.(?P<classname>[^\.]*)$", full_classname) 1087 if regex_result is None: 1088 continue 1089 classname = regex_result.group("classname") 1090 file = regex_result.group("file").replace(".", "/") 1091 testcase.set("classname", classname) 1092 testcase.set("file", f"{file}.py") 1093 tree.write(xml_file) 1094 1095 1096def get_pytest_test_cases(argv: List[str]) -> List[str]: 1097 class TestCollectorPlugin: 1098 def __init__(self) -> None: 1099 self.tests = [] 1100 1101 def pytest_collection_finish(self, session): 1102 for item in session.items: 1103 self.tests.append(session.config.cwd_relative_nodeid(item.nodeid)) 1104 1105 test_collector_plugin = TestCollectorPlugin() 1106 import pytest 1107 pytest.main( 1108 [arg for arg in argv if arg != '-vv'] + ['--collect-only', '-qq', '--use-main-module'], 1109 plugins=[test_collector_plugin] 1110 ) 1111 return test_collector_plugin.tests 1112 1113 1114def run_tests(argv=UNITTEST_ARGS): 1115 # import test files. 
    if SLOW_TESTS_FILE:
        if os.path.exists(SLOW_TESTS_FILE):
            with open(SLOW_TESTS_FILE) as fp:
                global slow_tests_dict
                slow_tests_dict = json.load(fp)
                # use env vars so pytest-xdist subprocesses can still access them
                os.environ['SLOW_TESTS_FILE'] = SLOW_TESTS_FILE
        else:
            warnings.warn(f'slow test file provided but not found: {SLOW_TESTS_FILE}')
    if DISABLED_TESTS_FILE:
        if os.path.exists(DISABLED_TESTS_FILE):
            with open(DISABLED_TESTS_FILE) as fp:
                global disabled_tests_dict
                disabled_tests_dict = json.load(fp)
                os.environ['DISABLED_TESTS_FILE'] = DISABLED_TESTS_FILE
        else:
            warnings.warn(f'disabled test file provided but not found: {DISABLED_TESTS_FILE}')
    # Determine the test launch mechanism
    if TEST_DISCOVER:
        _print_test_names()
        return

    # Before running the tests, lint to check that every test class extends from TestCase
    suite = unittest.TestLoader().loadTestsFromModule(__main__)
    if not lint_test_case_extension(suite):
        sys.exit(1)

    if SHOWLOCALS:
        argv = [
            argv[0],
            *(["--showlocals", "--tb=long", "--color=yes"] if USE_PYTEST else ["--locals"]),
            *argv[1:],
        ]

    if TEST_IN_SUBPROCESS:
        other_args = []
        if DISABLED_TESTS_FILE:
            other_args.append("--import-disabled-tests")
        if SLOW_TESTS_FILE:
            other_args.append("--import-slow-tests")
        if USE_PYTEST:
            other_args.append("--use-pytest")
        if RERUN_DISABLED_TESTS:
            other_args.append("--rerun-disabled-tests")
        if TEST_SAVE_XML:
            other_args += ['--save-xml', args.save_xml]

        test_cases = (
            get_pytest_test_cases(argv) if USE_PYTEST else
            [case.id().split('.', 1)[1] for case in discover_test_cases_recursively(suite)]
        )

        failed_tests = []

        for test_case_full_name in test_cases:

            cmd = (
                [sys.executable] + [argv[0]] + other_args + argv[1:] +
                (["--pytest-single-test"] if USE_PYTEST else []) +
                [test_case_full_name]
            )
            string_cmd = " ".join(cmd)

            timeout = None if RERUN_DISABLED_TESTS else 15 * 60

            exitcode, _ = retry_shell(cmd, timeout=timeout, retries=0 if RERUN_DISABLED_TESTS else 1)

            if exitcode != 0:
                # This is sort of hacky, but add on relevant env variables for distributed tests.
                if 'TestDistBackendWithSpawn' in test_case_full_name:
                    backend = os.environ.get("BACKEND", "")
                    world_size = os.environ.get("WORLD_SIZE", "")
                    env_prefix = f"BACKEND={backend} WORLD_SIZE={world_size}"
                    string_cmd = env_prefix + " " + string_cmd
                # Log the command to reproduce the failure.
                print(f"Test exited with non-zero exitcode {exitcode}. Command to reproduce: {string_cmd}")
                failed_tests.append(test_case_full_name)

        assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
            len(failed_tests), '\n\t'.join(failed_tests))

    elif RUN_PARALLEL > 1:
        test_cases = discover_test_cases_recursively(suite)
        test_batches = chunk_list(get_test_names(test_cases), RUN_PARALLEL)
        processes = []
        for i in range(RUN_PARALLEL):
            command = [sys.executable] + argv + [f'--log-suffix=-shard-{i + 1}'] + test_batches[i]
            processes.append(subprocess.Popen(command, universal_newlines=True))
        failed = False
        for p in processes:
            failed |= wait_for_process(p) != 0
        assert not failed, "Some test shards have failed"
    elif USE_PYTEST:
        pytest_args = argv + ["--use-main-module"]
        if TEST_SAVE_XML:
            test_report_path = get_report_path(pytest=True)
            print(f'Test results will be stored in {test_report_path}')
            pytest_args.append(f'--junit-xml-reruns={test_report_path}')
        if PYTEST_SINGLE_TEST:
            pytest_args = PYTEST_SINGLE_TEST + pytest_args[1:]

        import pytest
        os.environ["NO_COLOR"] = "1"
        exit_code = pytest.main(args=pytest_args)
        if TEST_SAVE_XML:
            sanitize_pytest_xml(test_report_path)

        if not RERUN_DISABLED_TESTS:
            # exitcode of 5 means no tests were found, which happens since some test configs don't
            # run tests from certain files
            sys.exit(0 if exit_code == 5 else exit_code)
        else:
            # Only record the test report and always return a success code when running under rerun
            # disabled tests mode
            sys.exit(0)
    elif TEST_SAVE_XML is not None:
        # import here so that non-CI doesn't need xmlrunner installed
        import xmlrunner  # type: ignore[import]
        from xmlrunner.result import _XMLTestResult  # type: ignore[import]

        class XMLTestResultVerbose(_XMLTestResult):
            """
            Adding verbosity to test outputs:
            by default test summary prints 'skip',
            but we want to also print the skip reason.
            GH issue: https://github.com/pytorch/pytorch/issues/69014

            This works with unittest_xml_reporting<=3.2.0,>=2.0.0
            (3.2.0 is latest at the moment)
            """
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def addSkip(self, test, reason):
                super().addSkip(test, reason)
                for c in self.callback.__closure__:
                    if isinstance(c.cell_contents, str) and c.cell_contents == 'skip':
                        # this message is printed in test summary;
                        # it stands for `verbose_str` captured in the closure
                        c.cell_contents = f"skip: {reason}"

            def printErrors(self) -> None:
                super().printErrors()
                self.printErrorList("XPASS", self.unexpectedSuccesses)
        test_report_path = get_report_path()
        verbose = '--verbose' in argv or '-v' in argv
        if verbose:
            print(f'Test results will be stored in {test_report_path}')
        unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
            output=test_report_path,
            verbosity=2 if verbose else 1,
            resultclass=XMLTestResultVerbose))
    elif REPEAT_COUNT > 1:
        for _ in range(REPEAT_COUNT):
            if not unittest.main(exit=False, argv=argv).result.wasSuccessful():
                sys.exit(-1)
    else:
        unittest.main(argv=argv)
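
# Typical entry point at the bottom of a PyTorch test module:
#
#   if __name__ == '__main__':
#       run_tests()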

IS_LINUX = sys.platform == "linux"
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
IS_PPC = platform.machine() == "ppc64le"
IS_X86 = platform.machine() in ('x86_64', 'i386')
IS_ARM64 = platform.machine() in ('arm64', 'aarch64')

def is_avx512_vnni_supported():
    if sys.platform != 'linux':
        return False
    with open("/proc/cpuinfo", encoding="ascii") as f:
        lines = f.read()
    return "vnni" in lines

IS_AVX512_VNNI_SUPPORTED = is_avx512_vnni_supported()

if IS_WINDOWS:
    @contextmanager
    def TemporaryFileName(*args, **kwargs):
        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
        # close the file after creation and try to remove it manually.
        if 'delete' in kwargs:
            if kwargs['delete'] is not False:
                raise UserWarning("only TemporaryFileName with delete=False is supported on Windows.")
        else:
            kwargs['delete'] = False
        f = tempfile.NamedTemporaryFile(*args, **kwargs)
        try:
            f.close()
            yield f.name
        finally:
            os.unlink(f.name)
else:
    @contextmanager  # noqa: T484
    def TemporaryFileName(*args, **kwargs):
        with tempfile.NamedTemporaryFile(*args, **kwargs) as f:
            yield f.name

if IS_WINDOWS:
    @contextmanager
    def TemporaryDirectoryName(suffix=None):
        # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely,
        # so we first create the directory using mkdtemp and then remove it manually
        try:
            dir_name = tempfile.mkdtemp(suffix=suffix)
            yield dir_name
        finally:
            shutil.rmtree(dir_name)
else:
    @contextmanager  # noqa: T484
    def TemporaryDirectoryName(suffix=None):
        with tempfile.TemporaryDirectory(suffix=suffix) as d:
            yield d
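
# Illustrative usage (a sketch): both helpers yield a path rather than an open
# file object, so the same code also works on Windows, where an open
# NamedTemporaryFile cannot be opened a second time:
#
#   with TemporaryFileName() as fname:
#       torch.save(torch.ones(2), fname)
#       loaded = torch.load(fname)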
env_var="PYTORCH_TEST_WITH_ROCM", 1392) 1393 1394# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen 1395# See #64427 1396TEST_WITH_MIOPEN_SUGGEST_NHWC = os.getenv('PYTORCH_MIOPEN_SUGGEST_NHWC', '0') == '1' 1397# Enables tests that are slow to run (disabled by default) 1398TEST_WITH_SLOW: bool = TestEnvironment.def_flag( 1399 "TEST_WITH_SLOW", 1400 env_var="PYTORCH_TEST_WITH_SLOW", 1401) 1402 1403# Disables non-slow tests (these tests enabled by default) 1404# This is usually used in conjunction with TEST_WITH_SLOW to 1405# run *only* slow tests. (I could have done an enum, but 1406# it felt a little awkward. 1407TEST_SKIP_FAST: bool = TestEnvironment.def_flag( 1408 "TEST_SKIP_FAST", 1409 env_var="PYTORCH_TEST_SKIP_FAST", 1410) 1411 1412# Enables crossref tests, in addition to standard tests which 1413# are being run. crossref tests work by installing a torch 1414# function mode that runs extra compute alongside the regular 1415# computation that happens with the test. After both computations 1416# are done, we cross-reference them (thus the name) to check for 1417# correction, before throwing out the extra compute and proceeding 1418# as we had before. By default, we don't run these tests. 1419TEST_WITH_CROSSREF: bool = TestEnvironment.def_flag( 1420 "TEST_WITH_CROSSREF", 1421 env_var="PYTORCH_TEST_WITH_CROSSREF", 1422) 1423 1424TEST_SKIP_CUDAGRAPH: bool = TestEnvironment.def_flag( 1425 "TEST_SKIP_CUDAGRAPH", 1426 env_var="PYTORCH_TEST_SKIP_CUDAGRAPH", 1427) 1428TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and ( 1429 (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 11) or 1430 (torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3) 1431) 1432 1433TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12) 1434 1435def allocator_option_enabled_fn(allocator_config, _, option): 1436 if allocator_config is None: 1437 return False 1438 allocator_config = allocator_config.split(',') if ',' in allocator_config else [allocator_config] 1439 mapping = dict([var.split(':') for var in allocator_config]) 1440 1441 if option in mapping and mapping[option] == 'True': 1442 return True 1443 else: 1444 return False 1445 1446EXPANDABLE_SEGMENTS: bool = TestEnvironment.def_flag( 1447 "EXPANDABLE_SEGMENTS", 1448 env_var="PYTORCH_CUDA_ALLOC_CONF", 1449 enabled_fn=functools.partial(allocator_option_enabled_fn, option='expandable_segments'), 1450) 1451 1452if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ: 1453 num_procs = int(os.getenv("NUM_PARALLEL_PROCS", "2")) 1454 gb_available = torch.cuda.mem_get_info()[1] / 2 ** 30 1455 # other libraries take up about a little under 1 GB of space per process 1456 torch.cuda.set_per_process_memory_fraction(round((gb_available - num_procs * .85) / gb_available / num_procs, 2)) 1457 1458requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "Requires CUDA") 1459 1460def skipIfCrossRef(fn): 1461 @wraps(fn) 1462 def wrapper(*args, **kwargs): 1463 if TEST_WITH_CROSSREF: 1464 raise unittest.SkipTest("test doesn't currently with crossref") 1465 else: 1466 fn(*args, **kwargs) 1467 return wrapper 1468 1469class CrossRefMode(torch.overrides.TorchFunctionMode): 1470 def __torch_function__(self, func, types, args=(), kwargs=None): 1471 kwargs = kwargs or {} 1472 r = func(*args, **kwargs) 1473 return r 1474 1475# Run PyTorch tests with TorchDynamo 1476TEST_WITH_TORCHINDUCTOR: bool = TestEnvironment.def_flag( 1477 "TEST_WITH_TORCHINDUCTOR", 1478 
env_var="PYTORCH_TEST_WITH_INDUCTOR", 1479) 1480# AOT_EAGER not tested in ci, useful for debugging 1481TEST_WITH_AOT_EAGER: bool = TestEnvironment.def_flag( 1482 "TEST_WITH_AOT_EAGER", 1483 env_var="PYTORCH_TEST_WITH_AOT_EAGER", 1484) 1485TEST_WITH_TORCHDYNAMO: bool = TestEnvironment.def_flag( 1486 "TEST_WITH_TORCHDYNAMO", 1487 env_var="PYTORCH_TEST_WITH_DYNAMO", 1488 implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER, 1489) 1490 1491if TEST_WITH_TORCHDYNAMO: 1492 import torch._dynamo 1493 # Do not spend time on helper functions that are called with different inputs 1494 torch._dynamo.config.accumulated_cache_size_limit = 64 1495 # Do not log compilation metrics from unit tests 1496 torch._dynamo.config.log_compilation_metrics = False 1497 if TEST_WITH_TORCHINDUCTOR: 1498 import torch._inductor.config 1499 torch._inductor.config.fallback_random = True 1500 1501 1502def xpassIfTorchDynamo(func): 1503 return func if TEST_WITH_TORCHDYNAMO else unittest.expectedFailure(func) 1504 1505 1506def xfailIfTorchDynamo(func): 1507 return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func 1508 1509 1510def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"): 1511 """ 1512 Usage: 1513 @skipIfTorchDynamo(msg) 1514 def test_blah(self): 1515 ... 1516 """ 1517 assert isinstance(msg, str), "Are you using skipIfTorchDynamo correctly?" 1518 1519 def decorator(fn): 1520 if not isinstance(fn, type): 1521 @wraps(fn) 1522 def wrapper(*args, **kwargs): 1523 if TEST_WITH_TORCHDYNAMO: 1524 raise unittest.SkipTest(msg) 1525 else: 1526 fn(*args, **kwargs) 1527 return wrapper 1528 1529 assert isinstance(fn, type) 1530 if TEST_WITH_TORCHDYNAMO: 1531 fn.__unittest_skip__ = True 1532 fn.__unittest_skip_why__ = msg 1533 1534 return fn 1535 1536 return decorator 1537 1538def skipIfTorchInductor(msg="test doesn't currently work with torchinductor", 1539 condition=TEST_WITH_TORCHINDUCTOR): 1540 def decorator(fn): 1541 if not isinstance(fn, type): 1542 @wraps(fn) 1543 def wrapper(*args, **kwargs): 1544 if condition: 1545 raise unittest.SkipTest(msg) 1546 else: 1547 fn(*args, **kwargs) 1548 return wrapper 1549 1550 assert isinstance(fn, type) 1551 if condition: 1552 fn.__unittest_skip__ = True 1553 fn.__unittest_skip_why__ = msg 1554 1555 return fn 1556 1557 return decorator 1558 1559def serialTest(condition=True): 1560 """ 1561 Decorator for running tests serially. Requires pytest 1562 """ 1563 def decorator(fn): 1564 if has_pytest and condition: 1565 return pytest.mark.serial(fn) 1566 return fn 1567 return decorator 1568 1569def unMarkDynamoStrictTest(cls=None): 1570 def decorator(cls): 1571 cls.dynamo_strict = False 1572 return cls 1573 1574 if cls is None: 1575 return decorator 1576 else: 1577 return decorator(cls) 1578 1579 1580def markDynamoStrictTest(cls_or_func=None, nopython=False): 1581 """ 1582 Marks the test as 'strict'. In strict mode, we reset before and after the 1583 test, and run without suppress errors. 1584 1585 Args: 1586 - nopython: if we should run torch._dynamo.optimize with nopython={True/False}. 
1587 """ 1588 def decorator(cls_or_func): 1589 if inspect.isclass(cls_or_func): 1590 cls_or_func.dynamo_strict = True 1591 cls_or_func.dynamo_strict_nopython = nopython 1592 return cls_or_func 1593 1594 fn = cls_or_func 1595 1596 @wraps(fn) 1597 def wrapper(*args, **kwargs): 1598 torch._dynamo.reset() 1599 with unittest.mock.patch("torch._dynamo.config.suppress_errors", False): 1600 fn(*args, **kwargs) 1601 torch._dynamo.reset() 1602 return wrapper 1603 1604 if cls_or_func is None: 1605 return decorator 1606 else: 1607 return decorator(cls_or_func) 1608 1609 1610def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"): 1611 return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR) 1612 1613def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"): 1614 def decorator(fn): 1615 if not isinstance(fn, type): 1616 @wraps(fn) 1617 def wrapper(*args, **kwargs): 1618 if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 1619 raise unittest.SkipTest(msg) 1620 else: 1621 fn(*args, **kwargs) 1622 return wrapper 1623 1624 assert isinstance(fn, type) 1625 if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 1626 fn.__unittest_skip__ = True 1627 fn.__unittest_skip_why__ = msg 1628 1629 return fn 1630 1631 1632 return decorator 1633 1634 1635# Run PyTorch tests with translation validation on. 1636TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1' 1637 1638if TEST_WITH_TV: 1639 torch.fx.experimental._config.translation_validation = True 1640 1641# Some tests take too long when dynamic_shapes is combined with 1642# translation_validation. Whenever that happens, we solve that by 1643# disabling translation_validation. 1644def disable_translation_validation_if_dynamic_shapes(fn): 1645 @functools.wraps(fn) 1646 def wrapper(*args, **kwargs): 1647 if torch._dynamo.config.dynamic_shapes: 1648 # Turning TV off due to high latency on dynamic shapes. 1649 torch.fx.experimental._config.translation_validation = False 1650 return fn(*args, **kwargs) 1651 return wrapper 1652 1653 1654# Determine whether to enable cuda memory leak check. 1655# CUDA mem leak check is expensive and thus we don't want to execute it on every 1656# test case / configuration. 1657# If this is True then CUDA memory leak checks are skipped. If this is false 1658# then CUDA memory leak checks are performed. 1659# See: https://github.com/pytorch/pytorch/pull/59402#issuecomment-858811135 1660TEST_CUDA_MEM_LEAK_CHECK: bool = TestEnvironment.def_flag( 1661 "TEST_CUDA_MEM_LEAK_CHECK", 1662 env_var="PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", 1663) 1664 1665 1666# Dict of NumPy dtype -> torch dtype (when the correspondence exists) 1667numpy_to_torch_dtype_dict = { 1668 np.bool_ : torch.bool, 1669 np.uint8 : torch.uint8, 1670 np.uint16 : torch.uint16, 1671 np.uint32 : torch.uint32, 1672 np.uint64 : torch.uint64, 1673 np.int8 : torch.int8, 1674 np.int16 : torch.int16, 1675 np.int32 : torch.int32, 1676 np.int64 : torch.int64, 1677 np.float16 : torch.float16, 1678 np.float32 : torch.float32, 1679 np.float64 : torch.float64, 1680 np.complex64 : torch.complex64, 1681 np.complex128 : torch.complex128 1682} 1683 1684 1685# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like 1686# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type. 1687# Especially when checking against a reference we can't be sure which variant we get, so we simply try both. 
1688def numpy_to_torch_dtype(np_dtype): 1689 try: 1690 return numpy_to_torch_dtype_dict[np_dtype] 1691 except KeyError: 1692 return numpy_to_torch_dtype_dict[np_dtype.type] 1693 1694 1695def has_corresponding_torch_dtype(np_dtype): 1696 try: 1697 numpy_to_torch_dtype(np_dtype) 1698 return True 1699 except KeyError: 1700 return False 1701 1702 1703if IS_WINDOWS: 1704 # Size of `np.intc` is platform defined. 1705 # It is returned by functions like `bitwise_not`. 1706 # On Windows `int` is 32-bit 1707 # https://docs.microsoft.com/en-us/cpp/cpp/data-type-ranges?view=msvc-160 1708 numpy_to_torch_dtype_dict[np.intc] = torch.int 1709 1710# Dict of torch dtype -> NumPy dtype 1711torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()} 1712torch_to_numpy_dtype_dict.update({ 1713 torch.bfloat16: np.float32, 1714 torch.complex32: np.complex64 1715}) 1716 1717def skipIfNNModuleInlined( 1718 msg="test doesn't currently work with nn module inlining", 1719 condition=torch._dynamo.config.inline_inbuilt_nn_modules, 1720): 1721 def decorator(fn): 1722 if not isinstance(fn, type): 1723 1724 @wraps(fn) 1725 def wrapper(*args, **kwargs): 1726 if condition: 1727 raise unittest.SkipTest(msg) 1728 else: 1729 fn(*args, **kwargs) 1730 1731 return wrapper 1732 1733 assert isinstance(fn, type) 1734 if condition: 1735 fn.__unittest_skip__ = True 1736 fn.__unittest_skip_why__ = msg 1737 1738 return fn 1739 1740 return decorator 1741 1742def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"): 1743 def dec_fn(fn): 1744 reason = f"skipIfRocm: {msg}" 1745 1746 @wraps(fn) 1747 def wrapper(*args, **kwargs): 1748 if TEST_WITH_ROCM: 1749 raise unittest.SkipTest(reason) 1750 else: 1751 return fn(*args, **kwargs) 1752 return wrapper 1753 if func: 1754 return dec_fn(func) 1755 return dec_fn 1756 1757def runOnRocm(fn): 1758 @wraps(fn) 1759 def wrapper(*args, **kwargs): 1760 if TEST_WITH_ROCM: 1761 fn(*args, **kwargs) 1762 else: 1763 raise unittest.SkipTest("test currently only works on the ROCm stack") 1764 return wrapper 1765 1766def runOnRocmArch(arch: Tuple[str, ...]): 1767 def dec_fn(fn): 1768 @wraps(fn) 1769 def wrap_fn(self, *args, **kwargs): 1770 if TEST_WITH_ROCM: 1771 prop = torch.cuda.get_device_properties(0) 1772 if prop.gcnArchName.split(":")[0] not in arch: 1773 reason = f"skipIfRocm: test only runs on {arch}" 1774 raise unittest.SkipTest(reason) 1775 return fn(self, *args, **kwargs) 1776 return wrap_fn 1777 return dec_fn 1778 1779def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"): 1780 def dec_fn(fn): 1781 reason = f"skipIfXpu: {msg}" 1782 1783 @wraps(fn) 1784 def wrapper(*args, **kwargs): 1785 if TEST_XPU: 1786 raise unittest.SkipTest(reason) 1787 else: 1788 return fn(*args, **kwargs) 1789 return wrapper 1790 if func: 1791 return dec_fn(func) 1792 return dec_fn 1793 1794def skipIfMps(fn): 1795 @wraps(fn) 1796 def wrapper(*args, **kwargs): 1797 if TEST_MPS: 1798 raise unittest.SkipTest("test doesn't currently work with MPS") 1799 else: 1800 fn(*args, **kwargs) 1801 return wrapper 1802 1803def skipIfHpu(fn): 1804 @wraps(fn) 1805 def wrapper(*args, **kwargs): 1806 if TEST_HPU: 1807 raise unittest.SkipTest("test doesn't currently work with HPU") 1808 else: 1809 fn(*args, **kwargs) 1810 return wrapper 1811 1812# Skips a test on CUDA if ROCm is available and its version is lower than requested. 
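# Usage sketch for the ROCm arch gate above (the test name is hypothetical):
#
#   @runOnRocmArch(MI300_ARCH)
#   def test_feature(self):
#       ...
#
# On non-ROCm builds the test runs unconditionally; on ROCm it is skipped
# unless the device's gcnArchName matches one of the given arches.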
# Skips a test when running on ROCm if the ROCm version is lower than requested.
def skipIfRocmVersionLessThan(version=None):
    def dec_fn(fn):
        @wraps(fn)
        def wrap_fn(self, *args, **kwargs):
            if TEST_WITH_ROCM:
                rocm_version = str(torch.version.hip)
                rocm_version = rocm_version.split("-")[0]    # ignore git sha
                rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
                if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
                    reason = f"ROCm {rocm_version_tuple} is available but {version} is required"
                    raise unittest.SkipTest(reason)
            return fn(self, *args, **kwargs)
        return wrap_fn
    return dec_fn

def skipIfNotMiopenSuggestNHWC(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not TEST_WITH_MIOPEN_SUGGEST_NHWC:
            raise unittest.SkipTest("test doesn't currently work without MIOpen NHWC activation")
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfWindows(func=None, *, msg="test doesn't currently work on the Windows stack"):
    def dec_fn(fn):
        reason = f"skipIfWindows: {msg}"

        @wraps(fn)
        def wrapper(*args, **kwargs):
            if IS_WINDOWS:  # noqa: F821
                raise unittest.SkipTest(reason)
            else:
                return fn(*args, **kwargs)
        return wrapper
    if func:
        return dec_fn(func)
    return dec_fn

# Reverts the linalg backend back to default to make sure potential failures in one
# test do not affect other tests
def setLinalgBackendsToDefaultFinally(fn):
    @wraps(fn)
    def _fn(*args, **kwargs):
        _preferred_backend = torch.backends.cuda.preferred_linalg_library()
        try:
            fn(*args, **kwargs)
        finally:
            torch.backends.cuda.preferred_linalg_library(_preferred_backend)
    return _fn


# Reverts the blas backend back to default to make sure potential failures in one
# test do not affect other tests
def setBlasBackendsToDefaultFinally(fn):
    @wraps(fn)
    def _fn(*args, **kwargs):
        _preferred_backend = torch.backends.cuda.preferred_blas_library()
        try:
            fn(*args, **kwargs)
        finally:
            torch.backends.cuda.preferred_blas_library(_preferred_backend)
    return _fn


# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
    def __init__(self, deterministic, *, warn_only=False, fill_uninitialized_memory=True):
        self.deterministic = deterministic
        self.warn_only = warn_only
        self.fill_uninitialized_memory = fill_uninitialized_memory

    def __enter__(self):
        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
        self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory
        torch.use_deterministic_algorithms(
            self.deterministic,
            warn_only=self.warn_only)
        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory

    def __exit__(self, exception_type, exception_value, traceback):
        torch.use_deterministic_algorithms(
            self.deterministic_restore,
            warn_only=self.warn_only_restore)
        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore
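# Usage sketch for DeterministicGuard (the body is hypothetical):
#
#   with DeterministicGuard(True, warn_only=False):
#       ...  # nondeterministic ops in here raise instead of running silently
#
# On exit, the previous use_deterministic_algorithms() state, the warn_only
# flag, and the fill_uninitialized_memory setting are all restored.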
class AlwaysWarnTypedStorageRemoval:
    def __init__(self, always_warn):
        assert isinstance(always_warn, bool)
        self.always_warn = always_warn

    def __enter__(self):
        self.always_warn_restore = torch.storage._get_always_warn_typed_storage_removal()
        torch.storage._set_always_warn_typed_storage_removal(self.always_warn)

    def __exit__(self, exception_type, exception_value, traceback):
        torch.storage._set_always_warn_typed_storage_removal(self.always_warn_restore)

# Context manager for setting CUDA sync debug mode and resetting it
# to its original value.
# We are not exposing it to the core because sync debug mode is
# global and thus not thread safe.
class CudaSyncGuard:
    def __init__(self, sync_debug_mode):
        self.mode = sync_debug_mode

    def __enter__(self):
        self.debug_mode_restore = torch.cuda.get_sync_debug_mode()
        torch.cuda.set_sync_debug_mode(self.mode)

    def __exit__(self, exception_type, exception_value, traceback):
        torch.cuda.set_sync_debug_mode(self.debug_mode_restore)

# Context manager for setting torch.__future__.set_swap_module_params_on_conversion
# and automatically resetting it to its original value
class SwapTensorsGuard:
    def __init__(self, use_swap_tensors):
        self.use_swap_tensors = use_swap_tensors

    def __enter__(self):
        self.swap_tensors_restore = torch.__future__.get_swap_module_params_on_conversion()
        if self.use_swap_tensors is not None:
            torch.__future__.set_swap_module_params_on_conversion(self.use_swap_tensors)

    def __exit__(self, exception_type, exception_value, traceback):
        torch.__future__.set_swap_module_params_on_conversion(self.swap_tensors_restore)
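# Usage sketch for CudaSyncGuard; the mode follows
# torch.cuda.set_sync_debug_mode, i.e. 0/"default", 1/"warn", or 2/"error":
#
#   with CudaSyncGuard("warn"):
#       ...  # synchronizing CUDA calls in here emit warnings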
# This decorator can be used for API tests that call
# torch.use_deterministic_algorithms().  When the test is finished, it will
# restore the previous deterministic flag setting.
#
# If CUDA >= 10.2, this will set the environment variable
# CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that
# setting is not thrown during the test unless the test changes that variable
# on purpose. The previous CUBLAS_WORKSPACE_CONFIG setting will also be
# restored once the test is finished.
#
# Note that if a test requires CUDA to actually register the changed
# CUBLAS_WORKSPACE_CONFIG variable, a new subprocess must be created, because
# CUDA only checks the variable when the runtime initializes. Tests can be
# run inside a subprocess like so:
#
# import subprocess, sys, os
# script = '''
# # Test code should go here
# '''
# try:
#     subprocess.check_output(
#         [sys.executable, '-c', script],
#         stderr=subprocess.STDOUT,
#         cwd=os.path.dirname(os.path.realpath(__file__)),
#         env=os.environ.copy())
# except subprocess.CalledProcessError as e:
#     error_message = e.output.decode('utf-8')
#     # Handle exceptions raised by the subprocess here
#
def wrapDeterministicFlagAPITest(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        with DeterministicGuard(
                torch.are_deterministic_algorithms_enabled(),
                warn_only=torch.is_deterministic_algorithms_warn_only_enabled()):
            class CuBLASConfigGuard:
                cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'

                def __enter__(self):
                    self.is_cuda10_2_or_higher = (
                        (torch.version.cuda is not None)
                        and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2]))
                    if self.is_cuda10_2_or_higher:
                        self.cublas_config_restore = os.environ.get(self.cublas_var_name)
                        os.environ[self.cublas_var_name] = ':4096:8'

                def __exit__(self, exception_type, exception_value, traceback):
                    if self.is_cuda10_2_or_higher:
                        cur_cublas_config = os.environ.get(self.cublas_var_name)
                        if self.cublas_config_restore is None:
                            if cur_cublas_config is not None:
                                del os.environ[self.cublas_var_name]
                        else:
                            os.environ[self.cublas_var_name] = self.cublas_config_restore
            with CuBLASConfigGuard():
                fn(*args, **kwargs)
    return wrapper

# This decorator can be used for API tests that want to safely call
# torch.__future__.set_swap_module_params_on_conversion.  `swap` can be set to
# True, False or None, where None indicates that the context manager does not
# set the flag. When the test is finished, it will restore the previous swap
# flag setting.
def wrapSwapTensorsTest(swap=None):
    def dec_fn(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            with SwapTensorsGuard(swap):
                fn(*args, **kwargs)
        return wrapper
    return dec_fn

# test parametrizer for swapping
class swap(_TestParametrizer):
    def __init__(self, swap_values):
        super().__init__()
        self.swap_values = swap_values

    def _parametrize_test(self, test, generic_cls, device_cls):
        for swap in self.swap_values:
            yield wrapSwapTensorsTest(swap)(test), f'swap_{swap}', {}, lambda _: []

def skipIfCompiledWithoutNumpy(fn):
    # Even if the numpy module is present, if `USE_NUMPY=0` is used during the
    # build, numpy tests will fail
    numpy_support = TEST_NUMPY
    if numpy_support:
        try:
            # The numpy module is present, verify that PyTorch is compiled with
            # numpy support
            torch.from_numpy(np.array([2, 2]))
        except RuntimeError:
            numpy_support = False

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not numpy_support:
            raise unittest.SkipTest("PyTorch was compiled without numpy support")
        else:
            fn(*args, **kwargs)
    return wrapper

def _test_function(fn, device):
    def run_test_function(self):
        return fn(self, device)
    return run_test_function

def skipIfNoXNNPACK(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch.backends.xnnpack.enabled:
            raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.')
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoLapack(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch._C.has_lapack:
            raise unittest.SkipTest('PyTorch compiled without Lapack')
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNotRegistered(op_name, message):
    """Wraps the decorator to hide the import of the `core`.

    Args:
        op_name: Check if this op is registered in `core._REGISTERED_OPERATORS`.
        message: message to fail with.

    Usage:
        @skipIfNotRegistered('MyOp', 'MyOp is not linked!')
        This will check if 'MyOp' is in caffe2.python.core.

    Note that since PyTorch is compiled without Caffe2, this decorator
    now unconditionally skips the test.
    """
    return unittest.skip("PyTorch is compiled without Caffe2")

def skipIfNoSciPy(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not TEST_SCIPY:
            raise unittest.SkipTest("test requires SciPy, but SciPy was not found")
        else:
            fn(*args, **kwargs)
    return wrapper

def skip_if_pytest(fn):
    @wraps(fn)
    def wrapped(*args, **kwargs):
        if "PYTEST_CURRENT_TEST" in os.environ:
            raise unittest.SkipTest("does not work under pytest")
        return fn(*args, **kwargs)

    return wrapped


def slowTest(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not TEST_WITH_SLOW:
            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
        else:
            fn(*args, **kwargs)
    wrapper.__dict__['slow_test'] = True
    return wrapper


def slowTestIf(condition):
    return slowTest if condition else lambda fn: fn


def skipCUDAMemoryLeakCheckIf(condition):
    def dec(fn):
        if getattr(fn, '_do_cuda_memory_leak_check', True):  # if currently True
            fn._do_cuda_memory_leak_check = not condition
        return fn
    return dec

def skipCUDANonDefaultStreamIf(condition):
    def dec(fn):
        if getattr(fn, '_do_cuda_non_default_stream', True):  # if currently True
            fn._do_cuda_non_default_stream = not condition
        return fn
    return dec

def suppress_warnings(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            fn(*args, **kwargs)
    return wrapper


def to_gpu(obj, type_map=None):
    if type_map is None:
        type_map = {}
    if isinstance(obj, torch.Tensor):
        assert obj.is_leaf
        t = type_map.get(obj.dtype, obj.dtype)
        with torch.no_grad():
            res = obj.clone().to(dtype=t, device="cuda")
            res.requires_grad = obj.requires_grad
        return res
    elif torch.is_storage(obj):
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)


def get_function_arglist(func):
    return inspect.getfullargspec(func).args


def set_rng_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    if TEST_NUMPY:
        np.random.seed(seed)


@contextlib.contextmanager
def set_default_dtype(dtype):
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(saved_dtype)
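# Usage sketch for set_default_dtype:
#
#   with set_default_dtype(torch.double):
#       t = torch.empty(3)   # t.dtype is torch.float64 inside the block
#   # the previous default dtype is restored here, even if an exception was raised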
@contextlib.contextmanager
def set_default_tensor_type(tensor_type):
    saved_tensor_type = torch.tensor([]).type()
    torch.set_default_tensor_type(tensor_type)
    try:
        yield
    finally:
        torch.set_default_tensor_type(saved_tensor_type)

def iter_indices(tensor):
    if tensor.dim() == 0:
        return range(0)
    if tensor.dim() == 1:
        return range(tensor.size(0))
    return product(*(range(s) for s in tensor.size()))


def is_iterable(obj):
    try:
        iter(obj)
        return True
    except TypeError:
        return False


def is_iterable_of_tensors(iterable, include_empty=False):
    """ Returns True if iterable is an iterable of tensors and False otherwise.

    If the iterable is empty, the return value is :attr:`include_empty`
    """
    # Tensor itself is iterable so we check this first
    if isinstance(iterable, torch.Tensor):
        return False

    try:
        if len(iterable) == 0:
            return include_empty

        for t in iter(iterable):
            if not isinstance(t, torch.Tensor):
                return False

    except TypeError:
        return False

    return True


class CudaNonDefaultStream:
    def __enter__(self):
        # Before starting CUDA test save currently active streams on all
        # CUDA devices and set new non default streams to all CUDA devices
        # to ensure CUDA tests do not use default stream by mistake.
        beforeDevice = torch.cuda.current_device()
        self.beforeStreams = []
        for d in range(torch.cuda.device_count()):
            self.beforeStreams.append(torch.cuda.current_stream(d))
            deviceStream = torch.cuda.Stream(device=d)
            self.beforeStreams[-1].synchronize()
            torch._C._cuda_setStream(stream_id=deviceStream.stream_id,
                                     device_index=deviceStream.device_index,
                                     device_type=deviceStream.device_type)
        torch._C._cuda_setDevice(beforeDevice)

    def __exit__(self, exec_type, exec_value, traceback):
        # After completing CUDA test load previously active streams on all
        # CUDA devices.
        beforeDevice = torch.cuda.current_device()
        for d in range(torch.cuda.device_count()):
            torch._C._cuda_setStream(stream_id=self.beforeStreams[d].stream_id,
                                     device_index=self.beforeStreams[d].device_index,
                                     device_type=self.beforeStreams[d].device_type)
        torch._C._cuda_setDevice(beforeDevice)
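# Usage sketch for CudaNonDefaultStream:
#
#   with CudaNonDefaultStream():
#       ...  # kernels launched here run on fresh side streams on every device
#
# This is how enforceNonDefaultStream() in TestCase below moves tests off the
# default stream.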
class CudaMemoryLeakCheck:
    def __init__(self, testcase, name=None):
        self.name = testcase.id() if name is None else name
        self.testcase = testcase

        # initialize context & RNG to prevent false positive detections
        # when the test is the first to initialize those
        from torch.testing._internal.common_cuda import initialize_cuda_context_rng
        initialize_cuda_context_rng()

    # Stores CUDA memory data provided by PyTorch's caching allocator and
    # the CUDA driver.
    #
    # NOTE: The undocumented torch.cuda.mem_get_info() returns
    #   (#free bytes, #total bytes available) on the GPU
    def __enter__(self):
        self.caching_allocator_befores = []
        self.driver_befores = []

        # Performs a gc if required (required if any CUDA memory is held)
        num_devices = torch.cuda.device_count()
        for i in range(num_devices):
            caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
            # NOTE: gc is based exclusively on caching allocator memory
            #   because the driver will always have some bytes in use (context size?)
            if caching_allocator_mem_allocated > 0:
                gc.collect()
                torch._C._cuda_clearCublasWorkspaces()
                torch.cuda.empty_cache()
                break

        # Acquires caching allocator and driver statistics before the test is run
        for i in range(num_devices):
            self.caching_allocator_befores.append(torch.cuda.memory_allocated(i))
            bytes_free, bytes_total = torch.cuda.mem_get_info(i)
            driver_mem_allocated = bytes_total - bytes_free
            self.driver_befores.append(driver_mem_allocated)

    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown
        if exec_type is not None:
            return

        # Compares caching allocator before/after statistics
        # An increase in allocated memory is a discrepancy indicating a possible
        # memory leak
        discrepancy_detected = False
        num_devices = torch.cuda.device_count()
        for i in range(num_devices):
            # avoid counting cublasWorkspace allocations
            torch._C._cuda_clearCublasWorkspaces()
            caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)

            if caching_allocator_mem_allocated > self.caching_allocator_befores[i]:
                discrepancy_detected = True
                break

        # Short-circuits if no discrepancy detected
        if not discrepancy_detected:
            return

        # Validates the discrepancy persists after garbage collection and
        # is confirmed by the driver API

        # NOTE: driver API discrepancies alone are ignored because with the jiterator
        #   some tests may permanently increase the CUDA context size and
        #   that will appear as a driver memory leak but is the expected behavior.

        # GCs and clears the cache
        gc.collect()
        torch.cuda.empty_cache()

        for i in range(num_devices):

            discrepancy_detected = True

            # Query memory multiple times to ensure leak was not transient
            for n in range(3):
                caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
                bytes_free, bytes_total = torch.cuda.mem_get_info(i)
                driver_mem_allocated = bytes_total - bytes_free

                caching_allocator_discrepancy = False
                driver_discrepancy = False

                if caching_allocator_mem_allocated > self.caching_allocator_befores[i]:
                    caching_allocator_discrepancy = True

                if driver_mem_allocated > self.driver_befores[i]:
                    driver_discrepancy = True

                if not (caching_allocator_discrepancy or driver_discrepancy):
                    # Leak was false positive, exit loop
                    discrepancy_detected = False
                    break

            if not discrepancy_detected:
                continue

            if caching_allocator_discrepancy and not driver_discrepancy:
                # Just raises a warning if the leak is not validated by the
                # driver API
                # NOTE: this may be a problem with how the caching allocator collects its
                #   statistics or a leak too small to trigger the allocation of an
                #   additional block of memory by the CUDA driver
                msg = ("CUDA caching allocator reports a memory leak not "
                       f"verified by the driver API in {self.name}! "
                       f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
                       f"and is now reported as {caching_allocator_mem_allocated} "
                       f"on device {i}. "
                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")
                warnings.warn(msg)
            elif caching_allocator_discrepancy and driver_discrepancy:
                # A caching allocator discrepancy validated by the driver API is a
                # failure
                msg = (f"CUDA driver API confirmed a leak in {self.name}! "
                       f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
                       f"and is now reported as {caching_allocator_mem_allocated} "
                       f"on device {i}. "
                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")

                raise RuntimeError(msg)

@contextmanager
def skip_exception_type(exc_type):
    try:
        yield
    except exc_type as e:
        raise unittest.SkipTest(f"not implemented: {e}") from e

@contextmanager
def print_repro_on_failure(repro_parts):
    try:
        yield
    except unittest.SkipTest:
        raise
    except Exception as e:
        # Get the index of the sample input that failed the test if possible.
        sample_isolation_prefix = ""
        tracked_input = getattr(e, "_tracked_input", None)
        if tracked_input is not None:
            sample_isolation_prefix = f"PYTORCH_OPINFO_SAMPLE_INPUT_INDEX={tracked_input.index}"

        repro_str = " ".join(filter(None, (sample_isolation_prefix, *repro_parts)))
        repro_msg = f"""
To execute this test, run the following from the base repo dir:
    {repro_str}

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""

        # NB: Hacking the exception args is the cleanest way I've found to append
        # failure reproduction info without poisoning the stack trace.
        if len(e.args) >= 1:
            e.args = (f"{e.args[0]}\n{repro_msg}", *e.args[1:])
        raise

# The "min_satisfying_examples" setting has been deprecated in hypothesis
# 3.56.0 and removed in hypothesis 4.x
try:
    import hypothesis

    def settings(*args, **kwargs):
        if 'min_satisfying_examples' in kwargs and hypothesis.version.__version_info__ >= (3, 56, 0):
            kwargs.pop('min_satisfying_examples')
        return hypothesis.settings(*args, **kwargs)


    hypothesis.settings.register_profile(
        "pytorch_ci",
        settings(
            derandomize=True,
            suppress_health_check=[hypothesis.HealthCheck.too_slow],
            database=None,
            max_examples=50,
            verbosity=hypothesis.Verbosity.normal))
    hypothesis.settings.register_profile(
        "dev",
        settings(
            suppress_health_check=[hypothesis.HealthCheck.too_slow],
            database=None,
            max_examples=10,
            verbosity=hypothesis.Verbosity.normal))
    hypothesis.settings.register_profile(
        "debug",
        settings(
            suppress_health_check=[hypothesis.HealthCheck.too_slow],
            database=None,
            max_examples=1000,
            verbosity=hypothesis.Verbosity.verbose))

    hypothesis.settings.load_profile(
        "pytorch_ci" if IS_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', 'dev')
    )
except ImportError:
    print('Failed to import hypothesis in common_utils; tests are not derandomized')
# Used in check_if_enable to see if a test method should be disabled by an issue;
# sanitizes a test method name by removing the suffixes appended by @dtypes parametrization.
# e.g., an issue with title "DISABLED test_bitwise_ops (__main__.TestBinaryUfuncs)" should
# disable ALL parametrized test_bitwise_ops tests, such as test_bitwise_ops_cuda_int32
def remove_device_and_dtype_suffixes(test_name: str) -> str:
    # import statement is localized to avoid circular dependency issues with common_device_type.py
    from torch.testing._internal.common_device_type import get_device_type_test_bases
    device_suffixes = [x.device_type for x in get_device_type_test_bases()]
    dtype_suffixes = [str(dt)[len("torch."):] for dt in get_all_dtypes()]

    test_name_chunks = test_name.split("_")
    if len(test_name_chunks) > 0 and test_name_chunks[-1] in dtype_suffixes:
        if len(test_name_chunks) > 1 and test_name_chunks[-2] in device_suffixes:
            return "_".join(test_name_chunks[0:-2])
        return "_".join(test_name_chunks[0:-1])
    return test_name
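# For example:
#   remove_device_and_dtype_suffixes("test_bitwise_ops_cuda_int32") -> "test_bitwise_ops"
#   remove_device_and_dtype_suffixes("test_add_float64")            -> "test_add"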
def check_if_enable(test: unittest.TestCase):
    classname = str(test.__class__).split("'")[1].split(".")[-1]
    sanitized_testname = remove_device_and_dtype_suffixes(test._testMethodName)

    def matches_test(target: str):
        target_test_parts = target.split()
        if len(target_test_parts) < 2:
            # poorly formed target test name
            return False
        target_testname = target_test_parts[0]
        target_classname = target_test_parts[1][1:-1].split(".")[-1]
        # if test method name or its sanitized version exactly matches the disabled
        # test method name AND allow non-parametrized suite names to disable
        # parametrized ones (TestSuite disables TestSuiteCPU)
        return classname.startswith(target_classname) and (target_testname in (test._testMethodName, sanitized_testname))

    if any(matches_test(x) for x in slow_tests_dict.keys()):
        getattr(test, test._testMethodName).__dict__['slow_test'] = True
        if not TEST_WITH_SLOW:
            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")

    if not IS_SANDCASTLE:
        should_skip = False
        skip_msg = ""

        for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
            if matches_test(disabled_test):
                platform_to_conditional: Dict = {
                    "mac": IS_MACOS,
                    "macos": IS_MACOS,
                    "win": IS_WINDOWS,
                    "windows": IS_WINDOWS,
                    "linux": IS_LINUX,
                    "rocm": TEST_WITH_ROCM,
                    "xpu": TEST_XPU,
                    "asan": TEST_WITH_ASAN,
                    "dynamo": TEST_WITH_TORCHDYNAMO,
                    "inductor": TEST_WITH_TORCHINDUCTOR,
                    "slow": TEST_WITH_SLOW,
                }

                invalid_platforms = list(filter(lambda p: p not in platform_to_conditional, platforms))
                if len(invalid_platforms) > 0:
                    invalid_plats_str = ", ".join(invalid_platforms)
                    valid_plats = ", ".join(platform_to_conditional.keys())

                    print(f"Test {disabled_test} is disabled for some unrecognized ",
                          f"platforms: [{invalid_plats_str}]. Please edit issue {issue_url} to fix the platforms ",
                          'assigned to this flaky test, changing "Platforms: ..." to a comma separated ',
                          f"subset of the following (or leave it blank to match all platforms): {valid_plats}")

                    # Sanitize the platforms list so that we continue to disable the test for any valid platforms given
                    platforms = list(filter(lambda p: p in platform_to_conditional, platforms))

                if platforms == [] or any(platform_to_conditional[platform] for platform in platforms):
                    should_skip = True
                    skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \
                        f" for {'all ' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \
                        "If you're seeing this on your local machine and would like to enable this test, " \
                        "please make sure CI is not set and you are not using the flag --import-disabled-tests."
                    break

        if should_skip and not RERUN_DISABLED_TESTS:
            # Skip the disabled test when not running under --rerun-disabled-tests verification mode
            raise unittest.SkipTest(skip_msg)

        if not should_skip and RERUN_DISABLED_TESTS:
            skip_msg = "Test is enabled but --rerun-disabled-tests verification mode is set, so only" \
                " disabled tests are run"
            raise unittest.SkipTest(skip_msg)

    if TEST_SKIP_FAST:
        if hasattr(test, test._testMethodName) and not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
            raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")


# `TestCase.assertEqual` is very permissive and coerces the inputs into a format that can be compared. This is very
# convenient when writing tests, but not so much while reviewing them. By default, the comparison `Pair` framework of
# `torch.testing._comparison.are_equal`, used for example by the public testing function
# `torch.testing.assert_close`, is more strict. In order to use the same framework and thus reduce the divergence
# between internal and external comparison logic as much as possible, we define some "relaxed" pairs here. They only
# change the supported inputs, but the comparison logic is the same.
# TODO: Revisit the relaxed pairs and check how much work it is to fix the tests that would fail without the relaxation.

class RelaxedBooleanPair(BooleanPair):
    """Pair for boolean-like inputs.

    In contrast to the builtin :class:`BooleanPair`, this class also supports one input being a number or a single
    element tensor-like.
    """
    _supported_number_types = NumberPair(0, 0)._supported_types

    def _process_inputs(self, actual, expected, *, id):
        # We require only one of the inputs to be a boolean; the other can also be a boolean, a
        # number, or a single element tensor or array, whereas in the default BooleanPair both inputs have to be booleans.
        tensor_or_array_types: Tuple[Type, ...] = (torch.Tensor, np.ndarray)
        other_supported_types = (*self._supported_types, *self._supported_number_types, *tensor_or_array_types)
        if not (
            (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
            or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types))
        ):
            self._inputs_not_supported()

        return [self._to_bool(input, id=id) for input in (actual, expected)]

    def _to_bool(self, bool_like, *, id):
        if isinstance(bool_like, np.number):
            return bool(bool_like.item())
        elif type(bool_like) in self._supported_number_types:
            return bool(bool_like)
        elif isinstance(bool_like, (torch.Tensor, np.ndarray)):
            numel = bool_like.numel() if isinstance(bool_like, torch.Tensor) else bool_like.size
            if numel > 1:
                self._fail(
                    ValueError,
                    f"Only single element tensor-likes can be compared against a boolean. "
                    f"Got {numel} elements instead.",
                    id=id
                )

            return bool(bool_like.item())
        else:
            return super()._to_bool(bool_like, id=id)


class RelaxedNumberPair(NumberPair):
    """Pair for number-like inputs.

    In contrast to the builtin :class:`NumberPair`, this class also supports one input being a single element
    tensor-like or a :class:`enum.Enum`. (D)Type checks are disabled, meaning comparing 1 to 1.0 succeeds even when
    ``check_dtype=True`` is passed.

    In addition, this class uses looser default tolerances for :class:`float` and :class:`complex` inputs. Also
    supports overriding the absolute and relative tolerance through the ``@precisionOverride`` and
    ``@toleranceOverride`` decorators.
    """
    _TYPE_TO_DTYPE = {
        int: torch.int64,
        float: torch.float32,
        complex: torch.complex64,
    }

    def __init__(
            self, actual, expected, *, rtol_override=0.0, atol_override=0.0, check_dtype=None, **other_parameters
    ) -> None:
        super().__init__(actual, expected, check_dtype=False, **other_parameters)
        self.rtol = max(self.rtol, rtol_override)
        self.atol = max(self.atol, atol_override)

    def _process_inputs(self, actual, expected, *, id):
        # We require only one of the inputs to be a number; the other can also be a number or a single
        # element tensor or array, whereas in the default NumberPair both inputs have to be numbers.
        tensor_or_array_types: Tuple[Type, ...] = (torch.Tensor, np.ndarray)
        other_supported_types = (*self._supported_types, *tensor_or_array_types)
        if not (
            (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
            or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types))
        ):
            self._inputs_not_supported()

        return [self._to_number(input, id=id) for input in (actual, expected)]

    def _to_number(self, number_like, *, id):
        if isinstance(number_like, (torch.Tensor, np.ndarray)):
            numel = number_like.numel() if isinstance(number_like, torch.Tensor) else number_like.size
            if numel > 1:
                self._fail(
                    ValueError,
                    f"Only single element tensor-likes can be compared against a number. "
                    f"Got {numel} elements instead.",
                    id=id
                )
            number = number_like.item()
            if isinstance(number, bool):
                number = int(number)

            return number
        elif isinstance(number_like, Enum):
            return int(number_like)  # type: ignore[call-overload]
        else:
            return super()._to_number(number_like, id=id)


class TensorOrArrayPair(TensorLikePair):
    """Pair for tensor-like inputs.

    On the one hand this class is stricter than the builtin :class:`TensorLikePair` since it only allows instances of
    :class:`torch.Tensor` and :class:`numpy.ndarray` rather than allowing any tensor-like that can be converted into a
    tensor. On the other hand this class is looser since it converts all inputs into tensors with no regard for their
    relationship, e.g. comparing a :class:`torch.Tensor` to a :class:`numpy.ndarray` is fine.

    In addition, this class supports overriding the absolute and relative tolerance through the ``@precisionOverride``
    and ``@toleranceOverride`` decorators.
    """
    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
        super().__init__(actual, expected, **other_parameters)
        self.rtol = max(self.rtol, rtol_override)
        self.atol = max(self.atol, atol_override)

    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
        self._check_inputs_isinstance(actual, expected, cls=(torch.Tensor, np.ndarray))

        actual, expected = (self._to_tensor(input) for input in (actual, expected))
        for tensor in (actual, expected):
            self._check_supported(tensor, id=id)
        return actual, expected


class TypedStoragePair(TensorLikePair):
    """Pair for :class:`torch.storage.TypedStorage` inputs."""
    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
        self._check_inputs_isinstance(actual, expected, cls=torch.storage.TypedStorage)
        super().__init__(actual, expected, **other_parameters)
        self.rtol = max(self.rtol, rtol_override)
        self.atol = max(self.atol, atol_override)

    def _to_tensor(self, typed_storage):
        return torch.tensor(
            typed_storage._untyped_storage,
            dtype={
                torch.quint8: torch.uint8,
                torch.quint4x2: torch.uint8,
                torch.quint2x4: torch.uint8,
                torch.qint32: torch.int32,
                torch.qint8: torch.int8
            }.get(typed_storage.dtype, typed_storage.dtype),
            device=typed_storage.device,
        )


class UnittestPair(Pair):
    """Fallback ABC pair that handles non-numeric inputs.

    To avoid recreating the mismatch messages of :meth:`unittest.TestCase.assertEqual`, this pair simply wraps it in
    order to use it with the :class:`Pair` "framework" from :func:`are_equal`.

    Define the :attr:`UnittestPair.CLS` in a subclass to indicate which class(es) of the inputs the pair should support.
    """
    CLS: Union[Type, Tuple[Type, ...]]
    TYPE_NAME: Optional[str] = None

    def __init__(self, actual, expected, **other_parameters):
        self._check_inputs_isinstance(actual, expected, cls=self.CLS)
        super().__init__(actual, expected, **other_parameters)

    def compare(self):
        test_case = unittest.TestCase()

        try:
            return test_case.assertEqual(self.actual, self.expected)
        except test_case.failureException as error:
            msg = str(error)

        type_name = self.TYPE_NAME or (self.CLS if isinstance(self.CLS, type) else self.CLS[0]).__name__
        self._fail(AssertionError, f"{type_name.title()} comparison failed: {msg}")


class StringPair(UnittestPair):
    CLS = (str, bytes)
    TYPE_NAME = "string"


class SetPair(UnittestPair):
    CLS = set


class TypePair(UnittestPair):
    CLS = type


class ObjectPair(UnittestPair):
    CLS = object
2792 try: 2793 torch.cuda.synchronize() 2794 except RuntimeError as rte: 2795 print("TEST SUITE EARLY TERMINATION due to torch.cuda.synchronize() failure", file=sys.stderr) 2796 print(str(rte), file=sys.stderr) 2797 return True 2798 return False 2799 else: 2800 return False 2801 2802 @property 2803 def precision(self) -> float: 2804 return self._precision 2805 2806 @precision.setter 2807 def precision(self, prec: float) -> None: 2808 self._precision = prec 2809 2810 @property 2811 def rel_tol(self) -> float: 2812 return self._rel_tol 2813 2814 @rel_tol.setter 2815 def rel_tol(self, prec: float) -> None: 2816 self._rel_tol = prec 2817 2818 _do_cuda_memory_leak_check = False 2819 _do_cuda_non_default_stream = False 2820 2821 # When True, if a test case raises a NotImplementedError, instead of failing 2822 # the test, skip it instead. 2823 _ignore_not_implemented_error = False 2824 2825 def __init__(self, method_name='runTest', methodName='runTest'): 2826 # methodName is the correct naming in unittest and testslide uses keyword arguments. 2827 # So we need to use both to 1) not break BC and, 2) support testslide. 2828 if methodName != "runTest": 2829 method_name = methodName 2830 super().__init__(method_name) 2831 2832 test_method = getattr(self, method_name, None) 2833 if test_method is not None: 2834 # Wraps the tested method if we should do CUDA memory check. 2835 if TEST_CUDA_MEM_LEAK_CHECK: 2836 self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True) 2837 # FIXME: figure out the flaky -1024 anti-leaks on windows. See #8044 2838 if self._do_cuda_memory_leak_check and not IS_WINDOWS: 2839 self.wrap_with_cuda_policy(method_name, self.assertLeaksNoCudaTensors) 2840 2841 # Wraps the tested method if we should enforce non default CUDA stream. 2842 self._do_cuda_non_default_stream &= getattr(test_method, '_do_cuda_non_default_stream', True) 2843 if self._do_cuda_non_default_stream and not IS_WINDOWS: 2844 self.wrap_with_cuda_policy(method_name, self.enforceNonDefaultStream) 2845 2846 if self._ignore_not_implemented_error: 2847 self.wrap_with_policy(method_name, lambda: skip_exception_type(NotImplementedError)) 2848 2849 if PRINT_REPRO_ON_FAILURE: 2850 try: 2851 def _get_rel_test_path(abs_test_path): 2852 # Attempt to get relative path based on the "test" dir. 2853 # In CI, the working dir is not guaranteed to be the base repo dir so 2854 # we can't just compute relative path from that. 2855 parts = Path(abs_test_path).parts 2856 for i, part in enumerate(parts): 2857 if part == "test": 2858 base_dir = os.path.join(*parts[:i]) if i > 0 else '' 2859 return os.path.relpath(abs_test_path, start=base_dir) 2860 2861 # Can't determine containing dir; just return the test filename. 2862 # The path isn't strictly correct but it's arguably better than nothing. 2863 return os.path.split(abs_test_path)[1] 2864 2865 # NB: In Python 3.8, the getfile() call will return a path relative 2866 # to the working directory, so convert that to absolute. 
2867 abs_test_path = os.path.abspath(inspect.getfile(type(self))) 2868 test_filename = _get_rel_test_path(abs_test_path) 2869 class_name = type(self).__name__ 2870 test_run_cmd = f"python {test_filename} {class_name}.{method_name}" 2871 env_var_prefix = TestEnvironment.repro_env_var_prefix() 2872 repro_parts = [env_var_prefix, test_run_cmd] 2873 self.wrap_with_policy( 2874 method_name, 2875 lambda repro_parts=repro_parts: print_repro_on_failure(repro_parts)) 2876 except Exception as e: 2877 # Don't fail entirely if we can't get the test filename 2878 log.info("could not print repro string", extra=str(e)) 2879 2880 def assertLeaksNoCudaTensors(self, name=None): 2881 name = self.id() if name is None else name 2882 return CudaMemoryLeakCheck(self, name) 2883 2884 def enforceNonDefaultStream(self): 2885 return CudaNonDefaultStream() 2886 2887 def _remove_ansi_escape(self, input): 2888 # 7-bit C1 ANSI sequences 2889 ansi_escape = re.compile(r''' 2890 \x1B # ESC 2891 (?: # 7-bit C1 Fe (except CSI) 2892 [@-Z\\-_] 2893 | # or [ for CSI, followed by a control sequence 2894 \[ 2895 [0-?]* # Parameter bytes 2896 [ -/]* # Intermediate bytes 2897 [@-~] # Final byte 2898 ) 2899 ''', re.VERBOSE) 2900 return ansi_escape.sub('', input) 2901 2902 def remove_comment_lines(self, input_string): 2903 lines = input_string.split('\n') 2904 filtered_lines = [line for line in lines if not line.strip().startswith('#')] 2905 return '\n'.join(filtered_lines) 2906 2907 def remove_empty_lines(self, input_string): 2908 lines = input_string.split('\n') 2909 filtered_lines = [line for line in lines if not line.strip() == ''] 2910 return '\n'.join(filtered_lines) 2911 2912 # ignore comments will ignore lines that starts with # after being stripped 2913 def assertExpectedInline(self, actual, expect, skip=0, ignore_comments=False, ignore_empty_lines=False): 2914 actual = actual if isinstance(actual, str) else str(actual) 2915 actual = self._remove_ansi_escape(actual) 2916 expect = self._remove_ansi_escape(expect) 2917 if ignore_comments: 2918 actual = self.remove_comment_lines(actual) 2919 expect = self.remove_comment_lines(expect) 2920 2921 if ignore_empty_lines: 2922 actual = self.remove_empty_lines(actual) 2923 expect = self.remove_empty_lines(expect) 2924 2925 return super().assertExpectedInline(actual if isinstance(actual, str) else str(actual), expect, skip + 1) 2926 2927 # Munges exceptions that internally contain stack traces, using munge_exc 2928 def assertExpectedInlineMunged( 2929 self, exc_type, callable, expect, *, suppress_suffix=True 2930 ): 2931 try: 2932 callable() 2933 except exc_type as e: 2934 self.assertExpectedInline( 2935 munge_exc(e, suppress_suffix=suppress_suffix, skip=1), expect, skip=1 2936 ) 2937 return 2938 self.fail(msg="Did not raise when expected to") 2939 2940 def assertLogs(self, logger=None, level=None): 2941 if logger is None: 2942 logger = logging.getLogger("torch") 2943 return super().assertLogs(logger, level) 2944 2945 def assertNoLogs(self, logger=None, level=None): 2946 if logger is None: 2947 logger = logging.getLogger("torch") 2948 return super().assertNoLogs(logger, level) 2949 2950 def wrap_with_cuda_policy(self, method_name, policy): 2951 test_method = getattr(self, method_name) 2952 # the import below may initialize CUDA context, so we do it only if 2953 # self._do_cuda_memory_leak_check or self._do_cuda_non_default_stream 2954 # is True. 
2955 # TODO: sure looks like we unconditionally initialize the context here 2956 # -- ezyang 2957 from torch.testing._internal.common_cuda import TEST_CUDA 2958 fullname = self.id().lower() # class_name.method_name 2959 if TEST_CUDA and ('gpu' in fullname or 'cuda' in fullname): 2960 setattr(self, method_name, self.wrap_method_with_policy(test_method, policy)) 2961 2962 def wrap_with_policy(self, method_name, policy): 2963 test_method = getattr(self, method_name) 2964 setattr(self, method_name, self.wrap_method_with_policy(test_method, policy)) 2965 2966 # A policy is a zero-argument function that returns a context manager. 2967 # We don't take the context manager directly as it may be necessary to 2968 # construct it once per test method 2969 def wrap_method_with_policy(self, method, policy): 2970 # Assumes that `method` is the tested function in `self`. 2971 # NOTE: Python Exceptions (e.g., unittest.Skip) keeps objects in scope 2972 # alive, so this cannot be done in setUp and tearDown because 2973 # tearDown is run unconditionally no matter whether the test 2974 # passes or not. For the same reason, we can't wrap the `method` 2975 # call in try-finally and always do the check. 2976 @wraps(method) 2977 def wrapper(self, *args, **kwargs): 2978 with policy(): 2979 method(*args, **kwargs) 2980 return types.MethodType(wrapper, self) 2981 2982 def wrap_with_cuda_memory_check(self, method): 2983 return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors) 2984 2985 def _run_custom(self, result=None): 2986 using_unittest = isinstance(result, unittest.TestResult) 2987 2988 super_run = super().run 2989 test_cls = super_run.__self__ 2990 2991 # Are we compiling? 2992 compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR 2993 # Is the class strict and compiling? 
2994 strict_default = False 2995 should_reset_dynamo = False 2996 if compiled: 2997 try: 2998 path = inspect.getfile(type(test_cls)) 2999 full_path = os.path.abspath(path) 3000 match = re.match(r".*/test/(.*).py", full_path) 3001 if match is not None: 3002 filename = match.group(1) 3003 if TEST_WITH_TORCHINDUCTOR: 3004 from .dynamo_test_failures import FIXME_inductor_non_strict 3005 strict_default = filename not in FIXME_inductor_non_strict 3006 3007 from .dynamo_test_failures import FIXME_inductor_dont_reset_dynamo 3008 should_reset_dynamo = filename not in FIXME_inductor_dont_reset_dynamo 3009 else: 3010 strict_default = True 3011 # inspect.getfile can fail with these 3012 except (OSError, TypeError): 3013 pass 3014 if "STRICT_DEFAULT" in os.environ: 3015 if os.environ["STRICT_DEFAULT"] == "1": 3016 strict_default = True 3017 3018 strict_mode = False 3019 if compiled: 3020 test_method = getattr(self, self._testMethodName) 3021 if hasattr(test_method, "dynamo_strict"): 3022 strict_mode = test_method.dynamo_strict 3023 elif hasattr(test_cls, "dynamo_strict"): 3024 strict_mode = test_cls.dynamo_strict 3025 else: 3026 strict_mode = strict_default 3027 nopython = getattr(test_cls, "dynamo_strict_nopython", False) and compiled 3028 3029 if strict_mode or should_reset_dynamo: 3030 torch._dynamo.reset() 3031 3032 # TODO: Remove this; this is grandfathered in because we suppressed errors 3033 # on test suite previously 3034 # When strict mode is False, suppress_errors is True 3035 if compiled: 3036 suppress_errors = not strict_mode 3037 else: 3038 suppress_errors = torch._dynamo.config.suppress_errors 3039 with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors): 3040 if TEST_WITH_TORCHINDUCTOR: 3041 super_run = torch._dynamo.optimize("inductor")(super_run) 3042 elif TEST_WITH_AOT_EAGER: 3043 super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run) 3044 elif TEST_WITH_TORCHDYNAMO: 3045 # TorchDynamo optimize annotation 3046 # Assume eager-generated GraphModules will not error out. 3047 # If we do, this is probably a Dynamo bug! 
3048 super_run = torch._dynamo.optimize("eager_noexcept", nopython=nopython)(super_run) 3049 key = f"{self.__class__.__name__}.{self._testMethodName}" 3050 from .dynamo_test_failures import dynamo_expected_failures, dynamo_skips 3051 3052 def expect_failure(f, test_name): 3053 @wraps(f) 3054 def wrapper(*args, **kwargs): 3055 try: 3056 f(*args, **kwargs) 3057 except BaseException as e: 3058 self.skipTest(e) 3059 raise RuntimeError(f"Unexpected success, please remove `test/dynamo_expected_failures/{test_name}`") 3060 return wrapper 3061 3062 if key in dynamo_expected_failures: 3063 method = getattr(self, self._testMethodName) 3064 setattr(self, self._testMethodName, expect_failure(method, key)) 3065 3066 def ignore_failure(f, test_name): 3067 @wraps(f) 3068 def wrapper(*args, **kwargs): 3069 try: 3070 f(*args, **kwargs) 3071 except BaseException as e: 3072 self.skipTest(e) 3073 method = getattr(self, self._testMethodName) 3074 if getattr(method, "__unittest_expecting_failure__", False): 3075 self.skipTest("unexpected success") 3076 else: 3077 self.skipTest(f"This test passed, maybe we can remove `test/dynamo_skips/{test_name}`") 3078 return wrapper 3079 3080 if key in dynamo_skips: 3081 method = getattr(self, self._testMethodName) 3082 setattr(self, self._testMethodName, ignore_failure(method, key)) 3083 3084 super_run(result=result) 3085 3086 if strict_mode or should_reset_dynamo: 3087 torch._dynamo.reset() 3088 3089 # Early terminate test if necessary. If using pytest, use the -x flag instead 3090 if using_unittest and self._should_stop_test_suite(): 3091 if result.wasSuccessful(): 3092 case = TestCase() 3093 if TEST_SAVE_XML is not None: 3094 # This is a big hacky, XMLRunner modifies expected type from TestCase to TestInfo 3095 # Create dummy TestInfo to record results correctly 3096 from xmlrunner.result import _TestInfo # type: ignore[import] 3097 case = _TestInfo(result, case) 3098 case.output = _TestInfo.ERROR 3099 case.elapsed_time = 0.0 3100 case.test_description = "TestSuiteEarlyFailure" 3101 # This shouldn't really happen, but if does add fake failure 3102 # For more details see https://github.com/pytorch/pytorch/issues/71973 3103 result.failures.append((case, "TestSuite execution was aborted early")) 3104 assert result.wasSuccessful() is False 3105 result.stop() 3106 3107 3108 def run(self, result=None): 3109 with contextlib.ExitStack() as stack: 3110 if TEST_WITH_CROSSREF: 3111 stack.enter_context(CrossRefMode()) 3112 self._run_custom( 3113 result=result, 3114 ) 3115 3116 def setUp(self): 3117 check_if_enable(self) 3118 set_rng_seed(SEED) 3119 3120 # Save global check sparse tensor invariants state that can be 3121 # restored from tearDown: 3122 self._check_invariants = torch.sparse.check_sparse_tensor_invariants.is_enabled() 3123 3124 # Enable invariant checks for all sparse tensors constructions 3125 # including the unsafe ones. If this is not desired for some 3126 # test case, use check_invariants=False optional argument to 3127 # sparse tensor constructors or 3128 # @torch.sparse.check_sparse_tensor_invariants(False) 3129 # decorator to disable the invariant checks. 
        torch.sparse.check_sparse_tensor_invariants.enable()

        if self._default_dtype_check_enabled:
            assert torch.get_default_dtype() == torch.float

        # Save the global grad-enabled state here so that tearDown can
        # reset it at the end of the test:
        self._prev_grad_state = torch.is_grad_enabled()

    def tearDown(self):
        # There exist test cases that override the TestCase.setUp
        # definition, so we cannot assume that the _check_invariants
        # attribute is defined in general.
        if hasattr(self, '_check_invariants'):
            # Restore the global check sparse tensor invariants state
            if self._check_invariants:
                torch.sparse.check_sparse_tensor_invariants.enable()
            else:
                torch.sparse.check_sparse_tensor_invariants.disable()

        if self._default_dtype_check_enabled:
            assert torch.get_default_dtype() == torch.float

        # attribute may not be defined, per above
        if hasattr(self, '_prev_grad_state'):
            torch.set_grad_enabled(self._prev_grad_state)

    @staticmethod
    def _make_crow_indices(n_rows, n_cols, nnz,
                           *, device, dtype, random=True):
        """Return crow_indices of a CSR tensor with size (n_rows, n_cols) and
        the number of specified elements nnz.

        If random is True, the column counts of rows are in random
        order. Otherwise, the column counts of rows are defined by the
        sampling method described below.

        Sampling method
        ---------------

        The sampling method was introduced in
        https://pearu.github.io/csr_sampling.html; here we give
        only an overall description of the method.

        Notice that crow_indices can be defined as cumsum(counts)
        where counts is a sequence of non-negative integers satisfying
        the following conditions:

          len(counts) == n_rows + 1
          counts.max() <= n_cols

        where counts[i + 1] is interpreted as the number of specified
        elements in the i-th row.

        The sampling method aims at increasing the diversity of
        CSR samples, that is, a CSR sample should contain (i) rows
        that are all filled, (ii) rows with no elements at all, and
        (iii) rows that are partially filled. At the same time, for
        the given total number of specified elements (nnz), there
        should be minimal preference for rows with a given number of
        elements. To achieve this, the sampling method is built on a
        sawteeth model for counts. In the simplest case, we would have

          counts = arange(n_rows + 1) % (n_cols + 1)

        which contains an equal number of occurrences of every possible
        column count per row. This formula can be used only for specific
        input values of n_rows, n_cols, and nnz. To generalize this
        model to any combination of inputs, the counts model above is
        extended with an incomplete sawtooth, and with right and lower
        rectangular parts that guarantee that

          counts.sum() == nnz

        for any combination of n_rows, n_cols, and nnz. Basically,
        we'll find a maximal window in the (n_rows + 1, n_cols + 1) grid
        that is able to hold a sequence of sawteeth and a so-called
        final correction, while the external part of the window is
        filled with counts to meet the nnz constraint exactly.
        """
        assert 0 <= nnz <= n_rows * n_cols, (nnz, n_rows, n_cols)

        def sawteeth(n, m):
            # Return the total number of counts in the sequence of
            # sawteeth, where n and m define a window in the
            # (n_rows + 1, n_cols + 1) rectangle into which the
            # sequence of sawteeth fits exactly.
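            # Worked example (n = m = 0): one full sawtooth spans
            # n_cols + 1 rows with counts 0, 1, ..., n_cols and contributes
            # n_cols * (n_cols + 1) // 2 counts; the K leftover rows form an
            # incomplete sawtooth with counts 0, 1, ..., K - 1, contributing
            # K * (K - 1) // 2 counts.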
3217 M = (n_cols - m) * (n_cols - m + 1) // 2 3218 K = (n_rows - n) % (n_cols - m + 1) 3219 return M * ((n_rows - n) // (n_cols - m + 1)) + K * (K - 1) // 2 3220 3221 # Different from the original method description, here counts 3222 # has leading 0 required by crow_indices: 3223 counts = torch.zeros(n_rows + 1, dtype=dtype, device=torch.device('cpu')) 3224 3225 n = m = 0 3226 N = sawteeth(n, m) 3227 if N and nnz >= max(N, n_cols): 3228 # determine the width of the sawteeth window. We use bisection to solve 3229 # N(n, 0) == 0 or nnz - n * n_cols < max(N(n, 0), n_cols) 3230 # for n 3231 n_left = n 3232 n_right = n_rows - 1 3233 N_right = sawteeth(n_right, m) 3234 while n_right - n_left > 1: 3235 n_middle = (n_left + n_right) // 2 3236 N_middle = sawteeth(n_middle, m) 3237 if N_middle == 0 or nnz - n_middle * n_cols < max(N_middle, n_cols): 3238 n_right, N_right = n_middle, N_middle 3239 else: 3240 n_left = n_middle 3241 n, N = n_right, N_right 3242 # fill the right rectangle with counts: 3243 assert n 3244 counts[-n:].fill_(n_cols) 3245 3246 if N and nnz - n * n_cols >= max(N, n_rows - n): 3247 # determine the height of the sawteeth window. We use bisection to solve 3248 # N(n, m) == 0 or nnz - n * n_cols - m * (n_rows - n) < max(N(n, m), n_rows - n) 3249 # for m. 3250 m_left = m 3251 m_right = n_cols - 1 3252 N_right = sawteeth(n, m_right) 3253 while m_right - m_left > 1: 3254 m_middle = (m_left + m_right) // 2 3255 N_middle = sawteeth(n, m_middle) 3256 if N_middle == 0 or nnz - n * n_cols - m_middle * (n_rows - n) < max(N_middle, n_rows - n): 3257 m_right, N_right = m_middle, N_middle 3258 else: 3259 m_left = m_middle 3260 m, N = m_right, N_right 3261 # fill the bottom rectangle with counts: 3262 assert m 3263 counts[1:n_rows - n + 1].fill_(m) 3264 3265 if N: 3266 # fill the sawteeth window with counts 3267 q, r = divmod(nnz - n * n_cols - m * (n_rows - n), 3268 (n_cols - m) * (n_cols - m + 1) // 2) 3269 p = 1 + q * (n_cols - m + 1) 3270 k = math.isqrt(2 * r) 3271 if k * (k + 1) > 2 * r: 3272 k -= 1 3273 corr = r - k * (k + 1) // 2 3274 assert not ((p > 1) and (m > 0)) # full sawteeth are never on top of a bottom rectangle 3275 # sequence of full sawteeth: 3276 counts[1:p] = torch.arange(p - 1, dtype=dtype, device=counts.device) % (n_cols - m + 1) 3277 # incomplete sawtooth: 3278 counts[p:p + k + 1] += torch.arange(k + 1, dtype=dtype, device=counts.device) 3279 else: 3280 # given input does not support sawteeth 3281 p = 1 3282 corr = nnz - n * n_cols - m * (n_rows - n) 3283 3284 # correction that will guarantee counts.sum() == nnz: 3285 counts[p] += corr 3286 3287 if random: 3288 # randomize crow_indices by shuffling the sawteeth 3289 # sequence: 3290 perm = torch.randperm(n_rows, device=counts.device) 3291 counts[1:] = counts[1:][perm] 3292 3293 # compute crow_indices: 3294 crow_indices = counts 3295 crow_indices.cumsum_(dim=0) 3296 return crow_indices.to(device=device) 3297 3298 def genSparseCompressedTensor(self, size, nnz, *, layout, device, dtype, index_dtype, blocksize=(), dense_dims=0): 3299 from operator import mul 3300 from functools import reduce 3301 sparse_dim = 2 3302 assert all(size[d] > 0 for d in range(len(size))) or nnz == 0, 'invalid arguments' 3303 assert len(size) >= sparse_dim 3304 if blocksize: 3305 assert len(blocksize) == 2, (size, blocksize) 3306 assert size[-2 - dense_dims] % blocksize[0] == 0, (size, blocksize) 3307 assert size[-1 - dense_dims] % blocksize[1] == 0, (size, blocksize) 3308 blocksize0, blocksize1 = blocksize 3309 else: 3310 blocksize0 = blocksize1 
= 1 3311 3312 size = tuple(size) 3313 dense_size = size[(len(size) - dense_dims):] 3314 3315 def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz): 3316 compressed_indices = self._make_crow_indices(n_compressed_dims, n_plain_dims, nnz, device=device, dtype=index_dtype) 3317 plain_indices = torch.zeros(nnz, dtype=index_dtype, device=device) 3318 for i in range(n_compressed_dims): 3319 count = compressed_indices[i + 1] - compressed_indices[i] 3320 plain_indices[compressed_indices[i]:compressed_indices[i + 1]], _ = torch.sort( 3321 torch.randperm(n_plain_dims, dtype=index_dtype, device=device)[:count]) 3322 low = -1 if dtype != torch.uint8 else 0 3323 high = 1 if dtype != torch.uint8 else 2 3324 values = make_tensor((nnz,) + blocksize + dense_size, device=device, dtype=dtype, low=low, high=high) 3325 return values, compressed_indices, plain_indices 3326 3327 batch_shape = size[:-2 - dense_dims] 3328 n_batch = reduce(mul, batch_shape, 1) 3329 3330 if layout in {torch.sparse_csr, torch.sparse_bsr}: 3331 n_compressed_dims, n_plain_dims = size[-2 - dense_dims] // blocksize0, size[-1 - dense_dims] // blocksize1 3332 else: 3333 n_compressed_dims, n_plain_dims = size[-1 - dense_dims] // blocksize1, size[-2 - dense_dims] // blocksize0 3334 blocknnz = nnz // (blocksize0 * blocksize1) 3335 sparse_tensors = [random_sparse_compressed(n_compressed_dims, n_plain_dims, blocknnz) for _ in range(n_batch)] 3336 sparse_tensors_it = map(list, zip(*sparse_tensors)) 3337 3338 values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, blocknnz, *blocksize, *dense_size) 3339 compressed_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1) 3340 plain_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1) 3341 return torch.sparse_compressed_tensor(compressed_indices, plain_indices, 3342 values, size=size, dtype=dtype, layout=layout, device=device) 3343 3344 def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0): 3345 return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csr, device=device, 3346 dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=dense_dims) 3347 3348 def genSparseCSCTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0): 3349 return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csc, device=device, 3350 dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=0) 3351 3352 def genSparseBSRTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0): 3353 assert len(blocksize) == 2 3354 return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsr, device=device, 3355 dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims) 3356 3357 def genSparseBSCTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0): 3358 assert len(blocksize) == 2 3359 return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsc, device=device, 3360 dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims) 3361 3362 def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device, dtype): 3363 # Assert not given impossible combination, where the sparse dims have 3364 # empty numel, but nnz > 0 makes the indices containing values. 
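        # (E.g., size=(0, 3) with sparse_dim=1 cannot hold nnz=2 specified
        # elements: there are no valid index values for the first dimension.)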
        assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'

        v_size = [nnz] + list(size[sparse_dim:])
        v = make_tensor(v_size, device=device, dtype=dtype, low=-1, high=1)
        i = torch.rand(sparse_dim, nnz, device=device)
        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
        i = i.to(torch.long)
        if is_uncoalesced:
            i1 = i[:, :(nnz // 2), ...]
            i2 = i[:, :((nnz + 1) // 2), ...]
            i = torch.cat([i1, i2], 1)
        x = torch.sparse_coo_tensor(i, v, torch.Size(size), dtype=dtype, device=device)

        if not is_uncoalesced:
            x = x.coalesce()
        else:
            # FIXME: `x` is a sparse view of `v`. Currently rebase_history for
            #        sparse views is not implemented, so this workaround is
            #        needed for inplace operations done on `x`, e.g., copy_().
            #        Remove after implementing something equivalent to CopySlice
            #        for sparse views.
            # NOTE: We do clone() after detach() here because we need to be able
            #       to change the size/storage of x afterwards.
            x = x.detach().clone()._coalesced_(False)
        return x, x._indices().clone(), x._values().clone()

    def generate_simple_inputs(self, layout,
                               device=None,
                               dtype=None,
                               index_dtype=None,
                               pin_memory=None,
                               members_pin_memory=None,
                               enable_batch=True,
                               enable_hybrid=True,
                               enable_zero_sized=True,
                               enable_non_contiguous_indices=True,
                               enable_non_contiguous_values=True,
                               enable_batch_variable_nse=False,
                               output_tensor=True,
                               patterns=None):
        """Generator of simple inputs for tensor constructors of the given layout.

        The generated tensor inputs have the following properties:

        - tensor shapes are minimal but not trivial
        - tensor values are sorted sequences for COO and CSR formats, e.g. [1, 2, 3, 4]
        - the generated tensors represent the same mathematical tensor for all layouts
        - the generated tensors include regular, zero-sized, and optionally, batched and/or hybrid tensors
        - the generated tensors include contiguous and non-contiguous tensors, both in indices and values

        If output_tensor is True, yield tensors with the given
        layout.
Otherwise, yield inputs to the corresponding tensor 3416 constructors: 3417 3418 - sparse compressed input is defined as 3419 (compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, 3420 pin_memory=pin_memory) 3421 3422 - sparse COO input is defined as 3423 (indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, pin_memory=pin_memory) 3424 3425 - strided input is defined as 3426 (values,), dict(device=device, dtype=dtype) 3427 """ 3428 if index_dtype is None: 3429 index_dtype = torch.int64 3430 3431 is_compressed_sparse_layout = layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} 3432 3433 if output_tensor: 3434 for args, kwargs in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype, 3435 pin_memory=pin_memory, 3436 enable_batch=enable_batch, enable_hybrid=enable_hybrid, 3437 enable_zero_sized=enable_zero_sized, 3438 enable_non_contiguous_indices=enable_non_contiguous_indices, 3439 enable_non_contiguous_values=enable_non_contiguous_values, 3440 enable_batch_variable_nse=enable_batch_variable_nse, 3441 output_tensor=False): 3442 if members_pin_memory: 3443 args = tuple(a.pin_memory() for a in args) 3444 if layout is torch.strided: 3445 assert len(args) == 1 3446 size = kwargs.pop('size', None) # to ensure that a zero-sized tensor has the desired shape 3447 assert size is not None 3448 if pin_memory: 3449 yield args[0].reshape(size).pin_memory() 3450 else: 3451 yield args[0].reshape(size) 3452 elif layout is torch.sparse_coo: 3453 yield torch.sparse_coo_tensor(*args, **kwargs) 3454 elif is_compressed_sparse_layout: 3455 kwargs.update(layout=layout) 3456 yield torch.sparse_compressed_tensor(*args, **kwargs) 3457 else: 3458 assert 0 # unreachable 3459 return 3460 3461 def get_blockpattern(pattern, blocksize): 3462 basesize = pattern.shape 3463 assert basesize[0] % blocksize[0] == 0, (basesize, blocksize) 3464 assert basesize[1] % blocksize[1] == 0, (basesize, blocksize) 3465 blockpattern = pattern.reshape(-1, 3466 blocksize[0], 3467 basesize[1] // blocksize[1], 3468 blocksize[1]).transpose(-3, -2).any(-1).any(-1) 3469 block_ids = torch.arange(1, blockpattern.numel() + 1).reshape(blockpattern.shape) 3470 return (blockpattern != 0) * block_ids 3471 3472 def get_sparse_data(pattern): 3473 basesize = pattern.shape 3474 assert len(basesize) == 2, basesize # pattern is expected to be a matrix 3475 3476 # We cannot use `torch.sparse_xyz_tensor(pattern)` to 3477 # compute the sparse layout indices and values because 3478 # generate_simple_inputs is used to generate the inputs to 3479 # test `torch.sparse_xyz_tensor` factory functions, so 3480 # we'll compute the indices and values independently of 3481 # the factory functions. 
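            # Worked example (illustrative): pattern = [[1, 0], [2, 3]] gives
            #   coo_indices  = [[0, 1, 1], [0, 0, 1]]
            #   crow_indices = [0, 1, 3], col_indices = [0, 0, 1]
            #   values       = [1, 2, 3]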
3482 3483 indices = torch.where(pattern != 0) 3484 coo_indices = torch.stack(indices) 3485 crow_indices = torch.zeros(basesize[0] + 1, dtype=torch.int64) 3486 crow_indices[1:] = torch.cumsum(coo_indices[0].bincount(minlength=basesize[0]), 0) 3487 col_indices = coo_indices[1] 3488 strided_values = torch.zeros(basesize, dtype=torch.int64) 3489 3490 # the property of `values == range(1, 1+nnz)` is used in 3491 # get_sparse_data_with_block to relate BSR and BSC values, 3492 # so, don't change the following line: 3493 values = torch.arange(1, 1 + len(indices[0]), dtype=torch.int64) 3494 strided_values[indices] = values 3495 3496 indices_T = torch.where(pattern.transpose(0, 1) != 0) 3497 coo_indices_T = torch.stack(indices_T) 3498 ccol_indices = torch.zeros(basesize[1] + 1, dtype=torch.int64) 3499 ccol_indices[1:] = torch.cumsum(coo_indices_T[0].bincount(minlength=basesize[1]), 0) 3500 row_indices = coo_indices_T[1] 3501 csc_values = strided_values.transpose(0, 1)[indices_T] 3502 3503 return {torch.sparse_coo: (coo_indices, values), 3504 torch.sparse_csr: (crow_indices, col_indices, values), 3505 torch.sparse_csc: (ccol_indices, row_indices, csc_values), 3506 torch.strided: (strided_values,)} 3507 3508 def get_sparse_data_with_block(pattern, blocksize): 3509 nonblock_data = get_sparse_data(pattern) 3510 blockpattern = get_blockpattern(pattern, blocksize) 3511 block_data = get_sparse_data(blockpattern) 3512 3513 strided_values = nonblock_data[torch.strided][0] 3514 block_indices = block_data[torch.sparse_coo][0] 3515 bsr_values = torch.stack([strided_values[bi * blocksize[0]:(bi + 1) * blocksize[0], 3516 bj * blocksize[1]:(bj + 1) * blocksize[1]] 3517 for bi, bj in block_indices.transpose(0, 1)]) 3518 3519 # here we use the property `values == range(1, 1+nnz)` and 3520 # `values` relation to `csc_values` (see get_sparse_data) 3521 # to get BSC blocks via reordering the BSR blocks: 3522 bsc_values = bsr_values[block_data[torch.sparse_csc][2] - 1] 3523 3524 return {torch.sparse_bsr: (*block_data[torch.sparse_csr][:2], bsr_values), 3525 torch.sparse_bsc: (*block_data[torch.sparse_csc][:2], bsc_values), 3526 **nonblock_data} 3527 3528 def get_batch_sparse_data(pattern, blocksize): 3529 size = pattern.shape 3530 if len(size) <= 2: # non-batch 3531 return get_sparse_data_with_block(pattern, blocksize) 3532 3533 # batch data is created recursively: 3534 batch_data = {} 3535 for i, item in enumerate(pattern): 3536 for layout, d in get_batch_sparse_data(item, blocksize).items(): 3537 target = batch_data.get(layout) 3538 if layout is torch.sparse_coo: 3539 # a "batch COO" means a COO with the leading 3540 # sparse dimensions interpreted as batch 3541 # dimensions 3542 ext_coo_indices1 = torch.cat((torch.full((1, len(d[1])), i, dtype=torch.int64), d[0])) 3543 if target is None: 3544 target = batch_data[layout] = (ext_coo_indices1, d[1]) 3545 else: 3546 target[0].set_(torch.cat((target[0], ext_coo_indices1), 1)) 3547 target[1].set_(torch.cat((target[1], d[1]))) 3548 else: 3549 if target is None: 3550 target = batch_data[layout] = tuple(d[j].unsqueeze(0) for j in range(len(d))) 3551 else: 3552 for j in range(len(d)): 3553 target[j].set_(torch.cat((target[j], d[j].unsqueeze(0)))) 3554 return batch_data 3555 3556 def generate_values(base, densesize): 3557 """Generates a tensor of shape densesize with values equal to 3558 3559 base + i_1 * 10^0 + ... 
+ i_d * 10^{d - 1} 3560 3561 at indices i_1, ..., i_d (with 0 <= i_j < densesize[j] for any 1 <= j <= 3562 len(densesize)) 3563 3564 This mapping produces unique values as long as 3565 densesize[i] < 10 for all i in range(len(densesize)). 3566 """ 3567 3568 if not densesize: 3569 return base 3570 if not isinstance(base, int) and base.ndim > 0: 3571 return torch.stack([generate_values(b, densesize) for b in base]) 3572 if base == 0: 3573 return torch.zeros(densesize, dtype=torch.int64) 3574 r = torch.arange(densesize[0], dtype=torch.int64) 3575 for i, d in enumerate(densesize[1:]): 3576 y = torch.arange(d, dtype=torch.int64) * (10 ** (i + 1)) 3577 r = r[..., None] + y[None, ...] 3578 r.add_(base) 3579 return r 3580 3581 if patterns is None: 3582 # A pattern is a 3-tuple with the following items: 3583 # 3584 # - a list of integers with the depth of two or more. The 3585 # integers define the sparsity patterns of the generated 3586 # inputs: zero values correspond to unspecified 3587 # elements/blocks, and non-zero values to the specified 3588 # elements. 3589 # 3590 # For debugging convenience, the elements with the same 3591 # value typically belong to the same block. However, it 3592 # is not a hard requirement: as long as the shape of a 3593 # pattern divides with block sizes, the pattern will be 3594 # a valid one. 3595 # 3596 # If the depth of the list is larger than two, inputs 3597 # with batch dimensions will be generated. 3598 # 3599 # - a list of 2-tuples of block sizes, used to generate 3600 # BSR/BSC tensors with various block size parameters 3601 # 3602 # - a list of tuples of dense dimensions, used to generate 3603 # hybrid tensors with various dense dimensions 3604 # 3605 patterns = [ 3606 # a simple 3 x 2 tensor: non-hybrid, hybrid with 1 and 2 dense dimensions 3607 ([[1, 2, 0], 3608 [1, 0, 3]], [(2, 1), (1, 3)], [(), (2,), (4, 5)]), 3609 # 2 x 3 batch of 3 x 2 tensors: non-hybrid and hybrid with 2 dense dimensions 3610 ([[[[1, 2, 0], 3611 [1, 0, 3]], 3612 [[1, 2, 3], 3613 [1, 0, 0]], 3614 [[1, 0, 0], 3615 [1, 2, 3]]], 3616 [[[0, 2, 0], 3617 [1, 2, 3]], 3618 [[1, 0, 3], 3619 [1, 2, 0]], 3620 [[1, 2, 3], 3621 [0, 2, 0]]]], [(2, 1), (2, 3)], [(), (2,)]), 3622 # tensor with non-trivial blocksize 3623 ([[0, 1, 0, 2, 0, 2], 3624 [0, 1, 0, 0, 2, 0], 3625 [3, 3, 3, 0, 0, 0], 3626 [0, 0, 0, 0, 0, 0], 3627 [0, 5, 0, 6, 6, 6], 3628 [5, 0, 5, 6, 6, 6], 3629 [0, 0, 0, 0, 8, 8], 3630 [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 5)]), 3631 # batch tensor with variable NSE 3632 # Requires https://github.com/pytorch/pytorch/pull/84843 or similar. 
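            # (NSE = number of specified elements; the two batch items below
            # have 4 and 1 specified elements, respectively.)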
3633 ([[[1, 2], 3634 [3, 4]], 3635 [[1, 0], 3636 [0, 0]]], [(1, 1)], ([()] if enable_batch_variable_nse else []))] 3637 3638 def non_contiguous_copy(t, dim=-1, offset=0): 3639 # return a copy of t that is non-contiguous along the 3640 # given dimension and with the given storage offset 3641 self.assertTrue(t.is_contiguous()) 3642 if dim < 0: 3643 dim = dim + t.ndim 3644 assert dim >= 0 and dim < t.ndim 3645 step = max(2, offset + 1) 3646 tmp = torch.zeros((*t.shape[:dim], t.shape[dim] * step, *t.shape[dim + 1:]), dtype=t.dtype, device=t.device) 3647 dim_slices = (*((slice(None),) * dim), slice(offset, None, step)) 3648 r = tmp[dim_slices].copy_(t) 3649 self.assertFalse(r.is_contiguous()) 3650 self.assertEqual(t, r) 3651 return r 3652 3653 # the main loop of the method: 3654 for pattern, blocksizes, densesizes in patterns: 3655 if not enable_hybrid: 3656 densesizes = [s for s in densesizes if not s] 3657 if not (densesizes and blocksizes): 3658 continue 3659 pattern = torch.tensor(pattern, dtype=torch.int64) 3660 if not enable_batch and pattern.ndim > 2: 3661 continue 3662 for blocksize in blocksizes: 3663 data = get_batch_sparse_data(pattern, blocksize)[layout] 3664 for densesize in densesizes: 3665 indices = [a.to(device=device, dtype=index_dtype) for a in data[:-1]] 3666 values = generate_values(data[-1], densesize).to(device=device, dtype=dtype) 3667 kwargs = dict(device=device, dtype=dtype, size=pattern.shape + densesize) 3668 if pin_memory is not None: 3669 kwargs.update(pin_memory=pin_memory) 3670 3671 yield (*indices, values), kwargs.copy() 3672 if enable_non_contiguous_indices and pattern.ndim > 2: 3673 # sparse compressed indices can be sliced only along batch dimensions 3674 for (dim, offset) in {(0, 1), (-2, 0)}: 3675 indices_copy = [non_contiguous_copy(a, dim=dim, offset=offset) for a in indices] 3676 yield (*indices_copy, values), kwargs.copy() 3677 3678 if enable_non_contiguous_values: 3679 values_copy = non_contiguous_copy(values, dim=-1, offset=1) 3680 yield (*indices_copy, values_copy), kwargs.copy() 3681 3682 if enable_non_contiguous_values: 3683 values_copy = non_contiguous_copy(values, dim=-1, offset=1) 3684 yield (*indices, values_copy), kwargs.copy() 3685 3686 # zero-sized tensor inputs, non-batch, non-hybrid/hybrid 3687 if enable_zero_sized: 3688 for basesize, blocksizes, densesizes in [ 3689 ((2, 0), [(1, 2)], [(), (2,), (2, 3)] if enable_hybrid else [()]), 3690 ((0, 2), [(1, 2), (2, 1), (3, 2)], [()]), 3691 ((0, 0), [(1, 2)], [()]), 3692 ]: 3693 for blocksize in blocksizes: 3694 for densesize in densesizes: 3695 if layout == torch.strided: 3696 indices = () 3697 values = torch.empty((basesize + densesize), device=device, dtype=dtype) 3698 elif layout == torch.sparse_coo: 3699 indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),) 3700 values = torch.empty((0, *densesize), device=device, dtype=dtype) 3701 elif layout == torch.sparse_csr: 3702 crow_indices = torch.tensor([0] * (basesize[0] + 1), device=device, dtype=index_dtype) 3703 col_indices = torch.empty(0, device=device, dtype=index_dtype) 3704 indices = (crow_indices, col_indices) 3705 values = torch.empty((0, *densesize), device=device, dtype=dtype) 3706 elif layout == torch.sparse_csc: 3707 ccol_indices = torch.tensor([0] * (basesize[1] + 1), device=device, dtype=index_dtype) 3708 row_indices = torch.empty(0, device=device, dtype=index_dtype) 3709 indices = (ccol_indices, row_indices) 3710 values = torch.empty((0, *densesize), device=device, dtype=dtype) 3711 elif layout == 
torch.sparse_bsr: 3712 crow_indices = torch.tensor([0] * (basesize[0] // blocksize[0] + 1), device=device, dtype=index_dtype) 3713 col_indices = torch.empty(0, device=device, dtype=index_dtype) 3714 indices = (crow_indices, col_indices) 3715 values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype) 3716 elif layout == torch.sparse_bsc: 3717 ccol_indices = torch.tensor([0] * (basesize[1] // blocksize[1] + 1), device=device, dtype=index_dtype) 3718 row_indices = torch.empty(0, device=device, dtype=index_dtype) 3719 indices = (ccol_indices, row_indices) 3720 values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype) 3721 else: 3722 assert 0 # unreachable 3723 kwargs = dict(device=device, dtype=dtype, size=basesize + densesize) 3724 if pin_memory is not None: 3725 kwargs.update(pin_memory=pin_memory) 3726 yield (*indices, values), kwargs 3727 3728 def safeToDense(self, t): 3729 # coalesce is only implemented for COO 3730 if t.layout == torch.sparse_coo: 3731 t = t.coalesce() 3732 return t.to_dense() 3733 3734 # Compares a torch function with a reference function for a given sample input (object of SampleInput) 3735 # Note: only values are compared, type comparison is not done here 3736 def compare_with_reference(self, torch_fn, ref_fn, sample_input, **kwargs): 3737 numpy_sample = sample_input.numpy() 3738 n_inp, n_args, n_kwargs = numpy_sample.input, numpy_sample.args, numpy_sample.kwargs 3739 t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs 3740 3741 actual = torch_fn(t_inp, *t_args, **t_kwargs) 3742 expected = ref_fn(n_inp, *n_args, **n_kwargs) 3743 3744 self.assertEqual(actual, expected, exact_device=False, **kwargs) 3745 3746 # Compares the given Torch and NumPy functions on the given tensor-like object. 3747 # NOTE: both torch_fn and np_fn should be functions that take a single 3748 # tensor (array). If the torch and/or NumPy function require additional 3749 # arguments then wrap the function in a lambda or pass a partial function. 3750 # TODO: add args/kwargs for passing to assertEqual (e.g. rtol, atol) 3751 def compare_with_numpy(self, torch_fn, np_fn, tensor_like, 3752 device=None, dtype=None, **kwargs): 3753 assert TEST_NUMPY 3754 3755 if isinstance(tensor_like, torch.Tensor): 3756 assert device is None 3757 assert dtype is None 3758 t_cpu = tensor_like.detach().cpu() 3759 if t_cpu.dtype is torch.bfloat16: 3760 t_cpu = t_cpu.float() 3761 a = t_cpu.numpy() 3762 t = tensor_like 3763 else: 3764 d = copy.copy(torch_to_numpy_dtype_dict) 3765 d[torch.bfloat16] = np.float32 3766 a = np.array(tensor_like, dtype=d[dtype]) 3767 t = torch.tensor(tensor_like, device=device, dtype=dtype) 3768 3769 np_result = np_fn(a) 3770 torch_result = torch_fn(t).cpu() 3771 3772 # Converts arrays to tensors 3773 if isinstance(np_result, np.ndarray): 3774 try: 3775 np_result = torch.from_numpy(np_result) 3776 except Exception: 3777 # NOTE: copying an array before conversion is necessary when, 3778 # for example, the array has negative strides. 
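                #   Illustrative: torch.from_numpy(np.arange(3)[::-1]) raises
                #   because the reversed view has negative strides, while
                #   torch.from_numpy(np.arange(3)[::-1].copy()) succeeds.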
            np_result = torch.from_numpy(np_result.copy())
            if t.dtype is torch.bfloat16 and torch_result.dtype is torch.bfloat16 and np_result.dtype is torch.float:
                torch_result = torch_result.to(torch.float)

        self.assertEqual(np_result, torch_result, **kwargs)

    def assertEqualIgnoreType(self, *args, **kwargs) -> None:
        # If you see this function used, the test is written incorrectly
        # and deserves detailed investigation
        return self.assertEqual(*args, exact_dtype=False, **kwargs)

    def assertEqualBroadcasting(self, x, y, *args, **kwargs) -> None:
        r"""Tests that tensor x equals tensor y, broadcasting y to x.shape
        if necessary.
        """
        if not isinstance(y, Iterable):
            # y is a scalar (int, float, etc.)
            y = torch.ones_like(x) * y
        if not isinstance(y, torch.Tensor):
            # y is an iterable, but not a tensor
            y = torch.ones_like(x) * torch.tensor(y)
        return self.assertEqual(x, y, *args, **kwargs)

    def assertEqual(
        self,
        x,
        y,
        msg: Optional[Union[str, Callable[[str], str]]] = None,
        *,
        atol: Optional[float] = None,
        rtol: Optional[float] = None,
        equal_nan=True,
        exact_dtype=True,
        # TODO: default this to True
        exact_device=False,
        exact_layout=False,
        exact_stride=False,
        exact_is_coalesced=False
    ):
        # Hide this function from `pytest`'s traceback
        __tracebackhide__ = True

        # numpy's dtypes are a superset of what PyTorch supports. In case we encounter an unsupported dtype, we fall
        # back to an elementwise comparison. Note that this has to happen here and not for example in
        # `TensorOrArrayPair`, since at that stage we can no longer split the array into its elements and perform
        # multiple comparisons.
        if any(
            isinstance(input, np.ndarray) and not has_corresponding_torch_dtype(input.dtype) for input in (x, y)
        ):
            def to_list(input):
                return input.tolist() if isinstance(input, (torch.Tensor, np.ndarray)) else list(input)

            x = to_list(x)
            y = to_list(y)
        # When comparing a sequence of numbers to a tensor, we need to convert the sequence to a tensor here.
        # Otherwise, the pair origination of `are_equal` will fail, because the sequence is recognized as a container
        # that should be checked elementwise while the tensor is not.
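        # E.g. (illustrative): assertEqual(torch.tensor([1.0, 2.0]), [1.0, 2.0])
        # converts the list into a tensor with x's dtype and device first.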
        elif isinstance(x, torch.Tensor) and isinstance(y, Sequence):
            y = torch.as_tensor(y, dtype=x.dtype, device=x.device)
        elif isinstance(x, Sequence) and isinstance(y, torch.Tensor):
            x = torch.as_tensor(x, dtype=y.dtype, device=y.device)

        # unbind NSTs to compare them; don't do this for NJTs
        if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.strided:
            x = x.unbind()
        if isinstance(y, torch.Tensor) and y.is_nested and y.layout == torch.strided:
            y = y.unbind()

        error_metas = not_close_error_metas(
            x,
            y,
            pair_types=(
                NonePair,
                RelaxedBooleanPair,
                RelaxedNumberPair,
                TensorOrArrayPair,
                TypedStoragePair,
                StringPair,
                SetPair,
                TypePair,
                ObjectPair,
            ),
            sequence_types=(
                Sequence,
                Sequential,
                ModuleList,
                ParameterList,
                ScriptList,
                torch.utils.data.dataset.Subset,
            ),
            mapping_types=(Mapping, ModuleDict, ParameterDict, ScriptDict),
            rtol=rtol,
            rtol_override=self.rel_tol,
            atol=atol,
            atol_override=self.precision,
            equal_nan=equal_nan,
            check_device=exact_device,
            check_dtype=exact_dtype,
            check_layout=exact_layout,
            check_stride=exact_stride,
            check_is_coalesced=exact_is_coalesced,
        )

        if error_metas:
            # See [ErrorMeta Cycles]
            error_metas = [error_metas]
            # TODO: compose all metas into one AssertionError
            raise error_metas.pop()[0].to_error(
                # This emulates unittest.TestCase's behavior if a custom message is passed and
                # TestCase.longMessage (https://docs.python.org/3/library/unittest.html#unittest.TestCase.longMessage)
                # is True (the default)
                (lambda generated_msg: f"{generated_msg}\n{msg}") if isinstance(msg, str) and self.longMessage else msg
            )

    def assertNotEqual(self, x, y, msg: Optional[str] = None, *,  # type: ignore[override]
                       atol: Optional[float] = None, rtol: Optional[float] = None, **kwargs) -> None:
        with self.assertRaises(AssertionError, msg=msg):
            self.assertEqual(x, y, msg, atol=atol, rtol=rtol, **kwargs)

    def assertEqualTypeString(self, x, y) -> None:
        # This API is used to simulate the deprecated x.type() == y.type() check
        self.assertEqual(x.device, y.device)
        self.assertEqual(x.dtype, y.dtype)
        self.assertEqual(x.is_sparse, y.is_sparse)

    def assertObjectIn(self, obj: Any, iterable: Iterable[Any]) -> None:
        for elem in iterable:
            if id(obj) == id(elem):
                return
        raise AssertionError("object not found in iterable")

    # Reimplemented to provide special behavior when
    # _ignore_not_implemented_error is True
    def assertRaises(self, expected_exception, *args, **kwargs):
        if self._ignore_not_implemented_error:
            context: Optional[AssertRaisesContextIgnoreNotImplementedError] = \
                AssertRaisesContextIgnoreNotImplementedError(expected_exception, self)  # type: ignore[call-arg]
            try:
                return context.handle('assertRaises', args, kwargs)  # type: ignore[union-attr]
            finally:
                # see https://bugs.python.org/issue23890
                context = None
        else:
            return super().assertRaises(expected_exception, *args, **kwargs)

    # Reimplemented to provide special behavior when
    # _ignore_not_implemented_error is True
    def assertRaisesRegex(self, expected_exception, expected_regex, *args, **kwargs):
        # Verifies that an exception with type expected_exception and a message
        # matching the regular expression defined by
expected_regex is thrown. 3928 # If the test is instantiated for a non-native device type (like XLA) 3929 # then the message is not validated. 3930 3931 # Checks whether the test is instantiated for a device type by testing 3932 # if the test class has defined the device_type attribute and, 3933 # if so, tests whether the instantiated device type is native or not 3934 if hasattr(self, 'device_type') and self.device_type not in NATIVE_DEVICES and self.device_type != "mps": # type: ignore[attr-defined] 3935 # empty string matches any string 3936 expected_regex = '' 3937 3938 if self._ignore_not_implemented_error: 3939 context = AssertRaisesContextIgnoreNotImplementedError( # type: ignore[call-arg] 3940 expected_exception, self, expected_regex) 3941 return context.handle('assertRaisesRegex', args, kwargs) # type: ignore[attr-defined] 3942 else: 3943 return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs) 3944 3945 # Verifies that no unraisable exceptions are raised by callable. Unlike regular 3946 # exceptions, these do not actually propagate to the caller and are 3947 # suppressed. We must test for them specially. 3948 def assertNoUnraisable(self, callable, *args, **kwargs): 3949 raised = None 3950 3951 def record_unraisable(unraisable): 3952 nonlocal raised 3953 raised = unraisable 3954 3955 # Disable GC when running the callable to prevent spurious flakiness 3956 # from unlucky GCs inside the callable 3957 prev = gc.isenabled() 3958 gc.disable() 3959 try: 3960 with unittest.mock.patch("sys.unraisablehook", record_unraisable): 3961 callable(*args, **kwargs) 3962 finally: 3963 if prev: 3964 gc.enable() 3965 3966 self.assertIsNone(raised) 3967 3968 # TODO: Support context manager interface 3969 # NB: The kwargs forwarding to callable robs the 'subname' parameter. 3970 # If you need it, manually apply your callable in a lambda instead. 3971 def assertExpectedRaises(self, exc_type, callable, *args, **kwargs): 3972 subname = None 3973 if 'subname' in kwargs: 3974 subname = kwargs['subname'] 3975 del kwargs['subname'] 3976 try: 3977 callable(*args, **kwargs) 3978 except exc_type as e: 3979 self.assertExpected(str(e), subname) 3980 return 3981 # Don't put this in the try block; the AssertionError will catch it 3982 self.fail(msg="Did not raise when expected to") 3983 3984 def assertNotWarn(self, callable, msg=''): 3985 r""" 3986 Test if :attr:`callable` does not raise a warning. 3987 """ 3988 with warnings.catch_warnings(record=True) as ws: 3989 warnings.simplefilter("always") # allow any warning to be raised 3990 with set_warn_always_context(True): 3991 callable() 3992 self.assertTrue(len(ws) == 0, msg) 3993 3994 @contextmanager 3995 def assertWarnsOnceRegex(self, category, regex=''): 3996 """Context manager for code that *must always* warn 3997 3998 This filters expected warnings from the test and fails if 3999 the expected warning is not caught. 
It uses set_warn_always() to force 4000 TORCH_WARN_ONCE to behave like TORCH_WARN 4001 """ 4002 pattern = re.compile(regex) 4003 with warnings.catch_warnings(record=True) as ws: 4004 warnings.simplefilter("always") # allow any warning to be raised 4005 with set_warn_always_context(True): 4006 yield 4007 if len(ws) == 0: 4008 self.fail('no warning caught') 4009 self.assertTrue(any(type(w.message) is category for w in ws)) 4010 self.assertTrue( 4011 any(re.match(pattern, str(w.message)) for w in ws), 4012 f'{pattern}, {[w.message for w in ws if type(w.message) is category]}') 4013 4014 def assertExpected(self, s, subname=None): 4015 r""" 4016 Test that a string matches the recorded contents of a file 4017 derived from the name of this test and subname. This file 4018 is placed in the 'expect' directory in the same directory 4019 as the test script. You can automatically update the recorded test 4020 output using --accept. 4021 4022 If you call this multiple times in a single function, you must 4023 give a unique subname each time. 4024 """ 4025 if not isinstance(s, str): 4026 raise TypeError("assertExpected is strings only") 4027 4028 def remove_prefix(text, prefix): 4029 if text.startswith(prefix): 4030 return text[len(prefix):] 4031 return text 4032 # NB: we take __file__ from the module that defined the test 4033 # class, so we place the expect directory where the test script 4034 # lives, NOT where test/common_utils.py lives. This doesn't matter in 4035 # PyTorch where all test scripts are in the same directory as 4036 # test/common_utils.py, but it matters in onnx-pytorch 4037 module_id = self.__class__.__module__ 4038 munged_id = remove_prefix(self.id(), module_id + ".") 4039 test_file = os.path.realpath(sys.modules[module_id].__file__) 4040 expected_file = os.path.join(os.path.dirname(test_file), 4041 "expect", 4042 munged_id) 4043 4044 subname_output = "" 4045 if subname: 4046 expected_file += "-" + subname 4047 subname_output = f" ({subname})" 4048 expected_file += ".expect" 4049 expected = None 4050 4051 def accept_output(update_type): 4052 print(f"Accepting {update_type} for {munged_id}{subname_output}:\n\n{s}") 4053 with open(expected_file, 'w') as f: 4054 # Adjust for producer_version, leave s unmodified 4055 s_tag = re.sub(r'(producer_version): "[0-9.]*"', 4056 r'\1: "CURRENT_VERSION"', s) 4057 f.write(s_tag) 4058 4059 try: 4060 with open(expected_file) as f: 4061 expected = f.read() 4062 except OSError as e: 4063 if e.errno != errno.ENOENT: 4064 raise 4065 elif expecttest.ACCEPT: 4066 return accept_output("output") 4067 else: 4068 raise RuntimeError( 4069 f"I got this output for {munged_id}{subname_output}:\n\n{s}\n\n" 4070 "No expect file exists; to accept the current output, run:\n" 4071 f"python {__main__.__file__} {munged_id} --accept") from None 4072 4073 # a hack for JIT tests 4074 if IS_WINDOWS: 4075 expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected) 4076 s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s) 4077 4078 # Adjust for producer_version 4079 expected = expected.replace( 4080 'producer_version: "CURRENT_VERSION"', 4081 f'producer_version: "{torch.onnx.producer_version}"' 4082 ) 4083 if expecttest.ACCEPT: 4084 if expected != s: 4085 return accept_output("updated output") 4086 else: 4087 if hasattr(self, "assertMultiLineEqual"): 4088 # Python 2.7 only 4089 # NB: Python considers lhs "old" and rhs "new". 
4090 self.assertMultiLineEqual(expected, s) 4091 else: 4092 self.assertEqual(s, expected) 4093 4094 def assertExpectedStripMangled(self, s, subname=None): 4095 s = re.sub(r'__torch__[^ ]+', '', s) 4096 self.assertExpected(s, subname) 4097 4098 def assertGreaterAlmostEqual(self, first, second, places=None, msg=None, delta=None): 4099 """Assert that ``first`` is greater than or almost equal to ``second``. 4100 4101 The equality of ``first`` and ``second`` is determined in a similar way to 4102 the ``assertAlmostEqual`` function of the standard library. 4103 """ 4104 if delta is not None and places is not None: 4105 raise TypeError("specify delta or places not both") 4106 4107 if first >= second: 4108 return 4109 4110 diff = second - first 4111 if delta is not None: 4112 if diff <= delta: 4113 return 4114 4115 standardMsg = f"{first} not greater than or equal to {second} within {delta} delta" 4116 else: 4117 if places is None: 4118 places = 7 4119 4120 if round(diff, places) == 0: 4121 return 4122 4123 standardMsg = f"{first} not greater than or equal to {second} within {places} places" 4124 4125 msg = self._formatMessage(msg, standardMsg) 4126 raise self.failureException(msg) 4127 4128 def assertAtenOp(self, onnx_model, operator, overload_name=""): 4129 all_aten_nodes = [p for p in onnx_model.graph.node 4130 if p.op_type == "ATen" and p.domain == "org.pytorch.aten"] 4131 self.assertTrue(all_aten_nodes) 4132 4133 for op in all_aten_nodes: 4134 attrs = {attr.name: attr.s.decode() for attr in op.attribute} 4135 if attrs.get("operator") == operator: 4136 break 4137 4138 self.assertEqual(attrs["operator"], operator) 4139 self.assertEqual(attrs.get("overload_name", ""), overload_name) 4140 4141 def check_nondeterministic_alert(self, fn, caller_name, should_alert=True): 4142 '''Checks that an operation produces a nondeterministic alert when 4143 expected while `torch.use_deterministic_algorithms(True)` is set. 4144 4145 Args: 4146 fn (callable): Function to check for a nondeterministic alert 4147 4148 caller_name (str): Name of the operation that produces the 4149 nondeterministic alert. This name is expected to appear at the 4150 beginning of the error/warning message. 4151 4152 should_alert (bool, optional): If True, then the check will only pass 4153 if calling `fn` produces a nondeterministic error/warning with the 4154 expected message. If False, then the check will only pass if 4155 calling `fn` does not produce an error. Default: `True`. 
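
        Example (illustrative sketch; `some_op` and the caller name below
        are placeholders, not a real API):

        >>> # xdoctest: +SKIP("undefined variables")
        >>> self.check_nondeterministic_alert(
        ...     lambda: some_op(input_tensor),
        ...     'some_op_cuda',
        ...     should_alert=torch.device(device).type == 'cuda')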
4156 ''' 4157 4158 alert_message = '^' + caller_name + ' does not have a deterministic implementation, but you set' 4159 4160 # Check that errors are thrown correctly 4161 with DeterministicGuard(True): 4162 if should_alert: 4163 with self.assertRaisesRegex( 4164 RuntimeError, 4165 alert_message, 4166 msg='expected a non-deterministic error, but it was not raised'): 4167 fn() 4168 4169 else: 4170 # If a nondeterministic error is not expected, make sure 4171 # that it is not raised 4172 try: 4173 fn() 4174 except RuntimeError as e: 4175 if 'does not have a deterministic implementation' in str(e): 4176 self.fail( 4177 'did not expect non-deterministic error message, ' 4178 + 'but got one anyway: "' + str(e) + '"') 4179 # Reraise exceptions unrelated to nondeterminism 4180 raise 4181 4182 # Check that warnings are thrown correctly 4183 with DeterministicGuard(True, warn_only=True): 4184 if should_alert: 4185 with self.assertWarnsRegex( 4186 UserWarning, 4187 alert_message): 4188 fn() 4189 else: 4190 with warnings.catch_warnings(record=True) as w: 4191 warnings.simplefilter("always") 4192 fn() 4193 for warning in w: 4194 if isinstance(warning, UserWarning): 4195 self.assertTrue(re.search(alert_message, str(warning)) is None) 4196 4197 # run code in subprocess and capture exceptions. 4198 @staticmethod 4199 def run_process_no_exception(code, env=None): 4200 import subprocess 4201 4202 popen = subprocess.Popen( 4203 [sys.executable, '-c', code], 4204 stdout=subprocess.PIPE, 4205 stderr=subprocess.PIPE, 4206 env=env) 4207 (stdout, stderr) = popen.communicate() 4208 return (stdout, stderr) 4209 4210 # returns captured stderr 4211 @staticmethod 4212 def runWithPytorchAPIUsageStderr(code): 4213 env = os.environ.copy() 4214 env["PYTORCH_API_USAGE_STDERR"] = "1" 4215 # remove CI flag since this is a wrapped test process. 4216 # CI flag should be set in the parent process only. 4217 env.pop("CI", None) 4218 env.pop("TEST_SHOWLOCALS", None) 4219 (stdout, stderr) = TestCase.run_process_no_exception(code, env=env) 4220 return stderr.decode('ascii') 4221 4222 4223class TestCaseBase(TestCase): 4224 # Calls to super() in dynamically created classes are a bit odd. 4225 # See https://github.com/pytorch/pytorch/pull/118586 for more info 4226 # Subclassing this class and then calling super(TestCaseBase) will run 4227 # TestCase's setUp, tearDown etc functions 4228 pass 4229 4230 4231def download_file(url, binary=True): 4232 from urllib.parse import urlsplit 4233 from urllib import request, error 4234 4235 filename = os.path.basename(urlsplit(url)[2]) 4236 data_dir = get_writable_path(os.path.join(os.path.dirname(__file__), 'data')) 4237 path = os.path.join(data_dir, filename) 4238 4239 if os.path.exists(path): 4240 return path 4241 try: 4242 data = request.urlopen(url, timeout=15).read() 4243 with open(path, 'wb' if binary else 'w') as f: 4244 f.write(data) 4245 return path 4246 except error.URLError as e: 4247 msg = f"could not download test file '{url}'" 4248 warnings.warn(msg, RuntimeWarning) 4249 raise unittest.SkipTest(msg) from e 4250 4251def find_free_port(): 4252 """ 4253 Finds an available port and returns that port number. 
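
    Example (illustrative):

    >>> # xdoctest: +SKIP("binds a real socket")
    >>> port = find_free_port()
    >>> assert 0 < port < 65536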

    NOTE: If this function is being used to allocate a port to Store (or
    indirectly via init_process_group or init_rpc), it should be used
    in conjunction with the `retry_on_connect_failures` decorator, as there is
    a potential race condition where the allocated port may become unavailable
    before it can be used.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('localhost', 0))
        _, port = sock.getsockname()
        return port

# Errors that we can get in c10d initialization, for which we should retry tests.
ADDRESS_IN_USE = "Address already in use"
CONNECT_TIMEOUT = "connect() timed out."

# NB: the trailing comma matters: connect_errors must be a tuple of strings,
# not a bare string, since it is iterated over below.
def retry_on_connect_failures(func=None, connect_errors=(ADDRESS_IN_USE,)):
    """Reruns a test if the test returns a RuntimeError and the exception
    contains one of the strings in connect_errors."""
    # This if block is executed when using this function as a decorator with arguments.
    if func is None:
        return partial(retry_on_connect_failures, connect_errors=connect_errors)

    @wraps(func)
    def wrapper(*args, **kwargs):
        n_retries = 10
        tries_remaining = n_retries
        while True:
            try:
                return func(*args, **kwargs)
            except RuntimeError as error:
                if any(connect_error in str(error) for connect_error in connect_errors):
                    tries_remaining -= 1
                    if tries_remaining == 0:
                        raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}") from error
                    time.sleep(random.random())
                    continue
                raise
    return wrapper


# Decorator to retry upon certain Exceptions.
def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
    def deco_retry(f):
        @wraps(f)
        def f_retry(*args, **kwargs):
            mtries, mdelay = tries, delay
            while mtries > 1:
                try:
                    return f(*args, **kwargs)
                except ExceptionToCheck as e:
                    msg = "%s, Retrying in %d seconds..." % (str(e), mdelay)
                    print(msg)
                    time.sleep(mdelay)
                    mtries -= 1
            try:
                return f(*args, **kwargs)
            except ExceptionToCheck as e:
                # Retries are exhausted: either convert the failure into a
                # skip or re-raise the original exception.
                if skip_after_retries:
                    raise unittest.SkipTest(f"Skipping after {tries} consecutive {str(e)}") from e
                raise
        return f_retry  # true decorator
    return deco_retry


# FIXME: modernize these to be consistent with make_tensor
#        and review including them in torch.testing
# Methods for matrix generation

def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
    assert rank <= l
    A = torch.randn(l, l, dtype=dtype, device=device)
    u, s, vh = torch.linalg.svd(A, full_matrices=False)
    for i in range(l):
        if i >= rank:
            s[i] = 0
        elif s[i] == 0:
            s[i] = 1
    return (u * s.to(dtype).unsqueeze(-2)) @ vh

def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001):
    """
    Returns a random rectangular matrix (batch of matrices)
    with singular values sampled from a Gaussian with
    mean `mean` and standard deviation `sigma`.
    The smaller the `sigma`, the better conditioned
    the output matrix is.
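
    Example (illustrative):

    >>> # xdoctest: +SKIP("undefined variables")
    >>> a = random_well_conditioned_matrix(2, 5, 5, dtype=torch.double, device=device)
    >>> torch.linalg.cond(a)  # expected to be close to 1 for the default sigma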
    """
    primitive_dtype = {
        torch.float: torch.float,
        torch.double: torch.double,
        torch.cfloat: torch.float,
        torch.cdouble: torch.double
    }
    x = torch.rand(shape, dtype=dtype, device=device)
    m = x.size(-2)
    n = x.size(-1)
    u, _, vh = torch.linalg.svd(x, full_matrices=False)
    s = (torch.randn(*(shape[:-2] + (min(m, n),)), dtype=primitive_dtype[dtype], device=device) * sigma + mean) \
        .sort(-1, descending=True).values.to(dtype)
    return (u * s.unsqueeze(-2)) @ vh

# Returns a noncontiguous tensor with the same shape and values as t.
# The noncontiguous tensor is constructed such that elements in the innermost
# dimension are separated by zeros or (whenever possible) nans.
# TODO: consider more complicated noncontiguity schemes
def noncontiguous_like(t):
    # Short-circuits if t is already noncontiguous
    if not t.is_contiguous():
        return t

    # Choose a "weird" value that won't be accessed
    if t.dtype.is_floating_point or t.dtype.is_complex:
        value = math.nan
    elif t.dtype == torch.bool:
        value = True
    else:
        value = 12

    result = t.new_empty(t.shape + (2,))
    result[..., 0] = value
    result[..., 1] = t.detach()
    result = result[..., 1]
    result.requires_grad_(t.requires_grad)
    return result

# TODO: remove this (prefer make_symmetric_matrices below)
def random_symmetric_matrix(l, *batches, **kwargs):
    dtype = kwargs.get('dtype', torch.double)
    device = kwargs.get('device', 'cpu')
    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
    A = (A + A.mT).div_(2)
    return A

# Creates a symmetric matrix or a batch of symmetric matrices.
# The shape must describe a square matrix or a batch of square matrices.
def make_symmetric_matrices(*shape, device, dtype):
    assert shape[-1] == shape[-2]
    t = make_tensor(shape, device=device, dtype=dtype)
    t = (t + t.mT).div_(2)
    return t

def random_hermitian_matrix(l, *batches, **kwargs):
    dtype = kwargs.get('dtype', torch.double)
    device = kwargs.get('device', 'cpu')
    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
    A = (A + A.mH).div_(2)
    return A


def random_symmetric_psd_matrix(l, *batches, **kwargs):
    """
    Returns a batch of random symmetric positive-semi-definite matrices.
    The shape of the result is batch_dims + (matrix_size, matrix_size).
    The following example creates a tensor of size 2 x 4 x 3 x 3:
    >>> # xdoctest: +SKIP("undefined variables")
    >>> matrices = random_symmetric_psd_matrix(3, 2, 4, dtype=dtype, device=device)
    """
    dtype = kwargs.get('dtype', torch.double)
    device = kwargs.get('device', 'cpu')
    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
    return A @ A.mT


def random_hermitian_psd_matrix(matrix_size, *batch_dims, dtype=torch.double, device='cpu'):
    """
    Returns a batch of random Hermitian positive-semi-definite matrices.
    The shape of the result is batch_dims + (matrix_size, matrix_size).
    The following example creates a tensor of size 2 x 4 x 3 x 3:
    >>> # xdoctest: +SKIP("undefined variables")
    >>> matrices = random_hermitian_psd_matrix(3, 2, 4, dtype=dtype, device=device)
    """
    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), dtype=dtype, device=device)
    return A @ A.mH


# TODO: remove this (prefer make_symmetric_pd_matrices below)
def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs):
    dtype = kwargs.get('dtype', torch.double)
    device = kwargs.get('device', 'cpu')
    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
                    dtype=dtype, device=device)
    return torch.matmul(A, A.mT) \
        + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5


# Creates a symmetric positive-definite matrix or a batch of
# such matrices
def make_symmetric_pd_matrices(*shape, device, dtype):
    assert shape[-1] == shape[-2]
    t = make_tensor(shape, device=device, dtype=dtype)
    i = torch.eye(shape[-1], device=device, dtype=dtype) * 1e-5
    return t @ t.mT + i

def random_hermitian_pd_matrix(matrix_size, *batch_dims, dtype, device):
    """
    Returns a batch of random Hermitian positive-definite matrices.
    The shape of the result is batch_dims + (matrix_size, matrix_size).
    The following example creates a tensor of size 2 x 4 x 3 x 3:
    >>> # xdoctest: +SKIP("undefined variables")
    >>> matrices = random_hermitian_pd_matrix(3, 2, 4, dtype=dtype, device=device)
    """
    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
                    dtype=dtype, device=device)
    return A @ A.mH + torch.eye(matrix_size, dtype=dtype, device=device)

# Creates a full-rank matrix with distinct singular values, or
# a batch of such matrices
def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype, requires_grad=False):
    with torch.no_grad():
        t = make_tensor(shape, device=device, dtype=dtype)
        u, _, vh = torch.linalg.svd(t, full_matrices=False)
        real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
        k = min(shape[-1], shape[-2])
        # We choose the singular values to be "around one".
        # This is to make the matrix well conditioned:
        # s = [2, 3, ..., k+1]
        s = torch.arange(2, k + 2, dtype=real_dtype, device=device)
        # s = [2, -3, 4, ..., (-1)^k (k+1)]
        s[1::2] *= -1.
        # 1 + 1/s so that the singular values are in the range [2/3, 3/2].
        # This gives a condition number of 9/4, which should be good enough.
        s.reciprocal_().add_(1.)
        # Note that the singular values need not be ordered in an SVD, so
        # we don't need to sort s.
        x = (u * s.to(u.dtype)) @ vh
    x.requires_grad_(requires_grad)
    return x

def random_matrix(rows, columns, *batch_dims, **kwargs):
    """Return a rectangular matrix or a batch of rectangular matrices.

    Parameters:
      dtype - the data type
      device - the device kind
      silent - when True and LAPACK is unavailable, return a matrix of ones
      singular - when True, the output will be singular
    """
    dtype = kwargs.get('dtype', torch.double)
    device = kwargs.get('device', 'cpu')
    silent = kwargs.get("silent", False)
    singular = kwargs.get("singular", False)
    if silent and not torch._C.has_lapack:
        return torch.ones(rows, columns, dtype=dtype, device=device)

    A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device)
    if A.numel() == 0:
        return A
    u, _, vh = torch.linalg.svd(A, full_matrices=False)
    k = min(rows, columns)
    s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device)
    if singular:
        # make the matrix singular
        s[k - 1] = 0
        if k > 2:
            # increase the order of the singularity so that the pivoting
            # in the LU factorization will be non-trivial
            s[0] = 0
    return (u * s.unsqueeze(-2)) @ vh


def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs):
    """Return a rectangular matrix, or a batch of rectangular matrices, with
    the given rank.
    """
    B = random_matrix(rows, rank, *batch_dims, **kwargs)
    C = random_matrix(rank, columns, *batch_dims, **kwargs)
    return B.matmul(C)


def _generate_indices_prefer_all_rows(rows: int, cols: int, num_indices: int) -> torch.Tensor:
    """Generate indices for a rows x cols matrix, preferring at least one index per row if possible."""
    indices = []
    n_per_row = math.ceil(num_indices / rows)
    col_indices = list(range(cols))

    for r in range(rows):
        # Note that this can yield overlapping indices
        for c in random.choices(col_indices, k=n_per_row):
            indices.append((r, c))

    return torch.tensor(indices[:num_indices])


def random_sparse_matrix(rows, columns, density=0.01, **kwargs):
    """Return a rectangular random sparse matrix with the given density.

    The density of the result approaches the given density as the size
    of the matrix increases, provided that the specified density is
    relatively small but higher than min(rows, columns)/(rows * columns),
    which is required for non-singular matrices.
    """
    dtype = kwargs.get('dtype', torch.double)
    device = kwargs.get('device', 'cpu')

    nonzero_elements = max(min(rows, columns), int(rows * columns * density))
    indices = _generate_indices_prefer_all_rows(rows, columns, nonzero_elements)
    values = torch.randn(nonzero_elements, dtype=dtype, device=device)

    # ensure that the diagonal dominates by exponentially damping
    # values that lie far from the diagonal
    values *= torch.tensor([-float(i - j)**2 for i, j in indices], dtype=dtype, device=device).exp()
    A = torch.sparse_coo_tensor(indices.t(), values, (rows, columns), device=device)
    return A.coalesce()


def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
    """Return a random sparse positive-definite matrix with the given density.
def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
    """Return a random sparse positive-definite matrix with the given density.

    The eigenvalues of the matrix are defined as::
        arange(1, matrix_size+1)/matrix_size

    Algorithm:
      A = diag(arange(1, matrix_size+1)/matrix_size)
      while <A density is smaller than required>:
          <choose random i, j in range(matrix_size), theta in [0, 2*pi]>
          R = <rotation matrix (i,j,theta)>
          A = R^T A R
    """
    torch = kwargs.get('torch', globals()['torch'])
    dtype = kwargs.get('dtype', torch.double)
    device = kwargs.get('device', 'cpu')
    data = {(i, i): float(i + 1) / matrix_size
            for i in range(matrix_size)}

    def multiply(data, N, i, j, cs, sn, left=True):
        for k in range(N):
            if left:
                ik, jk = (k, i), (k, j)
            else:
                ik, jk = (i, k), (j, k)
            aik, ajk = data.get(ik, 0), data.get(jk, 0)
            aik, ajk = cs * aik + sn * ajk, -sn * aik + cs * ajk
            if aik:
                data[ik] = aik
            else:
                data.pop(ik, None)
            if ajk:
                data[jk] = ajk
            else:
                data.pop(jk, None)

    target_nnz = density * matrix_size * matrix_size
    while len(data) < target_nnz:
        i = random.randint(0, matrix_size - 1)
        j = random.randint(0, matrix_size - 1)
        if i != j:
            theta = random.uniform(0, 2 * math.pi)
            cs = math.cos(theta)
            sn = math.sin(theta)
            multiply(data, matrix_size, i, j, cs, sn, left=True)
            multiply(data, matrix_size, i, j, cs, sn, left=False)
    icoords, jcoords, values = [], [], []
    for (i, j), v in sorted(data.items()):
        icoords.append(i)
        jcoords.append(j)
        values.append(v)
    indices_tensor = torch.tensor([icoords, jcoords])
    return torch.sparse_coo_tensor(indices_tensor, values, (matrix_size, matrix_size), dtype=dtype, device=device)
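
# The similarity transforms R^T A R above use orthogonal (Givens) rotations,
# so they fill in off-diagonal entries while preserving the eigenvalues
# arange(1, n+1)/n. Illustrative check (skipped at import time):
# >>> # xdoctest: +SKIP("illustrative only")
# >>> A = random_sparse_pd_matrix(4, density=0.5)
# >>> torch.linalg.eigvalsh(A.to_dense())  # ~[0.25, 0.50, 0.75, 1.00]
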
# FIXME: remove this by updating test suites using it
def do_test_dtypes(self, dtypes, layout, device):
    for dtype in dtypes:
        if dtype != torch.float16:
            out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device)
            self.assertIs(dtype, out.dtype)
            self.assertIs(layout, out.layout)
            self.assertEqual(device, out.device)

# FIXME: remove this by updating test suites using it
def do_test_empty_full(self, dtypes, layout, device):
    shape = torch.Size([2, 3])

    def check_value(tensor, dtype, layout, device, value, requires_grad):
        self.assertEqual(shape, tensor.shape)
        self.assertIs(dtype, tensor.dtype)
        self.assertIs(layout, tensor.layout)
        self.assertEqual(tensor.requires_grad, requires_grad)
        if tensor.is_cuda and device is not None:
            self.assertEqual(device, tensor.device)
        if value is not None:
            fill = tensor.new(shape).fill_(value)
            self.assertEqual(tensor, fill)

    def get_int64_dtype(dtype):
        module = '.'.join(str(dtype).split('.')[1:-1])
        if not module:
            return torch.int64
        return operator.attrgetter(module)(torch).int64

    default_dtype = torch.get_default_dtype()
    check_value(torch.empty(shape), default_dtype, torch.strided, -1, None, False)
    check_value(torch.full(shape, -5.), default_dtype, torch.strided, -1, None, False)
    for dtype in dtypes:
        for rg in {dtype.is_floating_point, False}:
            int64_dtype = get_int64_dtype(dtype)
            v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
            check_value(v, dtype, layout, device, None, rg)
            out = v.new()
            check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg),
                        dtype, layout, device, None, rg)
            check_value(v.new_empty(shape), dtype, layout, device, None, False)
            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
                        int64_dtype, layout, device, None, False)
            check_value(torch.empty_like(v), dtype, layout, device, None, False)
            check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
                        int64_dtype, layout, device, None, False)

            if dtype is not torch.float16 and layout != torch.sparse_coo:
                fv = 3
                v = torch.full(shape, fv, dtype=dtype, layout=layout, device=device, requires_grad=rg)
                check_value(v, dtype, layout, device, fv, rg)
                check_value(v.new_full(shape, fv + 1), dtype, layout, device, fv + 1, False)
                out = v.new()
                check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg),
                            dtype, layout, device, fv + 2, rg)
                check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=False),
                            int64_dtype, layout, device, fv + 3, False)
                check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False)
                check_value(torch.full_like(v, fv + 5,
                                            dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
                            int64_dtype, layout, device, fv + 5, False)

# FIXME: improve load_tests() documentation here
running_script_path = None
def set_running_script_path():
    global running_script_path
    try:
        running_file = os.path.abspath(os.path.realpath(sys.argv[0]))
        if running_file.endswith('.py'):  # skip if the running file is not a script
            running_script_path = running_file
    except Exception:
        pass

def check_test_defined_in_running_script(test_case):
    if running_script_path is None:
        return
    test_case_class_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__)))
    assert test_case_class_file == running_script_path, f'Class of loaded TestCase "{test_case.id()}" ' \
        f'is not defined in the running script "{running_script_path}", but in "{test_case_class_file}". Did you ' \
        "accidentally import a unittest.TestCase from another file?"

def load_tests(loader, tests, pattern):
    set_running_script_path()
    test_suite = unittest.TestSuite()
    for test_group in tests:
        if not DISABLE_RUNNING_SCRIPT_CHK:
            for test in test_group:
                check_test_defined_in_running_script(test)
        if test_group._tests:
            test_suite.addTest(test_group)
    return test_suite

# FIXME: document this and move it to test_serialization
class BytesIOContext(io.BytesIO):
    def __enter__(self):
        return self

    def __exit__(self, *args):
        pass

# Tentative value for nondet_tol for gradcheck when backward implementation
# relies on nondeterministic operations, i.e., those listed here:
# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
#
# For more information see https://github.com/pytorch/pytorch/issues/56202
GRADCHECK_NONDET_TOL = 1e-12

TEST_WITH_SLOW_GRADCHECK: bool = TestEnvironment.def_flag(
    "TEST_WITH_SLOW_GRADCHECK",
    env_var="PYTORCH_TEST_WITH_SLOW_GRADCHECK",
)

skipIfSlowGradcheckEnv = unittest.skipIf(
    TEST_WITH_SLOW_GRADCHECK,
    "Tests that don't use gradcheck don't need to run on slow_gradcheck CI",
)
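
# Usage sketch (illustrative; `TestShapeOps` is a hypothetical test class):
# decorate tests that don't exercise gradcheck so they are skipped on the
# slow-gradcheck CI configuration.
# >>> # xdoctest: +SKIP("illustrative only")
# >>> @skipIfSlowGradcheckEnv
# ... class TestShapeOps(TestCase):
# ...     ...
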
def gradcheck(fn, inputs, **kwargs):
    # Wrapper around gradcheck that enables certain keys by default.
    #
    # Use this testing-internal gradcheck instead of autograd.gradcheck so that new features like vmap and
    # forward-mode AD are tested by default. We created this wrapper because we'd like to keep new checks
    # disabled by default in the public-facing API, to avoid breaking user code.
    #
    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradcheck.
    default_values = {
        "check_batched_grad": True,
        "fast_mode": True,
    }

    if TEST_WITH_SLOW_GRADCHECK:
        default_values["fast_mode"] = False

    for key, value in default_values.items():
        # default values also override values explicitly set to None
        k = kwargs.get(key, None)
        kwargs[key] = k if k is not None else value

    return torch.autograd.gradcheck(fn, inputs, **kwargs)

def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):
    # Wrapper around gradgradcheck that enables certain keys by default.
    # See gradcheck above for an explanation of why we need something like this.
    #
    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradgradcheck.
    default_values = {
        "check_batched_grad": True,
        "fast_mode": True,
    }

    if TEST_WITH_SLOW_GRADCHECK:
        default_values["fast_mode"] = False

    for key, value in default_values.items():
        # default values also override values explicitly set to None
        k = kwargs.get(key, None)
        kwargs[key] = k if k is not None else value

    return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)


def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, **kwargs):
    # Call the assert functions rather than returning a bool, since it's nicer
    # to know whether the failure came from the gradcheck or the gradgradcheck.
    test_case.assertTrue(gradcheck(apply_fn, inputs, **kwargs))
    test_case.assertTrue(gradgradcheck(apply_fn, inputs, **kwargs))
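
# Usage sketch (illustrative): the wrappers above are drop-in replacements
# for torch.autograd.gradcheck / gradgradcheck; use double-precision inputs,
# as for the underlying autograd checks.
# >>> # xdoctest: +SKIP("illustrative only")
# >>> x = torch.randn(3, dtype=torch.double, requires_grad=True)
# >>> gradcheck(torch.sin, (x,))  # fast_mode and check_batched_grad enabled
# True
# >>> gradgradcheck(torch.sin, (x,))
# True
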
@contextmanager
def set_cwd(path: str) -> Iterator[None]:
    old_cwd = os.getcwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(old_cwd)


# FIXME: delete this
# Using @toleranceOverride specific to your test is the recommended way
# of doing this. These are just some values that worked for test_nn.
dtype2prec_DONTUSE = {torch.float: 1e-5,
                      torch.double: 1e-5,
                      torch.half: 1e-2,
                      torch.bfloat16: 1e-1}

# FIXME: move to test_sparse or sparse utils
# Decorator that runs the wrapped test twice: once with coalesced=True and
# once with coalesced=False, to cover both coalesced and uncoalesced sparse tensors.
def coalescedonoff(f):
    @wraps(f)
    def wrapped(self, *args, **kwargs):
        f(self, *args, **kwargs, coalesced=True)
        f(self, *args, **kwargs, coalesced=False)
    return wrapped


def is_coalesced_indices(s):
    indices = s._indices()
    hash_coeffs = (1,) + s.shape[s.sparse_dim() - 1:0:-1]
    hash_indices = torch.tensor(hash_coeffs, device=s.device).cumprod(-1).flip(-1)
    if s.sparse_dim() > 1:
        hash_indices.unsqueeze_(-1)
        hash_indices = (indices * hash_indices).sum(0)
    else:
        hash_indices = indices * hash_indices

    # check if indices are sorted
    res = torch.allclose(hash_indices, hash_indices.sort()[0])

    # check if there are no repeated indices
    res = res and torch.allclose(hash_indices, hash_indices.unique())

    return res


@contextlib.contextmanager
def disable_gc():
    if gc.isenabled():
        try:
            gc.disable()
            yield
        finally:
            gc.enable()
    else:
        yield


def find_library_location(lib_name: str) -> Path:
    # Return the shared library file from the installed folder if it exists,
    # else the file from the build folder.
    torch_root = Path(torch.__file__).resolve().parent
    path = torch_root / 'lib' / lib_name
    if os.path.exists(path):
        return path
    torch_root = Path(__file__).resolve().parent.parent.parent
    return torch_root / 'build' / 'lib' / lib_name

def skip_but_pass_in_sandcastle(reason):
    """
    Similar to unittest.skip; however, in the sandcastle environment it just
    "passes" the test instead, to avoid creating tasks complaining about tests
    skipping continuously.
    """
    def decorator(func):
        if not IS_SANDCASTLE:
            func.__unittest_skip__ = True
            func.__unittest_skip_why__ = reason
            return func

        @wraps(func)
        def wrapper(*args, **kwargs):
            print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
            return
        return wrapper

    return decorator

def mock_wrapper(method):
    """
    Returns a function that calls the real implementation of a method
    in addition to passing args to a mock object.
    """
    mock = MagicMock()

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        mock(*args, **kwargs)
        return method(self, *args, **kwargs)
    wrapper.mock = mock  # type: ignore[attr-defined]
    return wrapper

def get_tensors_from(args, kwargs):
    """Returns a set of all Tensor objects in the given args and kwargs."""
    return set([arg for arg in args if isinstance(arg, Tensor)] +
               [v for v in kwargs.values() if isinstance(v, Tensor)])
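
# Usage sketch for mock_wrapper above (illustrative; `Foo` is a hypothetical class):
# >>> # xdoctest: +SKIP("illustrative only")
# >>> class Foo:
# ...     def bar(self, x):
# ...         return x + 1
# >>> Foo.bar = mock_wrapper(Foo.bar)
# >>> Foo().bar(1)
# 2
# >>> Foo.bar.mock.assert_called_once_with(1)
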
""" 4887 return set([arg for arg in args if isinstance(arg, Tensor)] + 4888 [v for v in kwargs.values() if isinstance(v, Tensor)]) 4889 4890 4891# Returns scalar tensor representation of a list of integer byte values 4892def bytes_to_scalar(byte_list: List[int], dtype: torch.dtype, device: torch.device): 4893 dtype_to_ctype: Dict[torch.dtype, Any] = { 4894 torch.int8: ctypes.c_int8, 4895 torch.uint8: ctypes.c_uint8, 4896 torch.uint16: ctypes.c_uint16, 4897 torch.uint32: ctypes.c_uint32, 4898 torch.uint64: ctypes.c_uint64, 4899 torch.int16: ctypes.c_int16, 4900 torch.int32: ctypes.c_int32, 4901 torch.int64: ctypes.c_int64, 4902 torch.bool: ctypes.c_bool, 4903 torch.float32: ctypes.c_float, 4904 torch.complex64: ctypes.c_float, 4905 torch.float64: ctypes.c_double, 4906 torch.complex128: ctypes.c_double, 4907 } 4908 ctype = dtype_to_ctype[dtype] 4909 num_bytes = ctypes.sizeof(ctype) 4910 4911 def check_bytes(byte_list): 4912 for byte in byte_list: 4913 assert 0 <= byte <= 255 4914 4915 if dtype.is_complex: 4916 assert len(byte_list) == (num_bytes * 2) 4917 check_bytes(byte_list) 4918 real = ctype.from_buffer((ctypes.c_byte * num_bytes)( 4919 *byte_list[:num_bytes])).value 4920 imag = ctype.from_buffer((ctypes.c_byte * num_bytes)( 4921 *byte_list[num_bytes:])).value 4922 res = real + 1j * imag 4923 else: 4924 assert len(byte_list) == num_bytes 4925 check_bytes(byte_list) 4926 res = ctype.from_buffer((ctypes.c_byte * num_bytes)( 4927 *byte_list)).value 4928 4929 return torch.tensor(res, device=device, dtype=dtype) 4930 4931 4932def copy_func(f): 4933 """Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)""" 4934 g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, 4935 argdefs=f.__defaults__, 4936 closure=f.__closure__) 4937 g = functools.update_wrapper(g, f) 4938 g.__kwdefaults__ = f.__kwdefaults__ 4939 return g 4940 4941 4942def xfail_inherited_tests(tests): 4943 """ 4944 Given a list of test names which are defined by a superclass of the 4945 class this decorates, mark them as expected failure. This is useful 4946 if you are doing poor man's parameterized tests by subclassing a generic 4947 test class. 4948 """ 4949 def deco(cls): 4950 for t in tests: 4951 # NB: expectedFailure operates by mutating the method in question, 4952 # which is why you have to copy the function first 4953 setattr(cls, t, unittest.expectedFailure(copy_func(getattr(cls, t)))) 4954 return cls 4955 return deco 4956 4957 4958def skip_but_pass_in_sandcastle_if(condition, reason): 4959 """ 4960 Similar to unittest.skipIf, however in the sandcastle environment it just 4961 "passes" the test instead to avoid creating tasks complaining about tests 4962 skipping continuously. 4963 """ 4964 def decorator(func): 4965 if condition: 4966 if IS_SANDCASTLE: 4967 @wraps(func) 4968 def wrapper(*args, **kwargs): 4969 print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr) 4970 return wrapper 4971 else: 4972 func.__unittest_skip__ = True 4973 func.__unittest_skip_why__ = reason 4974 4975 return func 4976 4977 return decorator 4978 4979def dtype_name(dtype): 4980 """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). 
""" 4981 return str(dtype).split('.')[1] 4982 4983 4984dtype_abbrs = { 4985 torch.bfloat16: 'bf16', 4986 torch.float64: 'f64', 4987 torch.float32: 'f32', 4988 torch.float16: 'f16', 4989 torch.complex32: 'c32', 4990 torch.complex64: 'c64', 4991 torch.complex128: 'c128', 4992 torch.int8: 'i8', 4993 torch.int16: 'i16', 4994 torch.int32: 'i32', 4995 torch.int64: 'i64', 4996 torch.bool: 'b8', 4997 torch.uint8: 'u8', 4998} 4999 5000 5001@functools.lru_cache 5002def get_cycles_per_ms() -> float: 5003 """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep 5004 """ 5005 5006 def measure() -> float: 5007 start = torch.cuda.Event(enable_timing=True) 5008 end = torch.cuda.Event(enable_timing=True) 5009 start.record() 5010 torch.cuda._sleep(1000000) 5011 end.record() 5012 end.synchronize() 5013 cycles_per_ms = 1000000 / start.elapsed_time(end) 5014 return cycles_per_ms 5015 5016 # Get 10 values and remove the 2 max and 2 min and return the avg. 5017 # This is to avoid system disturbance that skew the results, e.g. 5018 # the very first cuda call likely does a bunch of init, which takes 5019 # much longer than subsequent calls. 5020 # 5021 # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs 5022 # and seems to return stable values. Therefore, we enable caching 5023 # using lru_cache decorator above. 5024 num = 10 5025 vals = [] 5026 for _ in range(num): 5027 vals.append(measure()) 5028 vals = sorted(vals) 5029 return mean(vals[2 : num - 2]) 5030 5031 5032# OpInfo utils 5033 5034T = TypeVar('T') 5035def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T: 5036 """ 5037 Returns the first sample from an iterable of samples, like those returned by OpInfo. 5038 The test will be skipped if no samples are available. 5039 """ 5040 try: 5041 return next(iter(samples)) 5042 except StopIteration as e: 5043 raise unittest.SkipTest('Skipped! 
# OpInfo utils

T = TypeVar('T')
def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T:
    """
    Returns the first sample from an iterable of samples, like those returned by OpInfo.
    The test will be skipped if no samples are available.
    """
    try:
        return next(iter(samples))
    except StopIteration as e:
        raise unittest.SkipTest('Skipped! Need at least 1 sample input') from e

# Helper to recursively clone the tensor-type inputs of operators tested by OpInfo
def clone_input_helper(input):
    if isinstance(input, torch.Tensor):
        return torch.clone(input)

    if isinstance(input, Sequence):
        return tuple(map(clone_input_helper, input))

    return input

@contextmanager
def custom_op(opname, symbolic_fn, opset_version):
    """Context manager/decorator to test ONNX export with a custom operator"""
    try:
        register_custom_op_symbolic(opname, symbolic_fn, opset_version)
        yield
    finally:
        unregister_custom_op_symbolic(opname, opset_version)


def outs_and_grads(fn, graph_inps, inps):
    outs = fn(*graph_inps)
    for out in pytree.tree_leaves(outs):
        if isinstance(out, torch.Tensor) and out.requires_grad:
            out.sum().backward(retain_graph=True)
    grads = [inp.grad for inp in pytree.tree_leaves(inps) if isinstance(inp, torch.Tensor)]
    for inp in pytree.tree_leaves(inps):
        if isinstance(inp, torch.Tensor):
            inp.grad = None
    return outs, grads

def compare_equal_outs_and_grads(test, m1, m2, inps):
    r1, g1 = outs_and_grads(m1, inps, inps)
    r2, g2 = outs_and_grads(m2, inps, inps)
    test.assertEqual(r1, r2)
    test.assertEqual(g1, g2)
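
# Usage sketch (illustrative; `self` is a TestCase and `m` a hypothetical
# differentiable callable): compare eager outputs/grads against a compiled copy.
# >>> # xdoctest: +SKIP("undefined variables")
# >>> inps = (torch.randn(2, 2, requires_grad=True),)
# >>> compare_equal_outs_and_grads(self, m, torch.compile(m), inps)
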
class TestGradients(TestCase):
    exact_dtype = True

    # Copies inputs to inplace operations to avoid inplace modifications
    # to leaves requiring gradient
    def _get_safe_inplace(self, inplace_variant):
        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            return inplace_variant(t.clone(), *args, **kwargs)

        return _fn

    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
                      check_batched_grad=None, check_batched_forward_grad=False):
        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
        # NB: check_backward_ad does not affect gradgradcheck (always True)
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        def is_inplace(variant):
            if hasattr(variant, "__wrapped__"):
                return variant.__wrapped__ is op.get_inplace()
            return variant is op.get_inplace()

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex

        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
                                   small_inputs_only=TEST_WITH_SLOW_GRADCHECK)

        for sample in samples:
            if sample.broadcasts_input and is_inplace(variant):
                continue

            # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
            # and tensors passed as kwargs. The following creates a function that accepts just
            # the tensors that require grad as varargs, and then recomposes them back into the
            # original input.

            # Creates gradcheck inputs by identifying tensors requiring grad
            all_args = None
            if is_iterable_of_tensors(sample.input):
                all_args = chain(sample.input, sample.args, sample.kwargs.values())
            else:
                all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
            gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))

            # Verifies that sample input tensors have no grad.
            # A leftover grad can happen if the same tensor is used in two different SampleInputs.
            for t in gradcheck_args:
                self.assertIsNone(t.grad,
                                  "A sampled input has a gradient before running autograd. "
                                  "This usually means that (at least) one input tensor is reused "
                                  "across different SampleInputs. "
                                  "Please create a new tensor for each SampleInput.")

            def _input_recomposition_helper(inputs, inp, input_idx):
                if is_iterable_of_tensors(inp):
                    tensor_list = []
                    for x in inp:
                        if isinstance(x, torch.Tensor) and x.requires_grad:
                            tensor_list.append(inputs[input_idx])
                            input_idx = input_idx + 1
                        else:
                            tensor_list.append(x)
                    return tensor_list, input_idx
                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
                    return inputs[input_idx], input_idx + 1
                else:
                    return inp, input_idx

            def fn(*inputs):
                # Puts inputs back into sample properly
                positional_args = []
                input_idx = 0
                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
                positional_args.append(inp)

                for x in sample.args:
                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
                    positional_args.append(inp)

                # Recreates kwargs
                kwargs = {}
                for k, v in sample.kwargs.items():
                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
                    kwargs[k] = inp

                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            if check == 'gradcheck':
                if check_batched_grad is None:
                    check_batched_grad = op.check_batched_grad
                self.assertTrue(gradcheck(fn, gradcheck_args,
                                          check_batched_grad=check_batched_grad,
                                          check_grad_dtypes=True,
                                          nondet_tol=op.gradcheck_nondet_tol,
                                          fast_mode=op.gradcheck_fast_mode,
                                          check_forward_ad=check_forward_ad,
                                          check_backward_ad=check_backward_ad,
                                          check_undefined_grad=True,
                                          check_batched_forward_grad=check_batched_forward_grad))
            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
                for gen_non_contig_grad_outputs in (False, True):
                    kwargs = {
                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
                        "check_batched_grad": op.check_batched_gradgrad,
                        "check_grad_dtypes": True,
                        "nondet_tol": op.gradcheck_nondet_tol,
                        "fast_mode": op.gradcheck_fast_mode
                    }
                    if check == "fwgrad_bwgrad":
                        kwargs["check_fwd_over_rev"] = True
                        kwargs["check_rev_over_rev"] = False
                        kwargs["check_batched_grad"] = False
                        kwargs["check_undefined_grad"] = False

                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
            else:
                self.assertTrue(False, msg="Unknown check requested!")

    def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
                          check_batched_grad=None, check_batched_forward_grad=False):
        return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
                                  check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
                                  check_batched_forward_grad=check_batched_forward_grad)

    def _skip_helper(self, op, device, dtype):
        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest("Skipped! Op doesn't support autograd for this dtype.")
        if not op.supports_autograd and not op.supports_forward_ad:
            self.skipTest("Skipped! autograd not supported.")

def make_lazy_class(cls):

    def lazy_init(self, cb):
        self._cb = cb
        self._value = None

    cls.__init__ = lazy_init

    for basename in [
        "add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow",
        "lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert",
        "eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index",
    ]:
        name = f"__{basename}__"

        def inner_wrapper(name):
            use_operator = basename not in ("bool", "int")

            def wrapped(self, *args, **kwargs):
                if self._cb is not None:
                    self._value = self._cb()
                    self._cb = None
                if not use_operator:
                    return getattr(self._value, name)(*args, **kwargs)
                else:
                    return getattr(operator, name)(self._value, *args, **kwargs)
            return wrapped

        setattr(cls, name, inner_wrapper(name))

    return cls


# Base TestCase for NT tests; used to define common helpers, etc.
class NestedTensorTestCase(TestCase):
    def assertEqualIgnoringNestedInts(self, a, b):
        # unbinding NJTs allows us to compare them as essentially equal without
        # caring about exact nested int comparison
        def _unbind_njts(x):
            if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
                return x.unbind()
            else:
                return x

        self.assertEqual(pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b))

    @contextlib.contextmanager
    def branch_nested_state(self):
        """Context manager to branch and restore the nested tensor state."""
        nested_tensor_module = torch.nested._internal.nested_tensor
        original_tensor_symint_registry = nested_tensor_module._tensor_symint_registry.copy()
        original_tensor_id_counter = nested_tensor_module._tensor_id_counter
        try:
            yield
        finally:
            nested_tensor_module._tensor_id_counter = original_tensor_id_counter
            nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry


@make_lazy_class
class LazyVal:
    pass
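
# Usage sketch (illustrative): LazyVal defers computing a value until it is
# first used in an operation.
# >>> # xdoctest: +SKIP("illustrative only")
# >>> v = LazyVal(lambda: 41)
# >>> v + 1  # the callback runs here
# 42
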
def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=0):
    if file is None:
        file = inspect.stack()[1 + skip].filename  # skip one frame

    file = _as_posix_path(file)
    s = _as_posix_path(str(e))

    # Remove everything that looks like stack frames that are NOT in this file
    def repl_frame(m):
        if m.group(1) != file:
            return ""
        # Don't accept top-level, even for this script; these will wobble
        # depending on how the testing script was invoked
        if m.group(2) == "<module>":
            return ""

        return m.group(0)

    s = re.sub(r'  File "([^"]+)", line \d+, in (.+)\n(    .+\n( +[~^]+ *\n)?)+', repl_frame, s)
    s = re.sub(r"line \d+", "line N", s)
    s = re.sub(r".py:\d+", ".py:N", s)
    s = re.sub(file, _as_posix_path(os.path.basename(file)), s)
    s = re.sub(_as_posix_path(os.path.join(os.path.dirname(torch.__file__), "")), "", s)
    if suppress_suffix:
        s = re.sub(r"\n*Set TORCH_LOGS.+", "", s, flags=re.DOTALL)
        s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL)
    if suppress_prefix:
        s = re.sub(r"Cannot export model.+\n\n", "", s)
    s = re.sub(r" +$", "", s, flags=re.MULTILINE)
    return s
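
# Usage sketch (illustrative): normalize an exception message for comparison
# with expecttest, replacing line numbers and stripping frames from other files.
# >>> # xdoctest: +SKIP("illustrative only")
# >>> try:
# ...     raise RuntimeError("failed at foo.py:123")
# ... except RuntimeError as e:
# ...     print(munge_exc(e))
# failed at foo.py:N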