1 #pragma once 2 3 #include <ATen/native/DispatchStub.h> 4 #include <ATen/Generator.h> 5 #include <c10/core/Scalar.h> 6 #include <stdexcept> 7 8 namespace at { 9 class Tensor; 10 class TensorBase; 11 struct TensorIteratorBase; 12 } 13 14 namespace at::native { 15 16 using unary_fn = void(*)(TensorIteratorBase&); 17 using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a); 18 19 inline namespace CPU_CAPABILITY { 20 void conj_kernel(TensorIteratorBase &iter); 21 void neg_kernel(TensorIteratorBase &iter); 22 void reciprocal_kernel(TensorIteratorBase &iter); 23 void rsqrt_kernel(TensorIteratorBase& iter); 24 void sqrt_kernel(TensorIteratorBase& iter); 25 } // namespace CPU_CAPABILITY 26 27 DECLARE_DISPATCH(unary_fn, abs_stub); 28 DECLARE_DISPATCH(unary_fn, angle_stub); 29 DECLARE_DISPATCH(unary_fn, conj_physical_stub); 30 DECLARE_DISPATCH(unary_fn, acos_stub); 31 DECLARE_DISPATCH(unary_fn, acosh_stub); 32 DECLARE_DISPATCH(unary_fn, asinh_stub); 33 DECLARE_DISPATCH(unary_fn, atanh_stub); 34 DECLARE_DISPATCH(unary_fn, asin_stub); 35 DECLARE_DISPATCH(unary_fn, atan_stub); 36 DECLARE_DISPATCH(unary_fn, bitwise_not_stub); 37 DECLARE_DISPATCH(unary_fn, logical_not_stub); 38 DECLARE_DISPATCH(unary_fn, ceil_stub); 39 DECLARE_DISPATCH(unary_fn, cos_stub); 40 DECLARE_DISPATCH(unary_fn, cosh_stub); 41 DECLARE_DISPATCH(unary_fn, digamma_stub); 42 DECLARE_DISPATCH(unary_fn, special_entr_stub); 43 DECLARE_DISPATCH(unary_fn, special_erfcx_stub); 44 DECLARE_DISPATCH(unary_fn, erf_stub); 45 DECLARE_DISPATCH(unary_fn, erfc_stub); 46 DECLARE_DISPATCH(unary_fn, erfinv_stub); 47 DECLARE_DISPATCH(unary_fn, exp_stub); 48 DECLARE_DISPATCH(unary_fn, exp2_stub); 49 DECLARE_DISPATCH(unary_fn, expm1_stub); 50 DECLARE_DISPATCH(unary_fn, floor_stub); 51 DECLARE_DISPATCH(unary_fn, frac_stub); 52 DECLARE_DISPATCH(unary_fn, frexp_stub); 53 DECLARE_DISPATCH(unary_fn, i0_stub); 54 DECLARE_DISPATCH(unary_fn, special_i0e_stub); 55 DECLARE_DISPATCH(unary_fn, special_i1_stub); 56 DECLARE_DISPATCH(unary_fn, special_i1e_stub); 57 DECLARE_DISPATCH(unary_fn, log_stub); 58 DECLARE_DISPATCH(unary_fn, log10_stub); 59 DECLARE_DISPATCH(unary_fn, log1p_stub); 60 DECLARE_DISPATCH(unary_fn, log2_stub); 61 DECLARE_DISPATCH(unary_fn, special_ndtri_stub); 62 DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub); 63 DECLARE_DISPATCH(unary_fn, neg_stub); 64 65 DECLARE_DISPATCH(unary_fn, reciprocal_stub); 66 DECLARE_DISPATCH(unary_fn, round_stub); 67 DECLARE_DISPATCH(unary_fn, rsqrt_stub); 68 DECLARE_DISPATCH(unary_fn, sigmoid_stub); 69 DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub); 70 DECLARE_DISPATCH(unary_fn, sign_stub); 71 DECLARE_DISPATCH(unary_fn, signbit_stub); 72 DECLARE_DISPATCH(unary_fn, sgn_stub); 73 DECLARE_DISPATCH(unary_fn, sin_stub); 74 DECLARE_DISPATCH(unary_fn, sinc_stub); 75 DECLARE_DISPATCH(unary_fn, sinh_stub); 76 DECLARE_DISPATCH(unary_fn, sqrt_stub); 77 DECLARE_DISPATCH(unary_fn, tan_stub); 78 DECLARE_DISPATCH(unary_fn, tanh_stub); 79 DECLARE_DISPATCH(unary_fn, trigamma_stub); 80 DECLARE_DISPATCH(unary_fn, trunc_stub); 81 DECLARE_DISPATCH(unary_fn, lgamma_stub); 82 DECLARE_DISPATCH(unary_fn, special_airy_ai_stub); 83 DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub); 84 DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub); 85 DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub); 86 DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub); 87 DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub); 88 DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub); 89 DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub); 90 DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub); 91 DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub); 92 DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub); 93 DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub); 94 95 // NB: these are actually defined in Distribution 96 DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, std::optional<Generator>), bernoulli_tensor_stub); 97 DECLARE_DISPATCH(void(*)(const TensorBase&, const double, std::optional<Generator>), bernoulli_scalar_stub); 98 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), cauchy_stub); 99 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), exponential_stub); 100 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), geometric_stub); 101 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), log_normal_stub); 102 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), uniform_stub); 103 DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, std::optional<Generator>), normal_stub); 104 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, std::optional<Generator>), random_from_to_stub); 105 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_full_64_bits_range_stub); 106 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_stub); 107 108 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub); 109 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub); 110 DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const Scalar& a, const Scalar& b), clamp_stub); 111 DECLARE_DISPATCH( 112 void (*)(Tensor&, const Tensor&, int64_t, std::optional<Generator>), 113 multinomial_with_replacement_stub); 114 DECLARE_DISPATCH( 115 void (*)( 116 TensorIteratorBase&, 117 std::optional<double>, 118 std::optional<double>, 119 std::optional<double>), 120 nan_to_num_stub); 121 DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub); 122 123 // Missing unary functions 124 // digamma 125 // lgamma 126 // erfinv 127 // clone 128 // contiguous 129 // zero 130 } // namespace at::native 131