# Owner(s): ["module: tests"]

import torch
import numpy as np

import math
from numbers import Number
import random
import unittest

from torch import inf, nan
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    torch_to_numpy_dtype_dict,
    numpy_to_torch_dtype_dict,
    suppress_warnings,
    TEST_SCIPY,
    slowTest,
    skipIfNoSciPy,
    IS_WINDOWS,
    gradcheck,
    is_iterable_of_tensors,
    xfailIfTorchDynamo,
)
from torch.testing._internal.common_methods_invocations import (
    unary_ufuncs,
    generate_elementwise_unary_tensors,
    generate_elementwise_unary_small_value_tensors,
    generate_elementwise_unary_large_value_tensors,
    generate_elementwise_unary_extremal_value_tensors,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
    dtypes,
    onlyCPU,
    onlyNativeDeviceTypes,
    onlyCUDA,
    dtypesIfCUDA,
    precisionOverride,
    dtypesIfCPU,
)
from torch.utils import _pytree as pytree

from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    floating_types_and,
    all_types_and_complex_and,
    integral_types_and,
    get_all_math_dtypes,
    complex_types,
    floating_and_complex_types_and,
)

if TEST_SCIPY:
    import scipy

# Refer [scipy reference filter]
# Filter operators for which the reference function
# is available in the current environment (for reference_numerics tests).
reference_filtered_ops = list(filter(lambda op: op.ref is not None, unary_ufuncs))

# Tests for unary "universal functions (ufuncs)" that accept a single
# tensor and have common properties like:
#   - they are elementwise functions
#   - the input shape is the output shape
#   - they typically have method and inplace variants
#   - they typically support the out kwarg
#   - they typically have NumPy or SciPy references

# See NumPy's universal function documentation
# (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
# about the concept of ufuncs.

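# Illustrative sketch (exposition only, not an additional test): torch.exp is
# a typical unary ufunc, so it has the variants listed above and applies
# elementwise, preserving the input shape:
#
#   >>> x = torch.tensor([0.0, 1.0])
#   >>> torch.exp(x)                       # function variant
#   >>> x.exp()                            # method variant
#   >>> x.exp_()                           # inplace variant (mutates x)
#   >>> torch.exp(x, out=torch.empty(2))   # out= variant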

# TODO: port test_unary_out_op_mem_overlap
# TODO: add test for inplace variants erroring on broadcasted inputs
class TestUnaryUfuncs(TestCase):
    exact_dtype = True

    @ops(
        [_fn for _fn in unary_ufuncs if _fn.domain != (None, None)],
        allowed_dtypes=floating_types_and(torch.bfloat16, torch.half),
    )
    def test_float_domains(self, device, dtype, op):
        eps = (1e-5, 1e-3, 1e-1, 1, 2, 10, 20, 50, 100)

        low, high = op.domain
        # NOTE: the following two loops are separated for readability
        if low is not None:
            low_tensor = torch.tensor(low, device=device, dtype=dtype)
            for epsilon in eps:
                lower_tensor = low_tensor - epsilon

                # Skips the test if the difference is not representable,
                #   which can occur if, for example, the difference is small
                #   and the dtype is imprecise (like bfloat16 is)
                if lower_tensor.item() == low_tensor.item():
                    continue

                result = op(lower_tensor)
                self.assertEqual(
                    result.item(),
                    float("nan"),
                    msg=(
                        f"input of {lower_tensor.item()} outside lower domain boundary"
                        f" {low} produced {result.item()}, not nan!"
                    ),
                )

        if high is not None:
            high_tensor = torch.tensor(high, device=device, dtype=dtype)
            for epsilon in eps:
                higher_tensor = high_tensor + epsilon

                # See above comment
                if higher_tensor.item() == high_tensor.item():
                    continue

                result = op(higher_tensor)
                self.assertEqual(
                    result.item(),
                    float("nan"),
                    msg=(
                        f"input of {higher_tensor.item()} outside upper domain boundary"
                        f" {high} produced {result.item()}, not nan!"
                    ),
                )

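    # Illustrative sketch of the domain behavior checked above (exposition
    # only): ops with a restricted real domain produce nan outside it, e.g.
    #
    #   >>> torch.log(torch.tensor(-1.0))   # -1 is below log's lower boundary of 0
    #   tensor(nan)
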
    # Helper for comparing torch tensors and numpy arrays
    # TODO: should this or assertEqual also validate that strides are equal?
    def assertEqualHelper(
        self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs
    ):
        assert isinstance(actual, torch.Tensor)

        # Some NumPy functions return scalars, not arrays
        if isinstance(expected, Number):
            self.assertEqual(actual.item(), expected, msg, **kwargs)
        elif isinstance(expected, np.ndarray):
            # Handles exact dtype comparisons between arrays and tensors
            if exact_dtype:
                if (
                    actual.dtype is torch.bfloat16
                    or expected.dtype != torch_to_numpy_dtype_dict[actual.dtype]
                ):
                    # Allows array dtype to be float32 when comparing with bfloat16 tensors
                    #   since NumPy doesn't support the bfloat16 dtype
                    # Also ops like scipy.special.erf, scipy.special.erfc, etc, promote float16
                    # to float32
                    if expected.dtype == np.float32:
                        assert actual.dtype in (
                            torch.float16,
                            torch.bfloat16,
                            torch.float32,
                        )
                    elif expected.dtype == np.float64:
                        assert actual.dtype in (
                            torch.float16,
                            torch.bfloat16,
                            torch.float32,
                            torch.float64,
                        )
                    else:
                        self.fail(
                            f"Expected dtype {expected.dtype} but got {actual.dtype}!"
                        )

            self.assertEqual(
                actual,
                torch.from_numpy(expected).to(actual.dtype),
                msg,
                exact_device=False,
                **kwargs
            )
        else:
            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)

    # Tests that the function and its (array-accepting) reference produce the same
    #   values on given tensors
    def _test_reference_numerics(self, dtype, op, tensors, equal_nan=True):
        def _helper_reference_numerics(
            expected, actual, msg, exact_dtype, equal_nan=True
        ):
            if not torch.can_cast(
                numpy_to_torch_dtype_dict[expected.dtype.type], dtype
            ):
                exact_dtype = False

            if dtype in [torch.uint8, torch.int8, torch.bool]:
                # NOTE: For these dtypes, PyTorch computes in the default scalar type (float)
                # while NumPy computes in float16
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    exact_dtype=exact_dtype,
                    rtol=1e-3,
                    atol=1e-2,
                )
            elif dtype is torch.bfloat16:
                # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    exact_dtype=exact_dtype,
                    rtol=16e-3,
                    atol=1e-5,
                )
            elif dtype is torch.half:
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    exact_dtype=exact_dtype,
                    rtol=1.2e-03,
                    atol=1e-03,
                )
            else:
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    equal_nan=equal_nan,
                    exact_dtype=exact_dtype,
                )

        for t in tensors:
            t = t.input
            torch_kwargs, numpy_kwargs = op.sample_kwargs(t.device, dtype, t)
            if dtype is torch.bfloat16:
                a = t.cpu().to(torch.float32).numpy()
            elif dtype is torch.complex32:
                a = t.cpu().to(torch.complex64).numpy()
            else:
                a = t.cpu().numpy()

            actual = op(t, **torch_kwargs)
            expected = op.ref(a, **numpy_kwargs)

            # Crafts a custom error message for smaller, printable tensors
            if t.numel() < 10:
                msg = (
                    "Failed to produce expected results! Input tensor was"
                    f" {t}, torch result is {actual}, and reference result is"
                    f" {expected}."
                )
            else:
                msg = None

            exact_dtype = True
            if isinstance(actual, torch.Tensor):
                _helper_reference_numerics(
                    expected, actual, msg, exact_dtype, equal_nan
                )
            else:
                for x, y in zip(expected, actual):
                    # testing multi-outputs results
                    _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan)

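    # Minimal sketch of the torch-vs-reference comparison made above, using
    # torch.sin and np.sin as a stand-in for an op and its op.ref:
    #
    #   >>> t = torch.linspace(-3, 3, 10)
    #   >>> torch.testing.assert_close(torch.sin(t), torch.from_numpy(np.sin(t.numpy())))
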
    # Tests that the function and its (array-accepting) reference produce the same
    #   values on a range of tensors, including empty tensors, scalar tensors,
    #   1D tensors and a large 2D tensor with interesting and extremal values
    #   and noncontiguities.
    @suppress_warnings
    @ops(reference_filtered_ops)
    def test_reference_numerics_normal(self, device, dtype, op):
        tensors = generate_elementwise_unary_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    @suppress_warnings
    @ops(reference_filtered_ops)
    def test_reference_numerics_small(self, device, dtype, op):
        if dtype in (torch.bool,):
            raise self.skipTest("bool has no small values")

        tensors = generate_elementwise_unary_small_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    @suppress_warnings
    @ops(reference_filtered_ops)
    def test_reference_numerics_large(self, device, dtype, op):
        if dtype in (torch.bool, torch.uint8, torch.int8):
            raise self.skipTest("bool, uint8, and int8 dtypes have no large values")

        tensors = generate_elementwise_unary_large_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    @suppress_warnings
    @ops(
        reference_filtered_ops,
        allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
    )
    def test_reference_numerics_extremal(self, device, dtype, op):
        tensors = generate_elementwise_unary_extremal_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    # Tests for testing (non)contiguity consistency
    @ops(unary_ufuncs)
    def test_contig_vs_every_other(self, device, dtype, op):
        contig = make_tensor(
            (1026,), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1]
        )
        non_contig = contig[::2]

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, non_contig)
        expected = op(non_contig, **torch_kwargs)
        result = op(contig, **torch_kwargs)
        result = pytree.tree_map(lambda x: x[::2], result)
        self.assertEqual(result, expected)

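    # For reference, a sketch of the (non)contiguity compared above: every
    # other element of a 1D tensor is a strided view, not contiguous memory:
    #
    #   >>> t = torch.arange(6.0)
    #   >>> t.is_contiguous(), t[::2].is_contiguous()
    #   (True, False)
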
    @ops(unary_ufuncs)
    def test_contig_vs_transposed(self, device, dtype, op):
        contig = make_tensor(
            (789, 357), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1]
        )
        non_contig = contig.T

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        expected = op(non_contig, **torch_kwargs)
        result = op(contig, **torch_kwargs)
        result = pytree.tree_map(lambda x: x.T, result)
        self.assertEqual(result, expected)

    @ops(unary_ufuncs)
    def test_non_contig(self, device, dtype, op):
        shapes = [(5, 7), (1024,)]
        for shape in shapes:
            contig = make_tensor(
                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
            )
            non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0]
            non_contig.copy_(contig)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
            self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_non_contig_index(self, device, dtype, op):
        contig = make_tensor(
            (2, 2, 1, 2),
            dtype=dtype,
            device=device,
            low=op.domain[0],
            high=op.domain[1],
        )
        non_contig = contig[:, 1, ...]
        contig = non_contig.contiguous()

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_non_contig_expand(self, device, dtype, op):
        shapes = [(1, 3), (1, 7), (5, 7)]
        for shape in shapes:
            contig = make_tensor(
                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
            )
            non_contig = contig.clone().expand(3, -1, -1)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
            contig = op(contig, **torch_kwargs)
            non_contig = op(non_contig, **torch_kwargs)
            for i in range(3):
                non_contig_i = pytree.tree_map(lambda x: x[i], non_contig)
                self.assertEqual(
                    contig, non_contig_i, msg="non-contiguous expand[" + str(i) + "]"
                )

    @ops(unary_ufuncs)
    def test_contig_size1(self, device, dtype, op):
        contig = make_tensor(
            (5, 100), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
        )
        contig = contig[:1, :50]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_contig_size1_large_dim(self, device, dtype, op):
        contig = make_tensor(
            (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4),
            dtype=dtype,
            device=device,
            low=op.domain[0],
            high=op.domain[1],
        )
        contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs))

    # Tests that computation on multiple batches is the same as
    # per-batch computation.
    @ops(unary_ufuncs)
    def test_batch_vs_slicing(self, device, dtype, op):
        input = make_tensor(
            (1024, 512), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
        )

        torch_kwargs, _ = op.sample_kwargs(device, dtype, input)
        actual = op(input, **torch_kwargs)

        all_outs = [op(slice, **torch_kwargs) for slice in input]
        if is_iterable_of_tensors(actual):
            expected = [torch.stack([out[i] for out in all_outs]) for i in range(len(actual))]
        else:
            expected = torch.stack(all_outs)

        self.assertEqual(actual, expected)

    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
    def test_nan_to_num(self, device, dtype):
        for contiguous in [False, True]:
            x = make_tensor((64, 64), low=0.0, high=100.0, dtype=dtype, device=device)

            if dtype.is_floating_point:
                # Add extremal values.
                extremals = [float("nan"), float("inf"), -float("inf")]
                for idx, extremal in zip(torch.randint(0, 63, (3,)), extremals):
                    x[idx, :] = extremal

            if not contiguous:
                x = x.T

            # With args
            nan = random.random()
            posinf = random.random() * 5
            neginf = random.random() * 10

            self.compare_with_numpy(
                lambda x: x.nan_to_num(nan=nan, posinf=posinf),
                lambda x: np.nan_to_num(x, nan=nan, posinf=posinf),
                x,
            )
            self.compare_with_numpy(
                lambda x: x.nan_to_num(posinf=posinf, neginf=neginf),
                lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf),
                x,
            )

            # Out Variant
            out = torch.empty_like(x)
            result = torch.nan_to_num(x)
            torch.nan_to_num(x, out=out)
            self.assertEqual(result, out)

            result = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
            torch.nan_to_num(x, out=out, nan=nan, posinf=posinf, neginf=neginf)
            self.assertEqual(result, out)

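    # For reference, torch.nan_to_num's defaults mirror np.nan_to_num: nan
    # maps to 0 and +/-inf map to the dtype's largest/lowest finite value;
    # a quick float32 sketch (exposition only):
    #
    #   >>> torch.nan_to_num(torch.tensor([float("nan"), float("inf"), float("-inf")]))
    #   # -> approximately [0.0, 3.4028e+38, -3.4028e+38]
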
    @onlyCPU
    def test_nan_to_num_bfloat16(self, device):
        def test_dtype(fn, input, dtype):
            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
            input2 = input.detach().clone().float().requires_grad_(True)
            out = fn(input)
            out.sum().backward()
            out2 = fn(input2)
            out2.sum().backward()
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2, exact_dtype=False)
            self.assertEqual(input.grad, input2.grad, exact_dtype=False)

        def func():
            return torch.nan_to_num

        shapes = [[1, 3, 6, 6], [1, 3, 6, 128], [1, 3, 256, 256]]
        for shape in shapes:
            x = torch.randn(shape, device=device)
            extremals = [float('nan'), float('inf'), -float('inf')]
            for id1, id2, extremal in zip(torch.randint(0, 2, (3,)), torch.randint(0, 5, (3,)), extremals):
                x[0, id1, id2, :] = extremal
            test_dtype(func(), x, torch.bfloat16)

    @dtypes(torch.complex64, torch.complex128)
    def test_nan_to_num_complex(self, device, dtype):
        value_dtype = torch.tensor([], dtype=dtype).real.dtype

        def gen_tensor(a):
            return torch.view_as_complex(torch.tensor(a, dtype=value_dtype, device=device))

        for extremal, kwarg_name in zip(['nan', 'inf', '-inf'], ['nan', 'posinf', 'neginf']):
            a = gen_tensor([123, float(extremal)])
            res = torch.nan_to_num(a, **{kwarg_name: 12})
            res_check = gen_tensor([123, 12])
            self.assertEqual(res, res_check)

            a = gen_tensor([float(extremal), 456])
            res = torch.nan_to_num(a, **{kwarg_name: 21})
            res_check = gen_tensor([21, 456])
            self.assertEqual(res, res_check)

    @dtypes(torch.cdouble)
    def test_complex_edge_values(self, device, dtype):
        # sqrt Test Reference: https://github.com/pytorch/pytorch/pull/47424
        x = torch.tensor(0.0 - 1.0e20j, dtype=dtype, device=device)
        self.compare_with_numpy(torch.sqrt, np.sqrt, x)
        # acos test reference: https://github.com/pytorch/pytorch/issue/42952
        # Skip on Windows, as CUDA acos returns the conjugate value
        # see https://github.com/pytorch/pytorch/issues/52299
        if not (IS_WINDOWS and dtype == torch.cdouble and "cuda" in device):
            self.compare_with_numpy(torch.acos, np.arccos, x)

        x = torch.tensor(
            (-1.0e60 if dtype == torch.cdouble else -1.0e20) - 4988429.2j,
            dtype=dtype,
            device=device,
        )
        self.compare_with_numpy(torch.sqrt, np.sqrt, x)

    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
    @dtypes(torch.float, torch.double)
    def test_digamma_special(self, device, dtype):
        # Based on SciPy test for the following special values.
        # Reference:
        # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22
        euler = 0.57721566490153286
        dataset = [
            (0.0, -0.0),
            (1, -euler),
            (0.5, -2 * math.log(2) - euler),
            (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler),
            (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler),
            (
                1 / 6,
                -math.pi * math.sqrt(3) / 2
                - 2 * math.log(2)
                - 3 * math.log(3) / 2
                - euler,
            ),
            (
                1 / 8,
                -math.pi / 2
                - 4 * math.log(2)
                - (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2)))
                / math.sqrt(2)
                - euler,
            ),
        ]
        x = torch.tensor(dataset, device=device, dtype=dtype)
        self.compare_with_numpy(torch.digamma, scipy.special.digamma, x)

    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
    @dtypes(torch.float, torch.double)
    def test_digamma(self, device, dtype):
        # Tests pole behavior
        tensor = torch.tensor(
            [
                -0.999999994,
                -1.999999994,
                -2.0000000111,
                -100.99999994,
                0.000000111,
                -1931.99999994,
                -0.000000111,
                0,
                -0,
                -1,
                -2,
                -931,
            ],
            dtype=dtype,
            device=device,
        )
        self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor)

    @dtypes(*floating_types_and(torch.half))
    def test_frexp(self, device, dtype):
        input = make_tensor((50, 50), dtype=dtype, device=device)
        mantissa, exponent = torch.frexp(input)
        np_mantissa, np_exponent = np.frexp(input.cpu().numpy())

        self.assertEqual(mantissa, np_mantissa)
        self.assertEqual(exponent, np_exponent)

        # torch.frexp returns exponent in int32 to be compatible with np.frexp
        self.assertTrue(exponent.dtype == torch.int32)
        self.assertTrue(torch_to_numpy_dtype_dict[exponent.dtype] == np_exponent.dtype)

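    # For context, frexp decomposes each value as mantissa * 2**exponent with
    # the mantissa's magnitude in [0.5, 1); a quick sketch (exposition only):
    #
    #   >>> torch.frexp(torch.tensor([8.0]))
    #   # mantissa == 0.5, exponent == 4, and 0.5 * 2**4 == 8.0
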
    def test_frexp_assert_raises(self, device):
        invalid_input_dtypes = integral_types_and(torch.bool) + complex_types()
        for dtype in invalid_input_dtypes:
            input = make_tensor((50, 50), dtype=dtype, device=device)
            with self.assertRaisesRegex(
                RuntimeError, r"torch\.frexp\(\) only supports floating-point dtypes"
            ):
                torch.frexp(input)

        for dtype in floating_types_and(torch.half):
            input = make_tensor((50, 50), dtype=dtype, device=device)

            dtypes = list(
                all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)
            )
            dtypes.remove(dtype)
            for mantissa_dtype in dtypes:
                mantissa = torch.empty_like(input, dtype=mantissa_dtype)
                exponent = torch.empty_like(input, dtype=torch.int)
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"torch\.frexp\(\) expects mantissa to have dtype .+ but got .+",
                ):
                    torch.frexp(input, out=(mantissa, exponent))

            dtypes.append(dtype)
            dtypes.remove(torch.int)
            for exponent_dtype in dtypes:
                mantissa = torch.empty_like(input)
                exponent = torch.empty_like(input, dtype=exponent_dtype)
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"torch\.frexp\(\) expects exponent to have int dtype but got .+",
                ):
                    torch.frexp(input, out=(mantissa, exponent))

    def test_polygamma_neg(self, device):
        with self.assertRaisesRegex(
            RuntimeError, r"polygamma\(n, x\) does not support negative n\."
        ):
            torch.polygamma(-1, torch.tensor([1.0, 2.0], device=device))

    # TODO resolve with opinfos
    @onlyCPU
    def test_op_invert(self, device):
        res = 0xFFFF - torch.arange(127, dtype=torch.int8)
        for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            a = torch.arange(127, dtype=dtype)
            self.assertEqual(res.to(dtype), ~a)

        self.assertEqual(torch.tensor([True, False]), ~torch.tensor([False, True]))

        # test exceptions
        for dtype in (torch.half, torch.float, torch.double):
            a = torch.zeros(10, dtype=dtype)
            with self.assertRaises(TypeError):
                b = ~a

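    # Sketch of the identity exercised above: for signed integral dtypes,
    # bitwise NOT is ~a == -a - 1 (two's complement), and for bools it is
    # logical negation:
    #
    #   >>> ~torch.tensor([0, 1, 2], dtype=torch.int8)
    #   tensor([-1, -2, -3], dtype=torch.int8)
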
    @dtypes(torch.complex64, torch.complex128)
    def test_abs_angle_complex_to_float(self, device, dtype):
        # Constructs random complex values
        from random import random

        random_vals = []
        for multiplier in (-1, 1, -10, 10, -100, 100):
            for _ in range(10):
                random_vals.append(
                    complex(random() * multiplier, random() * multiplier)
                )

        for vals in (random_vals, []):
            a = np.array(vals, dtype=torch_to_numpy_dtype_dict[dtype])
            t = torch.tensor(vals, device=device, dtype=dtype)

            for fn_name in ("abs", "angle"):
                torch_fn = getattr(torch, fn_name)
                np_fn = getattr(np, fn_name)

                # Tests function
                np_result = torch.from_numpy(np_fn(a))
                torch_result = torch_fn(t).cpu()
                self.assertEqual(np_result, torch_result, exact_dtype=True)

                # Tests float out
                float_dtype = (
                    torch.float32 if dtype is torch.complex64 else torch.float64
                )
                np_float_out = np_fn(a).astype(torch_to_numpy_dtype_dict[float_dtype])
                float_out = torch.empty_like(t, dtype=float_dtype)
                torch_fn(t, out=float_out)
                self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu())

                # Tests float out (resized out)
                float_out = torch.empty(1, device=device, dtype=float_dtype)
                torch_fn(t, out=float_out)
                self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu())

                # Tests complex out
                np_complex_out = np_fn(a).astype(torch_to_numpy_dtype_dict[dtype])
                complex_out = torch.empty_like(t)
                torch_fn(t, out=complex_out)
                self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu())

                # Tests complex out (resized out)
                complex_out = torch.empty(0, device=device, dtype=dtype)
                torch_fn(t, out=complex_out)
                self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu())

                # Tests long out behavior (expected failure)
                long_out = torch.empty(0, device=device, dtype=torch.long)
                with self.assertRaises(RuntimeError):
                    torch_fn(t, out=long_out)

                # Tests inplace
                if fn_name == "abs":
                    torch_inplace_method = getattr(torch.Tensor, fn_name + "_")
                    np_fn(a, out=a)
                    if dtype.is_complex:
                        with self.assertRaisesRegex(
                            RuntimeError,
                            "In-place abs is not supported for complex tensors.",
                        ):
                            torch_inplace_method(t)
                        return
                    torch_inplace_method(t)
                    self.assertEqual(torch.from_numpy(a), t.cpu())

                # Note: angle does not have an in-place variant
                if fn_name == "angle":
                    with self.assertRaises(AttributeError):
                        torch_inplace_method = getattr(torch.Tensor, fn_name + "_")

    def check_internal_mem_overlap(
        self, inplace_op, num_inputs, dtype, device, expected_failure=False
    ):
        if isinstance(inplace_op, str):
            inplace_op = getattr(torch.Tensor, inplace_op)
        input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
        inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)]
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, "single memory location"):
                inplace_op(*inputs)
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, "single memory location"):
                    inplace_op(*inputs)

    def unary_check_input_output_mem_overlap(
        self, data, sz, op, expected_failure=False
    ):
        def _test(op, output, input):
            output_exp = torch.empty_like(output)
            op(input, out=output_exp)
            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)

        # output is identical to input:
        _test(op, output=data[0:sz], input=data[0:sz])
        # output and input are independent:
        _test(op, output=data[0:sz], input=data[sz : 2 * sz])
        # output partially overlaps with input:
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                _test(op, data[0:sz], data[1 : sz + 1])
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                    _test(op, data[0:sz], data[1 : sz + 1])

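    # Minimal sketch of the internal-overlap failure the helper above expects:
    # writing in place through an expanded view maps several output elements
    # onto one memory location, which PyTorch rejects:
    #
    #   >>> x = torch.randn(1).expand(3)
    #   >>> x.sin_()   # RuntimeError mentioning a "single memory location"
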
    # TODO: run on non-native device types
    # https://github.com/pytorch/pytorch/issues/126474
    @xfailIfTorchDynamo
    @dtypes(torch.double)
    def test_unary_out_op_mem_overlap(self, device, dtype):
        sz = 3
        doubles = torch.randn(2 * sz, dtype=dtype, device=device)
        positives = torch.randint(1, 100, (2 * sz,), device=device).double()
        ints = torch.randint(-100, 100, (2 * sz,), device=device)
        unary_mem_overlap_cases = [
            ("abs", doubles, True, True, "cpu"),
            ("abs", doubles, True, True, "cuda"),
            ("acos", doubles, True, True, "cpu"),
            ("acos", doubles, True, True, "cuda"),
            ("asin", doubles, True, True, "cpu"),
            ("asin", doubles, True, True, "cuda"),
            ("atan", doubles, True, True, "cpu"),
            ("atan", doubles, True, True, "cuda"),
            ("acosh", doubles, True, True, "cpu"),
            ("acosh", doubles, True, True, "cuda"),
            ("asinh", doubles, True, True, "cpu"),
            ("asinh", doubles, True, True, "cuda"),
            ("atanh", doubles, True, True, "cpu"),
            ("atanh", doubles, True, True, "cuda"),
            ("bitwise_not", ints, True, True, "cpu"),
            ("bitwise_not", ints, True, True, "cuda"),
            ("ceil", doubles, True, True, "cpu"),
            ("ceil", doubles, True, True, "cuda"),
            ("cos", doubles, True, True, "cpu"),
            ("cos", doubles, True, True, "cuda"),
            ("cosh", doubles, True, True, "cpu"),
            ("cosh", doubles, True, True, "cuda"),
            ("digamma", doubles, True, True, "cpu"),
            ("erf", doubles, True, True, "cpu"),
            ("erf", doubles, True, True, "cuda"),
            ("erfc", doubles, True, True, "cpu"),
            ("erfc", doubles, True, True, "cuda"),
            ("erfinv", doubles, True, True, "cpu"),
            ("erfinv", doubles, True, True, "cuda"),
            ("exp", doubles, True, True, "cpu"),
            ("exp", doubles, True, True, "cuda"),
            ("exp2", doubles, True, True, "cpu"),
            ("exp2", doubles, True, True, "cuda"),
            ("expm1", doubles, True, True, "cpu"),
            ("expm1", doubles, True, True, "cuda"),
            ("floor", doubles, True, True, "cpu"),
            ("floor", doubles, True, True, "cuda"),
            ("frac", doubles, True, True, "cpu"),
            ("frac", doubles, True, True, "cuda"),
            ("i0", doubles, True, True, "cpu"),
            ("i0", doubles, True, True, "cuda"),
            ("log", positives, True, True, "cpu"),
            ("log", positives, True, True, "cuda"),
            ("log10", positives, True, True, "cpu"),
            ("log10", positives, True, True, "cuda"),
            ("log1p", positives, True, True, "cpu"),
            ("log1p", positives, True, True, "cuda"),
            ("log2", positives, True, True, "cpu"),
            ("log2", positives, True, True, "cuda"),
            ("neg", doubles, True, True, "cpu"),
            ("neg", doubles, True, True, "cuda"),
            ("reciprocal", doubles, True, True, "cpu"),
            ("reciprocal", doubles, True, True, "cuda"),
            ("round", doubles, True, True, "cpu"),
            ("round", doubles, True, True, "cuda"),
            ("rsqrt", positives, True, True, "cpu"),
            ("rsqrt", positives, True, True, "cuda"),
            ("sin", doubles, True, True, "cpu"),
            ("sin", doubles, True, True, "cuda"),
            ("sinh", doubles, True, True, "cpu"),
            ("sinh", doubles, False, True, "cuda"),
            ("sigmoid", doubles, True, True, "cpu"),
            ("sigmoid", doubles, True, True, "cuda"),
            ("logit", doubles, True, True, "cpu"),
            ("logit", doubles, True, True, "cuda"),
            ("sqrt", doubles, True, True, "cpu"),
            ("sqrt", doubles, False, True, "cuda"),
            ("tan", doubles, True, True, "cpu"),
            ("tan", doubles, True, True, "cuda"),
            ("tanh", doubles, True, True, "cpu"),
            ("tanh", doubles, True, True, "cuda"),
            ("trunc", doubles, True, True, "cpu"),
            ("trunc", doubles, True, True, "cuda"),
        ]

        for (
            fn,
            inputs,
            has_input_output_mem_overlap_check,
            has_internal_mem_overlap_check,
            dev,
        ) in unary_mem_overlap_cases:
            if dev != device:
                continue
            out_fn = getattr(torch, fn)
            in_fn = getattr(torch.Tensor, fn + "_")

            self.unary_check_input_output_mem_overlap(
                inputs,
                sz,
                out_fn,
                expected_failure=not has_input_output_mem_overlap_check,
            )

            self.check_internal_mem_overlap(
                in_fn,
                1,
                dtype,
                dev,
                expected_failure=not has_internal_mem_overlap_check,
            )

    # TODO: opinfo hardshrink
    @onlyCPU
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardshrink(self, device, dtype):
        data = torch.tensor([1, 0.5, 0.3, 0.6], dtype=dtype, device=device).view(2, 2)
        self.assertEqual(
            torch.tensor([1, 0.5, 0, 0.6], dtype=dtype, device=device).view(2, 2),
            data.hardshrink(0.3),
        )
        self.assertEqual(
            torch.tensor([1, 0, 0, 0.6], dtype=dtype, device=device).view(2, 2),
            data.hardshrink(0.5),
        )

        # test default lambd=0.5
        self.assertEqual(data.hardshrink(), data.hardshrink(0.5))

        # test non-contiguous case
        self.assertEqual(
            torch.tensor([1, 0, 0.5, 0.6], dtype=dtype, device=device).view(2, 2),
            data.t().hardshrink(0.3),
        )

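    # For reference, hardshrink zeroes entries whose magnitude is at most
    # lambd and passes the rest through unchanged:
    #
    #   hardshrink(x, lambd) = x if |x| > lambd else 0
    #
    #   >>> torch.tensor([-1.0, -0.25, 0.25, 1.0]).hardshrink(0.5)
    #   # -> [-1.0, 0.0, 0.0, 1.0]
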
    @onlyCPU
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardshrink_edge_cases(self, device, dtype) -> None:
        def h(values, l_expected):
            for l, expected in l_expected.items():
                values_tensor = torch.tensor(
                    [float(v) for v in values], dtype=dtype, device=device
                )
                expected_tensor = torch.tensor(
                    [float(v) for v in expected], dtype=dtype, device=device
                )
                self.assertEqual(
                    expected_tensor == values_tensor.hardshrink(l),
                    torch.ones_like(values_tensor, dtype=torch.bool),
                )

        def test_helper(min, max):
            h(
                [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
                {
                    0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
                    min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
                    0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf],
                    1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf],
                    max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf],
                    inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                },
            )

        test_helper(torch.finfo(dtype).tiny, torch.finfo(dtype).max)

    @onlyCPU
    @slowTest
    @dtypes(torch.float)
    @unittest.skipIf(True, "Insufficient memory on linux.(2|4)xlarge")
    def test_exp_slow(self, device, dtype):
        # Test for https://github.com/pytorch/pytorch/issues/17271
        # This is pretty slow on my Macbook but it only takes a few
        # seconds on a beefy Xeon server
        a = torch.exp(torch.ones(2**31, dtype=dtype, device=device))
        b = torch.exp(torch.ones(1, dtype=dtype, device=device))
        self.assertEqual(a, b.expand(2**31))

    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardswish(self, device, dtype):
        inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
        expectedOutput = np.multiply(
            inputValues, np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0
        )

        inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)
        expectedOutputTensor = torch.tensor(expectedOutput, dtype=dtype, device=device)

        # normal
        self.assertEqual(
            torch.nn.functional.hardswish(inputTensor), expectedOutputTensor
        )

        # inplace
        inputTensorCpy = inputTensor.clone().detach()
        torch.nn.functional.hardswish(inputTensorCpy, inplace=True)
        self.assertEqual(inputTensorCpy, expectedOutputTensor)

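    # The closed form exercised above: hardswish(x) = x * relu6(x + 3) / 6, so
    # it is 0 for x <= -3, x for x >= 3, and x * (x + 3) / 6 in between
    # (e.g. hardswish(1.0) == 1 * 4 / 6 == 2 / 3).
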
    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardsigmoid(self, device, dtype):
        inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
        expectedOutput = np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0

        inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)

        # normal
        self.assertEqual(
            torch.nn.functional.hardsigmoid(inputTensor),
            torch.tensor(expectedOutput, dtype=dtype, device=device),
        )

        # inplace
        inputTensorCpy = inputTensor.clone().detach()
        self.assertEqual(
            torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True),
            torch.tensor(expectedOutput, dtype=dtype, device=device),
        )

    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardsigmoid_backward(self, device, dtype):
        inputValues = [-3.0, 3.0, -2.0, 2.0, -6.0, 6.0]
        expectedValues = [0.0, 0.0, 1.0 / 6.0, 1.0 / 6.0, 0.0, 0.0]
        inputTensor = torch.tensor(
            inputValues, dtype=dtype, device=device
        ).requires_grad_()
        expectedTensor = torch.tensor(expectedValues, dtype=dtype, device=device)
        out = torch.nn.functional.hardsigmoid(inputTensor)
        out.backward(torch.ones_like(inputTensor))
        self.assertEqual(inputTensor.grad, expectedTensor)

    @skipIfNoSciPy
    @dtypes(torch.float, torch.double)
    def test_silu(self, device, dtype):
        input_np = np.random.randn(5, 8)
        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
        input_np = np.concatenate((input_np, special_input), axis=0).astype(
            torch_to_numpy_dtype_dict[dtype]
        )
        expected_output_np = input_np * scipy.special.expit(input_np)

        expected_output = torch.from_numpy(expected_output_np).to(device)
        expected_output_noncontig = expected_output.transpose(0, 1)

        atol = 1e-6
        rtol = 1e-6

        input = torch.from_numpy(input_np).clone().contiguous().to(device)
        self.assertEqual(
            torch.nn.functional.silu(input), expected_output, atol=atol, rtol=rtol
        )
        self.assertEqual(
            torch.nn.functional.silu(input, inplace=True),
            expected_output,
            atol=atol,
            rtol=rtol,
        )

        input = torch.from_numpy(input_np).clone().to(device)
        input_noncontig = input.transpose(0, 1)
        self.assertEqual(
            torch.nn.functional.silu(input_noncontig),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )
        self.assertEqual(
            torch.nn.functional.silu(input_noncontig, inplace=True),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )

    @dtypes(torch.complex64, torch.complex128)
    def test_silu_complex(self, device, dtype):
        atol = 1e-6
        rtol = 1e-6
        inouts = [
            (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
            (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
            (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
            (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
            (2.0j, -1.55740761756896972656 + 0.99999988079071044922j)
        ]

        for inp, out in inouts:
            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device))
            self.assertFalse(torch.any(torch.isnan(res)))
            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)

        for inp, out in inouts:
            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device), inplace=True)
            self.assertFalse(torch.any(torch.isnan(res)))
            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)

    # It is not obvious how to merge this into OpInfo because these inputs
    # succeed for gradcheck but are expected to fail for gradgradcheck
    @dtypes(torch.double)
    def test_sinc(self, device, dtype):
        # The derivative of sinc(x) at x=0 has to be special cased.
        # A naive computation will result in 0/0 -> NaN.
        # We also need to be careful when we are very close to 0, as the
        # derivative's denominator is squared, and there are some floats
        # that are positive and whose squares are zero.
        a = torch.tensor(
            [0.0, torch.finfo(torch.double).tiny, 1.0],
            dtype=dtype,
            requires_grad=True,
            device=device,
        )
        gradcheck(torch.sinc, a)

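    # For context, torch.sinc is the normalized sinc,
    # sinc(x) = sin(pi * x) / (pi * x) with sinc(0) defined as 1, which is why
    # the derivative at (and near) zero needs the special casing noted above.
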
    @skipIfNoSciPy
    @dtypes(torch.float, torch.double)
    def test_mish(self, device, dtype):
        input_np = np.random.randn(5, 8)
        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
        input_np = np.concatenate((input_np, special_input), axis=0).astype(
            torch_to_numpy_dtype_dict[dtype]
        )
        expected_output_np = input_np * np.tanh(np.log1p(np.exp(input_np)))

        expected_output = torch.from_numpy(expected_output_np).to(device)
        expected_output_noncontig = expected_output.transpose(0, 1)

        atol = 1e-6
        rtol = 1e-6

        input = torch.from_numpy(input_np).clone().contiguous().to(device)
        self.assertEqual(
            torch.nn.functional.mish(input), expected_output, atol=atol, rtol=rtol
        )
        self.assertEqual(
            torch.nn.functional.mish(input, inplace=True),
            expected_output,
            atol=atol,
            rtol=rtol,
        )

        input = torch.from_numpy(input_np).clone().to(device)
        input_noncontig = input.transpose(0, 1)
        self.assertEqual(
            torch.nn.functional.mish(input_noncontig),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )
        self.assertEqual(
            torch.nn.functional.mish(input_noncontig, inplace=True),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )

    @dtypes(torch.complex64, torch.complex128)
    def test_log1p_complex(self, device, dtype):
        # The output values here were obtained using arbitrary precision math (mpmath)
        # and double checked with WolframAlpha.
        # Not using numpy's log1p here because by the time of writing this,
        # np.log1p has precision problems for small complex input values, see here:
        # https://github.com/numpy/numpy/issues/22609
        inouts = [
            (0.2 + 0.3j, 0.21263386770217202 + 0.24497866312686414j),
            (1e-19 + 1e-18j, 1e-19 + 1e-18j),
            (1e-18 + 0.1j, 0.00497517 + 0.0996687j),
            (0.1 + 1e-18j, 0.0953102 + 9.090909090909090909e-19j),
            (0.5 + 0j, 0.40546510810816 + 0j),
            (0.0 + 0.5j, 0.111571776 + 0.463647609j),
            (2.0 + 1.0j, 1.151292546497023 + 0.3217505543966422j),
            (-1.0 + 2.0j, 0.6931471805599453 + 1.570796326794897j),
            (2.0j, 0.80471895621705014 + 1.1071487177940904j),
            (-2.0j, 0.80471895621705014 - 1.1071487177940904j),
        ]
        # test the extreme values
        if dtype == torch.complex128:
            inouts += [
                (-1 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
                (1e250 + 1j, 575.6462732485114 + 1e-250j),
                (1e250 + 1e250j, 575.9928468387914 + 0.7853981633974483j),
                (1e-250 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
                (1e-250 + 2e-250j, 1e-250 + 2e-250j),
                (1e250 + 1e-250j, 575.6462732485114 + 0.0j),
            ]
        elif dtype == torch.complex64:
            inouts += [
                (-1 + 1e30j, 69.07755278982137 + 1.5707963267948966j),
                (1e30 + 1j, 69.07755278982137 + 1e-30j),
                (1e30 + 1e30j, 69.42412638010134 + 0.7853981633974483j),
                (1e-30 + 1e30j, 69.07755278982137 + 1.5707963267948966j),
                (1e-30 + 2e-30j, 1e-30 + 2e-30j),
                (1e30 + 1e-30j, 69.07755278982137 + 0.0j),
            ]

        # test the log1p individually
        for inp, out in inouts:
            res = torch.log1p(torch.tensor(inp, dtype=dtype, device=device))
            self.assertFalse(torch.any(torch.isnan(res)))
            # setting up atol == 0.0 because some part has very small values
            self.assertEqual(res.real, out.real, atol=0.0, rtol=1e-6)
            self.assertEqual(res.imag, out.imag, atol=0.0, rtol=1e-6)

        # test the log1p in tensor
        inp_lst, out_lst = (list(elmt) for elmt in zip(*inouts))
        inp_tens = torch.tensor(inp_lst, dtype=dtype, device=device)
        out_tens = torch.tensor(out_lst, dtype=dtype, device=device)
        res_tens = torch.log1p(inp_tens)
        self.assertEqual(res_tens.real, out_tens.real, atol=0.0, rtol=1e-6)
        self.assertEqual(res_tens.imag, out_tens.imag, atol=0.0, rtol=1e-6)

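    # The point of log1p, and why tiny inputs appear above: log1p(x) computes
    # log(1 + x) without explicitly forming 1 + x, so precision is kept when
    # |x| is far below machine epsilon; a double-precision sketch:
    #
    #   >>> torch.log1p(torch.tensor(1e-19, dtype=torch.float64))    # ~1e-19
    #   >>> torch.log(1 + torch.tensor(1e-19, dtype=torch.float64))  # rounds to 0.0
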
1210    # do ops like threshold need a test_unary(_nonufunc) test suite?
1211    @onlyCPU
1212    @dtypes(*get_all_math_dtypes("cpu"))
1213    def test_threshold(self, device, dtype):
1214        if dtype != torch.uint8 and dtype != torch.float16 and not dtype.is_complex:
1215            # 100 is wide enough to use AVX2 instructions for all types
1216            x = (
1217                torch.randn(100, dtype=torch.float, device=device)
1218                .sign()
1219                .to(dtype=dtype)
1220            )
1221            y = torch.threshold(x, 0, 0)
1222            self.assertTrue(y.le(0).any())
1223
1224    def _helper_test_igamma(self, loglo, loghi, device, dtype, torch_fcn, scipy_fcn):
1225        exp1 = 2.71828182846
1226        vec1 = torch.logspace(
1227            loglo, loghi, steps=500, base=exp1, dtype=torch.float64, device=device
1228        ).unsqueeze(-1)
1229        vec1 = vec1.to(dtype)
1230        inputs = [
1231            (vec1, vec1.transpose(0, 1)),
1232            (vec1, vec1),  # for large number, it should approach 0.5
1233            (vec1, 0.5 * vec1),  # test for considerable ratio
1234            (vec1, 2.0 * vec1),
1235            (vec1[::2, :], vec1[::2, :]),  # contiguous/noncontiguous tests
1236            (vec1[::2, :], vec1[: vec1.shape[0] // 2, :]),
1237            (vec1[: vec1.shape[0] // 2, :], vec1[::2, :]),
1238        ]
1239        half_prec = dtype in [torch.bfloat16, torch.float16]
1240        for input0, input1 in inputs:
1241            actual = torch_fcn(input0, input1)
1242            if half_prec:
1243                input0 = input0.to(torch.float)
1244                input1 = input1.to(torch.float)
1245            expected = scipy_fcn(input0.cpu().numpy(), input1.cpu().numpy())
1246            expected = torch.from_numpy(expected).to(dtype)
1247            self.assertEqual(actual, expected)
1248
1249    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
1250    @dtypes(torch.float32, torch.float64)
1251    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1252    @onlyNativeDeviceTypes
1253    def test_igamma_common(self, device, dtype):
1254        # test igamma for reasonable range of values
1255        loglo = -4  # approx 0.018
1256        loghi = 4  # approx 54.6
1257        self._helper_test_igamma(
1258            loglo, loghi, device, dtype, torch.igamma, scipy.special.gammainc
1259        )
1260
1261    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
1262    @dtypes(torch.float32, torch.float64)
1263    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1264    @onlyNativeDeviceTypes
1265    def test_igammac_common(self, device, dtype):
1266        # test igammac for reasonable range of values
1267        loglo = -4  # approx 0.018
1268        loghi = 4  # approx 54.6
1269        self._helper_test_igamma(
1270            loglo, loghi, device, dtype, torch.igammac, scipy.special.gammaincc
1271        )
1272
1273    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
1274    @dtypes(torch.float32, torch.float64)
1275    @onlyNativeDeviceTypes
1276    def test_igamma_edge_cases(self, device, dtype):
1277        tkwargs = {"dtype": dtype, "device": device}
1278        infs = torch.zeros((3,), **tkwargs) + float("inf")
1279        zeros = torch.zeros((3,), **tkwargs)
1280        ones = torch.ones((3,), **tkwargs)
1281        zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs)
1282        small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs)
1283        nans = torch.zeros((3,), **tkwargs) + float("nan")
1284        inpouts = [
1285            # (a    ,    x),       out
1286            ((zeros, small_to_inf), ones),
1287            ((small_to_inf, zeros), zeros),
1288            ((infs, zero_to_large), zeros),
1289            ((zero_to_large, infs), ones),
1290            ((zeros, zeros), nans),
1291            ((infs, infs), nans),
1292            ((-small_to_inf, small_to_inf), nans),
1293        ]
1294        for inputs, output in inpouts:
1295            input0, input1 = inputs
1296            calc = torch.igamma(input0, input1)
1297            if torch.all(torch.isnan(output)):
1298                self.assertTrue(torch.all(torch.isnan(calc)))
1299            else:
1300                self.assertEqual(calc, output)
1301
1302    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
1303    @dtypes(torch.float32, torch.float64)
1304    @onlyNativeDeviceTypes
1305    def test_igammac_edge_cases(self, device, dtype):
1306        tkwargs = {"dtype": dtype, "device": device}
1307        infs = torch.zeros((3,), **tkwargs) + float("inf")
1308        zeros = torch.zeros((3,), **tkwargs)
1309        ones = torch.ones((3,), **tkwargs)
1310        zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs)
1311        small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs)
1312        nans = torch.zeros((3,), **tkwargs) + float("nan")
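        # igammac is the regularized upper incomplete gamma function
        # Q(a, x) = 1 - igamma(a, x), so the expected outputs below are the
        # complements of those in test_igamma_edge_cases; the nan cases are
        # the same.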
1313        inpouts = [
1314            # (a    ,    x),       out
1315            ((zeros, small_to_inf), zeros),
1316            ((small_to_inf, zeros), ones),
1317            ((infs, zero_to_large), ones),
1318            ((zero_to_large, infs), zeros),
1319            ((zeros, zeros), nans),
1320            ((infs, infs), nans),
1321            ((-small_to_inf, small_to_inf), nans),
1322        ]
1323        for inputs, output in inpouts:
1324            input0, input1 = inputs
1325            calc = torch.igammac(input0, input1)
1326            if torch.all(torch.isnan(output)):
1327                self.assertTrue(torch.all(torch.isnan(calc)))
1328            else:
1329                self.assertEqual(calc, output)
1330
1331    def _i0_helper(self, t):
1332        # Test by comparing to scipy
1333        dtype = t.dtype
1334        actual = torch.i0(t)
1335        if dtype is torch.bfloat16:
1336            t = t.to(torch.float32)
1337        expected = scipy.special.i0(t.cpu().numpy())
        # SciPy computes the reference in at least float32, so cast it back down
        # before comparing for the reduced-precision dtypes
1339        if dtype is torch.bfloat16 or dtype is torch.float16:
1340            expected = torch.from_numpy(expected).to(dtype)
1341        self.assertEqual(actual, expected)
1342
    def _i0_range_helper(self, range, device, dtype):
        # The i0 tests are split by the largest domain over which each dtype does
        # not overflow. This exercises the function across its representable
        # inputs without having to handle inf or nan in the reference comparison.
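        # i0(x) grows roughly like exp(|x|) / sqrt(2 * pi * |x|) for large |x|,
        # so the ranges passed in by the tests below are chosen to keep the
        # result within finfo(dtype).max for the dtype they name.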
1347        for r in (range, -range):
1348            t = torch.rand(1000, device=device).to(dtype) * r
1349            self._i0_helper(t)
1350
1351    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
1352    @dtypes(torch.bfloat16, torch.float32, torch.float64)
1353    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1354    def test_i0_range1(self, device, dtype):
1355        # This tests the domain for i0 for which float16 does not overflow
1356        # The domain is (-13.25, 13.25)
1357        self._i0_range_helper(13.25, device, dtype)
1358
1359    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
1360    @dtypes(torch.bfloat16, torch.float32, torch.float64)
1361    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1362    def test_i0_range2(self, device, dtype):
        # This tests the domain for i0 for which float32 and bfloat16 do not overflow
1364        # The domain is (-88.5, 88.5)
1365        self._i0_range_helper(88.5, device, dtype)
1366
1367    @dtypes(torch.float64)
1368    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1369    def test_i0_range3(self, device, dtype):
1370        # This tests the domain for i0 for which float64 does not overflow
1371        # The domain is (-709.75, 709.75)
1372        self._i0_range_helper(709.75, device, dtype)
1373
1374    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
1375    @dtypes(torch.bfloat16, torch.float32, torch.float64)
1376    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1377    def test_i0_special(self, device, dtype):
1378        t = torch.tensor([], device=device, dtype=dtype)
1379        self._i0_helper(t)
1380
1381        t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype)
1382        self.assertTrue(torch.i0(t).isnan().all())
1383
1384    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
1385    @dtypes(torch.bfloat16, torch.float32, torch.float64)
1386    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1387    def test_special_i0_i1_vs_scipy(self, device, dtype):
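        # Compares i0, i0e, i1, and i1e against SciPy on empty, wide-range, and
        # finfo-boundary inputs. i0e(x) = exp(-|x|) * i0(x) and
        # i1e(x) = exp(-|x|) * i1(x) are the exponentially scaled variants, so
        # they stay finite even where i0 and i1 overflow.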
1388        def check_equal(t, torch_fn, scipy_fn):
1389            # Test by comparing to scipy
1390            actual = torch_fn(t)
1391            if dtype is torch.bfloat16:
1392                t = t.to(torch.float32)
1393            expected = scipy_fn(t.cpu().numpy())
1394
            # SciPy computes the reference in at least float32, so cast it back
            # down before comparing for the reduced-precision dtypes
1396            if dtype is torch.bfloat16 or dtype is torch.float16:
1397                expected = torch.from_numpy(expected).to(dtype)
1398            self.assertEqual(actual, expected)
1399
1400        t = torch.tensor([], device=device, dtype=dtype)
1401        check_equal(t, torch.i0, scipy.special.i0)
1402        check_equal(t, torch.special.i0e, scipy.special.i0e)
1403        if dtype not in [torch.half, torch.bfloat16]:
1404            check_equal(t, torch.special.i1, scipy.special.i1)
1405            check_equal(t, torch.special.i1e, scipy.special.i1e)
1406
1407        range = (-1e7, 1e7)
1408        if dtype == torch.half:
1409            range = (-65000, 65000)
1410
1411        t = torch.linspace(*range, int(1e4), device=device, dtype=dtype)
1412        check_equal(t, torch.i0, scipy.special.i0)
1413        check_equal(t, torch.special.i0e, scipy.special.i0e)
1414        if dtype not in [torch.half, torch.bfloat16]:
1415            check_equal(t, torch.special.i1, scipy.special.i1)
1416            check_equal(t, torch.special.i1e, scipy.special.i1e)
1417
1418        # NaN, inf, -inf are tested in reference_numerics tests.
1419        info = torch.finfo(dtype)
1420        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
1421        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
1422        check_equal(t, torch.i0, scipy.special.i0)
1423        check_equal(t, torch.special.i0e, scipy.special.i0e)
1424        if dtype not in [torch.half, torch.bfloat16]:
1425            check_equal(t, torch.special.i1, scipy.special.i1)
1426            check_equal(t, torch.special.i1e, scipy.special.i1e)
1427
1428    @dtypes(torch.float32, torch.float64)
1429    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1430    def test_special_ndtr_vs_scipy(self, device, dtype):
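        # ndtr(x) is the standard normal CDF: ndtr(x) = (1 + erf(x / sqrt(2))) / 2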
1431        def check_equal(t):
1432            # Test by comparing to scipy
1433            actual = torch.special.ndtr(t)
1434            expected = scipy.special.ndtr(t.cpu().numpy())
1435            self.assertEqual(actual, expected)
1436
1437        range = (-10, 10)
        t = torch.linspace(*range, int(1e4), device=device, dtype=dtype)
1439        check_equal(t)
1440
1441        # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests.
1442        info = torch.finfo(dtype)
1443        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
1444        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
1445        check_equal(t)
1446
1447    @dtypes(torch.float32, torch.float64)
1448    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1449    def test_special_log_ndtr_vs_scipy(self, device, dtype):
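        # log_ndtr(x) = log(ndtr(x)), i.e. the log of the standard normal CDF,
        # which stays finite for very negative x where ndtr itself underflows.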
1450        def check_equal(t):
1451            # Test by comparing with scipy
1452            actual = torch.special.log_ndtr(t)
1453            expected = scipy.special.log_ndtr(t.cpu().numpy())
1454            self.assertEqual(actual, expected)
1455
1456        # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests.
1457        info = torch.finfo(dtype)
1458        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
1459        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
1460        check_equal(t)
1461
1462    # TODO: allow large opinfo values to be opted-into via metadata
1463    @dtypes(torch.long)
1464    def test_abs_big_number(self, device, dtype):
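        # 2**31 + 1 does not fit in int32, so this exercises abs on an int64
        # value outside the 32-bit range.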
1465        bignumber = 2**31 + 1
1466        res = torch.tensor([bignumber], device=device, dtype=dtype)
1467        self.assertGreater(res.abs()[0], 0)
1468
1469    # TODO: add signed zero testing to opinfos
1470    @dtypes(torch.float, torch.double)
1471    def test_abs_signed_zero(self, device, dtype):
1472        # Both abs(0.0) and abs(-0.0) should result in 0.0
        size = 128 + 1  # pick a size that is large enough, with a remainder, so
        # that both the vectorized and non-vectorized code paths are exercised
1475        inp = torch.zeros(size, device=device, dtype=dtype)
1476        inp[::2] = -0.0
1477        inp = inp.abs()
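        # math.copysign(1.0, v) is -1.0 iff v carries a negative sign bit, so
        # this asserts that abs cleared the sign bit of every -0.0 element.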
1478        for v in inp:
1479            self.assertGreater(math.copysign(1.0, v), 0.0)
1480
1481    # TODO: update to compare against NumPy by rationalizing with OpInfo
1482    @onlyCUDA
1483    @dtypes(torch.float, torch.double)
1484    def test_abs_zero(self, device, dtype):
1485        # Both abs(0.0) and abs(-0.0) should result in 0.0
1486        abs_zeros = torch.tensor([0.0, -0.0], device=device, dtype=dtype).abs().tolist()
1487        for num in abs_zeros:
1488            self.assertGreater(math.copysign(1.0, num), 0.0)
1489
1490    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
1491    def test_isposinf_isneginf_non_boolean_output(self, device, dtype):
        # test non-boolean tensors as the `out=` parameter;
        # boolean outputs are tested in the test cases above
1494        vals = (float("inf"), -float("inf"), 1.2)
1495        t = torch.tensor(vals, device=device)
1496        for torch_op in (torch.isposinf, torch.isneginf):
1497            out = torch.empty_like(t, dtype=dtype)
1498            with self.assertRaisesRegex(
1499                RuntimeError, "does not support non-boolean outputs"
1500            ):
1501                torch_op(t, out=out)
1502
1503    def test_nonzero_empty(self, device):
1504        def assert_tuple_empty(tup, dim):
1505            self.assertEqual(dim, len(tup))
1506            for t in tup:
1507                self.assertEqual(torch.Size([0]), t.shape)
1508
1509        x = torch.randn(0, 2, 0, 5, 0, device=device)
1510        y = torch.nonzero(x)
1511        z = torch.nonzero(x, as_tuple=True)
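        # nonzero returns an (n, ndim) tensor of indices; with no nonzero
        # elements n == 0, so y has shape (0, 5) and each entry of the
        # as_tuple result is a length-0 index tensor.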
1512
1513        self.assertEqual(0, y.numel())
1514        self.assertEqual(torch.Size([0, 5]), y.shape)
1515        assert_tuple_empty(z, 5)
1516
1517        x = torch.tensor(0.5, device=device)
1518        y = torch.nonzero(x)
        # nonzero with as_tuple=True returns a tuple of length 1 for a
        # zero-dim tensor, matching NumPy's behavior.
1522        z = torch.nonzero(x, as_tuple=True)
1523        self.assertEqual(1, len(z))
1524        self.assertEqual(torch.zeros(1, dtype=torch.long), z[0])
1525
1526        x = torch.zeros((), device=device)
1527        y = torch.nonzero(x)
1528        z = torch.nonzero(x, as_tuple=True)
1529        self.assertEqual(torch.Size([0, 0]), y.shape)
1530        self.assertEqual(1, len(z))
1531        self.assertEqual(torch.empty(0, dtype=torch.long), z[0])
1532
1533    # TODO: rationalize with exp OpInfo
1534    @dtypes(*floating_and_complex_types_and(torch.bfloat16))
1535    @dtypesIfCUDA(*floating_and_complex_types_and(torch.half, torch.bfloat16))
1536    def test_exp(self, device, dtype):
        # bfloat16 overflows for these inputs, so skip it entirely
        if dtype == torch.bfloat16:
            return
        for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
            a = (
                torch.tensor(v, dtype=dtype, device=device)
                * torch.arange(18, device=device)
                / 3
                * math.pi
            )
            a = a.to(dtype)
            self.compare_with_numpy(torch.exp, np.exp, a)
1549
1550            if dtype.is_complex:
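                # The expected special values below follow the C99 Annex G
                # conventions for complex exp (which the NumPy reference also
                # matches for these inputs); the CPU branches are relaxed
                # because of the known issue tracked in
                # https://github.com/pytorch/pytorch/issues/40590.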
1551                inf_real_zero_imag_in = torch.tensor(
1552                    complex(float("inf"), 0), device=device, dtype=dtype
1553                )
1554                inf_real_zero_imag_out = torch.exp(inf_real_zero_imag_in).item()
1555                self.assertTrue(math.isinf(inf_real_zero_imag_out.real))
1556                if self.device_type == "cpu":
1557                    pass
                    # These are commented out because they cannot be consistently reproduced.
1559                    # This is incorrect. It should be zero. Need fix!
1560                    # https://github.com/pytorch/pytorch/issues/40590
1561                    # self.assertNotEqual(inf_real_zero_imag_out.imag, 0)
1562                    # This is incorrect. They should equal. Need fix!
1563                    # https://github.com/pytorch/pytorch/issues/40590
1564                    # with self.assertRaises(AssertionError):
1565                    #     self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in)
1566                else:
1567                    self.assertEqual(inf_real_zero_imag_out.imag, 0, atol=0, rtol=0)
1568                    self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in)
1569
1570                zero_real_inf_imag_in = torch.tensor(
1571                    complex(0, float("inf")), device=device, dtype=dtype
1572                )
1573                zero_real_inf_imag_out = torch.exp(zero_real_inf_imag_in).item()
1574                self.assertTrue(math.isnan(zero_real_inf_imag_out.real))
1575                self.assertTrue(math.isnan(zero_real_inf_imag_out.imag))
1576                # Ensure we are notified when NumPy changes its behavior
1577                self.compare_with_numpy(torch.exp, np.exp, zero_real_inf_imag_in)
1578
1579                inf_real_imag_in = torch.tensor(
1580                    complex(float("inf"), float("inf")), device=device, dtype=dtype
1581                )
1582                inf_real_imag_out = torch.exp(inf_real_imag_in).item()
1583                if self.device_type == "cpu":
1584                    pass
1585                    # This is incorrect. Need fix! https://github.com/pytorch/pytorch/issues/40590
1586                    # This is commented out because it cannot be consistently reproduced.
1587                    # with self.assertRaises(AssertionError):
1588                    #     self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in)
1589                else:
1590                    self.assertTrue(math.isinf(inf_real_imag_out.real))
1591                    self.assertTrue(math.isnan(inf_real_imag_out.imag))
1592                    self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in)
1593
1594                inf_real_nan_imag_in = torch.tensor(
1595                    complex(float("inf"), float("nan")), device=device, dtype=dtype
1596                )
1597                inf_real_nan_imag_out = torch.exp(inf_real_nan_imag_in).item()
1598                if self.device_type == "cpu":
1599                    pass
1600                    # This is incorrect. It should be inf. Need fix! https://github.com/pytorch/pytorch/issues/40590
1601                    # This is commented out because it cannot be consistently reproduced.
1602                    # with self.assertRaises(AssertionError):
1603                    #     self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in)
1604                else:
1605                    self.assertTrue(math.isinf(inf_real_nan_imag_out.real))
1606                    self.assertTrue(math.isnan(inf_real_nan_imag_out.imag))
1607                    self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in)
1608
1609                nan_real_inf_imag_in = torch.tensor(
1610                    complex(float("nan"), float("inf")), device=device, dtype=dtype
1611                )
1612                nan_real_inf_imag_out = torch.exp(nan_real_inf_imag_in).item()
1613                self.assertTrue(math.isnan(nan_real_inf_imag_out.real))
1614                self.assertTrue(math.isnan(nan_real_inf_imag_out.imag))
1615                # Ensure we are notified when NumPy changes its behavior
1616                self.compare_with_numpy(torch.exp, np.exp, nan_real_inf_imag_in)
1617
1618
1619instantiate_device_type_tests(TestUnaryUfuncs, globals())
1620
1621if __name__ == "__main__":
1622    run_tests()
1623