xref: /aosp_15_r20/external/pytorch/test/test_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport copy
5*da0073e9SAndroid Build Coastguard Workerimport inspect
6*da0073e9SAndroid Build Coastguard Workerimport itertools
7*da0073e9SAndroid Build Coastguard Workerimport os
8*da0073e9SAndroid Build Coastguard Workerimport re
9*da0073e9SAndroid Build Coastguard Workerimport unittest
10*da0073e9SAndroid Build Coastguard Workerimport warnings
11*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict
12*da0073e9SAndroid Build Coastguard Workerfrom collections.abc import Sequence
13*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
14*da0073e9SAndroid Build Coastguard Workerfrom importlib import import_module
15*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, List
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerimport torch
18*da0073e9SAndroid Build Coastguard Workerimport torch._prims as prims
19*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree
20*da0073e9SAndroid Build Coastguard Workerfrom torch._prims.context import TorchRefsMode
21*da0073e9SAndroid Build Coastguard Workerfrom torch._prims_common.wrappers import _maybe_remove_out_wrapper
22*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
23*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.fake_utils import outputs_alias_inputs
24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
25*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal import composite_compliance, opinfo
26*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
27*da0073e9SAndroid Build Coastguard Worker    deviceCountAtLeast,
28*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
29*da0073e9SAndroid Build Coastguard Worker    onlyCPU,
30*da0073e9SAndroid Build Coastguard Worker    onlyCUDA,
31*da0073e9SAndroid Build Coastguard Worker    onlyNativeDeviceTypesAnd,
32*da0073e9SAndroid Build Coastguard Worker    OpDTypes,
33*da0073e9SAndroid Build Coastguard Worker    ops,
34*da0073e9SAndroid Build Coastguard Worker    skipMeta,
35*da0073e9SAndroid Build Coastguard Worker)
36*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import (
37*da0073e9SAndroid Build Coastguard Worker    all_types_and_complex_and,
38*da0073e9SAndroid Build Coastguard Worker    floating_and_complex_types_and,
39*da0073e9SAndroid Build Coastguard Worker    integral_types_and,
40*da0073e9SAndroid Build Coastguard Worker)
41*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import (
42*da0073e9SAndroid Build Coastguard Worker    BinaryUfuncInfo,
43*da0073e9SAndroid Build Coastguard Worker    op_db,
44*da0073e9SAndroid Build Coastguard Worker    ops_and_refs,
45*da0073e9SAndroid Build Coastguard Worker    python_ref_db,
46*da0073e9SAndroid Build Coastguard Worker    ReductionOpInfo,
47*da0073e9SAndroid Build Coastguard Worker    ReductionPythonRefInfo,
48*da0073e9SAndroid Build Coastguard Worker    skip,
49*da0073e9SAndroid Build Coastguard Worker    skipOps,
50*da0073e9SAndroid Build Coastguard Worker    SpectralFuncInfo,
51*da0073e9SAndroid Build Coastguard Worker    UnaryUfuncInfo,
52*da0073e9SAndroid Build Coastguard Worker    xfail,
53*da0073e9SAndroid Build Coastguard Worker)
54*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
55*da0073e9SAndroid Build Coastguard Worker    clone_input_helper,
56*da0073e9SAndroid Build Coastguard Worker    first_sample,
57*da0073e9SAndroid Build Coastguard Worker    IS_CI,
58*da0073e9SAndroid Build Coastguard Worker    IS_FBCODE,
59*da0073e9SAndroid Build Coastguard Worker    is_iterable_of_tensors,
60*da0073e9SAndroid Build Coastguard Worker    IS_SANDCASTLE,
61*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
62*da0073e9SAndroid Build Coastguard Worker    noncontiguous_like,
63*da0073e9SAndroid Build Coastguard Worker    parametrize,
64*da0073e9SAndroid Build Coastguard Worker    run_tests,
65*da0073e9SAndroid Build Coastguard Worker    set_default_dtype,
66*da0073e9SAndroid Build Coastguard Worker    skipIfTorchInductor,
67*da0073e9SAndroid Build Coastguard Worker    slowTest,
68*da0073e9SAndroid Build Coastguard Worker    suppress_warnings,
69*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ASAN,
70*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ROCM,
71*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TORCHDYNAMO,
72*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TORCHINDUCTOR,
73*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_UBSAN,
74*da0073e9SAndroid Build Coastguard Worker    TestCase,
75*da0073e9SAndroid Build Coastguard Worker    unMarkDynamoStrictTest,
76*da0073e9SAndroid Build Coastguard Worker)
77*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import TorchDispatchMode
78*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Workerassert torch.get_default_dtype() == torch.float32
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker# variant testing is only done with torch.float and torch.cfloat to avoid
84*da0073e9SAndroid Build Coastguard Worker#   excessive test times and maximize signal to noise ratio
85*da0073e9SAndroid Build Coastguard Worker_variant_ops = partial(
86*da0073e9SAndroid Build Coastguard Worker    ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
87*da0073e9SAndroid Build Coastguard Worker)
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker# Get names of all the operators which have ref in their entry in OpInfo (testing infra)
90*da0073e9SAndroid Build Coastguard Worker#   except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py),
91*da0073e9SAndroid Build Coastguard Worker#   elementwise binary operators (separately implemented in test_binary_ufuncs.py),
92*da0073e9SAndroid Build Coastguard Worker#   reduction operations (separately impelemented in test_reductions.py),
93*da0073e9SAndroid Build Coastguard Worker#   and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py)
94*da0073e9SAndroid Build Coastguard Worker_ref_test_ops = tuple(
95*da0073e9SAndroid Build Coastguard Worker    filter(
96*da0073e9SAndroid Build Coastguard Worker        lambda op: not isinstance(
97*da0073e9SAndroid Build Coastguard Worker            op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
98*da0073e9SAndroid Build Coastguard Worker        )
99*da0073e9SAndroid Build Coastguard Worker        and op.ref is not None,
100*da0073e9SAndroid Build Coastguard Worker        op_db,
101*da0073e9SAndroid Build Coastguard Worker    )
102*da0073e9SAndroid Build Coastguard Worker)
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Workerdef reduction_dtype_filter(op):
106*da0073e9SAndroid Build Coastguard Worker    if (
107*da0073e9SAndroid Build Coastguard Worker        not isinstance(op, ReductionPythonRefInfo)
108*da0073e9SAndroid Build Coastguard Worker        or not op.supports_out
109*da0073e9SAndroid Build Coastguard Worker        or torch.int16 not in op.dtypes
110*da0073e9SAndroid Build Coastguard Worker    ):
111*da0073e9SAndroid Build Coastguard Worker        return False
112*da0073e9SAndroid Build Coastguard Worker    return "dtype" in inspect.getfullargspec(op.op).kwonlyargs
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker# Create a list of operators that are a subset of _ref_test_ops but don't have a
116*da0073e9SAndroid Build Coastguard Worker# numpy ref to compare them too, If both CPU and CUDA are compared to numpy
117*da0073e9SAndroid Build Coastguard Worker# then they do not need to be compared to each other
118*da0073e9SAndroid Build Coastguard Worker_ops_and_refs_with_no_numpy_ref = [op for op in ops_and_refs if op.ref is None]
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Workeraten = torch.ops.aten
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker# Tests that apply to all operators and aren't related to any particular
124*da0073e9SAndroid Build Coastguard Worker#   system
125*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest
126*da0073e9SAndroid Build Coastguard Workerclass TestCommon(TestCase):
127*da0073e9SAndroid Build Coastguard Worker    exact_dtype = True
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
130*da0073e9SAndroid Build Coastguard Worker    @classmethod
131*da0073e9SAndroid Build Coastguard Worker    def tearDownClass(cls):
132*da0073e9SAndroid Build Coastguard Worker        super().tearDownClass()
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker        if IS_CI:
135*da0073e9SAndroid Build Coastguard Worker            err_msg = (
136*da0073e9SAndroid Build Coastguard Worker                "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
137*da0073e9SAndroid Build Coastguard Worker                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
138*da0073e9SAndroid Build Coastguard Worker            )
139*da0073e9SAndroid Build Coastguard Worker            # Assure no opinfo entry has dynamic_dtypes
140*da0073e9SAndroid Build Coastguard Worker            filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db))
141*da0073e9SAndroid Build Coastguard Worker            for op in filtered_ops:
142*da0073e9SAndroid Build Coastguard Worker                fmt_str = opinfo.utils.str_format_dynamic_dtype(op)
143*da0073e9SAndroid Build Coastguard Worker                err_msg += "\n" + fmt_str
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker            assert len(filtered_ops) == 0, err_msg
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker    # Validates that each OpInfo works correctly on different CUDA devices
148*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
149*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
150*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, allowed_dtypes=(torch.float32, torch.long))
151*da0073e9SAndroid Build Coastguard Worker    def test_multiple_devices(self, devices, dtype, op):
152*da0073e9SAndroid Build Coastguard Worker        for cuda_device_str in devices:
153*da0073e9SAndroid Build Coastguard Worker            cuda_device = torch.device(cuda_device_str)
154*da0073e9SAndroid Build Coastguard Worker            # NOTE: only tests on first sample
155*da0073e9SAndroid Build Coastguard Worker            samples = op.sample_inputs(cuda_device, dtype)
156*da0073e9SAndroid Build Coastguard Worker            sample = first_sample(self, samples)
157*da0073e9SAndroid Build Coastguard Worker            result = op(sample.input, *sample.args, **sample.kwargs)
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker            if isinstance(result, torch.Tensor):
160*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(result.device == cuda_device)
161*da0073e9SAndroid Build Coastguard Worker            elif is_iterable_of_tensors(result):
162*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(all(t.device == cuda_device for t in result))
163*da0073e9SAndroid Build Coastguard Worker            else:
164*da0073e9SAndroid Build Coastguard Worker                self.skipTest(
165*da0073e9SAndroid Build Coastguard Worker                    "Skipped! Only supports single tensor or iterable of tensor outputs."
166*da0073e9SAndroid Build Coastguard Worker                )
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker    def test_pointwise_tag_coverage(self):
169*da0073e9SAndroid Build Coastguard Worker        pytorch_dir = os.path.abspath(__file__ + "/../../")
170*da0073e9SAndroid Build Coastguard Worker        files = [
171*da0073e9SAndroid Build Coastguard Worker            "aten/src/ATen/native/UnaryOps.cpp",
172*da0073e9SAndroid Build Coastguard Worker            "aten/src/ATen/native/BinaryOps.cpp",
173*da0073e9SAndroid Build Coastguard Worker            "aten/src/ATen/native/PointwiseOps.cpp",
174*da0073e9SAndroid Build Coastguard Worker            "aten/src/ATen/native/TensorCompare.cpp",
175*da0073e9SAndroid Build Coastguard Worker        ]
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker        allowed_functions = (
178*da0073e9SAndroid Build Coastguard Worker            # reduction version of these operators
179*da0073e9SAndroid Build Coastguard Worker            "aten.max.default",
180*da0073e9SAndroid Build Coastguard Worker            "aten.max.dim",
181*da0073e9SAndroid Build Coastguard Worker            "aten.max.dim_max",
182*da0073e9SAndroid Build Coastguard Worker            "aten.max.names_dim",
183*da0073e9SAndroid Build Coastguard Worker            "aten.max.names_dim_max",
184*da0073e9SAndroid Build Coastguard Worker            "aten.max.unary_out",
185*da0073e9SAndroid Build Coastguard Worker            "aten.min.default",
186*da0073e9SAndroid Build Coastguard Worker            "aten.min.dim",
187*da0073e9SAndroid Build Coastguard Worker            "aten.min.dim_min",
188*da0073e9SAndroid Build Coastguard Worker            "aten.min.names_dim",
189*da0073e9SAndroid Build Coastguard Worker            "aten.min.names_dim_min",
190*da0073e9SAndroid Build Coastguard Worker            "aten.min.unary_out",
191*da0073e9SAndroid Build Coastguard Worker            # not pointwise
192*da0073e9SAndroid Build Coastguard Worker            "aten.isin.Tensor_Tensor",
193*da0073e9SAndroid Build Coastguard Worker            "aten.isin.Tensor_Tensor_out",
194*da0073e9SAndroid Build Coastguard Worker            "aten.isin.Tensor_Scalar",
195*da0073e9SAndroid Build Coastguard Worker            "aten.isin.Tensor_Scalar_out",
196*da0073e9SAndroid Build Coastguard Worker            "aten.isin.Scalar_Tensor",
197*da0073e9SAndroid Build Coastguard Worker            "aten.isin.Scalar_Tensor_out",
198*da0073e9SAndroid Build Coastguard Worker            "aten.mode.default",
199*da0073e9SAndroid Build Coastguard Worker            "aten.mode.dimname",
200*da0073e9SAndroid Build Coastguard Worker            "aten.mode.dimname_out",
201*da0073e9SAndroid Build Coastguard Worker            "aten.mode.values",
202*da0073e9SAndroid Build Coastguard Worker        )
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker        regex = re.compile(r"DEFINE_DISPATCH\(.*_stub")
205*da0073e9SAndroid Build Coastguard Worker
206*da0073e9SAndroid Build Coastguard Worker        def get_opoverloadpacket_from_dispatch(kernel):
207*da0073e9SAndroid Build Coastguard Worker            if hasattr(torch.ops.aten, kernel):
208*da0073e9SAndroid Build Coastguard Worker                return kernel
209*da0073e9SAndroid Build Coastguard Worker            if hasattr(torch.ops.aten, f"__{kernel}__"):
210*da0073e9SAndroid Build Coastguard Worker                return f"__{kernel}__"
211*da0073e9SAndroid Build Coastguard Worker            if hasattr(torch.ops.aten, f"special_{kernel}"):
212*da0073e9SAndroid Build Coastguard Worker                return f"special_{kernel}"
213*da0073e9SAndroid Build Coastguard Worker            if "_" in kernel:
214*da0073e9SAndroid Build Coastguard Worker                kernel_split = kernel.split("_")
215*da0073e9SAndroid Build Coastguard Worker                new_kernel = "_".join(kernel_split[:-1])
216*da0073e9SAndroid Build Coastguard Worker                if hasattr(torch.ops.aten, new_kernel):
217*da0073e9SAndroid Build Coastguard Worker                    return new_kernel
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker            # could not find op from kernel dispatch string
220*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(False)
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker        for file_name in files:
223*da0073e9SAndroid Build Coastguard Worker            with open(os.path.join(pytorch_dir, file_name)) as f:
224*da0073e9SAndroid Build Coastguard Worker                lines = f.read()
225*da0073e9SAndroid Build Coastguard Worker                matches = regex.findall(lines)
226*da0073e9SAndroid Build Coastguard Worker                for match in matches:
227*da0073e9SAndroid Build Coastguard Worker                    kernel = match[len("DEFINE_DISPATCH(") : -len("_stub")]
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker                    # no op definition for it, but defined with DEFINE_DISPATCH ?
230*da0073e9SAndroid Build Coastguard Worker                    if kernel == "trigamma":
231*da0073e9SAndroid Build Coastguard Worker                        continue
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker                    kernel = get_opoverloadpacket_from_dispatch(kernel)
234*da0073e9SAndroid Build Coastguard Worker                    overloadpacket = getattr(torch.ops.aten, kernel)
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker                    for overload_name in overloadpacket.overloads():
237*da0073e9SAndroid Build Coastguard Worker                        overload = getattr(overloadpacket, overload_name)
238*da0073e9SAndroid Build Coastguard Worker
239*da0073e9SAndroid Build Coastguard Worker                        if not torch._C._dispatch_has_kernel(overload.name()):
240*da0073e9SAndroid Build Coastguard Worker                            continue
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker                        # TODO: tags are not propagated to generated overload,
243*da0073e9SAndroid Build Coastguard Worker                        # and there's no way of specifying them
244*da0073e9SAndroid Build Coastguard Worker                        if torch.Tag.generated in overload.tags:
245*da0073e9SAndroid Build Coastguard Worker                            continue
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker                        if str(overload) in allowed_functions:
248*da0073e9SAndroid Build Coastguard Worker                            continue
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(torch.Tag.pointwise in overload.tags)
251*da0073e9SAndroid Build Coastguard Worker
252*da0073e9SAndroid Build Coastguard Worker    # Tests that the function and its (ndarray-accepting) reference produce the same
253*da0073e9SAndroid Build Coastguard Worker    #   values on the tensors from sample_inputs func for the corresponding op.
254*da0073e9SAndroid Build Coastguard Worker    # This test runs in double and complex double precision because
255*da0073e9SAndroid Build Coastguard Worker    # NumPy does computation internally using double precision for many functions
256*da0073e9SAndroid Build Coastguard Worker    # resulting in possible equality check failures.
257*da0073e9SAndroid Build Coastguard Worker    # skip windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947
258*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypesAnd(["hpu"])
259*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
260*da0073e9SAndroid Build Coastguard Worker    @ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128))
261*da0073e9SAndroid Build Coastguard Worker    def test_numpy_ref(self, device, dtype, op):
262*da0073e9SAndroid Build Coastguard Worker        if (
263*da0073e9SAndroid Build Coastguard Worker            TEST_WITH_TORCHINDUCTOR
264*da0073e9SAndroid Build Coastguard Worker            and op.formatted_name
265*da0073e9SAndroid Build Coastguard Worker            in ("signal_windows_exponential", "signal_windows_bartlett")
266*da0073e9SAndroid Build Coastguard Worker            and dtype == torch.float64
267*da0073e9SAndroid Build Coastguard Worker            and "cuda" in device
268*da0073e9SAndroid Build Coastguard Worker            or "cpu" in device
269*da0073e9SAndroid Build Coastguard Worker        ):  # noqa: E121
270*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("XXX: raises tensor-likes are not close.")
271*da0073e9SAndroid Build Coastguard Worker
272*da0073e9SAndroid Build Coastguard Worker        # Sets the default dtype to NumPy's default dtype of double
273*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.double):
274*da0073e9SAndroid Build Coastguard Worker            for sample_input in op.reference_inputs(device, dtype):
275*da0073e9SAndroid Build Coastguard Worker                self.compare_with_reference(
276*da0073e9SAndroid Build Coastguard Worker                    op, op.ref, sample_input, exact_dtype=(dtype is not torch.long)
277*da0073e9SAndroid Build Coastguard Worker                )
278*da0073e9SAndroid Build Coastguard Worker
279*da0073e9SAndroid Build Coastguard Worker    # Tests that the cpu and gpu results are consistent
280*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
281*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
282*da0073e9SAndroid Build Coastguard Worker    @slowTest
283*da0073e9SAndroid Build Coastguard Worker    @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
284*da0073e9SAndroid Build Coastguard Worker    def test_compare_cpu(self, device, dtype, op):
285*da0073e9SAndroid Build Coastguard Worker        def to_cpu(arg):
286*da0073e9SAndroid Build Coastguard Worker            if isinstance(arg, torch.Tensor):
287*da0073e9SAndroid Build Coastguard Worker                return arg.to(device="cpu")
288*da0073e9SAndroid Build Coastguard Worker            return arg
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker        samples = op.reference_inputs(device, dtype)
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
293*da0073e9SAndroid Build Coastguard Worker            cpu_sample = sample.transform(to_cpu)
294*da0073e9SAndroid Build Coastguard Worker            cuda_results = op(sample.input, *sample.args, **sample.kwargs)
295*da0073e9SAndroid Build Coastguard Worker            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker            # output_process_fn_grad has a very unfortunate name
298*da0073e9SAndroid Build Coastguard Worker            # We use this function in linalg extensively to postprocess the inputs of functions
299*da0073e9SAndroid Build Coastguard Worker            # that are not completely well-defined. Think svd and muliplying the singular vectors by -1.
300*da0073e9SAndroid Build Coastguard Worker            # CPU and CUDA implementations of the SVD can return valid SVDs that are different.
301*da0073e9SAndroid Build Coastguard Worker            # We use this function to compare them.
302*da0073e9SAndroid Build Coastguard Worker            cuda_results = sample.output_process_fn_grad(cuda_results)
303*da0073e9SAndroid Build Coastguard Worker            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker            # Lower tolerance because we are running this as a `@slowTest`
306*da0073e9SAndroid Build Coastguard Worker            # Don't want the periodic tests to fail frequently
307*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker    # Tests that experimental Python References can propagate shape, dtype,
310*da0073e9SAndroid Build Coastguard Worker    # and device metadata properly.
311*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/78050 for a discussion of stride propagation.
312*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypesAnd(["hpu"])
313*da0073e9SAndroid Build Coastguard Worker    @ops(python_ref_db)
314*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("Takes too long for inductor")
315*da0073e9SAndroid Build Coastguard Worker    def test_python_ref_meta(self, device, dtype, op):
316*da0073e9SAndroid Build Coastguard Worker        CHECK_CONJ_SKIPS = {
317*da0073e9SAndroid Build Coastguard Worker            torch._refs.linalg.svd,
318*da0073e9SAndroid Build Coastguard Worker        }
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker        with FakeTensorMode() as mode:
321*da0073e9SAndroid Build Coastguard Worker            pass
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker        def _to_tensormeta(x):
324*da0073e9SAndroid Build Coastguard Worker            if isinstance(x, torch.Tensor):
325*da0073e9SAndroid Build Coastguard Worker                out = FakeTensor.from_tensor(x, mode)
326*da0073e9SAndroid Build Coastguard Worker                return out
327*da0073e9SAndroid Build Coastguard Worker            return x
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker        # TODO: iterate over requires_grad true/false
330*da0073e9SAndroid Build Coastguard Worker        for sample in op.reference_inputs(device, dtype, requires_grad=False):
331*da0073e9SAndroid Build Coastguard Worker            result = op(sample.input, *sample.args, **sample.kwargs)
332*da0073e9SAndroid Build Coastguard Worker
333*da0073e9SAndroid Build Coastguard Worker            meta_sample = sample.transform(_to_tensormeta)
334*da0073e9SAndroid Build Coastguard Worker            try:
335*da0073e9SAndroid Build Coastguard Worker                with mode:
336*da0073e9SAndroid Build Coastguard Worker                    meta_result = op(
337*da0073e9SAndroid Build Coastguard Worker                        meta_sample.input, *meta_sample.args, **meta_sample.kwargs
338*da0073e9SAndroid Build Coastguard Worker                    )
339*da0073e9SAndroid Build Coastguard Worker            except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
340*da0073e9SAndroid Build Coastguard Worker                continue
341*da0073e9SAndroid Build Coastguard Worker            except torch._subclasses.fake_tensor.DataDependentOutputException:
342*da0073e9SAndroid Build Coastguard Worker                continue
343*da0073e9SAndroid Build Coastguard Worker            except torch._subclasses.fake_tensor.UnsupportedOperatorException:
344*da0073e9SAndroid Build Coastguard Worker                continue
345*da0073e9SAndroid Build Coastguard Worker
346*da0073e9SAndroid Build Coastguard Worker            if isinstance(result, torch.Tensor):
347*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(meta_result, FakeTensor))
348*da0073e9SAndroid Build Coastguard Worker                prims.utils.compare_tensor_meta(
349*da0073e9SAndroid Build Coastguard Worker                    result, meta_result, check_conj=op.op not in CHECK_CONJ_SKIPS
350*da0073e9SAndroid Build Coastguard Worker                )
351*da0073e9SAndroid Build Coastguard Worker            elif isinstance(result, Sequence):
352*da0073e9SAndroid Build Coastguard Worker                for a, b in zip(result, meta_result):
353*da0073e9SAndroid Build Coastguard Worker                    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
354*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(isinstance(b, FakeTensor))
355*da0073e9SAndroid Build Coastguard Worker                        prims.utils.compare_tensor_meta(
356*da0073e9SAndroid Build Coastguard Worker                            a, b, check_conj=op.op not in CHECK_CONJ_SKIPS
357*da0073e9SAndroid Build Coastguard Worker                        )
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker    def _ref_test_helper(
360*da0073e9SAndroid Build Coastguard Worker        self,
361*da0073e9SAndroid Build Coastguard Worker        ctx,
362*da0073e9SAndroid Build Coastguard Worker        device,
363*da0073e9SAndroid Build Coastguard Worker        dtype,
364*da0073e9SAndroid Build Coastguard Worker        op,
365*da0073e9SAndroid Build Coastguard Worker        skip_zero_numel=False,
366*da0073e9SAndroid Build Coastguard Worker        skip_zero_dim=False,
367*da0073e9SAndroid Build Coastguard Worker        skip_bfloat=False,
368*da0073e9SAndroid Build Coastguard Worker        skip_view_consistency=False,
369*da0073e9SAndroid Build Coastguard Worker    ):
370*da0073e9SAndroid Build Coastguard Worker        # NOTE: this test works by comparing the reference
371*da0073e9SAndroid Build Coastguard Worker        ex = None
372*da0073e9SAndroid Build Coastguard Worker        for sample in op.reference_inputs(device, dtype, requires_grad=False):
373*da0073e9SAndroid Build Coastguard Worker            if (
374*da0073e9SAndroid Build Coastguard Worker                isinstance(sample.input, torch.Tensor)
375*da0073e9SAndroid Build Coastguard Worker                and sample.input.numel() == 0
376*da0073e9SAndroid Build Coastguard Worker                and skip_zero_numel
377*da0073e9SAndroid Build Coastguard Worker            ):
378*da0073e9SAndroid Build Coastguard Worker                continue
379*da0073e9SAndroid Build Coastguard Worker            if (
380*da0073e9SAndroid Build Coastguard Worker                isinstance(sample.input, torch.Tensor)
381*da0073e9SAndroid Build Coastguard Worker                and sample.input.ndim == 0
382*da0073e9SAndroid Build Coastguard Worker                and skip_zero_dim
383*da0073e9SAndroid Build Coastguard Worker            ):
384*da0073e9SAndroid Build Coastguard Worker                continue
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker            if skip_bfloat and (
387*da0073e9SAndroid Build Coastguard Worker                (
388*da0073e9SAndroid Build Coastguard Worker                    isinstance(sample.input, torch.Tensor)
389*da0073e9SAndroid Build Coastguard Worker                    and sample.input.dtype == torch.bfloat16
390*da0073e9SAndroid Build Coastguard Worker                )
391*da0073e9SAndroid Build Coastguard Worker                or any(
392*da0073e9SAndroid Build Coastguard Worker                    isinstance(arg, torch.Tensor) and arg.dtype == torch.bfloat16
393*da0073e9SAndroid Build Coastguard Worker                    for arg in sample.args
394*da0073e9SAndroid Build Coastguard Worker                )
395*da0073e9SAndroid Build Coastguard Worker            ):
396*da0073e9SAndroid Build Coastguard Worker                continue
397*da0073e9SAndroid Build Coastguard Worker            with ctx():
398*da0073e9SAndroid Build Coastguard Worker                ref_result = op(sample.input, *sample.args, **sample.kwargs)
399*da0073e9SAndroid Build Coastguard Worker            torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker            for a, b in zip(
402*da0073e9SAndroid Build Coastguard Worker                pytree.tree_leaves(ref_result), pytree.tree_leaves(torch_result)
403*da0073e9SAndroid Build Coastguard Worker            ):
404*da0073e9SAndroid Build Coastguard Worker                if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
405*da0073e9SAndroid Build Coastguard Worker                    prims.utils.compare_tensor_meta(a, b)
406*da0073e9SAndroid Build Coastguard Worker                    if (
407*da0073e9SAndroid Build Coastguard Worker                        getattr(op, "validate_view_consistency", True)
408*da0073e9SAndroid Build Coastguard Worker                        and not skip_view_consistency
409*da0073e9SAndroid Build Coastguard Worker                    ):
410*da0073e9SAndroid Build Coastguard Worker                        msg = (
411*da0073e9SAndroid Build Coastguard Worker                            f"The torch implementation {'returns' if b._is_view() else 'does not return'} "
412*da0073e9SAndroid Build Coastguard Worker                            f"a view, while the reference {'does' if a._is_view() else 'does not'}"
413*da0073e9SAndroid Build Coastguard Worker                        )
414*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(a._is_view(), b._is_view(), msg)
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker            # Computes the dtype the more precise computatino would occur in
417*da0073e9SAndroid Build Coastguard Worker            precise_dtype = torch.bool
418*da0073e9SAndroid Build Coastguard Worker            if prims.utils.is_integer_dtype(dtype):
419*da0073e9SAndroid Build Coastguard Worker                # Note: bool and integer dtypes do not have more
420*da0073e9SAndroid Build Coastguard Worker                # precise dtypes -- they simply must be close
421*da0073e9SAndroid Build Coastguard Worker                precise_dtype = dtype
422*da0073e9SAndroid Build Coastguard Worker            if prims.utils.is_float_dtype(dtype):
423*da0073e9SAndroid Build Coastguard Worker                precise_dtype = torch.double
424*da0073e9SAndroid Build Coastguard Worker            if prims.utils.is_complex_dtype(dtype):
425*da0073e9SAndroid Build Coastguard Worker                precise_dtype = torch.cdouble
426*da0073e9SAndroid Build Coastguard Worker
427*da0073e9SAndroid Build Coastguard Worker            # Checks if the results are close
428*da0073e9SAndroid Build Coastguard Worker            try:
429*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
430*da0073e9SAndroid Build Coastguard Worker                    ref_result,
431*da0073e9SAndroid Build Coastguard Worker                    torch_result,
432*da0073e9SAndroid Build Coastguard Worker                    exact_stride=False,
433*da0073e9SAndroid Build Coastguard Worker                    exact_device=True,
434*da0073e9SAndroid Build Coastguard Worker                    exact_layout=True,
435*da0073e9SAndroid Build Coastguard Worker                    exact_is_coalesced=True,
436*da0073e9SAndroid Build Coastguard Worker                )
437*da0073e9SAndroid Build Coastguard Worker            except AssertionError as e:
438*da0073e9SAndroid Build Coastguard Worker                # Raises the error if the precise dtype comparison wouldn't be
439*da0073e9SAndroid Build Coastguard Worker                # different
440*da0073e9SAndroid Build Coastguard Worker                if dtype is precise_dtype:
441*da0073e9SAndroid Build Coastguard Worker                    raise e
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker                ex = e
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker            # Goes to next sample if these results are close
446*da0073e9SAndroid Build Coastguard Worker            if not ex:
447*da0073e9SAndroid Build Coastguard Worker                continue
448*da0073e9SAndroid Build Coastguard Worker
449*da0073e9SAndroid Build Coastguard Worker            # If the results are not close, checks that the
450*da0073e9SAndroid Build Coastguard Worker            # reference is more accurate than the torch op
451*da0073e9SAndroid Build Coastguard Worker            def _make_precise(x):
452*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, torch.dtype):
453*da0073e9SAndroid Build Coastguard Worker                    return precise_dtype
454*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, torch.Tensor) and x.dtype is dtype:
455*da0073e9SAndroid Build Coastguard Worker                    return x.to(precise_dtype)
456*da0073e9SAndroid Build Coastguard Worker                return x
457*da0073e9SAndroid Build Coastguard Worker
458*da0073e9SAndroid Build Coastguard Worker            precise_sample = sample.transform(_make_precise)
459*da0073e9SAndroid Build Coastguard Worker            precise_result = op.torch_opinfo(
460*da0073e9SAndroid Build Coastguard Worker                precise_sample.input, *precise_sample.args, **precise_sample.kwargs
461*da0073e9SAndroid Build Coastguard Worker            )
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Worker            def _distance(a, b):
464*da0073e9SAndroid Build Coastguard Worker                # Special-cases boolean comparisons
465*da0073e9SAndroid Build Coastguard Worker                if prims.utils.is_boolean_dtype(a.dtype):
466*da0073e9SAndroid Build Coastguard Worker                    assert b.dtype is torch.bool
467*da0073e9SAndroid Build Coastguard Worker                    return (a ^ b).sum()
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Worker                same = a == b
470*da0073e9SAndroid Build Coastguard Worker                if prims.utils.is_float_dtype(a.dtype) or prims.utils.is_complex_dtype(
471*da0073e9SAndroid Build Coastguard Worker                    a.dtype
472*da0073e9SAndroid Build Coastguard Worker                ):
473*da0073e9SAndroid Build Coastguard Worker                    same = torch.logical_or(
474*da0073e9SAndroid Build Coastguard Worker                        same, torch.logical_and(torch.isnan(a), torch.isnan(b))
475*da0073e9SAndroid Build Coastguard Worker                    )
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker                actual_error = torch.where(same, 0, torch.abs(a - b)).sum()
478*da0073e9SAndroid Build Coastguard Worker                return actual_error
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker            ref_distance = 0
481*da0073e9SAndroid Build Coastguard Worker            for a, b in zip(
482*da0073e9SAndroid Build Coastguard Worker                pytree.tree_leaves(ref_result), pytree.tree_leaves(precise_result)
483*da0073e9SAndroid Build Coastguard Worker            ):
484*da0073e9SAndroid Build Coastguard Worker                ref_distance = ref_distance + _distance(a, b)
485*da0073e9SAndroid Build Coastguard Worker
486*da0073e9SAndroid Build Coastguard Worker            torch_distance = 0
487*da0073e9SAndroid Build Coastguard Worker            for a, b in zip(
488*da0073e9SAndroid Build Coastguard Worker                pytree.tree_leaves(torch_result), pytree.tree_leaves(precise_result)
489*da0073e9SAndroid Build Coastguard Worker            ):
490*da0073e9SAndroid Build Coastguard Worker                torch_distance = torch_distance + _distance(a, b)
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker            # TODO: consider adding some tolerance to this comparison
493*da0073e9SAndroid Build Coastguard Worker            msg = (
494*da0073e9SAndroid Build Coastguard Worker                f"Reference result was farther ({ref_distance}) from the precise "
495*da0073e9SAndroid Build Coastguard Worker                f"computation than the torch result was ({torch_distance})!"
496*da0073e9SAndroid Build Coastguard Worker            )
497*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(ref_distance <= torch_distance, msg=msg)
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker        # Reports numerical accuracy discrepancies
500*da0073e9SAndroid Build Coastguard Worker        if ex is not None:
501*da0073e9SAndroid Build Coastguard Worker            msg = "Test passed because the reference was more accurate than the torch operator."
502*da0073e9SAndroid Build Coastguard Worker            warnings.warn(msg)
503*da0073e9SAndroid Build Coastguard Worker
504*da0073e9SAndroid Build Coastguard Worker    # Tests that experimental Python References perform the same computation
505*da0073e9SAndroid Build Coastguard Worker    # as the operators they reference, when operator calls in the torch
506*da0073e9SAndroid Build Coastguard Worker    # namesapce are remapped to the refs namespace (torch.foo becomes refs.foo).
507*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypesAnd(["hpu"])
508*da0073e9SAndroid Build Coastguard Worker    @ops(python_ref_db)
509*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("Takes too long for inductor")
510*da0073e9SAndroid Build Coastguard Worker    def test_python_ref(self, device, dtype, op):
511*da0073e9SAndroid Build Coastguard Worker        # In this test, primTorch refs call into the refs namespace
512*da0073e9SAndroid Build Coastguard Worker        # For example, a ref with torch.foo in it will calls refs.foo instead
513*da0073e9SAndroid Build Coastguard Worker        # Direct calls to refs and prims are not affected
514*da0073e9SAndroid Build Coastguard Worker        if (
515*da0073e9SAndroid Build Coastguard Worker            TEST_WITH_ROCM
516*da0073e9SAndroid Build Coastguard Worker            and (op.name == "_refs.fft.ihfftn" or op.name == "_refs.fft.ihfft2")
517*da0073e9SAndroid Build Coastguard Worker            and dtype == torch.float16
518*da0073e9SAndroid Build Coastguard Worker        ):
519*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skipped on ROCm")
520*da0073e9SAndroid Build Coastguard Worker        self._ref_test_helper(lambda: TorchRefsMode(strict=True), device, dtype, op)
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker    # Tests that experimental Python References perform the same computation
523*da0073e9SAndroid Build Coastguard Worker    # as the operators they reference, when operator calls in the torch
524*da0073e9SAndroid Build Coastguard Worker    # namespace are preserved (torch.foo remains torch.foo).
525*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypesAnd(["hpu"])
526*da0073e9SAndroid Build Coastguard Worker    @ops(python_ref_db)
527*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("Takes too long for inductor")
528*da0073e9SAndroid Build Coastguard Worker    def test_python_ref_torch_fallback(self, device, dtype, op):
529*da0073e9SAndroid Build Coastguard Worker        # In this test, refs call into the torch namespace (after the initial invocation)
530*da0073e9SAndroid Build Coastguard Worker        # For example, a ref with torch.foo in it will call torch.foo instead of refs.foo
531*da0073e9SAndroid Build Coastguard Worker        # Direct calls to refs and prims are not translated
532*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM and op.name == "_refs.fft.ihfftn" and dtype == torch.float16:
533*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skipped on ROCm")
534*da0073e9SAndroid Build Coastguard Worker        self._ref_test_helper(contextlib.nullcontext, device, dtype, op)
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
537*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
538*da0073e9SAndroid Build Coastguard Worker    @ops(python_ref_db)
539*da0073e9SAndroid Build Coastguard Worker    @parametrize("executor", ["aten"])
540*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("Takes too long for inductor")
541*da0073e9SAndroid Build Coastguard Worker    def test_python_ref_executor(self, device, dtype, op, executor):
542*da0073e9SAndroid Build Coastguard Worker        if (
543*da0073e9SAndroid Build Coastguard Worker            TEST_WITH_ROCM
544*da0073e9SAndroid Build Coastguard Worker            and (op.name == "_refs.fft.ihfftn" or op.name == "_refs.fft.ihfft2")
545*da0073e9SAndroid Build Coastguard Worker            and dtype == torch.float16
546*da0073e9SAndroid Build Coastguard Worker        ):
547*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skipped on ROCm")
548*da0073e9SAndroid Build Coastguard Worker        # skip zero-dim tensors for some composites of reduction operations and view
549*da0073e9SAndroid Build Coastguard Worker        skip_zero_dim_ops = [
550*da0073e9SAndroid Build Coastguard Worker            "_refs.logsumexp",
551*da0073e9SAndroid Build Coastguard Worker            "_refs.log_softmax",
552*da0073e9SAndroid Build Coastguard Worker            "_refs.native_group_norm",
553*da0073e9SAndroid Build Coastguard Worker            "_refs.softmax",
554*da0073e9SAndroid Build Coastguard Worker            "_refs.sum_to_size",
555*da0073e9SAndroid Build Coastguard Worker            "ops.nvprims.view",
556*da0073e9SAndroid Build Coastguard Worker        ]
557*da0073e9SAndroid Build Coastguard Worker
558*da0073e9SAndroid Build Coastguard Worker        from copy import copy
559*da0073e9SAndroid Build Coastguard Worker
560*da0073e9SAndroid Build Coastguard Worker        from torch._prims.executor import make_traced
561*da0073e9SAndroid Build Coastguard Worker
562*da0073e9SAndroid Build Coastguard Worker        op = copy(op)
563*da0073e9SAndroid Build Coastguard Worker        op.op = partial(make_traced(op.op), executor=executor)
564*da0073e9SAndroid Build Coastguard Worker        self._ref_test_helper(contextlib.nullcontext, device, dtype, op)
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Worker    @skipMeta
567*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypesAnd(["hpu"])
568*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
569*da0073e9SAndroid Build Coastguard Worker    def test_errors(self, device, op):
570*da0073e9SAndroid Build Coastguard Worker        error_inputs = op.error_inputs(device)
571*da0073e9SAndroid Build Coastguard Worker        for ei in error_inputs:
572*da0073e9SAndroid Build Coastguard Worker            si = ei.sample_input
573*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(ei.error_type, ei.error_regex):
574*da0073e9SAndroid Build Coastguard Worker                out = op(si.input, *si.args, **si.kwargs)
575*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(isinstance(out, type(NotImplemented)))
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker    @skipMeta
578*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypesAnd(["hpu"])
579*da0073e9SAndroid Build Coastguard Worker    @ops(
580*da0073e9SAndroid Build Coastguard Worker        [op for op in op_db if op.error_inputs_sparse_func is not None],
581*da0073e9SAndroid Build Coastguard Worker        dtypes=OpDTypes.none,
582*da0073e9SAndroid Build Coastguard Worker    )
583*da0073e9SAndroid Build Coastguard Worker    @parametrize(
584*da0073e9SAndroid Build Coastguard Worker        "layout",
585*da0073e9SAndroid Build Coastguard Worker        (
586*da0073e9SAndroid Build Coastguard Worker            torch.sparse_csr,
587*da0073e9SAndroid Build Coastguard Worker            torch.sparse_csc,
588*da0073e9SAndroid Build Coastguard Worker            torch.sparse_bsr,
589*da0073e9SAndroid Build Coastguard Worker            torch.sparse_bsc,
590*da0073e9SAndroid Build Coastguard Worker            torch.sparse_coo,
591*da0073e9SAndroid Build Coastguard Worker        ),
592*da0073e9SAndroid Build Coastguard Worker    )
593*da0073e9SAndroid Build Coastguard Worker    def test_errors_sparse(self, device, op, layout):
594*da0073e9SAndroid Build Coastguard Worker        for ei in op.error_inputs_sparse(device, layout):
595*da0073e9SAndroid Build Coastguard Worker            si = ei.sample_input
596*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(ei.error_type, ei.error_regex):
597*da0073e9SAndroid Build Coastguard Worker                out = op(si.input, *si.args, **si.kwargs)
598*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(isinstance(out, type(NotImplemented)))
599*da0073e9SAndroid Build Coastguard Worker
600*da0073e9SAndroid Build Coastguard Worker    @skipMeta
601*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypesAnd(["hpu"])
602*da0073e9SAndroid Build Coastguard Worker    @ops(
603*da0073e9SAndroid Build Coastguard Worker        [op for op in python_ref_db if op.error_inputs_func is not None],
604*da0073e9SAndroid Build Coastguard Worker        dtypes=OpDTypes.none,
605*da0073e9SAndroid Build Coastguard Worker    )
606*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchInductor("Takes too long for inductor")
607*da0073e9SAndroid Build Coastguard Worker    def test_python_ref_errors(self, device, op):
608*da0073e9SAndroid Build Coastguard Worker        mode = FakeTensorMode()
609*da0073e9SAndroid Build Coastguard Worker        with mode:
610*da0073e9SAndroid Build Coastguard Worker            pass
611*da0073e9SAndroid Build Coastguard Worker
612*da0073e9SAndroid Build Coastguard Worker        def _to_tensormeta(x):
613*da0073e9SAndroid Build Coastguard Worker            if isinstance(x, torch.Tensor):
614*da0073e9SAndroid Build Coastguard Worker                return FakeTensor.from_tensor(x, mode)
615*da0073e9SAndroid Build Coastguard Worker            return x
616*da0073e9SAndroid Build Coastguard Worker
617*da0073e9SAndroid Build Coastguard Worker        error_inputs = op.error_inputs(device)
618*da0073e9SAndroid Build Coastguard Worker        for ei in error_inputs:
619*da0073e9SAndroid Build Coastguard Worker            si = ei.sample_input
620*da0073e9SAndroid Build Coastguard Worker            meta_sample = si.transform(_to_tensormeta)
621*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(ei.error_type, ei.error_regex):
622*da0073e9SAndroid Build Coastguard Worker                op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
623*da0073e9SAndroid Build Coastguard Worker
624*da0073e9SAndroid Build Coastguard Worker    # Tests that the function produces the same result when called with
625*da0073e9SAndroid Build Coastguard Worker    #   noncontiguous tensors.
626*da0073e9SAndroid Build Coastguard Worker    # TODO: get working with Windows by addressing failing operators
627*da0073e9SAndroid Build Coastguard Worker    # TODO: get working with ASAN by addressing failing operators
628*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
629*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypesAnd(["hpu"])
630*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
631*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, allowed_dtypes=(torch.float32, torch.long, torch.complex64))
632*da0073e9SAndroid Build Coastguard Worker    def test_noncontiguous_samples(self, device, dtype, op):
633*da0073e9SAndroid Build Coastguard Worker        test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
634*da0073e9SAndroid Build Coastguard Worker        sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
635*da0073e9SAndroid Build Coastguard Worker        for sample_input in sample_inputs:
636*da0073e9SAndroid Build Coastguard Worker            t_inp, t_args, t_kwargs = (
637*da0073e9SAndroid Build Coastguard Worker                sample_input.input,
638*da0073e9SAndroid Build Coastguard Worker                sample_input.args,
639*da0073e9SAndroid Build Coastguard Worker                sample_input.kwargs,
640*da0073e9SAndroid Build Coastguard Worker            )
641*da0073e9SAndroid Build Coastguard Worker            noncontig_sample = sample_input.noncontiguous()
642*da0073e9SAndroid Build Coastguard Worker            n_inp, n_args, n_kwargs = (
643*da0073e9SAndroid Build Coastguard Worker                noncontig_sample.input,
644*da0073e9SAndroid Build Coastguard Worker                noncontig_sample.args,
645*da0073e9SAndroid Build Coastguard Worker                noncontig_sample.kwargs,
646*da0073e9SAndroid Build Coastguard Worker            )
647*da0073e9SAndroid Build Coastguard Worker
648*da0073e9SAndroid Build Coastguard Worker            # validates forward
649*da0073e9SAndroid Build Coastguard Worker            expected = op(t_inp, *t_args, **t_kwargs)
650*da0073e9SAndroid Build Coastguard Worker            actual = op(n_inp, *n_args, **n_kwargs)
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
653*da0073e9SAndroid Build Coastguard Worker
654*da0073e9SAndroid Build Coastguard Worker            # Validate backward
655*da0073e9SAndroid Build Coastguard Worker            # Short-circuits if the op doesn't support grad in this device x dtype
656*da0073e9SAndroid Build Coastguard Worker            if not test_grad:
657*da0073e9SAndroid Build Coastguard Worker                continue
658*da0073e9SAndroid Build Coastguard Worker
659*da0073e9SAndroid Build Coastguard Worker            expected = sample_input.output_process_fn_grad(expected)
660*da0073e9SAndroid Build Coastguard Worker            actual = sample_input.output_process_fn_grad(actual)
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker            if isinstance(expected, torch.Tensor):
663*da0073e9SAndroid Build Coastguard Worker                grad_for_expected = torch.randn_like(expected)
664*da0073e9SAndroid Build Coastguard Worker                grad_for_actual = noncontiguous_like(grad_for_expected)
665*da0073e9SAndroid Build Coastguard Worker            elif isinstance(expected, Sequence):
666*da0073e9SAndroid Build Coastguard Worker                # Filter output elements that do not require grad
667*da0073e9SAndroid Build Coastguard Worker                expected = [
668*da0073e9SAndroid Build Coastguard Worker                    t
669*da0073e9SAndroid Build Coastguard Worker                    for t in expected
670*da0073e9SAndroid Build Coastguard Worker                    if isinstance(t, torch.Tensor) and t.requires_grad
671*da0073e9SAndroid Build Coastguard Worker                ]
672*da0073e9SAndroid Build Coastguard Worker                actual = [
673*da0073e9SAndroid Build Coastguard Worker                    n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad
674*da0073e9SAndroid Build Coastguard Worker                ]
675*da0073e9SAndroid Build Coastguard Worker                grad_for_expected = [torch.randn_like(t) for t in expected]
676*da0073e9SAndroid Build Coastguard Worker                grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected]
677*da0073e9SAndroid Build Coastguard Worker            else:
678*da0073e9SAndroid Build Coastguard Worker                # Nothing to do if it returns a scalar or things like that
679*da0073e9SAndroid Build Coastguard Worker                continue
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker            # Concatenate inputs into a tuple
682*da0073e9SAndroid Build Coastguard Worker            t_inputs = (
683*da0073e9SAndroid Build Coastguard Worker                (t_inp,) + t_args
684*da0073e9SAndroid Build Coastguard Worker                if isinstance(t_inp, torch.Tensor)
685*da0073e9SAndroid Build Coastguard Worker                else tuple(t_inp) + t_args
686*da0073e9SAndroid Build Coastguard Worker            )
687*da0073e9SAndroid Build Coastguard Worker            n_inputs = (
688*da0073e9SAndroid Build Coastguard Worker                (n_inp,) + n_args
689*da0073e9SAndroid Build Coastguard Worker                if isinstance(n_inp, torch.Tensor)
690*da0073e9SAndroid Build Coastguard Worker                else tuple(n_inp) + n_args
691*da0073e9SAndroid Build Coastguard Worker            )
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker            # Filter the elemnts that are tensors that require grad
694*da0073e9SAndroid Build Coastguard Worker            t_input_tensors = [
695*da0073e9SAndroid Build Coastguard Worker                t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad
696*da0073e9SAndroid Build Coastguard Worker            ]
697*da0073e9SAndroid Build Coastguard Worker            n_input_tensors = [
698*da0073e9SAndroid Build Coastguard Worker                n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad
699*da0073e9SAndroid Build Coastguard Worker            ]
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(t_input_tensors), len(n_input_tensors))
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker            # Some functions may not use all the inputs to generate gradients. One of the
704*da0073e9SAndroid Build Coastguard Worker            # few examples of this "odd" behaviour is F.hinge_embedding_loss
705*da0073e9SAndroid Build Coastguard Worker            t_grads = torch.autograd.grad(
706*da0073e9SAndroid Build Coastguard Worker                expected, t_input_tensors, grad_for_expected, allow_unused=True
707*da0073e9SAndroid Build Coastguard Worker            )
708*da0073e9SAndroid Build Coastguard Worker            n_grads = torch.autograd.grad(
709*da0073e9SAndroid Build Coastguard Worker                actual, n_input_tensors, grad_for_actual, allow_unused=True
710*da0073e9SAndroid Build Coastguard Worker            )
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker            msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}."
713*da0073e9SAndroid Build Coastguard Worker            for i, (t, n) in enumerate(zip(t_grads, n_grads)):
714*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t, n, msg=msg.format(i))
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker    # Separates one case from the following test_out because many ops don't properly implement the
717*da0073e9SAndroid Build Coastguard Worker    #   incorrectly sized out parameter warning properly yet
718*da0073e9SAndroid Build Coastguard Worker    # Cases test here:
719*da0073e9SAndroid Build Coastguard Worker    #   - out= with the correct dtype and device, but the wrong shape
720*da0073e9SAndroid Build Coastguard Worker    @ops(ops_and_refs, dtypes=OpDTypes.none)
721*da0073e9SAndroid Build Coastguard Worker    def test_out_warning(self, device, op):
722*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_TORCHDYNAMO and op.name == "_refs.clamp":
723*da0073e9SAndroid Build Coastguard Worker            self.skipTest("flaky")
724*da0073e9SAndroid Build Coastguard Worker        # Prefers running in float32 but has a fallback for the first listed supported dtype
725*da0073e9SAndroid Build Coastguard Worker        supported_dtypes = op.supported_dtypes(self.device_type)
726*da0073e9SAndroid Build Coastguard Worker        if len(supported_dtypes) == 0:
727*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skipped! Op has not supported dtypes on this device.")
728*da0073e9SAndroid Build Coastguard Worker        dtype = (
729*da0073e9SAndroid Build Coastguard Worker            torch.float32
730*da0073e9SAndroid Build Coastguard Worker            if torch.float32 in supported_dtypes
731*da0073e9SAndroid Build Coastguard Worker            else next(iter(supported_dtypes))
732*da0073e9SAndroid Build Coastguard Worker        )
733*da0073e9SAndroid Build Coastguard Worker
734*da0073e9SAndroid Build Coastguard Worker        # Ops from python_ref_db point to python decomps that are potentially
735*da0073e9SAndroid Build Coastguard Worker        # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these
736*da0073e9SAndroid Build Coastguard Worker        # ops before testing to avoid clashing with OpInfo.supports_out
737*da0073e9SAndroid Build Coastguard Worker        if not op.supports_out:
738*da0073e9SAndroid Build Coastguard Worker            op = copy.copy(op)
739*da0073e9SAndroid Build Coastguard Worker            op.op = _maybe_remove_out_wrapper(op.op)
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype)
742*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
743*da0073e9SAndroid Build Coastguard Worker            # calls it normally to get the expected result
744*da0073e9SAndroid Build Coastguard Worker            expected = op(sample.input, *sample.args, **sample.kwargs)
745*da0073e9SAndroid Build Coastguard Worker            op_out = partial(op, sample.input, *sample.args, **sample.kwargs)
746*da0073e9SAndroid Build Coastguard Worker
747*da0073e9SAndroid Build Coastguard Worker            # Short-circuits if output is not a single tensor or an
748*da0073e9SAndroid Build Coastguard Worker            #   iterable of tensors
749*da0073e9SAndroid Build Coastguard Worker            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
750*da0073e9SAndroid Build Coastguard Worker                expected, include_empty=True
751*da0073e9SAndroid Build Coastguard Worker            ):
752*da0073e9SAndroid Build Coastguard Worker                self.skipTest(
753*da0073e9SAndroid Build Coastguard Worker                    "Skipped! Only supports single tensor or iterable of tensor outputs."
754*da0073e9SAndroid Build Coastguard Worker                )
755*da0073e9SAndroid Build Coastguard Worker
756*da0073e9SAndroid Build Coastguard Worker            # Validates the op doesn't support out if it claims not to
757*da0073e9SAndroid Build Coastguard Worker            if not op.supports_out:
758*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(Exception):
759*da0073e9SAndroid Build Coastguard Worker                    assert op_out(out=expected) != NotImplemented
760*da0073e9SAndroid Build Coastguard Worker                return
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker            # A wrapper around map that works with single tensors and always
763*da0073e9SAndroid Build Coastguard Worker            #   instantiates the map. Used below to apply transforms to
764*da0073e9SAndroid Build Coastguard Worker            #   single tensor and iterable tensor outputs.
765*da0073e9SAndroid Build Coastguard Worker            def _apply_out_transform(fn, out):
766*da0073e9SAndroid Build Coastguard Worker                if isinstance(out, torch.Tensor):
767*da0073e9SAndroid Build Coastguard Worker                    return fn(out)
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker                # assumes (see above) that out is an iterable of tensors
770*da0073e9SAndroid Build Coastguard Worker                return tuple(map(fn, out))
771*da0073e9SAndroid Build Coastguard Worker
772*da0073e9SAndroid Build Coastguard Worker            # Extracts strides from a tensor or iterable of tensors into a tuple
773*da0073e9SAndroid Build Coastguard Worker            def _extract_strides(out):
774*da0073e9SAndroid Build Coastguard Worker                if isinstance(out, torch.Tensor):
775*da0073e9SAndroid Build Coastguard Worker                    return (out.stride(),)
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker                # assumes (see above) that out is an iterable of tensors
778*da0073e9SAndroid Build Coastguard Worker                return tuple(t.stride() for t in out)
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker            # Extracts data pointers from a tensor or iterable of tensors into a tuple
781*da0073e9SAndroid Build Coastguard Worker            # NOTE: only extracts on the CPU and CUDA device types since some
782*da0073e9SAndroid Build Coastguard Worker            #   device types don't have storage
783*da0073e9SAndroid Build Coastguard Worker            def _extract_data_ptrs(out):
784*da0073e9SAndroid Build Coastguard Worker                if self.device_type != "cpu" and self.device_type != "cuda":
785*da0073e9SAndroid Build Coastguard Worker                    return ()
786*da0073e9SAndroid Build Coastguard Worker
787*da0073e9SAndroid Build Coastguard Worker                if isinstance(out, torch.Tensor):
788*da0073e9SAndroid Build Coastguard Worker                    return (out.data_ptr(),)
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker                # assumes (see above) that out is an iterable of tensors
791*da0073e9SAndroid Build Coastguard Worker                return tuple(t.data_ptr() for t in out)
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker            @suppress_warnings
794*da0073e9SAndroid Build Coastguard Worker            def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
795*da0073e9SAndroid Build Coastguard Worker                out = _apply_out_transform(transform, expected)
796*da0073e9SAndroid Build Coastguard Worker                original_strides = _extract_strides(out)
797*da0073e9SAndroid Build Coastguard Worker                original_ptrs = _extract_data_ptrs(out)
798*da0073e9SAndroid Build Coastguard Worker
799*da0073e9SAndroid Build Coastguard Worker                op_out(out=out)
800*da0073e9SAndroid Build Coastguard Worker                final_strides = _extract_strides(out)
801*da0073e9SAndroid Build Coastguard Worker                final_ptrs = _extract_data_ptrs(out)
802*da0073e9SAndroid Build Coastguard Worker
803*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected, out)
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker                if compare_strides_and_data_ptrs:
806*da0073e9SAndroid Build Coastguard Worker                    stride_msg = (
807*da0073e9SAndroid Build Coastguard Worker                        f"Strides are not the same! Original strides were {original_strides} "
808*da0073e9SAndroid Build Coastguard Worker                        f"and strides are now {final_strides}"
809*da0073e9SAndroid Build Coastguard Worker                    )
810*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(original_strides, final_strides, msg=stride_msg)
811*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(original_ptrs, final_ptrs)
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker            # Case Zero: out= with the correct dtype and device, but the wrong shape
814*da0073e9SAndroid Build Coastguard Worker            #   Expected behavior: if nonempty, resize with a warning.
815*da0073e9SAndroid Build Coastguard Worker            def _case_zero_transform(t):
816*da0073e9SAndroid Build Coastguard Worker                wrong_shape = list(t.shape)
817*da0073e9SAndroid Build Coastguard Worker
818*da0073e9SAndroid Build Coastguard Worker                if len(wrong_shape) == 0:
819*da0073e9SAndroid Build Coastguard Worker                    # Handles scalar tensor case (empty list)
820*da0073e9SAndroid Build Coastguard Worker                    wrong_shape = [2]
821*da0073e9SAndroid Build Coastguard Worker                else:
822*da0073e9SAndroid Build Coastguard Worker                    wrong_shape[-1] = wrong_shape[-1] + 1
823*da0073e9SAndroid Build Coastguard Worker                return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)
824*da0073e9SAndroid Build Coastguard Worker
825*da0073e9SAndroid Build Coastguard Worker            # Verifies the out values are correct
826*da0073e9SAndroid Build Coastguard Worker            _compare_out(_case_zero_transform, compare_strides_and_data_ptrs=False)
827*da0073e9SAndroid Build Coastguard Worker
828*da0073e9SAndroid Build Coastguard Worker            # Additionally validates that the appropriate warning is thrown if a nonempty
829*da0073e9SAndroid Build Coastguard Worker            #   tensor is resized.
830*da0073e9SAndroid Build Coastguard Worker            def _any_nonempty(out):
831*da0073e9SAndroid Build Coastguard Worker                if isinstance(out, torch.Tensor):
832*da0073e9SAndroid Build Coastguard Worker                    return out.numel() > 0
833*da0073e9SAndroid Build Coastguard Worker
834*da0073e9SAndroid Build Coastguard Worker                return any(x.numel() > 0 for x in out)
835*da0073e9SAndroid Build Coastguard Worker
836*da0073e9SAndroid Build Coastguard Worker            out = _apply_out_transform(_case_zero_transform, expected)
837*da0073e9SAndroid Build Coastguard Worker            msg_fail = "Resized a non-empty tensor but did not warn about it."
838*da0073e9SAndroid Build Coastguard Worker            if _any_nonempty(out):
839*da0073e9SAndroid Build Coastguard Worker                with self.assertWarnsRegex(
840*da0073e9SAndroid Build Coastguard Worker                    UserWarning, "An output with one or more elements", msg=msg_fail
841*da0073e9SAndroid Build Coastguard Worker                ):
842*da0073e9SAndroid Build Coastguard Worker                    op_out(out=out)
843*da0073e9SAndroid Build Coastguard Worker
844*da0073e9SAndroid Build Coastguard Worker    # Validates ops implement the correct out= behavior
845*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
846*da0073e9SAndroid Build Coastguard Worker    #   for a description of the correct behavior
847*da0073e9SAndroid Build Coastguard Worker    # Validates the following cases:
848*da0073e9SAndroid Build Coastguard Worker    #   - Case 0: out has the correct shape, dtype, and device but is full of extremal values
849*da0073e9SAndroid Build Coastguard Worker    #   - Case 1: out has the correct shape, dtype, and device but is noncontiguous
850*da0073e9SAndroid Build Coastguard Worker    #   - Case 2: out has the correct dtype and device, but is zero elements
851*da0073e9SAndroid Build Coastguard Worker    #   - Case 3: out has the correct shape and dtype, but is on a different device type
852*da0073e9SAndroid Build Coastguard Worker    #   - Case 4: out has the correct shape and device, but a dtype that cannot
853*da0073e9SAndroid Build Coastguard Worker    #       "safely" cast to
854*da0073e9SAndroid Build Coastguard Worker    #
855*da0073e9SAndroid Build Coastguard Worker    # Case 3 and 4 are slightly different when the op is a factory function:
856*da0073e9SAndroid Build Coastguard Worker    #   - if device, dtype are NOT passed, any combination of dtype/device should be OK for out
857*da0073e9SAndroid Build Coastguard Worker    #   - if device, dtype are passed, device and dtype should match
858*da0073e9SAndroid Build Coastguard Worker    @ops(ops_and_refs, dtypes=OpDTypes.any_one)
859*da0073e9SAndroid Build Coastguard Worker    def test_out(self, device, dtype, op):
860*da0073e9SAndroid Build Coastguard Worker        # Prefers running in float32 but has a fallback for the first listed supported dtype
861*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype)
862*da0073e9SAndroid Build Coastguard Worker
863*da0073e9SAndroid Build Coastguard Worker        # Ops from python_ref_db point to python decomps that are potentially
864*da0073e9SAndroid Build Coastguard Worker        # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these
865*da0073e9SAndroid Build Coastguard Worker        # ops before testing to avoid clashing with OpInfo.supports_out
866*da0073e9SAndroid Build Coastguard Worker        if not op.supports_out:
867*da0073e9SAndroid Build Coastguard Worker            op = copy.copy(op)
868*da0073e9SAndroid Build Coastguard Worker            op.op = _maybe_remove_out_wrapper(op.op)
869*da0073e9SAndroid Build Coastguard Worker
870*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
871*da0073e9SAndroid Build Coastguard Worker            # calls it normally to get the expected result
872*da0073e9SAndroid Build Coastguard Worker            expected = op(sample.input, *sample.args, **sample.kwargs)
873*da0073e9SAndroid Build Coastguard Worker            op_out = partial(op, sample.input, *sample.args, **sample.kwargs)
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker            # Short-circuits if output is not a single tensor or an
876*da0073e9SAndroid Build Coastguard Worker            #   iterable of tensors
877*da0073e9SAndroid Build Coastguard Worker            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
878*da0073e9SAndroid Build Coastguard Worker                expected, include_empty=True
879*da0073e9SAndroid Build Coastguard Worker            ):
880*da0073e9SAndroid Build Coastguard Worker                self.skipTest(
881*da0073e9SAndroid Build Coastguard Worker                    "Skipped! Only supports single tensor or iterable of tensor outputs."
882*da0073e9SAndroid Build Coastguard Worker                )
883*da0073e9SAndroid Build Coastguard Worker
884*da0073e9SAndroid Build Coastguard Worker            # Validates the op doesn't support out if it claims not to
885*da0073e9SAndroid Build Coastguard Worker            if not op.supports_out:
886*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(Exception):
887*da0073e9SAndroid Build Coastguard Worker                    assert op_out(out=expected) != NotImplemented
888*da0073e9SAndroid Build Coastguard Worker                return
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker            # A wrapper around map that works with single tensors and always
891*da0073e9SAndroid Build Coastguard Worker            #   instantiates the map. Used below to apply transforms to
892*da0073e9SAndroid Build Coastguard Worker            #   single tensor and iterable tensor outputs.
893*da0073e9SAndroid Build Coastguard Worker            def _apply_out_transform(fn, out):
894*da0073e9SAndroid Build Coastguard Worker                if isinstance(out, torch.Tensor):
895*da0073e9SAndroid Build Coastguard Worker                    return fn(out)
896*da0073e9SAndroid Build Coastguard Worker
897*da0073e9SAndroid Build Coastguard Worker                # assumes (see above) that out is an iterable of tensors
898*da0073e9SAndroid Build Coastguard Worker                return tuple(map(fn, out))
899*da0073e9SAndroid Build Coastguard Worker
900*da0073e9SAndroid Build Coastguard Worker            # Extracts strides from a tensor or iterable of tensors into a tuple
901*da0073e9SAndroid Build Coastguard Worker            def _extract_strides(out):
902*da0073e9SAndroid Build Coastguard Worker                if isinstance(out, torch.Tensor):
903*da0073e9SAndroid Build Coastguard Worker                    return (out.stride(),)
904*da0073e9SAndroid Build Coastguard Worker
905*da0073e9SAndroid Build Coastguard Worker                # assumes (see above) that out is an iterable of tensors
906*da0073e9SAndroid Build Coastguard Worker                return tuple(t.stride() for t in out)
907*da0073e9SAndroid Build Coastguard Worker
908*da0073e9SAndroid Build Coastguard Worker            # Extracts data pointers from a tensor or iterable of tensors into a tuple
909*da0073e9SAndroid Build Coastguard Worker            # NOTE: only extracts on the CPU and CUDA device types since some
910*da0073e9SAndroid Build Coastguard Worker            #   device types don't have storage
911*da0073e9SAndroid Build Coastguard Worker            def _extract_data_ptrs(out):
912*da0073e9SAndroid Build Coastguard Worker                if self.device_type != "cpu" and self.device_type != "cuda":
913*da0073e9SAndroid Build Coastguard Worker                    return ()
914*da0073e9SAndroid Build Coastguard Worker
915*da0073e9SAndroid Build Coastguard Worker                if isinstance(out, torch.Tensor):
916*da0073e9SAndroid Build Coastguard Worker                    return (out.data_ptr(),)
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker                # assumes (see above) that out is an iterable of tensors
919*da0073e9SAndroid Build Coastguard Worker                return tuple(t.data_ptr() for t in out)
920*da0073e9SAndroid Build Coastguard Worker
921*da0073e9SAndroid Build Coastguard Worker            def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
922*da0073e9SAndroid Build Coastguard Worker                out = _apply_out_transform(transform, expected)
923*da0073e9SAndroid Build Coastguard Worker                original_strides = _extract_strides(out)
924*da0073e9SAndroid Build Coastguard Worker                original_ptrs = _extract_data_ptrs(out)
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Worker                op_out(out=out)
927*da0073e9SAndroid Build Coastguard Worker                final_strides = _extract_strides(out)
928*da0073e9SAndroid Build Coastguard Worker                final_ptrs = _extract_data_ptrs(out)
929*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected, out)
930*da0073e9SAndroid Build Coastguard Worker
931*da0073e9SAndroid Build Coastguard Worker                if compare_strides_and_data_ptrs:
932*da0073e9SAndroid Build Coastguard Worker                    stride_msg = (
933*da0073e9SAndroid Build Coastguard Worker                        "Strides are not the same! "
934*da0073e9SAndroid Build Coastguard Worker                        f"Original strides were {original_strides} and strides are now {final_strides}"
935*da0073e9SAndroid Build Coastguard Worker                    )
936*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(original_strides, final_strides, msg=stride_msg)
937*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(original_ptrs, final_ptrs)
938*da0073e9SAndroid Build Coastguard Worker
939*da0073e9SAndroid Build Coastguard Worker            # Case 0: out= with the correct shape, dtype, and device
940*da0073e9SAndroid Build Coastguard Worker            #   but NaN values for floating point and complex tensors, and
941*da0073e9SAndroid Build Coastguard Worker            #   maximum values for integer tensors.
942*da0073e9SAndroid Build Coastguard Worker            #   Expected behavior: out= values have no effect on the computation.
943*da0073e9SAndroid Build Coastguard Worker            def _case_zero_transform(t):
944*da0073e9SAndroid Build Coastguard Worker                try:
945*da0073e9SAndroid Build Coastguard Worker                    info = torch.iinfo(t.dtype)
946*da0073e9SAndroid Build Coastguard Worker                    return torch.full_like(t, info.max)
947*da0073e9SAndroid Build Coastguard Worker                except TypeError as te:
948*da0073e9SAndroid Build Coastguard Worker                    # for non-integer types fills with NaN
949*da0073e9SAndroid Build Coastguard Worker                    return torch.full_like(t, float("nan"))
950*da0073e9SAndroid Build Coastguard Worker
951*da0073e9SAndroid Build Coastguard Worker            _compare_out(_case_zero_transform)
952*da0073e9SAndroid Build Coastguard Worker
953*da0073e9SAndroid Build Coastguard Worker            # Case 1: out= with the correct shape, dtype, and device,
954*da0073e9SAndroid Build Coastguard Worker            #   but noncontiguous.
955*da0073e9SAndroid Build Coastguard Worker            #   Expected behavior: strides are respected and `out` storage is not changed.
956*da0073e9SAndroid Build Coastguard Worker            def _case_one_transform(t):
957*da0073e9SAndroid Build Coastguard Worker                return make_tensor(
958*da0073e9SAndroid Build Coastguard Worker                    t.shape, dtype=t.dtype, device=t.device, noncontiguous=True
959*da0073e9SAndroid Build Coastguard Worker                )
960*da0073e9SAndroid Build Coastguard Worker
961*da0073e9SAndroid Build Coastguard Worker            _compare_out(_case_one_transform)
962*da0073e9SAndroid Build Coastguard Worker
963*da0073e9SAndroid Build Coastguard Worker            # Case 2: out= with the correct dtype and device, but has no elements.
964*da0073e9SAndroid Build Coastguard Worker            #   Expected behavior: resize without warning.
965*da0073e9SAndroid Build Coastguard Worker            def _case_two_transform(t):
966*da0073e9SAndroid Build Coastguard Worker                return make_tensor((0,), dtype=t.dtype, device=t.device)
967*da0073e9SAndroid Build Coastguard Worker
968*da0073e9SAndroid Build Coastguard Worker            _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False)
969*da0073e9SAndroid Build Coastguard Worker
970*da0073e9SAndroid Build Coastguard Worker            # Also validates that no warning is thrown when this out is resized
971*da0073e9SAndroid Build Coastguard Worker            out = _apply_out_transform(_case_two_transform, expected)
972*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as caught:
973*da0073e9SAndroid Build Coastguard Worker                warnings.simplefilter("always")
974*da0073e9SAndroid Build Coastguard Worker                op_out(out=out)
975*da0073e9SAndroid Build Coastguard Worker
976*da0073e9SAndroid Build Coastguard Worker            # Verifies no warning is a resize warning
977*da0073e9SAndroid Build Coastguard Worker            for w in caught:
978*da0073e9SAndroid Build Coastguard Worker                if "An output with one or more elements" in str(w.message):
979*da0073e9SAndroid Build Coastguard Worker                    self.fail(
980*da0073e9SAndroid Build Coastguard Worker                        "Resizing an out= argument with no elements threw a resize warning!"
981*da0073e9SAndroid Build Coastguard Worker                    )
982*da0073e9SAndroid Build Coastguard Worker
983*da0073e9SAndroid Build Coastguard Worker            # Case 3: out= with correct shape and dtype, but wrong device.
984*da0073e9SAndroid Build Coastguard Worker            wrong_device = None
985*da0073e9SAndroid Build Coastguard Worker            if torch.device(device).type != "cpu":
986*da0073e9SAndroid Build Coastguard Worker                wrong_device = "cpu"
987*da0073e9SAndroid Build Coastguard Worker            elif torch.cuda.is_available():
988*da0073e9SAndroid Build Coastguard Worker                wrong_device = "cuda"
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker            factory_fn_msg = (
991*da0073e9SAndroid Build Coastguard Worker                "\n\nNOTE: If your op is a factory function (i.e., it accepts TensorOptions) you should mark its "
992*da0073e9SAndroid Build Coastguard Worker                "OpInfo with `is_factory_function=True`."
993*da0073e9SAndroid Build Coastguard Worker            )
994*da0073e9SAndroid Build Coastguard Worker            if wrong_device is not None:
995*da0073e9SAndroid Build Coastguard Worker
996*da0073e9SAndroid Build Coastguard Worker                def _case_three_transform(t):
997*da0073e9SAndroid Build Coastguard Worker                    return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)
998*da0073e9SAndroid Build Coastguard Worker
999*da0073e9SAndroid Build Coastguard Worker                out = _apply_out_transform(_case_three_transform, expected)
1000*da0073e9SAndroid Build Coastguard Worker
1001*da0073e9SAndroid Build Coastguard Worker                if op.is_factory_function and sample.kwargs.get("device", None) is None:
1002*da0073e9SAndroid Build Coastguard Worker                    op_out(out=out)
1003*da0073e9SAndroid Build Coastguard Worker                else:
1004*da0073e9SAndroid Build Coastguard Worker                    msg_fail = (
1005*da0073e9SAndroid Build Coastguard Worker                        f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}."
1006*da0073e9SAndroid Build Coastguard Worker                    ) + factory_fn_msg
1007*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(RuntimeError, msg=msg_fail):
1008*da0073e9SAndroid Build Coastguard Worker                        op_out(out=out)
1009*da0073e9SAndroid Build Coastguard Worker
1010*da0073e9SAndroid Build Coastguard Worker            # Case 4: out= with correct shape and device, but a dtype
1011*da0073e9SAndroid Build Coastguard Worker            #   that output cannot be "safely" cast to (long).
1012*da0073e9SAndroid Build Coastguard Worker            #   Expected behavior: error.
1013*da0073e9SAndroid Build Coastguard Worker            # NOTE: this case is filtered by dtype since some ops produce
1014*da0073e9SAndroid Build Coastguard Worker            #   bool tensors, for example, which can be safely cast to any
1015*da0073e9SAndroid Build Coastguard Worker            #   dtype. It is applied when single tensors are floating point or complex
1016*da0073e9SAndroid Build Coastguard Worker            #   dtypes, or if an op returns multiple tensors when at least one such
1017*da0073e9SAndroid Build Coastguard Worker            #   tensor is a floating point or complex dtype.
1018*da0073e9SAndroid Build Coastguard Worker            _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
1019*da0073e9SAndroid Build Coastguard Worker            if (
1020*da0073e9SAndroid Build Coastguard Worker                isinstance(expected, torch.Tensor)
1021*da0073e9SAndroid Build Coastguard Worker                and expected.dtype in _dtypes
1022*da0073e9SAndroid Build Coastguard Worker                or (
1023*da0073e9SAndroid Build Coastguard Worker                    not isinstance(expected, torch.Tensor)
1024*da0073e9SAndroid Build Coastguard Worker                    and any(t.dtype in _dtypes for t in expected)
1025*da0073e9SAndroid Build Coastguard Worker                )
1026*da0073e9SAndroid Build Coastguard Worker            ):
1027*da0073e9SAndroid Build Coastguard Worker
1028*da0073e9SAndroid Build Coastguard Worker                def _case_four_transform(t):
1029*da0073e9SAndroid Build Coastguard Worker                    return make_tensor(t.shape, dtype=torch.long, device=t.device)
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker                out = _apply_out_transform(_case_four_transform, expected)
1032*da0073e9SAndroid Build Coastguard Worker                msg_fail = "Expected RuntimeError when doing an unsafe cast!"
1033*da0073e9SAndroid Build Coastguard Worker                msg_fail = (
1034*da0073e9SAndroid Build Coastguard Worker                    msg_fail
1035*da0073e9SAndroid Build Coastguard Worker                    if not isinstance(expected, torch.Tensor)
1036*da0073e9SAndroid Build Coastguard Worker                    else (
1037*da0073e9SAndroid Build Coastguard Worker                        "Expected RuntimeError when doing an unsafe cast from a result of dtype "
1038*da0073e9SAndroid Build Coastguard Worker                        f"{expected.dtype} into an out= with dtype torch.long"
1039*da0073e9SAndroid Build Coastguard Worker                    )
1040*da0073e9SAndroid Build Coastguard Worker                ) + factory_fn_msg
1041*da0073e9SAndroid Build Coastguard Worker
1042*da0073e9SAndroid Build Coastguard Worker                if op.is_factory_function and sample.kwargs.get("dtype", None) is None:
1043*da0073e9SAndroid Build Coastguard Worker                    op_out(out=out)
1044*da0073e9SAndroid Build Coastguard Worker                else:
1045*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(RuntimeError, msg=msg_fail):
1046*da0073e9SAndroid Build Coastguard Worker                        op_out(out=out)
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker    @ops(
1049*da0073e9SAndroid Build Coastguard Worker        [
1050*da0073e9SAndroid Build Coastguard Worker            op
1051*da0073e9SAndroid Build Coastguard Worker            for op in op_db
1052*da0073e9SAndroid Build Coastguard Worker            if op.supports_out and (op.supports_autograd or op.is_factory_function)
1053*da0073e9SAndroid Build Coastguard Worker        ],
1054*da0073e9SAndroid Build Coastguard Worker        dtypes=OpDTypes.supported,
1055*da0073e9SAndroid Build Coastguard Worker        allowed_dtypes=[torch.float, torch.cfloat],
1056*da0073e9SAndroid Build Coastguard Worker    )
1057*da0073e9SAndroid Build Coastguard Worker    def test_out_requires_grad_error(self, device, dtype, op):
1058*da0073e9SAndroid Build Coastguard Worker        sample = first_sample(self, op.sample_inputs(device, dtype))
1059*da0073e9SAndroid Build Coastguard Worker
1060*da0073e9SAndroid Build Coastguard Worker        # Call op to get prototype for out arguments
1061*da0073e9SAndroid Build Coastguard Worker        expect = op(sample.input, *sample.args, **sample.kwargs)
1062*da0073e9SAndroid Build Coastguard Worker        any_requires_grad = False
1063*da0073e9SAndroid Build Coastguard Worker
1064*da0073e9SAndroid Build Coastguard Worker        def set_requires_grad(x):
1065*da0073e9SAndroid Build Coastguard Worker            nonlocal any_requires_grad
1066*da0073e9SAndroid Build Coastguard Worker            if isinstance(x, torch.Tensor) and (
1067*da0073e9SAndroid Build Coastguard Worker                x.is_floating_point() or x.is_complex()
1068*da0073e9SAndroid Build Coastguard Worker            ):
1069*da0073e9SAndroid Build Coastguard Worker                any_requires_grad = True
1070*da0073e9SAndroid Build Coastguard Worker                x.requires_grad_(True)
1071*da0073e9SAndroid Build Coastguard Worker            return x
1072*da0073e9SAndroid Build Coastguard Worker
1073*da0073e9SAndroid Build Coastguard Worker        out = pytree.tree_map_(set_requires_grad, expect)
1074*da0073e9SAndroid Build Coastguard Worker        if not any_requires_grad:
1075*da0073e9SAndroid Build Coastguard Worker            # Skip ops without any floating point outputs, e.g. isnan
1076*da0073e9SAndroid Build Coastguard Worker            return
1077*da0073e9SAndroid Build Coastguard Worker
1078*da0073e9SAndroid Build Coastguard Worker        msg = (
1079*da0073e9SAndroid Build Coastguard Worker            "functions with out=... arguments don't support automatic "
1080*da0073e9SAndroid Build Coastguard Worker            "differentiation, but one of the arguments requires grad."
1081*da0073e9SAndroid Build Coastguard Worker        )
1082*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError, msg=msg):
1083*da0073e9SAndroid Build Coastguard Worker            op(sample.input, *sample.args, **sample.kwargs, out=out)
1084*da0073e9SAndroid Build Coastguard Worker
1085*da0073e9SAndroid Build Coastguard Worker    @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,))
1086*da0073e9SAndroid Build Coastguard Worker    def test_out_integral_dtype(self, device, dtype, op):
1087*da0073e9SAndroid Build Coastguard Worker        def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs):
1088*da0073e9SAndroid Build Coastguard Worker            out = None
1089*da0073e9SAndroid Build Coastguard Worker            try:
1090*da0073e9SAndroid Build Coastguard Worker                if with_out:
1091*da0073e9SAndroid Build Coastguard Worker                    out = torch.empty(0, dtype=torch.int32, device=device)
1092*da0073e9SAndroid Build Coastguard Worker                    op_to_test(inputs, *args, out=out, **kwargs)
1093*da0073e9SAndroid Build Coastguard Worker                else:
1094*da0073e9SAndroid Build Coastguard Worker                    out = op_to_test(inputs, *args, **kwargs)
1095*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(expectFail)
1096*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as err:
1097*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1098*da0073e9SAndroid Build Coastguard Worker                    str(err), "dtype argument and out dtype must match in reduction"
1099*da0073e9SAndroid Build Coastguard Worker                )
1100*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(expectFail)
1101*da0073e9SAndroid Build Coastguard Worker            return out
1102*da0073e9SAndroid Build Coastguard Worker
1103*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype)
1104*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
1105*da0073e9SAndroid Build Coastguard Worker            if "dtype" not in sample.kwargs:
1106*da0073e9SAndroid Build Coastguard Worker                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
1107*da0073e9SAndroid Build Coastguard Worker                helper(True, False, op, sample.input, *sample.args, **sample.kwargs)
1108*da0073e9SAndroid Build Coastguard Worker                sample.kwargs["dtype"] = torch.int16
1109*da0073e9SAndroid Build Coastguard Worker                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
1110*da0073e9SAndroid Build Coastguard Worker                helper(True, True, op, sample.input, *sample.args, **sample.kwargs)
1111*da0073e9SAndroid Build Coastguard Worker                sample.kwargs["dtype"] = torch.int32
1112*da0073e9SAndroid Build Coastguard Worker                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
1113*da0073e9SAndroid Build Coastguard Worker                helper(True, False, op, sample.input, *sample.args, **sample.kwargs)
1114*da0073e9SAndroid Build Coastguard Worker            else:
1115*da0073e9SAndroid Build Coastguard Worker                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
1116*da0073e9SAndroid Build Coastguard Worker                helper(
1117*da0073e9SAndroid Build Coastguard Worker                    True,
1118*da0073e9SAndroid Build Coastguard Worker                    sample.kwargs["dtype"] != torch.int32,
1119*da0073e9SAndroid Build Coastguard Worker                    op,
1120*da0073e9SAndroid Build Coastguard Worker                    sample.input,
1121*da0073e9SAndroid Build Coastguard Worker                    *sample.args,
1122*da0073e9SAndroid Build Coastguard Worker                    **sample.kwargs,
1123*da0073e9SAndroid Build Coastguard Worker                )
1124*da0073e9SAndroid Build Coastguard Worker
1125*da0073e9SAndroid Build Coastguard Worker    # Tests that the forward and backward passes of operations produce the
1126*da0073e9SAndroid Build Coastguard Worker    #   same values for the cross-product of op variants (method, inplace)
1127*da0073e9SAndroid Build Coastguard Worker    #   against eager's gold standard op function variant
1128*da0073e9SAndroid Build Coastguard Worker    @_variant_ops(op_db)
1129*da0073e9SAndroid Build Coastguard Worker    def test_variant_consistency_eager(self, device, dtype, op):
1130*da0073e9SAndroid Build Coastguard Worker        # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases)
1131*da0073e9SAndroid Build Coastguard Worker
1132*da0073e9SAndroid Build Coastguard Worker        method = op.method_variant
1133*da0073e9SAndroid Build Coastguard Worker        inplace = op.inplace_variant
1134*da0073e9SAndroid Build Coastguard Worker        operator = op.operator_variant
1135*da0073e9SAndroid Build Coastguard Worker        inplace_operator = op.inplace_operator_variant
1136*da0073e9SAndroid Build Coastguard Worker
1137*da0073e9SAndroid Build Coastguard Worker        # list of all inplace ops: inplace variant + alias inplace variants if exist
1138*da0073e9SAndroid Build Coastguard Worker        inplace_ops = [inplace, inplace_operator]
1139*da0073e9SAndroid Build Coastguard Worker        variants = [method, inplace, operator, inplace_operator]
1140*da0073e9SAndroid Build Coastguard Worker        operators = [operator, inplace_operator]
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker        for a_op in op.aliases:
1143*da0073e9SAndroid Build Coastguard Worker            variants.append(a_op.op)
1144*da0073e9SAndroid Build Coastguard Worker            variants.append(a_op.method_variant)
1145*da0073e9SAndroid Build Coastguard Worker            variants.append(a_op.inplace_variant)
1146*da0073e9SAndroid Build Coastguard Worker            inplace_ops.append(a_op.inplace_variant)
1147*da0073e9SAndroid Build Coastguard Worker
1148*da0073e9SAndroid Build Coastguard Worker        inplace_variants = tuple(filter(None, inplace_ops))
1149*da0073e9SAndroid Build Coastguard Worker        variants = tuple(filter(None, variants))
1150*da0073e9SAndroid Build Coastguard Worker        operators = tuple(filter(None, operators))
1151*da0073e9SAndroid Build Coastguard Worker
1152*da0073e9SAndroid Build Coastguard Worker        _requires_grad = dtype in op.supported_backward_dtypes(
1153*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type
1154*da0073e9SAndroid Build Coastguard Worker        )
1155*da0073e9SAndroid Build Coastguard Worker
1156*da0073e9SAndroid Build Coastguard Worker        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
1157*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(
1158*da0073e9SAndroid Build Coastguard Worker            device,
1159*da0073e9SAndroid Build Coastguard Worker            dtype,
1160*da0073e9SAndroid Build Coastguard Worker            requires_grad=_requires_grad,
1161*da0073e9SAndroid Build Coastguard Worker            include_conjugated_inputs=include_conjugated_inputs,
1162*da0073e9SAndroid Build Coastguard Worker        )
1163*da0073e9SAndroid Build Coastguard Worker        samples = list(samples)
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Worker        def _test_consistency_helper(samples, variants):
1166*da0073e9SAndroid Build Coastguard Worker            for sample in samples:
1167*da0073e9SAndroid Build Coastguard Worker                # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
1168*da0073e9SAndroid Build Coastguard Worker                tensor = (
1169*da0073e9SAndroid Build Coastguard Worker                    sample.input
1170*da0073e9SAndroid Build Coastguard Worker                    if isinstance(sample.input, torch.Tensor)
1171*da0073e9SAndroid Build Coastguard Worker                    else sample.input[0]
1172*da0073e9SAndroid Build Coastguard Worker                )
1173*da0073e9SAndroid Build Coastguard Worker
1174*da0073e9SAndroid Build Coastguard Worker                # Computes function forward and backward values
1175*da0073e9SAndroid Build Coastguard Worker                tensor.grad = None
1176*da0073e9SAndroid Build Coastguard Worker                expected_forward = op(sample.input, *sample.args, **sample.kwargs)
1177*da0073e9SAndroid Build Coastguard Worker                expected_grad = None
1178*da0073e9SAndroid Build Coastguard Worker
1179*da0073e9SAndroid Build Coastguard Worker                output_process_fn_grad = (
1180*da0073e9SAndroid Build Coastguard Worker                    sample.output_process_fn_grad
1181*da0073e9SAndroid Build Coastguard Worker                    if sample.output_process_fn_grad
1182*da0073e9SAndroid Build Coastguard Worker                    else lambda x: x
1183*da0073e9SAndroid Build Coastguard Worker                )
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker                # Skips inplace variants if the output dtype is not the same as
1186*da0073e9SAndroid Build Coastguard Worker                #   the input dtype
1187*da0073e9SAndroid Build Coastguard Worker                skip_inplace = False
1188*da0073e9SAndroid Build Coastguard Worker                if (
1189*da0073e9SAndroid Build Coastguard Worker                    isinstance(expected_forward, torch.Tensor)
1190*da0073e9SAndroid Build Coastguard Worker                    and expected_forward.dtype is not tensor.dtype
1191*da0073e9SAndroid Build Coastguard Worker                ):
1192*da0073e9SAndroid Build Coastguard Worker                    skip_inplace = True
1193*da0073e9SAndroid Build Coastguard Worker
1194*da0073e9SAndroid Build Coastguard Worker                # TODO: backward consistency only supported for single tensor outputs
1195*da0073e9SAndroid Build Coastguard Worker                # TODO: backward consistency only checked on sample.input, not all
1196*da0073e9SAndroid Build Coastguard Worker                #   tensor inputs
1197*da0073e9SAndroid Build Coastguard Worker                # TODO: update to handle checking grads of all tensor inputs as
1198*da0073e9SAndroid Build Coastguard Worker                #   derived from each tensor output
1199*da0073e9SAndroid Build Coastguard Worker                if isinstance(
1200*da0073e9SAndroid Build Coastguard Worker                    expected_forward, torch.Tensor
1201*da0073e9SAndroid Build Coastguard Worker                ) and dtype in op.supported_backward_dtypes(torch.device(device).type):
1202*da0073e9SAndroid Build Coastguard Worker                    out = output_process_fn_grad(expected_forward).sum()
1203*da0073e9SAndroid Build Coastguard Worker                    if out.dtype.is_complex:
1204*da0073e9SAndroid Build Coastguard Worker                        out = out.abs()
1205*da0073e9SAndroid Build Coastguard Worker                    out.backward()
1206*da0073e9SAndroid Build Coastguard Worker                    expected_grad = tensor.grad
1207*da0073e9SAndroid Build Coastguard Worker
1208*da0073e9SAndroid Build Coastguard Worker                # Test eager consistency
1209*da0073e9SAndroid Build Coastguard Worker                for variant in variants:
1210*da0073e9SAndroid Build Coastguard Worker                    # Skips inplace ops
1211*da0073e9SAndroid Build Coastguard Worker                    if variant in inplace_ops and skip_inplace:
1212*da0073e9SAndroid Build Coastguard Worker                        continue
1213*da0073e9SAndroid Build Coastguard Worker
1214*da0073e9SAndroid Build Coastguard Worker                    # Compares variant's forward
1215*da0073e9SAndroid Build Coastguard Worker                    # Note: copies the to-be-modified input when testing the inplace variant
1216*da0073e9SAndroid Build Coastguard Worker                    tensor.grad = None
1217*da0073e9SAndroid Build Coastguard Worker                    cloned = (
1218*da0073e9SAndroid Build Coastguard Worker                        clone_input_helper(sample.input)
1219*da0073e9SAndroid Build Coastguard Worker                        if variant in inplace_ops
1220*da0073e9SAndroid Build Coastguard Worker                        else sample.input
1221*da0073e9SAndroid Build Coastguard Worker                    )
1222*da0073e9SAndroid Build Coastguard Worker
1223*da0073e9SAndroid Build Coastguard Worker                    if variant in inplace_ops and sample.broadcasts_input:
1224*da0073e9SAndroid Build Coastguard Worker                        with self.assertRaises(
1225*da0073e9SAndroid Build Coastguard Worker                            RuntimeError,
1226*da0073e9SAndroid Build Coastguard Worker                            msg=(
1227*da0073e9SAndroid Build Coastguard Worker                                "inplace variant either incorrectly allowed "
1228*da0073e9SAndroid Build Coastguard Worker                                f"resizing or you have marked the sample {sample.summary()}"
1229*da0073e9SAndroid Build Coastguard Worker                                " incorrectly with `broadcasts_self=True"
1230*da0073e9SAndroid Build Coastguard Worker                            ),
1231*da0073e9SAndroid Build Coastguard Worker                        ):
1232*da0073e9SAndroid Build Coastguard Worker                            variant_forward = variant(
1233*da0073e9SAndroid Build Coastguard Worker                                cloned, *sample.args, **sample.kwargs
1234*da0073e9SAndroid Build Coastguard Worker                            )
1235*da0073e9SAndroid Build Coastguard Worker                        continue
1236*da0073e9SAndroid Build Coastguard Worker
1237*da0073e9SAndroid Build Coastguard Worker                    if variant in operators and sample.kwargs:
1238*da0073e9SAndroid Build Coastguard Worker                        # skip samples with kwargs for operator variants
1239*da0073e9SAndroid Build Coastguard Worker                        continue
1240*da0073e9SAndroid Build Coastguard Worker
1241*da0073e9SAndroid Build Coastguard Worker                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
1242*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expected_forward, variant_forward)
1243*da0073e9SAndroid Build Coastguard Worker
1244*da0073e9SAndroid Build Coastguard Worker                    # Compares variant's backward
1245*da0073e9SAndroid Build Coastguard Worker                    if expected_grad is not None and (
1246*da0073e9SAndroid Build Coastguard Worker                        variant not in inplace_ops or op.supports_inplace_autograd
1247*da0073e9SAndroid Build Coastguard Worker                    ):
1248*da0073e9SAndroid Build Coastguard Worker                        out = output_process_fn_grad(variant_forward).sum()
1249*da0073e9SAndroid Build Coastguard Worker                        if out.dtype.is_complex:
1250*da0073e9SAndroid Build Coastguard Worker                            out = out.abs()
1251*da0073e9SAndroid Build Coastguard Worker                        out.backward()
1252*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(expected_grad, tensor.grad)
1253*da0073e9SAndroid Build Coastguard Worker
1254*da0073e9SAndroid Build Coastguard Worker        _test_consistency_helper(samples, variants)
1255*da0073e9SAndroid Build Coastguard Worker
1256*da0073e9SAndroid Build Coastguard Worker        def _test_inplace_preserve_storage(samples, variants):
1257*da0073e9SAndroid Build Coastguard Worker            for sample in samples:
1258*da0073e9SAndroid Build Coastguard Worker                # Skips inplace variants if the output dtype is not the same as
1259*da0073e9SAndroid Build Coastguard Worker                #   the input dtype
1260*da0073e9SAndroid Build Coastguard Worker                expected_forward = op(sample.input, *sample.args, **sample.kwargs)
1261*da0073e9SAndroid Build Coastguard Worker                tensor = (
1262*da0073e9SAndroid Build Coastguard Worker                    sample.input
1263*da0073e9SAndroid Build Coastguard Worker                    if isinstance(sample.input, torch.Tensor)
1264*da0073e9SAndroid Build Coastguard Worker                    else sample.input[0]
1265*da0073e9SAndroid Build Coastguard Worker                )
1266*da0073e9SAndroid Build Coastguard Worker                skip_inplace = False
1267*da0073e9SAndroid Build Coastguard Worker                if (
1268*da0073e9SAndroid Build Coastguard Worker                    isinstance(expected_forward, torch.Tensor)
1269*da0073e9SAndroid Build Coastguard Worker                    and expected_forward.dtype is not tensor.dtype
1270*da0073e9SAndroid Build Coastguard Worker                ):
1271*da0073e9SAndroid Build Coastguard Worker                    skip_inplace = True
1272*da0073e9SAndroid Build Coastguard Worker                if skip_inplace:
1273*da0073e9SAndroid Build Coastguard Worker                    return
1274*da0073e9SAndroid Build Coastguard Worker                for variant in variants:
1275*da0073e9SAndroid Build Coastguard Worker                    cloned = (
1276*da0073e9SAndroid Build Coastguard Worker                        clone_input_helper(sample.input)
1277*da0073e9SAndroid Build Coastguard Worker                        if variant in inplace_ops
1278*da0073e9SAndroid Build Coastguard Worker                        else sample.input
1279*da0073e9SAndroid Build Coastguard Worker                    )
1280*da0073e9SAndroid Build Coastguard Worker                    inp_tensor = (
1281*da0073e9SAndroid Build Coastguard Worker                        cloned if isinstance(cloned, torch.Tensor) else cloned[0]
1282*da0073e9SAndroid Build Coastguard Worker                    )
1283*da0073e9SAndroid Build Coastguard Worker                    data_ptr = inp_tensor.data_ptr()
1284*da0073e9SAndroid Build Coastguard Worker                    if variant in operators and sample.kwargs:
1285*da0073e9SAndroid Build Coastguard Worker                        # skip samples with kwargs for operator variants
1286*da0073e9SAndroid Build Coastguard Worker                        continue
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
1289*da0073e9SAndroid Build Coastguard Worker                    # TODO Support non-tensor outputs if they exist for inplace ops
1290*da0073e9SAndroid Build Coastguard Worker                    if isinstance(variant_forward, torch.Tensor):
1291*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
1292*da0073e9SAndroid Build Coastguard Worker                            data_ptr, variant_forward.data_ptr(), atol=0, rtol=0
1293*da0073e9SAndroid Build Coastguard Worker                        )
1294*da0073e9SAndroid Build Coastguard Worker                    else:
1295*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(
1296*da0073e9SAndroid Build Coastguard Worker                            False,
1297*da0073e9SAndroid Build Coastguard Worker                            "Non-tensor outputs for inplace ops are not supported",
1298*da0073e9SAndroid Build Coastguard Worker                        )
1299*da0073e9SAndroid Build Coastguard Worker
1300*da0073e9SAndroid Build Coastguard Worker        if len(inplace_ops) > 0:
1301*da0073e9SAndroid Build Coastguard Worker            inplace_samples = list(
1302*da0073e9SAndroid Build Coastguard Worker                filter(lambda sample: not sample.broadcasts_input, samples)
1303*da0073e9SAndroid Build Coastguard Worker            )
1304*da0073e9SAndroid Build Coastguard Worker            _test_inplace_preserve_storage(inplace_samples, inplace_variants)
1305*da0073e9SAndroid Build Coastguard Worker
1306*da0073e9SAndroid Build Coastguard Worker    # Reference testing for operations in complex32 against complex64.
1307*da0073e9SAndroid Build Coastguard Worker    # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype.
1308*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, allowed_dtypes=(torch.complex32,))
1309*da0073e9SAndroid Build Coastguard Worker    def test_complex_half_reference_testing(self, device, dtype, op):
1310*da0073e9SAndroid Build Coastguard Worker        if not op.supports_dtype(torch.complex32, device):
1311*da0073e9SAndroid Build Coastguard Worker            unittest.skip("Does not support complex32")
1312*da0073e9SAndroid Build Coastguard Worker
1313*da0073e9SAndroid Build Coastguard Worker        for sample in op.sample_inputs(device, dtype):
1314*da0073e9SAndroid Build Coastguard Worker            actual = op(sample.input, *sample.args, **sample.kwargs)
1315*da0073e9SAndroid Build Coastguard Worker            # sample.transform applies the lambda to torch.Tensor and torch.dtype.
1316*da0073e9SAndroid Build Coastguard Worker            # However, we only want to apply it to Tensors with dtype `torch.complex32`..
1317*da0073e9SAndroid Build Coastguard Worker            transformed_sample = sample.transform(
1318*da0073e9SAndroid Build Coastguard Worker                lambda x: x.to(torch.complex64)
1319*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, torch.Tensor) and x.dtype is torch.complex32
1320*da0073e9SAndroid Build Coastguard Worker                else x
1321*da0073e9SAndroid Build Coastguard Worker            )
1322*da0073e9SAndroid Build Coastguard Worker            expected = op(
1323*da0073e9SAndroid Build Coastguard Worker                transformed_sample.input,
1324*da0073e9SAndroid Build Coastguard Worker                *transformed_sample.args,
1325*da0073e9SAndroid Build Coastguard Worker                **transformed_sample.kwargs,
1326*da0073e9SAndroid Build Coastguard Worker            )
1327*da0073e9SAndroid Build Coastguard Worker            # Since range of chalf is much less compared to cfloat,
1328*da0073e9SAndroid Build Coastguard Worker            # we get `inf`s easily (eg. with `pow`, `exp`),
1329*da0073e9SAndroid Build Coastguard Worker            # so we cast `cfloat` back to `chalf`.
1330*da0073e9SAndroid Build Coastguard Worker            expected = tree_map(
1331*da0073e9SAndroid Build Coastguard Worker                lambda x: x.to(torch.complex32)
1332*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, torch.Tensor) and x.dtype is torch.complex64
1333*da0073e9SAndroid Build Coastguard Worker                else x,
1334*da0073e9SAndroid Build Coastguard Worker                expected,
1335*da0073e9SAndroid Build Coastguard Worker            )
1336*da0073e9SAndroid Build Coastguard Worker
1337*da0073e9SAndroid Build Coastguard Worker            # `exact_dtype` is False because for ops like real, imag
1338*da0073e9SAndroid Build Coastguard Worker            # we get different dtypes for `actual` and `expected`
1339*da0073e9SAndroid Build Coastguard Worker            # `chalf` input -> `half` output
1340*da0073e9SAndroid Build Coastguard Worker            # `cfloat` input -> `float` output
1341*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, exact_dtype=False)
1342*da0073e9SAndroid Build Coastguard Worker
1343*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, allowed_dtypes=(torch.bool,))
1344*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior")
1345*da0073e9SAndroid Build Coastguard Worker    def test_non_standard_bool_values(self, device, dtype, op):
1346*da0073e9SAndroid Build Coastguard Worker        # Test boolean values other than 0x00 and 0x01 (gh-54789)
1347*da0073e9SAndroid Build Coastguard Worker        def convert_boolean_tensors(x):
1348*da0073e9SAndroid Build Coastguard Worker            if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
1349*da0073e9SAndroid Build Coastguard Worker                return x
1350*da0073e9SAndroid Build Coastguard Worker
1351*da0073e9SAndroid Build Coastguard Worker            # Map False -> 0 and True -> Random value in [2, 255]
1352*da0073e9SAndroid Build Coastguard Worker            true_vals = torch.randint(
1353*da0073e9SAndroid Build Coastguard Worker                2, 255, x.shape, dtype=torch.uint8, device=x.device
1354*da0073e9SAndroid Build Coastguard Worker            )
1355*da0073e9SAndroid Build Coastguard Worker            false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
1356*da0073e9SAndroid Build Coastguard Worker            x_int = torch.where(x, true_vals, false_vals)
1357*da0073e9SAndroid Build Coastguard Worker
1358*da0073e9SAndroid Build Coastguard Worker            ret = x_int.view(torch.bool)
1359*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ret, x)
1360*da0073e9SAndroid Build Coastguard Worker            return ret
1361*da0073e9SAndroid Build Coastguard Worker
1362*da0073e9SAndroid Build Coastguard Worker        for sample in op.sample_inputs(device, dtype):
1363*da0073e9SAndroid Build Coastguard Worker            expect = op(sample.input, *sample.args, **sample.kwargs)
1364*da0073e9SAndroid Build Coastguard Worker
1365*da0073e9SAndroid Build Coastguard Worker            transformed = sample.transform(convert_boolean_tensors)
1366*da0073e9SAndroid Build Coastguard Worker            actual = op(transformed.input, *transformed.args, **transformed.kwargs)
1367*da0073e9SAndroid Build Coastguard Worker
1368*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
1369*da0073e9SAndroid Build Coastguard Worker
1370*da0073e9SAndroid Build Coastguard Worker    # Validates that each OpInfo specifies its forward and backward dtypes
1371*da0073e9SAndroid Build Coastguard Worker    #   correctly for CPU and CUDA devices
1372*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1373*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypesAnd(["hpu"])
1374*da0073e9SAndroid Build Coastguard Worker    @ops(ops_and_refs, dtypes=OpDTypes.none)
1375*da0073e9SAndroid Build Coastguard Worker    def test_dtypes(self, device, op):
1376*da0073e9SAndroid Build Coastguard Worker        # Check complex32 support only if the op claims.
1377*da0073e9SAndroid Build Coastguard Worker        # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally.
1378*da0073e9SAndroid Build Coastguard Worker        device_type = torch.device(device).type
1379*da0073e9SAndroid Build Coastguard Worker        include_complex32 = (
1380*da0073e9SAndroid Build Coastguard Worker            (torch.complex32,)
1381*da0073e9SAndroid Build Coastguard Worker            if op.supports_dtype(torch.complex32, device_type)
1382*da0073e9SAndroid Build Coastguard Worker            else ()
1383*da0073e9SAndroid Build Coastguard Worker        )
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker        # dtypes to try to backward in
1386*da0073e9SAndroid Build Coastguard Worker        allowed_backward_dtypes = floating_and_complex_types_and(
1387*da0073e9SAndroid Build Coastguard Worker            *((torch.half, torch.bfloat16) + include_complex32)
1388*da0073e9SAndroid Build Coastguard Worker        )
1389*da0073e9SAndroid Build Coastguard Worker
1390*da0073e9SAndroid Build Coastguard Worker        # lists for (un)supported dtypes
1391*da0073e9SAndroid Build Coastguard Worker        supported_dtypes = set()
1392*da0073e9SAndroid Build Coastguard Worker        unsupported_dtypes = set()
1393*da0073e9SAndroid Build Coastguard Worker        supported_backward_dtypes = set()
1394*da0073e9SAndroid Build Coastguard Worker        unsupported_backward_dtypes = set()
1395*da0073e9SAndroid Build Coastguard Worker        dtype_error: Dict[torch.dtype, Exception] = {}
1396*da0073e9SAndroid Build Coastguard Worker
1397*da0073e9SAndroid Build Coastguard Worker        def unsupported(dtype, e):
1398*da0073e9SAndroid Build Coastguard Worker            dtype_error[dtype] = e
1399*da0073e9SAndroid Build Coastguard Worker            unsupported_dtypes.add(dtype)
1400*da0073e9SAndroid Build Coastguard Worker            if dtype in allowed_backward_dtypes:
1401*da0073e9SAndroid Build Coastguard Worker                unsupported_backward_dtypes.add(dtype)
1402*da0073e9SAndroid Build Coastguard Worker
1403*da0073e9SAndroid Build Coastguard Worker        for dtype in all_types_and_complex_and(
1404*da0073e9SAndroid Build Coastguard Worker            *((torch.half, torch.bfloat16, torch.bool) + include_complex32)
1405*da0073e9SAndroid Build Coastguard Worker        ):
1406*da0073e9SAndroid Build Coastguard Worker            # tries to acquire samples - failure indicates lack of support
1407*da0073e9SAndroid Build Coastguard Worker            requires_grad = dtype in allowed_backward_dtypes
1408*da0073e9SAndroid Build Coastguard Worker            try:
1409*da0073e9SAndroid Build Coastguard Worker                samples = tuple(
1410*da0073e9SAndroid Build Coastguard Worker                    op.sample_inputs(device, dtype, requires_grad=requires_grad)
1411*da0073e9SAndroid Build Coastguard Worker                )
1412*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
1413*da0073e9SAndroid Build Coastguard Worker                unsupported(dtype, e)
1414*da0073e9SAndroid Build Coastguard Worker                continue
1415*da0073e9SAndroid Build Coastguard Worker
1416*da0073e9SAndroid Build Coastguard Worker            for sample in samples:
1417*da0073e9SAndroid Build Coastguard Worker                # tries to call operator with the sample - failure indicates
1418*da0073e9SAndroid Build Coastguard Worker                #   lack of support
1419*da0073e9SAndroid Build Coastguard Worker                try:
1420*da0073e9SAndroid Build Coastguard Worker                    result = op(sample.input, *sample.args, **sample.kwargs)
1421*da0073e9SAndroid Build Coastguard Worker                    supported_dtypes.add(dtype)
1422*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
1423*da0073e9SAndroid Build Coastguard Worker                    # NOTE: some ops will fail in forward if their inputs
1424*da0073e9SAndroid Build Coastguard Worker                    #   require grad but they don't support computing the gradient
1425*da0073e9SAndroid Build Coastguard Worker                    #   in that type! This is a bug in the op!
1426*da0073e9SAndroid Build Coastguard Worker                    unsupported(dtype, e)
1427*da0073e9SAndroid Build Coastguard Worker                    continue
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker                # Checks for backward support in the same dtype, if the input has
1430*da0073e9SAndroid Build Coastguard Worker                # one or more tensors requiring grad
1431*da0073e9SAndroid Build Coastguard Worker                def _tensor_requires_grad(x):
1432*da0073e9SAndroid Build Coastguard Worker                    if isinstance(x, dict):
1433*da0073e9SAndroid Build Coastguard Worker                        for v in x.values():
1434*da0073e9SAndroid Build Coastguard Worker                            if _tensor_requires_grad(v):
1435*da0073e9SAndroid Build Coastguard Worker                                return True
1436*da0073e9SAndroid Build Coastguard Worker                    if isinstance(x, (list, tuple)):
1437*da0073e9SAndroid Build Coastguard Worker                        for a in x:
1438*da0073e9SAndroid Build Coastguard Worker                            if _tensor_requires_grad(a):
1439*da0073e9SAndroid Build Coastguard Worker                                return True
1440*da0073e9SAndroid Build Coastguard Worker                    if isinstance(x, torch.Tensor) and x.requires_grad:
1441*da0073e9SAndroid Build Coastguard Worker                        return True
1442*da0073e9SAndroid Build Coastguard Worker
1443*da0073e9SAndroid Build Coastguard Worker                    return False
1444*da0073e9SAndroid Build Coastguard Worker
1445*da0073e9SAndroid Build Coastguard Worker                requires_grad = (
1446*da0073e9SAndroid Build Coastguard Worker                    _tensor_requires_grad(sample.input)
1447*da0073e9SAndroid Build Coastguard Worker                    or _tensor_requires_grad(sample.args)
1448*da0073e9SAndroid Build Coastguard Worker                    or _tensor_requires_grad(sample.kwargs)
1449*da0073e9SAndroid Build Coastguard Worker                )
1450*da0073e9SAndroid Build Coastguard Worker                if not requires_grad:
1451*da0073e9SAndroid Build Coastguard Worker                    continue
1452*da0073e9SAndroid Build Coastguard Worker
1453*da0073e9SAndroid Build Coastguard Worker                try:
1454*da0073e9SAndroid Build Coastguard Worker                    result = sample.output_process_fn_grad(result)
1455*da0073e9SAndroid Build Coastguard Worker                    if isinstance(result, torch.Tensor):
1456*da0073e9SAndroid Build Coastguard Worker                        backward_tensor = result
1457*da0073e9SAndroid Build Coastguard Worker                    elif isinstance(result, Sequence) and isinstance(
1458*da0073e9SAndroid Build Coastguard Worker                        result[0], torch.Tensor
1459*da0073e9SAndroid Build Coastguard Worker                    ):
1460*da0073e9SAndroid Build Coastguard Worker                        backward_tensor = result[0]
1461*da0073e9SAndroid Build Coastguard Worker                    else:
1462*da0073e9SAndroid Build Coastguard Worker                        continue
1463*da0073e9SAndroid Build Coastguard Worker
1464*da0073e9SAndroid Build Coastguard Worker                    # Note: this grad may not have the same dtype as dtype
1465*da0073e9SAndroid Build Coastguard Worker                    # For functions like complex (float -> complex) or abs
1466*da0073e9SAndroid Build Coastguard Worker                    #   (complex -> float) the grad tensor will have a
1467*da0073e9SAndroid Build Coastguard Worker                    #   different dtype than the input.
1468*da0073e9SAndroid Build Coastguard Worker                    #   For simplicity, this is still modeled as these ops
1469*da0073e9SAndroid Build Coastguard Worker                    #   supporting grad in the input dtype.
1470*da0073e9SAndroid Build Coastguard Worker                    grad = torch.randn_like(backward_tensor)
1471*da0073e9SAndroid Build Coastguard Worker                    backward_tensor.backward(grad)
1472*da0073e9SAndroid Build Coastguard Worker                    supported_backward_dtypes.add(dtype)
1473*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
1474*da0073e9SAndroid Build Coastguard Worker                    dtype_error[dtype] = e
1475*da0073e9SAndroid Build Coastguard Worker                    unsupported_backward_dtypes.add(dtype)
1476*da0073e9SAndroid Build Coastguard Worker
1477*da0073e9SAndroid Build Coastguard Worker        # Checks that dtypes are listed correctly and generates an informative
1478*da0073e9SAndroid Build Coastguard Worker        #   error message
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker        supported_forward = supported_dtypes - unsupported_dtypes
1481*da0073e9SAndroid Build Coastguard Worker        partially_supported_forward = supported_dtypes & unsupported_dtypes
1482*da0073e9SAndroid Build Coastguard Worker        unsupported_forward = unsupported_dtypes - supported_dtypes
1483*da0073e9SAndroid Build Coastguard Worker        supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
1484*da0073e9SAndroid Build Coastguard Worker        partially_supported_backward = (
1485*da0073e9SAndroid Build Coastguard Worker            supported_backward_dtypes & unsupported_backward_dtypes
1486*da0073e9SAndroid Build Coastguard Worker        )
1487*da0073e9SAndroid Build Coastguard Worker        unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes
1488*da0073e9SAndroid Build Coastguard Worker
1489*da0073e9SAndroid Build Coastguard Worker        device_type = torch.device(device).type
1490*da0073e9SAndroid Build Coastguard Worker
1491*da0073e9SAndroid Build Coastguard Worker        claimed_forward = set(op.supported_dtypes(device_type))
1492*da0073e9SAndroid Build Coastguard Worker        supported_but_unclaimed_forward = supported_forward - claimed_forward
1493*da0073e9SAndroid Build Coastguard Worker        claimed_but_unsupported_forward = claimed_forward & unsupported_forward
1494*da0073e9SAndroid Build Coastguard Worker
1495*da0073e9SAndroid Build Coastguard Worker        claimed_backward = set(op.supported_backward_dtypes(device_type))
1496*da0073e9SAndroid Build Coastguard Worker        supported_but_unclaimed_backward = supported_backward - claimed_backward
1497*da0073e9SAndroid Build Coastguard Worker        claimed_but_unsupported_backward = claimed_backward & unsupported_backward
1498*da0073e9SAndroid Build Coastguard Worker
1499*da0073e9SAndroid Build Coastguard Worker        # Partially supporting a dtype is not an error, but we print a warning
1500*da0073e9SAndroid Build Coastguard Worker        if (len(partially_supported_forward) + len(partially_supported_backward)) > 0:
1501*da0073e9SAndroid Build Coastguard Worker            msg = f"Some dtypes for {op.name} on device type {device_type} are only partially supported!\n"
1502*da0073e9SAndroid Build Coastguard Worker            if len(partially_supported_forward) > 0:
1503*da0073e9SAndroid Build Coastguard Worker                msg = (
1504*da0073e9SAndroid Build Coastguard Worker                    msg
1505*da0073e9SAndroid Build Coastguard Worker                    + f"The following dtypes only worked on some samples during forward: {partially_supported_forward}.\n"
1506*da0073e9SAndroid Build Coastguard Worker                )
1507*da0073e9SAndroid Build Coastguard Worker            if len(partially_supported_backward) > 0:
1508*da0073e9SAndroid Build Coastguard Worker                msg = (
1509*da0073e9SAndroid Build Coastguard Worker                    msg
1510*da0073e9SAndroid Build Coastguard Worker                    + f"The following dtypes only worked on some samples during backward: {partially_supported_backward}.\n"
1511*da0073e9SAndroid Build Coastguard Worker                )
1512*da0073e9SAndroid Build Coastguard Worker            print(msg)
1513*da0073e9SAndroid Build Coastguard Worker
1514*da0073e9SAndroid Build Coastguard Worker        if (
1515*da0073e9SAndroid Build Coastguard Worker            len(supported_but_unclaimed_forward)
1516*da0073e9SAndroid Build Coastguard Worker            + len(claimed_but_unsupported_forward)
1517*da0073e9SAndroid Build Coastguard Worker            + len(supported_but_unclaimed_backward)
1518*da0073e9SAndroid Build Coastguard Worker            + len(claimed_but_unsupported_backward)
1519*da0073e9SAndroid Build Coastguard Worker        ) == 0:
1520*da0073e9SAndroid Build Coastguard Worker            return
1521*da0073e9SAndroid Build Coastguard Worker
1522*da0073e9SAndroid Build Coastguard Worker        # Reference operators often support additional dtypes, and that's OK
1523*da0073e9SAndroid Build Coastguard Worker        if op in python_ref_db:
1524*da0073e9SAndroid Build Coastguard Worker            if (
1525*da0073e9SAndroid Build Coastguard Worker                len(claimed_but_unsupported_forward)
1526*da0073e9SAndroid Build Coastguard Worker                + len(claimed_but_unsupported_backward)
1527*da0073e9SAndroid Build Coastguard Worker            ) == 0:
1528*da0073e9SAndroid Build Coastguard Worker                return
1529*da0073e9SAndroid Build Coastguard Worker
1530*da0073e9SAndroid Build Coastguard Worker        # Generates error msg
1531*da0073e9SAndroid Build Coastguard Worker        msg = f"The supported dtypes for {op.name} on device type {device_type} are incorrect!\n"
1532*da0073e9SAndroid Build Coastguard Worker        if len(supported_but_unclaimed_forward) > 0:
1533*da0073e9SAndroid Build Coastguard Worker            msg = (
1534*da0073e9SAndroid Build Coastguard Worker                msg
1535*da0073e9SAndroid Build Coastguard Worker                + "The following dtypes worked in forward but are not listed by the OpInfo: "
1536*da0073e9SAndroid Build Coastguard Worker                + f"{supported_but_unclaimed_forward}.\n"
1537*da0073e9SAndroid Build Coastguard Worker            )
1538*da0073e9SAndroid Build Coastguard Worker        if len(supported_but_unclaimed_backward) > 0:
1539*da0073e9SAndroid Build Coastguard Worker            msg = (
1540*da0073e9SAndroid Build Coastguard Worker                msg
1541*da0073e9SAndroid Build Coastguard Worker                + "The following dtypes worked in backward but are not listed by the OpInfo: "
1542*da0073e9SAndroid Build Coastguard Worker                + f"{supported_but_unclaimed_backward}.\n"
1543*da0073e9SAndroid Build Coastguard Worker            )
1544*da0073e9SAndroid Build Coastguard Worker        if len(claimed_but_unsupported_forward) > 0:
1545*da0073e9SAndroid Build Coastguard Worker            msg = (
1546*da0073e9SAndroid Build Coastguard Worker                msg
1547*da0073e9SAndroid Build Coastguard Worker                + "The following dtypes did not work in forward but are listed by the OpInfo: "
1548*da0073e9SAndroid Build Coastguard Worker                + f"{claimed_but_unsupported_forward}.\n"
1549*da0073e9SAndroid Build Coastguard Worker            )
1550*da0073e9SAndroid Build Coastguard Worker        if len(claimed_but_unsupported_backward) > 0:
1551*da0073e9SAndroid Build Coastguard Worker            msg = (
1552*da0073e9SAndroid Build Coastguard Worker                msg
1553*da0073e9SAndroid Build Coastguard Worker                + "The following dtypes did not work in backward "
1554*da0073e9SAndroid Build Coastguard Worker                + f"but are listed by the OpInfo: {claimed_but_unsupported_backward}.\n"
1555*da0073e9SAndroid Build Coastguard Worker            )
1556*da0073e9SAndroid Build Coastguard Worker
1557*da0073e9SAndroid Build Coastguard Worker        all_claimed_but_unsupported = set.union(
1558*da0073e9SAndroid Build Coastguard Worker            claimed_but_unsupported_backward, claimed_but_unsupported_forward
1559*da0073e9SAndroid Build Coastguard Worker        )
1560*da0073e9SAndroid Build Coastguard Worker        if all_claimed_but_unsupported:
1561*da0073e9SAndroid Build Coastguard Worker            msg += "Unexpected failures raised the following errors:\n"
1562*da0073e9SAndroid Build Coastguard Worker            for dtype in all_claimed_but_unsupported:
1563*da0073e9SAndroid Build Coastguard Worker                msg += f"{dtype} - {dtype_error[dtype]}\n"
1564*da0073e9SAndroid Build Coastguard Worker
1565*da0073e9SAndroid Build Coastguard Worker        self.fail(msg)
1566*da0073e9SAndroid Build Coastguard Worker
1567*da0073e9SAndroid Build Coastguard Worker    # Validates that each OpInfo that sets promotes_int_to_float=True does as it says
1568*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1569*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypesAnd(["hpu"])
1570*da0073e9SAndroid Build Coastguard Worker    @ops(
1571*da0073e9SAndroid Build Coastguard Worker        (op for op in op_db if op.promotes_int_to_float),
1572*da0073e9SAndroid Build Coastguard Worker        allowed_dtypes=integral_types_and(torch.bool),
1573*da0073e9SAndroid Build Coastguard Worker    )
1574*da0073e9SAndroid Build Coastguard Worker    def test_promotes_int_to_float(self, device, dtype, op):
1575*da0073e9SAndroid Build Coastguard Worker        for sample in op.sample_inputs(device, dtype):
1576*da0073e9SAndroid Build Coastguard Worker            output = op(sample.input, *sample.args, **sample.kwargs)
1577*da0073e9SAndroid Build Coastguard Worker            if not output.dtype.is_floating_point:
1578*da0073e9SAndroid Build Coastguard Worker                self.fail(
1579*da0073e9SAndroid Build Coastguard Worker                    f"The OpInfo sets `promotes_int_to_float=True`, but {dtype} was promoted to {output.dtype}."
1580*da0073e9SAndroid Build Coastguard Worker                )
1581*da0073e9SAndroid Build Coastguard Worker
1582*da0073e9SAndroid Build Coastguard Worker
1583*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest
1584*da0073e9SAndroid Build Coastguard Workerclass TestCompositeCompliance(TestCase):
1585*da0073e9SAndroid Build Coastguard Worker    # Checks if the operator (if it is composite) is written to support most
1586*da0073e9SAndroid Build Coastguard Worker    # backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance"
1587*da0073e9SAndroid Build Coastguard Worker    # in aten/src/ATen/native/README.md for more details
1588*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1589*da0073e9SAndroid Build Coastguard Worker        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
1590*da0073e9SAndroid Build Coastguard Worker    )
1591*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, allowed_dtypes=(torch.float,))
1592*da0073e9SAndroid Build Coastguard Worker    def test_operator(self, device, dtype, op):
1593*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=False)
1594*da0073e9SAndroid Build Coastguard Worker
1595*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
1596*da0073e9SAndroid Build Coastguard Worker            args = [sample.input] + list(sample.args)
1597*da0073e9SAndroid Build Coastguard Worker            kwargs = sample.kwargs
1598*da0073e9SAndroid Build Coastguard Worker            composite_compliance.check_with_mode(op, args, kwargs, self.assertEqual)
1599*da0073e9SAndroid Build Coastguard Worker            composite_compliance.check_all_permutations(
1600*da0073e9SAndroid Build Coastguard Worker                op, args, kwargs, self.assertEqual
1601*da0073e9SAndroid Build Coastguard Worker            )
1602*da0073e9SAndroid Build Coastguard Worker
1603*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1604*da0073e9SAndroid Build Coastguard Worker        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
1605*da0073e9SAndroid Build Coastguard Worker    )
1606*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
1607*da0073e9SAndroid Build Coastguard Worker    def test_backward(self, device, dtype, op):
1608*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=True)
1609*da0073e9SAndroid Build Coastguard Worker
1610*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
1611*da0073e9SAndroid Build Coastguard Worker            args = [sample.input] + list(sample.args)
1612*da0073e9SAndroid Build Coastguard Worker            kwargs = sample.kwargs
1613*da0073e9SAndroid Build Coastguard Worker            # We pass assertEqual so that decorators like `toleranceOverride`
1614*da0073e9SAndroid Build Coastguard Worker            # actually work (otherwise they silently do nothing!)
1615*da0073e9SAndroid Build Coastguard Worker            composite_compliance.check_backward_formula(
1616*da0073e9SAndroid Build Coastguard Worker                op.get_op(),
1617*da0073e9SAndroid Build Coastguard Worker                args,
1618*da0073e9SAndroid Build Coastguard Worker                kwargs,
1619*da0073e9SAndroid Build Coastguard Worker                sample.output_process_fn_grad,
1620*da0073e9SAndroid Build Coastguard Worker                op.gradcheck_wrapper,
1621*da0073e9SAndroid Build Coastguard Worker                self.assertEqual,
1622*da0073e9SAndroid Build Coastguard Worker            )
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1625*da0073e9SAndroid Build Coastguard Worker        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
1626*da0073e9SAndroid Build Coastguard Worker    )
1627*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, allowed_dtypes=(torch.float,))
1628*da0073e9SAndroid Build Coastguard Worker    def test_forward_ad(self, device, dtype, op):
1629*da0073e9SAndroid Build Coastguard Worker        if torch.float not in op.supported_backward_dtypes(device):
1630*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("Does not support autograd")
1631*da0073e9SAndroid Build Coastguard Worker
1632*da0073e9SAndroid Build Coastguard Worker        if not op.supports_forward_ad:
1633*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("Does not support forward_ad")
1634*da0073e9SAndroid Build Coastguard Worker
1635*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=True)
1636*da0073e9SAndroid Build Coastguard Worker
1637*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
1638*da0073e9SAndroid Build Coastguard Worker            args = [sample.input] + list(sample.args)
1639*da0073e9SAndroid Build Coastguard Worker            kwargs = sample.kwargs
1640*da0073e9SAndroid Build Coastguard Worker            # We pass assertEqual so that decorators like `toleranceOverride`
1641*da0073e9SAndroid Build Coastguard Worker            # actually work (otherwise they silently do nothing!)
1642*da0073e9SAndroid Build Coastguard Worker            composite_compliance.check_forward_ad_formula(
1643*da0073e9SAndroid Build Coastguard Worker                op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual
1644*da0073e9SAndroid Build Coastguard Worker            )
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, allowed_dtypes=(torch.float,))
1647*da0073e9SAndroid Build Coastguard Worker    def test_cow_input(self, device, dtype, op):
1648*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker        def is_strided_tensor(arg):
1651*da0073e9SAndroid Build Coastguard Worker            return torch.is_tensor(arg) and arg.layout == torch.strided
1652*da0073e9SAndroid Build Coastguard Worker
1653*da0073e9SAndroid Build Coastguard Worker        def check_ignore_materialize(idx_or_kw, allow_list):
1654*da0073e9SAndroid Build Coastguard Worker            return (allow_list is not None) and (idx_or_kw in allow_list)
1655*da0073e9SAndroid Build Coastguard Worker
1656*da0073e9SAndroid Build Coastguard Worker        def check_cow_input(
1657*da0073e9SAndroid Build Coastguard Worker            arg,
1658*da0073e9SAndroid Build Coastguard Worker            arg_copy,
1659*da0073e9SAndroid Build Coastguard Worker            idx_or_kw,
1660*da0073e9SAndroid Build Coastguard Worker            backward_or_forward="forward",
1661*da0073e9SAndroid Build Coastguard Worker            supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward,
1662*da0073e9SAndroid Build Coastguard Worker            allow_list=op.allow_cow_input_materialize_forward,
1663*da0073e9SAndroid Build Coastguard Worker        ):
1664*da0073e9SAndroid Build Coastguard Worker            arg_name = (
1665*da0073e9SAndroid Build Coastguard Worker                f"Argument {idx_or_kw}"
1666*da0073e9SAndroid Build Coastguard Worker                if isinstance(idx_or_kw, int)
1667*da0073e9SAndroid Build Coastguard Worker                else f"Keyword argument '{idx_or_kw}'"
1668*da0073e9SAndroid Build Coastguard Worker            ) + f" during {backward_or_forward} call"
1669*da0073e9SAndroid Build Coastguard Worker
1670*da0073e9SAndroid Build Coastguard Worker            if is_strided_tensor(arg):
1671*da0073e9SAndroid Build Coastguard Worker                is_cow = torch._C._is_cow_tensor(arg)
1672*da0073e9SAndroid Build Coastguard Worker
1673*da0073e9SAndroid Build Coastguard Worker                if supports_cow_input_no_materialize and not check_ignore_materialize(
1674*da0073e9SAndroid Build Coastguard Worker                    idx_or_kw, allow_list
1675*da0073e9SAndroid Build Coastguard Worker                ):
1676*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
1677*da0073e9SAndroid Build Coastguard Worker                        is_cow,
1678*da0073e9SAndroid Build Coastguard Worker                        msg=(
1679*da0073e9SAndroid Build Coastguard Worker                            f"{arg_name} unexpectedly materializes. "
1680*da0073e9SAndroid Build Coastguard Worker                            f"Either set `supports_cow_input_no_materialize_{backward_or_forward}=False` "
1681*da0073e9SAndroid Build Coastguard Worker                            "in this operation's OpInfo, add the arg to the OpInfo's "
1682*da0073e9SAndroid Build Coastguard Worker                            f"`allow_cow_input_materialize_{backward_or_forward}` list, or change the "
1683*da0073e9SAndroid Build Coastguard Worker                            "implementation to avoid materialization."
1684*da0073e9SAndroid Build Coastguard Worker                        ),
1685*da0073e9SAndroid Build Coastguard Worker                    )
1686*da0073e9SAndroid Build Coastguard Worker
1687*da0073e9SAndroid Build Coastguard Worker                if is_cow:
1688*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
1689*da0073e9SAndroid Build Coastguard Worker                        torch.allclose(arg, arg_copy, rtol=0, atol=0, equal_nan=True),
1690*da0073e9SAndroid Build Coastguard Worker                        msg=(
1691*da0073e9SAndroid Build Coastguard Worker                            f"{arg_name} avoided materialization, "
1692*da0073e9SAndroid Build Coastguard Worker                            "but the operation mutated its data."
1693*da0073e9SAndroid Build Coastguard Worker                        ),
1694*da0073e9SAndroid Build Coastguard Worker                    )
1695*da0073e9SAndroid Build Coastguard Worker
1696*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
1697*da0073e9SAndroid Build Coastguard Worker            args_raw = [sample.input] + list(sample.args)
1698*da0073e9SAndroid Build Coastguard Worker            kwargs_raw = sample.kwargs
1699*da0073e9SAndroid Build Coastguard Worker            args_copy = []
1700*da0073e9SAndroid Build Coastguard Worker            args = []
1701*da0073e9SAndroid Build Coastguard Worker            kwargs_copy = {}
1702*da0073e9SAndroid Build Coastguard Worker            kwargs = {}
1703*da0073e9SAndroid Build Coastguard Worker
1704*da0073e9SAndroid Build Coastguard Worker            # Convert strided tensor inputs to COW tensors and make copies of
1705*da0073e9SAndroid Build Coastguard Worker            # all inputs
1706*da0073e9SAndroid Build Coastguard Worker            for idx, arg in enumerate(args_raw):
1707*da0073e9SAndroid Build Coastguard Worker                if is_strided_tensor(arg):
1708*da0073e9SAndroid Build Coastguard Worker                    args_copy.append(arg.clone().detach())
1709*da0073e9SAndroid Build Coastguard Worker                    args.append(torch._lazy_clone(arg))
1710*da0073e9SAndroid Build Coastguard Worker                else:
1711*da0073e9SAndroid Build Coastguard Worker                    if torch.is_tensor(arg):
1712*da0073e9SAndroid Build Coastguard Worker                        args_copy.append(arg.clone().detach())
1713*da0073e9SAndroid Build Coastguard Worker                    else:
1714*da0073e9SAndroid Build Coastguard Worker                        args_copy.append(copy.deepcopy(arg))
1715*da0073e9SAndroid Build Coastguard Worker                    args.append(arg)
1716*da0073e9SAndroid Build Coastguard Worker
1717*da0073e9SAndroid Build Coastguard Worker            for kw, arg in kwargs_raw.items():
1718*da0073e9SAndroid Build Coastguard Worker                if is_strided_tensor(arg):
1719*da0073e9SAndroid Build Coastguard Worker                    kwargs_copy[kw] = arg.clone().detach()
1720*da0073e9SAndroid Build Coastguard Worker                    kwargs[kw] = torch._lazy_clone(arg)
1721*da0073e9SAndroid Build Coastguard Worker                else:
1722*da0073e9SAndroid Build Coastguard Worker                    if torch.is_tensor(arg):
1723*da0073e9SAndroid Build Coastguard Worker                        kwargs_copy[kw] = arg.clone().detach()
1724*da0073e9SAndroid Build Coastguard Worker                    else:
1725*da0073e9SAndroid Build Coastguard Worker                        kwargs_copy[kw] = copy.deepcopy(arg)
1726*da0073e9SAndroid Build Coastguard Worker                    kwargs[kw] = arg
1727*da0073e9SAndroid Build Coastguard Worker
1728*da0073e9SAndroid Build Coastguard Worker            leaf_tensors = composite_compliance.gather_leaf_tensors(args, kwargs)
1729*da0073e9SAndroid Build Coastguard Worker
1730*da0073e9SAndroid Build Coastguard Worker            # Call forward op
1731*da0073e9SAndroid Build Coastguard Worker            results_raw = op.get_op()(*args, **kwargs)
1732*da0073e9SAndroid Build Coastguard Worker
1733*da0073e9SAndroid Build Coastguard Worker            # Check that COW inputs remain COW after the forward op is executed
1734*da0073e9SAndroid Build Coastguard Worker            for idx, arg in enumerate(args):
1735*da0073e9SAndroid Build Coastguard Worker                check_cow_input(arg, args_copy[idx], idx)
1736*da0073e9SAndroid Build Coastguard Worker
1737*da0073e9SAndroid Build Coastguard Worker            for kw, arg in kwargs.items():
1738*da0073e9SAndroid Build Coastguard Worker                check_cow_input(arg, kwargs_copy[kw], kw)
1739*da0073e9SAndroid Build Coastguard Worker
1740*da0073e9SAndroid Build Coastguard Worker            # Call backward op if it is supported. This part of the test is
1741*da0073e9SAndroid Build Coastguard Worker            # based on `composite_compliance.check_backward_formula`
1742*da0073e9SAndroid Build Coastguard Worker            if (
1743*da0073e9SAndroid Build Coastguard Worker                op.supports_autograd
1744*da0073e9SAndroid Build Coastguard Worker                and len(leaf_tensors) > 0
1745*da0073e9SAndroid Build Coastguard Worker                and not op.skip_cow_input_backward
1746*da0073e9SAndroid Build Coastguard Worker            ):
1747*da0073e9SAndroid Build Coastguard Worker                if sample.output_process_fn_grad is not None:
1748*da0073e9SAndroid Build Coastguard Worker                    results_raw = sample.output_process_fn_grad(results_raw)
1749*da0073e9SAndroid Build Coastguard Worker
1750*da0073e9SAndroid Build Coastguard Worker                leaf_results = pytree.tree_leaves(results_raw)
1751*da0073e9SAndroid Build Coastguard Worker                results = [
1752*da0073e9SAndroid Build Coastguard Worker                    r
1753*da0073e9SAndroid Build Coastguard Worker                    for r in leaf_results
1754*da0073e9SAndroid Build Coastguard Worker                    if isinstance(r, torch.Tensor) and r.requires_grad
1755*da0073e9SAndroid Build Coastguard Worker                ]
1756*da0073e9SAndroid Build Coastguard Worker
1757*da0073e9SAndroid Build Coastguard Worker                all_results_strided = all(
1758*da0073e9SAndroid Build Coastguard Worker                    is_strided_tensor(result) for result in results
1759*da0073e9SAndroid Build Coastguard Worker                )
1760*da0073e9SAndroid Build Coastguard Worker
1761*da0073e9SAndroid Build Coastguard Worker                # Only test backward if the results are strided tensors
1762*da0073e9SAndroid Build Coastguard Worker                if all_results_strided:
1763*da0073e9SAndroid Build Coastguard Worker                    output_grads_raw = [
1764*da0073e9SAndroid Build Coastguard Worker                        torch.ones(r.shape, device=r.device, dtype=r.dtype)
1765*da0073e9SAndroid Build Coastguard Worker                        for r in results
1766*da0073e9SAndroid Build Coastguard Worker                    ]
1767*da0073e9SAndroid Build Coastguard Worker                    output_grads_copy = []
1768*da0073e9SAndroid Build Coastguard Worker                    output_grads = []
1769*da0073e9SAndroid Build Coastguard Worker
1770*da0073e9SAndroid Build Coastguard Worker                    # Convert output grads to COW tensors and make copies
1771*da0073e9SAndroid Build Coastguard Worker                    for output_grad in output_grads_raw:
1772*da0073e9SAndroid Build Coastguard Worker                        output_grads_copy.append(output_grad.clone().detach())
1773*da0073e9SAndroid Build Coastguard Worker                        output_grads.append(torch._lazy_clone(output_grad))
1774*da0073e9SAndroid Build Coastguard Worker
1775*da0073e9SAndroid Build Coastguard Worker                    input_grads = torch.autograd.grad(
1776*da0073e9SAndroid Build Coastguard Worker                        results,
1777*da0073e9SAndroid Build Coastguard Worker                        leaf_tensors,
1778*da0073e9SAndroid Build Coastguard Worker                        output_grads,
1779*da0073e9SAndroid Build Coastguard Worker                        allow_unused=True,
1780*da0073e9SAndroid Build Coastguard Worker                        retain_graph=True,
1781*da0073e9SAndroid Build Coastguard Worker                    )
1782*da0073e9SAndroid Build Coastguard Worker
1783*da0073e9SAndroid Build Coastguard Worker                    # Check that COW inputs remain COW after the backward op is executed
1784*da0073e9SAndroid Build Coastguard Worker                    for idx, arg in enumerate(args):
1785*da0073e9SAndroid Build Coastguard Worker                        check_cow_input(
1786*da0073e9SAndroid Build Coastguard Worker                            arg,
1787*da0073e9SAndroid Build Coastguard Worker                            args_copy[idx],
1788*da0073e9SAndroid Build Coastguard Worker                            idx,
1789*da0073e9SAndroid Build Coastguard Worker                            backward_or_forward="backward",
1790*da0073e9SAndroid Build Coastguard Worker                            supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
1791*da0073e9SAndroid Build Coastguard Worker                            allow_list=op.allow_cow_input_materialize_backward,
1792*da0073e9SAndroid Build Coastguard Worker                        )
1793*da0073e9SAndroid Build Coastguard Worker
1794*da0073e9SAndroid Build Coastguard Worker                    # Check that COW inputs remain COW after the backward op is executed
1795*da0073e9SAndroid Build Coastguard Worker                    for idx, output_grad in enumerate(output_grads):
1796*da0073e9SAndroid Build Coastguard Worker                        check_cow_input(
1797*da0073e9SAndroid Build Coastguard Worker                            output_grad,
1798*da0073e9SAndroid Build Coastguard Worker                            output_grads_copy[idx],
1799*da0073e9SAndroid Build Coastguard Worker                            f"output grad {idx}",
1800*da0073e9SAndroid Build Coastguard Worker                            backward_or_forward="backward",
1801*da0073e9SAndroid Build Coastguard Worker                            supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
1802*da0073e9SAndroid Build Coastguard Worker                            allow_list=op.allow_cow_input_materialize_backward,
1803*da0073e9SAndroid Build Coastguard Worker                        )
1804*da0073e9SAndroid Build Coastguard Worker
1805*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, allowed_dtypes=(torch.float,))
1806*da0073e9SAndroid Build Coastguard Worker    def test_view_replay(self, device, dtype, op):
1807*da0073e9SAndroid Build Coastguard Worker        def _assert_match_metadata(a, b):
1808*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.size(), b.size())
1809*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.stride(), b.stride())
1810*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.storage_offset(), b.storage_offset())
1811*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.device, b.device)
1812*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.dtype, b.dtype)
1813*da0073e9SAndroid Build Coastguard Worker
1814*da0073e9SAndroid Build Coastguard Worker        # ensure view replay is enabled
1815*da0073e9SAndroid Build Coastguard Worker        with torch.autograd._force_original_view_tracking(True):
1816*da0073e9SAndroid Build Coastguard Worker            for sample in op.sample_inputs(device, dtype, requires_grad=False):
1817*da0073e9SAndroid Build Coastguard Worker                inp = sample.input
1818*da0073e9SAndroid Build Coastguard Worker                outs = op(inp, *sample.args, **sample.kwargs)
1819*da0073e9SAndroid Build Coastguard Worker                if not isinstance(outs, (tuple, List)):
1820*da0073e9SAndroid Build Coastguard Worker                    outs = [outs]
1821*da0073e9SAndroid Build Coastguard Worker
1822*da0073e9SAndroid Build Coastguard Worker                # for all outputs that are views of the input, we should be able to replay the
1823*da0073e9SAndroid Build Coastguard Worker                # forward and reverse views via a functioning view_func() / rev_view_func().
1824*da0073e9SAndroid Build Coastguard Worker                for out in outs:
1825*da0073e9SAndroid Build Coastguard Worker                    if not (
1826*da0073e9SAndroid Build Coastguard Worker                        isinstance(out, torch.Tensor)
1827*da0073e9SAndroid Build Coastguard Worker                        and out._is_view()
1828*da0073e9SAndroid Build Coastguard Worker                        and out._base is inp
1829*da0073e9SAndroid Build Coastguard Worker                    ):
1830*da0073e9SAndroid Build Coastguard Worker                        continue
1831*da0073e9SAndroid Build Coastguard Worker
1832*da0073e9SAndroid Build Coastguard Worker                    # forward view_func
1833*da0073e9SAndroid Build Coastguard Worker                    new_inp = inp.clone()
1834*da0073e9SAndroid Build Coastguard Worker                    _assert_match_metadata(new_inp, inp)
1835*da0073e9SAndroid Build Coastguard Worker                    new_out = out._view_func_unsafe(new_inp)
1836*da0073e9SAndroid Build Coastguard Worker                    _assert_match_metadata(new_out, out)
1837*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(new_out, out)
1838*da0073e9SAndroid Build Coastguard Worker
1839*da0073e9SAndroid Build Coastguard Worker                    # reverse view_func
1840*da0073e9SAndroid Build Coastguard Worker                    new_out = out.detach()
1841*da0073e9SAndroid Build Coastguard Worker                    new_inp = out._rev_view_func_unsafe(new_out)
1842*da0073e9SAndroid Build Coastguard Worker                    _assert_match_metadata(new_inp, inp)
1843*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(new_inp._is_view())
1844*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(new_inp._base is new_out)
1845*da0073e9SAndroid Build Coastguard Worker
1846*da0073e9SAndroid Build Coastguard Worker
1847*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest
1848*da0073e9SAndroid Build Coastguard Workerclass TestMathBits(TestCase):
1849*da0073e9SAndroid Build Coastguard Worker    # Tests that
1850*da0073e9SAndroid Build Coastguard Worker    # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors
1851*da0073e9SAndroid Build Coastguard Worker    # produces the same value
1852*da0073e9SAndroid Build Coastguard Worker    # 2. The gradients are same in both cases mentioned in (1)
1853*da0073e9SAndroid Build Coastguard Worker    # 3. If the operator's inplace variant is supported, tests that the inplace operation
1854*da0073e9SAndroid Build Coastguard Worker    #    produces the correct value when called on a conjugate/negative view tensor and that the output
1855*da0073e9SAndroid Build Coastguard Worker    #    has its conj/neg bit set to true
1856*da0073e9SAndroid Build Coastguard Worker    # This test only runs for C -> R and C -> C functions
1857*da0073e9SAndroid Build Coastguard Worker    # TODO: add tests for `R->C` functions
1858*da0073e9SAndroid Build Coastguard Worker    # Note: This test runs for functions that take both tensors and tensorlists as input.
1859*da0073e9SAndroid Build Coastguard Worker    def _test_math_view(
1860*da0073e9SAndroid Build Coastguard Worker        self,
1861*da0073e9SAndroid Build Coastguard Worker        device,
1862*da0073e9SAndroid Build Coastguard Worker        dtype,
1863*da0073e9SAndroid Build Coastguard Worker        op,
1864*da0073e9SAndroid Build Coastguard Worker        samples,
1865*da0073e9SAndroid Build Coastguard Worker        math_op_physical,
1866*da0073e9SAndroid Build Coastguard Worker        math_op_view,
1867*da0073e9SAndroid Build Coastguard Worker        is_bit_set,
1868*da0073e9SAndroid Build Coastguard Worker        out_type,
1869*da0073e9SAndroid Build Coastguard Worker    ):
1870*da0073e9SAndroid Build Coastguard Worker        inplace_variant = op.inplace_variant
1871*da0073e9SAndroid Build Coastguard Worker
1872*da0073e9SAndroid Build Coastguard Worker        # helper function to clone and conjugate/negate the input if its a tensor
1873*da0073e9SAndroid Build Coastguard Worker        # else clone the sequence and conjugate/negate the first element in the sequence
1874*da0073e9SAndroid Build Coastguard Worker        # If a requires_grad argument is provided the tensor being conjugated/negated will
1875*da0073e9SAndroid Build Coastguard Worker        # have its requires_grad set to that value.
1876*da0073e9SAndroid Build Coastguard Worker        def clone_and_perform_view(input, **kwargs):
1877*da0073e9SAndroid Build Coastguard Worker            if isinstance(input, torch.Tensor):
1878*da0073e9SAndroid Build Coastguard Worker                requires_grad = kwargs.get("requires_grad", input.requires_grad)
1879*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
1880*da0073e9SAndroid Build Coastguard Worker                    # Ensure view represents the original sample input
1881*da0073e9SAndroid Build Coastguard Worker                    input = math_op_physical(input)
1882*da0073e9SAndroid Build Coastguard Worker                # Note: .conj() is not called under no_grad mode since it's not allowed to modify a
1883*da0073e9SAndroid Build Coastguard Worker                # view created in no_grad mode. Here it's ok to do so, so as a workaround we call conj
1884*da0073e9SAndroid Build Coastguard Worker                # before resetting the requires_grad field for input
1885*da0073e9SAndroid Build Coastguard Worker                input = math_op_view(input)
1886*da0073e9SAndroid Build Coastguard Worker                assert input.is_leaf
1887*da0073e9SAndroid Build Coastguard Worker                return input.requires_grad_(requires_grad)
1888*da0073e9SAndroid Build Coastguard Worker
1889*da0073e9SAndroid Build Coastguard Worker            if isinstance(input, Sequence):
1890*da0073e9SAndroid Build Coastguard Worker                out = list(map(clone_input_helper, input))
1891*da0073e9SAndroid Build Coastguard Worker                out[0] = clone_and_perform_view(out[0])
1892*da0073e9SAndroid Build Coastguard Worker                return tuple(out)
1893*da0073e9SAndroid Build Coastguard Worker
1894*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
1895*da0073e9SAndroid Build Coastguard Worker            tensor = (
1896*da0073e9SAndroid Build Coastguard Worker                sample.input
1897*da0073e9SAndroid Build Coastguard Worker                if isinstance(sample.input, torch.Tensor)
1898*da0073e9SAndroid Build Coastguard Worker                else sample.input[0]
1899*da0073e9SAndroid Build Coastguard Worker            )
1900*da0073e9SAndroid Build Coastguard Worker            cloned1 = clone_and_perform_view(sample.input)
1901*da0073e9SAndroid Build Coastguard Worker
1902*da0073e9SAndroid Build Coastguard Worker            # Computes function forward value with a physically conjugated/negated tensor and
1903*da0073e9SAndroid Build Coastguard Worker            # a conj/neg view tensor and verifies that the output in both case are equal.
1904*da0073e9SAndroid Build Coastguard Worker            expected_forward = op(sample.input, *sample.args, **sample.kwargs)
1905*da0073e9SAndroid Build Coastguard Worker            forward_with_mathview = op(cloned1, *sample.args, **sample.kwargs)
1906*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_forward, forward_with_mathview)
1907*da0073e9SAndroid Build Coastguard Worker
1908*da0073e9SAndroid Build Coastguard Worker            # If the op has an inplace variant, and the input doesn't require broadcasting
1909*da0073e9SAndroid Build Coastguard Worker            # and has the same dtype as output, verify that the inplace operation on a conjugated/negated
1910*da0073e9SAndroid Build Coastguard Worker            # input produces correct output, and the output tensor has the conj/neg bit set to True
1911*da0073e9SAndroid Build Coastguard Worker            if inplace_variant is not None and not sample.broadcasts_input:
1912*da0073e9SAndroid Build Coastguard Worker                cloned2 = clone_and_perform_view(tensor, requires_grad=False)
1913*da0073e9SAndroid Build Coastguard Worker                if (
1914*da0073e9SAndroid Build Coastguard Worker                    isinstance(expected_forward, torch.Tensor)
1915*da0073e9SAndroid Build Coastguard Worker                    and expected_forward.dtype is tensor.dtype
1916*da0073e9SAndroid Build Coastguard Worker                ):
1917*da0073e9SAndroid Build Coastguard Worker                    inplace_forward = inplace_variant(
1918*da0073e9SAndroid Build Coastguard Worker                        cloned2, *sample.args, **sample.kwargs
1919*da0073e9SAndroid Build Coastguard Worker                    )
1920*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(is_bit_set(inplace_forward))
1921*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(inplace_forward, expected_forward)
1922*da0073e9SAndroid Build Coastguard Worker
1923*da0073e9SAndroid Build Coastguard Worker            # TODO: backward consistency only supported for single tensor outputs
1924*da0073e9SAndroid Build Coastguard Worker            # TODO: backward consistency only checked on sample.input, not all
1925*da0073e9SAndroid Build Coastguard Worker            #   tensor inputs
1926*da0073e9SAndroid Build Coastguard Worker            # TODO: update to handle checking grads of all tensor inputs as
1927*da0073e9SAndroid Build Coastguard Worker            #   derived from each tensor output
1928*da0073e9SAndroid Build Coastguard Worker            if (
1929*da0073e9SAndroid Build Coastguard Worker                isinstance(expected_forward, torch.Tensor)
1930*da0073e9SAndroid Build Coastguard Worker                and expected_forward.requires_grad
1931*da0073e9SAndroid Build Coastguard Worker            ):
1932*da0073e9SAndroid Build Coastguard Worker                output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)
1933*da0073e9SAndroid Build Coastguard Worker                expected_forward = output_process_fn_grad(expected_forward)
1934*da0073e9SAndroid Build Coastguard Worker                forward_with_mathview = output_process_fn_grad(forward_with_mathview)
1935*da0073e9SAndroid Build Coastguard Worker
1936*da0073e9SAndroid Build Coastguard Worker                tensor = (
1937*da0073e9SAndroid Build Coastguard Worker                    sample.input
1938*da0073e9SAndroid Build Coastguard Worker                    if isinstance(sample.input, torch.Tensor)
1939*da0073e9SAndroid Build Coastguard Worker                    else sample.input[0]
1940*da0073e9SAndroid Build Coastguard Worker                )
1941*da0073e9SAndroid Build Coastguard Worker                expected_forward.sum().abs().backward(retain_graph=True)
1942*da0073e9SAndroid Build Coastguard Worker                forward_with_mathview.sum().abs().backward(retain_graph=True)
1943*da0073e9SAndroid Build Coastguard Worker                if tensor.grad is not None:
1944*da0073e9SAndroid Build Coastguard Worker                    cloned1_tensor = (
1945*da0073e9SAndroid Build Coastguard Worker                        cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
1946*da0073e9SAndroid Build Coastguard Worker                    )
1947*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(tensor.grad, cloned1_tensor.grad)
1948*da0073e9SAndroid Build Coastguard Worker
1949*da0073e9SAndroid Build Coastguard Worker                    tensor.grad, cloned1_tensor.grad = None, None
1950*da0073e9SAndroid Build Coastguard Worker
1951*da0073e9SAndroid Build Coastguard Worker                    # a repeat of the above test if output is not complex valued
1952*da0073e9SAndroid Build Coastguard Worker                    if out_type(expected_forward):
1953*da0073e9SAndroid Build Coastguard Worker                        grad = torch.randn_like(expected_forward)
1954*da0073e9SAndroid Build Coastguard Worker                        expected_forward.backward(grad)
1955*da0073e9SAndroid Build Coastguard Worker                        forward_with_mathview.backward(
1956*da0073e9SAndroid Build Coastguard Worker                            math_op_view(math_op_physical(grad))
1957*da0073e9SAndroid Build Coastguard Worker                        )
1958*da0073e9SAndroid Build Coastguard Worker
1959*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(tensor.grad, cloned1_tensor.grad)
1960*da0073e9SAndroid Build Coastguard Worker
1961*da0073e9SAndroid Build Coastguard Worker    @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,))
1962*da0073e9SAndroid Build Coastguard Worker    def test_conj_view(self, device, dtype, op):
1963*da0073e9SAndroid Build Coastguard Worker        if not op.test_conjugated_samples:
1964*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Operation doesn't support conjugated inputs.")
1965*da0073e9SAndroid Build Coastguard Worker        math_op_physical = torch.conj_physical
1966*da0073e9SAndroid Build Coastguard Worker        math_op_view = torch.conj
1967*da0073e9SAndroid Build Coastguard Worker        _requires_grad = torch.cfloat in op.supported_backward_dtypes(
1968*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type
1969*da0073e9SAndroid Build Coastguard Worker        )
1970*da0073e9SAndroid Build Coastguard Worker        is_bit_set = torch.is_conj
1971*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
1972*da0073e9SAndroid Build Coastguard Worker        self._test_math_view(
1973*da0073e9SAndroid Build Coastguard Worker            device,
1974*da0073e9SAndroid Build Coastguard Worker            dtype,
1975*da0073e9SAndroid Build Coastguard Worker            op,
1976*da0073e9SAndroid Build Coastguard Worker            samples,
1977*da0073e9SAndroid Build Coastguard Worker            math_op_physical,
1978*da0073e9SAndroid Build Coastguard Worker            math_op_view,
1979*da0073e9SAndroid Build Coastguard Worker            is_bit_set,
1980*da0073e9SAndroid Build Coastguard Worker            torch.is_complex,
1981*da0073e9SAndroid Build Coastguard Worker        )
1982*da0073e9SAndroid Build Coastguard Worker
1983*da0073e9SAndroid Build Coastguard Worker    @ops(ops_and_refs, allowed_dtypes=(torch.double,))
1984*da0073e9SAndroid Build Coastguard Worker    def test_neg_view(self, device, dtype, op):
1985*da0073e9SAndroid Build Coastguard Worker        if not op.test_neg_view:
1986*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Operation not tested with tensors with negative bit.")
1987*da0073e9SAndroid Build Coastguard Worker        math_op_physical = torch.neg
1988*da0073e9SAndroid Build Coastguard Worker        math_op_view = torch._neg_view
1989*da0073e9SAndroid Build Coastguard Worker        is_bit_set = torch.is_neg
1990*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
1991*da0073e9SAndroid Build Coastguard Worker        self._test_math_view(
1992*da0073e9SAndroid Build Coastguard Worker            device,
1993*da0073e9SAndroid Build Coastguard Worker            dtype,
1994*da0073e9SAndroid Build Coastguard Worker            op,
1995*da0073e9SAndroid Build Coastguard Worker            samples,
1996*da0073e9SAndroid Build Coastguard Worker            math_op_physical,
1997*da0073e9SAndroid Build Coastguard Worker            math_op_view,
1998*da0073e9SAndroid Build Coastguard Worker            is_bit_set,
1999*da0073e9SAndroid Build Coastguard Worker            lambda x: True,
2000*da0073e9SAndroid Build Coastguard Worker        )
2001*da0073e9SAndroid Build Coastguard Worker
2002*da0073e9SAndroid Build Coastguard Worker    @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,))
2003*da0073e9SAndroid Build Coastguard Worker    def test_neg_conj_view(self, device, dtype, op):
2004*da0073e9SAndroid Build Coastguard Worker        if not op.test_neg_view:
2005*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Operation not tested with tensors with negative bit.")
2006*da0073e9SAndroid Build Coastguard Worker        if not op.test_conjugated_samples:
2007*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Operation doesn't support conjugated inputs.")
2008*da0073e9SAndroid Build Coastguard Worker
2009*da0073e9SAndroid Build Coastguard Worker        def math_op_physical(x):
2010*da0073e9SAndroid Build Coastguard Worker            return -x.conj_physical()
2011*da0073e9SAndroid Build Coastguard Worker
2012*da0073e9SAndroid Build Coastguard Worker        def math_op_view(x):
2013*da0073e9SAndroid Build Coastguard Worker            return torch._neg_view(x).conj()
2014*da0073e9SAndroid Build Coastguard Worker
2015*da0073e9SAndroid Build Coastguard Worker        def is_bit_set(x):
2016*da0073e9SAndroid Build Coastguard Worker            return torch.is_neg(x) and torch.is_conj(x)
2017*da0073e9SAndroid Build Coastguard Worker
2018*da0073e9SAndroid Build Coastguard Worker        _requires_grad = dtype in op.supported_backward_dtypes(
2019*da0073e9SAndroid Build Coastguard Worker            torch.device(device).type
2020*da0073e9SAndroid Build Coastguard Worker        )
2021*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
2022*da0073e9SAndroid Build Coastguard Worker        # Only test one sample
2023*da0073e9SAndroid Build Coastguard Worker        samples = itertools.islice(samples, 1)
2024*da0073e9SAndroid Build Coastguard Worker        self._test_math_view(
2025*da0073e9SAndroid Build Coastguard Worker            device,
2026*da0073e9SAndroid Build Coastguard Worker            dtype,
2027*da0073e9SAndroid Build Coastguard Worker            op,
2028*da0073e9SAndroid Build Coastguard Worker            samples,
2029*da0073e9SAndroid Build Coastguard Worker            math_op_physical,
2030*da0073e9SAndroid Build Coastguard Worker            math_op_view,
2031*da0073e9SAndroid Build Coastguard Worker            is_bit_set,
2032*da0073e9SAndroid Build Coastguard Worker            torch.is_complex,
2033*da0073e9SAndroid Build Coastguard Worker        )
2034*da0073e9SAndroid Build Coastguard Worker
2035*da0073e9SAndroid Build Coastguard Worker
2036*da0073e9SAndroid Build Coastguard Worker# input strides and size may have been altered due to the result of an inplace op
2037*da0073e9SAndroid Build Coastguard Workerdef check_inplace_view(func, input, rs, input_size, input_strides):
2038*da0073e9SAndroid Build Coastguard Worker    if func is None:
2039*da0073e9SAndroid Build Coastguard Worker        return
2040*da0073e9SAndroid Build Coastguard Worker    # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm(_legit).out
2041*da0073e9SAndroid Build Coastguard Worker    # which mutate not necessarily the first input.
2042*da0073e9SAndroid Build Coastguard Worker    if isinstance(rs, torch.Tensor) and rs is input:
2043*da0073e9SAndroid Build Coastguard Worker        unequal_size = rs.size() != input_size
2044*da0073e9SAndroid Build Coastguard Worker        unequal_strides = rs.stride() != input_strides
2045*da0073e9SAndroid Build Coastguard Worker        # resize_ should probably have inplace_view tag. Not adding the tag since it
2046*da0073e9SAndroid Build Coastguard Worker        # breaks some codegen logic
2047*da0073e9SAndroid Build Coastguard Worker        if unequal_size or unequal_strides:
2048*da0073e9SAndroid Build Coastguard Worker            if isinstance(func, torch._ops.OpOverloadPacket):
2049*da0073e9SAndroid Build Coastguard Worker                func = func.default
2050*da0073e9SAndroid Build Coastguard Worker            # Reference: https://github.com/pytorch/pytorch/issues/78759
2051*da0073e9SAndroid Build Coastguard Worker            if func is not torch.ops.aten.resize_.default:
2052*da0073e9SAndroid Build Coastguard Worker                # TODO: use self.assertIn when we have separate tests for each tag
2053*da0073e9SAndroid Build Coastguard Worker                assert torch.Tag.inplace_view in func.tags
2054*da0073e9SAndroid Build Coastguard Worker
2055*da0073e9SAndroid Build Coastguard Worker
2056*da0073e9SAndroid Build Coastguard Worker# A mode that when enabled runs correctness checks to ensure
2057*da0073e9SAndroid Build Coastguard Worker# that operators have expected tags based on their input and
2058*da0073e9SAndroid Build Coastguard Worker# output tensor properties
2059*da0073e9SAndroid Build Coastguard Workerclass TestTagsMode(TorchDispatchMode):
2060*da0073e9SAndroid Build Coastguard Worker    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
2061*da0073e9SAndroid Build Coastguard Worker        if isinstance(args[0], torch.Tensor):
2062*da0073e9SAndroid Build Coastguard Worker            old_size = args[0].size()
2063*da0073e9SAndroid Build Coastguard Worker            old_stride = args[0].stride()
2064*da0073e9SAndroid Build Coastguard Worker            rs = func(*args, **kwargs)
2065*da0073e9SAndroid Build Coastguard Worker            check_inplace_view(func, args[0], rs, old_size, old_stride)
2066*da0073e9SAndroid Build Coastguard Worker        else:
2067*da0073e9SAndroid Build Coastguard Worker            rs = func(*args, **kwargs)
2068*da0073e9SAndroid Build Coastguard Worker        return rs
2069*da0073e9SAndroid Build Coastguard Worker
2070*da0073e9SAndroid Build Coastguard Worker
2071*da0073e9SAndroid Build Coastguard Worker# Test to verify the correctness for tags in `tags.yaml`, also available for access through `torch.Tags`
2072*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest
2073*da0073e9SAndroid Build Coastguard Workerclass TestTags(TestCase):
2074*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
2075*da0073e9SAndroid Build Coastguard Worker    @ops(ops_and_refs, dtypes=OpDTypes.any_one)
2076*da0073e9SAndroid Build Coastguard Worker    def test_tags(self, device, dtype, op):
2077*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=False)
2078*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
2079*da0073e9SAndroid Build Coastguard Worker            # TODO: Test tags for ops that return a list of tensors
2080*da0073e9SAndroid Build Coastguard Worker            input = sample.input
2081*da0073e9SAndroid Build Coastguard Worker            if isinstance(input, torch.Tensor):
2082*da0073e9SAndroid Build Coastguard Worker                old_size = input.size()
2083*da0073e9SAndroid Build Coastguard Worker                old_stride = input.stride()
2084*da0073e9SAndroid Build Coastguard Worker                with TestTagsMode():
2085*da0073e9SAndroid Build Coastguard Worker                    rs = op(input, *sample.args, **sample.kwargs)
2086*da0073e9SAndroid Build Coastguard Worker                # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
2087*da0073e9SAndroid Build Coastguard Worker                aten_name = op.aten_name if op.aten_name is not None else op.name
2088*da0073e9SAndroid Build Coastguard Worker                opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
2089*da0073e9SAndroid Build Coastguard Worker                check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
2090*da0073e9SAndroid Build Coastguard Worker
2091*da0073e9SAndroid Build Coastguard Worker
2092*da0073e9SAndroid Build Coastguard Workerclass TestSelfKwarg(TestCase):
2093*da0073e9SAndroid Build Coastguard Worker    def test_self_kwargs(self):
2094*da0073e9SAndroid Build Coastguard Worker        """Verify that we can call the aten ops with all kwargs even if the
2095*da0073e9SAndroid Build Coastguard Worker        argument's name is "self"
2096*da0073e9SAndroid Build Coastguard Worker        """
2097*da0073e9SAndroid Build Coastguard Worker        torch.ops.aten.reshape.default(self=torch.rand(1, 2), shape=[2])
2098*da0073e9SAndroid Build Coastguard Worker        torch.ops.aten.min.default(self=torch.rand(100))
2099*da0073e9SAndroid Build Coastguard Worker
2100*da0073e9SAndroid Build Coastguard Worker
2101*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest
2102*da0073e9SAndroid Build Coastguard Workerclass TestRefsOpsInfo(TestCase):
2103*da0073e9SAndroid Build Coastguard Worker    import_paths = [
2104*da0073e9SAndroid Build Coastguard Worker        "_refs",
2105*da0073e9SAndroid Build Coastguard Worker        "_refs.special",
2106*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional",
2107*da0073e9SAndroid Build Coastguard Worker        "_refs.fft",
2108*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions",
2109*da0073e9SAndroid Build Coastguard Worker    ]
2110*da0073e9SAndroid Build Coastguard Worker    module_alls = [
2111*da0073e9SAndroid Build Coastguard Worker        (path, import_module(f"torch.{path}").__all__) for path in import_paths
2112*da0073e9SAndroid Build Coastguard Worker    ]
2113*da0073e9SAndroid Build Coastguard Worker    ref_ops_names = tuple(
2114*da0073e9SAndroid Build Coastguard Worker        itertools.chain.from_iterable(
2115*da0073e9SAndroid Build Coastguard Worker            [f"{path}.{op}" for op in module_all] for path, module_all in module_alls
2116*da0073e9SAndroid Build Coastguard Worker        )
2117*da0073e9SAndroid Build Coastguard Worker    )
2118*da0073e9SAndroid Build Coastguard Worker    ref_db_names = {ref_op.name for ref_op in python_ref_db}
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker    # TODO: References that do not have an entry in python_ref_db
2121*da0073e9SAndroid Build Coastguard Worker    skip_ref_ops = {
2122*da0073e9SAndroid Build Coastguard Worker        "_refs.alias",
2123*da0073e9SAndroid Build Coastguard Worker        "_refs.bitwise_right_shift",
2124*da0073e9SAndroid Build Coastguard Worker        "_refs.copy_to",
2125*da0073e9SAndroid Build Coastguard Worker        "_refs.empty_permuted",
2126*da0073e9SAndroid Build Coastguard Worker        "_refs.empty_strided",
2127*da0073e9SAndroid Build Coastguard Worker        "_refs.equal",
2128*da0073e9SAndroid Build Coastguard Worker        "_refs.full",
2129*da0073e9SAndroid Build Coastguard Worker        "_refs.full_like",
2130*da0073e9SAndroid Build Coastguard Worker        "_refs.is_complex",
2131*da0073e9SAndroid Build Coastguard Worker        "_refs.to",
2132*da0073e9SAndroid Build Coastguard Worker        "_refs.mvlgamma",
2133*da0073e9SAndroid Build Coastguard Worker        "_refs.ones",
2134*da0073e9SAndroid Build Coastguard Worker        "_refs.ones_like",
2135*da0073e9SAndroid Build Coastguard Worker        "_refs.special.expit",
2136*da0073e9SAndroid Build Coastguard Worker        "_refs.std_var",
2137*da0073e9SAndroid Build Coastguard Worker        "_refs.swap_axes",
2138*da0073e9SAndroid Build Coastguard Worker        "_refs.uniform",
2139*da0073e9SAndroid Build Coastguard Worker        "_refs.scalar_tensor",
2140*da0073e9SAndroid Build Coastguard Worker        "_refs.trunc_divide",
2141*da0073e9SAndroid Build Coastguard Worker        "_refs.zero",
2142*da0073e9SAndroid Build Coastguard Worker        "_refs.zeros",
2143*da0073e9SAndroid Build Coastguard Worker        "_refs.zeros_like",
2144*da0073e9SAndroid Build Coastguard Worker        "_refs.rfloordiv",
2145*da0073e9SAndroid Build Coastguard Worker        "_refs.rtruediv",
2146*da0073e9SAndroid Build Coastguard Worker        "_refs.rpow",
2147*da0073e9SAndroid Build Coastguard Worker        # These should be tested with their out-of-place counterparts
2148*da0073e9SAndroid Build Coastguard Worker        "_refs.index_add_",
2149*da0073e9SAndroid Build Coastguard Worker        "_refs.index_copy_",
2150*da0073e9SAndroid Build Coastguard Worker        "_refs.index_fill_",
2151*da0073e9SAndroid Build Coastguard Worker        "_refs.native_group_norm",
2152*da0073e9SAndroid Build Coastguard Worker    }
2153*da0073e9SAndroid Build Coastguard Worker
2154*da0073e9SAndroid Build Coastguard Worker    not_in_decomp_table = {
2155*da0073e9SAndroid Build Coastguard Worker        # duplicated in _decomp and _refs
2156*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.group_norm",
2157*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.mse_loss",
2158*da0073e9SAndroid Build Coastguard Worker        "_refs.floor_divide",
2159*da0073e9SAndroid Build Coastguard Worker        # duplicated as refs do not have decent support for advanced indexing
2160*da0073e9SAndroid Build Coastguard Worker        "_refs.index_copy",
2161*da0073e9SAndroid Build Coastguard Worker        "_refs.index_copy_",
2162*da0073e9SAndroid Build Coastguard Worker        "_refs.index_add",
2163*da0073e9SAndroid Build Coastguard Worker        "_refs.index_add_",
2164*da0073e9SAndroid Build Coastguard Worker        # these are not aten ops?
2165*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.bfloat16",
2166*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.bool",
2167*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.byte",
2168*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.char",
2169*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.double",
2170*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.float",
2171*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.half",
2172*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.int",
2173*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.long",
2174*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.short",
2175*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.chalf",
2176*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.cfloat",
2177*da0073e9SAndroid Build Coastguard Worker        "_refs._conversions.cdouble",
2178*da0073e9SAndroid Build Coastguard Worker        "_refs.broadcast_shapes",
2179*da0073e9SAndroid Build Coastguard Worker        "_refs.broadcast_tensors",
2180*da0073e9SAndroid Build Coastguard Worker        "_refs.mvlgamma",
2181*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.layer_norm",
2182*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.tanhshrink",
2183*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.triplet_margin_loss",
2184*da0073e9SAndroid Build Coastguard Worker        "_refs.rfloordiv",
2185*da0073e9SAndroid Build Coastguard Worker        "_refs.rtruediv",
2186*da0073e9SAndroid Build Coastguard Worker        "_refs.rpow",
2187*da0073e9SAndroid Build Coastguard Worker        # CompositeImplicitAutograd
2188*da0073e9SAndroid Build Coastguard Worker        "_refs.allclose",
2189*da0073e9SAndroid Build Coastguard Worker        "_refs.atleast_1d",
2190*da0073e9SAndroid Build Coastguard Worker        "_refs.atleast_2d",
2191*da0073e9SAndroid Build Coastguard Worker        "_refs.atleast_3d",
2192*da0073e9SAndroid Build Coastguard Worker        "_refs.broadcast_to",
2193*da0073e9SAndroid Build Coastguard Worker        "_refs.chunk",
2194*da0073e9SAndroid Build Coastguard Worker        "_refs.column_stack",
2195*da0073e9SAndroid Build Coastguard Worker        "_refs.contiguous",
2196*da0073e9SAndroid Build Coastguard Worker        "_refs.dsplit",
2197*da0073e9SAndroid Build Coastguard Worker        "_refs.dstack",
2198*da0073e9SAndroid Build Coastguard Worker        "_refs.fill",
2199*da0073e9SAndroid Build Coastguard Worker        "_refs.fill_",
2200*da0073e9SAndroid Build Coastguard Worker        "_refs.flatten",
2201*da0073e9SAndroid Build Coastguard Worker        "_refs.fliplr",
2202*da0073e9SAndroid Build Coastguard Worker        "_refs.flipud",
2203*da0073e9SAndroid Build Coastguard Worker        "_refs.float_power",
2204*da0073e9SAndroid Build Coastguard Worker        "_refs.hsplit",
2205*da0073e9SAndroid Build Coastguard Worker        "_refs.hstack",
2206*da0073e9SAndroid Build Coastguard Worker        "_refs.isclose",
2207*da0073e9SAndroid Build Coastguard Worker        "_refs.isfinite",
2208*da0073e9SAndroid Build Coastguard Worker        "_refs.isreal",
2209*da0073e9SAndroid Build Coastguard Worker        "_refs.istft",
2210*da0073e9SAndroid Build Coastguard Worker        "_refs.log_softmax",
2211*da0073e9SAndroid Build Coastguard Worker        "_refs.movedim",
2212*da0073e9SAndroid Build Coastguard Worker        "_refs.narrow",
2213*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.dropout",
2214*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.l1_loss",
2215*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.smooth_l1_loss",
2216*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.log_softmax",
2217*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.poisson_nll_loss",
2218*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.softmax",
2219*da0073e9SAndroid Build Coastguard Worker        "_refs.nn.functional.softmin",
2220*da0073e9SAndroid Build Coastguard Worker        "_refs.positive",
2221*da0073e9SAndroid Build Coastguard Worker        "_refs.ravel",
2222*da0073e9SAndroid Build Coastguard Worker        "_refs.reshape",
2223*da0073e9SAndroid Build Coastguard Worker        "_refs.softmax",
2224*da0073e9SAndroid Build Coastguard Worker        "_refs.special.expit",
2225*da0073e9SAndroid Build Coastguard Worker        "_refs.special.log_softmax",
2226*da0073e9SAndroid Build Coastguard Worker        "_refs.special.softmax",
2227*da0073e9SAndroid Build Coastguard Worker        "_refs.square",
2228*da0073e9SAndroid Build Coastguard Worker        "_refs.stft",
2229*da0073e9SAndroid Build Coastguard Worker        "_refs.T",
2230*da0073e9SAndroid Build Coastguard Worker        "_refs.take_along_dim",
2231*da0073e9SAndroid Build Coastguard Worker        "_refs.tensor_split",
2232*da0073e9SAndroid Build Coastguard Worker        "_refs.to",
2233*da0073e9SAndroid Build Coastguard Worker        "_refs.true_divide",
2234*da0073e9SAndroid Build Coastguard Worker        "_refs.trunc_divide",
2235*da0073e9SAndroid Build Coastguard Worker        "_refs.vsplit",
2236*da0073e9SAndroid Build Coastguard Worker        "_refs.vstack",
2237*da0073e9SAndroid Build Coastguard Worker        "_refs.linalg.matrix_norm",
2238*da0073e9SAndroid Build Coastguard Worker        "_refs.linalg.norm",
2239*da0073e9SAndroid Build Coastguard Worker        "_refs.linalg.svd",
2240*da0073e9SAndroid Build Coastguard Worker        "_refs.linalg.svdvals",
2241*da0073e9SAndroid Build Coastguard Worker        "_refs.unflatten",
2242*da0073e9SAndroid Build Coastguard Worker        "_refs.sum_to_size",
2243*da0073e9SAndroid Build Coastguard Worker        # ref implementation missing kwargs
2244*da0073e9SAndroid Build Coastguard Worker        "_refs.full_like",  # missing "layout"
2245*da0073e9SAndroid Build Coastguard Worker        "_refs.scalar_tensor",  # missing "layout"
2246*da0073e9SAndroid Build Coastguard Worker        # other
2247*da0073e9SAndroid Build Coastguard Worker        "_refs.block_diag",  # only refs._block_diag_iterable is in decomposition table
2248*da0073e9SAndroid Build Coastguard Worker        "_refs.empty",  # intentional; direct empty is faster and has less guards
2249*da0073e9SAndroid Build Coastguard Worker        "_refs.empty_permuted",  # intentional; direct empty is faster and has less guards
2250*da0073e9SAndroid Build Coastguard Worker        "_refs.expand_as",
2251*da0073e9SAndroid Build Coastguard Worker        "_refs.as_strided",  # _prims._as_strided_meta: "reduce() of empty sequence with no initial value"
2252*da0073e9SAndroid Build Coastguard Worker        "_refs.copy_to",  # torch._C._jit_get_operation: No such operator aten::copy_to
2253*da0073e9SAndroid Build Coastguard Worker        "_refs.equal",  # 'bool' object has no attribute 'dtype'
2254*da0073e9SAndroid Build Coastguard Worker        "_refs.conj",  # Calls _prims.conj
2255*da0073e9SAndroid Build Coastguard Worker        "_refs.real",
2256*da0073e9SAndroid Build Coastguard Worker        "_refs.imag",
2257*da0073e9SAndroid Build Coastguard Worker        "_refs.reshape_as",
2258*da0073e9SAndroid Build Coastguard Worker        "_refs.view_as",
2259*da0073e9SAndroid Build Coastguard Worker        "_refs.view_as_complex",  # TorchInductor does not support complex at the moment.
2260*da0073e9SAndroid Build Coastguard Worker        # the decompositions for these ops are slightly different
2261*da0073e9SAndroid Build Coastguard Worker        # because of out handling
2262*da0073e9SAndroid Build Coastguard Worker        "_refs.var_mean",
2263*da0073e9SAndroid Build Coastguard Worker        "_refs.std_mean",
2264*da0073e9SAndroid Build Coastguard Worker        "_refs.native_layer_norm",
2265*da0073e9SAndroid Build Coastguard Worker    }
2266*da0073e9SAndroid Build Coastguard Worker
2267*da0073e9SAndroid Build Coastguard Worker    @parametrize("op", ref_ops_names)
2268*da0073e9SAndroid Build Coastguard Worker    def test_refs_are_in_python_ref_db(self, op):
2269*da0073e9SAndroid Build Coastguard Worker        inplace = op[-1] == "_"
2270*da0073e9SAndroid Build Coastguard Worker        if op in self.skip_ref_ops:
2271*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest(f"{op} does not have an entry in python_ref_db")
2272*da0073e9SAndroid Build Coastguard Worker        elif inplace:
2273*da0073e9SAndroid Build Coastguard Worker            self.assertNotIn(
2274*da0073e9SAndroid Build Coastguard Worker                op,
2275*da0073e9SAndroid Build Coastguard Worker                self.ref_db_names,
2276*da0073e9SAndroid Build Coastguard Worker                msg=f"{op} is an in-place operation and should not have an OpInfo",
2277*da0073e9SAndroid Build Coastguard Worker            )
2278*da0073e9SAndroid Build Coastguard Worker        else:
2279*da0073e9SAndroid Build Coastguard Worker            # Intentionally don't use assertIn to avoid printing the
2280*da0073e9SAndroid Build Coastguard Worker            # (very large) container
2281*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(op in self.ref_db_names, msg=f"{op} not in ref_db_names")
2282*da0073e9SAndroid Build Coastguard Worker
2283*da0073e9SAndroid Build Coastguard Worker    @parametrize("op", ref_ops_names)
2284*da0073e9SAndroid Build Coastguard Worker    def test_refs_are_in_decomp_table(self, op):
2285*da0073e9SAndroid Build Coastguard Worker        path = op.split(".")
2286*da0073e9SAndroid Build Coastguard Worker        module_path = ".".join(path[:-1])
2287*da0073e9SAndroid Build Coastguard Worker        op_name = path[-1]
2288*da0073e9SAndroid Build Coastguard Worker        op_impl = getattr(import_module(f"torch.{module_path}"), op_name)
2289*da0073e9SAndroid Build Coastguard Worker
2290*da0073e9SAndroid Build Coastguard Worker        if op in self.not_in_decomp_table:
2291*da0073e9SAndroid Build Coastguard Worker            self.assertNotIn(
2292*da0073e9SAndroid Build Coastguard Worker                op_impl,
2293*da0073e9SAndroid Build Coastguard Worker                torch._decomp.decomposition_table.values(),
2294*da0073e9SAndroid Build Coastguard Worker                f"Unexpectedly found {op} in torch._decomp.decomposition_table.values()",
2295*da0073e9SAndroid Build Coastguard Worker            )
2296*da0073e9SAndroid Build Coastguard Worker        else:
2297*da0073e9SAndroid Build Coastguard Worker            self.assertIn(
2298*da0073e9SAndroid Build Coastguard Worker                op_impl,
2299*da0073e9SAndroid Build Coastguard Worker                torch._decomp.decomposition_table.values(),
2300*da0073e9SAndroid Build Coastguard Worker                f"Did not find {op} in torch._decomp.decomposition_table.values()",
2301*da0073e9SAndroid Build Coastguard Worker            )
2302*da0073e9SAndroid Build Coastguard Worker
2303*da0073e9SAndroid Build Coastguard Worker
2304*da0073e9SAndroid Build Coastguard Workerfake_skips = (
2305*da0073e9SAndroid Build Coastguard Worker    "aminmax",  # failing input
2306*da0073e9SAndroid Build Coastguard Worker    "cov",  # aweights cannot be negtaive
2307*da0073e9SAndroid Build Coastguard Worker    "istft",  # window overlap add min: 0
2308*da0073e9SAndroid Build Coastguard Worker    "linalg.eigvals",  # The tensor has a non-zero number of elements, but its data is not allocated yet
2309*da0073e9SAndroid Build Coastguard Worker    "linalg.eigvalsh",  # aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend
2310*da0073e9SAndroid Build Coastguard Worker    "linalg.matrix_power",  # Could not run 'aten::eye.m_out' with arguments from the 'Meta' backend
2311*da0073e9SAndroid Build Coastguard Worker    # "linalg.pinv",  # Could not run 'aten::pinv.out' with arguments from the 'Meta' backen
2312*da0073e9SAndroid Build Coastguard Worker    "linalg.matrix_rank.hermitian",  # Could not run 'aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend
2313*da0073e9SAndroid Build Coastguard Worker    "linalg.pinv.hermitian",  # tensor.mH is only supported on matrices or batches of matrices. Got 1-D tensor
2314*da0073e9SAndroid Build Coastguard Worker    "linalg.solve",  # Could not run 'aten::linalg_solve' with arguments from the 'Meta' backend
2315*da0073e9SAndroid Build Coastguard Worker    "linalg.tensorsolve",  # Could not run 'aten::linalg_solve' with arguments from the 'Meta'
2316*da0073e9SAndroid Build Coastguard Worker    "lu_solve",  # MALLOC ERROR: debug
2317*da0073e9SAndroid Build Coastguard Worker    "multinomial",  # Could not run 'aten::multinomial' with arguments from the 'Meta' backend
2318*da0073e9SAndroid Build Coastguard Worker    "mvlgamma.mvlgamma_p_1",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
2319*da0073e9SAndroid Build Coastguard Worker    "mvlgamma.mvlgamma_p_3",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
2320*da0073e9SAndroid Build Coastguard Worker    "mvlgamma.mvlgamma_p_5",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
2321*da0073e9SAndroid Build Coastguard Worker    "nanmean",  # logical_not() got an unexpected keyword argument 'out'
2322*da0073e9SAndroid Build Coastguard Worker    "quantile",  # quantile() q values must be in the range [0, 1]
2323*da0073e9SAndroid Build Coastguard Worker    "nanquantile",  # quantile() q values must be in the range [0, 1]
2324*da0073e9SAndroid Build Coastguard Worker    "nn.functional.ctc_loss",  # The tensor has a non-zero number of elements, but its data is not allocated yet
2325*da0073e9SAndroid Build Coastguard Worker    "nn.functional.embedding_bag",  # sometimes errors
2326*da0073e9SAndroid Build Coastguard Worker    "nn.functional.nll_loss",  # sometimes errors
2327*da0073e9SAndroid Build Coastguard Worker    "nn.functional.max_pool1d",  # The tensor has a non-zero number of elements
2328*da0073e9SAndroid Build Coastguard Worker    "to_sparse",  # Could not run 'aten::_to_sparse' with arguments from the 'Meta' backend
2329*da0073e9SAndroid Build Coastguard Worker    "tensor_split",  # The tensor has a non-zero number of elements, but its data is not allocated yet
2330*da0073e9SAndroid Build Coastguard Worker    "repeat_interleave",  # cannot repeat_interleave a meta tensor without output_size
2331*da0073e9SAndroid Build Coastguard Worker    "sparse.sampled.addmm",  # sparsity not supported
2332*da0073e9SAndroid Build Coastguard Worker    # Can not infer total number of classes from meta. no way at present to throw DynamicOutputShapeException
2333*da0073e9SAndroid Build Coastguard Worker    "nn.functional.one_hot",
2334*da0073e9SAndroid Build Coastguard Worker    "narrow",  # Fails only for one overload with DataDependentOutputException (hence skip).
2335*da0073e9SAndroid Build Coastguard Worker)
2336*da0073e9SAndroid Build Coastguard Worker
2337*da0073e9SAndroid Build Coastguard Workerfake_autocast_device_skips = defaultdict(dict)
2338*da0073e9SAndroid Build Coastguard Worker
2339*da0073e9SAndroid Build Coastguard Worker# TODO: investigate/fix
2340*da0073e9SAndroid Build Coastguard Workerfake_autocast_device_skips["cpu"] = {"linalg.pinv"}
2341*da0073e9SAndroid Build Coastguard Workerfake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"}
2342*da0073e9SAndroid Build Coastguard Worker
2343*da0073e9SAndroid Build Coastguard Worker
2344*da0073e9SAndroid Build Coastguard Workerdynamic_output_op_tests = (
2345*da0073e9SAndroid Build Coastguard Worker    "argwhere",
2346*da0073e9SAndroid Build Coastguard Worker    "bincount",
2347*da0073e9SAndroid Build Coastguard Worker    "combinations",
2348*da0073e9SAndroid Build Coastguard Worker    "linalg.lstsq",
2349*da0073e9SAndroid Build Coastguard Worker    "masked_select",
2350*da0073e9SAndroid Build Coastguard Worker    "nonzero",
2351*da0073e9SAndroid Build Coastguard Worker    "unique_consecutive",
2352*da0073e9SAndroid Build Coastguard Worker    "unique",
2353*da0073e9SAndroid Build Coastguard Worker    "linalg.lstsq.grad_oriented",
2354*da0073e9SAndroid Build Coastguard Worker)
2355*da0073e9SAndroid Build Coastguard Worker
2356*da0073e9SAndroid Build Coastguard Worker# Ops that have dynamic output shapes that we can handle when
2357*da0073e9SAndroid Build Coastguard Worker# allow_dynamic_shape_ops is True in fake tensor shape environment.
2358*da0073e9SAndroid Build Coastguard Workersupported_dynamic_output_op_tests = (
2359*da0073e9SAndroid Build Coastguard Worker    "nonzero",
2360*da0073e9SAndroid Build Coastguard Worker    "unique",
2361*da0073e9SAndroid Build Coastguard Worker    "repeat_interleave",
2362*da0073e9SAndroid Build Coastguard Worker    "masked_select",
2363*da0073e9SAndroid Build Coastguard Worker)
2364*da0073e9SAndroid Build Coastguard Worker
2365*da0073e9SAndroid Build Coastguard Worker# some inputs invoke dynamic output shape operators, some do not
2366*da0073e9SAndroid Build Coastguard Workersometimes_dynamic_output_op_test = ("__getitem__", "index_select")
2367*da0073e9SAndroid Build Coastguard Worker
2368*da0073e9SAndroid Build Coastguard Workerdata_dependent_op_tests = (
2369*da0073e9SAndroid Build Coastguard Worker    "equal",
2370*da0073e9SAndroid Build Coastguard Worker    "corrcoef",
2371*da0073e9SAndroid Build Coastguard Worker    "nn.functional.gaussian_nll_loss",
2372*da0073e9SAndroid Build Coastguard Worker    "allclose",
2373*da0073e9SAndroid Build Coastguard Worker)
2374*da0073e9SAndroid Build Coastguard Worker
2375*da0073e9SAndroid Build Coastguard Workeraliasing_failures = ("histogramdd",)
2376*da0073e9SAndroid Build Coastguard Worker
2377*da0073e9SAndroid Build Coastguard Workerfake_backward_skips = {
2378*da0073e9SAndroid Build Coastguard Worker    "linalg.cond",
2379*da0073e9SAndroid Build Coastguard Worker    "linalg.matrix_norm",
2380*da0073e9SAndroid Build Coastguard Worker    "linalg.norm",
2381*da0073e9SAndroid Build Coastguard Worker    "linalg.svd",
2382*da0073e9SAndroid Build Coastguard Worker    "linalg.svdvals",
2383*da0073e9SAndroid Build Coastguard Worker    "pca_lowrank",
2384*da0073e9SAndroid Build Coastguard Worker    "roll",
2385*da0073e9SAndroid Build Coastguard Worker    "svd_lowrank",
2386*da0073e9SAndroid Build Coastguard Worker    "sgn",
2387*da0073e9SAndroid Build Coastguard Worker}
2388*da0073e9SAndroid Build Coastguard Worker
2389*da0073e9SAndroid Build Coastguard Workerfake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
2390*da0073e9SAndroid Build Coastguard Worker    xfail("fft.ihfftn"),  # Mismatch in aten._conj_physical.default
2391*da0073e9SAndroid Build Coastguard Worker    xfail("fft.ihfft2"),  # Mismatch in aten._conj_physical.default
2392*da0073e9SAndroid Build Coastguard Worker    skip("nn.functional.ctc_loss"),
2393*da0073e9SAndroid Build Coastguard Worker}
2394*da0073e9SAndroid Build Coastguard Worker
2395*da0073e9SAndroid Build Coastguard Workerfake_autocast_backward_xfails = {
2396*da0073e9SAndroid Build Coastguard Worker    skip("nn.functional.binary_cross_entropy"),
2397*da0073e9SAndroid Build Coastguard Worker    skip("sparse.sampled_addmm"),
2398*da0073e9SAndroid Build Coastguard Worker    skip("linalg.pinv"),
2399*da0073e9SAndroid Build Coastguard Worker    skip("linalg.pinv", "hermitian"),
2400*da0073e9SAndroid Build Coastguard Worker    skip("linalg.pinv", "singular"),
2401*da0073e9SAndroid Build Coastguard Worker    skip("pinverse"),
2402*da0073e9SAndroid Build Coastguard Worker}
2403*da0073e9SAndroid Build Coastguard Worker
2404*da0073e9SAndroid Build Coastguard Worker
2405*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest
2406*da0073e9SAndroid Build Coastguard Workerclass TestFakeTensor(TestCase):
2407*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
2408*da0073e9SAndroid Build Coastguard Worker        # Turn on FakeTensor caching and cross-checking for these tests:
2409*da0073e9SAndroid Build Coastguard Worker        cache_enabled = unittest.mock.patch(
2410*da0073e9SAndroid Build Coastguard Worker            "torch._dynamo.config.fake_tensor_cache_enabled", True
2411*da0073e9SAndroid Build Coastguard Worker        )
2412*da0073e9SAndroid Build Coastguard Worker        cache_enabled.start()
2413*da0073e9SAndroid Build Coastguard Worker        self.addCleanup(cache_enabled.stop)
2414*da0073e9SAndroid Build Coastguard Worker
2415*da0073e9SAndroid Build Coastguard Worker        cache_crosscheck = unittest.mock.patch(
2416*da0073e9SAndroid Build Coastguard Worker            "torch._dynamo.config.fake_tensor_cache_crosscheck_enabled", True
2417*da0073e9SAndroid Build Coastguard Worker        )
2418*da0073e9SAndroid Build Coastguard Worker        cache_crosscheck.start()
2419*da0073e9SAndroid Build Coastguard Worker        self.addCleanup(cache_crosscheck.stop)
2420*da0073e9SAndroid Build Coastguard Worker
2421*da0073e9SAndroid Build Coastguard Worker    def _test_fake_helper(self, device, dtype, op, context):
2422*da0073e9SAndroid Build Coastguard Worker        name = op.name
2423*da0073e9SAndroid Build Coastguard Worker        if op.variant_test_name:
2424*da0073e9SAndroid Build Coastguard Worker            name += "." + op.variant_test_name
2425*da0073e9SAndroid Build Coastguard Worker        if name in fake_skips or "sparse" in name or "jiterator" in name:
2426*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skip failing test")
2427*da0073e9SAndroid Build Coastguard Worker
2428*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=False)
2429*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
2430*da0073e9SAndroid Build Coastguard Worker            mode = FakeTensorMode()
2431*da0073e9SAndroid Build Coastguard Worker
2432*da0073e9SAndroid Build Coastguard Worker            from torch.fx.experimental.symbolic_shapes import ShapeEnv
2433*da0073e9SAndroid Build Coastguard Worker
2434*da0073e9SAndroid Build Coastguard Worker            allow_dynamic_output_shape_shape_env = ShapeEnv(
2435*da0073e9SAndroid Build Coastguard Worker                allow_dynamic_output_shape_ops=True
2436*da0073e9SAndroid Build Coastguard Worker            )
2437*da0073e9SAndroid Build Coastguard Worker
2438*da0073e9SAndroid Build Coastguard Worker            allow_dynamic_output_shape_mode = FakeTensorMode(
2439*da0073e9SAndroid Build Coastguard Worker                shape_env=allow_dynamic_output_shape_shape_env
2440*da0073e9SAndroid Build Coastguard Worker            )
2441*da0073e9SAndroid Build Coastguard Worker
2442*da0073e9SAndroid Build Coastguard Worker            try:
2443*da0073e9SAndroid Build Coastguard Worker                with context():
2444*da0073e9SAndroid Build Coastguard Worker                    res = op(sample.input, *sample.args, **sample.kwargs)
2445*da0073e9SAndroid Build Coastguard Worker            except Exception:
2446*da0073e9SAndroid Build Coastguard Worker                continue
2447*da0073e9SAndroid Build Coastguard Worker
2448*da0073e9SAndroid Build Coastguard Worker            def run_with_fake_mode_and_verify(fake_mode, match_results=True):
2449*da0073e9SAndroid Build Coastguard Worker                def map_to_fake(e):
2450*da0073e9SAndroid Build Coastguard Worker                    if isinstance(e, torch.Tensor):
2451*da0073e9SAndroid Build Coastguard Worker                        return fake_mode.from_tensor(e)
2452*da0073e9SAndroid Build Coastguard Worker                    else:
2453*da0073e9SAndroid Build Coastguard Worker                        return e
2454*da0073e9SAndroid Build Coastguard Worker
2455*da0073e9SAndroid Build Coastguard Worker                input = tree_map(map_to_fake, sample.input)
2456*da0073e9SAndroid Build Coastguard Worker                args = tree_map(map_to_fake, sample.args)
2457*da0073e9SAndroid Build Coastguard Worker                kwargs = tree_map(map_to_fake, sample.kwargs)
2458*da0073e9SAndroid Build Coastguard Worker
2459*da0073e9SAndroid Build Coastguard Worker                try:
2460*da0073e9SAndroid Build Coastguard Worker                    with context():
2461*da0073e9SAndroid Build Coastguard Worker                        with fake_mode:
2462*da0073e9SAndroid Build Coastguard Worker                            res_fake = op(input, *args, **kwargs)
2463*da0073e9SAndroid Build Coastguard Worker
2464*da0073e9SAndroid Build Coastguard Worker                    if not match_results:
2465*da0073e9SAndroid Build Coastguard Worker                        return
2466*da0073e9SAndroid Build Coastguard Worker
2467*da0073e9SAndroid Build Coastguard Worker                    for fake_out, real_out in zip(
2468*da0073e9SAndroid Build Coastguard Worker                        pytree.tree_leaves(res_fake), pytree.tree_leaves(res)
2469*da0073e9SAndroid Build Coastguard Worker                    ):
2470*da0073e9SAndroid Build Coastguard Worker                        if not isinstance(fake_out, torch.Tensor):
2471*da0073e9SAndroid Build Coastguard Worker                            self.assertTrue(not isinstance(real_out, torch.Tensor))
2472*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(fake_out, real_out)
2473*da0073e9SAndroid Build Coastguard Worker                            continue
2474*da0073e9SAndroid Build Coastguard Worker
2475*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(isinstance(fake_out, FakeTensor))
2476*da0073e9SAndroid Build Coastguard Worker                        # if you see a shape exception here, you may need to add
2477*da0073e9SAndroid Build Coastguard Worker                        # a `dynamic_output_shape` tag to an operator
2478*da0073e9SAndroid Build Coastguard Worker
2479*da0073e9SAndroid Build Coastguard Worker                        if op.op not in [
2480*da0073e9SAndroid Build Coastguard Worker                            torch.ops.aten._efficient_attention_forward,
2481*da0073e9SAndroid Build Coastguard Worker                            torch.ops.aten._flash_attention_forward,
2482*da0073e9SAndroid Build Coastguard Worker                        ]:
2483*da0073e9SAndroid Build Coastguard Worker                            # prims/decomps must correctly model strides,
2484*da0073e9SAndroid Build Coastguard Worker                            # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
2485*da0073e9SAndroid Build Coastguard Worker
2486*da0073e9SAndroid Build Coastguard Worker                            # note: the excluded ops have intentionally incorrect device;
2487*da0073e9SAndroid Build Coastguard Worker                            # see "Note [Seed and Offset]" (_meta_registrations.py)
2488*da0073e9SAndroid Build Coastguard Worker                            prims.utils.compare_tensor_meta(fake_out, real_out, True)
2489*da0073e9SAndroid Build Coastguard Worker
2490*da0073e9SAndroid Build Coastguard Worker                        if name not in aliasing_failures:
2491*da0073e9SAndroid Build Coastguard Worker                            fake_aliasing = outputs_alias_inputs(
2492*da0073e9SAndroid Build Coastguard Worker                                (input, args, kwargs), res_fake
2493*da0073e9SAndroid Build Coastguard Worker                            )
2494*da0073e9SAndroid Build Coastguard Worker                            real_aliasing = outputs_alias_inputs(
2495*da0073e9SAndroid Build Coastguard Worker                                (sample.input, sample, args, sample.kwargs), res
2496*da0073e9SAndroid Build Coastguard Worker                            )
2497*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(fake_aliasing, real_aliasing)
2498*da0073e9SAndroid Build Coastguard Worker
2499*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
2500*da0073e9SAndroid Build Coastguard Worker                        name not in dynamic_output_op_tests
2501*da0073e9SAndroid Build Coastguard Worker                        and name not in data_dependent_op_tests
2502*da0073e9SAndroid Build Coastguard Worker                    )
2503*da0073e9SAndroid Build Coastguard Worker
2504*da0073e9SAndroid Build Coastguard Worker                except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
2505*da0073e9SAndroid Build Coastguard Worker                    pass
2506*da0073e9SAndroid Build Coastguard Worker                except torch._subclasses.fake_tensor.UnsupportedOperatorException:
2507*da0073e9SAndroid Build Coastguard Worker                    pass
2508*da0073e9SAndroid Build Coastguard Worker                except torch._subclasses.fake_tensor.DynamicOutputShapeException:
2509*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
2510*da0073e9SAndroid Build Coastguard Worker                        name in dynamic_output_op_tests
2511*da0073e9SAndroid Build Coastguard Worker                        or name in sometimes_dynamic_output_op_test
2512*da0073e9SAndroid Build Coastguard Worker                    )
2513*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
2514*da0073e9SAndroid Build Coastguard Worker                        fake_mode.shape_env is None
2515*da0073e9SAndroid Build Coastguard Worker                        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
2516*da0073e9SAndroid Build Coastguard Worker                        or name not in supported_dynamic_output_op_tests
2517*da0073e9SAndroid Build Coastguard Worker                    )
2518*da0073e9SAndroid Build Coastguard Worker                except torch._subclasses.fake_tensor.DataDependentOutputException:
2519*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(name in data_dependent_op_tests)
2520*da0073e9SAndroid Build Coastguard Worker
2521*da0073e9SAndroid Build Coastguard Worker            run_with_fake_mode_and_verify(mode)
2522*da0073e9SAndroid Build Coastguard Worker            if name in supported_dynamic_output_op_tests:
2523*da0073e9SAndroid Build Coastguard Worker                run_with_fake_mode_and_verify(
2524*da0073e9SAndroid Build Coastguard Worker                    allow_dynamic_output_shape_mode, match_results=False
2525*da0073e9SAndroid Build Coastguard Worker                )
2526*da0073e9SAndroid Build Coastguard Worker
2527*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, dtypes=OpDTypes.any_one)
2528*da0073e9SAndroid Build Coastguard Worker    def test_pointwise_ops(self, device, dtype, op):
2529*da0073e9SAndroid Build Coastguard Worker        name = op.name
2530*da0073e9SAndroid Build Coastguard Worker        if op.variant_test_name:
2531*da0073e9SAndroid Build Coastguard Worker            name += "." + op.variant_test_name
2532*da0073e9SAndroid Build Coastguard Worker        if name in fake_skips or "sparse" in name or "jiterator" in name:
2533*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skip failing test")
2534*da0073e9SAndroid Build Coastguard Worker
2535*da0073e9SAndroid Build Coastguard Worker        test_self = self
2536*da0073e9SAndroid Build Coastguard Worker
2537*da0073e9SAndroid Build Coastguard Worker        class TestPointwiseMode(TorchDispatchMode):
2538*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
2539*da0073e9SAndroid Build Coastguard Worker                kwargs = kwargs or {}
2540*da0073e9SAndroid Build Coastguard Worker
2541*da0073e9SAndroid Build Coastguard Worker                out = func(*args, **kwargs)
2542*da0073e9SAndroid Build Coastguard Worker
2543*da0073e9SAndroid Build Coastguard Worker                if torch.Tag.pointwise in func.tags:
2544*da0073e9SAndroid Build Coastguard Worker                    shapes = []
2545*da0073e9SAndroid Build Coastguard Worker                    for inp in pytree.arg_tree_leaves(*args, **kwargs):
2546*da0073e9SAndroid Build Coastguard Worker                        if isinstance(inp, torch.Tensor):
2547*da0073e9SAndroid Build Coastguard Worker                            shapes.append(inp.shape)
2548*da0073e9SAndroid Build Coastguard Worker
2549*da0073e9SAndroid Build Coastguard Worker                    out_shape = torch._refs._broadcast_shapes(*shapes)
2550*da0073e9SAndroid Build Coastguard Worker
2551*da0073e9SAndroid Build Coastguard Worker                    for out_elem in pytree.tree_leaves(out):
2552*da0073e9SAndroid Build Coastguard Worker                        if isinstance(out_elem, torch.Tensor):
2553*da0073e9SAndroid Build Coastguard Worker                            test_self.assertEqual(out_elem.shape, out_shape)
2554*da0073e9SAndroid Build Coastguard Worker
2555*da0073e9SAndroid Build Coastguard Worker                return out
2556*da0073e9SAndroid Build Coastguard Worker
2557*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=False)
2558*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
2559*da0073e9SAndroid Build Coastguard Worker            mode = FakeTensorMode()
2560*da0073e9SAndroid Build Coastguard Worker
2561*da0073e9SAndroid Build Coastguard Worker            def map_to_fake(e):
2562*da0073e9SAndroid Build Coastguard Worker                if isinstance(e, torch.Tensor):
2563*da0073e9SAndroid Build Coastguard Worker                    return mode.from_tensor(e)
2564*da0073e9SAndroid Build Coastguard Worker                else:
2565*da0073e9SAndroid Build Coastguard Worker                    return e
2566*da0073e9SAndroid Build Coastguard Worker
2567*da0073e9SAndroid Build Coastguard Worker            input = tree_map(map_to_fake, sample.input)
2568*da0073e9SAndroid Build Coastguard Worker            args = tree_map(map_to_fake, sample.args)
2569*da0073e9SAndroid Build Coastguard Worker            kwargs = tree_map(map_to_fake, sample.kwargs)
2570*da0073e9SAndroid Build Coastguard Worker
2571*da0073e9SAndroid Build Coastguard Worker            try:
2572*da0073e9SAndroid Build Coastguard Worker                op(input, *args, **kwargs)
2573*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
2574*da0073e9SAndroid Build Coastguard Worker                continue
2575*da0073e9SAndroid Build Coastguard Worker
2576*da0073e9SAndroid Build Coastguard Worker            with TestPointwiseMode():
2577*da0073e9SAndroid Build Coastguard Worker                with mode:
2578*da0073e9SAndroid Build Coastguard Worker                    op(input, *args, **kwargs)
2579*da0073e9SAndroid Build Coastguard Worker
2580*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, dtypes=OpDTypes.any_one)
2581*da0073e9SAndroid Build Coastguard Worker    def test_fake(self, device, dtype, op):
2582*da0073e9SAndroid Build Coastguard Worker        self._test_fake_helper(device, dtype, op, contextlib.nullcontext)
2583*da0073e9SAndroid Build Coastguard Worker
2584*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, dtypes=OpDTypes.any_one)
2585*da0073e9SAndroid Build Coastguard Worker    def test_fake_autocast(self, device, dtype, op):
2586*da0073e9SAndroid Build Coastguard Worker        device_type = torch.device(device).type
2587*da0073e9SAndroid Build Coastguard Worker        if op.name in fake_autocast_device_skips[device_type]:
2588*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skip failing test")
2589*da0073e9SAndroid Build Coastguard Worker
2590*da0073e9SAndroid Build Coastguard Worker        def context_fn():
2591*da0073e9SAndroid Build Coastguard Worker            return torch.amp.autocast(device_type)
2592*da0073e9SAndroid Build Coastguard Worker
2593*da0073e9SAndroid Build Coastguard Worker        self._test_fake_helper(device, dtype, op, context_fn)
2594*da0073e9SAndroid Build Coastguard Worker
2595*da0073e9SAndroid Build Coastguard Worker    def _test_fake_crossref_helper(self, device, dtype, op, context):
2596*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=True)
2597*da0073e9SAndroid Build Coastguard Worker
2598*da0073e9SAndroid Build Coastguard Worker        for iter, sample in enumerate(samples):
2599*da0073e9SAndroid Build Coastguard Worker            args = [sample.input] + list(sample.args)
2600*da0073e9SAndroid Build Coastguard Worker            kwargs = sample.kwargs
2601*da0073e9SAndroid Build Coastguard Worker
2602*da0073e9SAndroid Build Coastguard Worker            # skip these to speed up tests
2603*da0073e9SAndroid Build Coastguard Worker            common_skip_ops = (
2604*da0073e9SAndroid Build Coastguard Worker                aten.detach.default,
2605*da0073e9SAndroid Build Coastguard Worker                aten.empty_strided.default,
2606*da0073e9SAndroid Build Coastguard Worker                aten.copy_.default,
2607*da0073e9SAndroid Build Coastguard Worker                aten.is_same_size.default,
2608*da0073e9SAndroid Build Coastguard Worker            )
2609*da0073e9SAndroid Build Coastguard Worker
2610*da0073e9SAndroid Build Coastguard Worker            # TODO: enable check_aliasing, batch norm fails
2611*da0073e9SAndroid Build Coastguard Worker            try:
2612*da0073e9SAndroid Build Coastguard Worker                with torch._subclasses.CrossRefFakeMode(
2613*da0073e9SAndroid Build Coastguard Worker                    ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True
2614*da0073e9SAndroid Build Coastguard Worker                ):
2615*da0073e9SAndroid Build Coastguard Worker                    with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled(
2616*da0073e9SAndroid Build Coastguard Worker                        False
2617*da0073e9SAndroid Build Coastguard Worker                    ):
2618*da0073e9SAndroid Build Coastguard Worker                        composite_compliance.compute_expected_grads(
2619*da0073e9SAndroid Build Coastguard Worker                            op.get_op(),
2620*da0073e9SAndroid Build Coastguard Worker                            args,
2621*da0073e9SAndroid Build Coastguard Worker                            kwargs,
2622*da0073e9SAndroid Build Coastguard Worker                            sample.output_process_fn_grad,
2623*da0073e9SAndroid Build Coastguard Worker                            op.gradcheck_wrapper,
2624*da0073e9SAndroid Build Coastguard Worker                        )
2625*da0073e9SAndroid Build Coastguard Worker            except torch._subclasses.fake_tensor.UnsupportedOperatorException:
2626*da0073e9SAndroid Build Coastguard Worker                pass
2627*da0073e9SAndroid Build Coastguard Worker
2628*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
2629*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
2630*da0073e9SAndroid Build Coastguard Worker    @skipOps(
2631*da0073e9SAndroid Build Coastguard Worker        "TestFakeTensor", "test_fake_crossref_backward_no_amp", fake_backward_xfails
2632*da0073e9SAndroid Build Coastguard Worker    )
2633*da0073e9SAndroid Build Coastguard Worker    def test_fake_crossref_backward_no_amp(self, device, dtype, op):
2634*da0073e9SAndroid Build Coastguard Worker        self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext)
2635*da0073e9SAndroid Build Coastguard Worker
2636*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
2637*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
2638*da0073e9SAndroid Build Coastguard Worker    @skipOps(
2639*da0073e9SAndroid Build Coastguard Worker        "TestFakeTensor",
2640*da0073e9SAndroid Build Coastguard Worker        "test_fake_crossref_backward_amp",
2641*da0073e9SAndroid Build Coastguard Worker        fake_backward_xfails | fake_autocast_backward_xfails,
2642*da0073e9SAndroid Build Coastguard Worker    )
2643*da0073e9SAndroid Build Coastguard Worker    def test_fake_crossref_backward_amp(self, device, dtype, op):
2644*da0073e9SAndroid Build Coastguard Worker        self._test_fake_crossref_helper(device, dtype, op, torch.cuda.amp.autocast)
2645*da0073e9SAndroid Build Coastguard Worker
2646*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in ops_and_refs if op.is_factory_function])
2647*da0073e9SAndroid Build Coastguard Worker    def test_strided_layout(self, device, dtype, op):
2648*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype)
2649*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
2650*da0073e9SAndroid Build Coastguard Worker            kwargs = sample.kwargs.copy()
2651*da0073e9SAndroid Build Coastguard Worker            kwargs["layout"] = torch.strided
2652*da0073e9SAndroid Build Coastguard Worker            strided_result = op(sample.input, *sample.args, **kwargs)
2653*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(strided_result.layout, torch.strided)
2654*da0073e9SAndroid Build Coastguard Worker
2655*da0073e9SAndroid Build Coastguard Worker
2656*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestCommon, globals())
2657*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestCompositeCompliance, globals())
2658*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestMathBits, globals())
2659*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
2660*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestFakeTensor, globals())
2661*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTags, globals())
2662*da0073e9SAndroid Build Coastguard Worker
2663*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
2664*da0073e9SAndroid Build Coastguard Worker    TestCase._default_dtype_check_enabled = True
2665*da0073e9SAndroid Build Coastguard Worker    run_tests()
2666