# Owner(s): ["module: tests"]

import torch
import numpy as np

import math
from numbers import Number
import random
import unittest

from torch import inf, nan
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    torch_to_numpy_dtype_dict,
    numpy_to_torch_dtype_dict,
    suppress_warnings,
    TEST_SCIPY,
    slowTest,
    skipIfNoSciPy,
    IS_WINDOWS,
    gradcheck,
    is_iterable_of_tensors,
    xfailIfTorchDynamo,
)
from torch.testing._internal.common_methods_invocations import (
    unary_ufuncs,
    generate_elementwise_unary_tensors,
    generate_elementwise_unary_small_value_tensors,
    generate_elementwise_unary_large_value_tensors,
    generate_elementwise_unary_extremal_value_tensors,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
    dtypes,
    onlyCPU,
    onlyNativeDeviceTypes,
    onlyCUDA,
    dtypesIfCUDA,
    precisionOverride,
    dtypesIfCPU,
)
from torch.utils import _pytree as pytree

from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    floating_types_and,
    all_types_and_complex_and,
    integral_types_and,
    get_all_math_dtypes,
    complex_types,
    floating_and_complex_types_and,
)

if TEST_SCIPY:
    import scipy

# Refer [scipy reference filter]
# Keep only the operators whose reference function is available in the
# current environment (used by the reference_numerics tests below).
reference_filtered_ops = [op for op in unary_ufuncs if op.ref is not None]

# Tests for unary "universal functions (ufuncs)" that accept a single
# tensor and have common properties like:
# - they are elementwise functions
# - the input shape is the output shape
# - they typically have method and inplace variants
# - they typically support the out kwarg
# - they typically have NumPy or SciPy references

# See NumPy's universal function documentation
# (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
# about the concept of ufuncs.


# TODO: port test_unary_out_op_mem_overlap
# TODO: add test for inplace variants erroring on broadcasted inputs
class TestUnaryUfuncs(TestCase):
    """Device-generic tests for unary elementwise ufuncs.

    Each test is instantiated per device/dtype/OpInfo combination by the
    ``@ops`` decorator; ``op`` is an OpInfo-like object exposing ``domain``,
    ``ref``, and ``sample_kwargs``.
    """

    # Request strict dtype matching from TestCase.assertEqual by default.
    exact_dtype = True

    @ops(
        [_fn for _fn in unary_ufuncs if _fn.domain != (None, None)],
        allowed_dtypes=floating_types_and(torch.bfloat16, torch.half),
    )
    def test_float_domains(self, device, dtype, op):
        """Checks that inputs just outside an op's domain produce nan.

        Only ops with a bounded domain (domain != (None, None)) are run.
        Probes points at several offsets below the lower bound and above
        the upper bound.
        """
        # Offsets used to step past each domain boundary.
        eps = (1e-5, 1e-3, 1e-1, 1, 2, 10, 20, 50, 100)

        low, high = op.domain
        # NOTE: the following two loops are separated for readability
        if low is not None:
            low_tensor = torch.tensor(low, device=device, dtype=dtype)
            for epsilon in eps:
                lower_tensor = low_tensor - epsilon

                # Skips the test if the difference is not representable,
                # which can occur if, for example, the difference is small
                # and the dtype is imprecise (like bfloat16 is)
                if lower_tensor.item() == low_tensor.item():
                    continue

                result = op(lower_tensor)
                # nan == nan compares equal under TestCase.assertEqual here,
                # so this asserts the op returned nan.
                self.assertEqual(
                    result.item(),
                    float("nan"),
                    msg=(
                        f"input of {lower_tensor.item()} outside lower domain boundary"
                        f" {low} produced {result.item()}, not nan!"
                    ),
                )

        if high is not None:
            high_tensor = torch.tensor(high, device=device, dtype=dtype)
            for epsilon in eps:
                higher_tensor = high_tensor + epsilon

                # See above comment
                if higher_tensor.item() == high_tensor.item():
                    continue

                result = op(higher_tensor)
                self.assertEqual(
                    result.item(),
                    float("nan"),
                    msg=(
                        f"input of {higher_tensor.item()} outside upper domain boundary"
                        f" {high} produced {result.item()}, not nan!"
                    ),
                )

    # Helper for comparing torch tensors and numpy arrays
    # TODO: should this or assertEqual also validate that strides are equal?
    def assertEqualHelper(
        self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs
    ):
        """Compares a torch tensor against a NumPy scalar or array.

        ``expected`` may be a Python Number (some NumPy functions return
        scalars) or an ndarray. When ``exact_dtype`` is True, dtype
        mismatches are only tolerated for the documented float promotions
        below; any other mismatch fails the test.
        """
        assert isinstance(actual, torch.Tensor)

        # Some NumPy functions return scalars, not arrays
        if isinstance(expected, Number):
            self.assertEqual(actual.item(), expected, msg, **kwargs)
        elif isinstance(expected, np.ndarray):
            # Handles exact dtype comparisons between arrays and tensors
            if exact_dtype:
                if (
                    actual.dtype is torch.bfloat16
                    or expected.dtype != torch_to_numpy_dtype_dict[actual.dtype]
                ):
                    # Allows array dtype to be float32 when comparing with bfloat16 tensors
                    # since NumPy doesn't support the bfloat16 dtype
                    # Also ops like scipy.special.erf, scipy.special.erfc, etc, promote float16
                    # to float32
                    if expected.dtype == np.float32:
                        assert actual.dtype in (
                            torch.float16,
                            torch.bfloat16,
                            torch.float32,
                        )
                    elif expected.dtype == np.float64:
                        assert actual.dtype in (
                            torch.float16,
                            torch.bfloat16,
                            torch.float32,
                            torch.float64,
                        )
                    else:
                        self.fail(
                            f"Expected dtype {expected.dtype} but got {actual.dtype}!"
                        )

            # Cast the NumPy result to the tensor's dtype before comparing;
            # devices are allowed to differ (reference is always on CPU).
            self.assertEqual(
                actual,
                torch.from_numpy(expected).to(actual.dtype),
                msg,
                exact_device=False,
                **kwargs
            )
        else:
            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)

    # Tests that the function and its (array-accepting) reference produce the same
    # values on given tensors
    def _test_reference_numerics(self, dtype, op, tensors, equal_nan=True):
        """Runs ``op`` and ``op.ref`` on each sample and compares results.

        Per-dtype tolerances are applied by the inner helper; multi-output
        ops are compared output-by-output.
        """

        def _helper_reference_numerics(
            expected, actual, msg, exact_dtype, equal_nan=True
        ):
            # If NumPy produced a dtype that can't be losslessly cast to the
            # torch dtype under test, relax the dtype check.
            if not torch.can_cast(
                numpy_to_torch_dtype_dict[expected.dtype.type], dtype
            ):
                exact_dtype = False

            if dtype in [torch.uint8, torch.int8, torch.bool]:
                # NOTE: For these dtypes, PyTorch computes in the default scalar type (float)
                # while NumPy computes in float16
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    exact_dtype=exact_dtype,
                    rtol=1e-3,
                    atol=1e-2,
                )
            elif dtype is torch.bfloat16:
                # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    exact_dtype=exact_dtype,
                    rtol=16e-3,
                    atol=1e-5,
                )
            elif dtype is torch.half:
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    exact_dtype=exact_dtype,
                    rtol=1.2e-03,
                    atol=1e-03,
                )
            else:
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    equal_nan=equal_nan,
                    exact_dtype=exact_dtype,
                )

        for t in tensors:
            t = t.input
            torch_kwargs, numpy_kwargs = op.sample_kwargs(t.device, dtype, t)
            # NumPy has no bfloat16/complex32, so upcast before converting.
            if dtype is torch.bfloat16:
                a = t.cpu().to(torch.float32).numpy()
            elif dtype is torch.complex32:
                a = t.cpu().to(torch.complex64).numpy()
            else:
                a = t.cpu().numpy()

            actual = op(t, **torch_kwargs)
            expected = op.ref(a, **numpy_kwargs)

            # Crafts a custom error message for smaller, printable tensors
            if t.numel() < 10:
                msg = (
                    "Failed to produce expected results! Input tensor was"
                    f" {t}, torch result is {actual}, and reference result is"
                    f" {expected}."
                )
            else:
                msg = None

            exact_dtype = True
            if isinstance(actual, torch.Tensor):
                _helper_reference_numerics(
                    expected, actual, msg, exact_dtype, equal_nan
                )
            else:
                for x, y in zip(expected, actual):
                    # testing multi-outputs results
                    _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan)

    # Tests that the function and its (array-accepting) reference produce the same
    # values on a range of tensors, including empty tensors, scalar tensors,
    # 1D tensors and a large 2D tensor with interesting and extremal values
    # and noncontiguities.
    @suppress_warnings
    @ops(reference_filtered_ops)
    def test_reference_numerics_normal(self, device, dtype, op):
        """Compares op vs. its reference on ordinary-valued sample tensors."""
        tensors = generate_elementwise_unary_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    @suppress_warnings
    @ops(reference_filtered_ops)
    def test_reference_numerics_small(self, device, dtype, op):
        """Compares op vs. its reference on small-magnitude values."""
        if dtype in (torch.bool,):
            raise self.skipTest("bool has no small values")

        tensors = generate_elementwise_unary_small_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    @suppress_warnings
    @ops(reference_filtered_ops)
    def test_reference_numerics_large(self, device, dtype, op):
        """Compares op vs. its reference on large-magnitude values."""
        if dtype in (torch.bool, torch.uint8, torch.int8):
            raise self.skipTest("bool, uint8, and int8 dtypes have no large values")

        tensors = generate_elementwise_unary_large_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    @suppress_warnings
    @ops(
        reference_filtered_ops,
        allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
    )
    def test_reference_numerics_extremal(self, device, dtype, op):
        """Compares op vs. its reference on extremal values (inf, nan, ...)."""
        tensors = generate_elementwise_unary_extremal_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    # Tests for testing (non)contiguity consistency
    @ops(unary_ufuncs)
    def test_contig_vs_every_other(self, device, dtype, op):
        """Result on a strided (every-other-element) view must match the
        corresponding slice of the result on the contiguous tensor."""
        contig = make_tensor(
            (1026,), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1]
        )
        non_contig = contig[::2]

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, non_contig)
        expected = op(non_contig, **torch_kwargs)
        result = op(contig, **torch_kwargs)
        # tree_map handles ops that return tuples/lists of tensors.
        result = pytree.tree_map(lambda x: x[::2], result)
        self.assertEqual(result, expected)

    @ops(unary_ufuncs)
    def test_contig_vs_transposed(self, device, dtype, op):
        """Result on a transposed view must match the transpose of the
        result on the contiguous tensor."""
        contig = make_tensor(
            (789, 357), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1]
        )
        non_contig = contig.T

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        expected = op(non_contig, **torch_kwargs)
        result = op(contig, **torch_kwargs)
        result = pytree.tree_map(lambda x: x.T, result)
        self.assertEqual(result, expected)

    @ops(unary_ufuncs)
    def test_non_contig(self, device, dtype, op):
        """Same values in a contiguous and a non-contiguous layout must
        produce equal results."""
        shapes = [(5, 7), (1024,)]
        for shape in shapes:
            contig = make_tensor(
                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
            )
            # Build a non-contiguous tensor by slicing the last dim of a
            # wider buffer, then copy the same values in.
            non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0]
            non_contig.copy_(contig)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
            self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_non_contig_index(self, device, dtype, op):
        """A non-contiguous indexed view and its contiguous copy must
        produce equal results."""
        contig = make_tensor(
            (2, 2, 1, 2),
            dtype=dtype,
            device=device,
            low=op.domain[0],
            high=op.domain[1],
        )
        # Indexing a middle dim yields a non-contiguous view; rebind
        # `contig` to its contiguous materialization for comparison.
        non_contig = contig[:, 1, ...]
        contig = non_contig.contiguous()

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_non_contig_expand(self, device, dtype, op):
        """Applying the op to a broadcast-expanded view must match the op
        applied to the base tensor, for every expanded slice."""
        shapes = [(1, 3), (1, 7), (5, 7)]
        for shape in shapes:
            contig = make_tensor(
                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
            )
            # expand() creates a stride-0 (non-contiguous) view.
            non_contig = contig.clone().expand(3, -1, -1)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
            contig = op(contig, **torch_kwargs)
            non_contig = op(non_contig, **torch_kwargs)
            for i in range(3):
                non_contig_i = pytree.tree_map(lambda x: x[i], non_contig)
                self.assertEqual(
                    contig, non_contig_i, msg="non-contiguous expand[" + str(i) + "]"
                )

    @ops(unary_ufuncs)
    def test_contig_size1(self, device, dtype, op):
        """A contiguous slice (still contiguous because its leading dim is 1)
        and a freshly allocated contiguous copy must produce equal results."""
        contig = make_tensor(
            (5, 100), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
        )
        contig = contig[:1, :50]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_contig_size1_large_dim(self, device, dtype, op):
        """Same as test_contig_size1 but with a high-rank tensor, exercising
        size-1 dims in a many-dimensional layout."""
        contig = make_tensor(
            (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4),
            dtype=dtype,
            device=device,
            low=op.domain[0],
            high=op.domain[1],
        )
        # Slice to leading size 1; the result is still contiguous.
        contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs))

    # Tests that computation on a multiple batches is the same as
    # per-batch computation.
    @ops(unary_ufuncs)
    def test_batch_vs_slicing(self, device, dtype, op):
        """Applying the op to a batch must equal stacking per-row results."""
        input = make_tensor(
            (1024, 512), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
        )

        torch_kwargs, _ = op.sample_kwargs(device, dtype, input)
        actual = op(input, **torch_kwargs)

        # NOTE(review): `slice` shadows the builtin here; harmless in this
        # scope but worth renaming in a cleanup pass.
        all_outs = [op(slice, **torch_kwargs) for slice in input]
        if is_iterable_of_tensors(actual):
            # Multi-output op: restack each output position separately.
            expected = [torch.stack([out[i] for out in all_outs]) for i in range(len(actual))]
        else:
            expected = torch.stack(all_outs)

        self.assertEqual(actual, expected)

    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
    def test_nan_to_num(self, device, dtype):
        """Compares torch.nan_to_num against np.nan_to_num on contiguous and
        transposed inputs, and checks the out= variant matches."""
        for contiguous in [False, True]:
            x = make_tensor((64, 64), low=0.0, high=100.0, dtype=dtype, device=device)

            if dtype.is_floating_point:
                # Add extremal values.
                extremals = [float("nan"), float("inf"), -float("inf")]
                for idx, extremal in zip(torch.randint(0, 63, (3,)), extremals):
                    x[idx, :] = extremal

            if not contiguous:
                x = x.T

            # With args
            nan = random.random()
            posinf = random.random() * 5
            # NOTE(review): neginf is drawn positive; both torch and NumPy
            # receive the same value so the comparison is still valid, but a
            # negative value would better match the name — confirm intent.
            neginf = random.random() * 10

            self.compare_with_numpy(
                lambda x: x.nan_to_num(nan=nan, posinf=posinf),
                lambda x: np.nan_to_num(x, nan=nan, posinf=posinf),
                x,
            )
            self.compare_with_numpy(
                lambda x: x.nan_to_num(posinf=posinf, neginf=neginf),
                lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf),
                x,
            )

            # Out Variant
            out = torch.empty_like(x)
            result = torch.nan_to_num(x)
            torch.nan_to_num(x, out=out)
            self.assertEqual(result, out)

            result = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
            torch.nan_to_num(x, out=out, nan=nan, posinf=posinf, neginf=neginf)
            self.assertEqual(result, out)

    @onlyCPU
    def test_nan_to_num_bfloat16(self, device):
        # NOTE: this test continues beyond this chunk of the file.
        def test_dtype(fn, input, dtype):
            """Runs fn in the given dtype and in float32, with backward,
            and checks the low-precision output keeps its dtype."""
            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
            input2 = input.detach().clone().float().requires_grad_(True)
            out = fn(input)
            out.sum().backward()
            out2 = fn(input2)
            out2.sum().backward()
            self.assertEqual(out.dtype, dtype)
