1# Owner(s): ["module: tests"] 2 3import torch 4import numpy as np 5 6import math 7from numbers import Number 8import random 9import unittest 10 11from torch import inf, nan 12from torch.testing._internal.common_utils import ( 13 TestCase, 14 run_tests, 15 torch_to_numpy_dtype_dict, 16 numpy_to_torch_dtype_dict, 17 suppress_warnings, 18 TEST_SCIPY, 19 slowTest, 20 skipIfNoSciPy, 21 IS_WINDOWS, 22 gradcheck, 23 is_iterable_of_tensors, 24 xfailIfTorchDynamo, 25) 26from torch.testing._internal.common_methods_invocations import ( 27 unary_ufuncs, 28 generate_elementwise_unary_tensors, 29 generate_elementwise_unary_small_value_tensors, 30 generate_elementwise_unary_large_value_tensors, 31 generate_elementwise_unary_extremal_value_tensors, 32) 33from torch.testing._internal.common_device_type import ( 34 instantiate_device_type_tests, 35 ops, 36 dtypes, 37 onlyCPU, 38 onlyNativeDeviceTypes, 39 onlyCUDA, 40 dtypesIfCUDA, 41 precisionOverride, 42 dtypesIfCPU, 43) 44from torch.utils import _pytree as pytree 45 46from torch.testing import make_tensor 47from torch.testing._internal.common_dtype import ( 48 floating_types_and, 49 all_types_and_complex_and, 50 integral_types_and, 51 get_all_math_dtypes, 52 complex_types, 53 floating_and_complex_types_and, 54) 55 56if TEST_SCIPY: 57 import scipy 58 59# Refer [scipy reference filter] 60# Filter operators for which the reference function 61# is available in the current environment (for reference_numerics tests). 62reference_filtered_ops = list(filter(lambda op: op.ref is not None, unary_ufuncs)) 63 64# Tests for unary "universal functions (ufuncs)" that accept a single 65# tensor and have common properties like: 66# - they are elementwise functions 67# - the input shape is the output shape 68# - they typically have method and inplace variants 69# - they typically support the out kwarg 70# - they typically have NumPy or SciPy references 71 72# See NumPy's universal function documentation 73# (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details 74# about the concept of ufuncs. 75 76 77# TODO: port test_unary_out_op_mem_overlap 78# TODO: add test for inplace variants erroring on broadcasted inputs 79class TestUnaryUfuncs(TestCase): 80 exact_dtype = True 81 82 @ops( 83 [_fn for _fn in unary_ufuncs if _fn.domain != (None, None)], 84 allowed_dtypes=floating_types_and(torch.bfloat16, torch.half), 85 ) 86 def test_float_domains(self, device, dtype, op): 87 eps = (1e-5, 1e-3, 1e-1, 1, 2, 10, 20, 50, 100) 88 89 low, high = op.domain 90 # NOTE: the following two loops are separated for readability 91 if low is not None: 92 low_tensor = torch.tensor(low, device=device, dtype=dtype) 93 for epsilon in eps: 94 lower_tensor = low_tensor - epsilon 95 96 # Skips the test if the difference is not representable, 97 # which can occur if, for example, the difference is small 98 # and the dtype is imprecise (like bfloat16 is) 99 if lower_tensor.item() == low_tensor.item(): 100 continue 101 102 result = op(lower_tensor) 103 self.assertEqual( 104 result.item(), 105 float("nan"), 106 msg=( 107 f"input of {lower_tensor.item()} outside lower domain boundary" 108 f" {low} produced {result.item()}, not nan!" 
109 ), 110 ) 111 112 if high is not None: 113 high_tensor = torch.tensor(high, device=device, dtype=dtype) 114 for epsilon in eps: 115 higher_tensor = high_tensor + epsilon 116 117 # See above comment 118 if higher_tensor.item() == high_tensor.item(): 119 continue 120 121 result = op(higher_tensor) 122 self.assertEqual( 123 result.item(), 124 float("nan"), 125 msg=( 126 f"input of {higher_tensor.item()} outside upper domain boundary" 127 f" {high} produced {result.item()}, not nan!" 128 ), 129 ) 130 131 # Helper for comparing torch tensors and numpy arrays 132 # TODO: should this or assertEqual also validate that strides are equal? 133 def assertEqualHelper( 134 self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs 135 ): 136 assert isinstance(actual, torch.Tensor) 137 138 # Some NumPy functions return scalars, not arrays 139 if isinstance(expected, Number): 140 self.assertEqual(actual.item(), expected, msg, **kwargs) 141 elif isinstance(expected, np.ndarray): 142 # Handles exact dtype comparisons between arrays and tensors 143 if exact_dtype: 144 if ( 145 actual.dtype is torch.bfloat16 146 or expected.dtype != torch_to_numpy_dtype_dict[actual.dtype] 147 ): 148 # Allows array dtype to be float32 when comparing with bfloat16 tensors 149 # since NumPy doesn't support the bfloat16 dtype 150 # Also ops like scipy.special.erf, scipy.special.erfc, etc, promote float16 151 # to float32 152 if expected.dtype == np.float32: 153 assert actual.dtype in ( 154 torch.float16, 155 torch.bfloat16, 156 torch.float32, 157 ) 158 elif expected.dtype == np.float64: 159 assert actual.dtype in ( 160 torch.float16, 161 torch.bfloat16, 162 torch.float32, 163 torch.float64, 164 ) 165 else: 166 self.fail( 167 f"Expected dtype {expected.dtype} but got {actual.dtype}!" 
168 ) 169 170 self.assertEqual( 171 actual, 172 torch.from_numpy(expected).to(actual.dtype), 173 msg, 174 exact_device=False, 175 **kwargs 176 ) 177 else: 178 self.assertEqual(actual, expected, msg, exact_device=False, **kwargs) 179 180 # Tests that the function and its (array-accepting) reference produce the same 181 # values on given tensors 182 def _test_reference_numerics(self, dtype, op, tensors, equal_nan=True): 183 def _helper_reference_numerics( 184 expected, actual, msg, exact_dtype, equal_nan=True 185 ): 186 if not torch.can_cast( 187 numpy_to_torch_dtype_dict[expected.dtype.type], dtype 188 ): 189 exact_dtype = False 190 191 if dtype in [torch.uint8, torch.int8, torch.bool]: 192 # NOTE: For these dtypes, PyTorch computes in the default scalar type (float) 193 # while NumPy computes in float16 194 self.assertEqualHelper( 195 actual, 196 expected, 197 msg, 198 dtype=dtype, 199 exact_dtype=exact_dtype, 200 rtol=1e-3, 201 atol=1e-2, 202 ) 203 elif dtype is torch.bfloat16: 204 # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149 205 self.assertEqualHelper( 206 actual, 207 expected, 208 msg, 209 dtype=dtype, 210 exact_dtype=exact_dtype, 211 rtol=16e-3, 212 atol=1e-5, 213 ) 214 elif dtype is torch.half: 215 self.assertEqualHelper( 216 actual, 217 expected, 218 msg, 219 dtype=dtype, 220 exact_dtype=exact_dtype, 221 rtol=1.2e-03, 222 atol=1e-03, 223 ) 224 else: 225 self.assertEqualHelper( 226 actual, 227 expected, 228 msg, 229 dtype=dtype, 230 equal_nan=equal_nan, 231 exact_dtype=exact_dtype, 232 ) 233 234 for t in tensors: 235 t = t.input 236 torch_kwargs, numpy_kwargs = op.sample_kwargs(t.device, dtype, t) 237 if dtype is torch.bfloat16: 238 a = t.cpu().to(torch.float32).numpy() 239 elif dtype is torch.complex32: 240 a = t.cpu().to(torch.complex64).numpy() 241 else: 242 a = t.cpu().numpy() 243 244 actual = op(t, **torch_kwargs) 245 expected = op.ref(a, **numpy_kwargs) 246 247 # Crafts a custom error message for smaller, printable tensors 248 if t.numel() < 10: 249 msg = ( 250 "Failed to produce expected results! Input tensor was" 251 f" {t}, torch result is {actual}, and reference result is" 252 f" {expected}." 253 ) 254 else: 255 msg = None 256 257 exact_dtype = True 258 if isinstance(actual, torch.Tensor): 259 _helper_reference_numerics( 260 expected, actual, msg, exact_dtype, equal_nan 261 ) 262 else: 263 for x, y in zip(expected, actual): 264 # testing multi-outputs results 265 _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan) 266 267 # Tests that the function and its (array-accepting) reference produce the same 268 # values on a range of tensors, including empty tensors, scalar tensors, 269 # 1D tensors and a large 2D tensor with interesting and extremal values 270 # and noncontiguities. 
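    # NOTE (editorial, descriptive comment): op.ref is the NumPy/SciPy reference
    # function kept by the [scipy reference filter] above; these tests compare
    # the torch result against it elementwise.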
271 @suppress_warnings 272 @ops(reference_filtered_ops) 273 def test_reference_numerics_normal(self, device, dtype, op): 274 tensors = generate_elementwise_unary_tensors( 275 op, device=device, dtype=dtype, requires_grad=False 276 ) 277 self._test_reference_numerics(dtype, op, tensors) 278 279 @suppress_warnings 280 @ops(reference_filtered_ops) 281 def test_reference_numerics_small(self, device, dtype, op): 282 if dtype in (torch.bool,): 283 raise self.skipTest("bool has no small values") 284 285 tensors = generate_elementwise_unary_small_value_tensors( 286 op, device=device, dtype=dtype, requires_grad=False 287 ) 288 self._test_reference_numerics(dtype, op, tensors) 289 290 @suppress_warnings 291 @ops(reference_filtered_ops) 292 def test_reference_numerics_large(self, device, dtype, op): 293 if dtype in (torch.bool, torch.uint8, torch.int8): 294 raise self.skipTest("bool, uint8, and int8 dtypes have no large values") 295 296 tensors = generate_elementwise_unary_large_value_tensors( 297 op, device=device, dtype=dtype, requires_grad=False 298 ) 299 self._test_reference_numerics(dtype, op, tensors) 300 301 @suppress_warnings 302 @ops( 303 reference_filtered_ops, 304 allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), 305 ) 306 def test_reference_numerics_extremal(self, device, dtype, op): 307 tensors = generate_elementwise_unary_extremal_value_tensors( 308 op, device=device, dtype=dtype, requires_grad=False 309 ) 310 self._test_reference_numerics(dtype, op, tensors) 311 312 # Tests for testing (non)contiguity consistency 313 @ops(unary_ufuncs) 314 def test_contig_vs_every_other(self, device, dtype, op): 315 contig = make_tensor( 316 (1026,), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1] 317 ) 318 non_contig = contig[::2] 319 320 self.assertTrue(contig.is_contiguous()) 321 self.assertFalse(non_contig.is_contiguous()) 322 323 torch_kwargs, _ = op.sample_kwargs(device, dtype, non_contig) 324 expected = op(non_contig, **torch_kwargs) 325 result = op(contig, **torch_kwargs) 326 result = pytree.tree_map(lambda x: x[::2], result) 327 self.assertEqual(result, expected) 328 329 @ops(unary_ufuncs) 330 def test_contig_vs_transposed(self, device, dtype, op): 331 contig = make_tensor( 332 (789, 357), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1] 333 ) 334 non_contig = contig.T 335 336 self.assertTrue(contig.is_contiguous()) 337 self.assertFalse(non_contig.is_contiguous()) 338 339 torch_kwargs, _ = op.sample_kwargs(device, dtype, contig) 340 expected = op(non_contig, **torch_kwargs) 341 result = op(contig, **torch_kwargs) 342 result = pytree.tree_map(lambda x: x.T, result) 343 self.assertEqual(result, expected) 344 345 @ops(unary_ufuncs) 346 def test_non_contig(self, device, dtype, op): 347 shapes = [(5, 7), (1024,)] 348 for shape in shapes: 349 contig = make_tensor( 350 shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] 351 ) 352 non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] 353 non_contig.copy_(contig) 354 355 self.assertTrue(contig.is_contiguous()) 356 self.assertFalse(non_contig.is_contiguous()) 357 358 torch_kwargs, _ = op.sample_kwargs(device, dtype, contig) 359 self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs)) 360 361 @ops(unary_ufuncs) 362 def test_non_contig_index(self, device, dtype, op): 363 contig = make_tensor( 364 (2, 2, 1, 2), 365 dtype=dtype, 366 device=device, 367 low=op.domain[0], 368 high=op.domain[1], 369 ) 370 non_contig = contig[:, 1, ...] 
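        # Indexing the middle dimension produces a view whose strides skip the
        # unselected elements, so it is not contiguous; `contig` below is its
        # contiguous copy.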
        contig = non_contig.contiguous()

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_non_contig_expand(self, device, dtype, op):
        shapes = [(1, 3), (1, 7), (5, 7)]
        for shape in shapes:
            contig = make_tensor(
                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
            )
            non_contig = contig.clone().expand(3, -1, -1)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
            contig = op(contig, **torch_kwargs)
            non_contig = op(non_contig, **torch_kwargs)
            for i in range(3):
                non_contig_i = pytree.tree_map(lambda x: x[i], non_contig)
                self.assertEqual(
                    contig, non_contig_i, msg="non-contiguous expand[" + str(i) + "]"
                )

    @ops(unary_ufuncs)
    def test_contig_size1(self, device, dtype, op):
        contig = make_tensor(
            (5, 100), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
        )
        contig = contig[:1, :50]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_contig_size1_large_dim(self, device, dtype, op):
        contig = make_tensor(
            (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4),
            dtype=dtype,
            device=device,
            low=op.domain[0],
            high=op.domain[1],
        )
        contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs))

    # Tests that computation on multiple batches is the same as
    # per-batch computation.
    @ops(unary_ufuncs)
    def test_batch_vs_slicing(self, device, dtype, op):
        input = make_tensor(
            (1024, 512), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
        )

        torch_kwargs, _ = op.sample_kwargs(device, dtype, input)
        actual = op(input, **torch_kwargs)

        all_outs = [op(slice, **torch_kwargs) for slice in input]
        if is_iterable_of_tensors(actual):
            expected = [torch.stack([out[i] for out in all_outs]) for i in range(len(actual))]
        else:
            expected = torch.stack(all_outs)

        self.assertEqual(actual, expected)

    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
    def test_nan_to_num(self, device, dtype):
        for contiguous in [False, True]:
            x = make_tensor((64, 64), low=0.0, high=100.0, dtype=dtype, device=device)

            if dtype.is_floating_point:
                # Add extremal values.
460 extremals = [float("nan"), float("inf"), -float("inf")] 461 for idx, extremal in zip(torch.randint(0, 63, (3,)), extremals): 462 x[idx, :] = extremal 463 464 if not contiguous: 465 x = x.T 466 467 # With args 468 nan = random.random() 469 posinf = random.random() * 5 470 neginf = random.random() * 10 471 472 self.compare_with_numpy( 473 lambda x: x.nan_to_num(nan=nan, posinf=posinf), 474 lambda x: np.nan_to_num(x, nan=nan, posinf=posinf), 475 x, 476 ) 477 self.compare_with_numpy( 478 lambda x: x.nan_to_num(posinf=posinf, neginf=neginf), 479 lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf), 480 x, 481 ) 482 483 # Out Variant 484 out = torch.empty_like(x) 485 result = torch.nan_to_num(x) 486 torch.nan_to_num(x, out=out) 487 self.assertEqual(result, out) 488 489 result = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) 490 torch.nan_to_num(x, out=out, nan=nan, posinf=posinf, neginf=neginf) 491 self.assertEqual(result, out) 492 493 @onlyCPU 494 def test_nan_to_num_bfloat16(self, device): 495 def test_dtype(fn, input, dtype): 496 input = input.detach().clone().to(dtype=dtype).requires_grad_(True) 497 input2 = input.detach().clone().float().requires_grad_(True) 498 out = fn(input) 499 out.sum().backward() 500 out2 = fn(input2) 501 out2.sum().backward() 502 self.assertEqual(out.dtype, dtype) 503 self.assertEqual(input.grad.dtype, dtype) 504 self.assertEqual(out, out2, exact_dtype=False) 505 self.assertEqual(input.grad, input2.grad, exact_dtype=False) 506 507 def func(): 508 return torch.nan_to_num 509 510 shapes = [[1, 3, 6, 6], [1, 3, 6, 128], [1, 3, 256, 256]] 511 for shape in shapes: 512 x = torch.randn(shape, device=device) 513 extremals = [float('nan'), float('inf'), -float('inf')] 514 for id1, id2, extremal in zip(torch.randint(0, 2, (3,)), torch.randint(0, 5, (3,)), extremals): 515 x[0, id1, id2, :] = extremal 516 test_dtype(func(), x, torch.bfloat16) 517 518 @dtypes(torch.complex64, torch.complex128) 519 def test_nan_to_num_complex(self, device, dtype): 520 value_dtype = torch.tensor([], dtype=dtype).real.dtype 521 522 def gen_tensor(a): 523 return torch.view_as_complex(torch.tensor(a, dtype=value_dtype, device=device)) 524 525 for extremal, kwarg_name in zip(['nan', 'inf', '-inf'], ['nan', 'posinf', 'neginf']): 526 a = gen_tensor([123, float(extremal)]) 527 res = torch.nan_to_num(a, **{kwarg_name: 12}) 528 res_check = gen_tensor([123, 12]) 529 self.assertEqual(res, res_check) 530 531 a = gen_tensor([float(extremal), 456]) 532 res = torch.nan_to_num(a, **{kwarg_name: 21}) 533 res_check = gen_tensor([21, 456]) 534 self.assertEqual(res, res_check) 535 536 @dtypes(torch.cdouble) 537 def test_complex_edge_values(self, device, dtype): 538 # sqrt Test Reference: https://github.com/pytorch/pytorch/pull/47424 539 x = torch.tensor(0.0 - 1.0e20j, dtype=dtype, device=device) 540 self.compare_with_numpy(torch.sqrt, np.sqrt, x) 541 # acos test reference: https://github.com/pytorch/pytorch/issue/42952 542 # Skip on Windows, as CUDA acos returns conjugate value 543 # see https://github.com/pytorch/pytorch/issues/52299 544 if not (IS_WINDOWS and dtype == torch.cdouble and "cuda" in device): 545 self.compare_with_numpy(torch.acos, np.arccos, x) 546 547 x = torch.tensor( 548 (-1.0e60 if dtype == torch.cdouble else -1.0e20) - 4988429.2j, 549 dtype=dtype, 550 device=device, 551 ) 552 self.compare_with_numpy(torch.sqrt, np.sqrt, x) 553 554 @unittest.skipIf(not TEST_SCIPY, "Requires SciPy") 555 @dtypes(torch.float, torch.double) 556 def test_digamma_special(self, device, dtype): 557 # 
Based on SciPy test for the following special values. 558 # Reference: 559 # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22 560 euler = 0.57721566490153286 561 dataset = [ 562 (0.0, -0.0), 563 (1, -euler), 564 (0.5, -2 * math.log(2) - euler), 565 (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler), 566 (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler), 567 ( 568 1 / 6, 569 -math.pi * math.sqrt(3) / 2 570 - 2 * math.log(2) 571 - 3 * math.log(3) / 2 572 - euler, 573 ), 574 ( 575 1 / 8, 576 -math.pi / 2 577 - 4 * math.log(2) 578 - (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2))) 579 / math.sqrt(2) 580 - euler, 581 ), 582 ] 583 x = torch.tensor(dataset, device=device, dtype=dtype) 584 self.compare_with_numpy(torch.digamma, scipy.special.digamma, x) 585 586 @unittest.skipIf(not TEST_SCIPY, "Requires SciPy") 587 @dtypes(torch.float, torch.double) 588 def test_digamma(self, device, dtype): 589 # Tests pole behavior 590 tensor = torch.tensor( 591 [ 592 -0.999999994, 593 -1.999999994, 594 -2.0000000111, 595 -100.99999994, 596 0.000000111, 597 -1931.99999994, 598 -0.000000111, 599 0, 600 -0, 601 -1, 602 -2, 603 -931, 604 ], 605 dtype=dtype, 606 device=device, 607 ) 608 self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor) 609 610 @dtypes(*floating_types_and(torch.half)) 611 def test_frexp(self, device, dtype): 612 input = make_tensor((50, 50), dtype=dtype, device=device) 613 mantissa, exponent = torch.frexp(input) 614 np_mantissa, np_exponent = np.frexp(input.cpu().numpy()) 615 616 self.assertEqual(mantissa, np_mantissa) 617 self.assertEqual(exponent, np_exponent) 618 619 # torch.frexp returns exponent in int32 to be compatible with np.frexp 620 self.assertTrue(exponent.dtype == torch.int32) 621 self.assertTrue(torch_to_numpy_dtype_dict[exponent.dtype] == np_exponent.dtype) 622 623 def test_frexp_assert_raises(self, device): 624 invalid_input_dtypes = integral_types_and(torch.bool) + complex_types() 625 for dtype in invalid_input_dtypes: 626 input = make_tensor((50, 50), dtype=dtype, device=device) 627 with self.assertRaisesRegex( 628 RuntimeError, r"torch\.frexp\(\) only supports floating-point dtypes" 629 ): 630 torch.frexp(input) 631 632 for dtype in floating_types_and(torch.half): 633 input = make_tensor((50, 50), dtype=dtype, device=device) 634 635 dtypes = list( 636 all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16) 637 ) 638 dtypes.remove(dtype) 639 for mantissa_dtype in dtypes: 640 mantissa = torch.empty_like(input, dtype=mantissa_dtype) 641 exponent = torch.empty_like(input, dtype=torch.int) 642 with self.assertRaisesRegex( 643 RuntimeError, 644 r"torch\.frexp\(\) expects mantissa to have dtype .+ but got .+", 645 ): 646 torch.frexp(input, out=(mantissa, exponent)) 647 648 dtypes.append(dtype) 649 dtypes.remove(torch.int) 650 for exponent_dtype in dtypes: 651 mantissa = torch.empty_like(input) 652 exponent = torch.empty_like(input, dtype=exponent_dtype) 653 with self.assertRaisesRegex( 654 RuntimeError, 655 r"torch\.frexp\(\) expects exponent to have int dtype but got .+", 656 ): 657 torch.frexp(input, out=(mantissa, exponent)) 658 659 def test_polygamma_neg(self, device): 660 with self.assertRaisesRegex( 661 RuntimeError, r"polygamma\(n, x\) does not support negative n\." 
662 ): 663 torch.polygamma(-1, torch.tensor([1.0, 2.0], device=device)) 664 665 # TODO resolve with opinfos 666 @onlyCPU 667 def test_op_invert(self, device): 668 res = 0xFFFF - torch.arange(127, dtype=torch.int8) 669 for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): 670 a = torch.arange(127, dtype=dtype) 671 self.assertEqual(res.to(dtype), ~a) 672 673 self.assertEqual(torch.tensor([True, False]), ~torch.tensor([False, True])) 674 675 # test exceptions 676 for dtype in (torch.half, torch.float, torch.double): 677 a = torch.zeros(10, dtype=dtype) 678 with self.assertRaises(TypeError): 679 b = ~a 680 681 @dtypes(torch.complex64, torch.complex128) 682 def test_abs_angle_complex_to_float(self, device, dtype): 683 # Constructs random complex values 684 from random import random 685 686 random_vals = [] 687 for multiplier in (-1, 1, -10, 10, -100, 100): 688 for _ in range(10): 689 random_vals.append( 690 complex(random() * multiplier, random() * multiplier) 691 ) 692 693 for vals in (random_vals, []): 694 a = np.array(vals, dtype=torch_to_numpy_dtype_dict[dtype]) 695 t = torch.tensor(vals, device=device, dtype=dtype) 696 697 for fn_name in ("abs", "angle"): 698 torch_fn = getattr(torch, fn_name) 699 np_fn = getattr(np, fn_name) 700 701 # Tests function 702 np_result = torch.from_numpy(np_fn(a)) 703 torch_result = torch_fn(t).cpu() 704 self.assertEqual(np_result, torch_result, exact_dtype=True) 705 706 # Tests float out 707 float_dtype = ( 708 torch.float32 if dtype is torch.complex64 else torch.float64 709 ) 710 np_float_out = np_fn(a).astype(torch_to_numpy_dtype_dict[float_dtype]) 711 float_out = torch.empty_like(t, dtype=float_dtype) 712 torch_fn(t, out=float_out) 713 self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu()) 714 715 # Tests float out (resized out) 716 float_out = torch.empty(1, device=device, dtype=float_dtype) 717 torch_fn(t, out=float_out) 718 self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu()) 719 720 # Tests complex out 721 np_complex_out = np_fn(a).astype(torch_to_numpy_dtype_dict[dtype]) 722 complex_out = torch.empty_like(t) 723 torch_fn(t, out=complex_out) 724 self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu()) 725 726 # Tests complex out (resized out) 727 complex_out = torch.empty(0, device=device, dtype=dtype) 728 torch_fn(t, out=complex_out) 729 self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu()) 730 731 # Tests long out behavior (expected failure) 732 long_out = torch.empty(0, device=device, dtype=torch.long) 733 with self.assertRaises(RuntimeError): 734 torch_fn(t, out=long_out) 735 736 # Tests inplace 737 if fn_name == "abs": 738 torch_inplace_method = getattr(torch.Tensor, fn_name + "_") 739 np_fn(a, out=a) 740 if dtype.is_complex: 741 with self.assertRaisesRegex( 742 RuntimeError, 743 "In-place abs is not supported for complex tensors.", 744 ): 745 torch_inplace_method(t) 746 return 747 torch_inplace_method(t) 748 self.assertEqual(torch.from_numpy(a), t.cpu()) 749 750 # Note: angle does not have an in-place variant 751 if fn_name == "angle": 752 with self.assertRaises(AttributeError): 753 torch_inplace_method = getattr(torch.Tensor, fn_name + "_") 754 755 def check_internal_mem_overlap( 756 self, inplace_op, num_inputs, dtype, device, expected_failure=False 757 ): 758 if isinstance(inplace_op, str): 759 inplace_op = getattr(torch.Tensor, inplace_op) 760 input = torch.randn(1, dtype=dtype, device=device).expand(3, 3) 761 inputs = [input] + [torch.randn_like(input) for i in 
range(num_inputs - 1)] 762 if not expected_failure: 763 with self.assertRaisesRegex(RuntimeError, "single memory location"): 764 inplace_op(*inputs) 765 else: 766 with self.assertRaises(AssertionError): 767 with self.assertRaisesRegex(RuntimeError, "single memory location"): 768 inplace_op(*inputs) 769 770 def unary_check_input_output_mem_overlap( 771 self, data, sz, op, expected_failure=False 772 ): 773 def _test(op, output, input): 774 output_exp = torch.empty_like(output) 775 op(input, out=output_exp) 776 self.assertEqual(op(input, out=output), output_exp, msg=op.__name__) 777 778 # output is identical to input: 779 _test(op, output=data[0:sz], input=data[0:sz]) 780 # output and input are independent: 781 _test(op, output=data[0:sz], input=data[sz : 2 * sz]) 782 # output partially overlaps with input: 783 if not expected_failure: 784 with self.assertRaisesRegex(RuntimeError, "unsupported operation"): 785 _test(op, data[0:sz], data[1 : sz + 1]) 786 else: 787 with self.assertRaises(AssertionError): 788 with self.assertRaisesRegex(RuntimeError, "unsupported operation"): 789 _test(op, data[0:sz], data[1 : sz + 1]) 790 791 # TODO: run on non-native device types 792 # https://github.com/pytorch/pytorch/issues/126474 793 @xfailIfTorchDynamo 794 @dtypes(torch.double) 795 def test_unary_out_op_mem_overlap(self, device, dtype): 796 sz = 3 797 doubles = torch.randn(2 * sz, dtype=dtype, device=device) 798 positives = torch.randint(1, 100, (2 * sz,), device=device).double() 799 ints = torch.randint(-100, 100, (2 * sz,), device=device) 800 unary_mem_overlap_cases = [ 801 ("abs", doubles, True, True, "cpu"), 802 ("abs", doubles, True, True, "cuda"), 803 ("acos", doubles, True, True, "cpu"), 804 ("acos", doubles, True, True, "cuda"), 805 ("asin", doubles, True, True, "cpu"), 806 ("asin", doubles, True, True, "cuda"), 807 ("atan", doubles, True, True, "cpu"), 808 ("atan", doubles, True, True, "cuda"), 809 ("acosh", doubles, True, True, "cpu"), 810 ("acosh", doubles, True, True, "cuda"), 811 ("asinh", doubles, True, True, "cpu"), 812 ("asinh", doubles, True, True, "cuda"), 813 ("atanh", doubles, True, True, "cpu"), 814 ("atanh", doubles, True, True, "cuda"), 815 ("bitwise_not", ints, True, True, "cpu"), 816 ("bitwise_not", ints, True, True, "cuda"), 817 ("ceil", doubles, True, True, "cpu"), 818 ("ceil", doubles, True, True, "cuda"), 819 ("cos", doubles, True, True, "cpu"), 820 ("cos", doubles, True, True, "cuda"), 821 ("cosh", doubles, True, True, "cpu"), 822 ("cosh", doubles, True, True, "cuda"), 823 ("digamma", doubles, True, True, "cpu"), 824 ("erf", doubles, True, True, "cpu"), 825 ("erf", doubles, True, True, "cuda"), 826 ("erfc", doubles, True, True, "cpu"), 827 ("erfc", doubles, True, True, "cuda"), 828 ("erfinv", doubles, True, True, "cpu"), 829 ("erfinv", doubles, True, True, "cuda"), 830 ("exp", doubles, True, True, "cpu"), 831 ("exp", doubles, True, True, "cuda"), 832 ("exp2", doubles, True, True, "cpu"), 833 ("exp2", doubles, True, True, "cuda"), 834 ("expm1", doubles, True, True, "cpu"), 835 ("expm1", doubles, True, True, "cuda"), 836 ("floor", doubles, True, True, "cpu"), 837 ("floor", doubles, True, True, "cuda"), 838 ("frac", doubles, True, True, "cpu"), 839 ("frac", doubles, True, True, "cuda"), 840 ("i0", doubles, True, True, "cpu"), 841 ("i0", doubles, True, True, "cuda"), 842 ("log", positives, True, True, "cpu"), 843 ("log", positives, True, True, "cuda"), 844 ("log10", positives, True, True, "cpu"), 845 ("log10", positives, True, True, "cuda"), 846 ("log1p", positives, True, True, 
"cpu"), 847 ("log1p", positives, True, True, "cuda"), 848 ("log2", positives, True, True, "cpu"), 849 ("log2", positives, True, True, "cuda"), 850 ("neg", doubles, True, True, "cpu"), 851 ("neg", doubles, True, True, "cuda"), 852 ("reciprocal", doubles, True, True, "cpu"), 853 ("reciprocal", doubles, True, True, "cuda"), 854 ("round", doubles, True, True, "cpu"), 855 ("round", doubles, True, True, "cuda"), 856 ("rsqrt", positives, True, True, "cpu"), 857 ("rsqrt", positives, True, True, "cuda"), 858 ("sin", doubles, True, True, "cpu"), 859 ("sin", doubles, True, True, "cuda"), 860 ("sinh", doubles, True, True, "cpu"), 861 ("sinh", doubles, False, True, "cuda"), 862 ("sigmoid", doubles, True, True, "cpu"), 863 ("sigmoid", doubles, True, True, "cuda"), 864 ("logit", doubles, True, True, "cpu"), 865 ("logit", doubles, True, True, "cuda"), 866 ("sqrt", doubles, True, True, "cpu"), 867 ("sqrt", doubles, False, True, "cuda"), 868 ("tan", doubles, True, True, "cpu"), 869 ("tan", doubles, True, True, "cuda"), 870 ("tanh", doubles, True, True, "cpu"), 871 ("tanh", doubles, True, True, "cuda"), 872 ("trunc", doubles, True, True, "cpu"), 873 ("trunc", doubles, True, True, "cuda"), 874 ] 875 876 for ( 877 fn, 878 inputs, 879 has_input_output_mem_overlap_check, 880 has_internal_mem_overlap_check, 881 dev, 882 ) in unary_mem_overlap_cases: 883 if dev != device: 884 continue 885 out_fn = getattr(torch, fn) 886 in_fn = getattr(torch.Tensor, fn + "_") 887 888 self.unary_check_input_output_mem_overlap( 889 inputs, 890 sz, 891 out_fn, 892 expected_failure=not has_input_output_mem_overlap_check, 893 ) 894 895 self.check_internal_mem_overlap( 896 in_fn, 897 1, 898 dtype, 899 dev, 900 expected_failure=not has_internal_mem_overlap_check, 901 ) 902 903 # TODO: opinfo hardshrink 904 @onlyCPU 905 @dtypes(torch.float, torch.double, torch.bfloat16) 906 def test_hardshrink(self, device, dtype): 907 data = torch.tensor([1, 0.5, 0.3, 0.6], dtype=dtype, device=device).view(2, 2) 908 self.assertEqual( 909 torch.tensor([1, 0.5, 0, 0.6], dtype=dtype, device=device).view(2, 2), 910 data.hardshrink(0.3), 911 ) 912 self.assertEqual( 913 torch.tensor([1, 0, 0, 0.6], dtype=dtype, device=device).view(2, 2), 914 data.hardshrink(0.5), 915 ) 916 917 # test default lambd=0.5 918 self.assertEqual(data.hardshrink(), data.hardshrink(0.5)) 919 920 # test non-contiguous case 921 self.assertEqual( 922 torch.tensor([1, 0, 0.5, 0.6], dtype=dtype, device=device).view(2, 2), 923 data.t().hardshrink(0.3), 924 ) 925 926 @onlyCPU 927 @dtypes(torch.float, torch.double, torch.bfloat16) 928 def test_hardshrink_edge_cases(self, device, dtype) -> None: 929 def h(values, l_expected): 930 for l, expected in l_expected.items(): 931 values_tensor = torch.tensor( 932 [float(v) for v in values], dtype=dtype, device=device 933 ) 934 expected_tensor = torch.tensor( 935 [float(v) for v in expected], dtype=dtype, device=device 936 ) 937 self.assertEqual( 938 expected_tensor == values_tensor.hardshrink(l), 939 torch.ones_like(values_tensor, dtype=torch.bool), 940 ) 941 942 def test_helper(min, max): 943 h( 944 [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], 945 { 946 0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], 947 min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], 948 0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf], 949 1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf], 950 max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf], 951 inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
                    0.0, 0.0, 0.0, 0.0],
                },
            )

        test_helper(torch.finfo(dtype).tiny, torch.finfo(dtype).max)

    @onlyCPU
    @slowTest
    @dtypes(torch.float)
    @unittest.skipIf(True, "Insufficient memory on linux.(2|4)xlarge")
    def test_exp_slow(self, device, dtype):
        # Test for https://github.com/pytorch/pytorch/issues/17271
        # This is pretty slow on my Macbook but it only takes a few
        # seconds on a beefy Xeon server
        a = torch.exp(torch.ones(2**31, dtype=dtype, device=device))
        b = torch.exp(torch.ones(1, dtype=dtype, device=device))
        self.assertEqual(a, b.expand(2**31))

    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardswish(self, device, dtype):
        inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
        expectedOutput = np.multiply(
            inputValues, np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0
        )

        inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)
        expectedOutputTensor = torch.tensor(expectedOutput, dtype=dtype, device=device)

        # normal
        self.assertEqual(
            torch.nn.functional.hardswish(inputTensor), expectedOutputTensor
        )

        # inplace
        inputTensorCpy = inputTensor.clone().detach()
        torch.nn.functional.hardswish(inputTensorCpy, inplace=True)
        self.assertEqual(inputTensorCpy, expectedOutputTensor)

    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardsigmoid(self, device, dtype):
        inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
        expectedOutput = np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0

        inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)

        # normal
        self.assertEqual(
            torch.nn.functional.hardsigmoid(inputTensor),
            torch.tensor(expectedOutput, dtype=dtype, device=device),
        )

        # inplace
        inputTensorCpy = inputTensor.clone().detach()
        self.assertEqual(
            torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True),
            torch.tensor(expectedOutput, dtype=dtype, device=device),
        )

    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardsigmoid_backward(self, device, dtype):
        inputValues = [-3.0, 3.0, -2.0, 2.0, -6.0, 6.0]
        expectedValues = [0.0, 0.0, 1.0 / 6.0, 1.0 / 6.0, 0.0, 0.0]
        inputTensor = torch.tensor(
            inputValues, dtype=dtype, device=device
        ).requires_grad_()
        expectedTensor = torch.tensor(expectedValues, dtype=dtype, device=device)
        out = torch.nn.functional.hardsigmoid(inputTensor)
        out.backward(torch.ones_like(inputTensor))
        self.assertEqual(inputTensor.grad, expectedTensor)

    @skipIfNoSciPy
    @dtypes(torch.float, torch.double)
    def test_silu(self, device, dtype):
        input_np = np.random.randn(5, 8)
        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
        input_np = np.concatenate((input_np, special_input), axis=0).astype(
            torch_to_numpy_dtype_dict[dtype]
        )
        expected_output_np = input_np * scipy.special.expit(input_np)

        expected_output = torch.from_numpy(expected_output_np).to(device)
        expected_output_noncontig = expected_output.transpose(0, 1)

        atol = 1e-6
        rtol = 1e-6

        input = torch.from_numpy(input_np).clone().contiguous().to(device)
        self.assertEqual(
            torch.nn.functional.silu(input), expected_output, atol=atol, rtol=rtol
        )
        self.assertEqual(
            torch.nn.functional.silu(input, inplace=True),
            expected_output,
            atol=atol,
            rtol=rtol,
        )

        input = torch.from_numpy(input_np).clone().to(device)
        input_noncontig = input.transpose(0, 1)
        self.assertEqual(
            torch.nn.functional.silu(input_noncontig),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )
        self.assertEqual(
            torch.nn.functional.silu(input_noncontig, inplace=True),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )

    @dtypes(torch.complex64, torch.complex128)
    def test_silu_complex(self, device, dtype):
        atol = 1e-6
        rtol = 1e-6
        inouts = [
            (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
            (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
            (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
            (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
            (2.0j, -1.55740761756896972656 + 0.99999988079071044922j)
        ]

        for inp, out in inouts:
            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device))
            self.assertFalse(torch.any(torch.isnan(res)))
            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)

        for inp, out in inouts:
            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device), inplace=True)
            self.assertFalse(torch.any(torch.isnan(res)))
            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)

    # It is not obvious how to merge this into OpInfo because these inputs
    # succeed for gradcheck but are expected to fail for gradgradcheck
    @dtypes(torch.double)
    def test_sinc(self, device, dtype):
        # The derivative of sinc(x) at x=0 has to be special cased.
        # A naive computation will result in 0/0 -> NaN.
        # We also need to be careful when we are very close to 0, as the
        # derivative's denominator is squared, and there are some floats
        # that are positive and whose squares are zero.
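        # torch.finfo(torch.double).tiny (~2.2e-308) is one such value: it is
        # positive, yet its square underflows to 0.0 in double precision.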
1105 a = torch.tensor( 1106 [0.0, torch.finfo(torch.double).tiny, 1.0], 1107 dtype=dtype, 1108 requires_grad=True, 1109 device=device, 1110 ) 1111 gradcheck(torch.sinc, a) 1112 1113 @skipIfNoSciPy 1114 @dtypes(torch.float, torch.double) 1115 def test_mish(self, device, dtype): 1116 input_np = np.random.randn(5, 8) 1117 special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]] 1118 input_np = np.concatenate((input_np, special_input), axis=0).astype( 1119 torch_to_numpy_dtype_dict[dtype] 1120 ) 1121 expected_output_np = input_np * np.tanh(np.log1p(np.exp(input_np))) 1122 1123 expected_output = torch.from_numpy(expected_output_np).to(device) 1124 expected_output_noncontig = expected_output.transpose(0, 1) 1125 1126 atol = 1e-6 1127 rtol = 1e-6 1128 1129 input = torch.from_numpy(input_np).clone().contiguous().to(device) 1130 self.assertEqual( 1131 torch.nn.functional.mish(input), expected_output, atol=atol, rtol=rtol 1132 ) 1133 self.assertEqual( 1134 torch.nn.functional.mish(input, inplace=True), 1135 expected_output, 1136 atol=atol, 1137 rtol=rtol, 1138 ) 1139 1140 input = torch.from_numpy(input_np).clone().to(device) 1141 input_noncontig = input.transpose(0, 1) 1142 self.assertEqual( 1143 torch.nn.functional.mish(input_noncontig), 1144 expected_output_noncontig, 1145 atol=atol, 1146 rtol=rtol, 1147 ) 1148 self.assertEqual( 1149 torch.nn.functional.mish(input_noncontig, inplace=True), 1150 expected_output_noncontig, 1151 atol=atol, 1152 rtol=rtol, 1153 ) 1154 1155 @dtypes(torch.complex64, torch.complex128) 1156 def test_log1p_complex(self, device, dtype): 1157 # The output values here were obtained using arbitrary precision math (mpmath) 1158 # and double checked with WolframAlpha. 1159 # Not using numpy's log1p here because by the time of writing this, 1160 # np.log1p has precision problems for small complex input values, see here: 1161 # https://github.com/numpy/numpy/issues/22609 1162 inouts = [ 1163 (0.2 + 0.3j, 0.21263386770217202 + 0.24497866312686414j), 1164 (1e-19 + 1e-18j, 1e-19 + 1e-18j), 1165 (1e-18 + 0.1j, 0.00497517 + 0.0996687j), 1166 (0.1 + 1e-18j, 0.0953102 + 9.090909090909090909e-19j), 1167 (0.5 + 0j, 0.40546510810816 + 0j), 1168 (0.0 + 0.5j, 0.111571776 + 0.463647609j), 1169 (2.0 + 1.0j, 1.151292546497023 + 0.3217505543966422j), 1170 (-1.0 + 2.0j, 0.6931471805599453 + 1.570796326794897j), 1171 (2.0j, 0.80471895621705014 + 1.1071487177940904j), 1172 (-2.0j, 0.80471895621705014 - 1.1071487177940904j), 1173 ] 1174 # test the extreme values 1175 if dtype == torch.complex128: 1176 inouts += [ 1177 (-1 + 1e250j, 575.6462732485114 + 1.5707963267948966j), 1178 (1e250 + 1j, 575.6462732485114 + 1e-250j), 1179 (1e250 + 1e250j, 575.9928468387914 + 0.7853981633974483j), 1180 (1e-250 + 1e250j, 575.6462732485114 + 1.5707963267948966j), 1181 (1e-250 + 2e-250j, 1e-250 + 2e-250j), 1182 (1e250 + 1e-250j, 575.6462732485114 + 0.0j), 1183 ] 1184 elif dtype == torch.complex64: 1185 inouts += [ 1186 (-1 + 1e30j, 69.07755278982137 + 1.5707963267948966j), 1187 (1e30 + 1j, 69.07755278982137 + 1e-30j), 1188 (1e30 + 1e30j, 69.42412638010134 + 0.7853981633974483j), 1189 (1e-30 + 1e30j, 69.07755278982137 + 1.5707963267948966j), 1190 (1e-30 + 2e-30j, 1e-30 + 2e-30j), 1191 (1e30 + 1e-30j, 69.07755278982137 + 0.0j), 1192 ] 1193 1194 # test the log1p individually 1195 for inp, out in inouts: 1196 res = torch.log1p(torch.tensor(inp, dtype=dtype, device=device)) 1197 self.assertFalse(torch.any(torch.isnan(res))) 1198 # setting up atol == 0.0 because some part has very small values 1199 self.assertEqual(res.real, 
out.real, atol=0.0, rtol=1e-6) 1200 self.assertEqual(res.imag, out.imag, atol=0.0, rtol=1e-6) 1201 1202 # test the log1p in tensor 1203 inp_lst, out_lst = (list(elmt) for elmt in zip(*inouts)) 1204 inp_tens = torch.tensor(inp_lst, dtype=dtype, device=device) 1205 out_tens = torch.tensor(out_lst, dtype=dtype, device=device) 1206 res_tens = torch.log1p(inp_tens) 1207 self.assertEqual(res_tens.real, out_tens.real, atol=0.0, rtol=1e-6) 1208 self.assertEqual(res_tens.imag, out_tens.imag, atol=0.0, rtol=1e-6) 1209 1210 # do ops like threshold need a test_unary(_nonufunc) test suite? 1211 @onlyCPU 1212 @dtypes(*get_all_math_dtypes("cpu")) 1213 def test_threshold(self, device, dtype): 1214 if dtype != torch.uint8 and dtype != torch.float16 and not dtype.is_complex: 1215 # 100 is wide enough to use AVX2 instructions for all types 1216 x = ( 1217 torch.randn(100, dtype=torch.float, device=device) 1218 .sign() 1219 .to(dtype=dtype) 1220 ) 1221 y = torch.threshold(x, 0, 0) 1222 self.assertTrue(y.le(0).any()) 1223 1224 def _helper_test_igamma(self, loglo, loghi, device, dtype, torch_fcn, scipy_fcn): 1225 exp1 = 2.71828182846 1226 vec1 = torch.logspace( 1227 loglo, loghi, steps=500, base=exp1, dtype=torch.float64, device=device 1228 ).unsqueeze(-1) 1229 vec1 = vec1.to(dtype) 1230 inputs = [ 1231 (vec1, vec1.transpose(0, 1)), 1232 (vec1, vec1), # for large number, it should approach 0.5 1233 (vec1, 0.5 * vec1), # test for considerable ratio 1234 (vec1, 2.0 * vec1), 1235 (vec1[::2, :], vec1[::2, :]), # contiguous/noncontiguous tests 1236 (vec1[::2, :], vec1[: vec1.shape[0] // 2, :]), 1237 (vec1[: vec1.shape[0] // 2, :], vec1[::2, :]), 1238 ] 1239 half_prec = dtype in [torch.bfloat16, torch.float16] 1240 for input0, input1 in inputs: 1241 actual = torch_fcn(input0, input1) 1242 if half_prec: 1243 input0 = input0.to(torch.float) 1244 input1 = input1.to(torch.float) 1245 expected = scipy_fcn(input0.cpu().numpy(), input1.cpu().numpy()) 1246 expected = torch.from_numpy(expected).to(dtype) 1247 self.assertEqual(actual, expected) 1248 1249 @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) 1250 @dtypes(torch.float32, torch.float64) 1251 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1252 @onlyNativeDeviceTypes 1253 def test_igamma_common(self, device, dtype): 1254 # test igamma for reasonable range of values 1255 loglo = -4 # approx 0.018 1256 loghi = 4 # approx 54.6 1257 self._helper_test_igamma( 1258 loglo, loghi, device, dtype, torch.igamma, scipy.special.gammainc 1259 ) 1260 1261 @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) 1262 @dtypes(torch.float32, torch.float64) 1263 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1264 @onlyNativeDeviceTypes 1265 def test_igammac_common(self, device, dtype): 1266 # test igammac for reasonable range of values 1267 loglo = -4 # approx 0.018 1268 loghi = 4 # approx 54.6 1269 self._helper_test_igamma( 1270 loglo, loghi, device, dtype, torch.igammac, scipy.special.gammaincc 1271 ) 1272 1273 @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) 1274 @dtypes(torch.float32, torch.float64) 1275 @onlyNativeDeviceTypes 1276 def test_igamma_edge_cases(self, device, dtype): 1277 tkwargs = {"dtype": dtype, "device": device} 1278 infs = torch.zeros((3,), **tkwargs) + float("inf") 1279 zeros = torch.zeros((3,), **tkwargs) 1280 ones = torch.ones((3,), **tkwargs) 1281 zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs) 1282 small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs) 1283 nans = 
torch.zeros((3,), **tkwargs) + float("nan") 1284 inpouts = [ 1285 # (a , x), out 1286 ((zeros, small_to_inf), ones), 1287 ((small_to_inf, zeros), zeros), 1288 ((infs, zero_to_large), zeros), 1289 ((zero_to_large, infs), ones), 1290 ((zeros, zeros), nans), 1291 ((infs, infs), nans), 1292 ((-small_to_inf, small_to_inf), nans), 1293 ] 1294 for inputs, output in inpouts: 1295 input0, input1 = inputs 1296 calc = torch.igamma(input0, input1) 1297 if torch.all(torch.isnan(output)): 1298 self.assertTrue(torch.all(torch.isnan(calc))) 1299 else: 1300 self.assertEqual(calc, output) 1301 1302 @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) 1303 @dtypes(torch.float32, torch.float64) 1304 @onlyNativeDeviceTypes 1305 def test_igammac_edge_cases(self, device, dtype): 1306 tkwargs = {"dtype": dtype, "device": device} 1307 infs = torch.zeros((3,), **tkwargs) + float("inf") 1308 zeros = torch.zeros((3,), **tkwargs) 1309 ones = torch.ones((3,), **tkwargs) 1310 zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs) 1311 small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs) 1312 nans = torch.zeros((3,), **tkwargs) + float("nan") 1313 inpouts = [ 1314 # (a , x), out 1315 ((zeros, small_to_inf), zeros), 1316 ((small_to_inf, zeros), ones), 1317 ((infs, zero_to_large), ones), 1318 ((zero_to_large, infs), zeros), 1319 ((zeros, zeros), nans), 1320 ((infs, infs), nans), 1321 ((-small_to_inf, small_to_inf), nans), 1322 ] 1323 for inputs, output in inpouts: 1324 input0, input1 = inputs 1325 calc = torch.igammac(input0, input1) 1326 if torch.all(torch.isnan(output)): 1327 self.assertTrue(torch.all(torch.isnan(calc))) 1328 else: 1329 self.assertEqual(calc, output) 1330 1331 def _i0_helper(self, t): 1332 # Test by comparing to scipy 1333 dtype = t.dtype 1334 actual = torch.i0(t) 1335 if dtype is torch.bfloat16: 1336 t = t.to(torch.float32) 1337 expected = scipy.special.i0(t.cpu().numpy()) 1338 # Casting down for dtype float16 is required since scipy upcasts to float32 1339 if dtype is torch.bfloat16 or dtype is torch.float16: 1340 expected = torch.from_numpy(expected).to(dtype) 1341 self.assertEqual(actual, expected) 1342 1343 def _i0_range_helper(self, range, device, dtype): 1344 # i0 tests are broken up by the domain for which the function does not overflow for each dtype 1345 # This is done to ensure that the function performs well across all possible input values, without worrying 1346 # about inf or nan possibilities 1347 for r in (range, -range): 1348 t = torch.rand(1000, device=device).to(dtype) * r 1349 self._i0_helper(t) 1350 1351 @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) 1352 @dtypes(torch.bfloat16, torch.float32, torch.float64) 1353 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1354 def test_i0_range1(self, device, dtype): 1355 # This tests the domain for i0 for which float16 does not overflow 1356 # The domain is (-13.25, 13.25) 1357 self._i0_range_helper(13.25, device, dtype) 1358 1359 @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) 1360 @dtypes(torch.bfloat16, torch.float32, torch.float64) 1361 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1362 def test_i0_range2(self, device, dtype): 1363 # This tests the domain for i0 for which float32 and bfloat16 does not overflow 1364 # The domain is (-88.5, 88.5) 1365 self._i0_range_helper(88.5, device, dtype) 1366 1367 @dtypes(torch.float64) 1368 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1369 def test_i0_range3(self, device, dtype): 1370 # This tests the domain for i0 for 
which float64 does not overflow 1371 # The domain is (-709.75, 709.75) 1372 self._i0_range_helper(709.75, device, dtype) 1373 1374 @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) 1375 @dtypes(torch.bfloat16, torch.float32, torch.float64) 1376 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1377 def test_i0_special(self, device, dtype): 1378 t = torch.tensor([], device=device, dtype=dtype) 1379 self._i0_helper(t) 1380 1381 t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype) 1382 self.assertTrue(torch.i0(t).isnan().all()) 1383 1384 @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) 1385 @dtypes(torch.bfloat16, torch.float32, torch.float64) 1386 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1387 def test_special_i0_i1_vs_scipy(self, device, dtype): 1388 def check_equal(t, torch_fn, scipy_fn): 1389 # Test by comparing to scipy 1390 actual = torch_fn(t) 1391 if dtype is torch.bfloat16: 1392 t = t.to(torch.float32) 1393 expected = scipy_fn(t.cpu().numpy()) 1394 1395 # Casting down for dtype float16 is required since scipy upcasts to float32 1396 if dtype is torch.bfloat16 or dtype is torch.float16: 1397 expected = torch.from_numpy(expected).to(dtype) 1398 self.assertEqual(actual, expected) 1399 1400 t = torch.tensor([], device=device, dtype=dtype) 1401 check_equal(t, torch.i0, scipy.special.i0) 1402 check_equal(t, torch.special.i0e, scipy.special.i0e) 1403 if dtype not in [torch.half, torch.bfloat16]: 1404 check_equal(t, torch.special.i1, scipy.special.i1) 1405 check_equal(t, torch.special.i1e, scipy.special.i1e) 1406 1407 range = (-1e7, 1e7) 1408 if dtype == torch.half: 1409 range = (-65000, 65000) 1410 1411 t = torch.linspace(*range, int(1e4), device=device, dtype=dtype) 1412 check_equal(t, torch.i0, scipy.special.i0) 1413 check_equal(t, torch.special.i0e, scipy.special.i0e) 1414 if dtype not in [torch.half, torch.bfloat16]: 1415 check_equal(t, torch.special.i1, scipy.special.i1) 1416 check_equal(t, torch.special.i1e, scipy.special.i1e) 1417 1418 # NaN, inf, -inf are tested in reference_numerics tests. 1419 info = torch.finfo(dtype) 1420 min, max, eps, tiny = info.min, info.max, info.eps, info.tiny 1421 t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device) 1422 check_equal(t, torch.i0, scipy.special.i0) 1423 check_equal(t, torch.special.i0e, scipy.special.i0e) 1424 if dtype not in [torch.half, torch.bfloat16]: 1425 check_equal(t, torch.special.i1, scipy.special.i1) 1426 check_equal(t, torch.special.i1e, scipy.special.i1e) 1427 1428 @dtypes(torch.float32, torch.float64) 1429 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1430 def test_special_ndtr_vs_scipy(self, device, dtype): 1431 def check_equal(t): 1432 # Test by comparing to scipy 1433 actual = torch.special.ndtr(t) 1434 expected = scipy.special.ndtr(t.cpu().numpy()) 1435 self.assertEqual(actual, expected) 1436 1437 range = (-10, 10) 1438 t = torch.linspace(*range, 1, device=device, dtype=dtype) 1439 check_equal(t) 1440 1441 # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests. 
1442 info = torch.finfo(dtype) 1443 min, max, eps, tiny = info.min, info.max, info.eps, info.tiny 1444 t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device) 1445 check_equal(t) 1446 1447 @dtypes(torch.float32, torch.float64) 1448 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1449 def test_special_log_ndtr_vs_scipy(self, device, dtype): 1450 def check_equal(t): 1451 # Test by comparing with scipy 1452 actual = torch.special.log_ndtr(t) 1453 expected = scipy.special.log_ndtr(t.cpu().numpy()) 1454 self.assertEqual(actual, expected) 1455 1456 # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests. 1457 info = torch.finfo(dtype) 1458 min, max, eps, tiny = info.min, info.max, info.eps, info.tiny 1459 t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device) 1460 check_equal(t) 1461 1462 # TODO: allow large opinfo values to be opted-into via metadata 1463 @dtypes(torch.long) 1464 def test_abs_big_number(self, device, dtype): 1465 bignumber = 2**31 + 1 1466 res = torch.tensor([bignumber], device=device, dtype=dtype) 1467 self.assertGreater(res.abs()[0], 0) 1468 1469 # TODO: add signed zero testing to opinfos 1470 @dtypes(torch.float, torch.double) 1471 def test_abs_signed_zero(self, device, dtype): 1472 # Both abs(0.0) and abs(-0.0) should result in 0.0 1473 size = 128 + 1 # pick a large enough number with remainder so that 1474 # both vectorized and nonvectorized op is tested 1475 inp = torch.zeros(size, device=device, dtype=dtype) 1476 inp[::2] = -0.0 1477 inp = inp.abs() 1478 for v in inp: 1479 self.assertGreater(math.copysign(1.0, v), 0.0) 1480 1481 # TODO: update to compare against NumPy by rationalizing with OpInfo 1482 @onlyCUDA 1483 @dtypes(torch.float, torch.double) 1484 def test_abs_zero(self, device, dtype): 1485 # Both abs(0.0) and abs(-0.0) should result in 0.0 1486 abs_zeros = torch.tensor([0.0, -0.0], device=device, dtype=dtype).abs().tolist() 1487 for num in abs_zeros: 1488 self.assertGreater(math.copysign(1.0, num), 0.0) 1489 1490 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) 1491 def test_isposinf_isneginf_non_boolean_output(self, device, dtype): 1492 # test non-boolean tensors as the `out=` parameters 1493 # boolean outputs are tested in the above testcases 1494 vals = (float("inf"), -float("inf"), 1.2) 1495 t = torch.tensor(vals, device=device) 1496 for torch_op in (torch.isposinf, torch.isneginf): 1497 out = torch.empty_like(t, dtype=dtype) 1498 with self.assertRaisesRegex( 1499 RuntimeError, "does not support non-boolean outputs" 1500 ): 1501 torch_op(t, out=out) 1502 1503 def test_nonzero_empty(self, device): 1504 def assert_tuple_empty(tup, dim): 1505 self.assertEqual(dim, len(tup)) 1506 for t in tup: 1507 self.assertEqual(torch.Size([0]), t.shape) 1508 1509 x = torch.randn(0, 2, 0, 5, 0, device=device) 1510 y = torch.nonzero(x) 1511 z = torch.nonzero(x, as_tuple=True) 1512 1513 self.assertEqual(0, y.numel()) 1514 self.assertEqual(torch.Size([0, 5]), y.shape) 1515 assert_tuple_empty(z, 5) 1516 1517 x = torch.tensor(0.5, device=device) 1518 y = torch.nonzero(x) 1519 # nonzero with as_tuple returns a 1520 # tuple of len 1 for a zero-dim tensor. 1521 # This is done to match Numpy behavior. 
1522 z = torch.nonzero(x, as_tuple=True) 1523 self.assertEqual(1, len(z)) 1524 self.assertEqual(torch.zeros(1, dtype=torch.long), z[0]) 1525 1526 x = torch.zeros((), device=device) 1527 y = torch.nonzero(x) 1528 z = torch.nonzero(x, as_tuple=True) 1529 self.assertEqual(torch.Size([0, 0]), y.shape) 1530 self.assertEqual(1, len(z)) 1531 self.assertEqual(torch.empty(0, dtype=torch.long), z[0]) 1532 1533 # TODO: rationalize with exp OpInfo 1534 @dtypes(*floating_and_complex_types_and(torch.bfloat16)) 1535 @dtypesIfCUDA(*floating_and_complex_types_and(torch.half, torch.bfloat16)) 1536 def test_exp(self, device, dtype): 1537 for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()): 1538 a = ( 1539 torch.tensor(v, dtype=dtype, device=device) 1540 * torch.arange(18, device=device) 1541 / 3 1542 * math.pi 1543 ) 1544 a = a.to(dtype) 1545 # bfloat16 overflows 1546 if dtype == torch.bfloat16: 1547 return 1548 self.compare_with_numpy(torch.exp, np.exp, a) 1549 1550 if dtype.is_complex: 1551 inf_real_zero_imag_in = torch.tensor( 1552 complex(float("inf"), 0), device=device, dtype=dtype 1553 ) 1554 inf_real_zero_imag_out = torch.exp(inf_real_zero_imag_in).item() 1555 self.assertTrue(math.isinf(inf_real_zero_imag_out.real)) 1556 if self.device_type == "cpu": 1557 pass 1558 # These are commented out because it cannot be consistently reproduced. 1559 # This is incorrect. It should be zero. Need fix! 1560 # https://github.com/pytorch/pytorch/issues/40590 1561 # self.assertNotEqual(inf_real_zero_imag_out.imag, 0) 1562 # This is incorrect. They should equal. Need fix! 1563 # https://github.com/pytorch/pytorch/issues/40590 1564 # with self.assertRaises(AssertionError): 1565 # self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in) 1566 else: 1567 self.assertEqual(inf_real_zero_imag_out.imag, 0, atol=0, rtol=0) 1568 self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in) 1569 1570 zero_real_inf_imag_in = torch.tensor( 1571 complex(0, float("inf")), device=device, dtype=dtype 1572 ) 1573 zero_real_inf_imag_out = torch.exp(zero_real_inf_imag_in).item() 1574 self.assertTrue(math.isnan(zero_real_inf_imag_out.real)) 1575 self.assertTrue(math.isnan(zero_real_inf_imag_out.imag)) 1576 # Ensure we are notified when NumPy changes its behavior 1577 self.compare_with_numpy(torch.exp, np.exp, zero_real_inf_imag_in) 1578 1579 inf_real_imag_in = torch.tensor( 1580 complex(float("inf"), float("inf")), device=device, dtype=dtype 1581 ) 1582 inf_real_imag_out = torch.exp(inf_real_imag_in).item() 1583 if self.device_type == "cpu": 1584 pass 1585 # This is incorrect. Need fix! https://github.com/pytorch/pytorch/issues/40590 1586 # This is commented out because it cannot be consistently reproduced. 1587 # with self.assertRaises(AssertionError): 1588 # self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in) 1589 else: 1590 self.assertTrue(math.isinf(inf_real_imag_out.real)) 1591 self.assertTrue(math.isnan(inf_real_imag_out.imag)) 1592 self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in) 1593 1594 inf_real_nan_imag_in = torch.tensor( 1595 complex(float("inf"), float("nan")), device=device, dtype=dtype 1596 ) 1597 inf_real_nan_imag_out = torch.exp(inf_real_nan_imag_in).item() 1598 if self.device_type == "cpu": 1599 pass 1600 # This is incorrect. It should be inf. Need fix! https://github.com/pytorch/pytorch/issues/40590 1601 # This is commented out because it cannot be consistently reproduced. 
                # with self.assertRaises(AssertionError):
                #     self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in)
            else:
                self.assertTrue(math.isinf(inf_real_nan_imag_out.real))
                self.assertTrue(math.isnan(inf_real_nan_imag_out.imag))
                self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in)

            nan_real_inf_imag_in = torch.tensor(
                complex(float("nan"), float("inf")), device=device, dtype=dtype
            )
            nan_real_inf_imag_out = torch.exp(nan_real_inf_imag_in).item()
            self.assertTrue(math.isnan(nan_real_inf_imag_out.real))
            self.assertTrue(math.isnan(nan_real_inf_imag_out.imag))
            # Ensure we are notified when NumPy changes its behavior
            self.compare_with_numpy(torch.exp, np.exp, nan_real_inf_imag_in)


instantiate_device_type_tests(TestUnaryUfuncs, globals())

if __name__ == "__main__":
    run_tests()
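# NOTE: assuming this module is run directly (e.g. as a test file under test/),
# individual instantiated tests can typically be selected with unittest's -k
# filter, e.g. "-k test_frexp"; the exact generated test names depend on how
# instantiate_device_type_tests expands them for the available devices.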