// aten/src/ATen/native/cpu/UnaryOpsKernel.cpp (AOSP xref, revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>

#include <cmath>
#include <limits>
#include <type_traits>

#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vml.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/CopyKernel.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/cpu/zmath.h>
#include <ATen/OpMathType.h>

#include <c10/util/MathConstants.h>
#include <c10/core/Scalar.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/irange.h>

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif

namespace at::native {

inline namespace CPU_CAPABILITY {

using namespace vec;

static void sigmoid_kernel(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
  if (at::isReducedFloatingType(dtype)) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "sigmoid_cpu_reduced_float", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t {
            float a0 = static_cast<float>(a);
            return static_cast<float>(1) / (static_cast<float>(1) + std::exp((-a0)));
          },
          [=](Vectorized<scalar_t> a) {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            a0 = (Vectorized<float>(static_cast<float>(1)) + a0.neg().exp()).reciprocal();
            a1 = (Vectorized<float>(static_cast<float>(1)) + a1.neg().exp()).reciprocal();
            return convert_from_float<scalar_t>(a0, a1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(dtype, "sigmoid_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t {
            return (static_cast<scalar_t>(1) / (static_cast<scalar_t>(1) + std::exp((-a))));
          },
          [=](Vectorized<scalar_t> a) {
            a = Vectorized<scalar_t>(static_cast<scalar_t>(0)) - a;
            a = a.exp();
            a = Vectorized<scalar_t>(static_cast<scalar_t>(1)) + a;
            a = a.reciprocal();
            return a;
          });
    });
  }
}
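
// Editorial sketch (not part of the original file): the reduced-precision
// branch above widens bf16/fp16 inputs to float before evaluating
// sigmoid(x) = 1 / (1 + exp(-x)), then narrows the result back. A scalar
// equivalent, assuming only <cmath>, would be:
//
//   float sigmoid_widened(at::BFloat16 a) {
//     const float a0 = static_cast<float>(a);  // widen to fp32
//     return 1.0f / (1.0f + std::exp(-a0));    // compute in fp32
//   }
//
// The vectorized lambda does the same on the two Vectorized<float> halves
// produced by convert_to_float.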

#if AT_MKL_ENABLED()

template <typename T>
void VmlLog(int64_t N, const T* X, T* Y) {
  constexpr int64_t K = Vectorized<T>::size();
  at::parallel_for(0, N, K, [=](int64_t begin, int64_t end) {
    using VT = at::opmath_type<T>;
    vec::map(
        [](Vectorized<VT> x_vec) { return x_vec.log(); },
        Y + begin,
        X + begin,
        end - begin);
  });
}

template <>
void VmlLog<float>(int64_t N, const float* X, float* Y) {
  vsLn(N, X, Y);
}

template <>
void VmlLog<double>(int64_t N, const double* X, double* Y) {
  vdLn(N, X, Y);
}

template <typename T>
void LogitMKLKernel(T eps, TensorIteratorBase* it) {
  if (!it->can_use_32bit_indexing()) {
    for (auto& sub_it : it->with_32bit_indexing()) {
      LogitMKLKernel<T>(eps, &sub_it);
    }
    return;
  }

  constexpr int64_t K = Vectorized<T>::size();
  const int64_t N = it->numel();
  const T* X_data = static_cast<T*>(it->data_ptr(1));
  T* Y_data = static_cast<T*>(it->data_ptr(0));
  if (eps < T(0)) {
    at::parallel_for(0, N, K, [=](int64_t begin, int64_t end) {
      for (const auto i : c10::irange(begin, end)) {
        Y_data[i] = X_data[i] == T(1) ? std::numeric_limits<T>::infinity()
                                      : X_data[i] / (T(1) - X_data[i]);
      }
      VmlLog<T>(end - begin, Y_data + begin, Y_data + begin);
    });
  } else {
    const T lo = eps;
    const T hi = T(1) - eps;
    at::parallel_for(0, N, K, [=](int64_t begin, int64_t end) {
      for (const auto i : c10::irange(begin, end)) {
        const T x = X_data[i] < lo ? lo : (X_data[i] > hi ? hi : X_data[i]);
        Y_data[i] =
            x == T(1) ? std::numeric_limits<T>::infinity() : (x / (T(1) - x));
      }
      VmlLog<T>(end - begin, Y_data + begin, Y_data + begin);
    });
  }
}

#else

template <typename T>
void LogitMKLKernel(T eps, TensorIteratorBase* it) {
  TORCH_CHECK(false, "ATen not compiled with MKL");
}

#endif // AT_MKL_ENABLED
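
// Editorial note: when MKL is available, LogitMKLKernel computes
// y = x / (1 - x) elementwise (clamping to [eps, 1 - eps] first when
// eps >= 0) and then takes the natural log in place through the VML entry
// points vsLn (float) / vdLn (double); other floating dtypes go through the
// generic vec::map-based VmlLog template above. In the non-MKL build the
// stub above simply throws, and logit_kernel below never reaches it because
// at::hasMKL() returns false.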

static void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.common_dtype(), "logit_cpu", [&]() {
        const scalar_t eps = eps_scalar.to<scalar_t>();
        if (at::hasMKL() && iter.is_contiguous()) {
          LogitMKLKernel<scalar_t>(eps, &iter);
          iter.cast_outputs();
        } else if (eps < scalar_t(0)) {
          const Vectorized<scalar_t> kOneVec(scalar_t(1));
          cpu_kernel_vec(
              iter,
              [](scalar_t x) {
                return x == scalar_t(1)
                    ? std::numeric_limits<scalar_t>::infinity()
                    : std::log(x / (scalar_t(1) - x));
              },
              [kOneVec](Vectorized<scalar_t> x_vec) {
                return (x_vec / (kOneVec - x_vec)).log();
              });
        } else {
          const scalar_t lo = eps;
          const scalar_t hi = scalar_t(1) - eps;
          const Vectorized<scalar_t> kOneVec(scalar_t(1));
          const Vectorized<scalar_t> lo_vec(lo);
          const Vectorized<scalar_t> hi_vec(hi);
          cpu_kernel_vec(
              iter,
              [lo, hi](scalar_t x) {
                x = x < lo ? lo : (x > hi ? hi : x);
                return x == scalar_t(1)
                    ? std::numeric_limits<scalar_t>::infinity()
                    : std::log(x / (scalar_t(1) - x));
              },
              [kOneVec, lo_vec, hi_vec](Vectorized<scalar_t> x_vec) {
                x_vec = vec::clamp(x_vec, lo_vec, hi_vec);
                return (x_vec / (kOneVec - x_vec)).log();
              });
        }
      });
}
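
// Example of the eps convention used above (sketch, mirroring the
// torch.logit documentation): a negative eps means "no clamping", so
// logit(1) overflows to +inf and inputs outside [0, 1] yield NaN; a
// non-negative eps first clamps x to [eps, 1 - eps]. E.g. with eps = 0.25:
//   x = 1.0 -> clamped to 0.75 -> log(0.75 / 0.25) = log(3) ~= 1.0986
//   x = 0.5 -> unchanged       -> log(0.5 / 0.5)   = 0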

#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...)                                                 \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                                 \
            kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
            TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...)          \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(          \
            kHalf, kBFloat16,                            \
            TYPE, NAME, __VA_ARGS__)
#endif

static void abs_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (dtype == kComplexHalf) {
    using scalar_t = c10::complex<Half>;
    using opmath_t = at::opmath_type<scalar_t>;
    cpu_kernel(iter, [=](scalar_t a) -> scalar_t { return abs_impl(opmath_t{a}); });
  } else {
    _AT_DISPATCH_ABS_TYPES(iter.dtype(), "abs_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return abs_impl(a); },
          [=](Vectorized<scalar_t> a) { return a.abs(); });
    });
  }
}

static void angle_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "angle_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return angle_impl(a); },
        [=](Vectorized<scalar_t> a) { return a.angle(); });
  });
}

// NB: Ignores the negative bit on tensors
void conj_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_SWITCH(iter.common_dtype(), "conj_cpu",
    AT_DISPATCH_CASE_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, [&] {
      // conj is a no-op for non-complex types
      direct_copy_kernel(iter);
    })
    AT_DISPATCH_CASE_COMPLEX_TYPES_AND(kComplexHalf, [&] {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return conj_impl(a); },
          [=](Vectorized<scalar_t> a) { return a.conj(); });
    })
  );
}

static void bitwise_not_kernel(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    // Boolean type does not work with ~ (bitwise NOT) in C++. bitwise_not wraps this operation for both Boolean and
    // integral types.
    cpu_kernel(
          iter,
          [](bool a) {
            return !a;
          });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_not_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a) -> scalar_t {
            return ~a;
          },
          [](Vectorized<scalar_t> a) -> Vectorized<scalar_t> {
            return ~a;
          });
    });
  }
}
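
// Why the bool special case above exists (editorial note): applying ~ to a
// C++ bool integer-promotes it first, so ~true evaluates to -2, which casts
// back to true. A minimal sketch of the difference:
//
//   bool b = true;
//   // static_cast<bool>(~b) -> true   (not what bitwise_not means for bool)
//   // !b                    -> false  (what the bool branch returns)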

static void frac_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "frac_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return a - std::trunc(a); },
        [=](Vectorized<scalar_t> a) { return a.frac(); });
  });
}

static void logical_not_kernel(TensorIteratorBase& iter) {
  // NOTE: this implementation differs from the CUDA implementation which only does single dispatch
  // (to avoid expensive compilation) because CPU kernels don't handle dynamic_casting
  // (see needs_dynamic_casting).
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(1), "logical_not_cpu", [&]() {
    using self_t = scalar_t;
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(0), "logical_not_cpu", [&]() {
      cpu_kernel(iter, [](self_t a) -> scalar_t { return static_cast<scalar_t>(!a); });
    });
  });
}
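
// Sketch of why logical_not uses the nested dispatch above: the input and
// output dtypes can differ (the output is typically kBool) and CPU loops do
// not insert dynamic casts. The outer dispatch binds self_t to the input
// dtype (iter.dtype(1)) and the inner one binds scalar_t to the output dtype
// (iter.dtype(0)), e.g. self_t = float and scalar_t = bool for
// torch.logical_not on a float tensor.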

void reciprocal_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "reciprocal_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return static_cast<scalar_t>(1.0) / a; },
        [=](Vectorized<scalar_t> a) { return a.reciprocal(); });
  });
}

// NB: Ignores the negative bit on tensors
void neg_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kComplexHalf, kBFloat16, kHalf, iter.dtype(), "neg_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return -a; },
        [=](Vectorized<scalar_t> a) { return a.neg(); });
  });
}

static void sign_kernel(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    cpu_kernel(iter, [=](bool x) -> bool { return x; });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, ScalarType::Half, iter.dtype(), "sign_cpu", [&]() {
        auto zero_vec = Vectorized<scalar_t>(static_cast<scalar_t>(0));
        auto one_vec = Vectorized<scalar_t>(static_cast<scalar_t>(1));

        cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return (0 < a) - c10::is_negative(a); },
          [=](Vectorized<scalar_t> self_vec) {
              // Comparison operators return a per-lane bitmask.
              auto left = Vectorized<scalar_t>::blendv(zero_vec, one_vec, zero_vec < self_vec);
              auto right = Vectorized<scalar_t>::blendv(zero_vec, one_vec, self_vec < zero_vec);

              return left - right;
          });
    });
  }
}
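
// How the vectorized sign above works (illustrative note): Vectorized
// comparison operators return an all-ones / all-zeros mask per lane, and
// blendv(a, b, mask) picks b in the lanes where the mask is set. Per lane
// this is equivalent to:
//
//   left  = (0 < x) ? 1 : 0;
//   right = (x < 0) ? 1 : 0;
//   sign  = left - right;   // -1, 0, or +1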

static void signbit_kernel(TensorIteratorBase& iter) {
  // NOTE: signbit does not always support integral arguments.
  AT_DISPATCH_SWITCH(iter.input_dtype(), "signbit_cpu",
      AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
        cpu_kernel(iter, [](scalar_t a) -> bool { return c10::is_negative(a); });
      })
      AT_DISPATCH_CASE_FLOATING_TYPES_AND2(kBFloat16, ScalarType::Half, [&] {
        using opmath_t = at::opmath_type<scalar_t>;
        cpu_kernel(iter, [](scalar_t a) -> bool { return std::signbit(opmath_t{a}); });
      })
    );
}

static void sgn_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (dtype == kComplexHalf) {
    using scalar_t = c10::complex<Half>;
    using opmath_t = at::opmath_type<scalar_t>;
    cpu_kernel(
        iter, [=](scalar_t a) -> scalar_t { return sgn_impl(opmath_t{a}); });
  } else {
    AT_DISPATCH_COMPLEX_TYPES(dtype, "sgn_cpu", [&]() {
      cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return sgn_impl(a); },
        [=](Vectorized<scalar_t> a) { return a.sgn(); });
    });
  }
}

static void sinc_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "sinc_cpu", [&]() {
    cpu_kernel(
        iter,
        [=](scalar_t a) -> scalar_t {
          if (a == scalar_t(0)) {
            return scalar_t(1);
          } else {
            using opmath_t = at::opmath_type<scalar_t>;
            opmath_t product = c10::pi<opmath_t> * opmath_t{a};
            return static_cast<scalar_t>(std::sin(product) / product);
          }
        });
  });
}
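
// Reference formula for the kernel above (normalized sinc, as documented for
// torch.sinc): sinc(0) = 1 and sinc(x) = sin(pi * x) / (pi * x) otherwise,
// with the product accumulated in opmath_t (float for bf16/half inputs).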

static void sinh_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "sinh_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return std::sinh(a); },
        [=](Vectorized<scalar_t> self_vec) { return self_vec.sinh(); });
  });
}

static void cosh_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "cosh_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return std::cosh(a); },
        [=](Vectorized<scalar_t> self_vec) { return self_vec.cosh(); });
  });
}

static void acosh_kernel(TensorIteratorBase& iter) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "acosh_cpu", [&]() {
      cpu_kernel(
        iter,
        [=](scalar_t a) -> scalar_t { return std::acosh(a); });
    });
}

static void asinh_kernel(TensorIteratorBase& iter) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "asinh_cpu", [&]() {
      cpu_kernel(
        iter,
        [=](scalar_t a) -> scalar_t { return std::asinh(a); });
    });
}

static void atanh_kernel(TensorIteratorBase& iter) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "atanh_cpu", [&]() {
      cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return std::atanh(a); },
        [=](Vectorized<scalar_t> self_vec) { return self_vec.atanh(); });
    });
}

static void digamma_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "digamma", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return calc_digamma(a); },
        [=](Vectorized<scalar_t> x) { return x.digamma(); });
  });
}

static void trigamma_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "trigamma", [&]() {
    cpu_kernel(
        iter,
        [=](scalar_t a) -> scalar_t { return trigamma(a); });
  });
}

static void exp2_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      kBFloat16, kHalf, iter.dtype(), "exp2", [&] {
    cpu_kernel_vec(
        iter,
        [](scalar_t a) -> scalar_t { return exp2_impl(a); },
        [](Vectorized<scalar_t> a) { return a.exp2(); });
  });
}

static void polygamma_kernel(TensorIteratorBase& iter, int64_t n) {
  if (n == 0) {
    digamma_kernel(iter);
  } else if (n == 1) {
    trigamma_kernel(iter);
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "polygamma", [&]() {
      cpu_kernel(
          iter, [=](scalar_t a) -> scalar_t { return calc_polygamma(a, n); });
    });
  }
}

template <typename scalar_t>
inline scalar_t _nan_to_num_replace(
    scalar_t a, scalar_t nan_replacement, scalar_t pos_inf_replacement, scalar_t neg_inf_replacement) {
  if (at::_isnan(a)) {
    return nan_replacement;
  } else if (a == std::numeric_limits<scalar_t>::infinity()) {
    return pos_inf_replacement;
  } else if (a == -std::numeric_limits<scalar_t>::infinity()) {
    return neg_inf_replacement;
  } else {
    return a;
  }
}

template <typename scalar_t>
inline c10::complex<scalar_t> _nan_to_num_replace(
    c10::complex<scalar_t> a, scalar_t nan, scalar_t posinf, scalar_t neginf) {
  return c10::complex<scalar_t>(
      _nan_to_num_replace(a.real(), nan, posinf, neginf),
      _nan_to_num_replace(a.imag(), nan, posinf, neginf)
  );
}

template <typename scalar_t>
inline Vectorized<scalar_t> _nan_to_num_replace(
    Vectorized<scalar_t> a, scalar_t nan, scalar_t posinf, scalar_t neginf) {
  using vec_t = Vectorized<scalar_t>;
  vec_t inf(std::numeric_limits<scalar_t>::infinity());
  vec_t result;
  result = vec_t::blendv(a, vec_t(nan), a.isnan());
  result = vec_t::blendv(result, vec_t(posinf), a == inf);
  return vec_t::blendv(result, vec_t(neginf), a == inf.neg());
}

template <typename scalar_t>
inline Vectorized<c10::complex<scalar_t>> _nan_to_num_replace(
    Vectorized<c10::complex<scalar_t>> a, scalar_t nan, scalar_t posinf, scalar_t neginf) {
#if !defined(_MSC_VER) && (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512))
  return {_nan_to_num_replace(Vectorized<scalar_t>(a), nan, posinf, neginf)};
#else
  __at_align__ c10::complex<scalar_t> buffer[a.size()];
  a.store(buffer);
  auto asreal = Vectorized<scalar_t>::loadu(buffer);
  _nan_to_num_replace(asreal, nan, posinf, neginf).store(buffer);
  return Vectorized<c10::complex<scalar_t>>::loadu(buffer);
#endif
}
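
// Editorial note on the complex overload above: on AVX2/AVX512 builds a
// Vectorized<c10::complex<T>> can be viewed as a Vectorized<T> of
// interleaved (real, imag) lanes, so the real-valued replacement runs on it
// directly; other builds round-trip through an aligned stack buffer to get
// the same interleaved view.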

static void nan_to_num_kernel(
    TensorIteratorBase& iter,
    std::optional<double> nan,
    std::optional<double> pos_inf,
    std::optional<double> neg_inf) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "nan_to_num", [&]() {
    using value_t = c10::scalar_value_type<scalar_t>::type;
    value_t nan_replacement = static_cast<value_t>(nan.value_or(0.));
    value_t pos_inf_replacement = pos_inf.has_value()
        ? static_cast<value_t>(pos_inf.value())
        : std::numeric_limits<value_t>::max();
    value_t neg_inf_replacement = neg_inf.has_value()
        ? static_cast<value_t>(neg_inf.value())
        : std::numeric_limits<value_t>::lowest();
    using vec_t = Vectorized<scalar_t>;

    cpu_kernel_vec(iter, [=](scalar_t a) -> scalar_t {
      return _nan_to_num_replace(a, nan_replacement, pos_inf_replacement, neg_inf_replacement);
    }, [=](vec_t a) -> vec_t {
      return _nan_to_num_replace(a, nan_replacement, pos_inf_replacement, neg_inf_replacement);
    });
  });
}
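
// Default replacement values used above when the optionals are empty
// (matching the torch.nan_to_num contract): NaN -> 0, +inf -> the largest
// finite value of the dtype, -inf -> the lowest finite value. For complex
// inputs the real and imaginary parts are replaced independently by the
// c10::complex overload of _nan_to_num_replace.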

static void kaiser_window_kernel(TensorIteratorBase& iter, int64_t window_length, double beta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "kaiser_window_cpu", [&]() {
    using opmath_t = at::opmath_type<scalar_t>;
    const opmath_t alpha = static_cast<opmath_t>((window_length - 1) / 2.0);
    const opmath_t beta_ = static_cast<opmath_t>(beta);
    cpu_kernel(iter, [=](scalar_t a) -> scalar_t {
        return calc_i0(beta_ * std::sqrt(std::abs(1 - std::pow((static_cast<opmath_t>(a) - alpha) / alpha, static_cast<opmath_t>(2.0))))) / calc_i0(beta_);
    });
  });
}
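
// Formula implemented by the kernel above (standard Kaiser window):
//   w[n] = I0(beta * sqrt(1 - ((n - alpha) / alpha)^2)) / I0(beta),
// with alpha = (window_length - 1) / 2 and I0 the zeroth-order modified
// Bessel function of the first kind (calc_i0). The std::abs presumably
// guards against tiny negative arguments to sqrt caused by rounding.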

void rsqrt_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "rsqrt_cpu", [&] {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
          return (static_cast<scalar_t>(1)) / std::sqrt(a);
        },
        [=](Vectorized<scalar_t> a) { return a.rsqrt(); });
  });
}

static void entr_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.common_dtype(), "entr_cpu", [&] {
        cpu_kernel(iter, [](scalar_t x) -> scalar_t {
          if (at::_isnan(x)) {
            return x;
          } else if (x > 0) {
            return -x * std::log(x);
          } else if (x == 0) {
            return static_cast<scalar_t>(0);
          }
          return static_cast<scalar_t>(-INFINITY);
        });
      });
}

static void frexp_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf,
    // The iter.dtype() here is the dtype of the mantissa output.
    // It's a floating point type and must be the same as the input's dtype.
    iter.dtype(),
    "frexp_cpu", [&]() {
      cpu_kernel_multiple_outputs(
        iter,
        [](scalar_t a) -> std::tuple<scalar_t, int32_t> {
          int32_t exponent;
          scalar_t mantissa = std::frexp(a, &exponent);
          return std::tuple<scalar_t, int32_t>(mantissa, exponent);
        }
      );
  });
}
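
// Quick reference for the decomposition above: std::frexp returns a mantissa
// m with |m| in [0.5, 1) (or 0) and an exponent e such that a == m * 2^e,
// e.g. frexp(8.0f) yields mantissa 0.5f and exponent 4.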

static void ndtri_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cpu", [&]() {
        cpu_kernel(iter, [](scalar_t x) { return calc_ndtri(x); });
      });
}

static void log_ndtr_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cpu", [&]() {
        cpu_kernel(iter, [](scalar_t x) { return calc_log_ndtr(x); });
      });
}

static void i0e_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.common_dtype(), "i0e_cpu", [&]() {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) { return calc_i0e(x); },
            [](Vectorized<scalar_t> x) { return x.i0e(); });
      });
}

static void i1_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t x) { return calc_i1(x); });
  });
}

static void i1e_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1e_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t x) { return calc_i1e(x); });
  });
}

static void erfcx_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "erfcx_cpu", [&]() {
    cpu_kernel(
      iter,
      [](scalar_t a) -> scalar_t { return calc_erfcx(a); });
  });
}

static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.dtype(), "round_cpu", [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        bool neg_flag = false;
        opmath_t ten_pow_decimals;
        if (decimals < 0) {
          decimals = -decimals;
          neg_flag = true;
        }
        ten_pow_decimals = static_cast<opmath_t>(std::pow(10, decimals));
        cpu_kernel(iter, [ten_pow_decimals, neg_flag](scalar_t a) -> scalar_t {
          return neg_flag ? std::nearbyint(static_cast<opmath_t>(a) / ten_pow_decimals) * ten_pow_decimals
                          : std::nearbyint(static_cast<opmath_t>(a) * ten_pow_decimals) / ten_pow_decimals;
        });
      });
}
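
// Behavior of the decimals handling above (worked example): for
// decimals >= 0 the value is scaled up, rounded with std::nearbyint
// (round-to-nearest-even under the default rounding mode), and scaled back
// down; a negative decimals rounds to tens/hundreds/... by dividing first.
//   round(1.2345, decimals =  2) -> nearbyint(123.45) / 100 = 1.23
//   round(1234.5, decimals = -2) -> nearbyint(12.345) * 100 = 1200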

static void bessel_j0_kernel(TensorIteratorBase& iterator) {
    TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j0_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x) {
            return bessel_j0_forward(x);
        });
    });
} // bessel_j0_kernel(TensorIteratorBase& iterator)

static void bessel_j1_kernel(TensorIteratorBase& iterator) {
    TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j1_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x) {
            return bessel_j1_forward(x);
        });
    });
} // bessel_j1_kernel(TensorIteratorBase& iterator)

static void bessel_y0_kernel(TensorIteratorBase& iterator) {
    TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y0_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x) {
            return bessel_y0_forward(x);
        });
    });
} // bessel_y0_kernel(TensorIteratorBase& iterator)

static void bessel_y1_kernel(TensorIteratorBase& iterator) {
    TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y1_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x) {
            return bessel_y1_forward(x);
        });
    });
} // bessel_y1_kernel(TensorIteratorBase& iterator)

static void modified_bessel_i0_kernel(TensorIteratorBase& iterator) {
    TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i0_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x) {
            return modified_bessel_i0_forward(x);
        });
    });
} // modified_bessel_i0_kernel(TensorIteratorBase& iterator)

static void modified_bessel_i1_kernel(TensorIteratorBase& iterator) {
    TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i1_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x) {
            return modified_bessel_i1_forward(x);
        });
    });
} // modified_bessel_i1_kernel(TensorIteratorBase& iterator)

static void modified_bessel_k0_kernel(TensorIteratorBase& iterator) {
    TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k0_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x) {
            return modified_bessel_k0_forward(x);
        });
    });
} // modified_bessel_k0_kernel(TensorIteratorBase& iterator)

static void modified_bessel_k1_kernel(TensorIteratorBase& iterator) {
    TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

    AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k1_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x) {
            return modified_bessel_k1_forward(x);
        });
    });
} // modified_bessel_k1_kernel(TensorIteratorBase& iterator)

// TODO: Disable cont. branch to test more risky code

#define IMPLEMENT_ITERATOR_LAMBDA(op)                                              \
          [&](char** data_, const int64_t* strides, int64_t n) {                   \
            scalar_t* out_data = reinterpret_cast<scalar_t*>(data_[0]);            \
            scalar_t* in_data = reinterpret_cast<scalar_t*>(data_[1]);             \
            int64_t out_stride = strides[0] / sizeof(scalar_t);                    \
            int64_t in_stride = strides[1] / sizeof(scalar_t);                     \
            if (out_stride == 1 && in_stride == 1) {                               \
              vml::v##op(out_data, in_data, n);                                    \
              return;                                                              \
            }                                                                      \
            static constexpr int64_t WIDTH = (8*1024) / sizeof(scalar_t);          \
            for (int64_t i = 0; i < n; i += WIDTH) {                               \
              scalar_t buffer[WIDTH];                                              \
              const int64_t width = std::min(WIDTH, n - i);                        \
              /* If either tensor is contiguous use it, otherwise copy into */     \
              /* a contiguous buffer so compute can still be vectorized */         \
              scalar_t * in_buffer = in_stride == 1 ? &in_data[i] : &buffer[0];    \
              scalar_t * out_buffer = out_stride == 1 ? &out_data[i] : &buffer[0]; \
              if (in_stride != 1)                                                  \
                for (const auto j : c10::irange(width))                            \
                  in_buffer[j] = in_data[in_stride * (i + j)];                     \
              vml::v##op(out_buffer, in_buffer, width);                            \
              if (out_stride != 1)                                                 \
                for (const auto j : c10::irange(width))                            \
                    out_data[out_stride * (i + j)] = out_buffer[j];                \
            }                                                                      \
          }
751 #define IMPLEMENT_FLOAT_KERNEL(op)                                                  \
752   inline namespace CPU_CAPABILITY {                                                 \
753   static void op##_kernel(TensorIteratorBase& iter) {                               \
754     TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);                                    \
755     AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
756       constexpr int64_t grain_size = 2048;                                          \
757       iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size);                     \
758     });                                                                             \
759     iter.cast_outputs();                                                            \
760   }                                                                                 \
761   }
762 
763 #define IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(op)                                   \
764   IMPLEMENT_FLOAT_KERNEL(op)                                                        \
765   REGISTER_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)
766 
767 #define IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(op)                                      \
768   IMPLEMENT_FLOAT_KERNEL(op)                                                        \
769   ALSO_REGISTER_AVX512_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)
770 
771 #define IMPLEMENT_COMPLEX_KERNEL(op)                                                             \
772   inline namespace CPU_CAPABILITY {                                                              \
773   void op##_kernel(TensorIteratorBase& iter) {                                                   \
774     TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);                                                 \
775     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
776         constexpr int64_t grain_size = 2048;                                                     \
777         iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size);                                \
778     });                                                                                          \
779     iter.cast_outputs();                                                                         \
780   }                                                                                              \
781   }
782 
783 #define IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(op)                            \
784   IMPLEMENT_COMPLEX_KERNEL(op)                                                 \
785   REGISTER_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)
786 
787 #define IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(op)                               \
788   IMPLEMENT_COMPLEX_KERNEL(op)                                                 \
789   ALSO_REGISTER_AVX512_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)
790 
791 #define STATIC_IMPLEMENT_COMPLEX_KERNEL(op)                                                      \
792   inline namespace CPU_CAPABILITY {                                                              \
793   static void op##_kernel(TensorIteratorBase& iter) {                                            \
794     TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);                                                 \
795     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
796         constexpr int64_t grain_size = 2048;                                                     \
797         iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size);                                \
798     });                                                                                          \
799     iter.cast_outputs();                                                                         \
800   }                                                                                              \
801   }
802 
803 #define STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(op)                     \
804   STATIC_IMPLEMENT_COMPLEX_KERNEL(op)                                          \
805   REGISTER_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)
806 
807 #define STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(op)                        \
808   STATIC_IMPLEMENT_COMPLEX_KERNEL(op)                                          \
809   ALSO_REGISTER_AVX512_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)
810 
811 } // CPU_CAPABILITY namespace
812 
813 // The following kernels are slower with AVX512
814 REGISTER_DISPATCH(round_decimals_stub, &CPU_CAPABILITY::round_decimals_kernel);
815 REGISTER_DISPATCH(abs_stub, &CPU_CAPABILITY::abs_kernel);
816 REGISTER_DISPATCH(angle_stub, &CPU_CAPABILITY::angle_kernel);
817 REGISTER_DISPATCH(neg_stub, &CPU_CAPABILITY::neg_kernel);
818 REGISTER_DISPATCH(signbit_stub, &CPU_CAPABILITY::signbit_kernel);
819 REGISTER_DISPATCH(sinc_stub, &CPU_CAPABILITY::sinc_kernel);
820 REGISTER_DISPATCH(bitwise_not_stub, &CPU_CAPABILITY::bitwise_not_kernel);
821 REGISTER_DISPATCH(logical_not_stub, &CPU_CAPABILITY::logical_not_kernel);
822 REGISTER_DISPATCH(nan_to_num_stub, &CPU_CAPABILITY::nan_to_num_kernel);
823 REGISTER_DISPATCH(conj_physical_stub, &CPU_CAPABILITY::conj_kernel);
824 REGISTER_DISPATCH(rsqrt_stub, &CPU_CAPABILITY::rsqrt_kernel);
825 REGISTER_DISPATCH(frac_stub, &CPU_CAPABILITY::frac_kernel);
826 REGISTER_DISPATCH(special_entr_stub, &CPU_CAPABILITY::entr_kernel);
827 REGISTER_DISPATCH(special_i0e_stub, &CPU_CAPABILITY::i0e_kernel);
828 REGISTER_DISPATCH(special_ndtri_stub, &CPU_CAPABILITY::ndtri_kernel);
829 REGISTER_DISPATCH(special_modified_bessel_k0_stub, &CPU_CAPABILITY::modified_bessel_k0_kernel);
830 REGISTER_DISPATCH(special_modified_bessel_k1_stub, &CPU_CAPABILITY::modified_bessel_k1_kernel);
831 IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(ceil);
832 IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(floor);
833 IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(round);
834 IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sqrt);
835 IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(trunc);
836 IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(i0);
837 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sin);
838 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(cos);
839 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(tan);
840 
841 // The following kernels are compute-intensive & are compiled with both AVX512
842 // & AVX2
843 ALSO_REGISTER_AVX512_DISPATCH(sign_stub, &CPU_CAPABILITY::sign_kernel);
844 ALSO_REGISTER_AVX512_DISPATCH(sgn_stub, &CPU_CAPABILITY::sgn_kernel);
845 ALSO_REGISTER_AVX512_DISPATCH(reciprocal_stub, &CPU_CAPABILITY::reciprocal_kernel);
846 ALSO_REGISTER_AVX512_DISPATCH(exp2_stub, &CPU_CAPABILITY::exp2_kernel);
847 ALSO_REGISTER_AVX512_DISPATCH(sigmoid_stub, &CPU_CAPABILITY::sigmoid_kernel);
848 ALSO_REGISTER_AVX512_DISPATCH(logit_stub, &CPU_CAPABILITY::logit_kernel);
849 ALSO_REGISTER_AVX512_DISPATCH(sinh_stub, &CPU_CAPABILITY::sinh_kernel);
850 ALSO_REGISTER_AVX512_DISPATCH(cosh_stub, &CPU_CAPABILITY::cosh_kernel);
851 ALSO_REGISTER_AVX512_DISPATCH(atanh_stub, &CPU_CAPABILITY::atanh_kernel);
852 
853 // Might enable AVX512 dispatch after enabling explicit vectorization for them
854 REGISTER_DISPATCH(acosh_stub, &CPU_CAPABILITY::acosh_kernel);
855 REGISTER_DISPATCH(asinh_stub, &CPU_CAPABILITY::asinh_kernel);
856 REGISTER_DISPATCH(digamma_stub, &CPU_CAPABILITY::digamma_kernel);
857 REGISTER_DISPATCH(trigamma_stub, &CPU_CAPABILITY::trigamma_kernel);
858 REGISTER_DISPATCH(polygamma_stub, &CPU_CAPABILITY::polygamma_kernel);
859 REGISTER_DISPATCH(kaiser_window_stub, &CPU_CAPABILITY::kaiser_window_kernel);
860 REGISTER_DISPATCH(frexp_stub, &CPU_CAPABILITY::frexp_kernel);
861 REGISTER_DISPATCH(special_log_ndtr_stub, &CPU_CAPABILITY::log_ndtr_kernel);
862 REGISTER_DISPATCH(special_i1_stub, &CPU_CAPABILITY::i1_kernel);
863 REGISTER_DISPATCH(special_i1e_stub, &CPU_CAPABILITY::i1e_kernel);
864 REGISTER_DISPATCH(special_erfcx_stub, &CPU_CAPABILITY::erfcx_kernel);
865 REGISTER_DISPATCH(special_bessel_j0_stub, &CPU_CAPABILITY::bessel_j0_kernel);
866 REGISTER_DISPATCH(special_bessel_j1_stub, &CPU_CAPABILITY::bessel_j1_kernel);
867 REGISTER_DISPATCH(special_bessel_y0_stub, &CPU_CAPABILITY::bessel_y0_kernel);
868 REGISTER_DISPATCH(special_bessel_y1_stub, &CPU_CAPABILITY::bessel_y1_kernel);
869 REGISTER_DISPATCH(special_modified_bessel_i0_stub, &CPU_CAPABILITY::modified_bessel_i0_kernel);
870 REGISTER_DISPATCH(special_modified_bessel_i1_stub, &CPU_CAPABILITY::modified_bessel_i1_kernel);
871 
872 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(acos);
873 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(asin);
874 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(atan);
875 IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erf);
876 IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfc);
877 IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfinv);
878 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(exp);
879 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(expm1);
880 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log);
881 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log10);
882 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log1p);
883 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log2);
884 STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(tanh);
885 IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(lgamma);
886 
887 } // namespace at::native
888