xref: /aosp_15_r20/external/pytorch/test/test_binary_ufuncs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: tests"]
2
3import itertools
4import math
5import operator
6import random
7import warnings
8from functools import partial
9from itertools import chain, product
10from numbers import Number
11
12import numpy as np
13
14import torch
15import torch.autograd.forward_ad as fwAD
16from torch import inf, nan
17from torch.testing import make_tensor
18from torch.testing._internal.common_device_type import (
19    deviceCountAtLeast,
20    dtypes,
21    dtypesIfCPU,
22    dtypesIfCUDA,
23    expectedFailureMeta,
24    instantiate_device_type_tests,
25    onlyCPU,
26    onlyCUDA,
27    onlyNativeDeviceTypes,
28    OpDTypes,
29    ops,
30    precisionOverride,
31    skipIf,
32    skipMeta,
33)
34from torch.testing._internal.common_dtype import (
35    all_types_and,
36    all_types_and_complex_and,
37    complex_types,
38    floating_and_complex_types,
39    floating_types_and,
40    get_all_int_dtypes,
41    get_all_math_dtypes,
42    integral_types,
43    integral_types_and,
44)
45from torch.testing._internal.common_methods_invocations import (
46    binary_ufuncs,
47    binary_ufuncs_and_refs,
48    generate_elementwise_binary_broadcasting_tensors,
49    generate_elementwise_binary_extremal_value_tensors,
50    generate_elementwise_binary_large_value_tensors,
51    generate_elementwise_binary_small_value_tensors,
52    generate_elementwise_binary_tensors,
53    generate_elementwise_binary_with_scalar_and_type_promotion_samples,
54    generate_elementwise_binary_with_scalar_samples,
55)
56from torch.testing._internal.common_utils import (
57    gradcheck,
58    iter_indices,
59    numpy_to_torch_dtype_dict,
60    run_tests,
61    set_default_dtype,
62    skipIfTorchDynamo,
63    slowTest,
64    TEST_SCIPY,
65    TestCase,
66    torch_to_numpy_dtype_dict,
67    xfailIfTorchDynamo,
68)
69
70
71if TEST_SCIPY:
72    import scipy.integrate
73    import scipy.special
74
75
76# TODO: update to use opinfos consistently
77class TestBinaryUfuncs(TestCase):
78    # Generic tests for elementwise binary (AKA binary universal (u) functions (funcs))
79    # TODO: below contiguous tensor results are compared with a variety of noncontiguous results.
80    #   It would be interesting to have the lhs and rhs have different discontiguities.
81
82    # Helper for comparing torch tensors and NumPy arrays
83    # TODO: should this or assertEqual also validate that strides are equal?
84    def assertEqualHelper(
85        self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs
86    ):
87        assert isinstance(actual, torch.Tensor)
88
89        # Some NumPy functions return scalars, not arrays
90        if isinstance(expected, Number):
91            self.assertEqual(actual.item(), expected, msg=msg, **kwargs)
92        elif isinstance(expected, np.ndarray):
93            # Handles exact dtype comparisons between arrays and tensors
94            if exact_dtype:
95                # Allows array dtype to be float32 when comparing with bfloat16 tensors
96                #   since NumPy doesn't support the bfloat16 dtype
97                # Also ops like scipy.special.erf, scipy.special.erfc, etc, promote float16
98                # to float32
99                if expected.dtype == np.float32:
100                    assert actual.dtype in (
101                        torch.float16,
102                        torch.bfloat16,
103                        torch.float32,
104                    )
105                else:
106                    assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype]
107
108            self.assertEqual(
109                actual,
110                torch.from_numpy(expected).to(actual.dtype),
111                msg,
112                exact_device=False,
113                **kwargs,
114            )
115        else:
116            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)
117
118    # Tests that the function and its (array-accepting) reference produce the same
119    #   values on given tensors
120    def _test_reference_numerics(self, dtype, op, gen, equal_nan=True):
121        def _helper_reference_numerics(
122            expected, actual, msg, exact_dtype, equal_nan=True
123        ):
124            if not torch.can_cast(
125                numpy_to_torch_dtype_dict[expected.dtype.type], dtype
126            ):
127                exact_dtype = False
128
129            if dtype is torch.bfloat16 and expected.dtype == np.float32:
130                # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149
131                self.assertEqualHelper(
132                    actual,
133                    expected,
134                    msg,
135                    dtype=dtype,
136                    exact_dtype=exact_dtype,
137                    rtol=16e-3,
138                    atol=1e-5,
139                )
140            else:
141                self.assertEqualHelper(
142                    actual,
143                    expected,
144                    msg,
145                    dtype=dtype,
146                    equal_nan=equal_nan,
147                    exact_dtype=exact_dtype,
148                )
149
150        for sample in gen:
151            # Each sample input acquired from the generator is just one lhs tensor
152            #   and one rhs tensor
153            l = sample.input
154            r = sample.args[0]
155
156            numpy_sample = sample.numpy()
157            l_numpy = numpy_sample.input
158            r_numpy = numpy_sample.args[0]
159            actual = op(l, r)
160            expected = op.ref(l_numpy, r_numpy)
161
162            # Crafts a custom error message for smaller, printable tensors
163            def _numel(x):
164                if isinstance(x, torch.Tensor):
165                    return x.numel()
166                # Assumes x is a scalar
167                return 1
168
169            if _numel(l) <= 100 and _numel(r) <= 100:
170                msg = (
171                    "Failed to produce expected results! Input lhs tensor was"
172                    f" {l}, rhs tensor was {r}, torch result is {actual}, and reference result is"
173                    f" {expected}."
174                )
175            else:
176                msg = None
177
178            exact_dtype = True
179            if isinstance(actual, torch.Tensor):
180                _helper_reference_numerics(
181                    expected, actual, msg, exact_dtype, equal_nan
182                )
183            else:
184                for x, y in zip(expected, actual):
185                    # testing multi-outputs results
186                    _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan)
187
188    # The following tests only apply to elementwise binary operators with references
189    binary_ufuncs_with_references = list(
190        filter(lambda op: op.ref is not None and op.ref is not None, binary_ufuncs)
191    )
192
193    @ops(binary_ufuncs_with_references)
194    def test_reference_numerics(self, device, dtype, op):
195        gen = generate_elementwise_binary_tensors(op, device=device, dtype=dtype)
196        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
197
198    @ops(binary_ufuncs_with_references)
199    def test_reference_numerics_small_values(self, device, dtype, op):
200        if dtype is torch.bool:
201            self.skipTest("Doesn't support bool!")
202
203        gen = generate_elementwise_binary_small_value_tensors(
204            op, device=device, dtype=dtype
205        )
206        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
207
208    @ops(
209        binary_ufuncs_with_references,
210        allowed_dtypes=(
211            torch.int16,
212            torch.int32,
213            torch.int64,
214            torch.float16,
215            torch.bfloat16,
216            torch.float32,
217            torch.float64,
218            torch.complex64,
219            torch.complex128,
220        ),
221    )
222    def test_reference_numerics_large_values(self, device, dtype, op):
223        gen = generate_elementwise_binary_large_value_tensors(
224            op, device=device, dtype=dtype
225        )
226        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
227
228    @ops(
229        binary_ufuncs_with_references,
230        allowed_dtypes=(
231            torch.float16,
232            torch.bfloat16,
233            torch.float32,
234            torch.float64,
235            torch.complex64,
236            torch.complex128,
237        ),
238    )
239    def test_reference_numerics_extremal_values(self, device, dtype, op):
240        gen = generate_elementwise_binary_extremal_value_tensors(
241            op, device=device, dtype=dtype
242        )
243        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
244
245    # tests broadcasting and noncontiguous broadcasting behavior
246    @ops(
247        binary_ufuncs_with_references,
248        allowed_dtypes=(
249            torch.long,
250            torch.float32,
251        ),
252    )
253    def test_broadcasting(self, device, dtype, op):
254        gen = generate_elementwise_binary_broadcasting_tensors(
255            op, device=device, dtype=dtype
256        )
257        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
258
259    @ops(
260        binary_ufuncs_with_references,
261        allowed_dtypes=(torch.long, torch.float32, torch.complex64),
262    )
263    def test_scalar_support(self, device, dtype, op):
264        gen = generate_elementwise_binary_with_scalar_samples(
265            op, device=device, dtype=dtype
266        )
267        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
268        gen = generate_elementwise_binary_with_scalar_and_type_promotion_samples(
269            op, device=device, dtype=dtype
270        )
271        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
272
273    @ops(binary_ufuncs)
274    def test_contig_vs_every_other(self, device, dtype, op):
275        lhs = make_tensor(
276            (1026,), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs
277        )
278        rhs = make_tensor(
279            (1026,), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs
280        )
281
282        lhs_non_contig = lhs[::2]
283        rhs_non_contig = rhs[::2]
284
285        self.assertTrue(lhs.is_contiguous())
286        self.assertTrue(rhs.is_contiguous())
287
288        self.assertFalse(lhs_non_contig.is_contiguous())
289        self.assertFalse(rhs_non_contig.is_contiguous())
290
291        expected = op(lhs, rhs)[::2]
292        actual = op(lhs_non_contig, rhs_non_contig)
293        self.assertEqual(expected, actual)
294
295    @ops(binary_ufuncs)
296    def test_contig_vs_transposed(self, device, dtype, op):
297        lhs = make_tensor(
298            (789, 357), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs
299        )
300        rhs = make_tensor(
301            (789, 357), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs
302        )
303
304        lhs_non_contig = lhs.T
305        rhs_non_contig = rhs.T
306
307        self.assertTrue(lhs.is_contiguous())
308        self.assertTrue(rhs.is_contiguous())
309
310        self.assertFalse(lhs_non_contig.is_contiguous())
311        self.assertFalse(rhs_non_contig.is_contiguous())
312
313        expected = op(lhs, rhs).T
314        actual = op(lhs_non_contig, rhs_non_contig)
315        self.assertEqual(expected, actual)
316
317    @ops(binary_ufuncs)
318    def test_non_contig(self, device, dtype, op):
319        shapes = ((5, 7), (1024,))
320        for shape in shapes:
321            lhs = make_tensor(
322                shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
323            )
324            rhs = make_tensor(
325                shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
326            )
327
328            lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[
329                ..., 0
330            ]
331            lhs_non_contig.copy_(lhs)
332
333            rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[
334                ..., 0
335            ]
336            rhs_non_contig.copy_(rhs)
337
338            self.assertTrue(lhs.is_contiguous())
339            self.assertTrue(rhs.is_contiguous())
340
341            self.assertFalse(lhs_non_contig.is_contiguous())
342            self.assertFalse(rhs_non_contig.is_contiguous())
343
344            expected = op(lhs, rhs)
345            actual = op(lhs_non_contig, rhs_non_contig)
346            self.assertEqual(expected, actual)
347
348    @ops(binary_ufuncs)
349    def test_non_contig_index(self, device, dtype, op):
350        shape = (2, 2, 1, 2)
351        lhs = make_tensor(
352            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
353        )
354        rhs = make_tensor(
355            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
356        )
357
358        lhs_non_contig = lhs[:, 1, ...]
359        lhs = lhs_non_contig.contiguous()
360
361        rhs_non_contig = rhs[:, 1, ...]
362        rhs = rhs_non_contig.contiguous()
363
364        self.assertTrue(lhs.is_contiguous())
365        self.assertTrue(rhs.is_contiguous())
366
367        self.assertFalse(lhs_non_contig.is_contiguous())
368        self.assertFalse(rhs_non_contig.is_contiguous())
369
370        expected = op(lhs, rhs)
371        actual = op(lhs_non_contig, rhs_non_contig)
372        self.assertEqual(expected, actual)
373
374    @ops(binary_ufuncs)
375    def test_non_contig_expand(self, device, dtype, op):
376        shapes = [(1, 3), (1, 7), (5, 7)]
377        for shape in shapes:
378            lhs = make_tensor(
379                shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
380            )
381            rhs = make_tensor(
382                shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
383            )
384
385            lhs_non_contig = lhs.clone().expand(3, -1, -1)
386            rhs_non_contig = rhs.clone().expand(3, -1, -1)
387
388            self.assertTrue(lhs.is_contiguous())
389            self.assertTrue(rhs.is_contiguous())
390
391            self.assertFalse(lhs_non_contig.is_contiguous())
392            self.assertFalse(rhs_non_contig.is_contiguous())
393
394            expected = op(lhs, rhs)
395            actual = op(lhs_non_contig, rhs_non_contig)
396            for i in range(3):
397                self.assertEqual(expected, actual[i])
398
399    @ops(binary_ufuncs)
400    def test_contig_size1(self, device, dtype, op):
401        shape = (5, 100)
402        lhs = make_tensor(
403            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
404        )
405        rhs = make_tensor(
406            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
407        )
408
409        lhs = lhs[:1, :50]
410        lhs_alt = torch.empty(lhs.size(), device=device, dtype=dtype)
411        lhs_alt.copy_(lhs)
412
413        rhs = rhs[:1, :50]
414        rhs_alt = torch.empty(rhs.size(), device=device, dtype=dtype)
415        rhs_alt.copy_(rhs)
416
417        self.assertTrue(lhs.is_contiguous())
418        self.assertTrue(rhs.is_contiguous())
419
420        self.assertTrue(lhs_alt.is_contiguous())
421        self.assertTrue(rhs_alt.is_contiguous())
422
423        expected = op(lhs, rhs)
424        actual = op(lhs_alt, rhs_alt)
425        self.assertEqual(expected, actual)
426
427    @ops(binary_ufuncs)
428    def test_contig_size1_large_dim(self, device, dtype, op):
429        shape = (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4)
430        lhs = make_tensor(
431            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
432        )
433        rhs = make_tensor(
434            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
435        )
436
437        lhs = lhs[:1, :, :, :, :, :, :, :, :, :, :, :]
438        lhs_alt = torch.empty(lhs.size(), device=device, dtype=dtype)
439        lhs_alt.copy_(lhs)
440
441        rhs = rhs[:1, :, :, :, :, :, :, :, :, :, :, :]
442        rhs_alt = torch.empty(rhs.size(), device=device, dtype=dtype)
443        rhs_alt.copy_(rhs)
444
445        self.assertTrue(lhs.is_contiguous())
446        self.assertTrue(rhs.is_contiguous())
447
448        self.assertTrue(lhs_alt.is_contiguous())
449        self.assertTrue(rhs_alt.is_contiguous())
450
451        expected = op(lhs, rhs)
452        actual = op(lhs_alt, rhs_alt)
453        self.assertEqual(expected, actual)
454
455    @ops(binary_ufuncs)
456    def test_batch_vs_slicing(self, device, dtype, op):
457        shape = (32, 512)
458        lhs = make_tensor(
459            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
460        )
461        rhs = make_tensor(
462            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
463        )
464
465        expected = op(lhs, rhs)
466
467        actual = []
468        for idx in range(32):
469            actual.append(op(lhs[idx], rhs[idx]))
470        actual = torch.stack(actual)
471
472        self.assertEqual(expected, actual)
473
474    # Tests that elementwise binary operators participate in type promotion properly
475    # NOTE: because the cross-product of all possible type promotion tests is huge, this
476    #   just spot checks some handwritten cases.
477    # NOTE: It may be possible to refactor this test into something simpler
478    @ops(binary_ufuncs_and_refs, dtypes=OpDTypes.none)
479    def test_type_promotion(self, device, op):
480        supported_dtypes = op.supported_dtypes(torch.device(device).type)
481
482        make_lhs = partial(
483            make_tensor, (5,), device=device, **op.lhs_make_tensor_kwargs
484        )
485        make_rhs = partial(
486            make_tensor, (5,), device=device, **op.rhs_make_tensor_kwargs
487        )
488
489        make_rhs_scalar_tensor = partial(
490            make_tensor, (), device="cpu", **op.rhs_make_tensor_kwargs
491        )
492
493        def _supported(dtypes):
494            return all(x in supported_dtypes for x in dtypes)
495
496        # int x int type promotion
497        if _supported((torch.int16, torch.int32, torch.int64)):
498            lhs_i16 = make_lhs(dtype=torch.int16)
499            lhs_i32 = make_lhs(dtype=torch.int32)
500            lhs_i64 = make_lhs(dtype=torch.int64)
501
502            rhs_i16 = make_rhs(dtype=torch.int16)
503            rhs_i32 = make_rhs(dtype=torch.int32)
504            rhs_i64 = make_rhs(dtype=torch.int64)
505
506            if op.promotes_int_to_float:
507                default_dtype = torch.get_default_dtype()
508                self.assertEqual(op(lhs_i16, rhs_i32).dtype, default_dtype)
509                self.assertEqual(
510                    op(lhs_i16, rhs_i32),
511                    op(lhs_i16.to(default_dtype), rhs_i32.to(default_dtype)),
512                )
513
514                self.assertEqual(op(lhs_i32, rhs_i64).dtype, default_dtype)
515                self.assertEqual(
516                    op(lhs_i32, rhs_i64),
517                    op(lhs_i32.to(default_dtype), rhs_i64.to(default_dtype)),
518                )
519            elif op.always_returns_bool:
520                self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.bool)
521                self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.bool)
522            else:  # standard type promotion
523                self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.int32)
524                self.assertEqual(
525                    op(lhs_i16, rhs_i32), op(lhs_i16.to(torch.int32), rhs_i32)
526                )
527
528                self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.int64)
529                self.assertEqual(
530                    op(lhs_i32, rhs_i64), op(lhs_i32.to(torch.int64), rhs_i64)
531                )
532
533            if op.supports_out:
534                if not op.promotes_int_to_float:
535                    # Integers can be safely cast to other integer types
536                    out = torch.empty_like(lhs_i64)
537                    self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.int64)
538                    self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)
539
540                    out = torch.empty_like(lhs_i16)
541                    self.assertEqual(op(lhs_i32, rhs_i64, out=out).dtype, torch.int16)
542                else:
543                    # Float outs cannot be safely cast to integer types
544                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
545                        op(lhs_i16, rhs_i32, out=torch.empty_like(lhs_i64))
546
547                if not op.always_returns_bool:
548                    # Neither integer nor float outs can be cast to bool
549                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
550                        op(
551                            lhs_i16,
552                            rhs_i32,
553                            out=torch.empty_like(lhs_i64, dtype=torch.bool),
554                        )
555
556                # All these output types can be cast to any float or complex type
557                out = torch.empty_like(lhs_i64, dtype=torch.float16)
558                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.float16)
559
560                out = torch.empty_like(lhs_i64, dtype=torch.bfloat16)
561                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.bfloat16)
562
563                out = torch.empty_like(lhs_i64, dtype=torch.float32)
564                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.float32)
565                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)
566
567                out = torch.empty_like(lhs_i64, dtype=torch.complex64)
568                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.complex64)
569                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)
570
571        # float x float type promotion
572        if _supported((torch.float32, torch.float64)):
573            lhs_f32 = make_lhs(dtype=torch.float32)
574            lhs_f64 = make_lhs(dtype=torch.float64)
575
576            rhs_f32 = make_rhs(dtype=torch.float32)
577            rhs_f64 = make_rhs(dtype=torch.float64)
578
579            if op.always_returns_bool:
580                self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.bool)
581            else:  # normal float type promotion
582                self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.float64)
583                self.assertEqual(
584                    op(lhs_f32, rhs_f64), op(lhs_f32.to(torch.float64), rhs_f64)
585                )
586
587            if op.supports_out:
588                # All these output types can be cast to any float or complex type
589                out = torch.empty_like(lhs_f64, dtype=torch.float16)
590                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.float16)
591
592                out = torch.empty_like(lhs_f64, dtype=torch.bfloat16)
593                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.bfloat16)
594                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)
595
596                out = torch.empty_like(lhs_f64, dtype=torch.float32)
597                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.float32)
598                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)
599
600                out = torch.empty_like(lhs_f64, dtype=torch.complex64)
601                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.complex64)
602                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)
603
604                if not op.always_returns_bool:
605                    # float outs can't be cast to an integer dtype
606                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
607                        op(
608                            lhs_f32,
609                            rhs_f64,
610                            out=torch.empty_like(lhs_f64, dtype=torch.int64),
611                        )
612                else:
613                    # bool outs can be cast to an integer dtype
614                    out = torch.empty_like(lhs_f64, dtype=torch.int64)
615                    self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.int64)
616                    self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)
617
618        # complex x complex type promotion
619        if _supported((torch.complex64, torch.complex128)):
620            lhs_c64 = make_lhs(dtype=torch.complex64)
621            lhs_c128 = make_lhs(dtype=torch.complex128)
622
623            rhs_c64 = make_rhs(dtype=torch.complex64)
624            rhs_c128 = make_rhs(dtype=torch.complex128)
625
626            if op.always_returns_bool:
627                self.assertEqual(op(lhs_c64, lhs_c128).dtype, torch.bool)
628            else:  # normal complex type promotion
629                self.assertEqual(op(lhs_c64, rhs_c128).dtype, torch.complex128)
630                self.assertEqual(
631                    op(lhs_c64, rhs_c128), op(lhs_c64.to(torch.complex128), rhs_c128)
632                )
633
634            if op.supports_out:
635                # All these output types can be cast to any or complex type
636                out = torch.empty_like(lhs_c64, dtype=torch.complex64)
637
638                self.assertEqual(op(lhs_c64, rhs_c128, out=out).dtype, torch.complex64)
639                result = op(lhs_c64, rhs_c128)
640                self.assertEqual(result, out.to(result.dtype))
641
642                if not op.always_returns_bool:
643                    # complex outs can't be cast to float types
644                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
645                        op(
646                            lhs_c64,
647                            rhs_c128,
648                            out=torch.empty_like(lhs_c64, dtype=torch.float64),
649                        )
650                    # complex outs can't be cast to an integer dtype
651                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
652                        op(
653                            lhs_c64,
654                            rhs_c128,
655                            out=torch.empty_like(lhs_c64, dtype=torch.int64),
656                        )
657                else:
658                    # bool outs can be cast to a float type
659                    out = torch.empty_like(lhs_c64, dtype=torch.float64)
660                    self.assertEqual(
661                        op(lhs_c64, rhs_c128, out=out).dtype, torch.float64
662                    )
663                    self.assertEqual(op(lhs_c64, rhs_c128), out, exact_dtype=False)
664
665                    # bool outs can be cast to an integer dtype
666                    out = torch.empty_like(lhs_f64, dtype=torch.int64)
667                    self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.int64)
668                    self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)
669
670        # int x float type promotion
671        # Note: float type is the result dtype
672        if _supported((torch.long, torch.float32)):
673            lhs_i64 = make_lhs(dtype=torch.int64)
674            rhs_f32 = make_rhs(dtype=torch.float32)
675
676            result = op(lhs_i64, rhs_f32)
677            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
678            self.assertEqual(result.dtype, expected_dtype)
679
680        # float x complex type promotion
681        # Note: complex type with highest "value type" is the result dtype
682        if _supported((torch.float64, torch.complex64)):
683            lhs_f64 = make_lhs(dtype=torch.float64)
684            rhs_c64 = make_rhs(dtype=torch.complex64)
685
686            result = op(lhs_f64, rhs_c64)
687            expected_dtype = (
688                torch.complex128 if not op.always_returns_bool else torch.bool
689            )
690            self.assertEqual(result.dtype, expected_dtype)
691
692        # int x float scalar type promotion
693        # Note: default float dtype is the result dtype
694        if _supported((torch.int64, torch.float32)) and op.supports_rhs_python_scalar:
695            lhs_i64 = make_lhs(dtype=torch.int64)
696            rhs_f_scalar = 1.0
697
698            result = op(lhs_i64, rhs_f_scalar)
699            expected_dtype = (
700                torch.get_default_dtype() if not op.always_returns_bool else torch.bool
701            )
702            self.assertEqual(result.dtype, expected_dtype)
703
704            # repeats with a scalar float tensor, which should set the dtype
705            rhs_f32_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float32)
706            result = op(lhs_i64, rhs_f32_scalar_tensor)
707            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
708            self.assertEqual(result.dtype, expected_dtype)
709
710            # Additional test with double
711            if _supported((torch.float64,)):
712                rhs_f64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float64)
713                result = op(lhs_i64, rhs_f64_scalar_tensor)
714                expected_dtype = (
715                    torch.float64 if not op.always_returns_bool else torch.bool
716                )
717                self.assertEqual(result.dtype, expected_dtype)
718
719        # float x complex scalar type promotion
720        # Note: result dtype is complex with highest "value type" among all tensors
721        if (
722            _supported((torch.float32, torch.complex64))
723            and op.supports_rhs_python_scalar
724        ):
725            lhs_f32 = make_lhs(dtype=torch.float32)
726            rhs_c_scalar = complex(1, 1)
727
728            result = op(lhs_f32, rhs_c_scalar)
729            expected_dtype = (
730                torch.complex64 if not op.always_returns_bool else torch.bool
731            )
732            self.assertEqual(result.dtype, expected_dtype)
733
734            # repeats with a scalar complex tensor
735            rhs_c64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex64)
736            result = op(lhs_f32, rhs_c64_scalar_tensor)
737            expected_dtype = (
738                torch.complex64 if not op.always_returns_bool else torch.bool
739            )
740            self.assertEqual(result.dtype, expected_dtype)
741
742            # Additional test with complexdouble
743            if _supported((torch.complex128,)):
744                rhs_c128_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex128)
745                result = op(lhs_f32, rhs_c128_scalar_tensor)
746                # Value type of 1D+ Tensor (lhs_f32) takes priority over scalar tensor (rhs_c128).
747                expected_dtype = (
748                    torch.complex64 if not op.always_returns_bool else torch.bool
749                )
750                self.assertEqual(result.dtype, expected_dtype)
751
752        # float x float scalar tensor
753        # Note: result dtype is the type of the float tensor
754        if _supported((torch.float32, torch.float64)) and op.supports_rhs_python_scalar:
755            lhs_f32 = make_lhs(dtype=torch.float32)
756            rhs_f64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float64)
757
758            result = op(lhs_f32, rhs_f64_scalar_tensor)
759            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
760            self.assertEqual(result.dtype, expected_dtype)
761
762        # complex x complex scalar tensor
763        # Note: result dtype is the type of the complex tensor
764        if (
765            _supported((torch.complex64, torch.complex128))
766            and op.supports_rhs_python_scalar
767        ):
768            lhs_c64 = make_lhs(dtype=torch.complex64)
769            rhs_c128_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex128)
770
771            result = op(lhs_c64, rhs_c128_scalar_tensor)
772            expected_dtype = (
773                torch.complex64 if not op.always_returns_bool else torch.bool
774            )
775            self.assertEqual(result.dtype, expected_dtype)
776
777        # scalar  x scalar
778        # Note: result dtype is default float type
779        if op.supports_two_python_scalars and _supported((torch.long, torch.float32)):
780            rhs_f_scalar = 2.0
781            for lhs in (1, 1.0):
782                result = op(lhs, rhs_f_scalar)
783                expected_dtype = (
784                    torch.get_default_dtype()
785                    if not op.always_returns_bool
786                    else torch.bool
787                )
788                self.assertEqual(result.dtype, expected_dtype)
789
790    # TODO: move to error input test
791    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
792    def test_not_broadcastable(self, device, dtype, op):
793        for shape_lhs, shape_rhs in (
794            ((2,), (3,)),
795            ((3, 1), (2, 1)),
796            ((1, 3, 2), (3,)),
797            ((3, 1, 2), (2, 1, 2)),
798        ):
799            lhs = make_tensor(
800                shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs
801            )
802            rhs = make_tensor(
803                shape_rhs, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs
804            )
805
806            try:
807                broadcasted_shape = op(lhs, rhs).shape
808            except RuntimeError:
809                continue
810
811            msg = (
812                f"On {device}, torch.{op.name} broadcasts inputs shapes {shape_lhs} and {shape_rhs} into "
813                f"{broadcasted_shape}, although they are not broadcastable."
814            )
815            raise AssertionError(msg)
816
817    def test_add_broadcast_empty(self, device):
818        # empty + empty
819        self.assertRaises(
820            RuntimeError,
821            lambda: torch.randn(5, 0, device=device) + torch.randn(0, 5, device=device),
822        )
823        self.assertEqual(
824            torch.randn(5, 0, device=device),
825            torch.randn(0, device=device) + torch.randn(5, 0, device=device),
826        )
827        self.assertEqual(
828            torch.randn(5, 0, 0, device=device),
829            torch.randn(0, device=device) + torch.randn(5, 0, 1, device=device),
830        )
831
832        # scalar + empty
833        self.assertEqual(
834            torch.randn(5, 0, 6, device=device),
835            torch.randn((), device=device) + torch.randn(5, 0, 6, device=device),
836        )
837
838        # non-empty, empty
839        self.assertEqual(
840            torch.randn(0, device=device),
841            torch.randn(0, device=device) + torch.randn(1, device=device),
842        )
843        self.assertEqual(
844            torch.randn(0, 7, 0, 6, 5, 0, 7, device=device),
845            torch.randn(0, 7, 0, 6, 5, 0, 1, device=device)
846            + torch.randn(1, 1, 5, 1, 7, device=device),
847        )
848        self.assertRaises(
849            RuntimeError,
850            lambda: torch.randn(7, 0, device=device) + torch.randn(2, 1, device=device),
851        )
852
853    def test_addcmul_scalars_as_floats(self, device):
854        # zero-dim variables that don't require grad should bind to scalar arguments
855        x = torch.tensor(2.0)
856        y = torch.tensor(3.0, device=device)
857        # 3 + (3 * 3) * 2
858        self.assertEqual(y.addcmul(y, y, value=x), 21)
859
860        x = torch.tensor(2.0, requires_grad=True)
861        self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x))
862
863    # Tests that the binary operators and, or, and xor (as well as their reflected and inplace versions)
864    # work properly (AKA &, ||, ^ and &=, |=, ^=)
865    @dtypes(*integral_types_and(torch.bool))
866    def test_bitwise_ops(self, device, dtype):
867        # Tensor x Tensor and Tensor x Scalar ops
868        ops = (
869            operator.and_,
870            operator.iand,
871            operator.or_,
872            operator.ior,
873            operator.xor,
874            operator.ixor,
875        )
876        inplace_ops = (operator.iand, operator.ior, operator.ixor)
877        shapes = ((5,), (15, 15), (500, 500))
878
879        for op, shape in itertools.product(ops, shapes):
880            # Tests tensor x tensor case
881            a = make_tensor(shape, device=device, dtype=dtype)
882            b = make_tensor(shape, device=device, dtype=dtype)
883            a_np = a.cpu().clone().numpy()
884            b_np = b.cpu().clone().numpy()
885            self.assertEqual(op(a, b), op(a_np, b_np))
886
887            # Tests tensor x scalar case
888            a = make_tensor(shape, device=device, dtype=dtype)
889            b_scalar = make_tensor((), device="cpu", dtype=dtype).item()
890            a_np = a.cpu().clone().numpy()
891            self.assertEqual(op(a, b_scalar), op(a_np, b_scalar))
892
893            # Tests scalar x tensor case
894            a_scalar = make_tensor((), device="cpu", dtype=dtype).item()
895            b = make_tensor(shape, device=device, dtype=dtype)
896            b_np = b.cpu().clone().numpy()
897            self.assertEqual(op(a_scalar, b), op(a_scalar, b_np))
898
899            # Tests scalar x tensor case (for ops which aren't inplace)
900            if op in inplace_ops:
901                # Tests tensor x tensor case
902                a = make_tensor(shape, device=device, dtype=dtype)
903                b = make_tensor(shape, device=device, dtype=dtype)
904                a_np = a.cpu().clone().numpy()
905                b_np = b.cpu().clone().numpy()
906                op(a, b)
907                op(a_np, b_np)
908                self.assertEqual(a, a_np)
909
910                # Tests tensor x scalar case
911                a = make_tensor(shape, device=device, dtype=dtype)
912                b_scalar = make_tensor((), device="cpu", dtype=dtype).item()
913                a_np = a.cpu().clone().numpy()
914                op(a, b_scalar)
915                op(a_np, b_scalar)
916                self.assertEqual(a, a_np)
917
918    def test_inplace_division(self, device):
919        t = torch.rand(5, 5, device=device)
920        id_before = id(t)
921        t /= 2
922        id_after = id(t)
923        self.assertEqual(id_before, id_after)
924
925    @dtypes(*all_types_and(torch.half, torch.bfloat16))
926    def test_div_rounding_modes(self, device, dtype):
927        if dtype.is_floating_point:
928            low, high = -10.0, 10.0
929        else:
930            info = torch.iinfo(dtype)
931            low, high = info.min, info.max
932
933        a = make_tensor((100,), dtype=dtype, device=device, low=low, high=high)
934        b = make_tensor((100,), dtype=dtype, device=device, low=low, high=high)
935
936        # Avoid division by zero so we can test (a / b) * b == a
937        if dtype.is_floating_point:
938            eps = 0.1
939            b[(-eps < b) & (b < eps)] = eps
940        else:
941            b[b == 0] = 1
942
943        if not dtype.is_floating_point:
944            # floor(a / b) * b can be < a, so fixup slightly to avoid underflow
945            a = torch.where(a < 0, a + b, a)
946
947        d_true = torch.divide(a, b, rounding_mode=None)
948        self.assertTrue(d_true.is_floating_point())
949        self.assertEqual(d_true * b, a.to(d_true.dtype))
950
951        d_floor = torch.divide(a, b, rounding_mode="floor")
952        if dtype not in (torch.bfloat16, torch.half):
953            self.assertEqual(d_floor * b + torch.remainder(a, b), a)
954        else:
955            self.assertEqual(
956                d_floor * b + torch.remainder(a.float(), b.float()),
957                a,
958                exact_dtype=False,
959            )
960
961        d_trunc = torch.divide(a, b, rounding_mode="trunc")
962        rounding_unsupported = (
963            dtype == torch.half
964            and device != "cuda"
965            or dtype == torch.bfloat16
966            and device != "cpu"
967        )
968        d_ref = d_true.float() if rounding_unsupported else d_true
969        self.assertEqual(d_trunc, d_ref.trunc().to(dtype))
970
971    @dtypes(*floating_types_and(torch.bfloat16, torch.float16))
972    def test_floor_div_extremal(self, device, dtype):
973        for num, denom, shape in itertools.product(
974            [torch.finfo(dtype).max * 0.7],
975            [0.5, -0.5, 0.0],
976            [(), (32,)],  # Scalar and vectorized
977        ):
978            a = torch.full(shape, num, dtype=dtype, device=device)
979            b = torch.full(shape, denom, dtype=dtype, device=device)
980
981            ref = np.floor_divide(num, denom).item()
982            if ref > torch.finfo(dtype).max:
983                ref = np.inf
984            elif ref < torch.finfo(dtype).min:
985                ref = -np.inf
986            expect = torch.full(shape, ref, dtype=dtype, device=device)
987            actual = torch.div(a, b, rounding_mode="floor")
988            self.assertEqual(expect, actual)
989
990    @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
991    def test_div_rounding_nonfinite(self, device, dtype):
992        # Compare division of special floating point values against NumPy
993        num = torch.tensor(
994            [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
995            dtype=dtype,
996            device=device,
997        )
998        # Divide by zero is tested separately
999        denom = num[num != 0]
1000
1001        a, b = num[None, :].clone(), denom[:, None].clone()
1002
1003        # Compare bfloat16 against NumPy float
1004        exact_dtype = dtype != torch.bfloat16
1005        if exact_dtype:
1006            an, bn = a.cpu().numpy(), b.cpu().numpy()
1007        else:
1008            an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()
1009
1010        for mode, np_ref in ((None, np.true_divide), ("floor", np.floor_divide)):
1011            expect = np_ref(an, bn)
1012            kwargs = dict(rounding_mode=mode) if mode is not None else {}
1013            with set_default_dtype(torch.double):
1014                actual = torch.divide(a, b, **kwargs)
1015            self.assertEqual(
1016                actual,
1017                torch.from_numpy(expect),
1018                exact_device=False,
1019                exact_dtype=exact_dtype,
1020            )
1021
1022        # Compare contiguous (likely vectorized) against non-contiguous (not vectorized)
1023        a_noncontig = torch.empty([2 * i for i in a.shape], dtype=dtype, device=device)[
1024            ::2, ::2
1025        ]
1026        a_noncontig[:] = a
1027        b_noncontig = torch.empty([2 * i for i in b.shape], dtype=dtype, device=device)[
1028            ::2, ::2
1029        ]
1030        b_noncontig[:] = b
1031
1032        for rounding_mode in (None, "trunc", "floor"):
1033            expect = torch.divide(a_noncontig, b_noncontig, rounding_mode=rounding_mode)
1034            actual = torch.divide(a, b, rounding_mode=rounding_mode)
1035            self.assertEqual(actual, expect)
1036
1037    @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
1038    def test_divide_by_zero_rounding(self, device, dtype):
1039        a = torch.tensor(
1040            [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
1041            dtype=dtype,
1042        )
1043        exact_dtype = dtype != torch.bfloat16
1044        if exact_dtype:
1045            an = a.cpu().numpy()
1046        else:
1047            an = a.float().cpu().numpy()
1048
1049        zero = torch.zeros_like(a)
1050
1051        # NOTE: NumPy's floor_divide rounding changed in 1.20.0 to be consistent with divide
1052        expect = np.divide(an, 0)
1053        for rounding_mode in (None, "floor"):
1054            # CPU scalar
1055            actual = torch.divide(a, 0, rounding_mode=rounding_mode)
1056            self.assertEqual(actual, expect, exact_dtype=exact_dtype)
1057            # Device tensor
1058            actual = torch.divide(a, zero, rounding_mode=rounding_mode)
1059            self.assertEqual(actual, expect, exact_dtype=exact_dtype)
1060
1061    @dtypes(*all_types_and(torch.half))
1062    def test_div_rounding_numpy(self, device, dtype):
1063        info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
1064        low, high = info.min, info.max
1065
1066        # Compare division of random values against NumPy
1067        a = make_tensor((4096,), dtype=dtype, device=device, low=low, high=high)
1068        b = make_tensor((4096,), dtype=dtype, device=device, low=low, high=high)
1069
1070        # Avoid division by zero which raises for integers and, for floats,
1071        # NumPy 1.20 changed floor_divide to follow IEEE rules for inf/nan
1072        # after dividing by zero.
1073        b[b == 0] = 1
1074
1075        # Compare bfloat16 against NumPy float
1076        exact_dtype = dtype != torch.bfloat16
1077
1078        if exact_dtype:
1079            an, bn = a.cpu().numpy(), b.cpu().numpy()
1080        else:
1081            an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()
1082
1083        for mode, np_ref in (
1084            (None, np.true_divide),
1085            ("floor", np.floor_divide),
1086            ("trunc", lambda a, b: np.trunc(np.true_divide(a, b)).astype(a.dtype)),
1087        ):
1088            expect = torch.from_numpy(np_ref(an, bn))
1089
1090            kwargs = dict(rounding_mode=mode) if mode is not None else {}
1091            # Contiguous (likely vectorized)
1092            with set_default_dtype(torch.double):
1093                actual = torch.divide(a, b, **kwargs)
1094            self.assertEqual(
1095                actual, expect, exact_device=False, exact_dtype=exact_dtype
1096            )
1097
1098            # Non-contiguous (not vectorized)
1099            expect = expect[::2]
1100            with set_default_dtype(torch.double):
1101                actual = torch.divide(a[::2], b[::2], **kwargs)
1102
1103            self.assertEqual(
1104                actual, expect, exact_device=False, exact_dtype=exact_dtype
1105            )
1106
1107    @dtypes(*complex_types())
1108    def test_complex_div_underflow_overflow(self, device, dtype):
1109        # test to make sure the complex division does not produce underflow or overflow
1110        # in the intermediate of its calculations
1111        # NOTE: the calculation still produces an error if the number is greater than
1112        # finfo.max / 2, but hopefully people realized that it's a dangerous region to work with
1113        finfo = torch.finfo(dtype)
1114        nom_lst = [
1115            complex(finfo.min / 2, finfo.min / 2),
1116            complex(finfo.max / 2, finfo.max / 2),
1117            complex(finfo.tiny, finfo.tiny),
1118            complex(finfo.tiny, 0.0),
1119            complex(0.0, 0.0),
1120        ]
1121        denom_lst = [
1122            complex(finfo.min / 2, finfo.min / 2),
1123            complex(finfo.max / 2, finfo.max / 2),
1124            complex(finfo.tiny, finfo.tiny),
1125            complex(0.0, finfo.tiny),
1126            complex(finfo.tiny, finfo.tiny),
1127        ]
1128        expected_lst = [
1129            complex(1.0, 0.0),
1130            complex(1.0, 0.0),
1131            complex(1.0, 0.0),
1132            complex(0.0, -1.0),
1133            complex(0.0, 0.0),
1134        ]
1135        nom = torch.tensor(nom_lst, dtype=dtype, device=device)
1136        denom = torch.tensor(denom_lst, dtype=dtype, device=device)
1137        expected = torch.tensor(expected_lst, dtype=dtype, device=device)
1138        res = nom / denom
1139        self.assertEqual(res, expected)
1140
1141    # Tests that trying to add, inplace, a CUDA tensor to a CPU tensor
1142    #   throws the correct error message
1143    @onlyCUDA
1144    def test_cross_device_inplace_error_msg(self, device):
1145        a = torch.tensor(2.0)
1146        b = torch.tensor(2.0, device=device)
1147        with self.assertRaisesRegex(
1148            RuntimeError, "Expected all tensors to be on the same device"
1149        ):
1150            a += b
1151
1152    # TODO: refactor this test into a more generic one, it's parked here currently
1153    @onlyNativeDeviceTypes
1154    def test_out_resize_warning(self, device):
1155        a = torch.tensor((1, 2, 3), device=device, dtype=torch.float32)
1156        b = torch.tensor((4, 5, 6), device=device, dtype=torch.float32)
1157
1158        unary_inputs = (a,)
1159        binary_inputs = (a, b)
1160        unary_ops = (torch.ceil, torch.exp)
1161        binary_ops = (torch.add, torch.sub)
1162        for op in unary_ops + binary_ops:
1163            with warnings.catch_warnings(record=True) as w:
1164                warnings.simplefilter("always")
1165                inputs = unary_inputs if op in unary_ops else binary_inputs
1166
1167                # No warnings
1168                op(*inputs, out=torch.empty(3, device=device))
1169                op(*inputs, out=torch.empty(0, device=device))
1170                self.assertEqual(len(w), 0)
1171
1172                # Cases that throw warnings
1173                op(*inputs, out=torch.empty(2, device=device))
1174                self.assertEqual(len(w), 1)
1175        # test that multi-d out doesn't trigger segfault
1176        arg1 = (torch.ones(2, 1, device=device), torch.ones(1, device=device))
1177        arg2 = (torch.ones(2, device=device), torch.ones(1, 1, device=device))
1178        outs = (
1179            torch.ones(2, 1, 1, 1, device=device),
1180            torch.ones(2, 2, 2, 2, device=device),
1181        )
1182
1183        for a1, a2, o in zip(arg1, arg2, outs):
1184            with warnings.catch_warnings(record=True) as w:
1185                warnings.simplefilter("always")
1186                torch.mul(a1, a2, out=o)
1187                self.assertEqual(len(w), 1)
1188
1189    # Verifies that the inplace dunders (like idiv) actually are in place
1190    @expectedFailureMeta  # UserWarning not triggered
1191    @onlyNativeDeviceTypes
1192    def test_inplace_dunders(self, device):
1193        t = torch.randn((1,), device=device)
1194        expected = t.data_ptr()
1195        t += 1
1196        t -= 1
1197        t *= 1
1198        t /= 1
1199        t **= 1
1200        t //= 1
1201        t %= 1
1202        self.assertEqual(expected, t.data_ptr())
1203
1204    def check_internal_mem_overlap(
1205        self, inplace_op, num_inputs, dtype, device, expected_failure=False
1206    ):
1207        if isinstance(inplace_op, str):
1208            inplace_op = getattr(torch.Tensor, inplace_op)
1209        input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
1210        inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)]
1211        if not expected_failure:
1212            with self.assertRaisesRegex(RuntimeError, "single memory location"):
1213                inplace_op(*inputs)
1214        else:
1215            with self.assertRaises(AssertionError):
1216                with self.assertRaisesRegex(RuntimeError, "single memory location"):
1217                    inplace_op(*inputs)
1218
1219    def unary_check_input_output_mem_overlap(
1220        self, data, sz, op, expected_failure=False
1221    ):
1222        def _test(op, output, input):
1223            output_exp = torch.empty_like(output)
1224            op(input, out=output_exp)
1225            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)
1226
1227        # output is identical to input:
1228        _test(op, output=data[0:sz], input=data[0:sz])
1229        # output and input are independent:
1230        _test(op, output=data[0:sz], input=data[sz : 2 * sz])
1231        # output partially overlaps with input:
1232        if not expected_failure:
1233            with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
1234                _test(op, data[0:sz], data[1 : sz + 1])
1235        else:
1236            with self.assertRaises(AssertionError):
1237                with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
1238                    _test(op, data[0:sz], data[1 : sz + 1])
1239
1240    def binary_check_input_output_mem_overlap(self, op, device, expected_failure=False):
1241        sz = 3
1242        data = torch.randn(2 * sz, device=device)
1243        other = torch.randn(sz, device=device)
1244
1245        self.unary_check_input_output_mem_overlap(
1246            data,
1247            sz,
1248            lambda input, out: op(other, input, out=out),
1249            expected_failure=expected_failure,
1250        )
1251
1252        self.unary_check_input_output_mem_overlap(
1253            data,
1254            sz,
1255            lambda input, out: op(input, other, out=out),
1256            expected_failure=expected_failure,
1257        )
1258
1259    # https://github.com/pytorch/pytorch/issues/126474
1260    @xfailIfTorchDynamo
1261    @dtypes(torch.double)
1262    def test_binary_op_mem_overlap(self, device, dtype):
1263        ops = [
1264            ("add", True, True, "cpu"),
1265            ("add", True, True, "cuda"),
1266            ("mul", True, True, "cpu"),
1267            ("mul", True, True, "cuda"),
1268            ("sub", True, True, "cpu"),
1269            ("sub", True, True, "cuda"),
1270            ("div", True, True, "cpu"),
1271            ("div", True, True, "cuda"),
1272            ("pow", True, True, "cpu"),
1273            ("pow", True, True, "cuda"),
1274            ("fmod", True, True, "cpu"),
1275            ("fmod", True, True, "cuda"),
1276            ("atan2", True, True, "cpu"),
1277            ("atan2", True, True, "cuda"),
1278            ("hypot", True, True, "cpu"),
1279            ("hypot", True, True, "cuda"),
1280            ("igamma", True, True, "cpu"),
1281            ("igamma", True, True, "cuda"),
1282            ("igammac", True, True, "cpu"),
1283            ("igammac", True, True, "cuda"),
1284            ("nextafter", True, True, "cpu"),
1285            ("nextafter", True, True, "cuda"),
1286            ("le", True, True, "cpu"),
1287            ("le", True, True, "cuda"),
1288            ("lt", True, True, "cpu"),
1289            ("lt", True, True, "cuda"),
1290            ("ge", True, True, "cpu"),
1291            ("ge", True, True, "cuda"),
1292            ("gt", True, True, "cpu"),
1293            ("gt", True, True, "cuda"),
1294            ("eq", True, True, "cpu"),
1295            ("eq", True, True, "cuda"),
1296            ("ne", True, True, "cpu"),
1297            ("ne", True, True, "cuda"),
1298            ("logical_and", True, True, "cpu"),
1299            ("logical_and", True, True, "cuda"),
1300            ("logical_or", True, True, "cpu"),
1301            ("logical_or", True, True, "cuda"),
1302            ("logical_xor", True, True, "cpu"),
1303            ("logical_xor", True, True, "cuda"),
1304        ]
1305
1306        for (
1307            fn,
1308            has_input_output_mem_overlap_check,
1309            has_internal_mem_overlap_check,
1310            dev,
1311        ) in ops:
1312            if dev != device:
1313                continue
1314            out_op = getattr(torch, fn)
1315            inplace_op = getattr(torch.Tensor, fn + "_")
1316            self.check_internal_mem_overlap(
1317                inplace_op,
1318                2,
1319                dtype,
1320                device,
1321                expected_failure=not has_internal_mem_overlap_check,
1322            )
1323
1324            self.binary_check_input_output_mem_overlap(
1325                out_op, device, expected_failure=not has_input_output_mem_overlap_check
1326            )
1327
1328    def _do_pow_for_exponents(self, m1, exponents, pow_fn, atol):
1329        for num in exponents:
1330            if (
1331                isinstance(num, int)
1332                and num < 0
1333                and not m1.is_floating_point()
1334                and not m1.is_complex()
1335            ):
1336                with self.assertRaisesRegex(
1337                    RuntimeError,
1338                    r"Integers to negative integer powers are not allowed\.",
1339                ):
1340                    torch.pow(m1[4], num)
1341            else:
1342                # base - tensor, exponent - number
1343                # contiguous
1344                res1 = torch.pow(m1[4], num)
1345                res2 = res1.clone().zero_()
1346                # `math.pow` has issues with complex exponentiation so we need to resort to normal `pow`.
1347                for i in range(res2.size(0)):
1348                    res2[i] = pow_fn(m1[4][i], num)
1349                rtol = 0 if atol is not None else None
1350                self.assertEqual(res1, res2, atol=atol, rtol=rtol)
1351
1352                # non-contiguous
1353                res1 = torch.pow(m1[:, 4], num)
1354                res2 = res1.clone().zero_()
1355                for i in range(res2.size(0)):
1356                    res2[i] = pow_fn(m1[i, 4], num)
1357                self.assertEqual(res1, res2, atol=atol, rtol=rtol)
1358
1359                # scalar ** tensor to enforce correct handling of dtypes for __rpow__().
1360                expected_dtype = torch.result_type(num, m1)
1361                res1 = num ** m1[4]
1362                res2 = (
1363                    torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4]
1364                )
1365                self.assertEqual(res1, res2)
1366                self.assertEqual(res1.dtype, expected_dtype)
1367
1368    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
1369    def test_pow(self, device, dtype):
1370        m1 = torch.empty(0, dtype=dtype, device=device)
1371        if m1.is_floating_point() or m1.is_complex():
1372            m1 = (
1373                make_tensor((100, 100), low=0, high=1, dtype=dtype, device=device) + 0.5
1374            )
1375        else:
1376            # math.pow will overflow and throw exceptions for large integers
1377            range_high = 4 if dtype in (torch.int8, torch.uint8) else 10
1378            m1 = make_tensor(
1379                (100, 100), low=1, high=range_high, dtype=dtype, device=device
1380            )
1381
1382        exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3, True, False]
1383        complex_exponents = [
1384            -2.5j,
1385            -1.0j,
1386            0j,
1387            1.0j,
1388            2.5j,
1389            1.0 + 1.0j,
1390            -1.0 - 1.5j,
1391            3.3j,
1392        ]
1393        if m1.is_complex():
1394            self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4)
1395        else:
1396            self._do_pow_for_exponents(m1, exponents, math.pow, None)
1397            will_raise_error = (
1398                dtype is torch.half and torch.device(device).type == "cpu"
1399            )
1400            if will_raise_error:
1401                # On CPU,
1402                # Half Tensor with complex exponents leads to computation dtype
1403                # of ComplexHalf for which this ops is not supported yet
1404                with self.assertRaisesRegex(
1405                    RuntimeError, "not implemented for 'ComplexHalf'"
1406                ):
1407                    self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
1408            else:
1409                self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
1410
1411        # base - number, exponent - tensor
1412        # contiguous
1413        res1 = torch.pow(3, m1[4])
1414        res2 = res1.clone().zero_()
1415        for i in range(res2.size(0)):
1416            res2[i] = pow(3, m1[4, i])
1417        self.assertEqual(res1, res2)
1418
1419        # non-contiguous
1420        res1 = torch.pow(3, m1[:, 4])
1421        res2 = res1.clone().zero_()
1422        for i in range(res2.size(0)):
1423            res2[i] = pow(3, m1[i][4])
1424        self.assertEqual(res1, res2)
1425
1426    # TODO: refactor all these tests using opinfos properly
1427    def _test_pow(self, base, exponent, np_exponent=None):
1428        if np_exponent is None:
1429            np_exponent = exponent
1430
1431        def to_np(value):
1432            if isinstance(value, torch.Tensor):
1433                return value.cpu().numpy()
1434            return value
1435
1436        try:
1437            np_res = np.power(to_np(base), to_np(np_exponent))
1438            expected = (
1439                torch.from_numpy(np_res)
1440                if isinstance(np_res, np.ndarray)
1441                else torch.tensor(np_res, dtype=base.dtype)
1442            )
1443        except ValueError as e:
1444            err_msg = "Integers to negative integer powers are not allowed."
1445            self.assertEqual(str(e), err_msg)
1446            out = torch.empty_like(base)
1447            test_cases = [
1448                lambda: base.pow(exponent),
1449                lambda: base.pow_(exponent),
1450                lambda: torch.pow(base, exponent),
1451                lambda: torch.pow(base, exponent, out=out),
1452            ]
1453            for test_case in test_cases:
1454                self.assertRaisesRegex(RuntimeError, err_msg, test_case)
1455        else:
1456            if isinstance(base, torch.Tensor):
1457                actual = base.pow(exponent)
1458                self.assertEqual(actual, expected.to(actual))
1459                actual = base.clone()
1460                # When base is a 0-dim cpu tensor and exp is a cuda tensor, we exp `pow` to work but `pow_` to fail, since
1461                # `pow` will try to create the output tensor on a cuda device, but `pow_` needs to use the cpu tensor as the output
1462                if (
1463                    isinstance(exponent, torch.Tensor)
1464                    and base.dim() == 0
1465                    and base.device.type == "cpu"
1466                    and exponent.device.type == "cuda"
1467                ):
1468                    regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!"
1469                    self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)
1470                elif torch.can_cast(torch.result_type(base, exponent), base.dtype):
1471                    actual2 = actual.pow_(exponent)
1472                    self.assertEqual(actual, expected)
1473                    self.assertEqual(actual2, expected)
1474                else:
1475                    self.assertRaisesRegex(
1476                        RuntimeError,
1477                        "Found dtype \\w+ but expected \\w+",
1478                        lambda: actual.pow_(exponent),
1479                    )
1480
1481            actual = torch.pow(base, exponent)
1482            self.assertEqual(actual, expected.to(actual))
1483
1484            actual2 = torch.pow(base, exponent, out=actual)
1485            self.assertEqual(actual, expected.to(actual))
1486            self.assertEqual(actual2, expected.to(actual))
1487
1488    # We can potentially merge this into OpInfo, but one blocker is that the
1489    # first input must be a scalar. It is not as simple as just wrapping this in
1490    # a lambada that switches the inputs, because we also want to test samples inputs
1491    # where the second input is a scalar. The wrapper would need some more logic.
1492    def test_pow_scalar_base(self, device):
1493        a = (
1494            torch.arange(1, 13, dtype=torch.double, device=device)
1495            .view(3, 4)
1496            .requires_grad_()
1497        )
1498        gradcheck(lambda a: torch.pow(2, a), (a,))
1499
1500    # Tests pow() for integral, floating-type tensors, with integral, floating-type
1501    # exponents (tensor or scalar), respectively. noncontiguous tensors are also tested.
1502    def test_int_and_float_pow(self, device):
1503        def _test_int_and_float_pow(dt, low, high, dev):
1504            test_cases = (
1505                ((4, 4), 0, (4, 1)),
1506                ((3, 1), 4, (3, 1)),
1507                ((2,), 4, (1,)),
1508                ((1,), 2, ()),
1509                ((513, 513), 4, (513,)),
1510                ((5, 5, 5), 5, (5,)),
1511                ((), 2, ()),
1512            )
1513            for base_shape, exp_scalar, exp_shape in test_cases:
1514                base_tensor = make_tensor(
1515                    base_shape, dtype=dt, device=dev, low=low, high=high
1516                )
1517                # int tensors don't take negative exponents
1518                if dt in [
1519                    torch.uint8,
1520                    torch.int8,
1521                    torch.int16,
1522                    torch.int32,
1523                    torch.int64,
1524                ]:
1525                    exp_tensor = make_tensor(
1526                        exp_shape, dtype=dt, device=dev, low=0, high=high
1527                    )
1528                else:
1529                    exp_tensor = make_tensor(
1530                        exp_shape, dtype=dt, device=dev, low=low, high=high
1531                    )
1532                self._test_pow(base_tensor, exp_scalar)
1533                self._test_pow(base_tensor, exp_tensor)
1534                # test non-contiguous tensors as well
1535                base_tensor = make_tensor(
1536                    base_shape,
1537                    dtype=dt,
1538                    device=dev,
1539                    low=low,
1540                    high=high,
1541                    noncontiguous=True,
1542                )
1543                if dt in [
1544                    torch.uint8,
1545                    torch.int8,
1546                    torch.int16,
1547                    torch.int32,
1548                    torch.int64,
1549                ]:
1550                    exp_tensor = make_tensor(
1551                        exp_shape,
1552                        dtype=dt,
1553                        device=dev,
1554                        low=0,
1555                        high=high,
1556                        noncontiguous=True,
1557                    )
1558                else:
1559                    exp_tensor = make_tensor(
1560                        exp_shape,
1561                        dtype=dt,
1562                        device=dev,
1563                        low=low,
1564                        high=high,
1565                        noncontiguous=True,
1566                    )
1567                self._test_pow(base_tensor, exp_scalar)
1568                self._test_pow(base_tensor, exp_tensor)
1569
1570        _test_int_and_float_pow(torch.int8, -2, 2, device)
1571        _test_int_and_float_pow(torch.uint8, 0, 3, device)
1572        _test_int_and_float_pow(torch.int16, -5, 5, device)
1573        _test_int_and_float_pow(torch.int64, -10, 10, device)
1574        _test_int_and_float_pow(torch.int32, -10, 10, device)
1575        _test_int_and_float_pow(torch.float16, 0.0, 5.0, device)
1576        _test_int_and_float_pow(torch.float32, 0.0, 10.0, device)
1577        _test_int_and_float_pow(torch.float64, 0.0, 10.0, device)
1578        # pow's output would have some NaNs as well
1579        _test_int_and_float_pow(torch.float32, -10.0, 10.0, device)
1580        _test_int_and_float_pow(torch.float64, -10.0, 10.0, device)
1581
1582    # Tests that a Runtime error occurs when a base tensor cannot be resized
1583    # by pow's inplace variant due to PyTorch's broadcasting semantics.
1584    def test_pow_inplace_resizing_exception(self, device):
1585        test_cases = (
1586            ((), (3,)),
1587            ((2,), (2, 1)),
1588            ((2, 1), (2, 2)),
1589            ((2, 2), (2, 1, 1)),
1590        )
1591        test_inputs = [
1592            (
1593                make_tensor(
1594                    base_size, dtype=torch.float64, device=device, high=10.0, low=0.0
1595                ),
1596                make_tensor(
1597                    exp_size, dtype=torch.float64, device=device, high=10.0, low=0.0
1598                ),
1599            )
1600            for base_size, exp_size in test_cases
1601        ]
1602        for base, exponent in test_inputs:
1603            regex = "doesn't match the broadcast shape"
1604            self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)
1605
1606    def test_int_tensor_pow_neg_ints(self, device):
1607        ints = [
1608            torch.iinfo(torch.int32).min,
1609            -3,
1610            -2,
1611            -1,
1612            0,
1613            1,
1614            2,
1615            3,
1616            torch.iinfo(torch.int32).max,
1617        ]
1618        neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1]
1619        tensor = torch.tensor(ints, dtype=torch.int32, device=device)
1620        for pow in neg_ints:
1621            self._test_pow(tensor, pow)
1622
1623    def test_long_tensor_pow_floats(self, device):
1624        ints = [0, 1, 23, 4567]
1625        floats = [0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0]
1626        tensor = torch.tensor(ints, dtype=torch.int64, device=device)
1627        for pow in floats:
1628            self._test_pow(tensor, pow)
1629
1630    @dtypes(*[torch.float32, torch.float64])
1631    def test_float_scalar_pow_float_tensor(self, device, dtype):
1632        floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0]
1633        exponent_shapes = (
1634            (1,),
1635            (2, 2),
1636            (2, 1),
1637            (2, 2, 2),
1638        )
1639        tensors = [
1640            make_tensor(shape, dtype=dtype, device=device, low=0)
1641            for shape in exponent_shapes
1642        ]
1643        floats_tensor = torch.tensor(floats, dtype=dtype, device=device)
1644        for base in floats:
1645            self._test_pow(base, floats_tensor)
1646            for tensor in tensors:
1647                self._test_pow(base, tensor)
1648
1649    @onlyCUDA
1650    def test_cuda_tensor_pow_scalar_tensor(self, device):
1651        cuda_tensors = [
1652            torch.randn((3, 3), device=device),
1653            torch.tensor(3.0, device=device),
1654        ]
1655        scalar_tensors = [
1656            torch.tensor(5.0, device="cpu"),
1657            torch.tensor(-3),
1658            torch.tensor(1),
1659        ]
1660        for base, exp in product(cuda_tensors, scalar_tensors):
1661            self._test_pow(base, exp)
1662
1663    @onlyCUDA
1664    def test_cpu_tensor_pow_cuda_scalar_tensor(self, device):
1665        cuda_tensors = [
1666            torch.tensor(5.0, device="cuda"),
1667            torch.tensor(-3, device="cuda"),
1668        ]
1669        for exp in cuda_tensors:
1670            base = torch.randn((3, 3), device="cpu")
1671            regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!"
1672            self.assertRaisesRegex(RuntimeError, regex, torch.pow, base, exp)
1673        for exp in cuda_tensors:
1674            # Binary ops with a cpu + cuda tensor are allowed if the cpu tensor has 0 dimension
1675            base = torch.tensor(3.0, device="cpu")
1676            self._test_pow(base, exp)
1677
1678    @onlyCUDA
1679    @dtypes(torch.complex64, torch.complex128)
1680    def test_pow_cuda_complex_extremal_failing(self, device, dtype):
1681        t = torch.tensor(complex(-1.0, float("inf")), dtype=dtype, device=device)
1682        with self.assertRaises(AssertionError):
1683            cuda_out = t.pow(2)
1684            cpu_out = t.cpu().pow(2)
1685            self.assertEqual(cpu_out, cuda_out)
1686
1687    @skipIfTorchDynamo()
1688    @onlyNativeDeviceTypes
1689    @dtypes(*all_types_and_complex_and(torch.half))
1690    def test_complex_scalar_pow_tensor(self, device, dtype):
1691        complexes = [0.5j, 1.0 + 1.0j, -1.5j, 2.2 - 1.6j, 1 + 0j]
1692        first_exp = make_tensor((100,), dtype=dtype, device=device, low=-2, high=2)
1693        second_exp = make_tensor(
1694            (100,), dtype=dtype, device=device, low=-2, high=2, noncontiguous=True
1695        )
1696        first_exp[0] = first_exp[10] = first_exp[20] = 0
1697        second_exp[0] = second_exp[10] = second_exp[20] = 0
1698        for base in complexes:
1699            # On CPU,
1700            # Half Tensor with complex base leads to computation dtype
1701            # of ComplexHalf for which this ops is not supported yet
1702            # NOTE: pow has fast-path when base is 1 which supports
1703            # ComplexHalf
1704            will_raise_error = (
1705                torch.device(device).type == "cpu"
1706                and dtype is torch.half
1707                and base != (1 + 0j)
1708            )
1709            if will_raise_error:
1710                with self.assertRaisesRegex(
1711                    RuntimeError, "not implemented for 'ComplexHalf'"
1712                ):
1713                    self._test_pow(base, first_exp)
1714                    self._test_pow(base, second_exp)
1715            else:
1716                self._test_pow(base, first_exp)
1717                self._test_pow(base, second_exp)
1718
1719    @onlyNativeDeviceTypes
1720    @skipMeta
1721    def test_pow_scalar_type_promotion(self, device):
1722        # Test against a scalar and non-scalar input
1723        inputs = [17, [17]]
1724        for input in inputs:
1725            # We expect the computation to be performed in uint8 (overflowing to 0), and then cast to int64
1726            input_tensor_uint8 = torch.tensor(input, dtype=torch.uint8, device=device)
1727            out_uint8_computation = torch.pow(
1728                2,
1729                input_tensor_uint8,
1730                out=torch.tensor(0, dtype=torch.int64, device=device),
1731            )
1732
1733            # Computation should run in int64, and not overflow
1734            input_tensor_int64 = torch.tensor(input, dtype=torch.int64, device=device)
1735            out_int64_computation = torch.pow(
1736                2,
1737                input_tensor_int64,
1738                out=torch.tensor(0, dtype=torch.int64, device=device),
1739            )
1740
1741            self.assertNotEqual(out_uint8_computation, out_int64_computation)
1742            self.assertEqual(
1743                out_uint8_computation.to(dtype=torch.uint8),
1744                out_int64_computation.to(dtype=torch.uint8),
1745            )
1746
1747    def test_tensor_pow_tensor(self, device):
1748        def rotate(l, n):
1749            return l[-n:] + l[:-n]
1750
1751        def test_tensor_pow_tensor(values, torch_type, numpy_type):
1752            vals_tensor = torch.tensor(values, dtype=torch_type, device=device)
1753            for i in range(len(values)):
1754                pows = rotate(values, i)
1755                pows_tensor = torch.tensor(pows, dtype=torch_type, device=device)
1756                self._test_pow(vals_tensor, pows_tensor)
1757
1758        ints = [0, 1, 2, 3]
1759        test_tensor_pow_tensor(ints, torch.uint8, np.uint8)
1760        test_tensor_pow_tensor(ints, torch.int8, np.int8)
1761        test_tensor_pow_tensor(ints, torch.int16, np.int16)
1762        test_tensor_pow_tensor(ints, torch.int32, np.int32)
1763        test_tensor_pow_tensor(ints, torch.int64, np.int64)
1764
1765        floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 2.0, 3.0]
1766        test_tensor_pow_tensor(floats, torch.float16, np.float16)
1767        test_tensor_pow_tensor(floats, torch.float32, np.float32)
1768        test_tensor_pow_tensor(floats, torch.float64, np.float64)
1769
1770    def test_logical_xor_with_nontrivial_alignment(self, device):
1771        # test tensor that is not aligned to multiple of 16 bytes
1772        size = 128
1773        a = torch.randn(size, device=device) > 0
1774        b = torch.randn(size, device=device) > 0
1775        c = torch.randn(size, device=device) > 0
1776        non_trivial_alignment = [1, 2, 4, 8, 15]
1777        for i in non_trivial_alignment:
1778            for j in non_trivial_alignment:
1779                for k in non_trivial_alignment:
1780                    a_ = a[i : 100 + i]
1781                    b_ = b[j : 100 + j]
1782                    c_ = c[k : 100 + k]
1783                    torch.logical_xor(a_, b_, out=c_)
1784                    for x, y, z in zip(a_.tolist(), b_.tolist(), c_.tolist()):
1785                        self.assertEqual(x ^ y, z)
1786
1787    @dtypes(torch.float)
1788    def test_add_with_tail(self, device, dtype):
1789        # test tensor where there is a tail which is not a multiple
1790        # of GPU warp size
1791        for tail_size in [1, 63, 67, 130]:
1792            size = 4096 + tail_size
1793            a = torch.randn(size, device=device, dtype=dtype)
1794            b = torch.randn(size, device=device, dtype=dtype)
1795            c = a + b
1796            for x, y, z in zip(a.tolist(), b.tolist(), c.tolist()):
1797                self.assertEqual(x + y, z)
1798
1799    # Tests that CUDA tensors on different devices cannot be used in the same
1800    # binary operation, and that CUDA "scalars" cannot be used in the same
1801    # binary operation as non-scalar CPU tensors.
1802    @deviceCountAtLeast(2)
1803    @onlyCUDA
1804    def test_cross_device_binary_ops(self, devices):
1805        vals = (1.0, (2.0,))
1806        cpu_tensor = torch.randn(2, 2)
1807
1808        def do_test(op, a, b):
1809            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
1810                op(a, b)
1811            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
1812                op(b, a)
1813            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
1814                op(a, cpu_tensor)
1815            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
1816                op(cpu_tensor, a)
1817
1818        for op in (
1819            operator.add,
1820            torch.add,
1821            operator.sub,
1822            torch.sub,
1823            operator.mul,
1824            torch.mul,
1825            operator.truediv,
1826            torch.true_divide,
1827            operator.floordiv,
1828            torch.floor_divide,
1829        ):
1830            for a, b in product(vals, vals):
1831                a = torch.tensor(a, device=devices[0])
1832                b = torch.tensor(b, device=devices[1])
1833
1834            do_test(op, a, b)
1835
1836    # This test ensures that a scalar Tensor can be safely used
1837    # in a binary operation in conjunction with a Tensor on all
1838    # available CUDA devices
1839    @deviceCountAtLeast(2)
1840    @onlyCUDA
1841    def test_binary_op_scalar_device_unspecified(self, devices):
1842        scalar_val = torch.tensor(1.0)
1843        for default_device in devices:
1844            with torch.cuda.device(default_device):
1845                for device in devices:
1846                    device_obj = torch.device(device)
1847                    x = torch.rand(3, device=device)
1848                    y0 = x * scalar_val
1849                    self.assertEqual(y0.device, device_obj)
1850                    y1 = scalar_val * x
1851                    self.assertEqual(y1.device, device_obj)
1852                    self.assertEqual(y0, y1)
1853
1854    def test_div_and_floordiv_vs_python(self, device):
1855        # Tests torch division ops which can handle both arguments being
1856        #   scalars.
1857        def _scalar_helper(python_op, torch_op):
1858            for a, b in product(range(-10, 10), range(-10, 10)):
1859                for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
1860                    a = op(a)
1861                    b = op(b)
1862
1863                    # Skips zero divisors
1864                    if b == 0:
1865                        continue
1866
1867                    expected = python_op(a, b)
1868
1869                    for op in (operator.truediv, torch.true_divide):
1870                        actual_scalar = torch_op(a, b)
1871
1872                        a_t = torch.tensor(a, device=device)
1873                        b_t = torch.tensor(b, device=device)
1874
1875                        actual_tensor = torch_op(a_t, b_t)
1876                        actual_first_tensor = torch_op(a_t, b)
1877                        actual_second_tensor = torch_op(a, b_t)
1878
1879                        self.assertEqual(actual_scalar, expected)
1880                        self.assertEqual(actual_tensor.item(), expected)
1881                        self.assertEqual(actual_first_tensor, actual_tensor)
1882                        self.assertEqual(actual_second_tensor, actual_tensor)
1883
1884        _scalar_helper(operator.truediv, operator.truediv)
1885        _scalar_helper(operator.truediv, torch.true_divide)
1886        _scalar_helper(lambda a, b: math.floor(a / b), operator.floordiv)
1887        _scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide)
1888
1889    @onlyNativeDeviceTypes
1890    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
1891    def test_div_and_floordiv_script_vs_python(self, device):
1892        # Creates jitted functions of two tensors
1893        def _wrapped_div(a, b):
1894            return a / b
1895
1896        def _wrapped_floordiv(a, b):
1897            return a // b
1898
1899        scripted_div = torch.jit.script(_wrapped_div)
1900        scripted_floordiv = torch.jit.script(_wrapped_floordiv)
1901        for a, b in product(range(-10, 10), range(-10, 10)):
1902            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
1903                a = op(a)
1904                b = op(b)
1905
1906                # Skips zero divisors
1907                if b == 0:
1908                    continue
1909
1910                expected_div = a / b
1911                expected_floordiv = math.floor(a / b)
1912                a_t = torch.tensor(a, device=device)
1913                b_t = torch.tensor(b, device=device)
1914
1915                self.assertEqual(scripted_div(a_t, b_t), expected_div)
1916                self.assertEqual(scripted_floordiv(a_t, b_t), expected_floordiv)
1917
1918        # Creates jitted functions of one tensor
1919        def _wrapped_div_scalar(a):
1920            return a / 5
1921
1922        # NOTE: the JIT implements division as torch.reciprocal(a) * 5
1923        def _wrapped_rdiv_scalar(a):
1924            return 5 / a
1925
1926        def _wrapped_floordiv_scalar(a):
1927            return a // 5
1928
1929        # NOTE: this fails if the input is not an integer tensor
1930        # See https://github.com/pytorch/pytorch/issues/45199
1931        def _wrapped_rfloordiv_scalar(a):
1932            return 5 // a
1933
1934        scripted_div_scalar = torch.jit.script(_wrapped_div_scalar)
1935        scripted_rdiv_scalar = torch.jit.script(_wrapped_rdiv_scalar)
1936        scripted_floordiv_scalar = torch.jit.script(_wrapped_floordiv_scalar)
1937        scripted_rfloordiv_scalar = torch.jit.script(_wrapped_rfloordiv_scalar)
1938
1939        for a in range(-10, 10):
1940            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
1941                a = op(a)
1942
1943                a_t = torch.tensor(a, device=device)
1944
1945                self.assertEqual(a / 5, scripted_div_scalar(a_t))
1946
1947                # Skips zero divisors
1948                if a == 0:
1949                    continue
1950
1951                self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))
1952
1953                # Handles Issue 45199 (see comment above)
1954                if a_t.is_floating_point():
1955                    with self.assertRaises(RuntimeError):
1956                        scripted_rfloordiv_scalar(a_t)
1957                else:
1958                    # This should emit a UserWarning, why doesn't it?
1959                    # See issue gh-52387
1960                    self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))
1961
1962    @onlyNativeDeviceTypes
1963    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
1964    def test_idiv_and_ifloordiv_vs_python(self, device):
1965        def _wrapped_idiv_tensor(a, b):
1966            a /= b
1967            return a
1968
1969        def _wrapped_idiv_scalar(a):
1970            a /= 5
1971            return a
1972
1973        def _wrapped_true_divide__tensor(a, b):
1974            a.true_divide_(b)
1975            return a
1976
1977        def _wrapped_true_divide__scalar(a):
1978            a.true_divide_(5)
1979            return a
1980
1981        def _wrapped_floor_divide__tensor(a, b):
1982            a.floor_divide_(b)
1983            return a
1984
1985        def _wrapped_floor_divide__scalar(a):
1986            a.floor_divide_(5)
1987            return a
1988
1989        # The following functions are unsupported by the JIT
1990        def _wrapped_ifloordiv_tensor(a, b):
1991            a //= b
1992            return a
1993
1994        def _wrapped_ifloordiv_scalar(a):
1995            a //= 5
1996            return a
1997
1998        with self.assertRaises(torch.jit.frontend.NotSupportedError):
1999            scripted_ifloordiv_tensor = torch.jit.script(_wrapped_ifloordiv_tensor)
2000
2001        with self.assertRaises(torch.jit.frontend.NotSupportedError):
2002            scripted_ifloordiv_scalar = torch.jit.script(_wrapped_ifloordiv_scalar)
2003
2004        scripted_idiv_tensor = torch.jit.script(_wrapped_idiv_tensor)
2005        scripted_idiv_scalar = torch.jit.script(_wrapped_idiv_scalar)
2006        scripted_true_divide__tensor = torch.jit.script(_wrapped_true_divide__tensor)
2007        scripted_true_divide__scalar = torch.jit.script(_wrapped_true_divide__scalar)
2008        scripted_floor_divide__tensor = torch.jit.script(_wrapped_floor_divide__tensor)
2009        scripted_floor_divide__scalar = torch.jit.script(_wrapped_floor_divide__scalar)
2010
2011        for a, b in product(range(-10, 10), range(-10, 10)):
2012            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
2013                a = op(a)
2014                b = op(b)
2015
2016                # Skips zero divisors
2017                if b == 0:
2018                    continue
2019
2020                expected_idiv = a / b
2021                expected_ifloordiv = a // b
2022
2023                a_t = torch.tensor(a, device=device)
2024                b_t = torch.tensor(b, device=device)
2025
2026                if a_t.is_floating_point():
2027                    tmp0 = a_t.clone()
2028                    tmp0 /= b
2029
2030                    tmp1 = a_t.clone()
2031                    tmp1 /= b_t
2032
2033                    self.assertEqual(tmp0.item(), expected_idiv)
2034                    self.assertEqual(tmp1.item(), expected_idiv)
2035                    self.assertEqual(
2036                        scripted_true_divide__tensor(a_t.clone(), b_t).item(),
2037                        expected_idiv,
2038                    )
2039                    self.assertEqual(
2040                        scripted_true_divide__scalar(a_t.clone()).item(), a / 5
2041                    )
2042                else:
2043                    tmp = a_t.clone()
2044                    with self.assertRaises(RuntimeError):
2045                        tmp /= b
2046                    with self.assertRaises(RuntimeError):
2047                        tmp /= b_t
2048                    with self.assertRaises(RuntimeError):
2049                        scripted_true_divide__tensor(tmp, b_t)
2050                    with self.assertRaises(RuntimeError):
2051                        scripted_true_divide__scalar(tmp)
2052
2053                if not a_t.is_floating_point() and b_t.is_floating_point():
2054                    # Inplace modification fails because a float tensor is required
2055                    #   if the divisor is a float tensor
2056                    a_t.clone().floor_divide_(b_t)
2057                    scripted_floor_divide__tensor(a_t.clone(), b_t)
2058                    tmp = a_t.clone()
2059                    tmp //= b_t
2060                else:
2061                    # Inplace modification is OK when both or neither tensor is
2062                    #   a float tensor
2063                    self.assertEqual(
2064                        a_t.clone().floor_divide_(b_t).item(), expected_ifloordiv
2065                    )
2066                    self.assertEqual(
2067                        scripted_floor_divide__tensor(a_t.clone(), b_t).item(),
2068                        expected_ifloordiv,
2069                    )
2070                    tmp = a_t.clone()
2071                    tmp //= b_t
2072                    self.assertEqual(tmp.item(), expected_ifloordiv)
2073
2074                self.assertEqual(scripted_floor_divide__scalar(a_t), math.floor(a / 5))
2075
2076    # Tests binary op equivalence with Python builtin ops
2077    # Also tests that reverse operations are equivalent to forward ops
2078    # NOTE: division ops are tested separately above
2079    def test_binary_ops_with_scalars(self, device):
2080        for python_op, torch_op in (
2081            (operator.add, torch.add),
2082            (operator.sub, torch.sub),
2083            (operator.mul, torch.mul),
2084            (operator.truediv, torch.div),
2085        ):
2086            for a, b in product(range(-10, 10), range(-10, 10)):
2087                for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
2088                    a = op(a)
2089                    b = op(b)
2090
2091                    # Skips zero divisors
2092                    if b == 0 or a == 0:
2093                        continue
2094
2095                    a_tensor = torch.tensor(a, device=device)
2096                    b_tensor = torch.tensor(b, device=device)
2097                    a_tensor_cpu = a_tensor.cpu()
2098                    b_tensor_cpu = b_tensor.cpu()
2099                    vals = (a, b, a_tensor, b_tensor, a_tensor_cpu, b_tensor_cpu)
2100
2101                    for args in product(vals, vals):
2102                        first, second = args
2103
2104                        first_scalar = (
2105                            first
2106                            if not isinstance(first, torch.Tensor)
2107                            else first.item()
2108                        )
2109                        second_scalar = (
2110                            second
2111                            if not isinstance(second, torch.Tensor)
2112                            else second.item()
2113                        )
2114                        expected = python_op(first_scalar, second_scalar)
2115
2116                        self.assertEqual(expected, python_op(first, second))
2117                        self.assertEqual(expected, torch_op(first, second))
2118
2119    @dtypes(
2120        *product(
2121            all_types_and(torch.half, torch.bfloat16, torch.bool),
2122            all_types_and(torch.half, torch.bfloat16, torch.bool),
2123        )
2124    )
2125    def test_maximum_minimum_type_promotion(self, device, dtypes):
2126        a = torch.tensor((0, 1), device=device, dtype=dtypes[0])
2127        b = torch.tensor((1, 0), device=device, dtype=dtypes[1])
2128        for op in (
2129            torch.maximum,
2130            torch.max,
2131            torch.fmax,
2132            torch.minimum,
2133            torch.min,
2134            torch.fmin,
2135        ):
2136            result = op(a, b)
2137            self.assertEqual(result.dtype, torch.result_type(a, b))
2138
2139    @dtypes(*integral_types_and(torch.bool))
2140    def test_maximum_minimum_int_and_bool(self, device, dtype):
2141        ops = (
2142            (torch.maximum, torch.max, np.maximum),
2143            (torch.minimum, torch.min, np.minimum),
2144            (torch.fmax, None, np.fmax),
2145            (torch.fmin, None, np.fmin),
2146        )
2147        rng = np.random.default_rng()
2148        a_np = np.array(
2149            rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]
2150        )
2151        b_np = np.array(
2152            rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]
2153        )
2154
2155        for torch_op, alias, numpy_op in ops:
2156            a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
2157            b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
2158            tensor_result = torch_op(a_tensor, b_tensor)
2159
2160            out = torch.empty_like(a_tensor)
2161            torch_op(a_tensor, b_tensor, out=out)
2162
2163            numpy_result = numpy_op(a_np, b_np)
2164
2165            if alias is not None:
2166                alias_result = alias(a_tensor, b_tensor)
2167                self.assertEqual(alias_result, tensor_result)
2168
2169            self.assertEqual(tensor_result, numpy_result)
2170            self.assertEqual(out, numpy_result)
2171
2172    @precisionOverride({torch.bfloat16: 1e-2})
2173    @dtypes(*(floating_types_and(torch.half, torch.bfloat16)))
2174    def test_maximum_minimum_float(self, device, dtype):
2175        ops = (
2176            (torch.maximum, torch.max, np.maximum),
2177            (torch.minimum, torch.min, np.minimum),
2178            (torch.fmax, None, np.fmax),
2179            (torch.fmin, None, np.fmin),
2180        )
2181
2182        if dtype == torch.bfloat16:
2183            a_np = np.random.randn(10).astype(np.float64)
2184            b_np = np.random.randn(10).astype(np.float64)
2185        else:
2186            a_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype])
2187            b_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype])
2188
2189        for torch_op, alias, numpy_op in ops:
2190            numpy_result = numpy_op(a_np, b_np)
2191
2192            a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
2193            b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
2194            tensor_result = torch_op(a_tensor, b_tensor)
2195            out = torch.empty_like(a_tensor)
2196            torch_op(a_tensor, b_tensor, out=out)
2197
2198            if alias is not None:
2199                alias_result = alias(a_tensor, b_tensor)
2200                self.assertEqual(alias_result, tensor_result, exact_dtype=False)
2201
2202            self.assertEqual(tensor_result, numpy_result, exact_dtype=False)
2203            self.assertEqual(out, numpy_result, exact_dtype=False)
2204
2205    @dtypes(*(floating_types_and(torch.half, torch.bfloat16)))
2206    def test_maximum_minimum_float_nan_and_inf(self, device, dtype):
2207        # np.maximum and np.minimum functions compare input arrays element-wisely.
2208        # if one of the elements being compared is a NaN, then that element is returned.
2209        ops = (
2210            (torch.maximum, torch.max, np.maximum),
2211            (torch.minimum, torch.min, np.minimum),
2212            (torch.fmax, None, np.fmax),
2213            (torch.fmin, None, np.fmin),
2214        )
2215        a_vals = (
2216            float("inf"),
2217            -float("inf"),
2218            float("nan"),
2219            float("inf"),
2220            float("nan"),
2221            float("nan"),
2222            1,
2223            float("nan"),
2224        )
2225        b_vals = (
2226            -float("inf"),
2227            float("inf"),
2228            float("inf"),
2229            float("nan"),
2230            float("nan"),
2231            0,
2232            float("nan"),
2233            -5,
2234        )
2235        if dtype == torch.bfloat16:
2236            a_np = np.array(a_vals, dtype=np.float64)
2237            b_np = np.array(b_vals, dtype=np.float64)
2238        else:
2239            a_np = np.array(a_vals, dtype=torch_to_numpy_dtype_dict[dtype])
2240            b_np = np.array(b_vals, dtype=torch_to_numpy_dtype_dict[dtype])
2241
2242        for torch_op, alias, numpy_op in ops:
2243            numpy_result = numpy_op(a_np, b_np)
2244
2245            a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
2246            b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
2247            tensor_result = torch_op(a_tensor, b_tensor)
2248
2249            out = torch.empty_like(a_tensor)
2250            torch_op(a_tensor, b_tensor, out=out)
2251
2252            if alias is not None:
2253                alias_result = alias(a_tensor, b_tensor)
2254                self.assertEqual(alias_result, tensor_result)
2255
2256            if dtype == torch.bfloat16:
2257                self.assertEqual(tensor_result, numpy_result, exact_dtype=False)
2258                self.assertEqual(out, numpy_result, exact_dtype=False)
2259            else:
2260                self.assertEqual(tensor_result, numpy_result)
2261                self.assertEqual(out, numpy_result)
2262
2263    @dtypes(
2264        *product(
2265            complex_types(),
2266            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
2267        )
2268    )
2269    def test_maximum_minimum_complex(self, device, dtypes):
2270        for torch_op in (
2271            torch.maximum,
2272            torch.minimum,
2273            torch.max,
2274            torch.min,
2275            torch.fmax,
2276            torch.fmin,
2277        ):
2278            with self.assertRaisesRegex(RuntimeError, ".+not implemented for.+"):
2279                torch_op(
2280                    torch.ones(1, device=device, dtype=dtypes[0]),
2281                    torch.ones(1, device=device, dtype=dtypes[1]),
2282                )
2283
2284            with self.assertRaisesRegex(RuntimeError, ".+not implemented for.+"):
2285                torch_op(
2286                    torch.ones(1, device=device, dtype=dtypes[1]),
2287                    torch.ones(1, device=device, dtype=dtypes[0]),
2288                )
2289
2290    @onlyCUDA
2291    def test_maximum_minimum_cross_device(self, device):
2292        a = torch.tensor((1, 2, -1))
2293        b = torch.tensor((3, 0, 4), device=device)
2294        ops = (torch.maximum, torch.minimum)
2295
2296        for torch_op in ops:
2297            with self.assertRaisesRegex(
2298                RuntimeError, "Expected all tensors to be on the same device"
2299            ):
2300                torch_op(a, b)
2301
2302            with self.assertRaisesRegex(
2303                RuntimeError, "Expected all tensors to be on the same device"
2304            ):
2305                torch_op(b, a)
2306
2307        # test cuda tensor and cpu scalar
2308        ops = ((torch.maximum, np.maximum), (torch.minimum, np.minimum))
2309        a_np = np.array(1)
2310        b_np = np.array([3, 0, 4])
2311
2312        for torch_op, numpy_op in ops:
2313            a_tensor = torch.from_numpy(a_np)
2314            b_tensor = torch.from_numpy(b_np).to(device=device)
2315            tensor_result_1 = torch_op(a_tensor, b_tensor)
2316            numpy_result_1 = numpy_op(a_np, b_np)
2317            tensor_result_2 = torch_op(b_tensor, a_tensor)
2318            numpy_result_2 = numpy_op(b_np, a_np)
2319
2320            self.assertEqual(tensor_result_1, numpy_result_1)
2321            self.assertEqual(tensor_result_2, numpy_result_2)
2322
2323    @dtypes(
2324        *product(
2325            floating_types_and(torch.half, torch.bfloat16),
2326            floating_types_and(torch.half, torch.bfloat16),
2327        )
2328    )
2329    def test_maximum_and_minimum_subgradient(self, device, dtypes):
2330        def run_test(f, a, b, expected_a_grad, expected_b_grad):
2331            a = torch.tensor(a, requires_grad=True, device=device, dtype=dtypes[0])
2332            b = torch.tensor(b, requires_grad=True, device=device, dtype=dtypes[1])
2333            z = f(a, b)
2334            z.sum().backward()
2335            self.assertEqual(a.grad, expected_a_grad)
2336            self.assertEqual(b.grad, expected_b_grad)
2337
2338        run_test(
2339            torch.maximum,
2340            [0.0, 1.0, 2.0],
2341            [1.0, 1.0, 1.0],
2342            [0.0, 0.5, 1.0],
2343            [1.0, 0.5, 0.0],
2344        )
2345        run_test(
2346            torch.minimum,
2347            [0.0, 1.0, 2.0],
2348            [1.0, 1.0, 1.0],
2349            [1.0, 0.5, 0.0],
2350            [0.0, 0.5, 1.0],
2351        )
2352
2353    def test_maximum_minimum_forward_ad_float32(self, device):
2354        # TODO: This should really be covered by OpInfo but it isn't. The problem
2355        # is that our gradient tests test using float64 but it should also test
2356        # float32
2357        x = torch.randn(3, device=device, dtype=torch.float32)
2358        y = torch.randn(3, device=device, dtype=torch.float32)
2359        tx = torch.randn(3, device=device, dtype=torch.float32)
2360        ty = torch.randn(3, device=device, dtype=torch.float32)
2361
2362        with fwAD.dual_level():
2363            x_dual = fwAD.make_dual(x, tx)
2364            y_dual = fwAD.make_dual(y, ty)
2365            result = torch.maximum(x_dual, y_dual)
2366            _, result_tangent = fwAD.unpack_dual(result)
2367
2368        expected = torch.where(x > y, tx, ty)
2369        self.assertEqual(result_tangent, expected)
2370
2371        with fwAD.dual_level():
2372            x_dual = fwAD.make_dual(x, tx)
2373            y_dual = fwAD.make_dual(y, ty)
2374            result = torch.minimum(x_dual, y_dual)
2375            _, result_tangent = fwAD.unpack_dual(result)
2376
2377        expected = torch.where(x < y, tx, ty)
2378        self.assertEqual(result_tangent, expected)
2379
2380    # TODO: tests like this should be generic
2381    @dtypesIfCUDA(torch.half, torch.float, torch.double)
2382    @dtypes(torch.float, torch.double)
2383    def test_mul_intertype_scalar(self, device, dtype):
2384        x = torch.tensor(1.5, dtype=dtype, device=device)
2385        y = torch.tensor(3, dtype=torch.int32, device=device)
2386
2387        self.assertEqual(x * y, 4.5)
2388        self.assertEqual(y * x, 4.5)
2389
2390        with self.assertRaisesRegex(
2391            RuntimeError, "can't be cast to the desired output type"
2392        ):
2393            y *= x
2394        x *= y
2395        self.assertEqual(x, 4.5)
2396
2397    @onlyCPU
2398    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
2399    def test_sub(self, device, dtype):
2400        if dtype in integral_types():
2401            # Before Python 3.10, floats were implicitly converted to ints, but with
2402            #   DeprecationWarning: an integer is required (got type float).
2403            #   Implicit conversion to integers using __int__ is deprecated,
2404            #   and may be removed in a future version of Python.
2405            # Since Python 3.10, that attempt gives an error.
2406            m1 = torch.tensor([2, 4], dtype=dtype, device=device)
2407            m2 = torch.tensor([1, 2], dtype=dtype, device=device)
2408            diff = torch.tensor([1, 2], dtype=dtype)
2409        else:
2410            m1 = torch.tensor([2.34, 4.44], dtype=dtype, device=device)
2411            m2 = torch.tensor([1.23, 2.33], dtype=dtype, device=device)
2412            diff = torch.tensor([1.11, 2.11], dtype=dtype)
2413
2414        if dtype == torch.bool:
2415            self.assertRaises(RuntimeError, lambda: m1 - m2)
2416        elif dtype == torch.bfloat16 or dtype == torch.half:
2417            # bfloat16 has a lower precision so we have to have a separate check for it
2418            self.assertEqual(m1 - m2, diff, atol=0.01, rtol=0)
2419        else:
2420            self.assertEqual(m1 - m2, diff)
2421
2422    # TODO: what is this test testing?
2423    @onlyCPU
2424    @dtypes(torch.float)
2425    def test_csub(self, device, dtype):
2426        # with a tensor
2427        a = torch.randn(100, 90, dtype=dtype, device=device)
2428        b = a.clone().normal_()
2429
2430        res_add = torch.add(a, b, alpha=-1)
2431        res_csub = a.clone()
2432        res_csub.sub_(b)
2433        self.assertEqual(res_add, res_csub)
2434
2435        # with a scalar
2436        a = torch.randn(100, 100, dtype=dtype, device=device)
2437
2438        scalar = 123.5
2439        res_add = torch.add(a, -scalar)
2440        res_csub = a.clone()
2441        res_csub.sub_(scalar)
2442        self.assertEqual(res_add, res_csub)
2443
2444    # TODO: reconcile with minimum/maximum tests
2445    @dtypesIfCUDA(torch.half, torch.float, torch.double)
2446    @dtypes(torch.float, torch.double)
2447    def test_min_max_binary_op_nan(self, device, dtype):
2448        a = torch.rand(1000, dtype=dtype, device=device)
2449        b = torch.rand(1000, dtype=dtype, device=device)
2450
2451        # 0:250: a -- nan, b -- not nan
2452        a[:250] = float("nan")
2453        # 250:500: a -- not nan, b -- nan
2454        b[250:500] = float("nan")
2455        # 500:750: a and b both nan
2456        a[500:750] = float("nan")
2457        b[500:750] = float("nan")
2458        # 750:1000: neither nan
2459
2460        ma = torch.max(a, b)
2461        mi = torch.min(a, b)
2462
2463        for i in range(750):
2464            self.assertTrue(
2465                torch.isnan(ma[i]),
2466                f"max(a, b): {ma[i]}, a: {a[i]}, b: {b[i]}",
2467            )
2468            self.assertTrue(
2469                torch.isnan(mi[i]),
2470                f"min(a, b): {mi[i]}, a: {a[i]}, b: {b[i]}",
2471            )
2472
2473        for i in range(750, 1000):
2474            self.assertFalse(
2475                torch.isnan(ma[i]),
2476                f"max(a, b): {ma[i]}, a: {a[i]}, b: {b[i]}",
2477            )
2478            self.assertFalse(
2479                torch.isnan(mi[i]),
2480                f"min(a, b): {mi[i]}, a: {a[i]}, b: {b[i]}",
2481            )
2482
2483    @dtypes(
2484        *product(
2485            all_types_and(torch.half, torch.bfloat16, torch.bool),
2486            all_types_and(torch.half, torch.bfloat16, torch.bool),
2487        )
2488    )
2489    def test_copysign(self, device, dtypes):
2490        def _test_copysign_numpy(a, b):
2491            torch_result = torch.copysign(a, b)
2492
2493            if a.dtype == torch.bfloat16:
2494                np_a = a.to(torch.float).cpu().numpy()
2495            else:
2496                np_a = a.cpu().numpy()
2497
2498            if b.dtype == torch.bfloat16:
2499                np_b = b.to(torch.float).cpu().numpy()
2500            else:
2501                np_b = b.cpu().numpy()
2502            expected = torch.from_numpy(np.copysign(np_a, np_b))
2503            # To handle inconsistencies of type promotion between PyTorch and Numpy
2504            # Applied for both arguments having integral precision and bfloat16
2505            types = integral_types_and(torch.bool, torch.bfloat16)
2506            if a.dtype in types or b.dtype in types:
2507                promoted_type = torch.promote_types(torch_result.dtype, expected.dtype)
2508                torch_result = torch_result.to(promoted_type)
2509                expected = expected.to(promoted_type)
2510
2511            # Verify Value
2512            self.assertEqual(torch_result, expected)
2513            # Verify Sign
2514            # Use double copysign to verify the correctnes of 0.0 and -0.0, since
2515            # it always True for self.assertEqual(0.0 == -0.0). So, we use 1 as the
2516            # magnitude to verify the sign between torch and numpy results, elementwise.
2517            # Special case: NaN conversions between FP32 and FP16 is not bitwise
2518            # equivalent to pass this assertion.
2519            if a.dtype != torch.float16 and b.dtype != torch.float16:
2520                self.assertEqual(
2521                    torch.copysign(torch.tensor(1.0), torch_result),
2522                    torch.copysign(torch.tensor(1.0), expected),
2523                )
2524
2525        # Compare Result with NumPy
2526        # Type promotion
2527        a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
2528        b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
2529        _test_copysign_numpy(a, b)
2530
2531        # Broadcast
2532        a = make_tensor((10, 1, 10), device=device, dtype=dtypes[0], low=-9, high=9)
2533        b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
2534        _test_copysign_numpy(a, b)
2535
2536        a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
2537        b = make_tensor((10, 1, 10), device=device, dtype=dtypes[1], low=-9, high=9)
2538        _test_copysign_numpy(a, b)
2539
2540        # 0.0/-0.0/inf/-inf/nan
2541        cases = [0.0, -0.0, float("inf"), float("-inf"), float("nan")]
2542        # torch.bfloat16 can not hold '-nan'
2543        # torch.half can not hold '-nan' on CUDA
2544        types = [torch.float32, torch.float64]
2545        if device == "cpu":
2546            types.append(torch.float16)
2547        if dtypes[0] in types:
2548            b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
2549            for case in cases:
2550                _test_copysign_numpy(
2551                    torch.tensor([case], device=device, dtype=dtypes[0]), b
2552                )
2553
2554        if dtypes[1] in floating_types_and(torch.half, torch.bfloat16):
2555            a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
2556            for case in cases:
2557                _test_copysign_numpy(
2558                    a, torch.tensor([case], device=device, dtype=dtypes[1])
2559                )
2560
2561    @dtypes(
2562        *product(
2563            floating_types_and(torch.half, torch.bfloat16),
2564            floating_types_and(torch.half, torch.bfloat16),
2565        )
2566    )
2567    def test_copysign_subgradient(self, device, dtypes):
2568        # Input is 0.0
2569        x = torch.tensor(
2570            [0.0, 0.0, 0.0], dtype=dtypes[0], device=device, requires_grad=True
2571        )
2572        y = torch.tensor(
2573            [-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True
2574        )
2575        out = torch.copysign(x, y)
2576        out.sum().backward()
2577        self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
2578        self.assertEqual(y.grad.tolist(), [0.0] * 3)
2579
2580        # Input is -0.0
2581        x = torch.tensor(
2582            [-0.0, -0.0, -0.0], dtype=dtypes[0], device=device, requires_grad=True
2583        )
2584        y = torch.tensor(
2585            [-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True
2586        )
2587        out = torch.copysign(x, y)
2588        out.sum().backward()
2589        self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
2590        self.assertEqual(y.grad.tolist(), [0.0] * 3)
2591
2592        # Other is 0.0
2593        x = torch.tensor(
2594            [-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True
2595        )
2596        y = torch.tensor(
2597            [0.0, 0.0, 0.0], dtype=dtypes[1], device=device, requires_grad=True
2598        )
2599        out = torch.copysign(x, y)
2600        out.sum().backward()
2601        self.assertEqual(x.grad.tolist(), [-1.0, 0.0, 1.0])
2602        self.assertEqual(y.grad.tolist(), [0.0] * 3)
2603
2604        # Other is -0.0
2605        x = torch.tensor(
2606            [-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True
2607        )
2608        y = torch.tensor(
2609            [-0.0, -0.0, -0.0], dtype=dtypes[1], device=device, requires_grad=True
2610        )
2611        out = torch.copysign(x, y)
2612        out.sum().backward()
2613        self.assertEqual(x.grad.tolist(), [1.0, 0.0, -1.0])
2614        self.assertEqual(y.grad.tolist(), [0.0] * 3)
2615
2616    @dtypes(torch.bfloat16, torch.float)
2617    def test_div(self, device, dtype):
2618        for op, method, inplace in (
2619            (torch.div, torch.Tensor.div, torch.Tensor.div_),
2620            (torch.true_divide, torch.Tensor.true_divide, torch.Tensor.true_divide_),
2621        ):
2622            m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype)
2623            res1 = m1.clone()
2624            inplace(res1[:, 3], 2)
2625            res2 = m1.clone()
2626            for i in range(m1.size(0)):
2627                res2[i, 3] = res2[i, 3] / 2
2628            self.assertEqual(res1, res2)
2629
2630            if dtype == torch.bfloat16:
2631                a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
2632                a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device)
2633                self.assertEqual(
2634                    op(a1, a2),
2635                    torch.tensor([2.1, 3.1], dtype=dtype, device=device),
2636                    atol=0.01,
2637                    rtol=0,
2638                )
2639                self.assertEqual(method(a1, a2), op(a1, a2))
2640
2641    @dtypes(torch.bfloat16, torch.float)
2642    def test_true_divide_out(self, device, dtype):
2643        a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
2644        a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device)
2645        res = torch.empty_like(a1)
2646        self.assertEqual(
2647            torch.true_divide(a1, a2, out=res),
2648            torch.tensor([2.1, 3.1], dtype=dtype, device=device),
2649            atol=0.01,
2650            rtol=0,
2651        )
2652
2653    @dtypes(torch.half)
2654    def test_divmul_scalar(self, device, dtype):
2655        x = torch.tensor(100.0, device=device, dtype=dtype)
2656        x_ref = x.float()
2657        scale = 1e5
2658        res = x.div(scale)
2659        expected = x_ref.div(scale)
2660        self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0)
2661        x = torch.tensor(1e-5, device=device, dtype=dtype)
2662        x_ref = x.float()
2663        res = x.mul(scale)
2664        expected = x_ref.mul(scale)
2665        self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0)
2666        res = scale * x
2667        self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0)
2668
2669    @dtypesIfCUDA(
2670        *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128}
2671    )
2672    @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128})
2673    def test_floor_divide_tensor(self, device, dtype):
2674        x = torch.randn(10, device=device).mul(30).to(dtype)
2675        y = torch.arange(1, 11, dtype=dtype, device=device)
2676
2677        z = x // y
2678        z_alt = torch.floor(x.double() / y.double()).to(dtype)
2679
2680        self.assertEqual(z.dtype, x.dtype)
2681        self.assertEqual(z, z_alt)
2682
2683    @dtypesIfCUDA(
2684        *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128}
2685    )
2686    @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128})
2687    def test_floor_divide_scalar(self, device, dtype):
2688        x = torch.randn(100, device=device).mul(10).to(dtype)
2689
2690        z = x // 3
2691        z_alt = torch.tensor(
2692            [math.floor(v.item() / 3.0) for v in x], dtype=x.dtype, device=device
2693        )
2694
2695        self.assertEqual(z.dtype, x.dtype)
2696        self.assertEqual(z, z_alt)
2697
2698    @onlyCPU
2699    @dtypes(*get_all_math_dtypes("cpu"))
2700    def test_rdiv(self, device, dtype):
2701        if dtype is torch.float16:
2702            return
2703        elif dtype.is_complex:
2704            x = torch.rand(100, dtype=dtype, device=device).add(1).mul(4)
2705        else:
2706            x = torch.rand(100, device=device).add(1).mul(4).to(dtype)
2707        y = 30 / x
2708        z = torch.tensor([30 / v.item() for v in x], device=device)
2709        self.assertEqual(y, z, exact_dtype=False)
2710
2711    @dtypes(*floating_types_and(torch.half))
2712    def test_fmod_remainder_by_zero_float(self, device, dtype):
2713        fn_list = (torch.fmod, torch.remainder)
2714        for fn in fn_list:
2715            # check floating-point tensor fmod/remainder to zero is nan on both CPU and GPU
2716            x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
2717            zero = torch.zeros_like(x)
2718            self.assertTrue(torch.all(fn(x, 0.0).isnan()))
2719            self.assertTrue(torch.all(fn(x, zero).isnan()))
2720
2721    @onlyNativeDeviceTypes  # Check Issue https://github.com/pytorch/pytorch/issues/48130
2722    @dtypes(*integral_types())
2723    def test_fmod_remainder_by_zero_integral(self, device, dtype):
2724        fn_list = (torch.fmod, torch.remainder)
2725        for fn in fn_list:
2726            # check integral tensor fmod/remainder to zero
2727            x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
2728            zero = torch.zeros_like(x)
2729            # RuntimeError on CPU
2730            if self.device_type == "cpu":
2731                with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"):
2732                    fn(x, zero)
2733            elif torch.version.hip is not None:
2734                # ROCm behavior: x % 0 is a no-op; x is returned
2735                self.assertEqual(fn(x, zero), x)
2736            else:
2737                # CUDA behavior: Different value for different dtype
2738                # Due to it's an undefined behavior, CUDA returns a pattern of all 1s
2739                # for integral dividend (other than int64) divided by zero. For int64,
2740                # CUDA returns all 1s for negative dividend, half 1s for positive dividend.
2741                # uint8: 0xff -> 255
2742                # int32: 0xffffffff -> -1
2743                if dtype == torch.int64:
2744                    self.assertEqual(fn(x, zero) == 4294967295, x >= 0)
2745                    self.assertEqual(fn(x, zero) == -1, x < 0)
2746                else:
2747                    value = 255 if dtype == torch.uint8 else -1
2748                    self.assertTrue(torch.all(fn(x, zero) == value))
2749
2750    @dtypes(*all_types_and(torch.half))
2751    def test_fmod_remainder(self, device, dtype):
2752        # Use numpy as reference
2753        def _helper(x, mod, fns_list):
2754            for fn, inplace_fn, ref_fn in fns_list:
2755                np_x = x.cpu().numpy() if torch.is_tensor(x) else x
2756                np_mod = mod.cpu().numpy() if torch.is_tensor(mod) else mod
2757                exp = ref_fn(np_x, np_mod)
2758                exp = torch.from_numpy(exp)
2759                res = fn(x, mod)
2760
2761                self.assertEqual(res, exp, exact_dtype=False)
2762
2763                if torch.is_tensor(x):
2764                    # out
2765                    out = torch.empty(0, device=device, dtype=res.dtype)
2766                    fn(x, mod, out=out)
2767                    self.assertEqual(out, exp, exact_dtype=False)
2768                    self.assertEqual(out.size(), torch.Size([10, 10]))
2769                    # in-place (Type cast runtime error)
2770                    try:
2771                        inplace_fn(x, mod)
2772                        self.assertEqual(x, exp, exact_dtype=False)
2773                    except RuntimeError as e:
2774                        self.assertRegex(
2775                            str(e),
2776                            "result type (Half|Float|Double) "
2777                            "can't be cast to the desired output "
2778                            "type (Byte|Char|Short|Int|Long)",
2779                        )
2780
2781        x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
2782        # mod with same dtype as x
2783        mod = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
2784        # Exclude 0
2785        mod[mod == 0] = 1
2786
2787        # Mods: Integer, Float, Tensor, Non-contiguous Tensor
2788        mods = [3, 2.3, mod, mod.t()]
2789        # mod with floating-point dtype
2790        if dtype in integral_types():
2791            mod_float = make_tensor(
2792                (10, 10), device=device, dtype=torch.float, low=-9, high=9
2793            )
2794            mod[mod == 0] = 1
2795            mods.append(mod_float)
2796
2797        for dividend, mod in product([x, x.t()], mods):
2798            _helper(
2799                dividend,
2800                mod,
2801                (
2802                    (torch.fmod, torch.Tensor.fmod_, np.fmod),
2803                    (torch.remainder, torch.Tensor.remainder_, np.remainder),
2804                ),
2805            )
2806
2807        # Tests for torch.remainder(scalar, tensor)
2808        for dividend, mod in product([5, 3.14], mods):
2809            if torch.is_tensor(mod):
2810                _helper(
2811                    dividend,
2812                    mod,
2813                    ((torch.remainder, torch.Tensor.remainder_, np.remainder),),
2814                )
2815
2816    @dtypes(torch.float, torch.double)
2817    def test_remainder_fmod_large_dividend(self, device, dtype):
2818        alarge = 1e9
2819        pi = 3.14159265358979
2820        for avalue in [alarge, -alarge]:
2821            for bvalue in [pi, -pi]:
2822                a = torch.tensor([avalue], dtype=dtype, device=device)
2823                b = torch.tensor([bvalue], dtype=dtype, device=device)
2824                c = torch.remainder(a, b)
2825                d = torch.fmod(a, b)
2826                self.assertTrue(
2827                    (b[0] > 0) == (c[0] > 0)
2828                )  # remainder has same sign as divisor
2829                self.assertTrue(
2830                    (a[0] > 0) == (d[0] > 0)
2831                )  # fmod has same sign as dividend
2832                self.assertTrue(
2833                    abs(c[0]) < abs(b[0])
2834                )  # remainder is within range of divisor
2835                self.assertTrue(
2836                    abs(d[0]) < abs(b[0])
2837                )  # fmod is within range of divisor
2838                if (a[0] > 0) == (b[0] > 0):
2839                    self.assertTrue(c[0] == d[0])  # remainder is same as fmod
2840                else:
2841                    self.assertTrue(
2842                        abs(c[0] - d[0]) == abs(b[0])
2843                    )  # differ by one divisor
2844
2845    @dtypesIfCPU(torch.bfloat16, torch.half, torch.float32, torch.float64)
2846    @dtypes(torch.float32, torch.float64)
2847    def test_hypot(self, device, dtype):
2848        inputs = [
2849            (
2850                torch.randn(10, device=device).to(dtype),
2851                torch.randn(10, device=device).to(dtype),
2852            ),
2853            (
2854                torch.randn((3, 3, 3), device=device).to(dtype),
2855                torch.randn((3, 3, 3), device=device).to(dtype),
2856            ),
2857            (
2858                torch.randn((10, 1), device=device).to(dtype),
2859                torch.randn((10, 1), device=device).to(dtype).transpose(0, 1),
2860            ),
2861            (
2862                torch.randint(100, (10,), device=device, dtype=torch.long),
2863                torch.randn(10, device=device).to(dtype),
2864            ),
2865        ]
2866        for input in inputs:
2867            actual = torch.hypot(input[0], input[1])
2868            if dtype in [torch.bfloat16, torch.half]:
2869                expected = torch.sqrt(input[0] * input[0] + input[1] * input[1])
2870            else:
2871                expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy())
2872            self.assertEqual(actual, expected, exact_dtype=False)
2873
2874    @onlyNativeDeviceTypes
2875    @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
2876    def test_gcd(self, device, dtype):
2877        # Tests gcd(0, 0), gcd(0, a) cases
2878        t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device)
2879        t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device)
2880        actual = torch.gcd(t1, t2)
2881        expected = np.gcd([0, 10, 0], [0, 0, 10])
2882        self.assertEqual(actual, expected, exact_dtype=False)
2883
2884        if dtype == torch.uint8:
2885            # Test unsigned integers with potential sign issues (i.e., uint8 with value >= 128)
2886            a = torch.tensor([190, 210], device=device, dtype=dtype)
2887            b = torch.tensor([190, 220], device=device, dtype=dtype)
2888            actual = torch.gcd(a, b)
2889            expected = torch.tensor([190, 10], device=device, dtype=dtype)
2890            self.assertEqual(actual, expected)
2891        else:
2892            # Compares with NumPy
2893            a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
2894            b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
2895            actual = torch.gcd(a, b)
2896            expected = np.gcd(a.cpu().numpy(), b.cpu().numpy())
2897            self.assertEqual(actual, expected)
2898
2899    @onlyNativeDeviceTypes
2900    @dtypes(torch.int16, torch.int32, torch.int64)
2901    def test_lcm(self, device, dtype):
2902        # Tests lcm(0, 0), lcm(0, a) cases
2903        t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device)
2904        t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device)
2905        actual = torch.lcm(t1, t2)
2906        expected = np.lcm([0, 10, 0], [0, 0, 10])
2907        self.assertEqual(actual, expected, exact_dtype=False)
2908
2909        # Compares with NumPy
2910        a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
2911        b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
2912        actual = torch.lcm(a, b)
2913        expected = np.lcm(a.cpu().numpy(), b.cpu().numpy())
2914        self.assertEqual(actual, expected, exact_dtype=False)
2915
2916    @onlyNativeDeviceTypes
2917    @dtypesIfCPU(torch.float32, torch.float64, torch.float16)
2918    @dtypes(torch.float32, torch.float64)
2919    def test_nextafter(self, device, dtype):
2920        # Test special cases
2921        t1 = torch.tensor([0, 0, 10], device=device, dtype=dtype)
2922        t2 = torch.tensor([inf, -inf, 10], device=device, dtype=dtype)
2923        actual = torch.nextafter(t1, t2)
2924        expected = np.nextafter(t1.cpu().numpy(), t2.cpu().numpy())
2925        self.assertEqual(actual, expected, atol=0, rtol=0)
2926
2927        actual = torch.nextafter(t2, t1)
2928        expected = np.nextafter(t2.cpu().numpy(), t1.cpu().numpy())
2929        self.assertEqual(actual, expected, atol=0, rtol=0)
2930
2931        t1 = torch.tensor([0, nan], device=device, dtype=dtype)
2932        t2 = torch.tensor([nan, 0], device=device, dtype=dtype)
2933        self.assertTrue(torch.nextafter(t1, t2).isnan().all())
2934
2935        a = torch.randn(100, device=device, dtype=dtype)
2936        b = torch.randn(100, device=device, dtype=dtype)
2937        actual = torch.nextafter(a, b)
2938        expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy())
2939        self.assertEqual(actual, expected, atol=0, rtol=0)
2940
2941    @onlyNativeDeviceTypes
2942    @dtypes(torch.bfloat16)
2943    def test_nextafter_bfloat16(self, device, dtype):
2944        nan = float("nan")
2945        inf = float("inf")
2946        cases = (
2947            # (from, to, expected)
2948            (0, 1, 9.183549615799121e-41),
2949            (0, -1, -9.183549615799121e-41),
2950            (1, -2, 0.99609375),
2951            (1, 0, 0.99609375),
2952            (1, 2, 1.0078125),
2953            (-1, -2, -1.0078125),
2954            (-1, 0, -0.99609375),
2955            (2, -1, 1.9921875),
2956            (2, 1, 1.9921875),
2957            (20, 3000, 20.125),
2958            (20, -3000, 19.875),
2959            (3000, -20, 2992.0),
2960            (-3000, 20, -2992.0),
2961            (65536, 0, 65280.0),
2962            (65536, inf, 66048.0),
2963            (-65536, 0, -65280.0),
2964            (-65536, -inf, -66048.0),
2965            (nan, 0, nan),
2966            (0, nan, nan),
2967            (nan, nan, nan),
2968            (nan, inf, nan),
2969            (inf, nan, nan),
2970            (inf, -inf, 3.3895313892515355e38),
2971            (-inf, inf, -3.3895313892515355e38),
2972            (inf, 0, 3.3895313892515355e38),
2973            (0, inf, 9.183549615799121e-41),
2974            (-inf, 0, -3.3895313892515355e38),
2975            (0, -inf, -9.183549615799121e-41),
2976        )
2977
2978        for from_v, to_v, expected in cases:
2979            from_t = torch.tensor([from_v], device=device, dtype=dtype)
2980            to_t = torch.tensor([to_v], device=device, dtype=dtype)
2981            actual = torch.nextafter(from_t, to_t).item()
2982            self.assertEqual(actual, expected, atol=0, rtol=0)
2983
2984    def _test_cop(self, torchfn, mathfn, dtype, device):
2985        def reference_implementation(res2):
2986            for i, j in iter_indices(sm1):
2987                idx1d = i * sm1.size(0) + j
2988                res2[i, j] = mathfn(sm1[i, j], sm2[idx1d])
2989            return res2
2990
2991        # contiguous
2992        m1 = torch.randn(10, 10, 10, dtype=dtype, device=device)
2993        m2 = torch.randn(10, 10 * 10, dtype=dtype, device=device)
2994        sm1 = m1[4]
2995        sm2 = m2[4]
2996
2997        res1 = torchfn(sm1, sm2.view(10, 10))
2998        res2 = reference_implementation(res1.clone())
2999        self.assertEqual(res1, res2)
3000
3001        # non-contiguous
3002        m1 = torch.randn(10, 10, 10, dtype=dtype, device=device)
3003        m2 = torch.randn(10 * 10, 10 * 10, dtype=dtype, device=device)
3004        sm1 = m1[:, 4]
3005        sm2 = m2[:, 4]
3006        # view as sm1.size()
3007        sm2.set_(
3008            sm2.storage(),
3009            sm2.storage_offset(),
3010            sm1.size(),
3011            (sm2.stride()[0] * 10, sm2.stride()[0]),
3012        )
3013        res1 = torchfn(sm1, sm2)
3014        # reference_implementation assumes 1-d sm2
3015        sm2.set_(
3016            sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride()
3017        )
3018        res2 = reference_implementation(res1.clone())
3019        self.assertEqual(res1, res2)
3020
3021    @onlyCPU
3022    @dtypes(torch.float)
3023    def test_cdiv(self, device, dtype):
3024        self._test_cop(torch.div, operator.truediv, dtype, device)
3025
3026    @onlyCPU
3027    @dtypes(torch.float)
3028    def test_cremainder(self, device, dtype):
3029        self._test_cop(torch.remainder, operator.mod, dtype, device)
3030
3031    @onlyCPU
3032    @dtypes(torch.float)
3033    def test_cmul(self, device, dtype):
3034        self._test_cop(torch.mul, operator.mul, dtype, device)
3035
3036    @onlyCPU
3037    @dtypes(torch.float)
3038    def test_cpow(self, device, dtype):
3039        self._test_cop(
3040            torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y), dtype, device
3041        )
3042
3043    @onlyCPU
3044    @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
3045    def test_floor_divide_zero(self, device, dtype):
3046        a = torch.tensor([0, 1], dtype=dtype, device=device)
3047        b = torch.tensor([0, 1], dtype=dtype, device=device)
3048        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"):
3049            with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
3050                a // b
3051
3052    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
3053    def test_muldiv_scalar(self, device, dtype):
3054        x = make_tensor((10, 3), dtype=dtype, device=device, low=None, high=None)
3055        s = make_tensor((1,), dtype=dtype, device="cpu", low=None, high=None).item()
3056        y = torch.full_like(x, s)
3057        self.assertEqual(x * s, x * y)
3058        self.assertEqual(s * x, y * x)
3059        self.assertEqual(x / s, x / y)
3060        self.assertEqual(s / x, y / x)
3061
3062    # TODO: update make_tensor to support extremal additions and remove this in favor of make_tensor
3063    def _generate_input(self, shape, dtype, device, with_extremal):
3064        if shape == ():
3065            x = torch.tensor((), dtype=dtype, device=device)
3066        else:
3067            if dtype.is_floating_point or dtype.is_complex:
3068                # work around torch.randn not being implemented for bfloat16
3069                if dtype == torch.bfloat16:
3070                    x = torch.randn(*shape, device=device) * random.randint(30, 100)
3071                    x = x.to(torch.bfloat16)
3072                else:
3073                    x = torch.randn(
3074                        *shape, dtype=dtype, device=device
3075                    ) * random.randint(30, 100)
3076                x[torch.randn(*shape) > 0.5] = 0
3077                if with_extremal and dtype.is_floating_point:
3078                    # Use extremal values
3079                    x[torch.randn(*shape) > 0.5] = float("nan")
3080                    x[torch.randn(*shape) > 0.5] = float("inf")
3081                    x[torch.randn(*shape) > 0.5] = float("-inf")
3082                elif with_extremal and dtype.is_complex:
3083                    x[torch.randn(*shape) > 0.5] = complex("nan")
3084                    x[torch.randn(*shape) > 0.5] = complex("inf")
3085                    x[torch.randn(*shape) > 0.5] = complex("-inf")
3086            elif dtype == torch.bool:
3087                x = torch.zeros(shape, dtype=dtype, device=device)
3088                x[torch.randn(*shape) > 0.5] = True
3089            else:
3090                x = torch.randint(15, 100, shape, dtype=dtype, device=device)
3091
3092        return x
3093
3094    @dtypes(
3095        *tuple(
3096            itertools.combinations_with_replacement(
3097                all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 2
3098            )
3099        )
3100    )
3101    def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes):
3102        # issue #42660
3103        # testing all combinations of broadcasting and type promotion
3104        # with a range of dtypes and input shapes, and with extremal values
3105        def compare_with_numpy_bin_op(torch_fn, np_fn, x, y, out=None):
3106            # working around the fact that numpy doesn't support bfloat16
3107            # by letting numpy treat them as float32's
3108            x_np = x if x.dtype != torch.bfloat16 else x.to(torch.float32)
3109            y_np = (
3110                y.cpu().numpy()
3111                if y.dtype != torch.bfloat16
3112                else y.to(torch.float32).cpu().numpy()
3113            )
3114            self.compare_with_numpy(
3115                lambda inp: torch_fn(inp, y, out=out) if out else torch_fn(inp, y),
3116                lambda inp: np_fn(inp, y_np, out=out) if out else np_fn(inp, y_np),
3117                x_np,
3118            )
3119
3120        complex_op_denylist = [
3121            torch.lt,
3122            torch.le,
3123            torch.gt,
3124            torch.ge,
3125        ]  # complex not supported
3126        input_sizes = [(1,), (10,), (10, 1), (1, 10), (4, 10), (64, 10), (12, 3)]
3127        op_pairs = [
3128            (torch.lt, np.less),
3129            (torch.le, np.less_equal),
3130            (torch.gt, np.greater),
3131            (torch.ge, np.greater_equal),
3132            (torch.eq, np.equal),
3133            (torch.ne, np.not_equal),
3134            (torch.logical_and, np.logical_and),
3135            (torch.logical_or, np.logical_or),
3136            (torch.logical_xor, np.logical_xor),
3137        ]
3138
3139        for size1 in input_sizes:
3140            size2 = (2,) + size1  # perform broadcasting
3141            for with_extremal in [False, True]:
3142                a = self._generate_input(size1, dtypes[0], device, with_extremal)
3143                b = self._generate_input(size2, dtypes[1], device, with_extremal)
3144                for torch_op, numpy_op in op_pairs:
3145                    if (
3146                        dtypes[0].is_complex or dtypes[1].is_complex
3147                    ) and torch_op in complex_op_denylist:
3148                        continue
3149                    # functional version of op
3150                    compare_with_numpy_bin_op(torch_op, numpy_op, a, b)
3151
3152                    # functional comparison ops always return bool tensors
3153                    self.assertEqual(torch_op(a, b).dtype, torch.bool)
3154
3155                    # out version of op
3156                    out = torch.zeros(
3157                        1, dtype=torch.complex128
3158                    )  # all casts to complex128 are safe
3159                    compare_with_numpy_bin_op(torch_op, numpy_op, a, b, out=out)
3160
3161    @onlyNativeDeviceTypes
3162    @dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
3163    def test_signed_shift(self, device, dtype):
3164        "Ensure that signed integer bit shifting works as expected."
3165        a = torch.tensor([-10, 10], device=device, dtype=dtype)  # [11...1110110, 1010]
3166        expected_l = torch.tensor(
3167            [-40, 40], device=device, dtype=dtype
3168        )  # [11...11011000, 101000]
3169        self.assertEqual(a << 2, expected_l)
3170        self.compare_with_numpy(lambda x: x << 2, lambda x: np.left_shift(x, 2), a)
3171        expected_r = torch.tensor(
3172            [-5, 5], device=device, dtype=dtype
3173        )  # [1111...111011, 101]
3174        self.assertEqual(a >> 1, expected_r)
3175        self.compare_with_numpy(lambda x: x >> 1, lambda x: np.right_shift(x, 1), a)
3176
3177    @onlyNativeDeviceTypes
3178    @dtypes(*get_all_int_dtypes())
3179    def test_shift_limits(self, device, dtype):
3180        "Ensure that integer bit shifting works as expected with out-of-limits shift values."
3181        # Issue #70904
3182        iinfo = torch.iinfo(dtype)
3183        bits = iinfo.bits
3184        low = iinfo.min
3185        high = iinfo.max
3186        exact_dtype = (
3187            dtype != torch.uint8
3188        )  # numpy changes dtype from uint8 to int16 for some out-of-limits shift values
3189        for input in (
3190            torch.tensor(
3191                [-1, 0, 1], device=device, dtype=dtype
3192            ),  # small for non-vectorized operation
3193            torch.tensor(
3194                [low, high], device=device, dtype=dtype
3195            ),  # small for non-vectorized operation
3196            make_tensor(
3197                (64, 64, 64), low=low, high=high, device=device, dtype=dtype
3198            ),  # large for vectorized operation
3199        ):
3200            shift_left_expected = torch.zeros_like(input)
3201            shift_right_expected = torch.clamp(input, -1, 0)
3202            for shift in chain(range(-100, -1), range(bits, 100)):
3203                shift_left = input << shift
3204                self.assertEqual(shift_left, shift_left_expected, msg=f"<< {shift}")
3205                self.compare_with_numpy(
3206                    lambda x: x << shift,
3207                    lambda x: np.left_shift(x, shift),
3208                    input,
3209                    exact_dtype=exact_dtype,
3210                    msg=f"<< {shift}",
3211                )
3212                shift_right = input >> shift
3213                self.assertEqual(shift_right, shift_right_expected, msg=f">> {shift}")
3214                self.compare_with_numpy(
3215                    lambda x: x >> shift,
3216                    lambda x: np.right_shift(x, shift),
3217                    input,
3218                    exact_dtype=exact_dtype,
3219                    msg=f">> {shift}",
3220                )
3221
3222    @onlyNativeDeviceTypes
3223    @dtypes(
3224        *list(
3225            product(
3226                all_types_and(torch.half, torch.bfloat16, torch.bool),
3227                all_types_and(torch.half, torch.bfloat16, torch.bool),
3228            )
3229        )
3230    )
3231    def test_heaviside(self, device, dtypes):
3232        input_dtype = dtypes[0]
3233        values_dtype = dtypes[1]
3234
3235        rng = np.random.default_rng()
3236        input = np.array(
3237            rng.integers(-10, 10, size=10),
3238            dtype=torch_to_numpy_dtype_dict[
3239                input_dtype if (input_dtype != torch.bfloat16) else torch.float64
3240            ],
3241        )
3242        input[0] = input[3] = input[7] = 0
3243        values = np.array(
3244            rng.integers(-10, 10, size=10),
3245            dtype=torch_to_numpy_dtype_dict[
3246                values_dtype if (values_dtype != torch.bfloat16) else torch.float64
3247            ],
3248        )
3249        np_result = torch.from_numpy(np.heaviside(input, values)).to(
3250            device=device, dtype=input_dtype
3251        )
3252
3253        input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
3254        values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
3255        out = torch.empty_like(input)
3256
3257        if input_dtype == values_dtype:
3258            torch_result = torch.heaviside(input, values)
3259            self.assertEqual(np_result, torch_result)
3260
3261            torch_result = input.heaviside(values)
3262            self.assertEqual(np_result, torch_result)
3263
3264            torch.heaviside(input, values, out=out)
3265            self.assertEqual(np_result, out)
3266
3267            input.heaviside_(values)
3268            self.assertEqual(np_result, input)
3269        else:
3270            with self.assertRaisesRegex(
3271                RuntimeError,
3272                "heaviside is not yet implemented for tensors with different dtypes.",
3273            ):
3274                torch.heaviside(input, values)
3275            with self.assertRaisesRegex(
3276                RuntimeError,
3277                "heaviside is not yet implemented for tensors with different dtypes.",
3278            ):
3279                input.heaviside(values)
3280            with self.assertRaisesRegex(
3281                RuntimeError,
3282                "heaviside is not yet implemented for tensors with different dtypes.",
3283            ):
3284                torch.heaviside(input, values, out=out)
3285            with self.assertRaisesRegex(
3286                RuntimeError,
3287                "heaviside is not yet implemented for tensors with different dtypes.",
3288            ):
3289                input.heaviside_(values)
3290
3291    @onlyCUDA
3292    def test_heaviside_cross_device(self, device):
3293        x = torch.tensor([-9, 5, 0, 6, -2, 2], device=device)
3294        y = torch.tensor(0)
3295        result = torch.heaviside(x, y)
3296        expect = torch.tensor([0, 1, 0, 1, 0, 1], device=device)
3297        self.assertEqual(result, expect)
3298
3299        result = torch.heaviside(y, x)
3300        expect = torch.tensor([-9, 5, 0, 6, -2, 2], device=device)
3301        self.assertEqual(result, expect)
3302
3303        x = torch.tensor([-9, 5, 0, 6, -2, 2])
3304        y = torch.tensor(0, device=device)
3305        with self.assertRaisesRegex(
3306            RuntimeError, "Expected all tensors to be on the same device"
3307        ):
3308            torch.heaviside(x, y)
3309
3310        with self.assertRaisesRegex(
3311            RuntimeError, "Expected all tensors to be on the same device"
3312        ):
3313            torch.heaviside(y, x)
3314
3315    @dtypes(*list(product(complex_types(), complex_types())))
3316    def test_heaviside_complex(self, device, dtypes):
3317        input_dtype = dtypes[0]
3318        values_dtype = dtypes[1]
3319
3320        data = (complex(0, -6), complex(-1, 3), complex(1, 1))
3321        input = torch.tensor(data, device=device, dtype=input_dtype)
3322        values = torch.tensor(data, device=device, dtype=values_dtype)
3323        out = torch.empty_like(input)
3324        real = input.real
3325
3326        with self.assertRaisesRegex(
3327            RuntimeError, "heaviside is not yet implemented for complex tensors."
3328        ):
3329            torch.heaviside(input, real)
3330        with self.assertRaisesRegex(
3331            RuntimeError, "heaviside is not yet implemented for complex tensors."
3332        ):
3333            real.heaviside(values)
3334        with self.assertRaisesRegex(
3335            RuntimeError, "heaviside is not yet implemented for complex tensors."
3336        ):
3337            input.heaviside_(values)
3338        with self.assertRaisesRegex(
3339            RuntimeError, "heaviside is not yet implemented for complex tensors."
3340        ):
3341            torch.heaviside(real, real, out=out)
3342
3343    def _test_logical(self, device, dtypes, op, a_, b_, expected_res_):
3344        expected_res = torch.tensor(expected_res_, dtype=dtypes[0], device=device)
3345        a = torch.tensor(a_, dtype=dtypes[0], device=device)
3346        b = torch.tensor(b_, dtype=dtypes[1], device=device)
3347
3348        # new tensor
3349        self.assertEqual(expected_res.bool(), getattr(a, op)(b))
3350        # out
3351        c = torch.empty(0, dtype=torch.bool, device=device)
3352        getattr(torch, op)(a, b, out=c)
3353        self.assertEqual(expected_res.bool(), c)
3354
3355        getattr(a, op + "_")(b)
3356        self.assertEqual(expected_res, a)
3357
3358    @dtypes(
3359        *product(
3360            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3361            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3362        )
3363    )
3364    def test_logical_xor(self, device, dtypes):
3365        self._test_logical(
3366            device, dtypes, "logical_xor", [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1]
3367        )
3368
3369    @dtypes(
3370        *product(
3371            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3372            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3373        )
3374    )
3375    def test_logical_and(self, device, dtypes):
3376        self._test_logical(
3377            device, dtypes, "logical_and", [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0]
3378        )
3379
3380    @dtypes(
3381        *product(
3382            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3383            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
3384        )
3385    )
3386    def test_logical_or(self, device, dtypes):
3387        self._test_logical(
3388            device, dtypes, "logical_or", [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1]
3389        )
3390
3391    def test_remainder_overflow(self, device):
3392        # Check Integer Overflows
3393        x = torch.tensor(23500, dtype=torch.int64, device=device)
3394        q = 392486996410368
3395        self.assertEqual(x % q, x)
3396        self.assertEqual(-x % q, q - x)
3397        self.assertEqual(x % -q, x - q)
3398        self.assertEqual(-x % -q, -x)
3399
3400    def test_rpow(self, device):
3401        m = torch.randn(10, 10, device=device)
3402        self.assertEqual(torch.pow(2, m), 2**m)
3403
3404        # test with scalar
3405        m = torch.randn(1, device=device).squeeze()
3406        assert m.dim() == 0, "m is intentionally a scalar"
3407        self.assertEqual(torch.pow(2, m), 2**m)
3408
3409    def test_ldexp(self, device):
3410        # random values
3411        mantissas = torch.randn(64, device=device)
3412        exponents = torch.randint(-31, 31, (64,), device=device, dtype=torch.int32)
3413
3414        # basic test
3415        np_outcome = np.ldexp(mantissas.cpu().numpy(), exponents.cpu().numpy())
3416        pt_outcome_1 = torch.ldexp(mantissas, exponents)
3417        pt_outcome_2 = mantissas.ldexp(exponents)
3418        self.assertEqual(np_outcome, pt_outcome_1.cpu())
3419        self.assertEqual(np_outcome, pt_outcome_2.cpu())
3420        mantissas.ldexp_(exponents)
3421        self.assertEqual(np_outcome, mantissas.cpu())
3422
3423        # test bounds
3424        mantissas = torch.tensor(
3425            [float("inf"), float("-inf"), float("inf"), float("nan")], device=device
3426        )
3427        exponents = torch.randint(0, 31, (4,), device=device, dtype=torch.int32)
3428        np_outcome = np.ldexp(mantissas.cpu().numpy(), exponents.cpu().numpy())
3429        pt_outcome = torch.ldexp(mantissas, exponents)
3430        self.assertEqual(np_outcome, pt_outcome.cpu())
3431
3432        # test half dtype behavior
3433        mantissas = torch.randn(64, device=device, dtype=torch.half)
3434        exponents = torch.randint(-5, 5, (64,), device=device)
3435        self.assertEqual(torch.ldexp(mantissas, exponents).dtype, torch.half)
3436
3437        # test float64 computation
3438        mantissas = torch.tensor([1], dtype=torch.float64, device=device)
3439        exponents = torch.tensor([128], dtype=torch.int64, device=device)
3440        expected = torch.pow(
3441            torch.full((1,), 2, device=device, dtype=torch.float64), 128
3442        )
3443        self.assertEqual(torch.ldexp(mantissas, exponents), expected)
3444
3445    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3446    def test_lerp(self, device, dtype):
3447        start_end_weight_shapes = [(), (5,), (5, 5)]
3448        for shapes in product(
3449            start_end_weight_shapes, start_end_weight_shapes, start_end_weight_shapes
3450        ):
3451            start = torch.randn(shapes[0], device=device, dtype=dtype)
3452            end = torch.randn(shapes[1], device=device, dtype=dtype)
3453
3454            # Tensor weights
3455            weights = [
3456                torch.randn(shapes[2], device=device, dtype=dtype),
3457                random.random(),
3458            ]
3459            if dtype.is_complex:
3460                weights += [complex(0, 1), complex(0.4, 1.2)]
3461
3462            for weight in weights:
3463                actual = torch.lerp(start, end, weight)
3464                actual_method = start.lerp(end, weight)
3465                self.assertEqual(actual, actual_method)
3466                actual_out = torch.tensor(1.0, dtype=dtype, device=device)
3467                torch.lerp(start, end, weight, out=actual_out)
3468                self.assertEqual(actual, actual_out)
3469                expected = start + weight * (end - start)
3470                self.assertEqual(expected, actual)
3471
3472    @onlyCUDA
3473    @dtypes(torch.half, torch.bfloat16)
3474    def test_lerp_lowp(self, device, dtype):
3475        xvals = (0.0, -30000.0)
3476        yvals = (0.1, -20000.0)
3477        xs = [torch.full((4,), xval, device=device, dtype=dtype) for xval in xvals]
3478        ys = [torch.full((4,), yval, device=device, dtype=dtype) for yval in yvals]
3479        weights = [70000, torch.full((4,), 8, device=device, dtype=dtype)]
3480        for x, y, w in zip(xs, ys, weights):
3481            xref = x.float()
3482            yref = y.float()
3483            wref = w.float() if isinstance(w, torch.Tensor) else w
3484            actual = torch.lerp(x, y, w)
3485            expected = torch.lerp(xref, yref, wref).to(dtype)
3486            self.assertEqual(actual, expected, atol=0.0, rtol=0.0)
3487
3488    @onlyCPU
3489    @dtypes(torch.half, torch.bfloat16)
3490    def test_lerp_lowp_cpu(self, device, dtype):
3491        xvals = (0.0, -30000.0)
3492        yvals = (0.1, -20000.0)
3493        for shape in [(4,), (20,), (3, 10, 10)]:
3494            xs = [torch.full(shape, xval, device=device, dtype=dtype) for xval in xvals]
3495            ys = [torch.full(shape, yval, device=device, dtype=dtype) for yval in yvals]
3496            weights = [70000, torch.full(shape, 8, device=device, dtype=dtype)]
3497            for x, y, w in zip(xs, ys, weights):
3498                xref = x.float()
3499                yref = y.float()
3500                wref = w.float() if isinstance(w, torch.Tensor) else w
3501                actual = torch.lerp(x, y, w)
3502                expected = torch.lerp(xref, yref, wref).to(dtype)
3503                self.assertEqual(actual, expected, atol=0.0, rtol=0.0)
3504
3505    def _test_logaddexp(self, device, dtype, base2):
3506        if base2:
3507            ref_func = np.logaddexp2
3508            our_func = torch.logaddexp2
3509        elif dtype in (torch.complex64, torch.complex128):
3510            # numpy has not implemented logaddexp for complex
3511            def _ref_func(x, y):
3512                return scipy.special.logsumexp(np.stack((x, y), axis=0), axis=0)
3513
3514            ref_func = _ref_func
3515            our_func = torch.logaddexp
3516        else:
3517            ref_func = np.logaddexp
3518            our_func = torch.logaddexp
3519
3520        def _test_helper(a, b):
3521            if dtype == torch.bfloat16:
3522                ref = ref_func(a.cpu().float().numpy(), b.cpu().float().numpy())
3523                v = our_func(a, b)
3524                self.assertEqual(ref, v.float(), atol=0.01, rtol=0.01)
3525            else:
3526                ref = ref_func(a.cpu().numpy(), b.cpu().numpy())
3527                v = our_func(a, b)
3528                self.assertEqual(ref, v)
3529
3530        # simple test
3531        a = torch.randn(64, 2, dtype=dtype, device=device) - 0.5
3532        b = torch.randn(64, 2, dtype=dtype, device=device) - 0.5
3533        _test_helper(a, b)
3534        _test_helper(a[:3], b[:3])
3535
3536        # large value test for numerical stability
3537        a *= 10000
3538        b *= 10000
3539        _test_helper(a, b)
3540        _test_helper(a[:3], b[:3])
3541
3542        a = torch.tensor(
3543            [float("inf"), float("-inf"), float("inf"), float("nan")],
3544            dtype=dtype,
3545            device=device,
3546        )
3547        b = torch.tensor(
3548            [float("inf"), float("-inf"), float("-inf"), float("nan")],
3549            dtype=dtype,
3550            device=device,
3551        )
3552        _test_helper(a, b)
3553
3554    @skipIfTorchDynamo()  # complex infs/nans differ under Dynamo/Inductor
3555    @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16)
3556    @dtypes(
3557        torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128
3558    )
3559    def test_logaddexp(self, device, dtype):
3560        self._test_logaddexp(device, dtype, base2=False)
3561
3562    @dtypes(torch.float32, torch.float64, torch.bfloat16)
3563    def test_logaddexp2(self, device, dtype):
3564        self._test_logaddexp(device, dtype, base2=True)
3565
3566    def test_add(self, device):
3567        dtypes = floating_and_complex_types()
3568        for dtype in dtypes:
3569            # [res] torch.add([res,] tensor1, tensor2)
3570            m1 = torch.randn(100, 100, dtype=dtype, device=device)
3571            v1 = torch.randn(100, dtype=dtype, device=device)
3572
3573            # contiguous
3574            res1 = torch.add(m1[4], v1)
3575            res2 = res1.clone().zero_()
3576            for i in range(m1.size(1)):
3577                res2[i] = m1[4, i] + v1[i]
3578            self.assertEqual(res1, res2)
3579
3580            m1 = torch.randn(100, 100, device=device)
3581            v1 = torch.randn(100, device=device)
3582
3583            # non-contiguous
3584            res1 = torch.add(m1[:, 4], v1)
3585            res2 = res1.clone().zero_()
3586            for i in range(m1.size(0)):
3587                res2[i] = m1[i, 4] + v1[i]
3588            self.assertEqual(res1, res2)
3589
3590            # [res] torch.add([res,] tensor, value)
3591            m1 = torch.randn(10, 10, device=device)
3592
3593            # contiguous
3594            res1 = m1.clone()
3595            res1[3].add_(2)
3596            res2 = m1.clone()
3597            for i in range(m1.size(1)):
3598                res2[3, i] = res2[3, i] + 2
3599            self.assertEqual(res1, res2)
3600
3601            # non-contiguous
3602            m1 = torch.randn(10, 10, device=device)
3603            res1 = m1.clone()
3604            res1[:, 3].add_(2)
3605            res2 = m1.clone()
3606            for i in range(m1.size(0)):
3607                res2[i, 3] = res2[i, 3] + 2
3608            self.assertEqual(res1, res2)
3609
3610            # inter-type
3611            m1 = torch.randn(10, 10, dtype=dtype, device=device)
3612            self.assertEqual(m1 + 3, m1 + torch.tensor(3))
3613            self.assertEqual(3 + m1, torch.tensor(3) + m1)
3614
3615            # contiguous + non-contiguous
3616            m1 = torch.randn(10, 10, dtype=dtype, device=device)
3617            m2 = torch.randn(10, 10, dtype=dtype, device=device).t()
3618            res = m1 + m2
3619            self.assertTrue(res.is_contiguous())
3620            self.assertEqual(res, m1 + m2.contiguous())
3621
3622            # 1d + empty
3623            m1 = torch.tensor([1.0], dtype=dtype, device=device)
3624            m2 = torch.tensor([], dtype=dtype, device=device)
3625            self.assertEqual(m1 + m2, [])
3626
3627        # inter-type unint8
3628        one = torch.tensor(1, dtype=torch.uint8, device=device)
3629        self.assertEqual(torch.add(one, 1), 2)
3630        self.assertEqual(torch.add(one, 1).dtype, torch.uint8)
3631
3632        # bool
3633        m1 = torch.tensor(
3634            [True, False, False, True, False, False], dtype=torch.bool, device=device
3635        )
3636        m2 = torch.tensor(
3637            [True, True, False, False, False, True], dtype=torch.bool, device=device
3638        )
3639        expected = torch.tensor(
3640            [True, True, False, True, False, True], dtype=torch.bool, device=device
3641        )
3642        self.assertEqual(m1 + m2, expected)
3643
3644        # fused multiply add
3645        a = torch.zeros(2, 3, dtype=torch.bool, device=device)
3646        res = torch.add(a, a, alpha=0)
3647        expected = torch.zeros(2, 3, device=device).bool()
3648        self.assertEqual(res, expected)
3649
3650        # bfloat16
3651        m1 = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
3652        m2 = torch.tensor([3.0, 4.0], dtype=torch.bfloat16)
3653        self.assertEqual(m1 + m2, torch.tensor([4.0, 6.0], dtype=torch.bfloat16))
3654
3655        # different alpha types
3656        m1 = torch.tensor([2 + 3j, 4 + 5j], dtype=torch.complex64, device=device)
3657        m2 = torch.tensor([4 + 5j, 2 + 3j], dtype=torch.complex64, device=device)
3658        # add complex numbers with float alpha
3659        res = torch.add(m1, m2, alpha=0.1)
3660        expected = torch.tensor(
3661            [2.4000 + 3.5000j, 4.2000 + 5.3000j], dtype=torch.complex64, device=device
3662        )
3663        self.assertEqual(res, expected)
3664
3665        # add complex numbers with complex alpha
3666        res = torch.add(m1, m2, alpha=complex(0.1, 0.2))
3667        expected = torch.tensor(
3668            [1.4000 + 4.3000j, 3.6000 + 5.7000j], dtype=torch.complex64, device=device
3669        )
3670        self.assertEqual(res, expected)
3671
3672        # add complex numbers with integer alpha
3673        res = torch.add(m1, m2, alpha=2)
3674        expected = torch.tensor(
3675            [10.0 + 13.0j, 8.0 + 11.0j], dtype=torch.complex64, device=device
3676        )
3677        self.assertEqual(res, expected)
3678
3679        # mismatched alpha
3680        m1 = torch.tensor([1], dtype=torch.int8, device=device)
3681        m2 = torch.tensor([2], dtype=torch.int8, device=device)
3682        self.assertRaisesRegex(
3683            RuntimeError,
3684            r"Boolean alpha only supported for Boolean results\.",
3685            lambda: torch.add(m1, m2, alpha=True),
3686        )
3687        self.assertRaisesRegex(
3688            RuntimeError,
3689            r"For integral input tensors, argument alpha must not be a floating point number\.",
3690            lambda: torch.add(m1, m2, alpha=1.0),
3691        )
3692
3693        # mismatched alpha, float / double tensor and complex alpha
3694        msg = r"For non-complex input tensors, argument alpha must not be a complex number\."
3695        m1 = torch.tensor([3.0, 4.0], device=device)
3696        m2 = torch.tensor([4.0, 3.0], device=device)
3697        self.assertRaisesRegex(
3698            RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))
3699        )
3700
3701        m1 = torch.tensor([3.0, 4.0], dtype=torch.double, device=device)
3702        m2 = torch.tensor([4.0, 3.0], dtype=torch.double, device=device)
3703        self.assertRaisesRegex(
3704            RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))
3705        )
3706
3707        # complex
3708        m1 = torch.tensor((4.0000 + 4.0000j), dtype=torch.complex64)
3709        m2 = torch.tensor(4.0, dtype=torch.float64)
3710        self.assertRaisesRegex(
3711            RuntimeError,
3712            r"result type ComplexFloat can't be cast to the desired output type Double",
3713            lambda: torch.add(m1, m1, out=m2),
3714        )
3715
3716    @onlyCUDA
3717    def test_addsub_half_tensor(self, device):
3718        x = torch.tensor([60000.0], dtype=torch.half, device=device)
3719        for op, y, alpha in (
3720            (torch.add, torch.tensor([-60000.0], dtype=torch.half, device=device), 2),
3721            (torch.sub, torch.tensor([60000.0], dtype=torch.half, device=device), 2),
3722            (torch.add, -70000.0, 1),
3723            (torch.sub, 70000.0, 1),
3724        ):
3725            actual = op(x, y, alpha=alpha)
3726            self.assertTrue(not (actual.isnan() or actual.isinf()))
3727
3728    def test_sub_typing(self, device):
3729        m1 = torch.tensor(
3730            [True, False, False, True, False, False], dtype=torch.bool, device=device
3731        )
3732        m2 = torch.tensor(
3733            [True, True, False, False, False, True], dtype=torch.bool, device=device
3734        )
3735        self.assertRaisesRegex(
3736            RuntimeError,
3737            r"Subtraction, the `\-` operator, with two bool tensors is not supported. "
3738            r"Use the `\^` or `logical_xor\(\)` operator instead.",
3739            lambda: m1 - m2,
3740        )
3741        self.assertRaisesRegex(
3742            RuntimeError,
3743            r"Subtraction, the `\-` operator, with a bool tensor is not supported. "
3744            r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.",
3745            lambda: 1 - m1,
3746        )
3747        self.assertRaisesRegex(
3748            RuntimeError,
3749            r"Subtraction, the `\-` operator, with a bool tensor is not supported. "
3750            r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.",
3751            lambda: m2 - 1,
3752        )
3753
3754        # mismatched alpha
3755        m1 = torch.tensor([1], dtype=torch.int8, device=device)
3756        m2 = torch.tensor([2], dtype=torch.int8, device=device)
3757        self.assertRaisesRegex(
3758            RuntimeError,
3759            r"Boolean alpha only supported for Boolean results\.",
3760            lambda: torch.sub(m1, m2, alpha=True),
3761        )
3762        self.assertRaisesRegex(
3763            RuntimeError,
3764            r"For integral input tensors, argument alpha must not be a floating point number\.",
3765            lambda: torch.sub(m1, m2, alpha=1.0),
3766        )
3767
3768    def test_mul(self, device):
3769        m1 = torch.randn(10, 10, device=device)
3770        res1 = m1.clone()
3771        res1[:, 3].mul_(2)
3772        res2 = m1.clone()
3773        for i in range(res1.size(0)):
3774            res2[i, 3] = res2[i, 3] * 2
3775        self.assertEqual(res1, res2)
3776
3777        a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device)
3778        a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device)
3779        self.assertEqual(
3780            a1 * a2,
3781            torch.tensor([True, False, False, False], dtype=torch.bool, device=device),
3782        )
3783
3784        if device == "cpu":
3785            a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device)
3786            a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device)
3787            self.assertEqual(
3788                a1 * a2,
3789                torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device),
3790                atol=0.01,
3791                rtol=0,
3792            )
3793            self.assertEqual(a1.mul(a2), a1 * a2)
3794
3795    def test_bool_tensor_comparison_ops(self, device):
3796        a = torch.tensor(
3797            [True, False, True, False, True, False], dtype=torch.bool, device=device
3798        )
3799        b = torch.tensor(
3800            [True, False, True, True, True, True], dtype=torch.bool, device=device
3801        )
3802        self.assertEqual(
3803            a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)
3804        )
3805        self.assertEqual(
3806            a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)
3807        )
3808        self.assertEqual(
3809            a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)
3810        )
3811        self.assertEqual(
3812            a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device)
3813        )
3814        self.assertEqual(
3815            a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)
3816        )
3817        self.assertEqual(
3818            a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device)
3819        )
3820        self.assertEqual(
3821            a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)
3822        )
3823        self.assertEqual(
3824            a == torch.tensor(True, dtype=torch.bool, device=device),
3825            torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device),
3826        )
3827        self.assertEqual(
3828            a == torch.tensor(0, dtype=torch.bool, device=device),
3829            torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device),
3830        )
3831        self.assertFalse(a.equal(b))
3832
3833    @dtypes(*all_types_and(torch.half, torch.bfloat16, torch.bool))
3834    def test_logical(self, device, dtype):
3835        if dtype != torch.bool:
3836            x = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype)
3837            b = torch.tensor([2], device=device, dtype=dtype)
3838            self.assertEqual(x.lt(2), torch.tensor([True, False, False, False]))
3839            self.assertEqual(x.le(2), torch.tensor([True, True, False, False]))
3840            self.assertEqual(x.ge(2), torch.tensor([False, True, True, True]))
3841            self.assertEqual(x.gt(2), torch.tensor([False, False, True, True]))
3842            self.assertEqual(x.eq(2), torch.tensor([False, True, False, False]))
3843            self.assertEqual(x.ne(2), torch.tensor([True, False, True, True]))
3844
3845            self.assertEqual(x.lt(b), torch.tensor([True, False, False, False]))
3846            self.assertEqual(x.le(b), torch.tensor([True, True, False, False]))
3847            self.assertEqual(x.ge(b), torch.tensor([False, True, True, True]))
3848            self.assertEqual(x.gt(b), torch.tensor([False, False, True, True]))
3849            self.assertEqual(x.eq(b), torch.tensor([False, True, False, False]))
3850            self.assertEqual(x.ne(b), torch.tensor([True, False, True, True]))
3851        else:
3852            x = torch.tensor([True, False, True, False], device=device)
3853            self.assertEqual(x.lt(True), torch.tensor([False, True, False, True]))
3854            self.assertEqual(x.le(True), torch.tensor([True, True, True, True]))
3855            self.assertEqual(x.ge(True), torch.tensor([True, False, True, False]))
3856            self.assertEqual(x.gt(True), torch.tensor([False, False, False, False]))
3857            self.assertEqual(x.eq(True), torch.tensor([True, False, True, False]))
3858            self.assertEqual(x.ne(True), torch.tensor([False, True, False, True]))
3859
3860    def test_atan2(self, device):
3861        def _test_atan2_with_size(size, device):
3862            a = torch.rand(size=size, device=device, dtype=torch.double)
3863            b = torch.rand(size=size, device=device, dtype=torch.double)
3864            actual = a.atan2(b)
3865            x = a.view(-1)
3866            y = b.view(-1)
3867            expected = torch.tensor(
3868                [math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())],
3869                device=device,
3870                dtype=torch.double,
3871            )
3872            self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02)
3873
3874            # bfloat16/float16
3875            for lowp_dtype in [torch.bfloat16, torch.float16]:
3876                if lowp_dtype == torch.bfloat16:
3877                    rtol = 0
3878                    atol = 0.02
3879                else:
3880                    rtol = 0
3881                    atol = 0.001
3882                a_16 = a.to(dtype=lowp_dtype)
3883                b_16 = b.to(dtype=lowp_dtype)
3884                actual_16 = a_16.atan2(b_16)
3885                self.assertEqual(actual_16, actual.to(dtype=lowp_dtype))
3886                self.assertEqual(
3887                    expected,
3888                    actual_16.view(-1),
3889                    exact_dtype=False,
3890                    rtol=rtol,
3891                    atol=atol,
3892                )
3893
3894        _test_atan2_with_size((2, 2), device)
3895        _test_atan2_with_size((3, 3), device)
3896        _test_atan2_with_size((5, 5), device)
3897
3898    def test_atan2_edgecases(self, device):
3899        def _test_atan2(x, y, expected, device, dtype):
3900            expected_tensor = torch.tensor([expected], dtype=dtype, device=device)
3901            x_tensor = torch.tensor([x], dtype=dtype, device=device)
3902            y_tensor = torch.tensor([y], dtype=dtype, device=device)
3903            actual = torch.atan2(y_tensor, x_tensor)
3904            self.assertEqual(expected_tensor, actual, rtol=0, atol=0.02)
3905
3906        for dtype in [torch.float, torch.double]:
3907            _test_atan2(0, 0, 0, device, dtype)
3908            _test_atan2(0, 1, math.pi / 2, device, dtype)
3909            _test_atan2(0, -1, math.pi / -2, device, dtype)
3910            _test_atan2(-1, 0, math.pi, device, dtype)
3911            _test_atan2(1, 0, 0, device, dtype)
3912            _test_atan2(-1, -1, math.pi * -3 / 4, device, dtype)
3913            _test_atan2(1, 1, math.pi / 4, device, dtype)
3914            _test_atan2(1, -1, math.pi / -4, device, dtype)
3915            _test_atan2(-1, 1, math.pi * 3 / 4, device, dtype)
3916
3917    def test_trapezoid(self, device):
3918        def test_dx(sizes, dim, dx, device):
3919            t = torch.randn(sizes, device=device)
3920            actual = torch.trapezoid(t, dx=dx, dim=dim)
3921            expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim)  # noqa: NPY201
3922            self.assertEqual(expected.shape, actual.shape)
3923            self.assertEqual(expected, actual, exact_dtype=False)
3924
3925        def test_x(sizes, dim, x, device):
3926            t = torch.randn(sizes, device=device)
3927            actual = torch.trapezoid(t, x=torch.tensor(x, device=device), dim=dim)
3928            expected = np.trapz(t.cpu().numpy(), x=x, axis=dim)  # noqa: NPY201
3929            self.assertEqual(expected.shape, actual.shape)
3930            self.assertEqual(expected, actual.cpu(), exact_dtype=False)
3931
3932        test_dx((2, 3, 4), 1, 1, device)
3933        test_dx((10, 2), 0, 0.1, device)
3934        test_dx((1, 10), 0, 2.3, device)
3935        test_dx((0, 2), 0, 1.0, device)
3936        test_dx((0, 2), 1, 1.0, device)
3937        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
3938        test_x(
3939            (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device
3940        )
3941        test_x((1, 10), 0, [1.0], device)
3942        test_x((0, 2), 0, [], device)
3943        test_x((0, 2), 1, [1.0, 2.0], device)
3944        test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device)
3945        test_x((2, 3, 4), 0, [1.0, 2.0], device)
3946        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
3947        test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device)
3948        test_x((2, 2, 4), -1, [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], device)
3949        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
3950            test_x((2, 3), 2, [], device)
3951            test_dx((2, 3), 2, 1.0, device)
3952        with self.assertRaisesRegex(
3953            RuntimeError, "There must be one `x` value for each sample point"
3954        ):
3955            test_x((2, 3), 1, [1.0, 2.0], device)
3956            test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)
3957
3958    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
3959    def test_cumulative_trapezoid(self, device):
3960        import scipy.integrate
3961
3962        if hasattr(scipy.integrate, "cumulative_trapezoid"):
3963            _scipy_cumulative_trapezoid = scipy.integrate.cumulative_trapezoid
3964        else:  # Older version of SciPy uses a different name
3965            _scipy_cumulative_trapezoid = scipy.integrate.cumtrapz
3966
3967        def scipy_cumulative_trapezoid(y, x=None, dx=1.0, axis=-1, initial=None):
3968            if y.shape[axis] == 0:
3969                return np.empty_like(y)
3970            else:
3971                return _scipy_cumulative_trapezoid(y, x, dx, axis, initial)
3972
3973        def test_dx(sizes, dim, dx, device):
3974            t = torch.randn(sizes, device=device)
3975            y = t.cpu().numpy()
3976            actual = torch.cumulative_trapezoid(t, dx=dx, dim=dim)
3977            expected = scipy_cumulative_trapezoid(t.cpu().numpy(), dx=dx, axis=dim)
3978            self.assertEqual(expected.shape, actual.shape)
3979            self.assertEqual(expected, actual, exact_dtype=False, atol=1e-4, rtol=1e-4)
3980
3981        def test_x(sizes, dim, x, device):
3982            t = torch.randn(sizes, device=device)
3983            actual = torch.cumulative_trapezoid(
3984                t, x=torch.tensor(x, device=device), dim=dim
3985            )
3986            expected = scipy_cumulative_trapezoid(t.cpu().numpy(), x=x, axis=dim)
3987            self.assertEqual(expected.shape, actual.shape)
3988            self.assertEqual(
3989                expected, actual.cpu(), exact_dtype=False, atol=1e-4, rtol=1e-4
3990            )
3991
3992        def test_empty_x(sizes, dim, x, device):
3993            t = torch.randn(sizes, device=device)
3994            actual = torch.cumulative_trapezoid(
3995                t, x=torch.tensor(x, device=device), dim=dim
3996            )
3997            self.assertEqual(torch.empty(actual.shape), actual)
3998
3999        test_dx((2,), -1, 1, device)
4000        test_dx((3, 3), -1, 1, device)
4001        test_dx((4, 2), 0, 1, device)
4002        test_dx((2, 3, 4), 1, 1, device)
4003        test_dx((10, 2), 0, 0.1, device)
4004        test_dx((1, 10), 0, 2.3, device)
4005        test_dx((0, 2), 0, 1.0, device)
4006        test_dx((0, 2), 1, 1.0, device)
4007        test_dx((512, 512), 1, 1.0, device)
4008        test_dx((100, 100, 100), 1, 1.0, device)
4009
4010        test_x((2,), -1, [100, 50], device)
4011        test_x((4, 2), 0, [2, 3, 4, 5], device)
4012        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
4013        test_x(
4014            (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device
4015        )
4016        test_x((1, 10), 0, [1.0], device)
4017        test_x((0, 2), 1, [1, 2], device)
4018        test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device)
4019        test_x((2, 3, 4), 0, [1.0, 2.0], device)
4020        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
4021        test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device)
4022
4023        test_empty_x(
4024            (0, 2), 0, [], device
4025        )  # SciPy failing when x == [], but our version returns empty
4026
4027        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
4028            test_x((2, 3), 2, [], device)
4029            test_dx((2, 3), 2, 1.0, device)
4030        with self.assertRaisesRegex(
4031            RuntimeError, "There must be one `x` value for each sample point"
4032        ):
4033            test_x((2, 3), 1, [1.0, 2.0], device)
4034            test_x((0, 2), 0, [1.0, 2.0], device)
4035            test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)
4036        with self.assertRaisesRegex(
4037            RuntimeError, "Currently, we only support dx as a real number"
4038        ):
4039            test_dx((2, 2), -1, complex(1, 1), device)
4040        with self.assertRaisesRegex(
4041            TypeError, "received an invalid combination of arguments"
4042        ):
4043            actual = torch.cumulative_trapezoid(
4044                torch.randn((3, 3)), x=torch.randn((3, 3)), dx=3
4045            )
4046
4047    @skipMeta
4048    @dtypes(torch.double)
4049    def test_pow_scalar_overloads_mem_overlap(self, device, dtype):
4050        sz = 3
4051        doubles = torch.randn(2 * sz, dtype=dtype, device=device)
4052        self.check_internal_mem_overlap(lambda t: t.pow_(42), 1, dtype, device)
4053        self.unary_check_input_output_mem_overlap(
4054            doubles, sz, lambda input, out: torch.pow(input, 42, out=out)
4055        )
4056        self.unary_check_input_output_mem_overlap(
4057            doubles, sz, lambda input, out: torch.pow(42, input, out=out)
4058        )
4059
4060    @dtypes(
4061        *list(
4062            product(
4063                all_types_and_complex_and(torch.half, torch.bfloat16),
4064                all_types_and_complex_and(torch.half, torch.bfloat16),
4065            )
4066        )
4067    )
4068    def test_float_power(self, device, dtypes):
4069        def to_np(value):
4070            if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
4071                return value.to(torch.float).cpu().numpy()
4072            return value.cpu().numpy() if isinstance(value, torch.Tensor) else value
4073
4074        base_dtype = dtypes[0]
4075        exp_dtype = dtypes[1]
4076        out_dtype = (
4077            torch.complex128
4078            if base_dtype.is_complex or exp_dtype.is_complex
4079            else torch.float64
4080        )
4081
4082        base = make_tensor((30,), dtype=base_dtype, device=device, low=1, high=100)
4083        # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0
4084        # Related: https://github.com/pytorch/pytorch/issues/48000
4085        # base[0] = base[3] = base[7] = 0
4086        exp = make_tensor((30,), dtype=exp_dtype, device=device, low=-2, high=2)
4087        exp[0] = exp[4] = exp[6] = 0
4088
4089        expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))
4090
4091        exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
4092        complex_exponents = exponents + [
4093            -2.5j,
4094            -1.0j,
4095            1.0j,
4096            2.5j,
4097            1.0 + 1.0j,
4098            -1.0 - 1.5j,
4099            3.3j,
4100        ]
4101
4102        for op in (
4103            torch.float_power,
4104            torch.Tensor.float_power,
4105            torch.Tensor.float_power_,
4106        ):
4107            # Case of Tensor x Tensor
4108            if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
4109                with self.assertRaisesRegex(
4110                    RuntimeError, "operation's result requires dtype"
4111                ):
4112                    op(base.clone(), exp)
4113            else:
4114                result = op(base.clone(), exp)
4115                self.assertEqual(expected, result)
4116
4117            if op is torch.float_power:
4118                out = torch.empty_like(base).to(device=device, dtype=out_dtype)
4119                op(base, exp, out=out)
4120                self.assertEqual(expected, out)
4121
4122            # Case of Tensor x Scalar
4123            for i in complex_exponents if exp_dtype.is_complex else exponents:
4124                out_dtype_scalar_exp = (
4125                    torch.complex128
4126                    if base_dtype.is_complex or type(i) == complex
4127                    else torch.float64
4128                )
4129                expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))
4130
4131                if (
4132                    op is torch.Tensor.float_power_
4133                    and base_dtype != out_dtype_scalar_exp
4134                ):
4135                    with self.assertRaisesRegex(
4136                        RuntimeError, "operation's result requires dtype"
4137                    ):
4138                        op(base.clone(), i)
4139                else:
4140                    result = op(base.clone(), i)
4141                    self.assertEqual(expected_scalar_exp, result)
4142
4143                if op is torch.float_power:
4144                    out = torch.empty_like(base).to(
4145                        device=device, dtype=out_dtype_scalar_exp
4146                    )
4147                    op(base, i, out=out)
4148                    self.assertEqual(expected_scalar_exp, out)
4149
4150        # Case of Scalar x Tensor
4151        for i in complex_exponents if base_dtype.is_complex else exponents:
4152            out_dtype_scalar_base = (
4153                torch.complex128
4154                if exp_dtype.is_complex or type(i) == complex
4155                else torch.float64
4156            )
4157            expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))
4158
4159            result = torch.float_power(i, exp)
4160            self.assertEqual(expected_scalar_base, result)
4161
4162            out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
4163            torch.float_power(i, exp, out=out)
4164            self.assertEqual(expected_scalar_base, out)
4165
4166    def test_float_power_exceptions(self, device):
4167        def _promo_helper(x, y):
4168            for i in (x, y):
4169                if type(i) == complex:
4170                    return torch.complex128
4171                elif type(i) == torch.Tensor and i.is_complex():
4172                    return torch.complex128
4173            return torch.double
4174
4175        test_cases = (
4176            (torch.tensor([-2, -1, 0, 1, 2], device=device), -0.25),
4177            (
4178                torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device),
4179                2.0,
4180            ),
4181        )
4182        for base, exp in test_cases:
4183            for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
4184                out = torch.empty(1, device=device, dtype=out_dtype)
4185                required_dtype = _promo_helper(base, exp)
4186
4187                if out.dtype == required_dtype:
4188                    torch.float_power(base, exp, out=out)
4189                else:
4190                    with self.assertRaisesRegex(
4191                        RuntimeError, "operation's result requires dtype"
4192                    ):
4193                        torch.float_power(base, exp, out=out)
4194
4195                if base.dtype == required_dtype:
4196                    torch.Tensor.float_power_(base.clone(), exp)
4197                else:
4198                    with self.assertRaisesRegex(
4199                        RuntimeError, "operation's result requires dtype"
4200                    ):
4201                        torch.Tensor.float_power_(base.clone(), exp)
4202
4203    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
4204    @dtypes(
4205        *product(
4206            all_types_and(torch.half, torch.bool), all_types_and(torch.half, torch.bool)
4207        )
4208    )
4209    def test_xlogy_xlog1py(self, device, dtypes):
4210        x_dtype, y_dtype = dtypes
4211
4212        def out_variant_helper(torch_fn, x, y):
4213            expected = torch_fn(x, y)
4214            out = torch.empty_like(expected)
4215            torch_fn(x, y, out=out)
4216            self.assertEqual(expected, out)
4217
4218        def xlogy_inplace_variant_helper(x, y):
4219            if x.dtype in integral_types_and(torch.bool):
4220                with self.assertRaisesRegex(
4221                    RuntimeError, "can't be cast to the desired output type"
4222                ):
4223                    x.clone().xlogy_(y)
4224            else:
4225                expected = torch.empty_like(x)
4226                torch.xlogy(x, y, out=expected)
4227                inplace_out = x.clone().xlogy_(y)
4228                self.assertEqual(expected, inplace_out)
4229
4230        def test_helper(torch_fn, reference_fn, inputs, scalar=None):
4231            x, y, z = inputs
4232            torch_fn_partial = partial(torch_fn, x)
4233            reference_fn_partial = partial(reference_fn, x.cpu().numpy())
4234            self.compare_with_numpy(
4235                torch_fn_partial, reference_fn_partial, x, exact_dtype=False
4236            )
4237            self.compare_with_numpy(
4238                torch_fn_partial, reference_fn_partial, y, exact_dtype=False
4239            )
4240            self.compare_with_numpy(
4241                torch_fn_partial, reference_fn_partial, z, exact_dtype=False
4242            )
4243
4244            val = scalar if scalar is not None else x
4245            out_variant_helper(torch_fn, val, x)
4246            out_variant_helper(torch_fn, val, y)
4247            out_variant_helper(torch_fn, val, z)
4248
4249        # Tensor-Tensor Test (tensor of same and different shape)
4250        x = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=0.5, high=1000)
4251        y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
4252        z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
4253
4254        x_1p = make_tensor(
4255            (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.5, high=1000
4256        )
4257        y_1p = make_tensor(
4258            (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000
4259        )
4260        z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000)
4261
4262        xlogy_fns = torch.xlogy, scipy.special.xlogy
4263        xlog1py_fns = torch.special.xlog1py, scipy.special.xlog1py
4264
4265        test_helper(*xlogy_fns, (x, y, z))
4266        xlogy_inplace_variant_helper(x, x)
4267        xlogy_inplace_variant_helper(x, y)
4268        xlogy_inplace_variant_helper(x, z)
4269        test_helper(*xlog1py_fns, (x_1p, y_1p, z_1p))
4270
4271        # Scalar-Tensor Test
4272        test_helper(*xlogy_fns, (x, y, z), 3.14)
4273        test_helper(*xlog1py_fns, (x_1p, y_1p, z_1p), 3.14)
4274
4275        # Special Values Tensor-Tensor
4276        t = torch.tensor(
4277            [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")],
4278            device=device,
4279        )
4280        zeros = torch.zeros(7, dtype=y_dtype, device=device)
4281
4282        def test_zeros_special_helper(torch_fn, reference_fn, scalar=False):
4283            zeros_t = 0 if scalar else zeros
4284            zeros_np = 0 if scalar else zeros.cpu().numpy()
4285            torch_fn_partial = partial(torch_fn, zeros_t)
4286            reference_fn_partial = partial(reference_fn, zeros_np)
4287            self.compare_with_numpy(
4288                torch_fn_partial, reference_fn_partial, t, exact_dtype=False
4289            )
4290            out_variant_helper(torch_fn, zeros_t, t)
4291
4292        test_zeros_special_helper(*xlogy_fns)
4293        xlogy_inplace_variant_helper(zeros, t)
4294        test_zeros_special_helper(*xlog1py_fns)
4295
4296        # Special Values Scalar-Tensor
4297        test_zeros_special_helper(*xlogy_fns, scalar=True)
4298        test_zeros_special_helper(*xlog1py_fns, scalar=True)
4299
4300    @dtypes(torch.float64)
4301    def test_xlogy_xlog1py_gradients(self, device, dtype):
4302        make_arg = partial(torch.tensor, dtype=dtype, device=device, requires_grad=True)
4303
4304        zeros = torch.zeros((2,), dtype=dtype, device=device)
4305
4306        x = make_arg([0.0, 0.0])
4307        y = make_arg([-1.5, 0.0])
4308        torch.special.xlogy(x, y).sum().backward()
4309        self.assertEqual(x.grad, zeros)
4310
4311        x = make_arg([0.0, 0.0])
4312        y = make_arg([-2.5, -1.0])
4313        torch.special.xlog1py(x, y).sum().backward()
4314        self.assertEqual(x.grad, zeros)
4315
4316    def test_xlogy_xlog1py_scalar_type_promotion(self, device):
4317        # Test that python numbers don't participate in type promotion at the same
4318        # priority level as 0-dim tensors
4319        t = torch.randn((), dtype=torch.float32, device=device)
4320
4321        self.assertEqual(t.dtype, torch.xlogy(t, 5).dtype)
4322        self.assertEqual(t.dtype, torch.xlogy(t, 5.0).dtype)
4323        self.assertEqual(t.dtype, torch.special.xlog1py(t, 5).dtype)
4324        self.assertEqual(t.dtype, torch.special.xlog1py(t, 5.0).dtype)
4325
4326        self.assertEqual(t.dtype, torch.xlogy(5, t).dtype)
4327        self.assertEqual(t.dtype, torch.xlogy(5.0, t).dtype)
4328        self.assertEqual(t.dtype, torch.special.xlog1py(5, t).dtype)
4329        self.assertEqual(t.dtype, torch.special.xlog1py(5.0, t).dtype)
4330
4331    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
4332    def test_xlogy_xlog1py_bfloat16(self, device):
4333        def _compare_helper(x, y, torch_fn, reference_fn):
4334            x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy()
4335            y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy()
4336            expected = torch.from_numpy(reference_fn(x_np, y_np))
4337            actual = torch_fn(x, y)
4338            self.assertEqual(expected, actual, exact_dtype=False)
4339
4340        x_dtype, y_dtype = torch.bfloat16, torch.bfloat16
4341
4342        # Tensor-Tensor Test (tensor of same and different shape)
4343        x = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=0.5, high=1000)
4344        y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
4345        z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
4346
4347        x_1p = make_tensor(
4348            (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.8, high=1000
4349        )
4350        y_1p = make_tensor(
4351            (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000
4352        )
4353        z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000)
4354
4355        xlogy_fns = torch.xlogy, scipy.special.xlogy
4356        xlog1py_fns = torch.special.xlog1py, scipy.special.xlog1py
4357
4358        _compare_helper(x, x, *xlogy_fns)
4359        _compare_helper(x, y, *xlogy_fns)
4360        _compare_helper(x, z, *xlogy_fns)
4361        _compare_helper(x, 3.14, *xlogy_fns)
4362        _compare_helper(y, 3.14, *xlogy_fns)
4363        _compare_helper(z, 3.14, *xlogy_fns)
4364
4365        _compare_helper(x_1p, x_1p, *xlog1py_fns)
4366        _compare_helper(x_1p, y_1p, *xlog1py_fns)
4367        _compare_helper(x_1p, z_1p, *xlog1py_fns)
4368        _compare_helper(x_1p, 3.14, *xlog1py_fns)
4369        _compare_helper(y_1p, 3.14, *xlog1py_fns)
4370        _compare_helper(z_1p, 3.14, *xlog1py_fns)
4371
4372        # Special Values Tensor-Tensor
4373        t = torch.tensor(
4374            [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")],
4375            device=device,
4376        )
4377        zeros = torch.tensor(7, dtype=y_dtype, device=device)
4378
4379        _compare_helper(t, zeros, *xlogy_fns)
4380        _compare_helper(t, 0.0, *xlogy_fns)
4381
4382        _compare_helper(t, zeros, *xlog1py_fns)
4383        _compare_helper(t, 0.0, *xlog1py_fns)
4384
4385    @dtypes(*product(all_types_and(torch.bool), all_types_and(torch.bool)))
4386    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
4387    @slowTest
4388    def test_zeta(self, device, dtypes):
4389        x_dtype, q_dtype = dtypes
4390
4391        def test_helper(x, q):
4392            x_np = x if isinstance(x, float) else x.cpu().numpy()
4393            q_np = q if isinstance(q, float) else q.cpu().numpy()
4394            expected = torch.from_numpy(scipy.special.zeta(x_np, q_np))
4395            actual = torch.special.zeta(x, q)
4396
4397            rtol, atol = None, None
4398            if self.device_type == "cpu":
4399                rtol, atol = 1e-6, 1e-6
4400            self.assertEqual(expected, actual, rtol=rtol, atol=atol, exact_dtype=False)
4401
4402        # x tensor - q tensor same size
4403        x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
4404        q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
4405        test_helper(x, q)
4406
4407        # x tensor - q tensor broadcast lhs
4408        x = make_tensor((2, 1, 4), dtype=x_dtype, device=device)
4409        q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
4410        test_helper(x, q)
4411
4412        # x tensor - q tensor broadcast rhs
4413        x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
4414        q = make_tensor((2, 1, 4), dtype=q_dtype, device=device)
4415        test_helper(x, q)
4416
4417        # x tensor - q tensor broadcast all
4418        x = make_tensor((2, 3, 1), dtype=x_dtype, device=device)
4419        q = make_tensor((2, 1, 4), dtype=q_dtype, device=device)
4420        test_helper(x, q)
4421
4422        # x scalar - q tensor
4423        for x in np.linspace(-5, 5, num=10).tolist():
4424            if not q_dtype.is_floating_point:
4425                q_dtype = torch.get_default_dtype()
4426            q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
4427            test_helper(x, q)
4428
4429        # x tensor - q scalar
4430        for q in np.linspace(-5, 5, num=10).tolist():
4431            if not x_dtype.is_floating_point:
4432                x_dtype = torch.get_default_dtype()
4433            x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
4434            test_helper(x, q)
4435
4436    @onlyCUDA
4437    @dtypes(torch.chalf)
4438    def test_mul_chalf_tensor_and_cpu_scalar(self, device, dtype):
4439        # Tests that Tensor and CPU Scalar work for `mul` for chalf.
4440        # Ideally, this should be covered by `test_complex_half_reference_testing`
4441        # from test_ops.py by checking reference_samples from the OpInfo.
4442        # But currently that doesn't work as sample generation requires support of
4443        # `index_select` which is not implemented for `complex32` at the
4444        # time of writing this test.
4445        # TODO: Remove this test once above issue is fixed.
4446        # Ref: https://github.com/pytorch/pytorch/pull/76364
4447        x = make_tensor((2, 2), device=device, dtype=dtype)
4448        self.assertEqual(x * 2.5, x * torch.tensor(2.5, device=device, dtype=dtype))
4449
4450
4451tensor_binary_ops = [
4452    "__lt__",
4453    "__le__",
4454    "__gt__",
4455    "__ge__",
4456    "__eq__",
4457    "__ne__",
4458    "__add__",
4459    "__radd__",
4460    "__iadd__",
4461    "__sub__",
4462    "__rsub__",
4463    "__isub__",
4464    "__mul__",
4465    "__rmul__",
4466    "__imul__",
4467    "__matmul__",
4468    "__rmatmul__",
4469    "__truediv__",
4470    "__rtruediv__",
4471    "__itruediv__",
4472    "__floordiv__",
4473    "__rfloordiv__",
4474    "__ifloordiv__",
4475    "__mod__",
4476    "__rmod__",
4477    "__imod__",
4478    "__pow__",
4479    "__rpow__",
4480    "__ipow__",
4481    "__lshift__",
4482    "__rlshift__",
4483    "__ilshift__",
4484    "__rshift__",
4485    "__rrshift__",
4486    "__irshift__",
4487    "__and__",
4488    "__rand__",
4489    "__iand__",
4490    "__xor__",
4491    "__rxor__",
4492    "__ixor__",
4493    "__or__",
4494    "__ror__",
4495    "__ior__",
4496    # Unsupported operators
4497    # '__imatmul__',
4498    # '__divmod__', '__rdivmod__', '__idivmod__',
4499]
4500
4501
4502# Test that binary math operations return NotImplemented for unknown types.
4503def generate_not_implemented_tests(cls):
4504    class UnknownType:
4505        pass
4506
4507    # TODO: refactor to inline these
4508    _types = [
4509        torch.half,
4510        torch.float,
4511        torch.double,
4512        torch.int8,
4513        torch.short,
4514        torch.int,
4515        torch.long,
4516        torch.uint8,
4517    ]
4518
4519    def create_test_func(op):
4520        @dtypes(*_types)
4521        def test(self, device, dtype):
4522            # Generate the inputs
4523            tensor = torch.empty((), device=device, dtype=dtype)
4524
4525            # Runs the tensor op on the device
4526            result = getattr(tensor, op)(UnknownType())
4527            self.assertEqual(result, NotImplemented)
4528
4529        return test
4530
4531    for op in tensor_binary_ops:
4532        test_name = f"test_{op}_not_implemented"
4533        assert not hasattr(cls, test_name), f"{test_name} already in {cls.__name__}"
4534
4535        setattr(cls, test_name, create_test_func(op))
4536
4537
4538generate_not_implemented_tests(TestBinaryUfuncs)
4539instantiate_device_type_tests(TestBinaryUfuncs, globals())
4540
4541if __name__ == "__main__":
4542    run_tests()
4543