1# mypy: ignore-errors 2 3import unittest 4from functools import partial 5from itertools import product 6from typing import List 7 8import numpy as np 9 10import torch 11from torch.testing import make_tensor 12from torch.testing._internal.common_device_type import ( 13 precisionOverride, 14 tol, 15 toleranceOverride, 16) 17from torch.testing._internal.common_dtype import all_types_and, floating_types 18from torch.testing._internal.common_utils import ( 19 TEST_SCIPY, 20 TEST_WITH_ROCM, 21 torch_to_numpy_dtype_dict, 22) 23from torch.testing._internal.opinfo.core import ( 24 BinaryUfuncInfo, 25 DecorateInfo, 26 L, 27 NumericsFilter, 28 OpInfo, 29 S, 30 SampleInput, 31 UnaryUfuncInfo, 32) 33from torch.testing._internal.opinfo.refs import ( 34 ElementwiseBinaryPythonRefInfo, 35 ElementwiseUnaryPythonRefInfo, 36) 37from torch.testing._internal.opinfo.utils import ( 38 np_unary_ufunc_integer_promotion_wrapper, 39) 40 41 42if TEST_SCIPY: 43 import scipy.special 44 45 46# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`, 47# supports `exclude` argument. 48# For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617 49def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs): 50 exclude_zero = requires_grad and op_info.op == torch.special.i0e 51 make_arg = partial( 52 make_tensor, 53 dtype=dtype, 54 device=device, 55 requires_grad=requires_grad, 56 exclude_zero=exclude_zero, 57 ) 58 yield SampleInput(make_arg((S,))) 59 yield SampleInput(make_arg(())) 60 61 if requires_grad and not exclude_zero: 62 # Special Case for gradient 63 # Sample with `0` in the input 64 t = make_arg((S,)) 65 t[0] = 0 66 67 yield SampleInput(t) 68 69 70def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs): 71 make_arg = partial( 72 make_tensor, 73 device=device, 74 # TODO: eliminate low after gh-106692 is fixed: 75 low=(1 if dtype in {torch.int32, torch.int64} else None), 76 dtype=dtype, 77 requires_grad=requires_grad, 78 ) 79 tensor_shapes = ((S, S), ()) 80 ns = (1, 2, 3, 4, 5) 81 82 for shape, n in product(tensor_shapes, ns): 83 yield SampleInput(make_arg(shape), args=(n,)) 84 85 86def reference_polygamma(x, n): 87 # WEIRD `scipy.special.polygamma` behavior 88 # >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype 89 # dtype('float64') 90 # >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype 91 # dtype('float32') 92 # 93 # Thus we cast output to the default torch dtype or preserve double 94 result_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] 95 if x.dtype == np.double: 96 result_dtype = np.double 97 return scipy.special.polygamma(n, x).astype(result_dtype) 98 99 100def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs): 101 low, _ = op_info.domain 102 103 if requires_grad: 104 low = 0 + op_info._domain_eps 105 106 make_arg = partial( 107 make_tensor, dtype=dtype, device=device, low=low, requires_grad=requires_grad 108 ) 109 yield SampleInput(make_arg((L,))) 110 yield SampleInput(make_arg(())) 111 112 113def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): 114 for shape in ((L,), (1, 0, 3), ()): 115 yield SampleInput( 116 make_tensor( 117 shape, 118 device=device, 119 dtype=dtype, 120 low=-5, 121 requires_grad=requires_grad, 122 ), 123 ) 124 125 126op_db: List[OpInfo] = [ 127 UnaryUfuncInfo( 128 "special.i0e", 129 aten_name="special_i0e", 130 ref=scipy.special.i0e if TEST_SCIPY else None, 131 decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),), 132 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 133 backward_dtypes=floating_types(), 134 sample_inputs_func=sample_inputs_i0_i1, 135 supports_forward_ad=True, 136 supports_fwgrad_bwgrad=True, 137 ), 138 UnaryUfuncInfo( 139 "special.i1", 140 aten_name="special_i1", 141 ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1) 142 if TEST_SCIPY 143 else None, 144 dtypes=all_types_and(torch.bool), 145 dtypesIfCUDA=all_types_and(torch.bool), 146 sample_inputs_func=sample_inputs_i0_i1, 147 decorators=( 148 DecorateInfo( 149 toleranceOverride( 150 { 151 torch.float32: tol(atol=1e-4, rtol=0), 152 torch.bool: tol(atol=1e-4, rtol=0), 153 } 154 ) 155 ), 156 ), 157 skips=( 158 DecorateInfo( 159 unittest.skip("Incorrect result!"), 160 "TestUnaryUfuncs", 161 "test_reference_numerics_large", 162 dtypes=(torch.int8,), 163 ), 164 ), 165 supports_fwgrad_bwgrad=True, 166 supports_forward_ad=True, 167 ), 168 UnaryUfuncInfo( 169 "special.i1e", 170 aten_name="special_i1e", 171 ref=scipy.special.i1e if TEST_SCIPY else None, 172 dtypes=all_types_and(torch.bool), 173 dtypesIfCUDA=all_types_and(torch.bool), 174 sample_inputs_func=sample_inputs_i0_i1, 175 supports_forward_ad=True, 176 supports_fwgrad_bwgrad=True, 177 ), 178 UnaryUfuncInfo( 179 "special.ndtr", 180 aten_name="special_ndtr", 181 decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),), 182 ref=scipy.special.ndtr if TEST_SCIPY else None, 183 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 184 supports_forward_ad=True, 185 supports_fwgrad_bwgrad=True, 186 skips=( 187 # Dispatch stub: unsupported device typemeta 188 DecorateInfo( 189 unittest.expectedFailure, 190 "TestFwdGradients", 191 "test_fn_fwgrad_bwgrad", 192 device_type="meta", 193 ), 194 ), 195 ), 196 # A separate OpInfo entry for special.polygamma is needed to reorder the arguments 197 # for the alias. See the discussion here: https://github.com/pytorch/pytorch/pull/59691#discussion_r650261939 198 UnaryUfuncInfo( 199 "special.polygamma", 200 op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs), 201 variant_test_name="special_polygamma_n_0", 202 ref=reference_polygamma if TEST_SCIPY else None, 203 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 204 dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), 205 supports_forward_ad=True, 206 supports_fwgrad_bwgrad=True, 207 sample_inputs_func=sample_inputs_polygamma, 208 skips=( 209 # lambda impl 210 DecorateInfo( 211 unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" 212 ), 213 DecorateInfo( 214 unittest.expectedFailure, 215 "TestNormalizeOperators", 216 "test_normalize_operator_exhaustive", 217 ), 218 ), 219 sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}), 220 # polygamma functions have multiple singularities at x having non-positive integer value 221 reference_numerics_filter=NumericsFilter( 222 condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), safe_val=1 223 ), 224 ), 225 BinaryUfuncInfo( 226 "special.xlog1py", 227 aten_name="special_xlog1py", 228 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 229 promotes_int_to_float=True, 230 supports_forward_ad=True, 231 supports_fwgrad_bwgrad=True, 232 supports_one_python_scalar=True, 233 # We don't test -1 as the gradient will be NaN and it'll break 234 rhs_make_tensor_kwargs=dict(low=-0.99), 235 ), 236 BinaryUfuncInfo( 237 "special.zeta", 238 aten_name="special_zeta", 239 dtypes=all_types_and(torch.bool), 240 promotes_int_to_float=True, 241 supports_autograd=False, 242 supports_one_python_scalar=True, 243 skips=( 244 # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu 245 DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), 246 ), 247 ), 248 # TODO: FIXME 249 # OpInfo entry to verify the gradient formula of `other`/`q` 250 # BinaryUfuncInfo('special.zeta', 251 # op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs), 252 # aten_name='special_zeta', 253 # variant_test_name='grad', 254 # dtypes=all_types_and(torch.bool), 255 # promotes_int_to_float=True, 256 # supports_autograd=True, 257 # supports_rhs_python_scalar=False, 258 # decorators=[ 259 # # Derivative wrt first tensor not implemented 260 # DecorateInfo(unittest.expectedFailure, "TestCommon", 261 # "test_floating_inputs_are_differentiable") 262 # ], 263 # skips=( 264 # # Lambda doesn't work in JIT test 265 # # AssertionError: JIT Test does not execute any logic 266 # DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"), 267 # )), 268 UnaryUfuncInfo( 269 "special.entr", 270 ref=scipy.special.entr if TEST_SCIPY else None, 271 aten_name="special_entr", 272 supports_forward_ad=True, 273 supports_fwgrad_bwgrad=True, 274 decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),), 275 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 276 skips=( 277 DecorateInfo( 278 unittest.skip("Skipped!"), 279 "TestUnaryUfuncs", 280 "test_reference_numerics_large", 281 dtypes=[torch.bfloat16, torch.float16], 282 ), 283 ), 284 supports_inplace_autograd=False, 285 sample_inputs_func=sample_inputs_entr, 286 ), 287 UnaryUfuncInfo( 288 "special.ndtri", 289 ref=scipy.special.ndtri if TEST_SCIPY else None, 290 domain=(0, 1), 291 aten_name="special_ndtri", 292 dtypes=all_types_and(torch.bool), 293 supports_forward_ad=True, 294 supports_fwgrad_bwgrad=True, 295 ), 296 UnaryUfuncInfo( 297 "special.log_ndtr", 298 aten_name="special_log_ndtr", 299 ref=scipy.special.log_ndtr if TEST_SCIPY else None, 300 dtypes=all_types_and(torch.bool), 301 supports_forward_ad=True, 302 supports_fwgrad_bwgrad=True, 303 ), 304 UnaryUfuncInfo( 305 "special.erfcx", 306 ref=scipy.special.erfcx if TEST_SCIPY else None, 307 aten_name="special_erfcx", 308 decorators=( 309 toleranceOverride( 310 { 311 torch.float32: tol(atol=0, rtol=4e-6), 312 } 313 ), 314 ), 315 dtypes=all_types_and(torch.bool), 316 supports_forward_ad=True, 317 supports_fwgrad_bwgrad=True, 318 sample_inputs_func=sample_inputs_erfcx, 319 ), 320 UnaryUfuncInfo( 321 "special.airy_ai", 322 decorators=( 323 precisionOverride( 324 { 325 torch.float32: 1e-03, 326 torch.float64: 1e-05, 327 }, 328 ), 329 ), 330 dtypes=all_types_and(torch.bool), 331 ref=lambda x: scipy.special.airy(x)[0] if TEST_SCIPY else None, 332 skips=( 333 DecorateInfo( 334 unittest.skip("Skipped!"), 335 "TestUnaryUfuncs", 336 "test_reference_numerics_large", 337 ), 338 ), 339 supports_autograd=False, 340 ), 341 UnaryUfuncInfo( 342 "special.bessel_j0", 343 decorators=( 344 precisionOverride( 345 { 346 torch.float32: 1e-04, 347 torch.float64: 1e-05, 348 }, 349 ), 350 ), 351 dtypes=all_types_and(torch.bool), 352 ref=scipy.special.j0 if TEST_SCIPY else None, 353 supports_autograd=False, 354 ), 355 UnaryUfuncInfo( 356 "special.bessel_j1", 357 decorators=( 358 precisionOverride( 359 { 360 torch.float32: 1e-04, 361 torch.float64: 1e-05, 362 }, 363 ), 364 ), 365 dtypes=all_types_and(torch.bool), 366 ref=scipy.special.j1 if TEST_SCIPY else None, 367 supports_autograd=False, 368 ), 369 UnaryUfuncInfo( 370 "special.bessel_y0", 371 decorators=( 372 precisionOverride( 373 { 374 torch.float32: 1e-04, 375 torch.float64: 1e-05, 376 }, 377 ), 378 ), 379 dtypes=all_types_and(torch.bool), 380 ref=scipy.special.y0 if TEST_SCIPY else None, 381 supports_autograd=False, 382 ), 383 UnaryUfuncInfo( 384 "special.bessel_y1", 385 decorators=( 386 precisionOverride( 387 { 388 torch.float32: 1e-04, 389 torch.float64: 1e-05, 390 }, 391 ), 392 ), 393 dtypes=all_types_and(torch.bool), 394 ref=scipy.special.y1 if TEST_SCIPY else None, 395 supports_autograd=False, 396 ), 397 BinaryUfuncInfo( 398 "special.chebyshev_polynomial_t", 399 dtypes=all_types_and(torch.bool), 400 promotes_int_to_float=True, 401 skips=( 402 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 403 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 404 DecorateInfo( 405 unittest.skip("testing takes an unreasonably long time, #79528"), 406 "TestCommon", 407 "test_compare_cpu", 408 ), 409 ), 410 supports_one_python_scalar=True, 411 supports_autograd=False, 412 ), 413 BinaryUfuncInfo( 414 "special.chebyshev_polynomial_u", 415 dtypes=all_types_and(torch.bool), 416 promotes_int_to_float=True, 417 skips=( 418 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 419 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 420 DecorateInfo( 421 unittest.skip("testing takes an unreasonably long time, #79528"), 422 "TestCommon", 423 "test_compare_cpu", 424 ), 425 ), 426 supports_one_python_scalar=True, 427 supports_autograd=False, 428 ), 429 BinaryUfuncInfo( 430 "special.chebyshev_polynomial_v", 431 dtypes=all_types_and(torch.bool), 432 promotes_int_to_float=True, 433 skips=( 434 DecorateInfo( 435 unittest.skip( 436 "Skipping - testing takes an unreasonably long time, #79528" 437 ) 438 ), 439 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 440 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 441 ), 442 supports_one_python_scalar=True, 443 supports_autograd=False, 444 ), 445 BinaryUfuncInfo( 446 "special.chebyshev_polynomial_w", 447 dtypes=all_types_and(torch.bool), 448 promotes_int_to_float=True, 449 skips=( 450 DecorateInfo( 451 unittest.skip( 452 "Skipping - testing takes an unreasonably long time, #79528" 453 ) 454 ), 455 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 456 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 457 ), 458 supports_one_python_scalar=True, 459 supports_autograd=False, 460 ), 461 BinaryUfuncInfo( 462 "special.hermite_polynomial_h", 463 dtypes=all_types_and(torch.bool), 464 promotes_int_to_float=True, 465 skips=( 466 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 467 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 468 # Greatest absolute difference: inf 469 DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), 470 DecorateInfo(unittest.skip("Hangs on ROCm 6.1"), active_if=TEST_WITH_ROCM), 471 ), 472 supports_one_python_scalar=True, 473 supports_autograd=False, 474 ), 475 BinaryUfuncInfo( 476 "special.hermite_polynomial_he", 477 dtypes=all_types_and(torch.bool), 478 promotes_int_to_float=True, 479 skips=( 480 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 481 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 482 DecorateInfo( 483 unittest.skip("testing takes an unreasonably long time, #79528"), 484 "TestCommon", 485 "test_compare_cpu", 486 ), 487 ), 488 supports_one_python_scalar=True, 489 supports_autograd=False, 490 ), 491 BinaryUfuncInfo( 492 "special.laguerre_polynomial_l", 493 dtypes=all_types_and(torch.bool), 494 promotes_int_to_float=True, 495 skips=( 496 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 497 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 498 DecorateInfo( 499 unittest.skip("testing takes an unreasonably long time, #79528"), 500 "TestCommon", 501 "test_compare_cpu", 502 ), 503 ), 504 supports_one_python_scalar=True, 505 supports_autograd=False, 506 ), 507 BinaryUfuncInfo( 508 "special.legendre_polynomial_p", 509 dtypes=all_types_and(torch.bool), 510 promotes_int_to_float=True, 511 skips=( 512 DecorateInfo( 513 unittest.skip( 514 "Skipping - testing takes an unreasonably long time, #79528" 515 ) 516 ), 517 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 518 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 519 DecorateInfo( 520 unittest.skip("testing takes an unreasonably long time, #79528"), 521 "TestCommon", 522 "test_compare_cpu", 523 ), 524 ), 525 supports_one_python_scalar=True, 526 supports_autograd=False, 527 ), 528 UnaryUfuncInfo( 529 "special.modified_bessel_i0", 530 decorators=( 531 precisionOverride( 532 { 533 torch.float32: 1e-03, 534 torch.float64: 1e-05, 535 }, 536 ), 537 ), 538 dtypes=all_types_and(torch.bool), 539 ref=scipy.special.i0 if TEST_SCIPY else None, 540 supports_autograd=False, 541 ), 542 UnaryUfuncInfo( 543 "special.modified_bessel_i1", 544 decorators=( 545 precisionOverride( 546 { 547 torch.float32: 1e-03, 548 torch.float64: 1e-05, 549 }, 550 ), 551 ), 552 dtypes=all_types_and(torch.bool), 553 ref=scipy.special.i1 if TEST_SCIPY else None, 554 supports_autograd=False, 555 ), 556 UnaryUfuncInfo( 557 "special.modified_bessel_k0", 558 decorators=( 559 precisionOverride( 560 { 561 torch.float32: 1e-03, 562 torch.float64: 1e-05, 563 }, 564 ), 565 ), 566 dtypes=all_types_and(torch.bool), 567 ref=scipy.special.k0 if TEST_SCIPY else None, 568 supports_autograd=False, 569 ), 570 UnaryUfuncInfo( 571 "special.modified_bessel_k1", 572 decorators=( 573 precisionOverride( 574 { 575 torch.float32: 1e-03, 576 torch.float64: 1e-05, 577 }, 578 ), 579 ), 580 dtypes=all_types_and(torch.bool), 581 ref=scipy.special.k1 if TEST_SCIPY else None, 582 supports_autograd=False, 583 ), 584 UnaryUfuncInfo( 585 "special.scaled_modified_bessel_k0", 586 decorators=( 587 toleranceOverride( 588 { 589 torch.float32: tol(atol=1e-03, rtol=1e-03), 590 torch.float64: tol(atol=1e-05, rtol=1e-03), 591 } 592 ), 593 ), 594 dtypes=all_types_and(torch.bool), 595 ref=scipy.special.k0e if TEST_SCIPY else None, 596 supports_autograd=False, 597 ), 598 UnaryUfuncInfo( 599 "special.scaled_modified_bessel_k1", 600 decorators=( 601 toleranceOverride( 602 { 603 torch.float32: tol(atol=1e-03, rtol=1e-03), 604 torch.float64: tol(atol=1e-05, rtol=1e-03), 605 } 606 ), 607 ), 608 dtypes=all_types_and(torch.bool), 609 ref=scipy.special.k1e if TEST_SCIPY else None, 610 supports_autograd=False, 611 ), 612 BinaryUfuncInfo( 613 "special.shifted_chebyshev_polynomial_t", 614 dtypes=all_types_and(torch.bool), 615 promotes_int_to_float=True, 616 skips=( 617 DecorateInfo( 618 unittest.skip( 619 "Skipping - testing takes an unreasonably long time, #79528" 620 ) 621 ), 622 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 623 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 624 DecorateInfo( 625 unittest.skip("testing takes an unreasonably long time, #79528"), 626 "TestCommon", 627 "test_compare_cpu", 628 ), 629 ), 630 supports_one_python_scalar=True, 631 supports_autograd=False, 632 ), 633 BinaryUfuncInfo( 634 "special.shifted_chebyshev_polynomial_u", 635 dtypes=all_types_and(torch.bool), 636 promotes_int_to_float=True, 637 skips=( 638 DecorateInfo( 639 unittest.skip( 640 "Skipping - testing takes an unreasonably long time, #79528" 641 ) 642 ), 643 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 644 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 645 DecorateInfo( 646 unittest.skip("testing takes an unreasonably long time, #79528"), 647 "TestCommon", 648 "test_compare_cpu", 649 ), 650 ), 651 supports_one_python_scalar=True, 652 supports_autograd=False, 653 ), 654 BinaryUfuncInfo( 655 "special.shifted_chebyshev_polynomial_v", 656 dtypes=all_types_and(torch.bool), 657 promotes_int_to_float=True, 658 skips=( 659 DecorateInfo( 660 unittest.skip( 661 "Skipping - testing takes an unreasonably long time, #79528" 662 ) 663 ), 664 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 665 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 666 DecorateInfo( 667 unittest.skip("testing takes an unreasonably long time, #79528"), 668 "TestCommon", 669 "test_compare_cpu", 670 ), 671 ), 672 supports_one_python_scalar=True, 673 supports_autograd=False, 674 ), 675 BinaryUfuncInfo( 676 "special.shifted_chebyshev_polynomial_w", 677 dtypes=all_types_and(torch.bool), 678 promotes_int_to_float=True, 679 skips=( 680 DecorateInfo( 681 unittest.skip( 682 "Skipping - testing takes an unreasonably long time, #79528" 683 ) 684 ), 685 DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), 686 DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), 687 DecorateInfo( 688 unittest.skip("testing takes an unreasonably long time, #79528"), 689 "TestCommon", 690 "test_compare_cpu", 691 ), 692 ), 693 supports_one_python_scalar=True, 694 supports_autograd=False, 695 ), 696 UnaryUfuncInfo( 697 "special.spherical_bessel_j0", 698 decorators=( 699 toleranceOverride( 700 { 701 torch.float32: tol(atol=1e-03, rtol=1e-03), 702 torch.float64: tol(atol=1e-05, rtol=1e-03), 703 } 704 ), 705 ), 706 dtypes=all_types_and(torch.bool), 707 ref=lambda x: scipy.special.spherical_jn(0, x) if TEST_SCIPY else None, 708 supports_autograd=False, 709 ), 710] 711 712python_ref_db: List[OpInfo] = [ 713 # 714 # Elementwise Unary Special OpInfos 715 # 716 ElementwiseUnaryPythonRefInfo( 717 "_refs.special.bessel_j0", 718 torch_opinfo_name="special.bessel_j0", 719 op_db=op_db, 720 decorators=( 721 precisionOverride( 722 { 723 torch.float32: 1e-04, 724 torch.float64: 1e-05, 725 }, 726 ), 727 ), 728 ), 729 ElementwiseUnaryPythonRefInfo( 730 "_refs.special.bessel_j1", 731 torch_opinfo_name="special.bessel_j1", 732 op_db=op_db, 733 decorators=( 734 precisionOverride( 735 { 736 torch.float32: 1e-04, 737 torch.float64: 1e-05, 738 }, 739 ), 740 ), 741 ), 742 ElementwiseUnaryPythonRefInfo( 743 "_refs.special.entr", 744 torch_opinfo_name="special.entr", 745 op_db=op_db, 746 decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),), 747 skips=( 748 DecorateInfo( 749 unittest.skip("Skipped!"), 750 "TestUnaryUfuncs", 751 "test_reference_numerics_large", 752 dtypes=[torch.bfloat16, torch.float16], 753 ), 754 ), 755 ), 756 ElementwiseUnaryPythonRefInfo( 757 "_refs.special.erfcx", 758 torch_opinfo_name="special.erfcx", 759 op_db=op_db, 760 decorators=( 761 toleranceOverride( 762 { 763 torch.float32: tol(atol=0, rtol=4e-6), 764 } 765 ), 766 ), 767 ), 768 ElementwiseUnaryPythonRefInfo( 769 "_refs.special.i0e", 770 torch_opinfo_name="special.i0e", 771 op_db=op_db, 772 decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),), 773 ), 774 ElementwiseUnaryPythonRefInfo( 775 "_refs.special.i1", 776 torch_opinfo_name="special.i1", 777 op_db=op_db, 778 decorators=( 779 DecorateInfo( 780 toleranceOverride( 781 { 782 torch.float32: tol(atol=1e-4, rtol=0), 783 torch.bool: tol(atol=1e-4, rtol=0), 784 } 785 ) 786 ), 787 ), 788 skips=( 789 DecorateInfo( 790 unittest.skip("Incorrect result!"), 791 "TestUnaryUfuncs", 792 "test_reference_numerics_large", 793 dtypes=(torch.int8,), 794 ), 795 ), 796 ), 797 ElementwiseUnaryPythonRefInfo( 798 "_refs.special.i1e", 799 torch_opinfo_name="special.i1e", 800 op_db=op_db, 801 ), 802 ElementwiseUnaryPythonRefInfo( 803 "_refs.special.log_ndtr", 804 torch_opinfo_name="special.log_ndtr", 805 op_db=op_db, 806 ), 807 ElementwiseUnaryPythonRefInfo( 808 "_refs.special.ndtr", 809 torch_opinfo_name="special.ndtr", 810 op_db=op_db, 811 ), 812 ElementwiseUnaryPythonRefInfo( 813 "_refs.special.ndtri", 814 torch_opinfo_name="special.ndtri", 815 op_db=op_db, 816 ), 817 ElementwiseUnaryPythonRefInfo( 818 "_refs.special.spherical_bessel_j0", 819 torch_opinfo_name="special.spherical_bessel_j0", 820 op_db=op_db, 821 decorators=( 822 toleranceOverride( 823 { 824 torch.float32: tol(atol=1e-03, rtol=1e-03), 825 torch.float64: tol(atol=1e-05, rtol=1e-03), 826 } 827 ), 828 ), 829 ), 830 # 831 # Elementwise Binary Special OpInfos 832 # 833 ElementwiseBinaryPythonRefInfo( 834 "_refs.special.zeta", 835 torch_opinfo_name="special.zeta", 836 supports_one_python_scalar=True, 837 op_db=op_db, 838 skips=( 839 # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu 840 DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), 841 ), 842 ), 843] 844