xref: /aosp_15_r20/external/pytorch/test/test_binary_ufuncs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport itertools
4*da0073e9SAndroid Build Coastguard Workerimport math
5*da0073e9SAndroid Build Coastguard Workerimport operator
6*da0073e9SAndroid Build Coastguard Workerimport random
7*da0073e9SAndroid Build Coastguard Workerimport warnings
8*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
9*da0073e9SAndroid Build Coastguard Workerfrom itertools import chain, product
10*da0073e9SAndroid Build Coastguard Workerfrom numbers import Number
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerimport numpy as np
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerimport torch
15*da0073e9SAndroid Build Coastguard Workerimport torch.autograd.forward_ad as fwAD
16*da0073e9SAndroid Build Coastguard Workerfrom torch import inf, nan
17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
19*da0073e9SAndroid Build Coastguard Worker    deviceCountAtLeast,
20*da0073e9SAndroid Build Coastguard Worker    dtypes,
21*da0073e9SAndroid Build Coastguard Worker    dtypesIfCPU,
22*da0073e9SAndroid Build Coastguard Worker    dtypesIfCUDA,
23*da0073e9SAndroid Build Coastguard Worker    expectedFailureMeta,
24*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
25*da0073e9SAndroid Build Coastguard Worker    onlyCPU,
26*da0073e9SAndroid Build Coastguard Worker    onlyCUDA,
27*da0073e9SAndroid Build Coastguard Worker    onlyNativeDeviceTypes,
28*da0073e9SAndroid Build Coastguard Worker    OpDTypes,
29*da0073e9SAndroid Build Coastguard Worker    ops,
30*da0073e9SAndroid Build Coastguard Worker    precisionOverride,
31*da0073e9SAndroid Build Coastguard Worker    skipIf,
32*da0073e9SAndroid Build Coastguard Worker    skipMeta,
33*da0073e9SAndroid Build Coastguard Worker)
34*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import (
35*da0073e9SAndroid Build Coastguard Worker    all_types_and,
36*da0073e9SAndroid Build Coastguard Worker    all_types_and_complex_and,
37*da0073e9SAndroid Build Coastguard Worker    complex_types,
38*da0073e9SAndroid Build Coastguard Worker    floating_and_complex_types,
39*da0073e9SAndroid Build Coastguard Worker    floating_types_and,
40*da0073e9SAndroid Build Coastguard Worker    get_all_int_dtypes,
41*da0073e9SAndroid Build Coastguard Worker    get_all_math_dtypes,
42*da0073e9SAndroid Build Coastguard Worker    integral_types,
43*da0073e9SAndroid Build Coastguard Worker    integral_types_and,
44*da0073e9SAndroid Build Coastguard Worker)
45*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import (
46*da0073e9SAndroid Build Coastguard Worker    binary_ufuncs,
47*da0073e9SAndroid Build Coastguard Worker    binary_ufuncs_and_refs,
48*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_binary_broadcasting_tensors,
49*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_binary_extremal_value_tensors,
50*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_binary_large_value_tensors,
51*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_binary_small_value_tensors,
52*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_binary_tensors,
53*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_binary_with_scalar_and_type_promotion_samples,
54*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_binary_with_scalar_samples,
55*da0073e9SAndroid Build Coastguard Worker)
56*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
57*da0073e9SAndroid Build Coastguard Worker    gradcheck,
58*da0073e9SAndroid Build Coastguard Worker    iter_indices,
59*da0073e9SAndroid Build Coastguard Worker    numpy_to_torch_dtype_dict,
60*da0073e9SAndroid Build Coastguard Worker    run_tests,
61*da0073e9SAndroid Build Coastguard Worker    set_default_dtype,
62*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
63*da0073e9SAndroid Build Coastguard Worker    slowTest,
64*da0073e9SAndroid Build Coastguard Worker    TEST_SCIPY,
65*da0073e9SAndroid Build Coastguard Worker    TestCase,
66*da0073e9SAndroid Build Coastguard Worker    torch_to_numpy_dtype_dict,
67*da0073e9SAndroid Build Coastguard Worker    xfailIfTorchDynamo,
68*da0073e9SAndroid Build Coastguard Worker)
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Workerif TEST_SCIPY:
72*da0073e9SAndroid Build Coastguard Worker    import scipy.integrate
73*da0073e9SAndroid Build Coastguard Worker    import scipy.special
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker# TODO: update to use opinfos consistently
77*da0073e9SAndroid Build Coastguard Workerclass TestBinaryUfuncs(TestCase):
78*da0073e9SAndroid Build Coastguard Worker    # Generic tests for elementwise binary (AKA binary universal (u) functions (funcs))
79*da0073e9SAndroid Build Coastguard Worker    # TODO: below contiguous tensor results are compared with a variety of noncontiguous results.
80*da0073e9SAndroid Build Coastguard Worker    #   It would be interesting to have the lhs and rhs have different discontiguities.
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker    # Helper for comparing torch tensors and NumPy arrays
83*da0073e9SAndroid Build Coastguard Worker    # TODO: should this or assertEqual also validate that strides are equal?
84*da0073e9SAndroid Build Coastguard Worker    def assertEqualHelper(
85*da0073e9SAndroid Build Coastguard Worker        self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs
86*da0073e9SAndroid Build Coastguard Worker    ):
87*da0073e9SAndroid Build Coastguard Worker        assert isinstance(actual, torch.Tensor)
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker        # Some NumPy functions return scalars, not arrays
90*da0073e9SAndroid Build Coastguard Worker        if isinstance(expected, Number):
91*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual.item(), expected, msg=msg, **kwargs)
92*da0073e9SAndroid Build Coastguard Worker        elif isinstance(expected, np.ndarray):
93*da0073e9SAndroid Build Coastguard Worker            # Handles exact dtype comparisons between arrays and tensors
94*da0073e9SAndroid Build Coastguard Worker            if exact_dtype:
95*da0073e9SAndroid Build Coastguard Worker                # Allows array dtype to be float32 when comparing with bfloat16 tensors
96*da0073e9SAndroid Build Coastguard Worker                #   since NumPy doesn't support the bfloat16 dtype
97*da0073e9SAndroid Build Coastguard Worker                # Also ops like scipy.special.erf, scipy.special.erfc, etc, promote float16
98*da0073e9SAndroid Build Coastguard Worker                # to float32
99*da0073e9SAndroid Build Coastguard Worker                if expected.dtype == np.float32:
100*da0073e9SAndroid Build Coastguard Worker                    assert actual.dtype in (
101*da0073e9SAndroid Build Coastguard Worker                        torch.float16,
102*da0073e9SAndroid Build Coastguard Worker                        torch.bfloat16,
103*da0073e9SAndroid Build Coastguard Worker                        torch.float32,
104*da0073e9SAndroid Build Coastguard Worker                    )
105*da0073e9SAndroid Build Coastguard Worker                else:
106*da0073e9SAndroid Build Coastguard Worker                    assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype]
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
109*da0073e9SAndroid Build Coastguard Worker                actual,
110*da0073e9SAndroid Build Coastguard Worker                torch.from_numpy(expected).to(actual.dtype),
111*da0073e9SAndroid Build Coastguard Worker                msg,
112*da0073e9SAndroid Build Coastguard Worker                exact_device=False,
113*da0073e9SAndroid Build Coastguard Worker                **kwargs,
114*da0073e9SAndroid Build Coastguard Worker            )
115*da0073e9SAndroid Build Coastguard Worker        else:
116*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker    # Tests that the function and its (array-accepting) reference produce the same
119*da0073e9SAndroid Build Coastguard Worker    #   values on given tensors
120*da0073e9SAndroid Build Coastguard Worker    def _test_reference_numerics(self, dtype, op, gen, equal_nan=True):
121*da0073e9SAndroid Build Coastguard Worker        def _helper_reference_numerics(
122*da0073e9SAndroid Build Coastguard Worker            expected, actual, msg, exact_dtype, equal_nan=True
123*da0073e9SAndroid Build Coastguard Worker        ):
124*da0073e9SAndroid Build Coastguard Worker            if not torch.can_cast(
125*da0073e9SAndroid Build Coastguard Worker                numpy_to_torch_dtype_dict[expected.dtype.type], dtype
126*da0073e9SAndroid Build Coastguard Worker            ):
127*da0073e9SAndroid Build Coastguard Worker                exact_dtype = False
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.bfloat16 and expected.dtype == np.float32:
130*da0073e9SAndroid Build Coastguard Worker                # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149
131*da0073e9SAndroid Build Coastguard Worker                self.assertEqualHelper(
132*da0073e9SAndroid Build Coastguard Worker                    actual,
133*da0073e9SAndroid Build Coastguard Worker                    expected,
134*da0073e9SAndroid Build Coastguard Worker                    msg,
135*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
136*da0073e9SAndroid Build Coastguard Worker                    exact_dtype=exact_dtype,
137*da0073e9SAndroid Build Coastguard Worker                    rtol=16e-3,
138*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
139*da0073e9SAndroid Build Coastguard Worker                )
140*da0073e9SAndroid Build Coastguard Worker            else:
141*da0073e9SAndroid Build Coastguard Worker                self.assertEqualHelper(
142*da0073e9SAndroid Build Coastguard Worker                    actual,
143*da0073e9SAndroid Build Coastguard Worker                    expected,
144*da0073e9SAndroid Build Coastguard Worker                    msg,
145*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
146*da0073e9SAndroid Build Coastguard Worker                    equal_nan=equal_nan,
147*da0073e9SAndroid Build Coastguard Worker                    exact_dtype=exact_dtype,
148*da0073e9SAndroid Build Coastguard Worker                )
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker        for sample in gen:
151*da0073e9SAndroid Build Coastguard Worker            # Each sample input acquired from the generator is just one lhs tensor
152*da0073e9SAndroid Build Coastguard Worker            #   and one rhs tensor
153*da0073e9SAndroid Build Coastguard Worker            l = sample.input
154*da0073e9SAndroid Build Coastguard Worker            r = sample.args[0]
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Worker            numpy_sample = sample.numpy()
157*da0073e9SAndroid Build Coastguard Worker            l_numpy = numpy_sample.input
158*da0073e9SAndroid Build Coastguard Worker            r_numpy = numpy_sample.args[0]
159*da0073e9SAndroid Build Coastguard Worker            actual = op(l, r)
160*da0073e9SAndroid Build Coastguard Worker            expected = op.ref(l_numpy, r_numpy)
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Worker            # Crafts a custom error message for smaller, printable tensors
163*da0073e9SAndroid Build Coastguard Worker            def _numel(x):
164*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, torch.Tensor):
165*da0073e9SAndroid Build Coastguard Worker                    return x.numel()
166*da0073e9SAndroid Build Coastguard Worker                # Assumes x is a scalar
167*da0073e9SAndroid Build Coastguard Worker                return 1
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker            if _numel(l) <= 100 and _numel(r) <= 100:
170*da0073e9SAndroid Build Coastguard Worker                msg = (
171*da0073e9SAndroid Build Coastguard Worker                    "Failed to produce expected results! Input lhs tensor was"
172*da0073e9SAndroid Build Coastguard Worker                    f" {l}, rhs tensor was {r}, torch result is {actual}, and reference result is"
173*da0073e9SAndroid Build Coastguard Worker                    f" {expected}."
174*da0073e9SAndroid Build Coastguard Worker                )
175*da0073e9SAndroid Build Coastguard Worker            else:
176*da0073e9SAndroid Build Coastguard Worker                msg = None
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker            exact_dtype = True
179*da0073e9SAndroid Build Coastguard Worker            if isinstance(actual, torch.Tensor):
180*da0073e9SAndroid Build Coastguard Worker                _helper_reference_numerics(
181*da0073e9SAndroid Build Coastguard Worker                    expected, actual, msg, exact_dtype, equal_nan
182*da0073e9SAndroid Build Coastguard Worker                )
183*da0073e9SAndroid Build Coastguard Worker            else:
184*da0073e9SAndroid Build Coastguard Worker                for x, y in zip(expected, actual):
185*da0073e9SAndroid Build Coastguard Worker                    # testing multi-outputs results
186*da0073e9SAndroid Build Coastguard Worker                    _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan)
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker    # The following tests only apply to elementwise binary operators with references
189*da0073e9SAndroid Build Coastguard Worker    binary_ufuncs_with_references = list(
190*da0073e9SAndroid Build Coastguard Worker        filter(lambda op: op.ref is not None and op.ref is not None, binary_ufuncs)
191*da0073e9SAndroid Build Coastguard Worker    )
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs_with_references)
194*da0073e9SAndroid Build Coastguard Worker    def test_reference_numerics(self, device, dtype, op):
195*da0073e9SAndroid Build Coastguard Worker        gen = generate_elementwise_binary_tensors(op, device=device, dtype=dtype)
196*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
197*da0073e9SAndroid Build Coastguard Worker
198*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs_with_references)
199*da0073e9SAndroid Build Coastguard Worker    def test_reference_numerics_small_values(self, device, dtype, op):
200*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.bool:
201*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Doesn't support bool!")
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker        gen = generate_elementwise_binary_small_value_tensors(
204*da0073e9SAndroid Build Coastguard Worker            op, device=device, dtype=dtype
205*da0073e9SAndroid Build Coastguard Worker        )
206*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker    @ops(
209*da0073e9SAndroid Build Coastguard Worker        binary_ufuncs_with_references,
210*da0073e9SAndroid Build Coastguard Worker        allowed_dtypes=(
211*da0073e9SAndroid Build Coastguard Worker            torch.int16,
212*da0073e9SAndroid Build Coastguard Worker            torch.int32,
213*da0073e9SAndroid Build Coastguard Worker            torch.int64,
214*da0073e9SAndroid Build Coastguard Worker            torch.float16,
215*da0073e9SAndroid Build Coastguard Worker            torch.bfloat16,
216*da0073e9SAndroid Build Coastguard Worker            torch.float32,
217*da0073e9SAndroid Build Coastguard Worker            torch.float64,
218*da0073e9SAndroid Build Coastguard Worker            torch.complex64,
219*da0073e9SAndroid Build Coastguard Worker            torch.complex128,
220*da0073e9SAndroid Build Coastguard Worker        ),
221*da0073e9SAndroid Build Coastguard Worker    )
222*da0073e9SAndroid Build Coastguard Worker    def test_reference_numerics_large_values(self, device, dtype, op):
223*da0073e9SAndroid Build Coastguard Worker        gen = generate_elementwise_binary_large_value_tensors(
224*da0073e9SAndroid Build Coastguard Worker            op, device=device, dtype=dtype
225*da0073e9SAndroid Build Coastguard Worker        )
226*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker    @ops(
229*da0073e9SAndroid Build Coastguard Worker        binary_ufuncs_with_references,
230*da0073e9SAndroid Build Coastguard Worker        allowed_dtypes=(
231*da0073e9SAndroid Build Coastguard Worker            torch.float16,
232*da0073e9SAndroid Build Coastguard Worker            torch.bfloat16,
233*da0073e9SAndroid Build Coastguard Worker            torch.float32,
234*da0073e9SAndroid Build Coastguard Worker            torch.float64,
235*da0073e9SAndroid Build Coastguard Worker            torch.complex64,
236*da0073e9SAndroid Build Coastguard Worker            torch.complex128,
237*da0073e9SAndroid Build Coastguard Worker        ),
238*da0073e9SAndroid Build Coastguard Worker    )
239*da0073e9SAndroid Build Coastguard Worker    def test_reference_numerics_extremal_values(self, device, dtype, op):
240*da0073e9SAndroid Build Coastguard Worker        gen = generate_elementwise_binary_extremal_value_tensors(
241*da0073e9SAndroid Build Coastguard Worker            op, device=device, dtype=dtype
242*da0073e9SAndroid Build Coastguard Worker        )
243*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker    # tests broadcasting and noncontiguous broadcasting behavior
246*da0073e9SAndroid Build Coastguard Worker    @ops(
247*da0073e9SAndroid Build Coastguard Worker        binary_ufuncs_with_references,
248*da0073e9SAndroid Build Coastguard Worker        allowed_dtypes=(
249*da0073e9SAndroid Build Coastguard Worker            torch.long,
250*da0073e9SAndroid Build Coastguard Worker            torch.float32,
251*da0073e9SAndroid Build Coastguard Worker        ),
252*da0073e9SAndroid Build Coastguard Worker    )
253*da0073e9SAndroid Build Coastguard Worker    def test_broadcasting(self, device, dtype, op):
254*da0073e9SAndroid Build Coastguard Worker        gen = generate_elementwise_binary_broadcasting_tensors(
255*da0073e9SAndroid Build Coastguard Worker            op, device=device, dtype=dtype
256*da0073e9SAndroid Build Coastguard Worker        )
257*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker    @ops(
260*da0073e9SAndroid Build Coastguard Worker        binary_ufuncs_with_references,
261*da0073e9SAndroid Build Coastguard Worker        allowed_dtypes=(torch.long, torch.float32, torch.complex64),
262*da0073e9SAndroid Build Coastguard Worker    )
263*da0073e9SAndroid Build Coastguard Worker    def test_scalar_support(self, device, dtype, op):
264*da0073e9SAndroid Build Coastguard Worker        gen = generate_elementwise_binary_with_scalar_samples(
265*da0073e9SAndroid Build Coastguard Worker            op, device=device, dtype=dtype
266*da0073e9SAndroid Build Coastguard Worker        )
267*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
268*da0073e9SAndroid Build Coastguard Worker        gen = generate_elementwise_binary_with_scalar_and_type_promotion_samples(
269*da0073e9SAndroid Build Coastguard Worker            op, device=device, dtype=dtype
270*da0073e9SAndroid Build Coastguard Worker        )
271*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs)
274*da0073e9SAndroid Build Coastguard Worker    def test_contig_vs_every_other(self, device, dtype, op):
275*da0073e9SAndroid Build Coastguard Worker        lhs = make_tensor(
276*da0073e9SAndroid Build Coastguard Worker            (1026,), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs
277*da0073e9SAndroid Build Coastguard Worker        )
278*da0073e9SAndroid Build Coastguard Worker        rhs = make_tensor(
279*da0073e9SAndroid Build Coastguard Worker            (1026,), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs
280*da0073e9SAndroid Build Coastguard Worker        )
281*da0073e9SAndroid Build Coastguard Worker
282*da0073e9SAndroid Build Coastguard Worker        lhs_non_contig = lhs[::2]
283*da0073e9SAndroid Build Coastguard Worker        rhs_non_contig = rhs[::2]
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(lhs.is_contiguous())
286*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(rhs.is_contiguous())
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(lhs_non_contig.is_contiguous())
289*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(rhs_non_contig.is_contiguous())
290*da0073e9SAndroid Build Coastguard Worker
291*da0073e9SAndroid Build Coastguard Worker        expected = op(lhs, rhs)[::2]
292*da0073e9SAndroid Build Coastguard Worker        actual = op(lhs_non_contig, rhs_non_contig)
293*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs)
296*da0073e9SAndroid Build Coastguard Worker    def test_contig_vs_transposed(self, device, dtype, op):
297*da0073e9SAndroid Build Coastguard Worker        lhs = make_tensor(
298*da0073e9SAndroid Build Coastguard Worker            (789, 357), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs
299*da0073e9SAndroid Build Coastguard Worker        )
300*da0073e9SAndroid Build Coastguard Worker        rhs = make_tensor(
301*da0073e9SAndroid Build Coastguard Worker            (789, 357), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs
302*da0073e9SAndroid Build Coastguard Worker        )
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker        lhs_non_contig = lhs.T
305*da0073e9SAndroid Build Coastguard Worker        rhs_non_contig = rhs.T
306*da0073e9SAndroid Build Coastguard Worker
307*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(lhs.is_contiguous())
308*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(rhs.is_contiguous())
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(lhs_non_contig.is_contiguous())
311*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(rhs_non_contig.is_contiguous())
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker        expected = op(lhs, rhs).T
314*da0073e9SAndroid Build Coastguard Worker        actual = op(lhs_non_contig, rhs_non_contig)
315*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs)
318*da0073e9SAndroid Build Coastguard Worker    def test_non_contig(self, device, dtype, op):
319*da0073e9SAndroid Build Coastguard Worker        shapes = ((5, 7), (1024,))
320*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
321*da0073e9SAndroid Build Coastguard Worker            lhs = make_tensor(
322*da0073e9SAndroid Build Coastguard Worker                shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
323*da0073e9SAndroid Build Coastguard Worker            )
324*da0073e9SAndroid Build Coastguard Worker            rhs = make_tensor(
325*da0073e9SAndroid Build Coastguard Worker                shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
326*da0073e9SAndroid Build Coastguard Worker            )
327*da0073e9SAndroid Build Coastguard Worker
328*da0073e9SAndroid Build Coastguard Worker            lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[
329*da0073e9SAndroid Build Coastguard Worker                ..., 0
330*da0073e9SAndroid Build Coastguard Worker            ]
331*da0073e9SAndroid Build Coastguard Worker            lhs_non_contig.copy_(lhs)
332*da0073e9SAndroid Build Coastguard Worker
333*da0073e9SAndroid Build Coastguard Worker            rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[
334*da0073e9SAndroid Build Coastguard Worker                ..., 0
335*da0073e9SAndroid Build Coastguard Worker            ]
336*da0073e9SAndroid Build Coastguard Worker            rhs_non_contig.copy_(rhs)
337*da0073e9SAndroid Build Coastguard Worker
338*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(lhs.is_contiguous())
339*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(rhs.is_contiguous())
340*da0073e9SAndroid Build Coastguard Worker
341*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(lhs_non_contig.is_contiguous())
342*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(rhs_non_contig.is_contiguous())
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker            expected = op(lhs, rhs)
345*da0073e9SAndroid Build Coastguard Worker            actual = op(lhs_non_contig, rhs_non_contig)
346*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
347*da0073e9SAndroid Build Coastguard Worker
348*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs)
349*da0073e9SAndroid Build Coastguard Worker    def test_non_contig_index(self, device, dtype, op):
350*da0073e9SAndroid Build Coastguard Worker        shape = (2, 2, 1, 2)
351*da0073e9SAndroid Build Coastguard Worker        lhs = make_tensor(
352*da0073e9SAndroid Build Coastguard Worker            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
353*da0073e9SAndroid Build Coastguard Worker        )
354*da0073e9SAndroid Build Coastguard Worker        rhs = make_tensor(
355*da0073e9SAndroid Build Coastguard Worker            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
356*da0073e9SAndroid Build Coastguard Worker        )
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker        lhs_non_contig = lhs[:, 1, ...]
359*da0073e9SAndroid Build Coastguard Worker        lhs = lhs_non_contig.contiguous()
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker        rhs_non_contig = rhs[:, 1, ...]
362*da0073e9SAndroid Build Coastguard Worker        rhs = rhs_non_contig.contiguous()
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(lhs.is_contiguous())
365*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(rhs.is_contiguous())
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(lhs_non_contig.is_contiguous())
368*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(rhs_non_contig.is_contiguous())
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker        expected = op(lhs, rhs)
371*da0073e9SAndroid Build Coastguard Worker        actual = op(lhs_non_contig, rhs_non_contig)
372*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs)
375*da0073e9SAndroid Build Coastguard Worker    def test_non_contig_expand(self, device, dtype, op):
376*da0073e9SAndroid Build Coastguard Worker        shapes = [(1, 3), (1, 7), (5, 7)]
377*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
378*da0073e9SAndroid Build Coastguard Worker            lhs = make_tensor(
379*da0073e9SAndroid Build Coastguard Worker                shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
380*da0073e9SAndroid Build Coastguard Worker            )
381*da0073e9SAndroid Build Coastguard Worker            rhs = make_tensor(
382*da0073e9SAndroid Build Coastguard Worker                shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
383*da0073e9SAndroid Build Coastguard Worker            )
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker            lhs_non_contig = lhs.clone().expand(3, -1, -1)
386*da0073e9SAndroid Build Coastguard Worker            rhs_non_contig = rhs.clone().expand(3, -1, -1)
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(lhs.is_contiguous())
389*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(rhs.is_contiguous())
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(lhs_non_contig.is_contiguous())
392*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(rhs_non_contig.is_contiguous())
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker            expected = op(lhs, rhs)
395*da0073e9SAndroid Build Coastguard Worker            actual = op(lhs_non_contig, rhs_non_contig)
396*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
397*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected, actual[i])
398*da0073e9SAndroid Build Coastguard Worker
399*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs)
400*da0073e9SAndroid Build Coastguard Worker    def test_contig_size1(self, device, dtype, op):
401*da0073e9SAndroid Build Coastguard Worker        shape = (5, 100)
402*da0073e9SAndroid Build Coastguard Worker        lhs = make_tensor(
403*da0073e9SAndroid Build Coastguard Worker            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
404*da0073e9SAndroid Build Coastguard Worker        )
405*da0073e9SAndroid Build Coastguard Worker        rhs = make_tensor(
406*da0073e9SAndroid Build Coastguard Worker            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
407*da0073e9SAndroid Build Coastguard Worker        )
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker        lhs = lhs[:1, :50]
410*da0073e9SAndroid Build Coastguard Worker        lhs_alt = torch.empty(lhs.size(), device=device, dtype=dtype)
411*da0073e9SAndroid Build Coastguard Worker        lhs_alt.copy_(lhs)
412*da0073e9SAndroid Build Coastguard Worker
413*da0073e9SAndroid Build Coastguard Worker        rhs = rhs[:1, :50]
414*da0073e9SAndroid Build Coastguard Worker        rhs_alt = torch.empty(rhs.size(), device=device, dtype=dtype)
415*da0073e9SAndroid Build Coastguard Worker        rhs_alt.copy_(rhs)
416*da0073e9SAndroid Build Coastguard Worker
417*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(lhs.is_contiguous())
418*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(rhs.is_contiguous())
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(lhs_alt.is_contiguous())
421*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(rhs_alt.is_contiguous())
422*da0073e9SAndroid Build Coastguard Worker
423*da0073e9SAndroid Build Coastguard Worker        expected = op(lhs, rhs)
424*da0073e9SAndroid Build Coastguard Worker        actual = op(lhs_alt, rhs_alt)
425*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
426*da0073e9SAndroid Build Coastguard Worker
427*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs)
428*da0073e9SAndroid Build Coastguard Worker    def test_contig_size1_large_dim(self, device, dtype, op):
429*da0073e9SAndroid Build Coastguard Worker        shape = (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4)
430*da0073e9SAndroid Build Coastguard Worker        lhs = make_tensor(
431*da0073e9SAndroid Build Coastguard Worker            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
432*da0073e9SAndroid Build Coastguard Worker        )
433*da0073e9SAndroid Build Coastguard Worker        rhs = make_tensor(
434*da0073e9SAndroid Build Coastguard Worker            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
435*da0073e9SAndroid Build Coastguard Worker        )
436*da0073e9SAndroid Build Coastguard Worker
437*da0073e9SAndroid Build Coastguard Worker        lhs = lhs[:1, :, :, :, :, :, :, :, :, :, :, :]
438*da0073e9SAndroid Build Coastguard Worker        lhs_alt = torch.empty(lhs.size(), device=device, dtype=dtype)
439*da0073e9SAndroid Build Coastguard Worker        lhs_alt.copy_(lhs)
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker        rhs = rhs[:1, :, :, :, :, :, :, :, :, :, :, :]
442*da0073e9SAndroid Build Coastguard Worker        rhs_alt = torch.empty(rhs.size(), device=device, dtype=dtype)
443*da0073e9SAndroid Build Coastguard Worker        rhs_alt.copy_(rhs)
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(lhs.is_contiguous())
446*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(rhs.is_contiguous())
447*da0073e9SAndroid Build Coastguard Worker
448*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(lhs_alt.is_contiguous())
449*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(rhs_alt.is_contiguous())
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker        expected = op(lhs, rhs)
452*da0073e9SAndroid Build Coastguard Worker        actual = op(lhs_alt, rhs_alt)
453*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs)
456*da0073e9SAndroid Build Coastguard Worker    def test_batch_vs_slicing(self, device, dtype, op):
457*da0073e9SAndroid Build Coastguard Worker        shape = (32, 512)
458*da0073e9SAndroid Build Coastguard Worker        lhs = make_tensor(
459*da0073e9SAndroid Build Coastguard Worker            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
460*da0073e9SAndroid Build Coastguard Worker        )
461*da0073e9SAndroid Build Coastguard Worker        rhs = make_tensor(
462*da0073e9SAndroid Build Coastguard Worker            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
463*da0073e9SAndroid Build Coastguard Worker        )
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker        expected = op(lhs, rhs)
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Worker        actual = []
468*da0073e9SAndroid Build Coastguard Worker        for idx in range(32):
469*da0073e9SAndroid Build Coastguard Worker            actual.append(op(lhs[idx], rhs[idx]))
470*da0073e9SAndroid Build Coastguard Worker        actual = torch.stack(actual)
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker    # Tests that elementwise binary operators participate in type promotion properly
475*da0073e9SAndroid Build Coastguard Worker    # NOTE: because the cross-product of all possible type promotion tests is huge, this
476*da0073e9SAndroid Build Coastguard Worker    #   just spot checks some handwritten cases.
477*da0073e9SAndroid Build Coastguard Worker    # NOTE: It may be possible to refactor this test into something simpler
478*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs_and_refs, dtypes=OpDTypes.none)
479*da0073e9SAndroid Build Coastguard Worker    def test_type_promotion(self, device, op):
480*da0073e9SAndroid Build Coastguard Worker        supported_dtypes = op.supported_dtypes(torch.device(device).type)
481*da0073e9SAndroid Build Coastguard Worker
482*da0073e9SAndroid Build Coastguard Worker        make_lhs = partial(
483*da0073e9SAndroid Build Coastguard Worker            make_tensor, (5,), device=device, **op.lhs_make_tensor_kwargs
484*da0073e9SAndroid Build Coastguard Worker        )
485*da0073e9SAndroid Build Coastguard Worker        make_rhs = partial(
486*da0073e9SAndroid Build Coastguard Worker            make_tensor, (5,), device=device, **op.rhs_make_tensor_kwargs
487*da0073e9SAndroid Build Coastguard Worker        )
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker        make_rhs_scalar_tensor = partial(
490*da0073e9SAndroid Build Coastguard Worker            make_tensor, (), device="cpu", **op.rhs_make_tensor_kwargs
491*da0073e9SAndroid Build Coastguard Worker        )
492*da0073e9SAndroid Build Coastguard Worker
493*da0073e9SAndroid Build Coastguard Worker        def _supported(dtypes):
494*da0073e9SAndroid Build Coastguard Worker            return all(x in supported_dtypes for x in dtypes)
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker        # int x int type promotion
497*da0073e9SAndroid Build Coastguard Worker        if _supported((torch.int16, torch.int32, torch.int64)):
498*da0073e9SAndroid Build Coastguard Worker            lhs_i16 = make_lhs(dtype=torch.int16)
499*da0073e9SAndroid Build Coastguard Worker            lhs_i32 = make_lhs(dtype=torch.int32)
500*da0073e9SAndroid Build Coastguard Worker            lhs_i64 = make_lhs(dtype=torch.int64)
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker            rhs_i16 = make_rhs(dtype=torch.int16)
503*da0073e9SAndroid Build Coastguard Worker            rhs_i32 = make_rhs(dtype=torch.int32)
504*da0073e9SAndroid Build Coastguard Worker            rhs_i64 = make_rhs(dtype=torch.int64)
505*da0073e9SAndroid Build Coastguard Worker
506*da0073e9SAndroid Build Coastguard Worker            if op.promotes_int_to_float:
507*da0073e9SAndroid Build Coastguard Worker                default_dtype = torch.get_default_dtype()
508*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i16, rhs_i32).dtype, default_dtype)
509*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
510*da0073e9SAndroid Build Coastguard Worker                    op(lhs_i16, rhs_i32),
511*da0073e9SAndroid Build Coastguard Worker                    op(lhs_i16.to(default_dtype), rhs_i32.to(default_dtype)),
512*da0073e9SAndroid Build Coastguard Worker                )
513*da0073e9SAndroid Build Coastguard Worker
514*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i32, rhs_i64).dtype, default_dtype)
515*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
516*da0073e9SAndroid Build Coastguard Worker                    op(lhs_i32, rhs_i64),
517*da0073e9SAndroid Build Coastguard Worker                    op(lhs_i32.to(default_dtype), rhs_i64.to(default_dtype)),
518*da0073e9SAndroid Build Coastguard Worker                )
519*da0073e9SAndroid Build Coastguard Worker            elif op.always_returns_bool:
520*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.bool)
521*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.bool)
522*da0073e9SAndroid Build Coastguard Worker            else:  # standard type promotion
523*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.int32)
524*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
525*da0073e9SAndroid Build Coastguard Worker                    op(lhs_i16, rhs_i32), op(lhs_i16.to(torch.int32), rhs_i32)
526*da0073e9SAndroid Build Coastguard Worker                )
527*da0073e9SAndroid Build Coastguard Worker
528*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.int64)
529*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
530*da0073e9SAndroid Build Coastguard Worker                    op(lhs_i32, rhs_i64), op(lhs_i32.to(torch.int64), rhs_i64)
531*da0073e9SAndroid Build Coastguard Worker                )
532*da0073e9SAndroid Build Coastguard Worker
533*da0073e9SAndroid Build Coastguard Worker            if op.supports_out:
534*da0073e9SAndroid Build Coastguard Worker                if not op.promotes_int_to_float:
535*da0073e9SAndroid Build Coastguard Worker                    # Integers can be safely cast to other integer types
536*da0073e9SAndroid Build Coastguard Worker                    out = torch.empty_like(lhs_i64)
537*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.int64)
538*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker                    out = torch.empty_like(lhs_i16)
541*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(op(lhs_i32, rhs_i64, out=out).dtype, torch.int16)
542*da0073e9SAndroid Build Coastguard Worker                else:
543*da0073e9SAndroid Build Coastguard Worker                    # Float outs cannot be safely cast to integer types
544*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
545*da0073e9SAndroid Build Coastguard Worker                        op(lhs_i16, rhs_i32, out=torch.empty_like(lhs_i64))
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker                if not op.always_returns_bool:
548*da0073e9SAndroid Build Coastguard Worker                    # Neither integer nor float outs can be cast to bool
549*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
550*da0073e9SAndroid Build Coastguard Worker                        op(
551*da0073e9SAndroid Build Coastguard Worker                            lhs_i16,
552*da0073e9SAndroid Build Coastguard Worker                            rhs_i32,
553*da0073e9SAndroid Build Coastguard Worker                            out=torch.empty_like(lhs_i64, dtype=torch.bool),
554*da0073e9SAndroid Build Coastguard Worker                        )
555*da0073e9SAndroid Build Coastguard Worker
556*da0073e9SAndroid Build Coastguard Worker                # All these output types can be cast to any float or complex type
557*da0073e9SAndroid Build Coastguard Worker                out = torch.empty_like(lhs_i64, dtype=torch.float16)
558*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.float16)
559*da0073e9SAndroid Build Coastguard Worker
560*da0073e9SAndroid Build Coastguard Worker                out = torch.empty_like(lhs_i64, dtype=torch.bfloat16)
561*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.bfloat16)
562*da0073e9SAndroid Build Coastguard Worker
563*da0073e9SAndroid Build Coastguard Worker                out = torch.empty_like(lhs_i64, dtype=torch.float32)
564*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.float32)
565*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker                out = torch.empty_like(lhs_i64, dtype=torch.complex64)
568*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.complex64)
569*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard Worker        # float x float type promotion
572*da0073e9SAndroid Build Coastguard Worker        if _supported((torch.float32, torch.float64)):
573*da0073e9SAndroid Build Coastguard Worker            lhs_f32 = make_lhs(dtype=torch.float32)
574*da0073e9SAndroid Build Coastguard Worker            lhs_f64 = make_lhs(dtype=torch.float64)
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker            rhs_f32 = make_rhs(dtype=torch.float32)
577*da0073e9SAndroid Build Coastguard Worker            rhs_f64 = make_rhs(dtype=torch.float64)
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker            if op.always_returns_bool:
580*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.bool)
581*da0073e9SAndroid Build Coastguard Worker            else:  # normal float type promotion
582*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.float64)
583*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
584*da0073e9SAndroid Build Coastguard Worker                    op(lhs_f32, rhs_f64), op(lhs_f32.to(torch.float64), rhs_f64)
585*da0073e9SAndroid Build Coastguard Worker                )
586*da0073e9SAndroid Build Coastguard Worker
587*da0073e9SAndroid Build Coastguard Worker            if op.supports_out:
588*da0073e9SAndroid Build Coastguard Worker                # All these output types can be cast to any float or complex type
589*da0073e9SAndroid Build Coastguard Worker                out = torch.empty_like(lhs_f64, dtype=torch.float16)
590*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.float16)
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker                out = torch.empty_like(lhs_f64, dtype=torch.bfloat16)
593*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.bfloat16)
594*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker                out = torch.empty_like(lhs_f64, dtype=torch.float32)
597*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.float32)
598*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)
599*da0073e9SAndroid Build Coastguard Worker
600*da0073e9SAndroid Build Coastguard Worker                out = torch.empty_like(lhs_f64, dtype=torch.complex64)
601*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.complex64)
602*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker                if not op.always_returns_bool:
605*da0073e9SAndroid Build Coastguard Worker                    # float outs can't be cast to an integer dtype
606*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
607*da0073e9SAndroid Build Coastguard Worker                        op(
608*da0073e9SAndroid Build Coastguard Worker                            lhs_f32,
609*da0073e9SAndroid Build Coastguard Worker                            rhs_f64,
610*da0073e9SAndroid Build Coastguard Worker                            out=torch.empty_like(lhs_f64, dtype=torch.int64),
611*da0073e9SAndroid Build Coastguard Worker                        )
612*da0073e9SAndroid Build Coastguard Worker                else:
613*da0073e9SAndroid Build Coastguard Worker                    # bool outs can be cast to an integer dtype
614*da0073e9SAndroid Build Coastguard Worker                    out = torch.empty_like(lhs_f64, dtype=torch.int64)
615*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.int64)
616*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)
617*da0073e9SAndroid Build Coastguard Worker
618*da0073e9SAndroid Build Coastguard Worker        # complex x complex type promotion
619*da0073e9SAndroid Build Coastguard Worker        if _supported((torch.complex64, torch.complex128)):
620*da0073e9SAndroid Build Coastguard Worker            lhs_c64 = make_lhs(dtype=torch.complex64)
621*da0073e9SAndroid Build Coastguard Worker            lhs_c128 = make_lhs(dtype=torch.complex128)
622*da0073e9SAndroid Build Coastguard Worker
623*da0073e9SAndroid Build Coastguard Worker            rhs_c64 = make_rhs(dtype=torch.complex64)
624*da0073e9SAndroid Build Coastguard Worker            rhs_c128 = make_rhs(dtype=torch.complex128)
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker            if op.always_returns_bool:
627*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_c64, lhs_c128).dtype, torch.bool)
628*da0073e9SAndroid Build Coastguard Worker            else:  # normal complex type promotion
629*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_c64, rhs_c128).dtype, torch.complex128)
630*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
631*da0073e9SAndroid Build Coastguard Worker                    op(lhs_c64, rhs_c128), op(lhs_c64.to(torch.complex128), rhs_c128)
632*da0073e9SAndroid Build Coastguard Worker                )
633*da0073e9SAndroid Build Coastguard Worker
634*da0073e9SAndroid Build Coastguard Worker            if op.supports_out:
635*da0073e9SAndroid Build Coastguard Worker                # All these output types can be cast to any or complex type
636*da0073e9SAndroid Build Coastguard Worker                out = torch.empty_like(lhs_c64, dtype=torch.complex64)
637*da0073e9SAndroid Build Coastguard Worker
638*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(lhs_c64, rhs_c128, out=out).dtype, torch.complex64)
639*da0073e9SAndroid Build Coastguard Worker                result = op(lhs_c64, rhs_c128)
640*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, out.to(result.dtype))
641*da0073e9SAndroid Build Coastguard Worker
642*da0073e9SAndroid Build Coastguard Worker                if not op.always_returns_bool:
643*da0073e9SAndroid Build Coastguard Worker                    # complex outs can't be cast to float types
644*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
645*da0073e9SAndroid Build Coastguard Worker                        op(
646*da0073e9SAndroid Build Coastguard Worker                            lhs_c64,
647*da0073e9SAndroid Build Coastguard Worker                            rhs_c128,
648*da0073e9SAndroid Build Coastguard Worker                            out=torch.empty_like(lhs_c64, dtype=torch.float64),
649*da0073e9SAndroid Build Coastguard Worker                        )
650*da0073e9SAndroid Build Coastguard Worker                    # complex outs can't be cast to an integer dtype
651*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
652*da0073e9SAndroid Build Coastguard Worker                        op(
653*da0073e9SAndroid Build Coastguard Worker                            lhs_c64,
654*da0073e9SAndroid Build Coastguard Worker                            rhs_c128,
655*da0073e9SAndroid Build Coastguard Worker                            out=torch.empty_like(lhs_c64, dtype=torch.int64),
656*da0073e9SAndroid Build Coastguard Worker                        )
657*da0073e9SAndroid Build Coastguard Worker                else:
658*da0073e9SAndroid Build Coastguard Worker                    # bool outs can be cast to a float type
659*da0073e9SAndroid Build Coastguard Worker                    out = torch.empty_like(lhs_c64, dtype=torch.float64)
660*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
661*da0073e9SAndroid Build Coastguard Worker                        op(lhs_c64, rhs_c128, out=out).dtype, torch.float64
662*da0073e9SAndroid Build Coastguard Worker                    )
663*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(op(lhs_c64, rhs_c128), out, exact_dtype=False)
664*da0073e9SAndroid Build Coastguard Worker
665*da0073e9SAndroid Build Coastguard Worker                    # bool outs can be cast to an integer dtype
666*da0073e9SAndroid Build Coastguard Worker                    out = torch.empty_like(lhs_f64, dtype=torch.int64)
667*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.int64)
668*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)
669*da0073e9SAndroid Build Coastguard Worker
670*da0073e9SAndroid Build Coastguard Worker        # int x float type promotion
671*da0073e9SAndroid Build Coastguard Worker        # Note: float type is the result dtype
672*da0073e9SAndroid Build Coastguard Worker        if _supported((torch.long, torch.float32)):
673*da0073e9SAndroid Build Coastguard Worker            lhs_i64 = make_lhs(dtype=torch.int64)
674*da0073e9SAndroid Build Coastguard Worker            rhs_f32 = make_rhs(dtype=torch.float32)
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker            result = op(lhs_i64, rhs_f32)
677*da0073e9SAndroid Build Coastguard Worker            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
678*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, expected_dtype)
679*da0073e9SAndroid Build Coastguard Worker
680*da0073e9SAndroid Build Coastguard Worker        # float x complex type promotion
681*da0073e9SAndroid Build Coastguard Worker        # Note: complex type with highest "value type" is the result dtype
682*da0073e9SAndroid Build Coastguard Worker        if _supported((torch.float64, torch.complex64)):
683*da0073e9SAndroid Build Coastguard Worker            lhs_f64 = make_lhs(dtype=torch.float64)
684*da0073e9SAndroid Build Coastguard Worker            rhs_c64 = make_rhs(dtype=torch.complex64)
685*da0073e9SAndroid Build Coastguard Worker
686*da0073e9SAndroid Build Coastguard Worker            result = op(lhs_f64, rhs_c64)
687*da0073e9SAndroid Build Coastguard Worker            expected_dtype = (
688*da0073e9SAndroid Build Coastguard Worker                torch.complex128 if not op.always_returns_bool else torch.bool
689*da0073e9SAndroid Build Coastguard Worker            )
690*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, expected_dtype)
691*da0073e9SAndroid Build Coastguard Worker
692*da0073e9SAndroid Build Coastguard Worker        # int x float scalar type promotion
693*da0073e9SAndroid Build Coastguard Worker        # Note: default float dtype is the result dtype
694*da0073e9SAndroid Build Coastguard Worker        if _supported((torch.int64, torch.float32)) and op.supports_rhs_python_scalar:
695*da0073e9SAndroid Build Coastguard Worker            lhs_i64 = make_lhs(dtype=torch.int64)
696*da0073e9SAndroid Build Coastguard Worker            rhs_f_scalar = 1.0
697*da0073e9SAndroid Build Coastguard Worker
698*da0073e9SAndroid Build Coastguard Worker            result = op(lhs_i64, rhs_f_scalar)
699*da0073e9SAndroid Build Coastguard Worker            expected_dtype = (
700*da0073e9SAndroid Build Coastguard Worker                torch.get_default_dtype() if not op.always_returns_bool else torch.bool
701*da0073e9SAndroid Build Coastguard Worker            )
702*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, expected_dtype)
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker            # repeats with a scalar float tensor, which should set the dtype
705*da0073e9SAndroid Build Coastguard Worker            rhs_f32_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float32)
706*da0073e9SAndroid Build Coastguard Worker            result = op(lhs_i64, rhs_f32_scalar_tensor)
707*da0073e9SAndroid Build Coastguard Worker            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
708*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, expected_dtype)
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker            # Additional test with double
711*da0073e9SAndroid Build Coastguard Worker            if _supported((torch.float64,)):
712*da0073e9SAndroid Build Coastguard Worker                rhs_f64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float64)
713*da0073e9SAndroid Build Coastguard Worker                result = op(lhs_i64, rhs_f64_scalar_tensor)
714*da0073e9SAndroid Build Coastguard Worker                expected_dtype = (
715*da0073e9SAndroid Build Coastguard Worker                    torch.float64 if not op.always_returns_bool else torch.bool
716*da0073e9SAndroid Build Coastguard Worker                )
717*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result.dtype, expected_dtype)
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker        # float x complex scalar type promotion
720*da0073e9SAndroid Build Coastguard Worker        # Note: result dtype is complex with highest "value type" among all tensors
721*da0073e9SAndroid Build Coastguard Worker        if (
722*da0073e9SAndroid Build Coastguard Worker            _supported((torch.float32, torch.complex64))
723*da0073e9SAndroid Build Coastguard Worker            and op.supports_rhs_python_scalar
724*da0073e9SAndroid Build Coastguard Worker        ):
725*da0073e9SAndroid Build Coastguard Worker            lhs_f32 = make_lhs(dtype=torch.float32)
726*da0073e9SAndroid Build Coastguard Worker            rhs_c_scalar = complex(1, 1)
727*da0073e9SAndroid Build Coastguard Worker
728*da0073e9SAndroid Build Coastguard Worker            result = op(lhs_f32, rhs_c_scalar)
729*da0073e9SAndroid Build Coastguard Worker            expected_dtype = (
730*da0073e9SAndroid Build Coastguard Worker                torch.complex64 if not op.always_returns_bool else torch.bool
731*da0073e9SAndroid Build Coastguard Worker            )
732*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, expected_dtype)
733*da0073e9SAndroid Build Coastguard Worker
734*da0073e9SAndroid Build Coastguard Worker            # repeats with a scalar complex tensor
735*da0073e9SAndroid Build Coastguard Worker            rhs_c64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex64)
736*da0073e9SAndroid Build Coastguard Worker            result = op(lhs_f32, rhs_c64_scalar_tensor)
737*da0073e9SAndroid Build Coastguard Worker            expected_dtype = (
738*da0073e9SAndroid Build Coastguard Worker                torch.complex64 if not op.always_returns_bool else torch.bool
739*da0073e9SAndroid Build Coastguard Worker            )
740*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, expected_dtype)
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker            # Additional test with complexdouble
743*da0073e9SAndroid Build Coastguard Worker            if _supported((torch.complex128,)):
744*da0073e9SAndroid Build Coastguard Worker                rhs_c128_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex128)
745*da0073e9SAndroid Build Coastguard Worker                result = op(lhs_f32, rhs_c128_scalar_tensor)
746*da0073e9SAndroid Build Coastguard Worker                # Value type of 1D+ Tensor (lhs_f32) takes priority over scalar tensor (rhs_c128).
747*da0073e9SAndroid Build Coastguard Worker                expected_dtype = (
748*da0073e9SAndroid Build Coastguard Worker                    torch.complex64 if not op.always_returns_bool else torch.bool
749*da0073e9SAndroid Build Coastguard Worker                )
750*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result.dtype, expected_dtype)
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker        # float x float scalar tensor
753*da0073e9SAndroid Build Coastguard Worker        # Note: result dtype is the type of the float tensor
754*da0073e9SAndroid Build Coastguard Worker        if _supported((torch.float32, torch.float64)) and op.supports_rhs_python_scalar:
755*da0073e9SAndroid Build Coastguard Worker            lhs_f32 = make_lhs(dtype=torch.float32)
756*da0073e9SAndroid Build Coastguard Worker            rhs_f64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float64)
757*da0073e9SAndroid Build Coastguard Worker
758*da0073e9SAndroid Build Coastguard Worker            result = op(lhs_f32, rhs_f64_scalar_tensor)
759*da0073e9SAndroid Build Coastguard Worker            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
760*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, expected_dtype)
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker        # complex x complex scalar tensor
763*da0073e9SAndroid Build Coastguard Worker        # Note: result dtype is the type of the complex tensor
764*da0073e9SAndroid Build Coastguard Worker        if (
765*da0073e9SAndroid Build Coastguard Worker            _supported((torch.complex64, torch.complex128))
766*da0073e9SAndroid Build Coastguard Worker            and op.supports_rhs_python_scalar
767*da0073e9SAndroid Build Coastguard Worker        ):
768*da0073e9SAndroid Build Coastguard Worker            lhs_c64 = make_lhs(dtype=torch.complex64)
769*da0073e9SAndroid Build Coastguard Worker            rhs_c128_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex128)
770*da0073e9SAndroid Build Coastguard Worker
771*da0073e9SAndroid Build Coastguard Worker            result = op(lhs_c64, rhs_c128_scalar_tensor)
772*da0073e9SAndroid Build Coastguard Worker            expected_dtype = (
773*da0073e9SAndroid Build Coastguard Worker                torch.complex64 if not op.always_returns_bool else torch.bool
774*da0073e9SAndroid Build Coastguard Worker            )
775*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, expected_dtype)
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker        # scalar  x scalar
778*da0073e9SAndroid Build Coastguard Worker        # Note: result dtype is default float type
779*da0073e9SAndroid Build Coastguard Worker        if op.supports_two_python_scalars and _supported((torch.long, torch.float32)):
780*da0073e9SAndroid Build Coastguard Worker            rhs_f_scalar = 2.0
781*da0073e9SAndroid Build Coastguard Worker            for lhs in (1, 1.0):
782*da0073e9SAndroid Build Coastguard Worker                result = op(lhs, rhs_f_scalar)
783*da0073e9SAndroid Build Coastguard Worker                expected_dtype = (
784*da0073e9SAndroid Build Coastguard Worker                    torch.get_default_dtype()
785*da0073e9SAndroid Build Coastguard Worker                    if not op.always_returns_bool
786*da0073e9SAndroid Build Coastguard Worker                    else torch.bool
787*da0073e9SAndroid Build Coastguard Worker                )
788*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result.dtype, expected_dtype)
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker    # TODO: move to error input test
791*da0073e9SAndroid Build Coastguard Worker    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
792*da0073e9SAndroid Build Coastguard Worker    def test_not_broadcastable(self, device, dtype, op):
793*da0073e9SAndroid Build Coastguard Worker        for shape_lhs, shape_rhs in (
794*da0073e9SAndroid Build Coastguard Worker            ((2,), (3,)),
795*da0073e9SAndroid Build Coastguard Worker            ((3, 1), (2, 1)),
796*da0073e9SAndroid Build Coastguard Worker            ((1, 3, 2), (3,)),
797*da0073e9SAndroid Build Coastguard Worker            ((3, 1, 2), (2, 1, 2)),
798*da0073e9SAndroid Build Coastguard Worker        ):
799*da0073e9SAndroid Build Coastguard Worker            lhs = make_tensor(
800*da0073e9SAndroid Build Coastguard Worker                shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs
801*da0073e9SAndroid Build Coastguard Worker            )
802*da0073e9SAndroid Build Coastguard Worker            rhs = make_tensor(
803*da0073e9SAndroid Build Coastguard Worker                shape_rhs, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs
804*da0073e9SAndroid Build Coastguard Worker            )
805*da0073e9SAndroid Build Coastguard Worker
806*da0073e9SAndroid Build Coastguard Worker            try:
807*da0073e9SAndroid Build Coastguard Worker                broadcasted_shape = op(lhs, rhs).shape
808*da0073e9SAndroid Build Coastguard Worker            except RuntimeError:
809*da0073e9SAndroid Build Coastguard Worker                continue
810*da0073e9SAndroid Build Coastguard Worker
811*da0073e9SAndroid Build Coastguard Worker            msg = (
812*da0073e9SAndroid Build Coastguard Worker                f"On {device}, torch.{op.name} broadcasts inputs shapes {shape_lhs} and {shape_rhs} into "
813*da0073e9SAndroid Build Coastguard Worker                f"{broadcasted_shape}, although they are not broadcastable."
814*da0073e9SAndroid Build Coastguard Worker            )
815*da0073e9SAndroid Build Coastguard Worker            raise AssertionError(msg)
816*da0073e9SAndroid Build Coastguard Worker
817*da0073e9SAndroid Build Coastguard Worker    def test_add_broadcast_empty(self, device):
818*da0073e9SAndroid Build Coastguard Worker        # empty + empty
819*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
820*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
821*da0073e9SAndroid Build Coastguard Worker            lambda: torch.randn(5, 0, device=device) + torch.randn(0, 5, device=device),
822*da0073e9SAndroid Build Coastguard Worker        )
823*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
824*da0073e9SAndroid Build Coastguard Worker            torch.randn(5, 0, device=device),
825*da0073e9SAndroid Build Coastguard Worker            torch.randn(0, device=device) + torch.randn(5, 0, device=device),
826*da0073e9SAndroid Build Coastguard Worker        )
827*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
828*da0073e9SAndroid Build Coastguard Worker            torch.randn(5, 0, 0, device=device),
829*da0073e9SAndroid Build Coastguard Worker            torch.randn(0, device=device) + torch.randn(5, 0, 1, device=device),
830*da0073e9SAndroid Build Coastguard Worker        )
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker        # scalar + empty
833*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
834*da0073e9SAndroid Build Coastguard Worker            torch.randn(5, 0, 6, device=device),
835*da0073e9SAndroid Build Coastguard Worker            torch.randn((), device=device) + torch.randn(5, 0, 6, device=device),
836*da0073e9SAndroid Build Coastguard Worker        )
837*da0073e9SAndroid Build Coastguard Worker
838*da0073e9SAndroid Build Coastguard Worker        # non-empty, empty
839*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
840*da0073e9SAndroid Build Coastguard Worker            torch.randn(0, device=device),
841*da0073e9SAndroid Build Coastguard Worker            torch.randn(0, device=device) + torch.randn(1, device=device),
842*da0073e9SAndroid Build Coastguard Worker        )
843*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
844*da0073e9SAndroid Build Coastguard Worker            torch.randn(0, 7, 0, 6, 5, 0, 7, device=device),
845*da0073e9SAndroid Build Coastguard Worker            torch.randn(0, 7, 0, 6, 5, 0, 1, device=device)
846*da0073e9SAndroid Build Coastguard Worker            + torch.randn(1, 1, 5, 1, 7, device=device),
847*da0073e9SAndroid Build Coastguard Worker        )
848*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
849*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
850*da0073e9SAndroid Build Coastguard Worker            lambda: torch.randn(7, 0, device=device) + torch.randn(2, 1, device=device),
851*da0073e9SAndroid Build Coastguard Worker        )
852*da0073e9SAndroid Build Coastguard Worker
853*da0073e9SAndroid Build Coastguard Worker    def test_addcmul_scalars_as_floats(self, device):
854*da0073e9SAndroid Build Coastguard Worker        # zero-dim variables that don't require grad should bind to scalar arguments
855*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(2.0)
856*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(3.0, device=device)
857*da0073e9SAndroid Build Coastguard Worker        # 3 + (3 * 3) * 2
858*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.addcmul(y, y, value=x), 21)
859*da0073e9SAndroid Build Coastguard Worker
860*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(2.0, requires_grad=True)
861*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x))
862*da0073e9SAndroid Build Coastguard Worker
863*da0073e9SAndroid Build Coastguard Worker    # Tests that the binary operators and, or, and xor (as well as their reflected and inplace versions)
864*da0073e9SAndroid Build Coastguard Worker    # work properly (AKA &, ||, ^ and &=, |=, ^=)
865*da0073e9SAndroid Build Coastguard Worker    @dtypes(*integral_types_and(torch.bool))
866*da0073e9SAndroid Build Coastguard Worker    def test_bitwise_ops(self, device, dtype):
867*da0073e9SAndroid Build Coastguard Worker        # Tensor x Tensor and Tensor x Scalar ops
868*da0073e9SAndroid Build Coastguard Worker        ops = (
869*da0073e9SAndroid Build Coastguard Worker            operator.and_,
870*da0073e9SAndroid Build Coastguard Worker            operator.iand,
871*da0073e9SAndroid Build Coastguard Worker            operator.or_,
872*da0073e9SAndroid Build Coastguard Worker            operator.ior,
873*da0073e9SAndroid Build Coastguard Worker            operator.xor,
874*da0073e9SAndroid Build Coastguard Worker            operator.ixor,
875*da0073e9SAndroid Build Coastguard Worker        )
876*da0073e9SAndroid Build Coastguard Worker        inplace_ops = (operator.iand, operator.ior, operator.ixor)
877*da0073e9SAndroid Build Coastguard Worker        shapes = ((5,), (15, 15), (500, 500))
878*da0073e9SAndroid Build Coastguard Worker
879*da0073e9SAndroid Build Coastguard Worker        for op, shape in itertools.product(ops, shapes):
880*da0073e9SAndroid Build Coastguard Worker            # Tests tensor x tensor case
881*da0073e9SAndroid Build Coastguard Worker            a = make_tensor(shape, device=device, dtype=dtype)
882*da0073e9SAndroid Build Coastguard Worker            b = make_tensor(shape, device=device, dtype=dtype)
883*da0073e9SAndroid Build Coastguard Worker            a_np = a.cpu().clone().numpy()
884*da0073e9SAndroid Build Coastguard Worker            b_np = b.cpu().clone().numpy()
885*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(op(a, b), op(a_np, b_np))
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker            # Tests tensor x scalar case
888*da0073e9SAndroid Build Coastguard Worker            a = make_tensor(shape, device=device, dtype=dtype)
889*da0073e9SAndroid Build Coastguard Worker            b_scalar = make_tensor((), device="cpu", dtype=dtype).item()
890*da0073e9SAndroid Build Coastguard Worker            a_np = a.cpu().clone().numpy()
891*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(op(a, b_scalar), op(a_np, b_scalar))
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker            # Tests scalar x tensor case
894*da0073e9SAndroid Build Coastguard Worker            a_scalar = make_tensor((), device="cpu", dtype=dtype).item()
895*da0073e9SAndroid Build Coastguard Worker            b = make_tensor(shape, device=device, dtype=dtype)
896*da0073e9SAndroid Build Coastguard Worker            b_np = b.cpu().clone().numpy()
897*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(op(a_scalar, b), op(a_scalar, b_np))
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Worker            # Tests scalar x tensor case (for ops which aren't inplace)
900*da0073e9SAndroid Build Coastguard Worker            if op in inplace_ops:
901*da0073e9SAndroid Build Coastguard Worker                # Tests tensor x tensor case
902*da0073e9SAndroid Build Coastguard Worker                a = make_tensor(shape, device=device, dtype=dtype)
903*da0073e9SAndroid Build Coastguard Worker                b = make_tensor(shape, device=device, dtype=dtype)
904*da0073e9SAndroid Build Coastguard Worker                a_np = a.cpu().clone().numpy()
905*da0073e9SAndroid Build Coastguard Worker                b_np = b.cpu().clone().numpy()
906*da0073e9SAndroid Build Coastguard Worker                op(a, b)
907*da0073e9SAndroid Build Coastguard Worker                op(a_np, b_np)
908*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, a_np)
909*da0073e9SAndroid Build Coastguard Worker
910*da0073e9SAndroid Build Coastguard Worker                # Tests tensor x scalar case
911*da0073e9SAndroid Build Coastguard Worker                a = make_tensor(shape, device=device, dtype=dtype)
912*da0073e9SAndroid Build Coastguard Worker                b_scalar = make_tensor((), device="cpu", dtype=dtype).item()
913*da0073e9SAndroid Build Coastguard Worker                a_np = a.cpu().clone().numpy()
914*da0073e9SAndroid Build Coastguard Worker                op(a, b_scalar)
915*da0073e9SAndroid Build Coastguard Worker                op(a_np, b_scalar)
916*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, a_np)
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker    def test_inplace_division(self, device):
919*da0073e9SAndroid Build Coastguard Worker        t = torch.rand(5, 5, device=device)
920*da0073e9SAndroid Build Coastguard Worker        id_before = id(t)
921*da0073e9SAndroid Build Coastguard Worker        t /= 2
922*da0073e9SAndroid Build Coastguard Worker        id_after = id(t)
923*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(id_before, id_after)
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half, torch.bfloat16))
926*da0073e9SAndroid Build Coastguard Worker    def test_div_rounding_modes(self, device, dtype):
927*da0073e9SAndroid Build Coastguard Worker        if dtype.is_floating_point:
928*da0073e9SAndroid Build Coastguard Worker            low, high = -10.0, 10.0
929*da0073e9SAndroid Build Coastguard Worker        else:
930*da0073e9SAndroid Build Coastguard Worker            info = torch.iinfo(dtype)
931*da0073e9SAndroid Build Coastguard Worker            low, high = info.min, info.max
932*da0073e9SAndroid Build Coastguard Worker
933*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((100,), dtype=dtype, device=device, low=low, high=high)
934*da0073e9SAndroid Build Coastguard Worker        b = make_tensor((100,), dtype=dtype, device=device, low=low, high=high)
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker        # Avoid division by zero so we can test (a / b) * b == a
937*da0073e9SAndroid Build Coastguard Worker        if dtype.is_floating_point:
938*da0073e9SAndroid Build Coastguard Worker            eps = 0.1
939*da0073e9SAndroid Build Coastguard Worker            b[(-eps < b) & (b < eps)] = eps
940*da0073e9SAndroid Build Coastguard Worker        else:
941*da0073e9SAndroid Build Coastguard Worker            b[b == 0] = 1
942*da0073e9SAndroid Build Coastguard Worker
943*da0073e9SAndroid Build Coastguard Worker        if not dtype.is_floating_point:
944*da0073e9SAndroid Build Coastguard Worker            # floor(a / b) * b can be < a, so fixup slightly to avoid underflow
945*da0073e9SAndroid Build Coastguard Worker            a = torch.where(a < 0, a + b, a)
946*da0073e9SAndroid Build Coastguard Worker
947*da0073e9SAndroid Build Coastguard Worker        d_true = torch.divide(a, b, rounding_mode=None)
948*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(d_true.is_floating_point())
949*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d_true * b, a.to(d_true.dtype))
950*da0073e9SAndroid Build Coastguard Worker
951*da0073e9SAndroid Build Coastguard Worker        d_floor = torch.divide(a, b, rounding_mode="floor")
952*da0073e9SAndroid Build Coastguard Worker        if dtype not in (torch.bfloat16, torch.half):
953*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(d_floor * b + torch.remainder(a, b), a)
954*da0073e9SAndroid Build Coastguard Worker        else:
955*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
956*da0073e9SAndroid Build Coastguard Worker                d_floor * b + torch.remainder(a.float(), b.float()),
957*da0073e9SAndroid Build Coastguard Worker                a,
958*da0073e9SAndroid Build Coastguard Worker                exact_dtype=False,
959*da0073e9SAndroid Build Coastguard Worker            )
960*da0073e9SAndroid Build Coastguard Worker
961*da0073e9SAndroid Build Coastguard Worker        d_trunc = torch.divide(a, b, rounding_mode="trunc")
962*da0073e9SAndroid Build Coastguard Worker        rounding_unsupported = (
963*da0073e9SAndroid Build Coastguard Worker            dtype == torch.half
964*da0073e9SAndroid Build Coastguard Worker            and device != "cuda"
965*da0073e9SAndroid Build Coastguard Worker            or dtype == torch.bfloat16
966*da0073e9SAndroid Build Coastguard Worker            and device != "cpu"
967*da0073e9SAndroid Build Coastguard Worker        )
968*da0073e9SAndroid Build Coastguard Worker        d_ref = d_true.float() if rounding_unsupported else d_true
969*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d_trunc, d_ref.trunc().to(dtype))
970*da0073e9SAndroid Build Coastguard Worker
971*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.bfloat16, torch.float16))
972*da0073e9SAndroid Build Coastguard Worker    def test_floor_div_extremal(self, device, dtype):
973*da0073e9SAndroid Build Coastguard Worker        for num, denom, shape in itertools.product(
974*da0073e9SAndroid Build Coastguard Worker            [torch.finfo(dtype).max * 0.7],
975*da0073e9SAndroid Build Coastguard Worker            [0.5, -0.5, 0.0],
976*da0073e9SAndroid Build Coastguard Worker            [(), (32,)],  # Scalar and vectorized
977*da0073e9SAndroid Build Coastguard Worker        ):
978*da0073e9SAndroid Build Coastguard Worker            a = torch.full(shape, num, dtype=dtype, device=device)
979*da0073e9SAndroid Build Coastguard Worker            b = torch.full(shape, denom, dtype=dtype, device=device)
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker            ref = np.floor_divide(num, denom).item()
982*da0073e9SAndroid Build Coastguard Worker            if ref > torch.finfo(dtype).max:
983*da0073e9SAndroid Build Coastguard Worker                ref = np.inf
984*da0073e9SAndroid Build Coastguard Worker            elif ref < torch.finfo(dtype).min:
985*da0073e9SAndroid Build Coastguard Worker                ref = -np.inf
986*da0073e9SAndroid Build Coastguard Worker            expect = torch.full(shape, ref, dtype=dtype, device=device)
987*da0073e9SAndroid Build Coastguard Worker            actual = torch.div(a, b, rounding_mode="floor")
988*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
991*da0073e9SAndroid Build Coastguard Worker    def test_div_rounding_nonfinite(self, device, dtype):
992*da0073e9SAndroid Build Coastguard Worker        # Compare division of special floating point values against NumPy
993*da0073e9SAndroid Build Coastguard Worker        num = torch.tensor(
994*da0073e9SAndroid Build Coastguard Worker            [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
995*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
996*da0073e9SAndroid Build Coastguard Worker            device=device,
997*da0073e9SAndroid Build Coastguard Worker        )
998*da0073e9SAndroid Build Coastguard Worker        # Divide by zero is tested separately
999*da0073e9SAndroid Build Coastguard Worker        denom = num[num != 0]
1000*da0073e9SAndroid Build Coastguard Worker
1001*da0073e9SAndroid Build Coastguard Worker        a, b = num[None, :].clone(), denom[:, None].clone()
1002*da0073e9SAndroid Build Coastguard Worker
1003*da0073e9SAndroid Build Coastguard Worker        # Compare bfloat16 against NumPy float
1004*da0073e9SAndroid Build Coastguard Worker        exact_dtype = dtype != torch.bfloat16
1005*da0073e9SAndroid Build Coastguard Worker        if exact_dtype:
1006*da0073e9SAndroid Build Coastguard Worker            an, bn = a.cpu().numpy(), b.cpu().numpy()
1007*da0073e9SAndroid Build Coastguard Worker        else:
1008*da0073e9SAndroid Build Coastguard Worker            an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()
1009*da0073e9SAndroid Build Coastguard Worker
1010*da0073e9SAndroid Build Coastguard Worker        for mode, np_ref in ((None, np.true_divide), ("floor", np.floor_divide)):
1011*da0073e9SAndroid Build Coastguard Worker            expect = np_ref(an, bn)
1012*da0073e9SAndroid Build Coastguard Worker            kwargs = dict(rounding_mode=mode) if mode is not None else {}
1013*da0073e9SAndroid Build Coastguard Worker            with set_default_dtype(torch.double):
1014*da0073e9SAndroid Build Coastguard Worker                actual = torch.divide(a, b, **kwargs)
1015*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1016*da0073e9SAndroid Build Coastguard Worker                actual,
1017*da0073e9SAndroid Build Coastguard Worker                torch.from_numpy(expect),
1018*da0073e9SAndroid Build Coastguard Worker                exact_device=False,
1019*da0073e9SAndroid Build Coastguard Worker                exact_dtype=exact_dtype,
1020*da0073e9SAndroid Build Coastguard Worker            )
1021*da0073e9SAndroid Build Coastguard Worker
1022*da0073e9SAndroid Build Coastguard Worker        # Compare contiguous (likely vectorized) against non-contiguous (not vectorized)
1023*da0073e9SAndroid Build Coastguard Worker        a_noncontig = torch.empty([2 * i for i in a.shape], dtype=dtype, device=device)[
1024*da0073e9SAndroid Build Coastguard Worker            ::2, ::2
1025*da0073e9SAndroid Build Coastguard Worker        ]
1026*da0073e9SAndroid Build Coastguard Worker        a_noncontig[:] = a
1027*da0073e9SAndroid Build Coastguard Worker        b_noncontig = torch.empty([2 * i for i in b.shape], dtype=dtype, device=device)[
1028*da0073e9SAndroid Build Coastguard Worker            ::2, ::2
1029*da0073e9SAndroid Build Coastguard Worker        ]
1030*da0073e9SAndroid Build Coastguard Worker        b_noncontig[:] = b
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard Worker        for rounding_mode in (None, "trunc", "floor"):
1033*da0073e9SAndroid Build Coastguard Worker            expect = torch.divide(a_noncontig, b_noncontig, rounding_mode=rounding_mode)
1034*da0073e9SAndroid Build Coastguard Worker            actual = torch.divide(a, b, rounding_mode=rounding_mode)
1035*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect)
1036*da0073e9SAndroid Build Coastguard Worker
1037*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
1038*da0073e9SAndroid Build Coastguard Worker    def test_divide_by_zero_rounding(self, device, dtype):
1039*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(
1040*da0073e9SAndroid Build Coastguard Worker            [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
1041*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
1042*da0073e9SAndroid Build Coastguard Worker        )
1043*da0073e9SAndroid Build Coastguard Worker        exact_dtype = dtype != torch.bfloat16
1044*da0073e9SAndroid Build Coastguard Worker        if exact_dtype:
1045*da0073e9SAndroid Build Coastguard Worker            an = a.cpu().numpy()
1046*da0073e9SAndroid Build Coastguard Worker        else:
1047*da0073e9SAndroid Build Coastguard Worker            an = a.float().cpu().numpy()
1048*da0073e9SAndroid Build Coastguard Worker
1049*da0073e9SAndroid Build Coastguard Worker        zero = torch.zeros_like(a)
1050*da0073e9SAndroid Build Coastguard Worker
1051*da0073e9SAndroid Build Coastguard Worker        # NOTE: NumPy's floor_divide rounding changed in 1.20.0 to be consistent with divide
1052*da0073e9SAndroid Build Coastguard Worker        expect = np.divide(an, 0)
1053*da0073e9SAndroid Build Coastguard Worker        for rounding_mode in (None, "floor"):
1054*da0073e9SAndroid Build Coastguard Worker            # CPU scalar
1055*da0073e9SAndroid Build Coastguard Worker            actual = torch.divide(a, 0, rounding_mode=rounding_mode)
1056*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect, exact_dtype=exact_dtype)
1057*da0073e9SAndroid Build Coastguard Worker            # Device tensor
1058*da0073e9SAndroid Build Coastguard Worker            actual = torch.divide(a, zero, rounding_mode=rounding_mode)
1059*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect, exact_dtype=exact_dtype)
1060*da0073e9SAndroid Build Coastguard Worker
1061*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half))
1062*da0073e9SAndroid Build Coastguard Worker    def test_div_rounding_numpy(self, device, dtype):
1063*da0073e9SAndroid Build Coastguard Worker        info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
1064*da0073e9SAndroid Build Coastguard Worker        low, high = info.min, info.max
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker        # Compare division of random values against NumPy
1067*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((4096,), dtype=dtype, device=device, low=low, high=high)
1068*da0073e9SAndroid Build Coastguard Worker        b = make_tensor((4096,), dtype=dtype, device=device, low=low, high=high)
1069*da0073e9SAndroid Build Coastguard Worker
1070*da0073e9SAndroid Build Coastguard Worker        # Avoid division by zero which raises for integers and, for floats,
1071*da0073e9SAndroid Build Coastguard Worker        # NumPy 1.20 changed floor_divide to follow IEEE rules for inf/nan
1072*da0073e9SAndroid Build Coastguard Worker        # after dividing by zero.
1073*da0073e9SAndroid Build Coastguard Worker        b[b == 0] = 1
1074*da0073e9SAndroid Build Coastguard Worker
1075*da0073e9SAndroid Build Coastguard Worker        # Compare bfloat16 against NumPy float
1076*da0073e9SAndroid Build Coastguard Worker        exact_dtype = dtype != torch.bfloat16
1077*da0073e9SAndroid Build Coastguard Worker
1078*da0073e9SAndroid Build Coastguard Worker        if exact_dtype:
1079*da0073e9SAndroid Build Coastguard Worker            an, bn = a.cpu().numpy(), b.cpu().numpy()
1080*da0073e9SAndroid Build Coastguard Worker        else:
1081*da0073e9SAndroid Build Coastguard Worker            an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()
1082*da0073e9SAndroid Build Coastguard Worker
1083*da0073e9SAndroid Build Coastguard Worker        for mode, np_ref in (
1084*da0073e9SAndroid Build Coastguard Worker            (None, np.true_divide),
1085*da0073e9SAndroid Build Coastguard Worker            ("floor", np.floor_divide),
1086*da0073e9SAndroid Build Coastguard Worker            ("trunc", lambda a, b: np.trunc(np.true_divide(a, b)).astype(a.dtype)),
1087*da0073e9SAndroid Build Coastguard Worker        ):
1088*da0073e9SAndroid Build Coastguard Worker            expect = torch.from_numpy(np_ref(an, bn))
1089*da0073e9SAndroid Build Coastguard Worker
1090*da0073e9SAndroid Build Coastguard Worker            kwargs = dict(rounding_mode=mode) if mode is not None else {}
1091*da0073e9SAndroid Build Coastguard Worker            # Contiguous (likely vectorized)
1092*da0073e9SAndroid Build Coastguard Worker            with set_default_dtype(torch.double):
1093*da0073e9SAndroid Build Coastguard Worker                actual = torch.divide(a, b, **kwargs)
1094*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1095*da0073e9SAndroid Build Coastguard Worker                actual, expect, exact_device=False, exact_dtype=exact_dtype
1096*da0073e9SAndroid Build Coastguard Worker            )
1097*da0073e9SAndroid Build Coastguard Worker
1098*da0073e9SAndroid Build Coastguard Worker            # Non-contiguous (not vectorized)
1099*da0073e9SAndroid Build Coastguard Worker            expect = expect[::2]
1100*da0073e9SAndroid Build Coastguard Worker            with set_default_dtype(torch.double):
1101*da0073e9SAndroid Build Coastguard Worker                actual = torch.divide(a[::2], b[::2], **kwargs)
1102*da0073e9SAndroid Build Coastguard Worker
1103*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1104*da0073e9SAndroid Build Coastguard Worker                actual, expect, exact_device=False, exact_dtype=exact_dtype
1105*da0073e9SAndroid Build Coastguard Worker            )
1106*da0073e9SAndroid Build Coastguard Worker
1107*da0073e9SAndroid Build Coastguard Worker    @dtypes(*complex_types())
1108*da0073e9SAndroid Build Coastguard Worker    def test_complex_div_underflow_overflow(self, device, dtype):
1109*da0073e9SAndroid Build Coastguard Worker        # test to make sure the complex division does not produce underflow or overflow
1110*da0073e9SAndroid Build Coastguard Worker        # in the intermediate of its calculations
1111*da0073e9SAndroid Build Coastguard Worker        # NOTE: the calculation still produces an error if the number is greater than
1112*da0073e9SAndroid Build Coastguard Worker        # finfo.max / 2, but hopefully people realized that it's a dangerous region to work with
1113*da0073e9SAndroid Build Coastguard Worker        finfo = torch.finfo(dtype)
1114*da0073e9SAndroid Build Coastguard Worker        nom_lst = [
1115*da0073e9SAndroid Build Coastguard Worker            complex(finfo.min / 2, finfo.min / 2),
1116*da0073e9SAndroid Build Coastguard Worker            complex(finfo.max / 2, finfo.max / 2),
1117*da0073e9SAndroid Build Coastguard Worker            complex(finfo.tiny, finfo.tiny),
1118*da0073e9SAndroid Build Coastguard Worker            complex(finfo.tiny, 0.0),
1119*da0073e9SAndroid Build Coastguard Worker            complex(0.0, 0.0),
1120*da0073e9SAndroid Build Coastguard Worker        ]
1121*da0073e9SAndroid Build Coastguard Worker        denom_lst = [
1122*da0073e9SAndroid Build Coastguard Worker            complex(finfo.min / 2, finfo.min / 2),
1123*da0073e9SAndroid Build Coastguard Worker            complex(finfo.max / 2, finfo.max / 2),
1124*da0073e9SAndroid Build Coastguard Worker            complex(finfo.tiny, finfo.tiny),
1125*da0073e9SAndroid Build Coastguard Worker            complex(0.0, finfo.tiny),
1126*da0073e9SAndroid Build Coastguard Worker            complex(finfo.tiny, finfo.tiny),
1127*da0073e9SAndroid Build Coastguard Worker        ]
1128*da0073e9SAndroid Build Coastguard Worker        expected_lst = [
1129*da0073e9SAndroid Build Coastguard Worker            complex(1.0, 0.0),
1130*da0073e9SAndroid Build Coastguard Worker            complex(1.0, 0.0),
1131*da0073e9SAndroid Build Coastguard Worker            complex(1.0, 0.0),
1132*da0073e9SAndroid Build Coastguard Worker            complex(0.0, -1.0),
1133*da0073e9SAndroid Build Coastguard Worker            complex(0.0, 0.0),
1134*da0073e9SAndroid Build Coastguard Worker        ]
1135*da0073e9SAndroid Build Coastguard Worker        nom = torch.tensor(nom_lst, dtype=dtype, device=device)
1136*da0073e9SAndroid Build Coastguard Worker        denom = torch.tensor(denom_lst, dtype=dtype, device=device)
1137*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor(expected_lst, dtype=dtype, device=device)
1138*da0073e9SAndroid Build Coastguard Worker        res = nom / denom
1139*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, expected)
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Worker    # Tests that trying to add, inplace, a CUDA tensor to a CPU tensor
1142*da0073e9SAndroid Build Coastguard Worker    #   throws the correct error message
1143*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1144*da0073e9SAndroid Build Coastguard Worker    def test_cross_device_inplace_error_msg(self, device):
1145*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(2.0)
1146*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(2.0, device=device)
1147*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1148*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Expected all tensors to be on the same device"
1149*da0073e9SAndroid Build Coastguard Worker        ):
1150*da0073e9SAndroid Build Coastguard Worker            a += b
1151*da0073e9SAndroid Build Coastguard Worker
1152*da0073e9SAndroid Build Coastguard Worker    # TODO: refactor this test into a more generic one, it's parked here currently
1153*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1154*da0073e9SAndroid Build Coastguard Worker    def test_out_resize_warning(self, device):
1155*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor((1, 2, 3), device=device, dtype=torch.float32)
1156*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor((4, 5, 6), device=device, dtype=torch.float32)
1157*da0073e9SAndroid Build Coastguard Worker
1158*da0073e9SAndroid Build Coastguard Worker        unary_inputs = (a,)
1159*da0073e9SAndroid Build Coastguard Worker        binary_inputs = (a, b)
1160*da0073e9SAndroid Build Coastguard Worker        unary_ops = (torch.ceil, torch.exp)
1161*da0073e9SAndroid Build Coastguard Worker        binary_ops = (torch.add, torch.sub)
1162*da0073e9SAndroid Build Coastguard Worker        for op in unary_ops + binary_ops:
1163*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
1164*da0073e9SAndroid Build Coastguard Worker                warnings.simplefilter("always")
1165*da0073e9SAndroid Build Coastguard Worker                inputs = unary_inputs if op in unary_ops else binary_inputs
1166*da0073e9SAndroid Build Coastguard Worker
1167*da0073e9SAndroid Build Coastguard Worker                # No warnings
1168*da0073e9SAndroid Build Coastguard Worker                op(*inputs, out=torch.empty(3, device=device))
1169*da0073e9SAndroid Build Coastguard Worker                op(*inputs, out=torch.empty(0, device=device))
1170*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(w), 0)
1171*da0073e9SAndroid Build Coastguard Worker
1172*da0073e9SAndroid Build Coastguard Worker                # Cases that throw warnings
1173*da0073e9SAndroid Build Coastguard Worker                op(*inputs, out=torch.empty(2, device=device))
1174*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(w), 1)
1175*da0073e9SAndroid Build Coastguard Worker        # test that multi-d out doesn't trigger segfault
1176*da0073e9SAndroid Build Coastguard Worker        arg1 = (torch.ones(2, 1, device=device), torch.ones(1, device=device))
1177*da0073e9SAndroid Build Coastguard Worker        arg2 = (torch.ones(2, device=device), torch.ones(1, 1, device=device))
1178*da0073e9SAndroid Build Coastguard Worker        outs = (
1179*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 1, 1, 1, device=device),
1180*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2, 2, 2, device=device),
1181*da0073e9SAndroid Build Coastguard Worker        )
1182*da0073e9SAndroid Build Coastguard Worker
1183*da0073e9SAndroid Build Coastguard Worker        for a1, a2, o in zip(arg1, arg2, outs):
1184*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
1185*da0073e9SAndroid Build Coastguard Worker                warnings.simplefilter("always")
1186*da0073e9SAndroid Build Coastguard Worker                torch.mul(a1, a2, out=o)
1187*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(w), 1)
1188*da0073e9SAndroid Build Coastguard Worker
1189*da0073e9SAndroid Build Coastguard Worker    # Verifies that the inplace dunders (like idiv) actually are in place
1190*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # UserWarning not triggered
1191*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1192*da0073e9SAndroid Build Coastguard Worker    def test_inplace_dunders(self, device):
1193*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((1,), device=device)
1194*da0073e9SAndroid Build Coastguard Worker        expected = t.data_ptr()
1195*da0073e9SAndroid Build Coastguard Worker        t += 1
1196*da0073e9SAndroid Build Coastguard Worker        t -= 1
1197*da0073e9SAndroid Build Coastguard Worker        t *= 1
1198*da0073e9SAndroid Build Coastguard Worker        t /= 1
1199*da0073e9SAndroid Build Coastguard Worker        t **= 1
1200*da0073e9SAndroid Build Coastguard Worker        t //= 1
1201*da0073e9SAndroid Build Coastguard Worker        t %= 1
1202*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, t.data_ptr())
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker    def check_internal_mem_overlap(
1205*da0073e9SAndroid Build Coastguard Worker        self, inplace_op, num_inputs, dtype, device, expected_failure=False
1206*da0073e9SAndroid Build Coastguard Worker    ):
1207*da0073e9SAndroid Build Coastguard Worker        if isinstance(inplace_op, str):
1208*da0073e9SAndroid Build Coastguard Worker            inplace_op = getattr(torch.Tensor, inplace_op)
1209*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
1210*da0073e9SAndroid Build Coastguard Worker        inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)]
1211*da0073e9SAndroid Build Coastguard Worker        if not expected_failure:
1212*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "single memory location"):
1213*da0073e9SAndroid Build Coastguard Worker                inplace_op(*inputs)
1214*da0073e9SAndroid Build Coastguard Worker        else:
1215*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(AssertionError):
1216*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "single memory location"):
1217*da0073e9SAndroid Build Coastguard Worker                    inplace_op(*inputs)
1218*da0073e9SAndroid Build Coastguard Worker
1219*da0073e9SAndroid Build Coastguard Worker    def unary_check_input_output_mem_overlap(
1220*da0073e9SAndroid Build Coastguard Worker        self, data, sz, op, expected_failure=False
1221*da0073e9SAndroid Build Coastguard Worker    ):
1222*da0073e9SAndroid Build Coastguard Worker        def _test(op, output, input):
1223*da0073e9SAndroid Build Coastguard Worker            output_exp = torch.empty_like(output)
1224*da0073e9SAndroid Build Coastguard Worker            op(input, out=output_exp)
1225*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)
1226*da0073e9SAndroid Build Coastguard Worker
1227*da0073e9SAndroid Build Coastguard Worker        # output is identical to input:
1228*da0073e9SAndroid Build Coastguard Worker        _test(op, output=data[0:sz], input=data[0:sz])
1229*da0073e9SAndroid Build Coastguard Worker        # output and input are independent:
1230*da0073e9SAndroid Build Coastguard Worker        _test(op, output=data[0:sz], input=data[sz : 2 * sz])
1231*da0073e9SAndroid Build Coastguard Worker        # output partially overlaps with input:
1232*da0073e9SAndroid Build Coastguard Worker        if not expected_failure:
1233*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
1234*da0073e9SAndroid Build Coastguard Worker                _test(op, data[0:sz], data[1 : sz + 1])
1235*da0073e9SAndroid Build Coastguard Worker        else:
1236*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(AssertionError):
1237*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
1238*da0073e9SAndroid Build Coastguard Worker                    _test(op, data[0:sz], data[1 : sz + 1])
1239*da0073e9SAndroid Build Coastguard Worker
1240*da0073e9SAndroid Build Coastguard Worker    def binary_check_input_output_mem_overlap(self, op, device, expected_failure=False):
1241*da0073e9SAndroid Build Coastguard Worker        sz = 3
1242*da0073e9SAndroid Build Coastguard Worker        data = torch.randn(2 * sz, device=device)
1243*da0073e9SAndroid Build Coastguard Worker        other = torch.randn(sz, device=device)
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker        self.unary_check_input_output_mem_overlap(
1246*da0073e9SAndroid Build Coastguard Worker            data,
1247*da0073e9SAndroid Build Coastguard Worker            sz,
1248*da0073e9SAndroid Build Coastguard Worker            lambda input, out: op(other, input, out=out),
1249*da0073e9SAndroid Build Coastguard Worker            expected_failure=expected_failure,
1250*da0073e9SAndroid Build Coastguard Worker        )
1251*da0073e9SAndroid Build Coastguard Worker
1252*da0073e9SAndroid Build Coastguard Worker        self.unary_check_input_output_mem_overlap(
1253*da0073e9SAndroid Build Coastguard Worker            data,
1254*da0073e9SAndroid Build Coastguard Worker            sz,
1255*da0073e9SAndroid Build Coastguard Worker            lambda input, out: op(input, other, out=out),
1256*da0073e9SAndroid Build Coastguard Worker            expected_failure=expected_failure,
1257*da0073e9SAndroid Build Coastguard Worker        )
1258*da0073e9SAndroid Build Coastguard Worker
1259*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/126474
1260*da0073e9SAndroid Build Coastguard Worker    @xfailIfTorchDynamo
1261*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1262*da0073e9SAndroid Build Coastguard Worker    def test_binary_op_mem_overlap(self, device, dtype):
1263*da0073e9SAndroid Build Coastguard Worker        ops = [
1264*da0073e9SAndroid Build Coastguard Worker            ("add", True, True, "cpu"),
1265*da0073e9SAndroid Build Coastguard Worker            ("add", True, True, "cuda"),
1266*da0073e9SAndroid Build Coastguard Worker            ("mul", True, True, "cpu"),
1267*da0073e9SAndroid Build Coastguard Worker            ("mul", True, True, "cuda"),
1268*da0073e9SAndroid Build Coastguard Worker            ("sub", True, True, "cpu"),
1269*da0073e9SAndroid Build Coastguard Worker            ("sub", True, True, "cuda"),
1270*da0073e9SAndroid Build Coastguard Worker            ("div", True, True, "cpu"),
1271*da0073e9SAndroid Build Coastguard Worker            ("div", True, True, "cuda"),
1272*da0073e9SAndroid Build Coastguard Worker            ("pow", True, True, "cpu"),
1273*da0073e9SAndroid Build Coastguard Worker            ("pow", True, True, "cuda"),
1274*da0073e9SAndroid Build Coastguard Worker            ("fmod", True, True, "cpu"),
1275*da0073e9SAndroid Build Coastguard Worker            ("fmod", True, True, "cuda"),
1276*da0073e9SAndroid Build Coastguard Worker            ("atan2", True, True, "cpu"),
1277*da0073e9SAndroid Build Coastguard Worker            ("atan2", True, True, "cuda"),
1278*da0073e9SAndroid Build Coastguard Worker            ("hypot", True, True, "cpu"),
1279*da0073e9SAndroid Build Coastguard Worker            ("hypot", True, True, "cuda"),
1280*da0073e9SAndroid Build Coastguard Worker            ("igamma", True, True, "cpu"),
1281*da0073e9SAndroid Build Coastguard Worker            ("igamma", True, True, "cuda"),
1282*da0073e9SAndroid Build Coastguard Worker            ("igammac", True, True, "cpu"),
1283*da0073e9SAndroid Build Coastguard Worker            ("igammac", True, True, "cuda"),
1284*da0073e9SAndroid Build Coastguard Worker            ("nextafter", True, True, "cpu"),
1285*da0073e9SAndroid Build Coastguard Worker            ("nextafter", True, True, "cuda"),
1286*da0073e9SAndroid Build Coastguard Worker            ("le", True, True, "cpu"),
1287*da0073e9SAndroid Build Coastguard Worker            ("le", True, True, "cuda"),
1288*da0073e9SAndroid Build Coastguard Worker            ("lt", True, True, "cpu"),
1289*da0073e9SAndroid Build Coastguard Worker            ("lt", True, True, "cuda"),
1290*da0073e9SAndroid Build Coastguard Worker            ("ge", True, True, "cpu"),
1291*da0073e9SAndroid Build Coastguard Worker            ("ge", True, True, "cuda"),
1292*da0073e9SAndroid Build Coastguard Worker            ("gt", True, True, "cpu"),
1293*da0073e9SAndroid Build Coastguard Worker            ("gt", True, True, "cuda"),
1294*da0073e9SAndroid Build Coastguard Worker            ("eq", True, True, "cpu"),
1295*da0073e9SAndroid Build Coastguard Worker            ("eq", True, True, "cuda"),
1296*da0073e9SAndroid Build Coastguard Worker            ("ne", True, True, "cpu"),
1297*da0073e9SAndroid Build Coastguard Worker            ("ne", True, True, "cuda"),
1298*da0073e9SAndroid Build Coastguard Worker            ("logical_and", True, True, "cpu"),
1299*da0073e9SAndroid Build Coastguard Worker            ("logical_and", True, True, "cuda"),
1300*da0073e9SAndroid Build Coastguard Worker            ("logical_or", True, True, "cpu"),
1301*da0073e9SAndroid Build Coastguard Worker            ("logical_or", True, True, "cuda"),
1302*da0073e9SAndroid Build Coastguard Worker            ("logical_xor", True, True, "cpu"),
1303*da0073e9SAndroid Build Coastguard Worker            ("logical_xor", True, True, "cuda"),
1304*da0073e9SAndroid Build Coastguard Worker        ]
1305*da0073e9SAndroid Build Coastguard Worker
1306*da0073e9SAndroid Build Coastguard Worker        for (
1307*da0073e9SAndroid Build Coastguard Worker            fn,
1308*da0073e9SAndroid Build Coastguard Worker            has_input_output_mem_overlap_check,
1309*da0073e9SAndroid Build Coastguard Worker            has_internal_mem_overlap_check,
1310*da0073e9SAndroid Build Coastguard Worker            dev,
1311*da0073e9SAndroid Build Coastguard Worker        ) in ops:
1312*da0073e9SAndroid Build Coastguard Worker            if dev != device:
1313*da0073e9SAndroid Build Coastguard Worker                continue
1314*da0073e9SAndroid Build Coastguard Worker            out_op = getattr(torch, fn)
1315*da0073e9SAndroid Build Coastguard Worker            inplace_op = getattr(torch.Tensor, fn + "_")
1316*da0073e9SAndroid Build Coastguard Worker            self.check_internal_mem_overlap(
1317*da0073e9SAndroid Build Coastguard Worker                inplace_op,
1318*da0073e9SAndroid Build Coastguard Worker                2,
1319*da0073e9SAndroid Build Coastguard Worker                dtype,
1320*da0073e9SAndroid Build Coastguard Worker                device,
1321*da0073e9SAndroid Build Coastguard Worker                expected_failure=not has_internal_mem_overlap_check,
1322*da0073e9SAndroid Build Coastguard Worker            )
1323*da0073e9SAndroid Build Coastguard Worker
1324*da0073e9SAndroid Build Coastguard Worker            self.binary_check_input_output_mem_overlap(
1325*da0073e9SAndroid Build Coastguard Worker                out_op, device, expected_failure=not has_input_output_mem_overlap_check
1326*da0073e9SAndroid Build Coastguard Worker            )
1327*da0073e9SAndroid Build Coastguard Worker
1328*da0073e9SAndroid Build Coastguard Worker    def _do_pow_for_exponents(self, m1, exponents, pow_fn, atol):
1329*da0073e9SAndroid Build Coastguard Worker        for num in exponents:
1330*da0073e9SAndroid Build Coastguard Worker            if (
1331*da0073e9SAndroid Build Coastguard Worker                isinstance(num, int)
1332*da0073e9SAndroid Build Coastguard Worker                and num < 0
1333*da0073e9SAndroid Build Coastguard Worker                and not m1.is_floating_point()
1334*da0073e9SAndroid Build Coastguard Worker                and not m1.is_complex()
1335*da0073e9SAndroid Build Coastguard Worker            ):
1336*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
1337*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
1338*da0073e9SAndroid Build Coastguard Worker                    r"Integers to negative integer powers are not allowed\.",
1339*da0073e9SAndroid Build Coastguard Worker                ):
1340*da0073e9SAndroid Build Coastguard Worker                    torch.pow(m1[4], num)
1341*da0073e9SAndroid Build Coastguard Worker            else:
1342*da0073e9SAndroid Build Coastguard Worker                # base - tensor, exponent - number
1343*da0073e9SAndroid Build Coastguard Worker                # contiguous
1344*da0073e9SAndroid Build Coastguard Worker                res1 = torch.pow(m1[4], num)
1345*da0073e9SAndroid Build Coastguard Worker                res2 = res1.clone().zero_()
1346*da0073e9SAndroid Build Coastguard Worker                # `math.pow` has issues with complex exponentiation so we need to resort to normal `pow`.
1347*da0073e9SAndroid Build Coastguard Worker                for i in range(res2.size(0)):
1348*da0073e9SAndroid Build Coastguard Worker                    res2[i] = pow_fn(m1[4][i], num)
1349*da0073e9SAndroid Build Coastguard Worker                rtol = 0 if atol is not None else None
1350*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res1, res2, atol=atol, rtol=rtol)
1351*da0073e9SAndroid Build Coastguard Worker
1352*da0073e9SAndroid Build Coastguard Worker                # non-contiguous
1353*da0073e9SAndroid Build Coastguard Worker                res1 = torch.pow(m1[:, 4], num)
1354*da0073e9SAndroid Build Coastguard Worker                res2 = res1.clone().zero_()
1355*da0073e9SAndroid Build Coastguard Worker                for i in range(res2.size(0)):
1356*da0073e9SAndroid Build Coastguard Worker                    res2[i] = pow_fn(m1[i, 4], num)
1357*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res1, res2, atol=atol, rtol=rtol)
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard Worker                # scalar ** tensor to enforce correct handling of dtypes for __rpow__().
1360*da0073e9SAndroid Build Coastguard Worker                expected_dtype = torch.result_type(num, m1)
1361*da0073e9SAndroid Build Coastguard Worker                res1 = num ** m1[4]
1362*da0073e9SAndroid Build Coastguard Worker                res2 = (
1363*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4]
1364*da0073e9SAndroid Build Coastguard Worker                )
1365*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res1, res2)
1366*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res1.dtype, expected_dtype)
1367*da0073e9SAndroid Build Coastguard Worker
1368*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
1369*da0073e9SAndroid Build Coastguard Worker    def test_pow(self, device, dtype):
1370*da0073e9SAndroid Build Coastguard Worker        m1 = torch.empty(0, dtype=dtype, device=device)
1371*da0073e9SAndroid Build Coastguard Worker        if m1.is_floating_point() or m1.is_complex():
1372*da0073e9SAndroid Build Coastguard Worker            m1 = (
1373*da0073e9SAndroid Build Coastguard Worker                make_tensor((100, 100), low=0, high=1, dtype=dtype, device=device) + 0.5
1374*da0073e9SAndroid Build Coastguard Worker            )
1375*da0073e9SAndroid Build Coastguard Worker        else:
1376*da0073e9SAndroid Build Coastguard Worker            # math.pow will overflow and throw exceptions for large integers
1377*da0073e9SAndroid Build Coastguard Worker            range_high = 4 if dtype in (torch.int8, torch.uint8) else 10
1378*da0073e9SAndroid Build Coastguard Worker            m1 = make_tensor(
1379*da0073e9SAndroid Build Coastguard Worker                (100, 100), low=1, high=range_high, dtype=dtype, device=device
1380*da0073e9SAndroid Build Coastguard Worker            )
1381*da0073e9SAndroid Build Coastguard Worker
1382*da0073e9SAndroid Build Coastguard Worker        exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3, True, False]
1383*da0073e9SAndroid Build Coastguard Worker        complex_exponents = [
1384*da0073e9SAndroid Build Coastguard Worker            -2.5j,
1385*da0073e9SAndroid Build Coastguard Worker            -1.0j,
1386*da0073e9SAndroid Build Coastguard Worker            0j,
1387*da0073e9SAndroid Build Coastguard Worker            1.0j,
1388*da0073e9SAndroid Build Coastguard Worker            2.5j,
1389*da0073e9SAndroid Build Coastguard Worker            1.0 + 1.0j,
1390*da0073e9SAndroid Build Coastguard Worker            -1.0 - 1.5j,
1391*da0073e9SAndroid Build Coastguard Worker            3.3j,
1392*da0073e9SAndroid Build Coastguard Worker        ]
1393*da0073e9SAndroid Build Coastguard Worker        if m1.is_complex():
1394*da0073e9SAndroid Build Coastguard Worker            self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4)
1395*da0073e9SAndroid Build Coastguard Worker        else:
1396*da0073e9SAndroid Build Coastguard Worker            self._do_pow_for_exponents(m1, exponents, math.pow, None)
1397*da0073e9SAndroid Build Coastguard Worker            will_raise_error = (
1398*da0073e9SAndroid Build Coastguard Worker                dtype is torch.half and torch.device(device).type == "cpu"
1399*da0073e9SAndroid Build Coastguard Worker            )
1400*da0073e9SAndroid Build Coastguard Worker            if will_raise_error:
1401*da0073e9SAndroid Build Coastguard Worker                # On CPU,
1402*da0073e9SAndroid Build Coastguard Worker                # Half Tensor with complex exponents leads to computation dtype
1403*da0073e9SAndroid Build Coastguard Worker                # of ComplexHalf for which this ops is not supported yet
1404*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
1405*da0073e9SAndroid Build Coastguard Worker                    RuntimeError, "not implemented for 'ComplexHalf'"
1406*da0073e9SAndroid Build Coastguard Worker                ):
1407*da0073e9SAndroid Build Coastguard Worker                    self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
1408*da0073e9SAndroid Build Coastguard Worker            else:
1409*da0073e9SAndroid Build Coastguard Worker                self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
1410*da0073e9SAndroid Build Coastguard Worker
1411*da0073e9SAndroid Build Coastguard Worker        # base - number, exponent - tensor
1412*da0073e9SAndroid Build Coastguard Worker        # contiguous
1413*da0073e9SAndroid Build Coastguard Worker        res1 = torch.pow(3, m1[4])
1414*da0073e9SAndroid Build Coastguard Worker        res2 = res1.clone().zero_()
1415*da0073e9SAndroid Build Coastguard Worker        for i in range(res2.size(0)):
1416*da0073e9SAndroid Build Coastguard Worker            res2[i] = pow(3, m1[4, i])
1417*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
1418*da0073e9SAndroid Build Coastguard Worker
1419*da0073e9SAndroid Build Coastguard Worker        # non-contiguous
1420*da0073e9SAndroid Build Coastguard Worker        res1 = torch.pow(3, m1[:, 4])
1421*da0073e9SAndroid Build Coastguard Worker        res2 = res1.clone().zero_()
1422*da0073e9SAndroid Build Coastguard Worker        for i in range(res2.size(0)):
1423*da0073e9SAndroid Build Coastguard Worker            res2[i] = pow(3, m1[i][4])
1424*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
1425*da0073e9SAndroid Build Coastguard Worker
    # TODO: refactor all these tests using opinfos properly
    def _test_pow(self, base, exponent, np_exponent=None):
        """Compare every torch pow entry point against a NumPy reference.

        Exercises ``Tensor.pow``, ``Tensor.pow_``, ``torch.pow`` and the
        ``out=`` variant, using ``np.power`` as the expected result; when NumPy
        itself rejects the inputs, asserts torch raises the matching error.

        Args:
            base: tensor or scalar base.
            exponent: tensor or scalar exponent passed to the torch ops.
            np_exponent: exponent for the NumPy reference; defaults to
                ``exponent`` when None.
        """

        if np_exponent is None:
            np_exponent = exponent

        def to_np(value):
            # np.power needs host-side ndarrays/scalars, not device tensors.
            if isinstance(value, torch.Tensor):
                return value.cpu().numpy()
            return value

        try:
            np_res = np.power(to_np(base), to_np(np_exponent))
            expected = (
                torch.from_numpy(np_res)
                if isinstance(np_res, np.ndarray)
                else torch.tensor(np_res, dtype=base.dtype)
            )
        except ValueError as e:
            # NumPy raises ValueError for integer ** negative integer; torch
            # must raise the same message from all four pow entry points.
            err_msg = "Integers to negative integer powers are not allowed."
            self.assertEqual(str(e), err_msg)
            out = torch.empty_like(base)
            test_cases = [
                lambda: base.pow(exponent),
                lambda: base.pow_(exponent),
                lambda: torch.pow(base, exponent),
                lambda: torch.pow(base, exponent, out=out),
            ]
            for test_case in test_cases:
                self.assertRaisesRegex(RuntimeError, err_msg, test_case)
        else:
            if isinstance(base, torch.Tensor):
                actual = base.pow(exponent)
                self.assertEqual(actual, expected.to(actual))
                actual = base.clone()
                # When base is a 0-dim cpu tensor and exp is a cuda tensor, we exp `pow` to work but `pow_` to fail, since
                # `pow` will try to create the output tensor on a cuda device, but `pow_` needs to use the cpu tensor as the output
                if (
                    isinstance(exponent, torch.Tensor)
                    and base.dim() == 0
                    and base.device.type == "cpu"
                    and exponent.device.type == "cuda"
                ):
                    regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!"
                    self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)
                elif torch.can_cast(torch.result_type(base, exponent), base.dtype):
                    # In-place pow is allowed when the promoted result dtype
                    # fits back into base's dtype; results must still match.
                    actual2 = actual.pow_(exponent)
                    self.assertEqual(actual, expected)
                    self.assertEqual(actual2, expected)
                else:
                    # Otherwise pow_ must refuse the implicit downcast.
                    self.assertRaisesRegex(
                        RuntimeError,
                        "Found dtype \\w+ but expected \\w+",
                        lambda: actual.pow_(exponent),
                    )

            actual = torch.pow(base, exponent)
            self.assertEqual(actual, expected.to(actual))

            # The out= variant must both fill and return the provided tensor.
            actual2 = torch.pow(base, exponent, out=actual)
            self.assertEqual(actual, expected.to(actual))
            self.assertEqual(actual2, expected.to(actual))
1487*da0073e9SAndroid Build Coastguard Worker
1488*da0073e9SAndroid Build Coastguard Worker    # We can potentially merge this into OpInfo, but one blocker is that the
1489*da0073e9SAndroid Build Coastguard Worker    # first input must be a scalar. It is not as simple as just wrapping this in
1490*da0073e9SAndroid Build Coastguard Worker    # a lambada that switches the inputs, because we also want to test samples inputs
1491*da0073e9SAndroid Build Coastguard Worker    # where the second input is a scalar. The wrapper would need some more logic.
1492*da0073e9SAndroid Build Coastguard Worker    def test_pow_scalar_base(self, device):
1493*da0073e9SAndroid Build Coastguard Worker        a = (
1494*da0073e9SAndroid Build Coastguard Worker            torch.arange(1, 13, dtype=torch.double, device=device)
1495*da0073e9SAndroid Build Coastguard Worker            .view(3, 4)
1496*da0073e9SAndroid Build Coastguard Worker            .requires_grad_()
1497*da0073e9SAndroid Build Coastguard Worker        )
1498*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda a: torch.pow(2, a), (a,))
1499*da0073e9SAndroid Build Coastguard Worker
1500*da0073e9SAndroid Build Coastguard Worker    # Tests pow() for integral, floating-type tensors, with integral, floating-type
1501*da0073e9SAndroid Build Coastguard Worker    # exponents (tensor or scalar), respectively. noncontiguous tensors are also tested.
1502*da0073e9SAndroid Build Coastguard Worker    def test_int_and_float_pow(self, device):
1503*da0073e9SAndroid Build Coastguard Worker        def _test_int_and_float_pow(dt, low, high, dev):
1504*da0073e9SAndroid Build Coastguard Worker            test_cases = (
1505*da0073e9SAndroid Build Coastguard Worker                ((4, 4), 0, (4, 1)),
1506*da0073e9SAndroid Build Coastguard Worker                ((3, 1), 4, (3, 1)),
1507*da0073e9SAndroid Build Coastguard Worker                ((2,), 4, (1,)),
1508*da0073e9SAndroid Build Coastguard Worker                ((1,), 2, ()),
1509*da0073e9SAndroid Build Coastguard Worker                ((513, 513), 4, (513,)),
1510*da0073e9SAndroid Build Coastguard Worker                ((5, 5, 5), 5, (5,)),
1511*da0073e9SAndroid Build Coastguard Worker                ((), 2, ()),
1512*da0073e9SAndroid Build Coastguard Worker            )
1513*da0073e9SAndroid Build Coastguard Worker            for base_shape, exp_scalar, exp_shape in test_cases:
1514*da0073e9SAndroid Build Coastguard Worker                base_tensor = make_tensor(
1515*da0073e9SAndroid Build Coastguard Worker                    base_shape, dtype=dt, device=dev, low=low, high=high
1516*da0073e9SAndroid Build Coastguard Worker                )
1517*da0073e9SAndroid Build Coastguard Worker                # int tensors don't take negative exponents
1518*da0073e9SAndroid Build Coastguard Worker                if dt in [
1519*da0073e9SAndroid Build Coastguard Worker                    torch.uint8,
1520*da0073e9SAndroid Build Coastguard Worker                    torch.int8,
1521*da0073e9SAndroid Build Coastguard Worker                    torch.int16,
1522*da0073e9SAndroid Build Coastguard Worker                    torch.int32,
1523*da0073e9SAndroid Build Coastguard Worker                    torch.int64,
1524*da0073e9SAndroid Build Coastguard Worker                ]:
1525*da0073e9SAndroid Build Coastguard Worker                    exp_tensor = make_tensor(
1526*da0073e9SAndroid Build Coastguard Worker                        exp_shape, dtype=dt, device=dev, low=0, high=high
1527*da0073e9SAndroid Build Coastguard Worker                    )
1528*da0073e9SAndroid Build Coastguard Worker                else:
1529*da0073e9SAndroid Build Coastguard Worker                    exp_tensor = make_tensor(
1530*da0073e9SAndroid Build Coastguard Worker                        exp_shape, dtype=dt, device=dev, low=low, high=high
1531*da0073e9SAndroid Build Coastguard Worker                    )
1532*da0073e9SAndroid Build Coastguard Worker                self._test_pow(base_tensor, exp_scalar)
1533*da0073e9SAndroid Build Coastguard Worker                self._test_pow(base_tensor, exp_tensor)
1534*da0073e9SAndroid Build Coastguard Worker                # test non-contiguous tensors as well
1535*da0073e9SAndroid Build Coastguard Worker                base_tensor = make_tensor(
1536*da0073e9SAndroid Build Coastguard Worker                    base_shape,
1537*da0073e9SAndroid Build Coastguard Worker                    dtype=dt,
1538*da0073e9SAndroid Build Coastguard Worker                    device=dev,
1539*da0073e9SAndroid Build Coastguard Worker                    low=low,
1540*da0073e9SAndroid Build Coastguard Worker                    high=high,
1541*da0073e9SAndroid Build Coastguard Worker                    noncontiguous=True,
1542*da0073e9SAndroid Build Coastguard Worker                )
1543*da0073e9SAndroid Build Coastguard Worker                if dt in [
1544*da0073e9SAndroid Build Coastguard Worker                    torch.uint8,
1545*da0073e9SAndroid Build Coastguard Worker                    torch.int8,
1546*da0073e9SAndroid Build Coastguard Worker                    torch.int16,
1547*da0073e9SAndroid Build Coastguard Worker                    torch.int32,
1548*da0073e9SAndroid Build Coastguard Worker                    torch.int64,
1549*da0073e9SAndroid Build Coastguard Worker                ]:
1550*da0073e9SAndroid Build Coastguard Worker                    exp_tensor = make_tensor(
1551*da0073e9SAndroid Build Coastguard Worker                        exp_shape,
1552*da0073e9SAndroid Build Coastguard Worker                        dtype=dt,
1553*da0073e9SAndroid Build Coastguard Worker                        device=dev,
1554*da0073e9SAndroid Build Coastguard Worker                        low=0,
1555*da0073e9SAndroid Build Coastguard Worker                        high=high,
1556*da0073e9SAndroid Build Coastguard Worker                        noncontiguous=True,
1557*da0073e9SAndroid Build Coastguard Worker                    )
1558*da0073e9SAndroid Build Coastguard Worker                else:
1559*da0073e9SAndroid Build Coastguard Worker                    exp_tensor = make_tensor(
1560*da0073e9SAndroid Build Coastguard Worker                        exp_shape,
1561*da0073e9SAndroid Build Coastguard Worker                        dtype=dt,
1562*da0073e9SAndroid Build Coastguard Worker                        device=dev,
1563*da0073e9SAndroid Build Coastguard Worker                        low=low,
1564*da0073e9SAndroid Build Coastguard Worker                        high=high,
1565*da0073e9SAndroid Build Coastguard Worker                        noncontiguous=True,
1566*da0073e9SAndroid Build Coastguard Worker                    )
1567*da0073e9SAndroid Build Coastguard Worker                self._test_pow(base_tensor, exp_scalar)
1568*da0073e9SAndroid Build Coastguard Worker                self._test_pow(base_tensor, exp_tensor)
1569*da0073e9SAndroid Build Coastguard Worker
1570*da0073e9SAndroid Build Coastguard Worker        _test_int_and_float_pow(torch.int8, -2, 2, device)
1571*da0073e9SAndroid Build Coastguard Worker        _test_int_and_float_pow(torch.uint8, 0, 3, device)
1572*da0073e9SAndroid Build Coastguard Worker        _test_int_and_float_pow(torch.int16, -5, 5, device)
1573*da0073e9SAndroid Build Coastguard Worker        _test_int_and_float_pow(torch.int64, -10, 10, device)
1574*da0073e9SAndroid Build Coastguard Worker        _test_int_and_float_pow(torch.int32, -10, 10, device)
1575*da0073e9SAndroid Build Coastguard Worker        _test_int_and_float_pow(torch.float16, 0.0, 5.0, device)
1576*da0073e9SAndroid Build Coastguard Worker        _test_int_and_float_pow(torch.float32, 0.0, 10.0, device)
1577*da0073e9SAndroid Build Coastguard Worker        _test_int_and_float_pow(torch.float64, 0.0, 10.0, device)
1578*da0073e9SAndroid Build Coastguard Worker        # pow's output would have some NaNs as well
1579*da0073e9SAndroid Build Coastguard Worker        _test_int_and_float_pow(torch.float32, -10.0, 10.0, device)
1580*da0073e9SAndroid Build Coastguard Worker        _test_int_and_float_pow(torch.float64, -10.0, 10.0, device)
1581*da0073e9SAndroid Build Coastguard Worker
1582*da0073e9SAndroid Build Coastguard Worker    # Tests that a Runtime error occurs when a base tensor cannot be resized
1583*da0073e9SAndroid Build Coastguard Worker    # by pow's inplace variant due to PyTorch's broadcasting semantics.
1584*da0073e9SAndroid Build Coastguard Worker    def test_pow_inplace_resizing_exception(self, device):
1585*da0073e9SAndroid Build Coastguard Worker        test_cases = (
1586*da0073e9SAndroid Build Coastguard Worker            ((), (3,)),
1587*da0073e9SAndroid Build Coastguard Worker            ((2,), (2, 1)),
1588*da0073e9SAndroid Build Coastguard Worker            ((2, 1), (2, 2)),
1589*da0073e9SAndroid Build Coastguard Worker            ((2, 2), (2, 1, 1)),
1590*da0073e9SAndroid Build Coastguard Worker        )
1591*da0073e9SAndroid Build Coastguard Worker        test_inputs = [
1592*da0073e9SAndroid Build Coastguard Worker            (
1593*da0073e9SAndroid Build Coastguard Worker                make_tensor(
1594*da0073e9SAndroid Build Coastguard Worker                    base_size, dtype=torch.float64, device=device, high=10.0, low=0.0
1595*da0073e9SAndroid Build Coastguard Worker                ),
1596*da0073e9SAndroid Build Coastguard Worker                make_tensor(
1597*da0073e9SAndroid Build Coastguard Worker                    exp_size, dtype=torch.float64, device=device, high=10.0, low=0.0
1598*da0073e9SAndroid Build Coastguard Worker                ),
1599*da0073e9SAndroid Build Coastguard Worker            )
1600*da0073e9SAndroid Build Coastguard Worker            for base_size, exp_size in test_cases
1601*da0073e9SAndroid Build Coastguard Worker        ]
1602*da0073e9SAndroid Build Coastguard Worker        for base, exponent in test_inputs:
1603*da0073e9SAndroid Build Coastguard Worker            regex = "doesn't match the broadcast shape"
1604*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)
1605*da0073e9SAndroid Build Coastguard Worker
1606*da0073e9SAndroid Build Coastguard Worker    def test_int_tensor_pow_neg_ints(self, device):
1607*da0073e9SAndroid Build Coastguard Worker        ints = [
1608*da0073e9SAndroid Build Coastguard Worker            torch.iinfo(torch.int32).min,
1609*da0073e9SAndroid Build Coastguard Worker            -3,
1610*da0073e9SAndroid Build Coastguard Worker            -2,
1611*da0073e9SAndroid Build Coastguard Worker            -1,
1612*da0073e9SAndroid Build Coastguard Worker            0,
1613*da0073e9SAndroid Build Coastguard Worker            1,
1614*da0073e9SAndroid Build Coastguard Worker            2,
1615*da0073e9SAndroid Build Coastguard Worker            3,
1616*da0073e9SAndroid Build Coastguard Worker            torch.iinfo(torch.int32).max,
1617*da0073e9SAndroid Build Coastguard Worker        ]
1618*da0073e9SAndroid Build Coastguard Worker        neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1]
1619*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor(ints, dtype=torch.int32, device=device)
1620*da0073e9SAndroid Build Coastguard Worker        for pow in neg_ints:
1621*da0073e9SAndroid Build Coastguard Worker            self._test_pow(tensor, pow)
1622*da0073e9SAndroid Build Coastguard Worker
1623*da0073e9SAndroid Build Coastguard Worker    def test_long_tensor_pow_floats(self, device):
1624*da0073e9SAndroid Build Coastguard Worker        ints = [0, 1, 23, 4567]
1625*da0073e9SAndroid Build Coastguard Worker        floats = [0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0]
1626*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor(ints, dtype=torch.int64, device=device)
1627*da0073e9SAndroid Build Coastguard Worker        for pow in floats:
1628*da0073e9SAndroid Build Coastguard Worker            self._test_pow(tensor, pow)
1629*da0073e9SAndroid Build Coastguard Worker
1630*da0073e9SAndroid Build Coastguard Worker    @dtypes(*[torch.float32, torch.float64])
1631*da0073e9SAndroid Build Coastguard Worker    def test_float_scalar_pow_float_tensor(self, device, dtype):
1632*da0073e9SAndroid Build Coastguard Worker        floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0]
1633*da0073e9SAndroid Build Coastguard Worker        exponent_shapes = (
1634*da0073e9SAndroid Build Coastguard Worker            (1,),
1635*da0073e9SAndroid Build Coastguard Worker            (2, 2),
1636*da0073e9SAndroid Build Coastguard Worker            (2, 1),
1637*da0073e9SAndroid Build Coastguard Worker            (2, 2, 2),
1638*da0073e9SAndroid Build Coastguard Worker        )
1639*da0073e9SAndroid Build Coastguard Worker        tensors = [
1640*da0073e9SAndroid Build Coastguard Worker            make_tensor(shape, dtype=dtype, device=device, low=0)
1641*da0073e9SAndroid Build Coastguard Worker            for shape in exponent_shapes
1642*da0073e9SAndroid Build Coastguard Worker        ]
1643*da0073e9SAndroid Build Coastguard Worker        floats_tensor = torch.tensor(floats, dtype=dtype, device=device)
1644*da0073e9SAndroid Build Coastguard Worker        for base in floats:
1645*da0073e9SAndroid Build Coastguard Worker            self._test_pow(base, floats_tensor)
1646*da0073e9SAndroid Build Coastguard Worker            for tensor in tensors:
1647*da0073e9SAndroid Build Coastguard Worker                self._test_pow(base, tensor)
1648*da0073e9SAndroid Build Coastguard Worker
1649*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1650*da0073e9SAndroid Build Coastguard Worker    def test_cuda_tensor_pow_scalar_tensor(self, device):
1651*da0073e9SAndroid Build Coastguard Worker        cuda_tensors = [
1652*da0073e9SAndroid Build Coastguard Worker            torch.randn((3, 3), device=device),
1653*da0073e9SAndroid Build Coastguard Worker            torch.tensor(3.0, device=device),
1654*da0073e9SAndroid Build Coastguard Worker        ]
1655*da0073e9SAndroid Build Coastguard Worker        scalar_tensors = [
1656*da0073e9SAndroid Build Coastguard Worker            torch.tensor(5.0, device="cpu"),
1657*da0073e9SAndroid Build Coastguard Worker            torch.tensor(-3),
1658*da0073e9SAndroid Build Coastguard Worker            torch.tensor(1),
1659*da0073e9SAndroid Build Coastguard Worker        ]
1660*da0073e9SAndroid Build Coastguard Worker        for base, exp in product(cuda_tensors, scalar_tensors):
1661*da0073e9SAndroid Build Coastguard Worker            self._test_pow(base, exp)
1662*da0073e9SAndroid Build Coastguard Worker
1663*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1664*da0073e9SAndroid Build Coastguard Worker    def test_cpu_tensor_pow_cuda_scalar_tensor(self, device):
1665*da0073e9SAndroid Build Coastguard Worker        cuda_tensors = [
1666*da0073e9SAndroid Build Coastguard Worker            torch.tensor(5.0, device="cuda"),
1667*da0073e9SAndroid Build Coastguard Worker            torch.tensor(-3, device="cuda"),
1668*da0073e9SAndroid Build Coastguard Worker        ]
1669*da0073e9SAndroid Build Coastguard Worker        for exp in cuda_tensors:
1670*da0073e9SAndroid Build Coastguard Worker            base = torch.randn((3, 3), device="cpu")
1671*da0073e9SAndroid Build Coastguard Worker            regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!"
1672*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError, regex, torch.pow, base, exp)
1673*da0073e9SAndroid Build Coastguard Worker        for exp in cuda_tensors:
1674*da0073e9SAndroid Build Coastguard Worker            # Binary ops with a cpu + cuda tensor are allowed if the cpu tensor has 0 dimension
1675*da0073e9SAndroid Build Coastguard Worker            base = torch.tensor(3.0, device="cpu")
1676*da0073e9SAndroid Build Coastguard Worker            self._test_pow(base, exp)
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1679*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.complex64, torch.complex128)
1680*da0073e9SAndroid Build Coastguard Worker    def test_pow_cuda_complex_extremal_failing(self, device, dtype):
1681*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(complex(-1.0, float("inf")), dtype=dtype, device=device)
1682*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
1683*da0073e9SAndroid Build Coastguard Worker            cuda_out = t.pow(2)
1684*da0073e9SAndroid Build Coastguard Worker            cpu_out = t.cpu().pow(2)
1685*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_out, cuda_out)
1686*da0073e9SAndroid Build Coastguard Worker
1687*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo()
1688*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1689*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half))
1690*da0073e9SAndroid Build Coastguard Worker    def test_complex_scalar_pow_tensor(self, device, dtype):
1691*da0073e9SAndroid Build Coastguard Worker        complexes = [0.5j, 1.0 + 1.0j, -1.5j, 2.2 - 1.6j, 1 + 0j]
1692*da0073e9SAndroid Build Coastguard Worker        first_exp = make_tensor((100,), dtype=dtype, device=device, low=-2, high=2)
1693*da0073e9SAndroid Build Coastguard Worker        second_exp = make_tensor(
1694*da0073e9SAndroid Build Coastguard Worker            (100,), dtype=dtype, device=device, low=-2, high=2, noncontiguous=True
1695*da0073e9SAndroid Build Coastguard Worker        )
1696*da0073e9SAndroid Build Coastguard Worker        first_exp[0] = first_exp[10] = first_exp[20] = 0
1697*da0073e9SAndroid Build Coastguard Worker        second_exp[0] = second_exp[10] = second_exp[20] = 0
1698*da0073e9SAndroid Build Coastguard Worker        for base in complexes:
1699*da0073e9SAndroid Build Coastguard Worker            # On CPU,
1700*da0073e9SAndroid Build Coastguard Worker            # Half Tensor with complex base leads to computation dtype
1701*da0073e9SAndroid Build Coastguard Worker            # of ComplexHalf for which this ops is not supported yet
1702*da0073e9SAndroid Build Coastguard Worker            # NOTE: pow has fast-path when base is 1 which supports
1703*da0073e9SAndroid Build Coastguard Worker            # ComplexHalf
1704*da0073e9SAndroid Build Coastguard Worker            will_raise_error = (
1705*da0073e9SAndroid Build Coastguard Worker                torch.device(device).type == "cpu"
1706*da0073e9SAndroid Build Coastguard Worker                and dtype is torch.half
1707*da0073e9SAndroid Build Coastguard Worker                and base != (1 + 0j)
1708*da0073e9SAndroid Build Coastguard Worker            )
1709*da0073e9SAndroid Build Coastguard Worker            if will_raise_error:
1710*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
1711*da0073e9SAndroid Build Coastguard Worker                    RuntimeError, "not implemented for 'ComplexHalf'"
1712*da0073e9SAndroid Build Coastguard Worker                ):
1713*da0073e9SAndroid Build Coastguard Worker                    self._test_pow(base, first_exp)
1714*da0073e9SAndroid Build Coastguard Worker                    self._test_pow(base, second_exp)
1715*da0073e9SAndroid Build Coastguard Worker            else:
1716*da0073e9SAndroid Build Coastguard Worker                self._test_pow(base, first_exp)
1717*da0073e9SAndroid Build Coastguard Worker                self._test_pow(base, second_exp)
1718*da0073e9SAndroid Build Coastguard Worker
1719*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1720*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1721*da0073e9SAndroid Build Coastguard Worker    def test_pow_scalar_type_promotion(self, device):
1722*da0073e9SAndroid Build Coastguard Worker        # Test against a scalar and non-scalar input
1723*da0073e9SAndroid Build Coastguard Worker        inputs = [17, [17]]
1724*da0073e9SAndroid Build Coastguard Worker        for input in inputs:
1725*da0073e9SAndroid Build Coastguard Worker            # We expect the computation to be performed in uint8 (overflowing to 0), and then cast to int64
1726*da0073e9SAndroid Build Coastguard Worker            input_tensor_uint8 = torch.tensor(input, dtype=torch.uint8, device=device)
1727*da0073e9SAndroid Build Coastguard Worker            out_uint8_computation = torch.pow(
1728*da0073e9SAndroid Build Coastguard Worker                2,
1729*da0073e9SAndroid Build Coastguard Worker                input_tensor_uint8,
1730*da0073e9SAndroid Build Coastguard Worker                out=torch.tensor(0, dtype=torch.int64, device=device),
1731*da0073e9SAndroid Build Coastguard Worker            )
1732*da0073e9SAndroid Build Coastguard Worker
1733*da0073e9SAndroid Build Coastguard Worker            # Computation should run in int64, and not overflow
1734*da0073e9SAndroid Build Coastguard Worker            input_tensor_int64 = torch.tensor(input, dtype=torch.int64, device=device)
1735*da0073e9SAndroid Build Coastguard Worker            out_int64_computation = torch.pow(
1736*da0073e9SAndroid Build Coastguard Worker                2,
1737*da0073e9SAndroid Build Coastguard Worker                input_tensor_int64,
1738*da0073e9SAndroid Build Coastguard Worker                out=torch.tensor(0, dtype=torch.int64, device=device),
1739*da0073e9SAndroid Build Coastguard Worker            )
1740*da0073e9SAndroid Build Coastguard Worker
1741*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(out_uint8_computation, out_int64_computation)
1742*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1743*da0073e9SAndroid Build Coastguard Worker                out_uint8_computation.to(dtype=torch.uint8),
1744*da0073e9SAndroid Build Coastguard Worker                out_int64_computation.to(dtype=torch.uint8),
1745*da0073e9SAndroid Build Coastguard Worker            )
1746*da0073e9SAndroid Build Coastguard Worker
1747*da0073e9SAndroid Build Coastguard Worker    def test_tensor_pow_tensor(self, device):
1748*da0073e9SAndroid Build Coastguard Worker        def rotate(l, n):
1749*da0073e9SAndroid Build Coastguard Worker            return l[-n:] + l[:-n]
1750*da0073e9SAndroid Build Coastguard Worker
1751*da0073e9SAndroid Build Coastguard Worker        def test_tensor_pow_tensor(values, torch_type, numpy_type):
1752*da0073e9SAndroid Build Coastguard Worker            vals_tensor = torch.tensor(values, dtype=torch_type, device=device)
1753*da0073e9SAndroid Build Coastguard Worker            for i in range(len(values)):
1754*da0073e9SAndroid Build Coastguard Worker                pows = rotate(values, i)
1755*da0073e9SAndroid Build Coastguard Worker                pows_tensor = torch.tensor(pows, dtype=torch_type, device=device)
1756*da0073e9SAndroid Build Coastguard Worker                self._test_pow(vals_tensor, pows_tensor)
1757*da0073e9SAndroid Build Coastguard Worker
1758*da0073e9SAndroid Build Coastguard Worker        ints = [0, 1, 2, 3]
1759*da0073e9SAndroid Build Coastguard Worker        test_tensor_pow_tensor(ints, torch.uint8, np.uint8)
1760*da0073e9SAndroid Build Coastguard Worker        test_tensor_pow_tensor(ints, torch.int8, np.int8)
1761*da0073e9SAndroid Build Coastguard Worker        test_tensor_pow_tensor(ints, torch.int16, np.int16)
1762*da0073e9SAndroid Build Coastguard Worker        test_tensor_pow_tensor(ints, torch.int32, np.int32)
1763*da0073e9SAndroid Build Coastguard Worker        test_tensor_pow_tensor(ints, torch.int64, np.int64)
1764*da0073e9SAndroid Build Coastguard Worker
1765*da0073e9SAndroid Build Coastguard Worker        floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 2.0, 3.0]
1766*da0073e9SAndroid Build Coastguard Worker        test_tensor_pow_tensor(floats, torch.float16, np.float16)
1767*da0073e9SAndroid Build Coastguard Worker        test_tensor_pow_tensor(floats, torch.float32, np.float32)
1768*da0073e9SAndroid Build Coastguard Worker        test_tensor_pow_tensor(floats, torch.float64, np.float64)
1769*da0073e9SAndroid Build Coastguard Worker
1770*da0073e9SAndroid Build Coastguard Worker    def test_logical_xor_with_nontrivial_alignment(self, device):
1771*da0073e9SAndroid Build Coastguard Worker        # test tensor that is not aligned to multiple of 16 bytes
1772*da0073e9SAndroid Build Coastguard Worker        size = 128
1773*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(size, device=device) > 0
1774*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(size, device=device) > 0
1775*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(size, device=device) > 0
1776*da0073e9SAndroid Build Coastguard Worker        non_trivial_alignment = [1, 2, 4, 8, 15]
1777*da0073e9SAndroid Build Coastguard Worker        for i in non_trivial_alignment:
1778*da0073e9SAndroid Build Coastguard Worker            for j in non_trivial_alignment:
1779*da0073e9SAndroid Build Coastguard Worker                for k in non_trivial_alignment:
1780*da0073e9SAndroid Build Coastguard Worker                    a_ = a[i : 100 + i]
1781*da0073e9SAndroid Build Coastguard Worker                    b_ = b[j : 100 + j]
1782*da0073e9SAndroid Build Coastguard Worker                    c_ = c[k : 100 + k]
1783*da0073e9SAndroid Build Coastguard Worker                    torch.logical_xor(a_, b_, out=c_)
1784*da0073e9SAndroid Build Coastguard Worker                    for x, y, z in zip(a_.tolist(), b_.tolist(), c_.tolist()):
1785*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(x ^ y, z)
1786*da0073e9SAndroid Build Coastguard Worker
1787*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
1788*da0073e9SAndroid Build Coastguard Worker    def test_add_with_tail(self, device, dtype):
1789*da0073e9SAndroid Build Coastguard Worker        # test tensor where there is a tail which is not a multiple
1790*da0073e9SAndroid Build Coastguard Worker        # of GPU warp size
1791*da0073e9SAndroid Build Coastguard Worker        for tail_size in [1, 63, 67, 130]:
1792*da0073e9SAndroid Build Coastguard Worker            size = 4096 + tail_size
1793*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(size, device=device, dtype=dtype)
1794*da0073e9SAndroid Build Coastguard Worker            b = torch.randn(size, device=device, dtype=dtype)
1795*da0073e9SAndroid Build Coastguard Worker            c = a + b
1796*da0073e9SAndroid Build Coastguard Worker            for x, y, z in zip(a.tolist(), b.tolist(), c.tolist()):
1797*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x + y, z)
1798*da0073e9SAndroid Build Coastguard Worker
1799*da0073e9SAndroid Build Coastguard Worker    # Tests that CUDA tensors on different devices cannot be used in the same
1800*da0073e9SAndroid Build Coastguard Worker    # binary operation, and that CUDA "scalars" cannot be used in the same
1801*da0073e9SAndroid Build Coastguard Worker    # binary operation as non-scalar CPU tensors.
1802*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
1803*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1804*da0073e9SAndroid Build Coastguard Worker    def test_cross_device_binary_ops(self, devices):
1805*da0073e9SAndroid Build Coastguard Worker        vals = (1.0, (2.0,))
1806*da0073e9SAndroid Build Coastguard Worker        cpu_tensor = torch.randn(2, 2)
1807*da0073e9SAndroid Build Coastguard Worker
1808*da0073e9SAndroid Build Coastguard Worker        def do_test(op, a, b):
1809*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
1810*da0073e9SAndroid Build Coastguard Worker                op(a, b)
1811*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
1812*da0073e9SAndroid Build Coastguard Worker                op(b, a)
1813*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
1814*da0073e9SAndroid Build Coastguard Worker                op(a, cpu_tensor)
1815*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
1816*da0073e9SAndroid Build Coastguard Worker                op(cpu_tensor, a)
1817*da0073e9SAndroid Build Coastguard Worker
1818*da0073e9SAndroid Build Coastguard Worker        for op in (
1819*da0073e9SAndroid Build Coastguard Worker            operator.add,
1820*da0073e9SAndroid Build Coastguard Worker            torch.add,
1821*da0073e9SAndroid Build Coastguard Worker            operator.sub,
1822*da0073e9SAndroid Build Coastguard Worker            torch.sub,
1823*da0073e9SAndroid Build Coastguard Worker            operator.mul,
1824*da0073e9SAndroid Build Coastguard Worker            torch.mul,
1825*da0073e9SAndroid Build Coastguard Worker            operator.truediv,
1826*da0073e9SAndroid Build Coastguard Worker            torch.true_divide,
1827*da0073e9SAndroid Build Coastguard Worker            operator.floordiv,
1828*da0073e9SAndroid Build Coastguard Worker            torch.floor_divide,
1829*da0073e9SAndroid Build Coastguard Worker        ):
1830*da0073e9SAndroid Build Coastguard Worker            for a, b in product(vals, vals):
1831*da0073e9SAndroid Build Coastguard Worker                a = torch.tensor(a, device=devices[0])
1832*da0073e9SAndroid Build Coastguard Worker                b = torch.tensor(b, device=devices[1])
1833*da0073e9SAndroid Build Coastguard Worker
1834*da0073e9SAndroid Build Coastguard Worker            do_test(op, a, b)
1835*da0073e9SAndroid Build Coastguard Worker
1836*da0073e9SAndroid Build Coastguard Worker    # This test ensures that a scalar Tensor can be safely used
1837*da0073e9SAndroid Build Coastguard Worker    # in a binary operation in conjunction with a Tensor on all
1838*da0073e9SAndroid Build Coastguard Worker    # available CUDA devices
1839*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
1840*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1841*da0073e9SAndroid Build Coastguard Worker    def test_binary_op_scalar_device_unspecified(self, devices):
1842*da0073e9SAndroid Build Coastguard Worker        scalar_val = torch.tensor(1.0)
1843*da0073e9SAndroid Build Coastguard Worker        for default_device in devices:
1844*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.device(default_device):
1845*da0073e9SAndroid Build Coastguard Worker                for device in devices:
1846*da0073e9SAndroid Build Coastguard Worker                    device_obj = torch.device(device)
1847*da0073e9SAndroid Build Coastguard Worker                    x = torch.rand(3, device=device)
1848*da0073e9SAndroid Build Coastguard Worker                    y0 = x * scalar_val
1849*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y0.device, device_obj)
1850*da0073e9SAndroid Build Coastguard Worker                    y1 = scalar_val * x
1851*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y1.device, device_obj)
1852*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y0, y1)
1853*da0073e9SAndroid Build Coastguard Worker
1854*da0073e9SAndroid Build Coastguard Worker    def test_div_and_floordiv_vs_python(self, device):
1855*da0073e9SAndroid Build Coastguard Worker        # Tests torch division ops which can handle both arguments being
1856*da0073e9SAndroid Build Coastguard Worker        #   scalars.
1857*da0073e9SAndroid Build Coastguard Worker        def _scalar_helper(python_op, torch_op):
1858*da0073e9SAndroid Build Coastguard Worker            for a, b in product(range(-10, 10), range(-10, 10)):
1859*da0073e9SAndroid Build Coastguard Worker                for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
1860*da0073e9SAndroid Build Coastguard Worker                    a = op(a)
1861*da0073e9SAndroid Build Coastguard Worker                    b = op(b)
1862*da0073e9SAndroid Build Coastguard Worker
1863*da0073e9SAndroid Build Coastguard Worker                    # Skips zero divisors
1864*da0073e9SAndroid Build Coastguard Worker                    if b == 0:
1865*da0073e9SAndroid Build Coastguard Worker                        continue
1866*da0073e9SAndroid Build Coastguard Worker
1867*da0073e9SAndroid Build Coastguard Worker                    expected = python_op(a, b)
1868*da0073e9SAndroid Build Coastguard Worker
1869*da0073e9SAndroid Build Coastguard Worker                    for op in (operator.truediv, torch.true_divide):
1870*da0073e9SAndroid Build Coastguard Worker                        actual_scalar = torch_op(a, b)
1871*da0073e9SAndroid Build Coastguard Worker
1872*da0073e9SAndroid Build Coastguard Worker                        a_t = torch.tensor(a, device=device)
1873*da0073e9SAndroid Build Coastguard Worker                        b_t = torch.tensor(b, device=device)
1874*da0073e9SAndroid Build Coastguard Worker
1875*da0073e9SAndroid Build Coastguard Worker                        actual_tensor = torch_op(a_t, b_t)
1876*da0073e9SAndroid Build Coastguard Worker                        actual_first_tensor = torch_op(a_t, b)
1877*da0073e9SAndroid Build Coastguard Worker                        actual_second_tensor = torch_op(a, b_t)
1878*da0073e9SAndroid Build Coastguard Worker
1879*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(actual_scalar, expected)
1880*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(actual_tensor.item(), expected)
1881*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(actual_first_tensor, actual_tensor)
1882*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(actual_second_tensor, actual_tensor)
1883*da0073e9SAndroid Build Coastguard Worker
1884*da0073e9SAndroid Build Coastguard Worker        _scalar_helper(operator.truediv, operator.truediv)
1885*da0073e9SAndroid Build Coastguard Worker        _scalar_helper(operator.truediv, torch.true_divide)
1886*da0073e9SAndroid Build Coastguard Worker        _scalar_helper(lambda a, b: math.floor(a / b), operator.floordiv)
1887*da0073e9SAndroid Build Coastguard Worker        _scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide)
1888*da0073e9SAndroid Build Coastguard Worker
1889*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1890*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
1891*da0073e9SAndroid Build Coastguard Worker    def test_div_and_floordiv_script_vs_python(self, device):
1892*da0073e9SAndroid Build Coastguard Worker        # Creates jitted functions of two tensors
1893*da0073e9SAndroid Build Coastguard Worker        def _wrapped_div(a, b):
1894*da0073e9SAndroid Build Coastguard Worker            return a / b
1895*da0073e9SAndroid Build Coastguard Worker
1896*da0073e9SAndroid Build Coastguard Worker        def _wrapped_floordiv(a, b):
1897*da0073e9SAndroid Build Coastguard Worker            return a // b
1898*da0073e9SAndroid Build Coastguard Worker
1899*da0073e9SAndroid Build Coastguard Worker        scripted_div = torch.jit.script(_wrapped_div)
1900*da0073e9SAndroid Build Coastguard Worker        scripted_floordiv = torch.jit.script(_wrapped_floordiv)
1901*da0073e9SAndroid Build Coastguard Worker        for a, b in product(range(-10, 10), range(-10, 10)):
1902*da0073e9SAndroid Build Coastguard Worker            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
1903*da0073e9SAndroid Build Coastguard Worker                a = op(a)
1904*da0073e9SAndroid Build Coastguard Worker                b = op(b)
1905*da0073e9SAndroid Build Coastguard Worker
1906*da0073e9SAndroid Build Coastguard Worker                # Skips zero divisors
1907*da0073e9SAndroid Build Coastguard Worker                if b == 0:
1908*da0073e9SAndroid Build Coastguard Worker                    continue
1909*da0073e9SAndroid Build Coastguard Worker
1910*da0073e9SAndroid Build Coastguard Worker                expected_div = a / b
1911*da0073e9SAndroid Build Coastguard Worker                expected_floordiv = math.floor(a / b)
1912*da0073e9SAndroid Build Coastguard Worker                a_t = torch.tensor(a, device=device)
1913*da0073e9SAndroid Build Coastguard Worker                b_t = torch.tensor(b, device=device)
1914*da0073e9SAndroid Build Coastguard Worker
1915*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(scripted_div(a_t, b_t), expected_div)
1916*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(scripted_floordiv(a_t, b_t), expected_floordiv)
1917*da0073e9SAndroid Build Coastguard Worker
1918*da0073e9SAndroid Build Coastguard Worker        # Creates jitted functions of one tensor
1919*da0073e9SAndroid Build Coastguard Worker        def _wrapped_div_scalar(a):
1920*da0073e9SAndroid Build Coastguard Worker            return a / 5
1921*da0073e9SAndroid Build Coastguard Worker
1922*da0073e9SAndroid Build Coastguard Worker        # NOTE: the JIT implements division as torch.reciprocal(a) * 5
1923*da0073e9SAndroid Build Coastguard Worker        def _wrapped_rdiv_scalar(a):
1924*da0073e9SAndroid Build Coastguard Worker            return 5 / a
1925*da0073e9SAndroid Build Coastguard Worker
1926*da0073e9SAndroid Build Coastguard Worker        def _wrapped_floordiv_scalar(a):
1927*da0073e9SAndroid Build Coastguard Worker            return a // 5
1928*da0073e9SAndroid Build Coastguard Worker
1929*da0073e9SAndroid Build Coastguard Worker        # NOTE: this fails if the input is not an integer tensor
1930*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/45199
1931*da0073e9SAndroid Build Coastguard Worker        def _wrapped_rfloordiv_scalar(a):
1932*da0073e9SAndroid Build Coastguard Worker            return 5 // a
1933*da0073e9SAndroid Build Coastguard Worker
1934*da0073e9SAndroid Build Coastguard Worker        scripted_div_scalar = torch.jit.script(_wrapped_div_scalar)
1935*da0073e9SAndroid Build Coastguard Worker        scripted_rdiv_scalar = torch.jit.script(_wrapped_rdiv_scalar)
1936*da0073e9SAndroid Build Coastguard Worker        scripted_floordiv_scalar = torch.jit.script(_wrapped_floordiv_scalar)
1937*da0073e9SAndroid Build Coastguard Worker        scripted_rfloordiv_scalar = torch.jit.script(_wrapped_rfloordiv_scalar)
1938*da0073e9SAndroid Build Coastguard Worker
1939*da0073e9SAndroid Build Coastguard Worker        for a in range(-10, 10):
1940*da0073e9SAndroid Build Coastguard Worker            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
1941*da0073e9SAndroid Build Coastguard Worker                a = op(a)
1942*da0073e9SAndroid Build Coastguard Worker
1943*da0073e9SAndroid Build Coastguard Worker                a_t = torch.tensor(a, device=device)
1944*da0073e9SAndroid Build Coastguard Worker
1945*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a / 5, scripted_div_scalar(a_t))
1946*da0073e9SAndroid Build Coastguard Worker
1947*da0073e9SAndroid Build Coastguard Worker                # Skips zero divisors
1948*da0073e9SAndroid Build Coastguard Worker                if a == 0:
1949*da0073e9SAndroid Build Coastguard Worker                    continue
1950*da0073e9SAndroid Build Coastguard Worker
1951*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))
1952*da0073e9SAndroid Build Coastguard Worker
1953*da0073e9SAndroid Build Coastguard Worker                # Handles Issue 45199 (see comment above)
1954*da0073e9SAndroid Build Coastguard Worker                if a_t.is_floating_point():
1955*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(RuntimeError):
1956*da0073e9SAndroid Build Coastguard Worker                        scripted_rfloordiv_scalar(a_t)
1957*da0073e9SAndroid Build Coastguard Worker                else:
1958*da0073e9SAndroid Build Coastguard Worker                    # This should emit a UserWarning, why doesn't it?
1959*da0073e9SAndroid Build Coastguard Worker                    # See issue gh-52387
1960*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1963*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
1964*da0073e9SAndroid Build Coastguard Worker    def test_idiv_and_ifloordiv_vs_python(self, device):
1965*da0073e9SAndroid Build Coastguard Worker        def _wrapped_idiv_tensor(a, b):
1966*da0073e9SAndroid Build Coastguard Worker            a /= b
1967*da0073e9SAndroid Build Coastguard Worker            return a
1968*da0073e9SAndroid Build Coastguard Worker
1969*da0073e9SAndroid Build Coastguard Worker        def _wrapped_idiv_scalar(a):
1970*da0073e9SAndroid Build Coastguard Worker            a /= 5
1971*da0073e9SAndroid Build Coastguard Worker            return a
1972*da0073e9SAndroid Build Coastguard Worker
1973*da0073e9SAndroid Build Coastguard Worker        def _wrapped_true_divide__tensor(a, b):
1974*da0073e9SAndroid Build Coastguard Worker            a.true_divide_(b)
1975*da0073e9SAndroid Build Coastguard Worker            return a
1976*da0073e9SAndroid Build Coastguard Worker
1977*da0073e9SAndroid Build Coastguard Worker        def _wrapped_true_divide__scalar(a):
1978*da0073e9SAndroid Build Coastguard Worker            a.true_divide_(5)
1979*da0073e9SAndroid Build Coastguard Worker            return a
1980*da0073e9SAndroid Build Coastguard Worker
1981*da0073e9SAndroid Build Coastguard Worker        def _wrapped_floor_divide__tensor(a, b):
1982*da0073e9SAndroid Build Coastguard Worker            a.floor_divide_(b)
1983*da0073e9SAndroid Build Coastguard Worker            return a
1984*da0073e9SAndroid Build Coastguard Worker
1985*da0073e9SAndroid Build Coastguard Worker        def _wrapped_floor_divide__scalar(a):
1986*da0073e9SAndroid Build Coastguard Worker            a.floor_divide_(5)
1987*da0073e9SAndroid Build Coastguard Worker            return a
1988*da0073e9SAndroid Build Coastguard Worker
1989*da0073e9SAndroid Build Coastguard Worker        # The following functions are unsupported by the JIT
1990*da0073e9SAndroid Build Coastguard Worker        def _wrapped_ifloordiv_tensor(a, b):
1991*da0073e9SAndroid Build Coastguard Worker            a //= b
1992*da0073e9SAndroid Build Coastguard Worker            return a
1993*da0073e9SAndroid Build Coastguard Worker
1994*da0073e9SAndroid Build Coastguard Worker        def _wrapped_ifloordiv_scalar(a):
1995*da0073e9SAndroid Build Coastguard Worker            a //= 5
1996*da0073e9SAndroid Build Coastguard Worker            return a
1997*da0073e9SAndroid Build Coastguard Worker
1998*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(torch.jit.frontend.NotSupportedError):
1999*da0073e9SAndroid Build Coastguard Worker            scripted_ifloordiv_tensor = torch.jit.script(_wrapped_ifloordiv_tensor)
2000*da0073e9SAndroid Build Coastguard Worker
2001*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(torch.jit.frontend.NotSupportedError):
2002*da0073e9SAndroid Build Coastguard Worker            scripted_ifloordiv_scalar = torch.jit.script(_wrapped_ifloordiv_scalar)
2003*da0073e9SAndroid Build Coastguard Worker
2004*da0073e9SAndroid Build Coastguard Worker        scripted_idiv_tensor = torch.jit.script(_wrapped_idiv_tensor)
2005*da0073e9SAndroid Build Coastguard Worker        scripted_idiv_scalar = torch.jit.script(_wrapped_idiv_scalar)
2006*da0073e9SAndroid Build Coastguard Worker        scripted_true_divide__tensor = torch.jit.script(_wrapped_true_divide__tensor)
2007*da0073e9SAndroid Build Coastguard Worker        scripted_true_divide__scalar = torch.jit.script(_wrapped_true_divide__scalar)
2008*da0073e9SAndroid Build Coastguard Worker        scripted_floor_divide__tensor = torch.jit.script(_wrapped_floor_divide__tensor)
2009*da0073e9SAndroid Build Coastguard Worker        scripted_floor_divide__scalar = torch.jit.script(_wrapped_floor_divide__scalar)
2010*da0073e9SAndroid Build Coastguard Worker
2011*da0073e9SAndroid Build Coastguard Worker        for a, b in product(range(-10, 10), range(-10, 10)):
2012*da0073e9SAndroid Build Coastguard Worker            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
2013*da0073e9SAndroid Build Coastguard Worker                a = op(a)
2014*da0073e9SAndroid Build Coastguard Worker                b = op(b)
2015*da0073e9SAndroid Build Coastguard Worker
2016*da0073e9SAndroid Build Coastguard Worker                # Skips zero divisors
2017*da0073e9SAndroid Build Coastguard Worker                if b == 0:
2018*da0073e9SAndroid Build Coastguard Worker                    continue
2019*da0073e9SAndroid Build Coastguard Worker
2020*da0073e9SAndroid Build Coastguard Worker                expected_idiv = a / b
2021*da0073e9SAndroid Build Coastguard Worker                expected_ifloordiv = a // b
2022*da0073e9SAndroid Build Coastguard Worker
2023*da0073e9SAndroid Build Coastguard Worker                a_t = torch.tensor(a, device=device)
2024*da0073e9SAndroid Build Coastguard Worker                b_t = torch.tensor(b, device=device)
2025*da0073e9SAndroid Build Coastguard Worker
2026*da0073e9SAndroid Build Coastguard Worker                if a_t.is_floating_point():
2027*da0073e9SAndroid Build Coastguard Worker                    tmp0 = a_t.clone()
2028*da0073e9SAndroid Build Coastguard Worker                    tmp0 /= b
2029*da0073e9SAndroid Build Coastguard Worker
2030*da0073e9SAndroid Build Coastguard Worker                    tmp1 = a_t.clone()
2031*da0073e9SAndroid Build Coastguard Worker                    tmp1 /= b_t
2032*da0073e9SAndroid Build Coastguard Worker
2033*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(tmp0.item(), expected_idiv)
2034*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(tmp1.item(), expected_idiv)
2035*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
2036*da0073e9SAndroid Build Coastguard Worker                        scripted_true_divide__tensor(a_t.clone(), b_t).item(),
2037*da0073e9SAndroid Build Coastguard Worker                        expected_idiv,
2038*da0073e9SAndroid Build Coastguard Worker                    )
2039*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
2040*da0073e9SAndroid Build Coastguard Worker                        scripted_true_divide__scalar(a_t.clone()).item(), a / 5
2041*da0073e9SAndroid Build Coastguard Worker                    )
2042*da0073e9SAndroid Build Coastguard Worker                else:
2043*da0073e9SAndroid Build Coastguard Worker                    tmp = a_t.clone()
2044*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(RuntimeError):
2045*da0073e9SAndroid Build Coastguard Worker                        tmp /= b
2046*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(RuntimeError):
2047*da0073e9SAndroid Build Coastguard Worker                        tmp /= b_t
2048*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(RuntimeError):
2049*da0073e9SAndroid Build Coastguard Worker                        scripted_true_divide__tensor(tmp, b_t)
2050*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(RuntimeError):
2051*da0073e9SAndroid Build Coastguard Worker                        scripted_true_divide__scalar(tmp)
2052*da0073e9SAndroid Build Coastguard Worker
2053*da0073e9SAndroid Build Coastguard Worker                if not a_t.is_floating_point() and b_t.is_floating_point():
2054*da0073e9SAndroid Build Coastguard Worker                    # Inplace modification fails because a float tensor is required
2055*da0073e9SAndroid Build Coastguard Worker                    #   if the divisor is a float tensor
2056*da0073e9SAndroid Build Coastguard Worker                    a_t.clone().floor_divide_(b_t)
2057*da0073e9SAndroid Build Coastguard Worker                    scripted_floor_divide__tensor(a_t.clone(), b_t)
2058*da0073e9SAndroid Build Coastguard Worker                    tmp = a_t.clone()
2059*da0073e9SAndroid Build Coastguard Worker                    tmp //= b_t
2060*da0073e9SAndroid Build Coastguard Worker                else:
2061*da0073e9SAndroid Build Coastguard Worker                    # Inplace modification is OK when both or neither tensor is
2062*da0073e9SAndroid Build Coastguard Worker                    #   a float tensor
2063*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
2064*da0073e9SAndroid Build Coastguard Worker                        a_t.clone().floor_divide_(b_t).item(), expected_ifloordiv
2065*da0073e9SAndroid Build Coastguard Worker                    )
2066*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
2067*da0073e9SAndroid Build Coastguard Worker                        scripted_floor_divide__tensor(a_t.clone(), b_t).item(),
2068*da0073e9SAndroid Build Coastguard Worker                        expected_ifloordiv,
2069*da0073e9SAndroid Build Coastguard Worker                    )
2070*da0073e9SAndroid Build Coastguard Worker                    tmp = a_t.clone()
2071*da0073e9SAndroid Build Coastguard Worker                    tmp //= b_t
2072*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(tmp.item(), expected_ifloordiv)
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(scripted_floor_divide__scalar(a_t), math.floor(a / 5))
2075*da0073e9SAndroid Build Coastguard Worker
2076*da0073e9SAndroid Build Coastguard Worker    # Tests binary op equivalence with Python builtin ops
2077*da0073e9SAndroid Build Coastguard Worker    # Also tests that reverse operations are equivalent to forward ops
2078*da0073e9SAndroid Build Coastguard Worker    # NOTE: division ops are tested separately above
2079*da0073e9SAndroid Build Coastguard Worker    def test_binary_ops_with_scalars(self, device):
2080*da0073e9SAndroid Build Coastguard Worker        for python_op, torch_op in (
2081*da0073e9SAndroid Build Coastguard Worker            (operator.add, torch.add),
2082*da0073e9SAndroid Build Coastguard Worker            (operator.sub, torch.sub),
2083*da0073e9SAndroid Build Coastguard Worker            (operator.mul, torch.mul),
2084*da0073e9SAndroid Build Coastguard Worker            (operator.truediv, torch.div),
2085*da0073e9SAndroid Build Coastguard Worker        ):
2086*da0073e9SAndroid Build Coastguard Worker            for a, b in product(range(-10, 10), range(-10, 10)):
2087*da0073e9SAndroid Build Coastguard Worker                for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
2088*da0073e9SAndroid Build Coastguard Worker                    a = op(a)
2089*da0073e9SAndroid Build Coastguard Worker                    b = op(b)
2090*da0073e9SAndroid Build Coastguard Worker
2091*da0073e9SAndroid Build Coastguard Worker                    # Skips zero divisors
2092*da0073e9SAndroid Build Coastguard Worker                    if b == 0 or a == 0:
2093*da0073e9SAndroid Build Coastguard Worker                        continue
2094*da0073e9SAndroid Build Coastguard Worker
2095*da0073e9SAndroid Build Coastguard Worker                    a_tensor = torch.tensor(a, device=device)
2096*da0073e9SAndroid Build Coastguard Worker                    b_tensor = torch.tensor(b, device=device)
2097*da0073e9SAndroid Build Coastguard Worker                    a_tensor_cpu = a_tensor.cpu()
2098*da0073e9SAndroid Build Coastguard Worker                    b_tensor_cpu = b_tensor.cpu()
2099*da0073e9SAndroid Build Coastguard Worker                    vals = (a, b, a_tensor, b_tensor, a_tensor_cpu, b_tensor_cpu)
2100*da0073e9SAndroid Build Coastguard Worker
2101*da0073e9SAndroid Build Coastguard Worker                    for args in product(vals, vals):
2102*da0073e9SAndroid Build Coastguard Worker                        first, second = args
2103*da0073e9SAndroid Build Coastguard Worker
2104*da0073e9SAndroid Build Coastguard Worker                        first_scalar = (
2105*da0073e9SAndroid Build Coastguard Worker                            first
2106*da0073e9SAndroid Build Coastguard Worker                            if not isinstance(first, torch.Tensor)
2107*da0073e9SAndroid Build Coastguard Worker                            else first.item()
2108*da0073e9SAndroid Build Coastguard Worker                        )
2109*da0073e9SAndroid Build Coastguard Worker                        second_scalar = (
2110*da0073e9SAndroid Build Coastguard Worker                            second
2111*da0073e9SAndroid Build Coastguard Worker                            if not isinstance(second, torch.Tensor)
2112*da0073e9SAndroid Build Coastguard Worker                            else second.item()
2113*da0073e9SAndroid Build Coastguard Worker                        )
2114*da0073e9SAndroid Build Coastguard Worker                        expected = python_op(first_scalar, second_scalar)
2115*da0073e9SAndroid Build Coastguard Worker
2116*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(expected, python_op(first, second))
2117*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(expected, torch_op(first, second))
2118*da0073e9SAndroid Build Coastguard Worker
2119*da0073e9SAndroid Build Coastguard Worker    @dtypes(
2120*da0073e9SAndroid Build Coastguard Worker        *product(
2121*da0073e9SAndroid Build Coastguard Worker            all_types_and(torch.half, torch.bfloat16, torch.bool),
2122*da0073e9SAndroid Build Coastguard Worker            all_types_and(torch.half, torch.bfloat16, torch.bool),
2123*da0073e9SAndroid Build Coastguard Worker        )
2124*da0073e9SAndroid Build Coastguard Worker    )
2125*da0073e9SAndroid Build Coastguard Worker    def test_maximum_minimum_type_promotion(self, device, dtypes):
2126*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor((0, 1), device=device, dtype=dtypes[0])
2127*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor((1, 0), device=device, dtype=dtypes[1])
2128*da0073e9SAndroid Build Coastguard Worker        for op in (
2129*da0073e9SAndroid Build Coastguard Worker            torch.maximum,
2130*da0073e9SAndroid Build Coastguard Worker            torch.max,
2131*da0073e9SAndroid Build Coastguard Worker            torch.fmax,
2132*da0073e9SAndroid Build Coastguard Worker            torch.minimum,
2133*da0073e9SAndroid Build Coastguard Worker            torch.min,
2134*da0073e9SAndroid Build Coastguard Worker            torch.fmin,
2135*da0073e9SAndroid Build Coastguard Worker        ):
2136*da0073e9SAndroid Build Coastguard Worker            result = op(a, b)
2137*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, torch.result_type(a, b))
2138*da0073e9SAndroid Build Coastguard Worker
2139*da0073e9SAndroid Build Coastguard Worker    @dtypes(*integral_types_and(torch.bool))
2140*da0073e9SAndroid Build Coastguard Worker    def test_maximum_minimum_int_and_bool(self, device, dtype):
2141*da0073e9SAndroid Build Coastguard Worker        ops = (
2142*da0073e9SAndroid Build Coastguard Worker            (torch.maximum, torch.max, np.maximum),
2143*da0073e9SAndroid Build Coastguard Worker            (torch.minimum, torch.min, np.minimum),
2144*da0073e9SAndroid Build Coastguard Worker            (torch.fmax, None, np.fmax),
2145*da0073e9SAndroid Build Coastguard Worker            (torch.fmin, None, np.fmin),
2146*da0073e9SAndroid Build Coastguard Worker        )
2147*da0073e9SAndroid Build Coastguard Worker        rng = np.random.default_rng()
2148*da0073e9SAndroid Build Coastguard Worker        a_np = np.array(
2149*da0073e9SAndroid Build Coastguard Worker            rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]
2150*da0073e9SAndroid Build Coastguard Worker        )
2151*da0073e9SAndroid Build Coastguard Worker        b_np = np.array(
2152*da0073e9SAndroid Build Coastguard Worker            rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]
2153*da0073e9SAndroid Build Coastguard Worker        )
2154*da0073e9SAndroid Build Coastguard Worker
2155*da0073e9SAndroid Build Coastguard Worker        for torch_op, alias, numpy_op in ops:
2156*da0073e9SAndroid Build Coastguard Worker            a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
2157*da0073e9SAndroid Build Coastguard Worker            b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
2158*da0073e9SAndroid Build Coastguard Worker            tensor_result = torch_op(a_tensor, b_tensor)
2159*da0073e9SAndroid Build Coastguard Worker
2160*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(a_tensor)
2161*da0073e9SAndroid Build Coastguard Worker            torch_op(a_tensor, b_tensor, out=out)
2162*da0073e9SAndroid Build Coastguard Worker
2163*da0073e9SAndroid Build Coastguard Worker            numpy_result = numpy_op(a_np, b_np)
2164*da0073e9SAndroid Build Coastguard Worker
2165*da0073e9SAndroid Build Coastguard Worker            if alias is not None:
2166*da0073e9SAndroid Build Coastguard Worker                alias_result = alias(a_tensor, b_tensor)
2167*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(alias_result, tensor_result)
2168*da0073e9SAndroid Build Coastguard Worker
2169*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_result, numpy_result)
2170*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, numpy_result)
2171*da0073e9SAndroid Build Coastguard Worker
2172*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.bfloat16: 1e-2})
2173*da0073e9SAndroid Build Coastguard Worker    @dtypes(*(floating_types_and(torch.half, torch.bfloat16)))
2174*da0073e9SAndroid Build Coastguard Worker    def test_maximum_minimum_float(self, device, dtype):
2175*da0073e9SAndroid Build Coastguard Worker        ops = (
2176*da0073e9SAndroid Build Coastguard Worker            (torch.maximum, torch.max, np.maximum),
2177*da0073e9SAndroid Build Coastguard Worker            (torch.minimum, torch.min, np.minimum),
2178*da0073e9SAndroid Build Coastguard Worker            (torch.fmax, None, np.fmax),
2179*da0073e9SAndroid Build Coastguard Worker            (torch.fmin, None, np.fmin),
2180*da0073e9SAndroid Build Coastguard Worker        )
2181*da0073e9SAndroid Build Coastguard Worker
2182*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bfloat16:
2183*da0073e9SAndroid Build Coastguard Worker            a_np = np.random.randn(10).astype(np.float64)
2184*da0073e9SAndroid Build Coastguard Worker            b_np = np.random.randn(10).astype(np.float64)
2185*da0073e9SAndroid Build Coastguard Worker        else:
2186*da0073e9SAndroid Build Coastguard Worker            a_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype])
2187*da0073e9SAndroid Build Coastguard Worker            b_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype])
2188*da0073e9SAndroid Build Coastguard Worker
2189*da0073e9SAndroid Build Coastguard Worker        for torch_op, alias, numpy_op in ops:
2190*da0073e9SAndroid Build Coastguard Worker            numpy_result = numpy_op(a_np, b_np)
2191*da0073e9SAndroid Build Coastguard Worker
2192*da0073e9SAndroid Build Coastguard Worker            a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
2193*da0073e9SAndroid Build Coastguard Worker            b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
2194*da0073e9SAndroid Build Coastguard Worker            tensor_result = torch_op(a_tensor, b_tensor)
2195*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(a_tensor)
2196*da0073e9SAndroid Build Coastguard Worker            torch_op(a_tensor, b_tensor, out=out)
2197*da0073e9SAndroid Build Coastguard Worker
2198*da0073e9SAndroid Build Coastguard Worker            if alias is not None:
2199*da0073e9SAndroid Build Coastguard Worker                alias_result = alias(a_tensor, b_tensor)
2200*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(alias_result, tensor_result, exact_dtype=False)
2201*da0073e9SAndroid Build Coastguard Worker
2202*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_result, numpy_result, exact_dtype=False)
2203*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, numpy_result, exact_dtype=False)
2204*da0073e9SAndroid Build Coastguard Worker
2205*da0073e9SAndroid Build Coastguard Worker    @dtypes(*(floating_types_and(torch.half, torch.bfloat16)))
2206*da0073e9SAndroid Build Coastguard Worker    def test_maximum_minimum_float_nan_and_inf(self, device, dtype):
2207*da0073e9SAndroid Build Coastguard Worker        # np.maximum and np.minimum functions compare input arrays element-wisely.
2208*da0073e9SAndroid Build Coastguard Worker        # if one of the elements being compared is a NaN, then that element is returned.
2209*da0073e9SAndroid Build Coastguard Worker        ops = (
2210*da0073e9SAndroid Build Coastguard Worker            (torch.maximum, torch.max, np.maximum),
2211*da0073e9SAndroid Build Coastguard Worker            (torch.minimum, torch.min, np.minimum),
2212*da0073e9SAndroid Build Coastguard Worker            (torch.fmax, None, np.fmax),
2213*da0073e9SAndroid Build Coastguard Worker            (torch.fmin, None, np.fmin),
2214*da0073e9SAndroid Build Coastguard Worker        )
2215*da0073e9SAndroid Build Coastguard Worker        a_vals = (
2216*da0073e9SAndroid Build Coastguard Worker            float("inf"),
2217*da0073e9SAndroid Build Coastguard Worker            -float("inf"),
2218*da0073e9SAndroid Build Coastguard Worker            float("nan"),
2219*da0073e9SAndroid Build Coastguard Worker            float("inf"),
2220*da0073e9SAndroid Build Coastguard Worker            float("nan"),
2221*da0073e9SAndroid Build Coastguard Worker            float("nan"),
2222*da0073e9SAndroid Build Coastguard Worker            1,
2223*da0073e9SAndroid Build Coastguard Worker            float("nan"),
2224*da0073e9SAndroid Build Coastguard Worker        )
2225*da0073e9SAndroid Build Coastguard Worker        b_vals = (
2226*da0073e9SAndroid Build Coastguard Worker            -float("inf"),
2227*da0073e9SAndroid Build Coastguard Worker            float("inf"),
2228*da0073e9SAndroid Build Coastguard Worker            float("inf"),
2229*da0073e9SAndroid Build Coastguard Worker            float("nan"),
2230*da0073e9SAndroid Build Coastguard Worker            float("nan"),
2231*da0073e9SAndroid Build Coastguard Worker            0,
2232*da0073e9SAndroid Build Coastguard Worker            float("nan"),
2233*da0073e9SAndroid Build Coastguard Worker            -5,
2234*da0073e9SAndroid Build Coastguard Worker        )
2235*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bfloat16:
2236*da0073e9SAndroid Build Coastguard Worker            a_np = np.array(a_vals, dtype=np.float64)
2237*da0073e9SAndroid Build Coastguard Worker            b_np = np.array(b_vals, dtype=np.float64)
2238*da0073e9SAndroid Build Coastguard Worker        else:
2239*da0073e9SAndroid Build Coastguard Worker            a_np = np.array(a_vals, dtype=torch_to_numpy_dtype_dict[dtype])
2240*da0073e9SAndroid Build Coastguard Worker            b_np = np.array(b_vals, dtype=torch_to_numpy_dtype_dict[dtype])
2241*da0073e9SAndroid Build Coastguard Worker
2242*da0073e9SAndroid Build Coastguard Worker        for torch_op, alias, numpy_op in ops:
2243*da0073e9SAndroid Build Coastguard Worker            numpy_result = numpy_op(a_np, b_np)
2244*da0073e9SAndroid Build Coastguard Worker
2245*da0073e9SAndroid Build Coastguard Worker            a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
2246*da0073e9SAndroid Build Coastguard Worker            b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
2247*da0073e9SAndroid Build Coastguard Worker            tensor_result = torch_op(a_tensor, b_tensor)
2248*da0073e9SAndroid Build Coastguard Worker
2249*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(a_tensor)
2250*da0073e9SAndroid Build Coastguard Worker            torch_op(a_tensor, b_tensor, out=out)
2251*da0073e9SAndroid Build Coastguard Worker
2252*da0073e9SAndroid Build Coastguard Worker            if alias is not None:
2253*da0073e9SAndroid Build Coastguard Worker                alias_result = alias(a_tensor, b_tensor)
2254*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(alias_result, tensor_result)
2255*da0073e9SAndroid Build Coastguard Worker
2256*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16:
2257*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tensor_result, numpy_result, exact_dtype=False)
2258*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, numpy_result, exact_dtype=False)
2259*da0073e9SAndroid Build Coastguard Worker            else:
2260*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tensor_result, numpy_result)
2261*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, numpy_result)
2262*da0073e9SAndroid Build Coastguard Worker
2263*da0073e9SAndroid Build Coastguard Worker    @dtypes(
2264*da0073e9SAndroid Build Coastguard Worker        *product(
2265*da0073e9SAndroid Build Coastguard Worker            complex_types(),
2266*da0073e9SAndroid Build Coastguard Worker            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
2267*da0073e9SAndroid Build Coastguard Worker        )
2268*da0073e9SAndroid Build Coastguard Worker    )
2269*da0073e9SAndroid Build Coastguard Worker    def test_maximum_minimum_complex(self, device, dtypes):
2270*da0073e9SAndroid Build Coastguard Worker        for torch_op in (
2271*da0073e9SAndroid Build Coastguard Worker            torch.maximum,
2272*da0073e9SAndroid Build Coastguard Worker            torch.minimum,
2273*da0073e9SAndroid Build Coastguard Worker            torch.max,
2274*da0073e9SAndroid Build Coastguard Worker            torch.min,
2275*da0073e9SAndroid Build Coastguard Worker            torch.fmax,
2276*da0073e9SAndroid Build Coastguard Worker            torch.fmin,
2277*da0073e9SAndroid Build Coastguard Worker        ):
2278*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, ".+not implemented for.+"):
2279*da0073e9SAndroid Build Coastguard Worker                torch_op(
2280*da0073e9SAndroid Build Coastguard Worker                    torch.ones(1, device=device, dtype=dtypes[0]),
2281*da0073e9SAndroid Build Coastguard Worker                    torch.ones(1, device=device, dtype=dtypes[1]),
2282*da0073e9SAndroid Build Coastguard Worker                )
2283*da0073e9SAndroid Build Coastguard Worker
2284*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, ".+not implemented for.+"):
2285*da0073e9SAndroid Build Coastguard Worker                torch_op(
2286*da0073e9SAndroid Build Coastguard Worker                    torch.ones(1, device=device, dtype=dtypes[1]),
2287*da0073e9SAndroid Build Coastguard Worker                    torch.ones(1, device=device, dtype=dtypes[0]),
2288*da0073e9SAndroid Build Coastguard Worker                )
2289*da0073e9SAndroid Build Coastguard Worker
2290*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
2291*da0073e9SAndroid Build Coastguard Worker    def test_maximum_minimum_cross_device(self, device):
2292*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor((1, 2, -1))
2293*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor((3, 0, 4), device=device)
2294*da0073e9SAndroid Build Coastguard Worker        ops = (torch.maximum, torch.minimum)
2295*da0073e9SAndroid Build Coastguard Worker
2296*da0073e9SAndroid Build Coastguard Worker        for torch_op in ops:
2297*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
2298*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "Expected all tensors to be on the same device"
2299*da0073e9SAndroid Build Coastguard Worker            ):
2300*da0073e9SAndroid Build Coastguard Worker                torch_op(a, b)
2301*da0073e9SAndroid Build Coastguard Worker
2302*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
2303*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "Expected all tensors to be on the same device"
2304*da0073e9SAndroid Build Coastguard Worker            ):
2305*da0073e9SAndroid Build Coastguard Worker                torch_op(b, a)
2306*da0073e9SAndroid Build Coastguard Worker
2307*da0073e9SAndroid Build Coastguard Worker        # test cuda tensor and cpu scalar
2308*da0073e9SAndroid Build Coastguard Worker        ops = ((torch.maximum, np.maximum), (torch.minimum, np.minimum))
2309*da0073e9SAndroid Build Coastguard Worker        a_np = np.array(1)
2310*da0073e9SAndroid Build Coastguard Worker        b_np = np.array([3, 0, 4])
2311*da0073e9SAndroid Build Coastguard Worker
2312*da0073e9SAndroid Build Coastguard Worker        for torch_op, numpy_op in ops:
2313*da0073e9SAndroid Build Coastguard Worker            a_tensor = torch.from_numpy(a_np)
2314*da0073e9SAndroid Build Coastguard Worker            b_tensor = torch.from_numpy(b_np).to(device=device)
2315*da0073e9SAndroid Build Coastguard Worker            tensor_result_1 = torch_op(a_tensor, b_tensor)
2316*da0073e9SAndroid Build Coastguard Worker            numpy_result_1 = numpy_op(a_np, b_np)
2317*da0073e9SAndroid Build Coastguard Worker            tensor_result_2 = torch_op(b_tensor, a_tensor)
2318*da0073e9SAndroid Build Coastguard Worker            numpy_result_2 = numpy_op(b_np, a_np)
2319*da0073e9SAndroid Build Coastguard Worker
2320*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_result_1, numpy_result_1)
2321*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_result_2, numpy_result_2)
2322*da0073e9SAndroid Build Coastguard Worker
2323*da0073e9SAndroid Build Coastguard Worker    @dtypes(
2324*da0073e9SAndroid Build Coastguard Worker        *product(
2325*da0073e9SAndroid Build Coastguard Worker            floating_types_and(torch.half, torch.bfloat16),
2326*da0073e9SAndroid Build Coastguard Worker            floating_types_and(torch.half, torch.bfloat16),
2327*da0073e9SAndroid Build Coastguard Worker        )
2328*da0073e9SAndroid Build Coastguard Worker    )
2329*da0073e9SAndroid Build Coastguard Worker    def test_maximum_and_minimum_subgradient(self, device, dtypes):
2330*da0073e9SAndroid Build Coastguard Worker        def run_test(f, a, b, expected_a_grad, expected_b_grad):
2331*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(a, requires_grad=True, device=device, dtype=dtypes[0])
2332*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor(b, requires_grad=True, device=device, dtype=dtypes[1])
2333*da0073e9SAndroid Build Coastguard Worker            z = f(a, b)
2334*da0073e9SAndroid Build Coastguard Worker            z.sum().backward()
2335*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.grad, expected_a_grad)
2336*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b.grad, expected_b_grad)
2337*da0073e9SAndroid Build Coastguard Worker
2338*da0073e9SAndroid Build Coastguard Worker        run_test(
2339*da0073e9SAndroid Build Coastguard Worker            torch.maximum,
2340*da0073e9SAndroid Build Coastguard Worker            [0.0, 1.0, 2.0],
2341*da0073e9SAndroid Build Coastguard Worker            [1.0, 1.0, 1.0],
2342*da0073e9SAndroid Build Coastguard Worker            [0.0, 0.5, 1.0],
2343*da0073e9SAndroid Build Coastguard Worker            [1.0, 0.5, 0.0],
2344*da0073e9SAndroid Build Coastguard Worker        )
2345*da0073e9SAndroid Build Coastguard Worker        run_test(
2346*da0073e9SAndroid Build Coastguard Worker            torch.minimum,
2347*da0073e9SAndroid Build Coastguard Worker            [0.0, 1.0, 2.0],
2348*da0073e9SAndroid Build Coastguard Worker            [1.0, 1.0, 1.0],
2349*da0073e9SAndroid Build Coastguard Worker            [1.0, 0.5, 0.0],
2350*da0073e9SAndroid Build Coastguard Worker            [0.0, 0.5, 1.0],
2351*da0073e9SAndroid Build Coastguard Worker        )
2352*da0073e9SAndroid Build Coastguard Worker
2353*da0073e9SAndroid Build Coastguard Worker    def test_maximum_minimum_forward_ad_float32(self, device):
2354*da0073e9SAndroid Build Coastguard Worker        # TODO: This should really be covered by OpInfo but it isn't. The problem
2355*da0073e9SAndroid Build Coastguard Worker        # is that our gradient tests test using float64 but it should also test
2356*da0073e9SAndroid Build Coastguard Worker        # float32
2357*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, device=device, dtype=torch.float32)
2358*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3, device=device, dtype=torch.float32)
2359*da0073e9SAndroid Build Coastguard Worker        tx = torch.randn(3, device=device, dtype=torch.float32)
2360*da0073e9SAndroid Build Coastguard Worker        ty = torch.randn(3, device=device, dtype=torch.float32)
2361*da0073e9SAndroid Build Coastguard Worker
2362*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
2363*da0073e9SAndroid Build Coastguard Worker            x_dual = fwAD.make_dual(x, tx)
2364*da0073e9SAndroid Build Coastguard Worker            y_dual = fwAD.make_dual(y, ty)
2365*da0073e9SAndroid Build Coastguard Worker            result = torch.maximum(x_dual, y_dual)
2366*da0073e9SAndroid Build Coastguard Worker            _, result_tangent = fwAD.unpack_dual(result)
2367*da0073e9SAndroid Build Coastguard Worker
2368*da0073e9SAndroid Build Coastguard Worker        expected = torch.where(x > y, tx, ty)
2369*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result_tangent, expected)
2370*da0073e9SAndroid Build Coastguard Worker
2371*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
2372*da0073e9SAndroid Build Coastguard Worker            x_dual = fwAD.make_dual(x, tx)
2373*da0073e9SAndroid Build Coastguard Worker            y_dual = fwAD.make_dual(y, ty)
2374*da0073e9SAndroid Build Coastguard Worker            result = torch.minimum(x_dual, y_dual)
2375*da0073e9SAndroid Build Coastguard Worker            _, result_tangent = fwAD.unpack_dual(result)
2376*da0073e9SAndroid Build Coastguard Worker
2377*da0073e9SAndroid Build Coastguard Worker        expected = torch.where(x < y, tx, ty)
2378*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result_tangent, expected)
2379*da0073e9SAndroid Build Coastguard Worker
2380*da0073e9SAndroid Build Coastguard Worker    # TODO: tests like this should be generic
2381*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.double)
2382*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
2383*da0073e9SAndroid Build Coastguard Worker    def test_mul_intertype_scalar(self, device, dtype):
2384*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(1.5, dtype=dtype, device=device)
2385*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(3, dtype=torch.int32, device=device)
2386*da0073e9SAndroid Build Coastguard Worker
2387*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x * y, 4.5)
2388*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y * x, 4.5)
2389*da0073e9SAndroid Build Coastguard Worker
2390*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2391*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "can't be cast to the desired output type"
2392*da0073e9SAndroid Build Coastguard Worker        ):
2393*da0073e9SAndroid Build Coastguard Worker            y *= x
2394*da0073e9SAndroid Build Coastguard Worker        x *= y
2395*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, 4.5)
2396*da0073e9SAndroid Build Coastguard Worker
2397*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
2398*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
2399*da0073e9SAndroid Build Coastguard Worker    def test_sub(self, device, dtype):
2400*da0073e9SAndroid Build Coastguard Worker        if dtype in integral_types():
2401*da0073e9SAndroid Build Coastguard Worker            # Before Python 3.10, floats were implicitly converted to ints, but with
2402*da0073e9SAndroid Build Coastguard Worker            #   DeprecationWarning: an integer is required (got type float).
2403*da0073e9SAndroid Build Coastguard Worker            #   Implicit conversion to integers using __int__ is deprecated,
2404*da0073e9SAndroid Build Coastguard Worker            #   and may be removed in a future version of Python.
2405*da0073e9SAndroid Build Coastguard Worker            # Since Python 3.10, that attempt gives an error.
2406*da0073e9SAndroid Build Coastguard Worker            m1 = torch.tensor([2, 4], dtype=dtype, device=device)
2407*da0073e9SAndroid Build Coastguard Worker            m2 = torch.tensor([1, 2], dtype=dtype, device=device)
2408*da0073e9SAndroid Build Coastguard Worker            diff = torch.tensor([1, 2], dtype=dtype)
2409*da0073e9SAndroid Build Coastguard Worker        else:
2410*da0073e9SAndroid Build Coastguard Worker            m1 = torch.tensor([2.34, 4.44], dtype=dtype, device=device)
2411*da0073e9SAndroid Build Coastguard Worker            m2 = torch.tensor([1.23, 2.33], dtype=dtype, device=device)
2412*da0073e9SAndroid Build Coastguard Worker            diff = torch.tensor([1.11, 2.11], dtype=dtype)
2413*da0073e9SAndroid Build Coastguard Worker
2414*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bool:
2415*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: m1 - m2)
2416*da0073e9SAndroid Build Coastguard Worker        elif dtype == torch.bfloat16 or dtype == torch.half:
2417*da0073e9SAndroid Build Coastguard Worker            # bfloat16 has a lower precision so we have to have a separate check for it
2418*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m1 - m2, diff, atol=0.01, rtol=0)
2419*da0073e9SAndroid Build Coastguard Worker        else:
2420*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m1 - m2, diff)
2421*da0073e9SAndroid Build Coastguard Worker
2422*da0073e9SAndroid Build Coastguard Worker    # TODO: what is this test testing?
2423*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
2424*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
2425*da0073e9SAndroid Build Coastguard Worker    def test_csub(self, device, dtype):
2426*da0073e9SAndroid Build Coastguard Worker        # with a tensor
2427*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(100, 90, dtype=dtype, device=device)
2428*da0073e9SAndroid Build Coastguard Worker        b = a.clone().normal_()
2429*da0073e9SAndroid Build Coastguard Worker
2430*da0073e9SAndroid Build Coastguard Worker        res_add = torch.add(a, b, alpha=-1)
2431*da0073e9SAndroid Build Coastguard Worker        res_csub = a.clone()
2432*da0073e9SAndroid Build Coastguard Worker        res_csub.sub_(b)
2433*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_add, res_csub)
2434*da0073e9SAndroid Build Coastguard Worker
2435*da0073e9SAndroid Build Coastguard Worker        # with a scalar
2436*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(100, 100, dtype=dtype, device=device)
2437*da0073e9SAndroid Build Coastguard Worker
2438*da0073e9SAndroid Build Coastguard Worker        scalar = 123.5
2439*da0073e9SAndroid Build Coastguard Worker        res_add = torch.add(a, -scalar)
2440*da0073e9SAndroid Build Coastguard Worker        res_csub = a.clone()
2441*da0073e9SAndroid Build Coastguard Worker        res_csub.sub_(scalar)
2442*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_add, res_csub)
2443*da0073e9SAndroid Build Coastguard Worker
2444*da0073e9SAndroid Build Coastguard Worker    # TODO: reconcile with minimum/maximum tests
2445*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.double)
2446*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
2447*da0073e9SAndroid Build Coastguard Worker    def test_min_max_binary_op_nan(self, device, dtype):
2448*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1000, dtype=dtype, device=device)
2449*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1000, dtype=dtype, device=device)
2450*da0073e9SAndroid Build Coastguard Worker
2451*da0073e9SAndroid Build Coastguard Worker        # 0:250: a -- nan, b -- not nan
2452*da0073e9SAndroid Build Coastguard Worker        a[:250] = float("nan")
2453*da0073e9SAndroid Build Coastguard Worker        # 250:500: a -- not nan, b -- nan
2454*da0073e9SAndroid Build Coastguard Worker        b[250:500] = float("nan")
2455*da0073e9SAndroid Build Coastguard Worker        # 500:750: a and b both nan
2456*da0073e9SAndroid Build Coastguard Worker        a[500:750] = float("nan")
2457*da0073e9SAndroid Build Coastguard Worker        b[500:750] = float("nan")
2458*da0073e9SAndroid Build Coastguard Worker        # 750:1000: neither nan
2459*da0073e9SAndroid Build Coastguard Worker
2460*da0073e9SAndroid Build Coastguard Worker        ma = torch.max(a, b)
2461*da0073e9SAndroid Build Coastguard Worker        mi = torch.min(a, b)
2462*da0073e9SAndroid Build Coastguard Worker
2463*da0073e9SAndroid Build Coastguard Worker        for i in range(750):
2464*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
2465*da0073e9SAndroid Build Coastguard Worker                torch.isnan(ma[i]),
2466*da0073e9SAndroid Build Coastguard Worker                f"max(a, b): {ma[i]}, a: {a[i]}, b: {b[i]}",
2467*da0073e9SAndroid Build Coastguard Worker            )
2468*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
2469*da0073e9SAndroid Build Coastguard Worker                torch.isnan(mi[i]),
2470*da0073e9SAndroid Build Coastguard Worker                f"min(a, b): {mi[i]}, a: {a[i]}, b: {b[i]}",
2471*da0073e9SAndroid Build Coastguard Worker            )
2472*da0073e9SAndroid Build Coastguard Worker
2473*da0073e9SAndroid Build Coastguard Worker        for i in range(750, 1000):
2474*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
2475*da0073e9SAndroid Build Coastguard Worker                torch.isnan(ma[i]),
2476*da0073e9SAndroid Build Coastguard Worker                f"max(a, b): {ma[i]}, a: {a[i]}, b: {b[i]}",
2477*da0073e9SAndroid Build Coastguard Worker            )
2478*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
2479*da0073e9SAndroid Build Coastguard Worker                torch.isnan(mi[i]),
2480*da0073e9SAndroid Build Coastguard Worker                f"min(a, b): {mi[i]}, a: {a[i]}, b: {b[i]}",
2481*da0073e9SAndroid Build Coastguard Worker            )
2482*da0073e9SAndroid Build Coastguard Worker
2483*da0073e9SAndroid Build Coastguard Worker    @dtypes(
2484*da0073e9SAndroid Build Coastguard Worker        *product(
2485*da0073e9SAndroid Build Coastguard Worker            all_types_and(torch.half, torch.bfloat16, torch.bool),
2486*da0073e9SAndroid Build Coastguard Worker            all_types_and(torch.half, torch.bfloat16, torch.bool),
2487*da0073e9SAndroid Build Coastguard Worker        )
2488*da0073e9SAndroid Build Coastguard Worker    )
2489*da0073e9SAndroid Build Coastguard Worker    def test_copysign(self, device, dtypes):
2490*da0073e9SAndroid Build Coastguard Worker        def _test_copysign_numpy(a, b):
2491*da0073e9SAndroid Build Coastguard Worker            torch_result = torch.copysign(a, b)
2492*da0073e9SAndroid Build Coastguard Worker
2493*da0073e9SAndroid Build Coastguard Worker            if a.dtype == torch.bfloat16:
2494*da0073e9SAndroid Build Coastguard Worker                np_a = a.to(torch.float).cpu().numpy()
2495*da0073e9SAndroid Build Coastguard Worker            else:
2496*da0073e9SAndroid Build Coastguard Worker                np_a = a.cpu().numpy()
2497*da0073e9SAndroid Build Coastguard Worker
2498*da0073e9SAndroid Build Coastguard Worker            if b.dtype == torch.bfloat16:
2499*da0073e9SAndroid Build Coastguard Worker                np_b = b.to(torch.float).cpu().numpy()
2500*da0073e9SAndroid Build Coastguard Worker            else:
2501*da0073e9SAndroid Build Coastguard Worker                np_b = b.cpu().numpy()
2502*da0073e9SAndroid Build Coastguard Worker            expected = torch.from_numpy(np.copysign(np_a, np_b))
2503*da0073e9SAndroid Build Coastguard Worker            # To handle inconsistencies of type promotion between PyTorch and Numpy
2504*da0073e9SAndroid Build Coastguard Worker            # Applied for both arguments having integral precision and bfloat16
2505*da0073e9SAndroid Build Coastguard Worker            types = integral_types_and(torch.bool, torch.bfloat16)
2506*da0073e9SAndroid Build Coastguard Worker            if a.dtype in types or b.dtype in types:
2507*da0073e9SAndroid Build Coastguard Worker                promoted_type = torch.promote_types(torch_result.dtype, expected.dtype)
2508*da0073e9SAndroid Build Coastguard Worker                torch_result = torch_result.to(promoted_type)
2509*da0073e9SAndroid Build Coastguard Worker                expected = expected.to(promoted_type)
2510*da0073e9SAndroid Build Coastguard Worker
2511*da0073e9SAndroid Build Coastguard Worker            # Verify Value
2512*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch_result, expected)
2513*da0073e9SAndroid Build Coastguard Worker            # Verify Sign
2514*da0073e9SAndroid Build Coastguard Worker            # Use double copysign to verify the correctnes of 0.0 and -0.0, since
2515*da0073e9SAndroid Build Coastguard Worker            # it always True for self.assertEqual(0.0 == -0.0). So, we use 1 as the
2516*da0073e9SAndroid Build Coastguard Worker            # magnitude to verify the sign between torch and numpy results, elementwise.
2517*da0073e9SAndroid Build Coastguard Worker            # Special case: NaN conversions between FP32 and FP16 is not bitwise
2518*da0073e9SAndroid Build Coastguard Worker            # equivalent to pass this assertion.
2519*da0073e9SAndroid Build Coastguard Worker            if a.dtype != torch.float16 and b.dtype != torch.float16:
2520*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2521*da0073e9SAndroid Build Coastguard Worker                    torch.copysign(torch.tensor(1.0), torch_result),
2522*da0073e9SAndroid Build Coastguard Worker                    torch.copysign(torch.tensor(1.0), expected),
2523*da0073e9SAndroid Build Coastguard Worker                )
2524*da0073e9SAndroid Build Coastguard Worker
2525*da0073e9SAndroid Build Coastguard Worker        # Compare Result with NumPy
2526*da0073e9SAndroid Build Coastguard Worker        # Type promotion
2527*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
2528*da0073e9SAndroid Build Coastguard Worker        b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
2529*da0073e9SAndroid Build Coastguard Worker        _test_copysign_numpy(a, b)
2530*da0073e9SAndroid Build Coastguard Worker
2531*da0073e9SAndroid Build Coastguard Worker        # Broadcast
2532*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((10, 1, 10), device=device, dtype=dtypes[0], low=-9, high=9)
2533*da0073e9SAndroid Build Coastguard Worker        b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
2534*da0073e9SAndroid Build Coastguard Worker        _test_copysign_numpy(a, b)
2535*da0073e9SAndroid Build Coastguard Worker
2536*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
2537*da0073e9SAndroid Build Coastguard Worker        b = make_tensor((10, 1, 10), device=device, dtype=dtypes[1], low=-9, high=9)
2538*da0073e9SAndroid Build Coastguard Worker        _test_copysign_numpy(a, b)
2539*da0073e9SAndroid Build Coastguard Worker
2540*da0073e9SAndroid Build Coastguard Worker        # 0.0/-0.0/inf/-inf/nan
2541*da0073e9SAndroid Build Coastguard Worker        cases = [0.0, -0.0, float("inf"), float("-inf"), float("nan")]
2542*da0073e9SAndroid Build Coastguard Worker        # torch.bfloat16 can not hold '-nan'
2543*da0073e9SAndroid Build Coastguard Worker        # torch.half can not hold '-nan' on CUDA
2544*da0073e9SAndroid Build Coastguard Worker        types = [torch.float32, torch.float64]
2545*da0073e9SAndroid Build Coastguard Worker        if device == "cpu":
2546*da0073e9SAndroid Build Coastguard Worker            types.append(torch.float16)
2547*da0073e9SAndroid Build Coastguard Worker        if dtypes[0] in types:
2548*da0073e9SAndroid Build Coastguard Worker            b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
2549*da0073e9SAndroid Build Coastguard Worker            for case in cases:
2550*da0073e9SAndroid Build Coastguard Worker                _test_copysign_numpy(
2551*da0073e9SAndroid Build Coastguard Worker                    torch.tensor([case], device=device, dtype=dtypes[0]), b
2552*da0073e9SAndroid Build Coastguard Worker                )
2553*da0073e9SAndroid Build Coastguard Worker
2554*da0073e9SAndroid Build Coastguard Worker        if dtypes[1] in floating_types_and(torch.half, torch.bfloat16):
2555*da0073e9SAndroid Build Coastguard Worker            a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
2556*da0073e9SAndroid Build Coastguard Worker            for case in cases:
2557*da0073e9SAndroid Build Coastguard Worker                _test_copysign_numpy(
2558*da0073e9SAndroid Build Coastguard Worker                    a, torch.tensor([case], device=device, dtype=dtypes[1])
2559*da0073e9SAndroid Build Coastguard Worker                )
2560*da0073e9SAndroid Build Coastguard Worker
2561*da0073e9SAndroid Build Coastguard Worker    @dtypes(
2562*da0073e9SAndroid Build Coastguard Worker        *product(
2563*da0073e9SAndroid Build Coastguard Worker            floating_types_and(torch.half, torch.bfloat16),
2564*da0073e9SAndroid Build Coastguard Worker            floating_types_and(torch.half, torch.bfloat16),
2565*da0073e9SAndroid Build Coastguard Worker        )
2566*da0073e9SAndroid Build Coastguard Worker    )
2567*da0073e9SAndroid Build Coastguard Worker    def test_copysign_subgradient(self, device, dtypes):
2568*da0073e9SAndroid Build Coastguard Worker        # Input is 0.0
2569*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(
2570*da0073e9SAndroid Build Coastguard Worker            [0.0, 0.0, 0.0], dtype=dtypes[0], device=device, requires_grad=True
2571*da0073e9SAndroid Build Coastguard Worker        )
2572*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(
2573*da0073e9SAndroid Build Coastguard Worker            [-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True
2574*da0073e9SAndroid Build Coastguard Worker        )
2575*da0073e9SAndroid Build Coastguard Worker        out = torch.copysign(x, y)
2576*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
2577*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
2578*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad.tolist(), [0.0] * 3)
2579*da0073e9SAndroid Build Coastguard Worker
2580*da0073e9SAndroid Build Coastguard Worker        # Input is -0.0
2581*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(
2582*da0073e9SAndroid Build Coastguard Worker            [-0.0, -0.0, -0.0], dtype=dtypes[0], device=device, requires_grad=True
2583*da0073e9SAndroid Build Coastguard Worker        )
2584*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(
2585*da0073e9SAndroid Build Coastguard Worker            [-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True
2586*da0073e9SAndroid Build Coastguard Worker        )
2587*da0073e9SAndroid Build Coastguard Worker        out = torch.copysign(x, y)
2588*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
2589*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
2590*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad.tolist(), [0.0] * 3)
2591*da0073e9SAndroid Build Coastguard Worker
2592*da0073e9SAndroid Build Coastguard Worker        # Other is 0.0
2593*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(
2594*da0073e9SAndroid Build Coastguard Worker            [-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True
2595*da0073e9SAndroid Build Coastguard Worker        )
2596*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(
2597*da0073e9SAndroid Build Coastguard Worker            [0.0, 0.0, 0.0], dtype=dtypes[1], device=device, requires_grad=True
2598*da0073e9SAndroid Build Coastguard Worker        )
2599*da0073e9SAndroid Build Coastguard Worker        out = torch.copysign(x, y)
2600*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
2601*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad.tolist(), [-1.0, 0.0, 1.0])
2602*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad.tolist(), [0.0] * 3)
2603*da0073e9SAndroid Build Coastguard Worker
2604*da0073e9SAndroid Build Coastguard Worker        # Other is -0.0
2605*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(
2606*da0073e9SAndroid Build Coastguard Worker            [-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True
2607*da0073e9SAndroid Build Coastguard Worker        )
2608*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(
2609*da0073e9SAndroid Build Coastguard Worker            [-0.0, -0.0, -0.0], dtype=dtypes[1], device=device, requires_grad=True
2610*da0073e9SAndroid Build Coastguard Worker        )
2611*da0073e9SAndroid Build Coastguard Worker        out = torch.copysign(x, y)
2612*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
2613*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad.tolist(), [1.0, 0.0, -1.0])
2614*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad.tolist(), [0.0] * 3)
2615*da0073e9SAndroid Build Coastguard Worker
2616*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float)
2617*da0073e9SAndroid Build Coastguard Worker    def test_div(self, device, dtype):
2618*da0073e9SAndroid Build Coastguard Worker        for op, method, inplace in (
2619*da0073e9SAndroid Build Coastguard Worker            (torch.div, torch.Tensor.div, torch.Tensor.div_),
2620*da0073e9SAndroid Build Coastguard Worker            (torch.true_divide, torch.Tensor.true_divide, torch.Tensor.true_divide_),
2621*da0073e9SAndroid Build Coastguard Worker        ):
2622*da0073e9SAndroid Build Coastguard Worker            m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype)
2623*da0073e9SAndroid Build Coastguard Worker            res1 = m1.clone()
2624*da0073e9SAndroid Build Coastguard Worker            inplace(res1[:, 3], 2)
2625*da0073e9SAndroid Build Coastguard Worker            res2 = m1.clone()
2626*da0073e9SAndroid Build Coastguard Worker            for i in range(m1.size(0)):
2627*da0073e9SAndroid Build Coastguard Worker                res2[i, 3] = res2[i, 3] / 2
2628*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res2)
2629*da0073e9SAndroid Build Coastguard Worker
2630*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16:
2631*da0073e9SAndroid Build Coastguard Worker                a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
2632*da0073e9SAndroid Build Coastguard Worker                a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device)
2633*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2634*da0073e9SAndroid Build Coastguard Worker                    op(a1, a2),
2635*da0073e9SAndroid Build Coastguard Worker                    torch.tensor([2.1, 3.1], dtype=dtype, device=device),
2636*da0073e9SAndroid Build Coastguard Worker                    atol=0.01,
2637*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
2638*da0073e9SAndroid Build Coastguard Worker                )
2639*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(method(a1, a2), op(a1, a2))
2640*da0073e9SAndroid Build Coastguard Worker
2641*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float)
2642*da0073e9SAndroid Build Coastguard Worker    def test_true_divide_out(self, device, dtype):
2643*da0073e9SAndroid Build Coastguard Worker        a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
2644*da0073e9SAndroid Build Coastguard Worker        a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device)
2645*da0073e9SAndroid Build Coastguard Worker        res = torch.empty_like(a1)
2646*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2647*da0073e9SAndroid Build Coastguard Worker            torch.true_divide(a1, a2, out=res),
2648*da0073e9SAndroid Build Coastguard Worker            torch.tensor([2.1, 3.1], dtype=dtype, device=device),
2649*da0073e9SAndroid Build Coastguard Worker            atol=0.01,
2650*da0073e9SAndroid Build Coastguard Worker            rtol=0,
2651*da0073e9SAndroid Build Coastguard Worker        )
2652*da0073e9SAndroid Build Coastguard Worker
2653*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half)
2654*da0073e9SAndroid Build Coastguard Worker    def test_divmul_scalar(self, device, dtype):
2655*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(100.0, device=device, dtype=dtype)
2656*da0073e9SAndroid Build Coastguard Worker        x_ref = x.float()
2657*da0073e9SAndroid Build Coastguard Worker        scale = 1e5
2658*da0073e9SAndroid Build Coastguard Worker        res = x.div(scale)
2659*da0073e9SAndroid Build Coastguard Worker        expected = x_ref.div(scale)
2660*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0)
2661*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(1e-5, device=device, dtype=dtype)
2662*da0073e9SAndroid Build Coastguard Worker        x_ref = x.float()
2663*da0073e9SAndroid Build Coastguard Worker        res = x.mul(scale)
2664*da0073e9SAndroid Build Coastguard Worker        expected = x_ref.mul(scale)
2665*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0)
2666*da0073e9SAndroid Build Coastguard Worker        res = scale * x
2667*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0)
2668*da0073e9SAndroid Build Coastguard Worker
2669*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(
2670*da0073e9SAndroid Build Coastguard Worker        *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128}
2671*da0073e9SAndroid Build Coastguard Worker    )
2672*da0073e9SAndroid Build Coastguard Worker    @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128})
2673*da0073e9SAndroid Build Coastguard Worker    def test_floor_divide_tensor(self, device, dtype):
2674*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, device=device).mul(30).to(dtype)
2675*da0073e9SAndroid Build Coastguard Worker        y = torch.arange(1, 11, dtype=dtype, device=device)
2676*da0073e9SAndroid Build Coastguard Worker
2677*da0073e9SAndroid Build Coastguard Worker        z = x // y
2678*da0073e9SAndroid Build Coastguard Worker        z_alt = torch.floor(x.double() / y.double()).to(dtype)
2679*da0073e9SAndroid Build Coastguard Worker
2680*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.dtype, x.dtype)
2681*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, z_alt)
2682*da0073e9SAndroid Build Coastguard Worker
2683*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(
2684*da0073e9SAndroid Build Coastguard Worker        *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128}
2685*da0073e9SAndroid Build Coastguard Worker    )
2686*da0073e9SAndroid Build Coastguard Worker    @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128})
2687*da0073e9SAndroid Build Coastguard Worker    def test_floor_divide_scalar(self, device, dtype):
2688*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(100, device=device).mul(10).to(dtype)
2689*da0073e9SAndroid Build Coastguard Worker
2690*da0073e9SAndroid Build Coastguard Worker        z = x // 3
2691*da0073e9SAndroid Build Coastguard Worker        z_alt = torch.tensor(
2692*da0073e9SAndroid Build Coastguard Worker            [math.floor(v.item() / 3.0) for v in x], dtype=x.dtype, device=device
2693*da0073e9SAndroid Build Coastguard Worker        )
2694*da0073e9SAndroid Build Coastguard Worker
2695*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.dtype, x.dtype)
2696*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, z_alt)
2697*da0073e9SAndroid Build Coastguard Worker
2698*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
2699*da0073e9SAndroid Build Coastguard Worker    @dtypes(*get_all_math_dtypes("cpu"))
2700*da0073e9SAndroid Build Coastguard Worker    def test_rdiv(self, device, dtype):
2701*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.float16:
2702*da0073e9SAndroid Build Coastguard Worker            return
2703*da0073e9SAndroid Build Coastguard Worker        elif dtype.is_complex:
2704*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(100, dtype=dtype, device=device).add(1).mul(4)
2705*da0073e9SAndroid Build Coastguard Worker        else:
2706*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(100, device=device).add(1).mul(4).to(dtype)
2707*da0073e9SAndroid Build Coastguard Worker        y = 30 / x
2708*da0073e9SAndroid Build Coastguard Worker        z = torch.tensor([30 / v.item() for v in x], device=device)
2709*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, z, exact_dtype=False)
2710*da0073e9SAndroid Build Coastguard Worker
2711*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half))
2712*da0073e9SAndroid Build Coastguard Worker    def test_fmod_remainder_by_zero_float(self, device, dtype):
2713*da0073e9SAndroid Build Coastguard Worker        fn_list = (torch.fmod, torch.remainder)
2714*da0073e9SAndroid Build Coastguard Worker        for fn in fn_list:
2715*da0073e9SAndroid Build Coastguard Worker            # check floating-point tensor fmod/remainder to zero is nan on both CPU and GPU
2716*da0073e9SAndroid Build Coastguard Worker            x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
2717*da0073e9SAndroid Build Coastguard Worker            zero = torch.zeros_like(x)
2718*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.all(fn(x, 0.0).isnan()))
2719*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.all(fn(x, zero).isnan()))
2720*da0073e9SAndroid Build Coastguard Worker
2721*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes  # Check Issue https://github.com/pytorch/pytorch/issues/48130
2722*da0073e9SAndroid Build Coastguard Worker    @dtypes(*integral_types())
2723*da0073e9SAndroid Build Coastguard Worker    def test_fmod_remainder_by_zero_integral(self, device, dtype):
2724*da0073e9SAndroid Build Coastguard Worker        fn_list = (torch.fmod, torch.remainder)
2725*da0073e9SAndroid Build Coastguard Worker        for fn in fn_list:
2726*da0073e9SAndroid Build Coastguard Worker            # check integral tensor fmod/remainder to zero
2727*da0073e9SAndroid Build Coastguard Worker            x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
2728*da0073e9SAndroid Build Coastguard Worker            zero = torch.zeros_like(x)
2729*da0073e9SAndroid Build Coastguard Worker            # RuntimeError on CPU
2730*da0073e9SAndroid Build Coastguard Worker            if self.device_type == "cpu":
2731*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"):
2732*da0073e9SAndroid Build Coastguard Worker                    fn(x, zero)
2733*da0073e9SAndroid Build Coastguard Worker            elif torch.version.hip is not None:
2734*da0073e9SAndroid Build Coastguard Worker                # ROCm behavior: x % 0 is a no-op; x is returned
2735*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(fn(x, zero), x)
2736*da0073e9SAndroid Build Coastguard Worker            else:
2737*da0073e9SAndroid Build Coastguard Worker                # CUDA behavior: Different value for different dtype
2738*da0073e9SAndroid Build Coastguard Worker                # Due to it's an undefined behavior, CUDA returns a pattern of all 1s
2739*da0073e9SAndroid Build Coastguard Worker                # for integral dividend (other than int64) divided by zero. For int64,
2740*da0073e9SAndroid Build Coastguard Worker                # CUDA returns all 1s for negative dividend, half 1s for positive dividend.
2741*da0073e9SAndroid Build Coastguard Worker                # uint8: 0xff -> 255
2742*da0073e9SAndroid Build Coastguard Worker                # int32: 0xffffffff -> -1
2743*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.int64:
2744*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(fn(x, zero) == 4294967295, x >= 0)
2745*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(fn(x, zero) == -1, x < 0)
2746*da0073e9SAndroid Build Coastguard Worker                else:
2747*da0073e9SAndroid Build Coastguard Worker                    value = 255 if dtype == torch.uint8 else -1
2748*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.all(fn(x, zero) == value))
2749*da0073e9SAndroid Build Coastguard Worker
2750*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half))
2751*da0073e9SAndroid Build Coastguard Worker    def test_fmod_remainder(self, device, dtype):
2752*da0073e9SAndroid Build Coastguard Worker        # Use numpy as reference
2753*da0073e9SAndroid Build Coastguard Worker        def _helper(x, mod, fns_list):
2754*da0073e9SAndroid Build Coastguard Worker            for fn, inplace_fn, ref_fn in fns_list:
2755*da0073e9SAndroid Build Coastguard Worker                np_x = x.cpu().numpy() if torch.is_tensor(x) else x
2756*da0073e9SAndroid Build Coastguard Worker                np_mod = mod.cpu().numpy() if torch.is_tensor(mod) else mod
2757*da0073e9SAndroid Build Coastguard Worker                exp = ref_fn(np_x, np_mod)
2758*da0073e9SAndroid Build Coastguard Worker                exp = torch.from_numpy(exp)
2759*da0073e9SAndroid Build Coastguard Worker                res = fn(x, mod)
2760*da0073e9SAndroid Build Coastguard Worker
2761*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res, exp, exact_dtype=False)
2762*da0073e9SAndroid Build Coastguard Worker
2763*da0073e9SAndroid Build Coastguard Worker                if torch.is_tensor(x):
2764*da0073e9SAndroid Build Coastguard Worker                    # out
2765*da0073e9SAndroid Build Coastguard Worker                    out = torch.empty(0, device=device, dtype=res.dtype)
2766*da0073e9SAndroid Build Coastguard Worker                    fn(x, mod, out=out)
2767*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out, exp, exact_dtype=False)
2768*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out.size(), torch.Size([10, 10]))
2769*da0073e9SAndroid Build Coastguard Worker                    # in-place (Type cast runtime error)
2770*da0073e9SAndroid Build Coastguard Worker                    try:
2771*da0073e9SAndroid Build Coastguard Worker                        inplace_fn(x, mod)
2772*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(x, exp, exact_dtype=False)
2773*da0073e9SAndroid Build Coastguard Worker                    except RuntimeError as e:
2774*da0073e9SAndroid Build Coastguard Worker                        self.assertRegex(
2775*da0073e9SAndroid Build Coastguard Worker                            str(e),
2776*da0073e9SAndroid Build Coastguard Worker                            "result type (Half|Float|Double) "
2777*da0073e9SAndroid Build Coastguard Worker                            "can't be cast to the desired output "
2778*da0073e9SAndroid Build Coastguard Worker                            "type (Byte|Char|Short|Int|Long)",
2779*da0073e9SAndroid Build Coastguard Worker                        )
2780*da0073e9SAndroid Build Coastguard Worker
2781*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
2782*da0073e9SAndroid Build Coastguard Worker        # mod with same dtype as x
2783*da0073e9SAndroid Build Coastguard Worker        mod = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
2784*da0073e9SAndroid Build Coastguard Worker        # Exclude 0
2785*da0073e9SAndroid Build Coastguard Worker        mod[mod == 0] = 1
2786*da0073e9SAndroid Build Coastguard Worker
2787*da0073e9SAndroid Build Coastguard Worker        # Mods: Integer, Float, Tensor, Non-contiguous Tensor
2788*da0073e9SAndroid Build Coastguard Worker        mods = [3, 2.3, mod, mod.t()]
2789*da0073e9SAndroid Build Coastguard Worker        # mod with floating-point dtype
2790*da0073e9SAndroid Build Coastguard Worker        if dtype in integral_types():
2791*da0073e9SAndroid Build Coastguard Worker            mod_float = make_tensor(
2792*da0073e9SAndroid Build Coastguard Worker                (10, 10), device=device, dtype=torch.float, low=-9, high=9
2793*da0073e9SAndroid Build Coastguard Worker            )
2794*da0073e9SAndroid Build Coastguard Worker            mod[mod == 0] = 1
2795*da0073e9SAndroid Build Coastguard Worker            mods.append(mod_float)
2796*da0073e9SAndroid Build Coastguard Worker
2797*da0073e9SAndroid Build Coastguard Worker        for dividend, mod in product([x, x.t()], mods):
2798*da0073e9SAndroid Build Coastguard Worker            _helper(
2799*da0073e9SAndroid Build Coastguard Worker                dividend,
2800*da0073e9SAndroid Build Coastguard Worker                mod,
2801*da0073e9SAndroid Build Coastguard Worker                (
2802*da0073e9SAndroid Build Coastguard Worker                    (torch.fmod, torch.Tensor.fmod_, np.fmod),
2803*da0073e9SAndroid Build Coastguard Worker                    (torch.remainder, torch.Tensor.remainder_, np.remainder),
2804*da0073e9SAndroid Build Coastguard Worker                ),
2805*da0073e9SAndroid Build Coastguard Worker            )
2806*da0073e9SAndroid Build Coastguard Worker
2807*da0073e9SAndroid Build Coastguard Worker        # Tests for torch.remainder(scalar, tensor)
2808*da0073e9SAndroid Build Coastguard Worker        for dividend, mod in product([5, 3.14], mods):
2809*da0073e9SAndroid Build Coastguard Worker            if torch.is_tensor(mod):
2810*da0073e9SAndroid Build Coastguard Worker                _helper(
2811*da0073e9SAndroid Build Coastguard Worker                    dividend,
2812*da0073e9SAndroid Build Coastguard Worker                    mod,
2813*da0073e9SAndroid Build Coastguard Worker                    ((torch.remainder, torch.Tensor.remainder_, np.remainder),),
2814*da0073e9SAndroid Build Coastguard Worker                )
2815*da0073e9SAndroid Build Coastguard Worker
2816*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
2817*da0073e9SAndroid Build Coastguard Worker    def test_remainder_fmod_large_dividend(self, device, dtype):
2818*da0073e9SAndroid Build Coastguard Worker        alarge = 1e9
2819*da0073e9SAndroid Build Coastguard Worker        pi = 3.14159265358979
2820*da0073e9SAndroid Build Coastguard Worker        for avalue in [alarge, -alarge]:
2821*da0073e9SAndroid Build Coastguard Worker            for bvalue in [pi, -pi]:
2822*da0073e9SAndroid Build Coastguard Worker                a = torch.tensor([avalue], dtype=dtype, device=device)
2823*da0073e9SAndroid Build Coastguard Worker                b = torch.tensor([bvalue], dtype=dtype, device=device)
2824*da0073e9SAndroid Build Coastguard Worker                c = torch.remainder(a, b)
2825*da0073e9SAndroid Build Coastguard Worker                d = torch.fmod(a, b)
2826*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
2827*da0073e9SAndroid Build Coastguard Worker                    (b[0] > 0) == (c[0] > 0)
2828*da0073e9SAndroid Build Coastguard Worker                )  # remainder has same sign as divisor
2829*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
2830*da0073e9SAndroid Build Coastguard Worker                    (a[0] > 0) == (d[0] > 0)
2831*da0073e9SAndroid Build Coastguard Worker                )  # fmod has same sign as dividend
2832*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
2833*da0073e9SAndroid Build Coastguard Worker                    abs(c[0]) < abs(b[0])
2834*da0073e9SAndroid Build Coastguard Worker                )  # remainder is within range of divisor
2835*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
2836*da0073e9SAndroid Build Coastguard Worker                    abs(d[0]) < abs(b[0])
2837*da0073e9SAndroid Build Coastguard Worker                )  # fmod is within range of divisor
2838*da0073e9SAndroid Build Coastguard Worker                if (a[0] > 0) == (b[0] > 0):
2839*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(c[0] == d[0])  # remainder is same as fmod
2840*da0073e9SAndroid Build Coastguard Worker                else:
2841*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
2842*da0073e9SAndroid Build Coastguard Worker                        abs(c[0] - d[0]) == abs(b[0])
2843*da0073e9SAndroid Build Coastguard Worker                    )  # differ by one divisor
2844*da0073e9SAndroid Build Coastguard Worker
2845*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.bfloat16, torch.half, torch.float32, torch.float64)
2846*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
2847*da0073e9SAndroid Build Coastguard Worker    def test_hypot(self, device, dtype):
2848*da0073e9SAndroid Build Coastguard Worker        inputs = [
2849*da0073e9SAndroid Build Coastguard Worker            (
2850*da0073e9SAndroid Build Coastguard Worker                torch.randn(10, device=device).to(dtype),
2851*da0073e9SAndroid Build Coastguard Worker                torch.randn(10, device=device).to(dtype),
2852*da0073e9SAndroid Build Coastguard Worker            ),
2853*da0073e9SAndroid Build Coastguard Worker            (
2854*da0073e9SAndroid Build Coastguard Worker                torch.randn((3, 3, 3), device=device).to(dtype),
2855*da0073e9SAndroid Build Coastguard Worker                torch.randn((3, 3, 3), device=device).to(dtype),
2856*da0073e9SAndroid Build Coastguard Worker            ),
2857*da0073e9SAndroid Build Coastguard Worker            (
2858*da0073e9SAndroid Build Coastguard Worker                torch.randn((10, 1), device=device).to(dtype),
2859*da0073e9SAndroid Build Coastguard Worker                torch.randn((10, 1), device=device).to(dtype).transpose(0, 1),
2860*da0073e9SAndroid Build Coastguard Worker            ),
2861*da0073e9SAndroid Build Coastguard Worker            (
2862*da0073e9SAndroid Build Coastguard Worker                torch.randint(100, (10,), device=device, dtype=torch.long),
2863*da0073e9SAndroid Build Coastguard Worker                torch.randn(10, device=device).to(dtype),
2864*da0073e9SAndroid Build Coastguard Worker            ),
2865*da0073e9SAndroid Build Coastguard Worker        ]
2866*da0073e9SAndroid Build Coastguard Worker        for input in inputs:
2867*da0073e9SAndroid Build Coastguard Worker            actual = torch.hypot(input[0], input[1])
2868*da0073e9SAndroid Build Coastguard Worker            if dtype in [torch.bfloat16, torch.half]:
2869*da0073e9SAndroid Build Coastguard Worker                expected = torch.sqrt(input[0] * input[0] + input[1] * input[1])
2870*da0073e9SAndroid Build Coastguard Worker            else:
2871*da0073e9SAndroid Build Coastguard Worker                expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy())
2872*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, exact_dtype=False)
2873*da0073e9SAndroid Build Coastguard Worker
2874*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
2875*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
2876*da0073e9SAndroid Build Coastguard Worker    def test_gcd(self, device, dtype):
2877*da0073e9SAndroid Build Coastguard Worker        # Tests gcd(0, 0), gcd(0, a) cases
2878*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device)
2879*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device)
2880*da0073e9SAndroid Build Coastguard Worker        actual = torch.gcd(t1, t2)
2881*da0073e9SAndroid Build Coastguard Worker        expected = np.gcd([0, 10, 0], [0, 0, 10])
2882*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected, exact_dtype=False)
2883*da0073e9SAndroid Build Coastguard Worker
2884*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.uint8:
2885*da0073e9SAndroid Build Coastguard Worker            # Test unsigned integers with potential sign issues (i.e., uint8 with value >= 128)
2886*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([190, 210], device=device, dtype=dtype)
2887*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor([190, 220], device=device, dtype=dtype)
2888*da0073e9SAndroid Build Coastguard Worker            actual = torch.gcd(a, b)
2889*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor([190, 10], device=device, dtype=dtype)
2890*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
2891*da0073e9SAndroid Build Coastguard Worker        else:
2892*da0073e9SAndroid Build Coastguard Worker            # Compares with NumPy
2893*da0073e9SAndroid Build Coastguard Worker            a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
2894*da0073e9SAndroid Build Coastguard Worker            b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
2895*da0073e9SAndroid Build Coastguard Worker            actual = torch.gcd(a, b)
2896*da0073e9SAndroid Build Coastguard Worker            expected = np.gcd(a.cpu().numpy(), b.cpu().numpy())
2897*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
2898*da0073e9SAndroid Build Coastguard Worker
2899*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
2900*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int16, torch.int32, torch.int64)
2901*da0073e9SAndroid Build Coastguard Worker    def test_lcm(self, device, dtype):
2902*da0073e9SAndroid Build Coastguard Worker        # Tests lcm(0, 0), lcm(0, a) cases
2903*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device)
2904*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device)
2905*da0073e9SAndroid Build Coastguard Worker        actual = torch.lcm(t1, t2)
2906*da0073e9SAndroid Build Coastguard Worker        expected = np.lcm([0, 10, 0], [0, 0, 10])
2907*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected, exact_dtype=False)
2908*da0073e9SAndroid Build Coastguard Worker
2909*da0073e9SAndroid Build Coastguard Worker        # Compares with NumPy
2910*da0073e9SAndroid Build Coastguard Worker        a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
2911*da0073e9SAndroid Build Coastguard Worker        b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
2912*da0073e9SAndroid Build Coastguard Worker        actual = torch.lcm(a, b)
2913*da0073e9SAndroid Build Coastguard Worker        expected = np.lcm(a.cpu().numpy(), b.cpu().numpy())
2914*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected, exact_dtype=False)
2915*da0073e9SAndroid Build Coastguard Worker
2916*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
2917*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.float32, torch.float64, torch.float16)
2918*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
2919*da0073e9SAndroid Build Coastguard Worker    def test_nextafter(self, device, dtype):
2920*da0073e9SAndroid Build Coastguard Worker        # Test special cases
2921*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor([0, 0, 10], device=device, dtype=dtype)
2922*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor([inf, -inf, 10], device=device, dtype=dtype)
2923*da0073e9SAndroid Build Coastguard Worker        actual = torch.nextafter(t1, t2)
2924*da0073e9SAndroid Build Coastguard Worker        expected = np.nextafter(t1.cpu().numpy(), t2.cpu().numpy())
2925*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected, atol=0, rtol=0)
2926*da0073e9SAndroid Build Coastguard Worker
2927*da0073e9SAndroid Build Coastguard Worker        actual = torch.nextafter(t2, t1)
2928*da0073e9SAndroid Build Coastguard Worker        expected = np.nextafter(t2.cpu().numpy(), t1.cpu().numpy())
2929*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected, atol=0, rtol=0)
2930*da0073e9SAndroid Build Coastguard Worker
2931*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor([0, nan], device=device, dtype=dtype)
2932*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor([nan, 0], device=device, dtype=dtype)
2933*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.nextafter(t1, t2).isnan().all())
2934*da0073e9SAndroid Build Coastguard Worker
2935*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(100, device=device, dtype=dtype)
2936*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(100, device=device, dtype=dtype)
2937*da0073e9SAndroid Build Coastguard Worker        actual = torch.nextafter(a, b)
2938*da0073e9SAndroid Build Coastguard Worker        expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy())
2939*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected, atol=0, rtol=0)
2940*da0073e9SAndroid Build Coastguard Worker
2941*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
2942*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16)
2943*da0073e9SAndroid Build Coastguard Worker    def test_nextafter_bfloat16(self, device, dtype):
2944*da0073e9SAndroid Build Coastguard Worker        nan = float("nan")
2945*da0073e9SAndroid Build Coastguard Worker        inf = float("inf")
2946*da0073e9SAndroid Build Coastguard Worker        cases = (
2947*da0073e9SAndroid Build Coastguard Worker            # (from, to, expected)
2948*da0073e9SAndroid Build Coastguard Worker            (0, 1, 9.183549615799121e-41),
2949*da0073e9SAndroid Build Coastguard Worker            (0, -1, -9.183549615799121e-41),
2950*da0073e9SAndroid Build Coastguard Worker            (1, -2, 0.99609375),
2951*da0073e9SAndroid Build Coastguard Worker            (1, 0, 0.99609375),
2952*da0073e9SAndroid Build Coastguard Worker            (1, 2, 1.0078125),
2953*da0073e9SAndroid Build Coastguard Worker            (-1, -2, -1.0078125),
2954*da0073e9SAndroid Build Coastguard Worker            (-1, 0, -0.99609375),
2955*da0073e9SAndroid Build Coastguard Worker            (2, -1, 1.9921875),
2956*da0073e9SAndroid Build Coastguard Worker            (2, 1, 1.9921875),
2957*da0073e9SAndroid Build Coastguard Worker            (20, 3000, 20.125),
2958*da0073e9SAndroid Build Coastguard Worker            (20, -3000, 19.875),
2959*da0073e9SAndroid Build Coastguard Worker            (3000, -20, 2992.0),
2960*da0073e9SAndroid Build Coastguard Worker            (-3000, 20, -2992.0),
2961*da0073e9SAndroid Build Coastguard Worker            (65536, 0, 65280.0),
2962*da0073e9SAndroid Build Coastguard Worker            (65536, inf, 66048.0),
2963*da0073e9SAndroid Build Coastguard Worker            (-65536, 0, -65280.0),
2964*da0073e9SAndroid Build Coastguard Worker            (-65536, -inf, -66048.0),
2965*da0073e9SAndroid Build Coastguard Worker            (nan, 0, nan),
2966*da0073e9SAndroid Build Coastguard Worker            (0, nan, nan),
2967*da0073e9SAndroid Build Coastguard Worker            (nan, nan, nan),
2968*da0073e9SAndroid Build Coastguard Worker            (nan, inf, nan),
2969*da0073e9SAndroid Build Coastguard Worker            (inf, nan, nan),
2970*da0073e9SAndroid Build Coastguard Worker            (inf, -inf, 3.3895313892515355e38),
2971*da0073e9SAndroid Build Coastguard Worker            (-inf, inf, -3.3895313892515355e38),
2972*da0073e9SAndroid Build Coastguard Worker            (inf, 0, 3.3895313892515355e38),
2973*da0073e9SAndroid Build Coastguard Worker            (0, inf, 9.183549615799121e-41),
2974*da0073e9SAndroid Build Coastguard Worker            (-inf, 0, -3.3895313892515355e38),
2975*da0073e9SAndroid Build Coastguard Worker            (0, -inf, -9.183549615799121e-41),
2976*da0073e9SAndroid Build Coastguard Worker        )
2977*da0073e9SAndroid Build Coastguard Worker
2978*da0073e9SAndroid Build Coastguard Worker        for from_v, to_v, expected in cases:
2979*da0073e9SAndroid Build Coastguard Worker            from_t = torch.tensor([from_v], device=device, dtype=dtype)
2980*da0073e9SAndroid Build Coastguard Worker            to_t = torch.tensor([to_v], device=device, dtype=dtype)
2981*da0073e9SAndroid Build Coastguard Worker            actual = torch.nextafter(from_t, to_t).item()
2982*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, atol=0, rtol=0)
2983*da0073e9SAndroid Build Coastguard Worker
2984*da0073e9SAndroid Build Coastguard Worker    def _test_cop(self, torchfn, mathfn, dtype, device):
2985*da0073e9SAndroid Build Coastguard Worker        def reference_implementation(res2):
2986*da0073e9SAndroid Build Coastguard Worker            for i, j in iter_indices(sm1):
2987*da0073e9SAndroid Build Coastguard Worker                idx1d = i * sm1.size(0) + j
2988*da0073e9SAndroid Build Coastguard Worker                res2[i, j] = mathfn(sm1[i, j], sm2[idx1d])
2989*da0073e9SAndroid Build Coastguard Worker            return res2
2990*da0073e9SAndroid Build Coastguard Worker
2991*da0073e9SAndroid Build Coastguard Worker        # contiguous
2992*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(10, 10, 10, dtype=dtype, device=device)
2993*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(10, 10 * 10, dtype=dtype, device=device)
2994*da0073e9SAndroid Build Coastguard Worker        sm1 = m1[4]
2995*da0073e9SAndroid Build Coastguard Worker        sm2 = m2[4]
2996*da0073e9SAndroid Build Coastguard Worker
2997*da0073e9SAndroid Build Coastguard Worker        res1 = torchfn(sm1, sm2.view(10, 10))
2998*da0073e9SAndroid Build Coastguard Worker        res2 = reference_implementation(res1.clone())
2999*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
3000*da0073e9SAndroid Build Coastguard Worker
3001*da0073e9SAndroid Build Coastguard Worker        # non-contiguous
3002*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(10, 10, 10, dtype=dtype, device=device)
3003*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(10 * 10, 10 * 10, dtype=dtype, device=device)
3004*da0073e9SAndroid Build Coastguard Worker        sm1 = m1[:, 4]
3005*da0073e9SAndroid Build Coastguard Worker        sm2 = m2[:, 4]
3006*da0073e9SAndroid Build Coastguard Worker        # view as sm1.size()
3007*da0073e9SAndroid Build Coastguard Worker        sm2.set_(
3008*da0073e9SAndroid Build Coastguard Worker            sm2.storage(),
3009*da0073e9SAndroid Build Coastguard Worker            sm2.storage_offset(),
3010*da0073e9SAndroid Build Coastguard Worker            sm1.size(),
3011*da0073e9SAndroid Build Coastguard Worker            (sm2.stride()[0] * 10, sm2.stride()[0]),
3012*da0073e9SAndroid Build Coastguard Worker        )
3013*da0073e9SAndroid Build Coastguard Worker        res1 = torchfn(sm1, sm2)
3014*da0073e9SAndroid Build Coastguard Worker        # reference_implementation assumes 1-d sm2
3015*da0073e9SAndroid Build Coastguard Worker        sm2.set_(
3016*da0073e9SAndroid Build Coastguard Worker            sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride()
3017*da0073e9SAndroid Build Coastguard Worker        )
3018*da0073e9SAndroid Build Coastguard Worker        res2 = reference_implementation(res1.clone())
3019*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
3020*da0073e9SAndroid Build Coastguard Worker
3021*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3022*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
3023*da0073e9SAndroid Build Coastguard Worker    def test_cdiv(self, device, dtype):
3024*da0073e9SAndroid Build Coastguard Worker        self._test_cop(torch.div, operator.truediv, dtype, device)
3025*da0073e9SAndroid Build Coastguard Worker
3026*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3027*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
3028*da0073e9SAndroid Build Coastguard Worker    def test_cremainder(self, device, dtype):
3029*da0073e9SAndroid Build Coastguard Worker        self._test_cop(torch.remainder, operator.mod, dtype, device)
3030*da0073e9SAndroid Build Coastguard Worker
3031*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3032*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
3033*da0073e9SAndroid Build Coastguard Worker    def test_cmul(self, device, dtype):
3034*da0073e9SAndroid Build Coastguard Worker        self._test_cop(torch.mul, operator.mul, dtype, device)
3035*da0073e9SAndroid Build Coastguard Worker
3036*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3037*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
3038*da0073e9SAndroid Build Coastguard Worker    def test_cpow(self, device, dtype):
3039*da0073e9SAndroid Build Coastguard Worker        self._test_cop(
3040*da0073e9SAndroid Build Coastguard Worker            torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y), dtype, device
3041*da0073e9SAndroid Build Coastguard Worker        )
3042*da0073e9SAndroid Build Coastguard Worker
3043*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3044*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
3045*da0073e9SAndroid Build Coastguard Worker    def test_floor_divide_zero(self, device, dtype):
3046*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([0, 1], dtype=dtype, device=device)
3047*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([0, 1], dtype=dtype, device=device)
3048*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"):
3049*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
3050*da0073e9SAndroid Build Coastguard Worker                a // b
3051*da0073e9SAndroid Build Coastguard Worker
3052*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
3053*da0073e9SAndroid Build Coastguard Worker    def test_muldiv_scalar(self, device, dtype):
3054*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((10, 3), dtype=dtype, device=device, low=None, high=None)
3055*da0073e9SAndroid Build Coastguard Worker        s = make_tensor((1,), dtype=dtype, device="cpu", low=None, high=None).item()
3056*da0073e9SAndroid Build Coastguard Worker        y = torch.full_like(x, s)
3057*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x * s, x * y)
3058*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s * x, y * x)
3059*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x / s, x / y)
3060*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s / x, y / x)
3061*da0073e9SAndroid Build Coastguard Worker
    # TODO: update make_tensor to support extremal additions and remove this in favor of make_tensor
    def _generate_input(self, shape, dtype, device, with_extremal):
        """Build a random test tensor of ``shape``/``dtype`` on ``device``.

        Floating/complex tensors are randn values scaled by a random factor
        with exact zeros sprinkled in; when ``with_extremal`` is True,
        NaN/+inf/-inf entries are also injected.  Bool tensors are random
        masks, and other (integral) dtypes draw uniformly from [15, 100).
        NOTE(review): the boolean masks come from ``torch.randn(*shape)``
        without a ``device`` argument, so they are created on the default
        (CPU) device regardless of ``device`` — presumably relying on
        cross-device mask indexing; confirm if porting this helper.
        """
        if shape == ():
            # NOTE: torch.tensor(()) is an empty 1-D tensor, not a 0-dim scalar.
            x = torch.tensor((), dtype=dtype, device=device)
        else:
            if dtype.is_floating_point or dtype.is_complex:
                # work around torch.randn not being implemented for bfloat16
                if dtype == torch.bfloat16:
                    x = torch.randn(*shape, device=device) * random.randint(30, 100)
                    x = x.to(torch.bfloat16)
                else:
                    x = torch.randn(
                        *shape, dtype=dtype, device=device
                    ) * random.randint(30, 100)
                # Zero out roughly half the entries so ops see exact zeros.
                x[torch.randn(*shape) > 0.5] = 0
                if with_extremal and dtype.is_floating_point:
                    # Use extremal values
                    x[torch.randn(*shape) > 0.5] = float("nan")
                    x[torch.randn(*shape) > 0.5] = float("inf")
                    x[torch.randn(*shape) > 0.5] = float("-inf")
                elif with_extremal and dtype.is_complex:
                    x[torch.randn(*shape) > 0.5] = complex("nan")
                    x[torch.randn(*shape) > 0.5] = complex("inf")
                    x[torch.randn(*shape) > 0.5] = complex("-inf")
            elif dtype == torch.bool:
                # Random boolean mask: roughly half True.
                x = torch.zeros(shape, dtype=dtype, device=device)
                x[torch.randn(*shape) > 0.5] = True
            else:
                # Integral dtypes: uniform draw in [15, 100).
                x = torch.randint(15, 100, shape, dtype=dtype, device=device)

        return x
3093*da0073e9SAndroid Build Coastguard Worker
    @dtypes(
        *tuple(
            itertools.combinations_with_replacement(
                all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 2
            )
        )
    )
    def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes):
        """Compare torch comparison/logical ops against numpy across broadcasting,
        type promotion, and extremal values, for every dtype pair.
        """
        # issue #42660
        # testing all combinations of broadcasting and type promotion
        # with a range of dtypes and input shapes, and with extremal values
        def compare_with_numpy_bin_op(torch_fn, np_fn, x, y, out=None):
            # working around the fact that numpy doesn't support bfloat16
            # by letting numpy treat them as float32's
            x_np = x if x.dtype != torch.bfloat16 else x.to(torch.float32)
            y_np = (
                y.cpu().numpy()
                if y.dtype != torch.bfloat16
                else y.to(torch.float32).cpu().numpy()
            )
            # When `out` is given, the same tensor is handed to both the torch
            # and numpy calls; `if out` relies on the 1-element tensor's truthiness.
            self.compare_with_numpy(
                lambda inp: torch_fn(inp, y, out=out) if out else torch_fn(inp, y),
                lambda inp: np_fn(inp, y_np, out=out) if out else np_fn(inp, y_np),
                x_np,
            )

        # Ordering comparisons are skipped for complex inputs below.
        complex_op_denylist = [
            torch.lt,
            torch.le,
            torch.gt,
            torch.ge,
        ]  # complex not supported
        input_sizes = [(1,), (10,), (10, 1), (1, 10), (4, 10), (64, 10), (12, 3)]
        # (torch op, equivalent numpy ufunc) pairs under test.
        op_pairs = [
            (torch.lt, np.less),
            (torch.le, np.less_equal),
            (torch.gt, np.greater),
            (torch.ge, np.greater_equal),
            (torch.eq, np.equal),
            (torch.ne, np.not_equal),
            (torch.logical_and, np.logical_and),
            (torch.logical_or, np.logical_or),
            (torch.logical_xor, np.logical_xor),
        ]

        for size1 in input_sizes:
            size2 = (2,) + size1  # perform broadcasting
            for with_extremal in [False, True]:
                a = self._generate_input(size1, dtypes[0], device, with_extremal)
                b = self._generate_input(size2, dtypes[1], device, with_extremal)
                for torch_op, numpy_op in op_pairs:
                    if (
                        dtypes[0].is_complex or dtypes[1].is_complex
                    ) and torch_op in complex_op_denylist:
                        continue
                    # functional version of op
                    compare_with_numpy_bin_op(torch_op, numpy_op, a, b)

                    # functional comparison ops always return bool tensors
                    self.assertEqual(torch_op(a, b).dtype, torch.bool)

                    # out version of op
                    out = torch.zeros(
                        1, dtype=torch.complex128
                    )  # all casts to complex128 are safe
                    compare_with_numpy_bin_op(torch_op, numpy_op, a, b, out=out)
3160*da0073e9SAndroid Build Coastguard Worker
3161*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3162*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
3163*da0073e9SAndroid Build Coastguard Worker    def test_signed_shift(self, device, dtype):
3164*da0073e9SAndroid Build Coastguard Worker        "Ensure that signed integer bit shifting works as expected."
3165*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([-10, 10], device=device, dtype=dtype)  # [11...1110110, 1010]
3166*da0073e9SAndroid Build Coastguard Worker        expected_l = torch.tensor(
3167*da0073e9SAndroid Build Coastguard Worker            [-40, 40], device=device, dtype=dtype
3168*da0073e9SAndroid Build Coastguard Worker        )  # [11...11011000, 101000]
3169*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a << 2, expected_l)
3170*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(lambda x: x << 2, lambda x: np.left_shift(x, 2), a)
3171*da0073e9SAndroid Build Coastguard Worker        expected_r = torch.tensor(
3172*da0073e9SAndroid Build Coastguard Worker            [-5, 5], device=device, dtype=dtype
3173*da0073e9SAndroid Build Coastguard Worker        )  # [1111...111011, 101]
3174*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a >> 1, expected_r)
3175*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(lambda x: x >> 1, lambda x: np.right_shift(x, 1), a)
3176*da0073e9SAndroid Build Coastguard Worker
3177*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3178*da0073e9SAndroid Build Coastguard Worker    @dtypes(*get_all_int_dtypes())
3179*da0073e9SAndroid Build Coastguard Worker    def test_shift_limits(self, device, dtype):
3180*da0073e9SAndroid Build Coastguard Worker        "Ensure that integer bit shifting works as expected with out-of-limits shift values."
3181*da0073e9SAndroid Build Coastguard Worker        # Issue #70904
3182*da0073e9SAndroid Build Coastguard Worker        iinfo = torch.iinfo(dtype)
3183*da0073e9SAndroid Build Coastguard Worker        bits = iinfo.bits
3184*da0073e9SAndroid Build Coastguard Worker        low = iinfo.min
3185*da0073e9SAndroid Build Coastguard Worker        high = iinfo.max
3186*da0073e9SAndroid Build Coastguard Worker        exact_dtype = (
3187*da0073e9SAndroid Build Coastguard Worker            dtype != torch.uint8
3188*da0073e9SAndroid Build Coastguard Worker        )  # numpy changes dtype from uint8 to int16 for some out-of-limits shift values
3189*da0073e9SAndroid Build Coastguard Worker        for input in (
3190*da0073e9SAndroid Build Coastguard Worker            torch.tensor(
3191*da0073e9SAndroid Build Coastguard Worker                [-1, 0, 1], device=device, dtype=dtype
3192*da0073e9SAndroid Build Coastguard Worker            ),  # small for non-vectorized operation
3193*da0073e9SAndroid Build Coastguard Worker            torch.tensor(
3194*da0073e9SAndroid Build Coastguard Worker                [low, high], device=device, dtype=dtype
3195*da0073e9SAndroid Build Coastguard Worker            ),  # small for non-vectorized operation
3196*da0073e9SAndroid Build Coastguard Worker            make_tensor(
3197*da0073e9SAndroid Build Coastguard Worker                (64, 64, 64), low=low, high=high, device=device, dtype=dtype
3198*da0073e9SAndroid Build Coastguard Worker            ),  # large for vectorized operation
3199*da0073e9SAndroid Build Coastguard Worker        ):
3200*da0073e9SAndroid Build Coastguard Worker            shift_left_expected = torch.zeros_like(input)
3201*da0073e9SAndroid Build Coastguard Worker            shift_right_expected = torch.clamp(input, -1, 0)
3202*da0073e9SAndroid Build Coastguard Worker            for shift in chain(range(-100, -1), range(bits, 100)):
3203*da0073e9SAndroid Build Coastguard Worker                shift_left = input << shift
3204*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(shift_left, shift_left_expected, msg=f"<< {shift}")
3205*da0073e9SAndroid Build Coastguard Worker                self.compare_with_numpy(
3206*da0073e9SAndroid Build Coastguard Worker                    lambda x: x << shift,
3207*da0073e9SAndroid Build Coastguard Worker                    lambda x: np.left_shift(x, shift),
3208*da0073e9SAndroid Build Coastguard Worker                    input,
3209*da0073e9SAndroid Build Coastguard Worker                    exact_dtype=exact_dtype,
3210*da0073e9SAndroid Build Coastguard Worker                    msg=f"<< {shift}",
3211*da0073e9SAndroid Build Coastguard Worker                )
3212*da0073e9SAndroid Build Coastguard Worker                shift_right = input >> shift
3213*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(shift_right, shift_right_expected, msg=f">> {shift}")
3214*da0073e9SAndroid Build Coastguard Worker                self.compare_with_numpy(
3215*da0073e9SAndroid Build Coastguard Worker                    lambda x: x >> shift,
3216*da0073e9SAndroid Build Coastguard Worker                    lambda x: np.right_shift(x, shift),
3217*da0073e9SAndroid Build Coastguard Worker                    input,
3218*da0073e9SAndroid Build Coastguard Worker                    exact_dtype=exact_dtype,
3219*da0073e9SAndroid Build Coastguard Worker                    msg=f">> {shift}",
3220*da0073e9SAndroid Build Coastguard Worker                )
3221*da0073e9SAndroid Build Coastguard Worker
3222*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
3223*da0073e9SAndroid Build Coastguard Worker    @dtypes(
3224*da0073e9SAndroid Build Coastguard Worker        *list(
3225*da0073e9SAndroid Build Coastguard Worker            product(
3226*da0073e9SAndroid Build Coastguard Worker                all_types_and(torch.half, torch.bfloat16, torch.bool),
3227*da0073e9SAndroid Build Coastguard Worker                all_types_and(torch.half, torch.bfloat16, torch.bool),
3228*da0073e9SAndroid Build Coastguard Worker            )
3229*da0073e9SAndroid Build Coastguard Worker        )
3230*da0073e9SAndroid Build Coastguard Worker    )
3231*da0073e9SAndroid Build Coastguard Worker    def test_heaviside(self, device, dtypes):
3232*da0073e9SAndroid Build Coastguard Worker        input_dtype = dtypes[0]
3233*da0073e9SAndroid Build Coastguard Worker        values_dtype = dtypes[1]
3234*da0073e9SAndroid Build Coastguard Worker
3235*da0073e9SAndroid Build Coastguard Worker        rng = np.random.default_rng()
3236*da0073e9SAndroid Build Coastguard Worker        input = np.array(
3237*da0073e9SAndroid Build Coastguard Worker            rng.integers(-10, 10, size=10),
3238*da0073e9SAndroid Build Coastguard Worker            dtype=torch_to_numpy_dtype_dict[
3239*da0073e9SAndroid Build Coastguard Worker                input_dtype if (input_dtype != torch.bfloat16) else torch.float64
3240*da0073e9SAndroid Build Coastguard Worker            ],
3241*da0073e9SAndroid Build Coastguard Worker        )
3242*da0073e9SAndroid Build Coastguard Worker        input[0] = input[3] = input[7] = 0
3243*da0073e9SAndroid Build Coastguard Worker        values = np.array(
3244*da0073e9SAndroid Build Coastguard Worker            rng.integers(-10, 10, size=10),
3245*da0073e9SAndroid Build Coastguard Worker            dtype=torch_to_numpy_dtype_dict[
3246*da0073e9SAndroid Build Coastguard Worker                values_dtype if (values_dtype != torch.bfloat16) else torch.float64
3247*da0073e9SAndroid Build Coastguard Worker            ],
3248*da0073e9SAndroid Build Coastguard Worker        )
3249*da0073e9SAndroid Build Coastguard Worker        np_result = torch.from_numpy(np.heaviside(input, values)).to(
3250*da0073e9SAndroid Build Coastguard Worker            device=device, dtype=input_dtype
3251*da0073e9SAndroid Build Coastguard Worker        )
3252*da0073e9SAndroid Build Coastguard Worker
3253*da0073e9SAndroid Build Coastguard Worker        input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
3254*da0073e9SAndroid Build Coastguard Worker        values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
3255*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(input)
3256*da0073e9SAndroid Build Coastguard Worker
3257*da0073e9SAndroid Build Coastguard Worker        if input_dtype == values_dtype:
3258*da0073e9SAndroid Build Coastguard Worker            torch_result = torch.heaviside(input, values)
3259*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(np_result, torch_result)
3260*da0073e9SAndroid Build Coastguard Worker
3261*da0073e9SAndroid Build Coastguard Worker            torch_result = input.heaviside(values)
3262*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(np_result, torch_result)
3263*da0073e9SAndroid Build Coastguard Worker
3264*da0073e9SAndroid Build Coastguard Worker            torch.heaviside(input, values, out=out)
3265*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(np_result, out)
3266*da0073e9SAndroid Build Coastguard Worker
3267*da0073e9SAndroid Build Coastguard Worker            input.heaviside_(values)
3268*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(np_result, input)
3269*da0073e9SAndroid Build Coastguard Worker        else:
3270*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
3271*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
3272*da0073e9SAndroid Build Coastguard Worker                "heaviside is not yet implemented for tensors with different dtypes.",
3273*da0073e9SAndroid Build Coastguard Worker            ):
3274*da0073e9SAndroid Build Coastguard Worker                torch.heaviside(input, values)
3275*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
3276*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
3277*da0073e9SAndroid Build Coastguard Worker                "heaviside is not yet implemented for tensors with different dtypes.",
3278*da0073e9SAndroid Build Coastguard Worker            ):
3279*da0073e9SAndroid Build Coastguard Worker                input.heaviside(values)
3280*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
3281*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
3282*da0073e9SAndroid Build Coastguard Worker                "heaviside is not yet implemented for tensors with different dtypes.",
3283*da0073e9SAndroid Build Coastguard Worker            ):
3284*da0073e9SAndroid Build Coastguard Worker                torch.heaviside(input, values, out=out)
3285*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
3286*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
3287*da0073e9SAndroid Build Coastguard Worker                "heaviside is not yet implemented for tensors with different dtypes.",
3288*da0073e9SAndroid Build Coastguard Worker            ):
3289*da0073e9SAndroid Build Coastguard Worker                input.heaviside_(values)
3290*da0073e9SAndroid Build Coastguard Worker
3291*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
3292*da0073e9SAndroid Build Coastguard Worker    def test_heaviside_cross_device(self, device):
3293*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([-9, 5, 0, 6, -2, 2], device=device)
3294*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(0)
3295*da0073e9SAndroid Build Coastguard Worker        result = torch.heaviside(x, y)
3296*da0073e9SAndroid Build Coastguard Worker        expect = torch.tensor([0, 1, 0, 1, 0, 1], device=device)
3297*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expect)
3298*da0073e9SAndroid Build Coastguard Worker
3299*da0073e9SAndroid Build Coastguard Worker        result = torch.heaviside(y, x)
3300*da0073e9SAndroid Build Coastguard Worker        expect = torch.tensor([-9, 5, 0, 6, -2, 2], device=device)
3301*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expect)
3302*da0073e9SAndroid Build Coastguard Worker
3303*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([-9, 5, 0, 6, -2, 2])
3304*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(0, device=device)
3305*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3306*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Expected all tensors to be on the same device"
3307*da0073e9SAndroid Build Coastguard Worker        ):
3308*da0073e9SAndroid Build Coastguard Worker            torch.heaviside(x, y)
3309*da0073e9SAndroid Build Coastguard Worker
3310*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3311*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Expected all tensors to be on the same device"
3312*da0073e9SAndroid Build Coastguard Worker        ):
3313*da0073e9SAndroid Build Coastguard Worker            torch.heaviside(y, x)
3314*da0073e9SAndroid Build Coastguard Worker
3315*da0073e9SAndroid Build Coastguard Worker    @dtypes(*list(product(complex_types(), complex_types())))
3316*da0073e9SAndroid Build Coastguard Worker    def test_heaviside_complex(self, device, dtypes):
3317*da0073e9SAndroid Build Coastguard Worker        input_dtype = dtypes[0]
3318*da0073e9SAndroid Build Coastguard Worker        values_dtype = dtypes[1]
3319*da0073e9SAndroid Build Coastguard Worker
3320*da0073e9SAndroid Build Coastguard Worker        data = (complex(0, -6), complex(-1, 3), complex(1, 1))
3321*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor(data, device=device, dtype=input_dtype)
3322*da0073e9SAndroid Build Coastguard Worker        values = torch.tensor(data, device=device, dtype=values_dtype)
3323*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(input)
3324*da0073e9SAndroid Build Coastguard Worker        real = input.real
3325*da0073e9SAndroid Build Coastguard Worker
3326*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3327*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "heaviside is not yet implemented for complex tensors."
3328*da0073e9SAndroid Build Coastguard Worker        ):
3329*da0073e9SAndroid Build Coastguard Worker            torch.heaviside(input, real)
3330*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3331*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "heaviside is not yet implemented for complex tensors."
3332*da0073e9SAndroid Build Coastguard Worker        ):
3333*da0073e9SAndroid Build Coastguard Worker            real.heaviside(values)
3334*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3335*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "heaviside is not yet implemented for complex tensors."
3336*da0073e9SAndroid Build Coastguard Worker        ):
3337*da0073e9SAndroid Build Coastguard Worker            input.heaviside_(values)
3338*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3339*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "heaviside is not yet implemented for complex tensors."
3340*da0073e9SAndroid Build Coastguard Worker        ):
3341*da0073e9SAndroid Build Coastguard Worker            torch.heaviside(real, real, out=out)
3342*da0073e9SAndroid Build Coastguard Worker
3343*da0073e9SAndroid Build Coastguard Worker    def _test_logical(self, device, dtypes, op, a_, b_, expected_res_):
3344*da0073e9SAndroid Build Coastguard Worker        expected_res = torch.tensor(expected_res_, dtype=dtypes[0], device=device)
3345*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(a_, dtype=dtypes[0], device=device)
3346*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(b_, dtype=dtypes[1], device=device)
3347*da0073e9SAndroid Build Coastguard Worker
3348*da0073e9SAndroid Build Coastguard Worker        # new tensor
3349*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res.bool(), getattr(a, op)(b))
3350*da0073e9SAndroid Build Coastguard Worker        # out
3351*da0073e9SAndroid Build Coastguard Worker        c = torch.empty(0, dtype=torch.bool, device=device)
3352*da0073e9SAndroid Build Coastguard Worker        getattr(torch, op)(a, b, out=c)
3353*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res.bool(), c)
3354*da0073e9SAndroid Build Coastguard Worker
3355*da0073e9SAndroid Build Coastguard Worker        getattr(a, op + "_")(b)
3356*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res, a)
3357*da0073e9SAndroid Build Coastguard Worker
3358*da0073e9SAndroid Build Coastguard Worker    @dtypes(
3359*da0073e9SAndroid Build Coastguard Worker        *product(
3360*da0073e9SAndroid Build Coastguard Worker            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3361*da0073e9SAndroid Build Coastguard Worker            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3362*da0073e9SAndroid Build Coastguard Worker        )
3363*da0073e9SAndroid Build Coastguard Worker    )
3364*da0073e9SAndroid Build Coastguard Worker    def test_logical_xor(self, device, dtypes):
3365*da0073e9SAndroid Build Coastguard Worker        self._test_logical(
3366*da0073e9SAndroid Build Coastguard Worker            device, dtypes, "logical_xor", [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1]
3367*da0073e9SAndroid Build Coastguard Worker        )
3368*da0073e9SAndroid Build Coastguard Worker
3369*da0073e9SAndroid Build Coastguard Worker    @dtypes(
3370*da0073e9SAndroid Build Coastguard Worker        *product(
3371*da0073e9SAndroid Build Coastguard Worker            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3372*da0073e9SAndroid Build Coastguard Worker            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3373*da0073e9SAndroid Build Coastguard Worker        )
3374*da0073e9SAndroid Build Coastguard Worker    )
3375*da0073e9SAndroid Build Coastguard Worker    def test_logical_and(self, device, dtypes):
3376*da0073e9SAndroid Build Coastguard Worker        self._test_logical(
3377*da0073e9SAndroid Build Coastguard Worker            device, dtypes, "logical_and", [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0]
3378*da0073e9SAndroid Build Coastguard Worker        )
3379*da0073e9SAndroid Build Coastguard Worker
3380*da0073e9SAndroid Build Coastguard Worker    @dtypes(
3381*da0073e9SAndroid Build Coastguard Worker        *product(
3382*da0073e9SAndroid Build Coastguard Worker            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3383*da0073e9SAndroid Build Coastguard Worker            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3384*da0073e9SAndroid Build Coastguard Worker        )
3385*da0073e9SAndroid Build Coastguard Worker    )
3386*da0073e9SAndroid Build Coastguard Worker    def test_logical_or(self, device, dtypes):
3387*da0073e9SAndroid Build Coastguard Worker        self._test_logical(
3388*da0073e9SAndroid Build Coastguard Worker            device, dtypes, "logical_or", [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1]
3389*da0073e9SAndroid Build Coastguard Worker        )
3390*da0073e9SAndroid Build Coastguard Worker
3391*da0073e9SAndroid Build Coastguard Worker    def test_remainder_overflow(self, device):
3392*da0073e9SAndroid Build Coastguard Worker        # Check Integer Overflows
3393*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(23500, dtype=torch.int64, device=device)
3394*da0073e9SAndroid Build Coastguard Worker        q = 392486996410368
3395*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x % q, x)
3396*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(-x % q, q - x)
3397*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x % -q, x - q)
3398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(-x % -q, -x)
3399*da0073e9SAndroid Build Coastguard Worker
3400*da0073e9SAndroid Build Coastguard Worker    def test_rpow(self, device):
3401*da0073e9SAndroid Build Coastguard Worker        m = torch.randn(10, 10, device=device)
3402*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.pow(2, m), 2**m)
3403*da0073e9SAndroid Build Coastguard Worker
3404*da0073e9SAndroid Build Coastguard Worker        # test with scalar
3405*da0073e9SAndroid Build Coastguard Worker        m = torch.randn(1, device=device).squeeze()
3406*da0073e9SAndroid Build Coastguard Worker        assert m.dim() == 0, "m is intentionally a scalar"
3407*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.pow(2, m), 2**m)
3408*da0073e9SAndroid Build Coastguard Worker
3409*da0073e9SAndroid Build Coastguard Worker    def test_ldexp(self, device):
3410*da0073e9SAndroid Build Coastguard Worker        # random values
3411*da0073e9SAndroid Build Coastguard Worker        mantissas = torch.randn(64, device=device)
3412*da0073e9SAndroid Build Coastguard Worker        exponents = torch.randint(-31, 31, (64,), device=device, dtype=torch.int32)
3413*da0073e9SAndroid Build Coastguard Worker
3414*da0073e9SAndroid Build Coastguard Worker        # basic test
3415*da0073e9SAndroid Build Coastguard Worker        np_outcome = np.ldexp(mantissas.cpu().numpy(), exponents.cpu().numpy())
3416*da0073e9SAndroid Build Coastguard Worker        pt_outcome_1 = torch.ldexp(mantissas, exponents)
3417*da0073e9SAndroid Build Coastguard Worker        pt_outcome_2 = mantissas.ldexp(exponents)
3418*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np_outcome, pt_outcome_1.cpu())
3419*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np_outcome, pt_outcome_2.cpu())
3420*da0073e9SAndroid Build Coastguard Worker        mantissas.ldexp_(exponents)
3421*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np_outcome, mantissas.cpu())
3422*da0073e9SAndroid Build Coastguard Worker
3423*da0073e9SAndroid Build Coastguard Worker        # test bounds
3424*da0073e9SAndroid Build Coastguard Worker        mantissas = torch.tensor(
3425*da0073e9SAndroid Build Coastguard Worker            [float("inf"), float("-inf"), float("inf"), float("nan")], device=device
3426*da0073e9SAndroid Build Coastguard Worker        )
3427*da0073e9SAndroid Build Coastguard Worker        exponents = torch.randint(0, 31, (4,), device=device, dtype=torch.int32)
3428*da0073e9SAndroid Build Coastguard Worker        np_outcome = np.ldexp(mantissas.cpu().numpy(), exponents.cpu().numpy())
3429*da0073e9SAndroid Build Coastguard Worker        pt_outcome = torch.ldexp(mantissas, exponents)
3430*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np_outcome, pt_outcome.cpu())
3431*da0073e9SAndroid Build Coastguard Worker
3432*da0073e9SAndroid Build Coastguard Worker        # test half dtype behavior
3433*da0073e9SAndroid Build Coastguard Worker        mantissas = torch.randn(64, device=device, dtype=torch.half)
3434*da0073e9SAndroid Build Coastguard Worker        exponents = torch.randint(-5, 5, (64,), device=device)
3435*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ldexp(mantissas, exponents).dtype, torch.half)
3436*da0073e9SAndroid Build Coastguard Worker
3437*da0073e9SAndroid Build Coastguard Worker        # test float64 computation
3438*da0073e9SAndroid Build Coastguard Worker        mantissas = torch.tensor([1], dtype=torch.float64, device=device)
3439*da0073e9SAndroid Build Coastguard Worker        exponents = torch.tensor([128], dtype=torch.int64, device=device)
3440*da0073e9SAndroid Build Coastguard Worker        expected = torch.pow(
3441*da0073e9SAndroid Build Coastguard Worker            torch.full((1,), 2, device=device, dtype=torch.float64), 128
3442*da0073e9SAndroid Build Coastguard Worker        )
3443*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ldexp(mantissas, exponents), expected)
3444*da0073e9SAndroid Build Coastguard Worker
3445*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3446*da0073e9SAndroid Build Coastguard Worker    def test_lerp(self, device, dtype):
3447*da0073e9SAndroid Build Coastguard Worker        start_end_weight_shapes = [(), (5,), (5, 5)]
3448*da0073e9SAndroid Build Coastguard Worker        for shapes in product(
3449*da0073e9SAndroid Build Coastguard Worker            start_end_weight_shapes, start_end_weight_shapes, start_end_weight_shapes
3450*da0073e9SAndroid Build Coastguard Worker        ):
3451*da0073e9SAndroid Build Coastguard Worker            start = torch.randn(shapes[0], device=device, dtype=dtype)
3452*da0073e9SAndroid Build Coastguard Worker            end = torch.randn(shapes[1], device=device, dtype=dtype)
3453*da0073e9SAndroid Build Coastguard Worker
3454*da0073e9SAndroid Build Coastguard Worker            # Tensor weights
3455*da0073e9SAndroid Build Coastguard Worker            weights = [
3456*da0073e9SAndroid Build Coastguard Worker                torch.randn(shapes[2], device=device, dtype=dtype),
3457*da0073e9SAndroid Build Coastguard Worker                random.random(),
3458*da0073e9SAndroid Build Coastguard Worker            ]
3459*da0073e9SAndroid Build Coastguard Worker            if dtype.is_complex:
3460*da0073e9SAndroid Build Coastguard Worker                weights += [complex(0, 1), complex(0.4, 1.2)]
3461*da0073e9SAndroid Build Coastguard Worker
3462*da0073e9SAndroid Build Coastguard Worker            for weight in weights:
3463*da0073e9SAndroid Build Coastguard Worker                actual = torch.lerp(start, end, weight)
3464*da0073e9SAndroid Build Coastguard Worker                actual_method = start.lerp(end, weight)
3465*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, actual_method)
3466*da0073e9SAndroid Build Coastguard Worker                actual_out = torch.tensor(1.0, dtype=dtype, device=device)
3467*da0073e9SAndroid Build Coastguard Worker                torch.lerp(start, end, weight, out=actual_out)
3468*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, actual_out)
3469*da0073e9SAndroid Build Coastguard Worker                expected = start + weight * (end - start)
3470*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected, actual)
3471*da0073e9SAndroid Build Coastguard Worker
3472*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
3473*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.bfloat16)
3474*da0073e9SAndroid Build Coastguard Worker    def test_lerp_lowp(self, device, dtype):
3475*da0073e9SAndroid Build Coastguard Worker        xvals = (0.0, -30000.0)
3476*da0073e9SAndroid Build Coastguard Worker        yvals = (0.1, -20000.0)
3477*da0073e9SAndroid Build Coastguard Worker        xs = [torch.full((4,), xval, device=device, dtype=dtype) for xval in xvals]
3478*da0073e9SAndroid Build Coastguard Worker        ys = [torch.full((4,), yval, device=device, dtype=dtype) for yval in yvals]
3479*da0073e9SAndroid Build Coastguard Worker        weights = [70000, torch.full((4,), 8, device=device, dtype=dtype)]
3480*da0073e9SAndroid Build Coastguard Worker        for x, y, w in zip(xs, ys, weights):
3481*da0073e9SAndroid Build Coastguard Worker            xref = x.float()
3482*da0073e9SAndroid Build Coastguard Worker            yref = y.float()
3483*da0073e9SAndroid Build Coastguard Worker            wref = w.float() if isinstance(w, torch.Tensor) else w
3484*da0073e9SAndroid Build Coastguard Worker            actual = torch.lerp(x, y, w)
3485*da0073e9SAndroid Build Coastguard Worker            expected = torch.lerp(xref, yref, wref).to(dtype)
3486*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, atol=0.0, rtol=0.0)
3487*da0073e9SAndroid Build Coastguard Worker
3488*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3489*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.bfloat16)
3490*da0073e9SAndroid Build Coastguard Worker    def test_lerp_lowp_cpu(self, device, dtype):
3491*da0073e9SAndroid Build Coastguard Worker        xvals = (0.0, -30000.0)
3492*da0073e9SAndroid Build Coastguard Worker        yvals = (0.1, -20000.0)
3493*da0073e9SAndroid Build Coastguard Worker        for shape in [(4,), (20,), (3, 10, 10)]:
3494*da0073e9SAndroid Build Coastguard Worker            xs = [torch.full(shape, xval, device=device, dtype=dtype) for xval in xvals]
3495*da0073e9SAndroid Build Coastguard Worker            ys = [torch.full(shape, yval, device=device, dtype=dtype) for yval in yvals]
3496*da0073e9SAndroid Build Coastguard Worker            weights = [70000, torch.full(shape, 8, device=device, dtype=dtype)]
3497*da0073e9SAndroid Build Coastguard Worker            for x, y, w in zip(xs, ys, weights):
3498*da0073e9SAndroid Build Coastguard Worker                xref = x.float()
3499*da0073e9SAndroid Build Coastguard Worker                yref = y.float()
3500*da0073e9SAndroid Build Coastguard Worker                wref = w.float() if isinstance(w, torch.Tensor) else w
3501*da0073e9SAndroid Build Coastguard Worker                actual = torch.lerp(x, y, w)
3502*da0073e9SAndroid Build Coastguard Worker                expected = torch.lerp(xref, yref, wref).to(dtype)
3503*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected, atol=0.0, rtol=0.0)
3504*da0073e9SAndroid Build Coastguard Worker
3505*da0073e9SAndroid Build Coastguard Worker    def _test_logaddexp(self, device, dtype, base2):
3506*da0073e9SAndroid Build Coastguard Worker        if base2:
3507*da0073e9SAndroid Build Coastguard Worker            ref_func = np.logaddexp2
3508*da0073e9SAndroid Build Coastguard Worker            our_func = torch.logaddexp2
3509*da0073e9SAndroid Build Coastguard Worker        elif dtype in (torch.complex64, torch.complex128):
3510*da0073e9SAndroid Build Coastguard Worker            # numpy has not implemented logaddexp for complex
3511*da0073e9SAndroid Build Coastguard Worker            def _ref_func(x, y):
3512*da0073e9SAndroid Build Coastguard Worker                return scipy.special.logsumexp(np.stack((x, y), axis=0), axis=0)
3513*da0073e9SAndroid Build Coastguard Worker
3514*da0073e9SAndroid Build Coastguard Worker            ref_func = _ref_func
3515*da0073e9SAndroid Build Coastguard Worker            our_func = torch.logaddexp
3516*da0073e9SAndroid Build Coastguard Worker        else:
3517*da0073e9SAndroid Build Coastguard Worker            ref_func = np.logaddexp
3518*da0073e9SAndroid Build Coastguard Worker            our_func = torch.logaddexp
3519*da0073e9SAndroid Build Coastguard Worker
3520*da0073e9SAndroid Build Coastguard Worker        def _test_helper(a, b):
3521*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16:
3522*da0073e9SAndroid Build Coastguard Worker                ref = ref_func(a.cpu().float().numpy(), b.cpu().float().numpy())
3523*da0073e9SAndroid Build Coastguard Worker                v = our_func(a, b)
3524*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ref, v.float(), atol=0.01, rtol=0.01)
3525*da0073e9SAndroid Build Coastguard Worker            else:
3526*da0073e9SAndroid Build Coastguard Worker                ref = ref_func(a.cpu().numpy(), b.cpu().numpy())
3527*da0073e9SAndroid Build Coastguard Worker                v = our_func(a, b)
3528*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ref, v)
3529*da0073e9SAndroid Build Coastguard Worker
3530*da0073e9SAndroid Build Coastguard Worker        # simple test
3531*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(64, 2, dtype=dtype, device=device) - 0.5
3532*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(64, 2, dtype=dtype, device=device) - 0.5
3533*da0073e9SAndroid Build Coastguard Worker        _test_helper(a, b)
3534*da0073e9SAndroid Build Coastguard Worker        _test_helper(a[:3], b[:3])
3535*da0073e9SAndroid Build Coastguard Worker
3536*da0073e9SAndroid Build Coastguard Worker        # large value test for numerical stability
3537*da0073e9SAndroid Build Coastguard Worker        a *= 10000
3538*da0073e9SAndroid Build Coastguard Worker        b *= 10000
3539*da0073e9SAndroid Build Coastguard Worker        _test_helper(a, b)
3540*da0073e9SAndroid Build Coastguard Worker        _test_helper(a[:3], b[:3])
3541*da0073e9SAndroid Build Coastguard Worker
3542*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(
3543*da0073e9SAndroid Build Coastguard Worker            [float("inf"), float("-inf"), float("inf"), float("nan")],
3544*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
3545*da0073e9SAndroid Build Coastguard Worker            device=device,
3546*da0073e9SAndroid Build Coastguard Worker        )
3547*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(
3548*da0073e9SAndroid Build Coastguard Worker            [float("inf"), float("-inf"), float("-inf"), float("nan")],
3549*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
3550*da0073e9SAndroid Build Coastguard Worker            device=device,
3551*da0073e9SAndroid Build Coastguard Worker        )
3552*da0073e9SAndroid Build Coastguard Worker        _test_helper(a, b)
3553*da0073e9SAndroid Build Coastguard Worker
    @skipIfTorchDynamo()  # complex infs/nans differ under Dynamo/Inductor
    # NOTE: complex dtypes are exercised on CPU only — the CUDA dtype list
    # below deliberately omits them.
    @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16)
    @dtypes(
        torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128
    )
    def test_logaddexp(self, device, dtype):
        """Check torch.logaddexp against the numpy/scipy reference (base e)."""
        self._test_logaddexp(device, dtype, base2=False)
3561*da0073e9SAndroid Build Coastguard Worker
    @dtypes(torch.float32, torch.float64, torch.bfloat16)
    def test_logaddexp2(self, device, dtype):
        """Check torch.logaddexp2 (base-2 variant) against the numpy reference."""
        self._test_logaddexp(device, dtype, base2=True)
3565*da0073e9SAndroid Build Coastguard Worker
    def test_add(self, device):
        """Broad smoke test of torch.add / ``+``.

        Covers: tensor+tensor on contiguous and non-contiguous views (checked
        element-by-element against a scalar loop), in-place add of a scalar,
        scalar/tensor mixing, uint8/bool/bfloat16/complex dtypes, the ``alpha``
        multiplier in its bool/float/complex/int forms, and the error cases
        for mismatched ``alpha`` and ``out`` dtypes.
        """
        # Shadows the `dtypes` decorator name locally; intentional in this file.
        dtypes = floating_and_complex_types()
        for dtype in dtypes:
            # [res] torch.add([res,] tensor1, tensor2)
            m1 = torch.randn(100, 100, dtype=dtype, device=device)
            v1 = torch.randn(100, dtype=dtype, device=device)

            # contiguous: row view `m1[4]` is contiguous
            res1 = torch.add(m1[4], v1)
            res2 = res1.clone().zero_()
            for i in range(m1.size(1)):
                res2[i] = m1[4, i] + v1[i]
            self.assertEqual(res1, res2)

            # NOTE(review): from here to the "inter-type" section the tensors
            # are created without `dtype=`, so these sub-checks always run in
            # the default float dtype regardless of the loop variable — confirm
            # whether that is intentional.
            m1 = torch.randn(100, 100, device=device)
            v1 = torch.randn(100, device=device)

            # non-contiguous: column view `m1[:, 4]` has stride 100
            res1 = torch.add(m1[:, 4], v1)
            res2 = res1.clone().zero_()
            for i in range(m1.size(0)):
                res2[i] = m1[i, 4] + v1[i]
            self.assertEqual(res1, res2)

            # [res] torch.add([res,] tensor, value)
            m1 = torch.randn(10, 10, device=device)

            # contiguous: in-place add of a scalar to one row
            res1 = m1.clone()
            res1[3].add_(2)
            res2 = m1.clone()
            for i in range(m1.size(1)):
                res2[3, i] = res2[3, i] + 2
            self.assertEqual(res1, res2)

            # non-contiguous: in-place add of a scalar to one column
            m1 = torch.randn(10, 10, device=device)
            res1 = m1.clone()
            res1[:, 3].add_(2)
            res2 = m1.clone()
            for i in range(m1.size(0)):
                res2[i, 3] = res2[i, 3] + 2
            self.assertEqual(res1, res2)

            # inter-type: Python scalar vs 0-d tensor operand must agree
            m1 = torch.randn(10, 10, dtype=dtype, device=device)
            self.assertEqual(m1 + 3, m1 + torch.tensor(3))
            self.assertEqual(3 + m1, torch.tensor(3) + m1)

            # contiguous + non-contiguous: result is freshly-allocated contiguous
            m1 = torch.randn(10, 10, dtype=dtype, device=device)
            m2 = torch.randn(10, 10, dtype=dtype, device=device).t()
            res = m1 + m2
            self.assertTrue(res.is_contiguous())
            self.assertEqual(res, m1 + m2.contiguous())

            # 1d + empty: broadcasting with a 0-element tensor yields empty
            m1 = torch.tensor([1.0], dtype=dtype, device=device)
            m2 = torch.tensor([], dtype=dtype, device=device)
            self.assertEqual(m1 + m2, [])

        # inter-type uint8: adding a Python int must not promote the dtype
        one = torch.tensor(1, dtype=torch.uint8, device=device)
        self.assertEqual(torch.add(one, 1), 2)
        self.assertEqual(torch.add(one, 1).dtype, torch.uint8)

        # bool: `+` acts as logical OR
        m1 = torch.tensor(
            [True, False, False, True, False, False], dtype=torch.bool, device=device
        )
        m2 = torch.tensor(
            [True, True, False, False, False, True], dtype=torch.bool, device=device
        )
        expected = torch.tensor(
            [True, True, False, True, False, True], dtype=torch.bool, device=device
        )
        self.assertEqual(m1 + m2, expected)

        # fused multiply add: alpha=0 on bool tensors leaves the input unchanged
        a = torch.zeros(2, 3, dtype=torch.bool, device=device)
        res = torch.add(a, a, alpha=0)
        expected = torch.zeros(2, 3, device=device).bool()
        self.assertEqual(res, expected)

        # bfloat16 (note: deliberately created on the default device)
        m1 = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
        m2 = torch.tensor([3.0, 4.0], dtype=torch.bfloat16)
        self.assertEqual(m1 + m2, torch.tensor([4.0, 6.0], dtype=torch.bfloat16))

        # different alpha types
        m1 = torch.tensor([2 + 3j, 4 + 5j], dtype=torch.complex64, device=device)
        m2 = torch.tensor([4 + 5j, 2 + 3j], dtype=torch.complex64, device=device)
        # add complex numbers with float alpha
        res = torch.add(m1, m2, alpha=0.1)
        expected = torch.tensor(
            [2.4000 + 3.5000j, 4.2000 + 5.3000j], dtype=torch.complex64, device=device
        )
        self.assertEqual(res, expected)

        # add complex numbers with complex alpha
        res = torch.add(m1, m2, alpha=complex(0.1, 0.2))
        expected = torch.tensor(
            [1.4000 + 4.3000j, 3.6000 + 5.7000j], dtype=torch.complex64, device=device
        )
        self.assertEqual(res, expected)

        # add complex numbers with integer alpha
        res = torch.add(m1, m2, alpha=2)
        expected = torch.tensor(
            [10.0 + 13.0j, 8.0 + 11.0j], dtype=torch.complex64, device=device
        )
        self.assertEqual(res, expected)

        # mismatched alpha: bool/float alpha is rejected for integral tensors
        m1 = torch.tensor([1], dtype=torch.int8, device=device)
        m2 = torch.tensor([2], dtype=torch.int8, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"Boolean alpha only supported for Boolean results\.",
            lambda: torch.add(m1, m2, alpha=True),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"For integral input tensors, argument alpha must not be a floating point number\.",
            lambda: torch.add(m1, m2, alpha=1.0),
        )

        # mismatched alpha, float / double tensor and complex alpha
        msg = r"For non-complex input tensors, argument alpha must not be a complex number\."
        m1 = torch.tensor([3.0, 4.0], device=device)
        m2 = torch.tensor([4.0, 3.0], device=device)
        self.assertRaisesRegex(
            RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))
        )

        m1 = torch.tensor([3.0, 4.0], dtype=torch.double, device=device)
        m2 = torch.tensor([4.0, 3.0], dtype=torch.double, device=device)
        self.assertRaisesRegex(
            RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))
        )

        # complex: a complex result cannot be written into a real `out` tensor
        m1 = torch.tensor((4.0000 + 4.0000j), dtype=torch.complex64)
        m2 = torch.tensor(4.0, dtype=torch.float64)
        self.assertRaisesRegex(
            RuntimeError,
            r"result type ComplexFloat can't be cast to the desired output type Double",
            lambda: torch.add(m1, m1, out=m2),
        )
3715*da0073e9SAndroid Build Coastguard Worker
3716*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
3717*da0073e9SAndroid Build Coastguard Worker    def test_addsub_half_tensor(self, device):
3718*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([60000.0], dtype=torch.half, device=device)
3719*da0073e9SAndroid Build Coastguard Worker        for op, y, alpha in (
3720*da0073e9SAndroid Build Coastguard Worker            (torch.add, torch.tensor([-60000.0], dtype=torch.half, device=device), 2),
3721*da0073e9SAndroid Build Coastguard Worker            (torch.sub, torch.tensor([60000.0], dtype=torch.half, device=device), 2),
3722*da0073e9SAndroid Build Coastguard Worker            (torch.add, -70000.0, 1),
3723*da0073e9SAndroid Build Coastguard Worker            (torch.sub, 70000.0, 1),
3724*da0073e9SAndroid Build Coastguard Worker        ):
3725*da0073e9SAndroid Build Coastguard Worker            actual = op(x, y, alpha=alpha)
3726*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(not (actual.isnan() or actual.isinf()))
3727*da0073e9SAndroid Build Coastguard Worker
3728*da0073e9SAndroid Build Coastguard Worker    def test_sub_typing(self, device):
3729*da0073e9SAndroid Build Coastguard Worker        m1 = torch.tensor(
3730*da0073e9SAndroid Build Coastguard Worker            [True, False, False, True, False, False], dtype=torch.bool, device=device
3731*da0073e9SAndroid Build Coastguard Worker        )
3732*da0073e9SAndroid Build Coastguard Worker        m2 = torch.tensor(
3733*da0073e9SAndroid Build Coastguard Worker            [True, True, False, False, False, True], dtype=torch.bool, device=device
3734*da0073e9SAndroid Build Coastguard Worker        )
3735*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
3736*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
3737*da0073e9SAndroid Build Coastguard Worker            r"Subtraction, the `\-` operator, with two bool tensors is not supported. "
3738*da0073e9SAndroid Build Coastguard Worker            r"Use the `\^` or `logical_xor\(\)` operator instead.",
3739*da0073e9SAndroid Build Coastguard Worker            lambda: m1 - m2,
3740*da0073e9SAndroid Build Coastguard Worker        )
3741*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
3742*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
3743*da0073e9SAndroid Build Coastguard Worker            r"Subtraction, the `\-` operator, with a bool tensor is not supported. "
3744*da0073e9SAndroid Build Coastguard Worker            r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.",
3745*da0073e9SAndroid Build Coastguard Worker            lambda: 1 - m1,
3746*da0073e9SAndroid Build Coastguard Worker        )
3747*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
3748*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
3749*da0073e9SAndroid Build Coastguard Worker            r"Subtraction, the `\-` operator, with a bool tensor is not supported. "
3750*da0073e9SAndroid Build Coastguard Worker            r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.",
3751*da0073e9SAndroid Build Coastguard Worker            lambda: m2 - 1,
3752*da0073e9SAndroid Build Coastguard Worker        )
3753*da0073e9SAndroid Build Coastguard Worker
3754*da0073e9SAndroid Build Coastguard Worker        # mismatched alpha
3755*da0073e9SAndroid Build Coastguard Worker        m1 = torch.tensor([1], dtype=torch.int8, device=device)
3756*da0073e9SAndroid Build Coastguard Worker        m2 = torch.tensor([2], dtype=torch.int8, device=device)
3757*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
3758*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
3759*da0073e9SAndroid Build Coastguard Worker            r"Boolean alpha only supported for Boolean results\.",
3760*da0073e9SAndroid Build Coastguard Worker            lambda: torch.sub(m1, m2, alpha=True),
3761*da0073e9SAndroid Build Coastguard Worker        )
3762*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
3763*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
3764*da0073e9SAndroid Build Coastguard Worker            r"For integral input tensors, argument alpha must not be a floating point number\.",
3765*da0073e9SAndroid Build Coastguard Worker            lambda: torch.sub(m1, m2, alpha=1.0),
3766*da0073e9SAndroid Build Coastguard Worker        )
3767*da0073e9SAndroid Build Coastguard Worker
3768*da0073e9SAndroid Build Coastguard Worker    def test_mul(self, device):
3769*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(10, 10, device=device)
3770*da0073e9SAndroid Build Coastguard Worker        res1 = m1.clone()
3771*da0073e9SAndroid Build Coastguard Worker        res1[:, 3].mul_(2)
3772*da0073e9SAndroid Build Coastguard Worker        res2 = m1.clone()
3773*da0073e9SAndroid Build Coastguard Worker        for i in range(res1.size(0)):
3774*da0073e9SAndroid Build Coastguard Worker            res2[i, 3] = res2[i, 3] * 2
3775*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
3776*da0073e9SAndroid Build Coastguard Worker
3777*da0073e9SAndroid Build Coastguard Worker        a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device)
3778*da0073e9SAndroid Build Coastguard Worker        a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device)
3779*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3780*da0073e9SAndroid Build Coastguard Worker            a1 * a2,
3781*da0073e9SAndroid Build Coastguard Worker            torch.tensor([True, False, False, False], dtype=torch.bool, device=device),
3782*da0073e9SAndroid Build Coastguard Worker        )
3783*da0073e9SAndroid Build Coastguard Worker
3784*da0073e9SAndroid Build Coastguard Worker        if device == "cpu":
3785*da0073e9SAndroid Build Coastguard Worker            a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device)
3786*da0073e9SAndroid Build Coastguard Worker            a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device)
3787*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
3788*da0073e9SAndroid Build Coastguard Worker                a1 * a2,
3789*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device),
3790*da0073e9SAndroid Build Coastguard Worker                atol=0.01,
3791*da0073e9SAndroid Build Coastguard Worker                rtol=0,
3792*da0073e9SAndroid Build Coastguard Worker            )
3793*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a1.mul(a2), a1 * a2)
3794*da0073e9SAndroid Build Coastguard Worker
3795*da0073e9SAndroid Build Coastguard Worker    def test_bool_tensor_comparison_ops(self, device):
3796*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(
3797*da0073e9SAndroid Build Coastguard Worker            [True, False, True, False, True, False], dtype=torch.bool, device=device
3798*da0073e9SAndroid Build Coastguard Worker        )
3799*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(
3800*da0073e9SAndroid Build Coastguard Worker            [True, False, True, True, True, True], dtype=torch.bool, device=device
3801*da0073e9SAndroid Build Coastguard Worker        )
3802*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3803*da0073e9SAndroid Build Coastguard Worker            a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)
3804*da0073e9SAndroid Build Coastguard Worker        )
3805*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3806*da0073e9SAndroid Build Coastguard Worker            a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)
3807*da0073e9SAndroid Build Coastguard Worker        )
3808*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3809*da0073e9SAndroid Build Coastguard Worker            a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)
3810*da0073e9SAndroid Build Coastguard Worker        )
3811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3812*da0073e9SAndroid Build Coastguard Worker            a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device)
3813*da0073e9SAndroid Build Coastguard Worker        )
3814*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3815*da0073e9SAndroid Build Coastguard Worker            a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)
3816*da0073e9SAndroid Build Coastguard Worker        )
3817*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3818*da0073e9SAndroid Build Coastguard Worker            a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device)
3819*da0073e9SAndroid Build Coastguard Worker        )
3820*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3821*da0073e9SAndroid Build Coastguard Worker            a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)
3822*da0073e9SAndroid Build Coastguard Worker        )
3823*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3824*da0073e9SAndroid Build Coastguard Worker            a == torch.tensor(True, dtype=torch.bool, device=device),
3825*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device),
3826*da0073e9SAndroid Build Coastguard Worker        )
3827*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3828*da0073e9SAndroid Build Coastguard Worker            a == torch.tensor(0, dtype=torch.bool, device=device),
3829*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device),
3830*da0073e9SAndroid Build Coastguard Worker        )
3831*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(a.equal(b))
3832*da0073e9SAndroid Build Coastguard Worker
3833*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half, torch.bfloat16, torch.bool))
3834*da0073e9SAndroid Build Coastguard Worker    def test_logical(self, device, dtype):
3835*da0073e9SAndroid Build Coastguard Worker        if dtype != torch.bool:
3836*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype)
3837*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor([2], device=device, dtype=dtype)
3838*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.lt(2), torch.tensor([True, False, False, False]))
3839*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.le(2), torch.tensor([True, True, False, False]))
3840*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.ge(2), torch.tensor([False, True, True, True]))
3841*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.gt(2), torch.tensor([False, False, True, True]))
3842*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.eq(2), torch.tensor([False, True, False, False]))
3843*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.ne(2), torch.tensor([True, False, True, True]))
3844*da0073e9SAndroid Build Coastguard Worker
3845*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.lt(b), torch.tensor([True, False, False, False]))
3846*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.le(b), torch.tensor([True, True, False, False]))
3847*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.ge(b), torch.tensor([False, True, True, True]))
3848*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.gt(b), torch.tensor([False, False, True, True]))
3849*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.eq(b), torch.tensor([False, True, False, False]))
3850*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.ne(b), torch.tensor([True, False, True, True]))
3851*da0073e9SAndroid Build Coastguard Worker        else:
3852*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([True, False, True, False], device=device)
3853*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.lt(True), torch.tensor([False, True, False, True]))
3854*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.le(True), torch.tensor([True, True, True, True]))
3855*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.ge(True), torch.tensor([True, False, True, False]))
3856*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.gt(True), torch.tensor([False, False, False, False]))
3857*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.eq(True), torch.tensor([True, False, True, False]))
3858*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.ne(True), torch.tensor([False, True, False, True]))
3859*da0073e9SAndroid Build Coastguard Worker
3860*da0073e9SAndroid Build Coastguard Worker    def test_atan2(self, device):
3861*da0073e9SAndroid Build Coastguard Worker        def _test_atan2_with_size(size, device):
3862*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(size=size, device=device, dtype=torch.double)
3863*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(size=size, device=device, dtype=torch.double)
3864*da0073e9SAndroid Build Coastguard Worker            actual = a.atan2(b)
3865*da0073e9SAndroid Build Coastguard Worker            x = a.view(-1)
3866*da0073e9SAndroid Build Coastguard Worker            y = b.view(-1)
3867*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor(
3868*da0073e9SAndroid Build Coastguard Worker                [math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())],
3869*da0073e9SAndroid Build Coastguard Worker                device=device,
3870*da0073e9SAndroid Build Coastguard Worker                dtype=torch.double,
3871*da0073e9SAndroid Build Coastguard Worker            )
3872*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02)
3873*da0073e9SAndroid Build Coastguard Worker
3874*da0073e9SAndroid Build Coastguard Worker            # bfloat16/float16
3875*da0073e9SAndroid Build Coastguard Worker            for lowp_dtype in [torch.bfloat16, torch.float16]:
3876*da0073e9SAndroid Build Coastguard Worker                if lowp_dtype == torch.bfloat16:
3877*da0073e9SAndroid Build Coastguard Worker                    rtol = 0
3878*da0073e9SAndroid Build Coastguard Worker                    atol = 0.02
3879*da0073e9SAndroid Build Coastguard Worker                else:
3880*da0073e9SAndroid Build Coastguard Worker                    rtol = 0
3881*da0073e9SAndroid Build Coastguard Worker                    atol = 0.001
3882*da0073e9SAndroid Build Coastguard Worker                a_16 = a.to(dtype=lowp_dtype)
3883*da0073e9SAndroid Build Coastguard Worker                b_16 = b.to(dtype=lowp_dtype)
3884*da0073e9SAndroid Build Coastguard Worker                actual_16 = a_16.atan2(b_16)
3885*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_16, actual.to(dtype=lowp_dtype))
3886*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
3887*da0073e9SAndroid Build Coastguard Worker                    expected,
3888*da0073e9SAndroid Build Coastguard Worker                    actual_16.view(-1),
3889*da0073e9SAndroid Build Coastguard Worker                    exact_dtype=False,
3890*da0073e9SAndroid Build Coastguard Worker                    rtol=rtol,
3891*da0073e9SAndroid Build Coastguard Worker                    atol=atol,
3892*da0073e9SAndroid Build Coastguard Worker                )
3893*da0073e9SAndroid Build Coastguard Worker
3894*da0073e9SAndroid Build Coastguard Worker        _test_atan2_with_size((2, 2), device)
3895*da0073e9SAndroid Build Coastguard Worker        _test_atan2_with_size((3, 3), device)
3896*da0073e9SAndroid Build Coastguard Worker        _test_atan2_with_size((5, 5), device)
3897*da0073e9SAndroid Build Coastguard Worker
3898*da0073e9SAndroid Build Coastguard Worker    def test_atan2_edgecases(self, device):
3899*da0073e9SAndroid Build Coastguard Worker        def _test_atan2(x, y, expected, device, dtype):
3900*da0073e9SAndroid Build Coastguard Worker            expected_tensor = torch.tensor([expected], dtype=dtype, device=device)
3901*da0073e9SAndroid Build Coastguard Worker            x_tensor = torch.tensor([x], dtype=dtype, device=device)
3902*da0073e9SAndroid Build Coastguard Worker            y_tensor = torch.tensor([y], dtype=dtype, device=device)
3903*da0073e9SAndroid Build Coastguard Worker            actual = torch.atan2(y_tensor, x_tensor)
3904*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_tensor, actual, rtol=0, atol=0.02)
3905*da0073e9SAndroid Build Coastguard Worker
3906*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float, torch.double]:
3907*da0073e9SAndroid Build Coastguard Worker            _test_atan2(0, 0, 0, device, dtype)
3908*da0073e9SAndroid Build Coastguard Worker            _test_atan2(0, 1, math.pi / 2, device, dtype)
3909*da0073e9SAndroid Build Coastguard Worker            _test_atan2(0, -1, math.pi / -2, device, dtype)
3910*da0073e9SAndroid Build Coastguard Worker            _test_atan2(-1, 0, math.pi, device, dtype)
3911*da0073e9SAndroid Build Coastguard Worker            _test_atan2(1, 0, 0, device, dtype)
3912*da0073e9SAndroid Build Coastguard Worker            _test_atan2(-1, -1, math.pi * -3 / 4, device, dtype)
3913*da0073e9SAndroid Build Coastguard Worker            _test_atan2(1, 1, math.pi / 4, device, dtype)
3914*da0073e9SAndroid Build Coastguard Worker            _test_atan2(1, -1, math.pi / -4, device, dtype)
3915*da0073e9SAndroid Build Coastguard Worker            _test_atan2(-1, 1, math.pi * 3 / 4, device, dtype)
3916*da0073e9SAndroid Build Coastguard Worker
3917*da0073e9SAndroid Build Coastguard Worker    def test_trapezoid(self, device):
3918*da0073e9SAndroid Build Coastguard Worker        def test_dx(sizes, dim, dx, device):
3919*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(sizes, device=device)
3920*da0073e9SAndroid Build Coastguard Worker            actual = torch.trapezoid(t, dx=dx, dim=dim)
3921*da0073e9SAndroid Build Coastguard Worker            expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim)  # noqa: NPY201
3922*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected.shape, actual.shape)
3923*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual, exact_dtype=False)
3924*da0073e9SAndroid Build Coastguard Worker
3925*da0073e9SAndroid Build Coastguard Worker        def test_x(sizes, dim, x, device):
3926*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(sizes, device=device)
3927*da0073e9SAndroid Build Coastguard Worker            actual = torch.trapezoid(t, x=torch.tensor(x, device=device), dim=dim)
3928*da0073e9SAndroid Build Coastguard Worker            expected = np.trapz(t.cpu().numpy(), x=x, axis=dim)  # noqa: NPY201
3929*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected.shape, actual.shape)
3930*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual.cpu(), exact_dtype=False)
3931*da0073e9SAndroid Build Coastguard Worker
3932*da0073e9SAndroid Build Coastguard Worker        test_dx((2, 3, 4), 1, 1, device)
3933*da0073e9SAndroid Build Coastguard Worker        test_dx((10, 2), 0, 0.1, device)
3934*da0073e9SAndroid Build Coastguard Worker        test_dx((1, 10), 0, 2.3, device)
3935*da0073e9SAndroid Build Coastguard Worker        test_dx((0, 2), 0, 1.0, device)
3936*da0073e9SAndroid Build Coastguard Worker        test_dx((0, 2), 1, 1.0, device)
3937*da0073e9SAndroid Build Coastguard Worker        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
3938*da0073e9SAndroid Build Coastguard Worker        test_x(
3939*da0073e9SAndroid Build Coastguard Worker            (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device
3940*da0073e9SAndroid Build Coastguard Worker        )
3941*da0073e9SAndroid Build Coastguard Worker        test_x((1, 10), 0, [1.0], device)
3942*da0073e9SAndroid Build Coastguard Worker        test_x((0, 2), 0, [], device)
3943*da0073e9SAndroid Build Coastguard Worker        test_x((0, 2), 1, [1.0, 2.0], device)
3944*da0073e9SAndroid Build Coastguard Worker        test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device)
3945*da0073e9SAndroid Build Coastguard Worker        test_x((2, 3, 4), 0, [1.0, 2.0], device)
3946*da0073e9SAndroid Build Coastguard Worker        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
3947*da0073e9SAndroid Build Coastguard Worker        test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device)
3948*da0073e9SAndroid Build Coastguard Worker        test_x((2, 2, 4), -1, [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], device)
3949*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
3950*da0073e9SAndroid Build Coastguard Worker            test_x((2, 3), 2, [], device)
3951*da0073e9SAndroid Build Coastguard Worker            test_dx((2, 3), 2, 1.0, device)
3952*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3953*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "There must be one `x` value for each sample point"
3954*da0073e9SAndroid Build Coastguard Worker        ):
3955*da0073e9SAndroid Build Coastguard Worker            test_x((2, 3), 1, [1.0, 2.0], device)
3956*da0073e9SAndroid Build Coastguard Worker            test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)
3957*da0073e9SAndroid Build Coastguard Worker
3958*da0073e9SAndroid Build Coastguard Worker    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
3959*da0073e9SAndroid Build Coastguard Worker    def test_cumulative_trapezoid(self, device):
3960*da0073e9SAndroid Build Coastguard Worker        import scipy.integrate
3961*da0073e9SAndroid Build Coastguard Worker
3962*da0073e9SAndroid Build Coastguard Worker        if hasattr(scipy.integrate, "cumulative_trapezoid"):
3963*da0073e9SAndroid Build Coastguard Worker            _scipy_cumulative_trapezoid = scipy.integrate.cumulative_trapezoid
3964*da0073e9SAndroid Build Coastguard Worker        else:  # Older version of SciPy uses a different name
3965*da0073e9SAndroid Build Coastguard Worker            _scipy_cumulative_trapezoid = scipy.integrate.cumtrapz
3966*da0073e9SAndroid Build Coastguard Worker
3967*da0073e9SAndroid Build Coastguard Worker        def scipy_cumulative_trapezoid(y, x=None, dx=1.0, axis=-1, initial=None):
3968*da0073e9SAndroid Build Coastguard Worker            if y.shape[axis] == 0:
3969*da0073e9SAndroid Build Coastguard Worker                return np.empty_like(y)
3970*da0073e9SAndroid Build Coastguard Worker            else:
3971*da0073e9SAndroid Build Coastguard Worker                return _scipy_cumulative_trapezoid(y, x, dx, axis, initial)
3972*da0073e9SAndroid Build Coastguard Worker
3973*da0073e9SAndroid Build Coastguard Worker        def test_dx(sizes, dim, dx, device):
3974*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(sizes, device=device)
3975*da0073e9SAndroid Build Coastguard Worker            y = t.cpu().numpy()
3976*da0073e9SAndroid Build Coastguard Worker            actual = torch.cumulative_trapezoid(t, dx=dx, dim=dim)
3977*da0073e9SAndroid Build Coastguard Worker            expected = scipy_cumulative_trapezoid(t.cpu().numpy(), dx=dx, axis=dim)
3978*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected.shape, actual.shape)
3979*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual, exact_dtype=False, atol=1e-4, rtol=1e-4)
3980*da0073e9SAndroid Build Coastguard Worker
3981*da0073e9SAndroid Build Coastguard Worker        def test_x(sizes, dim, x, device):
3982*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(sizes, device=device)
3983*da0073e9SAndroid Build Coastguard Worker            actual = torch.cumulative_trapezoid(
3984*da0073e9SAndroid Build Coastguard Worker                t, x=torch.tensor(x, device=device), dim=dim
3985*da0073e9SAndroid Build Coastguard Worker            )
3986*da0073e9SAndroid Build Coastguard Worker            expected = scipy_cumulative_trapezoid(t.cpu().numpy(), x=x, axis=dim)
3987*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected.shape, actual.shape)
3988*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
3989*da0073e9SAndroid Build Coastguard Worker                expected, actual.cpu(), exact_dtype=False, atol=1e-4, rtol=1e-4
3990*da0073e9SAndroid Build Coastguard Worker            )
3991*da0073e9SAndroid Build Coastguard Worker
3992*da0073e9SAndroid Build Coastguard Worker        def test_empty_x(sizes, dim, x, device):
3993*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(sizes, device=device)
3994*da0073e9SAndroid Build Coastguard Worker            actual = torch.cumulative_trapezoid(
3995*da0073e9SAndroid Build Coastguard Worker                t, x=torch.tensor(x, device=device), dim=dim
3996*da0073e9SAndroid Build Coastguard Worker            )
3997*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.empty(actual.shape), actual)
3998*da0073e9SAndroid Build Coastguard Worker
3999*da0073e9SAndroid Build Coastguard Worker        test_dx((2,), -1, 1, device)
4000*da0073e9SAndroid Build Coastguard Worker        test_dx((3, 3), -1, 1, device)
4001*da0073e9SAndroid Build Coastguard Worker        test_dx((4, 2), 0, 1, device)
4002*da0073e9SAndroid Build Coastguard Worker        test_dx((2, 3, 4), 1, 1, device)
4003*da0073e9SAndroid Build Coastguard Worker        test_dx((10, 2), 0, 0.1, device)
4004*da0073e9SAndroid Build Coastguard Worker        test_dx((1, 10), 0, 2.3, device)
4005*da0073e9SAndroid Build Coastguard Worker        test_dx((0, 2), 0, 1.0, device)
4006*da0073e9SAndroid Build Coastguard Worker        test_dx((0, 2), 1, 1.0, device)
4007*da0073e9SAndroid Build Coastguard Worker        test_dx((512, 512), 1, 1.0, device)
4008*da0073e9SAndroid Build Coastguard Worker        test_dx((100, 100, 100), 1, 1.0, device)
4009*da0073e9SAndroid Build Coastguard Worker
4010*da0073e9SAndroid Build Coastguard Worker        test_x((2,), -1, [100, 50], device)
4011*da0073e9SAndroid Build Coastguard Worker        test_x((4, 2), 0, [2, 3, 4, 5], device)
4012*da0073e9SAndroid Build Coastguard Worker        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
4013*da0073e9SAndroid Build Coastguard Worker        test_x(
4014*da0073e9SAndroid Build Coastguard Worker            (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device
4015*da0073e9SAndroid Build Coastguard Worker        )
4016*da0073e9SAndroid Build Coastguard Worker        test_x((1, 10), 0, [1.0], device)
4017*da0073e9SAndroid Build Coastguard Worker        test_x((0, 2), 1, [1, 2], device)
4018*da0073e9SAndroid Build Coastguard Worker        test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device)
4019*da0073e9SAndroid Build Coastguard Worker        test_x((2, 3, 4), 0, [1.0, 2.0], device)
4020*da0073e9SAndroid Build Coastguard Worker        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
4021*da0073e9SAndroid Build Coastguard Worker        test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device)
4022*da0073e9SAndroid Build Coastguard Worker
4023*da0073e9SAndroid Build Coastguard Worker        test_empty_x(
4024*da0073e9SAndroid Build Coastguard Worker            (0, 2), 0, [], device
4025*da0073e9SAndroid Build Coastguard Worker        )  # SciPy failing when x == [], but our version returns empty
4026*da0073e9SAndroid Build Coastguard Worker
4027*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
4028*da0073e9SAndroid Build Coastguard Worker            test_x((2, 3), 2, [], device)
4029*da0073e9SAndroid Build Coastguard Worker            test_dx((2, 3), 2, 1.0, device)
4030*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4031*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "There must be one `x` value for each sample point"
4032*da0073e9SAndroid Build Coastguard Worker        ):
4033*da0073e9SAndroid Build Coastguard Worker            test_x((2, 3), 1, [1.0, 2.0], device)
4034*da0073e9SAndroid Build Coastguard Worker            test_x((0, 2), 0, [1.0, 2.0], device)
4035*da0073e9SAndroid Build Coastguard Worker            test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)
4036*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4037*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Currently, we only support dx as a real number"
4038*da0073e9SAndroid Build Coastguard Worker        ):
4039*da0073e9SAndroid Build Coastguard Worker            test_dx((2, 2), -1, complex(1, 1), device)
4040*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4041*da0073e9SAndroid Build Coastguard Worker            TypeError, "received an invalid combination of arguments"
4042*da0073e9SAndroid Build Coastguard Worker        ):
4043*da0073e9SAndroid Build Coastguard Worker            actual = torch.cumulative_trapezoid(
4044*da0073e9SAndroid Build Coastguard Worker                torch.randn((3, 3)), x=torch.randn((3, 3)), dx=3
4045*da0073e9SAndroid Build Coastguard Worker            )
4046*da0073e9SAndroid Build Coastguard Worker
4047*da0073e9SAndroid Build Coastguard Worker    @skipMeta
4048*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
4049*da0073e9SAndroid Build Coastguard Worker    def test_pow_scalar_overloads_mem_overlap(self, device, dtype):
4050*da0073e9SAndroid Build Coastguard Worker        sz = 3
4051*da0073e9SAndroid Build Coastguard Worker        doubles = torch.randn(2 * sz, dtype=dtype, device=device)
4052*da0073e9SAndroid Build Coastguard Worker        self.check_internal_mem_overlap(lambda t: t.pow_(42), 1, dtype, device)
4053*da0073e9SAndroid Build Coastguard Worker        self.unary_check_input_output_mem_overlap(
4054*da0073e9SAndroid Build Coastguard Worker            doubles, sz, lambda input, out: torch.pow(input, 42, out=out)
4055*da0073e9SAndroid Build Coastguard Worker        )
4056*da0073e9SAndroid Build Coastguard Worker        self.unary_check_input_output_mem_overlap(
4057*da0073e9SAndroid Build Coastguard Worker            doubles, sz, lambda input, out: torch.pow(42, input, out=out)
4058*da0073e9SAndroid Build Coastguard Worker        )
4059*da0073e9SAndroid Build Coastguard Worker
4060*da0073e9SAndroid Build Coastguard Worker    @dtypes(
4061*da0073e9SAndroid Build Coastguard Worker        *list(
4062*da0073e9SAndroid Build Coastguard Worker            product(
4063*da0073e9SAndroid Build Coastguard Worker                all_types_and_complex_and(torch.half, torch.bfloat16),
4064*da0073e9SAndroid Build Coastguard Worker                all_types_and_complex_and(torch.half, torch.bfloat16),
4065*da0073e9SAndroid Build Coastguard Worker            )
4066*da0073e9SAndroid Build Coastguard Worker        )
4067*da0073e9SAndroid Build Coastguard Worker    )
4068*da0073e9SAndroid Build Coastguard Worker    def test_float_power(self, device, dtypes):
4069*da0073e9SAndroid Build Coastguard Worker        def to_np(value):
4070*da0073e9SAndroid Build Coastguard Worker            if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
4071*da0073e9SAndroid Build Coastguard Worker                return value.to(torch.float).cpu().numpy()
4072*da0073e9SAndroid Build Coastguard Worker            return value.cpu().numpy() if isinstance(value, torch.Tensor) else value
4073*da0073e9SAndroid Build Coastguard Worker
4074*da0073e9SAndroid Build Coastguard Worker        base_dtype = dtypes[0]
4075*da0073e9SAndroid Build Coastguard Worker        exp_dtype = dtypes[1]
4076*da0073e9SAndroid Build Coastguard Worker        out_dtype = (
4077*da0073e9SAndroid Build Coastguard Worker            torch.complex128
4078*da0073e9SAndroid Build Coastguard Worker            if base_dtype.is_complex or exp_dtype.is_complex
4079*da0073e9SAndroid Build Coastguard Worker            else torch.float64
4080*da0073e9SAndroid Build Coastguard Worker        )
4081*da0073e9SAndroid Build Coastguard Worker
4082*da0073e9SAndroid Build Coastguard Worker        base = make_tensor((30,), dtype=base_dtype, device=device, low=1, high=100)
4083*da0073e9SAndroid Build Coastguard Worker        # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0
4084*da0073e9SAndroid Build Coastguard Worker        # Related: https://github.com/pytorch/pytorch/issues/48000
4085*da0073e9SAndroid Build Coastguard Worker        # base[0] = base[3] = base[7] = 0
4086*da0073e9SAndroid Build Coastguard Worker        exp = make_tensor((30,), dtype=exp_dtype, device=device, low=-2, high=2)
4087*da0073e9SAndroid Build Coastguard Worker        exp[0] = exp[4] = exp[6] = 0
4088*da0073e9SAndroid Build Coastguard Worker
4089*da0073e9SAndroid Build Coastguard Worker        expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))
4090*da0073e9SAndroid Build Coastguard Worker
4091*da0073e9SAndroid Build Coastguard Worker        exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
4092*da0073e9SAndroid Build Coastguard Worker        complex_exponents = exponents + [
4093*da0073e9SAndroid Build Coastguard Worker            -2.5j,
4094*da0073e9SAndroid Build Coastguard Worker            -1.0j,
4095*da0073e9SAndroid Build Coastguard Worker            1.0j,
4096*da0073e9SAndroid Build Coastguard Worker            2.5j,
4097*da0073e9SAndroid Build Coastguard Worker            1.0 + 1.0j,
4098*da0073e9SAndroid Build Coastguard Worker            -1.0 - 1.5j,
4099*da0073e9SAndroid Build Coastguard Worker            3.3j,
4100*da0073e9SAndroid Build Coastguard Worker        ]
4101*da0073e9SAndroid Build Coastguard Worker
4102*da0073e9SAndroid Build Coastguard Worker        for op in (
4103*da0073e9SAndroid Build Coastguard Worker            torch.float_power,
4104*da0073e9SAndroid Build Coastguard Worker            torch.Tensor.float_power,
4105*da0073e9SAndroid Build Coastguard Worker            torch.Tensor.float_power_,
4106*da0073e9SAndroid Build Coastguard Worker        ):
4107*da0073e9SAndroid Build Coastguard Worker            # Case of Tensor x Tensor
4108*da0073e9SAndroid Build Coastguard Worker            if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
4109*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
4110*da0073e9SAndroid Build Coastguard Worker                    RuntimeError, "operation's result requires dtype"
4111*da0073e9SAndroid Build Coastguard Worker                ):
4112*da0073e9SAndroid Build Coastguard Worker                    op(base.clone(), exp)
4113*da0073e9SAndroid Build Coastguard Worker            else:
4114*da0073e9SAndroid Build Coastguard Worker                result = op(base.clone(), exp)
4115*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected, result)
4116*da0073e9SAndroid Build Coastguard Worker
4117*da0073e9SAndroid Build Coastguard Worker            if op is torch.float_power:
4118*da0073e9SAndroid Build Coastguard Worker                out = torch.empty_like(base).to(device=device, dtype=out_dtype)
4119*da0073e9SAndroid Build Coastguard Worker                op(base, exp, out=out)
4120*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected, out)
4121*da0073e9SAndroid Build Coastguard Worker
4122*da0073e9SAndroid Build Coastguard Worker            # Case of Tensor x Scalar
4123*da0073e9SAndroid Build Coastguard Worker            for i in complex_exponents if exp_dtype.is_complex else exponents:
4124*da0073e9SAndroid Build Coastguard Worker                out_dtype_scalar_exp = (
4125*da0073e9SAndroid Build Coastguard Worker                    torch.complex128
4126*da0073e9SAndroid Build Coastguard Worker                    if base_dtype.is_complex or type(i) == complex
4127*da0073e9SAndroid Build Coastguard Worker                    else torch.float64
4128*da0073e9SAndroid Build Coastguard Worker                )
4129*da0073e9SAndroid Build Coastguard Worker                expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))
4130*da0073e9SAndroid Build Coastguard Worker
4131*da0073e9SAndroid Build Coastguard Worker                if (
4132*da0073e9SAndroid Build Coastguard Worker                    op is torch.Tensor.float_power_
4133*da0073e9SAndroid Build Coastguard Worker                    and base_dtype != out_dtype_scalar_exp
4134*da0073e9SAndroid Build Coastguard Worker                ):
4135*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(
4136*da0073e9SAndroid Build Coastguard Worker                        RuntimeError, "operation's result requires dtype"
4137*da0073e9SAndroid Build Coastguard Worker                    ):
4138*da0073e9SAndroid Build Coastguard Worker                        op(base.clone(), i)
4139*da0073e9SAndroid Build Coastguard Worker                else:
4140*da0073e9SAndroid Build Coastguard Worker                    result = op(base.clone(), i)
4141*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expected_scalar_exp, result)
4142*da0073e9SAndroid Build Coastguard Worker
4143*da0073e9SAndroid Build Coastguard Worker                if op is torch.float_power:
4144*da0073e9SAndroid Build Coastguard Worker                    out = torch.empty_like(base).to(
4145*da0073e9SAndroid Build Coastguard Worker                        device=device, dtype=out_dtype_scalar_exp
4146*da0073e9SAndroid Build Coastguard Worker                    )
4147*da0073e9SAndroid Build Coastguard Worker                    op(base, i, out=out)
4148*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expected_scalar_exp, out)
4149*da0073e9SAndroid Build Coastguard Worker
4150*da0073e9SAndroid Build Coastguard Worker        # Case of Scalar x Tensor
4151*da0073e9SAndroid Build Coastguard Worker        for i in complex_exponents if base_dtype.is_complex else exponents:
4152*da0073e9SAndroid Build Coastguard Worker            out_dtype_scalar_base = (
4153*da0073e9SAndroid Build Coastguard Worker                torch.complex128
4154*da0073e9SAndroid Build Coastguard Worker                if exp_dtype.is_complex or type(i) == complex
4155*da0073e9SAndroid Build Coastguard Worker                else torch.float64
4156*da0073e9SAndroid Build Coastguard Worker            )
4157*da0073e9SAndroid Build Coastguard Worker            expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))
4158*da0073e9SAndroid Build Coastguard Worker
4159*da0073e9SAndroid Build Coastguard Worker            result = torch.float_power(i, exp)
4160*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_scalar_base, result)
4161*da0073e9SAndroid Build Coastguard Worker
4162*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
4163*da0073e9SAndroid Build Coastguard Worker            torch.float_power(i, exp, out=out)
4164*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_scalar_base, out)
4165*da0073e9SAndroid Build Coastguard Worker
4166*da0073e9SAndroid Build Coastguard Worker    def test_float_power_exceptions(self, device):
4167*da0073e9SAndroid Build Coastguard Worker        def _promo_helper(x, y):
4168*da0073e9SAndroid Build Coastguard Worker            for i in (x, y):
4169*da0073e9SAndroid Build Coastguard Worker                if type(i) == complex:
4170*da0073e9SAndroid Build Coastguard Worker                    return torch.complex128
4171*da0073e9SAndroid Build Coastguard Worker                elif type(i) == torch.Tensor and i.is_complex():
4172*da0073e9SAndroid Build Coastguard Worker                    return torch.complex128
4173*da0073e9SAndroid Build Coastguard Worker            return torch.double
4174*da0073e9SAndroid Build Coastguard Worker
4175*da0073e9SAndroid Build Coastguard Worker        test_cases = (
4176*da0073e9SAndroid Build Coastguard Worker            (torch.tensor([-2, -1, 0, 1, 2], device=device), -0.25),
4177*da0073e9SAndroid Build Coastguard Worker            (
4178*da0073e9SAndroid Build Coastguard Worker                torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device),
4179*da0073e9SAndroid Build Coastguard Worker                2.0,
4180*da0073e9SAndroid Build Coastguard Worker            ),
4181*da0073e9SAndroid Build Coastguard Worker        )
4182*da0073e9SAndroid Build Coastguard Worker        for base, exp in test_cases:
4183*da0073e9SAndroid Build Coastguard Worker            for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
4184*da0073e9SAndroid Build Coastguard Worker                out = torch.empty(1, device=device, dtype=out_dtype)
4185*da0073e9SAndroid Build Coastguard Worker                required_dtype = _promo_helper(base, exp)
4186*da0073e9SAndroid Build Coastguard Worker
4187*da0073e9SAndroid Build Coastguard Worker                if out.dtype == required_dtype:
4188*da0073e9SAndroid Build Coastguard Worker                    torch.float_power(base, exp, out=out)
4189*da0073e9SAndroid Build Coastguard Worker                else:
4190*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(
4191*da0073e9SAndroid Build Coastguard Worker                        RuntimeError, "operation's result requires dtype"
4192*da0073e9SAndroid Build Coastguard Worker                    ):
4193*da0073e9SAndroid Build Coastguard Worker                        torch.float_power(base, exp, out=out)
4194*da0073e9SAndroid Build Coastguard Worker
4195*da0073e9SAndroid Build Coastguard Worker                if base.dtype == required_dtype:
4196*da0073e9SAndroid Build Coastguard Worker                    torch.Tensor.float_power_(base.clone(), exp)
4197*da0073e9SAndroid Build Coastguard Worker                else:
4198*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(
4199*da0073e9SAndroid Build Coastguard Worker                        RuntimeError, "operation's result requires dtype"
4200*da0073e9SAndroid Build Coastguard Worker                    ):
4201*da0073e9SAndroid Build Coastguard Worker                        torch.Tensor.float_power_(base.clone(), exp)
4202*da0073e9SAndroid Build Coastguard Worker
    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(
        *product(
            all_types_and(torch.half, torch.bool), all_types_and(torch.half, torch.bool)
        )
    )
    def test_xlogy_xlog1py(self, device, dtypes):
        """Compares torch.xlogy / torch.special.xlog1py with their scipy
        references over every pair of dtypes from the decorator's product,
        covering tensor-tensor (same-shape and broadcast), scalar-tensor,
        out= variants, in-place xlogy_, and x == 0 special values against
        +/-inf and nan.
        """
        x_dtype, y_dtype = dtypes

        def out_variant_helper(torch_fn, x, y):
            # out= must produce the same values as the functional form.
            expected = torch_fn(x, y)
            out = torch.empty_like(expected)
            torch_fn(x, y, out=out)
            self.assertEqual(expected, out)

        def xlogy_inplace_variant_helper(x, y):
            # Integral (or bool) x cannot hold the promoted floating-point
            # result, so in-place xlogy_ must raise; otherwise it must match
            # the out= result.
            if x.dtype in integral_types_and(torch.bool):
                with self.assertRaisesRegex(
                    RuntimeError, "can't be cast to the desired output type"
                ):
                    x.clone().xlogy_(y)
            else:
                expected = torch.empty_like(x)
                torch.xlogy(x, y, out=expected)
                inplace_out = x.clone().xlogy_(y)
                self.assertEqual(expected, inplace_out)

        def test_helper(torch_fn, reference_fn, inputs, scalar=None):
            # Compares torch_fn(x, <each of x, y, z>) against the reference,
            # then exercises the out= variant with x (or `scalar`, when given)
            # as the first argument.
            x, y, z = inputs
            torch_fn_partial = partial(torch_fn, x)
            reference_fn_partial = partial(reference_fn, x.cpu().numpy())
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, x, exact_dtype=False
            )
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, y, exact_dtype=False
            )
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, z, exact_dtype=False
            )

            val = scalar if scalar is not None else x
            out_variant_helper(torch_fn, val, x)
            out_variant_helper(torch_fn, val, y)
            out_variant_helper(torch_fn, val, z)

        # Tensor-Tensor Test (tensor of same and different shape)
        # low=0.5 keeps log's argument strictly positive for xlogy.
        x = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=0.5, high=1000)
        y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
        z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)

        # xlog1py takes log1p(y), which is defined for y > -1, so the _1p
        # fixtures may dip below zero.
        x_1p = make_tensor(
            (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.5, high=1000
        )
        y_1p = make_tensor(
            (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000
        )
        z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000)

        xlogy_fns = torch.xlogy, scipy.special.xlogy
        xlog1py_fns = torch.special.xlog1py, scipy.special.xlog1py

        test_helper(*xlogy_fns, (x, y, z))
        xlogy_inplace_variant_helper(x, x)
        xlogy_inplace_variant_helper(x, y)
        xlogy_inplace_variant_helper(x, z)
        test_helper(*xlog1py_fns, (x_1p, y_1p, z_1p))

        # Scalar-Tensor Test
        test_helper(*xlogy_fns, (x, y, z), 3.14)
        test_helper(*xlog1py_fns, (x_1p, y_1p, z_1p), 3.14)

        # Special Values Tensor-Tensor
        t = torch.tensor(
            [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")],
            device=device,
        )
        zeros = torch.zeros(7, dtype=y_dtype, device=device)

        def test_zeros_special_helper(torch_fn, reference_fn, scalar=False):
            # With x == 0 the result is defined to be 0 even where the log
            # term is non-finite; check both the 0-scalar and zeros-tensor form.
            zeros_t = 0 if scalar else zeros
            zeros_np = 0 if scalar else zeros.cpu().numpy()
            torch_fn_partial = partial(torch_fn, zeros_t)
            reference_fn_partial = partial(reference_fn, zeros_np)
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, t, exact_dtype=False
            )
            out_variant_helper(torch_fn, zeros_t, t)

        test_zeros_special_helper(*xlogy_fns)
        xlogy_inplace_variant_helper(zeros, t)
        test_zeros_special_helper(*xlog1py_fns)

        # Special Values Scalar-Tensor
        test_zeros_special_helper(*xlogy_fns, scalar=True)
        test_zeros_special_helper(*xlog1py_fns, scalar=True)
4299*da0073e9SAndroid Build Coastguard Worker
4300*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float64)
4301*da0073e9SAndroid Build Coastguard Worker    def test_xlogy_xlog1py_gradients(self, device, dtype):
4302*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(torch.tensor, dtype=dtype, device=device, requires_grad=True)
4303*da0073e9SAndroid Build Coastguard Worker
4304*da0073e9SAndroid Build Coastguard Worker        zeros = torch.zeros((2,), dtype=dtype, device=device)
4305*da0073e9SAndroid Build Coastguard Worker
4306*da0073e9SAndroid Build Coastguard Worker        x = make_arg([0.0, 0.0])
4307*da0073e9SAndroid Build Coastguard Worker        y = make_arg([-1.5, 0.0])
4308*da0073e9SAndroid Build Coastguard Worker        torch.special.xlogy(x, y).sum().backward()
4309*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, zeros)
4310*da0073e9SAndroid Build Coastguard Worker
4311*da0073e9SAndroid Build Coastguard Worker        x = make_arg([0.0, 0.0])
4312*da0073e9SAndroid Build Coastguard Worker        y = make_arg([-2.5, -1.0])
4313*da0073e9SAndroid Build Coastguard Worker        torch.special.xlog1py(x, y).sum().backward()
4314*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, zeros)
4315*da0073e9SAndroid Build Coastguard Worker
4316*da0073e9SAndroid Build Coastguard Worker    def test_xlogy_xlog1py_scalar_type_promotion(self, device):
4317*da0073e9SAndroid Build Coastguard Worker        # Test that python numbers don't participate in type promotion at the same
4318*da0073e9SAndroid Build Coastguard Worker        # priority level as 0-dim tensors
4319*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((), dtype=torch.float32, device=device)
4320*da0073e9SAndroid Build Coastguard Worker
4321*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.dtype, torch.xlogy(t, 5).dtype)
4322*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.dtype, torch.xlogy(t, 5.0).dtype)
4323*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.dtype, torch.special.xlog1py(t, 5).dtype)
4324*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.dtype, torch.special.xlog1py(t, 5.0).dtype)
4325*da0073e9SAndroid Build Coastguard Worker
4326*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.dtype, torch.xlogy(5, t).dtype)
4327*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.dtype, torch.xlogy(5.0, t).dtype)
4328*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.dtype, torch.special.xlog1py(5, t).dtype)
4329*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.dtype, torch.special.xlog1py(5.0, t).dtype)
4330*da0073e9SAndroid Build Coastguard Worker
4331*da0073e9SAndroid Build Coastguard Worker    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
4332*da0073e9SAndroid Build Coastguard Worker    def test_xlogy_xlog1py_bfloat16(self, device):
4333*da0073e9SAndroid Build Coastguard Worker        def _compare_helper(x, y, torch_fn, reference_fn):
4334*da0073e9SAndroid Build Coastguard Worker            x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy()
4335*da0073e9SAndroid Build Coastguard Worker            y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy()
4336*da0073e9SAndroid Build Coastguard Worker            expected = torch.from_numpy(reference_fn(x_np, y_np))
4337*da0073e9SAndroid Build Coastguard Worker            actual = torch_fn(x, y)
4338*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual, exact_dtype=False)
4339*da0073e9SAndroid Build Coastguard Worker
4340*da0073e9SAndroid Build Coastguard Worker        x_dtype, y_dtype = torch.bfloat16, torch.bfloat16
4341*da0073e9SAndroid Build Coastguard Worker
4342*da0073e9SAndroid Build Coastguard Worker        # Tensor-Tensor Test (tensor of same and different shape)
4343*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=0.5, high=1000)
4344*da0073e9SAndroid Build Coastguard Worker        y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
4345*da0073e9SAndroid Build Coastguard Worker        z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
4346*da0073e9SAndroid Build Coastguard Worker
4347*da0073e9SAndroid Build Coastguard Worker        x_1p = make_tensor(
4348*da0073e9SAndroid Build Coastguard Worker            (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.8, high=1000
4349*da0073e9SAndroid Build Coastguard Worker        )
4350*da0073e9SAndroid Build Coastguard Worker        y_1p = make_tensor(
4351*da0073e9SAndroid Build Coastguard Worker            (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000
4352*da0073e9SAndroid Build Coastguard Worker        )
4353*da0073e9SAndroid Build Coastguard Worker        z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000)
4354*da0073e9SAndroid Build Coastguard Worker
4355*da0073e9SAndroid Build Coastguard Worker        xlogy_fns = torch.xlogy, scipy.special.xlogy
4356*da0073e9SAndroid Build Coastguard Worker        xlog1py_fns = torch.special.xlog1py, scipy.special.xlog1py
4357*da0073e9SAndroid Build Coastguard Worker
4358*da0073e9SAndroid Build Coastguard Worker        _compare_helper(x, x, *xlogy_fns)
4359*da0073e9SAndroid Build Coastguard Worker        _compare_helper(x, y, *xlogy_fns)
4360*da0073e9SAndroid Build Coastguard Worker        _compare_helper(x, z, *xlogy_fns)
4361*da0073e9SAndroid Build Coastguard Worker        _compare_helper(x, 3.14, *xlogy_fns)
4362*da0073e9SAndroid Build Coastguard Worker        _compare_helper(y, 3.14, *xlogy_fns)
4363*da0073e9SAndroid Build Coastguard Worker        _compare_helper(z, 3.14, *xlogy_fns)
4364*da0073e9SAndroid Build Coastguard Worker
4365*da0073e9SAndroid Build Coastguard Worker        _compare_helper(x_1p, x_1p, *xlog1py_fns)
4366*da0073e9SAndroid Build Coastguard Worker        _compare_helper(x_1p, y_1p, *xlog1py_fns)
4367*da0073e9SAndroid Build Coastguard Worker        _compare_helper(x_1p, z_1p, *xlog1py_fns)
4368*da0073e9SAndroid Build Coastguard Worker        _compare_helper(x_1p, 3.14, *xlog1py_fns)
4369*da0073e9SAndroid Build Coastguard Worker        _compare_helper(y_1p, 3.14, *xlog1py_fns)
4370*da0073e9SAndroid Build Coastguard Worker        _compare_helper(z_1p, 3.14, *xlog1py_fns)
4371*da0073e9SAndroid Build Coastguard Worker
4372*da0073e9SAndroid Build Coastguard Worker        # Special Values Tensor-Tensor
4373*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(
4374*da0073e9SAndroid Build Coastguard Worker            [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")],
4375*da0073e9SAndroid Build Coastguard Worker            device=device,
4376*da0073e9SAndroid Build Coastguard Worker        )
4377*da0073e9SAndroid Build Coastguard Worker        zeros = torch.tensor(7, dtype=y_dtype, device=device)
4378*da0073e9SAndroid Build Coastguard Worker
4379*da0073e9SAndroid Build Coastguard Worker        _compare_helper(t, zeros, *xlogy_fns)
4380*da0073e9SAndroid Build Coastguard Worker        _compare_helper(t, 0.0, *xlogy_fns)
4381*da0073e9SAndroid Build Coastguard Worker
4382*da0073e9SAndroid Build Coastguard Worker        _compare_helper(t, zeros, *xlog1py_fns)
4383*da0073e9SAndroid Build Coastguard Worker        _compare_helper(t, 0.0, *xlog1py_fns)
4384*da0073e9SAndroid Build Coastguard Worker
4385*da0073e9SAndroid Build Coastguard Worker    @dtypes(*product(all_types_and(torch.bool), all_types_and(torch.bool)))
4386*da0073e9SAndroid Build Coastguard Worker    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
4387*da0073e9SAndroid Build Coastguard Worker    @slowTest
4388*da0073e9SAndroid Build Coastguard Worker    def test_zeta(self, device, dtypes):
4389*da0073e9SAndroid Build Coastguard Worker        x_dtype, q_dtype = dtypes
4390*da0073e9SAndroid Build Coastguard Worker
4391*da0073e9SAndroid Build Coastguard Worker        def test_helper(x, q):
4392*da0073e9SAndroid Build Coastguard Worker            x_np = x if isinstance(x, float) else x.cpu().numpy()
4393*da0073e9SAndroid Build Coastguard Worker            q_np = q if isinstance(q, float) else q.cpu().numpy()
4394*da0073e9SAndroid Build Coastguard Worker            expected = torch.from_numpy(scipy.special.zeta(x_np, q_np))
4395*da0073e9SAndroid Build Coastguard Worker            actual = torch.special.zeta(x, q)
4396*da0073e9SAndroid Build Coastguard Worker
4397*da0073e9SAndroid Build Coastguard Worker            rtol, atol = None, None
4398*da0073e9SAndroid Build Coastguard Worker            if self.device_type == "cpu":
4399*da0073e9SAndroid Build Coastguard Worker                rtol, atol = 1e-6, 1e-6
4400*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual, rtol=rtol, atol=atol, exact_dtype=False)
4401*da0073e9SAndroid Build Coastguard Worker
4402*da0073e9SAndroid Build Coastguard Worker        # x tensor - q tensor same size
4403*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
4404*da0073e9SAndroid Build Coastguard Worker        q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
4405*da0073e9SAndroid Build Coastguard Worker        test_helper(x, q)
4406*da0073e9SAndroid Build Coastguard Worker
4407*da0073e9SAndroid Build Coastguard Worker        # x tensor - q tensor broadcast lhs
4408*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((2, 1, 4), dtype=x_dtype, device=device)
4409*da0073e9SAndroid Build Coastguard Worker        q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
4410*da0073e9SAndroid Build Coastguard Worker        test_helper(x, q)
4411*da0073e9SAndroid Build Coastguard Worker
4412*da0073e9SAndroid Build Coastguard Worker        # x tensor - q tensor broadcast rhs
4413*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
4414*da0073e9SAndroid Build Coastguard Worker        q = make_tensor((2, 1, 4), dtype=q_dtype, device=device)
4415*da0073e9SAndroid Build Coastguard Worker        test_helper(x, q)
4416*da0073e9SAndroid Build Coastguard Worker
4417*da0073e9SAndroid Build Coastguard Worker        # x tensor - q tensor broadcast all
4418*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((2, 3, 1), dtype=x_dtype, device=device)
4419*da0073e9SAndroid Build Coastguard Worker        q = make_tensor((2, 1, 4), dtype=q_dtype, device=device)
4420*da0073e9SAndroid Build Coastguard Worker        test_helper(x, q)
4421*da0073e9SAndroid Build Coastguard Worker
4422*da0073e9SAndroid Build Coastguard Worker        # x scalar - q tensor
4423*da0073e9SAndroid Build Coastguard Worker        for x in np.linspace(-5, 5, num=10).tolist():
4424*da0073e9SAndroid Build Coastguard Worker            if not q_dtype.is_floating_point:
4425*da0073e9SAndroid Build Coastguard Worker                q_dtype = torch.get_default_dtype()
4426*da0073e9SAndroid Build Coastguard Worker            q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
4427*da0073e9SAndroid Build Coastguard Worker            test_helper(x, q)
4428*da0073e9SAndroid Build Coastguard Worker
4429*da0073e9SAndroid Build Coastguard Worker        # x tensor - q scalar
4430*da0073e9SAndroid Build Coastguard Worker        for q in np.linspace(-5, 5, num=10).tolist():
4431*da0073e9SAndroid Build Coastguard Worker            if not x_dtype.is_floating_point:
4432*da0073e9SAndroid Build Coastguard Worker                x_dtype = torch.get_default_dtype()
4433*da0073e9SAndroid Build Coastguard Worker            x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
4434*da0073e9SAndroid Build Coastguard Worker            test_helper(x, q)
4435*da0073e9SAndroid Build Coastguard Worker
4436*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
4437*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.chalf)
4438*da0073e9SAndroid Build Coastguard Worker    def test_mul_chalf_tensor_and_cpu_scalar(self, device, dtype):
4439*da0073e9SAndroid Build Coastguard Worker        # Tests that Tensor and CPU Scalar work for `mul` for chalf.
4440*da0073e9SAndroid Build Coastguard Worker        # Ideally, this should be covered by `test_complex_half_reference_testing`
4441*da0073e9SAndroid Build Coastguard Worker        # from test_ops.py by checking reference_samples from the OpInfo.
4442*da0073e9SAndroid Build Coastguard Worker        # But currently that doesn't work as sample generation requires support of
4443*da0073e9SAndroid Build Coastguard Worker        # `index_select` which is not implemented for `complex32` at the
4444*da0073e9SAndroid Build Coastguard Worker        # time of writing this test.
4445*da0073e9SAndroid Build Coastguard Worker        # TODO: Remove this test once above issue is fixed.
4446*da0073e9SAndroid Build Coastguard Worker        # Ref: https://github.com/pytorch/pytorch/pull/76364
4447*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((2, 2), device=device, dtype=dtype)
4448*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x * 2.5, x * torch.tensor(2.5, device=device, dtype=dtype))
4449*da0073e9SAndroid Build Coastguard Worker
4450*da0073e9SAndroid Build Coastguard Worker
# Dunder names of Tensor's binary operators. Comparison operators have no
# reflected/in-place variants; every other operator contributes __op__,
# __r<op>__ and __i<op>__, except matmul, whose in-place form is unsupported.
# Unsupported operators:
# '__imatmul__'
# '__divmod__', '__rdivmod__', '__idivmod__'
_comparison_ops = ("lt", "le", "gt", "ge", "eq", "ne")
_arithmetic_ops = ("add", "sub", "mul")
_remaining_ops = (
    "truediv",
    "floordiv",
    "mod",
    "pow",
    "lshift",
    "rshift",
    "and",
    "xor",
    "or",
)

tensor_binary_ops = [f"__{_name}__" for _name in _comparison_ops]
for _name in _arithmetic_ops:
    tensor_binary_ops += [f"__{_name}__", f"__r{_name}__", f"__i{_name}__"]
tensor_binary_ops += ["__matmul__", "__rmatmul__"]
for _name in _remaining_ops:
    tensor_binary_ops += [f"__{_name}__", f"__r{_name}__", f"__i{_name}__"]
4500*da0073e9SAndroid Build Coastguard Worker
4501*da0073e9SAndroid Build Coastguard Worker
# Test that binary math operations return NotImplemented for unknown types.
def generate_not_implemented_tests(cls):
    class UnknownType:
        pass

    # TODO: refactor to inline these
    _types = [
        torch.half,
        torch.float,
        torch.double,
        torch.int8,
        torch.short,
        torch.int,
        torch.long,
        torch.uint8,
    ]

    def make_test(op_name):
        @dtypes(*_types)
        def test(self, device, dtype):
            lhs = torch.empty((), device=device, dtype=dtype)
            # Applying the operator to a type torch knows nothing about must
            # yield NotImplemented so Python can try the reflected operation.
            result = getattr(lhs, op_name)(UnknownType())
            self.assertEqual(result, NotImplemented)

        return test

    for op_name in tensor_binary_ops:
        name = f"test_{op_name}_not_implemented"
        assert not hasattr(cls, name), f"{name} already in {cls.__name__}"

        setattr(cls, name, make_test(op_name))
4536*da0073e9SAndroid Build Coastguard Worker
4537*da0073e9SAndroid Build Coastguard Worker
# Attach the generated NotImplemented tests to the base class first, then
# expand it into the per-device test classes in this module's namespace.
generate_not_implemented_tests(TestBinaryUfuncs)
instantiate_device_type_tests(TestBinaryUfuncs, globals())

if __name__ == "__main__":
    run_tests()
4543