# Owner(s): ["module: unknown"]

import contextlib
import copy
import inspect
import itertools
import os
import re
import unittest
import warnings
from collections import defaultdict
from collections.abc import Sequence
from functools import partial
from importlib import import_module
from typing import Dict, List

import torch
import torch._prims as prims
import torch.utils._pytree as pytree
from torch._prims.context import TorchRefsMode
from torch._prims_common.wrappers import _maybe_remove_out_wrapper
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._subclasses.fake_utils import outputs_alias_inputs
from torch.testing import make_tensor
from torch.testing._internal import composite_compliance, opinfo
from torch.testing._internal.common_device_type import (
    deviceCountAtLeast,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypesAnd,
    OpDTypes,
    ops,
    skipMeta,
)
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,
    floating_and_complex_types_and,
    integral_types_and,
)
from torch.testing._internal.common_methods_invocations import (
    BinaryUfuncInfo,
    op_db,
    ops_and_refs,
    python_ref_db,
    ReductionOpInfo,
    ReductionPythonRefInfo,
    skip,
    skipOps,
    SpectralFuncInfo,
    UnaryUfuncInfo,
    xfail,
)
from torch.testing._internal.common_utils import (
    clone_input_helper,
    first_sample,
    IS_CI,
    IS_FBCODE,
    is_iterable_of_tensors,
    IS_SANDCASTLE,
    IS_WINDOWS,
    noncontiguous_like,
    parametrize,
    run_tests,
    set_default_dtype,
    skipIfTorchInductor,
    slowTest,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TEST_WITH_TORCHDYNAMO,
    TEST_WITH_TORCHINDUCTOR,
    TEST_WITH_UBSAN,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map


assert torch.get_default_dtype() == torch.float32

# variant testing is only done with torch.float and torch.cfloat to avoid
#   excessive test times and maximize signal to noise ratio
_variant_ops = partial(
    ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
)

# Get names of all the operators which have ref in their entry in OpInfo (testing infra)
#   except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py),
#   elementwise binary operators (separately implemented in test_binary_ufuncs.py),
#   reduction operations (separately implemented in test_reductions.py),
#   and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py)
_ref_test_ops = tuple(
    filter(
        lambda op: not isinstance(
            op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
        )
        and op.ref is not None,
        op_db,
    )
)

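# Selects ReductionPythonRefInfo entries that support out=, support torch.int16,
# and whose underlying op takes a keyword-only `dtype` argument; these are the
# ops exercised by test_out_integral_dtype below.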
def reduction_dtype_filter(op):
    if (
        not isinstance(op, ReductionPythonRefInfo)
        or not op.supports_out
        or torch.int16 not in op.dtypes
    ):
        return False
    return "dtype" in inspect.getfullargspec(op.op).kwonlyargs


# Create a list of operators from ops_and_refs that don't have a NumPy ref to
# compare against. If both CPU and CUDA results are compared to NumPy, then
# they do not need to be compared to each other.
_ops_and_refs_with_no_numpy_ref = [op for op in ops_and_refs if op.ref is None]

aten = torch.ops.aten


# Tests that apply to all operators and aren't related to any particular
#   system
@unMarkDynamoStrictTest
class TestCommon(TestCase):
    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        if IS_CI:
            err_msg = (
                "The operator(s) below are using dynamic_dtypes in their OpInfo entries. "
                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
            )
            # Ensure no OpInfo entry has dynamic_dtypes
            filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db))
            for op in filtered_ops:
                fmt_str = opinfo.utils.str_format_dynamic_dtype(op)
                err_msg += "\n" + fmt_str

            assert len(filtered_ops) == 0, err_msg
    # Validates that each OpInfo works correctly on different CUDA devices
    @onlyCUDA
    @deviceCountAtLeast(2)
    @ops(op_db, allowed_dtypes=(torch.float32, torch.long))
    def test_multiple_devices(self, devices, dtype, op):
        for cuda_device_str in devices:
            cuda_device = torch.device(cuda_device_str)
            # NOTE: only tests on first sample
            samples = op.sample_inputs(cuda_device, dtype)
            sample = first_sample(self, samples)
            result = op(sample.input, *sample.args, **sample.kwargs)

            if isinstance(result, torch.Tensor):
                self.assertTrue(result.device == cuda_device)
            elif is_iterable_of_tensors(result):
                self.assertTrue(all(t.device == cuda_device for t in result))
            else:
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

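    # Checks that every kernel declared with DEFINE_DISPATCH in the pointwise
    # ATen source files below maps to an aten overload carrying the pointwise
    # tag, modulo an explicit allowlist of non-pointwise overloads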
    def test_pointwise_tag_coverage(self):
        pytorch_dir = os.path.abspath(__file__ + "/../../")
        files = [
            "aten/src/ATen/native/UnaryOps.cpp",
            "aten/src/ATen/native/BinaryOps.cpp",
            "aten/src/ATen/native/PointwiseOps.cpp",
            "aten/src/ATen/native/TensorCompare.cpp",
        ]

        allowed_functions = (
            # reduction version of these operators
            "aten.max.default",
            "aten.max.dim",
            "aten.max.dim_max",
            "aten.max.names_dim",
            "aten.max.names_dim_max",
            "aten.max.unary_out",
            "aten.min.default",
            "aten.min.dim",
            "aten.min.dim_min",
            "aten.min.names_dim",
            "aten.min.names_dim_min",
            "aten.min.unary_out",
            # not pointwise
            "aten.isin.Tensor_Tensor",
            "aten.isin.Tensor_Tensor_out",
            "aten.isin.Tensor_Scalar",
            "aten.isin.Tensor_Scalar_out",
            "aten.isin.Scalar_Tensor",
            "aten.isin.Scalar_Tensor_out",
            "aten.mode.default",
            "aten.mode.dimname",
            "aten.mode.dimname_out",
            "aten.mode.values",
        )

        regex = re.compile(r"DEFINE_DISPATCH\(.*_stub")

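        # Resolves a DEFINE_DISPATCH kernel name to an overload packet name on
        # torch.ops.aten by trying, in order: the name itself, its dunder form,
        # its special_-prefixed form, and the name with its final
        # underscore-separated token dropped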
        def get_opoverloadpacket_from_dispatch(kernel):
            if hasattr(torch.ops.aten, kernel):
                return kernel
            if hasattr(torch.ops.aten, f"__{kernel}__"):
                return f"__{kernel}__"
            if hasattr(torch.ops.aten, f"special_{kernel}"):
                return f"special_{kernel}"
            if "_" in kernel:
                kernel_split = kernel.split("_")
                new_kernel = "_".join(kernel_split[:-1])
                if hasattr(torch.ops.aten, new_kernel):
                    return new_kernel

            self.fail(f"could not find op from kernel dispatch string: {kernel}")

        for file_name in files:
            with open(os.path.join(pytorch_dir, file_name)) as f:
                lines = f.read()
                matches = regex.findall(lines)
                for match in matches:
                    kernel = match[len("DEFINE_DISPATCH(") : -len("_stub")]

                    # trigamma is defined with DEFINE_DISPATCH but has no op definition
                    if kernel == "trigamma":
                        continue

                    kernel = get_opoverloadpacket_from_dispatch(kernel)
                    overloadpacket = getattr(torch.ops.aten, kernel)

                    for overload_name in overloadpacket.overloads():
                        overload = getattr(overloadpacket, overload_name)

                        if not torch._C._dispatch_has_kernel(overload.name()):
                            continue

                        # TODO: tags are not propagated to generated overloads,
                        # and there's no way of specifying them
                        if torch.Tag.generated in overload.tags:
                            continue

                        if str(overload) in allowed_functions:
                            continue

                        self.assertTrue(torch.Tag.pointwise in overload.tags)

    # Tests that the function and its (ndarray-accepting) reference produce the same
    #   values on the tensors from the sample_inputs func for the corresponding op.
    # This test runs in double and complex double precision because
    #   NumPy does computation internally using double precision for many functions,
    #   resulting in possible equality check failures.
    # Skips the Windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947
    @onlyNativeDeviceTypesAnd(["hpu"])
    @suppress_warnings
    @ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128))
    def test_numpy_ref(self, device, dtype, op):
        if (
            TEST_WITH_TORCHINDUCTOR
            and op.formatted_name
            in ("signal_windows_exponential", "signal_windows_bartlett")
            and dtype == torch.float64
            and ("cuda" in device or "cpu" in device)
        ):
            raise unittest.SkipTest("XXX: raises tensor-likes are not close.")

        # Sets the default dtype to NumPy's default dtype of double
        with set_default_dtype(torch.double):
            for sample_input in op.reference_inputs(device, dtype):
                self.compare_with_reference(
                    op, op.ref, sample_input, exact_dtype=(dtype is not torch.long)
                )

    # Tests that the cpu and gpu results are consistent
    @onlyCUDA
    @suppress_warnings
    @slowTest
    @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
    def test_compare_cpu(self, device, dtype, op):
        def to_cpu(arg):
            if isinstance(arg, torch.Tensor):
                return arg.to(device="cpu")
            return arg

        samples = op.reference_inputs(device, dtype)

        for sample in samples:
            cpu_sample = sample.transform(to_cpu)
            cuda_results = op(sample.input, *sample.args, **sample.kwargs)
            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            # output_process_fn_grad has a very unfortunate name
            # We use this function in linalg extensively to postprocess the outputs of functions
            # that are not completely well-defined. Think svd and multiplying the singular vectors by -1.
            # CPU and CUDA implementations of the SVD can return valid SVDs that are different.
            # We use this function to compare them.
            cuda_results = sample.output_process_fn_grad(cuda_results)
            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

            # Lower tolerance because we are running this as a `@slowTest`
            # Don't want the periodic tests to fail frequently
            self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)

    # Tests that experimental Python References can propagate shape, dtype,
    # and device metadata properly.
    # See https://github.com/pytorch/pytorch/issues/78050 for a discussion of stride propagation.
    @onlyNativeDeviceTypesAnd(["hpu"])
    @ops(python_ref_db)
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref_meta(self, device, dtype, op):
        CHECK_CONJ_SKIPS = {
            torch._refs.linalg.svd,
        }

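        # Enters and immediately exits a FakeTensorMode context just to obtain
        # a mode instance; the mode is re-entered below when running the op on
        # fake-tensor inputs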
        with FakeTensorMode() as mode:
            pass

        def _to_tensormeta(x):
            if isinstance(x, torch.Tensor):
                out = FakeTensor.from_tensor(x, mode)
                return out
            return x

        # TODO: iterate over requires_grad true/false
        for sample in op.reference_inputs(device, dtype, requires_grad=False):
            result = op(sample.input, *sample.args, **sample.kwargs)

            meta_sample = sample.transform(_to_tensormeta)
            try:
                with mode:
                    meta_result = op(
                        meta_sample.input, *meta_sample.args, **meta_sample.kwargs
                    )
            except (
                torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
                torch._subclasses.fake_tensor.DataDependentOutputException,
                torch._subclasses.fake_tensor.UnsupportedOperatorException,
            ):
                continue

            if isinstance(result, torch.Tensor):
                self.assertTrue(isinstance(meta_result, FakeTensor))
                prims.utils.compare_tensor_meta(
                    result, meta_result, check_conj=op.op not in CHECK_CONJ_SKIPS
                )
            elif isinstance(result, Sequence):
                for a, b in zip(result, meta_result):
                    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
                        self.assertTrue(isinstance(b, FakeTensor))
                        prims.utils.compare_tensor_meta(
                            a, b, check_conj=op.op not in CHECK_CONJ_SKIPS
                        )

    def _ref_test_helper(
        self,
        ctx,
        device,
        dtype,
        op,
        skip_zero_numel=False,
        skip_zero_dim=False,
        skip_bfloat=False,
        skip_view_consistency=False,
    ):
        # NOTE: this helper works by comparing the reference's results to the
        # torch op's; if they differ, it then checks (below) that the reference
        # is at least as accurate as the torch op relative to a higher-precision
        # computation.
        ex = None
        for sample in op.reference_inputs(device, dtype, requires_grad=False):
            if (
                isinstance(sample.input, torch.Tensor)
                and sample.input.numel() == 0
                and skip_zero_numel
            ):
                continue
            if (
                isinstance(sample.input, torch.Tensor)
                and sample.input.ndim == 0
                and skip_zero_dim
            ):
                continue

            if skip_bfloat and (
                (
                    isinstance(sample.input, torch.Tensor)
                    and sample.input.dtype == torch.bfloat16
                )
                or any(
                    isinstance(arg, torch.Tensor) and arg.dtype == torch.bfloat16
                    for arg in sample.args
                )
            ):
                continue
            with ctx():
                ref_result = op(sample.input, *sample.args, **sample.kwargs)
            torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)

            for a, b in zip(
                pytree.tree_leaves(ref_result), pytree.tree_leaves(torch_result)
            ):
                if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
                    prims.utils.compare_tensor_meta(a, b)
                    if (
                        getattr(op, "validate_view_consistency", True)
                        and not skip_view_consistency
                    ):
                        msg = (
                            f"The torch implementation {'returns' if b._is_view() else 'does not return'} "
                            f"a view, while the reference {'does' if a._is_view() else 'does not'}"
                        )
                        self.assertEqual(a._is_view(), b._is_view(), msg)

            # Computes the dtype in which the more precise computation would occur
            precise_dtype = torch.bool
            if prims.utils.is_integer_dtype(dtype):
                # Note: bool and integer dtypes do not have more
                # precise dtypes -- they simply must be close
                precise_dtype = dtype
            if prims.utils.is_float_dtype(dtype):
                precise_dtype = torch.double
            if prims.utils.is_complex_dtype(dtype):
                precise_dtype = torch.cdouble

            # Checks if the results are close
            try:
                self.assertEqual(
                    ref_result,
                    torch_result,
                    exact_stride=False,
                    exact_device=True,
                    exact_layout=True,
                    exact_is_coalesced=True,
                )
            except AssertionError as e:
                # Re-raises the error if the comparison already happened in the
                # precise dtype, since retrying there would not change the result
                if dtype is precise_dtype:
                    raise e

                ex = e

            # Goes to next sample if these results are close
            if not ex:
                continue

            # If the results are not close, checks that the
            # reference is more accurate than the torch op
            def _make_precise(x):
                if isinstance(x, torch.dtype):
                    return precise_dtype
                if isinstance(x, torch.Tensor) and x.dtype is dtype:
                    return x.to(precise_dtype)
                return x

            precise_sample = sample.transform(_make_precise)
            precise_result = op.torch_opinfo(
                precise_sample.input, *precise_sample.args, **precise_sample.kwargs
            )

            def _distance(a, b):
                # Special-cases boolean comparisons
                if prims.utils.is_boolean_dtype(a.dtype):
                    assert b.dtype is torch.bool
                    return (a ^ b).sum()

                same = a == b
                if prims.utils.is_float_dtype(a.dtype) or prims.utils.is_complex_dtype(
                    a.dtype
                ):
                    same = torch.logical_or(
                        same, torch.logical_and(torch.isnan(a), torch.isnan(b))
                    )

                actual_error = torch.where(same, 0, torch.abs(a - b)).sum()
                return actual_error

            ref_distance = 0
            for a, b in zip(
                pytree.tree_leaves(ref_result), pytree.tree_leaves(precise_result)
            ):
                ref_distance = ref_distance + _distance(a, b)

            torch_distance = 0
            for a, b in zip(
                pytree.tree_leaves(torch_result), pytree.tree_leaves(precise_result)
            ):
                torch_distance = torch_distance + _distance(a, b)

            # TODO: consider adding some tolerance to this comparison
            msg = (
                f"Reference result was farther ({ref_distance}) from the precise "
                f"computation than the torch result was ({torch_distance})!"
            )
            self.assertTrue(ref_distance <= torch_distance, msg=msg)

        # Reports numerical accuracy discrepancies
        if ex is not None:
            msg = "Test passed because the reference was more accurate than the torch operator."
            warnings.warn(msg)

    # Tests that experimental Python References perform the same computation
    # as the operators they reference, when operator calls in the torch
    # namespace are remapped to the refs namespace (torch.foo becomes refs.foo).
    @onlyNativeDeviceTypesAnd(["hpu"])
    @ops(python_ref_db)
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref(self, device, dtype, op):
        # In this test, primTorch refs call into the refs namespace
        # For example, a ref with torch.foo in it will call refs.foo instead
        # Direct calls to refs and prims are not affected
        if (
            TEST_WITH_ROCM
            and (op.name == "_refs.fft.ihfftn" or op.name == "_refs.fft.ihfft2")
            and dtype == torch.float16
        ):
            self.skipTest("Skipped on ROCm")
        self._ref_test_helper(lambda: TorchRefsMode(strict=True), device, dtype, op)

    # Tests that experimental Python References perform the same computation
    # as the operators they reference, when operator calls in the torch
    # namespace are preserved (torch.foo remains torch.foo).
    @onlyNativeDeviceTypesAnd(["hpu"])
    @ops(python_ref_db)
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref_torch_fallback(self, device, dtype, op):
        # In this test, refs call into the torch namespace (after the initial invocation)
        # For example, a ref with torch.foo in it will call torch.foo instead of refs.foo
        # Direct calls to refs and prims are not translated
        if TEST_WITH_ROCM and op.name == "_refs.fft.ihfftn" and dtype == torch.float16:
            self.skipTest("Skipped on ROCm")
        self._ref_test_helper(contextlib.nullcontext, device, dtype, op)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyCUDA
    @ops(python_ref_db)
    @parametrize("executor", ["aten"])
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref_executor(self, device, dtype, op, executor):
        if (
            TEST_WITH_ROCM
            and (op.name == "_refs.fft.ihfftn" or op.name == "_refs.fft.ihfft2")
            and dtype == torch.float16
        ):
            self.skipTest("Skipped on ROCm")
        # skip zero-dim tensors for some composites of reduction operations and view
        skip_zero_dim_ops = [
            "_refs.logsumexp",
            "_refs.log_softmax",
            "_refs.native_group_norm",
            "_refs.softmax",
            "_refs.sum_to_size",
            "ops.nvprims.view",
        ]

        from copy import copy

        from torch._prims.executor import make_traced

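        # Copies the OpInfo so that patching .op below doesn't leak into other
        # tests, then wraps the op so it is traced and run through the
        # selected executor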
        op = copy(op)
        op.op = partial(make_traced(op.op), executor=executor)
        self._ref_test_helper(contextlib.nullcontext, device, dtype, op)

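    # Tests that ops raise the error type and match the error message regex
    # described by their OpInfo's error_inputs entries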
    @skipMeta
    @onlyNativeDeviceTypesAnd(["hpu"])
    @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
    def test_errors(self, device, op):
        error_inputs = op.error_inputs(device)
        for ei in error_inputs:
            si = ei.sample_input
            with self.assertRaisesRegex(ei.error_type, ei.error_regex):
                out = op(si.input, *si.args, **si.kwargs)
                self.assertFalse(isinstance(out, type(NotImplemented)))

    @skipMeta
    @onlyNativeDeviceTypesAnd(["hpu"])
    @ops(
        [op for op in op_db if op.error_inputs_sparse_func is not None],
        dtypes=OpDTypes.none,
    )
    @parametrize(
        "layout",
        (
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
            torch.sparse_coo,
        ),
    )
    def test_errors_sparse(self, device, op, layout):
        for ei in op.error_inputs_sparse(device, layout):
            si = ei.sample_input
            with self.assertRaisesRegex(ei.error_type, ei.error_regex):
                out = op(si.input, *si.args, **si.kwargs)
                self.assertFalse(isinstance(out, type(NotImplemented)))

    @skipMeta
    @onlyNativeDeviceTypesAnd(["hpu"])
    @ops(
        [op for op in python_ref_db if op.error_inputs_func is not None],
        dtypes=OpDTypes.none,
    )
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref_errors(self, device, op):
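        # Constructs a FakeTensorMode and enters/exits it once before using it
        # to convert the error inputs to FakeTensors (see test_python_ref_meta
        # for the same idiom)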
        mode = FakeTensorMode()
        with mode:
            pass

        def _to_tensormeta(x):
            if isinstance(x, torch.Tensor):
                return FakeTensor.from_tensor(x, mode)
            return x

        error_inputs = op.error_inputs(device)
        for ei in error_inputs:
            si = ei.sample_input
            meta_sample = si.transform(_to_tensormeta)
            with self.assertRaisesRegex(ei.error_type, ei.error_regex):
                op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)

    # Tests that the function produces the same result when called with
    #   noncontiguous tensors.
    # TODO: get working with Windows by addressing failing operators
    # TODO: get working with ASAN by addressing failing operators
    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
    @onlyNativeDeviceTypesAnd(["hpu"])
    @suppress_warnings
    @ops(op_db, allowed_dtypes=(torch.float32, torch.long, torch.complex64))
    def test_noncontiguous_samples(self, device, dtype, op):
        test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
        for sample_input in sample_inputs:
            t_inp, t_args, t_kwargs = (
                sample_input.input,
                sample_input.args,
                sample_input.kwargs,
            )
            noncontig_sample = sample_input.noncontiguous()
            n_inp, n_args, n_kwargs = (
                noncontig_sample.input,
                noncontig_sample.args,
                noncontig_sample.kwargs,
            )

            # Validates forward
            expected = op(t_inp, *t_args, **t_kwargs)
            actual = op(n_inp, *n_args, **n_kwargs)

            self.assertEqual(actual, expected)

            # Validates backward
            # Short-circuits if the op doesn't support grad in this device x dtype
            if not test_grad:
                continue

            expected = sample_input.output_process_fn_grad(expected)
            actual = sample_input.output_process_fn_grad(actual)

            if isinstance(expected, torch.Tensor):
                grad_for_expected = torch.randn_like(expected)
                grad_for_actual = noncontiguous_like(grad_for_expected)
            elif isinstance(expected, Sequence):
                # Filter output elements that do not require grad
                expected = [
                    t
                    for t in expected
                    if isinstance(t, torch.Tensor) and t.requires_grad
                ]
                actual = [
                    n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad
                ]
                grad_for_expected = [torch.randn_like(t) for t in expected]
                grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected]
            else:
                # Nothing to do if it returns a scalar or things like that
                continue

            # Concatenate inputs into a tuple
            t_inputs = (
                (t_inp,) + t_args
                if isinstance(t_inp, torch.Tensor)
                else tuple(t_inp) + t_args
            )
            n_inputs = (
                (n_inp,) + n_args
                if isinstance(n_inp, torch.Tensor)
                else tuple(n_inp) + n_args
            )

            # Filter the elements that are tensors requiring grad
            t_input_tensors = [
                t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad
            ]
            n_input_tensors = [
                n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad
            ]

            self.assertEqual(len(t_input_tensors), len(n_input_tensors))

            # Some functions may not use all the inputs to generate gradients. One of the
            # few examples of this "odd" behaviour is F.hinge_embedding_loss
            t_grads = torch.autograd.grad(
                expected, t_input_tensors, grad_for_expected, allow_unused=True
            )
            n_grads = torch.autograd.grad(
                actual, n_input_tensors, grad_for_actual, allow_unused=True
            )

            msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}."
            for i, (t, n) in enumerate(zip(t_grads, n_grads)):
                self.assertEqual(t, n, msg=msg.format(i))

    # Separated from the following test_out because many ops don't yet properly
    #   implement the warning for an incorrectly sized out= parameter
    # Case tested here:
    #   - out= with the correct dtype and device, but the wrong shape
    @ops(ops_and_refs, dtypes=OpDTypes.none)
    def test_out_warning(self, device, op):
        if TEST_WITH_TORCHDYNAMO and op.name == "_refs.clamp":
            self.skipTest("flaky")
        # Prefers running in float32 but has a fallback to the first listed supported dtype
        supported_dtypes = op.supported_dtypes(self.device_type)
        if len(supported_dtypes) == 0:
            self.skipTest("Skipped! Op has no supported dtypes on this device.")
        dtype = (
            torch.float32
            if torch.float32 in supported_dtypes
            else next(iter(supported_dtypes))
        )

        # Ops from python_ref_db point to python decomps that are potentially
        # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these
        # ops before testing to avoid clashing with OpInfo.supports_out
        if not op.supports_out:
            op = copy.copy(op)
            op.op = _maybe_remove_out_wrapper(op.op)

        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            # Calls the op normally to get the expected result
            expected = op(sample.input, *sample.args, **sample.kwargs)
            op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

            # Short-circuits if output is not a single tensor or an
            #   iterable of tensors
            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
                expected, include_empty=True
            ):
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

            # Validates the op doesn't support out= if it claims not to
            if not op.supports_out:
                with self.assertRaises(Exception):
                    assert op_out(out=expected) != NotImplemented
                return

            # A wrapper around map that works with single tensors and always
            #   instantiates the map. Used below to apply transforms to
            #   single tensor and iterable tensor outputs.
            def _apply_out_transform(fn, out):
                if isinstance(out, torch.Tensor):
                    return fn(out)

                # assumes (see above) that out is an iterable of tensors
                return tuple(map(fn, out))

            # Extracts strides from a tensor or iterable of tensors into a tuple
            def _extract_strides(out):
                if isinstance(out, torch.Tensor):
                    return (out.stride(),)

                # assumes (see above) that out is an iterable of tensors
                return tuple(t.stride() for t in out)

            # Extracts data pointers from a tensor or iterable of tensors into a tuple
            # NOTE: only extracts on the CPU and CUDA device types since some
            #   device types don't have storage
            def _extract_data_ptrs(out):
                if self.device_type != "cpu" and self.device_type != "cuda":
                    return ()

                if isinstance(out, torch.Tensor):
                    return (out.data_ptr(),)

                # assumes (see above) that out is an iterable of tensors
                return tuple(t.data_ptr() for t in out)

            @suppress_warnings
            def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
                out = _apply_out_transform(transform, expected)
                original_strides = _extract_strides(out)
                original_ptrs = _extract_data_ptrs(out)

                op_out(out=out)
                final_strides = _extract_strides(out)
                final_ptrs = _extract_data_ptrs(out)

                self.assertEqual(expected, out)

                if compare_strides_and_data_ptrs:
                    stride_msg = (
                        f"Strides are not the same! Original strides were {original_strides} "
                        f"and strides are now {final_strides}"
                    )
                    self.assertEqual(original_strides, final_strides, msg=stride_msg)
                    self.assertEqual(original_ptrs, final_ptrs)

            # Case Zero: out= with the correct dtype and device, but the wrong shape
            #   Expected behavior: if nonempty, resize with a warning.
            def _case_zero_transform(t):
                wrong_shape = list(t.shape)

                if len(wrong_shape) == 0:
                    # Handles scalar tensor case (empty list)
                    wrong_shape = [2]
                else:
                    wrong_shape[-1] = wrong_shape[-1] + 1
                return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)

            # Verifies the out values are correct
            _compare_out(_case_zero_transform, compare_strides_and_data_ptrs=False)

            # Additionally validates that the appropriate warning is thrown if a nonempty
            #   tensor is resized.
            def _any_nonempty(out):
                if isinstance(out, torch.Tensor):
                    return out.numel() > 0

                return any(x.numel() > 0 for x in out)

            out = _apply_out_transform(_case_zero_transform, expected)
            msg_fail = "Resized a non-empty tensor but did not warn about it."
            if _any_nonempty(out):
                with self.assertWarnsRegex(
                    UserWarning, "An output with one or more elements", msg=msg_fail
                ):
                    op_out(out=out)

    # Validates ops implement the correct out= behavior
    # See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
    #   for a description of the correct behavior
    # Validates the following cases:
    #   - Case 0: out has the correct shape, dtype, and device but is full of extremal values
    #   - Case 1: out has the correct shape, dtype, and device but is noncontiguous
    #   - Case 2: out has the correct dtype and device, but has zero elements
    #   - Case 3: out has the correct shape and dtype, but is on a different device type
    #   - Case 4: out has the correct shape and device, but a dtype the output cannot
    #       be "safely" cast to
    #
    # Cases 3 and 4 are slightly different when the op is a factory function:
    #   - if device and dtype are NOT passed, any combination of dtype/device should be OK for out
    #   - if device and dtype are passed, device and dtype should match
    @ops(ops_and_refs, dtypes=OpDTypes.any_one)
    def test_out(self, device, dtype, op):
        # Prefers running in float32 but has a fallback to the first listed supported dtype
        samples = op.sample_inputs(device, dtype)

        # Ops from python_ref_db point to python decomps that are potentially
        # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these
        # ops before testing to avoid clashing with OpInfo.supports_out
        if not op.supports_out:
            op = copy.copy(op)
            op.op = _maybe_remove_out_wrapper(op.op)

        for sample in samples:
            # Calls the op normally to get the expected result
            expected = op(sample.input, *sample.args, **sample.kwargs)
            op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

            # Short-circuits if output is not a single tensor or an
            #   iterable of tensors
            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
                expected, include_empty=True
            ):
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

            # Validates the op doesn't support out= if it claims not to
            if not op.supports_out:
                with self.assertRaises(Exception):
                    assert op_out(out=expected) != NotImplemented
                return

            # A wrapper around map that works with single tensors and always
            #   instantiates the map. Used below to apply transforms to
            #   single tensor and iterable tensor outputs.
            def _apply_out_transform(fn, out):
                if isinstance(out, torch.Tensor):
                    return fn(out)

                # assumes (see above) that out is an iterable of tensors
                return tuple(map(fn, out))

            # Extracts strides from a tensor or iterable of tensors into a tuple
            def _extract_strides(out):
                if isinstance(out, torch.Tensor):
                    return (out.stride(),)

                # assumes (see above) that out is an iterable of tensors
                return tuple(t.stride() for t in out)

            # Extracts data pointers from a tensor or iterable of tensors into a tuple
            # NOTE: only extracts on the CPU and CUDA device types since some
            #   device types don't have storage
            def _extract_data_ptrs(out):
                if self.device_type != "cpu" and self.device_type != "cuda":
                    return ()

                if isinstance(out, torch.Tensor):
                    return (out.data_ptr(),)

                # assumes (see above) that out is an iterable of tensors
                return tuple(t.data_ptr() for t in out)

            def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
                out = _apply_out_transform(transform, expected)
                original_strides = _extract_strides(out)
                original_ptrs = _extract_data_ptrs(out)

                op_out(out=out)
                final_strides = _extract_strides(out)
                final_ptrs = _extract_data_ptrs(out)
                self.assertEqual(expected, out)

                if compare_strides_and_data_ptrs:
                    stride_msg = (
                        "Strides are not the same! "
                        f"Original strides were {original_strides} and strides are now {final_strides}"
                    )
                    self.assertEqual(original_strides, final_strides, msg=stride_msg)
                    self.assertEqual(original_ptrs, final_ptrs)

            # Case 0: out= with the correct shape, dtype, and device,
            #   but NaN values for floating point and complex tensors, and
            #   maximum values for integer tensors.
            #   Expected behavior: out= values have no effect on the computation.
            def _case_zero_transform(t):
                try:
                    info = torch.iinfo(t.dtype)
                    return torch.full_like(t, info.max)
                except TypeError:
                    # for non-integer types fills with NaN
                    return torch.full_like(t, float("nan"))

            _compare_out(_case_zero_transform)

            # Case 1: out= with the correct shape, dtype, and device,
            #   but noncontiguous.
            #   Expected behavior: strides are respected and `out` storage is not changed.
            def _case_one_transform(t):
                return make_tensor(
                    t.shape, dtype=t.dtype, device=t.device, noncontiguous=True
                )

            _compare_out(_case_one_transform)

            # Case 2: out= with the correct dtype and device, but has no elements.
            #   Expected behavior: resize without warning.
            def _case_two_transform(t):
                return make_tensor((0,), dtype=t.dtype, device=t.device)

            _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False)

            # Also validates that no warning is thrown when this out is resized
            out = _apply_out_transform(_case_two_transform, expected)
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                op_out(out=out)

            # Verifies that no caught warning is a resize warning
            for w in caught:
                if "An output with one or more elements" in str(w.message):
                    self.fail(
                        "Resizing an out= argument with no elements threw a resize warning!"
                    )

            # Case 3: out= with correct shape and dtype, but wrong device.
            wrong_device = None
            if torch.device(device).type != "cpu":
                wrong_device = "cpu"
            elif torch.cuda.is_available():
                wrong_device = "cuda"

            factory_fn_msg = (
                "\n\nNOTE: If your op is a factory function (i.e., it accepts TensorOptions) you should mark its "
                "OpInfo with `is_factory_function=True`."
            )
            if wrong_device is not None:

                def _case_three_transform(t):
                    return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

                out = _apply_out_transform(_case_three_transform, expected)

                if op.is_factory_function and sample.kwargs.get("device", None) is None:
                    op_out(out=out)
                else:
                    msg_fail = (
                        f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}."
                    ) + factory_fn_msg
                    with self.assertRaises(RuntimeError, msg=msg_fail):
                        op_out(out=out)

            # Case 4: out= with correct shape and device, but a dtype
            #   that the output cannot be "safely" cast to (long).
            #   Expected behavior: error.
            # NOTE: this case is filtered by dtype since some ops produce
            #   bool tensors, for example, which can be safely cast to any
            #   dtype. It applies when a single tensor output has a floating
            #   point or complex dtype, or when an op returns multiple tensors
            #   and at least one of them has a floating point or complex dtype.
            _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
            if (
                isinstance(expected, torch.Tensor)
                and expected.dtype in _dtypes
                or (
                    not isinstance(expected, torch.Tensor)
                    and any(t.dtype in _dtypes for t in expected)
                )
            ):

                def _case_four_transform(t):
                    return make_tensor(t.shape, dtype=torch.long, device=t.device)

                out = _apply_out_transform(_case_four_transform, expected)
                msg_fail = "Expected RuntimeError when doing an unsafe cast!"
                msg_fail = (
                    msg_fail
                    if not isinstance(expected, torch.Tensor)
                    else (
                        "Expected RuntimeError when doing an unsafe cast from a result of dtype "
                        f"{expected.dtype} into an out= with dtype torch.long"
                    )
                ) + factory_fn_msg

                if op.is_factory_function and sample.kwargs.get("dtype", None) is None:
                    op_out(out=out)
                else:
                    with self.assertRaises(RuntimeError, msg=msg_fail):
                        op_out(out=out)

    @ops(
        [
            op
            for op in op_db
            if op.supports_out and (op.supports_autograd or op.is_factory_function)
        ],
        dtypes=OpDTypes.supported,
        allowed_dtypes=[torch.float, torch.cfloat],
    )
    def test_out_requires_grad_error(self, device, dtype, op):
        sample = first_sample(self, op.sample_inputs(device, dtype))

        # Call op to get prototype for out arguments
        expect = op(sample.input, *sample.args, **sample.kwargs)
        any_requires_grad = False

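        # Marks every floating point or complex tensor in the prototype output
        # as requiring grad, so that passing it as out= below must trigger the
        # autograd error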
        def set_requires_grad(x):
            nonlocal any_requires_grad
            if isinstance(x, torch.Tensor) and (
                x.is_floating_point() or x.is_complex()
            ):
                any_requires_grad = True
                x.requires_grad_(True)
            return x

        out = pytree.tree_map_(set_requires_grad, expect)
        if not any_requires_grad:
            # Skip ops without any floating point outputs, e.g. isnan
            return

        msg = (
            "functions with out=... arguments don't support automatic "
            "differentiation, but one of the arguments requires grad."
        )
        with self.assertRaises(RuntimeError, msg=msg):
            op(sample.input, *sample.args, **sample.kwargs, out=out)

    @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,))
    def test_out_integral_dtype(self, device, dtype, op):
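        # Runs op_to_test on the given inputs, using an int32 out= tensor when
        # with_out is True, and asserts that the "dtype argument and out dtype
        # must match in reduction" RuntimeError is raised exactly when
        # expectFail is True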
        def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs):
            out = None
            try:
                if with_out:
                    out = torch.empty(0, dtype=torch.int32, device=device)
                    op_to_test(inputs, *args, out=out, **kwargs)
                else:
                    out = op_to_test(inputs, *args, **kwargs)
                self.assertFalse(expectFail)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), "dtype argument and out dtype must match in reduction"
                )
                self.assertTrue(expectFail)
            return out

        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            if "dtype" not in sample.kwargs:
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, False, op, sample.input, *sample.args, **sample.kwargs)
                sample.kwargs["dtype"] = torch.int16
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, True, op, sample.input, *sample.args, **sample.kwargs)
                sample.kwargs["dtype"] = torch.int32
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, False, op, sample.input, *sample.args, **sample.kwargs)
            else:
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(
                    True,
                    sample.kwargs["dtype"] != torch.int32,
                    op,
                    sample.input,
                    *sample.args,
                    **sample.kwargs,
                )

    # Tests that the forward and backward passes of operations produce the
    #   same values for the cross-product of op variants (method, inplace)
    #   against eager's gold standard op function variant
    @_variant_ops(op_db)
    def test_variant_consistency_eager(self, device, dtype, op):
        # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases)

        method = op.method_variant
        inplace = op.inplace_variant
        operator = op.operator_variant
        inplace_operator = op.inplace_operator_variant

        # List of all inplace ops: the inplace variant plus the alias inplace variants, if they exist
        inplace_ops = [inplace, inplace_operator]
        variants = [method, inplace, operator, inplace_operator]
        operators = [operator, inplace_operator]

        for a_op in op.aliases:
            variants.append(a_op.op)
            variants.append(a_op.method_variant)
            variants.append(a_op.inplace_variant)
            inplace_ops.append(a_op.inplace_variant)

        inplace_variants = tuple(filter(None, inplace_ops))
        variants = tuple(filter(None, variants))
        operators = tuple(filter(None, operators))

        _requires_grad = dtype in op.supported_backward_dtypes(
            torch.device(device).type
        )

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=_requires_grad,
            include_conjugated_inputs=include_conjugated_inputs,
        )
        samples = list(samples)

        def _test_consistency_helper(samples, variants):
            for sample in samples:
                # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
                tensor = (
                    sample.input
                    if isinstance(sample.input, torch.Tensor)
                    else sample.input[0]
                )

                # Computes function forward and backward values
                tensor.grad = None
                expected_forward = op(sample.input, *sample.args, **sample.kwargs)
                expected_grad = None

                output_process_fn_grad = (
                    sample.output_process_fn_grad
                    if sample.output_process_fn_grad
                    else lambda x: x
                )

                # Skips inplace variants if the output dtype is not the same as
                #   the input dtype
                skip_inplace = False
                if (
                    isinstance(expected_forward, torch.Tensor)
                    and expected_forward.dtype is not tensor.dtype
                ):
                    skip_inplace = True

                # TODO: backward consistency only supported for single tensor outputs
                # TODO: backward consistency only checked on sample.input, not all
                #   tensor inputs
                # TODO: update to handle checking grads of all tensor inputs as
                #   derived from each tensor output
                if isinstance(
                    expected_forward, torch.Tensor
                ) and dtype in op.supported_backward_dtypes(torch.device(device).type):
                    out = output_process_fn_grad(expected_forward).sum()
                    if out.dtype.is_complex:
                        out = out.abs()
                    out.backward()
                    expected_grad = tensor.grad

                # Test eager consistency
                for variant in variants:
                    # Skips inplace ops
                    if variant in inplace_ops and skip_inplace:
                        continue

                    # Compares variant's forward
                    # Note: copies the to-be-modified input when testing the inplace variant
                    tensor.grad = None
                    cloned = (
                        clone_input_helper(sample.input)
                        if variant in inplace_ops
                        else sample.input
                    )

                    if variant in inplace_ops and sample.broadcasts_input:
                        with self.assertRaises(
                            RuntimeError,
                            msg=(
                                "inplace variant either incorrectly allowed "
                                f"resizing or you have marked the sample {sample.summary()}"
                                " incorrectly with `broadcasts_input=True`"
                            ),
                        ):
                            variant_forward = variant(
                                cloned, *sample.args, **sample.kwargs
                            )
                        continue

                    if variant in operators and sample.kwargs:
                        # Skip samples with kwargs for operator variants
                        continue

                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
                    self.assertEqual(expected_forward, variant_forward)

                    # Compares variant's backward
                    if expected_grad is not None and (
                        variant not in inplace_ops or op.supports_inplace_autograd
                    ):
                        out = output_process_fn_grad(variant_forward).sum()
                        if out.dtype.is_complex:
                            out = out.abs()
                        out.backward()
                        self.assertEqual(expected_grad, tensor.grad)

        _test_consistency_helper(samples, variants)

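        # Verifies that inplace variants write into the input tensor's storage
        # (the data_ptr is unchanged) rather than allocating a new tensor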
1256        def _test_inplace_preserve_storage(samples, variants):
1257            for sample in samples:
1258                # Skips inplace variants if the output dtype is not the same as
1259                #   the input dtype
1260                expected_forward = op(sample.input, *sample.args, **sample.kwargs)
1261                tensor = (
1262                    sample.input
1263                    if isinstance(sample.input, torch.Tensor)
1264                    else sample.input[0]
1265                )
1266                skip_inplace = False
1267                if (
1268                    isinstance(expected_forward, torch.Tensor)
1269                    and expected_forward.dtype is not tensor.dtype
1270                ):
1271                    skip_inplace = True
1272                if skip_inplace:
1273                    return
1274                for variant in variants:
1275                    cloned = (
1276                        clone_input_helper(sample.input)
1277                        if variant in inplace_ops
1278                        else sample.input
1279                    )
1280                    inp_tensor = (
1281                        cloned if isinstance(cloned, torch.Tensor) else cloned[0]
1282                    )
1283                    data_ptr = inp_tensor.data_ptr()
1284                    if variant in operators and sample.kwargs:
1285                        # skip samples with kwargs for operator variants
1286                        continue
1287
1288                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
1289                    # TODO Support non-tensor outputs if they exist for inplace ops
1290                    if isinstance(variant_forward, torch.Tensor):
1291                        self.assertEqual(
1292                            data_ptr, variant_forward.data_ptr(), atol=0, rtol=0
1293                        )
1294                    else:
1295                        self.fail(
1297                            "Non-tensor outputs for inplace ops are not supported",
1298                        )
1299
1300        if len(inplace_ops) > 0:
1301            inplace_samples = list(
1302                filter(lambda sample: not sample.broadcasts_input, samples)
1303            )
1304            _test_inplace_preserve_storage(inplace_samples, inplace_variants)
1305
1306    # Reference testing for operations in complex32 against complex64.
1307    # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype.
1308    @ops(op_db, allowed_dtypes=(torch.complex32,))
1309    def test_complex_half_reference_testing(self, device, dtype, op):
1310        if not op.supports_dtype(torch.complex32, device):
1311            self.skipTest("Does not support complex32")
1312
1313        for sample in op.sample_inputs(device, dtype):
1314            actual = op(sample.input, *sample.args, **sample.kwargs)
1315            # sample.transform applies the lambda to torch.Tensor and torch.dtype.
1316            # However, we only want to apply it to Tensors with dtype `torch.complex32`.
1317            transformed_sample = sample.transform(
1318                lambda x: x.to(torch.complex64)
1319                if isinstance(x, torch.Tensor) and x.dtype is torch.complex32
1320                else x
1321            )
1322            expected = op(
1323                transformed_sample.input,
1324                *transformed_sample.args,
1325                **transformed_sample.kwargs,
1326            )
1327            # Since the range of chalf is much smaller than that of cfloat,
1328            # we can easily get `inf`s (e.g. with `pow`, `exp`),
1329            # so we cast the `cfloat` result back to `chalf`.
1330            expected = tree_map(
1331                lambda x: x.to(torch.complex32)
1332                if isinstance(x, torch.Tensor) and x.dtype is torch.complex64
1333                else x,
1334                expected,
1335            )
1336
1337            # `exact_dtype` is False because for ops like real, imag
1338            # we get different dtypes for `actual` and `expected`
1339            # `chalf` input -> `half` output
1340            # `cfloat` input -> `float` output
1341            self.assertEqual(actual, expected, exact_dtype=False)
1342
1343    @ops(op_db, allowed_dtypes=(torch.bool,))
1344    @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior")
1345    def test_non_standard_bool_values(self, device, dtype, op):
1346        # Test boolean values other than 0x00 and 0x01 (gh-54789)
1347        def convert_boolean_tensors(x):
1348            if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
1349                return x
1350
1351            # Map False -> 0 and True -> a random value in [2, 255)
1352            true_vals = torch.randint(
1353                2, 255, x.shape, dtype=torch.uint8, device=x.device
1354            )
1355            false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
1356            x_int = torch.where(x, true_vals, false_vals)
1357
1358            ret = x_int.view(torch.bool)
1359            self.assertEqual(ret, x)
1360            return ret
1361
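        # For example (an illustrative sketch): any non-zero byte reinterpreted
        # as bool reads as True:
        #   >>> raw = torch.tensor([0, 7], dtype=torch.uint8)
        #   >>> bool_t = raw.view(torch.bool)    # reinterprets the bytes, no copy
        #   >>> bool(bool_t[1])                  # the non-standard byte 7 reads as True
        #   True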
1362        for sample in op.sample_inputs(device, dtype):
1363            expect = op(sample.input, *sample.args, **sample.kwargs)
1364
1365            transformed = sample.transform(convert_boolean_tensors)
1366            actual = op(transformed.input, *transformed.args, **transformed.kwargs)
1367
1368            self.assertEqual(expect, actual)
1369
1370    # Validates that each OpInfo specifies its forward and backward dtypes
1371    #   correctly for CPU and CUDA devices
1372    @skipMeta
1373    @onlyNativeDeviceTypesAnd(["hpu"])
1374    @ops(ops_and_refs, dtypes=OpDTypes.none)
1375    def test_dtypes(self, device, op):
1376        # Check complex32 support only if the op claims to support it.
1377        # TODO: Once complex32 support is better, we should check complex32 unconditionally.
1378        device_type = torch.device(device).type
1379        include_complex32 = (
1380            (torch.complex32,)
1381            if op.supports_dtype(torch.complex32, device_type)
1382            else ()
1383        )
1384
1385        # dtypes to try running backward in
1386        allowed_backward_dtypes = floating_and_complex_types_and(
1387            *((torch.half, torch.bfloat16) + include_complex32)
1388        )
1389
1390        # sets of (un)supported dtypes
1391        supported_dtypes = set()
1392        unsupported_dtypes = set()
1393        supported_backward_dtypes = set()
1394        unsupported_backward_dtypes = set()
1395        dtype_error: Dict[torch.dtype, Exception] = {}
1396
1397        def unsupported(dtype, e):
1398            dtype_error[dtype] = e
1399            unsupported_dtypes.add(dtype)
1400            if dtype in allowed_backward_dtypes:
1401                unsupported_backward_dtypes.add(dtype)
1402
1403        for dtype in all_types_and_complex_and(
1404            *((torch.half, torch.bfloat16, torch.bool) + include_complex32)
1405        ):
1406            # tries to acquire samples - failure indicates lack of support
1407            requires_grad = dtype in allowed_backward_dtypes
1408            try:
1409                samples = tuple(
1410                    op.sample_inputs(device, dtype, requires_grad=requires_grad)
1411                )
1412            except Exception as e:
1413                unsupported(dtype, e)
1414                continue
1415
1416            for sample in samples:
1417                # tries to call operator with the sample - failure indicates
1418                #   lack of support
1419                try:
1420                    result = op(sample.input, *sample.args, **sample.kwargs)
1421                    supported_dtypes.add(dtype)
1422                except Exception as e:
1423                    # NOTE: some ops will fail in forward if their inputs
1424                    #   require grad but they don't support computing the gradient
1425                    #   in that type! This is a bug in the op!
1426                    unsupported(dtype, e)
1427                    continue
1428
1429                # Checks for backward support in the same dtype, if the input has
1430                # one or more tensors requiring grad
1431                def _tensor_requires_grad(x):
1432                    if isinstance(x, dict):
1433                        for v in x.values():
1434                            if _tensor_requires_grad(v):
1435                                return True
1436                    if isinstance(x, (list, tuple)):
1437                        for a in x:
1438                            if _tensor_requires_grad(a):
1439                                return True
1440                    if isinstance(x, torch.Tensor) and x.requires_grad:
1441                        return True
1442
1443                    return False
1444
1445                requires_grad = (
1446                    _tensor_requires_grad(sample.input)
1447                    or _tensor_requires_grad(sample.args)
1448                    or _tensor_requires_grad(sample.kwargs)
1449                )
1450                if not requires_grad:
1451                    continue
1452
1453                try:
1454                    result = sample.output_process_fn_grad(result)
1455                    if isinstance(result, torch.Tensor):
1456                        backward_tensor = result
1457                    elif isinstance(result, Sequence) and isinstance(
1458                        result[0], torch.Tensor
1459                    ):
1460                        backward_tensor = result[0]
1461                    else:
1462                        continue
1463
1464                    # Note: this grad may not have the same dtype as dtype
1465                    # For functions like complex (float -> complex) or abs
1466                    #   (complex -> float) the grad tensor will have a
1467                    #   different dtype than the input.
1468                    #   For simplicity, this is still modeled as these ops
1469                    #   supporting grad in the input dtype.
1470                    grad = torch.randn_like(backward_tensor)
1471                    backward_tensor.backward(grad)
1472                    supported_backward_dtypes.add(dtype)
1473                except Exception as e:
1474                    dtype_error[dtype] = e
1475                    unsupported_backward_dtypes.add(dtype)
1476
1477        # Checks that dtypes are listed correctly and generates an informative
1478        #   error message
1479
1480        supported_forward = supported_dtypes - unsupported_dtypes
1481        partially_supported_forward = supported_dtypes & unsupported_dtypes
1482        unsupported_forward = unsupported_dtypes - supported_dtypes
1483        supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
1484        partially_supported_backward = (
1485            supported_backward_dtypes & unsupported_backward_dtypes
1486        )
1487        unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes
1488
1489        device_type = torch.device(device).type
1490
1491        claimed_forward = set(op.supported_dtypes(device_type))
1492        supported_but_unclaimed_forward = supported_forward - claimed_forward
1493        claimed_but_unsupported_forward = claimed_forward & unsupported_forward
1494
1495        claimed_backward = set(op.supported_backward_dtypes(device_type))
1496        supported_but_unclaimed_backward = supported_backward - claimed_backward
1497        claimed_but_unsupported_backward = claimed_backward & unsupported_backward
1498
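        # For example (an illustrative sketch of the set algebra above):
        #   >>> supported = {torch.float32, torch.float16}
        #   >>> unsupported = {torch.float16}
        #   >>> supported - unsupported    # worked on every sample tried
        #   {torch.float32}
        #   >>> supported & unsupported    # worked on some samples, failed on others
        #   {torch.float16}
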
1499        # Partially supporting a dtype is not an error, but we print a warning
1500        if (len(partially_supported_forward) + len(partially_supported_backward)) > 0:
1501            msg = f"Some dtypes for {op.name} on device type {device_type} are only partially supported!\n"
1502            if len(partially_supported_forward) > 0:
1503                msg = (
1504                    msg
1505                    + f"The following dtypes only worked on some samples during forward: {partially_supported_forward}.\n"
1506                )
1507            if len(partially_supported_backward) > 0:
1508                msg = (
1509                    msg
1510                    + f"The following dtypes only worked on some samples during backward: {partially_supported_backward}.\n"
1511                )
1512            print(msg)
1513
1514        if (
1515            len(supported_but_unclaimed_forward)
1516            + len(claimed_but_unsupported_forward)
1517            + len(supported_but_unclaimed_backward)
1518            + len(claimed_but_unsupported_backward)
1519        ) == 0:
1520            return
1521
1522        # Reference operators often support additional dtypes, and that's OK
1523        if op in python_ref_db:
1524            if (
1525                len(claimed_but_unsupported_forward)
1526                + len(claimed_but_unsupported_backward)
1527            ) == 0:
1528                return
1529
1530        # Generates the error message
1531        msg = f"The supported dtypes for {op.name} on device type {device_type} are incorrect!\n"
1532        if len(supported_but_unclaimed_forward) > 0:
1533            msg = (
1534                msg
1535                + "The following dtypes worked in forward but are not listed by the OpInfo: "
1536                + f"{supported_but_unclaimed_forward}.\n"
1537            )
1538        if len(supported_but_unclaimed_backward) > 0:
1539            msg = (
1540                msg
1541                + "The following dtypes worked in backward but are not listed by the OpInfo: "
1542                + f"{supported_but_unclaimed_backward}.\n"
1543            )
1544        if len(claimed_but_unsupported_forward) > 0:
1545            msg = (
1546                msg
1547                + "The following dtypes did not work in forward but are listed by the OpInfo: "
1548                + f"{claimed_but_unsupported_forward}.\n"
1549            )
1550        if len(claimed_but_unsupported_backward) > 0:
1551            msg = (
1552                msg
1553                + "The following dtypes did not work in backward "
1554                + f"but are listed by the OpInfo: {claimed_but_unsupported_backward}.\n"
1555            )
1556
1557        all_claimed_but_unsupported = set.union(
1558            claimed_but_unsupported_backward, claimed_but_unsupported_forward
1559        )
1560        if all_claimed_but_unsupported:
1561            msg += "Unexpected failures raised the following errors:\n"
1562            for dtype in all_claimed_but_unsupported:
1563                msg += f"{dtype} - {dtype_error[dtype]}\n"
1564
1565        self.fail(msg)
1566
1567    # Validates that each OpInfo that sets promotes_int_to_float=True does as it says
1568    @skipMeta
1569    @onlyNativeDeviceTypesAnd(["hpu"])
1570    @ops(
1571        (op for op in op_db if op.promotes_int_to_float),
1572        allowed_dtypes=integral_types_and(torch.bool),
1573    )
1574    def test_promotes_int_to_float(self, device, dtype, op):
1575        for sample in op.sample_inputs(device, dtype):
1576            output = op(sample.input, *sample.args, **sample.kwargs)
1577            if not output.dtype.is_floating_point:
1578                self.fail(
1579                    f"The OpInfo sets `promotes_int_to_float=True`, but {dtype} was promoted to {output.dtype}."
1580                )
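        # For example (an illustrative sketch of the promotion being checked):
        #   >>> torch.true_divide(torch.tensor([1], dtype=torch.int32), 2).dtype
        #   torch.float32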
1581
1582
1583@unMarkDynamoStrictTest
1584class TestCompositeCompliance(TestCase):
1585    # Checks if the operator (if it is composite) is written to support most
1586    # backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance"
1587    # in aten/src/ATen/native/README.md for more details
1588    @unittest.skipIf(
1589        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
1590    )
1591    @ops(op_db, allowed_dtypes=(torch.float,))
1592    def test_operator(self, device, dtype, op):
1593        samples = op.sample_inputs(device, dtype, requires_grad=False)
1594
1595        for sample in samples:
1596            args = [sample.input] + list(sample.args)
1597            kwargs = sample.kwargs
1598            composite_compliance.check_with_mode(op, args, kwargs, self.assertEqual)
1599            composite_compliance.check_all_permutations(
1600                op, args, kwargs, self.assertEqual
1601            )
1602
1603    @unittest.skipIf(
1604        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
1605    )
1606    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
1607    def test_backward(self, device, dtype, op):
1608        samples = op.sample_inputs(device, dtype, requires_grad=True)
1609
1610        for sample in samples:
1611            args = [sample.input] + list(sample.args)
1612            kwargs = sample.kwargs
1613            # We pass assertEqual so that decorators like `toleranceOverride`
1614            # actually work (otherwise they silently do nothing!)
1615            composite_compliance.check_backward_formula(
1616                op.get_op(),
1617                args,
1618                kwargs,
1619                sample.output_process_fn_grad,
1620                op.gradcheck_wrapper,
1621                self.assertEqual,
1622            )
1623
1624    @unittest.skipIf(
1625        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
1626    )
1627    @ops(op_db, allowed_dtypes=(torch.float,))
1628    def test_forward_ad(self, device, dtype, op):
1629        if torch.float not in op.supported_backward_dtypes(device):
1630            raise unittest.SkipTest("Does not support autograd")
1631
1632        if not op.supports_forward_ad:
1633            raise unittest.SkipTest("Does not support forward_ad")
1634
1635        samples = op.sample_inputs(device, dtype, requires_grad=True)
1636
1637        for sample in samples:
1638            args = [sample.input] + list(sample.args)
1639            kwargs = sample.kwargs
1640            # We pass assertEqual so that decorators like `toleranceOverride`
1641            # actually work (otherwise they silently do nothing!)
1642            composite_compliance.check_forward_ad_formula(
1643                op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual
1644            )
1645
1646    @ops(op_db, allowed_dtypes=(torch.float,))
1647    def test_cow_input(self, device, dtype, op):
1648        samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
1649
1650        def is_strided_tensor(arg):
1651            return torch.is_tensor(arg) and arg.layout == torch.strided
1652
1653        def check_ignore_materialize(idx_or_kw, allow_list):
1654            return (allow_list is not None) and (idx_or_kw in allow_list)
1655
1656        def check_cow_input(
1657            arg,
1658            arg_copy,
1659            idx_or_kw,
1660            backward_or_forward="forward",
1661            supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward,
1662            allow_list=op.allow_cow_input_materialize_forward,
1663        ):
1664            arg_name = (
1665                f"Argument {idx_or_kw}"
1666                if isinstance(idx_or_kw, int)
1667                else f"Keyword argument '{idx_or_kw}'"
1668            ) + f" during {backward_or_forward} call"
1669
1670            if is_strided_tensor(arg):
1671                is_cow = torch._C._is_cow_tensor(arg)
1672
1673                if supports_cow_input_no_materialize and not check_ignore_materialize(
1674                    idx_or_kw, allow_list
1675                ):
1676                    self.assertTrue(
1677                        is_cow,
1678                        msg=(
1679                            f"{arg_name} unexpectedly materializes. "
1680                            f"Either set `supports_cow_input_no_materialize_{backward_or_forward}=False` "
1681                            "in this operation's OpInfo, add the arg to the OpInfo's "
1682                            f"`allow_cow_input_materialize_{backward_or_forward}` list, or change the "
1683                            "implementation to avoid materialization."
1684                        ),
1685                    )
1686
1687                if is_cow:
1688                    self.assertTrue(
1689                        torch.allclose(arg, arg_copy, rtol=0, atol=0, equal_nan=True),
1690                        msg=(
1691                            f"{arg_name} avoided materialization, "
1692                            "but the operation mutated its data."
1693                        ),
1694                    )
1695
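        # Background (an illustrative sketch of the COW mechanics exercised here):
        #   >>> a = torch.ones(2)
        #   >>> b = torch._lazy_clone(a)     # `a` and `b` now share a copy-on-write storage
        #   >>> torch._C._is_cow_tensor(b)
        #   True
        #   >>> b += 1                       # the first write materializes a real copy
        #   >>> torch._C._is_cow_tensor(b)
        #   False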
1696        for sample in samples:
1697            args_raw = [sample.input] + list(sample.args)
1698            kwargs_raw = sample.kwargs
1699            args_copy = []
1700            args = []
1701            kwargs_copy = {}
1702            kwargs = {}
1703
1704            # Convert strided tensor inputs to COW tensors and make copies of
1705            # all inputs
1706            for idx, arg in enumerate(args_raw):
1707                if is_strided_tensor(arg):
1708                    args_copy.append(arg.clone().detach())
1709                    args.append(torch._lazy_clone(arg))
1710                else:
1711                    if torch.is_tensor(arg):
1712                        args_copy.append(arg.clone().detach())
1713                    else:
1714                        args_copy.append(copy.deepcopy(arg))
1715                    args.append(arg)
1716
1717            for kw, arg in kwargs_raw.items():
1718                if is_strided_tensor(arg):
1719                    kwargs_copy[kw] = arg.clone().detach()
1720                    kwargs[kw] = torch._lazy_clone(arg)
1721                else:
1722                    if torch.is_tensor(arg):
1723                        kwargs_copy[kw] = arg.clone().detach()
1724                    else:
1725                        kwargs_copy[kw] = copy.deepcopy(arg)
1726                    kwargs[kw] = arg
1727
1728            leaf_tensors = composite_compliance.gather_leaf_tensors(args, kwargs)
1729
1730            # Call forward op
1731            results_raw = op.get_op()(*args, **kwargs)
1732
1733            # Check that COW inputs remain COW after the forward op is executed
1734            for idx, arg in enumerate(args):
1735                check_cow_input(arg, args_copy[idx], idx)
1736
1737            for kw, arg in kwargs.items():
1738                check_cow_input(arg, kwargs_copy[kw], kw)
1739
1740            # Call backward op if it is supported. This part of the test is
1741            # based on `composite_compliance.check_backward_formula`
1742            if (
1743                op.supports_autograd
1744                and len(leaf_tensors) > 0
1745                and not op.skip_cow_input_backward
1746            ):
1747                if sample.output_process_fn_grad is not None:
1748                    results_raw = sample.output_process_fn_grad(results_raw)
1749
1750                leaf_results = pytree.tree_leaves(results_raw)
1751                results = [
1752                    r
1753                    for r in leaf_results
1754                    if isinstance(r, torch.Tensor) and r.requires_grad
1755                ]
1756
1757                all_results_strided = all(
1758                    is_strided_tensor(result) for result in results
1759                )
1760
1761                # Only test backward if the results are strided tensors
1762                if all_results_strided:
1763                    output_grads_raw = [
1764                        torch.ones(r.shape, device=r.device, dtype=r.dtype)
1765                        for r in results
1766                    ]
1767                    output_grads_copy = []
1768                    output_grads = []
1769
1770                    # Convert output grads to COW tensors and make copies
1771                    for output_grad in output_grads_raw:
1772                        output_grads_copy.append(output_grad.clone().detach())
1773                        output_grads.append(torch._lazy_clone(output_grad))
1774
1775                    input_grads = torch.autograd.grad(
1776                        results,
1777                        leaf_tensors,
1778                        output_grads,
1779                        allow_unused=True,
1780                        retain_graph=True,
1781                    )
1782
1783                    # Check that COW inputs remain COW after the backward op is executed
1784                    for idx, arg in enumerate(args):
1785                        check_cow_input(
1786                            arg,
1787                            args_copy[idx],
1788                            idx,
1789                            backward_or_forward="backward",
1790                            supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
1791                            allow_list=op.allow_cow_input_materialize_backward,
1792                        )
1793
1794                    # Check that COW output grads remain COW after the backward op is executed
1795                    for idx, output_grad in enumerate(output_grads):
1796                        check_cow_input(
1797                            output_grad,
1798                            output_grads_copy[idx],
1799                            f"output grad {idx}",
1800                            backward_or_forward="backward",
1801                            supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
1802                            allow_list=op.allow_cow_input_materialize_backward,
1803                        )
1804
1805    @ops(op_db, allowed_dtypes=(torch.float,))
1806    def test_view_replay(self, device, dtype, op):
1807        def _assert_match_metadata(a, b):
1808            self.assertEqual(a.size(), b.size())
1809            self.assertEqual(a.stride(), b.stride())
1810            self.assertEqual(a.storage_offset(), b.storage_offset())
1811            self.assertEqual(a.device, b.device)
1812            self.assertEqual(a.dtype, b.dtype)
1813
1814        # ensure view replay is enabled
1815        with torch.autograd._force_original_view_tracking(True):
1816            for sample in op.sample_inputs(device, dtype, requires_grad=False):
1817                inp = sample.input
1818                outs = op(inp, *sample.args, **sample.kwargs)
1819                if not isinstance(outs, (tuple, list)):
1820                    outs = [outs]
1821
1822                # for all outputs that are views of the input, we should be able to replay the
1823                # forward and reverse views via a functioning view_func() / rev_view_func().
1824                for out in outs:
1825                    if not (
1826                        isinstance(out, torch.Tensor)
1827                        and out._is_view()
1828                        and out._base is inp
1829                    ):
1830                        continue
1831
1832                    # forward view_func
1833                    new_inp = inp.clone()
1834                    _assert_match_metadata(new_inp, inp)
1835                    new_out = out._view_func_unsafe(new_inp)
1836                    _assert_match_metadata(new_out, out)
1837                    self.assertEqual(new_out, out)
1838
1839                    # reverse view_func
1840                    new_out = out.detach()
1841                    new_inp = out._rev_view_func_unsafe(new_out)
1842                    _assert_match_metadata(new_inp, inp)
1843                    self.assertTrue(new_inp._is_view())
1844                    self.assertTrue(new_inp._base is new_out)
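        # Roughly, the contract checked above (an illustrative sketch):
        #   base = torch.randn(4)
        #   v = base[1:3]                              # `v` is a view of `base`
        #   v2 = v._view_func_unsafe(base.clone())     # replays the same slicing on a new base
        #   assert v2.size() == v.size() and v2.storage_offset() == v.storage_offset()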
1845
1846
1847@unMarkDynamoStrictTest
1848class TestMathBits(TestCase):
1849    # Tests that
1850    # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors
1851    # produces the same value
1852    # 2. The gradients are the same in both cases mentioned in (1)
1853    # 3. If the operator's inplace variant is supported, tests that the inplace operation
1854    #    produces the correct value when called on a conjugate/negative view tensor and that the output
1855    #    has its conj/neg bit set to true
1856    # This test only runs for C -> R and C -> C functions
1857    # TODO: add tests for `R->C` functions
1858    # Note: This test runs for functions that take both tensors and tensorlists as input.
1859    def _test_math_view(
1860        self,
1861        device,
1862        dtype,
1863        op,
1864        samples,
1865        math_op_physical,
1866        math_op_view,
1867        is_bit_set,
1868        out_type,
1869    ):
1870        inplace_variant = op.inplace_variant
1871
1872        # helper function to clone and conjugate/negate the input if it's a tensor
1873        # else clone the sequence and conjugate/negate the first element in the sequence
1874        # If a requires_grad argument is provided the tensor being conjugated/negated will
1875        # have its requires_grad set to that value.
1876        def clone_and_perform_view(input, **kwargs):
1877            if isinstance(input, torch.Tensor):
1878                requires_grad = kwargs.get("requires_grad", input.requires_grad)
1879                with torch.no_grad():
1880                    # Ensure view represents the original sample input
1881                    input = math_op_physical(input)
1882                # Note: the view op is not called under no_grad mode since it's not allowed
1883                # to modify a view created in no_grad mode. Here it's ok to do so, so as a
1884                # workaround we create the view before resetting the requires_grad field of the input
1885                input = math_op_view(input)
1886                assert input.is_leaf
1887                return input.requires_grad_(requires_grad)
1888
1889            if isinstance(input, Sequence):
1890                out = list(map(clone_input_helper, input))
1891                out[0] = clone_and_perform_view(out[0])
1892                return tuple(out)
1893
1894        for sample in samples:
1895            tensor = (
1896                sample.input
1897                if isinstance(sample.input, torch.Tensor)
1898                else sample.input[0]
1899            )
1900            cloned1 = clone_and_perform_view(sample.input)
1901
1902            # Computes function forward value with a physically conjugated/negated tensor and
1903            # a conj/neg view tensor and verifies that the outputs in both cases are equal.
1904            expected_forward = op(sample.input, *sample.args, **sample.kwargs)
1905            forward_with_mathview = op(cloned1, *sample.args, **sample.kwargs)
1906            self.assertEqual(expected_forward, forward_with_mathview)
1907
1908            # If the op has an inplace variant, and the input doesn't require broadcasting
1909            # and has the same dtype as output, verify that the inplace operation on a conjugated/negated
1910            # input produces the correct output, and that the output tensor has the conj/neg bit set to True
1911            if inplace_variant is not None and not sample.broadcasts_input:
1912                cloned2 = clone_and_perform_view(tensor, requires_grad=False)
1913                if (
1914                    isinstance(expected_forward, torch.Tensor)
1915                    and expected_forward.dtype is tensor.dtype
1916                ):
1917                    inplace_forward = inplace_variant(
1918                        cloned2, *sample.args, **sample.kwargs
1919                    )
1920                    self.assertTrue(is_bit_set(inplace_forward))
1921                    self.assertEqual(inplace_forward, expected_forward)
1922
1923            # TODO: backward consistency only supported for single tensor outputs
1924            # TODO: backward consistency only checked on sample.input, not all
1925            #   tensor inputs
1926            # TODO: update to handle checking grads of all tensor inputs as
1927            #   derived from each tensor output
1928            if (
1929                isinstance(expected_forward, torch.Tensor)
1930                and expected_forward.requires_grad
1931            ):
1932                output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)
1933                expected_forward = output_process_fn_grad(expected_forward)
1934                forward_with_mathview = output_process_fn_grad(forward_with_mathview)
1935
1936                tensor = (
1937                    sample.input
1938                    if isinstance(sample.input, torch.Tensor)
1939                    else sample.input[0]
1940                )
1941                expected_forward.sum().abs().backward(retain_graph=True)
1942                forward_with_mathview.sum().abs().backward(retain_graph=True)
1943                if tensor.grad is not None:
1944                    cloned1_tensor = (
1945                        cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
1946                    )
1947                    self.assertEqual(tensor.grad, cloned1_tensor.grad)
1948
1949                    tensor.grad, cloned1_tensor.grad = None, None
1950
1951                    # a repeat of the above test when the output matches `out_type`
1952                    if out_type(expected_forward):
1953                        grad = torch.randn_like(expected_forward)
1954                        expected_forward.backward(grad)
1955                        forward_with_mathview.backward(
1956                            math_op_view(math_op_physical(grad))
1957                        )
1958
1959                        self.assertEqual(tensor.grad, cloned1_tensor.grad)
1960
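    # Background (illustrative): conj/neg "views" only set a bit on the tensor,
    # while the physical ops materialize new memory:
    #   >>> x = torch.tensor([1 + 1j])
    #   >>> x.conj().is_conj()                  # lazy: only the conjugate bit is set
    #   True
    #   >>> torch.conj_physical(x).is_conj()    # eager: bit not set
    #   False
    #   >>> torch._neg_view(torch.ones(1)).is_neg()
    #   True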
1961    @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,))
1962    def test_conj_view(self, device, dtype, op):
1963        if not op.test_conjugated_samples:
1964            self.skipTest("Operation doesn't support conjugated inputs.")
1965        math_op_physical = torch.conj_physical
1966        math_op_view = torch.conj
1967        _requires_grad = torch.cfloat in op.supported_backward_dtypes(
1968            torch.device(device).type
1969        )
1970        is_bit_set = torch.is_conj
1971        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
1972        self._test_math_view(
1973            device,
1974            dtype,
1975            op,
1976            samples,
1977            math_op_physical,
1978            math_op_view,
1979            is_bit_set,
1980            torch.is_complex,
1981        )
1982
1983    @ops(ops_and_refs, allowed_dtypes=(torch.double,))
1984    def test_neg_view(self, device, dtype, op):
1985        if not op.test_neg_view:
1986            self.skipTest("Operation not tested with tensors with negative bit.")
1987        math_op_physical = torch.neg
1988        math_op_view = torch._neg_view
1989        is_bit_set = torch.is_neg
1990        samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
1991        self._test_math_view(
1992            device,
1993            dtype,
1994            op,
1995            samples,
1996            math_op_physical,
1997            math_op_view,
1998            is_bit_set,
1999            lambda x: True,
2000        )
2001
2002    @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,))
2003    def test_neg_conj_view(self, device, dtype, op):
2004        if not op.test_neg_view:
2005            self.skipTest("Operation not tested with tensors with negative bit.")
2006        if not op.test_conjugated_samples:
2007            self.skipTest("Operation doesn't support conjugated inputs.")
2008
2009        def math_op_physical(x):
2010            return -x.conj_physical()
2011
2012        def math_op_view(x):
2013            return torch._neg_view(x).conj()
2014
2015        def is_bit_set(x):
2016            return torch.is_neg(x) and torch.is_conj(x)
2017
2018        _requires_grad = dtype in op.supported_backward_dtypes(
2019            torch.device(device).type
2020        )
2021        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
2022        # Only test one sample
2023        samples = itertools.islice(samples, 1)
2024        self._test_math_view(
2025            device,
2026            dtype,
2027            op,
2028            samples,
2029            math_op_physical,
2030            math_op_view,
2031            is_bit_set,
2032            torch.is_complex,
2033        )
2034
2035
2036# an input's size and strides may have been altered as the result of an inplace op
2037def check_inplace_view(func, input, rs, input_size, input_strides):
2038    if func is None:
2039        return
2040    # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm(_legit).out
2041    # which do not necessarily mutate the first input.
2042    if isinstance(rs, torch.Tensor) and rs is input:
2043        unequal_size = rs.size() != input_size
2044        unequal_strides = rs.stride() != input_strides
2045        # resize_ should probably have inplace_view tag. Not adding the tag since it
2046        # breaks some codegen logic
2047        if unequal_size or unequal_strides:
2048            if isinstance(func, torch._ops.OpOverloadPacket):
2049                func = func.default
2050            # Reference: https://github.com/pytorch/pytorch/issues/78759
2051            if func is not torch.ops.aten.resize_.default:
2052                # TODO: use self.assertIn when we have separate tests for each tag
2053                assert torch.Tag.inplace_view in func.tags
2054
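# For example (illustrative): `transpose_` mutates its input's size and strides
# in place and carries the tag:
#   >>> torch.Tag.inplace_view in torch.ops.aten.transpose_.default.tags
#   True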
2055
2056# A mode that, when enabled, runs correctness checks to ensure
2057# that operators have the expected tags based on their input and
2058# output tensor properties
2059class TestTagsMode(TorchDispatchMode):
2060    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
2061        if isinstance(args[0], torch.Tensor):
2062            old_size = args[0].size()
2063            old_stride = args[0].stride()
2064            rs = func(*args, **kwargs)
2065            check_inplace_view(func, args[0], rs, old_size, old_stride)
2066        else:
2067            rs = func(*args, **kwargs)
2068        return rs
2069
2070
2071# Test to verify the correctness of tags in `tags.yaml`, which are also accessible through `torch.Tag`
2072@unMarkDynamoStrictTest
2073class TestTags(TestCase):
2074    @onlyCPU
2075    @ops(ops_and_refs, dtypes=OpDTypes.any_one)
2076    def test_tags(self, device, dtype, op):
2077        samples = op.sample_inputs(device, dtype, requires_grad=False)
2078        for sample in samples:
2079            # TODO: Test tags for ops that return a list of tensors
2080            input = sample.input
2081            if isinstance(input, torch.Tensor):
2082                old_size = input.size()
2083                old_stride = input.stride()
2084                with TestTagsMode():
2085                    rs = op(input, *sample.args, **sample.kwargs)
2086                # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
2087                aten_name = op.aten_name if op.aten_name is not None else op.name
2088                opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
2089                check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
2090
2091
2092class TestSelfKwarg(TestCase):
2093    def test_self_kwargs(self):
2094        """Verify that we can call the aten ops with all kwargs even if the
2095        argument's name is "self"
2096        """
2097        torch.ops.aten.reshape.default(self=torch.rand(1, 2), shape=[2])
2098        torch.ops.aten.min.default(self=torch.rand(100))
2099
2100
2101@unMarkDynamoStrictTest
2102class TestRefsOpsInfo(TestCase):
2103    import_paths = [
2104        "_refs",
2105        "_refs.special",
2106        "_refs.nn.functional",
2107        "_refs.fft",
2108        "_refs._conversions",
2109    ]
2110    module_alls = [
2111        (path, import_module(f"torch.{path}").__all__) for path in import_paths
2112    ]
2113    ref_ops_names = tuple(
2114        itertools.chain.from_iterable(
2115            [f"{path}.{op}" for op in module_all] for path, module_all in module_alls
2116        )
2117    )
2118    ref_db_names = {ref_op.name for ref_op in python_ref_db}
2119
2120    # TODO: References that do not have an entry in python_ref_db
2121    skip_ref_ops = {
2122        "_refs.alias",
2123        "_refs.bitwise_right_shift",
2124        "_refs.copy_to",
2125        "_refs.empty_permuted",
2126        "_refs.empty_strided",
2127        "_refs.equal",
2128        "_refs.full",
2129        "_refs.full_like",
2130        "_refs.is_complex",
2131        "_refs.to",
2132        "_refs.mvlgamma",
2133        "_refs.ones",
2134        "_refs.ones_like",
2135        "_refs.special.expit",
2136        "_refs.std_var",
2137        "_refs.swap_axes",
2138        "_refs.uniform",
2139        "_refs.scalar_tensor",
2140        "_refs.trunc_divide",
2141        "_refs.zero",
2142        "_refs.zeros",
2143        "_refs.zeros_like",
2144        "_refs.rfloordiv",
2145        "_refs.rtruediv",
2146        "_refs.rpow",
2147        # These should be tested with their out-of-place counterparts
2148        "_refs.index_add_",
2149        "_refs.index_copy_",
2150        "_refs.index_fill_",
2151        "_refs.native_group_norm",
2152    }
2153
2154    not_in_decomp_table = {
2155        # duplicated in _decomp and _refs
2156        "_refs.nn.functional.group_norm",
2157        "_refs.nn.functional.mse_loss",
2158        "_refs.floor_divide",
2159        # duplicated as refs do not have decent support for advanced indexing
2160        "_refs.index_copy",
2161        "_refs.index_copy_",
2162        "_refs.index_add",
2163        "_refs.index_add_",
2164        # these are not aten ops?
2165        "_refs._conversions.bfloat16",
2166        "_refs._conversions.bool",
2167        "_refs._conversions.byte",
2168        "_refs._conversions.char",
2169        "_refs._conversions.double",
2170        "_refs._conversions.float",
2171        "_refs._conversions.half",
2172        "_refs._conversions.int",
2173        "_refs._conversions.long",
2174        "_refs._conversions.short",
2175        "_refs._conversions.chalf",
2176        "_refs._conversions.cfloat",
2177        "_refs._conversions.cdouble",
2178        "_refs.broadcast_shapes",
2179        "_refs.broadcast_tensors",
2180        "_refs.mvlgamma",
2181        "_refs.nn.functional.layer_norm",
2182        "_refs.nn.functional.tanhshrink",
2183        "_refs.nn.functional.triplet_margin_loss",
2184        "_refs.rfloordiv",
2185        "_refs.rtruediv",
2186        "_refs.rpow",
2187        # CompositeImplicitAutograd
2188        "_refs.allclose",
2189        "_refs.atleast_1d",
2190        "_refs.atleast_2d",
2191        "_refs.atleast_3d",
2192        "_refs.broadcast_to",
2193        "_refs.chunk",
2194        "_refs.column_stack",
2195        "_refs.contiguous",
2196        "_refs.dsplit",
2197        "_refs.dstack",
2198        "_refs.fill",
2199        "_refs.fill_",
2200        "_refs.flatten",
2201        "_refs.fliplr",
2202        "_refs.flipud",
2203        "_refs.float_power",
2204        "_refs.hsplit",
2205        "_refs.hstack",
2206        "_refs.isclose",
2207        "_refs.isfinite",
2208        "_refs.isreal",
2209        "_refs.istft",
2210        "_refs.log_softmax",
2211        "_refs.movedim",
2212        "_refs.narrow",
2213        "_refs.nn.functional.dropout",
2214        "_refs.nn.functional.l1_loss",
2215        "_refs.nn.functional.smooth_l1_loss",
2216        "_refs.nn.functional.log_softmax",
2217        "_refs.nn.functional.poisson_nll_loss",
2218        "_refs.nn.functional.softmax",
2219        "_refs.nn.functional.softmin",
2220        "_refs.positive",
2221        "_refs.ravel",
2222        "_refs.reshape",
2223        "_refs.softmax",
2224        "_refs.special.expit",
2225        "_refs.special.log_softmax",
2226        "_refs.special.softmax",
2227        "_refs.square",
2228        "_refs.stft",
2229        "_refs.T",
2230        "_refs.take_along_dim",
2231        "_refs.tensor_split",
2232        "_refs.to",
2233        "_refs.true_divide",
2234        "_refs.trunc_divide",
2235        "_refs.vsplit",
2236        "_refs.vstack",
2237        "_refs.linalg.matrix_norm",
2238        "_refs.linalg.norm",
2239        "_refs.linalg.svd",
2240        "_refs.linalg.svdvals",
2241        "_refs.unflatten",
2242        "_refs.sum_to_size",
2243        # ref implementation missing kwargs
2244        "_refs.full_like",  # missing "layout"
2245        "_refs.scalar_tensor",  # missing "layout"
2246        # other
2247        "_refs.block_diag",  # only refs._block_diag_iterable is in decomposition table
2248        "_refs.empty",  # intentional; direct empty is faster and has less guards
2249        "_refs.empty_permuted",  # intentional; direct empty is faster and has less guards
2250        "_refs.expand_as",
2251        "_refs.as_strided",  # _prims._as_strided_meta: "reduce() of empty sequence with no initial value"
2252        "_refs.copy_to",  # torch._C._jit_get_operation: No such operator aten::copy_to
2253        "_refs.equal",  # 'bool' object has no attribute 'dtype'
2254        "_refs.conj",  # Calls _prims.conj
2255        "_refs.real",
2256        "_refs.imag",
2257        "_refs.reshape_as",
2258        "_refs.view_as",
2259        "_refs.view_as_complex",  # TorchInductor does not support complex at the moment.
2260        # the decompositions for these ops are slightly different
2261        # because of out handling
2262        "_refs.var_mean",
2263        "_refs.std_mean",
2264        "_refs.native_layer_norm",
2265    }
2266
2267    @parametrize("op", ref_ops_names)
2268    def test_refs_are_in_python_ref_db(self, op):
2269        inplace = op[-1] == "_"
2270        if op in self.skip_ref_ops:
2271            raise unittest.SkipTest(f"{op} does not have an entry in python_ref_db")
2272        elif inplace:
2273            self.assertNotIn(
2274                op,
2275                self.ref_db_names,
2276                msg=f"{op} is an in-place operation and should not have an OpInfo",
2277            )
2278        else:
2279            # Intentionally don't use assertIn to avoid printing the
2280            # (very large) container
2281            self.assertTrue(op in self.ref_db_names, msg=f"{op} not in ref_db_names")
2282
2283    @parametrize("op", ref_ops_names)
2284    def test_refs_are_in_decomp_table(self, op):
2285        path = op.split(".")
2286        module_path = ".".join(path[:-1])
2287        op_name = path[-1]
2288        op_impl = getattr(import_module(f"torch.{module_path}"), op_name)
2289
2290        if op in self.not_in_decomp_table:
2291            self.assertNotIn(
2292                op_impl,
2293                torch._decomp.decomposition_table.values(),
2294                f"Unexpectedly found {op} in torch._decomp.decomposition_table.values()",
2295            )
2296        else:
2297            self.assertIn(
2298                op_impl,
2299                torch._decomp.decomposition_table.values(),
2300                f"Did not find {op} in torch._decomp.decomposition_table.values()",
2301            )
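    # For example (illustrative): `_refs.add` is registered as a decomposition,
    # so its function object appears among the table's values:
    #   >>> import torch._refs
    #   >>> torch._refs.add in torch._decomp.decomposition_table.values()
    #   True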
2302
2303
2304fake_skips = (
2305    "aminmax",  # failing input
2306    "cov",  # aweights cannot be negative
2307    "istft",  # window overlap add min: 0
2308    "linalg.eigvals",  # The tensor has a non-zero number of elements, but its data is not allocated yet
2309    "linalg.eigvalsh",  # aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend
2310    "linalg.matrix_power",  # Could not run 'aten::eye.m_out' with arguments from the 'Meta' backend
2311    # "linalg.pinv",  # Could not run 'aten::pinv.out' with arguments from the 'Meta' backend
2312    "linalg.matrix_rank.hermitian",  # Could not run 'aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend
2313    "linalg.pinv.hermitian",  # tensor.mH is only supported on matrices or batches of matrices. Got 1-D tensor
2314    "linalg.solve",  # Could not run 'aten::linalg_solve' with arguments from the 'Meta' backend
2315    "linalg.tensorsolve",  # Could not run 'aten::linalg_solve' with arguments from the 'Meta'
2316    "lu_solve",  # MALLOC ERROR: debug
2317    "multinomial",  # Could not run 'aten::multinomial' with arguments from the 'Meta' backend
2318    "mvlgamma.mvlgamma_p_1",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
2319    "mvlgamma.mvlgamma_p_3",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
2320    "mvlgamma.mvlgamma_p_5",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
2321    "nanmean",  # logical_not() got an unexpected keyword argument 'out'
2322    "quantile",  # quantile() q values must be in the range [0, 1]
2323    "nanquantile",  # quantile() q values must be in the range [0, 1]
2324    "nn.functional.ctc_loss",  # The tensor has a non-zero number of elements, but its data is not allocated yet
2325    "nn.functional.embedding_bag",  # sometimes errors
2326    "nn.functional.nll_loss",  # sometimes errors
2327    "nn.functional.max_pool1d",  # The tensor has a non-zero number of elements
2328    "to_sparse",  # Could not run 'aten::_to_sparse' with arguments from the 'Meta' backend
2329    "tensor_split",  # The tensor has a non-zero number of elements, but its data is not allocated yet
2330    "repeat_interleave",  # cannot repeat_interleave a meta tensor without output_size
2331    "sparse.sampled.addmm",  # sparsity not supported
2332    # Cannot infer the total number of classes from meta; there is currently no way to throw DynamicOutputShapeException
2333    "nn.functional.one_hot",
2334    "narrow",  # Fails only for one overload with DataDependentOutputException (hence skip).
2335)
2336
2337fake_autocast_device_skips = defaultdict(dict)
2338
2339# TODO: investigate/fix
2340fake_autocast_device_skips["cpu"] = {"linalg.pinv"}
2341fake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"}
2342
2343
2344dynamic_output_op_tests = (
2345    "argwhere",
2346    "bincount",
2347    "combinations",
2348    "linalg.lstsq",
2349    "masked_select",
2350    "nonzero",
2351    "unique_consecutive",
2352    "unique",
2353    "linalg.lstsq.grad_oriented",
2354)
2355
2356# Ops that have dynamic output shapes that we can handle when
2357# allow_dynamic_output_shape_ops is True in the fake tensor shape environment.
2358supported_dynamic_output_op_tests = (
2359    "nonzero",
2360    "unique",
2361    "repeat_interleave",
2362    "masked_select",
2363)
2364
2365# some inputs invoke dynamic output shape operators, some do not
2366sometimes_dynamic_output_op_test = ("__getitem__", "index_select")
2367
2368data_dependent_op_tests = (
2369    "equal",
2370    "corrcoef",
2371    "nn.functional.gaussian_nll_loss",
2372    "allclose",
2373)
2374
2375aliasing_failures = ("histogramdd",)
2376
2377fake_backward_skips = {
2378    "linalg.cond",
2379    "linalg.matrix_norm",
2380    "linalg.norm",
2381    "linalg.svd",
2382    "linalg.svdvals",
2383    "pca_lowrank",
2384    "roll",
2385    "svd_lowrank",
2386    "sgn",
2387}
2388
2389fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
2390    xfail("fft.ihfftn"),  # Mismatch in aten._conj_physical.default
2391    xfail("fft.ihfft2"),  # Mismatch in aten._conj_physical.default
2392    skip("nn.functional.ctc_loss"),
2393}
2394
2395fake_autocast_backward_xfails = {
2396    skip("nn.functional.binary_cross_entropy"),
2397    skip("sparse.sampled_addmm"),
2398    skip("linalg.pinv"),
2399    skip("linalg.pinv", "hermitian"),
2400    skip("linalg.pinv", "singular"),
2401    skip("pinverse"),
2402}
2403
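# Background (an illustrative sketch): a fake tensor carries only metadata
# (shape, dtype, device), so operators can run without real storage or compute:
#   >>> mode = FakeTensorMode()
#   >>> fake = mode.from_tensor(torch.ones(2, 3))
#   >>> isinstance(fake, FakeTensor)
#   True
#   >>> with mode:
#   ...     out = fake + 1
#   >>> out.shape
#   torch.Size([2, 3])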
2404
2405@unMarkDynamoStrictTest
2406class TestFakeTensor(TestCase):
2407    def setUp(self):
2408        # Turn on FakeTensor caching and cross-checking for these tests:
2409        cache_enabled = unittest.mock.patch(
2410            "torch._dynamo.config.fake_tensor_cache_enabled", True
2411        )
2412        cache_enabled.start()
2413        self.addCleanup(cache_enabled.stop)
2414
2415        cache_crosscheck = unittest.mock.patch(
2416            "torch._dynamo.config.fake_tensor_cache_crosscheck_enabled", True
2417        )
2418        cache_crosscheck.start()
2419        self.addCleanup(cache_crosscheck.stop)
2420
2421    def _test_fake_helper(self, device, dtype, op, context):
2422        name = op.name
2423        if op.variant_test_name:
2424            name += "." + op.variant_test_name
2425        if name in fake_skips or "sparse" in name or "jiterator" in name:
2426            self.skipTest("Skip failing test")
2427
2428        samples = op.sample_inputs(device, dtype, requires_grad=False)
2429        for sample in samples:
2430            mode = FakeTensorMode()
2431
2432            from torch.fx.experimental.symbolic_shapes import ShapeEnv
2433
2434            allow_dynamic_output_shape_shape_env = ShapeEnv(
2435                allow_dynamic_output_shape_ops=True
2436            )
2437
2438            allow_dynamic_output_shape_mode = FakeTensorMode(
2439                shape_env=allow_dynamic_output_shape_shape_env
2440            )
2441
2442            try:
2443                with context():
2444                    res = op(sample.input, *sample.args, **sample.kwargs)
2445            except Exception:
2446                continue
2447
2448            def run_with_fake_mode_and_verify(fake_mode, match_results=True):
2449                def map_to_fake(e):
2450                    if isinstance(e, torch.Tensor):
2451                        return fake_mode.from_tensor(e)
2452                    else:
2453                        return e
2454
2455                input = tree_map(map_to_fake, sample.input)
2456                args = tree_map(map_to_fake, sample.args)
2457                kwargs = tree_map(map_to_fake, sample.kwargs)
2458
2459                try:
2460                    with context():
2461                        with fake_mode:
2462                            res_fake = op(input, *args, **kwargs)
2463
2464                    if not match_results:
2465                        return
2466
2467                    for fake_out, real_out in zip(
2468                        pytree.tree_leaves(res_fake), pytree.tree_leaves(res)
2469                    ):
2470                        if not isinstance(fake_out, torch.Tensor):
2471                            self.assertNotIsInstance(real_out, torch.Tensor)
2472                            self.assertEqual(fake_out, real_out)
2473                            continue
2474
2475                        self.assertIsInstance(fake_out, FakeTensor)
2476                        # if you see a shape exception here, you may need to add
2477                        # a `dynamic_output_shape` tag to an operator
2478
2479                        if op.op not in [
2480                            torch.ops.aten._efficient_attention_forward,
2481                            torch.ops.aten._flash_attention_forward,
2482                        ]:
2483                            # prims/decomps must correctly model strides,
2484                            # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
2485
2486                            # note: the excluded ops have intentionally incorrect device;
2487                            # see "Note [Seed and Offset]" (_meta_registrations.py)
2488                            prims.utils.compare_tensor_meta(fake_out, real_out, True)
2489
2490                        if name not in aliasing_failures:
2491                            fake_aliasing = outputs_alias_inputs(
2492                                (input, args, kwargs), res_fake
2493                            )
2494                            real_aliasing = outputs_alias_inputs(
2495                                (sample.input, sample.args, sample.kwargs), res
2496                            )
2497                            self.assertEqual(fake_aliasing, real_aliasing)

                    self.assertTrue(
                        name not in dynamic_output_op_tests
                        and name not in data_dependent_op_tests
                    )

                except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
                    pass
                except torch._subclasses.fake_tensor.UnsupportedOperatorException:
                    pass
                except torch._subclasses.fake_tensor.DynamicOutputShapeException:
                    self.assertTrue(
                        name in dynamic_output_op_tests
                        or name in sometimes_dynamic_output_op_test
                    )
                    self.assertTrue(
                        fake_mode.shape_env is None
                        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
                        or name not in supported_dynamic_output_op_tests
                    )
                except torch._subclasses.fake_tensor.DataDependentOutputException:
                    self.assertTrue(name in data_dependent_op_tests)

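            # verify once with the default fake mode and, for ops known to
            # support dynamic output shapes, again with a shape env that allows
            # them (without comparing results)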
            run_with_fake_mode_and_verify(mode)
            if name in supported_dynamic_output_op_tests:
                run_with_fake_mode_and_verify(
                    allow_dynamic_output_shape_mode, match_results=False
                )

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_pointwise_ops(self, device, dtype, op):
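        # Verify that every op carrying the `pointwise` tag produces tensor
        # outputs whose shape is the broadcast of its tensor inputs' shapes.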
        name = op.name
        if op.variant_test_name:
            name += "." + op.variant_test_name
        if name in fake_skips or "sparse" in name or "jiterator" in name:
            self.skipTest("Skipped: known fake tensor failure")

        test_self = self

        class TestPointwiseMode(TorchDispatchMode):
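            # Intercepts every dispatched call; for pointwise-tagged ops it
            # recomputes the expected broadcast shape and asserts that each
            # tensor output has it.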
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}

                out = func(*args, **kwargs)

                if torch.Tag.pointwise in func.tags:
                    shapes = []
                    for inp in pytree.arg_tree_leaves(*args, **kwargs):
                        if isinstance(inp, torch.Tensor):
                            shapes.append(inp.shape)

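                    # a pointwise op's output shape is the broadcast of all of
                    # its tensor inputs' shapes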
                    out_shape = torch._refs._broadcast_shapes(*shapes)

                    for out_elem in pytree.tree_leaves(out):
                        if isinstance(out_elem, torch.Tensor):
                            test_self.assertEqual(out_elem.shape, out_shape)

                return out

        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample in samples:
            mode = FakeTensorMode()

            def map_to_fake(e):
                if isinstance(e, torch.Tensor):
                    return mode.from_tensor(e)
                else:
                    return e

            input = tree_map(map_to_fake, sample.input)
            args = tree_map(map_to_fake, sample.args)
            kwargs = tree_map(map_to_fake, sample.kwargs)

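            # skip samples the op cannot handle with fake tensors at all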
            try:
                op(input, *args, **kwargs)
            except Exception:
                continue

            with TestPointwiseMode():
                with mode:
                    op(input, *args, **kwargs)

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_fake(self, device, dtype, op):
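        # Check that running the op on fake tensor inputs produces outputs
        # whose metadata matches real execution (see _test_fake_helper).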
        self._test_fake_helper(device, dtype, op, contextlib.nullcontext)

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_fake_autocast(self, device, dtype, op):
        device_type = torch.device(device).type
        if op.name in fake_autocast_device_skips[device_type]:
            self.skipTest("Skipped: known fake tensor autocast failure")

        def context_fn():
            return torch.amp.autocast(device_type)

        self._test_fake_helper(device, dtype, op, context_fn)

    def _test_fake_crossref_helper(self, device, dtype, op, context):
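        # Run each sample's forward and backward under CrossRefFakeMode, which
        # cross-checks every real kernel invocation against its fake tensor
        # counterpart.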
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        # skip these to speed up tests
        common_skip_ops = (
            aten.detach.default,
            aten.empty_strided.default,
            aten.copy_.default,
            aten.is_same_size.default,
        )

        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs

            # TODO: enable check_aliasing, batch norm fails
            try:
                with torch._subclasses.CrossRefFakeMode(
                    ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True
                ):
                    with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled(
                        False
                    ):
                        composite_compliance.compute_expected_grads(
                            op.get_op(),
                            args,
                            kwargs,
                            sample.output_process_fn_grad,
                            op.gradcheck_wrapper,
                        )
            except torch._subclasses.fake_tensor.UnsupportedOperatorException:
                pass

    @onlyCUDA
    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
    @skipOps(
        "TestFakeTensor", "test_fake_crossref_backward_no_amp", fake_backward_xfails
    )
    def test_fake_crossref_backward_no_amp(self, device, dtype, op):
        self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext)

    @onlyCUDA
    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
    @skipOps(
        "TestFakeTensor",
        "test_fake_crossref_backward_amp",
        fake_backward_xfails | fake_autocast_backward_xfails,
    )
    def test_fake_crossref_backward_amp(self, device, dtype, op):
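        # same as the no-amp variant above, but run under CUDA autocast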
        self._test_fake_crossref_helper(
            device, dtype, op, partial(torch.amp.autocast, "cuda")
        )

    @ops([op for op in ops_and_refs if op.is_factory_function])
    def test_strided_layout(self, device, dtype, op):
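        # Factory functions must honor an explicit layout=torch.strided kwarg.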
        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            kwargs = sample.kwargs.copy()
            kwargs["layout"] = torch.strided
            strided_result = op(sample.input, *sample.args, **kwargs)
            self.assertEqual(strided_result.layout, torch.strided)


instantiate_device_type_tests(TestCommon, globals())
instantiate_device_type_tests(TestCompositeCompliance, globals())
instantiate_device_type_tests(TestMathBits, globals())
instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
instantiate_device_type_tests(TestFakeTensor, globals())
instantiate_device_type_tests(TestTags, globals())

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()