xref: /aosp_15_r20/external/pytorch/test/test_unary_ufuncs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerimport numpy as np
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport math
7*da0073e9SAndroid Build Coastguard Workerfrom numbers import Number
8*da0073e9SAndroid Build Coastguard Workerimport random
9*da0073e9SAndroid Build Coastguard Workerimport unittest
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerfrom torch import inf, nan
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
13*da0073e9SAndroid Build Coastguard Worker    TestCase,
14*da0073e9SAndroid Build Coastguard Worker    run_tests,
15*da0073e9SAndroid Build Coastguard Worker    torch_to_numpy_dtype_dict,
16*da0073e9SAndroid Build Coastguard Worker    numpy_to_torch_dtype_dict,
17*da0073e9SAndroid Build Coastguard Worker    suppress_warnings,
18*da0073e9SAndroid Build Coastguard Worker    TEST_SCIPY,
19*da0073e9SAndroid Build Coastguard Worker    slowTest,
20*da0073e9SAndroid Build Coastguard Worker    skipIfNoSciPy,
21*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
22*da0073e9SAndroid Build Coastguard Worker    gradcheck,
23*da0073e9SAndroid Build Coastguard Worker    is_iterable_of_tensors,
24*da0073e9SAndroid Build Coastguard Worker    xfailIfTorchDynamo,
25*da0073e9SAndroid Build Coastguard Worker)
26*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import (
27*da0073e9SAndroid Build Coastguard Worker    unary_ufuncs,
28*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_unary_tensors,
29*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_unary_small_value_tensors,
30*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_unary_large_value_tensors,
31*da0073e9SAndroid Build Coastguard Worker    generate_elementwise_unary_extremal_value_tensors,
32*da0073e9SAndroid Build Coastguard Worker)
33*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
34*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
35*da0073e9SAndroid Build Coastguard Worker    ops,
36*da0073e9SAndroid Build Coastguard Worker    dtypes,
37*da0073e9SAndroid Build Coastguard Worker    onlyCPU,
38*da0073e9SAndroid Build Coastguard Worker    onlyNativeDeviceTypes,
39*da0073e9SAndroid Build Coastguard Worker    onlyCUDA,
40*da0073e9SAndroid Build Coastguard Worker    dtypesIfCUDA,
41*da0073e9SAndroid Build Coastguard Worker    precisionOverride,
42*da0073e9SAndroid Build Coastguard Worker    dtypesIfCPU,
43*da0073e9SAndroid Build Coastguard Worker)
44*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
47*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import (
48*da0073e9SAndroid Build Coastguard Worker    floating_types_and,
49*da0073e9SAndroid Build Coastguard Worker    all_types_and_complex_and,
50*da0073e9SAndroid Build Coastguard Worker    integral_types_and,
51*da0073e9SAndroid Build Coastguard Worker    get_all_math_dtypes,
52*da0073e9SAndroid Build Coastguard Worker    complex_types,
53*da0073e9SAndroid Build Coastguard Worker    floating_and_complex_types_and,
54*da0073e9SAndroid Build Coastguard Worker)
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Workerif TEST_SCIPY:
57*da0073e9SAndroid Build Coastguard Worker    import scipy
58*da0073e9SAndroid Build Coastguard Worker
# Refer [scipy reference filter]
# Keep only the unary ufuncs that expose a reference implementation (op.ref)
# in the current environment; the reference_numerics tests compare against it.
reference_filtered_ops = [
    candidate for candidate in unary_ufuncs if candidate.ref is not None
]
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker# Tests for unary "universal functions (ufuncs)" that accept a single
65*da0073e9SAndroid Build Coastguard Worker# tensor and have common properties like:
66*da0073e9SAndroid Build Coastguard Worker#   - they are elementwise functions
67*da0073e9SAndroid Build Coastguard Worker#   - the input shape is the output shape
68*da0073e9SAndroid Build Coastguard Worker#   - they typically have method and inplace variants
69*da0073e9SAndroid Build Coastguard Worker#   - they typically support the out kwarg
70*da0073e9SAndroid Build Coastguard Worker#   - they typically have NumPy or SciPy references
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker# See NumPy's universal function documentation
73*da0073e9SAndroid Build Coastguard Worker# (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
74*da0073e9SAndroid Build Coastguard Worker# about the concept of ufuncs.
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker# TODO: port test_unary_out_op_mem_overlap
78*da0073e9SAndroid Build Coastguard Worker# TODO: add test for inplace variants erroring on broadcasted inputs
79*da0073e9SAndroid Build Coastguard Workerclass TestUnaryUfuncs(TestCase):
80*da0073e9SAndroid Build Coastguard Worker    exact_dtype = True
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker    @ops(
83*da0073e9SAndroid Build Coastguard Worker        [_fn for _fn in unary_ufuncs if _fn.domain != (None, None)],
84*da0073e9SAndroid Build Coastguard Worker        allowed_dtypes=floating_types_and(torch.bfloat16, torch.half),
85*da0073e9SAndroid Build Coastguard Worker    )
86*da0073e9SAndroid Build Coastguard Worker    def test_float_domains(self, device, dtype, op):
87*da0073e9SAndroid Build Coastguard Worker        eps = (1e-5, 1e-3, 1e-1, 1, 2, 10, 20, 50, 100)
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker        low, high = op.domain
90*da0073e9SAndroid Build Coastguard Worker        # NOTE: the following two loops are separated for readability
91*da0073e9SAndroid Build Coastguard Worker        if low is not None:
92*da0073e9SAndroid Build Coastguard Worker            low_tensor = torch.tensor(low, device=device, dtype=dtype)
93*da0073e9SAndroid Build Coastguard Worker            for epsilon in eps:
94*da0073e9SAndroid Build Coastguard Worker                lower_tensor = low_tensor - epsilon
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker                # Skips the test if the difference is not representable,
97*da0073e9SAndroid Build Coastguard Worker                #   which can occur if, for example, the difference is small
98*da0073e9SAndroid Build Coastguard Worker                #   and the dtype is imprecise (like bfloat16 is)
99*da0073e9SAndroid Build Coastguard Worker                if lower_tensor.item() == low_tensor.item():
100*da0073e9SAndroid Build Coastguard Worker                    continue
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker                result = op(lower_tensor)
103*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
104*da0073e9SAndroid Build Coastguard Worker                    result.item(),
105*da0073e9SAndroid Build Coastguard Worker                    float("nan"),
106*da0073e9SAndroid Build Coastguard Worker                    msg=(
107*da0073e9SAndroid Build Coastguard Worker                        f"input of {lower_tensor.item()} outside lower domain boundary"
108*da0073e9SAndroid Build Coastguard Worker                        f" {low} produced {result.item()}, not nan!"
109*da0073e9SAndroid Build Coastguard Worker                    ),
110*da0073e9SAndroid Build Coastguard Worker                )
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker        if high is not None:
113*da0073e9SAndroid Build Coastguard Worker            high_tensor = torch.tensor(high, device=device, dtype=dtype)
114*da0073e9SAndroid Build Coastguard Worker            for epsilon in eps:
115*da0073e9SAndroid Build Coastguard Worker                higher_tensor = high_tensor + epsilon
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker                # See above comment
118*da0073e9SAndroid Build Coastguard Worker                if higher_tensor.item() == high_tensor.item():
119*da0073e9SAndroid Build Coastguard Worker                    continue
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker                result = op(higher_tensor)
122*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
123*da0073e9SAndroid Build Coastguard Worker                    result.item(),
124*da0073e9SAndroid Build Coastguard Worker                    float("nan"),
125*da0073e9SAndroid Build Coastguard Worker                    msg=(
126*da0073e9SAndroid Build Coastguard Worker                        f"input of {higher_tensor.item()} outside upper domain boundary"
127*da0073e9SAndroid Build Coastguard Worker                        f" {high} produced {result.item()}, not nan!"
128*da0073e9SAndroid Build Coastguard Worker                    ),
129*da0073e9SAndroid Build Coastguard Worker                )
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker    # Helper for comparing torch tensors and numpy arrays
132*da0073e9SAndroid Build Coastguard Worker    # TODO: should this or assertEqual also validate that strides are equal?
133*da0073e9SAndroid Build Coastguard Worker    def assertEqualHelper(
134*da0073e9SAndroid Build Coastguard Worker        self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs
135*da0073e9SAndroid Build Coastguard Worker    ):
136*da0073e9SAndroid Build Coastguard Worker        assert isinstance(actual, torch.Tensor)
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker        # Some NumPy functions return scalars, not arrays
139*da0073e9SAndroid Build Coastguard Worker        if isinstance(expected, Number):
140*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual.item(), expected, msg, **kwargs)
141*da0073e9SAndroid Build Coastguard Worker        elif isinstance(expected, np.ndarray):
142*da0073e9SAndroid Build Coastguard Worker            # Handles exact dtype comparisons between arrays and tensors
143*da0073e9SAndroid Build Coastguard Worker            if exact_dtype:
144*da0073e9SAndroid Build Coastguard Worker                if (
145*da0073e9SAndroid Build Coastguard Worker                    actual.dtype is torch.bfloat16
146*da0073e9SAndroid Build Coastguard Worker                    or expected.dtype != torch_to_numpy_dtype_dict[actual.dtype]
147*da0073e9SAndroid Build Coastguard Worker                ):
148*da0073e9SAndroid Build Coastguard Worker                    # Allows array dtype to be float32 when comparing with bfloat16 tensors
149*da0073e9SAndroid Build Coastguard Worker                    #   since NumPy doesn't support the bfloat16 dtype
150*da0073e9SAndroid Build Coastguard Worker                    # Also ops like scipy.special.erf, scipy.special.erfc, etc, promote float16
151*da0073e9SAndroid Build Coastguard Worker                    # to float32
152*da0073e9SAndroid Build Coastguard Worker                    if expected.dtype == np.float32:
153*da0073e9SAndroid Build Coastguard Worker                        assert actual.dtype in (
154*da0073e9SAndroid Build Coastguard Worker                            torch.float16,
155*da0073e9SAndroid Build Coastguard Worker                            torch.bfloat16,
156*da0073e9SAndroid Build Coastguard Worker                            torch.float32,
157*da0073e9SAndroid Build Coastguard Worker                        )
158*da0073e9SAndroid Build Coastguard Worker                    elif expected.dtype == np.float64:
159*da0073e9SAndroid Build Coastguard Worker                        assert actual.dtype in (
160*da0073e9SAndroid Build Coastguard Worker                            torch.float16,
161*da0073e9SAndroid Build Coastguard Worker                            torch.bfloat16,
162*da0073e9SAndroid Build Coastguard Worker                            torch.float32,
163*da0073e9SAndroid Build Coastguard Worker                            torch.float64,
164*da0073e9SAndroid Build Coastguard Worker                        )
165*da0073e9SAndroid Build Coastguard Worker                    else:
166*da0073e9SAndroid Build Coastguard Worker                        self.fail(
167*da0073e9SAndroid Build Coastguard Worker                            f"Expected dtype {expected.dtype} but got {actual.dtype}!"
168*da0073e9SAndroid Build Coastguard Worker                        )
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
171*da0073e9SAndroid Build Coastguard Worker                actual,
172*da0073e9SAndroid Build Coastguard Worker                torch.from_numpy(expected).to(actual.dtype),
173*da0073e9SAndroid Build Coastguard Worker                msg,
174*da0073e9SAndroid Build Coastguard Worker                exact_device=False,
175*da0073e9SAndroid Build Coastguard Worker                **kwargs
176*da0073e9SAndroid Build Coastguard Worker            )
177*da0073e9SAndroid Build Coastguard Worker        else:
178*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)
179*da0073e9SAndroid Build Coastguard Worker
180*da0073e9SAndroid Build Coastguard Worker    # Tests that the function and its (array-accepting) reference produce the same
181*da0073e9SAndroid Build Coastguard Worker    #   values on given tensors
182*da0073e9SAndroid Build Coastguard Worker    def _test_reference_numerics(self, dtype, op, tensors, equal_nan=True):
183*da0073e9SAndroid Build Coastguard Worker        def _helper_reference_numerics(
184*da0073e9SAndroid Build Coastguard Worker            expected, actual, msg, exact_dtype, equal_nan=True
185*da0073e9SAndroid Build Coastguard Worker        ):
186*da0073e9SAndroid Build Coastguard Worker            if not torch.can_cast(
187*da0073e9SAndroid Build Coastguard Worker                numpy_to_torch_dtype_dict[expected.dtype.type], dtype
188*da0073e9SAndroid Build Coastguard Worker            ):
189*da0073e9SAndroid Build Coastguard Worker                exact_dtype = False
190*da0073e9SAndroid Build Coastguard Worker
191*da0073e9SAndroid Build Coastguard Worker            if dtype in [torch.uint8, torch.int8, torch.bool]:
192*da0073e9SAndroid Build Coastguard Worker                # NOTE: For these dtypes, PyTorch computes in the default scalar type (float)
193*da0073e9SAndroid Build Coastguard Worker                # while NumPy computes in float16
194*da0073e9SAndroid Build Coastguard Worker                self.assertEqualHelper(
195*da0073e9SAndroid Build Coastguard Worker                    actual,
196*da0073e9SAndroid Build Coastguard Worker                    expected,
197*da0073e9SAndroid Build Coastguard Worker                    msg,
198*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
199*da0073e9SAndroid Build Coastguard Worker                    exact_dtype=exact_dtype,
200*da0073e9SAndroid Build Coastguard Worker                    rtol=1e-3,
201*da0073e9SAndroid Build Coastguard Worker                    atol=1e-2,
202*da0073e9SAndroid Build Coastguard Worker                )
203*da0073e9SAndroid Build Coastguard Worker            elif dtype is torch.bfloat16:
204*da0073e9SAndroid Build Coastguard Worker                # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149
205*da0073e9SAndroid Build Coastguard Worker                self.assertEqualHelper(
206*da0073e9SAndroid Build Coastguard Worker                    actual,
207*da0073e9SAndroid Build Coastguard Worker                    expected,
208*da0073e9SAndroid Build Coastguard Worker                    msg,
209*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
210*da0073e9SAndroid Build Coastguard Worker                    exact_dtype=exact_dtype,
211*da0073e9SAndroid Build Coastguard Worker                    rtol=16e-3,
212*da0073e9SAndroid Build Coastguard Worker                    atol=1e-5,
213*da0073e9SAndroid Build Coastguard Worker                )
214*da0073e9SAndroid Build Coastguard Worker            elif dtype is torch.half:
215*da0073e9SAndroid Build Coastguard Worker                self.assertEqualHelper(
216*da0073e9SAndroid Build Coastguard Worker                    actual,
217*da0073e9SAndroid Build Coastguard Worker                    expected,
218*da0073e9SAndroid Build Coastguard Worker                    msg,
219*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
220*da0073e9SAndroid Build Coastguard Worker                    exact_dtype=exact_dtype,
221*da0073e9SAndroid Build Coastguard Worker                    rtol=1.2e-03,
222*da0073e9SAndroid Build Coastguard Worker                    atol=1e-03,
223*da0073e9SAndroid Build Coastguard Worker                )
224*da0073e9SAndroid Build Coastguard Worker            else:
225*da0073e9SAndroid Build Coastguard Worker                self.assertEqualHelper(
226*da0073e9SAndroid Build Coastguard Worker                    actual,
227*da0073e9SAndroid Build Coastguard Worker                    expected,
228*da0073e9SAndroid Build Coastguard Worker                    msg,
229*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
230*da0073e9SAndroid Build Coastguard Worker                    equal_nan=equal_nan,
231*da0073e9SAndroid Build Coastguard Worker                    exact_dtype=exact_dtype,
232*da0073e9SAndroid Build Coastguard Worker                )
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker        for t in tensors:
235*da0073e9SAndroid Build Coastguard Worker            t = t.input
236*da0073e9SAndroid Build Coastguard Worker            torch_kwargs, numpy_kwargs = op.sample_kwargs(t.device, dtype, t)
237*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.bfloat16:
238*da0073e9SAndroid Build Coastguard Worker                a = t.cpu().to(torch.float32).numpy()
239*da0073e9SAndroid Build Coastguard Worker            elif dtype is torch.complex32:
240*da0073e9SAndroid Build Coastguard Worker                a = t.cpu().to(torch.complex64).numpy()
241*da0073e9SAndroid Build Coastguard Worker            else:
242*da0073e9SAndroid Build Coastguard Worker                a = t.cpu().numpy()
243*da0073e9SAndroid Build Coastguard Worker
244*da0073e9SAndroid Build Coastguard Worker            actual = op(t, **torch_kwargs)
245*da0073e9SAndroid Build Coastguard Worker            expected = op.ref(a, **numpy_kwargs)
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker            # Crafts a custom error message for smaller, printable tensors
248*da0073e9SAndroid Build Coastguard Worker            if t.numel() < 10:
249*da0073e9SAndroid Build Coastguard Worker                msg = (
250*da0073e9SAndroid Build Coastguard Worker                    "Failed to produce expected results! Input tensor was"
251*da0073e9SAndroid Build Coastguard Worker                    f" {t}, torch result is {actual}, and reference result is"
252*da0073e9SAndroid Build Coastguard Worker                    f" {expected}."
253*da0073e9SAndroid Build Coastguard Worker                )
254*da0073e9SAndroid Build Coastguard Worker            else:
255*da0073e9SAndroid Build Coastguard Worker                msg = None
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Worker            exact_dtype = True
258*da0073e9SAndroid Build Coastguard Worker            if isinstance(actual, torch.Tensor):
259*da0073e9SAndroid Build Coastguard Worker                _helper_reference_numerics(
260*da0073e9SAndroid Build Coastguard Worker                    expected, actual, msg, exact_dtype, equal_nan
261*da0073e9SAndroid Build Coastguard Worker                )
262*da0073e9SAndroid Build Coastguard Worker            else:
263*da0073e9SAndroid Build Coastguard Worker                for x, y in zip(expected, actual):
264*da0073e9SAndroid Build Coastguard Worker                    # testing multi-outputs results
265*da0073e9SAndroid Build Coastguard Worker                    _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan)
266*da0073e9SAndroid Build Coastguard Worker
267*da0073e9SAndroid Build Coastguard Worker    # Tests that the function and its (array-accepting) reference produce the same
268*da0073e9SAndroid Build Coastguard Worker    #   values on a range of tensors, including empty tensors, scalar tensors,
269*da0073e9SAndroid Build Coastguard Worker    #   1D tensors and a large 2D tensor with interesting and extremal values
270*da0073e9SAndroid Build Coastguard Worker    #   and noncontiguities.
271*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
272*da0073e9SAndroid Build Coastguard Worker    @ops(reference_filtered_ops)
273*da0073e9SAndroid Build Coastguard Worker    def test_reference_numerics_normal(self, device, dtype, op):
274*da0073e9SAndroid Build Coastguard Worker        tensors = generate_elementwise_unary_tensors(
275*da0073e9SAndroid Build Coastguard Worker            op, device=device, dtype=dtype, requires_grad=False
276*da0073e9SAndroid Build Coastguard Worker        )
277*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, tensors)
278*da0073e9SAndroid Build Coastguard Worker
279*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
280*da0073e9SAndroid Build Coastguard Worker    @ops(reference_filtered_ops)
281*da0073e9SAndroid Build Coastguard Worker    def test_reference_numerics_small(self, device, dtype, op):
282*da0073e9SAndroid Build Coastguard Worker        if dtype in (torch.bool,):
283*da0073e9SAndroid Build Coastguard Worker            raise self.skipTest("bool has no small values")
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker        tensors = generate_elementwise_unary_small_value_tensors(
286*da0073e9SAndroid Build Coastguard Worker            op, device=device, dtype=dtype, requires_grad=False
287*da0073e9SAndroid Build Coastguard Worker        )
288*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, tensors)
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
291*da0073e9SAndroid Build Coastguard Worker    @ops(reference_filtered_ops)
292*da0073e9SAndroid Build Coastguard Worker    def test_reference_numerics_large(self, device, dtype, op):
293*da0073e9SAndroid Build Coastguard Worker        if dtype in (torch.bool, torch.uint8, torch.int8):
294*da0073e9SAndroid Build Coastguard Worker            raise self.skipTest("bool, uint8, and int8 dtypes have no large values")
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker        tensors = generate_elementwise_unary_large_value_tensors(
297*da0073e9SAndroid Build Coastguard Worker            op, device=device, dtype=dtype, requires_grad=False
298*da0073e9SAndroid Build Coastguard Worker        )
299*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, tensors)
300*da0073e9SAndroid Build Coastguard Worker
301*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
302*da0073e9SAndroid Build Coastguard Worker    @ops(
303*da0073e9SAndroid Build Coastguard Worker        reference_filtered_ops,
304*da0073e9SAndroid Build Coastguard Worker        allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
305*da0073e9SAndroid Build Coastguard Worker    )
306*da0073e9SAndroid Build Coastguard Worker    def test_reference_numerics_extremal(self, device, dtype, op):
307*da0073e9SAndroid Build Coastguard Worker        tensors = generate_elementwise_unary_extremal_value_tensors(
308*da0073e9SAndroid Build Coastguard Worker            op, device=device, dtype=dtype, requires_grad=False
309*da0073e9SAndroid Build Coastguard Worker        )
310*da0073e9SAndroid Build Coastguard Worker        self._test_reference_numerics(dtype, op, tensors)
311*da0073e9SAndroid Build Coastguard Worker
312*da0073e9SAndroid Build Coastguard Worker    # Tests for testing (non)contiguity consistency
313*da0073e9SAndroid Build Coastguard Worker    @ops(unary_ufuncs)
314*da0073e9SAndroid Build Coastguard Worker    def test_contig_vs_every_other(self, device, dtype, op):
315*da0073e9SAndroid Build Coastguard Worker        contig = make_tensor(
316*da0073e9SAndroid Build Coastguard Worker            (1026,), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1]
317*da0073e9SAndroid Build Coastguard Worker        )
318*da0073e9SAndroid Build Coastguard Worker        non_contig = contig[::2]
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(contig.is_contiguous())
321*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(non_contig.is_contiguous())
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker        torch_kwargs, _ = op.sample_kwargs(device, dtype, non_contig)
324*da0073e9SAndroid Build Coastguard Worker        expected = op(non_contig, **torch_kwargs)
325*da0073e9SAndroid Build Coastguard Worker        result = op(contig, **torch_kwargs)
326*da0073e9SAndroid Build Coastguard Worker        result = pytree.tree_map(lambda x: x[::2], result)
327*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker    @ops(unary_ufuncs)
330*da0073e9SAndroid Build Coastguard Worker    def test_contig_vs_transposed(self, device, dtype, op):
331*da0073e9SAndroid Build Coastguard Worker        contig = make_tensor(
332*da0073e9SAndroid Build Coastguard Worker            (789, 357), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1]
333*da0073e9SAndroid Build Coastguard Worker        )
334*da0073e9SAndroid Build Coastguard Worker        non_contig = contig.T
335*da0073e9SAndroid Build Coastguard Worker
336*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(contig.is_contiguous())
337*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(non_contig.is_contiguous())
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
340*da0073e9SAndroid Build Coastguard Worker        expected = op(non_contig, **torch_kwargs)
341*da0073e9SAndroid Build Coastguard Worker        result = op(contig, **torch_kwargs)
342*da0073e9SAndroid Build Coastguard Worker        result = pytree.tree_map(lambda x: x.T, result)
343*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker    @ops(unary_ufuncs)
346*da0073e9SAndroid Build Coastguard Worker    def test_non_contig(self, device, dtype, op):
347*da0073e9SAndroid Build Coastguard Worker        shapes = [(5, 7), (1024,)]
348*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
349*da0073e9SAndroid Build Coastguard Worker            contig = make_tensor(
350*da0073e9SAndroid Build Coastguard Worker                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
351*da0073e9SAndroid Build Coastguard Worker            )
352*da0073e9SAndroid Build Coastguard Worker            non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0]
353*da0073e9SAndroid Build Coastguard Worker            non_contig.copy_(contig)
354*da0073e9SAndroid Build Coastguard Worker
355*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(contig.is_contiguous())
356*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(non_contig.is_contiguous())
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker            torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
359*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs))
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker    @ops(unary_ufuncs)
362*da0073e9SAndroid Build Coastguard Worker    def test_non_contig_index(self, device, dtype, op):
363*da0073e9SAndroid Build Coastguard Worker        contig = make_tensor(
364*da0073e9SAndroid Build Coastguard Worker            (2, 2, 1, 2),
365*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
366*da0073e9SAndroid Build Coastguard Worker            device=device,
367*da0073e9SAndroid Build Coastguard Worker            low=op.domain[0],
368*da0073e9SAndroid Build Coastguard Worker            high=op.domain[1],
369*da0073e9SAndroid Build Coastguard Worker        )
370*da0073e9SAndroid Build Coastguard Worker        non_contig = contig[:, 1, ...]
371*da0073e9SAndroid Build Coastguard Worker        contig = non_contig.contiguous()
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(contig.is_contiguous())
374*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(non_contig.is_contiguous())
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
377*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs))
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker    @ops(unary_ufuncs)
380*da0073e9SAndroid Build Coastguard Worker    def test_non_contig_expand(self, device, dtype, op):
381*da0073e9SAndroid Build Coastguard Worker        shapes = [(1, 3), (1, 7), (5, 7)]
382*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
383*da0073e9SAndroid Build Coastguard Worker            contig = make_tensor(
384*da0073e9SAndroid Build Coastguard Worker                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
385*da0073e9SAndroid Build Coastguard Worker            )
386*da0073e9SAndroid Build Coastguard Worker            non_contig = contig.clone().expand(3, -1, -1)
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(contig.is_contiguous())
389*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(non_contig.is_contiguous())
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker            torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
392*da0073e9SAndroid Build Coastguard Worker            contig = op(contig, **torch_kwargs)
393*da0073e9SAndroid Build Coastguard Worker            non_contig = op(non_contig, **torch_kwargs)
394*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
395*da0073e9SAndroid Build Coastguard Worker                non_contig_i = pytree.tree_map(lambda x: x[i], non_contig)
396*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
397*da0073e9SAndroid Build Coastguard Worker                    contig, non_contig_i, msg="non-contiguous expand[" + str(i) + "]"
398*da0073e9SAndroid Build Coastguard Worker                )
399*da0073e9SAndroid Build Coastguard Worker
400*da0073e9SAndroid Build Coastguard Worker    @ops(unary_ufuncs)
401*da0073e9SAndroid Build Coastguard Worker    def test_contig_size1(self, device, dtype, op):
402*da0073e9SAndroid Build Coastguard Worker        contig = make_tensor(
403*da0073e9SAndroid Build Coastguard Worker            (5, 100), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
404*da0073e9SAndroid Build Coastguard Worker        )
405*da0073e9SAndroid Build Coastguard Worker        contig = contig[:1, :50]
406*da0073e9SAndroid Build Coastguard Worker        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
407*da0073e9SAndroid Build Coastguard Worker        contig2.copy_(contig)
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(contig.is_contiguous())
410*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(contig2.is_contiguous())
411*da0073e9SAndroid Build Coastguard Worker
412*da0073e9SAndroid Build Coastguard Worker        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
413*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs))
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker    @ops(unary_ufuncs)
416*da0073e9SAndroid Build Coastguard Worker    def test_contig_size1_large_dim(self, device, dtype, op):
417*da0073e9SAndroid Build Coastguard Worker        contig = make_tensor(
418*da0073e9SAndroid Build Coastguard Worker            (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4),
419*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
420*da0073e9SAndroid Build Coastguard Worker            device=device,
421*da0073e9SAndroid Build Coastguard Worker            low=op.domain[0],
422*da0073e9SAndroid Build Coastguard Worker            high=op.domain[1],
423*da0073e9SAndroid Build Coastguard Worker        )
424*da0073e9SAndroid Build Coastguard Worker        contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
425*da0073e9SAndroid Build Coastguard Worker        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
426*da0073e9SAndroid Build Coastguard Worker        contig2.copy_(contig)
427*da0073e9SAndroid Build Coastguard Worker
428*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(contig.is_contiguous())
429*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(contig2.is_contiguous())
430*da0073e9SAndroid Build Coastguard Worker
431*da0073e9SAndroid Build Coastguard Worker        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
432*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs))
433*da0073e9SAndroid Build Coastguard Worker
434*da0073e9SAndroid Build Coastguard Worker    # Tests that computation on a multiple batches is the same as
435*da0073e9SAndroid Build Coastguard Worker    # per-batch computation.
436*da0073e9SAndroid Build Coastguard Worker    @ops(unary_ufuncs)
437*da0073e9SAndroid Build Coastguard Worker    def test_batch_vs_slicing(self, device, dtype, op):
438*da0073e9SAndroid Build Coastguard Worker        input = make_tensor(
439*da0073e9SAndroid Build Coastguard Worker            (1024, 512), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
440*da0073e9SAndroid Build Coastguard Worker        )
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker        torch_kwargs, _ = op.sample_kwargs(device, dtype, input)
443*da0073e9SAndroid Build Coastguard Worker        actual = op(input, **torch_kwargs)
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker        all_outs = [op(slice, **torch_kwargs) for slice in input]
446*da0073e9SAndroid Build Coastguard Worker        if is_iterable_of_tensors(actual):
447*da0073e9SAndroid Build Coastguard Worker            expected = [torch.stack([out[i] for out in all_outs]) for i in range(len(actual))]
448*da0073e9SAndroid Build Coastguard Worker        else:
449*da0073e9SAndroid Build Coastguard Worker            expected = torch.stack(all_outs)
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected)
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
454*da0073e9SAndroid Build Coastguard Worker    def test_nan_to_num(self, device, dtype):
455*da0073e9SAndroid Build Coastguard Worker        for contiguous in [False, True]:
456*da0073e9SAndroid Build Coastguard Worker            x = make_tensor((64, 64), low=0.0, high=100.0, dtype=dtype, device=device)
457*da0073e9SAndroid Build Coastguard Worker
458*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point:
459*da0073e9SAndroid Build Coastguard Worker                # Add extremal values.
460*da0073e9SAndroid Build Coastguard Worker                extremals = [float("nan"), float("inf"), -float("inf")]
461*da0073e9SAndroid Build Coastguard Worker                for idx, extremal in zip(torch.randint(0, 63, (3,)), extremals):
462*da0073e9SAndroid Build Coastguard Worker                    x[idx, :] = extremal
463*da0073e9SAndroid Build Coastguard Worker
464*da0073e9SAndroid Build Coastguard Worker            if not contiguous:
465*da0073e9SAndroid Build Coastguard Worker                x = x.T
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Worker            # With args
468*da0073e9SAndroid Build Coastguard Worker            nan = random.random()
469*da0073e9SAndroid Build Coastguard Worker            posinf = random.random() * 5
470*da0073e9SAndroid Build Coastguard Worker            neginf = random.random() * 10
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(
473*da0073e9SAndroid Build Coastguard Worker                lambda x: x.nan_to_num(nan=nan, posinf=posinf),
474*da0073e9SAndroid Build Coastguard Worker                lambda x: np.nan_to_num(x, nan=nan, posinf=posinf),
475*da0073e9SAndroid Build Coastguard Worker                x,
476*da0073e9SAndroid Build Coastguard Worker            )
477*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(
478*da0073e9SAndroid Build Coastguard Worker                lambda x: x.nan_to_num(posinf=posinf, neginf=neginf),
479*da0073e9SAndroid Build Coastguard Worker                lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf),
480*da0073e9SAndroid Build Coastguard Worker                x,
481*da0073e9SAndroid Build Coastguard Worker            )
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker            # Out Variant
484*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(x)
485*da0073e9SAndroid Build Coastguard Worker            result = torch.nan_to_num(x)
486*da0073e9SAndroid Build Coastguard Worker            torch.nan_to_num(x, out=out)
487*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, out)
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker            result = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
490*da0073e9SAndroid Build Coastguard Worker            torch.nan_to_num(x, out=out, nan=nan, posinf=posinf, neginf=neginf)
491*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, out)
492*da0073e9SAndroid Build Coastguard Worker
493*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
494*da0073e9SAndroid Build Coastguard Worker    def test_nan_to_num_bfloat16(self, device):
495*da0073e9SAndroid Build Coastguard Worker        def test_dtype(fn, input, dtype):
496*da0073e9SAndroid Build Coastguard Worker            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
497*da0073e9SAndroid Build Coastguard Worker            input2 = input.detach().clone().float().requires_grad_(True)
498*da0073e9SAndroid Build Coastguard Worker            out = fn(input)
499*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
500*da0073e9SAndroid Build Coastguard Worker            out2 = fn(input2)
501*da0073e9SAndroid Build Coastguard Worker            out2.sum().backward()
502*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.dtype, dtype)
503*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad.dtype, dtype)
504*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, out2, exact_dtype=False)
505*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad, input2.grad, exact_dtype=False)
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker        def func():
508*da0073e9SAndroid Build Coastguard Worker            return torch.nan_to_num
509*da0073e9SAndroid Build Coastguard Worker
510*da0073e9SAndroid Build Coastguard Worker        shapes = [[1, 3, 6, 6], [1, 3, 6, 128], [1, 3, 256, 256]]
511*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
512*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, device=device)
513*da0073e9SAndroid Build Coastguard Worker            extremals = [float('nan'), float('inf'), -float('inf')]
514*da0073e9SAndroid Build Coastguard Worker            for id1, id2, extremal in zip(torch.randint(0, 2, (3,)), torch.randint(0, 5, (3,)), extremals):
515*da0073e9SAndroid Build Coastguard Worker                x[0, id1, id2, :] = extremal
516*da0073e9SAndroid Build Coastguard Worker            test_dtype(func(), x, torch.bfloat16)
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.complex64, torch.complex128)
519*da0073e9SAndroid Build Coastguard Worker    def test_nan_to_num_complex(self, device, dtype):
520*da0073e9SAndroid Build Coastguard Worker        value_dtype = torch.tensor([], dtype=dtype).real.dtype
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker        def gen_tensor(a):
523*da0073e9SAndroid Build Coastguard Worker            return torch.view_as_complex(torch.tensor(a, dtype=value_dtype, device=device))
524*da0073e9SAndroid Build Coastguard Worker
525*da0073e9SAndroid Build Coastguard Worker        for extremal, kwarg_name in zip(['nan', 'inf', '-inf'], ['nan', 'posinf', 'neginf']):
526*da0073e9SAndroid Build Coastguard Worker            a = gen_tensor([123, float(extremal)])
527*da0073e9SAndroid Build Coastguard Worker            res = torch.nan_to_num(a, **{kwarg_name: 12})
528*da0073e9SAndroid Build Coastguard Worker            res_check = gen_tensor([123, 12])
529*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res_check)
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker            a = gen_tensor([float(extremal), 456])
532*da0073e9SAndroid Build Coastguard Worker            res = torch.nan_to_num(a, **{kwarg_name: 21})
533*da0073e9SAndroid Build Coastguard Worker            res_check = gen_tensor([21, 456])
534*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res_check)
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.cdouble)
537*da0073e9SAndroid Build Coastguard Worker    def test_complex_edge_values(self, device, dtype):
538*da0073e9SAndroid Build Coastguard Worker        # sqrt Test Reference: https://github.com/pytorch/pytorch/pull/47424
539*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(0.0 - 1.0e20j, dtype=dtype, device=device)
540*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torch.sqrt, np.sqrt, x)
541*da0073e9SAndroid Build Coastguard Worker        # acos test reference: https://github.com/pytorch/pytorch/issue/42952
542*da0073e9SAndroid Build Coastguard Worker        # Skip on Windows, as CUDA acos  returns conjugate value
543*da0073e9SAndroid Build Coastguard Worker        # see https://github.com/pytorch/pytorch/issues/52299
544*da0073e9SAndroid Build Coastguard Worker        if not (IS_WINDOWS and dtype == torch.cdouble and "cuda" in device):
545*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(torch.acos, np.arccos, x)
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(
548*da0073e9SAndroid Build Coastguard Worker            (-1.0e60 if dtype == torch.cdouble else -1.0e20) - 4988429.2j,
549*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
550*da0073e9SAndroid Build Coastguard Worker            device=device,
551*da0073e9SAndroid Build Coastguard Worker        )
552*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torch.sqrt, np.sqrt, x)
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
555*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
556*da0073e9SAndroid Build Coastguard Worker    def test_digamma_special(self, device, dtype):
557*da0073e9SAndroid Build Coastguard Worker        # Based on SciPy test for the following special values.
558*da0073e9SAndroid Build Coastguard Worker        # Reference:
559*da0073e9SAndroid Build Coastguard Worker        # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22
560*da0073e9SAndroid Build Coastguard Worker        euler = 0.57721566490153286
561*da0073e9SAndroid Build Coastguard Worker        dataset = [
562*da0073e9SAndroid Build Coastguard Worker            (0.0, -0.0),
563*da0073e9SAndroid Build Coastguard Worker            (1, -euler),
564*da0073e9SAndroid Build Coastguard Worker            (0.5, -2 * math.log(2) - euler),
565*da0073e9SAndroid Build Coastguard Worker            (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler),
566*da0073e9SAndroid Build Coastguard Worker            (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler),
567*da0073e9SAndroid Build Coastguard Worker            (
568*da0073e9SAndroid Build Coastguard Worker                1 / 6,
569*da0073e9SAndroid Build Coastguard Worker                -math.pi * math.sqrt(3) / 2
570*da0073e9SAndroid Build Coastguard Worker                - 2 * math.log(2)
571*da0073e9SAndroid Build Coastguard Worker                - 3 * math.log(3) / 2
572*da0073e9SAndroid Build Coastguard Worker                - euler,
573*da0073e9SAndroid Build Coastguard Worker            ),
574*da0073e9SAndroid Build Coastguard Worker            (
575*da0073e9SAndroid Build Coastguard Worker                1 / 8,
576*da0073e9SAndroid Build Coastguard Worker                -math.pi / 2
577*da0073e9SAndroid Build Coastguard Worker                - 4 * math.log(2)
578*da0073e9SAndroid Build Coastguard Worker                - (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2)))
579*da0073e9SAndroid Build Coastguard Worker                / math.sqrt(2)
580*da0073e9SAndroid Build Coastguard Worker                - euler,
581*da0073e9SAndroid Build Coastguard Worker            ),
582*da0073e9SAndroid Build Coastguard Worker        ]
583*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(dataset, device=device, dtype=dtype)
584*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torch.digamma, scipy.special.digamma, x)
585*da0073e9SAndroid Build Coastguard Worker
586*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
587*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
588*da0073e9SAndroid Build Coastguard Worker    def test_digamma(self, device, dtype):
589*da0073e9SAndroid Build Coastguard Worker        # Tests pole behavior
590*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor(
591*da0073e9SAndroid Build Coastguard Worker            [
592*da0073e9SAndroid Build Coastguard Worker                -0.999999994,
593*da0073e9SAndroid Build Coastguard Worker                -1.999999994,
594*da0073e9SAndroid Build Coastguard Worker                -2.0000000111,
595*da0073e9SAndroid Build Coastguard Worker                -100.99999994,
596*da0073e9SAndroid Build Coastguard Worker                0.000000111,
597*da0073e9SAndroid Build Coastguard Worker                -1931.99999994,
598*da0073e9SAndroid Build Coastguard Worker                -0.000000111,
599*da0073e9SAndroid Build Coastguard Worker                0,
600*da0073e9SAndroid Build Coastguard Worker                -0,
601*da0073e9SAndroid Build Coastguard Worker                -1,
602*da0073e9SAndroid Build Coastguard Worker                -2,
603*da0073e9SAndroid Build Coastguard Worker                -931,
604*da0073e9SAndroid Build Coastguard Worker            ],
605*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
606*da0073e9SAndroid Build Coastguard Worker            device=device,
607*da0073e9SAndroid Build Coastguard Worker        )
608*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor)
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_types_and(torch.half))
611*da0073e9SAndroid Build Coastguard Worker    def test_frexp(self, device, dtype):
612*da0073e9SAndroid Build Coastguard Worker        input = make_tensor((50, 50), dtype=dtype, device=device)
613*da0073e9SAndroid Build Coastguard Worker        mantissa, exponent = torch.frexp(input)
614*da0073e9SAndroid Build Coastguard Worker        np_mantissa, np_exponent = np.frexp(input.cpu().numpy())
615*da0073e9SAndroid Build Coastguard Worker
616*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mantissa, np_mantissa)
617*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(exponent, np_exponent)
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker        # torch.frexp returns exponent in int32 to be compatible with np.frexp
620*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(exponent.dtype == torch.int32)
621*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch_to_numpy_dtype_dict[exponent.dtype] == np_exponent.dtype)
622*da0073e9SAndroid Build Coastguard Worker
623*da0073e9SAndroid Build Coastguard Worker    def test_frexp_assert_raises(self, device):
624*da0073e9SAndroid Build Coastguard Worker        invalid_input_dtypes = integral_types_and(torch.bool) + complex_types()
625*da0073e9SAndroid Build Coastguard Worker        for dtype in invalid_input_dtypes:
626*da0073e9SAndroid Build Coastguard Worker            input = make_tensor((50, 50), dtype=dtype, device=device)
627*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
628*da0073e9SAndroid Build Coastguard Worker                RuntimeError, r"torch\.frexp\(\) only supports floating-point dtypes"
629*da0073e9SAndroid Build Coastguard Worker            ):
630*da0073e9SAndroid Build Coastguard Worker                torch.frexp(input)
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker        for dtype in floating_types_and(torch.half):
633*da0073e9SAndroid Build Coastguard Worker            input = make_tensor((50, 50), dtype=dtype, device=device)
634*da0073e9SAndroid Build Coastguard Worker
635*da0073e9SAndroid Build Coastguard Worker            dtypes = list(
636*da0073e9SAndroid Build Coastguard Worker                all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)
637*da0073e9SAndroid Build Coastguard Worker            )
638*da0073e9SAndroid Build Coastguard Worker            dtypes.remove(dtype)
639*da0073e9SAndroid Build Coastguard Worker            for mantissa_dtype in dtypes:
640*da0073e9SAndroid Build Coastguard Worker                mantissa = torch.empty_like(input, dtype=mantissa_dtype)
641*da0073e9SAndroid Build Coastguard Worker                exponent = torch.empty_like(input, dtype=torch.int)
642*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
643*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
644*da0073e9SAndroid Build Coastguard Worker                    r"torch\.frexp\(\) expects mantissa to have dtype .+ but got .+",
645*da0073e9SAndroid Build Coastguard Worker                ):
646*da0073e9SAndroid Build Coastguard Worker                    torch.frexp(input, out=(mantissa, exponent))
647*da0073e9SAndroid Build Coastguard Worker
648*da0073e9SAndroid Build Coastguard Worker            dtypes.append(dtype)
649*da0073e9SAndroid Build Coastguard Worker            dtypes.remove(torch.int)
650*da0073e9SAndroid Build Coastguard Worker            for exponent_dtype in dtypes:
651*da0073e9SAndroid Build Coastguard Worker                mantissa = torch.empty_like(input)
652*da0073e9SAndroid Build Coastguard Worker                exponent = torch.empty_like(input, dtype=exponent_dtype)
653*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
654*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
655*da0073e9SAndroid Build Coastguard Worker                    r"torch\.frexp\(\) expects exponent to have int dtype but got .+",
656*da0073e9SAndroid Build Coastguard Worker                ):
657*da0073e9SAndroid Build Coastguard Worker                    torch.frexp(input, out=(mantissa, exponent))
658*da0073e9SAndroid Build Coastguard Worker
659*da0073e9SAndroid Build Coastguard Worker    def test_polygamma_neg(self, device):
660*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
661*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"polygamma\(n, x\) does not support negative n\."
662*da0073e9SAndroid Build Coastguard Worker        ):
663*da0073e9SAndroid Build Coastguard Worker            torch.polygamma(-1, torch.tensor([1.0, 2.0], device=device))
664*da0073e9SAndroid Build Coastguard Worker
665*da0073e9SAndroid Build Coastguard Worker    # TODO resolve with opinfos
666*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
667*da0073e9SAndroid Build Coastguard Worker    def test_op_invert(self, device):
668*da0073e9SAndroid Build Coastguard Worker        res = 0xFFFF - torch.arange(127, dtype=torch.int8)
669*da0073e9SAndroid Build Coastguard Worker        for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
670*da0073e9SAndroid Build Coastguard Worker            a = torch.arange(127, dtype=dtype)
671*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.to(dtype), ~a)
672*da0073e9SAndroid Build Coastguard Worker
673*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([True, False]), ~torch.tensor([False, True]))
674*da0073e9SAndroid Build Coastguard Worker
675*da0073e9SAndroid Build Coastguard Worker        # test exceptions
676*da0073e9SAndroid Build Coastguard Worker        for dtype in (torch.half, torch.float, torch.double):
677*da0073e9SAndroid Build Coastguard Worker            a = torch.zeros(10, dtype=dtype)
678*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(TypeError):
679*da0073e9SAndroid Build Coastguard Worker                b = ~a
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.complex64, torch.complex128)
682*da0073e9SAndroid Build Coastguard Worker    def test_abs_angle_complex_to_float(self, device, dtype):
683*da0073e9SAndroid Build Coastguard Worker        # Constructs random complex values
684*da0073e9SAndroid Build Coastguard Worker        from random import random
685*da0073e9SAndroid Build Coastguard Worker
686*da0073e9SAndroid Build Coastguard Worker        random_vals = []
687*da0073e9SAndroid Build Coastguard Worker        for multiplier in (-1, 1, -10, 10, -100, 100):
688*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
689*da0073e9SAndroid Build Coastguard Worker                random_vals.append(
690*da0073e9SAndroid Build Coastguard Worker                    complex(random() * multiplier, random() * multiplier)
691*da0073e9SAndroid Build Coastguard Worker                )
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker        for vals in (random_vals, []):
694*da0073e9SAndroid Build Coastguard Worker            a = np.array(vals, dtype=torch_to_numpy_dtype_dict[dtype])
695*da0073e9SAndroid Build Coastguard Worker            t = torch.tensor(vals, device=device, dtype=dtype)
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker            for fn_name in ("abs", "angle"):
698*da0073e9SAndroid Build Coastguard Worker                torch_fn = getattr(torch, fn_name)
699*da0073e9SAndroid Build Coastguard Worker                np_fn = getattr(np, fn_name)
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker                # Tests function
702*da0073e9SAndroid Build Coastguard Worker                np_result = torch.from_numpy(np_fn(a))
703*da0073e9SAndroid Build Coastguard Worker                torch_result = torch_fn(t).cpu()
704*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(np_result, torch_result, exact_dtype=True)
705*da0073e9SAndroid Build Coastguard Worker
706*da0073e9SAndroid Build Coastguard Worker                # Tests float out
707*da0073e9SAndroid Build Coastguard Worker                float_dtype = (
708*da0073e9SAndroid Build Coastguard Worker                    torch.float32 if dtype is torch.complex64 else torch.float64
709*da0073e9SAndroid Build Coastguard Worker                )
710*da0073e9SAndroid Build Coastguard Worker                np_float_out = np_fn(a).astype(torch_to_numpy_dtype_dict[float_dtype])
711*da0073e9SAndroid Build Coastguard Worker                float_out = torch.empty_like(t, dtype=float_dtype)
712*da0073e9SAndroid Build Coastguard Worker                torch_fn(t, out=float_out)
713*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu())
714*da0073e9SAndroid Build Coastguard Worker
715*da0073e9SAndroid Build Coastguard Worker                # Tests float out (resized out)
716*da0073e9SAndroid Build Coastguard Worker                float_out = torch.empty(1, device=device, dtype=float_dtype)
717*da0073e9SAndroid Build Coastguard Worker                torch_fn(t, out=float_out)
718*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu())
719*da0073e9SAndroid Build Coastguard Worker
720*da0073e9SAndroid Build Coastguard Worker                # Tests complex out
721*da0073e9SAndroid Build Coastguard Worker                np_complex_out = np_fn(a).astype(torch_to_numpy_dtype_dict[dtype])
722*da0073e9SAndroid Build Coastguard Worker                complex_out = torch.empty_like(t)
723*da0073e9SAndroid Build Coastguard Worker                torch_fn(t, out=complex_out)
724*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu())
725*da0073e9SAndroid Build Coastguard Worker
726*da0073e9SAndroid Build Coastguard Worker                # Tests complex out (resized out)
727*da0073e9SAndroid Build Coastguard Worker                complex_out = torch.empty(0, device=device, dtype=dtype)
728*da0073e9SAndroid Build Coastguard Worker                torch_fn(t, out=complex_out)
729*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu())
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker                # Tests long out behavior (expected failure)
732*da0073e9SAndroid Build Coastguard Worker                long_out = torch.empty(0, device=device, dtype=torch.long)
733*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(RuntimeError):
734*da0073e9SAndroid Build Coastguard Worker                    torch_fn(t, out=long_out)
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker                # Tests inplace
737*da0073e9SAndroid Build Coastguard Worker                if fn_name == "abs":
738*da0073e9SAndroid Build Coastguard Worker                    torch_inplace_method = getattr(torch.Tensor, fn_name + "_")
739*da0073e9SAndroid Build Coastguard Worker                    np_fn(a, out=a)
740*da0073e9SAndroid Build Coastguard Worker                    if dtype.is_complex:
741*da0073e9SAndroid Build Coastguard Worker                        with self.assertRaisesRegex(
742*da0073e9SAndroid Build Coastguard Worker                            RuntimeError,
743*da0073e9SAndroid Build Coastguard Worker                            "In-place abs is not supported for complex tensors.",
744*da0073e9SAndroid Build Coastguard Worker                        ):
745*da0073e9SAndroid Build Coastguard Worker                            torch_inplace_method(t)
746*da0073e9SAndroid Build Coastguard Worker                        return
747*da0073e9SAndroid Build Coastguard Worker                    torch_inplace_method(t)
748*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(torch.from_numpy(a), t.cpu())
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Worker                # Note: angle does not have an in-place variant
751*da0073e9SAndroid Build Coastguard Worker                if fn_name == "angle":
752*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(AttributeError):
753*da0073e9SAndroid Build Coastguard Worker                        torch_inplace_method = getattr(torch.Tensor, fn_name + "_")
754*da0073e9SAndroid Build Coastguard Worker
755*da0073e9SAndroid Build Coastguard Worker    def check_internal_mem_overlap(
756*da0073e9SAndroid Build Coastguard Worker        self, inplace_op, num_inputs, dtype, device, expected_failure=False
757*da0073e9SAndroid Build Coastguard Worker    ):
758*da0073e9SAndroid Build Coastguard Worker        if isinstance(inplace_op, str):
759*da0073e9SAndroid Build Coastguard Worker            inplace_op = getattr(torch.Tensor, inplace_op)
760*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
761*da0073e9SAndroid Build Coastguard Worker        inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)]
762*da0073e9SAndroid Build Coastguard Worker        if not expected_failure:
763*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "single memory location"):
764*da0073e9SAndroid Build Coastguard Worker                inplace_op(*inputs)
765*da0073e9SAndroid Build Coastguard Worker        else:
766*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(AssertionError):
767*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "single memory location"):
768*da0073e9SAndroid Build Coastguard Worker                    inplace_op(*inputs)
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker    def unary_check_input_output_mem_overlap(
771*da0073e9SAndroid Build Coastguard Worker        self, data, sz, op, expected_failure=False
772*da0073e9SAndroid Build Coastguard Worker    ):
773*da0073e9SAndroid Build Coastguard Worker        def _test(op, output, input):
774*da0073e9SAndroid Build Coastguard Worker            output_exp = torch.empty_like(output)
775*da0073e9SAndroid Build Coastguard Worker            op(input, out=output_exp)
776*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker        # output is identical to input:
779*da0073e9SAndroid Build Coastguard Worker        _test(op, output=data[0:sz], input=data[0:sz])
780*da0073e9SAndroid Build Coastguard Worker        # output and input are independent:
781*da0073e9SAndroid Build Coastguard Worker        _test(op, output=data[0:sz], input=data[sz : 2 * sz])
782*da0073e9SAndroid Build Coastguard Worker        # output partially overlaps with input:
783*da0073e9SAndroid Build Coastguard Worker        if not expected_failure:
784*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
785*da0073e9SAndroid Build Coastguard Worker                _test(op, data[0:sz], data[1 : sz + 1])
786*da0073e9SAndroid Build Coastguard Worker        else:
787*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(AssertionError):
788*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
789*da0073e9SAndroid Build Coastguard Worker                    _test(op, data[0:sz], data[1 : sz + 1])
790*da0073e9SAndroid Build Coastguard Worker
791*da0073e9SAndroid Build Coastguard Worker    # TODO: run on non-native device types
792*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/126474
793*da0073e9SAndroid Build Coastguard Worker    @xfailIfTorchDynamo
794*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
795*da0073e9SAndroid Build Coastguard Worker    def test_unary_out_op_mem_overlap(self, device, dtype):
796*da0073e9SAndroid Build Coastguard Worker        sz = 3
797*da0073e9SAndroid Build Coastguard Worker        doubles = torch.randn(2 * sz, dtype=dtype, device=device)
798*da0073e9SAndroid Build Coastguard Worker        positives = torch.randint(1, 100, (2 * sz,), device=device).double()
799*da0073e9SAndroid Build Coastguard Worker        ints = torch.randint(-100, 100, (2 * sz,), device=device)
800*da0073e9SAndroid Build Coastguard Worker        unary_mem_overlap_cases = [
801*da0073e9SAndroid Build Coastguard Worker            ("abs", doubles, True, True, "cpu"),
802*da0073e9SAndroid Build Coastguard Worker            ("abs", doubles, True, True, "cuda"),
803*da0073e9SAndroid Build Coastguard Worker            ("acos", doubles, True, True, "cpu"),
804*da0073e9SAndroid Build Coastguard Worker            ("acos", doubles, True, True, "cuda"),
805*da0073e9SAndroid Build Coastguard Worker            ("asin", doubles, True, True, "cpu"),
806*da0073e9SAndroid Build Coastguard Worker            ("asin", doubles, True, True, "cuda"),
807*da0073e9SAndroid Build Coastguard Worker            ("atan", doubles, True, True, "cpu"),
808*da0073e9SAndroid Build Coastguard Worker            ("atan", doubles, True, True, "cuda"),
809*da0073e9SAndroid Build Coastguard Worker            ("acosh", doubles, True, True, "cpu"),
810*da0073e9SAndroid Build Coastguard Worker            ("acosh", doubles, True, True, "cuda"),
811*da0073e9SAndroid Build Coastguard Worker            ("asinh", doubles, True, True, "cpu"),
812*da0073e9SAndroid Build Coastguard Worker            ("asinh", doubles, True, True, "cuda"),
813*da0073e9SAndroid Build Coastguard Worker            ("atanh", doubles, True, True, "cpu"),
814*da0073e9SAndroid Build Coastguard Worker            ("atanh", doubles, True, True, "cuda"),
815*da0073e9SAndroid Build Coastguard Worker            ("bitwise_not", ints, True, True, "cpu"),
816*da0073e9SAndroid Build Coastguard Worker            ("bitwise_not", ints, True, True, "cuda"),
817*da0073e9SAndroid Build Coastguard Worker            ("ceil", doubles, True, True, "cpu"),
818*da0073e9SAndroid Build Coastguard Worker            ("ceil", doubles, True, True, "cuda"),
819*da0073e9SAndroid Build Coastguard Worker            ("cos", doubles, True, True, "cpu"),
820*da0073e9SAndroid Build Coastguard Worker            ("cos", doubles, True, True, "cuda"),
821*da0073e9SAndroid Build Coastguard Worker            ("cosh", doubles, True, True, "cpu"),
822*da0073e9SAndroid Build Coastguard Worker            ("cosh", doubles, True, True, "cuda"),
823*da0073e9SAndroid Build Coastguard Worker            ("digamma", doubles, True, True, "cpu"),
824*da0073e9SAndroid Build Coastguard Worker            ("erf", doubles, True, True, "cpu"),
825*da0073e9SAndroid Build Coastguard Worker            ("erf", doubles, True, True, "cuda"),
826*da0073e9SAndroid Build Coastguard Worker            ("erfc", doubles, True, True, "cpu"),
827*da0073e9SAndroid Build Coastguard Worker            ("erfc", doubles, True, True, "cuda"),
828*da0073e9SAndroid Build Coastguard Worker            ("erfinv", doubles, True, True, "cpu"),
829*da0073e9SAndroid Build Coastguard Worker            ("erfinv", doubles, True, True, "cuda"),
830*da0073e9SAndroid Build Coastguard Worker            ("exp", doubles, True, True, "cpu"),
831*da0073e9SAndroid Build Coastguard Worker            ("exp", doubles, True, True, "cuda"),
832*da0073e9SAndroid Build Coastguard Worker            ("exp2", doubles, True, True, "cpu"),
833*da0073e9SAndroid Build Coastguard Worker            ("exp2", doubles, True, True, "cuda"),
834*da0073e9SAndroid Build Coastguard Worker            ("expm1", doubles, True, True, "cpu"),
835*da0073e9SAndroid Build Coastguard Worker            ("expm1", doubles, True, True, "cuda"),
836*da0073e9SAndroid Build Coastguard Worker            ("floor", doubles, True, True, "cpu"),
837*da0073e9SAndroid Build Coastguard Worker            ("floor", doubles, True, True, "cuda"),
838*da0073e9SAndroid Build Coastguard Worker            ("frac", doubles, True, True, "cpu"),
839*da0073e9SAndroid Build Coastguard Worker            ("frac", doubles, True, True, "cuda"),
840*da0073e9SAndroid Build Coastguard Worker            ("i0", doubles, True, True, "cpu"),
841*da0073e9SAndroid Build Coastguard Worker            ("i0", doubles, True, True, "cuda"),
842*da0073e9SAndroid Build Coastguard Worker            ("log", positives, True, True, "cpu"),
843*da0073e9SAndroid Build Coastguard Worker            ("log", positives, True, True, "cuda"),
844*da0073e9SAndroid Build Coastguard Worker            ("log10", positives, True, True, "cpu"),
845*da0073e9SAndroid Build Coastguard Worker            ("log10", positives, True, True, "cuda"),
846*da0073e9SAndroid Build Coastguard Worker            ("log1p", positives, True, True, "cpu"),
847*da0073e9SAndroid Build Coastguard Worker            ("log1p", positives, True, True, "cuda"),
848*da0073e9SAndroid Build Coastguard Worker            ("log2", positives, True, True, "cpu"),
849*da0073e9SAndroid Build Coastguard Worker            ("log2", positives, True, True, "cuda"),
850*da0073e9SAndroid Build Coastguard Worker            ("neg", doubles, True, True, "cpu"),
851*da0073e9SAndroid Build Coastguard Worker            ("neg", doubles, True, True, "cuda"),
852*da0073e9SAndroid Build Coastguard Worker            ("reciprocal", doubles, True, True, "cpu"),
853*da0073e9SAndroid Build Coastguard Worker            ("reciprocal", doubles, True, True, "cuda"),
854*da0073e9SAndroid Build Coastguard Worker            ("round", doubles, True, True, "cpu"),
855*da0073e9SAndroid Build Coastguard Worker            ("round", doubles, True, True, "cuda"),
856*da0073e9SAndroid Build Coastguard Worker            ("rsqrt", positives, True, True, "cpu"),
857*da0073e9SAndroid Build Coastguard Worker            ("rsqrt", positives, True, True, "cuda"),
858*da0073e9SAndroid Build Coastguard Worker            ("sin", doubles, True, True, "cpu"),
859*da0073e9SAndroid Build Coastguard Worker            ("sin", doubles, True, True, "cuda"),
860*da0073e9SAndroid Build Coastguard Worker            ("sinh", doubles, True, True, "cpu"),
861*da0073e9SAndroid Build Coastguard Worker            ("sinh", doubles, False, True, "cuda"),
862*da0073e9SAndroid Build Coastguard Worker            ("sigmoid", doubles, True, True, "cpu"),
863*da0073e9SAndroid Build Coastguard Worker            ("sigmoid", doubles, True, True, "cuda"),
864*da0073e9SAndroid Build Coastguard Worker            ("logit", doubles, True, True, "cpu"),
865*da0073e9SAndroid Build Coastguard Worker            ("logit", doubles, True, True, "cuda"),
866*da0073e9SAndroid Build Coastguard Worker            ("sqrt", doubles, True, True, "cpu"),
867*da0073e9SAndroid Build Coastguard Worker            ("sqrt", doubles, False, True, "cuda"),
868*da0073e9SAndroid Build Coastguard Worker            ("tan", doubles, True, True, "cpu"),
869*da0073e9SAndroid Build Coastguard Worker            ("tan", doubles, True, True, "cuda"),
870*da0073e9SAndroid Build Coastguard Worker            ("tanh", doubles, True, True, "cpu"),
871*da0073e9SAndroid Build Coastguard Worker            ("tanh", doubles, True, True, "cuda"),
872*da0073e9SAndroid Build Coastguard Worker            ("trunc", doubles, True, True, "cpu"),
873*da0073e9SAndroid Build Coastguard Worker            ("trunc", doubles, True, True, "cuda"),
874*da0073e9SAndroid Build Coastguard Worker        ]
875*da0073e9SAndroid Build Coastguard Worker
876*da0073e9SAndroid Build Coastguard Worker        for (
877*da0073e9SAndroid Build Coastguard Worker            fn,
878*da0073e9SAndroid Build Coastguard Worker            inputs,
879*da0073e9SAndroid Build Coastguard Worker            has_input_output_mem_overlap_check,
880*da0073e9SAndroid Build Coastguard Worker            has_internal_mem_overlap_check,
881*da0073e9SAndroid Build Coastguard Worker            dev,
882*da0073e9SAndroid Build Coastguard Worker        ) in unary_mem_overlap_cases:
883*da0073e9SAndroid Build Coastguard Worker            if dev != device:
884*da0073e9SAndroid Build Coastguard Worker                continue
885*da0073e9SAndroid Build Coastguard Worker            out_fn = getattr(torch, fn)
886*da0073e9SAndroid Build Coastguard Worker            in_fn = getattr(torch.Tensor, fn + "_")
887*da0073e9SAndroid Build Coastguard Worker
888*da0073e9SAndroid Build Coastguard Worker            self.unary_check_input_output_mem_overlap(
889*da0073e9SAndroid Build Coastguard Worker                inputs,
890*da0073e9SAndroid Build Coastguard Worker                sz,
891*da0073e9SAndroid Build Coastguard Worker                out_fn,
892*da0073e9SAndroid Build Coastguard Worker                expected_failure=not has_input_output_mem_overlap_check,
893*da0073e9SAndroid Build Coastguard Worker            )
894*da0073e9SAndroid Build Coastguard Worker
895*da0073e9SAndroid Build Coastguard Worker            self.check_internal_mem_overlap(
896*da0073e9SAndroid Build Coastguard Worker                in_fn,
897*da0073e9SAndroid Build Coastguard Worker                1,
898*da0073e9SAndroid Build Coastguard Worker                dtype,
899*da0073e9SAndroid Build Coastguard Worker                dev,
900*da0073e9SAndroid Build Coastguard Worker                expected_failure=not has_internal_mem_overlap_check,
901*da0073e9SAndroid Build Coastguard Worker            )
902*da0073e9SAndroid Build Coastguard Worker
903*da0073e9SAndroid Build Coastguard Worker    # TODO: opinfo hardshrink
904*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
905*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.bfloat16)
906*da0073e9SAndroid Build Coastguard Worker    def test_hardshrink(self, device, dtype):
907*da0073e9SAndroid Build Coastguard Worker        data = torch.tensor([1, 0.5, 0.3, 0.6], dtype=dtype, device=device).view(2, 2)
908*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
909*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 0.5, 0, 0.6], dtype=dtype, device=device).view(2, 2),
910*da0073e9SAndroid Build Coastguard Worker            data.hardshrink(0.3),
911*da0073e9SAndroid Build Coastguard Worker        )
912*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
913*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 0, 0, 0.6], dtype=dtype, device=device).view(2, 2),
914*da0073e9SAndroid Build Coastguard Worker            data.hardshrink(0.5),
915*da0073e9SAndroid Build Coastguard Worker        )
916*da0073e9SAndroid Build Coastguard Worker
917*da0073e9SAndroid Build Coastguard Worker        # test default lambd=0.5
918*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(data.hardshrink(), data.hardshrink(0.5))
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker        # test non-contiguous case
921*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
922*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 0, 0.5, 0.6], dtype=dtype, device=device).view(2, 2),
923*da0073e9SAndroid Build Coastguard Worker            data.t().hardshrink(0.3),
924*da0073e9SAndroid Build Coastguard Worker        )
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
927*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.bfloat16)
928*da0073e9SAndroid Build Coastguard Worker    def test_hardshrink_edge_cases(self, device, dtype) -> None:
929*da0073e9SAndroid Build Coastguard Worker        def h(values, l_expected):
930*da0073e9SAndroid Build Coastguard Worker            for l, expected in l_expected.items():
931*da0073e9SAndroid Build Coastguard Worker                values_tensor = torch.tensor(
932*da0073e9SAndroid Build Coastguard Worker                    [float(v) for v in values], dtype=dtype, device=device
933*da0073e9SAndroid Build Coastguard Worker                )
934*da0073e9SAndroid Build Coastguard Worker                expected_tensor = torch.tensor(
935*da0073e9SAndroid Build Coastguard Worker                    [float(v) for v in expected], dtype=dtype, device=device
936*da0073e9SAndroid Build Coastguard Worker                )
937*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
938*da0073e9SAndroid Build Coastguard Worker                    expected_tensor == values_tensor.hardshrink(l),
939*da0073e9SAndroid Build Coastguard Worker                    torch.ones_like(values_tensor, dtype=torch.bool),
940*da0073e9SAndroid Build Coastguard Worker                )
941*da0073e9SAndroid Build Coastguard Worker
942*da0073e9SAndroid Build Coastguard Worker        def test_helper(min, max):
943*da0073e9SAndroid Build Coastguard Worker            h(
944*da0073e9SAndroid Build Coastguard Worker                [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
945*da0073e9SAndroid Build Coastguard Worker                {
946*da0073e9SAndroid Build Coastguard Worker                    0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
947*da0073e9SAndroid Build Coastguard Worker                    min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
948*da0073e9SAndroid Build Coastguard Worker                    0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf],
949*da0073e9SAndroid Build Coastguard Worker                    1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf],
950*da0073e9SAndroid Build Coastguard Worker                    max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf],
951*da0073e9SAndroid Build Coastguard Worker                    inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
952*da0073e9SAndroid Build Coastguard Worker                },
953*da0073e9SAndroid Build Coastguard Worker            )
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker        test_helper(torch.finfo(dtype).tiny, torch.finfo(dtype).max)
956*da0073e9SAndroid Build Coastguard Worker
957*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
958*da0073e9SAndroid Build Coastguard Worker    @slowTest
959*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
960*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(True, "Insufficient memory on linux.(2|4)xlarge")
961*da0073e9SAndroid Build Coastguard Worker    def test_exp_slow(self, device, dtype):
962*da0073e9SAndroid Build Coastguard Worker        # Test for https://github.com/pytorch/pytorch/issues/17271
963*da0073e9SAndroid Build Coastguard Worker        # This is pretty slow on my Macbook but it only takes a few
964*da0073e9SAndroid Build Coastguard Worker        # seconds on a beefy Xeon server
965*da0073e9SAndroid Build Coastguard Worker        a = torch.exp(torch.ones(2**31, dtype=dtype, device=device))
966*da0073e9SAndroid Build Coastguard Worker        b = torch.exp(torch.ones(1, dtype=dtype, device=device))
967*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, b.expand(2**31))
968*da0073e9SAndroid Build Coastguard Worker
969*da0073e9SAndroid Build Coastguard Worker    @precisionOverride(
970*da0073e9SAndroid Build Coastguard Worker        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
971*da0073e9SAndroid Build Coastguard Worker    )
972*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.bfloat16)
973*da0073e9SAndroid Build Coastguard Worker    def test_hardswish(self, device, dtype):
974*da0073e9SAndroid Build Coastguard Worker        inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
975*da0073e9SAndroid Build Coastguard Worker        expectedOutput = np.multiply(
976*da0073e9SAndroid Build Coastguard Worker            inputValues, np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0
977*da0073e9SAndroid Build Coastguard Worker        )
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker        inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)
980*da0073e9SAndroid Build Coastguard Worker        expectedOutputTensor = torch.tensor(expectedOutput, dtype=dtype, device=device)
981*da0073e9SAndroid Build Coastguard Worker
982*da0073e9SAndroid Build Coastguard Worker        # normal
983*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
984*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.hardswish(inputTensor), expectedOutputTensor
985*da0073e9SAndroid Build Coastguard Worker        )
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Worker        # inplace
988*da0073e9SAndroid Build Coastguard Worker        inputTensorCpy = inputTensor.clone().detach()
989*da0073e9SAndroid Build Coastguard Worker        torch.nn.functional.hardswish(inputTensorCpy, inplace=True)
990*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inputTensorCpy, expectedOutputTensor)
991*da0073e9SAndroid Build Coastguard Worker
992*da0073e9SAndroid Build Coastguard Worker    @precisionOverride(
993*da0073e9SAndroid Build Coastguard Worker        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
994*da0073e9SAndroid Build Coastguard Worker    )
995*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.bfloat16)
996*da0073e9SAndroid Build Coastguard Worker    def test_hardsigmoid(self, device, dtype):
997*da0073e9SAndroid Build Coastguard Worker        inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
998*da0073e9SAndroid Build Coastguard Worker        expectedOutput = np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker        inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)
1001*da0073e9SAndroid Build Coastguard Worker
1002*da0073e9SAndroid Build Coastguard Worker        # normal
1003*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1004*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.hardsigmoid(inputTensor),
1005*da0073e9SAndroid Build Coastguard Worker            torch.tensor(expectedOutput, dtype=dtype, device=device),
1006*da0073e9SAndroid Build Coastguard Worker        )
1007*da0073e9SAndroid Build Coastguard Worker
1008*da0073e9SAndroid Build Coastguard Worker        # inplace
1009*da0073e9SAndroid Build Coastguard Worker        inputTensorCpy = inputTensor.clone().detach()
1010*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1011*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True),
1012*da0073e9SAndroid Build Coastguard Worker            torch.tensor(expectedOutput, dtype=dtype, device=device),
1013*da0073e9SAndroid Build Coastguard Worker        )
1014*da0073e9SAndroid Build Coastguard Worker
1015*da0073e9SAndroid Build Coastguard Worker    @precisionOverride(
1016*da0073e9SAndroid Build Coastguard Worker        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
1017*da0073e9SAndroid Build Coastguard Worker    )
1018*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.bfloat16)
1019*da0073e9SAndroid Build Coastguard Worker    def test_hardsigmoid_backward(self, device, dtype):
1020*da0073e9SAndroid Build Coastguard Worker        inputValues = [-3.0, 3.0, -2.0, 2.0, -6.0, 6.0]
1021*da0073e9SAndroid Build Coastguard Worker        expectedValues = [0.0, 0.0, 1.0 / 6.0, 1.0 / 6.0, 0.0, 0.0]
1022*da0073e9SAndroid Build Coastguard Worker        inputTensor = torch.tensor(
1023*da0073e9SAndroid Build Coastguard Worker            inputValues, dtype=dtype, device=device
1024*da0073e9SAndroid Build Coastguard Worker        ).requires_grad_()
1025*da0073e9SAndroid Build Coastguard Worker        expetedTensor = torch.tensor(expectedValues, dtype=dtype, device=device)
1026*da0073e9SAndroid Build Coastguard Worker        out = torch.nn.functional.hardsigmoid(inputTensor)
1027*da0073e9SAndroid Build Coastguard Worker        out.backward(torch.ones_like(inputTensor))
1028*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inputTensor.grad, expetedTensor)
1029*da0073e9SAndroid Build Coastguard Worker
1030*da0073e9SAndroid Build Coastguard Worker    @skipIfNoSciPy
1031*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1032*da0073e9SAndroid Build Coastguard Worker    def test_silu(self, device, dtype):
1033*da0073e9SAndroid Build Coastguard Worker        input_np = np.random.randn(5, 8)
1034*da0073e9SAndroid Build Coastguard Worker        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
1035*da0073e9SAndroid Build Coastguard Worker        input_np = np.concatenate((input_np, special_input), axis=0).astype(
1036*da0073e9SAndroid Build Coastguard Worker            torch_to_numpy_dtype_dict[dtype]
1037*da0073e9SAndroid Build Coastguard Worker        )
1038*da0073e9SAndroid Build Coastguard Worker        expected_output_np = input_np * scipy.special.expit(input_np)
1039*da0073e9SAndroid Build Coastguard Worker
1040*da0073e9SAndroid Build Coastguard Worker        expected_output = torch.from_numpy(expected_output_np).to(device)
1041*da0073e9SAndroid Build Coastguard Worker        expected_output_noncontig = expected_output.transpose(0, 1)
1042*da0073e9SAndroid Build Coastguard Worker
1043*da0073e9SAndroid Build Coastguard Worker        atol = 1e-6
1044*da0073e9SAndroid Build Coastguard Worker        rtol = 1e-6
1045*da0073e9SAndroid Build Coastguard Worker
1046*da0073e9SAndroid Build Coastguard Worker        input = torch.from_numpy(input_np).clone().contiguous().to(device)
1047*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1048*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.silu(input), expected_output, atol=atol, rtol=rtol
1049*da0073e9SAndroid Build Coastguard Worker        )
1050*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1051*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.silu(input, inplace=True),
1052*da0073e9SAndroid Build Coastguard Worker            expected_output,
1053*da0073e9SAndroid Build Coastguard Worker            atol=atol,
1054*da0073e9SAndroid Build Coastguard Worker            rtol=rtol,
1055*da0073e9SAndroid Build Coastguard Worker        )
1056*da0073e9SAndroid Build Coastguard Worker
1057*da0073e9SAndroid Build Coastguard Worker        input = torch.from_numpy(input_np).clone().to(device)
1058*da0073e9SAndroid Build Coastguard Worker        input_noncontig = input.transpose(0, 1)
1059*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1060*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.silu(input_noncontig),
1061*da0073e9SAndroid Build Coastguard Worker            expected_output_noncontig,
1062*da0073e9SAndroid Build Coastguard Worker            atol=atol,
1063*da0073e9SAndroid Build Coastguard Worker            rtol=rtol,
1064*da0073e9SAndroid Build Coastguard Worker        )
1065*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1066*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.silu(input_noncontig, inplace=True),
1067*da0073e9SAndroid Build Coastguard Worker            expected_output_noncontig,
1068*da0073e9SAndroid Build Coastguard Worker            atol=atol,
1069*da0073e9SAndroid Build Coastguard Worker            rtol=rtol,
1070*da0073e9SAndroid Build Coastguard Worker        )
1071*da0073e9SAndroid Build Coastguard Worker
1072*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.complex64, torch.complex128)
1073*da0073e9SAndroid Build Coastguard Worker    def test_silu_complex(self, device, dtype):
1074*da0073e9SAndroid Build Coastguard Worker        atol = 1e-6
1075*da0073e9SAndroid Build Coastguard Worker        rtol = 1e-6
1076*da0073e9SAndroid Build Coastguard Worker        inouts = [
1077*da0073e9SAndroid Build Coastguard Worker            (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
1078*da0073e9SAndroid Build Coastguard Worker            (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
1079*da0073e9SAndroid Build Coastguard Worker            (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
1080*da0073e9SAndroid Build Coastguard Worker            (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
1081*da0073e9SAndroid Build Coastguard Worker            (2.0j, -1.55740761756896972656 + 0.99999988079071044922j)
1082*da0073e9SAndroid Build Coastguard Worker        ]
1083*da0073e9SAndroid Build Coastguard Worker
1084*da0073e9SAndroid Build Coastguard Worker        for inp, out in inouts:
1085*da0073e9SAndroid Build Coastguard Worker            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device))
1086*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.any(torch.isnan(res)))
1087*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
1088*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
1089*da0073e9SAndroid Build Coastguard Worker
1090*da0073e9SAndroid Build Coastguard Worker        for inp, out in inouts:
1091*da0073e9SAndroid Build Coastguard Worker            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device), inplace=True)
1092*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.any(torch.isnan(res)))
1093*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
1094*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker    # It is not obvious how to merge this into OpInfo becuase these inputs
1097*da0073e9SAndroid Build Coastguard Worker    # succeed for gradcheck but are expected to fail for gradgradcheck
1098*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1099*da0073e9SAndroid Build Coastguard Worker    def test_sinc(self, device, dtype):
1100*da0073e9SAndroid Build Coastguard Worker        # The derivative of sinc(x) at x=0 has to be special cased.
1101*da0073e9SAndroid Build Coastguard Worker        # A naive computation will result in 0/0 -> NaN.
1102*da0073e9SAndroid Build Coastguard Worker        # We also need to be careful when we are very close to 0, as the
1103*da0073e9SAndroid Build Coastguard Worker        # derivative's denominator is squared, and there are some floats
1104*da0073e9SAndroid Build Coastguard Worker        # that are positive and whose squares are zero.
1105*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(
1106*da0073e9SAndroid Build Coastguard Worker            [0.0, torch.finfo(torch.double).tiny, 1.0],
1107*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
1108*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
1109*da0073e9SAndroid Build Coastguard Worker            device=device,
1110*da0073e9SAndroid Build Coastguard Worker        )
1111*da0073e9SAndroid Build Coastguard Worker        gradcheck(torch.sinc, a)
1112*da0073e9SAndroid Build Coastguard Worker
1113*da0073e9SAndroid Build Coastguard Worker    @skipIfNoSciPy
1114*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1115*da0073e9SAndroid Build Coastguard Worker    def test_mish(self, device, dtype):
1116*da0073e9SAndroid Build Coastguard Worker        input_np = np.random.randn(5, 8)
1117*da0073e9SAndroid Build Coastguard Worker        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
1118*da0073e9SAndroid Build Coastguard Worker        input_np = np.concatenate((input_np, special_input), axis=0).astype(
1119*da0073e9SAndroid Build Coastguard Worker            torch_to_numpy_dtype_dict[dtype]
1120*da0073e9SAndroid Build Coastguard Worker        )
1121*da0073e9SAndroid Build Coastguard Worker        expected_output_np = input_np * np.tanh(np.log1p(np.exp(input_np)))
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Worker        expected_output = torch.from_numpy(expected_output_np).to(device)
1124*da0073e9SAndroid Build Coastguard Worker        expected_output_noncontig = expected_output.transpose(0, 1)
1125*da0073e9SAndroid Build Coastguard Worker
1126*da0073e9SAndroid Build Coastguard Worker        atol = 1e-6
1127*da0073e9SAndroid Build Coastguard Worker        rtol = 1e-6
1128*da0073e9SAndroid Build Coastguard Worker
1129*da0073e9SAndroid Build Coastguard Worker        input = torch.from_numpy(input_np).clone().contiguous().to(device)
1130*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1131*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.mish(input), expected_output, atol=atol, rtol=rtol
1132*da0073e9SAndroid Build Coastguard Worker        )
1133*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1134*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.mish(input, inplace=True),
1135*da0073e9SAndroid Build Coastguard Worker            expected_output,
1136*da0073e9SAndroid Build Coastguard Worker            atol=atol,
1137*da0073e9SAndroid Build Coastguard Worker            rtol=rtol,
1138*da0073e9SAndroid Build Coastguard Worker        )
1139*da0073e9SAndroid Build Coastguard Worker
1140*da0073e9SAndroid Build Coastguard Worker        input = torch.from_numpy(input_np).clone().to(device)
1141*da0073e9SAndroid Build Coastguard Worker        input_noncontig = input.transpose(0, 1)
1142*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1143*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.mish(input_noncontig),
1144*da0073e9SAndroid Build Coastguard Worker            expected_output_noncontig,
1145*da0073e9SAndroid Build Coastguard Worker            atol=atol,
1146*da0073e9SAndroid Build Coastguard Worker            rtol=rtol,
1147*da0073e9SAndroid Build Coastguard Worker        )
1148*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1149*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.mish(input_noncontig, inplace=True),
1150*da0073e9SAndroid Build Coastguard Worker            expected_output_noncontig,
1151*da0073e9SAndroid Build Coastguard Worker            atol=atol,
1152*da0073e9SAndroid Build Coastguard Worker            rtol=rtol,
1153*da0073e9SAndroid Build Coastguard Worker        )
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.complex64, torch.complex128)
1156*da0073e9SAndroid Build Coastguard Worker    def test_log1p_complex(self, device, dtype):
1157*da0073e9SAndroid Build Coastguard Worker        # The output values here were obtained using arbitrary precision math (mpmath)
1158*da0073e9SAndroid Build Coastguard Worker        # and double checked with WolframAlpha.
1159*da0073e9SAndroid Build Coastguard Worker        # Not using numpy's log1p here because by the time of writing this,
1160*da0073e9SAndroid Build Coastguard Worker        # np.log1p has precision problems for small complex input values, see here:
1161*da0073e9SAndroid Build Coastguard Worker        # https://github.com/numpy/numpy/issues/22609
1162*da0073e9SAndroid Build Coastguard Worker        inouts = [
1163*da0073e9SAndroid Build Coastguard Worker            (0.2 + 0.3j, 0.21263386770217202 + 0.24497866312686414j),
1164*da0073e9SAndroid Build Coastguard Worker            (1e-19 + 1e-18j, 1e-19 + 1e-18j),
1165*da0073e9SAndroid Build Coastguard Worker            (1e-18 + 0.1j, 0.00497517 + 0.0996687j),
1166*da0073e9SAndroid Build Coastguard Worker            (0.1 + 1e-18j, 0.0953102 + 9.090909090909090909e-19j),
1167*da0073e9SAndroid Build Coastguard Worker            (0.5 + 0j, 0.40546510810816 + 0j),
1168*da0073e9SAndroid Build Coastguard Worker            (0.0 + 0.5j, 0.111571776 + 0.463647609j),
1169*da0073e9SAndroid Build Coastguard Worker            (2.0 + 1.0j, 1.151292546497023 + 0.3217505543966422j),
1170*da0073e9SAndroid Build Coastguard Worker            (-1.0 + 2.0j, 0.6931471805599453 + 1.570796326794897j),
1171*da0073e9SAndroid Build Coastguard Worker            (2.0j, 0.80471895621705014 + 1.1071487177940904j),
1172*da0073e9SAndroid Build Coastguard Worker            (-2.0j, 0.80471895621705014 - 1.1071487177940904j),
1173*da0073e9SAndroid Build Coastguard Worker        ]
1174*da0073e9SAndroid Build Coastguard Worker        # test the extreme values
1175*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.complex128:
1176*da0073e9SAndroid Build Coastguard Worker            inouts += [
1177*da0073e9SAndroid Build Coastguard Worker                (-1 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
1178*da0073e9SAndroid Build Coastguard Worker                (1e250 + 1j, 575.6462732485114 + 1e-250j),
1179*da0073e9SAndroid Build Coastguard Worker                (1e250 + 1e250j, 575.9928468387914 + 0.7853981633974483j),
1180*da0073e9SAndroid Build Coastguard Worker                (1e-250 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
1181*da0073e9SAndroid Build Coastguard Worker                (1e-250 + 2e-250j, 1e-250 + 2e-250j),
1182*da0073e9SAndroid Build Coastguard Worker                (1e250 + 1e-250j, 575.6462732485114 + 0.0j),
1183*da0073e9SAndroid Build Coastguard Worker            ]
1184*da0073e9SAndroid Build Coastguard Worker        elif dtype == torch.complex64:
1185*da0073e9SAndroid Build Coastguard Worker            inouts += [
1186*da0073e9SAndroid Build Coastguard Worker                (-1 + 1e30j, 69.07755278982137 + 1.5707963267948966j),
1187*da0073e9SAndroid Build Coastguard Worker                (1e30 + 1j, 69.07755278982137 + 1e-30j),
1188*da0073e9SAndroid Build Coastguard Worker                (1e30 + 1e30j, 69.42412638010134 + 0.7853981633974483j),
1189*da0073e9SAndroid Build Coastguard Worker                (1e-30 + 1e30j, 69.07755278982137 + 1.5707963267948966j),
1190*da0073e9SAndroid Build Coastguard Worker                (1e-30 + 2e-30j, 1e-30 + 2e-30j),
1191*da0073e9SAndroid Build Coastguard Worker                (1e30 + 1e-30j, 69.07755278982137 + 0.0j),
1192*da0073e9SAndroid Build Coastguard Worker            ]
1193*da0073e9SAndroid Build Coastguard Worker
1194*da0073e9SAndroid Build Coastguard Worker        # test the log1p individually
1195*da0073e9SAndroid Build Coastguard Worker        for inp, out in inouts:
1196*da0073e9SAndroid Build Coastguard Worker            res = torch.log1p(torch.tensor(inp, dtype=dtype, device=device))
1197*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.any(torch.isnan(res)))
1198*da0073e9SAndroid Build Coastguard Worker            # setting up atol == 0.0 because some part has very small values
1199*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.real, out.real, atol=0.0, rtol=1e-6)
1200*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.imag, out.imag, atol=0.0, rtol=1e-6)
1201*da0073e9SAndroid Build Coastguard Worker
1202*da0073e9SAndroid Build Coastguard Worker        # test the log1p in tensor
1203*da0073e9SAndroid Build Coastguard Worker        inp_lst, out_lst = (list(elmt) for elmt in zip(*inouts))
1204*da0073e9SAndroid Build Coastguard Worker        inp_tens = torch.tensor(inp_lst, dtype=dtype, device=device)
1205*da0073e9SAndroid Build Coastguard Worker        out_tens = torch.tensor(out_lst, dtype=dtype, device=device)
1206*da0073e9SAndroid Build Coastguard Worker        res_tens = torch.log1p(inp_tens)
1207*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_tens.real, out_tens.real, atol=0.0, rtol=1e-6)
1208*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_tens.imag, out_tens.imag, atol=0.0, rtol=1e-6)
1209*da0073e9SAndroid Build Coastguard Worker
1210*da0073e9SAndroid Build Coastguard Worker    # do ops like threshold need a test_unary(_nonufunc) test suite?
1211*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1212*da0073e9SAndroid Build Coastguard Worker    @dtypes(*get_all_math_dtypes("cpu"))
1213*da0073e9SAndroid Build Coastguard Worker    def test_threshold(self, device, dtype):
1214*da0073e9SAndroid Build Coastguard Worker        if dtype != torch.uint8 and dtype != torch.float16 and not dtype.is_complex:
1215*da0073e9SAndroid Build Coastguard Worker            # 100 is wide enough to use AVX2 instructions for all types
1216*da0073e9SAndroid Build Coastguard Worker            x = (
1217*da0073e9SAndroid Build Coastguard Worker                torch.randn(100, dtype=torch.float, device=device)
1218*da0073e9SAndroid Build Coastguard Worker                .sign()
1219*da0073e9SAndroid Build Coastguard Worker                .to(dtype=dtype)
1220*da0073e9SAndroid Build Coastguard Worker            )
1221*da0073e9SAndroid Build Coastguard Worker            y = torch.threshold(x, 0, 0)
1222*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.le(0).any())
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Worker    def _helper_test_igamma(self, loglo, loghi, device, dtype, torch_fcn, scipy_fcn):
1225*da0073e9SAndroid Build Coastguard Worker        exp1 = 2.71828182846
1226*da0073e9SAndroid Build Coastguard Worker        vec1 = torch.logspace(
1227*da0073e9SAndroid Build Coastguard Worker            loglo, loghi, steps=500, base=exp1, dtype=torch.float64, device=device
1228*da0073e9SAndroid Build Coastguard Worker        ).unsqueeze(-1)
1229*da0073e9SAndroid Build Coastguard Worker        vec1 = vec1.to(dtype)
1230*da0073e9SAndroid Build Coastguard Worker        inputs = [
1231*da0073e9SAndroid Build Coastguard Worker            (vec1, vec1.transpose(0, 1)),
1232*da0073e9SAndroid Build Coastguard Worker            (vec1, vec1),  # for large number, it should approach 0.5
1233*da0073e9SAndroid Build Coastguard Worker            (vec1, 0.5 * vec1),  # test for considerable ratio
1234*da0073e9SAndroid Build Coastguard Worker            (vec1, 2.0 * vec1),
1235*da0073e9SAndroid Build Coastguard Worker            (vec1[::2, :], vec1[::2, :]),  # contiguous/noncontiguous tests
1236*da0073e9SAndroid Build Coastguard Worker            (vec1[::2, :], vec1[: vec1.shape[0] // 2, :]),
1237*da0073e9SAndroid Build Coastguard Worker            (vec1[: vec1.shape[0] // 2, :], vec1[::2, :]),
1238*da0073e9SAndroid Build Coastguard Worker        ]
1239*da0073e9SAndroid Build Coastguard Worker        half_prec = dtype in [torch.bfloat16, torch.float16]
1240*da0073e9SAndroid Build Coastguard Worker        for input0, input1 in inputs:
1241*da0073e9SAndroid Build Coastguard Worker            actual = torch_fcn(input0, input1)
1242*da0073e9SAndroid Build Coastguard Worker            if half_prec:
1243*da0073e9SAndroid Build Coastguard Worker                input0 = input0.to(torch.float)
1244*da0073e9SAndroid Build Coastguard Worker                input1 = input1.to(torch.float)
1245*da0073e9SAndroid Build Coastguard Worker            expected = scipy_fcn(input0.cpu().numpy(), input1.cpu().numpy())
1246*da0073e9SAndroid Build Coastguard Worker            expected = torch.from_numpy(expected).to(dtype)
1247*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
1250*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
1251*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1252*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1253*da0073e9SAndroid Build Coastguard Worker    def test_igamma_common(self, device, dtype):
1254*da0073e9SAndroid Build Coastguard Worker        # test igamma for reasonable range of values
1255*da0073e9SAndroid Build Coastguard Worker        loglo = -4  # approx 0.018
1256*da0073e9SAndroid Build Coastguard Worker        loghi = 4  # approx 54.6
1257*da0073e9SAndroid Build Coastguard Worker        self._helper_test_igamma(
1258*da0073e9SAndroid Build Coastguard Worker            loglo, loghi, device, dtype, torch.igamma, scipy.special.gammainc
1259*da0073e9SAndroid Build Coastguard Worker        )
1260*da0073e9SAndroid Build Coastguard Worker
1261*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
1262*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
1263*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1264*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1265*da0073e9SAndroid Build Coastguard Worker    def test_igammac_common(self, device, dtype):
1266*da0073e9SAndroid Build Coastguard Worker        # test igammac for reasonable range of values
1267*da0073e9SAndroid Build Coastguard Worker        loglo = -4  # approx 0.018
1268*da0073e9SAndroid Build Coastguard Worker        loghi = 4  # approx 54.6
1269*da0073e9SAndroid Build Coastguard Worker        self._helper_test_igamma(
1270*da0073e9SAndroid Build Coastguard Worker            loglo, loghi, device, dtype, torch.igammac, scipy.special.gammaincc
1271*da0073e9SAndroid Build Coastguard Worker        )
1272*da0073e9SAndroid Build Coastguard Worker
1273*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
1274*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
1275*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1276*da0073e9SAndroid Build Coastguard Worker    def test_igamma_edge_cases(self, device, dtype):
1277*da0073e9SAndroid Build Coastguard Worker        tkwargs = {"dtype": dtype, "device": device}
1278*da0073e9SAndroid Build Coastguard Worker        infs = torch.zeros((3,), **tkwargs) + float("inf")
1279*da0073e9SAndroid Build Coastguard Worker        zeros = torch.zeros((3,), **tkwargs)
1280*da0073e9SAndroid Build Coastguard Worker        ones = torch.ones((3,), **tkwargs)
1281*da0073e9SAndroid Build Coastguard Worker        zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs)
1282*da0073e9SAndroid Build Coastguard Worker        small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs)
1283*da0073e9SAndroid Build Coastguard Worker        nans = torch.zeros((3,), **tkwargs) + float("nan")
1284*da0073e9SAndroid Build Coastguard Worker        inpouts = [
1285*da0073e9SAndroid Build Coastguard Worker            # (a    ,    x),       out
1286*da0073e9SAndroid Build Coastguard Worker            ((zeros, small_to_inf), ones),
1287*da0073e9SAndroid Build Coastguard Worker            ((small_to_inf, zeros), zeros),
1288*da0073e9SAndroid Build Coastguard Worker            ((infs, zero_to_large), zeros),
1289*da0073e9SAndroid Build Coastguard Worker            ((zero_to_large, infs), ones),
1290*da0073e9SAndroid Build Coastguard Worker            ((zeros, zeros), nans),
1291*da0073e9SAndroid Build Coastguard Worker            ((infs, infs), nans),
1292*da0073e9SAndroid Build Coastguard Worker            ((-small_to_inf, small_to_inf), nans),
1293*da0073e9SAndroid Build Coastguard Worker        ]
1294*da0073e9SAndroid Build Coastguard Worker        for inputs, output in inpouts:
1295*da0073e9SAndroid Build Coastguard Worker            input0, input1 = inputs
1296*da0073e9SAndroid Build Coastguard Worker            calc = torch.igamma(input0, input1)
1297*da0073e9SAndroid Build Coastguard Worker            if torch.all(torch.isnan(output)):
1298*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.all(torch.isnan(calc)))
1299*da0073e9SAndroid Build Coastguard Worker            else:
1300*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(calc, output)
1301*da0073e9SAndroid Build Coastguard Worker
1302*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
1303*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
1304*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1305*da0073e9SAndroid Build Coastguard Worker    def test_igammac_edge_cases(self, device, dtype):
1306*da0073e9SAndroid Build Coastguard Worker        tkwargs = {"dtype": dtype, "device": device}
1307*da0073e9SAndroid Build Coastguard Worker        infs = torch.zeros((3,), **tkwargs) + float("inf")
1308*da0073e9SAndroid Build Coastguard Worker        zeros = torch.zeros((3,), **tkwargs)
1309*da0073e9SAndroid Build Coastguard Worker        ones = torch.ones((3,), **tkwargs)
1310*da0073e9SAndroid Build Coastguard Worker        zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs)
1311*da0073e9SAndroid Build Coastguard Worker        small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs)
1312*da0073e9SAndroid Build Coastguard Worker        nans = torch.zeros((3,), **tkwargs) + float("nan")
1313*da0073e9SAndroid Build Coastguard Worker        inpouts = [
1314*da0073e9SAndroid Build Coastguard Worker            # (a    ,    x),       out
1315*da0073e9SAndroid Build Coastguard Worker            ((zeros, small_to_inf), zeros),
1316*da0073e9SAndroid Build Coastguard Worker            ((small_to_inf, zeros), ones),
1317*da0073e9SAndroid Build Coastguard Worker            ((infs, zero_to_large), ones),
1318*da0073e9SAndroid Build Coastguard Worker            ((zero_to_large, infs), zeros),
1319*da0073e9SAndroid Build Coastguard Worker            ((zeros, zeros), nans),
1320*da0073e9SAndroid Build Coastguard Worker            ((infs, infs), nans),
1321*da0073e9SAndroid Build Coastguard Worker            ((-small_to_inf, small_to_inf), nans),
1322*da0073e9SAndroid Build Coastguard Worker        ]
1323*da0073e9SAndroid Build Coastguard Worker        for inputs, output in inpouts:
1324*da0073e9SAndroid Build Coastguard Worker            input0, input1 = inputs
1325*da0073e9SAndroid Build Coastguard Worker            calc = torch.igammac(input0, input1)
1326*da0073e9SAndroid Build Coastguard Worker            if torch.all(torch.isnan(output)):
1327*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.all(torch.isnan(calc)))
1328*da0073e9SAndroid Build Coastguard Worker            else:
1329*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(calc, output)
1330*da0073e9SAndroid Build Coastguard Worker
1331*da0073e9SAndroid Build Coastguard Worker    def _i0_helper(self, t):
1332*da0073e9SAndroid Build Coastguard Worker        # Test by comparing to scipy
1333*da0073e9SAndroid Build Coastguard Worker        dtype = t.dtype
1334*da0073e9SAndroid Build Coastguard Worker        actual = torch.i0(t)
1335*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.bfloat16:
1336*da0073e9SAndroid Build Coastguard Worker            t = t.to(torch.float32)
1337*da0073e9SAndroid Build Coastguard Worker        expected = scipy.special.i0(t.cpu().numpy())
1338*da0073e9SAndroid Build Coastguard Worker        # Casting down for dtype float16 is required since scipy upcasts to float32
1339*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.bfloat16 or dtype is torch.float16:
1340*da0073e9SAndroid Build Coastguard Worker            expected = torch.from_numpy(expected).to(dtype)
1341*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected)
1342*da0073e9SAndroid Build Coastguard Worker
1343*da0073e9SAndroid Build Coastguard Worker    def _i0_range_helper(self, range, device, dtype):
1344*da0073e9SAndroid Build Coastguard Worker        # i0 tests are broken up by the domain for which the function does not overflow for each dtype
1345*da0073e9SAndroid Build Coastguard Worker        # This is done to ensure that the function performs well across all possible input values, without worrying
1346*da0073e9SAndroid Build Coastguard Worker        # about inf or nan possibilities
1347*da0073e9SAndroid Build Coastguard Worker        for r in (range, -range):
1348*da0073e9SAndroid Build Coastguard Worker            t = torch.rand(1000, device=device).to(dtype) * r
1349*da0073e9SAndroid Build Coastguard Worker            self._i0_helper(t)
1350*da0073e9SAndroid Build Coastguard Worker
1351*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
1352*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float32, torch.float64)
1353*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1354*da0073e9SAndroid Build Coastguard Worker    def test_i0_range1(self, device, dtype):
1355*da0073e9SAndroid Build Coastguard Worker        # This tests the domain for i0 for which float16 does not overflow
1356*da0073e9SAndroid Build Coastguard Worker        # The domain is (-13.25, 13.25)
1357*da0073e9SAndroid Build Coastguard Worker        self._i0_range_helper(13.25, device, dtype)
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
1360*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float32, torch.float64)
1361*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1362*da0073e9SAndroid Build Coastguard Worker    def test_i0_range2(self, device, dtype):
1363*da0073e9SAndroid Build Coastguard Worker        # This tests the domain for i0 for which float32 and bfloat16 does not overflow
1364*da0073e9SAndroid Build Coastguard Worker        # The domain is (-88.5, 88.5)
1365*da0073e9SAndroid Build Coastguard Worker        self._i0_range_helper(88.5, device, dtype)
1366*da0073e9SAndroid Build Coastguard Worker
1367*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float64)
1368*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1369*da0073e9SAndroid Build Coastguard Worker    def test_i0_range3(self, device, dtype):
1370*da0073e9SAndroid Build Coastguard Worker        # This tests the domain for i0 for which float64 does not overflow
1371*da0073e9SAndroid Build Coastguard Worker        # The domain is (-709.75, 709.75)
1372*da0073e9SAndroid Build Coastguard Worker        self._i0_range_helper(709.75, device, dtype)
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
1375*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float32, torch.float64)
1376*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1377*da0073e9SAndroid Build Coastguard Worker    def test_i0_special(self, device, dtype):
1378*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([], device=device, dtype=dtype)
1379*da0073e9SAndroid Build Coastguard Worker        self._i0_helper(t)
1380*da0073e9SAndroid Build Coastguard Worker
1381*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype)
1382*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.i0(t).isnan().all())
1383*da0073e9SAndroid Build Coastguard Worker
1384*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
1385*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float32, torch.float64)
1386*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1387*da0073e9SAndroid Build Coastguard Worker    def test_special_i0_i1_vs_scipy(self, device, dtype):
1388*da0073e9SAndroid Build Coastguard Worker        def check_equal(t, torch_fn, scipy_fn):
1389*da0073e9SAndroid Build Coastguard Worker            # Test by comparing to scipy
1390*da0073e9SAndroid Build Coastguard Worker            actual = torch_fn(t)
1391*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.bfloat16:
1392*da0073e9SAndroid Build Coastguard Worker                t = t.to(torch.float32)
1393*da0073e9SAndroid Build Coastguard Worker            expected = scipy_fn(t.cpu().numpy())
1394*da0073e9SAndroid Build Coastguard Worker
1395*da0073e9SAndroid Build Coastguard Worker            # Casting down for dtype float16 is required since scipy upcasts to float32
1396*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.bfloat16 or dtype is torch.float16:
1397*da0073e9SAndroid Build Coastguard Worker                expected = torch.from_numpy(expected).to(dtype)
1398*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
1399*da0073e9SAndroid Build Coastguard Worker
1400*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([], device=device, dtype=dtype)
1401*da0073e9SAndroid Build Coastguard Worker        check_equal(t, torch.i0, scipy.special.i0)
1402*da0073e9SAndroid Build Coastguard Worker        check_equal(t, torch.special.i0e, scipy.special.i0e)
1403*da0073e9SAndroid Build Coastguard Worker        if dtype not in [torch.half, torch.bfloat16]:
1404*da0073e9SAndroid Build Coastguard Worker            check_equal(t, torch.special.i1, scipy.special.i1)
1405*da0073e9SAndroid Build Coastguard Worker            check_equal(t, torch.special.i1e, scipy.special.i1e)
1406*da0073e9SAndroid Build Coastguard Worker
1407*da0073e9SAndroid Build Coastguard Worker        range = (-1e7, 1e7)
1408*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.half:
1409*da0073e9SAndroid Build Coastguard Worker            range = (-65000, 65000)
1410*da0073e9SAndroid Build Coastguard Worker
1411*da0073e9SAndroid Build Coastguard Worker        t = torch.linspace(*range, int(1e4), device=device, dtype=dtype)
1412*da0073e9SAndroid Build Coastguard Worker        check_equal(t, torch.i0, scipy.special.i0)
1413*da0073e9SAndroid Build Coastguard Worker        check_equal(t, torch.special.i0e, scipy.special.i0e)
1414*da0073e9SAndroid Build Coastguard Worker        if dtype not in [torch.half, torch.bfloat16]:
1415*da0073e9SAndroid Build Coastguard Worker            check_equal(t, torch.special.i1, scipy.special.i1)
1416*da0073e9SAndroid Build Coastguard Worker            check_equal(t, torch.special.i1e, scipy.special.i1e)
1417*da0073e9SAndroid Build Coastguard Worker
1418*da0073e9SAndroid Build Coastguard Worker        # NaN, inf, -inf are tested in reference_numerics tests.
1419*da0073e9SAndroid Build Coastguard Worker        info = torch.finfo(dtype)
1420*da0073e9SAndroid Build Coastguard Worker        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
1421*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
1422*da0073e9SAndroid Build Coastguard Worker        check_equal(t, torch.i0, scipy.special.i0)
1423*da0073e9SAndroid Build Coastguard Worker        check_equal(t, torch.special.i0e, scipy.special.i0e)
1424*da0073e9SAndroid Build Coastguard Worker        if dtype not in [torch.half, torch.bfloat16]:
1425*da0073e9SAndroid Build Coastguard Worker            check_equal(t, torch.special.i1, scipy.special.i1)
1426*da0073e9SAndroid Build Coastguard Worker            check_equal(t, torch.special.i1e, scipy.special.i1e)
1427*da0073e9SAndroid Build Coastguard Worker
1428*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
1429*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1430*da0073e9SAndroid Build Coastguard Worker    def test_special_ndtr_vs_scipy(self, device, dtype):
1431*da0073e9SAndroid Build Coastguard Worker        def check_equal(t):
1432*da0073e9SAndroid Build Coastguard Worker            # Test by comparing to scipy
1433*da0073e9SAndroid Build Coastguard Worker            actual = torch.special.ndtr(t)
1434*da0073e9SAndroid Build Coastguard Worker            expected = scipy.special.ndtr(t.cpu().numpy())
1435*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
1436*da0073e9SAndroid Build Coastguard Worker
1437*da0073e9SAndroid Build Coastguard Worker        range = (-10, 10)
1438*da0073e9SAndroid Build Coastguard Worker        t = torch.linspace(*range, 1, device=device, dtype=dtype)
1439*da0073e9SAndroid Build Coastguard Worker        check_equal(t)
1440*da0073e9SAndroid Build Coastguard Worker
1441*da0073e9SAndroid Build Coastguard Worker        # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests.
1442*da0073e9SAndroid Build Coastguard Worker        info = torch.finfo(dtype)
1443*da0073e9SAndroid Build Coastguard Worker        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
1444*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
1445*da0073e9SAndroid Build Coastguard Worker        check_equal(t)
1446*da0073e9SAndroid Build Coastguard Worker
1447*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
1448*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1449*da0073e9SAndroid Build Coastguard Worker    def test_special_log_ndtr_vs_scipy(self, device, dtype):
1450*da0073e9SAndroid Build Coastguard Worker        def check_equal(t):
1451*da0073e9SAndroid Build Coastguard Worker            # Test by comparing with scipy
1452*da0073e9SAndroid Build Coastguard Worker            actual = torch.special.log_ndtr(t)
1453*da0073e9SAndroid Build Coastguard Worker            expected = scipy.special.log_ndtr(t.cpu().numpy())
1454*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
1455*da0073e9SAndroid Build Coastguard Worker
1456*da0073e9SAndroid Build Coastguard Worker        # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests.
1457*da0073e9SAndroid Build Coastguard Worker        info = torch.finfo(dtype)
1458*da0073e9SAndroid Build Coastguard Worker        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
1459*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
1460*da0073e9SAndroid Build Coastguard Worker        check_equal(t)
1461*da0073e9SAndroid Build Coastguard Worker
1462*da0073e9SAndroid Build Coastguard Worker    # TODO: allow large opinfo values to be opted-into via metadata
1463*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.long)
1464*da0073e9SAndroid Build Coastguard Worker    def test_abs_big_number(self, device, dtype):
1465*da0073e9SAndroid Build Coastguard Worker        bignumber = 2**31 + 1
1466*da0073e9SAndroid Build Coastguard Worker        res = torch.tensor([bignumber], device=device, dtype=dtype)
1467*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(res.abs()[0], 0)
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Worker    # TODO: add signed zero testing to opinfos
1470*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1471*da0073e9SAndroid Build Coastguard Worker    def test_abs_signed_zero(self, device, dtype):
1472*da0073e9SAndroid Build Coastguard Worker        # Both abs(0.0) and abs(-0.0) should result in 0.0
1473*da0073e9SAndroid Build Coastguard Worker        size = 128 + 1  # pick a large enough number with remainder so that
1474*da0073e9SAndroid Build Coastguard Worker        # both vectorized and nonvectorized op is tested
1475*da0073e9SAndroid Build Coastguard Worker        inp = torch.zeros(size, device=device, dtype=dtype)
1476*da0073e9SAndroid Build Coastguard Worker        inp[::2] = -0.0
1477*da0073e9SAndroid Build Coastguard Worker        inp = inp.abs()
1478*da0073e9SAndroid Build Coastguard Worker        for v in inp:
1479*da0073e9SAndroid Build Coastguard Worker            self.assertGreater(math.copysign(1.0, v), 0.0)
1480*da0073e9SAndroid Build Coastguard Worker
1481*da0073e9SAndroid Build Coastguard Worker    # TODO: update to compare against NumPy by rationalizing with OpInfo
1482*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1483*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1484*da0073e9SAndroid Build Coastguard Worker    def test_abs_zero(self, device, dtype):
1485*da0073e9SAndroid Build Coastguard Worker        # Both abs(0.0) and abs(-0.0) should result in 0.0
1486*da0073e9SAndroid Build Coastguard Worker        abs_zeros = torch.tensor([0.0, -0.0], device=device, dtype=dtype).abs().tolist()
1487*da0073e9SAndroid Build Coastguard Worker        for num in abs_zeros:
1488*da0073e9SAndroid Build Coastguard Worker            self.assertGreater(math.copysign(1.0, num), 0.0)
1489*da0073e9SAndroid Build Coastguard Worker
1490*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
1491*da0073e9SAndroid Build Coastguard Worker    def test_isposinf_isneginf_non_boolean_output(self, device, dtype):
1492*da0073e9SAndroid Build Coastguard Worker        # test non-boolean tensors as the `out=` parameters
1493*da0073e9SAndroid Build Coastguard Worker        # boolean outputs are tested in the above testcases
1494*da0073e9SAndroid Build Coastguard Worker        vals = (float("inf"), -float("inf"), 1.2)
1495*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(vals, device=device)
1496*da0073e9SAndroid Build Coastguard Worker        for torch_op in (torch.isposinf, torch.isneginf):
1497*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(t, dtype=dtype)
1498*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
1499*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "does not support non-boolean outputs"
1500*da0073e9SAndroid Build Coastguard Worker            ):
1501*da0073e9SAndroid Build Coastguard Worker                torch_op(t, out=out)
1502*da0073e9SAndroid Build Coastguard Worker
1503*da0073e9SAndroid Build Coastguard Worker    def test_nonzero_empty(self, device):
1504*da0073e9SAndroid Build Coastguard Worker        def assert_tuple_empty(tup, dim):
1505*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dim, len(tup))
1506*da0073e9SAndroid Build Coastguard Worker            for t in tup:
1507*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.Size([0]), t.shape)
1508*da0073e9SAndroid Build Coastguard Worker
1509*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(0, 2, 0, 5, 0, device=device)
1510*da0073e9SAndroid Build Coastguard Worker        y = torch.nonzero(x)
1511*da0073e9SAndroid Build Coastguard Worker        z = torch.nonzero(x, as_tuple=True)
1512*da0073e9SAndroid Build Coastguard Worker
1513*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, y.numel())
1514*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.Size([0, 5]), y.shape)
1515*da0073e9SAndroid Build Coastguard Worker        assert_tuple_empty(z, 5)
1516*da0073e9SAndroid Build Coastguard Worker
1517*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(0.5, device=device)
1518*da0073e9SAndroid Build Coastguard Worker        y = torch.nonzero(x)
1519*da0073e9SAndroid Build Coastguard Worker        # nonzero with as_tuple returns a
1520*da0073e9SAndroid Build Coastguard Worker        # tuple of len 1 for a zero-dim tensor.
1521*da0073e9SAndroid Build Coastguard Worker        # This is done to match Numpy behavior.
1522*da0073e9SAndroid Build Coastguard Worker        z = torch.nonzero(x, as_tuple=True)
1523*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, len(z))
1524*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros(1, dtype=torch.long), z[0])
1525*da0073e9SAndroid Build Coastguard Worker
1526*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros((), device=device)
1527*da0073e9SAndroid Build Coastguard Worker        y = torch.nonzero(x)
1528*da0073e9SAndroid Build Coastguard Worker        z = torch.nonzero(x, as_tuple=True)
1529*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.Size([0, 0]), y.shape)
1530*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, len(z))
1531*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, dtype=torch.long), z[0])
1532*da0073e9SAndroid Build Coastguard Worker
1533*da0073e9SAndroid Build Coastguard Worker    # TODO: rationalize with exp OpInfo
1534*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types_and(torch.bfloat16))
1535*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_and_complex_types_and(torch.half, torch.bfloat16))
1536*da0073e9SAndroid Build Coastguard Worker    def test_exp(self, device, dtype):
1537*da0073e9SAndroid Build Coastguard Worker        for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
1538*da0073e9SAndroid Build Coastguard Worker            a = (
1539*da0073e9SAndroid Build Coastguard Worker                torch.tensor(v, dtype=dtype, device=device)
1540*da0073e9SAndroid Build Coastguard Worker                * torch.arange(18, device=device)
1541*da0073e9SAndroid Build Coastguard Worker                / 3
1542*da0073e9SAndroid Build Coastguard Worker                * math.pi
1543*da0073e9SAndroid Build Coastguard Worker            )
1544*da0073e9SAndroid Build Coastguard Worker            a = a.to(dtype)
1545*da0073e9SAndroid Build Coastguard Worker            # bfloat16 overflows
1546*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16:
1547*da0073e9SAndroid Build Coastguard Worker                return
1548*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(torch.exp, np.exp, a)
1549*da0073e9SAndroid Build Coastguard Worker
1550*da0073e9SAndroid Build Coastguard Worker            if dtype.is_complex:
1551*da0073e9SAndroid Build Coastguard Worker                inf_real_zero_imag_in = torch.tensor(
1552*da0073e9SAndroid Build Coastguard Worker                    complex(float("inf"), 0), device=device, dtype=dtype
1553*da0073e9SAndroid Build Coastguard Worker                )
1554*da0073e9SAndroid Build Coastguard Worker                inf_real_zero_imag_out = torch.exp(inf_real_zero_imag_in).item()
1555*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(math.isinf(inf_real_zero_imag_out.real))
1556*da0073e9SAndroid Build Coastguard Worker                if self.device_type == "cpu":
1557*da0073e9SAndroid Build Coastguard Worker                    pass
1558*da0073e9SAndroid Build Coastguard Worker                    # These are commented out because it cannot be consistently reproduced.
1559*da0073e9SAndroid Build Coastguard Worker                    # This is incorrect. It should be zero. Need fix!
1560*da0073e9SAndroid Build Coastguard Worker                    # https://github.com/pytorch/pytorch/issues/40590
1561*da0073e9SAndroid Build Coastguard Worker                    # self.assertNotEqual(inf_real_zero_imag_out.imag, 0)
1562*da0073e9SAndroid Build Coastguard Worker                    # This is incorrect. They should equal. Need fix!
1563*da0073e9SAndroid Build Coastguard Worker                    # https://github.com/pytorch/pytorch/issues/40590
1564*da0073e9SAndroid Build Coastguard Worker                    # with self.assertRaises(AssertionError):
1565*da0073e9SAndroid Build Coastguard Worker                    #     self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in)
1566*da0073e9SAndroid Build Coastguard Worker                else:
1567*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(inf_real_zero_imag_out.imag, 0, atol=0, rtol=0)
1568*da0073e9SAndroid Build Coastguard Worker                    self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in)
1569*da0073e9SAndroid Build Coastguard Worker
1570*da0073e9SAndroid Build Coastguard Worker                zero_real_inf_imag_in = torch.tensor(
1571*da0073e9SAndroid Build Coastguard Worker                    complex(0, float("inf")), device=device, dtype=dtype
1572*da0073e9SAndroid Build Coastguard Worker                )
1573*da0073e9SAndroid Build Coastguard Worker                zero_real_inf_imag_out = torch.exp(zero_real_inf_imag_in).item()
1574*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(math.isnan(zero_real_inf_imag_out.real))
1575*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(math.isnan(zero_real_inf_imag_out.imag))
1576*da0073e9SAndroid Build Coastguard Worker                # Ensure we are notified when NumPy changes its behavior
1577*da0073e9SAndroid Build Coastguard Worker                self.compare_with_numpy(torch.exp, np.exp, zero_real_inf_imag_in)
1578*da0073e9SAndroid Build Coastguard Worker
1579*da0073e9SAndroid Build Coastguard Worker                inf_real_imag_in = torch.tensor(
1580*da0073e9SAndroid Build Coastguard Worker                    complex(float("inf"), float("inf")), device=device, dtype=dtype
1581*da0073e9SAndroid Build Coastguard Worker                )
1582*da0073e9SAndroid Build Coastguard Worker                inf_real_imag_out = torch.exp(inf_real_imag_in).item()
1583*da0073e9SAndroid Build Coastguard Worker                if self.device_type == "cpu":
1584*da0073e9SAndroid Build Coastguard Worker                    pass
1585*da0073e9SAndroid Build Coastguard Worker                    # This is incorrect. Need fix! https://github.com/pytorch/pytorch/issues/40590
1586*da0073e9SAndroid Build Coastguard Worker                    # This is commented out because it cannot be consistently reproduced.
1587*da0073e9SAndroid Build Coastguard Worker                    # with self.assertRaises(AssertionError):
1588*da0073e9SAndroid Build Coastguard Worker                    #     self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in)
1589*da0073e9SAndroid Build Coastguard Worker                else:
1590*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(math.isinf(inf_real_imag_out.real))
1591*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(math.isnan(inf_real_imag_out.imag))
1592*da0073e9SAndroid Build Coastguard Worker                    self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in)
1593*da0073e9SAndroid Build Coastguard Worker
1594*da0073e9SAndroid Build Coastguard Worker                inf_real_nan_imag_in = torch.tensor(
1595*da0073e9SAndroid Build Coastguard Worker                    complex(float("inf"), float("nan")), device=device, dtype=dtype
1596*da0073e9SAndroid Build Coastguard Worker                )
1597*da0073e9SAndroid Build Coastguard Worker                inf_real_nan_imag_out = torch.exp(inf_real_nan_imag_in).item()
1598*da0073e9SAndroid Build Coastguard Worker                if self.device_type == "cpu":
1599*da0073e9SAndroid Build Coastguard Worker                    pass
1600*da0073e9SAndroid Build Coastguard Worker                    # This is incorrect. It should be inf. Need fix! https://github.com/pytorch/pytorch/issues/40590
1601*da0073e9SAndroid Build Coastguard Worker                    # This is commented out because it cannot be consistently reproduced.
1602*da0073e9SAndroid Build Coastguard Worker                    # with self.assertRaises(AssertionError):
1603*da0073e9SAndroid Build Coastguard Worker                    #     self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in)
1604*da0073e9SAndroid Build Coastguard Worker                else:
1605*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(math.isinf(inf_real_nan_imag_out.real))
1606*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(math.isnan(inf_real_nan_imag_out.imag))
1607*da0073e9SAndroid Build Coastguard Worker                    self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in)
1608*da0073e9SAndroid Build Coastguard Worker
1609*da0073e9SAndroid Build Coastguard Worker                nan_real_inf_imag_in = torch.tensor(
1610*da0073e9SAndroid Build Coastguard Worker                    complex(float("nan"), float("inf")), device=device, dtype=dtype
1611*da0073e9SAndroid Build Coastguard Worker                )
1612*da0073e9SAndroid Build Coastguard Worker                nan_real_inf_imag_out = torch.exp(nan_real_inf_imag_in).item()
1613*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(math.isnan(nan_real_inf_imag_out.real))
1614*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(math.isnan(nan_real_inf_imag_out.imag))
1615*da0073e9SAndroid Build Coastguard Worker                # Ensure we are notified when NumPy changes its behavior
1616*da0073e9SAndroid Build Coastguard Worker                self.compare_with_numpy(torch.exp, np.exp, nan_real_inf_imag_in)
1617*da0073e9SAndroid Build Coastguard Worker
1618*da0073e9SAndroid Build Coastguard Worker
# Generate the device-specific variants of the generic TestUnaryUfuncs
# template into this module's namespace.
instantiate_device_type_tests(TestUnaryUfuncs, globals())

if __name__ == "__main__":
    # Running the file directly hands control to PyTorch's test runner.
    run_tests()
1623