#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>

#include <cmath>
#include <limits>
#include <type_traits>

#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vml.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/CopyKernel.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/cpu/zmath.h>
#include <ATen/OpMathType.h>

#include <c10/util/MathConstants.h>
#include <c10/core/Scalar.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/irange.h>

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif

namespace at::native {

inline namespace CPU_CAPABILITY {

using namespace vec;

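// sigmoid(x) = 1 / (1 + exp(-x)). Reduced floating point types (Half,
// BFloat16) are promoted to float for the computation and converted back on
// store; other floating and complex types compute in their own precision.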
static void sigmoid_kernel(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
  if (at::isReducedFloatingType(dtype)) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "sigmoid_cpu_reduced_float", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t {
            float a0 = static_cast<float>(a);
            return static_cast<float>(1) / (static_cast<float>(1) + std::exp((-a0)));
          },
          [=](Vectorized<scalar_t> a) {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            a0 = (Vectorized<float>(static_cast<float>(1)) + a0.neg().exp()).reciprocal();
            a1 = (Vectorized<float>(static_cast<float>(1)) + a1.neg().exp()).reciprocal();
            return convert_from_float<scalar_t>(a0, a1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(dtype, "sigmoid_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t {
            return (static_cast<scalar_t>(1) / (static_cast<scalar_t>(1) + std::exp((-a))));
          },
          [=](Vectorized<scalar_t> a) {
            a = Vectorized<scalar_t>(static_cast<scalar_t>(0)) - a;
            a = a.exp();
            a = Vectorized<scalar_t>(static_cast<scalar_t>(1)) + a;
            a = a.reciprocal();
            return a;
          });
    });
  }
}

#if AT_MKL_ENABLED()

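// Vectorized natural log helper used by LogitMKLKernel. The generic template
// uses vec::map with Vectorized log; the float/double specializations call
// MKL's vsLn/vdLn directly.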
template <typename T>
void VmlLog(int64_t N, const T* X, T* Y) {
  constexpr int64_t K = Vectorized<T>::size();
  at::parallel_for(0, N, K, [=](int64_t begin, int64_t end) {
    using VT = at::opmath_type<T>;
    vec::map(
        [](Vectorized<VT> x_vec) { return x_vec.log(); },
        Y + begin,
        X + begin,
        end - begin);
  });
}

template <>
void VmlLog<float>(int64_t N, const float* X, float* Y) {
  vsLn(N, X, Y);
}

template <>
void VmlLog<double>(int64_t N, const double* X, double* Y) {
  vdLn(N, X, Y);
}

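// logit(x) = log(x / (1 - x)). When eps >= 0, x is first clamped to
// [eps, 1 - eps]. Iterators that need 64-bit indexing are split into
// 32-bit-indexable sub-iterators and handled recursively.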
template <typename T>
void LogitMKLKernel(T eps, TensorIteratorBase* it) {
  if (!it->can_use_32bit_indexing()) {
    for (auto& sub_it : it->with_32bit_indexing()) {
      LogitMKLKernel<T>(eps, &sub_it);
    }
    return;
  }

  constexpr int64_t K = Vectorized<T>::size();
  const int64_t N = it->numel();
  const T* X_data = static_cast<T*>(it->data_ptr(1));
  T* Y_data = static_cast<T*>(it->data_ptr(0));
  if (eps < T(0)) {
    at::parallel_for(0, N, K, [=](int64_t begin, int64_t end) {
      for (const auto i : c10::irange(begin, end)) {
        Y_data[i] = X_data[i] == T(1) ? std::numeric_limits<T>::infinity()
                                      : X_data[i] / (T(1) - X_data[i]);
      }
      VmlLog<T>(end - begin, Y_data + begin, Y_data + begin);
    });
  } else {
    const T lo = eps;
    const T hi = T(1) - eps;
    at::parallel_for(0, N, K, [=](int64_t begin, int64_t end) {
      for (const auto i : c10::irange(begin, end)) {
        const T x = X_data[i] < lo ? lo : (X_data[i] > hi ? hi : X_data[i]);
        Y_data[i] =
            x == T(1) ? std::numeric_limits<T>::infinity() : (x / (T(1) - x));
      }
      VmlLog<T>(end - begin, Y_data + begin, Y_data + begin);
    });
  }
}

#else

template <typename T>
void LogitMKLKernel(T eps, TensorIteratorBase* it) {
  TORCH_CHECK(false, "ATen not compiled with MKL");
}

#endif // AT_MKL_ENABLED

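// Dispatches logit to the MKL path for contiguous iterators when MKL is
// available; otherwise computes log(x / (1 - x)) directly, clamping the
// input to [eps, 1 - eps] when a non-negative eps is given.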
static void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.common_dtype(), "logit_cpu", [&]() {
        const scalar_t eps = eps_scalar.to<scalar_t>();
        if (at::hasMKL() && iter.is_contiguous()) {
          LogitMKLKernel<scalar_t>(eps, &iter);
          iter.cast_outputs();
        } else if (eps < scalar_t(0)) {
          const Vectorized<scalar_t> kOneVec(scalar_t(1));
          cpu_kernel_vec(
              iter,
              [](scalar_t x) {
                return x == scalar_t(1)
                    ? std::numeric_limits<scalar_t>::infinity()
                    : std::log(x / (scalar_t(1) - x));
              },
              [kOneVec](Vectorized<scalar_t> x_vec) {
                return (x_vec / (kOneVec - x_vec)).log();
              });
        } else {
          const scalar_t lo = eps;
          const scalar_t hi = scalar_t(1) - eps;
          const Vectorized<scalar_t> kOneVec(scalar_t(1));
          const Vectorized<scalar_t> lo_vec(lo);
          const Vectorized<scalar_t> hi_vec(hi);
          cpu_kernel_vec(
              iter,
              [lo, hi](scalar_t x) {
                x = x < lo ? lo : (x > hi ? hi : x);
                return x == scalar_t(1)
                    ? std::numeric_limits<scalar_t>::infinity()
                    : std::log(x / (scalar_t(1) - x));
              },
              [kOneVec, lo_vec, hi_vec](Vectorized<scalar_t> x_vec) {
                x_vec = vec::clamp(x_vec, lo_vec, hi_vec);
                return (x_vec / (kOneVec - x_vec)).log();
              });
        }
      });
}

#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...)                                                 \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                                 \
            kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
            TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
            kHalf, kBFloat16,                   \
            TYPE, NAME, __VA_ARGS__)
#endif

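// Absolute value over all real types, Half/BFloat16, float8 types
// (non-mobile builds only), and complex types. Complex Half is handled with
// a scalar kernel that promotes to its opmath type.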
static void abs_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (dtype == kComplexHalf) {
    using scalar_t = c10::complex<Half>;
    using opmath_t = at::opmath_type<scalar_t>;
    cpu_kernel(iter, [=](scalar_t a) -> scalar_t { return abs_impl(opmath_t{a}); });
  } else {
    _AT_DISPATCH_ABS_TYPES(iter.dtype(), "abs_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return abs_impl(a); },
          [=](Vectorized<scalar_t> a) { return a.abs(); });
    });
  }
}

static void angle_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "angle_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return angle_impl(a); },
        [=](Vectorized<scalar_t> a) { return a.angle(); });
  });
}

// NB: Ignores the negative bit on tensors
void conj_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_SWITCH(iter.common_dtype(), "conj_cpu",
    AT_DISPATCH_CASE_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, [&] {
      // conj is a no-op for non-complex types
      direct_copy_kernel(iter);
    })
    AT_DISPATCH_CASE_COMPLEX_TYPES_AND(kComplexHalf, [&] {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return conj_impl(a); },
          [=](Vectorized<scalar_t> a) { return a.conj(); });
    })
  );
}

static void bitwise_not_kernel(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    // Boolean type does not work with ~ (bitwise NOT) in C++. bitwise_not wraps
    // this operation for both Boolean and integral types.
    cpu_kernel(
        iter,
        [](bool a) {
          return !a;
        });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_not_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a) -> scalar_t {
            return ~a;
          },
          [](Vectorized<scalar_t> a) -> Vectorized<scalar_t> {
            return ~a;
          });
    });
  }
}

static void frac_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "frac_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return a - std::trunc(a); },
        [=](Vectorized<scalar_t> a) { return a.frac(); });
  });
}

static void logical_not_kernel(TensorIteratorBase& iter) {
  // NOTE: this implementation differs from the CUDA implementation, which only does a single
  // dispatch (to avoid expensive compilation), because CPU kernels don't handle dynamic_casting
  // (see needs_dynamic_casting).
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(1), "logical_not_cpu", [&]() {
    using self_t = scalar_t;
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(0), "logical_not_cpu", [&]() {
      cpu_kernel(iter, [](self_t a) -> scalar_t { return static_cast<scalar_t>(!a); });
    });
  });
}

void reciprocal_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "reciprocal_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return static_cast<scalar_t>(1.0) / a; },
        [=](Vectorized<scalar_t> a) { return a.reciprocal(); });
  });
}

// NB: Ignores the negative bit on tensors
void neg_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kComplexHalf, kBFloat16, kHalf, iter.dtype(), "neg_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return -a; },
        [=](Vectorized<scalar_t> a) { return a.neg(); });
  });
}

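// sign(x): -1, 0, or 1. The scalar path uses (0 < a) - is_negative(a); the
// vector path builds the same result from two blendv selections against the
// 0 and 1 vectors.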
static void sign_kernel(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    cpu_kernel(iter, [=](bool x) -> bool { return x; });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, ScalarType::Half, iter.dtype(), "sign_cpu", [&]() {
      auto zero_vec = Vectorized<scalar_t>(static_cast<scalar_t>(0));
      auto one_vec = Vectorized<scalar_t>(static_cast<scalar_t>(1));

      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return (0 < a) - c10::is_negative(a); },
          [=](Vectorized<scalar_t> self_vec) {
            // Comparison operators return a bitmask.
            auto left = Vectorized<scalar_t>::blendv(zero_vec, one_vec, zero_vec < self_vec);
            auto right = Vectorized<scalar_t>::blendv(zero_vec, one_vec, self_vec < zero_vec);

            return left - right;
          });
    });
  }
}

static void signbit_kernel(TensorIteratorBase& iter) {
  // NOTE: signbit does not always support integral arguments.
  AT_DISPATCH_SWITCH(iter.input_dtype(), "signbit_cpu",
    AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
      cpu_kernel(iter, [](scalar_t a) -> bool { return c10::is_negative(a); });
    })
    AT_DISPATCH_CASE_FLOATING_TYPES_AND2(kBFloat16, ScalarType::Half, [&] {
      using opmath_t = at::opmath_type<scalar_t>;
      cpu_kernel(iter, [](scalar_t a) -> bool { return std::signbit(opmath_t{a}); });
    })
  );
}

static void sgn_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (dtype == kComplexHalf) {
    using scalar_t = c10::complex<Half>;
    using opmath_t = at::opmath_type<scalar_t>;
    cpu_kernel(
        iter, [=](scalar_t a) -> scalar_t { return sgn_impl(opmath_t{a}); });
  } else {
    AT_DISPATCH_COMPLEX_TYPES(dtype, "sgn_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return sgn_impl(a); },
          [=](Vectorized<scalar_t> a) { return a.sgn(); });
    });
  }
}

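// Normalized sinc: sin(pi * x) / (pi * x), with sinc(0) defined as 1.
// The product pi * x is computed in opmath precision.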
static void sinc_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "sinc_cpu", [&]() {
    cpu_kernel(
        iter,
        [=](scalar_t a) -> scalar_t {
          if (a == scalar_t(0)) {
            return scalar_t(1);
          } else {
            using opmath_t = at::opmath_type<scalar_t>;
            opmath_t product = c10::pi<opmath_t> * opmath_t{a};
            return static_cast<scalar_t>(std::sin(product) / product);
          }
        });
  });
}

static void sinh_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "sinh_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return std::sinh(a); },
        [=](Vectorized<scalar_t> self_vec) { return self_vec.sinh(); });
  });
}

static void cosh_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "cosh_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return std::cosh(a); },
        [=](Vectorized<scalar_t> self_vec) { return self_vec.cosh(); });
  });
}

static void acosh_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "acosh_cpu", [&]() {
    cpu_kernel(
        iter,
        [=](scalar_t a) -> scalar_t { return std::acosh(a); });
  });
}

static void asinh_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "asinh_cpu", [&]() {
    cpu_kernel(
        iter,
        [=](scalar_t a) -> scalar_t { return std::asinh(a); });
  });
}

static void atanh_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "atanh_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return std::atanh(a); },
        [=](Vectorized<scalar_t> self_vec) { return self_vec.atanh(); });
  });
}

static void digamma_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "digamma", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return calc_digamma(a); },
        [=](Vectorized<scalar_t> x) { return x.digamma(); });
  });
}

static void trigamma_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "trigamma", [&]() {
    cpu_kernel(
        iter,
        [=](scalar_t a) -> scalar_t { return trigamma(a); });
  });
}

static void exp2_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      kBFloat16, kHalf, iter.dtype(), "exp2", [&] {
        cpu_kernel_vec(
            iter,
            [](scalar_t a) -> scalar_t { return exp2_impl(a); },
            [](Vectorized<scalar_t> a) { return a.exp2(); });
      });
}

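// polygamma(n, x): the n-th derivative of the digamma function (n == 0 is
// digamma itself). n == 0 and n == 1 reuse the dedicated digamma/trigamma
// kernels; higher orders go through calc_polygamma.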
static void polygamma_kernel(TensorIteratorBase& iter, int64_t n) {
  if (n == 0) {
    digamma_kernel(iter);
  } else if (n == 1) {
    trigamma_kernel(iter);
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "polygamma", [&]() {
      cpu_kernel(
          iter, [=](scalar_t a) -> scalar_t { return calc_polygamma(a, n); });
    });
  }
}

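// _nan_to_num_replace maps NaN, +inf, and -inf to the given replacement
// values and leaves everything else unchanged. Overloads cover scalars,
// complex scalars (applied elementwise to the real and imaginary parts),
// and the corresponding Vectorized types.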
template <typename scalar_t>
inline scalar_t _nan_to_num_replace(
    scalar_t a, scalar_t nan_replacement, scalar_t pos_inf_replacement, scalar_t neg_inf_replacement) {
  if (at::_isnan(a)) {
    return nan_replacement;
  } else if (a == std::numeric_limits<scalar_t>::infinity()) {
    return pos_inf_replacement;
  } else if (a == -std::numeric_limits<scalar_t>::infinity()) {
    return neg_inf_replacement;
  } else {
    return a;
  }
}

template <typename scalar_t>
inline c10::complex<scalar_t> _nan_to_num_replace(
    c10::complex<scalar_t> a, scalar_t nan, scalar_t posinf, scalar_t neginf) {
  return c10::complex<scalar_t>(
      _nan_to_num_replace(a.real(), nan, posinf, neginf),
      _nan_to_num_replace(a.imag(), nan, posinf, neginf)
  );
}

template <typename scalar_t>
inline Vectorized<scalar_t> _nan_to_num_replace(
    Vectorized<scalar_t> a, scalar_t nan, scalar_t posinf, scalar_t neginf) {
  using vec_t = Vectorized<scalar_t>;
  vec_t inf(std::numeric_limits<scalar_t>::infinity());
  vec_t result;
  result = vec_t::blendv(a, vec_t(nan), a.isnan());
  result = vec_t::blendv(result, vec_t(posinf), a == inf);
  return vec_t::blendv(result, vec_t(neginf), a == inf.neg());
}

template <typename scalar_t>
inline Vectorized<c10::complex<scalar_t>> _nan_to_num_replace(
    Vectorized<c10::complex<scalar_t>> a, scalar_t nan, scalar_t posinf, scalar_t neginf) {
#if !defined(_MSC_VER) && (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512))
  return {_nan_to_num_replace(Vectorized<scalar_t>(a), nan, posinf, neginf)};
#else
  __at_align__ c10::complex<scalar_t> buffer[a.size()];
  a.store(buffer);
  auto asreal = Vectorized<scalar_t>::loadu(buffer);
  _nan_to_num_replace(asreal, nan, posinf, neginf).store(buffer);
  return Vectorized<c10::complex<scalar_t>>::loadu(buffer);
#endif
}

static void nan_to_num_kernel(
    TensorIteratorBase& iter,
    std::optional<double> nan,
    std::optional<double> pos_inf,
    std::optional<double> neg_inf) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "nan_to_num", [&]() {
    using value_t = c10::scalar_value_type<scalar_t>::type;
    value_t nan_replacement = static_cast<value_t>(nan.value_or(0.));
    value_t pos_inf_replacement = pos_inf.has_value()
        ? static_cast<value_t>(pos_inf.value())
        : std::numeric_limits<value_t>::max();
    value_t neg_inf_replacement = neg_inf.has_value()
        ? static_cast<value_t>(neg_inf.value())
        : std::numeric_limits<value_t>::lowest();
    using vec_t = Vectorized<scalar_t>;

    cpu_kernel_vec(iter, [=](scalar_t a) -> scalar_t {
      return _nan_to_num_replace(a, nan_replacement, pos_inf_replacement, neg_inf_replacement);
    }, [=](vec_t a) -> vec_t {
      return _nan_to_num_replace(a, nan_replacement, pos_inf_replacement, neg_inf_replacement);
    });
  });
}

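// Kaiser window: w[n] = I0(beta * sqrt(1 - ((n - alpha) / alpha)^2)) / I0(beta),
// with alpha = (window_length - 1) / 2, evaluated in opmath precision. The
// radicand is wrapped in std::abs to guard against small negative values.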
static void kaiser_window_kernel(TensorIteratorBase& iter, int64_t window_length, double beta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "kaiser_window_cpu", [&]() {
    using opmath_t = at::opmath_type<scalar_t>;
    const opmath_t alpha = static_cast<opmath_t>((window_length - 1) / 2.0);
    const opmath_t beta_ = static_cast<opmath_t>(beta);
    cpu_kernel(iter, [=](scalar_t a) -> scalar_t {
      return calc_i0(beta_ * std::sqrt(std::abs(1 - std::pow((static_cast<opmath_t>(a) - alpha) / alpha, static_cast<opmath_t>(2.0))))) / calc_i0(beta_);
    });
  });
}

void rsqrt_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "rsqrt_cpu", [&] {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
          return (static_cast<scalar_t>(1)) / std::sqrt(a);
        },
        [=](Vectorized<scalar_t> a) { return a.rsqrt(); });
  });
}

static void entr_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.common_dtype(), "entr_cpu", [&] {
        cpu_kernel(iter, [](scalar_t x) -> scalar_t {
          if (at::_isnan(x)) {
            return x;
          } else if (x > 0) {
            return -x * std::log(x);
          } else if (x == 0) {
            return static_cast<scalar_t>(0);
          }
          return static_cast<scalar_t>(-INFINITY);
        });
      });
}

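// frexp decomposes x into a mantissa m and an integer exponent with
// x == m * 2^exponent and 0.5 <= |m| < 1 (m == 0 when x == 0). The kernel
// writes two outputs: the mantissa in the input's floating point type and an
// int32 exponent.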
static void frexp_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf,
    // The iter.dtype() here is the dtype of the mantissa output.
    // It's a floating point type and must be the same as the input's dtype.
    iter.dtype(),
    "frexp_cpu", [&]() {
      cpu_kernel_multiple_outputs(
        iter,
        [](scalar_t a) -> std::tuple<scalar_t, int32_t> {
          int32_t exponent;
          scalar_t mantissa = std::frexp(a, &exponent);
          return std::tuple<scalar_t, int32_t>(mantissa, exponent);
        }
      );
    });
}

static void ndtri_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t x) { return calc_ndtri(x); });
  });
}

static void log_ndtr_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t x) { return calc_log_ndtr(x); });
  });
}

static void i0e_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.common_dtype(), "i0e_cpu", [&]() {
        cpu_kernel_vec(
            iter,
            [](scalar_t x) { return calc_i0e(x); },
            [](Vectorized<scalar_t> x) { return x.i0e(); });
      });
}

static void i1_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t x) { return calc_i1(x); });
  });
}

static void i1e_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1e_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t x) { return calc_i1e(x); });
  });
}

static void erfcx_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "erfcx_cpu", [&]() {
    cpu_kernel(
        iter,
        [](scalar_t a) -> scalar_t { return calc_erfcx(a); });
  });
}

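// Rounds to `decimals` decimal places by scaling with 10^decimals, applying
// std::nearbyint, and scaling back. Negative `decimals` rounds to the left
// of the decimal point, so the scale is applied in the opposite order.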
static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.dtype(), "round_cpu", [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        bool neg_flag = false;
        opmath_t ten_pow_decimals;
        if (decimals < 0) {
          decimals = -decimals;
          neg_flag = true;
        }
        ten_pow_decimals = static_cast<opmath_t>(std::pow(10, decimals));
        cpu_kernel(iter, [ten_pow_decimals, neg_flag](scalar_t a) -> scalar_t {
          return neg_flag ? std::nearbyint(static_cast<opmath_t>(a) / ten_pow_decimals) * ten_pow_decimals
                          : std::nearbyint(static_cast<opmath_t>(a) * ten_pow_decimals) / ten_pow_decimals;
        });
      });
}

static void bessel_j0_kernel(TensorIteratorBase& iterator) {
  TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j0_cpu", [&]() {
    cpu_kernel(iterator, [](scalar_t x) {
      return bessel_j0_forward(x);
    });
  });
} // bessel_j0_kernel(TensorIteratorBase& iterator)

static void bessel_j1_kernel(TensorIteratorBase& iterator) {
  TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_j1_cpu", [&]() {
    cpu_kernel(iterator, [](scalar_t x) {
      return bessel_j1_forward(x);
    });
  });
} // bessel_j1_kernel(TensorIteratorBase& iterator)

static void bessel_y0_kernel(TensorIteratorBase& iterator) {
  TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y0_cpu", [&]() {
    cpu_kernel(iterator, [](scalar_t x) {
      return bessel_y0_forward(x);
    });
  });
} // bessel_y0_kernel(TensorIteratorBase& iterator)

static void bessel_y1_kernel(TensorIteratorBase& iterator) {
  TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "bessel_y1_cpu", [&]() {
    cpu_kernel(iterator, [](scalar_t x) {
      return bessel_y1_forward(x);
    });
  });
} // bessel_y1_kernel(TensorIteratorBase& iterator)

static void modified_bessel_i0_kernel(TensorIteratorBase& iterator) {
  TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i0_cpu", [&]() {
    cpu_kernel(iterator, [](scalar_t x) {
      return modified_bessel_i0_forward(x);
    });
  });
} // modified_bessel_i0_kernel(TensorIteratorBase& iterator)

static void modified_bessel_i1_kernel(TensorIteratorBase& iterator) {
  TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_i1_cpu", [&]() {
    cpu_kernel(iterator, [](scalar_t x) {
      return modified_bessel_i1_forward(x);
    });
  });
} // modified_bessel_i1_kernel(TensorIteratorBase& iterator)

static void modified_bessel_k0_kernel(TensorIteratorBase& iterator) {
  TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k0_cpu", [&]() {
    cpu_kernel(iterator, [](scalar_t x) {
      return modified_bessel_k0_forward(x);
    });
  });
} // modified_bessel_k0_kernel(TensorIteratorBase& iterator)

static void modified_bessel_k1_kernel(TensorIteratorBase& iterator) {
  TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2);

  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "modified_bessel_k1_cpu", [&]() {
    cpu_kernel(iterator, [](scalar_t x) {
      return modified_bessel_k1_forward(x);
    });
  });
} // modified_bessel_k1_kernel(TensorIteratorBase& iterator)

// TODO: Disable cont. branch to test more risky code

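// IMPLEMENT_ITERATOR_LAMBDA builds the TensorIterator loop body for the
// VML-backed kernels below. Contiguous inputs/outputs are passed straight to
// vml::v<op>; strided data is staged through an 8 KB stack buffer so the VML
// call still operates on contiguous memory.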
#define IMPLEMENT_ITERATOR_LAMBDA(op)                                              \
          [&](char** data_, const int64_t* strides, int64_t n) {                   \
            scalar_t* out_data = reinterpret_cast<scalar_t*>(data_[0]);            \
            scalar_t* in_data = reinterpret_cast<scalar_t*>(data_[1]);             \
            int64_t out_stride = strides[0] / sizeof(scalar_t);                    \
            int64_t in_stride = strides[1] / sizeof(scalar_t);                     \
            if (out_stride == 1 && in_stride == 1) {                               \
              vml::v##op(out_data, in_data, n);                                    \
              return;                                                              \
            }                                                                      \
            static constexpr int64_t WIDTH = (8*1024) / sizeof(scalar_t);          \
            for (int64_t i = 0; i < n; i += WIDTH) {                               \
              scalar_t buffer[WIDTH];                                              \
              const int64_t width = std::min(WIDTH, n - i);                        \
              /* If either tensor is contiguous use it, otherwise copy into */     \
              /* a contiguous buffer so compute can still be vectorized */         \
              scalar_t * in_buffer = in_stride == 1 ? &in_data[i] : &buffer[0];    \
              scalar_t * out_buffer = out_stride == 1 ? &out_data[i] : &buffer[0]; \
              if (in_stride != 1)                                                  \
                for (const auto j : c10::irange(width))                            \
                  in_buffer[j] = in_data[in_stride * (i + j)];                     \
              vml::v##op(out_buffer, in_buffer, width);                            \
              if (out_stride != 1)                                                 \
                for (const auto j : c10::irange(width))                            \
                  out_data[out_stride * (i + j)] = out_buffer[j];                  \
            }                                                                      \
          }

#define IMPLEMENT_FLOAT_KERNEL(op)                                                            \
  inline namespace CPU_CAPABILITY {                                                           \
  static void op##_kernel(TensorIteratorBase& iter) {                                         \
    TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);                                              \
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() {   \
      constexpr int64_t grain_size = 2048;                                                    \
      iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size);                               \
    });                                                                                       \
    iter.cast_outputs();                                                                      \
  }                                                                                           \
  }

#define IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(op) \
  IMPLEMENT_FLOAT_KERNEL(op)                      \
  REGISTER_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)

#define IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(op) \
  IMPLEMENT_FLOAT_KERNEL(op)                   \
  ALSO_REGISTER_AVX512_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)

#define IMPLEMENT_COMPLEX_KERNEL(op)                                                                    \
  inline namespace CPU_CAPABILITY {                                                                     \
  void op##_kernel(TensorIteratorBase& iter) {                                                          \
    TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);                                                        \
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
      constexpr int64_t grain_size = 2048;                                                              \
      iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size);                                         \
    });                                                                                                 \
    iter.cast_outputs();                                                                                \
  }                                                                                                     \
  }

#define IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(op) \
  IMPLEMENT_COMPLEX_KERNEL(op)                      \
  REGISTER_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)

#define IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(op) \
  IMPLEMENT_COMPLEX_KERNEL(op)                   \
  ALSO_REGISTER_AVX512_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)

#define STATIC_IMPLEMENT_COMPLEX_KERNEL(op)                                                             \
  inline namespace CPU_CAPABILITY {                                                                     \
  static void op##_kernel(TensorIteratorBase& iter) {                                                   \
    TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);                                                        \
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
      constexpr int64_t grain_size = 2048;                                                              \
      iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size);                                         \
    });                                                                                                 \
    iter.cast_outputs();                                                                                \
  }                                                                                                     \
  }

#define STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(op) \
  STATIC_IMPLEMENT_COMPLEX_KERNEL(op)                      \
  REGISTER_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)

#define STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(op) \
  STATIC_IMPLEMENT_COMPLEX_KERNEL(op)                   \
  ALSO_REGISTER_AVX512_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)

} // CPU_CAPABILITY namespace

// The following kernels are slower with AVX512
REGISTER_DISPATCH(round_decimals_stub, &CPU_CAPABILITY::round_decimals_kernel);
REGISTER_DISPATCH(abs_stub, &CPU_CAPABILITY::abs_kernel);
REGISTER_DISPATCH(angle_stub, &CPU_CAPABILITY::angle_kernel);
REGISTER_DISPATCH(neg_stub, &CPU_CAPABILITY::neg_kernel);
REGISTER_DISPATCH(signbit_stub, &CPU_CAPABILITY::signbit_kernel);
REGISTER_DISPATCH(sinc_stub, &CPU_CAPABILITY::sinc_kernel);
REGISTER_DISPATCH(bitwise_not_stub, &CPU_CAPABILITY::bitwise_not_kernel);
REGISTER_DISPATCH(logical_not_stub, &CPU_CAPABILITY::logical_not_kernel);
REGISTER_DISPATCH(nan_to_num_stub, &CPU_CAPABILITY::nan_to_num_kernel);
REGISTER_DISPATCH(conj_physical_stub, &CPU_CAPABILITY::conj_kernel);
REGISTER_DISPATCH(rsqrt_stub, &CPU_CAPABILITY::rsqrt_kernel);
REGISTER_DISPATCH(frac_stub, &CPU_CAPABILITY::frac_kernel);
REGISTER_DISPATCH(special_entr_stub, &CPU_CAPABILITY::entr_kernel);
REGISTER_DISPATCH(special_i0e_stub, &CPU_CAPABILITY::i0e_kernel);
REGISTER_DISPATCH(special_ndtri_stub, &CPU_CAPABILITY::ndtri_kernel);
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &CPU_CAPABILITY::modified_bessel_k0_kernel);
REGISTER_DISPATCH(special_modified_bessel_k1_stub, &CPU_CAPABILITY::modified_bessel_k1_kernel);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(ceil);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(floor);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(round);
IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sqrt);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(trunc);
IMPLEMENT_FLOAT_KERNEL_WITHOUT_AVX512(i0);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(sin);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(cos);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITHOUT_AVX512(tan);

// The following kernels are compute-intensive & are compiled with both AVX512
// & AVX2
ALSO_REGISTER_AVX512_DISPATCH(sign_stub, &CPU_CAPABILITY::sign_kernel);
ALSO_REGISTER_AVX512_DISPATCH(sgn_stub, &CPU_CAPABILITY::sgn_kernel);
ALSO_REGISTER_AVX512_DISPATCH(reciprocal_stub, &CPU_CAPABILITY::reciprocal_kernel);
ALSO_REGISTER_AVX512_DISPATCH(exp2_stub, &CPU_CAPABILITY::exp2_kernel);
ALSO_REGISTER_AVX512_DISPATCH(sigmoid_stub, &CPU_CAPABILITY::sigmoid_kernel);
ALSO_REGISTER_AVX512_DISPATCH(logit_stub, &CPU_CAPABILITY::logit_kernel);
ALSO_REGISTER_AVX512_DISPATCH(sinh_stub, &CPU_CAPABILITY::sinh_kernel);
ALSO_REGISTER_AVX512_DISPATCH(cosh_stub, &CPU_CAPABILITY::cosh_kernel);
ALSO_REGISTER_AVX512_DISPATCH(atanh_stub, &CPU_CAPABILITY::atanh_kernel);

// Might enable AVX512 dispatch after enabling explicit vectorization for them
REGISTER_DISPATCH(acosh_stub, &CPU_CAPABILITY::acosh_kernel);
REGISTER_DISPATCH(asinh_stub, &CPU_CAPABILITY::asinh_kernel);
REGISTER_DISPATCH(digamma_stub, &CPU_CAPABILITY::digamma_kernel);
REGISTER_DISPATCH(trigamma_stub, &CPU_CAPABILITY::trigamma_kernel);
REGISTER_DISPATCH(polygamma_stub, &CPU_CAPABILITY::polygamma_kernel);
REGISTER_DISPATCH(kaiser_window_stub, &CPU_CAPABILITY::kaiser_window_kernel);
REGISTER_DISPATCH(frexp_stub, &CPU_CAPABILITY::frexp_kernel);
REGISTER_DISPATCH(special_log_ndtr_stub, &CPU_CAPABILITY::log_ndtr_kernel);
REGISTER_DISPATCH(special_i1_stub, &CPU_CAPABILITY::i1_kernel);
REGISTER_DISPATCH(special_i1e_stub, &CPU_CAPABILITY::i1e_kernel);
REGISTER_DISPATCH(special_erfcx_stub, &CPU_CAPABILITY::erfcx_kernel);
REGISTER_DISPATCH(special_bessel_j0_stub, &CPU_CAPABILITY::bessel_j0_kernel);
REGISTER_DISPATCH(special_bessel_j1_stub, &CPU_CAPABILITY::bessel_j1_kernel);
REGISTER_DISPATCH(special_bessel_y0_stub, &CPU_CAPABILITY::bessel_y0_kernel);
REGISTER_DISPATCH(special_bessel_y1_stub, &CPU_CAPABILITY::bessel_y1_kernel);
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &CPU_CAPABILITY::modified_bessel_i0_kernel);
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &CPU_CAPABILITY::modified_bessel_i1_kernel);

STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(acos);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(asin);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(atan);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erf);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfc);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(erfinv);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(exp);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(expm1);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log10);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log1p);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(log2);
STATIC_IMPLEMENT_COMPLEX_KERNEL_WITH_AVX512(tanh);
IMPLEMENT_FLOAT_KERNEL_WITH_AVX512(lgamma);

} // namespace at::native