xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/common_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3r"""Importing this file must **not** initialize CUDA context. test_distributed
4relies on this assumption to properly run. This means that when this is imported
5no CUDA calls shall be made, including torch.cuda.device_count(), etc.
6
7torch.testing._internal.common_cuda.py can freely initialize CUDA context when imported.
8"""
9
10import argparse
11import contextlib
12import copy
13import ctypes
14import errno
15import functools
16import gc
17import hashlib
18import inspect
19import io
20import json
21import logging
22import math
23import operator
24import os
25import platform
26import random
27import re
28import shutil
29import signal
30import socket
31import subprocess
32import sys
33import tempfile
34import threading
35import time
36import types
37import unittest
38import warnings
39from collections.abc import Mapping, Sequence
40from contextlib import closing, contextmanager
41from copy import deepcopy
42from dataclasses import dataclass
43from enum import Enum
44from functools import partial, wraps
45from itertools import product, chain
46from pathlib import Path
47from statistics import mean
48from typing import (
49    Any,
50    Callable,
51    Dict,
52    Iterable,
53    Iterator,
54    List,
55    Optional,
56    Tuple,
57    Type,
58    TypeVar,
59    Union,
60)
61from unittest.mock import MagicMock
62
63import expecttest
64import numpy as np
65
66import __main__  # type: ignore[import]
67import torch
68import torch.backends.cudnn
69import torch.backends.mkl
70import torch.backends.mps
71import torch.backends.xnnpack
72import torch.cuda
73from torch import Tensor
74from torch._C import ScriptDict, ScriptList  # type: ignore[attr-defined]
75from torch._dynamo.trace_rules import _as_posix_path
76from torch._utils_internal import get_writable_path
77from torch.nn import (
78    ModuleDict,
79    ModuleList,
80    ParameterDict,
81    ParameterList,
82    Sequential,
83)
84from torch.onnx import (
85    register_custom_op_symbolic,
86    unregister_custom_op_symbolic,
87)
88from torch.testing import make_tensor
89from torch.testing._comparison import (
90    BooleanPair,
91    NonePair,
92    NumberPair,
93    Pair,
94    TensorLikePair,
95)
96from torch.testing._comparison import not_close_error_metas
97from torch.testing._internal.common_dtype import get_all_dtypes
98from torch.utils._import_utils import _check_module_exists
99import torch.utils._pytree as pytree
100try:
101    import pytest
102    has_pytest = True
103except ImportError:
104    has_pytest = False
105
106
107MI300_ARCH = ("gfx940", "gfx941", "gfx942")
108
109
110def freeze_rng_state(*args, **kwargs):
111    return torch.testing._utils.freeze_rng_state(*args, **kwargs)
112
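# Usage sketch (illustrative): freeze_rng_state acts as a context manager, so RNG use inside
# the block does not perturb the global random state of the surrounding test.
#
#     with freeze_rng_state():
#         torch.manual_seed(0)
#         x = torch.randn(3)  # RNG state outside the block is unaffected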
113
114# Class to keep track of test flags configurable by environment variables.
115# Flags set here are intended to be read-only and should not be modified after
116# definition.
117# TODO: Expand this class to handle arbitrary settings in addition to boolean flags?
118class TestEnvironment:
119    # Set of env vars to set for the repro command that is output on test failure.
120    # Specifically, this includes env vars that are set to non-default values and
121    # are not implied. Maps from env var name -> value (int)
122    repro_env_vars: dict = {}
123
124    # Defines a flag usable throughout the test suite, determining its value by querying
125    # the specified environment variable.
126    #
127    # Args:
128    #     name (str): The name of the flag. A global variable with this name will be set
129    #         for convenient access throughout the test suite.
130    #     env_var (str): The name of the primary environment variable from which to
131    #         determine the value of this flag. If this is None or the environment variable
132    #         is unset, the default value will be used unless otherwise implied (see
133    #         implied_by_fn). Default: None
134    #     default (bool): The default value to use for the flag if unset by the environment
135    #         variable and unimplied. Default: False
136    #     include_in_repro (bool): Indicates whether this flag should be included in the
137    #         repro command that is output on test failure (i.e. whether it is possibly
138    #         relevant to reproducing the test failure). Default: True
139    #     enabled_fn (Callable): Callable returning whether the flag should be enabled
140    #         given the environment variable value and the default value. Default: Lambda
141    #         requiring "0" to disable if on by default OR "1" to enable if off by default.
142    #     implied_by_fn (Callable): Thunk returning a bool to imply this flag as enabled
143    #         by something outside of its primary environment variable setting. For example,
144    #         this can be useful if the value of another environment variable implies the flag
145    #         as enabled. Default: Lambda returning False to indicate no implications.
146    @staticmethod
147    def def_flag(
148        name,
149        env_var=None,
150        default=False,
151        include_in_repro=True,
152        enabled_fn=lambda env_var_val, default: (
153            (env_var_val != "0") if default else (env_var_val == "1")),
154        implied_by_fn=lambda: False,
155    ):
156        enabled = default
157        if env_var is not None:
158            env_var_val = os.getenv(env_var)
159            enabled = enabled_fn(env_var_val, default)
160        implied = implied_by_fn()
161        enabled = enabled or implied
162        if include_in_repro and (env_var is not None) and (enabled != default) and not implied:
163            TestEnvironment.repro_env_vars[env_var] = env_var_val
164
165        # export flag globally for convenience
166        assert name not in globals(), f"duplicate definition of flag '{name}'"
167        globals()[name] = enabled
168        return enabled
169
170    # Defines a setting usable throughout the test suite, determining its value by querying
171    # the specified environment variable. This differs from a flag in that it's not restricted
172    # to a boolean value.
173    #
174    # Args:
175    #     name (str): The name of the setting. A global variable with this name will be set
176    #         for convenient access throughout the test suite.
177    #     env_var (str): The name of the primary environment variable from which to
178    #         determine the value of this setting. If this is None or the environment variable
179    #         is unset, the default value will be used. Default: None
180    #     default (Any): The default value to use for the setting if unset by the environment
181    #         variable. Default: None
182    #     include_in_repro (bool): Indicates whether this setting should be included in the
183    #         repro command that is output on test failure (i.e. whether it is possibly
184    #         relevant to reproducing the test failure). Default: True
185    #     parse_fn (Callable): Callable parsing the env var string. Default value just uses
186    #         the string itself.
187    @staticmethod
188    def def_setting(
189        name,
190        env_var=None,
191        default=None,
192        include_in_repro=True,
193        parse_fn=lambda maybe_val_str: maybe_val_str,
194    ):
195        value = default if env_var is None else os.getenv(env_var)
196        value = parse_fn(value)
197        if include_in_repro and (value != default):
198            TestEnvironment.repro_env_vars[env_var] = value
199
200        # export setting globally for convenience
201        assert name not in globals(), f"duplicate definition of setting '{name}'"
202        globals()[name] = value
203        return value
204
205    # Returns a string prefix usable to set environment variables for any test
206    # settings that should be explicitly set to match this instantiation of the
207    # test suite.
208    # Example: "PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_ROCM=1"
209    @staticmethod
210    def repro_env_var_prefix() -> str:
211        return " ".join([f"{env_var}={value}"
212                         for env_var, value in TestEnvironment.repro_env_vars.items()])
213
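# Illustrative sketch (the flag/setting names and env vars below are hypothetical, not ones
# defined in this file): a boolean flag and a parsed integer setting would be declared as
#
#     TEST_WITH_MY_FEATURE: bool = TestEnvironment.def_flag(
#         "TEST_WITH_MY_FEATURE",
#         env_var="PYTORCH_TEST_WITH_MY_FEATURE",
#     )
#     MY_FEATURE_LEVEL: Optional[int] = TestEnvironment.def_setting(
#         "MY_FEATURE_LEVEL",
#         env_var="PYTORCH_TEST_MY_FEATURE_LEVEL",
#         default=None,
#         parse_fn=lambda val: None if val is None else int(val),
#     )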
214
215log = logging.getLogger(__name__)
216torch.backends.disable_global_flags()
217
218FILE_SCHEMA = "file://"
219if sys.platform == 'win32':
220    FILE_SCHEMA = "file:///"
221
222# NB: This flag differs semantically from others in that setting the env var to any
223# non-empty value will cause it to be true:
224#   CI=1, CI="true", CI=0, etc. all set the flag to be true.
225#   CI= and an unset CI set the flag to be false.
226# GitHub sets the value to CI="true" to enable it.
227IS_CI: bool = TestEnvironment.def_flag(
228    "IS_CI",
229    env_var="CI",
230    include_in_repro=False,
231    enabled_fn=lambda env_var_value, _: bool(env_var_value),
232)
233IS_SANDCASTLE: bool = TestEnvironment.def_flag(
234    "IS_SANDCASTLE",
235    env_var="SANDCASTLE",
236    implied_by_fn=lambda: os.getenv("TW_JOB_USER") == "sandcastle",
237    include_in_repro=False,
238)
239
240_is_fbcode_default = (
241    hasattr(torch._utils_internal, "IS_FBSOURCE") and
242    torch._utils_internal.IS_FBSOURCE
243)
244
245IS_FBCODE: bool = TestEnvironment.def_flag(
246    "IS_FBCODE",
247    env_var="PYTORCH_TEST_FBCODE",
248    default=_is_fbcode_default,
249    include_in_repro=False,
250)
251IS_REMOTE_GPU: bool = TestEnvironment.def_flag(
252    "IS_REMOTE_GPU",
253    env_var="PYTORCH_TEST_REMOTE_GPU",
254    include_in_repro=False,
255)
256
257DISABLE_RUNNING_SCRIPT_CHK: bool = TestEnvironment.def_flag(
258    "DISABLE_RUNNING_SCRIPT_CHK",
259    env_var="PYTORCH_DISABLE_RUNNING_SCRIPT_CHK",
260    include_in_repro=False,
261)
262# NB: enabled by default unless in an fbcode context.
263PRINT_REPRO_ON_FAILURE: bool = TestEnvironment.def_flag(
264    "PRINT_REPRO_ON_FAILURE",
265    env_var="PYTORCH_PRINT_REPRO_ON_FAILURE",
266    default=(not IS_FBCODE),
267    include_in_repro=False,
268)
269
270# possibly restrict OpInfo tests to a single sample input
271OPINFO_SAMPLE_INPUT_INDEX: Optional[int] = TestEnvironment.def_setting(
272    "OPINFO_SAMPLE_INPUT_INDEX",
273    env_var="PYTORCH_OPINFO_SAMPLE_INPUT_INDEX",
274    default=None,
275    # Don't include the env var value in the repro command because the info will
276    # be queried from the tracked sample input instead
277    include_in_repro=False,
278    parse_fn=lambda val: None if val is None else int(val),
279)
280
281DEFAULT_DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
282DEFAULT_SLOW_TESTS_FILE = 'slow_tests.json'
283
284disabled_tests_dict = {}
285slow_tests_dict = {}
286
287def maybe_load_json(filename):
288    if os.path.isfile(filename):
289        with open(filename) as fp:
290            return json.load(fp)
291    log.warning("Attempted to load json file '%s' but it does not exist.", filename)
292    return {}
293
294# set them here in case the tests are running in a subprocess that doesn't call run_tests
295if os.getenv("SLOW_TESTS_FILE", ""):
296    slow_tests_dict = maybe_load_json(os.getenv("SLOW_TESTS_FILE", ""))
297if os.getenv("DISABLED_TESTS_FILE", ""):
298    disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", ""))
299
300NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', torch._C._get_privateuse1_backend_name())
301
302check_names = ['orin', 'concord', 'galen', 'xavier', 'nano', 'jetson', 'tegra']
303IS_JETSON = any(name in platform.platform() for name in check_names)
304
305def gcIfJetson(fn):
306    # Irregular Jetson host/device memory setup requires cleanup to avoid tests being killed
307    @functools.wraps(fn)
308    def wrapper(*args, **kwargs):
309        if IS_JETSON:
310            gc.collect()
311            torch.cuda.empty_cache()
312        fn(*args, **kwargs)
313    return wrapper
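# Usage sketch (illustrative): wrap an individual test so that Jetson devices reclaim memory
# before the test body runs.
#
#     @gcIfJetson
#     def test_large_allocation(self):
#         ...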
314
315# Tries to extract the current test function by crawling the stack.
316# If unsuccessful, returns None.
317def extract_test_fn() -> Optional[Callable]:
318    try:
319        stack = inspect.stack()
320        for frame_info in stack:
321            frame = frame_info.frame
322            if "self" not in frame.f_locals:
323                continue
324            self_val = frame.f_locals["self"]
325            if isinstance(self_val, unittest.TestCase):
326                test_id = self_val.id()
327                test_name = test_id.split('.')[2]
328                test_fn = getattr(self_val, test_name).__func__
329                return test_fn
330    except Exception:
331        pass
332    return None
333
334# Contains tracked input data useful for debugging purposes
335@dataclass
336class TrackedInput:
337    index: int
338    val: Any
339    type_desc: str
340
341# Attempt to pull out tracked input information from the test function.
342# A TrackedInputIter is used to insert this information.
343def get_tracked_input() -> Optional[TrackedInput]:
344    test_fn = extract_test_fn()
345    if test_fn is None:
346        return None
347    if not hasattr(test_fn, "tracked_input"):
348        return None
349    return test_fn.tracked_input
350
351def clear_tracked_input():
352    test_fn = extract_test_fn()
353    if test_fn is None:
354        return
355    if not hasattr(test_fn, "tracked_input"):
356        return None
357    test_fn.tracked_input = None
358
359# Wraps an iterator and tracks the most recent value the iterator produces
360# for debugging purposes. Tracked values are stored on the test function.
361class TrackedInputIter:
362    def __init__(self, child_iter, input_type_desc,
363                 callback=lambda x: x, set_seed=True, restrict_to_index=None):
364        self.child_iter = enumerate(child_iter)
365        # Input type describes the things we're tracking (e.g. "sample input", "error input").
366        self.input_type_desc = input_type_desc
367        # Callback is run on each iterated thing to get the thing to track.
368        self.callback = callback
369        self.test_fn = extract_test_fn()
370        # Indicates whether the random seed should be set before each call to the iterator
371        self.set_seed = set_seed
372        # Indicates that iteration should be restricted to only the provided index.
373        # If None, no restriction is done
374        self.restrict_to_index = restrict_to_index
375
376    def __iter__(self):
377        return self
378
379    def __next__(self):
380        while True:
381            if self.set_seed:
382                # use a test-name-specific hash for the seed if possible
383                seed = (
384                    int.from_bytes(hashlib.sha256(
385                        self.test_fn.__qualname__.encode("utf-8")).digest()[:4], 'little')
386                    if self.test_fn is not None else SEED
387                )
388                set_rng_seed(seed)
389
390            # allow StopIteration to bubble up
391            input_idx, input_val = next(self.child_iter)
392            if (self.restrict_to_index is None) or (input_idx == self.restrict_to_index):
393                break
394
395        self._set_tracked_input(
396            TrackedInput(
397                index=input_idx, val=self.callback(input_val), type_desc=self.input_type_desc
398            )
399        )
400        return input_val
401
402    def _set_tracked_input(self, tracked_input: TrackedInput):
403        if self.test_fn is None:
404            return
405        if not hasattr(self.test_fn, "tracked_input"):
406            return
407        self.test_fn.tracked_input = tracked_input
408
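# Illustrative sketch (the op.sample_inputs usage below is an assumption for demonstration):
# an OpInfo-style test can wrap its input generator in a TrackedInputIter so that the most
# recently produced input is recorded on the test function for debugging on failure.
#
#     for sample in TrackedInputIter(op.sample_inputs(device, dtype), "sample input"):
#         ...  # exercise the op; the last yielded sample is tracked if the test fails
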
409class _TestParametrizer:
410    """
411    Decorator class for parametrizing a test function, yielding a set of new tests spawned
412    from the original generic test, each specialized for a specific set of test inputs. For
413    example, parametrizing a test across the set of ops will result in a test function per op.
414
415    The decision of how to parametrize / what to parametrize over is intended to be implemented
416    by each derived class.
417
418    Internally, the decorator adds a 'parametrize_fn' property to the test function. This function
419    is intended to be called later by one of:
420      * Device-specific test instantiation via instantiate_device_type_tests(). Note that for this
421        case there is no need to explicitly parametrize over device type, as that is handled separately.
422      * Device-agnostic parametrized test instantiation via instantiate_parametrized_tests().
423
424    If the decorator is applied to a test function that already has a 'parametrize_fn' property, a new
425    composite 'parametrize_fn' will be created that generates tests with the product of the parameters
426    generated by the old and new parametrize_fns. This allows for convenient composability of decorators.
427    """
428    def _parametrize_test(self, test, generic_cls, device_cls):
429        """
430        Parametrizes the given test function across whatever dimension is specified by the derived class.
431        Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all
432        ops, all modules, or all ops + their associated dtypes.
433
434        Args:
435            test (fn): Test function to parametrize over
436            generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
437            device_cls (class): Device-specialized test class object (e.g. TestFooCPU); set to None
438                if the tests are not part of a device-specific set
439
440        Returns:
441            Generator object returning 4-tuples of:
442                test (fn): Parametrized test function; must support a device arg and args for any params
443                test_name (str): Parametrized suffix for the test (e.g. opname_int64); will be appended to
444                    the base name of the test
445                param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64})
446                decorator_fn (callable): Callable[[Dict], List] for list of decorators to apply given param_kwargs
447        """
448        raise NotImplementedError
449
450    def __call__(self, fn):
451        if hasattr(fn, 'parametrize_fn'):
452            # Do composition with the product of args.
453            old_parametrize_fn = fn.parametrize_fn
454            new_parametrize_fn = self._parametrize_test
455            fn.parametrize_fn = compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn)
456        else:
457            fn.parametrize_fn = self._parametrize_test
458        return fn
459
460
461def compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn):
462    """
463    Returns a parametrize_fn that parametrizes over the product of the parameters handled
464    by the given parametrize_fns. Each given parametrize_fn should have the signature
465    f(test, generic_cls, device_cls).
466
467    The test names will be a combination of the names produced by the parametrize_fns in
468    "<new_name>_<old_name>" order. This order is done to match intuition for constructed names
469    when composing multiple decorators; the names will be built in top to bottom order when stacking
470    parametrization decorators.
471
472    Args:
473        old_parametrize_fn (callable) - First parametrize_fn to compose.
474        new_parametrize_fn (callable) - Second parametrize_fn to compose.
475    """
476
477    def composite_fn(test, generic_cls, device_cls,
478                     old_parametrize_fn=old_parametrize_fn,
479                     new_parametrize_fn=new_parametrize_fn):
480        old_tests = list(old_parametrize_fn(test, generic_cls, device_cls))
481        for (old_test, old_test_name, old_param_kwargs, old_dec_fn) in old_tests:
482            for (new_test, new_test_name, new_param_kwargs, new_dec_fn) in \
483                    new_parametrize_fn(old_test, generic_cls, device_cls):
484                redundant_params = set(old_param_kwargs.keys()).intersection(new_param_kwargs.keys())
485                if redundant_params:
486                    raise RuntimeError('Parametrization over the same parameter by multiple parametrization '
487                                       f'decorators is not supported. For test "{test.__name__}", the following parameters '
488                                       f'are handled multiple times: {redundant_params}')
489                full_param_kwargs = {**old_param_kwargs, **new_param_kwargs}
490                merged_test_name = '{}{}{}'.format(new_test_name,
491                                                   '_' if old_test_name != '' and new_test_name != '' else '',
492                                                   old_test_name)
493
494                def merged_decorator_fn(param_kwargs, old_dec_fn=old_dec_fn, new_dec_fn=new_dec_fn):
495                    return list(old_dec_fn(param_kwargs)) + list(new_dec_fn(param_kwargs))
496
497                yield (new_test, merged_test_name, full_param_kwargs, merged_decorator_fn)
498
499    return composite_fn
500
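# Illustrative sketch of composition: stacking two parametrize decorators, e.g.
#
#     @parametrize("x", range(2))
#     @parametrize("y", range(2))
#     def test_foo(self, x, y):
#         ...
#
# installs the bottom decorator's parametrize_fn first and composes the top one with it, so
# instantiation yields names built top to bottom, e.g. test_foo_x_0_y_0.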
501
502def instantiate_parametrized_tests(generic_cls):
503    """
504    Instantiates tests that have been decorated with a parametrize_fn. This is generally performed by a
505    decorator subclass of _TestParametrizer. The generic test will be replaced on the test class by
506    parametrized tests with specialized names. This should be used instead of
507    instantiate_device_type_tests() if the test class contains device-agnostic tests.
508
509    You can also use it as a class decorator. E.g.
510
511    ```
512    @instantiate_parametrized_tests
513    class TestFoo(TestCase):
514        ...
515    ```
516
517    Args:
518        generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
519    """
520    for attr_name in tuple(dir(generic_cls)):
521        class_attr = getattr(generic_cls, attr_name)
522        if not hasattr(class_attr, 'parametrize_fn'):
523            continue
524
525        # Remove the generic test from the test class.
526        delattr(generic_cls, attr_name)
527
528        # Add parametrized tests to the test class.
529        def instantiate_test_helper(cls, name, test, param_kwargs):
530            @wraps(test)
531            def instantiated_test(self, param_kwargs=param_kwargs):
532                test(self, **param_kwargs)
533
534            assert not hasattr(generic_cls, name), f"Redefinition of test {name}"
535            setattr(generic_cls, name, instantiated_test)
536
537        for (test, test_suffix, param_kwargs, decorator_fn) in class_attr.parametrize_fn(
538                class_attr, generic_cls=generic_cls, device_cls=None):
539            full_name = f'{test.__name__}_{test_suffix}'
540
541            # Apply decorators based on full param kwargs.
542            for decorator in decorator_fn(param_kwargs):
543                test = decorator(test)
544
545            instantiate_test_helper(cls=generic_cls, name=full_name, test=test, param_kwargs=param_kwargs)
546    return generic_cls
547
548
549class subtest:
550    """
551    Explicit subtest case for use with test parametrization.
552    Allows for explicit naming of individual subtest cases as well as applying
553    decorators to the parametrized test.
554
555    Args:
556        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
557            tuples of arg values (e.g. [(1, 2), (3, 4)]).
558        name (str): Optional name to use for the test.
559        decorators (iterable): Iterable of decorators to apply to the generated test.
560    """
561    __slots__ = ['arg_values', 'name', 'decorators']
562
563    def __init__(self, arg_values, name=None, decorators=None):
564        self.arg_values = arg_values
565        self.name = name
566        self.decorators = decorators if decorators else []
567
568
569class parametrize(_TestParametrizer):
570    """
571    Decorator for applying generic test parametrizations.
572
573    The interface for this decorator is modeled after `@pytest.mark.parametrize`.
574    Basic usage between this decorator and pytest's is identical. The first argument
575    should be a string containing comma-separated names of parameters for the test, and
576    the second argument should be an iterable returning values or tuples of values for
577    the case of multiple parameters.
578
579    Beyond this basic usage, the decorator provides some additional functionality that
580    pytest does not.
581
582    1. Parametrized tests end up as generated test functions on unittest test classes.
583    Since this differs from how pytest works, this decorator takes on the additional
584    responsibility of naming these test functions. The default test names consist of
585    the test's base name followed by each parameter name + value (e.g. "test_bar_x_1_y_foo"),
586    but custom names can be defined using `name_fn` or the `subtest` structure (see below).
587
588    2. The decorator specially handles parameter values of type `subtest`, which allows for
589    more fine-grained control over both test naming and test execution. In particular, it can
590    be used to tag subtests with explicit test names or apply arbitrary decorators (see examples
591    below).
592
593    Examples::
594
595        @parametrize("x", range(5))
596        def test_foo(self, x):
597            ...
598
599        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
600        def test_bar(self, x, y):
601            ...
602
603        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')],
604                     name_fn=lambda x, y: '{}_{}'.format(x, y))
605        def test_bar_custom_names(self, x, y):
606            ...
607
608        @parametrize("x, y", [subtest((1, 2), name='double'),
609                              subtest((1, 3), name='triple', decorators=[unittest.expectedFailure]),
610                              subtest((1, 4), name='quadruple')])
611        def test_baz(self, x, y):
612            ...
613
614    To actually instantiate the parametrized tests, one of instantiate_parametrized_tests() or
615    instantiate_device_type_tests() should be called. The former is intended for test classes
616    that contain device-agnostic tests, while the latter should be used for test classes that
617    contain device-specific tests. Both support arbitrary parametrizations using the decorator.
618
619    Args:
620        arg_str (str): String of arg names separated by commas (e.g. "x,y").
621        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
622            tuples of arg values (e.g. [(1, 2), (3, 4)]).
623        name_fn (Callable): Optional function that takes in parameters and returns subtest name.
624    """
625    def __init__(self, arg_str, arg_values, name_fn=None):
626        self.arg_names: List[str] = [s.strip() for s in arg_str.split(',') if s != '']
627        self.arg_values = arg_values
628        self.name_fn = name_fn
629
630    def _formatted_str_repr(self, idx, name, value):
631        """ Returns a string representation for the given arg that is suitable for use in test function names. """
632        if isinstance(value, torch.dtype):
633            return dtype_name(value)
634        elif isinstance(value, torch.device):
635            return str(value)
636        # Can't use isinstance as it would cause a circular import
637        elif type(value).__name__ in {'OpInfo', 'ModuleInfo'}:
638            return value.formatted_name
639        elif isinstance(value, (int, float, str)):
640            return f"{name}_{str(value).replace('.', '_')}"
641        else:
642            return f"{name}{idx}"
643
644    def _default_subtest_name(self, idx, values):
645        return '_'.join([self._formatted_str_repr(idx, a, v) for a, v in zip(self.arg_names, values)])
646
647    def _get_subtest_name(self, idx, values, explicit_name=None):
648        if explicit_name:
649            subtest_name = explicit_name
650        elif self.name_fn:
651            subtest_name = self.name_fn(*values)
652        else:
653            subtest_name = self._default_subtest_name(idx, values)
654        return subtest_name
655
656    def _parametrize_test(self, test, generic_cls, device_cls):
657        if len(self.arg_names) == 0:
658            # No additional parameters needed for the test.
659            test_name = ''
660            yield (test, test_name, {}, lambda _: [])
661        else:
662            # Each "values" item is expected to be either:
663            # * A tuple of values with one for each arg. For a single arg, a single item is expected.
664            # * A subtest instance with arg_values matching the previous.
665            values = check_exhausted_iterator = object()
666            for idx, values in enumerate(self.arg_values):
667                maybe_name = None
668
669                decorators = []
670                if isinstance(values, subtest):
671                    sub = values
672                    values = sub.arg_values
673                    maybe_name = sub.name
674
675                    @wraps(test)
676                    def test_wrapper(*args, **kwargs):
677                        return test(*args, **kwargs)
678
679                    decorators = sub.decorators
680                    gen_test = test_wrapper
681                else:
682                    gen_test = test
683
684                values = list(values) if len(self.arg_names) > 1 else [values]
685                if len(values) != len(self.arg_names):
686                    raise RuntimeError(f'Expected # values == # arg names, but got: {len(values)} '
687                                       f'values and {len(self.arg_names)} names for test "{test.__name__}"')
688
689                param_kwargs = dict(zip(self.arg_names, values))
690
691                test_name = self._get_subtest_name(idx, values, explicit_name=maybe_name)
692
693                def decorator_fn(_, decorators=decorators):
694                    return decorators
695
696                yield (gen_test, test_name, param_kwargs, decorator_fn)
697
698            if values is check_exhausted_iterator:
699                raise ValueError(f'{test}: An empty arg_values was passed to @parametrize. '
700                                 'Note that this may result from reuse of a generator.')
701
702
703class decorateIf(_TestParametrizer):
704    """
705    Decorator for applying parameter-specific conditional decoration.
706    Composes with other test parametrizers (e.g. @modules, @ops, @parametrize, etc.).
707
708    Examples::
709
710        @decorateIf(unittest.skip, lambda params: params["x"] == 2)
711        @parametrize("x", range(5))
712        def test_foo(self, x):
713            ...
714
715        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
716        @decorateIf(
717            unittest.expectedFailure,
718            lambda params: params["x"] == 3 and params["y"] == "baz"
719        )
720        def test_bar(self, x, y):
721            ...
722
723        @decorateIf(
724            unittest.expectedFailure,
725            lambda params: params["op"].name == "add" and params["dtype"] == torch.float16
726        )
727        @ops(op_db)
728        def test_op_foo(self, device, dtype, op):
729            ...
730
731        @decorateIf(
732            unittest.skip,
733            lambda params: params["module_info"].module_cls is torch.nn.Linear and \
734                params["device"] == "cpu"
735        )
736        @modules(module_db)
737        def test_module_foo(self, device, dtype, module_info):
738            ...
739
740    Args:
741        decorator: Test decorator to apply if the predicate is satisfied.
742        predicate_fn (Callable): Function taking in a dict of params and returning a boolean
743            indicating whether the decorator should be applied or not.
744    """
745    def __init__(self, decorator, predicate_fn):
746        self.decorator = decorator
747        self.predicate_fn = predicate_fn
748
749    def _parametrize_test(self, test, generic_cls, device_cls):
750
751        # Leave test as-is and return the appropriate decorator_fn.
752        def decorator_fn(params, decorator=self.decorator, predicate_fn=self.predicate_fn):
753            if predicate_fn(params):
754                return [decorator]
755            else:
756                return []
757
758        @wraps(test)
759        def test_wrapper(*args, **kwargs):
760            return test(*args, **kwargs)
761
762        test_name = ''
763        yield (test_wrapper, test_name, {}, decorator_fn)
764
765
766class ProfilingMode(Enum):
767    LEGACY = 1
768    SIMPLE = 2
769    PROFILING = 3
770
771def cppProfilingFlagsToProfilingMode():
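    # Both setters return the previously configured value; setting them to True here is just
    # a way to read the current state, which is restored immediately below so this function
    # has no lasting side effect.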
772    old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
773    old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
774    torch._C._jit_set_profiling_executor(old_prof_exec_state)
775    torch._C._get_graph_executor_optimize(old_prof_mode_state)
776
777    if old_prof_exec_state:
778        if old_prof_mode_state:
779            return ProfilingMode.PROFILING
780        else:
781            return ProfilingMode.SIMPLE
782    else:
783        return ProfilingMode.LEGACY
784
785@contextmanager
786def enable_profiling_mode_for_profiling_tests():
787    if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
788        old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
789        old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
790    try:
791        yield
792    finally:
793        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
794            torch._C._jit_set_profiling_executor(old_prof_exec_state)
795            torch._C._get_graph_executor_optimize(old_prof_mode_state)
796
797@contextmanager
798def enable_profiling_mode():
799    old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
800    old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
801    try:
802        yield
803    finally:
804        torch._C._jit_set_profiling_executor(old_prof_exec_state)
805        torch._C._get_graph_executor_optimize(old_prof_mode_state)
806
807@contextmanager
808def num_profiled_runs(num_runs):
809    old_num_runs = torch._C._jit_set_num_profiled_runs(num_runs)
810    try:
811        yield
812    finally:
813        torch._C._jit_set_num_profiled_runs(old_num_runs)
814
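# Usage sketch (illustrative) for the JIT profiling context managers above:
#
#     with enable_profiling_mode():
#         scripted_fn(x)  # runs under the profiling executor, regardless of GRAPH_EXECUTOR
#     with num_profiled_runs(2):
#         ...  # two profiling runs are collected before an optimized plan is used
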
815func_call = torch._C.ScriptFunction.__call__
816meth_call = torch._C.ScriptMethod.__call__
817
818def prof_callable(callable, *args, **kwargs):
819    if 'profile_and_replay' in kwargs:
820        del kwargs['profile_and_replay']
821        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
822            with enable_profiling_mode_for_profiling_tests():
823                callable(*args, **kwargs)
824                return callable(*args, **kwargs)
825
826    return callable(*args, **kwargs)
827
828def prof_func_call(*args, **kwargs):
829    return prof_callable(func_call, *args, **kwargs)
830
831def prof_meth_call(*args, **kwargs):
832    return prof_callable(meth_call, *args, **kwargs)
833
834torch._C.ScriptFunction.__call__ = prof_func_call  # type: ignore[method-assign]
835torch._C.ScriptMethod.__call__ = prof_meth_call  # type: ignore[method-assign]
836
837def _get_test_report_path():
838    # allow users to override the test file location. We need this
839    # because the distributed tests run the same test file multiple
840    # times with different configurations.
841    override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
842    test_source = override if override is not None else 'python-unittest'
843    return os.path.join('test-reports', test_source)
844
845is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
846parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
847parser.add_argument('--subprocess', action='store_true',
848                    help='whether to run each test in a subprocess')
849parser.add_argument('--seed', type=int, default=1234)
850parser.add_argument('--accept', action='store_true')
851parser.add_argument('--jit-executor', '--jit_executor', type=str)
852parser.add_argument('--repeat', type=int, default=1)
853parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
854parser.add_argument('--use-pytest', action='store_true')
855parser.add_argument('--save-xml', nargs='?', type=str,
856                    const=_get_test_report_path(),
857                    default=_get_test_report_path() if IS_CI else None)
858parser.add_argument('--discover-tests', action='store_true')
859parser.add_argument('--log-suffix', type=str, default="")
860parser.add_argument('--run-parallel', type=int, default=1)
861parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
862parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
863parser.add_argument('--rerun-disabled-tests', action='store_true')
864parser.add_argument('--pytest-single-test', type=str, nargs=1)
865if sys.version_info >= (3, 9):
866    parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
867else:
868    parser.add_argument('--showlocals', action='store_true', default=False)
869    parser.add_argument('--no-showlocals', dest='showlocals', action='store_false')
870
871# Only run when -h or --help flag is active to display both unittest and parser help messages.
872def run_unittest_help(argv):
873    unittest.main(argv=argv)
874
875if '-h' in sys.argv or '--help' in sys.argv:
876    help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
877    help_thread.start()
878    help_thread.join()
879
880args, remaining = parser.parse_known_args()
881if args.jit_executor == 'legacy':
882    GRAPH_EXECUTOR = ProfilingMode.LEGACY
883elif args.jit_executor == 'profiling':
884    GRAPH_EXECUTOR = ProfilingMode.PROFILING
885elif args.jit_executor == 'simple':
886    GRAPH_EXECUTOR = ProfilingMode.SIMPLE
887else:
888    # infer flags based on the default settings
889    GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
890
891RERUN_DISABLED_TESTS = args.rerun_disabled_tests
892
893SLOW_TESTS_FILE = args.import_slow_tests
894DISABLED_TESTS_FILE = args.import_disabled_tests
895LOG_SUFFIX = args.log_suffix
896RUN_PARALLEL = args.run_parallel
897TEST_BAILOUTS = args.test_bailouts
898USE_PYTEST = args.use_pytest
899PYTEST_SINGLE_TEST = args.pytest_single_test
900TEST_DISCOVER = args.discover_tests
901TEST_IN_SUBPROCESS = args.subprocess
902TEST_SAVE_XML = args.save_xml
903REPEAT_COUNT = args.repeat
904SEED = args.seed
905SHOWLOCALS = args.showlocals
906if not getattr(expecttest, "ACCEPT", False):
907    expecttest.ACCEPT = args.accept
908UNITTEST_ARGS = [sys.argv[0]] + remaining
909torch.manual_seed(SEED)
910
911# CI Prefix path used only on CI environment
912CI_TEST_PREFIX = str(Path(os.getcwd()))
913CI_PT_ROOT = str(Path(os.getcwd()).parent)
914CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
915
916def wait_for_process(p, timeout=None):
917    try:
918        return p.wait(timeout=timeout)
919    except KeyboardInterrupt:
920        # Give `p` a chance to handle KeyboardInterrupt. Without this,
921        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
922        exit_status = p.wait(timeout=5)
923        if exit_status is not None:
924            return exit_status
925        else:
926            p.kill()
927            raise
928    except subprocess.TimeoutExpired:
929        # send SIGINT to give pytest a chance to make xml
930        p.send_signal(signal.SIGINT)
931        exit_status = None
932        try:
933            exit_status = p.wait(timeout=5)
934        # try to handle the case where p.wait(timeout=5) times out as well,
935        # as otherwise the wait() call in the finally block can potentially hang
936        except subprocess.TimeoutExpired:
937            pass
938        if exit_status is not None:
939            return exit_status
940        else:
941            p.kill()
942        raise
943    except:  # noqa: B001,E722, copied from python core library
944        p.kill()
945        raise
946    finally:
947        # Always call p.wait() to ensure exit
948        p.wait()
949
950def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None):
951    sys.stdout.flush()
952    sys.stderr.flush()
953    # The following snippet is copied from the Python 3 core library's subprocess.call,
954    # with two changes:
955    #   1. An `except KeyboardInterrupt` block is added for SIGINT handling.
956    #   2. In Python 2, subprocess.Popen doesn't return a context manager, so we call
957    #      `p.wait()` in a `finally` block to keep the code portable.
958    #
959    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
960    assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens"
961    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
962    return wait_for_process(p, timeout=timeout)
963
964
965def retry_shell(
966    command,
967    cwd=None,
968    env=None,
969    stdout=None,
970    stderr=None,
971    timeout=None,
972    retries=1,
973    was_rerun=False,
974) -> Tuple[int, bool]:
975    # Returns exit code + whether it was rerun
976    assert (
977        retries >= 0
978    ), f"Expecting non negative number for number of retries, got {retries}"
979    try:
980        exit_code = shell(
981            command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout
982        )
983        if exit_code == 0 or retries == 0:
984            return exit_code, was_rerun
985        print(
986            f"Got exit code {exit_code}, retrying (retries left={retries})",
987            file=stdout,
988            flush=True,
989        )
990    except subprocess.TimeoutExpired:
991        if retries == 0:
992            print(
993                f"Command took >{timeout // 60}min, returning 124",
994                file=stdout,
995                flush=True,
996            )
997            return 124, was_rerun
998        print(
999            f"Command took >{timeout // 60}min, retrying (retries left={retries})",
1000            file=stdout,
1001            flush=True,
1002        )
1003    return retry_shell(
1004        command,
1005        cwd=cwd,
1006        env=env,
1007        stdout=stdout,
1008        stderr=stderr,
1009        timeout=timeout,
1010        retries=retries - 1,
1011        was_rerun=True,
1012    )
1013
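# Usage sketch (illustrative; the test file path is hypothetical):
#
#     exit_code, was_rerun = retry_shell(
#         [sys.executable, "test/test_foo.py"], timeout=15 * 60, retries=1
#     )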
1014
1015def discover_test_cases_recursively(suite_or_case):
1016    if isinstance(suite_or_case, unittest.TestCase):
1017        return [suite_or_case]
1018    rc = []
1019    for element in suite_or_case:
1020        print(element)
1021        rc.extend(discover_test_cases_recursively(element))
1022    return rc
1023
1024def get_test_names(test_cases):
1025    return ['.'.join(case.id().split('.')[-2:]) for case in test_cases]
1026
1027def _print_test_names():
1028    suite = unittest.TestLoader().loadTestsFromModule(__main__)
1029    test_cases = discover_test_cases_recursively(suite)
1030    for name in get_test_names(test_cases):
1031        print(name)
1032
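# Round-robin chunking, e.g. chunk_list([1, 2, 3, 4, 5], 2) -> [[1, 3, 5], [2, 4]]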
1033def chunk_list(lst, nchunks):
1034    return [lst[i::nchunks] for i in range(nchunks)]
1035
1036# sanitize filename e.g., distributed/pipeline/sync/skip/test_api.py -> distributed.pipeline.sync.skip.test_api
1037def sanitize_test_filename(filename):
1038    # inspect.getfile returns an absolute path in some CI jobs; convert it to a relative path if needed
1039    if filename.startswith(CI_TEST_PREFIX):
1040        filename = filename[len(CI_TEST_PREFIX) + 1:]
1041    strip_py = re.sub(r'.py$', '', filename)
1042    return re.sub('/', r'.', strip_py)
1043
1044def lint_test_case_extension(suite):
1045    succeed = True
1046    for test_case_or_suite in suite:
1047        test_case = test_case_or_suite
1048        if isinstance(test_case_or_suite, unittest.TestSuite):
1049            first_test = test_case_or_suite._tests[0] if len(test_case_or_suite._tests) > 0 else None
1050            if first_test is not None and isinstance(first_test, unittest.TestSuite):
1051                return succeed and lint_test_case_extension(test_case_or_suite)
1052            test_case = first_test
1053
1054        if test_case is not None:
1055            test_class = test_case.id().split('.', 1)[1].split('.')[0]
1056            if not isinstance(test_case, TestCase):
1057                err = "This test class should extend from torch.testing._internal.common_utils.TestCase but it doesn't."
1058                print(f"{test_class} - failed. {err}")
1059                succeed = False
1060    return succeed
1061
1062
1063def get_report_path(argv=UNITTEST_ARGS, pytest=False):
1064    test_filename = sanitize_test_filename(argv[0])
1065    test_report_path = TEST_SAVE_XML + LOG_SUFFIX
1066    test_report_path = os.path.join(test_report_path, test_filename)
1067    if pytest:
1068        test_report_path = test_report_path.replace('python-unittest', 'python-pytest')
1069        os.makedirs(test_report_path, exist_ok=True)
1070        test_report_path = os.path.join(test_report_path, f"{test_filename}-{os.urandom(8).hex()}.xml")
1071        return test_report_path
1072    os.makedirs(test_report_path, exist_ok=True)
1073    return test_report_path
1074
1075
1076def sanitize_pytest_xml(xml_file: str):
1077    # pytest xml is different from unittest xml; this function makes pytest xml more similar to unittest xml
1078    # consider somehow modifying the XML logger in conftest to do this instead
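    # For example, a pytest classname like "test.test_nn.TestNN" becomes classname="TestNN"
    # with file="test_nn.py", matching the unittest-style report layout.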
1079    import xml.etree.ElementTree as ET
1080    tree = ET.parse(xml_file)
1081    for testcase in tree.iter('testcase'):
1082        full_classname = testcase.attrib.get("classname")
1083        if full_classname is None:
1084            continue
1085        # The test prefix is optional
1086        regex_result = re.search(r"^(test\.)?(?P<file>.*)\.(?P<classname>[^\.]*)$", full_classname)
1087        if regex_result is None:
1088            continue
1089        classname = regex_result.group("classname")
1090        file = regex_result.group("file").replace(".", "/")
1091        testcase.set("classname", classname)
1092        testcase.set("file", f"{file}.py")
1093    tree.write(xml_file)
1094
1095
1096def get_pytest_test_cases(argv: List[str]) -> List[str]:
1097    class TestCollectorPlugin:
1098        def __init__(self) -> None:
1099            self.tests = []
1100
1101        def pytest_collection_finish(self, session):
1102            for item in session.items:
1103                self.tests.append(session.config.cwd_relative_nodeid(item.nodeid))
1104
1105    test_collector_plugin = TestCollectorPlugin()
1106    import pytest
1107    pytest.main(
1108        [arg for arg in argv if arg != '-vv'] + ['--collect-only', '-qq', '--use-main-module'],
1109        plugins=[test_collector_plugin]
1110    )
1111    return test_collector_plugin.tests
1112
1113
1114def run_tests(argv=UNITTEST_ARGS):
1115    # import test files.
1116    if SLOW_TESTS_FILE:
1117        if os.path.exists(SLOW_TESTS_FILE):
1118            with open(SLOW_TESTS_FILE) as fp:
1119                global slow_tests_dict
1120                slow_tests_dict = json.load(fp)
1121                # use env vars so pytest-xdist subprocesses can still access them
1122                os.environ['SLOW_TESTS_FILE'] = SLOW_TESTS_FILE
1123        else:
1124            warnings.warn(f'slow test file provided but not found: {SLOW_TESTS_FILE}')
1125    if DISABLED_TESTS_FILE:
1126        if os.path.exists(DISABLED_TESTS_FILE):
1127            with open(DISABLED_TESTS_FILE) as fp:
1128                global disabled_tests_dict
1129                disabled_tests_dict = json.load(fp)
1130                os.environ['DISABLED_TESTS_FILE'] = DISABLED_TESTS_FILE
1131        else:
1132            warnings.warn(f'disabled test file provided but not found: {DISABLED_TESTS_FILE}')
1133    # Determine the test launch mechanism
1134    if TEST_DISCOVER:
1135        _print_test_names()
1136        return
1137
1138    # Before running the tests, lint to check that every test class extends from TestCase
1139    suite = unittest.TestLoader().loadTestsFromModule(__main__)
1140    if not lint_test_case_extension(suite):
1141        sys.exit(1)
1142
1143    if SHOWLOCALS:
1144        argv = [
1145            argv[0],
1146            *(["--showlocals", "--tb=long", "--color=yes"] if USE_PYTEST else ["--locals"]),
1147            *argv[1:],
1148        ]
1149
1150    if TEST_IN_SUBPROCESS:
1151        other_args = []
1152        if DISABLED_TESTS_FILE:
1153            other_args.append("--import-disabled-tests")
1154        if SLOW_TESTS_FILE:
1155            other_args.append("--import-slow-tests")
1156        if USE_PYTEST:
1157            other_args.append("--use-pytest")
1158        if RERUN_DISABLED_TESTS:
1159            other_args.append("--rerun-disabled-tests")
1160        if TEST_SAVE_XML:
1161            other_args += ['--save-xml', args.save_xml]
1162
1163        test_cases = (
1164            get_pytest_test_cases(argv) if USE_PYTEST else
1165            [case.id().split('.', 1)[1] for case in discover_test_cases_recursively(suite)]
1166        )
1167
1168        failed_tests = []
1169
1170        for test_case_full_name in test_cases:
1171
1172            cmd = (
1173                [sys.executable] + [argv[0]] + other_args + argv[1:] +
1174                (["--pytest-single-test"] if USE_PYTEST else []) +
1175                [test_case_full_name]
1176            )
1177            string_cmd = " ".join(cmd)
1178
1179            timeout = None if RERUN_DISABLED_TESTS else 15 * 60
1180
1181            exitcode, _ = retry_shell(cmd, timeout=timeout, retries=0 if RERUN_DISABLED_TESTS else 1)
1182
1183            if exitcode != 0:
1184                # This is sort of hacky, but add on relevant env variables for distributed tests.
1185                if 'TestDistBackendWithSpawn' in test_case_full_name:
1186                    backend = os.environ.get("BACKEND", "")
1187                    world_size = os.environ.get("WORLD_SIZE", "")
1188                    env_prefix = f"BACKEND={backend} WORLD_SIZE={world_size}"
1189                    string_cmd = env_prefix + " " + string_cmd
1190                # Log the command to reproduce the failure.
1191                print(f"Test exited with non-zero exitcode {exitcode}. Command to reproduce: {string_cmd}")
1192                failed_tests.append(test_case_full_name)
1193
1194            assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
1195                len(failed_tests), '\n\t'.join(failed_tests))
1196
1197    elif RUN_PARALLEL > 1:
1198        test_cases = discover_test_cases_recursively(suite)
1199        test_batches = chunk_list(get_test_names(test_cases), RUN_PARALLEL)
1200        processes = []
1201        for i in range(RUN_PARALLEL):
1202            command = [sys.executable] + argv + [f'--log-suffix=-shard-{i + 1}'] + test_batches[i]
1203            processes.append(subprocess.Popen(command, universal_newlines=True))
1204        failed = False
1205        for p in processes:
1206            failed |= wait_for_process(p) != 0
1207        assert not failed, "Some test shards have failed"
1208    elif USE_PYTEST:
1209        pytest_args = argv + ["--use-main-module"]
1210        if TEST_SAVE_XML:
1211            test_report_path = get_report_path(pytest=True)
1212            print(f'Test results will be stored in {test_report_path}')
1213            pytest_args.append(f'--junit-xml-reruns={test_report_path}')
1214        if PYTEST_SINGLE_TEST:
1215            pytest_args = PYTEST_SINGLE_TEST + pytest_args[1:]
1216
1217        import pytest
1218        os.environ["NO_COLOR"] = "1"
1219        exit_code = pytest.main(args=pytest_args)
1220        if TEST_SAVE_XML:
1221            sanitize_pytest_xml(test_report_path)
1222
1223        if not RERUN_DISABLED_TESTS:
1224            # exitcode of 5 means no tests were found, which happens since some test configs don't
1225            # run tests from certain files
1226            sys.exit(0 if exit_code == 5 else exit_code)
1227        else:
1228            # Only record the test report and always return a success code when running under rerun
1229            # disabled tests mode
1230            sys.exit(0)
1231    elif TEST_SAVE_XML is not None:
1232        # import here so that non-CI doesn't need xmlrunner installed
1233        import xmlrunner  # type: ignore[import]
1234        from xmlrunner.result import _XMLTestResult  # type: ignore[import]
1235
1236        class XMLTestResultVerbose(_XMLTestResult):
1237            """
1238            Adding verbosity to test outputs:
1239            by default test summary prints 'skip',
1240            but we want to also print the skip reason.
1241            GH issue: https://github.com/pytorch/pytorch/issues/69014
1242
1243            This works with unittest_xml_reporting<=3.2.0,>=2.0.0
1244            (3.2.0 is latest at the moment)
1245            """
1246            def __init__(self, *args, **kwargs):
1247                super().__init__(*args, **kwargs)
1248
1249            def addSkip(self, test, reason):
1250                super().addSkip(test, reason)
1251                for c in self.callback.__closure__:
1252                    if isinstance(c.cell_contents, str) and c.cell_contents == 'skip':
1253                        # this message is printed in test summary;
1254                        # it stands for `verbose_str` captured in the closure
1255                        c.cell_contents = f"skip: {reason}"
1256
1257            def printErrors(self) -> None:
1258                super().printErrors()
1259                self.printErrorList("XPASS", self.unexpectedSuccesses)
1260        test_report_path = get_report_path()
1261        verbose = '--verbose' in argv or '-v' in argv
1262        if verbose:
1263            print(f'Test results will be stored in {test_report_path}')
1264        unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
1265            output=test_report_path,
1266            verbosity=2 if verbose else 1,
1267            resultclass=XMLTestResultVerbose))
1268    elif REPEAT_COUNT > 1:
1269        for _ in range(REPEAT_COUNT):
1270            if not unittest.main(exit=False, argv=argv).result.wasSuccessful():
1271                sys.exit(-1)
1272    else:
1273        unittest.main(argv=argv)
1274
1275IS_LINUX = sys.platform == "linux"
1276IS_WINDOWS = sys.platform == "win32"
1277IS_MACOS = sys.platform == "darwin"
1278IS_PPC = platform.machine() == "ppc64le"
1279IS_X86 = platform.machine() in ('x86_64', 'i386')
1280IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
1281
1282def is_avx512_vnni_supported():
1283    if sys.platform != 'linux':
1284        return False
1285    with open("/proc/cpuinfo", encoding="ascii") as f:
1286        lines = f.read()
1287    return "vnni" in lines
1288
1289IS_AVX512_VNNI_SUPPORTED = is_avx512_vnni_supported()
1290
1291if IS_WINDOWS:
1292    @contextmanager
1293    def TemporaryFileName(*args, **kwargs):
1294        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
1295        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
1296        # close the file after creation and try to remove it manually
1297        if 'delete' in kwargs:
1298            if kwargs['delete'] is not False:
1299                raise UserWarning("only TemporaryFileName with delete=False is supported on Windows.")
1300        else:
1301            kwargs['delete'] = False
1302        f = tempfile.NamedTemporaryFile(*args, **kwargs)
1303        try:
1304            f.close()
1305            yield f.name
1306        finally:
1307            os.unlink(f.name)
1308else:
1309    @contextmanager  # noqa: T484
1310    def TemporaryFileName(*args, **kwargs):
1311        with tempfile.NamedTemporaryFile(*args, **kwargs) as f:
1312            yield f.name
1313
1314if IS_WINDOWS:
1315    @contextmanager
1316    def TemporaryDirectoryName(suffix=None):
1317        # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely,
1318        # so we first create the directory using mkdtemp and then remove it manually
1319        try:
1320            dir_name = tempfile.mkdtemp(suffix=suffix)
1321            yield dir_name
1322        finally:
1323            shutil.rmtree(dir_name)
1324else:
1325    @contextmanager  # noqa: T484
1326    def TemporaryDirectoryName(suffix=None):
1327        with tempfile.TemporaryDirectory(suffix=suffix) as d:
1328            yield d
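
# Illustrative usage of the helpers above (a sketch; tests pass their own
# suffixes/kwargs as needed):
#
#   with TemporaryFileName() as fname:
#       torch.save(torch.ones(3), fname)
#   with TemporaryDirectoryName() as dirname:
#       torch.save(torch.ones(3), os.path.join(dirname, "t.pt"))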
1329
1330
1331def is_privateuse1_backend_available():
1332    privateuse1_backend_name = torch._C._get_privateuse1_backend_name()
1333    privateuse1_backend_module = getattr(torch, privateuse1_backend_name, None)
1334    return hasattr(privateuse1_backend_module, "is_available") and privateuse1_backend_module.is_available()
1335
1336
1337IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'
1338
1339TEST_NUMPY = _check_module_exists('numpy')
1340TEST_FAIRSEQ = _check_module_exists('fairseq')
1341TEST_SCIPY = _check_module_exists('scipy')
1342TEST_MKL = torch.backends.mkl.is_available()
1343TEST_MPS = torch.backends.mps.is_available()
1344TEST_XPU = torch.xpu.is_available()
1345TEST_HPU = hasattr(torch, "hpu") and torch.hpu.is_available()
1346TEST_CUDA = torch.cuda.is_available()
1347custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
1348TEST_PRIVATEUSE1 = is_privateuse1_backend_available()
1349TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name()
1350TEST_NUMBA = _check_module_exists('numba')
1351TEST_TRANSFORMERS = _check_module_exists('transformers')
1352TEST_DILL = _check_module_exists('dill')
1353
1354TEST_LIBROSA = _check_module_exists('librosa') and not IS_ARM64
1355
1356TEST_OPT_EINSUM = _check_module_exists('opt_einsum')
1357
1358TEST_Z3 = _check_module_exists('z3')
1359
1360def split_if_not_empty(x: str):
1361    return x.split(",") if len(x) != 0 else []
1362
1363NOTEST_CPU = "cpu" in split_if_not_empty(os.getenv('PYTORCH_TESTING_DEVICE_EXCEPT_FOR', ''))
1364
1365skipIfNoDill = unittest.skipIf(not TEST_DILL, "no dill")
1366
1367
1368# Python 2.7 doesn't have spawn
1369NO_MULTIPROCESSING_SPAWN: bool = TestEnvironment.def_flag(
1370    "NO_MULTIPROCESSING_SPAWN",
1371    env_var="NO_MULTIPROCESSING_SPAWN",
1372)
1373TEST_WITH_ASAN: bool = TestEnvironment.def_flag(
1374    "TEST_WITH_ASAN",
1375    env_var="PYTORCH_TEST_WITH_ASAN",
1376)
1377TEST_WITH_DEV_DBG_ASAN: bool = TestEnvironment.def_flag(
1378    "TEST_WITH_DEV_DBG_ASAN",
1379    env_var="PYTORCH_TEST_WITH_DEV_DBG_ASAN",
1380)
1381TEST_WITH_TSAN: bool = TestEnvironment.def_flag(
1382    "TEST_WITH_TSAN",
1383    env_var="PYTORCH_TEST_WITH_TSAN",
1384)
1385TEST_WITH_UBSAN: bool = TestEnvironment.def_flag(
1386    "TEST_WITH_UBSAN",
1387    env_var="PYTORCH_TEST_WITH_UBSAN",
1388)
1389TEST_WITH_ROCM: bool = TestEnvironment.def_flag(
1390    "TEST_WITH_ROCM",
1391    env_var="PYTORCH_TEST_WITH_ROCM",
1392)
1393
1394# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
1395# See #64427
1396TEST_WITH_MIOPEN_SUGGEST_NHWC = os.getenv('PYTORCH_MIOPEN_SUGGEST_NHWC', '0') == '1'
1397# Enables tests that are slow to run (disabled by default)
1398TEST_WITH_SLOW: bool = TestEnvironment.def_flag(
1399    "TEST_WITH_SLOW",
1400    env_var="PYTORCH_TEST_WITH_SLOW",
1401)
1402
1403# Disables non-slow tests (these tests are enabled by default).
1404# This is usually used in conjunction with TEST_WITH_SLOW to
1405# run *only* slow tests.  (I could have done an enum, but
1406# it felt a little awkward.)
1407TEST_SKIP_FAST: bool = TestEnvironment.def_flag(
1408    "TEST_SKIP_FAST",
1409    env_var="PYTORCH_TEST_SKIP_FAST",
1410)
1411
1412# Enables crossref tests, in addition to standard tests which
1413# are being run.  crossref tests work by installing a torch
1414# function mode that runs extra compute alongside the regular
1415# computation that happens with the test.  After both computations
1416# are done, we cross-reference them (thus the name) to check for
1417# correctness, before throwing out the extra compute and proceeding
1418# as we had before.  By default, we don't run these tests.
1419TEST_WITH_CROSSREF: bool = TestEnvironment.def_flag(
1420    "TEST_WITH_CROSSREF",
1421    env_var="PYTORCH_TEST_WITH_CROSSREF",
1422)
1423
1424TEST_SKIP_CUDAGRAPH: bool = TestEnvironment.def_flag(
1425    "TEST_SKIP_CUDAGRAPH",
1426    env_var="PYTORCH_TEST_SKIP_CUDAGRAPH",
1427)
1428TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
1429    (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 11) or
1430    (torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3)
1431)
1432
1433TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
1434
1435def allocator_option_enabled_fn(allocator_config, _, option):
1436    if allocator_config is None:
1437        return False
1438    allocator_config = allocator_config.split(',') if ',' in allocator_config else [allocator_config]
1439    mapping = dict([var.split(':') for var in allocator_config])
1440
1441    if option in mapping and mapping[option] == 'True':
1442        return True
1443    else:
1444        return False
1445
1446EXPANDABLE_SEGMENTS: bool = TestEnvironment.def_flag(
1447    "EXPANDABLE_SEGMENTS",
1448    env_var="PYTORCH_CUDA_ALLOC_CONF",
1449    enabled_fn=functools.partial(allocator_option_enabled_fn, option='expandable_segments'),
1450)
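
# For example, running the suite with PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
# (possibly alongside other comma-separated key:value options) is parsed by
# allocator_option_enabled_fn above and turns this flag on; any other value, or an
# unset variable, leaves it off.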
1451
1452if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ:
1453    num_procs = int(os.getenv("NUM_PARALLEL_PROCS", "2"))
1454    gb_available = torch.cuda.mem_get_info()[1] / 2 ** 30
1455    # other libraries take up a little under 1 GB of space per process
1456    torch.cuda.set_per_process_memory_fraction(round((gb_available - num_procs * .85) / gb_available / num_procs, 2))
1457
1458requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "Requires CUDA")
1459
1460def skipIfCrossRef(fn):
1461    @wraps(fn)
1462    def wrapper(*args, **kwargs):
1463        if TEST_WITH_CROSSREF:
1464            raise unittest.SkipTest("test doesn't currently work with crossref")
1465        else:
1466            fn(*args, **kwargs)
1467    return wrapper
1468
1469class CrossRefMode(torch.overrides.TorchFunctionMode):
1470    def __torch_function__(self, func, types, args=(), kwargs=None):
1471        kwargs = kwargs or {}
1472        r = func(*args, **kwargs)
1473        return r
1474
1475# Run PyTorch tests with TorchDynamo
1476TEST_WITH_TORCHINDUCTOR: bool = TestEnvironment.def_flag(
1477    "TEST_WITH_TORCHINDUCTOR",
1478    env_var="PYTORCH_TEST_WITH_INDUCTOR",
1479)
1480# AOT_EAGER not tested in ci, useful for debugging
1481TEST_WITH_AOT_EAGER: bool = TestEnvironment.def_flag(
1482    "TEST_WITH_AOT_EAGER",
1483    env_var="PYTORCH_TEST_WITH_AOT_EAGER",
1484)
1485TEST_WITH_TORCHDYNAMO: bool = TestEnvironment.def_flag(
1486    "TEST_WITH_TORCHDYNAMO",
1487    env_var="PYTORCH_TEST_WITH_DYNAMO",
1488    implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER,
1489)
1490
1491if TEST_WITH_TORCHDYNAMO:
1492    import torch._dynamo
1493    # Do not spend time on helper functions that are called with different inputs
1494    torch._dynamo.config.accumulated_cache_size_limit = 64
1495    # Do not log compilation metrics from unit tests
1496    torch._dynamo.config.log_compilation_metrics = False
1497    if TEST_WITH_TORCHINDUCTOR:
1498        import torch._inductor.config
1499        torch._inductor.config.fallback_random = True
1500
1501
1502def xpassIfTorchDynamo(func):
1503    return func if TEST_WITH_TORCHDYNAMO else unittest.expectedFailure(func)
1504
1505
1506def xfailIfTorchDynamo(func):
1507    return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func
1508
1509
1510def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
1511    """
1512    Usage:
1513    @skipIfTorchDynamo(msg)
1514    def test_blah(self):
1515        ...
1516    """
1517    assert isinstance(msg, str), "Are you using skipIfTorchDynamo correctly?"
1518
1519    def decorator(fn):
1520        if not isinstance(fn, type):
1521            @wraps(fn)
1522            def wrapper(*args, **kwargs):
1523                if TEST_WITH_TORCHDYNAMO:
1524                    raise unittest.SkipTest(msg)
1525                else:
1526                    fn(*args, **kwargs)
1527            return wrapper
1528
1529        assert isinstance(fn, type)
1530        if TEST_WITH_TORCHDYNAMO:
1531            fn.__unittest_skip__ = True
1532            fn.__unittest_skip_why__ = msg
1533
1534        return fn
1535
1536    return decorator
1537
1538def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
1539                        condition=TEST_WITH_TORCHINDUCTOR):
1540    def decorator(fn):
1541        if not isinstance(fn, type):
1542            @wraps(fn)
1543            def wrapper(*args, **kwargs):
1544                if condition:
1545                    raise unittest.SkipTest(msg)
1546                else:
1547                    fn(*args, **kwargs)
1548            return wrapper
1549
1550        assert isinstance(fn, type)
1551        if condition:
1552            fn.__unittest_skip__ = True
1553            fn.__unittest_skip_why__ = msg
1554
1555        return fn
1556
1557    return decorator
1558
1559def serialTest(condition=True):
1560    """
1561    Decorator for running tests serially.  Requires pytest
1562    """
1563    def decorator(fn):
1564        if has_pytest and condition:
1565            return pytest.mark.serial(fn)
1566        return fn
1567    return decorator
1568
1569def unMarkDynamoStrictTest(cls=None):
1570    def decorator(cls):
1571        cls.dynamo_strict = False
1572        return cls
1573
1574    if cls is None:
1575        return decorator
1576    else:
1577        return decorator(cls)
1578
1579
1580def markDynamoStrictTest(cls_or_func=None, nopython=False):
1581    """
1582    Marks the test as 'strict'. In strict mode, we reset torch._dynamo before and
1583    after the test, and run without suppressing errors.
1584
1585    Args:
1586    - nopython: whether to run torch._dynamo.optimize with nopython=True.
1587    """
1588    def decorator(cls_or_func):
1589        if inspect.isclass(cls_or_func):
1590            cls_or_func.dynamo_strict = True
1591            cls_or_func.dynamo_strict_nopython = nopython
1592            return cls_or_func
1593
1594        fn = cls_or_func
1595
1596        @wraps(fn)
1597        def wrapper(*args, **kwargs):
1598            torch._dynamo.reset()
1599            with unittest.mock.patch("torch._dynamo.config.suppress_errors", False):
1600                fn(*args, **kwargs)
1601            torch._dynamo.reset()
1602        return wrapper
1603
1604    if cls_or_func is None:
1605        return decorator
1606    else:
1607        return decorator(cls_or_func)
1608
1609
1610def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"):
1611    return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)
1612
1613def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"):
1614    def decorator(fn):
1615        if not isinstance(fn, type):
1616            @wraps(fn)
1617            def wrapper(*args, **kwargs):
1618                if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
1619                    raise unittest.SkipTest(msg)
1620                else:
1621                    fn(*args, **kwargs)
1622            return wrapper
1623
1624        assert isinstance(fn, type)
1625        if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
1626            fn.__unittest_skip__ = True
1627            fn.__unittest_skip_why__ = msg
1628
1629        return fn
1630
1631
1632    return decorator
1633
1634
1635# Run PyTorch tests with translation validation on.
1636TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1'
1637
1638if TEST_WITH_TV:
1639    torch.fx.experimental._config.translation_validation = True
1640
1641# Some tests take too long when dynamic_shapes is combined with
1642# translation_validation. Whenever that happens, we solve that by
1643# disabling translation_validation.
1644def disable_translation_validation_if_dynamic_shapes(fn):
1645    @functools.wraps(fn)
1646    def wrapper(*args, **kwargs):
1647        if torch._dynamo.config.dynamic_shapes:
1648            # Turning TV off due to high latency on dynamic shapes.
1649            torch.fx.experimental._config.translation_validation = False
1650        return fn(*args, **kwargs)
1651    return wrapper
1652
1653
1654# Determine whether to enable cuda memory leak check.
1655# CUDA mem leak check is expensive and thus we don't want to execute it on every
1656# test case / configuration.
1657# If this is True then CUDA memory leak checks are skipped. If this is false
1658#   then CUDA memory leak checks are performed.
1659# See: https://github.com/pytorch/pytorch/pull/59402#issuecomment-858811135
1660TEST_CUDA_MEM_LEAK_CHECK: bool = TestEnvironment.def_flag(
1661    "TEST_CUDA_MEM_LEAK_CHECK",
1662    env_var="PYTORCH_TEST_CUDA_MEM_LEAK_CHECK",
1663)
1664
1665
1666# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
1667numpy_to_torch_dtype_dict = {
1668    np.bool_      : torch.bool,
1669    np.uint8      : torch.uint8,
1670    np.uint16     : torch.uint16,
1671    np.uint32     : torch.uint32,
1672    np.uint64     : torch.uint64,
1673    np.int8       : torch.int8,
1674    np.int16      : torch.int16,
1675    np.int32      : torch.int32,
1676    np.int64      : torch.int64,
1677    np.float16    : torch.float16,
1678    np.float32    : torch.float32,
1679    np.float64    : torch.float64,
1680    np.complex64  : torch.complex64,
1681    np.complex128 : torch.complex128
1682}
1683
1684
1685# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like
1686# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type.
1687# Especially when checking against a reference we can't be sure which variant we get, so we simply try both.
1688def numpy_to_torch_dtype(np_dtype):
1689    try:
1690        return numpy_to_torch_dtype_dict[np_dtype]
1691    except KeyError:
1692        return numpy_to_torch_dtype_dict[np_dtype.type]
1693
1694
1695def has_corresponding_torch_dtype(np_dtype):
1696    try:
1697        numpy_to_torch_dtype(np_dtype)
1698        return True
1699    except KeyError:
1700        return False
1701
1702
1703if IS_WINDOWS:
1704    # Size of `np.intc` is platform defined.
1705    # It is returned by functions like `bitwise_not`.
1706    # On Windows `int` is 32-bit
1707    # https://docs.microsoft.com/en-us/cpp/cpp/data-type-ranges?view=msvc-160
1708    numpy_to_torch_dtype_dict[np.intc] = torch.int
1709
1710# Dict of torch dtype -> NumPy dtype
1711torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
1712torch_to_numpy_dtype_dict.update({
1713    torch.bfloat16: np.float32,
1714    torch.complex32: np.complex64
1715})
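
# Illustrative round trips through the mappings above (a sketch):
#
#   numpy_to_torch_dtype(np.float32) is torch.float32
#   numpy_to_torch_dtype(np.dtype("float32")) is torch.float32   # dtype instance also works
#   torch_to_numpy_dtype_dict[torch.float32] is np.float32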
1716
1717def skipIfNNModuleInlined(
1718    msg="test doesn't currently work with nn module inlining",
1719    condition=torch._dynamo.config.inline_inbuilt_nn_modules,
1720):
1721    def decorator(fn):
1722        if not isinstance(fn, type):
1723
1724            @wraps(fn)
1725            def wrapper(*args, **kwargs):
1726                if condition:
1727                    raise unittest.SkipTest(msg)
1728                else:
1729                    fn(*args, **kwargs)
1730
1731            return wrapper
1732
1733        assert isinstance(fn, type)
1734        if condition:
1735            fn.__unittest_skip__ = True
1736            fn.__unittest_skip_why__ = msg
1737
1738        return fn
1739
1740    return decorator
1741
1742def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"):
1743    def dec_fn(fn):
1744        reason = f"skipIfRocm: {msg}"
1745
1746        @wraps(fn)
1747        def wrapper(*args, **kwargs):
1748            if TEST_WITH_ROCM:
1749                raise unittest.SkipTest(reason)
1750            else:
1751                return fn(*args, **kwargs)
1752        return wrapper
1753    if func:
1754        return dec_fn(func)
1755    return dec_fn
1756
1757def runOnRocm(fn):
1758    @wraps(fn)
1759    def wrapper(*args, **kwargs):
1760        if TEST_WITH_ROCM:
1761            fn(*args, **kwargs)
1762        else:
1763            raise unittest.SkipTest("test currently only works on the ROCm stack")
1764    return wrapper
1765
1766def runOnRocmArch(arch: Tuple[str, ...]):
1767    def dec_fn(fn):
1768        @wraps(fn)
1769        def wrap_fn(self, *args, **kwargs):
1770            if TEST_WITH_ROCM:
1771                prop = torch.cuda.get_device_properties(0)
1772                if prop.gcnArchName.split(":")[0] not in arch:
1773                    reason = f"skipIfRocm: test only runs on {arch}"
1774                    raise unittest.SkipTest(reason)
1775            return fn(self, *args, **kwargs)
1776        return wrap_fn
1777    return dec_fn
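
# Illustrative usage (a sketch): restrict a test to MI300-class GPUs when running
# on ROCm; on non-ROCm builds the test runs unconditionally.
#
#   @runOnRocmArch(MI300_ARCH)
#   def test_mi300_only(self):
#       ...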
1778
1779def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"):
1780    def dec_fn(fn):
1781        reason = f"skipIfXpu: {msg}"
1782
1783        @wraps(fn)
1784        def wrapper(*args, **kwargs):
1785            if TEST_XPU:
1786                raise unittest.SkipTest(reason)
1787            else:
1788                return fn(*args, **kwargs)
1789        return wrapper
1790    if func:
1791        return dec_fn(func)
1792    return dec_fn
1793
1794def skipIfMps(fn):
1795    @wraps(fn)
1796    def wrapper(*args, **kwargs):
1797        if TEST_MPS:
1798            raise unittest.SkipTest("test doesn't currently work with MPS")
1799        else:
1800            fn(*args, **kwargs)
1801    return wrapper
1802
1803def skipIfHpu(fn):
1804    @wraps(fn)
1805    def wrapper(*args, **kwargs):
1806        if TEST_HPU:
1807            raise unittest.SkipTest("test doesn't currently work with HPU")
1808        else:
1809            fn(*args, **kwargs)
1810    return wrapper
1811
1812# Skips a CUDA test when running on ROCm if the installed ROCm version is lower than requested.
1813def skipIfRocmVersionLessThan(version=None):
1814    def dec_fn(fn):
1815        @wraps(fn)
1816        def wrap_fn(self, *args, **kwargs):
1817            if TEST_WITH_ROCM:
1818                rocm_version = str(torch.version.hip)
1819                rocm_version = rocm_version.split("-")[0]    # ignore git sha
1820                rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
1821                if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
1822                    reason = f"ROCm {rocm_version_tuple} is available but {version} required"
1823                    raise unittest.SkipTest(reason)
1824            return fn(self, *args, **kwargs)
1825        return wrap_fn
1826    return dec_fn
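
# Illustrative usage (a sketch): skip under ROCm older than 5.6, run otherwise
# (and always run on non-ROCm builds).
#
#   @skipIfRocmVersionLessThan((5, 6))
#   def test_needs_rocm_5_6(self):
#       ...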
1827
1828def skipIfNotMiopenSuggestNHWC(fn):
1829    @wraps(fn)
1830    def wrapper(*args, **kwargs):
1831        if not TEST_WITH_MIOPEN_SUGGEST_NHWC:
1832            raise unittest.SkipTest("test doesn't currently work without MIOpen NHWC activation")
1833        else:
1834            fn(*args, **kwargs)
1835    return wrapper
1836
1837def skipIfWindows(func=None, *, msg="test doesn't currently work on the Windows stack"):
1838    def dec_fn(fn):
1839        reason = f"skipIfWindows: {msg}"
1840
1841        @wraps(fn)
1842        def wrapper(*args, **kwargs):
1843            if IS_WINDOWS:  # noqa: F821
1844                raise unittest.SkipTest(reason)
1845            else:
1846                return fn(*args, **kwargs)
1847        return wrapper
1848    if func:
1849        return dec_fn(func)
1850    return dec_fn
1851
1852# Reverts the linalg backend back to default to make sure potential failures in one
1853# test do not affect other tests
1854def setLinalgBackendsToDefaultFinally(fn):
1855    @wraps(fn)
1856    def _fn(*args, **kwargs):
1857        _preferred_backend = torch.backends.cuda.preferred_linalg_library()
1858        try:
1859            fn(*args, **kwargs)
1860        finally:
1861            torch.backends.cuda.preferred_linalg_library(_preferred_backend)
1862    return _fn
1863
1864
1865# Reverts the blas backend back to default to make sure potential failures in one
1866# test do not affect other tests
1867def setBlasBackendsToDefaultFinally(fn):
1868    @wraps(fn)
1869    def _fn(*args, **kwargs):
1870        _preferred_backend = torch.backends.cuda.preferred_blas_library()
1871        try:
1872            fn(*args, **kwargs)
1873        finally:
1874            torch.backends.cuda.preferred_blas_library(_preferred_backend)
1875    return _fn
1876
1877
1878# Context manager for setting deterministic flag and automatically
1879# resetting it to its original value
1880class DeterministicGuard:
1881    def __init__(self, deterministic, *, warn_only=False, fill_uninitialized_memory=True):
1882        self.deterministic = deterministic
1883        self.warn_only = warn_only
1884        self.fill_uninitialized_memory = fill_uninitialized_memory
1885
1886    def __enter__(self):
1887        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
1888        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
1889        self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory
1890        torch.use_deterministic_algorithms(
1891            self.deterministic,
1892            warn_only=self.warn_only)
1893        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory
1894
1895    def __exit__(self, exception_type, exception_value, traceback):
1896        torch.use_deterministic_algorithms(
1897            self.deterministic_restore,
1898            warn_only=self.warn_only_restore)
1899        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore
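
# Illustrative usage (a sketch): temporarily force deterministic algorithms and
# restore the previous settings on exit.
#
#   with DeterministicGuard(True, warn_only=False):
#       ...  # runs with torch.use_deterministic_algorithms(True)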
1900
1901class AlwaysWarnTypedStorageRemoval:
1902    def __init__(self, always_warn):
1903        assert isinstance(always_warn, bool)
1904        self.always_warn = always_warn
1905
1906    def __enter__(self):
1907        self.always_warn_restore = torch.storage._get_always_warn_typed_storage_removal()
1908        torch.storage._set_always_warn_typed_storage_removal(self.always_warn)
1909
1910    def __exit__(self, exception_type, exception_value, traceback):
1911        torch.storage._set_always_warn_typed_storage_removal(self.always_warn_restore)
1912
1913# Context manager for setting the CUDA sync debug mode and resetting it
1914# to its original value.
1915# We are not exposing this in core because sync debug mode is
1916# global and thus not thread safe.
1917class CudaSyncGuard:
1918    def __init__(self, sync_debug_mode):
1919        self.mode = sync_debug_mode
1920
1921    def __enter__(self):
1922        self.debug_mode_restore = torch.cuda.get_sync_debug_mode()
1923        torch.cuda.set_sync_debug_mode(self.mode)
1924
1925    def __exit__(self, exception_type, exception_value, traceback):
1926        torch.cuda.set_sync_debug_mode(self.debug_mode_restore)
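
# Illustrative usage (a sketch): error out on any synchronizing CUDA call inside
# the block, then restore the previous debug mode.
#
#   with CudaSyncGuard("error"):
#       ...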
1927
1928# Context manager for setting torch.__future__.set_swap_module_params_on_conversion
1929# and automatically resetting it to its original value
1930class SwapTensorsGuard:
1931    def __init__(self, use_swap_tensors):
1932        self.use_swap_tensors = use_swap_tensors
1933
1934    def __enter__(self):
1935        self.swap_tensors_restore = torch.__future__.get_swap_module_params_on_conversion()
1936        if self.use_swap_tensors is not None:
1937            torch.__future__.set_swap_module_params_on_conversion(self.use_swap_tensors)
1938
1939    def __exit__(self, exception_type, exception_value, traceback):
1940        torch.__future__.set_swap_module_params_on_conversion(self.swap_tensors_restore)
1941
1942# This decorator can be used for API tests that call
1943# torch.use_deterministic_algorithms().  When the test is finished, it will
1944# restore the previous deterministic flag setting.
1945#
1946# If CUDA >= 10.2, this will set the environment variable
1947# CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that
1948# setting is not thrown during the test unless the test changes that variable
1949# on purpose. The previous CUBLAS_WORKSPACE_CONFIG setting will also be
1950# restored once the test is finished.
1951#
1952# Note that if a test requires CUDA to actually register the changed
1953# CUBLAS_WORKSPACE_CONFIG variable, a new subprocess must be created, because
1954# CUDA only checks the variable when the runtime initializes. Tests can be
1955# run inside a subprocess like so:
1956#
1957#   import subprocess, sys, os
1958#   script = '''
1959#   # Test code should go here
1960#   '''
1961#   try:
1962#       subprocess.check_output(
1963#           [sys.executable, '-c', script],
1964#           stderr=subprocess.STDOUT,
1965#           cwd=os.path.dirname(os.path.realpath(__file__)),
1966#           env=os.environ.copy())
1967#   except subprocess.CalledProcessError as e:
1968#       error_message = e.output.decode('utf-8')
1969#       # Handle exceptions raised by the subprocess here
1970#
1971def wrapDeterministicFlagAPITest(fn):
1972    @wraps(fn)
1973    def wrapper(*args, **kwargs):
1974        with DeterministicGuard(
1975                torch.are_deterministic_algorithms_enabled(),
1976                warn_only=torch.is_deterministic_algorithms_warn_only_enabled()):
1977            class CuBLASConfigGuard:
1978                cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'
1979
1980                def __enter__(self):
1981                    self.is_cuda10_2_or_higher = (
1982                        (torch.version.cuda is not None)
1983                        and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2]))
1984                    if self.is_cuda10_2_or_higher:
1985                        self.cublas_config_restore = os.environ.get(self.cublas_var_name)
1986                        os.environ[self.cublas_var_name] = ':4096:8'
1987
1988                def __exit__(self, exception_type, exception_value, traceback):
1989                    if self.is_cuda10_2_or_higher:
1990                        cur_cublas_config = os.environ.get(self.cublas_var_name)
1991                        if self.cublas_config_restore is None:
1992                            if cur_cublas_config is not None:
1993                                del os.environ[self.cublas_var_name]
1994                        else:
1995                            os.environ[self.cublas_var_name] = self.cublas_config_restore
1996            with CuBLASConfigGuard():
1997                fn(*args, **kwargs)
1998    return wrapper
1999
2000# This decorator can be used for API tests that want to safely call
2001# torch.__future__.set_swap_module_params_on_conversion.  `swap` can be set to
2002# True, False or None where None indicates that the context manager does not
2003# set the flag. When the test is finished, it will restore the previous swap
2004# flag setting.
2005def wrapSwapTensorsTest(swap=None):
2006    def dec_fn(fn):
2007        @wraps(fn)
2008        def wrapper(*args, **kwargs):
2009            with SwapTensorsGuard(swap):
2010                fn(*args, **kwargs)
2011        return wrapper
2012    return dec_fn
2013
2014# test parametrizer for swapping
2015class swap(_TestParametrizer):
2016    def __init__(self, swap_values):
2017        super().__init__()
2018        self.swap_values = swap_values
2019
2020    def _parametrize_test(self, test, generic_cls, device_cls):
2021        for swap in self.swap_values:
2022            yield wrapSwapTensorsTest(swap)(test), f'swap_{swap}', {}, lambda _: []
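
# Illustrative usage (a sketch): parametrize a test over both settings of
# torch.__future__.set_swap_module_params_on_conversion, generating variants
# with names suffixed per value (e.g. ..._swap_True / ..._swap_False).
#
#   @swap([True, False])
#   def test_module_conversion(self):
#       ...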
2023
2024def skipIfCompiledWithoutNumpy(fn):
2025    # Even if the numpy module is present, if `USE_NUMPY=0` is used during the
2026    # build, numpy tests will fail
2027    numpy_support = TEST_NUMPY
2028    if numpy_support:
2029        try:
2030            # The numpy module is present, verify that PyTorch is compiled with
2031            # numpy support
2032            torch.from_numpy(np.array([2, 2]))
2033        except RuntimeError:
2034            numpy_support = False
2035
2036    @wraps(fn)
2037    def wrapper(*args, **kwargs):
2038        if not numpy_support:
2039            raise unittest.SkipTest("PyTorch was compiled without numpy support")
2040        else:
2041            fn(*args, **kwargs)
2042    return wrapper
2043
2044def _test_function(fn, device):
2045    def run_test_function(self):
2046        return fn(self, device)
2047    return run_test_function
2048
2049def skipIfNoXNNPACK(fn):
2050    @wraps(fn)
2051    def wrapper(*args, **kwargs):
2052        if not torch.backends.xnnpack.enabled:
2053            raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.')
2054        else:
2055            fn(*args, **kwargs)
2056    return wrapper
2057
2058def skipIfNoLapack(fn):
2059    @wraps(fn)
2060    def wrapper(*args, **kwargs):
2061        if not torch._C.has_lapack:
2062            raise unittest.SkipTest('PyTorch compiled without Lapack')
2063        else:
2064            fn(*args, **kwargs)
2065    return wrapper
2066
2067def skipIfNotRegistered(op_name, message):
2068    """Wraps the decorator to hide the import of the `core`.
2069
2070    Args:
2071        op_name: Check if this op is registered in `core._REGISTERED_OPERATORS`.
2072        message: message to fail with.
2073
2074    Usage:
2075        @skipIfNotRegistered('MyOp', 'MyOp is not linked!')
2076            This will check whether 'MyOp' is registered in caffe2.python.core.
2077    """
2078    return unittest.skip("Pytorch is compiled without Caffe2")
2079
2080def skipIfNoSciPy(fn):
2081    @wraps(fn)
2082    def wrapper(*args, **kwargs):
2083        if not TEST_SCIPY:
2084            raise unittest.SkipTest("test requires SciPy, but SciPy was not found")
2085        else:
2086            fn(*args, **kwargs)
2087    return wrapper
2088
2089def skip_if_pytest(fn):
2090    @wraps(fn)
2091    def wrapped(*args, **kwargs):
2092        if "PYTEST_CURRENT_TEST" in os.environ:
2093            raise unittest.SkipTest("does not work under pytest")
2094        return fn(*args, **kwargs)
2095
2096    return wrapped
2097
2098
2099def slowTest(fn):
2100    @wraps(fn)
2101    def wrapper(*args, **kwargs):
2102        if not TEST_WITH_SLOW:
2103            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
2104        else:
2105            fn(*args, **kwargs)
2106    wrapper.__dict__['slow_test'] = True
2107    return wrapper
2108
2109
2110def slowTestIf(condition):
2111    return slowTest if condition else lambda fn: fn
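
# Illustrative usage (a sketch): mark a test as slow only under a given condition.
#
#   @slowTestIf(TEST_WITH_ASAN)
#   def test_expensive_under_asan(self):
#       ...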
2112
2113
2114def skipCUDAMemoryLeakCheckIf(condition):
2115    def dec(fn):
2116        if getattr(fn, '_do_cuda_memory_leak_check', True):  # if current True
2117            fn._do_cuda_memory_leak_check = not condition
2118        return fn
2119    return dec
2120
2121def skipCUDANonDefaultStreamIf(condition):
2122    def dec(fn):
2123        if getattr(fn, '_do_cuda_non_default_stream', True):  # if current True
2124            fn._do_cuda_non_default_stream = not condition
2125        return fn
2126    return dec
2127
2128def suppress_warnings(fn):
2129    @wraps(fn)
2130    def wrapper(*args, **kwargs):
2131        with warnings.catch_warnings():
2132            warnings.simplefilter("ignore")
2133            fn(*args, **kwargs)
2134    return wrapper
2135
2136
2137def to_gpu(obj, type_map=None):
2138    if type_map is None:
2139        type_map = {}
2140    if isinstance(obj, torch.Tensor):
2141        assert obj.is_leaf
2142        t = type_map.get(obj.dtype, obj.dtype)
2143        with torch.no_grad():
2144            res = obj.clone().to(dtype=t, device="cuda")
2145            res.requires_grad = obj.requires_grad
2146        return res
2147    elif torch.is_storage(obj):
2148        return obj.new().resize_(obj.size()).copy_(obj)
2149    elif isinstance(obj, list):
2150        return [to_gpu(o, type_map) for o in obj]
2151    elif isinstance(obj, tuple):
2152        return tuple(to_gpu(o, type_map) for o in obj)
2153    else:
2154        return deepcopy(obj)
2155
2156
2157def get_function_arglist(func):
2158    return inspect.getfullargspec(func).args
2159
2160
2161def set_rng_seed(seed):
2162    torch.manual_seed(seed)
2163    random.seed(seed)
2164    if TEST_NUMPY:
2165        np.random.seed(seed)
2166
2167
2168@contextlib.contextmanager
2169def set_default_dtype(dtype):
2170    saved_dtype = torch.get_default_dtype()
2171    torch.set_default_dtype(dtype)
2172    try:
2173        yield
2174    finally:
2175        torch.set_default_dtype(saved_dtype)
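
# Illustrative usage (a sketch): run a block with float64 as the default dtype,
# restoring the previous default afterwards.
#
#   with set_default_dtype(torch.double):
#       assert torch.tensor([1.0]).dtype == torch.float64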
2176
2177@contextlib.contextmanager
2178def set_default_tensor_type(tensor_type):
2179    saved_tensor_type = torch.tensor([]).type()
2180    torch.set_default_tensor_type(tensor_type)
2181    try:
2182        yield
2183    finally:
2184        torch.set_default_tensor_type(saved_tensor_type)
2185
2186def iter_indices(tensor):
2187    if tensor.dim() == 0:
2188        return range(0)
2189    if tensor.dim() == 1:
2190        return range(tensor.size(0))
2191    return product(*(range(s) for s in tensor.size()))
2192
2193
2194def is_iterable(obj):
2195    try:
2196        iter(obj)
2197        return True
2198    except TypeError:
2199        return False
2200
2201
2202def is_iterable_of_tensors(iterable, include_empty=False):
2203    """ Returns True if iterable is an iterable of tensors and False otherwise.
2204
2205        If the iterable is empty, the return value is :attr:`include_empty`
2206    """
2207    # Tensor itself is iterable so we check this first
2208    if isinstance(iterable, torch.Tensor):
2209        return False
2210
2211    try:
2212        if len(iterable) == 0:
2213            return include_empty
2214
2215        for t in iter(iterable):
2216            if not isinstance(t, torch.Tensor):
2217                return False
2218
2219    except TypeError:
2220        return False
2221
2222    return True
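
# For example (a sketch of the intended behavior):
#
#   is_iterable_of_tensors([torch.ones(2), torch.zeros(3)])  # True
#   is_iterable_of_tensors(torch.ones(2))                    # False: a Tensor itself
#   is_iterable_of_tensors([], include_empty=True)           # True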
2223
2224
2225class CudaNonDefaultStream:
2226    def __enter__(self):
2227        # Before starting CUDA test save currently active streams on all
2228        # CUDA devices and set new non default streams to all CUDA devices
2229        # to ensure CUDA tests do not use default stream by mistake.
2230        beforeDevice = torch.cuda.current_device()
2231        self.beforeStreams = []
2232        for d in range(torch.cuda.device_count()):
2233            self.beforeStreams.append(torch.cuda.current_stream(d))
2234            deviceStream = torch.cuda.Stream(device=d)
2235            self.beforeStreams[-1].synchronize()
2236            torch._C._cuda_setStream(stream_id=deviceStream.stream_id,
2237                                     device_index=deviceStream.device_index,
2238                                     device_type=deviceStream.device_type)
2239        torch._C._cuda_setDevice(beforeDevice)
2240
2241    def __exit__(self, exec_type, exec_value, traceback):
2242        # After completing CUDA test load previously active streams on all
2243        # CUDA devices.
2244        beforeDevice = torch.cuda.current_device()
2245        for d in range(torch.cuda.device_count()):
2246            torch._C._cuda_setStream(stream_id=self.beforeStreams[d].stream_id,
2247                                     device_index=self.beforeStreams[d].device_index,
2248                                     device_type=self.beforeStreams[d].device_type)
2249        torch._C._cuda_setDevice(beforeDevice)
2250
2251class CudaMemoryLeakCheck:
2252    def __init__(self, testcase, name=None):
2253        self.name = testcase.id() if name is None else name
2254        self.testcase = testcase
2255
2256        # initialize context & RNG to prevent false positive detections
2257        # when the test is the first to initialize those
2258        from torch.testing._internal.common_cuda import initialize_cuda_context_rng
2259        initialize_cuda_context_rng()
2260
2261    # Stores CUDA memory data provided by PyTorch's caching allocator and
2262    #   the CUDA driver.
2263    #
2264    # NOTE: The undocumented torch.cuda.mem_get_info() returns
2265    #   (#free bytes, #total bytes available) on the GPU
2266    def __enter__(self):
2267        self.caching_allocator_befores = []
2268        self.driver_befores = []
2269
2270        # Performs a gc if required (required if any CUDA memory is held)
2271        num_devices = torch.cuda.device_count()
2272        for i in range(num_devices):
2273            caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
2274            # NOTE: gc is based exclusively on caching allocator memory
2275            #   because the driver will always have some bytes in use (context size?)
2276            if caching_allocator_mem_allocated > 0:
2277                gc.collect()
2278                torch._C._cuda_clearCublasWorkspaces()
2279                torch.cuda.empty_cache()
2280                break
2281
2282        # Acquires caching allocator and driver statistics before the test is run
2283        for i in range(num_devices):
2284            self.caching_allocator_befores.append(torch.cuda.memory_allocated(i))
2285            bytes_free, bytes_total = torch.cuda.mem_get_info(i)
2286            driver_mem_allocated = bytes_total - bytes_free
2287            self.driver_befores.append(driver_mem_allocated)
2288
2289    def __exit__(self, exec_type, exec_value, traceback):
2290        # Don't check for leaks if an exception was thrown
2291        if exec_type is not None:
2292            return
2293
2294        # Compares caching allocator before/after statistics
2295        # An increase in allocated memory is a discrepancy indicating a possible
2296        #   memory leak
2297        discrepancy_detected = False
2298        num_devices = torch.cuda.device_count()
2299        for i in range(num_devices):
2300            # avoid counting cublasWorkspace allocations
2301            torch._C._cuda_clearCublasWorkspaces()
2302            caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
2303
2304            if caching_allocator_mem_allocated > self.caching_allocator_befores[i]:
2305                discrepancy_detected = True
2306                break
2307
2308        # Short-circuits if no discrepancy detected
2309        if not discrepancy_detected:
2310            return
2311
2312        # Validates the discrepancy persists after garbage collection and
2313        #   is confirmed by the driver API
2314
2315        # NOTE: driver API discrepancies alone are ignored because with the jiterator
2316        #   some tests may permanently increase the CUDA context size and
2317        #   that will appear as a driver memory leak but is the expected behavior.
2318
2319        # GCs and clears the cache
2320        gc.collect()
2321        torch.cuda.empty_cache()
2322
2323        for i in range(num_devices):
2324
2325            discrepancy_detected = True
2326
2327            # Query memory multiple times to ensure the leak was not transient
2328            for n in range(3):
2329                caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
2330                bytes_free, bytes_total = torch.cuda.mem_get_info(i)
2331                driver_mem_allocated = bytes_total - bytes_free
2332
2333                caching_allocator_discrepancy = False
2334                driver_discrepancy = False
2335
2336                if caching_allocator_mem_allocated > self.caching_allocator_befores[i]:
2337                    caching_allocator_discrepancy = True
2338
2339                if driver_mem_allocated > self.driver_befores[i]:
2340                    driver_discrepancy = True
2341
2342                if not (caching_allocator_discrepancy or driver_discrepancy):
2343                    # Leak was false positive, exit loop
2344                    discrepancy_detected = False
2345                    break
2346
2347            if not discrepancy_detected:
2348                continue
2349
2350            if caching_allocator_discrepancy and not driver_discrepancy:
2351                # Just raises a warning if the leak is not validated by the
2352                #   driver API
2353                # NOTE: this may be a problem with how the caching allocator collects its
2354                #   statistics or a leak too small to trigger the allocation of an
2355                #   additional block of memory by the CUDA driver
2356                msg = ("CUDA caching allocator reports a memory leak not "
2357                       f"verified by the driver API in {self.name}! "
2358                       f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
2359                       f"and is now reported as {caching_allocator_mem_allocated} "
2360                       f"on device {i}. "
2361                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")
2362                warnings.warn(msg)
2363            elif caching_allocator_discrepancy and driver_discrepancy:
2364                # A caching allocator discrepancy validated by the driver API is a
2365                #   failure (except on ROCm, see below)
2366                msg = (f"CUDA driver API confirmed a leak in {self.name}! "
2367                       f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
2368                       f"and is now reported as {caching_allocator_mem_allocated} "
2369                       f"on device {i}. "
2370                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")
2371
2372                raise RuntimeError(msg)
2373
2374@contextmanager
2375def skip_exception_type(exc_type):
2376    try:
2377        yield
2378    except exc_type as e:
2379        raise unittest.SkipTest(f"not implemented: {e}") from e
2380
2381@contextmanager
2382def print_repro_on_failure(repro_parts):
2383    try:
2384        yield
2385    except unittest.SkipTest:
2386        raise
2387    except Exception as e:
2388        # Get the index of the sample input that failed the test if possible.
2389        sample_isolation_prefix = ""
2390        tracked_input = getattr(e, "_tracked_input", None)
2391        if tracked_input is not None:
2392            sample_isolation_prefix = f"PYTORCH_OPINFO_SAMPLE_INPUT_INDEX={tracked_input.index}"
2393
2394        repro_str = " ".join(filter(None, (sample_isolation_prefix, *repro_parts)))
2395        repro_msg = f"""
2396To execute this test, run the following from the base repo dir:
2397    {repro_str}
2398
2399This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
2400
2401        # NB: Hacking the exception args is the cleanest way I've found to append
2402        # failure reproduction info without poisoning the stack trace.
2403        if len(e.args) >= 1:
2404            e.args = (f"{e.args[0]}\n{repro_msg}", *e.args[1:])
2405        raise
2406
2407#  "min_satisfying_examples" setting has been deprecated in hypothesis
2408#  3.56.0 and removed in hypothesis 4.x
2409try:
2410    import hypothesis
2411
2412    def settings(*args, **kwargs):
2413        if 'min_satisfying_examples' in kwargs and hypothesis.version.__version_info__ >= (3, 56, 0):
2414            kwargs.pop('min_satisfying_examples')
2415        return hypothesis.settings(*args, **kwargs)
2416
2417
2418    hypothesis.settings.register_profile(
2419        "pytorch_ci",
2420        settings(
2421            derandomize=True,
2422            suppress_health_check=[hypothesis.HealthCheck.too_slow],
2423            database=None,
2424            max_examples=50,
2425            verbosity=hypothesis.Verbosity.normal))
2426    hypothesis.settings.register_profile(
2427        "dev",
2428        settings(
2429            suppress_health_check=[hypothesis.HealthCheck.too_slow],
2430            database=None,
2431            max_examples=10,
2432            verbosity=hypothesis.Verbosity.normal))
2433    hypothesis.settings.register_profile(
2434        "debug",
2435        settings(
2436            suppress_health_check=[hypothesis.HealthCheck.too_slow],
2437            database=None,
2438            max_examples=1000,
2439            verbosity=hypothesis.Verbosity.verbose))
2440
2441    hypothesis.settings.load_profile(
2442        "pytorch_ci" if IS_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', 'dev')
2443    )
2444except ImportError:
2445    print('Failed to import hypothesis in common_utils, tests are not derandomized')
2446
2447# Used in check_if_enable to see if a test method should be disabled by an issue.
2448# Sanitizes a test method name by stripping the suffixes appended by @dtypes parametrization.
2449# e.g., an issue with title "DISABLED test_bitwise_ops (__main__.TestBinaryUfuncs)" should
2450# disable ALL parametrized test_bitwise_ops tests, such as test_bitwise_ops_cuda_int32.
2451def remove_device_and_dtype_suffixes(test_name: str) -> str:
2452    # import statement is localized to avoid circular dependency issues with common_device_type.py
2453    from torch.testing._internal.common_device_type import get_device_type_test_bases
2454    device_suffixes = [x.device_type for x in get_device_type_test_bases()]
2455    dtype_suffixes = [str(dt)[len("torch."):] for dt in get_all_dtypes()]
2456
2457    test_name_chunks = test_name.split("_")
2458    if len(test_name_chunks) > 0 and test_name_chunks[-1] in dtype_suffixes:
2459        if len(test_name_chunks) > 1 and test_name_chunks[-2] in device_suffixes:
2460            return "_".join(test_name_chunks[0:-2])
2461        return "_".join(test_name_chunks[0:-1])
2462    return test_name
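
# For example, assuming "cuda" is a registered device type and "int32" a known dtype
# suffix, "test_bitwise_ops_cuda_int32" and "test_bitwise_ops_int32" both sanitize to
# "test_bitwise_ops".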
2463
2464
2465def check_if_enable(test: unittest.TestCase):
2466    classname = str(test.__class__).split("'")[1].split(".")[-1]
2467    sanitized_testname = remove_device_and_dtype_suffixes(test._testMethodName)
2468
2469    def matches_test(target: str):
2470        target_test_parts = target.split()
2471        if len(target_test_parts) < 2:
2472            # poorly formed target test name
2473            return False
2474        target_testname = target_test_parts[0]
2475        target_classname = target_test_parts[1][1:-1].split(".")[-1]
2476        # if test method name or its sanitized version exactly matches the disabled
2477        # Matches if the test method name or its sanitized version exactly matches the disabled
2478        # test method name, AND the class name check allows non-parametrized suite names to disable
2479        # parametrized ones (TestSuite disables TestSuiteCPU).
2480
2481    if any(matches_test(x) for x in slow_tests_dict.keys()):
2482        getattr(test, test._testMethodName).__dict__['slow_test'] = True
2483        if not TEST_WITH_SLOW:
2484            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
2485
2486    if not IS_SANDCASTLE:
2487        should_skip = False
2488        skip_msg = ""
2489
2490        for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
2491            if matches_test(disabled_test):
2492                platform_to_conditional: Dict = {
2493                    "mac": IS_MACOS,
2494                    "macos": IS_MACOS,
2495                    "win": IS_WINDOWS,
2496                    "windows": IS_WINDOWS,
2497                    "linux": IS_LINUX,
2498                    "rocm": TEST_WITH_ROCM,
2499                    "xpu": TEST_XPU,
2500                    "asan": TEST_WITH_ASAN,
2501                    "dynamo": TEST_WITH_TORCHDYNAMO,
2502                    "inductor": TEST_WITH_TORCHINDUCTOR,
2503                    "slow": TEST_WITH_SLOW,
2504                }
2505
2506                invalid_platforms = list(filter(lambda p: p not in platform_to_conditional, platforms))
2507                if len(invalid_platforms) > 0:
2508                    invalid_plats_str = ", ".join(invalid_platforms)
2509                    valid_plats = ", ".join(platform_to_conditional.keys())
2510
2511                    print(f"Test {disabled_test} is disabled for some unrecognized ",
2512                          f"platforms: [{invalid_plats_str}]. Please edit issue {issue_url} to fix the platforms ",
2513                          'assigned to this flaky test, changing "Platforms: ..." to a comma separated ',
2514                          f"subset of the following (or leave it blank to match all platforms): {valid_plats}")
2515
2516                    # Sanitize the platforms list so that we continue to disable the test for any valid platforms given
2517                    platforms = list(filter(lambda p: p in platform_to_conditional, platforms))
2518
2519                if platforms == [] or any(platform_to_conditional[platform] for platform in platforms):
2520                    should_skip = True
2521                    skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \
2522                        f" for {'all ' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \
2523                        "If you're seeing this on your local machine and would like to enable this test, " \
2524                        "please make sure CI is not set and you are not using the flag --import-disabled-tests."
2525                    break
2526
2527        if should_skip and not RERUN_DISABLED_TESTS:
2528            # Skip the disabled test when not running under --rerun-disabled-tests verification mode
2529            raise unittest.SkipTest(skip_msg)
2530
2531        if not should_skip and RERUN_DISABLED_TESTS:
2532            skip_msg = "Test is enabled but --rerun-disabled-tests verification mode is set, so only" \
2533                " disabled tests are run"
2534            raise unittest.SkipTest(skip_msg)
2535
2536    if TEST_SKIP_FAST:
2537        if hasattr(test, test._testMethodName) and not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
2538            raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
2539
2540
2541# `TestCase.assertEqual` is very permissive and coerces the inputs into a format that can be compared. This is very
2542# convenient when writing tests, but not so much while reviewing them. By default, the comparison `Pair` framework of
2543# `torch.testing._comparison.are_equal`, used for example by the public testing function
2544# `torch.testing.assert_close`, is more strict. In order to use the same framework and thus reduce the divergence
2545# between internal and external comparison logic as much as possible, we define some "relaxed" pairs here. They only
2546# change the supported inputs, but the comparison logic is the same.
2547# TODO: Revisit the relaxed pairs and check how much work it is to fix the tests that would fail without the relaxation.
2548
2549class RelaxedBooleanPair(BooleanPair):
2550    """Pair for boolean-like inputs.
2551
2552    In contrast to the builtin :class:`BooleanPair`, this class also supports one input being a number or a single
2553    element tensor-like.
2554    """
2555    _supported_number_types = NumberPair(0, 0)._supported_types
2556
2557    def _process_inputs(self, actual, expected, *, id):
2558        # We require only one of the inputs to be a boolean; the other can be a boolean, a
2559        # number, or a single element tensor or array, whereas in the default BooleanPair both inputs have to be booleans.
2560        tensor_or_array_types: Tuple[Type, ...] = (torch.Tensor, np.ndarray)
2561        other_supported_types = (*self._supported_types, *self._supported_number_types, *tensor_or_array_types)
2562        if not (
2563            (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
2564            or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types))
2565        ):
2566            self._inputs_not_supported()
2567
2568        return [self._to_bool(input, id=id) for input in (actual, expected)]
2569
2570    def _to_bool(self, bool_like, *, id):
2571        if isinstance(bool_like, np.number):
2572            return bool(bool_like.item())
2573        elif type(bool_like) in self._supported_number_types:
2574            return bool(bool_like)
2575        elif isinstance(bool_like, (torch.Tensor, np.ndarray)):
2576            numel = bool_like.numel() if isinstance(bool_like, torch.Tensor) else bool_like.size
2577            if numel > 1:
2578                self._fail(
2579                    ValueError,
2580                    f"Only single element tensor-likes can be compared against a boolean. "
2581                    f"Got {numel} elements instead.",
2582                    id=id
2583                )
2584
2585            return bool(bool_like.item())
2586        else:
2587            return super()._to_bool(bool_like, id=id)
2588
2589
2590class RelaxedNumberPair(NumberPair):
2591    """Pair for number-like inputs.
2592
2593    In contrast to the builtin :class:`NumberPair`, this class also supports one input being a single element
2594    tensor-like or a :class:`enum.Enum`. (D)Type checks are disabled, meaning comparing 1 to 1.0 succeeds even when
2595    ``check_dtype=True`` is passed.
2596
2597    In addition, this class uses looser default tolerances for :class:`float` and :class:`complex` inputs. It also
2598    supports overriding the absolute and relative tolerance through the ``@precisionOverride`` and
2599    ``@toleranceOverride`` decorators.
2600    """
2601    _TYPE_TO_DTYPE = {
2602        int: torch.int64,
2603        float: torch.float32,
2604        complex: torch.complex64,
2605    }
2606
2607    def __init__(
2608            self, actual, expected, *, rtol_override=0.0, atol_override=0.0, check_dtype=None, **other_parameters
2609    ) -> None:
2610        super().__init__(actual, expected, check_dtype=False, **other_parameters)
2611        self.rtol = max(self.rtol, rtol_override)
2612        self.atol = max(self.atol, atol_override)
2613
2614    def _process_inputs(self, actual, expected, *, id):
2615        # We require only one of the inputs to be a number; the other can be a number or a single
2616        # element tensor or array, whereas in the default NumberPair both inputs have to be numbers.
2617        tensor_or_array_types: Tuple[Type, ...] = (torch.Tensor, np.ndarray)
2618        other_supported_types = (*self._supported_types, *tensor_or_array_types)
2619        if not (
2620                (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
2621                or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types))
2622        ):
2623            self._inputs_not_supported()
2624
2625        return [self._to_number(input, id=id) for input in (actual, expected)]
2626
2627    def _to_number(self, number_like, *, id):
2628        if isinstance(number_like, (torch.Tensor, np.ndarray)):
2629            numel = number_like.numel() if isinstance(number_like, torch.Tensor) else number_like.size
2630            if numel > 1:
2631                self._fail(
2632                    ValueError,
2633                    f"Only single element tensor-likes can be compared against a number. "
2634                    f"Got {numel} elements instead.",
2635                    id=id
2636                )
2637            number = number_like.item()
2638            if isinstance(number, bool):
2639                number = int(number)
2640
2641            return number
2642        elif isinstance(number_like, Enum):
2643            return int(number_like)  # type: ignore[call-overload]
2644        else:
2645            return super()._to_number(number_like, id=id)
2646
2647
2648class TensorOrArrayPair(TensorLikePair):
2649    """Pair for tensor-like inputs.
2650
2651    On the one hand this class is stricter than the builtin :class:`TensorLikePair` since it only allows instances of
2652    :class:`torch.Tensor` and :class:`numpy.ndarray` rather than allowing any tensor-like that can be converted into a
2653    tensor. On the other hand this class is looser since it converts all inputs into tensors with no regard of their
2654    relationship, e.g. comparing a :class:`torch.Tensor` to :class:`numpy.ndarray` is fine.
2655
2656    In addition, this class supports overriding the absolute and relative tolerance through the ``@precisionOverride``
2657    and ``@toleranceOverride`` decorators.
2658    """
2659    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
2660        super().__init__(actual, expected, **other_parameters)
2661        self.rtol = max(self.rtol, rtol_override)
2662        self.atol = max(self.atol, atol_override)
2663
2664    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
2665        self._check_inputs_isinstance(actual, expected, cls=(torch.Tensor, np.ndarray))
2666
2667        actual, expected = (self._to_tensor(input) for input in (actual, expected))
2668        for tensor in (actual, expected):
2669            self._check_supported(tensor, id=id)
2670        return actual, expected
2671
2672
2673class TypedStoragePair(TensorLikePair):
2674    """Pair for :class:`torch.storage.TypedStorage` inputs."""
2675    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
2676        self._check_inputs_isinstance(actual, expected, cls=torch.storage.TypedStorage)
2677        super().__init__(actual, expected, **other_parameters)
2678        self.rtol = max(self.rtol, rtol_override)
2679        self.atol = max(self.atol, atol_override)
2680
2681    def _to_tensor(self, typed_storage):
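        # Quantized dtypes are mapped below to plain integer dtypes over the same untyped
        # bytes, since torch.tensor does not accept quantized dtypes here.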
2682        return torch.tensor(
2683            typed_storage._untyped_storage,
2684            dtype={
2685                torch.quint8: torch.uint8,
2686                torch.quint4x2: torch.uint8,
2687                torch.quint2x4: torch.uint8,
2688                torch.qint32: torch.int32,
2689                torch.qint8: torch.int8
2690            }.get(typed_storage.dtype, typed_storage.dtype),
2691            device=typed_storage.device,
2692        )
2693
2694
2695class UnittestPair(Pair):
2696    """Fallback ABC pair that handles non-numeric inputs.
2697
2698    To avoid recreating the mismatch messages of :meth:`unittest.TestCase.assertEqual`, this pair simply wraps it in
2699    order to use it with the :class:`Pair` "framework" from :func:`are_equal`.
2700
2701    Define :attr:`UnittestPair.CLS` in a subclass to indicate which class(es) of inputs the pair should support.
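
    A minimal subclass sketch (illustrative; it mirrors the concrete pairs defined below)::

        class FrozensetPair(UnittestPair):
            CLS = frozenset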
2702    """
2703    CLS: Union[Type, Tuple[Type, ...]]
2704    TYPE_NAME: Optional[str] = None
2705
2706    def __init__(self, actual, expected, **other_parameters):
2707        self._check_inputs_isinstance(actual, expected, cls=self.CLS)
2708        super().__init__(actual, expected, **other_parameters)
2709
2710    def compare(self):
2711        test_case = unittest.TestCase()
2712
2713        try:
2714            return test_case.assertEqual(self.actual, self.expected)
2715        except test_case.failureException as error:
2716            msg = str(error)
2717
2718        type_name = self.TYPE_NAME or (self.CLS if isinstance(self.CLS, type) else self.CLS[0]).__name__
2719        self._fail(AssertionError, f"{type_name.title()} comparison failed: {msg}")
2720
2721
2722class StringPair(UnittestPair):
2723    CLS = (str, bytes)
2724    TYPE_NAME = "string"
2725
2726
2727class SetPair(UnittestPair):
2728    CLS = set
2729
2730
2731class TypePair(UnittestPair):
2732    CLS = type
2733
2734
2735class ObjectPair(UnittestPair):
2736    CLS = object
2737
2738
2739# This implements a variant of assertRaises/assertRaisesRegex where we first test
2740# if the exception is NotImplementedError, and if so just skip the test instead
2741# of failing it.
2742#
2743# This is implemented by inheriting from the (private) implementation of
2744# assertRaises from unittest.case, and slightly tweaking it for this new
2745# behavior.  The year is 2021: this private class hierarchy hasn't changed since
2746# 2010, so it seems low risk to inherit from.
2747class AssertRaisesContextIgnoreNotImplementedError(unittest.case._AssertRaisesContext):
2748    def __exit__(self, exc_type, exc_value, tb):
2749        if exc_type is not None and issubclass(exc_type, NotImplementedError):
2750            self.test_case.skipTest(f"not_implemented: {exc_value}")  # type: ignore[attr-defined]
2751        return super().__exit__(exc_type, exc_value, tb)
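# Usage sketch (illustrative; hypothetical wiring, since in practice the context is
# constructed by assertRaises-style helpers rather than directly):
#
#   with AssertRaisesContextIgnoreNotImplementedError(RuntimeError, self):
#       op()  # a NotImplementedError raised here skips the test instead of failing it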
2752
2753
2754@contextmanager
2755def set_warn_always_context(new_val: bool):
2756    old_val = torch.is_warn_always_enabled()
2757    torch.set_warn_always(new_val)
2758    try:
2759        yield
2760    finally:
2761        torch.set_warn_always(old_val)
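# Usage sketch (illustrative):
#
#   with set_warn_always_context(True):
#       ...  # torch warnings fire on every occurrence inside this block, not just once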
2762
2763
2764class NoTest:
2765    # causes pytest to not recognize this class as a test
2766    __test__ = False
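# Usage sketch (illustrative; the class name below is hypothetical): mix NoTest into helper
# classes that pytest should not collect as tests, e.g.
#
#   class MyTestHelperBase(NoTest, TestCase):
#       ...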
2767
2768
2769class TestCase(expecttest.TestCase):
2770    # NOTE: "precision" lets classes and generated tests set minimum
2771    # atol values when comparing tensors. Used by @precisionOverride and @toleranceOverride, for
2772    # example.
2773    # NOTE: "rel_tol" lets classes and generated tests set minimum
2774    # rtol values when comparing tensors. Used by @toleranceOverride, for example.
2775    _precision: float = 0
2776    _rel_tol: float = 0
2777
2778    # Toggles whether to assert that `torch.get_default_dtype()` returns
2779    # `torch.float` when `setUp` and `tearDown` are called.
2780    _default_dtype_check_enabled: bool = False
2781
2782    # Always use difflib to print diffs on multi line equality.
2783    # Undocumented feature in unittest
2784    _diffThreshold = sys.maxsize
2785    maxDiff = None
2786
2787    # Checker to terminate the test suite early if an unrecoverable failure occurs.
2788    def _should_stop_test_suite(self):
2789        if torch.cuda.is_initialized():
2790            # A CUDA device-side error will cause subsequent test cases to fail.
2791            # Stop the entire test suite if we catch a RuntimeError during torch.cuda.synchronize().
2792            try:
2793                torch.cuda.synchronize()
2794            except RuntimeError as rte:
2795                print("TEST SUITE EARLY TERMINATION due to torch.cuda.synchronize() failure", file=sys.stderr)
2796                print(str(rte), file=sys.stderr)
2797                return True
2798            return False
2799        else:
2800            return False
2801
2802    @property
2803    def precision(self) -> float:
2804        return self._precision
2805
2806    @precision.setter
2807    def precision(self, prec: float) -> None:
2808        self._precision = prec
2809
2810    @property
2811    def rel_tol(self) -> float:
2812        return self._rel_tol
2813
2814    @rel_tol.setter
2815    def rel_tol(self, prec: float) -> None:
2816        self._rel_tol = prec
2817
2818    _do_cuda_memory_leak_check = False
2819    _do_cuda_non_default_stream = False
2820
2821    # When True, if a test case raises a NotImplementedError, skip the test
2822    # instead of failing it.
2823    _ignore_not_implemented_error = False
2824
2825    def __init__(self, method_name='runTest', methodName='runTest'):
2826        # methodName is the correct naming in unittest, and testslide uses keyword arguments.
2827        # So we need to use both to 1) not break BC and 2) support testslide.
2828        if methodName != "runTest":
2829            method_name = methodName
2830        super().__init__(method_name)
2831
2832        test_method = getattr(self, method_name, None)
2833        if test_method is not None:
2834            # Wraps the tested method if we should do CUDA memory check.
2835            if TEST_CUDA_MEM_LEAK_CHECK:
2836                self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True)
2837                # FIXME: figure out the flaky -1024 anti-leaks on windows. See #8044
2838                if self._do_cuda_memory_leak_check and not IS_WINDOWS:
2839                    self.wrap_with_cuda_policy(method_name, self.assertLeaksNoCudaTensors)
2840
2841            # Wraps the tested method if we should enforce a non-default CUDA stream.
2842            self._do_cuda_non_default_stream &= getattr(test_method, '_do_cuda_non_default_stream', True)
2843            if self._do_cuda_non_default_stream and not IS_WINDOWS:
2844                self.wrap_with_cuda_policy(method_name, self.enforceNonDefaultStream)
2845
2846            if self._ignore_not_implemented_error:
2847                self.wrap_with_policy(method_name, lambda: skip_exception_type(NotImplementedError))
2848
2849            if PRINT_REPRO_ON_FAILURE:
2850                try:
2851                    def _get_rel_test_path(abs_test_path):
2852                        # Attempt to get relative path based on the "test" dir.
2853                        # In CI, the working dir is not guaranteed to be the base repo dir so
2854                        # we can't just compute relative path from that.
2855                        parts = Path(abs_test_path).parts
2856                        for i, part in enumerate(parts):
2857                            if part == "test":
2858                                base_dir = os.path.join(*parts[:i]) if i > 0 else ''
2859                                return os.path.relpath(abs_test_path, start=base_dir)
2860
2861                        # Can't determine containing dir; just return the test filename.
2862                        # The path isn't strictly correct but it's arguably better than nothing.
2863                        return os.path.split(abs_test_path)[1]
2864
2865                    # NB: In Python 3.8, the getfile() call will return a path relative
2866                    # to the working directory, so convert that to absolute.
2867                    abs_test_path = os.path.abspath(inspect.getfile(type(self)))
2868                    test_filename = _get_rel_test_path(abs_test_path)
2869                    class_name = type(self).__name__
2870                    test_run_cmd = f"python {test_filename} {class_name}.{method_name}"
2871                    env_var_prefix = TestEnvironment.repro_env_var_prefix()
2872                    repro_parts = [env_var_prefix, test_run_cmd]
2873                    self.wrap_with_policy(
2874                        method_name,
2875                        lambda repro_parts=repro_parts: print_repro_on_failure(repro_parts))
2876                except Exception as e:
2877                    # Don't fail entirely if we can't get the test filename
2878                    log.info("could not print repro string %s", str(e))
2879
2880    def assertLeaksNoCudaTensors(self, name=None):
2881        name = self.id() if name is None else name
2882        return CudaMemoryLeakCheck(self, name)
2883
2884    def enforceNonDefaultStream(self):
2885        return CudaNonDefaultStream()
2886
2887    def _remove_ansi_escape(self, input):
2888        # 7-bit C1 ANSI sequences
2889        ansi_escape = re.compile(r'''
2890            \x1B  # ESC
2891            (?:   # 7-bit C1 Fe (except CSI)
2892                [@-Z\\-_]
2893            |     # or [ for CSI, followed by a control sequence
2894                \[
2895                [0-?]*  # Parameter bytes
2896                [ -/]*  # Intermediate bytes
2897                [@-~]   # Final byte
2898            )
2899        ''', re.VERBOSE)
2900        return ansi_escape.sub('', input)
2901
2902    def remove_comment_lines(self, input_string):
2903        lines = input_string.split('\n')
2904        filtered_lines = [line for line in lines if not line.strip().startswith('#')]
2905        return '\n'.join(filtered_lines)
2906
2907    def remove_empty_lines(self, input_string):
2908        lines = input_string.split('\n')
2909        filtered_lines = [line for line in lines if not line.strip() == '']
2910        return '\n'.join(filtered_lines)
2911
2912    # If ignore_comments is True, lines that start with # (after being stripped) are ignored.
2913    def assertExpectedInline(self, actual, expect, skip=0, ignore_comments=False, ignore_empty_lines=False):
2914        actual = actual if isinstance(actual, str) else str(actual)
2915        actual = self._remove_ansi_escape(actual)
2916        expect = self._remove_ansi_escape(expect)
2917        if ignore_comments:
2918            actual = self.remove_comment_lines(actual)
2919            expect = self.remove_comment_lines(expect)
2920
2921        if ignore_empty_lines:
2922            actual = self.remove_empty_lines(actual)
2923            expect = self.remove_empty_lines(expect)
2924
2925        return super().assertExpectedInline(actual, expect, skip + 1)
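    # Usage sketch for assertExpectedInline (illustrative; the strings are hypothetical):
    #
    #   self.assertExpectedInline("x = 1\n# generated note", """x = 1""", ignore_comments=True)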
2926
2927    # Munges exceptions that internally contain stack traces, using munge_exc
2928    def assertExpectedInlineMunged(
2929        self, exc_type, callable, expect, *, suppress_suffix=True
2930    ):
2931        try:
2932            callable()
2933        except exc_type as e:
2934            self.assertExpectedInline(
2935                munge_exc(e, suppress_suffix=suppress_suffix, skip=1), expect, skip=1
2936            )
2937            return
2938        self.fail(msg="Did not raise when expected to")
2939
2940    def assertLogs(self, logger=None, level=None):
2941        if logger is None:
2942            logger = logging.getLogger("torch")
2943        return super().assertLogs(logger, level)
2944
2945    def assertNoLogs(self, logger=None, level=None):
2946        if logger is None:
2947            logger = logging.getLogger("torch")
2948        return super().assertNoLogs(logger, level)
2949
2950    def wrap_with_cuda_policy(self, method_name, policy):
2951        test_method = getattr(self, method_name)
2952        # the import below may initialize CUDA context, so we do it only if
2953        # self._do_cuda_memory_leak_check or self._do_cuda_non_default_stream
2954        # is True.
2955        # TODO: sure looks like we unconditionally initialize the context here
2956        # -- ezyang
2957        from torch.testing._internal.common_cuda import TEST_CUDA
2958        fullname = self.id().lower()  # class_name.method_name
2959        if TEST_CUDA and ('gpu' in fullname or 'cuda' in fullname):
2960            setattr(self, method_name, self.wrap_method_with_policy(test_method, policy))
2961
2962    def wrap_with_policy(self, method_name, policy):
2963        test_method = getattr(self, method_name)
2964        setattr(self, method_name, self.wrap_method_with_policy(test_method, policy))
2965
2966    # A policy is a zero-argument function that returns a context manager.
2967    # We don't take the context manager directly as it may be necessary to
2968    # construct it once per test method
2969    def wrap_method_with_policy(self, method, policy):
2970        # Assumes that `method` is the tested function in `self`.
2971        # NOTE: Python Exceptions (e.g., unittest.Skip) keep objects in scope
2972        #       alive, so this cannot be done in setUp and tearDown because
2973        #       tearDown is run unconditionally no matter whether the test
2974        #       passes or not. For the same reason, we can't wrap the `method`
2975        #       call in try-finally and always do the check.
2976        @wraps(method)
2977        def wrapper(self, *args, **kwargs):
2978            with policy():
2979                method(*args, **kwargs)
2980        return types.MethodType(wrapper, self)
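    # Example policy (taken from __init__ above): a zero-argument callable returning a fresh
    # context manager for every wrapped call, e.g.
    #
    #   self.wrap_with_policy(method_name, lambda: skip_exception_type(NotImplementedError))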
2981
2982    def wrap_with_cuda_memory_check(self, method):
2983        return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors)
2984
2985    def _run_custom(self, result=None):
2986        using_unittest = isinstance(result, unittest.TestResult)
2987
2988        super_run = super().run
2989        test_cls = super_run.__self__
2990
2991        # Are we compiling?
2992        compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR
2993        # Is the class strict and compiling?
2994        strict_default = False
2995        should_reset_dynamo = False
2996        if compiled:
2997            try:
2998                path = inspect.getfile(type(test_cls))
2999                full_path = os.path.abspath(path)
3000                match = re.match(r".*/test/(.*).py", full_path)
3001                if match is not None:
3002                    filename = match.group(1)
3003                    if TEST_WITH_TORCHINDUCTOR:
3004                        from .dynamo_test_failures import FIXME_inductor_non_strict
3005                        strict_default = filename not in FIXME_inductor_non_strict
3006
3007                        from .dynamo_test_failures import FIXME_inductor_dont_reset_dynamo
3008                        should_reset_dynamo = filename not in FIXME_inductor_dont_reset_dynamo
3009                    else:
3010                        strict_default = True
3011            # inspect.getfile can fail with these
3012            except (OSError, TypeError):
3013                pass
3014            if "STRICT_DEFAULT" in os.environ:
3015                if os.environ["STRICT_DEFAULT"] == "1":
3016                    strict_default = True
3017
3018        strict_mode = False
3019        if compiled:
3020            test_method = getattr(self, self._testMethodName)
3021            if hasattr(test_method, "dynamo_strict"):
3022                strict_mode = test_method.dynamo_strict
3023            elif hasattr(test_cls, "dynamo_strict"):
3024                strict_mode = test_cls.dynamo_strict
3025            else:
3026                strict_mode = strict_default
3027        nopython = getattr(test_cls, "dynamo_strict_nopython", False) and compiled
3028
3029        if strict_mode or should_reset_dynamo:
3030            torch._dynamo.reset()
3031
3032        # TODO: Remove this; this is grandfathered in because we suppressed errors
3033        # on the test suite previously.
3034        # When strict mode is False, suppress_errors is True
3035        if compiled:
3036            suppress_errors = not strict_mode
3037        else:
3038            suppress_errors = torch._dynamo.config.suppress_errors
3039        with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors):
3040            if TEST_WITH_TORCHINDUCTOR:
3041                super_run = torch._dynamo.optimize("inductor")(super_run)
3042            elif TEST_WITH_AOT_EAGER:
3043                super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run)
3044            elif TEST_WITH_TORCHDYNAMO:
3045                # TorchDynamo optimize annotation
3046                # Assume eager-generated GraphModules will not error out.
3047                # If one does, this is probably a Dynamo bug!
3048                super_run = torch._dynamo.optimize("eager_noexcept", nopython=nopython)(super_run)
3049                key = f"{self.__class__.__name__}.{self._testMethodName}"
3050                from .dynamo_test_failures import dynamo_expected_failures, dynamo_skips
3051
3052                def expect_failure(f, test_name):
3053                    @wraps(f)
3054                    def wrapper(*args, **kwargs):
3055                        try:
3056                            f(*args, **kwargs)
3057                        except BaseException as e:
3058                            self.skipTest(e)
3059                        raise RuntimeError(f"Unexpected success, please remove `test/dynamo_expected_failures/{test_name}`")
3060                    return wrapper
3061
3062                if key in dynamo_expected_failures:
3063                    method = getattr(self, self._testMethodName)
3064                    setattr(self, self._testMethodName, expect_failure(method, key))
3065
3066                def ignore_failure(f, test_name):
3067                    @wraps(f)
3068                    def wrapper(*args, **kwargs):
3069                        try:
3070                            f(*args, **kwargs)
3071                        except BaseException as e:
3072                            self.skipTest(e)
3073                        method = getattr(self, self._testMethodName)
3074                        if getattr(method, "__unittest_expecting_failure__", False):
3075                            self.skipTest("unexpected success")
3076                        else:
3077                            self.skipTest(f"This test passed, maybe we can remove `test/dynamo_skips/{test_name}`")
3078                    return wrapper
3079
3080                if key in dynamo_skips:
3081                    method = getattr(self, self._testMethodName)
3082                    setattr(self, self._testMethodName, ignore_failure(method, key))
3083
3084            super_run(result=result)
3085
3086        if strict_mode or should_reset_dynamo:
3087            torch._dynamo.reset()
3088
3089        # Terminate the test suite early if necessary.  If using pytest, use the -x flag instead.
3090        if using_unittest and self._should_stop_test_suite():
3091            if result.wasSuccessful():
3092                case = TestCase()
3093                if TEST_SAVE_XML is not None:
3094                    # This is a bit hacky: XMLRunner modifies the expected type from TestCase to TestInfo.
3095                    # Create dummy TestInfo to record results correctly
3096                    from xmlrunner.result import _TestInfo  # type: ignore[import]
3097                    case = _TestInfo(result, case)
3098                    case.output = _TestInfo.ERROR
3099                    case.elapsed_time = 0.0
3100                    case.test_description = "TestSuiteEarlyFailure"
3101                # This shouldn't really happen, but if it does, add a fake failure.
3102                # For more details see https://github.com/pytorch/pytorch/issues/71973
3103                result.failures.append((case, "TestSuite execution was aborted early"))
3104                assert result.wasSuccessful() is False
3105            result.stop()
3106
3107
3108    def run(self, result=None):
3109        with contextlib.ExitStack() as stack:
3110            if TEST_WITH_CROSSREF:
3111                stack.enter_context(CrossRefMode())
3112            self._run_custom(
3113                result=result,
3114            )
3115
3116    def setUp(self):
3117        check_if_enable(self)
3118        set_rng_seed(SEED)
3119
3120        # Save the global sparse tensor invariants check state so that it can be
3121        # restored in tearDown:
3122        self._check_invariants = torch.sparse.check_sparse_tensor_invariants.is_enabled()
3123
3124        # Enable invariant checks for all sparse tensor constructions,
3125        # including the unsafe ones. If this is not desired for some
3126        # test case, use check_invariants=False optional argument to
3127        # sparse tensor constructors or
3128        # @torch.sparse.check_sparse_tensor_invariants(False)
3129        # decorator to disable the invariant checks.
3130        torch.sparse.check_sparse_tensor_invariants.enable()
3131
3132        if self._default_dtype_check_enabled:
3133            assert torch.get_default_dtype() == torch.float
3134
3135        # attempt to reset some global state at the end of the test
3136        self._prev_grad_state = torch.is_grad_enabled()
3137
3138    def tearDown(self):
3139        # There exist test cases that override the TestCase.setUp
3140        # definition, so we cannot assume that the _check_invariants
3141        # attribute is defined in general.
3142        if hasattr(self, '_check_invariants'):
3143            # Restore the global check sparse tensor invariants state
3144            if self._check_invariants:
3145                torch.sparse.check_sparse_tensor_invariants.enable()
3146            else:
3147                torch.sparse.check_sparse_tensor_invariants.disable()
3148
3149        if self._default_dtype_check_enabled:
3150            assert torch.get_default_dtype() == torch.float
3151
3152        # attribute may not be defined, per above
3153        if hasattr(self, '_prev_grad_state'):
3154            torch.set_grad_enabled(self._prev_grad_state)
3155
3156    @staticmethod
3157    def _make_crow_indices(n_rows, n_cols, nnz,
3158                           *, device, dtype, random=True):
3159        """Return crow_indices of a CSR tensor with size (n_rows, n_cols) and
3160        the number of specified elements nnz.
3161
3162        If random is True, the column counts of rows are in random
3163        order. Otherwise, the column counts of rows are defined by the
3164        sampling method used.
3165
3166        Sampling method
3167        ---------------
3168
3169        The sampling method used here was introduced in
3170        https://pearu.github.io/csr_sampling.html, and here we give
3171        only an overall description of the method.
3172
3173        Notice that crow_indices can be defined as cumsum(counts)
3174        where counts is a sequence of non-negative integers satisfying
3175        the following conditions:
3176
3177          len(counts) == n_rows + 1
3178          counts.max() <= n_cols
3179
3180        where counts[i + 1] is interpreted as the number of specified
3181        elements in the i-th row.
3182
3183        The sampling method aims at increasing the diversity of
3184        CSR samples, that is, a CSR sample should contain (i) rows
3185        that are all filled, (ii) rows with no elements at all, and
3186        (iii) rows that are partially filled. At the same time and for
3187        the given total number of specified elements (nnz), there
3188        should be minimal preference to rows with a given number of
3189        elements.  To achieve this, the sampling method is built on a
3190        sawtooth model for counts. In the simplest case, we
3191        would have
3192
3193          counts = arange(n_rows + 1) % (n_cols + 1)
3194
3195        which has an equal number of all possible column counts per row.
3196        This formula can be used only for specific input values of
3197        n_rows, n_cols, and nnz. To generalize this model to any
3198        combinations of inputs, the counts model above is extended
3199        with an incomplete sawtooth, and the right and lower
3200        rectangular parts that will guarantee that
3201
3202          counts.sum() == nnz
3203
3204        for any combination of n_rows, n_cols, and nnz. Basically,
3205        we'll find a maximal window in the (n_rows + 1, n_cols + 1) grid
3206        that is able to hold a sequence of sawteeth and a so-called
3207        final correction, while the external part of the window is
3208        filled with counts to meet the nnz constraint exactly.
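
        A call sketch (illustrative)::

            crow = TestCase._make_crow_indices(3, 4, 5, device='cpu', dtype=torch.int64)
            # crow has n_rows + 1 == 4 entries, starts at 0, is non-decreasing,
            # and crow[-1] == nnz == 5.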
3209        """
3210        assert 0 <= nnz <= n_rows * n_cols, (nnz, n_rows, n_cols)
3211
3212        def sawteeth(n, m):
3213            # Return the total number of counts in the sequence of
3214            # sawteeth, where n and m define a window in the (n_rows+1,
3215            # n_cols+1) rectangle in which the sequence of sawteeth
3216            # fits exactly.
3217            M = (n_cols - m) * (n_cols - m + 1) // 2
3218            K = (n_rows - n) % (n_cols - m + 1)
3219            return M * ((n_rows - n) // (n_cols - m + 1)) + K * (K - 1) // 2
3220
3221        # Unlike the original method description, counts here
3222        # has the leading 0 required by crow_indices:
3223        counts = torch.zeros(n_rows + 1, dtype=dtype, device=torch.device('cpu'))
3224
3225        n = m = 0
3226        N = sawteeth(n, m)
3227        if N and nnz >= max(N, n_cols):
3228            # determine the width of the sawteeth window. We use bisection to solve
3229            #   N(n, 0) == 0 or nnz - n * n_cols < max(N(n, 0), n_cols)
3230            # for n
3231            n_left = n
3232            n_right = n_rows - 1
3233            N_right = sawteeth(n_right, m)
3234            while n_right - n_left > 1:
3235                n_middle = (n_left + n_right) // 2
3236                N_middle = sawteeth(n_middle, m)
3237                if N_middle == 0 or nnz - n_middle * n_cols < max(N_middle, n_cols):
3238                    n_right, N_right = n_middle, N_middle
3239                else:
3240                    n_left = n_middle
3241            n, N = n_right, N_right
3242            # fill the right rectangle with counts:
3243            assert n
3244            counts[-n:].fill_(n_cols)
3245
3246        if N and nnz - n * n_cols >= max(N, n_rows - n):
3247            # determine the height of the sawteeth window. We use bisection to solve
3248            #   N(n, m) == 0 or nnz - n * n_cols - m * (n_rows - n) < max(N(n, m), n_rows - n)
3249            # for m.
3250            m_left = m
3251            m_right = n_cols - 1
3252            N_right = sawteeth(n, m_right)
3253            while m_right - m_left > 1:
3254                m_middle = (m_left + m_right) // 2
3255                N_middle = sawteeth(n, m_middle)
3256                if N_middle == 0 or nnz - n * n_cols - m_middle * (n_rows - n) < max(N_middle, n_rows - n):
3257                    m_right, N_right = m_middle, N_middle
3258                else:
3259                    m_left = m_middle
3260            m, N = m_right, N_right
3261            # fill the bottom rectangle with counts:
3262            assert m
3263            counts[1:n_rows - n + 1].fill_(m)
3264
3265        if N:
3266            # fill the sawteeth window with counts
3267            q, r = divmod(nnz - n * n_cols - m * (n_rows - n),
3268                          (n_cols - m) * (n_cols - m + 1) // 2)
3269            p = 1 + q * (n_cols - m + 1)
3270            k = math.isqrt(2 * r)
3271            if k * (k + 1) > 2 * r:
3272                k -= 1
3273            corr = r - k * (k + 1) // 2
3274            assert not ((p > 1) and (m > 0))  # full sawteeth are never on top of a bottom rectangle
3275            # sequence of full sawteeth:
3276            counts[1:p] = torch.arange(p - 1, dtype=dtype, device=counts.device) % (n_cols - m + 1)
3277            # incomplete sawtooth:
3278            counts[p:p + k + 1] += torch.arange(k + 1, dtype=dtype, device=counts.device)
3279        else:
3280            # given input does not support sawteeth
3281            p = 1
3282            corr = nnz - n * n_cols - m * (n_rows - n)
3283
3284        # correction that will guarantee counts.sum() == nnz:
3285        counts[p] += corr
3286
3287        if random:
3288            # randomize crow_indices by shuffling the sawteeth
3289            # sequence:
3290            perm = torch.randperm(n_rows, device=counts.device)
3291            counts[1:] = counts[1:][perm]
3292
3293        # compute crow_indices:
3294        crow_indices = counts
3295        crow_indices.cumsum_(dim=0)
3296        return crow_indices.to(device=device)
3297
3298    def genSparseCompressedTensor(self, size, nnz, *, layout, device, dtype, index_dtype, blocksize=(), dense_dims=0):
3299        from operator import mul
3300        from functools import reduce
3301        sparse_dim = 2
3302        assert all(size[d] > 0 for d in range(len(size))) or nnz == 0, 'invalid arguments'
3303        assert len(size) >= sparse_dim
3304        if blocksize:
3305            assert len(blocksize) == 2, (size, blocksize)
3306            assert size[-2 - dense_dims] % blocksize[0] == 0, (size, blocksize)
3307            assert size[-1 - dense_dims] % blocksize[1] == 0, (size, blocksize)
3308            blocksize0, blocksize1 = blocksize
3309        else:
3310            blocksize0 = blocksize1 = 1
3311
3312        size = tuple(size)
3313        dense_size = size[(len(size) - dense_dims):]
3314
3315        def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz):
3316            compressed_indices = self._make_crow_indices(n_compressed_dims, n_plain_dims, nnz, device=device, dtype=index_dtype)
3317            plain_indices = torch.zeros(nnz, dtype=index_dtype, device=device)
3318            for i in range(n_compressed_dims):
3319                count = compressed_indices[i + 1] - compressed_indices[i]
3320                plain_indices[compressed_indices[i]:compressed_indices[i + 1]], _ = torch.sort(
3321                    torch.randperm(n_plain_dims, dtype=index_dtype, device=device)[:count])
3322            low = -1 if dtype != torch.uint8 else 0
3323            high = 1 if dtype != torch.uint8 else 2
3324            values = make_tensor((nnz,) + blocksize + dense_size, device=device, dtype=dtype, low=low, high=high)
3325            return values, compressed_indices, plain_indices
3326
3327        batch_shape = size[:-2 - dense_dims]
3328        n_batch = reduce(mul, batch_shape, 1)
3329
3330        if layout in {torch.sparse_csr, torch.sparse_bsr}:
3331            n_compressed_dims, n_plain_dims = size[-2 - dense_dims] // blocksize0, size[-1 - dense_dims] // blocksize1
3332        else:
3333            n_compressed_dims, n_plain_dims = size[-1 - dense_dims] // blocksize1, size[-2 - dense_dims] // blocksize0
3334        blocknnz = nnz // (blocksize0 * blocksize1)
3335        sparse_tensors = [random_sparse_compressed(n_compressed_dims, n_plain_dims, blocknnz) for _ in range(n_batch)]
3336        sparse_tensors_it = map(list, zip(*sparse_tensors))
3337
3338        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, blocknnz, *blocksize, *dense_size)
3339        compressed_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
3340        plain_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
3341        return torch.sparse_compressed_tensor(compressed_indices, plain_indices,
3342                                              values, size=size, dtype=dtype, layout=layout, device=device)
3343
3344    def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0):
3345        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csr, device=device,
3346                                              dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=dense_dims)
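    # Usage sketch (illustrative): a 4 x 6 CSR tensor with 8 specified elements.
    #
    #   t = self.genSparseCSRTensor((4, 6), 8, device='cpu', dtype=torch.float64, index_dtype=torch.int64)
    #   # t.layout is torch.sparse_csr and t._nnz() == 8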
3347
3348    def genSparseCSCTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0):
3349        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csc, device=device,
3350                                              dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=dense_dims)
3351
3352    def genSparseBSRTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0):
3353        assert len(blocksize) == 2
3354        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsr, device=device,
3355                                              dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims)
3356
3357    def genSparseBSCTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0):
3358        assert len(blocksize) == 2
3359        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsc, device=device,
3360                                              dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims)
3361
3362    def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device, dtype):
3363        # Assert we are not given an impossible combination, where the sparse dims
3364        # have zero numel but nnz > 0 would require the indices to contain values.
3365        assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'
3366
3367        v_size = [nnz] + list(size[sparse_dim:])
3368        v = make_tensor(v_size, device=device, dtype=dtype, low=-1, high=1)
3369        i = torch.rand(sparse_dim, nnz, device=device)
3370        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
3371        i = i.to(torch.long)
3372        if is_uncoalesced:
3373            i1 = i[:, :(nnz // 2), ...]
3374            i2 = i[:, :((nnz + 1) // 2), ...]
3375            i = torch.cat([i1, i2], 1)
3376        x = torch.sparse_coo_tensor(i, v, torch.Size(size), dtype=dtype, device=device)
3377
3378        if not is_uncoalesced:
3379            x = x.coalesce()
3380        else:
3381            # FIXME: `x` is a sparse view of `v`. Currently rebase_history for
3382            #        sparse views is not implemented, so this workaround is
3383            #        needed for inplace operations done on `x`, e.g., copy_().
3384            #        Remove after implementing something equivalent to CopySlice
3385            #        for sparse views.
3386            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of x afterwards
3387            x = x.detach().clone()._coalesced_(False)
3388        return x, x._indices().clone(), x._values().clone()
3389
3390    def generate_simple_inputs(self, layout,
3391                               device=None,
3392                               dtype=None,
3393                               index_dtype=None,
3394                               pin_memory=None,
3395                               members_pin_memory=None,
3396                               enable_batch=True,
3397                               enable_hybrid=True,
3398                               enable_zero_sized=True,
3399                               enable_non_contiguous_indices=True,
3400                               enable_non_contiguous_values=True,
3401                               enable_batch_variable_nse=False,
3402                               output_tensor=True,
3403                               patterns=None):
3404        """Generator of simple inputs for tensor constructors of the given layout.
3405
3406        The generated tensor inputs have the following properties:
3407
3408        - tensor shapes are minimal but not trivial
3409        - tensor values are sorted sequences for COO and CSR formats, e.g. [1, 2, 3, 4]
3410        - the generated tensors represent the same mathematical tensor for all layouts
3411        - the generated tensors include regular, zero-sized, and optionally, batched and/or hybrid tensors.
3412        - the generated tensors include contiguous or non-contiguous tensors both in indices and values
3413
3414        If output_tensor is True, yield tensors with the given
3415        layout. Otherwise, yield inputs to the corresponding tensor
3416        constructors:
3417
3418          - sparse compressed input is defined as
3419            (compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype,
3420                                                              pin_memory=pin_memory)
3421
3422          - sparse COO input is defined as
3423            (indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, pin_memory=pin_memory)
3424
3425          - strided input is defined as
3426            (values,), dict(device=device, dtype=dtype)
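
        A consumption sketch (illustrative)::

            for t in self.generate_simple_inputs(torch.sparse_csr, device='cpu', dtype=torch.float64):
                assert t.layout is torch.sparse_csr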
3427        """
3428        if index_dtype is None:
3429            index_dtype = torch.int64
3430
3431        is_compressed_sparse_layout = layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
3432
3433        if output_tensor:
3434            for args, kwargs in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype,
3435                                                            pin_memory=pin_memory,
3436                                                            enable_batch=enable_batch, enable_hybrid=enable_hybrid,
3437                                                            enable_zero_sized=enable_zero_sized,
3438                                                            enable_non_contiguous_indices=enable_non_contiguous_indices,
3439                                                            enable_non_contiguous_values=enable_non_contiguous_values,
3440                                                            enable_batch_variable_nse=enable_batch_variable_nse,
3441                                                            output_tensor=False):
3442                if members_pin_memory:
3443                    args = tuple(a.pin_memory() for a in args)
3444                if layout is torch.strided:
3445                    assert len(args) == 1
3446                    size = kwargs.pop('size', None)  # to ensure that a zero-sized tensor has the desired shape
3447                    assert size is not None
3448                    if pin_memory:
3449                        yield args[0].reshape(size).pin_memory()
3450                    else:
3451                        yield args[0].reshape(size)
3452                elif layout is torch.sparse_coo:
3453                    yield torch.sparse_coo_tensor(*args, **kwargs)
3454                elif is_compressed_sparse_layout:
3455                    kwargs.update(layout=layout)
3456                    yield torch.sparse_compressed_tensor(*args, **kwargs)
3457                else:
3458                    assert 0  # unreachable
3459            return
3460
3461        def get_blockpattern(pattern, blocksize):
3462            basesize = pattern.shape
3463            assert basesize[0] % blocksize[0] == 0, (basesize, blocksize)
3464            assert basesize[1] % blocksize[1] == 0, (basesize, blocksize)
3465            blockpattern = pattern.reshape(-1,
3466                                           blocksize[0],
3467                                           basesize[1] // blocksize[1],
3468                                           blocksize[1]).transpose(-3, -2).any(-1).any(-1)
3469            block_ids = torch.arange(1, blockpattern.numel() + 1).reshape(blockpattern.shape)
3470            return (blockpattern != 0) * block_ids
3471
3472        def get_sparse_data(pattern):
3473            basesize = pattern.shape
3474            assert len(basesize) == 2, basesize  # pattern is expected to be a matrix
3475
3476            # We cannot use `torch.sparse_xyz_tensor(pattern)` to
3477            # compute the sparse layout indices and values because
3478            # generate_simple_inputs is used to generate the inputs to
3479            # test `torch.sparse_xyz_tensor` factory functions, so
3480            # we'll compute the indices and values independently of
3481            # the factory functions.
3482
3483            indices = torch.where(pattern != 0)
3484            coo_indices = torch.stack(indices)
3485            crow_indices = torch.zeros(basesize[0] + 1, dtype=torch.int64)
3486            crow_indices[1:] = torch.cumsum(coo_indices[0].bincount(minlength=basesize[0]), 0)
3487            col_indices = coo_indices[1]
3488            strided_values = torch.zeros(basesize, dtype=torch.int64)
3489
3490            # the property of `values == range(1, 1+nnz)` is used in
3491            # get_sparse_data_with_block to relate BSR and BSC values,
3492            # so, don't change the following line:
3493            values = torch.arange(1, 1 + len(indices[0]), dtype=torch.int64)
3494            strided_values[indices] = values
3495
3496            indices_T = torch.where(pattern.transpose(0, 1) != 0)
3497            coo_indices_T = torch.stack(indices_T)
3498            ccol_indices = torch.zeros(basesize[1] + 1, dtype=torch.int64)
3499            ccol_indices[1:] = torch.cumsum(coo_indices_T[0].bincount(minlength=basesize[1]), 0)
3500            row_indices = coo_indices_T[1]
3501            csc_values = strided_values.transpose(0, 1)[indices_T]
3502
3503            return {torch.sparse_coo: (coo_indices, values),
3504                    torch.sparse_csr: (crow_indices, col_indices, values),
3505                    torch.sparse_csc: (ccol_indices, row_indices, csc_values),
3506                    torch.strided: (strided_values,)}
3507
3508        def get_sparse_data_with_block(pattern, blocksize):
3509            nonblock_data = get_sparse_data(pattern)
3510            blockpattern = get_blockpattern(pattern, blocksize)
3511            block_data = get_sparse_data(blockpattern)
3512
3513            strided_values = nonblock_data[torch.strided][0]
3514            block_indices = block_data[torch.sparse_coo][0]
3515            bsr_values = torch.stack([strided_values[bi * blocksize[0]:(bi + 1) * blocksize[0],
3516                                                     bj * blocksize[1]:(bj + 1) * blocksize[1]]
3517                                      for bi, bj in block_indices.transpose(0, 1)])
3518
3519            # here we use the property `values == range(1, 1+nnz)` and
3520            # `values` relation to `csc_values` (see get_sparse_data)
3521            # to get BSC blocks via reordering the BSR blocks:
3522            bsc_values = bsr_values[block_data[torch.sparse_csc][2] - 1]
3523
3524            return {torch.sparse_bsr: (*block_data[torch.sparse_csr][:2], bsr_values),
3525                    torch.sparse_bsc: (*block_data[torch.sparse_csc][:2], bsc_values),
3526                    **nonblock_data}
3527
3528        def get_batch_sparse_data(pattern, blocksize):
3529            size = pattern.shape
3530            if len(size) <= 2:  # non-batch
3531                return get_sparse_data_with_block(pattern, blocksize)
3532
3533            # batch data is created recursively:
3534            batch_data = {}
3535            for i, item in enumerate(pattern):
3536                for layout, d in get_batch_sparse_data(item, blocksize).items():
3537                    target = batch_data.get(layout)
3538                    if layout is torch.sparse_coo:
3539                        # a "batch COO" means a COO with the leading
3540                        # sparse dimensions interpreted as batch
3541                        # dimensions
3542                        ext_coo_indices1 = torch.cat((torch.full((1, len(d[1])), i, dtype=torch.int64), d[0]))
3543                        if target is None:
3544                            target = batch_data[layout] = (ext_coo_indices1, d[1])
3545                        else:
3546                            target[0].set_(torch.cat((target[0], ext_coo_indices1), 1))
3547                            target[1].set_(torch.cat((target[1], d[1])))
3548                    else:
3549                        if target is None:
3550                            target = batch_data[layout] = tuple(d[j].unsqueeze(0) for j in range(len(d)))
3551                        else:
3552                            for j in range(len(d)):
3553                                target[j].set_(torch.cat((target[j], d[j].unsqueeze(0))))
3554            return batch_data
3555
3556        def generate_values(base, densesize):
3557            """Generates a tensor of shape densesize with values equal to
3558
3559              base + i_1 * 10^0 + ... + i_d * 10^{d - 1}
3560
3561            at indices i_1, ..., i_d (with 0 <= i_j < densesize[j] for any 1 <= j <=
3562            len(densesize))
3563
3564            This mapping produces unique values as long as
3565            densesize[i] < 10 for all i in range(len(densesize)).
3566            """
3567
3568            if not densesize:
3569                return base
3570            if not isinstance(base, int) and base.ndim > 0:
3571                return torch.stack([generate_values(b, densesize) for b in base])
3572            if base == 0:
3573                return torch.zeros(densesize, dtype=torch.int64)
3574            r = torch.arange(densesize[0], dtype=torch.int64)
3575            for i, d in enumerate(densesize[1:]):
3576                y = torch.arange(d, dtype=torch.int64) * (10 ** (i + 1))
3577                r = r[..., None] + y[None, ...]
3578            r.add_(base)
3579            return r
3580
3581        if patterns is None:
3582            # A pattern is a 3-tuple with the following items:
3583            #
3584            # - a nested list of integers with a depth of two or more. The
3585            #   integers define the sparsity patterns of the generated
3586            #   inputs: zero values correspond to unspecified
3587            #   elements/blocks, and non-zero values to the specified
3588            #   elements.
3589            #
3590            #   For debugging convenience, the elements with the same
3591            #   value typically belong to the same block. However, it
3592            #   is not a hard requirement: as long as the shape of a
3593            #   pattern is divisible by the block sizes, the pattern will be
3594            #   a valid one.
3595            #
3596            #   If the depth of the list is larger than two, inputs
3597            #   with batch dimensions will be generated.
3598            #
3599            # - a list of 2-tuples of block sizes, used to generate
3600            #   BSR/BSC tensors with various block size parameters
3601            #
3602            # - a list of tuples of dense dimensions, used to generate
3603            #   hybrid tensors with various dense dimensions
3604            #
3605            patterns = [
3606                # a simple 2 x 3 tensor: non-hybrid, hybrid with 1 and 2 dense dimensions
3607                ([[1, 2, 0],
3608                  [1, 0, 3]], [(2, 1), (1, 3)], [(), (2,), (4, 5)]),
3609                # 2 x 3 batch of 2 x 3 tensors: non-hybrid and hybrid with 2 dense dimensions
3610                ([[[[1, 2, 0],
3611                    [1, 0, 3]],
3612                   [[1, 2, 3],
3613                    [1, 0, 0]],
3614                   [[1, 0, 0],
3615                    [1, 2, 3]]],
3616                  [[[0, 2, 0],
3617                    [1, 2, 3]],
3618                   [[1, 0, 3],
3619                    [1, 2, 0]],
3620                   [[1, 2, 3],
3621                    [0, 2, 0]]]], [(2, 1), (2, 3)], [(), (2,)]),
3622                # tensor with non-trivial blocksize
3623                ([[0, 1, 0, 2, 0, 2],
3624                  [0, 1, 0, 0, 2, 0],
3625                  [3, 3, 3, 0, 0, 0],
3626                  [0, 0, 0, 0, 0, 0],
3627                  [0, 5, 0, 6, 6, 6],
3628                  [5, 0, 5, 6, 6, 6],
3629                  [0, 0, 0, 0, 8, 8],
3630                  [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 5)]),
3631                # batch tensor with variable NSE
3632                # Requires https://github.com/pytorch/pytorch/pull/84843 or similar.
3633                ([[[1, 2],
3634                   [3, 4]],
3635                  [[1, 0],
3636                   [0, 0]]], [(1, 1)], ([()] if enable_batch_variable_nse else []))]
3637
3638        def non_contiguous_copy(t, dim=-1, offset=0):
3639            # return a copy of t that is non-contiguous along the
3640            # given dimension and with the given storage offset
3641            self.assertTrue(t.is_contiguous())
3642            if dim < 0:
3643                dim = dim + t.ndim
3644            assert dim >= 0 and dim < t.ndim
3645            step = max(2, offset + 1)
3646            tmp = torch.zeros((*t.shape[:dim], t.shape[dim] * step, *t.shape[dim + 1:]), dtype=t.dtype, device=t.device)
3647            dim_slices = (*((slice(None),) * dim), slice(offset, None, step))
3648            r = tmp[dim_slices].copy_(t)
3649            self.assertFalse(r.is_contiguous())
3650            self.assertEqual(t, r)
3651            return r
3652
3653        # the main loop of the method:
3654        for pattern, blocksizes, densesizes in patterns:
3655            if not enable_hybrid:
3656                densesizes = [s for s in densesizes if not s]
3657            if not (densesizes and blocksizes):
3658                continue
3659            pattern = torch.tensor(pattern, dtype=torch.int64)
3660            if not enable_batch and pattern.ndim > 2:
3661                continue
3662            for blocksize in blocksizes:
3663                data = get_batch_sparse_data(pattern, blocksize)[layout]
3664                for densesize in densesizes:
3665                    indices = [a.to(device=device, dtype=index_dtype) for a in data[:-1]]
3666                    values = generate_values(data[-1], densesize).to(device=device, dtype=dtype)
3667                    kwargs = dict(device=device, dtype=dtype, size=pattern.shape + densesize)
3668                    if pin_memory is not None:
3669                        kwargs.update(pin_memory=pin_memory)
3670
3671                    yield (*indices, values), kwargs.copy()
3672                    if enable_non_contiguous_indices and pattern.ndim > 2:
3673                        # sparse compressed indices can be sliced only along batch dimensions
3674                        for (dim, offset) in {(0, 1), (-2, 0)}:
3675                            indices_copy = [non_contiguous_copy(a, dim=dim, offset=offset) for a in indices]
3676                            yield (*indices_copy, values), kwargs.copy()
3677
3678                            if enable_non_contiguous_values:
3679                                values_copy = non_contiguous_copy(values, dim=-1, offset=1)
3680                                yield (*indices_copy, values_copy), kwargs.copy()
3681
3682                    if enable_non_contiguous_values:
3683                        values_copy = non_contiguous_copy(values, dim=-1, offset=1)
3684                        yield (*indices, values_copy), kwargs.copy()
3685
3686        # zero-sized tensor inputs, non-batch, non-hybrid/hybrid
3687        if enable_zero_sized:
3688            for basesize, blocksizes, densesizes in [
3689                    ((2, 0), [(1, 2)], [(), (2,), (2, 3)] if enable_hybrid else [()]),
3690                    ((0, 2), [(1, 2), (2, 1), (3, 2)], [()]),
3691                    ((0, 0), [(1, 2)], [()]),
3692            ]:
3693                for blocksize in blocksizes:
3694                    for densesize in densesizes:
3695                        if layout == torch.strided:
3696                            indices = ()
3697                            values = torch.empty((basesize + densesize), device=device, dtype=dtype)
3698                        elif layout == torch.sparse_coo:
3699                            indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),)
3700                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
3701                        elif layout == torch.sparse_csr:
3702                            crow_indices = torch.tensor([0] * (basesize[0] + 1), device=device, dtype=index_dtype)
3703                            col_indices = torch.empty(0, device=device, dtype=index_dtype)
3704                            indices = (crow_indices, col_indices)
3705                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
3706                        elif layout == torch.sparse_csc:
3707                            ccol_indices = torch.tensor([0] * (basesize[1] + 1), device=device, dtype=index_dtype)
3708                            row_indices = torch.empty(0, device=device, dtype=index_dtype)
3709                            indices = (ccol_indices, row_indices)
3710                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
3711                        elif layout == torch.sparse_bsr:
3712                            crow_indices = torch.tensor([0] * (basesize[0] // blocksize[0] + 1), device=device, dtype=index_dtype)
3713                            col_indices = torch.empty(0, device=device, dtype=index_dtype)
3714                            indices = (crow_indices, col_indices)
3715                            values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
3716                        elif layout == torch.sparse_bsc:
3717                            ccol_indices = torch.tensor([0] * (basesize[1] // blocksize[1] + 1), device=device, dtype=index_dtype)
3718                            row_indices = torch.empty(0, device=device, dtype=index_dtype)
3719                            indices = (ccol_indices, row_indices)
3720                            values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
3721                        else:
3722                            assert 0  # unreachable
3723                        kwargs = dict(device=device, dtype=dtype, size=basesize + densesize)
3724                        if pin_memory is not None:
3725                            kwargs.update(pin_memory=pin_memory)
3726                        yield (*indices, values), kwargs
3727
3728    def safeToDense(self, t):
3729        # coalesce is only implemented for COO
3730        if t.layout == torch.sparse_coo:
3731            t = t.coalesce()
3732        return t.to_dense()
3733
3734    # Compares a torch function with a reference function for a given sample input (object of SampleInput)
3735    # Note: only values are compared, type comparison is not done here
3736    def compare_with_reference(self, torch_fn, ref_fn, sample_input, **kwargs):
3737        numpy_sample = sample_input.numpy()
3738        n_inp, n_args, n_kwargs = numpy_sample.input, numpy_sample.args, numpy_sample.kwargs
3739        t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
3740
3741        actual = torch_fn(t_inp, *t_args, **t_kwargs)
3742        expected = ref_fn(n_inp, *n_args, **n_kwargs)
3743
3744        self.assertEqual(actual, expected, exact_device=False, **kwargs)
3745
3746    # Compares the given Torch and NumPy functions on the given tensor-like object.
3747    # NOTE: both torch_fn and np_fn should be functions that take a single
3748    #   tensor (array). If the torch and/or NumPy function require additional
3749    #   arguments then wrap the function in a lambda or pass a partial function.
3750    # TODO: add args/kwargs for passing to assertEqual (e.g. rtol, atol)
3751    def compare_with_numpy(self, torch_fn, np_fn, tensor_like,
3752                           device=None, dtype=None, **kwargs):
3753        assert TEST_NUMPY
3754
3755        if isinstance(tensor_like, torch.Tensor):
3756            assert device is None
3757            assert dtype is None
3758            t_cpu = tensor_like.detach().cpu()
3759            if t_cpu.dtype is torch.bfloat16:
3760                t_cpu = t_cpu.float()
3761            a = t_cpu.numpy()
3762            t = tensor_like
3763        else:
3764            d = copy.copy(torch_to_numpy_dtype_dict)
3765            d[torch.bfloat16] = np.float32
3766            a = np.array(tensor_like, dtype=d[dtype])
3767            t = torch.tensor(tensor_like, device=device, dtype=dtype)
3768
3769        np_result = np_fn(a)
3770        torch_result = torch_fn(t).cpu()
3771
3772        # Converts arrays to tensors
3773        if isinstance(np_result, np.ndarray):
3774            try:
3775                np_result = torch.from_numpy(np_result)
3776            except Exception:
3777                # NOTE: copying an array before conversion is necessary when,
3778                #   for example, the array has negative strides.
3779                np_result = torch.from_numpy(np_result.copy())
3780            if t.dtype is torch.bfloat16 and torch_result.dtype is torch.bfloat16 and np_result.dtype is torch.float:
3781                torch_result = torch_result.to(torch.float)
3782
3783        self.assertEqual(np_result, torch_result, **kwargs)
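
    # Illustrative sketch (hypothetical helper, not used by the suite): both callables take a
    # single tensor/array, so extra arguments are bound with a lambda (or functools.partial).
    def _compare_with_numpy_example(self):
        t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
        self.compare_with_numpy(lambda x: torch.sum(x, dim=0),
                                lambda a: np.sum(a, axis=0),
                                t)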
3784
3785    def assertEqualIgnoreType(self, *args, **kwargs) -> None:
        # If you are seeing this function used, the test is written incorrectly
        # and deserves detailed investigation
3788        return self.assertEqual(*args, exact_dtype=False, **kwargs)
3789
3790    def assertEqualBroadcasting(self, x, y, *args, **kwargs) -> None:
3791        r"""Tests if tensor x equals to y, if y to be broadcast to x.shape.
3792        """
3793        if not isinstance(y, Iterable):
3794            # int, float, etc. or different shape tensors
3795            y = torch.ones_like(x) * y
3796        if not isinstance(y, torch.Tensor):
3797            # iterable, but not a tensor
3798            y = torch.ones_like(x) * torch.tensor(y)
3799        return self.assertEqual(x, y, *args, **kwargs)
3800
3801    def assertEqual(
3802            self,
3803            x,
3804            y,
3805            msg: Optional[Union[str, Callable[[str], str]]] = None,
3806            *,
3807            atol: Optional[float] = None,
3808            rtol: Optional[float] = None,
3809            equal_nan=True,
3810            exact_dtype=True,
3811            # TODO: default this to True
3812            exact_device=False,
3813            exact_layout=False,
3814            exact_stride=False,
3815            exact_is_coalesced=False
3816    ):
3817        # Hide this function from `pytest`'s traceback
3818        __tracebackhide__ = True
3819
3820        # numpy's dtypes are a superset of what PyTorch supports. In case we encounter an unsupported dtype, we fall
3821        # back to an elementwise comparison. Note that this has to happen here and not for example in
3822        # `TensorOrArrayPair`, since at that stage we can no longer split the array into its elements and perform
3823        # multiple comparisons.
3824        if any(
3825            isinstance(input, np.ndarray) and not has_corresponding_torch_dtype(input.dtype) for input in (x, y)
3826        ):
3827            def to_list(input):
3828                return input.tolist() if isinstance(input, (torch.Tensor, np.ndarray)) else list(input)
3829
3830            x = to_list(x)
3831            y = to_list(y)
        # When comparing a sequence of numbers to a tensor, we need to convert the sequence to a tensor here.
        # Otherwise, pair origination in `not_close_error_metas` below will fail, because the sequence is
        # recognized as a container that should be checked elementwise while the tensor is not.
3835        elif isinstance(x, torch.Tensor) and isinstance(y, Sequence):
3836            y = torch.as_tensor(y, dtype=x.dtype, device=x.device)
3837        elif isinstance(x, Sequence) and isinstance(y, torch.Tensor):
3838            x = torch.as_tensor(x, dtype=y.dtype, device=y.device)
3839
3840        # unbind NSTs to compare them; don't do this for NJTs
3841        if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.strided:
3842            x = x.unbind()
3843        if isinstance(y, torch.Tensor) and y.is_nested and y.layout == torch.strided:
3844            y = y.unbind()
3845
3846        error_metas = not_close_error_metas(
3847            x,
3848            y,
3849            pair_types=(
3850                NonePair,
3851                RelaxedBooleanPair,
3852                RelaxedNumberPair,
3853                TensorOrArrayPair,
3854                TypedStoragePair,
3855                StringPair,
3856                SetPair,
3857                TypePair,
3858                ObjectPair,
3859            ),
3860            sequence_types=(
3861                Sequence,
3862                Sequential,
3863                ModuleList,
3864                ParameterList,
3865                ScriptList,
3866                torch.utils.data.dataset.Subset,
3867            ),
3868            mapping_types=(Mapping, ModuleDict, ParameterDict, ScriptDict),
3869            rtol=rtol,
3870            rtol_override=self.rel_tol,
3871            atol=atol,
3872            atol_override=self.precision,
3873            equal_nan=equal_nan,
3874            check_device=exact_device,
3875            check_dtype=exact_dtype,
3876            check_layout=exact_layout,
3877            check_stride=exact_stride,
3878            check_is_coalesced=exact_is_coalesced,
3879        )
3880
3881        if error_metas:
3882            # See [ErrorMeta Cycles]
3883            error_metas = [error_metas]
3884            # TODO: compose all metas into one AssertionError
3885            raise error_metas.pop()[0].to_error(
3886                # This emulates unittest.TestCase's behavior if a custom message passed and
3887                # TestCase.longMessage (https://docs.python.org/3/library/unittest.html#unittest.TestCase.longMessage)
3888                # is True (default)
3889                (lambda generated_msg: f"{generated_msg}\n{msg}") if isinstance(msg, str) and self.longMessage else msg
3890            )
3891
3892    def assertNotEqual(self, x, y, msg: Optional[str] = None, *,                                       # type: ignore[override]
3893                       atol: Optional[float] = None, rtol: Optional[float] = None, **kwargs) -> None:
3894        with self.assertRaises(AssertionError, msg=msg):
3895            self.assertEqual(x, y, msg, atol=atol, rtol=rtol, **kwargs)
3896
3897    def assertEqualTypeString(self, x, y) -> None:
        # This API is used to simulate the deprecated check x.type() == y.type()
3899        self.assertEqual(x.device, y.device)
3900        self.assertEqual(x.dtype, y.dtype)
3901        self.assertEqual(x.is_sparse, y.is_sparse)
3902
3903    def assertObjectIn(self, obj: Any, iterable: Iterable[Any]) -> None:
3904        for elem in iterable:
3905            if id(obj) == id(elem):
3906                return
3907        raise AssertionError("object not found in iterable")
3908
3909    # Reimplemented to provide special behavior when
3910    # _ignore_not_implemented_error is True
3911    def assertRaises(self, expected_exception, *args, **kwargs):
3912        if self._ignore_not_implemented_error:
3913            context: Optional[AssertRaisesContextIgnoreNotImplementedError] = \
3914                AssertRaisesContextIgnoreNotImplementedError(expected_exception, self)  # type: ignore[call-arg]
3915            try:
3916                return context.handle('assertRaises', args, kwargs)  # type: ignore[union-attr]
3917            finally:
3918                # see https://bugs.python.org/issue23890
3919                context = None
3920        else:
3921            return super().assertRaises(expected_exception, *args, **kwargs)
3922
3923    # Reimplemented to provide special behavior when
3924    # _ignore_not_implemented_error is True
3925    def assertRaisesRegex(self, expected_exception, expected_regex, *args, **kwargs):
3926        # Verifies that an exception with the type expected_exception and message
3927        # matching the regular expression defined by expected_regex is thrown.
3928        # If the test is instantiated for a non-native device type (like XLA)
3929        # then the message is not validated.
3930
3931        # Checks whether the test is instantiated for a device type by testing
3932        # if the test class has defined the device_type attribute and,
3933        # if so, tests whether the instantiated device type is native or not
3934        if hasattr(self, 'device_type') and self.device_type not in NATIVE_DEVICES and self.device_type != "mps":  # type: ignore[attr-defined]
3935            # empty string matches any string
3936            expected_regex = ''
3937
3938        if self._ignore_not_implemented_error:
3939            context = AssertRaisesContextIgnoreNotImplementedError(  # type: ignore[call-arg]
3940                expected_exception, self, expected_regex)
3941            return context.handle('assertRaisesRegex', args, kwargs)  # type: ignore[attr-defined]
3942        else:
3943            return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
3944
3945    # Verifies that no unraisable exceptions are raised by callable.  Unlike regular
3946    # exceptions, these do not actually propagate to the caller and are
3947    # suppressed.  We must test for them specially.
3948    def assertNoUnraisable(self, callable, *args, **kwargs):
3949        raised = None
3950
3951        def record_unraisable(unraisable):
3952            nonlocal raised
3953            raised = unraisable
3954
3955        # Disable GC when running the callable to prevent spurious flakiness
3956        # from unlucky GCs inside the callable
3957        prev = gc.isenabled()
3958        gc.disable()
3959        try:
3960            with unittest.mock.patch("sys.unraisablehook", record_unraisable):
3961                callable(*args, **kwargs)
3962        finally:
3963            if prev:
3964                gc.enable()
3965
3966        self.assertIsNone(raised)
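
    # Illustrative sketch (hypothetical helper, not used by the suite): exceptions escaping
    # __del__ never propagate to the caller, so they can only be observed via sys.unraisablehook.
    def _assert_no_unraisable_example(self):
        class _QuietDeleter:
            def __del__(self):
                pass

        self.assertNoUnraisable(lambda: _QuietDeleter())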
3967
3968    # TODO: Support context manager interface
3969    # NB: The kwargs forwarding to callable robs the 'subname' parameter.
3970    # If you need it, manually apply your callable in a lambda instead.
3971    def assertExpectedRaises(self, exc_type, callable, *args, **kwargs):
3972        subname = None
3973        if 'subname' in kwargs:
3974            subname = kwargs['subname']
3975            del kwargs['subname']
3976        try:
3977            callable(*args, **kwargs)
3978        except exc_type as e:
3979            self.assertExpected(str(e), subname)
3980            return
3981        # Don't put this in the try block; the AssertionError will catch it
3982        self.fail(msg="Did not raise when expected to")
3983
3984    def assertNotWarn(self, callable, msg=''):
3985        r"""
3986        Test if :attr:`callable` does not raise a warning.
3987        """
3988        with warnings.catch_warnings(record=True) as ws:
3989            warnings.simplefilter("always")  # allow any warning to be raised
3990            with set_warn_always_context(True):
3991                callable()
3992            self.assertTrue(len(ws) == 0, msg)
3993
3994    @contextmanager
3995    def assertWarnsOnceRegex(self, category, regex=''):
3996        """Context manager for code that *must always* warn
3997
3998        This filters expected warnings from the test and fails if
3999        the expected warning is not caught. It uses set_warn_always() to force
4000        TORCH_WARN_ONCE to behave like TORCH_WARN
4001        """
4002        pattern = re.compile(regex)
4003        with warnings.catch_warnings(record=True) as ws:
4004            warnings.simplefilter("always")  # allow any warning to be raised
4005            with set_warn_always_context(True):
4006                yield
4007            if len(ws) == 0:
4008                self.fail('no warning caught')
4009            self.assertTrue(any(type(w.message) is category for w in ws))
4010            self.assertTrue(
4011                any(re.match(pattern, str(w.message)) for w in ws),
4012                f'{pattern}, {[w.message for w in ws if type(w.message) is category]}')
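
    # Illustrative sketch (hypothetical helper, not used by the suite): the regex is matched from
    # the start of the warning message, and TORCH_WARN_ONCE warnings are forced to re-fire.
    def _assert_warns_once_regex_example(self):
        with self.assertWarnsOnceRegex(UserWarning, "deprecated"):
            warnings.warn("deprecated: this API will be removed", UserWarning)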
4013
4014    def assertExpected(self, s, subname=None):
4015        r"""
4016        Test that a string matches the recorded contents of a file
4017        derived from the name of this test and subname.  This file
4018        is placed in the 'expect' directory in the same directory
4019        as the test script. You can automatically update the recorded test
4020        output using --accept.
4021
4022        If you call this multiple times in a single function, you must
4023        give a unique subname each time.
4024        """
4025        if not isinstance(s, str):
4026            raise TypeError("assertExpected is strings only")
4027
4028        def remove_prefix(text, prefix):
4029            if text.startswith(prefix):
4030                return text[len(prefix):]
4031            return text
4032        # NB: we take __file__ from the module that defined the test
4033        # class, so we place the expect directory where the test script
4034        # lives, NOT where test/common_utils.py lives.  This doesn't matter in
4035        # PyTorch where all test scripts are in the same directory as
4036        # test/common_utils.py, but it matters in onnx-pytorch
4037        module_id = self.__class__.__module__
4038        munged_id = remove_prefix(self.id(), module_id + ".")
4039        test_file = os.path.realpath(sys.modules[module_id].__file__)
4040        expected_file = os.path.join(os.path.dirname(test_file),
4041                                     "expect",
4042                                     munged_id)
4043
4044        subname_output = ""
4045        if subname:
4046            expected_file += "-" + subname
4047            subname_output = f" ({subname})"
4048        expected_file += ".expect"
4049        expected = None
4050
4051        def accept_output(update_type):
4052            print(f"Accepting {update_type} for {munged_id}{subname_output}:\n\n{s}")
4053            with open(expected_file, 'w') as f:
4054                # Adjust for producer_version, leave s unmodified
4055                s_tag = re.sub(r'(producer_version): "[0-9.]*"',
4056                               r'\1: "CURRENT_VERSION"', s)
4057                f.write(s_tag)
4058
4059        try:
4060            with open(expected_file) as f:
4061                expected = f.read()
4062        except OSError as e:
4063            if e.errno != errno.ENOENT:
4064                raise
4065            elif expecttest.ACCEPT:
4066                return accept_output("output")
4067            else:
4068                raise RuntimeError(
4069                      f"I got this output for {munged_id}{subname_output}:\n\n{s}\n\n"
4070                      "No expect file exists; to accept the current output, run:\n"
4071                      f"python {__main__.__file__} {munged_id} --accept") from None
4072
4073        # a hack for JIT tests
4074        if IS_WINDOWS:
4075            expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected)
4076            s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s)
4077
4078        # Adjust for producer_version
4079        expected = expected.replace(
4080            'producer_version: "CURRENT_VERSION"',
4081            f'producer_version: "{torch.onnx.producer_version}"'
4082        )
4083        if expecttest.ACCEPT:
4084            if expected != s:
4085                return accept_output("updated output")
4086        else:
4087            if hasattr(self, "assertMultiLineEqual"):
                # assertMultiLineEqual (available since Python 2.7) gives nicer diffs for multi-line strings
4089                # NB: Python considers lhs "old" and rhs "new".
4090                self.assertMultiLineEqual(expected, s)
4091            else:
4092                self.assertEqual(s, expected)
4093
4094    def assertExpectedStripMangled(self, s, subname=None):
4095        s = re.sub(r'__torch__[^ ]+', '', s)
4096        self.assertExpected(s, subname)
4097
4098    def assertGreaterAlmostEqual(self, first, second, places=None, msg=None, delta=None):
4099        """Assert that ``first`` is greater than or almost equal to ``second``.
4100
4101        The equality of ``first`` and ``second`` is determined in a similar way to
4102        the ``assertAlmostEqual`` function of the standard library.
4103        """
4104        if delta is not None and places is not None:
4105            raise TypeError("specify delta or places not both")
4106
4107        if first >= second:
4108            return
4109
4110        diff = second - first
4111        if delta is not None:
4112            if diff <= delta:
4113                return
4114
4115            standardMsg = f"{first} not greater than or equal to {second} within {delta} delta"
4116        else:
4117            if places is None:
4118                places = 7
4119
4120            if round(diff, places) == 0:
4121                return
4122
4123            standardMsg = f"{first} not greater than or equal to {second} within {places} places"
4124
4125        msg = self._formatMessage(msg, standardMsg)
4126        raise self.failureException(msg)
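
    # Illustrative sketch (hypothetical helper, not used by the suite): either a delta or a number
    # of decimal places (but not both) bounds how far `first` may fall below `second`.
    def _assert_greater_almost_equal_example(self):
        self.assertGreaterAlmostEqual(1.0, 1.0005, delta=1e-3)
        self.assertGreaterAlmostEqual(2.0, 2.0000001, places=6)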
4127
4128    def assertAtenOp(self, onnx_model, operator, overload_name=""):
4129        all_aten_nodes = [p for p in onnx_model.graph.node
4130                          if p.op_type == "ATen" and p.domain == "org.pytorch.aten"]
4131        self.assertTrue(all_aten_nodes)
4132
4133        for op in all_aten_nodes:
4134            attrs = {attr.name: attr.s.decode() for attr in op.attribute}
4135            if attrs.get("operator") == operator:
4136                break
4137
4138        self.assertEqual(attrs["operator"], operator)
4139        self.assertEqual(attrs.get("overload_name", ""), overload_name)
4140
4141    def check_nondeterministic_alert(self, fn, caller_name, should_alert=True):
4142        '''Checks that an operation produces a nondeterministic alert when
4143        expected while `torch.use_deterministic_algorithms(True)` is set.
4144
4145        Args:
4146          fn (callable): Function to check for a nondeterministic alert
4147
4148          caller_name (str): Name of the operation that produces the
4149              nondeterministic alert. This name is expected to appear at the
4150              beginning of the error/warning message.
4151
4152          should_alert (bool, optional): If True, then the check will only pass
4153              if calling `fn` produces a nondeterministic error/warning with the
4154              expected message. If False, then the check will only pass if
4155              calling `fn` does not produce an error. Default: `True`.
4156        '''
4157
4158        alert_message = '^' + caller_name + ' does not have a deterministic implementation, but you set'
4159
4160        # Check that errors are thrown correctly
4161        with DeterministicGuard(True):
4162            if should_alert:
4163                with self.assertRaisesRegex(
4164                        RuntimeError,
4165                        alert_message,
4166                        msg='expected a non-deterministic error, but it was not raised'):
4167                    fn()
4168
4169            else:
4170                # If a nondeterministic error is not expected, make sure
4171                # that it is not raised
4172                try:
4173                    fn()
4174                except RuntimeError as e:
4175                    if 'does not have a deterministic implementation' in str(e):
4176                        self.fail(
4177                            'did not expect non-deterministic error message, '
4178                            + 'but got one anyway: "' + str(e) + '"')
4179                    # Reraise exceptions unrelated to nondeterminism
4180                    raise
4181
4182        # Check that warnings are thrown correctly
4183        with DeterministicGuard(True, warn_only=True):
4184            if should_alert:
4185                with self.assertWarnsRegex(
4186                        UserWarning,
4187                        alert_message):
4188                    fn()
4189            else:
4190                with warnings.catch_warnings(record=True) as w:
4191                    warnings.simplefilter("always")
4192                    fn()
                    for warning in w:
                        if isinstance(warning.message, UserWarning):
                            self.assertTrue(re.search(alert_message, str(warning.message)) is None)
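
    # Illustrative sketch (hypothetical helper, not used by the suite): a deterministic op is
    # expected to run cleanly under torch.use_deterministic_algorithms(True), so should_alert=False.
    def _check_nondeterministic_alert_example(self):
        x = torch.randn(3, 3)
        self.check_nondeterministic_alert(lambda: torch.add(x, x), 'add', should_alert=False)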
4196
    # Runs the given code in a subprocess and captures its stdout/stderr without raising on failure.
4198    @staticmethod
4199    def run_process_no_exception(code, env=None):
4200        import subprocess
4201
4202        popen = subprocess.Popen(
4203            [sys.executable, '-c', code],
4204            stdout=subprocess.PIPE,
4205            stderr=subprocess.PIPE,
4206            env=env)
4207        (stdout, stderr) = popen.communicate()
4208        return (stdout, stderr)
4209
4210    # returns captured stderr
4211    @staticmethod
4212    def runWithPytorchAPIUsageStderr(code):
4213        env = os.environ.copy()
4214        env["PYTORCH_API_USAGE_STDERR"] = "1"
4215        # remove CI flag since this is a wrapped test process.
4216        # CI flag should be set in the parent process only.
4217        env.pop("CI", None)
4218        env.pop("TEST_SHOWLOCALS", None)
4219        (stdout, stderr) = TestCase.run_process_no_exception(code, env=env)
4220        return stderr.decode('ascii')
4221
4222
4223class TestCaseBase(TestCase):
4224    # Calls to super() in dynamically created classes are a bit odd.
4225    # See https://github.com/pytorch/pytorch/pull/118586 for more info
4226    # Subclassing this class and then calling super(TestCaseBase) will run
4227    # TestCase's setUp, tearDown etc functions
4228    pass
4229
4230
4231def download_file(url, binary=True):
4232    from urllib.parse import urlsplit
4233    from urllib import request, error
4234
4235    filename = os.path.basename(urlsplit(url)[2])
4236    data_dir = get_writable_path(os.path.join(os.path.dirname(__file__), 'data'))
4237    path = os.path.join(data_dir, filename)
4238
4239    if os.path.exists(path):
4240        return path
4241    try:
4242        data = request.urlopen(url, timeout=15).read()
4243        with open(path, 'wb' if binary else 'w') as f:
4244            f.write(data)
4245        return path
4246    except error.URLError as e:
4247        msg = f"could not download test file '{url}'"
4248        warnings.warn(msg, RuntimeWarning)
4249        raise unittest.SkipTest(msg) from e
4250
4251def find_free_port():
4252    """
4253    Finds an available port and returns that port number.
4254
4255    NOTE: If this function is being used to allocate a port to Store (or
4256    indirectly via init_process_group or init_rpc), it should be used
    in conjunction with the `retry_on_connect_failures` decorator as there is a potential
4258    race condition where the allocated port may become unavailable before it can be used
4259    """
4260    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
4261        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
4262        sock.bind(('localhost', 0))
4263        _, port = sock.getsockname()
4264        return port
4265
4266# Errors that we can get in c10d initialization for which we should retry tests for.
4267ADDRESS_IN_USE = "Address already in use"
4268CONNECT_TIMEOUT = "connect() timed out."
4269
def retry_on_connect_failures(func=None, connect_errors=(ADDRESS_IN_USE,)):
4271    """Reruns a test if the test returns a RuntimeError and the exception
4272    contains one of the strings in connect_errors."""
4273    # This if block is executed when using this function as a decorator with arguments.
4274    if func is None:
4275        return partial(retry_on_connect_failures, connect_errors=connect_errors)
4276
4277    @wraps(func)
4278    def wrapper(*args, **kwargs):
4279        n_retries = 10
4280        tries_remaining = n_retries
4281        while True:
4282            try:
4283                return func(*args, **kwargs)
4284            except RuntimeError as error:
4285                if any(connect_error in str(error) for connect_error in connect_errors):
4286                    tries_remaining -= 1
4287                    if tries_remaining == 0:
4288                        raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}") from error
4289                    time.sleep(random.random())
4290                    continue
4291                raise
4292    return wrapper
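
# Illustrative sketch (hypothetical helper, not used by the suite): find_free_port and
# retry_on_connect_failures are typically combined, since the port returned by the former can be
# taken by another process before it is bound. Assumes torch.distributed is available in this build.
@retry_on_connect_failures
def _example_create_tcp_store():
    import torch.distributed as dist
    return dist.TCPStore("localhost", find_free_port(), 1, True)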
4293
4294
4295# Decorator to retry upon certain Exceptions.
4296def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
4297    def deco_retry(f):
4298        @wraps(f)
4299        def f_retry(*args, **kwargs):
4300            mtries, mdelay = tries, delay
4301            while mtries > 1:
4302                try:
4303                    return f(*args, **kwargs)
4304                except ExceptionToCheck as e:
4305                    msg = "%s, Retrying in %d seconds..." % (str(e), mdelay)
4306                    print(msg)
4307                    time.sleep(mdelay)
4308                    mtries -= 1
4309            try:
4310                return f(*args, **kwargs)
4311            except ExceptionToCheck as e:
4312                raise unittest.SkipTest(f"Skipping after {tries} consecutive {str(e)}") from e if skip_after_retries else e
4313        return f_retry  # true decorator
4314    return deco_retry
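
# Illustrative sketch (hypothetical usage, not part of the suite): re-run a flaky step a few
# times. The mutable default argument below merely keeps a call count across retries so the
# function fails once and then succeeds, exercising the retry loop.
@retry(RuntimeError, tries=3, delay=0)
def _example_eventually_succeeds(_calls=[]):
    _calls.append(None)
    if len(_calls) < 2:
        raise RuntimeError("transient failure")
    return len(_calls)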
4315
4316
4317# FIXME: modernize these to be consistent with make_tensor
4318#   and review including them in torch.testing
4319# Methods for matrix generation
4320
4321def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
4322    assert rank <= l
4323    A = torch.randn(l, l, dtype=dtype, device=device)
4324    u, s, vh = torch.linalg.svd(A, full_matrices=False)
4325    for i in range(l):
4326        if i >= rank:
4327            s[i] = 0
4328        elif s[i] == 0:
4329            s[i] = 1
4330    return (u * s.to(dtype).unsqueeze(-2)) @ vh
4331
4332def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001):
4333    """
4334    Returns a random rectangular matrix (batch of matrices)
4335    with singular values sampled from a Gaussian with
4336    mean `mean` and standard deviation `sigma`.
4337    The smaller the `sigma`, the better conditioned
4338    the output matrix is.
4339    """
4340    primitive_dtype = {
4341        torch.float: torch.float,
4342        torch.double: torch.double,
4343        torch.cfloat: torch.float,
4344        torch.cdouble: torch.double
4345    }
4346    x = torch.rand(shape, dtype=dtype, device=device)
4347    m = x.size(-2)
4348    n = x.size(-1)
4349    u, _, vh = torch.linalg.svd(x, full_matrices=False)
4350    s = (torch.randn(*(shape[:-2] + (min(m, n),)), dtype=primitive_dtype[dtype], device=device) * sigma + mean) \
4351        .sort(-1, descending=True).values.to(dtype)
4352    return (u * s.unsqueeze(-2)) @ vh
4353
# Returns a noncontiguous tensor with the same shape and values as t.
4355# The noncontiguous tensor is constructed such that elements in the innermost
4356#   dimension are separated by zeros or (whenever possible) nans
4357# TODO: consider more complicated noncontiguity schemes
4358def noncontiguous_like(t):
4359    # Short-circuits if t is already noncontiguous
4360    if not t.is_contiguous():
4361        return t
4362
4363    # Choose a "weird" value that won't be accessed
4364    if t.dtype.is_floating_point or t.dtype.is_complex:
4365        value = math.nan
4366    elif t.dtype == torch.bool:
4367        value = True
4368    else:
4369        value = 12
4370
4371    result = t.new_empty(t.shape + (2,))
4372    result[..., 0] = value
4373    result[..., 1] = t.detach()
4374    result = result[..., 1]
4375    result.requires_grad_(t.requires_grad)
4376    return result
4377
4378# TODO: remove this (prefer make_symmetric_matrices below)
4379def random_symmetric_matrix(l, *batches, **kwargs):
4380    dtype = kwargs.get('dtype', torch.double)
4381    device = kwargs.get('device', 'cpu')
4382    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
4383    A = (A + A.mT).div_(2)
4384    return A
4385
4386# Creates a symmetric matrix or batch of symmetric matrices
4387# Shape must be a square matrix or batch of square matrices
4388def make_symmetric_matrices(*shape, device, dtype):
4389    assert shape[-1] == shape[-2]
4390    t = make_tensor(shape, device=device, dtype=dtype)
4391    t = (t + t.mT).div_(2)
4392    return t
4393
4394def random_hermitian_matrix(l, *batches, **kwargs):
4395    dtype = kwargs.get('dtype', torch.double)
4396    device = kwargs.get('device', 'cpu')
4397    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
4398    A = (A + A.mH).div_(2)
4399    return A
4400
4401
4402def random_symmetric_psd_matrix(l, *batches, **kwargs):
4403    """
4404    Returns a batch of random symmetric positive-semi-definite matrices.
4405    The shape of the result is batch_dims + (matrix_size, matrix_size)
4406    The following example creates a tensor of size 2 x 4 x 3 x 3
4407    >>> # xdoctest: +SKIP("undefined variables")
4408    >>> matrices = random_symmetric_psd_matrix(3, 2, 4, dtype=dtype, device=device)
4409    """
4410    dtype = kwargs.get('dtype', torch.double)
4411    device = kwargs.get('device', 'cpu')
4412    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
4413    return A @ A.mT
4414
4415
4416def random_hermitian_psd_matrix(matrix_size, *batch_dims, dtype=torch.double, device='cpu'):
4417    """
4418    Returns a batch of random Hermitian positive-semi-definite matrices.
4419    The shape of the result is batch_dims + (matrix_size, matrix_size)
4420    The following example creates a tensor of size 2 x 4 x 3 x 3
4421    >>> # xdoctest: +SKIP("undefined variables")
4422    >>> matrices = random_hermitian_psd_matrix(3, 2, 4, dtype=dtype, device=device)
4423    """
4424    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), dtype=dtype, device=device)
4425    return A @ A.mH
4426
4427
4428# TODO: remove this (prefer make_symmetric_pd_matrices below)
4429def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs):
4430    dtype = kwargs.get('dtype', torch.double)
4431    device = kwargs.get('device', 'cpu')
4432    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
4433                    dtype=dtype, device=device)
4434    return torch.matmul(A, A.mT) \
4435        + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5
4436
4437
4438# Creates a symmetric positive-definite matrix or batch of
4439#   such matrices
4440def make_symmetric_pd_matrices(*shape, device, dtype):
4441    assert shape[-1] == shape[-2]
4442    t = make_tensor(shape, device=device, dtype=dtype)
4443    i = torch.eye(shape[-1], device=device, dtype=dtype) * 1e-5
4444    return t @ t.mT + i
4445
4446def random_hermitian_pd_matrix(matrix_size, *batch_dims, dtype, device):
4447    """
4448    Returns a batch of random Hermitian positive-definite matrices.
4449    The shape of the result is batch_dims + (matrix_size, matrix_size)
4450    The following example creates a tensor of size 2 x 4 x 3 x 3
4451    >>> # xdoctest: +SKIP("undefined variables")
4452    >>> matrices = random_hermitian_pd_matrix(3, 2, 4, dtype=dtype, device=device)
4453    """
4454    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
4455                    dtype=dtype, device=device)
4456    return A @ A.mH + torch.eye(matrix_size, dtype=dtype, device=device)
4457
4458# Creates a full rank matrix with distinct singular values or
4459#   a batch of such matrices
4460def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype, requires_grad=False):
4461    with torch.no_grad():
4462        t = make_tensor(shape, device=device, dtype=dtype)
4463        u, _, vh = torch.linalg.svd(t, full_matrices=False)
4464        real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
4465        k = min(shape[-1], shape[-2])
4466        # We choose the singular values to be "around one"
4467        # This is to make the matrix well conditioned
4468        # s = [2, 3, ..., k+1]
4469        s = torch.arange(2, k + 2, dtype=real_dtype, device=device)
4470        # s = [2, -3, 4, ..., (-1)^k k+1]
4471        s[1::2] *= -1.
4472        # 1 + 1/s so that the singular values are in the range [2/3, 3/2]
4473        # This gives a condition number of 9/4, which should be good enough
4474        s.reciprocal_().add_(1.)
4475        # Note that the singular values need not be ordered in an SVD so
        # we don't need to sort S
4477        x = (u * s.to(u.dtype)) @ vh
4478    x.requires_grad_(requires_grad)
4479    return x
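
# Illustrative sketch (hypothetical usage, not part of the suite): the generated matrices are
# invertible with a modest condition number, so direct solves on them are numerically stable.
def _example_fullrank_solve():
    a = make_fullrank_matrices_with_distinct_singular_values(4, 4, device='cpu', dtype=torch.float64)
    b = torch.randn(4, 1, dtype=torch.float64)
    return torch.linalg.solve(a, b)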
4480
4481def random_matrix(rows, columns, *batch_dims, **kwargs):
4482    """Return rectangular matrix or batches of rectangular matrices.
4483
4484    Parameters:
4485      dtype - the data type
4486      device - the device kind
4487      singular - when True, the output will be singular
4488    """
4489    dtype = kwargs.get('dtype', torch.double)
4490    device = kwargs.get('device', 'cpu')
4491    silent = kwargs.get("silent", False)
4492    singular = kwargs.get("singular", False)
4493    if silent and not torch._C.has_lapack:
4494        return torch.ones(rows, columns, dtype=dtype, device=device)
4495
4496    A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device)
4497    if A.numel() == 0:
4498        return A
4499    u, _, vh = torch.linalg.svd(A, full_matrices=False)
4500    k = min(rows, columns)
4501    s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device)
4502    if singular:
4503        # make matrix singular
4504        s[k - 1] = 0
4505        if k > 2:
4506            # increase the order of singularity so that the pivoting
4507            # in LU factorization will be non-trivial
4508            s[0] = 0
4509    return (u * s.unsqueeze(-2)) @ vh
4510
4511
4512def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs):
4513    """Return rectangular matrix or batches of rectangular matrices with
4514    given rank.
4515    """
4516    B = random_matrix(rows, rank, *batch_dims, **kwargs)
4517    C = random_matrix(rank, columns, *batch_dims, **kwargs)
4518    return B.matmul(C)
4519
4520
4521def _generate_indices_prefer_all_rows(rows: int, cols: int, num_indices: int) -> torch.Tensor:
4522    """Generate indices for a row x cols matrix, preferring at least one index per row if possible."""
4523    indices = []
4524    n_per_row = math.ceil(num_indices / rows)
4525    col_indices = list(range(cols))
4526
4527    for r in range(rows):
4528        # Note that this can yield overlapping indices
4529        for c in random.choices(col_indices, k=n_per_row):
4530            indices.append((r, c))
4531
4532    return torch.tensor(indices[:num_indices])
4533
4534
4535def random_sparse_matrix(rows, columns, density=0.01, **kwargs):
4536    """Return rectangular random sparse matrix within given density.
4537
    The density of the result approaches the given density as the matrix
    size increases, provided the requested density is relatively small but
    still higher than min(rows, columns) / (rows * columns), which is
    required for non-singular matrices.
4542    """
4543    dtype = kwargs.get('dtype', torch.double)
4544    device = kwargs.get('device', 'cpu')
4545
4546    nonzero_elements = max(min(rows, columns), int(rows * columns * density))
4547    indices = _generate_indices_prefer_all_rows(rows, columns, nonzero_elements)
4548    values = torch.randn(nonzero_elements, dtype=dtype, device=device)
4549
4550    # ensure that the diagonal dominates
4551    values *= torch.tensor([-float(i - j)**2 for i, j in indices], dtype=dtype, device=device).exp()
4552    A = torch.sparse_coo_tensor(indices.t(), values, (rows, columns), device=device)
4553    return A.coalesce()
4554
4555
4556def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
4557    """Return random sparse positive-definite matrix with given density.
4558
4559    The eigenvalues of the matrix are defined as::
4560      arange(1, matrix_size+1)/matrix_size
4561
4562    Algorithm:
4563      A = diag(arange(1, matrix_size+1)/matrix_size)
4564      while <A density is smaller than required>:
4565          <choose random i, j in range(matrix_size), theta in [0, 2*pi]>
4566          R = <rotation matrix (i,j,theta)>
4567          A = R^T A R
4568    """
4569    import math
4570    torch = kwargs.get('torch', globals()['torch'])
4571    dtype = kwargs.get('dtype', torch.double)
4572    device = kwargs.get('device', 'cpu')
4573    data = {(i, i): float(i + 1) / matrix_size
4574            for i in range(matrix_size)}
4575
4576
4577    def multiply(data, N, i, j, cs, sn, left=True):
4578        for k in range(N):
4579            if left:
4580                ik, jk = (k, i), (k, j)
4581            else:
4582                ik, jk = (i, k), (j, k)
4583            aik, ajk = data.get(ik, 0), data.get(jk, 0)
4584            aik, ajk = cs * aik + sn * ajk, -sn * aik + cs * ajk
4585            if aik:
4586                data[ik] = aik
4587            else:
4588                data.pop(ik, None)
4589            if ajk:
4590                data[jk] = ajk
4591            else:
4592                data.pop(jk, None)
4593
4594    target_nnz = density * matrix_size * matrix_size
4595    while len(data) < target_nnz:
4596        i = random.randint(0, matrix_size - 1)
4597        j = random.randint(0, matrix_size - 1)
4598        if i != j:
4599            theta = random.uniform(0, 2 * math.pi)
4600            cs = math.cos(theta)
4601            sn = math.sin(theta)
4602            multiply(data, matrix_size, i, j, cs, sn, left=True)
4603            multiply(data, matrix_size, i, j, cs, sn, left=False)
4604    icoords, jcoords, values = [], [], []
4605    for (i, j), v in sorted(data.items()):
4606        icoords.append(i)
4607        jcoords.append(j)
4608        values.append(v)
4609    indices_tensor = torch.tensor([icoords, jcoords])
4610    return torch.sparse_coo_tensor(indices_tensor, values, (matrix_size, matrix_size), dtype=dtype, device=device)
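
# Illustrative sketch (hypothetical usage, not part of the suite): the eigenvalues are fixed to
# arange(1, n + 1) / n, so Cholesky on the densified matrix is expected to succeed.
def _example_sparse_pd_cholesky():
    a = random_sparse_pd_matrix(8, density=0.2)
    return torch.linalg.cholesky(a.to_dense())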
4611
4612# FIXME: remove this by updating test suites using it
4613def do_test_dtypes(self, dtypes, layout, device):
4614    for dtype in dtypes:
4615        if dtype != torch.float16:
4616            out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device)
4617            self.assertIs(dtype, out.dtype)
4618            self.assertIs(layout, out.layout)
4619            self.assertEqual(device, out.device)
4620
4621# FIXME: remove this by updating test suites using it
4622def do_test_empty_full(self, dtypes, layout, device):
4623    shape = torch.Size([2, 3])
4624
4625    def check_value(tensor, dtype, layout, device, value, requires_grad):
4626        self.assertEqual(shape, tensor.shape)
4627        self.assertIs(dtype, tensor.dtype)
4628        self.assertIs(layout, tensor.layout)
4629        self.assertEqual(tensor.requires_grad, requires_grad)
4630        if tensor.is_cuda and device is not None:
4631            self.assertEqual(device, tensor.device)
4632        if value is not None:
4633            fill = tensor.new(shape).fill_(value)
4634            self.assertEqual(tensor, fill)
4635
4636    def get_int64_dtype(dtype):
4637        module = '.'.join(str(dtype).split('.')[1:-1])
4638        if not module:
4639            return torch.int64
4640        return operator.attrgetter(module)(torch).int64
4641
4642    default_dtype = torch.get_default_dtype()
4643    check_value(torch.empty(shape), default_dtype, torch.strided, -1, None, False)
4644    check_value(torch.full(shape, -5.), default_dtype, torch.strided, -1, None, False)
4645    for dtype in dtypes:
4646        for rg in {dtype.is_floating_point, False}:
4647            int64_dtype = get_int64_dtype(dtype)
4648            v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
4649            check_value(v, dtype, layout, device, None, rg)
4650            out = v.new()
4651            check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg),
4652                        dtype, layout, device, None, rg)
4653            check_value(v.new_empty(shape), dtype, layout, device, None, False)
4654            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
4655                        int64_dtype, layout, device, None, False)
4656            check_value(torch.empty_like(v), dtype, layout, device, None, False)
4657            check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
4658                        int64_dtype, layout, device, None, False)
4659
4660            if dtype is not torch.float16 and layout != torch.sparse_coo:
4661                fv = 3
4662                v = torch.full(shape, fv, dtype=dtype, layout=layout, device=device, requires_grad=rg)
4663                check_value(v, dtype, layout, device, fv, rg)
4664                check_value(v.new_full(shape, fv + 1), dtype, layout, device, fv + 1, False)
4665                out = v.new()
4666                check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg),
4667                            dtype, layout, device, fv + 2, rg)
4668                check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=False),
4669                            int64_dtype, layout, device, fv + 3, False)
4670                check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False)
4671                check_value(torch.full_like(v, fv + 5,
4672                                            dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
4673                            int64_dtype, layout, device, fv + 5, False)
4674
4675# FIXME: improve load_tests() documentation here
4676running_script_path = None
4677def set_running_script_path():
4678    global running_script_path
4679    try:
4680        running_file = os.path.abspath(os.path.realpath(sys.argv[0]))
4681        if running_file.endswith('.py'):  # skip if the running file is not a script
4682            running_script_path = running_file
4683    except Exception:
4684        pass
4685
4686def check_test_defined_in_running_script(test_case):
4687    if running_script_path is None:
4688        return
4689    test_case_class_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__)))
4690    assert test_case_class_file == running_script_path, f'Class of loaded TestCase "{test_case.id()}" ' \
4691        f'is not defined in the running script "{running_script_path}", but in "{test_case_class_file}". Did you ' \
4692        "accidentally import a unittest.TestCase from another file?"
4693
4694def load_tests(loader, tests, pattern):
4695    set_running_script_path()
4696    test_suite = unittest.TestSuite()
4697    for test_group in tests:
4698        if not DISABLE_RUNNING_SCRIPT_CHK:
4699            for test in test_group:
4700                check_test_defined_in_running_script(test)
4701        if test_group._tests:
4702            test_suite.addTest(test_group)
4703    return test_suite
4704
4705# FIXME: document this and move it to test_serialization
4706class BytesIOContext(io.BytesIO):
4707    def __enter__(self):
4708        return self
4709
4710    def __exit__(self, *args):
4711        pass
4712
4713# Tentative value for nondet_tol for gradcheck when backward implementation
4714# relies on nondeterministic operations, i.e., those listed here:
4715# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
4716#
4717# For more information see https://github.com/pytorch/pytorch/issues/56202
4718GRADCHECK_NONDET_TOL = 1e-12
4719
4720TEST_WITH_SLOW_GRADCHECK: bool = TestEnvironment.def_flag(
4721    "TEST_WITH_SLOW_GRADCHECK",
4722    env_var="PYTORCH_TEST_WITH_SLOW_GRADCHECK",
4723)
4724
4725skipIfSlowGradcheckEnv = unittest.skipIf(
4726    TEST_WITH_SLOW_GRADCHECK,
4727    "Tests that don't use gradcheck don't need to run on slow_gradcheck CI",
4728)
4729
4730
4731def gradcheck(fn, inputs, **kwargs):
4732    # Wrapper around gradcheck that enables certain keys by default.
4733    # Use this testing-internal gradcheck instead of autograd.gradcheck so that new features like vmap and
    # forward-mode AD are tested by default. We create this wrapper because we'd like to keep new checks
    # disabled by default in the public-facing API to avoid breaking user code.
4736    #
4737    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradcheck.
4738    default_values = {
4739        "check_batched_grad": True,
4740        "fast_mode": True,
4741    }
4742
4743    if TEST_WITH_SLOW_GRADCHECK:
4744        default_values["fast_mode"] = False
4745
4746    for key, value in default_values.items():
        # the default value overrides values explicitly set to None
4748        k = kwargs.get(key, None)
4749        kwargs[key] = k if k is not None else value
4750
4751    return torch.autograd.gradcheck(fn, inputs, **kwargs)
4752
4753def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):
4754    # Wrapper around gradgradcheck that enables certain keys by default
4755    # See gradcheck above for an explanation of why we need something like this.
4756    #
4757    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradgradcheck
4758    default_values = {
4759        "check_batched_grad": True,
4760        "fast_mode": True,
4761    }
4762
4763    if TEST_WITH_SLOW_GRADCHECK:
4764        default_values["fast_mode"] = False
4765
4766    for key, value in default_values.items():
        # the default value overrides values explicitly set to None
4768        k = kwargs.get(key, None)
4769        kwargs[key] = k if k is not None else value
4770
4771    return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)
4772
4773
4774def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, **kwargs):
4775    # call assert function rather than returning a bool since it's nicer
4776    # if we get whether this failed on the gradcheck or the gradgradcheck.
4777    test_case.assertTrue(gradcheck(apply_fn, inputs, **kwargs))
4778    test_case.assertTrue(gradgradcheck(apply_fn, inputs, **kwargs))
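
# Illustrative sketch (hypothetical usage, not part of the suite): double-precision inputs with
# requires_grad=True are checked through both wrappers, which enable fast mode and the
# batched-grad checks by default.
def _example_grad_checks():
    x = torch.randn(3, dtype=torch.double, requires_grad=True)
    return gradcheck(torch.sin, (x,)) and gradgradcheck(torch.sin, (x,))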
4779
4780
4781@contextmanager
4782def set_cwd(path: str) -> Iterator[None]:
4783    old_cwd = os.getcwd()
4784    try:
4785        os.chdir(path)
4786        yield
4787    finally:
4788        os.chdir(old_cwd)
4789
4790
4791# FIXME: delete this
4792# Using @toleranceOverride specific to your test is the recommended way
4793# of doing this. These are just some values that worked for test_nn.
4794dtype2prec_DONTUSE = {torch.float: 1e-5,
4795                      torch.double: 1e-5,
4796                      torch.half: 1e-2,
4797                      torch.bfloat16: 1e-1}
4798
4799# FIXME: move to test_sparse or sparse utils
# This decorator wraps a test so that it runs twice: once with coalesced=True and
# once with coalesced=False, covering coalesced and uncoalesced sparse tensors.
4802def coalescedonoff(f):
4803    @wraps(f)
4804    def wrapped(self, *args, **kwargs):
4805        f(self, *args, **kwargs, coalesced=True)
4806        f(self, *args, **kwargs, coalesced=False)
4807    return wrapped
4808
4809
4810def is_coalesced_indices(s):
4811    indices = s._indices()
4812    hash_coeffs = (1,) + s.shape[s.sparse_dim() - 1:0:-1]
4813    hash_indices = torch.tensor(hash_coeffs, device=s.device).cumprod(-1).flip(-1)
4814    if s.sparse_dim() > 1:
4815        hash_indices.unsqueeze_(-1)
4816        hash_indices = (indices * hash_indices).sum(0)
4817    else:
4818        hash_indices = indices * hash_indices
4819
4820    # check if indices are sorted
4821    res = torch.allclose(hash_indices, hash_indices.sort()[0])
4822
4823    # check if there are no repeated indices
4824    res = res and torch.allclose(hash_indices, hash_indices.unique())
4825
4826    return res
4827
4828
4829@contextlib.contextmanager
4830def disable_gc():
4831    if gc.isenabled():
4832        try:
4833            gc.disable()
4834            yield
4835        finally:
4836            gc.enable()
4837    else:
4838        yield
4839
4840
4841def find_library_location(lib_name: str) -> Path:
    # Return the shared library file in the installed folder if it exists,
    # otherwise the file in the build folder.
4844    torch_root = Path(torch.__file__).resolve().parent
4845    path = torch_root / 'lib' / lib_name
4846    if os.path.exists(path):
4847        return path
4848    torch_root = Path(__file__).resolve().parent.parent.parent
4849    return torch_root / 'build' / 'lib' / lib_name
4850
4851def skip_but_pass_in_sandcastle(reason):
4852    """
4853    Similar to unittest.skip, however in the sandcastle environment it just
4854    "passes" the test instead to avoid creating tasks complaining about tests
4855    skipping continuously.
4856    """
4857    def decorator(func):
4858        if not IS_SANDCASTLE:
4859            func.__unittest_skip__ = True
4860            func.__unittest_skip_why__ = reason
4861            return func
4862
4863        @wraps(func)
4864        def wrapper(*args, **kwargs):
4865            print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
4866            return
4867        return wrapper
4868
4869    return decorator
4870
4871def mock_wrapper(method):
4872    """
4873    Returns a function that calls the real implementation of a method
4874    in addition to passing args to a mock object.
4875    """
4876    mock = MagicMock()
4877
4878    @wraps(method)
4879    def wrapper(self, *args, **kwargs):
4880        mock(*args, **kwargs)
4881        return method(self, *args, **kwargs)
4882    wrapper.mock = mock  # type: ignore[attr-defined]
4883    return wrapper
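
# Illustrative sketch (hypothetical usage, not part of the suite): the attached MagicMock records
# call arguments while the real method still executes, so both the result and the call pattern can
# be asserted on.
def _example_mock_wrapper():
    class _Counter:
        @mock_wrapper
        def bump(self, n):
            return n + 1

    counter = _Counter()
    result = counter.bump(1)                    # real method runs, returns 2
    calls = _Counter.bump.mock.call_args_list   # mock recorded call(1)
    return result, calls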
4884
4885def get_tensors_from(args, kwargs):
4886    """ Returns a set of all Tensor objects in the given args and kwargs. """
4887    return set([arg for arg in args if isinstance(arg, Tensor)] +
4888               [v for v in kwargs.values() if isinstance(v, Tensor)])
4889
4890
4891# Returns scalar tensor representation of a list of integer byte values
4892def bytes_to_scalar(byte_list: List[int], dtype: torch.dtype, device: torch.device):
4893    dtype_to_ctype: Dict[torch.dtype, Any] = {
4894        torch.int8: ctypes.c_int8,
4895        torch.uint8: ctypes.c_uint8,
4896        torch.uint16: ctypes.c_uint16,
4897        torch.uint32: ctypes.c_uint32,
4898        torch.uint64: ctypes.c_uint64,
4899        torch.int16: ctypes.c_int16,
4900        torch.int32: ctypes.c_int32,
4901        torch.int64: ctypes.c_int64,
4902        torch.bool: ctypes.c_bool,
4903        torch.float32: ctypes.c_float,
4904        torch.complex64: ctypes.c_float,
4905        torch.float64: ctypes.c_double,
4906        torch.complex128: ctypes.c_double,
4907    }
4908    ctype = dtype_to_ctype[dtype]
4909    num_bytes = ctypes.sizeof(ctype)
4910
4911    def check_bytes(byte_list):
4912        for byte in byte_list:
4913            assert 0 <= byte <= 255
4914
4915    if dtype.is_complex:
4916        assert len(byte_list) == (num_bytes * 2)
4917        check_bytes(byte_list)
4918        real = ctype.from_buffer((ctypes.c_byte * num_bytes)(
4919            *byte_list[:num_bytes])).value
4920        imag = ctype.from_buffer((ctypes.c_byte * num_bytes)(
4921            *byte_list[num_bytes:])).value
4922        res = real + 1j * imag
4923    else:
4924        assert len(byte_list) == num_bytes
4925        check_bytes(byte_list)
4926        res = ctype.from_buffer((ctypes.c_byte * num_bytes)(
4927            *byte_list)).value
4928
4929    return torch.tensor(res, device=device, dtype=dtype)
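
# Illustrative sketch (hypothetical usage, not part of the suite): on a little-endian host, the
# four bytes below are the IEEE-754 encoding of 1.0 reinterpreted as a float32 scalar tensor.
def _example_bytes_to_scalar():
    return bytes_to_scalar([0x00, 0x00, 0x80, 0x3F], torch.float32, torch.device('cpu'))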
4930
4931
4932def copy_func(f):
4933    """Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)"""
4934    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
4935                           argdefs=f.__defaults__,
4936                           closure=f.__closure__)
4937    g = functools.update_wrapper(g, f)
4938    g.__kwdefaults__ = f.__kwdefaults__
4939    return g
4940
4941
4942def xfail_inherited_tests(tests):
4943    """
4944    Given a list of test names which are defined by a superclass of the
4945    class this decorates, mark them as expected failure.  This is useful
4946    if you are doing poor man's parameterized tests by subclassing a generic
4947    test class.
4948    """
4949    def deco(cls):
4950        for t in tests:
4951            # NB: expectedFailure operates by mutating the method in question,
4952            # which is why you have to copy the function first
4953            setattr(cls, t, unittest.expectedFailure(copy_func(getattr(cls, t))))
4954        return cls
4955    return deco
4956
4957
4958def skip_but_pass_in_sandcastle_if(condition, reason):
4959    """
4960    Similar to unittest.skipIf, however in the sandcastle environment it just
4961    "passes" the test instead to avoid creating tasks complaining about tests
4962    skipping continuously.
4963    """
4964    def decorator(func):
4965        if condition:
4966            if IS_SANDCASTLE:
4967                @wraps(func)
4968                def wrapper(*args, **kwargs):
4969                    print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
4970                return wrapper
4971            else:
4972                func.__unittest_skip__ = True
4973                func.__unittest_skip_why__ = reason
4974
4975        return func
4976
4977    return decorator
4978
4979def dtype_name(dtype):
4980    """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """
4981    return str(dtype).split('.')[1]
4982
4983
4984dtype_abbrs = {
4985    torch.bfloat16: 'bf16',
4986    torch.float64: 'f64',
4987    torch.float32: 'f32',
4988    torch.float16: 'f16',
4989    torch.complex32: 'c32',
4990    torch.complex64: 'c64',
4991    torch.complex128: 'c128',
4992    torch.int8: 'i8',
4993    torch.int16: 'i16',
4994    torch.int32: 'i32',
4995    torch.int64: 'i64',
4996    torch.bool: 'b8',
4997    torch.uint8: 'u8',
4998}
4999
5000
5001@functools.lru_cache
5002def get_cycles_per_ms() -> float:
5003    """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
5004    """
5005
5006    def measure() -> float:
5007        start = torch.cuda.Event(enable_timing=True)
5008        end = torch.cuda.Event(enable_timing=True)
5009        start.record()
5010        torch.cuda._sleep(1000000)
5011        end.record()
5012        end.synchronize()
5013        cycles_per_ms = 1000000 / start.elapsed_time(end)
5014        return cycles_per_ms
5015
5016    # Get 10 values and remove the 2 max and 2 min and return the avg.
    # This is to avoid system disturbances that skew the results, e.g.
5018    # the very first cuda call likely does a bunch of init, which takes
5019    # much longer than subsequent calls.
5020    #
5021    # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs
5022    # and seems to return stable values. Therefore, we enable caching
5023    # using lru_cache decorator above.
5024    num = 10
5025    vals = []
5026    for _ in range(num):
5027        vals.append(measure())
5028    vals = sorted(vals)
5029    return mean(vals[2 : num - 2])
5030
5031
5032# OpInfo utils
5033
5034T = TypeVar('T')
5035def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T:
5036    """
5037    Returns the first sample from an iterable of samples, like those returned by OpInfo.
5038    The test will be skipped if no samples are available.
5039    """
5040    try:
5041        return next(iter(samples))
5042    except StopIteration as e:
5043        raise unittest.SkipTest('Skipped! Need at least 1 sample input') from e
5044
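# Example: a minimal sketch of taking only the first OpInfo sample inside a test
# method (device, dtype and op are assumed to be provided by the test harness):
#
#   sample = first_sample(self, op.sample_inputs(device, dtype))
#   result = op(sample.input, *sample.args, **sample.kwargs)
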
5045# This helper recursively clones the tensor-type inputs
5046# of operators tested by OpInfo.
5047def clone_input_helper(input):
5048    if isinstance(input, torch.Tensor):
5049        return torch.clone(input)
5050
5051    if isinstance(input, Sequence):
5052        return tuple(map(clone_input_helper, input))
5053
5054    return input
5055
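# Example: cloning the (possibly nested) tensor inputs of a SampleInput before
# an in-place variant mutates them (sample is a hypothetical SampleInput):
#
#   cloned_input = clone_input_helper(sample.input)
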
5056@contextmanager
5057def custom_op(opname, symbolic_fn, opset_version):
5058    """Context manager/decorator to test ONNX export with custom operator"""
5059    try:
5060        register_custom_op_symbolic(opname, symbolic_fn, opset_version)
5061        yield
5062    finally:
5063        unregister_custom_op_symbolic(opname, opset_version)
5064
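# Example: a minimal sketch of exporting a model that uses a custom op by
# temporarily registering its ONNX symbolic (all names below are hypothetical):
#
#   def my_symbolic(g, input):
#       return g.op("custom_domain::MyOp", input)
#
#   with custom_op("mynamespace::my_op", my_symbolic, opset_version=9):
#       torch.onnx.export(model, (x,), "model.onnx", opset_version=9)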
5065
5066def outs_and_grads(fn, graph_inps, inps):
5067    outs = fn(*graph_inps)
5068    for out in pytree.tree_leaves(outs):
5069        if isinstance(out, torch.Tensor) and out.requires_grad:
5070            out.sum().backward(retain_graph=True)
5071    grads = [inp.grad for inp in pytree.tree_leaves(inps) if isinstance(inp, torch.Tensor)]
5072    for inp in pytree.tree_leaves(inps):
5073        if isinstance(inp, torch.Tensor):
5074            inp.grad = None
5075    return outs, grads
5076
5077def compare_equal_outs_and_grads(test, m1, m2, inps):
5078    r1, g1 = outs_and_grads(m1, inps, inps)
5079    r2, g2 = outs_and_grads(m2, inps, inps)
5080    test.assertEqual(r1, r2)
5081    test.assertEqual(g1, g2)
5082
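# Example: a minimal sketch comparing an eager module against its compiled
# counterpart on the same inputs (model and x are hypothetical):
#
#   x = torch.randn(4, requires_grad=True)
#   compare_equal_outs_and_grads(self, model, torch.compile(model), (x,))
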
5083class TestGradients(TestCase):
5084    exact_dtype = True
5085
5086    # Copies inputs to inplace operations to avoid inplace modifications
5087    #   to leaves requiring gradient
5088    def _get_safe_inplace(self, inplace_variant):
5089        @wraps(inplace_variant)
5090        def _fn(t, *args, **kwargs):
5091            return inplace_variant(t.clone(), *args, **kwargs)
5092
5093        return _fn
5094
5095    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
5096                      check_batched_grad=None, check_batched_forward_grad=False):
5097        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
5098        # NB: check_backward_ad does not affect gradgradcheck (always True)
5099        if variant is None:
5100            self.skipTest("Skipped! Variant not implemented.")
5101        if not op.supports_dtype(dtype, torch.device(device).type):
5102            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")
5103
5104        def is_inplace(variant):
5105            if hasattr(variant, "__wrapped__"):
5106                return variant.__wrapped__ is op.get_inplace()
5107            return variant is op.get_inplace()
5108
5109        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
5110
5111        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
5112                                   small_inputs_only=TEST_WITH_SLOW_GRADCHECK)
5113
5114        for sample in samples:
5115            if sample.broadcasts_input and is_inplace(variant):
5116                continue
5117
5118            # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
5119            #   and tensors passed as kwargs. The following creates a function that accepts just
5120            #   the tensors that require grad as varargs, and then recomposes them back into the
5121            #   original input.
5122
5123            # Creates gradcheck inputs by identifying tensors requiring grad
5124            all_args = None
5125            if is_iterable_of_tensors(sample.input):
5126                all_args = chain(sample.input, sample.args, sample.kwargs.values())
5127            else:
5128                all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
5129            gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))
5130
5131            # Verifies that sample input tensors have no pre-existing grad.
5132            # A grad may be present if the same tensor is reused across different SampleInputs.
5133            for t in gradcheck_args:
5134                self.assertIsNone(t.grad,
5135                                  "A sampled input has a gradient before running autograd. "
5136                                  "This usually means that (at least) one input tensor is reused "
5137                                  "across different SampleInputs. "
5138                                  "Please create a new tensor for each SampleInput.")
5139
5140            def _input_recomposition_helper(inputs, inp, input_idx):
5141                if is_iterable_of_tensors(inp):
5142                    tensor_list = []
5143                    for x in inp:
5144                        if isinstance(x, torch.Tensor) and x.requires_grad:
5145                            tensor_list.append(inputs[input_idx])
5146                            input_idx = input_idx + 1
5147                        else:
5148                            tensor_list.append(x)
5149                    return tensor_list, input_idx
5150                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
5151                    return inputs[input_idx], input_idx + 1
5152                else:
5153                    return inp, input_idx
5154
5155            def fn(*inputs):
5156                # Puts inputs back into sample properly
5157                positional_args = []
5158                input_idx = 0
5159                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
5160                positional_args.append(inp)
5161
5162                for x in sample.args:
5163                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
5164                    positional_args.append(inp)
5165
5166                # Recreates kwargs
5167                kwargs = {}
5168                for k, v in sample.kwargs.items():
5169                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
5170                    kwargs[k] = inp
5171
5172                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
5173                if sample.output_process_fn_grad is not None:
5174                    return sample.output_process_fn_grad(output)
5175                return output
5176
5177            if check == 'gradcheck':
5178                if check_batched_grad is None:
5179                    check_batched_grad = op.check_batched_grad
5180                self.assertTrue(gradcheck(fn, gradcheck_args,
5181                                          check_batched_grad=check_batched_grad,
5182                                          check_grad_dtypes=True,
5183                                          nondet_tol=op.gradcheck_nondet_tol,
5184                                          fast_mode=op.gradcheck_fast_mode,
5185                                          check_forward_ad=check_forward_ad,
5186                                          check_backward_ad=check_backward_ad,
5187                                          check_undefined_grad=True,
5188                                          check_batched_forward_grad=check_batched_forward_grad))
5189            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
5190                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
5191                for gen_non_contig_grad_outputs in (False, True):
5192                    kwargs = {
5193                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
5194                        "check_batched_grad": op.check_batched_gradgrad,
5195                        "check_grad_dtypes": True,
5196                        "nondet_tol": op.gradcheck_nondet_tol,
5197                        "fast_mode": op.gradcheck_fast_mode
5198                    }
5199                    if check == "fwgrad_bwgrad":
5200                        kwargs["check_fwd_over_rev"] = True
5201                        kwargs["check_rev_over_rev"] = False
5202                        kwargs["check_batched_grad"] = False
5203                        kwargs["check_undefined_grad"] = False
5204
5205                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
5206            else:
5207                self.fail("Unknown check requested!")
5208
5209    def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
5210                          check_batched_grad=None, check_batched_forward_grad=False):
5211        return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
5212                                  check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
5213                                  check_batched_forward_grad=check_batched_forward_grad)
5214
5215    def _skip_helper(self, op, device, dtype):
5216        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
5217            self.skipTest("Skipped! Op doesn't support autograd for this dtype.")
5218        if not op.supports_autograd and not op.supports_forward_ad:
5219            self.skipTest("Skipped! autograd not supported.")
5220
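# Example: a minimal sketch of a concrete gradient test suite built on the
# helpers above (assumes the @ops decorator and op_db from
# torch.testing._internal.common_device_type / common_methods_invocations):
#
#   class TestBwdGradients(TestGradients):
#       @ops(op_db, allowed_dtypes=(torch.double,))
#       def test_fn_grad(self, device, dtype, op):
#           self._skip_helper(op, device, dtype)
#           self._grad_test_helper(device, dtype, op, op.get_op())
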
5221def make_lazy_class(cls):
5222
5223    def lazy_init(self, cb):
5224        self._cb = cb
5225        self._value = None
5226
5227    cls.__init__ = lazy_init
5228
5229    for basename in [
5230        "add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow",
5231        "lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert",
5232        "eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index",
5233    ]:
5234        name = f"__{basename}__"
5235
5236        def inner_wrapper(name):
5237            use_operator = basename not in ("bool", "int")
5238
5239            def wrapped(self, *args, **kwargs):
5240                if self._cb is not None:
5241                    self._value = self._cb()
5242                    self._cb = None
5243                if not use_operator:
5244                    return getattr(self._value, name)(*args, **kwargs)
5245                else:
5246                    return getattr(operator, name)(self._value, *args, **kwargs)
5247            return wrapped
5248
5249        setattr(cls, name, inner_wrapper(name))
5250
5251    return cls
5252
5253
5254# Base TestCase for NT tests; used to define common helpers, etc.
5255class NestedTensorTestCase(TestCase):
5256    def assertEqualIgnoringNestedInts(self, a, b):
5257        # Unbinding NJTs lets us compare their components for equality without
5258        # requiring exact nested int equality.
5259        def _unbind_njts(x):
5260            if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
5261                return x.unbind()
5262            else:
5263                return x
5264
5265        self.assertEqual(pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b))
5266
5267    @contextlib.contextmanager
5268    def branch_nested_state(self):
5269        """Context manager to branch and restore the nested tensor state."""
5270        nested_tensor_module = torch.nested._internal.nested_tensor
5271        original_tensor_symint_registry = nested_tensor_module._tensor_symint_registry.copy()
5272        original_tensor_id_counter = nested_tensor_module._tensor_id_counter
5273        try:
5274            yield
5275        finally:
5276            nested_tensor_module._tensor_id_counter = original_tensor_id_counter
5277            nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry
5278
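# Example: a minimal sketch of comparing two jagged NJTs while ignoring their
# nested ints (would run inside a NestedTensorTestCase subclass):
#
#   components = [torch.randn(2, 3), torch.randn(4, 3)]
#   a = torch.nested.nested_tensor(components, layout=torch.jagged)
#   b = torch.nested.nested_tensor(components, layout=torch.jagged)
#   self.assertEqualIgnoringNestedInts(a, b)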
5279
5280@make_lazy_class
5281class LazyVal:
5282    pass
5283
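# Example: LazyVal defers a computation until the value is first used; the
# callback runs at most once and the result is cached (expensive_check is a
# hypothetical function):
#
#   lazy_flag = LazyVal(lambda: expensive_check())
#   if lazy_flag:  # expensive_check() is evaluated here via __bool__
#       ...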
5284
5285def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=0):
5286    if file is None:
5287        file = inspect.stack()[1 + skip].filename  # skip one frame
5288
5289    file = _as_posix_path(file)
5290    s = _as_posix_path(str(e))
5291
5292    # Remove everything that looks like a stack frame from files other than this one
5293    def repl_frame(m):
5294        if m.group(1) != file:
5295            return ""
5296        # Don't accept top-level (<module>) frames, even for this file; these
5297        # wobble depending on how the testing script was invoked
5298        if m.group(2) == "<module>":
5299            return ""
5300
5301        return m.group(0)
5302
5303    s = re.sub(r'  File "([^"]+)", line \d+, in (.+)\n(    .+\n( +[~^]+ *\n)?)+', repl_frame, s)
5304    s = re.sub(r"line \d+", "line N", s)
5305    s = re.sub(r".py:\d+", ".py:N", s)
5306    s = re.sub(file, _as_posix_path(os.path.basename(file)), s)
5307    s = re.sub(_as_posix_path(os.path.join(os.path.dirname(torch.__file__), "")), "", s)
5308    if suppress_suffix:
5309        s = re.sub(r"\n*Set TORCH_LOGS.+", "", s, flags=re.DOTALL)
5310        s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL)
5311    if suppress_prefix:
5312        s = re.sub(r"Cannot export model.+\n\n", "", s)
5313    s = re.sub(r" +$", "", s, flags=re.MULTILINE)
5314    return s
5315
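# Example: a minimal sketch of normalizing an exception message so it stays
# stable across machines, paths, and line-number changes (the raising code is
# hypothetical):
#
#   try:
#       raise RuntimeError(f"failed at {__file__}:42")
#   except RuntimeError as e:
#       normalized = munge_exc(e)  # line numbers become "N", paths are trimmed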