# mypy: ignore-errors

import unittest
from functools import partial
from itertools import product
from typing import List

import numpy as np

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    precisionOverride,
    tol,
    toleranceOverride,
)
from torch.testing._internal.common_dtype import all_types_and, floating_types
from torch.testing._internal.common_utils import (
    TEST_SCIPY,
    TEST_WITH_ROCM,
    torch_to_numpy_dtype_dict,
)
from torch.testing._internal.opinfo.core import (
    BinaryUfuncInfo,
    DecorateInfo,
    L,
    NumericsFilter,
    OpInfo,
    S,
    SampleInput,
    UnaryUfuncInfo,
)
from torch.testing._internal.opinfo.refs import (
    ElementwiseBinaryPythonRefInfo,
    ElementwiseUnaryPythonRefInfo,
)
from torch.testing._internal.opinfo.utils import (
    np_unary_ufunc_integer_promotion_wrapper,
)


if TEST_SCIPY:
    import scipy.special


# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`
#       supports the `exclude` argument.
#       For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617
def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs):
    exclude_zero = requires_grad and op_info.op == torch.special.i0e
    make_arg = partial(
        make_tensor,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
        exclude_zero=exclude_zero,
    )
    yield SampleInput(make_arg((S,)))
    yield SampleInput(make_arg(()))

    if requires_grad and not exclude_zero:
        # Special case for the gradient:
        # sample with `0` in the input
        t = make_arg((S,))
        t[0] = 0

        yield SampleInput(t)
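
# Why `exclude_zero` matters above (hedged note, not exercised by the test
# suite): the backward formula for `i0e` involves `sgn(x)`, so the gradient is
# ill-defined at exactly zero, and zero is excluded from gradient-checked
# samples. An illustrative probe:
#
# >>> x = torch.tensor([0.0], requires_grad=True)  # hypothetical input
# >>> torch.special.i0e(x).sum().backward()
# >>> x.grad  # the value at the origin is the problematic point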


def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor,
        device=device,
        # TODO: eliminate low after gh-106692 is fixed:
        low=(1 if dtype in {torch.int32, torch.int64} else None),
        dtype=dtype,
        requires_grad=requires_grad,
    )
    tensor_shapes = ((S, S), ())
    ns = (1, 2, 3, 4, 5)

    for shape, n in product(tensor_shapes, ns):
        yield SampleInput(make_arg(shape), args=(n,))


def reference_polygamma(x, n):
    # WEIRD `scipy.special.polygamma` behavior
    # >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype
    # dtype('float64')
    # >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype
    # dtype('float32')
    #
    # Thus we cast the output to the default torch dtype, or preserve double.
    result_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
    if x.dtype == np.double:
        result_dtype = np.double
    return scipy.special.polygamma(n, x).astype(result_dtype)
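
# Sanity-check sketch for the wrapper above (assumes the default torch dtype
# is float32; illustrative only):
# >>> reference_polygamma(np.array(1.5, dtype=np.float32), 0).dtype
# dtype('float32')
# >>> reference_polygamma(np.array(1.5, dtype=np.float64), 0).dtype
# dtype('float64')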


def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
    low, _ = op_info.domain

    if requires_grad:
        low = 0 + op_info._domain_eps

    make_arg = partial(
        make_tensor, dtype=dtype, device=device, low=low, requires_grad=requires_grad
    )
    yield SampleInput(make_arg((L,)))
    yield SampleInput(make_arg(()))
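
# Why the lower bound is nudged above (hedged note): entr(x) = -x * log(x) on
# the positive reals, so its derivative, -log(x) - 1, diverges as x -> 0+.
# Shifting `low` by `_domain_eps` keeps gradient-checked samples strictly
# inside the open domain.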


def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs):
    for shape in ((L,), (1, 0, 3), ()):
        yield SampleInput(
            make_tensor(
                shape,
                device=device,
                dtype=dtype,
                low=-5,
                requires_grad=requires_grad,
            ),
        )
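
# Hedged note on `low=-5`: erfcx(x) = exp(x**2) * erfc(x) grows roughly like
# 2 * exp(x**2) for large negative x, so bounding samples from below avoids
# overflow in the reference comparison.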


op_db: List[OpInfo] = [
    UnaryUfuncInfo(
        "special.i0e",
        aten_name="special_i0e",
        ref=scipy.special.i0e if TEST_SCIPY else None,
        decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        backward_dtypes=floating_types(),
        sample_inputs_func=sample_inputs_i0_i1,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    UnaryUfuncInfo(
        "special.i1",
        aten_name="special_i1",
        ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1)
        if TEST_SCIPY
        else None,
        dtypes=all_types_and(torch.bool),
        dtypesIfCUDA=all_types_and(torch.bool),
        sample_inputs_func=sample_inputs_i0_i1,
        decorators=(
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.float32: tol(atol=1e-4, rtol=0),
                        torch.bool: tol(atol=1e-4, rtol=0),
                    }
                )
            ),
        ),
        skips=(
            DecorateInfo(
                unittest.skip("Incorrect result!"),
                "TestUnaryUfuncs",
                "test_reference_numerics_large",
                dtypes=(torch.int8,),
            ),
        ),
        supports_fwgrad_bwgrad=True,
        supports_forward_ad=True,
    ),
    UnaryUfuncInfo(
        "special.i1e",
        aten_name="special_i1e",
        ref=scipy.special.i1e if TEST_SCIPY else None,
        dtypes=all_types_and(torch.bool),
        dtypesIfCUDA=all_types_and(torch.bool),
        sample_inputs_func=sample_inputs_i0_i1,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    UnaryUfuncInfo(
        "special.ndtr",
        aten_name="special_ndtr",
        decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),),
        ref=scipy.special.ndtr if TEST_SCIPY else None,
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        skips=(
            # Dispatch stub: unsupported device type "meta"
            DecorateInfo(
                unittest.expectedFailure,
                "TestFwdGradients",
                "test_fn_fwgrad_bwgrad",
                device_type="meta",
            ),
        ),
    ),
    # A separate OpInfo entry for special.polygamma is needed to reorder the arguments
    # for the alias. See the discussion here: https://github.com/pytorch/pytorch/pull/59691#discussion_r650261939
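    # Hedged illustration of the reorder: the OpInfo machinery calls
    # ``op(sample_input, *args)``, i.e. ``op(x, n)``, while the ATen signature
    # is ``torch.special.polygamma(n, input)``; the lambda below swaps them.
    # Both of these compute the same digamma value:
    # >>> torch.special.polygamma(0, torch.tensor([1.0]))
    # >>> torch.tensor([1.0]).polygamma(0)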
    UnaryUfuncInfo(
        "special.polygamma",
        op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs),
        variant_test_name="special_polygamma_n_0",
        ref=reference_polygamma if TEST_SCIPY else None,
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_polygamma,
        skips=(
            # lambda impl is not supported by the JIT
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
        ),
        sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}),
        # polygamma functions have multiple singularities at non-positive integer values of x
        reference_numerics_filter=NumericsFilter(
            condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), safe_val=1
        ),
    ),
    BinaryUfuncInfo(
        "special.xlog1py",
        aten_name="special_xlog1py",
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        promotes_int_to_float=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_one_python_scalar=True,
        # We don't test -1, as the gradient there is NaN and it would break gradcheck
        rhs_make_tensor_kwargs=dict(low=-0.99),
    ),
    BinaryUfuncInfo(
        "special.zeta",
        aten_name="special_zeta",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        supports_autograd=False,
        supports_one_python_scalar=True,
        skips=(
            # Reference inputs produce NaNs and infs on CUDA, and NaN, inf, 0., -inf on CPU
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
    ),
    # TODO: FIXME
    # OpInfo entry to verify the gradient formula of `other`/`q`
    # BinaryUfuncInfo('special.zeta',
    #                 op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs),
    #                 aten_name='special_zeta',
    #                 variant_test_name='grad',
    #                 dtypes=all_types_and(torch.bool),
    #                 promotes_int_to_float=True,
    #                 supports_autograd=True,
    #                 supports_rhs_python_scalar=False,
    #                 decorators=[
    #                     # Derivative wrt first tensor not implemented
    #                     DecorateInfo(unittest.expectedFailure, "TestCommon",
    #                                  "test_floating_inputs_are_differentiable")
    #                 ],
    #                 skips=(
    #                     # Lambda doesn't work in JIT test
    #                     # AssertionError: JIT Test does not execute any logic
    #                     DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"),
    #                 )),
    UnaryUfuncInfo(
        "special.entr",
        ref=scipy.special.entr if TEST_SCIPY else None,
        aten_name="special_entr",
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestUnaryUfuncs",
                "test_reference_numerics_large",
                dtypes=[torch.bfloat16, torch.float16],
            ),
        ),
        supports_inplace_autograd=False,
        sample_inputs_func=sample_inputs_entr,
    ),
    UnaryUfuncInfo(
        "special.ndtri",
        ref=scipy.special.ndtri if TEST_SCIPY else None,
        domain=(0, 1),
        aten_name="special_ndtri",
        dtypes=all_types_and(torch.bool),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    UnaryUfuncInfo(
        "special.log_ndtr",
        aten_name="special_log_ndtr",
        ref=scipy.special.log_ndtr if TEST_SCIPY else None,
        dtypes=all_types_and(torch.bool),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    UnaryUfuncInfo(
        "special.erfcx",
        ref=scipy.special.erfcx if TEST_SCIPY else None,
        aten_name="special_erfcx",
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=0, rtol=4e-6),
                }
            ),
        ),
        dtypes=all_types_and(torch.bool),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_erfcx,
    ),
    UnaryUfuncInfo(
        "special.airy_ai",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-03,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        # Parenthesize the lambda so the conditional selects between the lambda
        # and None; otherwise `ref` would be a lambda that returns None when
        # scipy is unavailable, rather than None itself.
        ref=(lambda x: scipy.special.airy(x)[0]) if TEST_SCIPY else None,
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestUnaryUfuncs",
                "test_reference_numerics_large",
            ),
        ),
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.bessel_j0",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.j0 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.bessel_j1",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.j1 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.bessel_y0",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.y0 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.bessel_y1",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.y1 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.chebyshev_polynomial_t",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            DecorateInfo(
                unittest.skip("testing takes an unreasonably long time, #79528"),
                "TestCommon",
                "test_compare_cpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.chebyshev_polynomial_u",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            DecorateInfo(
                unittest.skip("testing takes an unreasonably long time, #79528"),
                "TestCommon",
                "test_compare_cpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.chebyshev_polynomial_v",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(
                unittest.skip(
                    "Skipping - testing takes an unreasonably long time, #79528"
                )
            ),
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.chebyshev_polynomial_w",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(
                unittest.skip(
                    "Skipping - testing takes an unreasonably long time, #79528"
                )
            ),
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.hermite_polynomial_h",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: inf
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
            DecorateInfo(unittest.skip("Hangs on ROCm 6.1"), active_if=TEST_WITH_ROCM),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.hermite_polynomial_he",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            DecorateInfo(
                unittest.skip("testing takes an unreasonably long time, #79528"),
                "TestCommon",
                "test_compare_cpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.laguerre_polynomial_l",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            DecorateInfo(
                unittest.skip("testing takes an unreasonably long time, #79528"),
                "TestCommon",
                "test_compare_cpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.legendre_polynomial_p",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(
                unittest.skip(
                    "Skipping - testing takes an unreasonably long time, #79528"
                )
            ),
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            DecorateInfo(
                unittest.skip("testing takes an unreasonably long time, #79528"),
                "TestCommon",
                "test_compare_cpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.modified_bessel_i0",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-03,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.i0 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.modified_bessel_i1",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-03,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.i1 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.modified_bessel_k0",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-03,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.k0 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.modified_bessel_k1",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-03,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.k1 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.scaled_modified_bessel_k0",
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=1e-03, rtol=1e-03),
                    torch.float64: tol(atol=1e-05, rtol=1e-03),
                }
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.k0e if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.scaled_modified_bessel_k1",
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=1e-03, rtol=1e-03),
                    torch.float64: tol(atol=1e-05, rtol=1e-03),
                }
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.k1e if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.shifted_chebyshev_polynomial_t",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(
                unittest.skip(
                    "Skipping - testing takes an unreasonably long time, #79528"
                )
            ),
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            DecorateInfo(
                unittest.skip("testing takes an unreasonably long time, #79528"),
                "TestCommon",
                "test_compare_cpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.shifted_chebyshev_polynomial_u",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(
                unittest.skip(
                    "Skipping - testing takes an unreasonably long time, #79528"
                )
            ),
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            DecorateInfo(
                unittest.skip("testing takes an unreasonably long time, #79528"),
                "TestCommon",
                "test_compare_cpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.shifted_chebyshev_polynomial_v",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(
                unittest.skip(
                    "Skipping - testing takes an unreasonably long time, #79528"
                )
            ),
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            DecorateInfo(
                unittest.skip("testing takes an unreasonably long time, #79528"),
                "TestCommon",
                "test_compare_cpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.shifted_chebyshev_polynomial_w",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(
                unittest.skip(
                    "Skipping - testing takes an unreasonably long time, #79528"
                )
            ),
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            DecorateInfo(
                unittest.skip("testing takes an unreasonably long time, #79528"),
                "TestCommon",
                "test_compare_cpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.spherical_bessel_j0",
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=1e-03, rtol=1e-03),
                    torch.float64: tol(atol=1e-05, rtol=1e-03),
                }
            ),
        ),
        dtypes=all_types_and(torch.bool),
        # Parenthesized for the same reason as `special.airy_ai` above: the
        # conditional must select between the lambda and None.
        ref=(lambda x: scipy.special.spherical_jn(0, x)) if TEST_SCIPY else None,
        supports_autograd=False,
    ),
]
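
# Hedged sketch of how `op_db` entries are typically consumed by a
# device-generic test (the `ops` decorator and `OpInfo.sample_inputs` are the
# real test-suite hooks; this particular test body is illustrative only):
#
# from torch.testing._internal.common_device_type import ops
#
# class TestSpecialOps(TestCase):
#     @ops(op_db)
#     def test_forward(self, device, dtype, op):
#         for sample in op.sample_inputs(device, dtype):
#             op(sample.input, *sample.args, **sample.kwargs)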
python_ref_db: List[OpInfo] = [
    #
    # Elementwise Unary Special OpInfos
    #
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.bessel_j0",
        torch_opinfo_name="special.bessel_j0",
        op_db=op_db,
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.bessel_j1",
        torch_opinfo_name="special.bessel_j1",
        op_db=op_db,
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.entr",
        torch_opinfo_name="special.entr",
        op_db=op_db,
        decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestUnaryUfuncs",
                "test_reference_numerics_large",
                dtypes=[torch.bfloat16, torch.float16],
            ),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.erfcx",
        torch_opinfo_name="special.erfcx",
        op_db=op_db,
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=0, rtol=4e-6),
                }
            ),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.i0e",
        torch_opinfo_name="special.i0e",
        op_db=op_db,
        decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.i1",
        torch_opinfo_name="special.i1",
        op_db=op_db,
        decorators=(
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.float32: tol(atol=1e-4, rtol=0),
                        torch.bool: tol(atol=1e-4, rtol=0),
                    }
                )
            ),
        ),
        skips=(
            DecorateInfo(
                unittest.skip("Incorrect result!"),
                "TestUnaryUfuncs",
                "test_reference_numerics_large",
                dtypes=(torch.int8,),
            ),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.i1e",
        torch_opinfo_name="special.i1e",
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.log_ndtr",
        torch_opinfo_name="special.log_ndtr",
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.ndtr",
        torch_opinfo_name="special.ndtr",
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.ndtri",
        torch_opinfo_name="special.ndtri",
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.spherical_bessel_j0",
        torch_opinfo_name="special.spherical_bessel_j0",
        op_db=op_db,
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=1e-03, rtol=1e-03),
                    torch.float64: tol(atol=1e-05, rtol=1e-03),
                }
            ),
        ),
    ),
    #
    # Elementwise Binary Special OpInfos
    #
    ElementwiseBinaryPythonRefInfo(
        "_refs.special.zeta",
        torch_opinfo_name="special.zeta",
        supports_one_python_scalar=True,
        op_db=op_db,
        skips=(
            # Reference inputs produce NaNs and infs on CUDA, and NaN, inf, 0., -inf on CPU
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
    ),
]