self.assertEqual(input.grad.dtype, dtype) 504*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out2, exact_dtype=False) 505*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input.grad, input2.grad, exact_dtype=False) 506*da0073e9SAndroid Build Coastguard Worker 507*da0073e9SAndroid Build Coastguard Worker def func(): 508*da0073e9SAndroid Build Coastguard Worker return torch.nan_to_num 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Worker shapes = [[1, 3, 6, 6], [1, 3, 6, 128], [1, 3, 256, 256]] 511*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 512*da0073e9SAndroid Build Coastguard Worker x = torch.randn(shape, device=device) 513*da0073e9SAndroid Build Coastguard Worker extremals = [float('nan'), float('inf'), -float('inf')] 514*da0073e9SAndroid Build Coastguard Worker for id1, id2, extremal in zip(torch.randint(0, 2, (3,)), torch.randint(0, 5, (3,)), extremals): 515*da0073e9SAndroid Build Coastguard Worker x[0, id1, id2, :] = extremal 516*da0073e9SAndroid Build Coastguard Worker test_dtype(func(), x, torch.bfloat16) 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.complex64, torch.complex128) 519*da0073e9SAndroid Build Coastguard Worker def test_nan_to_num_complex(self, device, dtype): 520*da0073e9SAndroid Build Coastguard Worker value_dtype = torch.tensor([], dtype=dtype).real.dtype 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker def gen_tensor(a): 523*da0073e9SAndroid Build Coastguard Worker return torch.view_as_complex(torch.tensor(a, dtype=value_dtype, device=device)) 524*da0073e9SAndroid Build Coastguard Worker 525*da0073e9SAndroid Build Coastguard Worker for extremal, kwarg_name in zip(['nan', 'inf', '-inf'], ['nan', 'posinf', 'neginf']): 526*da0073e9SAndroid Build Coastguard Worker a = gen_tensor([123, float(extremal)]) 527*da0073e9SAndroid Build Coastguard Worker res = torch.nan_to_num(a, 
**{kwarg_name: 12}) 528*da0073e9SAndroid Build Coastguard Worker res_check = gen_tensor([123, 12]) 529*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, res_check) 530*da0073e9SAndroid Build Coastguard Worker 531*da0073e9SAndroid Build Coastguard Worker a = gen_tensor([float(extremal), 456]) 532*da0073e9SAndroid Build Coastguard Worker res = torch.nan_to_num(a, **{kwarg_name: 21}) 533*da0073e9SAndroid Build Coastguard Worker res_check = gen_tensor([21, 456]) 534*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, res_check) 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.cdouble) 537*da0073e9SAndroid Build Coastguard Worker def test_complex_edge_values(self, device, dtype): 538*da0073e9SAndroid Build Coastguard Worker # sqrt Test Reference: https://github.com/pytorch/pytorch/pull/47424 539*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(0.0 - 1.0e20j, dtype=dtype, device=device) 540*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.sqrt, np.sqrt, x) 541*da0073e9SAndroid Build Coastguard Worker # acos test reference: https://github.com/pytorch/pytorch/issue/42952 542*da0073e9SAndroid Build Coastguard Worker # Skip on Windows, as CUDA acos returns conjugate value 543*da0073e9SAndroid Build Coastguard Worker # see https://github.com/pytorch/pytorch/issues/52299 544*da0073e9SAndroid Build Coastguard Worker if not (IS_WINDOWS and dtype == torch.cdouble and "cuda" in device): 545*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.acos, np.arccos, x) 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker x = torch.tensor( 548*da0073e9SAndroid Build Coastguard Worker (-1.0e60 if dtype == torch.cdouble else -1.0e20) - 4988429.2j, 549*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 550*da0073e9SAndroid Build Coastguard Worker device=device, 551*da0073e9SAndroid Build Coastguard Worker ) 
552*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.sqrt, np.sqrt, x) 553*da0073e9SAndroid Build Coastguard Worker 554*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_SCIPY, "Requires SciPy") 555*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 556*da0073e9SAndroid Build Coastguard Worker def test_digamma_special(self, device, dtype): 557*da0073e9SAndroid Build Coastguard Worker # Based on SciPy test for the following special values. 558*da0073e9SAndroid Build Coastguard Worker # Reference: 559*da0073e9SAndroid Build Coastguard Worker # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22 560*da0073e9SAndroid Build Coastguard Worker euler = 0.57721566490153286 561*da0073e9SAndroid Build Coastguard Worker dataset = [ 562*da0073e9SAndroid Build Coastguard Worker (0.0, -0.0), 563*da0073e9SAndroid Build Coastguard Worker (1, -euler), 564*da0073e9SAndroid Build Coastguard Worker (0.5, -2 * math.log(2) - euler), 565*da0073e9SAndroid Build Coastguard Worker (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler), 566*da0073e9SAndroid Build Coastguard Worker (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler), 567*da0073e9SAndroid Build Coastguard Worker ( 568*da0073e9SAndroid Build Coastguard Worker 1 / 6, 569*da0073e9SAndroid Build Coastguard Worker -math.pi * math.sqrt(3) / 2 570*da0073e9SAndroid Build Coastguard Worker - 2 * math.log(2) 571*da0073e9SAndroid Build Coastguard Worker - 3 * math.log(3) / 2 572*da0073e9SAndroid Build Coastguard Worker - euler, 573*da0073e9SAndroid Build Coastguard Worker ), 574*da0073e9SAndroid Build Coastguard Worker ( 575*da0073e9SAndroid Build Coastguard Worker 1 / 8, 576*da0073e9SAndroid Build Coastguard Worker -math.pi / 2 577*da0073e9SAndroid Build Coastguard Worker - 4 * math.log(2) 578*da0073e9SAndroid Build Coastguard Worker - (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - 
math.sqrt(2))) 579*da0073e9SAndroid Build Coastguard Worker / math.sqrt(2) 580*da0073e9SAndroid Build Coastguard Worker - euler, 581*da0073e9SAndroid Build Coastguard Worker ), 582*da0073e9SAndroid Build Coastguard Worker ] 583*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(dataset, device=device, dtype=dtype) 584*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.digamma, scipy.special.digamma, x) 585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_SCIPY, "Requires SciPy") 587*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 588*da0073e9SAndroid Build Coastguard Worker def test_digamma(self, device, dtype): 589*da0073e9SAndroid Build Coastguard Worker # Tests pole behavior 590*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor( 591*da0073e9SAndroid Build Coastguard Worker [ 592*da0073e9SAndroid Build Coastguard Worker -0.999999994, 593*da0073e9SAndroid Build Coastguard Worker -1.999999994, 594*da0073e9SAndroid Build Coastguard Worker -2.0000000111, 595*da0073e9SAndroid Build Coastguard Worker -100.99999994, 596*da0073e9SAndroid Build Coastguard Worker 0.000000111, 597*da0073e9SAndroid Build Coastguard Worker -1931.99999994, 598*da0073e9SAndroid Build Coastguard Worker -0.000000111, 599*da0073e9SAndroid Build Coastguard Worker 0, 600*da0073e9SAndroid Build Coastguard Worker -0, 601*da0073e9SAndroid Build Coastguard Worker -1, 602*da0073e9SAndroid Build Coastguard Worker -2, 603*da0073e9SAndroid Build Coastguard Worker -931, 604*da0073e9SAndroid Build Coastguard Worker ], 605*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 606*da0073e9SAndroid Build Coastguard Worker device=device, 607*da0073e9SAndroid Build Coastguard Worker ) 608*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor) 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker 
@dtypes(*floating_types_and(torch.half)) 611*da0073e9SAndroid Build Coastguard Worker def test_frexp(self, device, dtype): 612*da0073e9SAndroid Build Coastguard Worker input = make_tensor((50, 50), dtype=dtype, device=device) 613*da0073e9SAndroid Build Coastguard Worker mantissa, exponent = torch.frexp(input) 614*da0073e9SAndroid Build Coastguard Worker np_mantissa, np_exponent = np.frexp(input.cpu().numpy()) 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mantissa, np_mantissa) 617*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exponent, np_exponent) 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker # torch.frexp returns exponent in int32 to be compatible with np.frexp 620*da0073e9SAndroid Build Coastguard Worker self.assertTrue(exponent.dtype == torch.int32) 621*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch_to_numpy_dtype_dict[exponent.dtype] == np_exponent.dtype) 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker def test_frexp_assert_raises(self, device): 624*da0073e9SAndroid Build Coastguard Worker invalid_input_dtypes = integral_types_and(torch.bool) + complex_types() 625*da0073e9SAndroid Build Coastguard Worker for dtype in invalid_input_dtypes: 626*da0073e9SAndroid Build Coastguard Worker input = make_tensor((50, 50), dtype=dtype, device=device) 627*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 628*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"torch\.frexp\(\) only supports floating-point dtypes" 629*da0073e9SAndroid Build Coastguard Worker ): 630*da0073e9SAndroid Build Coastguard Worker torch.frexp(input) 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Worker for dtype in floating_types_and(torch.half): 633*da0073e9SAndroid Build Coastguard Worker input = make_tensor((50, 50), dtype=dtype, device=device) 634*da0073e9SAndroid 
Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Worker dtypes = list( 636*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16) 637*da0073e9SAndroid Build Coastguard Worker ) 638*da0073e9SAndroid Build Coastguard Worker dtypes.remove(dtype) 639*da0073e9SAndroid Build Coastguard Worker for mantissa_dtype in dtypes: 640*da0073e9SAndroid Build Coastguard Worker mantissa = torch.empty_like(input, dtype=mantissa_dtype) 641*da0073e9SAndroid Build Coastguard Worker exponent = torch.empty_like(input, dtype=torch.int) 642*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 643*da0073e9SAndroid Build Coastguard Worker RuntimeError, 644*da0073e9SAndroid Build Coastguard Worker r"torch\.frexp\(\) expects mantissa to have dtype .+ but got .+", 645*da0073e9SAndroid Build Coastguard Worker ): 646*da0073e9SAndroid Build Coastguard Worker torch.frexp(input, out=(mantissa, exponent)) 647*da0073e9SAndroid Build Coastguard Worker 648*da0073e9SAndroid Build Coastguard Worker dtypes.append(dtype) 649*da0073e9SAndroid Build Coastguard Worker dtypes.remove(torch.int) 650*da0073e9SAndroid Build Coastguard Worker for exponent_dtype in dtypes: 651*da0073e9SAndroid Build Coastguard Worker mantissa = torch.empty_like(input) 652*da0073e9SAndroid Build Coastguard Worker exponent = torch.empty_like(input, dtype=exponent_dtype) 653*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 654*da0073e9SAndroid Build Coastguard Worker RuntimeError, 655*da0073e9SAndroid Build Coastguard Worker r"torch\.frexp\(\) expects exponent to have int dtype but got .+", 656*da0073e9SAndroid Build Coastguard Worker ): 657*da0073e9SAndroid Build Coastguard Worker torch.frexp(input, out=(mantissa, exponent)) 658*da0073e9SAndroid Build Coastguard Worker 659*da0073e9SAndroid Build Coastguard Worker def test_polygamma_neg(self, device): 660*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 
661*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"polygamma\(n, x\) does not support negative n\." 662*da0073e9SAndroid Build Coastguard Worker ): 663*da0073e9SAndroid Build Coastguard Worker torch.polygamma(-1, torch.tensor([1.0, 2.0], device=device)) 664*da0073e9SAndroid Build Coastguard Worker 665*da0073e9SAndroid Build Coastguard Worker # TODO resolve with opinfos 666*da0073e9SAndroid Build Coastguard Worker @onlyCPU 667*da0073e9SAndroid Build Coastguard Worker def test_op_invert(self, device): 668*da0073e9SAndroid Build Coastguard Worker res = 0xFFFF - torch.arange(127, dtype=torch.int8) 669*da0073e9SAndroid Build Coastguard Worker for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): 670*da0073e9SAndroid Build Coastguard Worker a = torch.arange(127, dtype=dtype) 671*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.to(dtype), ~a) 672*da0073e9SAndroid Build Coastguard Worker 673*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([True, False]), ~torch.tensor([False, True])) 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker # test exceptions 676*da0073e9SAndroid Build Coastguard Worker for dtype in (torch.half, torch.float, torch.double): 677*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(10, dtype=dtype) 678*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 679*da0073e9SAndroid Build Coastguard Worker b = ~a 680*da0073e9SAndroid Build Coastguard Worker 681*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.complex64, torch.complex128) 682*da0073e9SAndroid Build Coastguard Worker def test_abs_angle_complex_to_float(self, device, dtype): 683*da0073e9SAndroid Build Coastguard Worker # Constructs random complex values 684*da0073e9SAndroid Build Coastguard Worker from random import random 685*da0073e9SAndroid Build Coastguard Worker 686*da0073e9SAndroid Build Coastguard Worker random_vals = [] 687*da0073e9SAndroid Build 
Coastguard Worker for multiplier in (-1, 1, -10, 10, -100, 100): 688*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 689*da0073e9SAndroid Build Coastguard Worker random_vals.append( 690*da0073e9SAndroid Build Coastguard Worker complex(random() * multiplier, random() * multiplier) 691*da0073e9SAndroid Build Coastguard Worker ) 692*da0073e9SAndroid Build Coastguard Worker 693*da0073e9SAndroid Build Coastguard Worker for vals in (random_vals, []): 694*da0073e9SAndroid Build Coastguard Worker a = np.array(vals, dtype=torch_to_numpy_dtype_dict[dtype]) 695*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(vals, device=device, dtype=dtype) 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker for fn_name in ("abs", "angle"): 698*da0073e9SAndroid Build Coastguard Worker torch_fn = getattr(torch, fn_name) 699*da0073e9SAndroid Build Coastguard Worker np_fn = getattr(np, fn_name) 700*da0073e9SAndroid Build Coastguard Worker 701*da0073e9SAndroid Build Coastguard Worker # Tests function 702*da0073e9SAndroid Build Coastguard Worker np_result = torch.from_numpy(np_fn(a)) 703*da0073e9SAndroid Build Coastguard Worker torch_result = torch_fn(t).cpu() 704*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_result, torch_result, exact_dtype=True) 705*da0073e9SAndroid Build Coastguard Worker 706*da0073e9SAndroid Build Coastguard Worker # Tests float out 707*da0073e9SAndroid Build Coastguard Worker float_dtype = ( 708*da0073e9SAndroid Build Coastguard Worker torch.float32 if dtype is torch.complex64 else torch.float64 709*da0073e9SAndroid Build Coastguard Worker ) 710*da0073e9SAndroid Build Coastguard Worker np_float_out = np_fn(a).astype(torch_to_numpy_dtype_dict[float_dtype]) 711*da0073e9SAndroid Build Coastguard Worker float_out = torch.empty_like(t, dtype=float_dtype) 712*da0073e9SAndroid Build Coastguard Worker torch_fn(t, out=float_out) 713*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu()) 714*da0073e9SAndroid Build Coastguard Worker 715*da0073e9SAndroid Build Coastguard Worker # Tests float out (resized out) 716*da0073e9SAndroid Build Coastguard Worker float_out = torch.empty(1, device=device, dtype=float_dtype) 717*da0073e9SAndroid Build Coastguard Worker torch_fn(t, out=float_out) 718*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu()) 719*da0073e9SAndroid Build Coastguard Worker 720*da0073e9SAndroid Build Coastguard Worker # Tests complex out 721*da0073e9SAndroid Build Coastguard Worker np_complex_out = np_fn(a).astype(torch_to_numpy_dtype_dict[dtype]) 722*da0073e9SAndroid Build Coastguard Worker complex_out = torch.empty_like(t) 723*da0073e9SAndroid Build Coastguard Worker torch_fn(t, out=complex_out) 724*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu()) 725*da0073e9SAndroid Build Coastguard Worker 726*da0073e9SAndroid Build Coastguard Worker # Tests complex out (resized out) 727*da0073e9SAndroid Build Coastguard Worker complex_out = torch.empty(0, device=device, dtype=dtype) 728*da0073e9SAndroid Build Coastguard Worker torch_fn(t, out=complex_out) 729*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu()) 730*da0073e9SAndroid Build Coastguard Worker 731*da0073e9SAndroid Build Coastguard Worker # Tests long out behavior (expected failure) 732*da0073e9SAndroid Build Coastguard Worker long_out = torch.empty(0, device=device, dtype=torch.long) 733*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 734*da0073e9SAndroid Build Coastguard Worker torch_fn(t, out=long_out) 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard Worker # Tests inplace 737*da0073e9SAndroid Build Coastguard Worker if fn_name == "abs": 738*da0073e9SAndroid Build Coastguard Worker 
torch_inplace_method = getattr(torch.Tensor, fn_name + "_") 739*da0073e9SAndroid Build Coastguard Worker np_fn(a, out=a) 740*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 741*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 742*da0073e9SAndroid Build Coastguard Worker RuntimeError, 743*da0073e9SAndroid Build Coastguard Worker "In-place abs is not supported for complex tensors.", 744*da0073e9SAndroid Build Coastguard Worker ): 745*da0073e9SAndroid Build Coastguard Worker torch_inplace_method(t) 746*da0073e9SAndroid Build Coastguard Worker return 747*da0073e9SAndroid Build Coastguard Worker torch_inplace_method(t) 748*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.from_numpy(a), t.cpu()) 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker # Note: angle does not have an in-place variant 751*da0073e9SAndroid Build Coastguard Worker if fn_name == "angle": 752*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AttributeError): 753*da0073e9SAndroid Build Coastguard Worker torch_inplace_method = getattr(torch.Tensor, fn_name + "_") 754*da0073e9SAndroid Build Coastguard Worker 755*da0073e9SAndroid Build Coastguard Worker def check_internal_mem_overlap( 756*da0073e9SAndroid Build Coastguard Worker self, inplace_op, num_inputs, dtype, device, expected_failure=False 757*da0073e9SAndroid Build Coastguard Worker ): 758*da0073e9SAndroid Build Coastguard Worker if isinstance(inplace_op, str): 759*da0073e9SAndroid Build Coastguard Worker inplace_op = getattr(torch.Tensor, inplace_op) 760*da0073e9SAndroid Build Coastguard Worker input = torch.randn(1, dtype=dtype, device=device).expand(3, 3) 761*da0073e9SAndroid Build Coastguard Worker inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)] 762*da0073e9SAndroid Build Coastguard Worker if not expected_failure: 763*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "single memory 
location"): 764*da0073e9SAndroid Build Coastguard Worker inplace_op(*inputs) 765*da0073e9SAndroid Build Coastguard Worker else: 766*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 767*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "single memory location"): 768*da0073e9SAndroid Build Coastguard Worker inplace_op(*inputs) 769*da0073e9SAndroid Build Coastguard Worker 770*da0073e9SAndroid Build Coastguard Worker def unary_check_input_output_mem_overlap( 771*da0073e9SAndroid Build Coastguard Worker self, data, sz, op, expected_failure=False 772*da0073e9SAndroid Build Coastguard Worker ): 773*da0073e9SAndroid Build Coastguard Worker def _test(op, output, input): 774*da0073e9SAndroid Build Coastguard Worker output_exp = torch.empty_like(output) 775*da0073e9SAndroid Build Coastguard Worker op(input, out=output_exp) 776*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op(input, out=output), output_exp, msg=op.__name__) 777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Worker # output is identical to input: 779*da0073e9SAndroid Build Coastguard Worker _test(op, output=data[0:sz], input=data[0:sz]) 780*da0073e9SAndroid Build Coastguard Worker # output and input are independent: 781*da0073e9SAndroid Build Coastguard Worker _test(op, output=data[0:sz], input=data[sz : 2 * sz]) 782*da0073e9SAndroid Build Coastguard Worker # output partially overlaps with input: 783*da0073e9SAndroid Build Coastguard Worker if not expected_failure: 784*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "unsupported operation"): 785*da0073e9SAndroid Build Coastguard Worker _test(op, data[0:sz], data[1 : sz + 1]) 786*da0073e9SAndroid Build Coastguard Worker else: 787*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 788*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "unsupported operation"): 
789*da0073e9SAndroid Build Coastguard Worker _test(op, data[0:sz], data[1 : sz + 1]) 790*da0073e9SAndroid Build Coastguard Worker 791*da0073e9SAndroid Build Coastguard Worker # TODO: run on non-native device types 792*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/126474 793*da0073e9SAndroid Build Coastguard Worker @xfailIfTorchDynamo 794*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 795*da0073e9SAndroid Build Coastguard Worker def test_unary_out_op_mem_overlap(self, device, dtype): 796*da0073e9SAndroid Build Coastguard Worker sz = 3 797*da0073e9SAndroid Build Coastguard Worker doubles = torch.randn(2 * sz, dtype=dtype, device=device) 798*da0073e9SAndroid Build Coastguard Worker positives = torch.randint(1, 100, (2 * sz,), device=device).double() 799*da0073e9SAndroid Build Coastguard Worker ints = torch.randint(-100, 100, (2 * sz,), device=device) 800*da0073e9SAndroid Build Coastguard Worker unary_mem_overlap_cases = [ 801*da0073e9SAndroid Build Coastguard Worker ("abs", doubles, True, True, "cpu"), 802*da0073e9SAndroid Build Coastguard Worker ("abs", doubles, True, True, "cuda"), 803*da0073e9SAndroid Build Coastguard Worker ("acos", doubles, True, True, "cpu"), 804*da0073e9SAndroid Build Coastguard Worker ("acos", doubles, True, True, "cuda"), 805*da0073e9SAndroid Build Coastguard Worker ("asin", doubles, True, True, "cpu"), 806*da0073e9SAndroid Build Coastguard Worker ("asin", doubles, True, True, "cuda"), 807*da0073e9SAndroid Build Coastguard Worker ("atan", doubles, True, True, "cpu"), 808*da0073e9SAndroid Build Coastguard Worker ("atan", doubles, True, True, "cuda"), 809*da0073e9SAndroid Build Coastguard Worker ("acosh", doubles, True, True, "cpu"), 810*da0073e9SAndroid Build Coastguard Worker ("acosh", doubles, True, True, "cuda"), 811*da0073e9SAndroid Build Coastguard Worker ("asinh", doubles, True, True, "cpu"), 812*da0073e9SAndroid Build Coastguard Worker ("asinh", doubles, True, True, "cuda"), 
813*da0073e9SAndroid Build Coastguard Worker ("atanh", doubles, True, True, "cpu"), 814*da0073e9SAndroid Build Coastguard Worker ("atanh", doubles, True, True, "cuda"), 815*da0073e9SAndroid Build Coastguard Worker ("bitwise_not", ints, True, True, "cpu"), 816*da0073e9SAndroid Build Coastguard Worker ("bitwise_not", ints, True, True, "cuda"), 817*da0073e9SAndroid Build Coastguard Worker ("ceil", doubles, True, True, "cpu"), 818*da0073e9SAndroid Build Coastguard Worker ("ceil", doubles, True, True, "cuda"), 819*da0073e9SAndroid Build Coastguard Worker ("cos", doubles, True, True, "cpu"), 820*da0073e9SAndroid Build Coastguard Worker ("cos", doubles, True, True, "cuda"), 821*da0073e9SAndroid Build Coastguard Worker ("cosh", doubles, True, True, "cpu"), 822*da0073e9SAndroid Build Coastguard Worker ("cosh", doubles, True, True, "cuda"), 823*da0073e9SAndroid Build Coastguard Worker ("digamma", doubles, True, True, "cpu"), 824*da0073e9SAndroid Build Coastguard Worker ("erf", doubles, True, True, "cpu"), 825*da0073e9SAndroid Build Coastguard Worker ("erf", doubles, True, True, "cuda"), 826*da0073e9SAndroid Build Coastguard Worker ("erfc", doubles, True, True, "cpu"), 827*da0073e9SAndroid Build Coastguard Worker ("erfc", doubles, True, True, "cuda"), 828*da0073e9SAndroid Build Coastguard Worker ("erfinv", doubles, True, True, "cpu"), 829*da0073e9SAndroid Build Coastguard Worker ("erfinv", doubles, True, True, "cuda"), 830*da0073e9SAndroid Build Coastguard Worker ("exp", doubles, True, True, "cpu"), 831*da0073e9SAndroid Build Coastguard Worker ("exp", doubles, True, True, "cuda"), 832*da0073e9SAndroid Build Coastguard Worker ("exp2", doubles, True, True, "cpu"), 833*da0073e9SAndroid Build Coastguard Worker ("exp2", doubles, True, True, "cuda"), 834*da0073e9SAndroid Build Coastguard Worker ("expm1", doubles, True, True, "cpu"), 835*da0073e9SAndroid Build Coastguard Worker ("expm1", doubles, True, True, "cuda"), 836*da0073e9SAndroid Build Coastguard Worker ("floor", doubles, 
True, True, "cpu"), 837*da0073e9SAndroid Build Coastguard Worker ("floor", doubles, True, True, "cuda"), 838*da0073e9SAndroid Build Coastguard Worker ("frac", doubles, True, True, "cpu"), 839*da0073e9SAndroid Build Coastguard Worker ("frac", doubles, True, True, "cuda"), 840*da0073e9SAndroid Build Coastguard Worker ("i0", doubles, True, True, "cpu"), 841*da0073e9SAndroid Build Coastguard Worker ("i0", doubles, True, True, "cuda"), 842*da0073e9SAndroid Build Coastguard Worker ("log", positives, True, True, "cpu"), 843*da0073e9SAndroid Build Coastguard Worker ("log", positives, True, True, "cuda"), 844*da0073e9SAndroid Build Coastguard Worker ("log10", positives, True, True, "cpu"), 845*da0073e9SAndroid Build Coastguard Worker ("log10", positives, True, True, "cuda"), 846*da0073e9SAndroid Build Coastguard Worker ("log1p", positives, True, True, "cpu"), 847*da0073e9SAndroid Build Coastguard Worker ("log1p", positives, True, True, "cuda"), 848*da0073e9SAndroid Build Coastguard Worker ("log2", positives, True, True, "cpu"), 849*da0073e9SAndroid Build Coastguard Worker ("log2", positives, True, True, "cuda"), 850*da0073e9SAndroid Build Coastguard Worker ("neg", doubles, True, True, "cpu"), 851*da0073e9SAndroid Build Coastguard Worker ("neg", doubles, True, True, "cuda"), 852*da0073e9SAndroid Build Coastguard Worker ("reciprocal", doubles, True, True, "cpu"), 853*da0073e9SAndroid Build Coastguard Worker ("reciprocal", doubles, True, True, "cuda"), 854*da0073e9SAndroid Build Coastguard Worker ("round", doubles, True, True, "cpu"), 855*da0073e9SAndroid Build Coastguard Worker ("round", doubles, True, True, "cuda"), 856*da0073e9SAndroid Build Coastguard Worker ("rsqrt", positives, True, True, "cpu"), 857*da0073e9SAndroid Build Coastguard Worker ("rsqrt", positives, True, True, "cuda"), 858*da0073e9SAndroid Build Coastguard Worker ("sin", doubles, True, True, "cpu"), 859*da0073e9SAndroid Build Coastguard Worker ("sin", doubles, True, True, "cuda"), 860*da0073e9SAndroid Build 
Coastguard Worker ("sinh", doubles, True, True, "cpu"), 861*da0073e9SAndroid Build Coastguard Worker ("sinh", doubles, False, True, "cuda"), 862*da0073e9SAndroid Build Coastguard Worker ("sigmoid", doubles, True, True, "cpu"), 863*da0073e9SAndroid Build Coastguard Worker ("sigmoid", doubles, True, True, "cuda"), 864*da0073e9SAndroid Build Coastguard Worker ("logit", doubles, True, True, "cpu"), 865*da0073e9SAndroid Build Coastguard Worker ("logit", doubles, True, True, "cuda"), 866*da0073e9SAndroid Build Coastguard Worker ("sqrt", doubles, True, True, "cpu"), 867*da0073e9SAndroid Build Coastguard Worker ("sqrt", doubles, False, True, "cuda"), 868*da0073e9SAndroid Build Coastguard Worker ("tan", doubles, True, True, "cpu"), 869*da0073e9SAndroid Build Coastguard Worker ("tan", doubles, True, True, "cuda"), 870*da0073e9SAndroid Build Coastguard Worker ("tanh", doubles, True, True, "cpu"), 871*da0073e9SAndroid Build Coastguard Worker ("tanh", doubles, True, True, "cuda"), 872*da0073e9SAndroid Build Coastguard Worker ("trunc", doubles, True, True, "cpu"), 873*da0073e9SAndroid Build Coastguard Worker ("trunc", doubles, True, True, "cuda"), 874*da0073e9SAndroid Build Coastguard Worker ] 875*da0073e9SAndroid Build Coastguard Worker 876*da0073e9SAndroid Build Coastguard Worker for ( 877*da0073e9SAndroid Build Coastguard Worker fn, 878*da0073e9SAndroid Build Coastguard Worker inputs, 879*da0073e9SAndroid Build Coastguard Worker has_input_output_mem_overlap_check, 880*da0073e9SAndroid Build Coastguard Worker has_internal_mem_overlap_check, 881*da0073e9SAndroid Build Coastguard Worker dev, 882*da0073e9SAndroid Build Coastguard Worker ) in unary_mem_overlap_cases: 883*da0073e9SAndroid Build Coastguard Worker if dev != device: 884*da0073e9SAndroid Build Coastguard Worker continue 885*da0073e9SAndroid Build Coastguard Worker out_fn = getattr(torch, fn) 886*da0073e9SAndroid Build Coastguard Worker in_fn = getattr(torch.Tensor, fn + "_") 887*da0073e9SAndroid Build Coastguard Worker 
            # Out-of-place check: op must detect (or is known to miss —
            # expected_failure) overlap between input and output buffers.
            self.unary_check_input_output_mem_overlap(
                inputs,
                sz,
                out_fn,
                expected_failure=not has_input_output_mem_overlap_check,
            )

            # In-place variant: check detection of internal (self) overlap.
            self.check_internal_mem_overlap(
                in_fn,
                1,
                dtype,
                dev,
                expected_failure=not has_internal_mem_overlap_check,
            )

    # TODO: opinfo hardshrink
    @onlyCPU
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardshrink(self, device, dtype):
        """Compare hardshrink output against hand-computed expectations for a
        couple of lambd values, the default lambd, and a non-contiguous input."""
        data = torch.tensor([1, 0.5, 0.3, 0.6], dtype=dtype, device=device).view(2, 2)
        # lambd=0.3 zeroes only the 0.3 entry.
        self.assertEqual(
            torch.tensor([1, 0.5, 0, 0.6], dtype=dtype, device=device).view(2, 2),
            data.hardshrink(0.3),
        )
        # lambd=0.5 additionally zeroes the 0.5 entry.
        self.assertEqual(
            torch.tensor([1, 0, 0, 0.6], dtype=dtype, device=device).view(2, 2),
            data.hardshrink(0.5),
        )

        # test default lambd=0.5
        self.assertEqual(data.hardshrink(), data.hardshrink(0.5))

        # test non-contiguous case
        self.assertEqual(
            torch.tensor([1, 0, 0.5, 0.6], dtype=dtype, device=device).view(2, 2),
            data.t().hardshrink(0.3),
        )

    @onlyCPU
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardshrink_edge_cases(self, device, dtype) -> None:
        """Exercise hardshrink at boundary inputs (0, +/-tiny, +/-max, +/-inf)
        across thresholds from 0 up to inf."""

        def h(values, l_expected):
            # For each threshold l, check elementwise equality against the
            # expected row (comparison done via a boolean mask so that exact
            # zeros are matched exactly).
            for l, expected in l_expected.items():
                values_tensor = torch.tensor(
                    [float(v) for v in values], dtype=dtype, device=device
                )
                expected_tensor = torch.tensor(
                    [float(v) for v in expected], dtype=dtype, device=device
                )
                self.assertEqual(
                    expected_tensor == values_tensor.hardshrink(l),
                    torch.ones_like(values_tensor, dtype=torch.bool),
                )

        def test_helper(min, max):
            # Rows map threshold -> expected outputs for the fixed input row.
            h(
                [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
                {
                    0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
                    min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
                    0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf],
                    1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf],
                    max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf],
                    inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                },
            )

        test_helper(torch.finfo(dtype).tiny, torch.finfo(dtype).max)

    @onlyCPU
    @slowTest
    @dtypes(torch.float)
    @unittest.skipIf(True, "Insufficient memory on linux.(2|4)xlarge")
    def test_exp_slow(self, device, dtype):
        """Regression test: exp over a 2**31-element tensor matches a
        broadcast single-element result."""
        # Test for https://github.com/pytorch/pytorch/issues/17271
        # This is pretty slow on my Macbook but it only takes a few
        # seconds on a beefy Xeon server
        a = torch.exp(torch.ones(2**31, dtype=dtype, device=device))
        b = torch.exp(torch.ones(1, dtype=dtype, device=device))
        self.assertEqual(a, b.expand(2**31))
