#define TORCH_ASSERT_NO_OPERATORS
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include <ATen/native/Activation.h>

#include <cmath>
#include <functional>

#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/Parallel.h>

#include <c10/core/Scalar.h>

namespace at::native {

namespace {

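// log_sigmoid(x) = -log(1 + exp(-x)), computed stably as
// min(x, 0) - log1p(exp(-|x|)). The exp(-|x|) intermediate is stored in
// `buffer` for reuse by the backward kernel.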
static void log_sigmoid_cpu_kernel(TensorBase& output, TensorBase& buffer, const TensorBase& input) {
  if (at::isReducedFloatingType(input.scalar_type())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "log_sigmoid_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      scalar_t* output_data = output.data_ptr<scalar_t>();
      scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
      const scalar_t* input_data = input.const_data_ptr<scalar_t>();
      parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
        int64_t size = end - begin;
        int64_t d = 0;
        for (; d < size - (size % Vec::size()); d += Vec::size()) {
          Vec data_vec = Vec::loadu(input_data + begin + d);
          auto [data_vec0, data_vec1] = convert_to_float<scalar_t>(data_vec);
          Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
          Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
          Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
          min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
          Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
          Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
          convert_from_float<scalar_t>(buffer_vec0, buffer_vec1).store(buffer_data + begin + d);
          convert_from_float<scalar_t>(output_vec0, output_vec1).store(output_data + begin + d);
        }
        if (size - d > 0) {
          Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
          auto [data_vec0, data_vec1] = convert_to_float<scalar_t>(data_vec);
          Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
          Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
          Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
          min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
          Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
          Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
          convert_from_float<scalar_t>(buffer_vec0, buffer_vec1).store(buffer_data + begin + d, size - d);
          convert_from_float<scalar_t>(output_vec0, output_vec1).store(output_data + begin + d, size - d);
        }
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_sigmoid_cpu", [&] {
      using Vec = Vectorized<scalar_t>;
      scalar_t* output_data = output.data_ptr<scalar_t>();
      scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
      const scalar_t* input_data = input.const_data_ptr<scalar_t>();
      parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
        int64_t size = end - begin;
        int64_t d = 0;
        for (; d < size - (size % Vec::size()); d += Vec::size()) {
          Vec data_vec = Vec::loadu(input_data + begin + d);
          Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
          Vec buffer_vec = data_vec.abs().neg().exp();
          Vec output_vec = min_vec - buffer_vec.log1p();
          buffer_vec.store(buffer_data + begin + d);
          output_vec.store(output_data + begin + d);
        }
        if (size - d > 0) {
          Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
          Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
          Vec buffer_vec = data_vec.abs().neg().exp();
          Vec output_vec = min_vec - buffer_vec.log1p();
          buffer_vec.store(buffer_data + begin + d, size - d);
          output_vec.store(output_data + begin + d, size - d);
        }
      });
    });
  }
}

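// log_sigmoid backward: a is the input, b the saved buffer exp(-|x|), c the
// incoming gradient. d/dx log_sigmoid(x) = sigmoid(-x), expressed here as
// max_deriv - sign * b / (1 + b) with (max_deriv, sign) = (1, 1) for x < 0
// and (0, -1) otherwise.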
static void log_sigmoid_backward_cpu_kernel(TensorIterator& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "log_sigmoid_backward_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto zero_val = float(0);
      auto zero_vec = Vectorized<float>(zero_val);
      auto one_val = float(1);
      auto one_vec = Vectorized<float>(one_val);
      cpu_kernel_vec(iter,
        [=](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
          auto in_negative = float(a) < float(0);
          auto max_deriv = in_negative ? float(1) : float(0);
          auto sign = in_negative ? float(1) : -float(1);
          return (max_deriv - sign * (float(b) / (float(1) + b))) * float(c);
        },
        [=](Vec a, Vec b, Vec c) -> Vec {
          auto [a0, a1] = convert_to_float<scalar_t>(a);
          auto [b0, b1] = convert_to_float<scalar_t>(b);
          auto [c0, c1] = convert_to_float<scalar_t>(c);
          auto mask = a0 < zero_vec;
          auto max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
          auto sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
          a0 = (max_deriv_vec - sign_vec * (b0 / (one_vec + b0))) * c0;
          mask = a1 < zero_vec;
          max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
          sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
          a1 = (max_deriv_vec - sign_vec * (b1 / (one_vec + b1))) * c1;
          return convert_from_float<scalar_t>(a0, a1);
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "log_sigmoid_backward_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto zero_val = scalar_t(0);
      auto zero_vec = Vec(zero_val);
      auto one_val = scalar_t(1);
      auto one_vec = Vec(one_val);
      cpu_kernel_vec(iter,
        [=](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
          auto in_negative = a < scalar_t(0);
          auto max_deriv = in_negative ? scalar_t(1) : scalar_t(0);
          auto sign = in_negative ? scalar_t(1) : -scalar_t(1);
          return (max_deriv - sign * (b / (scalar_t(1) + b))) * c;
        },
        [=](Vec a, Vec b, Vec c) -> Vec {
          auto mask = a < zero_vec;
          auto max_deriv_vec = Vec::blendv(zero_vec, one_vec, mask);
          auto sign_vec = Vec::blendv(one_vec.neg(), one_vec, mask);
          return (max_deriv_vec - sign_vec * (b / (one_vec + b))) * c;
        });
    });
  }
}

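// threshold: out = value where x <= threshold, otherwise `other`. `other` is
// the second iterator operand (the input itself for the forward op; the
// incoming gradient when the kernel is reused for threshold_backward).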
static void threshold_kernel(
    TensorIteratorBase& iter,
    const Scalar& threshold_scalar,
    const Scalar& value_scalar) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "threshold_cpu", [&]() {
      using Vec = Vectorized<float>;
      float threshold = threshold_scalar.to<float>();
      Vec threshold_v = Vec(threshold);
      scalar_t value = value_scalar.to<scalar_t>();
      Vec value_v = Vec(float(value));
      cpu_kernel_vec(
          iter,
          [&](scalar_t x, scalar_t other) -> scalar_t {
            return float(x) <= threshold ? value : other;
          },
          [&](Vectorized<scalar_t> x, Vectorized<scalar_t> other) -> Vectorized<scalar_t> {
            auto [x0, x1] = convert_to_float<scalar_t>(x);
            auto [other0, other1] = convert_to_float<scalar_t>(other);
            return convert_from_float<scalar_t>(Vec::blendv(other0, value_v, x0 <= threshold_v),
                                                Vec::blendv(other1, value_v, x1 <= threshold_v));
          });
    });
  } else {
    AT_DISPATCH_ALL_TYPES(iter.dtype(), "threshold_cpu", [&] {
      using Vec = Vectorized<scalar_t>;
      scalar_t threshold = threshold_scalar.to<scalar_t>();
      Vec threshold_v = Vec(threshold);
      scalar_t value = value_scalar.to<scalar_t>();
      Vec value_v = Vec(value);
      cpu_kernel_vec(
          iter,
          [&](scalar_t x, scalar_t other) -> scalar_t {
            return x <= threshold ? value : other;
          },
          [&](Vec x, Vec other) -> Vec {
            return Vec::blendv(other, value_v, x <= threshold_v);
          });
    });
  }
}

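// ELU with the scale factors folded in: for x <= 0,
// out = negcoef * (exp(negiptcoef * x) - 1) with negcoef = alpha * scale and
// negiptcoef = input_scale; for x > 0, out = poscoef * x with poscoef = scale.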
void elu_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) {
  if (at::isReducedFloatingType(it.common_dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "elu_cpu", [&]() {
      auto negcoef = alpha.to<float>() * scale.to<float>();
      auto poscoef = scale.to<float>();
      auto negiptcoef = input_scale.to<float>();
      const Vectorized<float> negcoef_vec(negcoef);
      const Vectorized<float> negiptcoef_vec(negiptcoef);
      const Vectorized<float> poscoef_vec(poscoef);
      const Vectorized<float> one_vec(static_cast<float>(1));
      const Vectorized<float> zero_vec(static_cast<float>(0));
      cpu_kernel_vec(
        it,
        [negcoef, negiptcoef, poscoef](scalar_t a) -> scalar_t {
          return float(a) <= float(0) ? (std::exp(float(a) * negiptcoef) - float(1)) * negcoef : float(a) * poscoef;
        },
        [&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &one_vec, &zero_vec](Vectorized<scalar_t> a) -> Vectorized<scalar_t> {
          auto [a0, a1] = convert_to_float<scalar_t>(a);
          auto cmp0 = (a0 > zero_vec);
          auto cmp1 = (a1 > zero_vec);
          auto get_res_masked = [&](Vectorized<float>& cmp, Vectorized<float>& a) {
            return !cmp.zero_mask() ? a * poscoef_vec :
              Vectorized<float>::blendv(((a * negiptcoef_vec).exp() - one_vec) * negcoef_vec, a * poscoef_vec, cmp);
          };
          auto res0 = get_res_masked(cmp0, a0);
          auto res1 = get_res_masked(cmp1, a1);
          return convert_from_float<scalar_t>(res0, res1);
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(it.common_dtype(), "elu_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
      auto poscoef = scale.to<scalar_t>();
      auto negiptcoef = input_scale.to<scalar_t>();
      const Vec negcoef_vec(negcoef);
      const Vec negiptcoef_vec(negiptcoef);
      const Vec poscoef_vec(poscoef);
      const Vec one_vec(static_cast<scalar_t>(1));
      const Vec zero_vec(static_cast<scalar_t>(0));
      cpu_kernel_vec(
          it,
          [negcoef, negiptcoef, poscoef](scalar_t a) -> scalar_t {
            return a <= scalar_t(0) ? (std::exp(a * negiptcoef) - scalar_t(1)) * negcoef : a * poscoef;
          },
          [&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &one_vec, &zero_vec](Vec a) -> Vec {
            auto cmp = (a > zero_vec);
            if (!cmp.zero_mask()) {  // only a * poscoef (which is very quick) needs to be computed
              return a * poscoef_vec;
            } else {
              return Vec::blendv(((a * negiptcoef_vec).exp() - one_vec) * negcoef_vec, a * poscoef_vec, cmp);
            }
          });
    });
  }
}

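// ELU backward. b is the forward output when is_result is true, otherwise the
// original input. Negative region: grad * negiptcoef * (out + negcoef), using
// out = negcoef * (exp(negiptcoef * x) - 1), or the equivalent
// grad * negiptcoef * negcoef * exp(negiptcoef * x) recomputed from the input.
// Positive region: grad * poscoef.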
void elu_backward_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, bool is_result) {
  if (at::isReducedFloatingType(it.common_dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "elu_backward_cpu", [&]() {
      auto negcoef = alpha.to<float>() * scale.to<float>();
      auto poscoef = scale.to<float>();
      auto negiptcoef = input_scale.to<float>();
      const Vectorized<float> negcoef_vec(negcoef);
      const Vectorized<float> negiptcoef_vec(negiptcoef);
      const Vectorized<float> poscoef_vec(poscoef);
      const Vectorized<float> zero_vec(static_cast<float>(0));
      cpu_kernel_vec(
          it,
          [negcoef, negiptcoef, poscoef, is_result](scalar_t a, scalar_t b) -> scalar_t {
            if (is_result) {
              return float(b) <= float(0) ? float(a) * negiptcoef * (float(b) + negcoef) : float(a) * poscoef;
            } else {
              return float(b) <= float(0) ? float(a) * negiptcoef * negcoef * std::exp(float(b) * negiptcoef) : float(a) * poscoef;
            }
          },
          [&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &zero_vec, is_result](Vectorized<scalar_t> a, Vectorized<scalar_t> b) -> Vectorized<scalar_t> {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            auto [b0, b1] = convert_to_float<scalar_t>(b);
            auto cmp0 = (b0 > zero_vec);
            auto cmp1 = (b1 > zero_vec);
            auto get_res_masked = [&](Vectorized<float>& cmp, Vectorized<float>& a, Vectorized<float>& b) {
              if (is_result) {
                return !cmp.zero_mask() ? a * poscoef_vec :
                  Vectorized<float>::blendv(a * negiptcoef_vec * (b + negcoef_vec), a * poscoef_vec, cmp);
              } else {
                return Vectorized<float>::blendv(a * negiptcoef_vec * negcoef_vec * (b * negiptcoef_vec).exp(), a * poscoef_vec, cmp);
              }
            };
            auto res0 = get_res_masked(cmp0, a0, b0);
            auto res1 = get_res_masked(cmp1, a1, b1);
            return convert_from_float<scalar_t>(res0, res1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(it.dtype(), "elu_backward_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
      auto poscoef = scale.to<scalar_t>();
      auto negiptcoef = input_scale.to<scalar_t>();
      const Vec negcoef_vec(negcoef);
      const Vec negiptcoef_vec(negiptcoef);
      const Vec poscoef_vec(poscoef);
      const Vec zero_vec(static_cast<scalar_t>(0));
      cpu_kernel_vec(
          it,
          [negcoef, negiptcoef, poscoef, is_result](scalar_t a, scalar_t b) -> scalar_t {
            if (is_result) {
              return b <= scalar_t(0) ? a * negiptcoef * (b + negcoef) : a * poscoef;
            } else {
              return b <= scalar_t(0) ? a * negiptcoef * negcoef * std::exp(b * negiptcoef) : a * poscoef;
            }
          },
          [&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &zero_vec, is_result](Vec a, Vec b) -> Vec {
            auto cmp = (b > zero_vec);
            if (is_result) {
              if (!cmp.zero_mask()) {  // only a * poscoef (which is very quick) needs to be computed
                return a * poscoef_vec;
              } else {
                return Vec::blendv(a * negiptcoef_vec * (b + negcoef_vec), a * poscoef_vec, cmp);
              }
            } else {
              return Vec::blendv(a * negiptcoef_vec * negcoef_vec * (b * negiptcoef_vec).exp(), a * poscoef_vec, cmp);
            }
          }
      );
    });
  }
}

// TODO(yangxm): Add another fast kernel using formula
// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
// and the fast tanh impl from Eigen.
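// GELU. GeluType::None computes the exact form x * 0.5 * (1 + erf(x / sqrt(2)));
// GeluType::Tanh computes the approximation above, where
// kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2/pi) and kKappa = 0.044715.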
void GeluKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  auto grain_size = at::internal::GRAIN_SIZE;
  // Numbers based on benchmarking.
  // Benchmark: benchmarks/operator_benchmarks/pt/gelu_test.py
#ifdef C10_MOBILE
  // Benchmarked on S8 US phone.
  // Internal benchmarking that converts the operator benchmark into
  // a torchscript module and runs it on mobile.
  // Same benchmark as server side.
  constexpr int64_t GELU_MIN_ELEMENTS_FOR_MULTI_THREADING{6144};
#else
  // Benchmarked on i9 8 core 16 thread machine.
  // 1 thread: cd benchmark/operator_benchmarks;
  //           python -m pt.gelu_test --tag_filter long --omp_num_threads 1
  // 2 threads: cd benchmark/operator_benchmarks;
  //           python -m pt.gelu_test --tag_filter long --omp_num_threads 2
  constexpr int64_t GELU_MIN_ELEMENTS_FOR_MULTI_THREADING{16384};
#endif
  if (it.numel() > GELU_MIN_ELEMENTS_FOR_MULTI_THREADING) {
    grain_size = it.numel() / at::get_num_threads();
  }
  if (approximate == GeluType::Tanh) {
    if (at::isReducedFloatingType(it.common_dtype())) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "GeluKernelImpl", [&]() {
        auto kBetaVec = Vectorized<float>((float)(M_SQRT2 * M_2_SQRTPI * 0.5));
        auto kKappaVec = Vectorized<float>((float)(0.044715));
        auto kOneVec = Vectorized<float>((float)(1));
        auto kPointFiveVec = Vectorized<float>((float)(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t x) -> scalar_t {
              const float kBeta = float(M_SQRT2 * M_2_SQRTPI * 0.5);
              const float kKappa = float(0.044715);
              float x_cube = float(x) * float(x) * float(x);
              float inner = kBeta * (float(x) + kKappa * x_cube);
              return float(0.5) * float(x) * (float(1) + std::tanh(inner));
            },
            [&](Vectorized<scalar_t> x) -> Vectorized<scalar_t> {
              auto [x0, x1] = convert_to_float<scalar_t>(x);
              auto x0_cube = x0 * x0 * x0;
              auto x1_cube = x1 * x1 * x1;
              auto inner_vec0 = kBetaVec * (x0 + kKappaVec * x0_cube);
              auto inner_vec1 = kBetaVec * (x1 + kKappaVec * x1_cube);
              auto res0 = kPointFiveVec * x0 * (kOneVec + inner_vec0.tanh());
              auto res1 = kPointFiveVec * x1 * (kOneVec + inner_vec1.tanh());
              return convert_from_float<scalar_t>(res0, res1);
            },
            grain_size);
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          it.dtype(), "GeluKernelImpl", [&]() {
        using Vec = vec::Vectorized<scalar_t>;
        const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
        const Vec kKappaVec(scalar_t(0.044715));
        const Vec kOneVec(scalar_t(1));
        const Vec kPointFiveVec(scalar_t(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t x) {
              const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
              const scalar_t kKappa = 0.044715;
              auto x_cube = x * x * x;
              auto inner = kBeta * (x + kKappa * x_cube);
              return scalar_t(0.5) * x * (scalar_t(1) + std::tanh(inner));
            },
            [&](Vec x_vec) {
              auto x_cube = x_vec * x_vec * x_vec;
              auto inner_vec = kBetaVec * (x_vec + kKappaVec * x_cube);
              return kPointFiveVec * x_vec * (kOneVec + inner_vec.tanh());
            },
            grain_size);
      });
    }
  } else {
    if (at::isReducedFloatingType(it.common_dtype())) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(it.dtype(), "GeluKernelImpl", [&]() {
        auto kAlphaVec = Vectorized<float>((float)(M_SQRT1_2));
        auto kOneVec = Vectorized<float>((float)(1));
        auto kPointFiveVec = Vectorized<float>((float)(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t x) -> scalar_t {
              const float kAlpha = float(M_SQRT1_2);
              return float(x) * float(0.5) * (float(1) + std::erf(float(x) * kAlpha));
            },
            [&](Vectorized<scalar_t> x) -> Vectorized<scalar_t> {
              auto [x0, x1] = convert_to_float<scalar_t>(x);
              auto res0 = x0 * kPointFiveVec * (kOneVec + (x0 * kAlphaVec).erf());
              auto res1 = x1 * kPointFiveVec * (kOneVec + (x1 * kAlphaVec).erf());
              return convert_from_float<scalar_t>(res0, res1);
            },
            grain_size);
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          it.dtype(), "GeluKernelImpl", [&]() {
        using Vec = vec::Vectorized<scalar_t>;
        const Vec kAlphaVec(scalar_t(M_SQRT1_2));
        const Vec kOneVec(scalar_t(1));
        const Vec kPointFiveVec(scalar_t(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t x) {
              const scalar_t kAlpha = scalar_t(M_SQRT1_2);
              return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
            },
            [&](Vec x_vec) {
              return x_vec * kPointFiveVec *
                  (kOneVec + (x_vec * kAlphaVec).erf());
            },
            grain_size);
      });
    }
  }
}

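// GELU backward. Exact form: d/dx = cdf(x) + x * pdf(x), with
// cdf(x) = 0.5 * (1 + erf(x / sqrt(2))) and pdf(x) = exp(-0.5 * x^2) / sqrt(2*pi).
// Tanh form: product rule over 0.5 * x * (1 + tanh(inner)) with
// inner = sqrt(2/pi) * (x + 0.044715 * x^3).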
void GeluBackwardKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  if (approximate == GeluType::Tanh) {
    if (at::isReducedFloatingType(it.common_dtype())) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "GeluBackwardKernelImpl", [&]() {
        auto kBetaVec = Vectorized<float>((float)(M_SQRT2 * M_2_SQRTPI * 0.5));
        auto kKappaVec = Vectorized<float>((float)(0.044715));
        auto kOneVec = Vectorized<float>((float)(1));
        auto kThreeVec = Vectorized<float>((float)(3));
        auto kPointFiveVec = Vectorized<float>((float)(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t dy, scalar_t x) -> scalar_t {
              const float kBeta = float(M_SQRT2 * M_2_SQRTPI * 0.5);
              const float kKappa = float(0.044715);
              float x_sq = float(x) * float(x);
              float x_cube = x_sq * float(x);
              float inner = kBeta * (float(x) + kKappa * x_cube);
              float tanh_inner = float(std::tanh(inner));

              float left = float(0.5) * float(x);
              float right = float(1) + tanh_inner;

              float left_derivative = float(0.5) * right;

              float tanh_derivative = float(1) - tanh_inner * tanh_inner;
              float inner_derivative =
                  kBeta * (float(1) + float(3) * kKappa * x_sq);
              float right_derivative = left * tanh_derivative * inner_derivative;

              return float(dy) * (left_derivative + right_derivative);
            },
            [&](Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) -> Vectorized<scalar_t> {
              auto [x0_vec, x1_vec] = convert_to_float<scalar_t>(x_vec);
              auto [dy0_vec, dy1_vec] = convert_to_float<scalar_t>(dy_vec);
              auto x0_sq = x0_vec * x0_vec;
              auto x1_sq = x1_vec * x1_vec;
              auto x0_cube = x0_vec * x0_vec * x0_vec;
              auto x1_cube = x1_vec * x1_vec * x1_vec;
              auto inner_vec0 = kBetaVec * (x0_vec + kKappaVec * x0_cube);
              auto inner_vec1 = kBetaVec * (x1_vec + kKappaVec * x1_cube);
              auto tanh_inner_vec0 = inner_vec0.tanh();
              auto tanh_inner_vec1 = inner_vec1.tanh();

              auto left_vec0 = kPointFiveVec * x0_vec;
              auto left_vec1 = kPointFiveVec * x1_vec;
              auto right_vec0 = kOneVec + tanh_inner_vec0;
              auto right_vec1 = kOneVec + tanh_inner_vec1;

              auto left_derivative_vec0 = kPointFiveVec * right_vec0;
              auto left_derivative_vec1 = kPointFiveVec * right_vec1;

              auto tanh_derivative_vec0 = kOneVec - tanh_inner_vec0 * tanh_inner_vec0;
              auto tanh_derivative_vec1 = kOneVec - tanh_inner_vec1 * tanh_inner_vec1;
              auto inner_derivative_vec0 = kBetaVec * (kOneVec + kThreeVec * kKappaVec * x0_sq);
              auto inner_derivative_vec1 = kBetaVec * (kOneVec + kThreeVec * kKappaVec * x1_sq);
              auto right_derivative_vec0 = left_vec0 * tanh_derivative_vec0 * inner_derivative_vec0;
              auto right_derivative_vec1 = left_vec1 * tanh_derivative_vec1 * inner_derivative_vec1;

              auto res0 = dy0_vec * (left_derivative_vec0 + right_derivative_vec0);
              auto res1 = dy1_vec * (left_derivative_vec1 + right_derivative_vec1);
              return convert_from_float<scalar_t>(res0, res1);
            });
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          it.dtype(), "GeluBackwardKernelImpl", [&]() {
        using Vec = vec::Vectorized<scalar_t>;
        const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
        const Vec kKappaVec(scalar_t(0.044715));
        const Vec kOneVec(scalar_t(1));
        const Vec kThreeVec(scalar_t(3));
        const Vec kPointFiveVec(scalar_t(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t dy, scalar_t x) {
              const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
              const scalar_t kKappa = 0.044715;
              auto x_sq = x * x;
              auto x_cube = x_sq * x;
              auto inner = kBeta * (x + kKappa * x_cube);
              auto tanh_inner = std::tanh(inner);

              auto left = scalar_t(0.5) * x;
              auto right = scalar_t(1) + tanh_inner;

              auto left_derivative = scalar_t(0.5) * right;

              auto tanh_derivative = scalar_t(1) - tanh_inner * tanh_inner;
              auto inner_derivative =
                  kBeta * (scalar_t(1) + scalar_t(3) * kKappa * x_sq);
              auto right_derivative = left * tanh_derivative * inner_derivative;

              return dy * (left_derivative + right_derivative);
            },
            [&](Vec dy_vec, Vec x_vec) {
              auto x_sq = x_vec * x_vec;
              auto x_cube = x_vec * x_vec * x_vec;
              auto inner_vec =
                  kBetaVec * (x_vec + kKappaVec * x_cube);
              auto tanh_inner_vec = inner_vec.tanh();

              auto left_vec = kPointFiveVec * x_vec;
              auto right_vec = kOneVec + tanh_inner_vec;

              auto left_derivative_vec = kPointFiveVec * right_vec;

              auto tanh_derivative_vec =
                  kOneVec - tanh_inner_vec * tanh_inner_vec;
              auto inner_derivative_vec =
                  kBetaVec * (kOneVec + kThreeVec * kKappaVec * x_sq);
              auto right_derivative_vec =
                  left_vec * tanh_derivative_vec * inner_derivative_vec;

              return dy_vec * (left_derivative_vec + right_derivative_vec);
            });
      });
    }
  } else {
    if (at::isReducedFloatingType(it.common_dtype())) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "GeluBackwardKernelImpl", [&]() {
        auto kAlphaVec = Vectorized<float>((float)(M_SQRT1_2));
        auto kBetaVec = Vectorized<float>((float)(M_2_SQRTPI * M_SQRT1_2 * 0.5));
        auto kOneVec = Vectorized<float>((float)(1));
        auto kPointFiveVec = Vectorized<float>((float)(0.5));
        auto kMinusPointFiveVec = Vectorized<float>((float)(-0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t dy, scalar_t x) -> scalar_t {
              const float kAlpha = float(M_SQRT1_2);
              const float kBeta = float(M_2_SQRTPI) * float(M_SQRT1_2) * float(0.5);
              const float cdf =
                  float(0.5) * (float(1) + std::erf(float(x) * kAlpha));
              const float pdf = kBeta * std::exp(float(x) * float(x) * float(-0.5));
              return float(dy) * (cdf + float(x) * pdf);
            },
            [&](Vectorized<scalar_t> dy, Vectorized<scalar_t> x) -> Vectorized<scalar_t> {
              auto [x0, x1] = convert_to_float<scalar_t>(x);
              auto [dy0, dy1] = convert_to_float<scalar_t>(dy);
              auto cdf_vec0 = kPointFiveVec * (kOneVec + (x0 * kAlphaVec).erf());
              auto cdf_vec1 = kPointFiveVec * (kOneVec + (x1 * kAlphaVec).erf());
              auto pdf_vec0 = kBetaVec * (x0 * x0 * kMinusPointFiveVec).exp();
              auto pdf_vec1 = kBetaVec * (x1 * x1 * kMinusPointFiveVec).exp();
              auto res0 = dy0 * (cdf_vec0 + x0 * pdf_vec0);
              auto res1 = dy1 * (cdf_vec1 + x1 * pdf_vec1);
              return convert_from_float<scalar_t>(res0, res1);
            });
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          it.dtype(), "GeluBackwardKernelImpl", [&]() {
        using Vec = vec::Vectorized<scalar_t>;
        const Vec kAlphaVec(scalar_t(M_SQRT1_2));
        const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
        const Vec kOneVec(scalar_t(1));
        const Vec kPointFiveVec(scalar_t(0.5));
        const Vec kMinusPointFiveVec(scalar_t(-0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t dy, scalar_t x) {
              const scalar_t kAlpha = scalar_t(M_SQRT1_2);
              const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
              const scalar_t cdf =
                  scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
              const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
              return dy * (cdf + x * pdf);
            },
            [&](Vec dy_vec, Vec x_vec) {
              const Vec cdf_vec =
                  kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
              const Vec pdf_vec =
                  kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
              return dy_vec * (cdf_vec + x_vec * pdf_vec);
            });
      });
    }
  }
}

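// hardsigmoid(x) = clamp(x + 3, 0, 6) / 6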
void hardsigmoid_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "hardsigmoid_cpu", [&]() {
      const float zero(0.0f);
      const float three(3.0f);
      const float six(6.0f);
      using Vec = vec::Vectorized<float>;
      const Vec kZeroVec(zero);
      const Vec kThreeVec(three);
      const Vec kSixVec(six);
      cpu_kernel_vec(
          iter,
          [&](scalar_t self_val) -> scalar_t {
            return std::min(std::max(float(self_val) + three, zero), six) / six;
          },
          [&](vec::Vectorized<scalar_t> self_val) -> vec::Vectorized<scalar_t> {
            auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
            self_val0 = minimum(
              maximum(self_val0 + kThreeVec, kZeroVec),
              kSixVec
            ) / kSixVec;
            self_val1 = minimum(
              maximum(self_val1 + kThreeVec, kZeroVec),
              kSixVec
            ) / kSixVec;
            return convert_from_float<scalar_t>(self_val0, self_val1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardsigmoid_cpu", [&] {
      const scalar_t zero(0.0f);
      const scalar_t three(3.0f);
      const scalar_t six(6.0f);
      using Vec = vec::Vectorized<scalar_t>;
      const Vec kZeroVec(zero);
      const Vec kThreeVec(three);
      const Vec kSixVec(six);
      cpu_kernel_vec(
          iter,
          [&](scalar_t self_val) {
            return std::min(std::max(self_val + three, zero), six) / six;
          },
          [&](Vec self_val) {
            return vec::minimum(
              vec::maximum(self_val + kThreeVec, kZeroVec),
              kSixVec
            ) / kSixVec;
          });
    });
  }
}

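// hardsigmoid backward: grad / 6 inside the linear region -3 < x < 3, else 0.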
void hardsigmoid_backward_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.common_dtype(), "hardsigmoid_backward", [&]() {
      const float zero(0.0f);
      const float three(3.0f);
      const float neg_three(-3.0f);
      const float one_sixth(1.0f / 6.0f);
      using Vec = Vectorized<float>;
      Vec kZeroVec(0.0f);
      Vec kOneSixthVec(1.0f / 6.0f);
      cpu_kernel_vec(
          iter,
          [=](scalar_t grad_val, scalar_t self_val) -> scalar_t {
            return (float(self_val) > neg_three && float(self_val) < three)
              ? float(grad_val) * one_sixth
              : zero;
          },
          [=](Vectorized<scalar_t> grad_val, Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
            auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
            auto [grad_val0, grad_val1] = convert_to_float<scalar_t>(grad_val);
            Vec gradNonZeroMask = (self_val0 > neg_three) & (self_val0 < three);
            self_val0 = Vec::blendv(kZeroVec, grad_val0 * kOneSixthVec, gradNonZeroMask);
            gradNonZeroMask = (self_val1 > neg_three) & (self_val1 < three);
            self_val1 = Vec::blendv(kZeroVec, grad_val1 * kOneSixthVec, gradNonZeroMask);
            return convert_from_float<scalar_t>(self_val0, self_val1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardsigmoid_backward", [&] {
      const scalar_t zero(0.0f);
      const scalar_t three(3.0f);
      const scalar_t neg_three(-3.0f);
      const scalar_t one_sixth(1.0f / 6.0f);
      using Vec = Vectorized<scalar_t>;
      Vec kZeroVec(0.0f);
      Vec kOneSixthVec(1.0f / 6.0f);
      cpu_kernel_vec(
          iter,
          [=](scalar_t grad_val, scalar_t self_val) {
            return (self_val > neg_three && self_val < three)
              ? grad_val * one_sixth
              : zero;
          },
          [=](Vec grad_val, Vec self_val) {
            Vec gradNonZeroMask = (self_val > neg_three) & (self_val < three);
            return Vec::blendv(kZeroVec, grad_val * kOneSixthVec, gradNonZeroMask);
          });
    });
  }
}

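// hardshrink(x) = 0 for |x| <= lambd, else x.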
void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "hardshrink_cpu", [&] {
    auto lambd_val = lambd.to<scalar_t>();
    using Vec = Vectorized<scalar_t>;
    cpu_kernel_vec(
        iter,
        [=](scalar_t self_val) {
          return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
                                                                   : self_val;
        },
        [=](Vec self_val) {
          return Vec::blendv(self_val, Vec(0), (self_val >= -lambd_val) & (self_val <= lambd_val));
        });
  });
}

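// softshrink(x) = x - lambd for x > lambd, x + lambd for x < -lambd, else 0.
// The vector path computes each branch under its comparison mask and ORs the
// disjoint results together.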
void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.common_dtype(), "softshrink_cpu", [&]() {
      auto lambd_val = lambd.to<float>();
      auto lambdVec = Vectorized<float>(lambd_val);
      cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t {
          return float(a) > lambd_val ? a - lambd_val : (float(a) < -lambd_val ? a + lambd_val : float(0));
        },
        [=](Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
          auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
          auto self_val_t0 = convert_from_float<scalar_t>((self_val0 > lambdVec) & (self_val0 - lambdVec), (self_val1 > lambdVec) & (self_val1 - lambdVec));
          auto self_val_t1 = convert_from_float<scalar_t>((self_val0 < -lambd_val) & (self_val0 + lambdVec), (self_val1 < -lambd_val) & (self_val1 + lambdVec));
          return (self_val_t0 | self_val_t1);
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "softshrink_cpu", [&]() {
      auto lambd_val = lambd.to<scalar_t>();
      auto lambdVec = Vectorized<scalar_t>(lambd_val);
      cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t {
          return a > lambd_val ? a - lambd_val : (a < -lambd_val ? a + lambd_val : scalar_t(0));
        },
        [=](Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
          Vectorized<scalar_t> self_val_t0, self_val_t1;
          self_val_t0 = (self_val > lambdVec) & (self_val - lambdVec);
          self_val_t1 = (self_val < -lambd_val) & (self_val + lambdVec);
          return (self_val_t0 | self_val_t1);
        });
    });
  }
}

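// Shared backward for the shrink ops: the gradient passes through where
// |x| > lambd and is zeroed inside the band.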
void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "shrink_backward_cpu", [&] {
    auto lambd_val = lambd.to<scalar_t>();
    cpu_kernel_vec(
        iter,
        [=](scalar_t grad_val, scalar_t self_val) {
          return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
                                                                   : grad_val;
        },
        [=](Vectorized<scalar_t> grad_val, Vectorized<scalar_t> self_val) {
          return ((self_val < -lambd_val) | (self_val > lambd_val)) & grad_val;
        });
  });
}

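// hardtanh backward: grad passes through where min_val < x < max_val, else 0.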
void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Scalar& max) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "hardtanh_backward_cpu", [&]() {
      auto min_val = min.to<float>();
      auto max_val = max.to<float>();
      cpu_kernel_vec(
          iter,
          [=](scalar_t grad_val, scalar_t self_val) -> scalar_t {
            return (float(self_val) <= min_val || float(self_val) >= max_val) ? scalar_t(0) : grad_val;
          },
          [=](Vectorized<scalar_t> grad_val, Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
            auto [grad_val0, grad_val1] = convert_to_float<scalar_t>(grad_val);
            auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
            return convert_from_float<scalar_t>(
              ((self_val0 > min_val) & (self_val0 < max_val)) & grad_val0,
              ((self_val1 > min_val) & (self_val1 < max_val)) & grad_val1
            );
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardtanh_backward_cpu", [&] {
      auto min_val = min.to<scalar_t>();
      auto max_val = max.to<scalar_t>();
      cpu_kernel_vec(
          iter,
          [=](scalar_t grad_val, scalar_t self_val) {
            return (self_val <= min_val || self_val >= max_val) ? scalar_t(0) : grad_val;
          },
          [=](Vectorized<scalar_t> grad_val, Vectorized<scalar_t> self_val) {
            return ((self_val > min_val) & (self_val < max_val)) & grad_val;
          });
    });
  }
}

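// hardswish(x) = x * clamp(x + 3, 0, 6) / 6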
void hardswish_kernel(TensorIterator& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "hardswish_cpu", [&]() {
      const float zero(0.0f);
      const float three(3.0f);
      const float six(6.0f);
      using Vec = vec::Vectorized<float>;
      const Vec kZeroVec(zero);
      const Vec kThreeVec(three);
      const Vec kSixVec(six);
      cpu_kernel_vec(
        iter,
        [&](scalar_t x) -> scalar_t {
          return float(x) * std::min(std::max(float(x) + three, zero), six) / six;
        },
        [&](vec::Vectorized<scalar_t> x_vec) {
          auto [x_vec0, x_vec1] = convert_to_float<scalar_t>(x_vec);
          x_vec0 = x_vec0 * minimum(
            maximum(x_vec0 + kThreeVec, kZeroVec),
            kSixVec
          ) / kSixVec;
          x_vec1 = x_vec1 * minimum(
            maximum(x_vec1 + kThreeVec, kZeroVec),
            kSixVec
          ) / kSixVec;
          return convert_from_float<scalar_t>(x_vec0, x_vec1);
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardswish_cpu", [&]() {
      const scalar_t zero(0.0f);
      const scalar_t three(3.0f);
      const scalar_t six(6.0f);
      using Vec = vec::Vectorized<scalar_t>;
      const Vec kZeroVec(zero);
      const Vec kThreeVec(three);
      const Vec kSixVec(six);
      cpu_kernel_vec(
        iter,
        [&](scalar_t x) {
          return x * std::min(std::max(x + three, zero), six) / six;
        },
        [&](Vec x_vec) {
          return x_vec * vec::minimum(
            vec::maximum(x_vec + kThreeVec, kZeroVec),
            kSixVec
          ) / kSixVec;
        }
      );
    });
  }
}

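// hardswish backward: 0 for x < -3; grad * (x / 3 + 0.5) for -3 <= x <= 3;
// grad for x > 3.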
void hardswish_backward_kernel(TensorIterator& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "hardswish_backward_cpu", [&]() {
      const float zero(0.0f);
      const float three(3.0f);
      const float neg_three(-3.0f);
      const float one_half(0.5f);
      using Vec = vec::Vectorized<float>;
      const Vec kZeroVec(zero);
      const Vec kThreeVec(three);
      const Vec kNegThreeVec(neg_three);
      const Vec kOneHalfVec(one_half);
      cpu_kernel_vec(
        iter,
        [&](scalar_t grad_val, scalar_t self_val) -> scalar_t {
          if (float(self_val) < neg_three) {
            return zero;
          } else if (float(self_val) <= three) {
            return float(grad_val) * ((float(self_val) / three) + one_half);
          } else {
            return grad_val;
          }
        },
        [&](vec::Vectorized<scalar_t> grad_val, vec::Vectorized<scalar_t> self_val) {
          auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
          auto [grad_val0, grad_val1] = convert_to_float<scalar_t>(grad_val);
          self_val0 = Vec::blendv(
            Vec::blendv(
              grad_val0 * ((self_val0 / kThreeVec) + kOneHalfVec),
              grad_val0,
              self_val0 >= kThreeVec
            ),
            kZeroVec,
            self_val0 < kNegThreeVec
          );
          self_val1 = Vec::blendv(
            Vec::blendv(
              grad_val1 * ((self_val1 / kThreeVec) + kOneHalfVec),
              grad_val1,
              self_val1 >= kThreeVec
            ),
            kZeroVec,
            self_val1 < kNegThreeVec
          );
          return convert_from_float<scalar_t>(self_val0, self_val1);
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardswish_backward_cpu", [&]() {
      const scalar_t zero(0.0f);
      const scalar_t three(3.0f);
      const scalar_t neg_three(-3.0f);
      const scalar_t one_half(0.5f);
      using Vec = vec::Vectorized<scalar_t>;
      const Vec kZeroVec(zero);
      const Vec kThreeVec(three);
      const Vec kNegThreeVec(neg_three);
      const Vec kOneHalfVec(one_half);
      cpu_kernel_vec(
        iter,
        [&](scalar_t grad_val, scalar_t self_val) {
          if (self_val < neg_three) {
            return zero;
          } else if (self_val <= three) {
            return grad_val * ((self_val / three) + one_half);
          } else {
            return grad_val;
          }
        },
        [&](Vec grad_val, Vec self_val) {
          return Vec::blendv(
            Vec::blendv(
              grad_val * ((self_val / kThreeVec) + kOneHalfVec),
              grad_val,
              self_val >= kThreeVec
            ),
            kZeroVec,
            self_val < kNegThreeVec
          );
        }
      );
    });
  }
}

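// leaky_relu(x) = x for x > 0, else negval * x. The vector path multiplies by
// a blend of 1 and negval instead of branching.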
static void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "leaky_relu_cpu", [&]() {
      auto zero_vec = Vectorized<float>((float)(0));
      auto one_vec = Vectorized<float>((float)(1));
      float negval = negval_.to<float>();
      Vectorized<float> negval_v = Vectorized<float>(negval);
      cpu_kernel_vec(
          iter,
          [&](scalar_t a) -> scalar_t {
            return float(a) > float(0) ? float(a) : float(a) * negval;
          },
          [&](Vectorized<scalar_t> a) -> Vectorized<scalar_t> {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            auto res0 = a0 * (Vectorized<float>::blendv(negval_v, one_vec, a0 > zero_vec));
            auto res1 = a1 * (Vectorized<float>::blendv(negval_v, one_vec, a1 > zero_vec));
            return convert_from_float<scalar_t>(res0, res1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "leaky_relu_cpu", [&] {
      using Vec = Vectorized<scalar_t>;
      auto zero_vec = Vec((scalar_t)(0));
      auto one_vec = Vec((scalar_t)(1));
      scalar_t negval = negval_.to<scalar_t>();
      Vec negval_v = Vec(negval);
      cpu_kernel_vec(
          iter,
          [&](scalar_t a) -> scalar_t {
            return a > scalar_t(0) ? a : a * negval;
          },
          [&](Vec a) -> Vec {
            auto r = Vec::blendv(negval_v, one_vec, a > zero_vec);
            return a * r;
          });
    });
  }
}

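// leaky_relu backward: a is the input (or saved result) and b the incoming
// gradient; out = b for a > 0, else negval * b.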
static void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& negval_) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "leaky_relu_backward_cpu", [&]() {
      auto zero_vec = Vectorized<float>((float)(0));
      auto one_vec = Vectorized<float>((float)(1));
      float negval = negval_.to<float>();
      Vectorized<float> negval_v = Vectorized<float>(negval);
      cpu_kernel_vec(
        iter,
        [&](scalar_t a, scalar_t b) -> scalar_t {
          return float(a) > float(0) ? float(b) : float(b) * negval;
        },
        [&](Vectorized<scalar_t> a, Vectorized<scalar_t> b) -> Vectorized<scalar_t> {
          auto [a0, a1] = convert_to_float<scalar_t>(a);
          auto [b0, b1] = convert_to_float<scalar_t>(b);
          auto res0 = b0 * (Vectorized<float>::blendv(negval_v, one_vec, a0 > zero_vec));
          auto res1 = b1 * (Vectorized<float>::blendv(negval_v, one_vec, a1 > zero_vec));
          return convert_from_float<scalar_t>(res0, res1);
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "leaky_relu_backward_cpu", [&] {
      using Vec = Vectorized<scalar_t>;
      auto zero_vec = Vec((scalar_t)(0));
      auto one_vec = Vec((scalar_t)(1));
      scalar_t negval = negval_.to<scalar_t>();
      Vec negval_v = Vec(negval);
      cpu_kernel_vec(
          iter,
          [&](scalar_t a, scalar_t b) -> scalar_t {
            return a > scalar_t(0) ? b : b * negval;
          },
          [&](Vec a, Vec b) -> Vec {
            auto r = Vec::blendv(negval_v, one_vec, a > zero_vec);
            return b * r;
          });
    });
  }
}

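// softplus(x) = log1p(exp(beta * x)) / beta, falling back to the identity once
// beta * x > threshold, where softplus is linear to within rounding and exp()
// would overflow.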
void softplus_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "softplus_cpu", [&]() {
      using Vec = Vectorized<float>;
      auto beta = beta_.to<float>();
      auto threshold = threshold_.to<float>();
      const Vec beta_vec(beta);
      const Vec threshold_vec(threshold);
      cpu_kernel_vec(
          iter,
          [beta, threshold](scalar_t a) -> scalar_t {
            return (float(a) * beta) > threshold ? a
              : static_cast<scalar_t>((std::log1p(std::exp(float(a) * beta))) / beta);
          },
          [beta_vec, threshold_vec](Vectorized<scalar_t> a) -> Vectorized<scalar_t> {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            a0 = Vec::blendv((a0 * beta_vec).exp().log1p() / beta_vec, a0, (a0 * beta_vec) > threshold_vec);
            a1 = Vec::blendv((a1 * beta_vec).exp().log1p() / beta_vec, a1, (a1 * beta_vec) > threshold_vec);
            return convert_from_float<scalar_t>(a0, a1);
          }
      );
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "softplus_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto beta = beta_.to<scalar_t>();
      auto threshold = threshold_.to<scalar_t>();
      const Vec beta_vec(beta);
      const Vec threshold_vec(threshold);
      cpu_kernel_vec(
          iter,
          [beta, threshold](scalar_t a) -> scalar_t {
            return (a * beta) > threshold ? a
              : static_cast<scalar_t>(std::log1p(std::exp(a * beta))) / beta;
          },
          [beta_vec, threshold_vec](Vec a) -> Vec {
            return Vec::blendv((a * beta_vec).exp().log1p() / beta_vec, a, (a * beta_vec) > threshold_vec);
          }
      );
    });
  }
}

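// softplus backward: with z = exp(beta * x), d/dx softplus = z / (z + 1)
// (i.e. sigmoid(beta * x)), so out = grad * z / (z + 1); past the threshold
// the gradient passes through unchanged.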
void softplus_backward_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "softplus_backward_cpu", [&]() {
      using Vec = Vectorized<float>;
      auto beta = beta_.to<float>();
      auto threshold = threshold_.to<float>();
      const Vec beta_vec(beta);
      const Vec threshold_vec(threshold);
      const Vec one_vec(static_cast<float>(1.0));
      cpu_kernel_vec(
          iter,
          [beta, threshold](scalar_t a, scalar_t b) -> scalar_t {
            float z = std::exp(float(b) * beta);
            return (float(b) * beta) > threshold ? a : static_cast<scalar_t>(float(a) * z / (z + float(1.)));
          },
          [beta_vec, one_vec, threshold_vec](Vectorized<scalar_t> a, Vectorized<scalar_t> b) -> Vectorized<scalar_t> {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            auto [b0, b1] = convert_to_float<scalar_t>(b);
            Vec z = (b0 * beta_vec).exp();
            a0 = Vec::blendv(a0 * z / (z + one_vec), a0, (b0 * beta_vec) > threshold_vec);
            z = (b1 * beta_vec).exp();
            a1 = Vec::blendv(a1 * z / (z + one_vec), a1, (b1 * beta_vec) > threshold_vec);
            return convert_from_float<scalar_t>(a0, a1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "softplus_backward_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto beta = beta_.to<scalar_t>();
      auto threshold = threshold_.to<scalar_t>();
      const Vec beta_vec(beta);
      const Vec threshold_vec(threshold);
      const Vec one_vec(static_cast<scalar_t>(1.0));
      cpu_kernel_vec(
          iter,
          [beta, threshold](scalar_t a, scalar_t b) -> scalar_t {
            scalar_t z = std::exp(b * beta);
            return (b * beta) > threshold ? a : a * z / (z + scalar_t(1.));
          },
          [beta_vec, one_vec, threshold_vec](Vec a, Vec b) -> Vec {
            const Vec z = (b * beta_vec).exp();
            return Vec::blendv(a * z / (z + one_vec), a, (b * beta_vec) > threshold_vec);
          }
      );
    });
  }
}

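// GLU: out = a * sigmoid(b), where a and b are the two halves of the input
// along the glu dimension (the TensorIterator supplies them as separate operands).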
void glu_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "glu_cpu", [&]() {
      const float float_one_val(1);
      const Vectorized<float> float_one_vec(float_one_val);
      cpu_kernel_vec(
        iter,
        [float_one_val](scalar_t a, scalar_t b) -> scalar_t {
          return float(a) * (float_one_val / (float_one_val + std::exp(-float(b))));
        },
        [float_one_vec](Vectorized<scalar_t> a, Vectorized<scalar_t> b) -> Vectorized<scalar_t> {
          auto [a0, a1] = convert_to_float<scalar_t>(a);
          auto [b0, b1] = convert_to_float<scalar_t>(b);
          return convert_from_float<scalar_t>(a0 * (float_one_vec / (float_one_vec + b0.neg().exp())),
                                              a1 * (float_one_vec / (float_one_vec + b1.neg().exp())));
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_cpu", [&] {
      using Vec = Vectorized<scalar_t>;
      const scalar_t one_val(1);
      const Vec one_vec(one_val);
      cpu_kernel_vec(
        iter,
        [one_val](scalar_t a, scalar_t b) -> scalar_t {
          return a * (one_val / (one_val + std::exp(-b)));
        },
        [one_vec](Vec a, Vec b) -> Vec {
          return a * (one_vec / (one_vec + b.neg().exp()));
        }
      );
    });
  }
}

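// GLU forward-mode derivative: with s = sigmoid(b) and res = a * s,
// jvp = da * s + res * (db - s * db) = da * s + a * s * (1 - s) * db.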
void glu_jvp_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_jvp_cpu", [&] {
    using Vec = Vectorized<scalar_t>;
    const scalar_t one(1);
    const Vec ones(one);
    cpu_kernel_vec(
      iter,
      [one](scalar_t res, scalar_t b, scalar_t da, scalar_t db) -> scalar_t {
        const auto sig_b = one / (one + std::exp(-b));
        return da * sig_b + res * (db - sig_b * db);
      },
      [ones](Vec res, Vec b, Vec da, Vec db) -> Vec {
        const auto sig_b = ones / (ones + b.neg().exp());
        return da * sig_b + res * (db - sig_b * db);
      }
    );
  });
}

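// GLU backward. With a = sigmoid(gate), b = first input half, and c = the
// incoming gradient (as supplied by the iterator), computes
// c * b * a * (1 - a), i.e. the gradient flowing into the gate half.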
void glu_backward_kernel(TensorIterator& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "glu_backward_cpu", [&]() {
    const float float_one_val(1);
    const Vectorized<float> float_one_vec(float_one_val);
    cpu_kernel_vec(
      iter,
      [float_one_val](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
        return (float_one_val - float(a)) * float(a) * float(b) * float(c);
      },
      [float_one_vec](Vectorized<scalar_t> a, Vectorized<scalar_t> b, Vectorized<scalar_t> c) -> Vectorized<scalar_t> {
        auto [a0, a1] = convert_to_float<scalar_t>(a);
        auto [b0, b1] = convert_to_float<scalar_t>(b);
        auto [c0, c1] = convert_to_float<scalar_t>(c);
        a0 = (float_one_vec - a0) * a0 * b0 * c0;
        a1 = (float_one_vec - a1) * a1 * b1 * c1;
        return convert_from_float<scalar_t>(a0, a1);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_backward_cpu", [&] {
      using Vec = Vectorized<scalar_t>;
      const scalar_t one_val(1);
      const Vec one_vec(one_val);
      cpu_kernel_vec(
        iter,
        [one_val](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
          return (one_val - a) * a * b * c;
        },
        [one_vec](Vec a, Vec b, Vec c) -> Vec {
          return (one_vec - a) * a * b * c;
        }
      );
    });
  }
}

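// SiLU (a.k.a. swish): out = x * sigmoid(x) = x / (1 + exp(-x)). The
// reduced-float path performs the math in float; the full-precision path also
// covers complex dtypes.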
void silu_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "silu_cpu", [&]() {
      const Vectorized<float> kOneVec(1.0f);
      cpu_kernel_vec(
          iter,
          [](scalar_t x) -> scalar_t {
            return float(x) / (1.0f + std::exp(-float(x)));
          },
          [kOneVec](Vectorized<scalar_t> x_vec) -> Vectorized<scalar_t> {
            auto [x_vec0, x_vec1] = convert_to_float<scalar_t>(x_vec);
            return convert_from_float<scalar_t>(
              x_vec0 / (kOneVec + x_vec0.neg().exp()),
              x_vec1 / (kOneVec + x_vec1.neg().exp()));
          });
    });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      iter.dtype(), "silu_cpu", [&]() {
        const Vectorized<scalar_t> kOneVec(scalar_t(1));
        cpu_kernel_vec(
            iter,
            [](scalar_t x) {
              return x / (scalar_t(1) + std::exp(-x));
            },
            [kOneVec](Vectorized<scalar_t> x_vec) {
              return x_vec / (kOneVec + x_vec.neg().exp());
            });
      });
  }
}

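// Gradient of SiLU: d/dx [x * sigmoid(x)]
//   = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
//   = sigmoid(x) * (1 + x * (1 - sigmoid(x))),
// which is the factored form both lambdas multiply by the upstream grad dy.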
void silu_backward_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "silu_backward_cpu", [&]() {
    const Vectorized<float> kOneVec(1.0f);
    cpu_kernel_vec(
        iter,
        [](scalar_t dy, scalar_t x) -> scalar_t {
          const float sigmoid =
              1.0f / (1.0f + std::exp(-float(x)));
          return dy * sigmoid * (1.0f + x * (1.0f - sigmoid));
        },
        [kOneVec](Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) -> Vectorized<scalar_t> {
          auto [x_vec0, x_vec1] = convert_to_float<scalar_t>(x_vec);
          auto [dy_vec0, dy_vec1] = convert_to_float<scalar_t>(dy_vec);
          const Vectorized<float> sigmoid0 =
              kOneVec / (kOneVec + x_vec0.neg().exp());
          const Vectorized<float> sigmoid1 =
              kOneVec / (kOneVec + x_vec1.neg().exp());
          return convert_from_float<scalar_t>(
            dy_vec0 * sigmoid0 * (kOneVec + x_vec0 * (kOneVec - sigmoid0)),
            dy_vec1 * sigmoid1 * (kOneVec + x_vec1 * (kOneVec - sigmoid1)));
        });
    });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      iter.dtype(), "silu_backward_cpu", [&]() {
        const Vectorized<scalar_t> kOneVec(scalar_t(1));
        cpu_kernel_vec(
            iter,
            [](scalar_t dy, scalar_t x) {
              const scalar_t sigmoid =
                  scalar_t(1) / (scalar_t(1) + std::exp(-x));
              return dy * sigmoid * (scalar_t(1) + x * (scalar_t(1) - sigmoid));
            },
            [kOneVec](Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) {
              const Vectorized<scalar_t> sigmoid =
                  kOneVec / (kOneVec + x_vec.neg().exp());
              return dy_vec * sigmoid * (kOneVec + x_vec * (kOneVec - sigmoid));
            });
      });
  }
}

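// Mish: out = x * tanh(softplus(x)) = x * tanh(log1p(exp(x))). log1p keeps
// the softplus accurate when exp(x) is tiny; the vector path chains
// exp().log1p().tanh() for the same computation.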
void mish_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "mish_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [](scalar_t x) -> scalar_t {
          return static_cast<scalar_t>(float(x) * std::tanh(std::log1p(std::exp(float(x)))));
        },
        [](Vectorized<scalar_t> x_vec) -> Vectorized<scalar_t> {
          auto [x_vec0, x_vec1] = convert_to_float<scalar_t>(x_vec);
          return convert_from_float<scalar_t>(
            x_vec0 * x_vec0.exp().log1p().tanh(),
            x_vec1 * x_vec1.exp().log1p().tanh()
          );
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_cpu", [&]() {
        using Vec = Vectorized<scalar_t>;
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t {
              return static_cast<scalar_t>(x * std::tanh(std::log1p(std::exp(x))));
            },
            [](Vec x_vec) -> Vec {
              return x_vec * x_vec.exp().log1p().tanh();
            });
      });
  }
}

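// Gradient of Mish. With sp = softplus(x) and t = tanh(sp):
//   d/dx [x * t] = t + x * (1 - t^2) * sigmoid(x),
// since d(sp)/dx = sigmoid(x) and d(tanh u)/du = 1 - tanh(u)^2. Both lambdas
// evaluate exactly this, scaled by the upstream grad dy.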
void mish_backward_kernel(TensorIterator& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "mish_backward_cpu", [&]() {
    using Vec = Vectorized<float>;
    const Vec kOneVec(1.0f);
    cpu_kernel_vec(
        iter,
        [](scalar_t dy, scalar_t x) -> scalar_t {
          const float sigmoid =
              1.0f / (1.0f + std::exp(-float(x)));
          const float tanh_softplus = std::tanh(std::log1p(std::exp(float(x))));
          return dy * (tanh_softplus + x * sigmoid * (1.0f - tanh_softplus * tanh_softplus));
        },
        [kOneVec](Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) -> Vectorized<scalar_t> {
          auto [x_vec0, x_vec1] = convert_to_float<scalar_t>(x_vec);
          auto [dy_vec0, dy_vec1] = convert_to_float<scalar_t>(dy_vec);
          const Vec sigmoid0 = kOneVec / (kOneVec + x_vec0.neg().exp());
          const Vec sigmoid1 = kOneVec / (kOneVec + x_vec1.neg().exp());
          const Vec tanh_softplus0 = x_vec0.exp().log1p().tanh();
          const Vec tanh_softplus1 = x_vec1.exp().log1p().tanh();
          return convert_from_float<scalar_t>(
            dy_vec0 * (tanh_softplus0 + x_vec0 * sigmoid0 * (kOneVec - tanh_softplus0 * tanh_softplus0)),
            dy_vec1 * (tanh_softplus1 + x_vec1 * sigmoid1 * (kOneVec - tanh_softplus1 * tanh_softplus1))
          );
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_backward_cpu", [&]() {
        using Vec = Vectorized<scalar_t>;
        const Vec kOneVec(scalar_t(1));
        cpu_kernel_vec(
            iter,
            [](scalar_t dy, scalar_t x) -> scalar_t {
              const scalar_t sigmoid =
                  scalar_t(1) / (scalar_t(1) + std::exp(-x));
              const scalar_t tanh_softplus = std::tanh(std::log1p(std::exp(x)));
              return dy * (tanh_softplus + x * sigmoid * (scalar_t(1) - tanh_softplus * tanh_softplus));
            },
            [kOneVec](Vec dy_vec, Vec x_vec) -> Vec {
              const Vec sigmoid = kOneVec / (kOneVec + x_vec.neg().exp());
              const Vec tanh_softplus = x_vec.exp().log1p().tanh();
              return dy_vec * (tanh_softplus + x_vec * sigmoid * (kOneVec - tanh_softplus * tanh_softplus));
            });
      });
  }
}

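// PReLU: out = x for x > 0, weight * x otherwise. The vector path selects
// between the two branches with blendv on the (x > 0) mask instead of
// branching per element.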
void prelu_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "prelu_cpu", [&]() {
    using Vec = Vectorized<scalar_t>;
    cpu_kernel_vec(
      iter,
      [](scalar_t input, scalar_t weight) {
        return (input > scalar_t(0)) ? input : weight * input;
      },
      [](Vec input, Vec weight) {
        return Vec::blendv(weight * input, input, input > Vec(0));
      });
  });
}

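// Backward for PReLU, emitted per element: grad_input = grad for input > 0
// and weight * grad otherwise; grad_weight = 0 for input > 0 and
// input * grad otherwise. Any reduction of grad_weight over the broadcast
// dimensions is left to the caller.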
void prelu_backward_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "prelu_backward_cpu", [&]() {
    cpu_kernel_multiple_outputs(iter,
      [](scalar_t input, scalar_t weight, scalar_t grad) -> std::tuple<scalar_t, scalar_t> {
        auto mask = input > scalar_t{0};
        auto grad_input = mask ? grad : weight * grad;
        auto grad_weight = mask ? scalar_t{0} : input * grad;
        return {grad_input, grad_weight};
      });
  });
}

} // namespace

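// Stub registrations for this translation unit. REGISTER_DISPATCH wires a
// kernel into its DispatchStub for the standard CPU capability levels;
// ALSO_REGISTER_AVX512_DISPATCH additionally opts the kernel into an AVX-512
// registration, whereas kernels registered with plain REGISTER_DISPATCH fall
// back to the AVX2 build when AVX-512 is detected.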
REGISTER_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel);
REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel);
REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel);
REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel);
REGISTER_DISPATCH(prelu_stub, &prelu_kernel);
REGISTER_DISPATCH(prelu_backward_stub, &prelu_backward_kernel);
REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel);
REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel);
REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel);
REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel);

ALSO_REGISTER_AVX512_DISPATCH(log_sigmoid_cpu_stub, &log_sigmoid_cpu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_cpu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(glu_stub, &glu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(glu_backward_stub, &glu_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(glu_jvp_stub, &glu_jvp_kernel);
ALSO_REGISTER_AVX512_DISPATCH(elu_stub, &elu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(elu_backward_stub, &elu_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(GeluKernel, &GeluKernelImpl);
ALSO_REGISTER_AVX512_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);
ALSO_REGISTER_AVX512_DISPATCH(hardswish_stub, &hardswish_kernel);
ALSO_REGISTER_AVX512_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(softplus_stub, &softplus_kernel);
ALSO_REGISTER_AVX512_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(silu_stub, &silu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(silu_backward_stub, &silu_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(mish_stub, &mish_kernel);
ALSO_REGISTER_AVX512_DISPATCH(mish_backward_stub, &mish_backward_kernel);

} // namespace at::native