// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/UnaryOps.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/native/DispatchStub.h>
#include <ATen/Generator.h>
#include <c10/core/Scalar.h>

#include <cstdint>
#include <optional>
#include <stdexcept>

// Forward declarations only. Everything below names these types through a
// reference or a pointer, so their full definitions are not needed here.
namespace at {
struct TensorIteratorBase;
class TensorBase;
class Tensor;
} // namespace at

14 namespace at::native {
15 
16 using unary_fn = void(*)(TensorIteratorBase&);
17 using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a);
18 
19 inline namespace CPU_CAPABILITY {
20 void conj_kernel(TensorIteratorBase &iter);
21 void neg_kernel(TensorIteratorBase &iter);
22 void reciprocal_kernel(TensorIteratorBase &iter);
23 void rsqrt_kernel(TensorIteratorBase& iter);
24 void sqrt_kernel(TensorIteratorBase& iter);
25 } // namespace CPU_CAPABILITY
26 
27 DECLARE_DISPATCH(unary_fn, abs_stub);
28 DECLARE_DISPATCH(unary_fn, angle_stub);
29 DECLARE_DISPATCH(unary_fn, conj_physical_stub);
30 DECLARE_DISPATCH(unary_fn, acos_stub);
31 DECLARE_DISPATCH(unary_fn, acosh_stub);
32 DECLARE_DISPATCH(unary_fn, asinh_stub);
33 DECLARE_DISPATCH(unary_fn, atanh_stub);
34 DECLARE_DISPATCH(unary_fn, asin_stub);
35 DECLARE_DISPATCH(unary_fn, atan_stub);
36 DECLARE_DISPATCH(unary_fn, bitwise_not_stub);
37 DECLARE_DISPATCH(unary_fn, logical_not_stub);
38 DECLARE_DISPATCH(unary_fn, ceil_stub);
39 DECLARE_DISPATCH(unary_fn, cos_stub);
40 DECLARE_DISPATCH(unary_fn, cosh_stub);
41 DECLARE_DISPATCH(unary_fn, digamma_stub);
42 DECLARE_DISPATCH(unary_fn, special_entr_stub);
43 DECLARE_DISPATCH(unary_fn, special_erfcx_stub);
44 DECLARE_DISPATCH(unary_fn, erf_stub);
45 DECLARE_DISPATCH(unary_fn, erfc_stub);
46 DECLARE_DISPATCH(unary_fn, erfinv_stub);
47 DECLARE_DISPATCH(unary_fn, exp_stub);
48 DECLARE_DISPATCH(unary_fn, exp2_stub);
49 DECLARE_DISPATCH(unary_fn, expm1_stub);
50 DECLARE_DISPATCH(unary_fn, floor_stub);
51 DECLARE_DISPATCH(unary_fn, frac_stub);
52 DECLARE_DISPATCH(unary_fn, frexp_stub);
53 DECLARE_DISPATCH(unary_fn, i0_stub);
54 DECLARE_DISPATCH(unary_fn, special_i0e_stub);
55 DECLARE_DISPATCH(unary_fn, special_i1_stub);
56 DECLARE_DISPATCH(unary_fn, special_i1e_stub);
57 DECLARE_DISPATCH(unary_fn, log_stub);
58 DECLARE_DISPATCH(unary_fn, log10_stub);
59 DECLARE_DISPATCH(unary_fn, log1p_stub);
60 DECLARE_DISPATCH(unary_fn, log2_stub);
61 DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
62 DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
63 DECLARE_DISPATCH(unary_fn, neg_stub);
64 
65 DECLARE_DISPATCH(unary_fn, reciprocal_stub);
66 DECLARE_DISPATCH(unary_fn, round_stub);
67 DECLARE_DISPATCH(unary_fn, rsqrt_stub);
68 DECLARE_DISPATCH(unary_fn, sigmoid_stub);
69 DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub);
70 DECLARE_DISPATCH(unary_fn, sign_stub);
71 DECLARE_DISPATCH(unary_fn, signbit_stub);
72 DECLARE_DISPATCH(unary_fn, sgn_stub);
73 DECLARE_DISPATCH(unary_fn, sin_stub);
74 DECLARE_DISPATCH(unary_fn, sinc_stub);
75 DECLARE_DISPATCH(unary_fn, sinh_stub);
76 DECLARE_DISPATCH(unary_fn, sqrt_stub);
77 DECLARE_DISPATCH(unary_fn, tan_stub);
78 DECLARE_DISPATCH(unary_fn, tanh_stub);
79 DECLARE_DISPATCH(unary_fn, trigamma_stub);
80 DECLARE_DISPATCH(unary_fn, trunc_stub);
81 DECLARE_DISPATCH(unary_fn, lgamma_stub);
82 DECLARE_DISPATCH(unary_fn, special_airy_ai_stub);
83 DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
84 DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
85 DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
86 DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
87 DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
88 DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
89 DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
90 DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
91 DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub);
92 DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub);
93 DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub);
94 
95 // NB: these are actually defined in Distribution
96 DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, std::optional<Generator>), bernoulli_tensor_stub);
97 DECLARE_DISPATCH(void(*)(const TensorBase&, const double, std::optional<Generator>), bernoulli_scalar_stub);
98 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), cauchy_stub);
99 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), exponential_stub);
100 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), geometric_stub);
101 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), log_normal_stub);
102 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), uniform_stub);
103 DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, std::optional<Generator>), normal_stub);
104 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, std::optional<Generator>), random_from_to_stub);
105 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_full_64_bits_range_stub);
106 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_stub);
107 
108 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub);
109 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub);
110 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const Scalar& a, const Scalar& b), clamp_stub);
111 DECLARE_DISPATCH(
112     void (*)(Tensor&, const Tensor&, int64_t, std::optional<Generator>),
113     multinomial_with_replacement_stub);
114 DECLARE_DISPATCH(
115     void (*)(
116         TensorIteratorBase&,
117         std::optional<double>,
118         std::optional<double>,
119         std::optional<double>),
120     nan_to_num_stub);
121 DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub);
122 
123 // Missing unary functions
124 // digamma
125 // lgamma
126 // erfinv
127 // clone
128 // contiguous
129 // zero
130 } // namespace at::native
131