968*da0073e9SAndroid Build Coastguard Worker 969*da0073e9SAndroid Build Coastguard Worker @precisionOverride( 970*da0073e9SAndroid Build Coastguard Worker {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002} 971*da0073e9SAndroid Build Coastguard Worker ) 972*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.bfloat16) 973*da0073e9SAndroid Build Coastguard Worker def test_hardswish(self, device, dtype): 974*da0073e9SAndroid Build Coastguard Worker inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000] 975*da0073e9SAndroid Build Coastguard Worker expectedOutput = np.multiply( 976*da0073e9SAndroid Build Coastguard Worker inputValues, np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0 977*da0073e9SAndroid Build Coastguard Worker ) 978*da0073e9SAndroid Build Coastguard Worker 979*da0073e9SAndroid Build Coastguard Worker inputTensor = torch.tensor(inputValues, dtype=dtype, device=device) 980*da0073e9SAndroid Build Coastguard Worker expectedOutputTensor = torch.tensor(expectedOutput, dtype=dtype, device=device) 981*da0073e9SAndroid Build Coastguard Worker 982*da0073e9SAndroid Build Coastguard Worker # normal 983*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 984*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.hardswish(inputTensor), expectedOutputTensor 985*da0073e9SAndroid Build Coastguard Worker ) 986*da0073e9SAndroid Build Coastguard Worker 987*da0073e9SAndroid Build Coastguard Worker # inplace 988*da0073e9SAndroid Build Coastguard Worker inputTensorCpy = inputTensor.clone().detach() 989*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.hardswish(inputTensorCpy, inplace=True) 990*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inputTensorCpy, expectedOutputTensor) 991*da0073e9SAndroid Build Coastguard Worker 992*da0073e9SAndroid Build Coastguard Worker @precisionOverride( 993*da0073e9SAndroid Build Coastguard Worker {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 
0.0002} 994*da0073e9SAndroid Build Coastguard Worker ) 995*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.bfloat16) 996*da0073e9SAndroid Build Coastguard Worker def test_hardsigmoid(self, device, dtype): 997*da0073e9SAndroid Build Coastguard Worker inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000] 998*da0073e9SAndroid Build Coastguard Worker expectedOutput = np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0 999*da0073e9SAndroid Build Coastguard Worker 1000*da0073e9SAndroid Build Coastguard Worker inputTensor = torch.tensor(inputValues, dtype=dtype, device=device) 1001*da0073e9SAndroid Build Coastguard Worker 1002*da0073e9SAndroid Build Coastguard Worker # normal 1003*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1004*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.hardsigmoid(inputTensor), 1005*da0073e9SAndroid Build Coastguard Worker torch.tensor(expectedOutput, dtype=dtype, device=device), 1006*da0073e9SAndroid Build Coastguard Worker ) 1007*da0073e9SAndroid Build Coastguard Worker 1008*da0073e9SAndroid Build Coastguard Worker # inplace 1009*da0073e9SAndroid Build Coastguard Worker inputTensorCpy = inputTensor.clone().detach() 1010*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1011*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True), 1012*da0073e9SAndroid Build Coastguard Worker torch.tensor(expectedOutput, dtype=dtype, device=device), 1013*da0073e9SAndroid Build Coastguard Worker ) 1014*da0073e9SAndroid Build Coastguard Worker 1015*da0073e9SAndroid Build Coastguard Worker @precisionOverride( 1016*da0073e9SAndroid Build Coastguard Worker {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002} 1017*da0073e9SAndroid Build Coastguard Worker ) 1018*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.bfloat16) 1019*da0073e9SAndroid Build Coastguard Worker def test_hardsigmoid_backward(self, 
device, dtype): 1020*da0073e9SAndroid Build Coastguard Worker inputValues = [-3.0, 3.0, -2.0, 2.0, -6.0, 6.0] 1021*da0073e9SAndroid Build Coastguard Worker expectedValues = [0.0, 0.0, 1.0 / 6.0, 1.0 / 6.0, 0.0, 0.0] 1022*da0073e9SAndroid Build Coastguard Worker inputTensor = torch.tensor( 1023*da0073e9SAndroid Build Coastguard Worker inputValues, dtype=dtype, device=device 1024*da0073e9SAndroid Build Coastguard Worker ).requires_grad_() 1025*da0073e9SAndroid Build Coastguard Worker expetedTensor = torch.tensor(expectedValues, dtype=dtype, device=device) 1026*da0073e9SAndroid Build Coastguard Worker out = torch.nn.functional.hardsigmoid(inputTensor) 1027*da0073e9SAndroid Build Coastguard Worker out.backward(torch.ones_like(inputTensor)) 1028*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inputTensor.grad, expetedTensor) 1029*da0073e9SAndroid Build Coastguard Worker 1030*da0073e9SAndroid Build Coastguard Worker @skipIfNoSciPy 1031*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 1032*da0073e9SAndroid Build Coastguard Worker def test_silu(self, device, dtype): 1033*da0073e9SAndroid Build Coastguard Worker input_np = np.random.randn(5, 8) 1034*da0073e9SAndroid Build Coastguard Worker special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]] 1035*da0073e9SAndroid Build Coastguard Worker input_np = np.concatenate((input_np, special_input), axis=0).astype( 1036*da0073e9SAndroid Build Coastguard Worker torch_to_numpy_dtype_dict[dtype] 1037*da0073e9SAndroid Build Coastguard Worker ) 1038*da0073e9SAndroid Build Coastguard Worker expected_output_np = input_np * scipy.special.expit(input_np) 1039*da0073e9SAndroid Build Coastguard Worker 1040*da0073e9SAndroid Build Coastguard Worker expected_output = torch.from_numpy(expected_output_np).to(device) 1041*da0073e9SAndroid Build Coastguard Worker expected_output_noncontig = expected_output.transpose(0, 1) 1042*da0073e9SAndroid Build Coastguard Worker 1043*da0073e9SAndroid Build Coastguard Worker 
atol = 1e-6 1044*da0073e9SAndroid Build Coastguard Worker rtol = 1e-6 1045*da0073e9SAndroid Build Coastguard Worker 1046*da0073e9SAndroid Build Coastguard Worker input = torch.from_numpy(input_np).clone().contiguous().to(device) 1047*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1048*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.silu(input), expected_output, atol=atol, rtol=rtol 1049*da0073e9SAndroid Build Coastguard Worker ) 1050*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1051*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.silu(input, inplace=True), 1052*da0073e9SAndroid Build Coastguard Worker expected_output, 1053*da0073e9SAndroid Build Coastguard Worker atol=atol, 1054*da0073e9SAndroid Build Coastguard Worker rtol=rtol, 1055*da0073e9SAndroid Build Coastguard Worker ) 1056*da0073e9SAndroid Build Coastguard Worker 1057*da0073e9SAndroid Build Coastguard Worker input = torch.from_numpy(input_np).clone().to(device) 1058*da0073e9SAndroid Build Coastguard Worker input_noncontig = input.transpose(0, 1) 1059*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1060*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.silu(input_noncontig), 1061*da0073e9SAndroid Build Coastguard Worker expected_output_noncontig, 1062*da0073e9SAndroid Build Coastguard Worker atol=atol, 1063*da0073e9SAndroid Build Coastguard Worker rtol=rtol, 1064*da0073e9SAndroid Build Coastguard Worker ) 1065*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1066*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.silu(input_noncontig, inplace=True), 1067*da0073e9SAndroid Build Coastguard Worker expected_output_noncontig, 1068*da0073e9SAndroid Build Coastguard Worker atol=atol, 1069*da0073e9SAndroid Build Coastguard Worker rtol=rtol, 1070*da0073e9SAndroid Build Coastguard Worker ) 1071*da0073e9SAndroid Build Coastguard Worker 1072*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.complex64, torch.complex128) 
1073*da0073e9SAndroid Build Coastguard Worker def test_silu_complex(self, device, dtype): 1074*da0073e9SAndroid Build Coastguard Worker atol = 1e-6 1075*da0073e9SAndroid Build Coastguard Worker rtol = 1e-6 1076*da0073e9SAndroid Build Coastguard Worker inouts = [ 1077*da0073e9SAndroid Build Coastguard Worker (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j), 1078*da0073e9SAndroid Build Coastguard Worker (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j), 1079*da0073e9SAndroid Build Coastguard Worker (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j), 1080*da0073e9SAndroid Build Coastguard Worker (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j), 1081*da0073e9SAndroid Build Coastguard Worker (2.0j, -1.55740761756896972656 + 0.99999988079071044922j) 1082*da0073e9SAndroid Build Coastguard Worker ] 1083*da0073e9SAndroid Build Coastguard Worker 1084*da0073e9SAndroid Build Coastguard Worker for inp, out in inouts: 1085*da0073e9SAndroid Build Coastguard Worker res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device)) 1086*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.any(torch.isnan(res))) 1087*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.real, out.real, atol=atol, rtol=rtol) 1088*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol) 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker for inp, out in inouts: 1091*da0073e9SAndroid Build Coastguard Worker res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device), inplace=True) 1092*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.any(torch.isnan(res))) 1093*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.real, out.real, atol=atol, rtol=rtol) 1094*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol) 1095*da0073e9SAndroid Build 
    # It is not obvious how to merge this into OpInfo because these inputs
    # succeed for gradcheck but are expected to fail for gradgradcheck
    @dtypes(torch.double)
    def test_sinc(self, device, dtype):
        """Gradcheck sinc at zero, a value extremely close to zero, and an
        ordinary point."""
        # The derivative of sinc(x) at x=0 has to be special cased.
        # A naive computation will result in 0/0 -> NaN.
        # We also need to be careful when we are very close to 0, as the
        # derivative's denominator is squared, and there are some floats
        # that are positive and whose squares are zero.
        a = torch.tensor(
            [0.0, torch.finfo(torch.double).tiny, 1.0],
            dtype=dtype,
            requires_grad=True,
            device=device,
        )
        gradcheck(torch.sinc, a)

    @skipIfNoSciPy
    @dtypes(torch.float, torch.double)
    def test_mish(self, device, dtype):
        """Compare mish against a NumPy reference for contiguous and
        non-contiguous inputs, out-of-place and in-place."""
        input_np = np.random.randn(5, 8)
        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
        input_np = np.concatenate((input_np, special_input), axis=0).astype(
            torch_to_numpy_dtype_dict[dtype]
        )
        # mish(x) = x * tanh(softplus(x)); log1p(exp(x)) is softplus.
        expected_output_np = input_np * np.tanh(np.log1p(np.exp(input_np)))

        expected_output = torch.from_numpy(expected_output_np).to(device)
        expected_output_noncontig = expected_output.transpose(0, 1)

        atol = 1e-6
        rtol = 1e-6

        # Contiguous input, out-of-place then in-place.
        input = torch.from_numpy(input_np).clone().contiguous().to(device)
        self.assertEqual(
            torch.nn.functional.mish(input), expected_output, atol=atol, rtol=rtol
        )
        self.assertEqual(
            torch.nn.functional.mish(input, inplace=True),
            expected_output,
            atol=atol,
            rtol=rtol,
        )

        # Non-contiguous (transposed) input, out-of-place then in-place.
        input = torch.from_numpy(input_np).clone().to(device)
        input_noncontig = input.transpose(0, 1)
        self.assertEqual(
            torch.nn.functional.mish(input_noncontig),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )
        self.assertEqual(
            torch.nn.functional.mish(input_noncontig, inplace=True),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )

    @dtypes(torch.complex64, torch.complex128)
    def test_log1p_complex(self, device, dtype):
        """log1p on complex inputs, checked against high-precision reference
        values, elementwise and as a whole tensor."""
        # The output values here were obtained using arbitrary precision math (mpmath)
        # and double checked with WolframAlpha.
        # Not using numpy's log1p here because by the time of writing this,
        # np.log1p has precision problems for small complex input values, see here:
        # https://github.com/numpy/numpy/issues/22609
        inouts = [
            (0.2 + 0.3j, 0.21263386770217202 + 0.24497866312686414j),
            (1e-19 + 1e-18j, 1e-19 + 1e-18j),
            (1e-18 + 0.1j, 0.00497517 + 0.0996687j),
            (0.1 + 1e-18j, 0.0953102 + 9.090909090909090909e-19j),
            (0.5 + 0j, 0.40546510810816 + 0j),
            (0.0 + 0.5j, 0.111571776 + 0.463647609j),
            (2.0 + 1.0j, 1.151292546497023 + 0.3217505543966422j),
            (-1.0 + 2.0j, 0.6931471805599453 + 1.570796326794897j),
            (2.0j, 0.80471895621705014 + 1.1071487177940904j),
            (-2.0j, 0.80471895621705014 - 1.1071487177940904j),
        ]
        # test the extreme values
        if dtype == torch.complex128:
            inouts += [
                (-1 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
                (1e250 + 1j, 575.6462732485114 + 1e-250j),
                (1e250 + 1e250j, 575.9928468387914 + 0.7853981633974483j),
                (1e-250 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
                (1e-250 + 2e-250j, 1e-250 + 2e-250j),
                (1e250 + 1e-250j, 575.6462732485114 + 0.0j),
            ]
        elif dtype == torch.complex64:
            inouts += [
                (-1 + 1e30j, 69.07755278982137 + 1.5707963267948966j),
                (1e30 + 1j, 69.07755278982137 + 1e-30j),
                (1e30 + 1e30j, 69.42412638010134 + 0.7853981633974483j),
                (1e-30 + 1e30j, 69.07755278982137 + 1.5707963267948966j),
                (1e-30 + 2e-30j, 1e-30 + 2e-30j),
                (1e30 + 1e-30j, 69.07755278982137 + 0.0j),
            ]

        # test the log1p individually
        for inp, out in inouts:
            res = torch.log1p(torch.tensor(inp, dtype=dtype, device=device))
            self.assertFalse(torch.any(torch.isnan(res)))
            # setting up atol == 0.0 because some part has very small values
            self.assertEqual(res.real, out.real, atol=0.0, rtol=1e-6)
            self.assertEqual(res.imag, out.imag, atol=0.0, rtol=1e-6)

        # test the log1p in tensor
        inp_lst, out_lst = (list(elmt) for elmt in zip(*inouts))
        inp_tens = torch.tensor(inp_lst, dtype=dtype, device=device)
        out_tens = torch.tensor(out_lst, dtype=dtype, device=device)
        res_tens = torch.log1p(inp_tens)
        self.assertEqual(res_tens.real, out_tens.real, atol=0.0, rtol=1e-6)
        self.assertEqual(res_tens.imag, out_tens.imag, atol=0.0, rtol=1e-6)

    # do ops like threshold need a test_unary(_nonufunc) test suite?
    @onlyCPU
    @dtypes(*get_all_math_dtypes("cpu"))
    def test_threshold(self, device, dtype):
        """Smoke-test torch.threshold on a random +/-1 vector: the output must
        contain non-positive elements (the thresholded ones)."""
        if dtype != torch.uint8 and dtype != torch.float16 and not dtype.is_complex:
            # 100 is wide enough to use AVX2 instructions for all types
            x = (
                torch.randn(100, dtype=torch.float, device=device)
                .sign()
                .to(dtype=dtype)
            )
            y = torch.threshold(x, 0, 0)
            self.assertTrue(y.le(0).any())

    def _helper_test_igamma(self, loglo, loghi, device, dtype, torch_fcn, scipy_fcn):
        """Compare torch_fcn against scipy_fcn on log-spaced argument pairs.

        Builds a 500x1 column of values log-spaced between e**loglo and
        e**loghi and evaluates several broadcasting and contiguity
        combinations of it against itself.
        """
        exp1 = 2.71828182846  # base e, spelled out as a literal for logspace
        vec1 = torch.logspace(
            loglo, loghi, steps=500, base=exp1, dtype=torch.float64, device=device
        ).unsqueeze(-1)
        vec1 = vec1.to(dtype)
        inputs = [
            (vec1, vec1.transpose(0, 1)),
            (vec1, vec1),  # for large number, it should approach 0.5
            (vec1, 0.5 * vec1),  # test for considerable ratio
            (vec1, 2.0 * vec1),
            (vec1[::2, :], vec1[::2, :]),  # contiguous/noncontiguous tests
            (vec1[::2, :], vec1[: vec1.shape[0] // 2, :]),
            (vec1[: vec1.shape[0] // 2, :], vec1[::2, :]),
        ]
        half_prec = dtype in [torch.bfloat16, torch.float16]
        for input0, input1 in inputs:
            actual = torch_fcn(input0, input1)
            if half_prec:
                # Compute the scipy reference in float32 and cast the result
                # back down for comparison.
                input0 = input0.to(torch.float)
                input1 = input1.to(torch.float)
            expected = scipy_fcn(input0.cpu().numpy(), input1.cpu().numpy())
            expected = torch.from_numpy(expected).to(dtype)
            self.assertEqual(actual, expected)

    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @onlyNativeDeviceTypes
    def test_igamma_common(self, device, dtype):
        """Check torch.igamma against scipy.special.gammainc on moderate values."""
        # test igamma for reasonable range of values
        loglo = -4  # approx 0.018
        loghi = 4  # approx 54.6
        self._helper_test_igamma(
            loglo, loghi, device, dtype, torch.igamma, scipy.special.gammainc
        )
    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @onlyNativeDeviceTypes
    def test_igammac_common(self, device, dtype):
        """Check torch.igammac against scipy.special.gammaincc on moderate values."""
        # test igammac for reasonable range of values
        loglo = -4  # approx 0.018
        loghi = 4  # approx 54.6
        self._helper_test_igamma(
            loglo, loghi, device, dtype, torch.igammac, scipy.special.gammaincc
        )

    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @onlyNativeDeviceTypes
    def test_igamma_edge_cases(self, device, dtype):
        """torch.igamma limit behavior for 0/inf arguments and nan-producing
        argument pairs."""
        tkwargs = {"dtype": dtype, "device": device}
        infs = torch.zeros((3,), **tkwargs) + float("inf")
        zeros = torch.zeros((3,), **tkwargs)
        ones = torch.ones((3,), **tkwargs)
        zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs)
        small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs)
        nans = torch.zeros((3,), **tkwargs) + float("nan")
        inpouts = [
            # (a , x), out
            ((zeros, small_to_inf), ones),
            ((small_to_inf, zeros), zeros),
            ((infs, zero_to_large), zeros),
            ((zero_to_large, infs), ones),
            ((zeros, zeros), nans),
            ((infs, infs), nans),
            ((-small_to_inf, small_to_inf), nans),
        ]
        for inputs, output in inpouts:
            input0, input1 = inputs
            calc = torch.igamma(input0, input1)
            if torch.all(torch.isnan(output)):
                # nan != nan, so nan expectations need an explicit isnan check.
                self.assertTrue(torch.all(torch.isnan(calc)))
            else:
                self.assertEqual(calc, output)

    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @onlyNativeDeviceTypes
    def test_igammac_edge_cases(self, device, dtype):
        """torch.igammac limit behavior for 0/inf arguments and nan-producing
        argument pairs (complementary to the igamma expectations)."""
        tkwargs = {"dtype": dtype, "device": device}
        infs = torch.zeros((3,), **tkwargs) + float("inf")
        zeros = torch.zeros((3,), **tkwargs)
        ones = torch.ones((3,), **tkwargs)
        zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs)
        small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs)
        nans = torch.zeros((3,), **tkwargs) + float("nan")
        inpouts = [
            # (a , x), out
            ((zeros, small_to_inf), zeros),
            ((small_to_inf, zeros), ones),
            ((infs, zero_to_large), ones),
            ((zero_to_large, infs), zeros),
            ((zeros, zeros), nans),
            ((infs, infs), nans),
            ((-small_to_inf, small_to_inf), nans),
        ]
        for inputs, output in inpouts:
            input0, input1 = inputs
            calc = torch.igammac(input0, input1)
            if torch.all(torch.isnan(output)):
                # nan != nan, so nan expectations need an explicit isnan check.
                self.assertTrue(torch.all(torch.isnan(calc)))
            else:
                self.assertEqual(calc, output)

    def _i0_helper(self, t):
        """Compare torch.i0(t) elementwise against scipy.special.i0."""
        # Test by comparing to scipy
        dtype = t.dtype
        actual = torch.i0(t)
        if dtype is torch.bfloat16:
            # NumPy has no bfloat16; compute the reference in float32.
            t = t.to(torch.float32)
        expected = scipy.special.i0(t.cpu().numpy())
        # Casting down for dtype float16 is required since scipy upcasts to float32
        if dtype is torch.bfloat16 or dtype is torch.float16:
            expected = torch.from_numpy(expected).to(dtype)
        self.assertEqual(actual, expected)

    def _i0_range_helper(self, range, device, dtype):
        """Run _i0_helper on 1000 random values scaled into (-range, range)."""
        # i0 tests are broken up by the domain for which the function does not overflow for each dtype
        # This is done to ensure that the function performs well across all possible input values, without worrying
        # about inf or nan possibilities
        for r in (range, -range):
            t = torch.rand(1000, device=device).to(dtype) * r
            self._i0_helper(t)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.bfloat16, torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_i0_range1(self, device, dtype):
        # This tests the domain for i0 for which float16 does not overflow
        # The domain is (-13.25, 13.25)
self._i0_range_helper(13.25, device, dtype) 1358*da0073e9SAndroid Build Coastguard Worker 1359*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) 1360*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bfloat16, torch.float32, torch.float64) 1361*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1362*da0073e9SAndroid Build Coastguard Worker def test_i0_range2(self, device, dtype): 1363*da0073e9SAndroid Build Coastguard Worker # This tests the domain for i0 for which float32 and bfloat16 does not overflow 1364*da0073e9SAndroid Build Coastguard Worker # The domain is (-88.5, 88.5) 1365*da0073e9SAndroid Build Coastguard Worker self._i0_range_helper(88.5, device, dtype) 1366*da0073e9SAndroid Build Coastguard Worker 1367*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float64) 1368*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1369*da0073e9SAndroid Build Coastguard Worker def test_i0_range3(self, device, dtype): 1370*da0073e9SAndroid Build Coastguard Worker # This tests the domain for i0 for which float64 does not overflow 1371*da0073e9SAndroid Build Coastguard Worker # The domain is (-709.75, 709.75) 1372*da0073e9SAndroid Build Coastguard Worker self._i0_range_helper(709.75, device, dtype) 1373*da0073e9SAndroid Build Coastguard Worker 1374*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) 1375*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bfloat16, torch.float32, torch.float64) 1376*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1377*da0073e9SAndroid Build Coastguard Worker def test_i0_special(self, device, dtype): 1378*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([], device=device, dtype=dtype) 1379*da0073e9SAndroid Build Coastguard Worker self._i0_helper(t) 1380*da0073e9SAndroid Build Coastguard 
Worker 1381*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype) 1382*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.i0(t).isnan().all()) 1383*da0073e9SAndroid Build Coastguard Worker 1384*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) 1385*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bfloat16, torch.float32, torch.float64) 1386*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1387*da0073e9SAndroid Build Coastguard Worker def test_special_i0_i1_vs_scipy(self, device, dtype): 1388*da0073e9SAndroid Build Coastguard Worker def check_equal(t, torch_fn, scipy_fn): 1389*da0073e9SAndroid Build Coastguard Worker # Test by comparing to scipy 1390*da0073e9SAndroid Build Coastguard Worker actual = torch_fn(t) 1391*da0073e9SAndroid Build Coastguard Worker if dtype is torch.bfloat16: 1392*da0073e9SAndroid Build Coastguard Worker t = t.to(torch.float32) 1393*da0073e9SAndroid Build Coastguard Worker expected = scipy_fn(t.cpu().numpy()) 1394*da0073e9SAndroid Build Coastguard Worker 1395*da0073e9SAndroid Build Coastguard Worker # Casting down for dtype float16 is required since scipy upcasts to float32 1396*da0073e9SAndroid Build Coastguard Worker if dtype is torch.bfloat16 or dtype is torch.float16: 1397*da0073e9SAndroid Build Coastguard Worker expected = torch.from_numpy(expected).to(dtype) 1398*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 1399*da0073e9SAndroid Build Coastguard Worker 1400*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([], device=device, dtype=dtype) 1401*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.i0, scipy.special.i0) 1402*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.special.i0e, scipy.special.i0e) 1403*da0073e9SAndroid Build Coastguard Worker if dtype not in [torch.half, torch.bfloat16]: 1404*da0073e9SAndroid 
Build Coastguard Worker check_equal(t, torch.special.i1, scipy.special.i1) 1405*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.special.i1e, scipy.special.i1e) 1406*da0073e9SAndroid Build Coastguard Worker 1407*da0073e9SAndroid Build Coastguard Worker range = (-1e7, 1e7) 1408*da0073e9SAndroid Build Coastguard Worker if dtype == torch.half: 1409*da0073e9SAndroid Build Coastguard Worker range = (-65000, 65000) 1410*da0073e9SAndroid Build Coastguard Worker 1411*da0073e9SAndroid Build Coastguard Worker t = torch.linspace(*range, int(1e4), device=device, dtype=dtype) 1412*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.i0, scipy.special.i0) 1413*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.special.i0e, scipy.special.i0e) 1414*da0073e9SAndroid Build Coastguard Worker if dtype not in [torch.half, torch.bfloat16]: 1415*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.special.i1, scipy.special.i1) 1416*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.special.i1e, scipy.special.i1e) 1417*da0073e9SAndroid Build Coastguard Worker 1418*da0073e9SAndroid Build Coastguard Worker # NaN, inf, -inf are tested in reference_numerics tests. 
1419*da0073e9SAndroid Build Coastguard Worker info = torch.finfo(dtype) 1420*da0073e9SAndroid Build Coastguard Worker min, max, eps, tiny = info.min, info.max, info.eps, info.tiny 1421*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device) 1422*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.i0, scipy.special.i0) 1423*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.special.i0e, scipy.special.i0e) 1424*da0073e9SAndroid Build Coastguard Worker if dtype not in [torch.half, torch.bfloat16]: 1425*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.special.i1, scipy.special.i1) 1426*da0073e9SAndroid Build Coastguard Worker check_equal(t, torch.special.i1e, scipy.special.i1e) 1427*da0073e9SAndroid Build Coastguard Worker 1428*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 1429*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1430*da0073e9SAndroid Build Coastguard Worker def test_special_ndtr_vs_scipy(self, device, dtype): 1431*da0073e9SAndroid Build Coastguard Worker def check_equal(t): 1432*da0073e9SAndroid Build Coastguard Worker # Test by comparing to scipy 1433*da0073e9SAndroid Build Coastguard Worker actual = torch.special.ndtr(t) 1434*da0073e9SAndroid Build Coastguard Worker expected = scipy.special.ndtr(t.cpu().numpy()) 1435*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 1436*da0073e9SAndroid Build Coastguard Worker 1437*da0073e9SAndroid Build Coastguard Worker range = (-10, 10) 1438*da0073e9SAndroid Build Coastguard Worker t = torch.linspace(*range, 1, device=device, dtype=dtype) 1439*da0073e9SAndroid Build Coastguard Worker check_equal(t) 1440*da0073e9SAndroid Build Coastguard Worker 1441*da0073e9SAndroid Build Coastguard Worker # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests. 
1442*da0073e9SAndroid Build Coastguard Worker info = torch.finfo(dtype) 1443*da0073e9SAndroid Build Coastguard Worker min, max, eps, tiny = info.min, info.max, info.eps, info.tiny 1444*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device) 1445*da0073e9SAndroid Build Coastguard Worker check_equal(t) 1446*da0073e9SAndroid Build Coastguard Worker 1447*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 1448*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1449*da0073e9SAndroid Build Coastguard Worker def test_special_log_ndtr_vs_scipy(self, device, dtype): 1450*da0073e9SAndroid Build Coastguard Worker def check_equal(t): 1451*da0073e9SAndroid Build Coastguard Worker # Test by comparing with scipy 1452*da0073e9SAndroid Build Coastguard Worker actual = torch.special.log_ndtr(t) 1453*da0073e9SAndroid Build Coastguard Worker expected = scipy.special.log_ndtr(t.cpu().numpy()) 1454*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 1455*da0073e9SAndroid Build Coastguard Worker 1456*da0073e9SAndroid Build Coastguard Worker # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests. 
1457*da0073e9SAndroid Build Coastguard Worker info = torch.finfo(dtype) 1458*da0073e9SAndroid Build Coastguard Worker min, max, eps, tiny = info.min, info.max, info.eps, info.tiny 1459*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device) 1460*da0073e9SAndroid Build Coastguard Worker check_equal(t) 1461*da0073e9SAndroid Build Coastguard Worker 1462*da0073e9SAndroid Build Coastguard Worker # TODO: allow large opinfo values to be opted-into via metadata 1463*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.long) 1464*da0073e9SAndroid Build Coastguard Worker def test_abs_big_number(self, device, dtype): 1465*da0073e9SAndroid Build Coastguard Worker bignumber = 2**31 + 1 1466*da0073e9SAndroid Build Coastguard Worker res = torch.tensor([bignumber], device=device, dtype=dtype) 1467*da0073e9SAndroid Build Coastguard Worker self.assertGreater(res.abs()[0], 0) 1468*da0073e9SAndroid Build Coastguard Worker 1469*da0073e9SAndroid Build Coastguard Worker # TODO: add signed zero testing to opinfos 1470*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 1471*da0073e9SAndroid Build Coastguard Worker def test_abs_signed_zero(self, device, dtype): 1472*da0073e9SAndroid Build Coastguard Worker # Both abs(0.0) and abs(-0.0) should result in 0.0 1473*da0073e9SAndroid Build Coastguard Worker size = 128 + 1 # pick a large enough number with remainder so that 1474*da0073e9SAndroid Build Coastguard Worker # both vectorized and nonvectorized op is tested 1475*da0073e9SAndroid Build Coastguard Worker inp = torch.zeros(size, device=device, dtype=dtype) 1476*da0073e9SAndroid Build Coastguard Worker inp[::2] = -0.0 1477*da0073e9SAndroid Build Coastguard Worker inp = inp.abs() 1478*da0073e9SAndroid Build Coastguard Worker for v in inp: 1479*da0073e9SAndroid Build Coastguard Worker self.assertGreater(math.copysign(1.0, v), 0.0) 1480*da0073e9SAndroid Build Coastguard Worker 1481*da0073e9SAndroid Build Coastguard 
Worker # TODO: update to compare against NumPy by rationalizing with OpInfo 1482*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1483*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 1484*da0073e9SAndroid Build Coastguard Worker def test_abs_zero(self, device, dtype): 1485*da0073e9SAndroid Build Coastguard Worker # Both abs(0.0) and abs(-0.0) should result in 0.0 1486*da0073e9SAndroid Build Coastguard Worker abs_zeros = torch.tensor([0.0, -0.0], device=device, dtype=dtype).abs().tolist() 1487*da0073e9SAndroid Build Coastguard Worker for num in abs_zeros: 1488*da0073e9SAndroid Build Coastguard Worker self.assertGreater(math.copysign(1.0, num), 0.0) 1489*da0073e9SAndroid Build Coastguard Worker 1490*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) 1491*da0073e9SAndroid Build Coastguard Worker def test_isposinf_isneginf_non_boolean_output(self, device, dtype): 1492*da0073e9SAndroid Build Coastguard Worker # test non-boolean tensors as the `out=` parameters 1493*da0073e9SAndroid Build Coastguard Worker # boolean outputs are tested in the above testcases 1494*da0073e9SAndroid Build Coastguard Worker vals = (float("inf"), -float("inf"), 1.2) 1495*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(vals, device=device) 1496*da0073e9SAndroid Build Coastguard Worker for torch_op in (torch.isposinf, torch.isneginf): 1497*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(t, dtype=dtype) 1498*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1499*da0073e9SAndroid Build Coastguard Worker RuntimeError, "does not support non-boolean outputs" 1500*da0073e9SAndroid Build Coastguard Worker ): 1501*da0073e9SAndroid Build Coastguard Worker torch_op(t, out=out) 1502*da0073e9SAndroid Build Coastguard Worker 1503*da0073e9SAndroid Build Coastguard Worker def test_nonzero_empty(self, device): 1504*da0073e9SAndroid Build Coastguard Worker def assert_tuple_empty(tup, 
dim): 1505*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dim, len(tup)) 1506*da0073e9SAndroid Build Coastguard Worker for t in tup: 1507*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Size([0]), t.shape) 1508*da0073e9SAndroid Build Coastguard Worker 1509*da0073e9SAndroid Build Coastguard Worker x = torch.randn(0, 2, 0, 5, 0, device=device) 1510*da0073e9SAndroid Build Coastguard Worker y = torch.nonzero(x) 1511*da0073e9SAndroid Build Coastguard Worker z = torch.nonzero(x, as_tuple=True) 1512*da0073e9SAndroid Build Coastguard Worker 1513*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, y.numel()) 1514*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Size([0, 5]), y.shape) 1515*da0073e9SAndroid Build Coastguard Worker assert_tuple_empty(z, 5) 1516*da0073e9SAndroid Build Coastguard Worker 1517*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(0.5, device=device) 1518*da0073e9SAndroid Build Coastguard Worker y = torch.nonzero(x) 1519*da0073e9SAndroid Build Coastguard Worker # nonzero with as_tuple returns a 1520*da0073e9SAndroid Build Coastguard Worker # tuple of len 1 for a zero-dim tensor. 1521*da0073e9SAndroid Build Coastguard Worker # This is done to match Numpy behavior. 
1522*da0073e9SAndroid Build Coastguard Worker z = torch.nonzero(x, as_tuple=True) 1523*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, len(z)) 1524*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros(1, dtype=torch.long), z[0]) 1525*da0073e9SAndroid Build Coastguard Worker 1526*da0073e9SAndroid Build Coastguard Worker x = torch.zeros((), device=device) 1527*da0073e9SAndroid Build Coastguard Worker y = torch.nonzero(x) 1528*da0073e9SAndroid Build Coastguard Worker z = torch.nonzero(x, as_tuple=True) 1529*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Size([0, 0]), y.shape) 1530*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, len(z)) 1531*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty(0, dtype=torch.long), z[0]) 1532*da0073e9SAndroid Build Coastguard Worker 1533*da0073e9SAndroid Build Coastguard Worker # TODO: rationalize with exp OpInfo 1534*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types_and(torch.bfloat16)) 1535*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*floating_and_complex_types_and(torch.half, torch.bfloat16)) 1536*da0073e9SAndroid Build Coastguard Worker def test_exp(self, device, dtype): 1537*da0073e9SAndroid Build Coastguard Worker for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()): 1538*da0073e9SAndroid Build Coastguard Worker a = ( 1539*da0073e9SAndroid Build Coastguard Worker torch.tensor(v, dtype=dtype, device=device) 1540*da0073e9SAndroid Build Coastguard Worker * torch.arange(18, device=device) 1541*da0073e9SAndroid Build Coastguard Worker / 3 1542*da0073e9SAndroid Build Coastguard Worker * math.pi 1543*da0073e9SAndroid Build Coastguard Worker ) 1544*da0073e9SAndroid Build Coastguard Worker a = a.to(dtype) 1545*da0073e9SAndroid Build Coastguard Worker # bfloat16 overflows 1546*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 1547*da0073e9SAndroid Build Coastguard Worker return 1548*da0073e9SAndroid 
Build Coastguard Worker self.compare_with_numpy(torch.exp, np.exp, a) 1549*da0073e9SAndroid Build Coastguard Worker 1550*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 1551*da0073e9SAndroid Build Coastguard Worker inf_real_zero_imag_in = torch.tensor( 1552*da0073e9SAndroid Build Coastguard Worker complex(float("inf"), 0), device=device, dtype=dtype 1553*da0073e9SAndroid Build Coastguard Worker ) 1554*da0073e9SAndroid Build Coastguard Worker inf_real_zero_imag_out = torch.exp(inf_real_zero_imag_in).item() 1555*da0073e9SAndroid Build Coastguard Worker self.assertTrue(math.isinf(inf_real_zero_imag_out.real)) 1556*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cpu": 1557*da0073e9SAndroid Build Coastguard Worker pass 1558*da0073e9SAndroid Build Coastguard Worker # These are commented out because it cannot be consistently reproduced. 1559*da0073e9SAndroid Build Coastguard Worker # This is incorrect. It should be zero. Need fix! 1560*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/40590 1561*da0073e9SAndroid Build Coastguard Worker # self.assertNotEqual(inf_real_zero_imag_out.imag, 0) 1562*da0073e9SAndroid Build Coastguard Worker # This is incorrect. They should equal. Need fix! 
1563*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/40590 1564*da0073e9SAndroid Build Coastguard Worker # with self.assertRaises(AssertionError): 1565*da0073e9SAndroid Build Coastguard Worker # self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in) 1566*da0073e9SAndroid Build Coastguard Worker else: 1567*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inf_real_zero_imag_out.imag, 0, atol=0, rtol=0) 1568*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in) 1569*da0073e9SAndroid Build Coastguard Worker 1570*da0073e9SAndroid Build Coastguard Worker zero_real_inf_imag_in = torch.tensor( 1571*da0073e9SAndroid Build Coastguard Worker complex(0, float("inf")), device=device, dtype=dtype 1572*da0073e9SAndroid Build Coastguard Worker ) 1573*da0073e9SAndroid Build Coastguard Worker zero_real_inf_imag_out = torch.exp(zero_real_inf_imag_in).item() 1574*da0073e9SAndroid Build Coastguard Worker self.assertTrue(math.isnan(zero_real_inf_imag_out.real)) 1575*da0073e9SAndroid Build Coastguard Worker self.assertTrue(math.isnan(zero_real_inf_imag_out.imag)) 1576*da0073e9SAndroid Build Coastguard Worker # Ensure we are notified when NumPy changes its behavior 1577*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.exp, np.exp, zero_real_inf_imag_in) 1578*da0073e9SAndroid Build Coastguard Worker 1579*da0073e9SAndroid Build Coastguard Worker inf_real_imag_in = torch.tensor( 1580*da0073e9SAndroid Build Coastguard Worker complex(float("inf"), float("inf")), device=device, dtype=dtype 1581*da0073e9SAndroid Build Coastguard Worker ) 1582*da0073e9SAndroid Build Coastguard Worker inf_real_imag_out = torch.exp(inf_real_imag_in).item() 1583*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cpu": 1584*da0073e9SAndroid Build Coastguard Worker pass 1585*da0073e9SAndroid Build Coastguard Worker # This is incorrect. Need fix! 
https://github.com/pytorch/pytorch/issues/40590 1586*da0073e9SAndroid Build Coastguard Worker # This is commented out because it cannot be consistently reproduced. 1587*da0073e9SAndroid Build Coastguard Worker # with self.assertRaises(AssertionError): 1588*da0073e9SAndroid Build Coastguard Worker # self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in) 1589*da0073e9SAndroid Build Coastguard Worker else: 1590*da0073e9SAndroid Build Coastguard Worker self.assertTrue(math.isinf(inf_real_imag_out.real)) 1591*da0073e9SAndroid Build Coastguard Worker self.assertTrue(math.isnan(inf_real_imag_out.imag)) 1592*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in) 1593*da0073e9SAndroid Build Coastguard Worker 1594*da0073e9SAndroid Build Coastguard Worker inf_real_nan_imag_in = torch.tensor( 1595*da0073e9SAndroid Build Coastguard Worker complex(float("inf"), float("nan")), device=device, dtype=dtype 1596*da0073e9SAndroid Build Coastguard Worker ) 1597*da0073e9SAndroid Build Coastguard Worker inf_real_nan_imag_out = torch.exp(inf_real_nan_imag_in).item() 1598*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cpu": 1599*da0073e9SAndroid Build Coastguard Worker pass 1600*da0073e9SAndroid Build Coastguard Worker # This is incorrect. It should be inf. Need fix! https://github.com/pytorch/pytorch/issues/40590 1601*da0073e9SAndroid Build Coastguard Worker # This is commented out because it cannot be consistently reproduced. 
1602*da0073e9SAndroid Build Coastguard Worker # with self.assertRaises(AssertionError): 1603*da0073e9SAndroid Build Coastguard Worker # self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in) 1604*da0073e9SAndroid Build Coastguard Worker else: 1605*da0073e9SAndroid Build Coastguard Worker self.assertTrue(math.isinf(inf_real_nan_imag_out.real)) 1606*da0073e9SAndroid Build Coastguard Worker self.assertTrue(math.isnan(inf_real_nan_imag_out.imag)) 1607*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in) 1608*da0073e9SAndroid Build Coastguard Worker 1609*da0073e9SAndroid Build Coastguard Worker nan_real_inf_imag_in = torch.tensor( 1610*da0073e9SAndroid Build Coastguard Worker complex(float("nan"), float("inf")), device=device, dtype=dtype 1611*da0073e9SAndroid Build Coastguard Worker ) 1612*da0073e9SAndroid Build Coastguard Worker nan_real_inf_imag_out = torch.exp(nan_real_inf_imag_in).item() 1613*da0073e9SAndroid Build Coastguard Worker self.assertTrue(math.isnan(nan_real_inf_imag_out.real)) 1614*da0073e9SAndroid Build Coastguard Worker self.assertTrue(math.isnan(nan_real_inf_imag_out.imag)) 1615*da0073e9SAndroid Build Coastguard Worker # Ensure we are notified when NumPy changes its behavior 1616*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.exp, np.exp, nan_real_inf_imag_in) 1617*da0073e9SAndroid Build Coastguard Worker 1618*da0073e9SAndroid Build Coastguard Worker 1619*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestUnaryUfuncs, globals()) 1620*da0073e9SAndroid Build Coastguard Worker 1621*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1622*da0073e9SAndroid Build Coastguard Worker run_tests() 1623