1 #define TORCH_ASSERT_NO_OPERATORS
2 #ifndef _USE_MATH_DEFINES
3 #define _USE_MATH_DEFINES
4 #endif
5
6 #include <ATen/native/Activation.h>
7
8
9 #include <cmath>
10 #include <functional>
11
12 #include <ATen/Dispatch.h>
13 #include <ATen/OpMathType.h>
14 #include <ATen/core/TensorBase.h>
15 #include <ATen/cpu/vec/functional.h>
16 #include <ATen/cpu/vec/vec.h>
17 #include <ATen/native/TensorIterator.h>
18 #include <ATen/native/cpu/Loops.h>
19 #include <ATen/Parallel.h>
20
21 #include <c10/core/Scalar.h>
22
23 namespace at::native {
24
25 namespace {
26
// Elementwise log_sigmoid(x) = min(x, 0) - log1p(exp(-|x|)).
// Also writes the intermediate exp(-|x|) into `buffer` so the backward
// kernel can reuse it without recomputing the exponential.
// Reduced floating types (Half/BFloat16) are widened to float for the math
// and narrowed back on store; other floating types compute natively.
static void log_sigmoid_cpu_kernel(TensorBase &output, TensorBase &buffer, const TensorBase &input) {
  if (at::isReducedFloatingType(input.scalar_type())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "log_sigmoid_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      scalar_t* output_data = output.data_ptr<scalar_t>();
      scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
      const scalar_t* input_data = input.const_data_ptr<scalar_t>();
      parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
        int64_t size = end - begin;
        int64_t d = 0;
        // Main loop over full Vec-width chunks. One reduced-precision vector
        // splits into two float vectors (convert_to_float), each processed
        // with the same min/exp/log1p sequence, then packed back on store.
        for (; d < size - (size % Vec::size()); d += Vec::size()) {
          Vec data_vec = Vec::loadu(input_data + begin + d);
          auto [data_vec0, data_vec1] = convert_to_float<scalar_t>(data_vec);
          Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
          Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();  // exp(-|x|)
          Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
          min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
          Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
          Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
          convert_from_float<scalar_t>(buffer_vec0, buffer_vec1).store(buffer_data + begin + d);
          convert_from_float<scalar_t>(output_vec0, output_vec1).store(output_data + begin + d);
        }
        // Tail: partial-count load/store covers the remaining size - d elements.
        if (size - d > 0) {
          Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
          auto [data_vec0, data_vec1] = convert_to_float<scalar_t>(data_vec);
          Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
          Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
          Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
          min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
          Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
          Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
          convert_from_float<scalar_t>(buffer_vec0, buffer_vec1).store(buffer_data + begin + d, size - d);
          convert_from_float<scalar_t>(output_vec0, output_vec1).store(output_data + begin + d, size - d);
        }
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_sigmoid_cpu", [&] {
      using Vec = Vectorized<scalar_t>;
      scalar_t* output_data = output.data_ptr<scalar_t>();
      scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
      const scalar_t* input_data = input.const_data_ptr<scalar_t>();
      parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
        int64_t size = end - begin;
        int64_t d = 0;
        // Full-width chunks, computed natively in scalar_t.
        for (; d < size - (size % Vec::size()); d += Vec::size()) {
          Vec data_vec = Vec::loadu(input_data + begin + d);
          Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
          Vec buffer_vec = data_vec.abs().neg().exp();  // exp(-|x|)
          Vec output_vec = min_vec - buffer_vec.log1p();
          buffer_vec.store(buffer_data + begin + d);
          output_vec.store(output_data + begin + d);
        }
        // Tail elements.
        if (size - d > 0) {
          Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
          Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
          Vec buffer_vec = data_vec.abs().neg().exp();
          Vec output_vec = min_vec - buffer_vec.log1p();
          buffer_vec.store(buffer_data + begin + d, size - d);
          output_vec.store(output_data + begin + d, size - d);
        }
      });
    });
  }
}
92
// Backward of log_sigmoid. Per element the lambda receives (a, b, c) where:
//   a — forward input x, used only for its sign here;
//   b — the saved buffer value exp(-|x|) from the forward pass;
//   c — the incoming gradient.
// (Operand roles inferred from the math; confirm against the TensorIterator
// construction at the call site.)
// d/dx log_sigmoid(x) is expressed as max_deriv - sign * b / (1 + b), with
// max_deriv = 1 and sign = +1 for x < 0, max_deriv = 0 and sign = -1 otherwise.
static void log_sigmoid_backward_cpu_kernel(TensorIterator& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "log_sigmoid_backward_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto zero_val = float(0);
      auto zero_vec = Vectorized<float>(zero_val);
      auto one_val = float(1);
      auto one_vec = Vectorized<float>(one_val);
      cpu_kernel_vec(iter,
        [=](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
          auto in_negative = float(a) < float(0);
          auto max_deriv = in_negative ? float(1) : float(0);
          auto sign = in_negative ? float(1) : -float(1);
          return (max_deriv - sign * (float(b) / (float(1) + b))) * float(c);
        },
        [=](Vec a, Vec b, Vec c) -> Vec {
          // Widen all three reduced-precision operands to float pairs.
          auto [a0, a1] = convert_to_float<scalar_t>(a);
          auto [b0, b1] = convert_to_float<scalar_t>(b);
          auto [c0, c1] = convert_to_float<scalar_t>(c);
          // blendv selects per-lane: max_deriv = 1 / sign = +1 where a < 0.
          auto mask = a0 < zero_vec;
          auto max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
          auto sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
          a0 = (max_deriv_vec - sign_vec * (b0 / (one_vec + b0))) * c0;
          mask = a1 < zero_vec;
          max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
          sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
          a1 = (max_deriv_vec - sign_vec * (b1 / (one_vec + b1))) * c1;
          return convert_from_float<scalar_t>(a0, a1);
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "log_sigmoid_backward_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto zero_val = scalar_t(0);
      auto zero_vec = Vec(zero_val);
      auto one_val = scalar_t(1);
      auto one_vec = Vec(one_val);
      cpu_kernel_vec(iter,
        [=](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
          auto in_negative = a < scalar_t(0);
          auto max_deriv = in_negative ? scalar_t(1) : scalar_t(0);
          auto sign = in_negative ? scalar_t(1) : -scalar_t(1);
          return (max_deriv - sign * (b / (scalar_t(1) + b))) * c;
        },
        [=](Vec a, Vec b, Vec c) -> Vec {
          auto mask = a < zero_vec;
          auto max_deriv_vec = Vec::blendv(zero_vec, one_vec, mask);
          auto sign_vec = Vec::blendv(one_vec.neg(), one_vec, mask);
          return (max_deriv_vec - sign_vec * (b / (one_vec + b))) * c;
        });
    });
  }
}
146
// threshold: out = (x <= threshold) ? value : other.
// For reduced floating types the comparison is done in float while `value`
// is kept at the original precision; the vector path compares widened lanes
// and blends the broadcast value back in. The non-reduced branch dispatches
// over all types (integral included), not just floating point.
static void threshold_kernel(
    TensorIteratorBase& iter,
    const Scalar& threshold_scalar,
    const Scalar& value_scalar) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "threshold_cpu", [&]() {
      using Vec = Vectorized<float>;
      float threshold = threshold_scalar.to<float>();
      Vec threshold_v = Vec(threshold);
      scalar_t value = value_scalar.to<scalar_t>();
      Vec value_v = Vec(float(value));
      cpu_kernel_vec(
          iter,
          [&](scalar_t x, scalar_t other) -> scalar_t {
            return float(x) <= threshold ? value : other;
          },
          [&](Vectorized<scalar_t> x, Vectorized<scalar_t> other) -> Vectorized<scalar_t> {
            auto [x0, x1] = convert_to_float<scalar_t>(x);
            auto [other0, other1] = convert_to_float<scalar_t>(other);
            // blendv picks value_v in lanes where x <= threshold, other elsewhere.
            return convert_from_float<scalar_t>(Vec::blendv(other0, value_v, x0 <= threshold_v),
                Vec::blendv(other1, value_v, x1 <= threshold_v));
          });
    });
  } else {
    AT_DISPATCH_ALL_TYPES(iter.dtype(), "threshold_cpu", [&] {
      using Vec = Vectorized<scalar_t>;
      scalar_t threshold = threshold_scalar.to<scalar_t>();
      Vec threshold_v = Vec(threshold);
      scalar_t value = value_scalar.to<scalar_t>();
      Vec value_v = Vec(value);
      cpu_kernel_vec(
          iter,
          [&](scalar_t x, scalar_t other) -> scalar_t {
            return x <= threshold ? value : other;
          },
          [&](Vec x, Vec other) -> Vec {
            return Vec::blendv(other, value_v, x <= threshold_v);
          });
    });
  }
}
188
// ELU: out = x > 0 ? x * scale
//              : (exp(x * input_scale) - 1) * alpha * scale.
// negcoef folds alpha*scale, poscoef is scale, negiptcoef is input_scale.
// Reduced floating types compute in float and convert back on store.
void elu_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) {
  if (at::isReducedFloatingType(it.common_dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "elu_cpu", [&]() {
      auto negcoef = alpha.to<float>() * scale.to<float>();
      auto poscoef = scale.to<float>();
      auto negiptcoef = input_scale.to<float>();
      const Vectorized<float> negcoef_vec(negcoef);
      const Vectorized<float> negiptcoef_vec(negiptcoef);
      const Vectorized<float> poscoef_vec(poscoef);
      const Vectorized<float> one_vec(static_cast<float>(1));
      const Vectorized<float> zero_vec(static_cast<float>(0));
      cpu_kernel_vec(
          it,
          [negcoef, negiptcoef, poscoef](scalar_t a) -> scalar_t {
            return float(a) <= float(0) ? (std::exp(float(a) * negiptcoef) - float(1)) * negcoef : float(a) * poscoef;
          },
          [&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &one_vec, &zero_vec](Vectorized<scalar_t> a) -> Vectorized<scalar_t> {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            auto cmp0 = (a0 > zero_vec);
            auto cmp1 = (a1 > zero_vec);
            // Fast path: cmp.zero_mask() == 0 means every lane is positive, so
            // only the cheap a * poscoef is needed; otherwise blend per lane.
            auto get_res_masked = [&](Vectorized<float>& cmp, Vectorized<float>& a) {
              return !cmp.zero_mask() ? a * poscoef_vec :
                  Vectorized<float>::blendv(((a * negiptcoef_vec).exp() - one_vec) * negcoef_vec, a * poscoef_vec, cmp);
            };
            auto res0 = get_res_masked(cmp0, a0);
            auto res1 = get_res_masked(cmp1, a1);
            return convert_from_float<scalar_t>(res0, res1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(it.common_dtype(), "elu_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
      auto poscoef = scale.to<scalar_t>();
      auto negiptcoef = input_scale.to<scalar_t>();
      const Vec negcoef_vec(negcoef);
      const Vec negiptcoef_vec(negiptcoef);
      const Vec poscoef_vec(poscoef);
      const Vec one_vec(static_cast<scalar_t>(1));
      const Vec zero_vec(static_cast<scalar_t>(0));
      cpu_kernel_vec(
          it,
          [negcoef, negiptcoef, poscoef](scalar_t a) -> scalar_t {
            return a <= scalar_t(0) ? (std::exp(a * negiptcoef) - scalar_t(1)) * negcoef : a * poscoef;
          },
          [&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &one_vec, &zero_vec](Vec a) -> Vec {
            auto cmp = (a > zero_vec);
            if (!cmp.zero_mask()) { // only a * poscoef (which is very quick) needs to be computed
              return a * poscoef_vec;
            } else {
              return Vec::blendv(((a * negiptcoef_vec).exp() - one_vec) * negcoef_vec, a * poscoef_vec, cmp);
            }
          });
    });
  }
}
245
// Backward of ELU. Per element: a = incoming gradient, b = either the forward
// output (is_result == true) or the forward input (is_result == false).
// With y = elu(x): for b <= 0,
//   is_result:  grad * input_scale * (y + alpha*scale)   (reuses the output)
//   otherwise:  grad * input_scale * alpha*scale * exp(x * input_scale)
// and grad * scale for the positive side in both cases.
void elu_backward_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, bool is_result) {
  if (at::isReducedFloatingType(it.common_dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "elu_backward_cpu", [&]() {
      auto negcoef = alpha.to<float>() * scale.to<float>();
      auto poscoef = scale.to<float>();
      auto negiptcoef = input_scale.to<float>();
      const Vectorized<float> negcoef_vec(negcoef);
      const Vectorized<float> negiptcoef_vec(negiptcoef);
      const Vectorized<float> poscoef_vec(poscoef);
      const Vectorized<float> zero_vec(static_cast<float>(0));
      cpu_kernel_vec(
          it,
          [negcoef, negiptcoef, poscoef, is_result](scalar_t a, scalar_t b) -> scalar_t {
            if (is_result) {
              return float(b) <= float(0) ? float(a) * negiptcoef * (float(b) + negcoef) : float(a) * poscoef;
            } else {
              return float(b) <= float(0) ? float(a) * negiptcoef * negcoef * std::exp(float(b) * negiptcoef): float(a) * poscoef;
            }
          },
          [&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &zero_vec, is_result](Vectorized<scalar_t> a, Vectorized<scalar_t> b) -> Vectorized<scalar_t> {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            auto [b0, b1] = convert_to_float<scalar_t>(b);
            auto cmp0 = (b0 > zero_vec);
            auto cmp1 = (b1 > zero_vec);
            // Only the is_result branch has the all-positive fast path
            // (zero_mask() == 0 means no lane needs the negative formula).
            auto get_res_masked = [&](Vectorized<float>& cmp, Vectorized<float>& a, Vectorized<float>& b) {
              if (is_result) {
                return !cmp.zero_mask() ? a * poscoef_vec :
                    Vectorized<float>::blendv(a * negiptcoef_vec * (b + negcoef_vec), a * poscoef_vec, cmp);
              } else {
                return Vectorized<float>::blendv(a * negiptcoef_vec * negcoef_vec * (b * negiptcoef_vec).exp(), a * poscoef_vec, cmp);
              }
            };
            auto res0 = get_res_masked(cmp0, a0, b0);
            auto res1 = get_res_masked(cmp1, a1, b1);
            return convert_from_float<scalar_t>(res0, res1);
          });
    });
  } else {
    // NOTE(review): this branch dispatches on it.dtype() while the reduced
    // branch above uses it.common_dtype() — looks inconsistent; confirm
    // whether dtype and common_dtype can ever differ for this iterator.
    AT_DISPATCH_FLOATING_TYPES(it.dtype(), "elu_backward_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
      auto poscoef = scale.to<scalar_t>();
      auto negiptcoef = input_scale.to<scalar_t>();
      const Vec negcoef_vec(negcoef);
      const Vec negiptcoef_vec(negiptcoef);
      const Vec poscoef_vec(poscoef);
      const Vec zero_vec(static_cast<scalar_t>(0));
      cpu_kernel_vec(
          it,
          [negcoef, negiptcoef, poscoef, is_result](scalar_t a, scalar_t b) -> scalar_t {
            if (is_result) {
              return b <= scalar_t(0) ? a * negiptcoef * (b + negcoef) : a * poscoef;
            } else {
              return b <= scalar_t(0) ? a * negiptcoef * negcoef * std::exp(b * negiptcoef): a * poscoef;
            }
          },
          [&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &zero_vec, is_result](Vec a, Vec b) -> Vec {
            auto cmp = (b > zero_vec);
            if (is_result) {
              if (!cmp.zero_mask()) { // only a * poscoef (which is very quick) needs to be computed
                return a * poscoef_vec;
              } else {
                return Vec::blendv(a * negiptcoef_vec * (b + negcoef_vec), a * poscoef_vec, cmp);
              }
            } else {
              return Vec::blendv(a * negiptcoef_vec * negcoef_vec * (b * negiptcoef_vec).exp(), a * poscoef_vec, cmp);
            }
          }
      );
    });
  }
}
318
319 // TODO(yangxm): Add another fast kernel using formula
320 // y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
321 // and the fast tanh impl from Eigen.
// GELU forward.
//   Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
//   Exact (erf) form:   0.5 * x * (1 + erf(x / sqrt(2)))
// kBeta = sqrt(2/pi) is spelled M_SQRT2 * M_2_SQRTPI * 0.5; kAlpha = 1/sqrt(2).
// The grain size is tuned so small tensors stay single-threaded (thresholds
// below were benchmarked per platform).
void GeluKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  auto grain_size = at::internal::GRAIN_SIZE;
  // Numbers based on benchmarking.
  // Benchmark: benchmarks/operator_benchmarks/pt/gelu_test.py
#ifdef C10_MOBILE
  // Benchmarked on S8 US phone.
  // Internal benchmarking that converts operator benchmark into
  // a torchscript module and run that on mobile.
  // Same benchmark as server side.
  constexpr int64_t GELU_MIN_ELEMENTS_FOR_MULTI_THREADING{6144};
#else
  // Benchmarked on i9 8 core 16 thread machine.
  // 1 thread: cd benchmark/operator_benchmarks;
  //           python -m pt.gelu_test --tag_filter long --omp_num_threads 1
  // 2 threads: cd benchmark/operator_benchmarks;
  //           python -m pt.gelu_test --tag_filter long --omp_num_threads 1
  constexpr int64_t GELU_MIN_ELEMENTS_FOR_MULTI_THREADING{16384};
#endif
  if (it.numel() > GELU_MIN_ELEMENTS_FOR_MULTI_THREADING) {
    // Above the threshold, split evenly across the available threads.
    grain_size = it.numel() / at::get_num_threads();
  }
  if (approximate == GeluType::Tanh) {
    if (at::isReducedFloatingType(it.common_dtype())) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "GeluKernelImpl", [&]() {
        auto kBetaVec = Vectorized<float>((float)(M_SQRT2 * M_2_SQRTPI * 0.5));
        auto kKappaVec = Vectorized<float>((float)(0.044715));
        auto kOneVec = Vectorized<float>((float)(1));
        auto kPointFiveVec = Vectorized<float>((float)(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t x) -> scalar_t {
              const float kBeta = float(M_SQRT2 * M_2_SQRTPI * 0.5);
              const float kKappa = float(0.044715);
              float x_cube = float(x) * float(x) * float(x);
              float inner = kBeta * (float(x) + kKappa * x_cube);
              return float(0.5) * float(x) * (float(1) + std::tanh(inner));
            },
            [&](Vectorized<scalar_t> x) -> Vectorized<scalar_t> {
              auto [x0, x1] = convert_to_float<scalar_t>(x);
              auto x0_cube = x0 * x0 * x0;
              auto x1_cube = x1 * x1 * x1;
              auto inner_vec0 = kBetaVec * (x0 + kKappaVec * x0_cube);
              auto inner_vec1 = kBetaVec * (x1 + kKappaVec * x1_cube);
              auto res0 = kPointFiveVec * x0 * (kOneVec + inner_vec0.tanh());
              auto res1 = kPointFiveVec * x1 * (kOneVec + inner_vec1.tanh());
              return convert_from_float<scalar_t>(res0, res1);
            },
            grain_size);
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          it.dtype(), "GeluKernelImpl", [&]() {
        using Vec = vec::Vectorized<scalar_t>;
        const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
        const Vec kKappaVec(scalar_t(0.044715));
        const Vec kOneVec(scalar_t(1));
        const Vec kPointFiveVec(scalar_t(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t x) {
              const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
              const scalar_t kKappa = 0.044715;
              auto x_cube = x * x * x;
              auto inner = kBeta * (x + kKappa * x_cube);
              return scalar_t(0.5) * x * (scalar_t(1) + std::tanh(inner));
            },
            [&](Vec x_vec) {
              auto x_cube = x_vec * x_vec * x_vec;
              auto inner_vec = kBetaVec * (x_vec + kKappaVec * x_cube);
              return kPointFiveVec * x_vec * (kOneVec + inner_vec.tanh());
            },
            grain_size);
      });
    }
  } else {
    if (at::isReducedFloatingType(it.common_dtype())) {
      // NOTE(review): branch condition checks it.common_dtype() but dispatch
      // uses it.dtype() — verify these cannot disagree for this iterator.
      AT_DISPATCH_REDUCED_FLOATING_TYPES(it.dtype(), "GeluKernelImpl", [&]() {
        auto kAlphaVec = Vectorized<float>((float)(M_SQRT1_2));
        auto kOneVec = Vectorized<float>((float)(1));
        auto kPointFiveVec = Vectorized<float>((float)(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t x) -> scalar_t {
              const float kAlpha = float(M_SQRT1_2);
              return float(x) * float(0.5) * (float(1) + std::erf(float(x) * kAlpha));
            },
            [&](Vectorized<scalar_t> x) -> Vectorized<scalar_t> {
              auto [x0, x1] = convert_to_float<scalar_t>(x);
              auto res0 = x0 * kPointFiveVec * (kOneVec + (x0 * kAlphaVec).erf());
              auto res1 = x1 * kPointFiveVec * (kOneVec + (x1 * kAlphaVec).erf());
              return convert_from_float<scalar_t>(res0, res1);
            },
            grain_size);
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          it.dtype(), "GeluKernelImpl", [&]() {
        using Vec = vec::Vectorized<scalar_t>;
        const Vec kAlphaVec(scalar_t(M_SQRT1_2));
        const Vec kOneVec(scalar_t(1));
        const Vec kPointFiveVec(scalar_t(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t x) {
              const scalar_t kAlpha = scalar_t(M_SQRT1_2);
              return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
            },
            [&](Vec x_vec) {
              return x_vec * kPointFiveVec *
                  (kOneVec + (x_vec * kAlphaVec).erf());
            },
            grain_size);
      });
    }
  }
}
438
// GELU backward. Per element: dy = incoming gradient, x = forward input.
// Tanh branch differentiates 0.5*x*(1 + tanh(inner)) with
//   inner = kBeta * (x + kKappa * x^3), kBeta = sqrt(2/pi), kKappa = 0.044715:
//   d/dx = 0.5*(1 + tanh(inner))                       (product rule, left)
//        + 0.5*x * (1 - tanh^2(inner)) * inner'        (right),
//   inner' = kBeta * (1 + 3*kKappa*x^2).
// Erf branch uses d/dx [x * cdf(x)] = cdf(x) + x * pdf(x) with
//   cdf = 0.5*(1 + erf(x/sqrt(2))), pdf = exp(-x^2/2)/sqrt(2*pi)
//   (kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5 = 1/sqrt(2*pi)).
void GeluBackwardKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  if (approximate == GeluType::Tanh) {
    if (at::isReducedFloatingType(it.common_dtype())) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "GeluBackwardKernelImpl", [&]() {
        auto kBetaVec = Vectorized<float>((float)(M_SQRT2 * M_2_SQRTPI * 0.5));
        auto kKappaVec = Vectorized<float>((float)(0.044715));
        auto kOneVec = Vectorized<float>((float)(1));
        auto kThreeVec = Vectorized<float>((float)(3));
        auto kPointFiveVec = Vectorized<float>((float)(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t dy, scalar_t x) -> scalar_t {
              const float kBeta = float(M_SQRT2 * M_2_SQRTPI * 0.5);
              const float kKappa = float(0.044715);
              float x_sq = float(x) * float(x);
              float x_cube = x_sq * float(x);
              float inner = kBeta * (float(x) + kKappa * x_cube);
              float tanh_inner = float(std::tanh(inner));

              float left = float(0.5) * float(x);
              float right = float(1) + tanh_inner;

              float left_derivative = float(0.5) * right;

              float tanh_derivative = float(1) - tanh_inner * tanh_inner;
              float inner_derivative =
                  kBeta * (float(1) + float(3) * kKappa * x_sq);
              float right_derivative = left * tanh_derivative * inner_derivative;

              return float(dy) * (left_derivative + right_derivative);
            },
            [&](Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) -> Vectorized<scalar_t> {
              // Same derivative computed lane-wise on the two float halves.
              auto [x0_vec, x1_vec] = convert_to_float<scalar_t>(x_vec);
              auto [dy0_vec, dy1_vec] = convert_to_float<scalar_t>(dy_vec);
              auto x0_sq = x0_vec * x0_vec;
              auto x1_sq = x1_vec * x1_vec;
              auto x0_cube = x0_vec * x0_vec * x0_vec;
              auto x1_cube = x1_vec * x1_vec * x1_vec;
              auto inner_vec0 = kBetaVec * (x0_vec + kKappaVec * x0_cube);
              auto inner_vec1 = kBetaVec * (x1_vec + kKappaVec * x1_cube);
              auto tanh_inner_vec0 = inner_vec0.tanh();
              auto tanh_inner_vec1 = inner_vec1.tanh();

              auto left_vec0 = kPointFiveVec * x0_vec;
              auto left_vec1 = kPointFiveVec * x1_vec;
              auto right_vec0 = kOneVec + tanh_inner_vec0;
              auto right_vec1 = kOneVec + tanh_inner_vec1;

              auto left_derivative_vec0 = kPointFiveVec * right_vec0;
              auto left_derivative_vec1 = kPointFiveVec * right_vec1;

              auto tanh_derivative_vec0 = kOneVec - tanh_inner_vec0 * tanh_inner_vec0;
              auto tanh_derivative_vec1 = kOneVec - tanh_inner_vec1 * tanh_inner_vec1;
              auto inner_derivative_vec0 = kBetaVec * (kOneVec + kThreeVec * kKappaVec * x0_sq);
              auto inner_derivative_vec1 = kBetaVec * (kOneVec + kThreeVec * kKappaVec * x1_sq);
              auto right_derivative_vec0 = left_vec0 * tanh_derivative_vec0 * inner_derivative_vec0;
              auto right_derivative_vec1 = left_vec1 * tanh_derivative_vec1 * inner_derivative_vec1;

              auto res0 = dy0_vec * (left_derivative_vec0 + right_derivative_vec0);
              auto res1 = dy1_vec * (left_derivative_vec1 + right_derivative_vec1);
              return convert_from_float<scalar_t>(res0, res1);
            });
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          it.dtype(), "GeluBackwardKernelImpl", [&]() {
        using Vec = vec::Vectorized<scalar_t>;
        const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
        const Vec kKappaVec(scalar_t(0.044715));
        const Vec kOneVec(scalar_t(1));
        const Vec kThreeVec(scalar_t(3));
        const Vec kPointFiveVec(scalar_t(0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t dy, scalar_t x) {
              const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
              const scalar_t kKappa = 0.044715;
              auto x_sq = x * x;
              auto x_cube = x_sq * x;
              auto inner = kBeta * (x + kKappa * x_cube);
              auto tanh_inner = std::tanh(inner);

              auto left = scalar_t(0.5) * x;
              auto right = scalar_t(1) + tanh_inner;

              auto left_derivative = scalar_t(0.5) * right;

              auto tanh_derivative = scalar_t(1) - tanh_inner * tanh_inner;
              auto inner_derivative =
                  kBeta * (scalar_t(1) + scalar_t(3) * kKappa * x_sq);
              auto right_derivative = left * tanh_derivative * inner_derivative;

              return dy * (left_derivative + right_derivative);
            },
            [&](Vec dy_vec, Vec x_vec) {
              auto x_sq = x_vec * x_vec;
              auto x_cube = x_vec * x_vec * x_vec;
              auto inner_vec =
                  kBetaVec * (x_vec + kKappaVec * x_cube);
              auto tanh_inner_vec = inner_vec.tanh();

              auto left_vec = kPointFiveVec * x_vec;
              auto right_vec = kOneVec + tanh_inner_vec;

              auto left_derivative_vec = kPointFiveVec * right_vec;

              auto tanh_derivative_vec =
                  kOneVec - tanh_inner_vec * tanh_inner_vec;
              auto inner_derivative_vec =
                  kBetaVec * (kOneVec + kThreeVec * kKappaVec * x_sq);
              auto right_derivative_vec =
                  left_vec * tanh_derivative_vec * inner_derivative_vec;

              return dy_vec * (left_derivative_vec + right_derivative_vec);
            });
      });
    }
  } else {
    if (at::isReducedFloatingType(it.common_dtype())) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "GeluBackwardKernelImpl", [&]() {
        auto kAlphaVec = Vectorized<float>((float)(M_SQRT1_2));
        auto kBetaVec = Vectorized<float>((float)(M_2_SQRTPI * M_SQRT1_2 * 0.5));
        auto kOneVec = Vectorized<float>((float)(1));
        auto kPointFiveVec = Vectorized<float>((float)(0.5));
        auto kMinusPointFiveVec = Vectorized<float>((float)(-0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t dy, scalar_t x) -> scalar_t {
              const float kAlpha = float(M_SQRT1_2);
              const float kBeta = float(M_2_SQRTPI) * float(M_SQRT1_2) * float(0.5);
              const float cdf =
                  float(0.5) * (float(1) + std::erf(float(x) * kAlpha));
              const float pdf = kBeta * std::exp(float(x) * float(x) * float(-0.5));
              return float(dy) * (cdf + float(x) * pdf);
            },
            [&](Vectorized<scalar_t> dy, Vectorized<scalar_t> x) -> Vectorized<scalar_t> {
              auto [x0, x1] = convert_to_float<scalar_t>(x);
              auto [dy0, dy1] = convert_to_float<scalar_t>(dy);
              auto cdf_vec0 = kPointFiveVec * (kOneVec + (x0 * kAlphaVec).erf());
              auto cdf_vec1 = kPointFiveVec * (kOneVec + (x1 * kAlphaVec).erf());
              auto pdf_vec0 = kBetaVec * (x0 * x0 * kMinusPointFiveVec).exp();
              auto pdf_vec1 = kBetaVec * (x1 * x1 * kMinusPointFiveVec).exp();
              auto res0 = dy0 * (cdf_vec0 + x0 * pdf_vec0);
              auto res1 = dy1 * (cdf_vec1 + x1 * pdf_vec1);
              return convert_from_float<scalar_t>(res0, res1);
            });
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          it.dtype(), "GeluBackwardKernelImpl", [&]() {
        using Vec = vec::Vectorized<scalar_t>;
        const Vec kAlphaVec(scalar_t(M_SQRT1_2));
        const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
        const Vec kOneVec(scalar_t(1));
        const Vec kPointFiveVec(scalar_t(0.5));
        const Vec kMinusPointFiveVec(scalar_t(-0.5));
        cpu_kernel_vec(
            it,
            [](scalar_t dy, scalar_t x) {
              const scalar_t kAlpha = scalar_t(M_SQRT1_2);
              const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
              const scalar_t cdf =
                  scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
              const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
              return dy * (cdf + x * pdf);
            },
            [&](Vec dy_vec, Vec x_vec) {
              const Vec cdf_vec =
                  kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
              const Vec pdf_vec =
                  kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
              return dy_vec * (cdf_vec + x_vec * pdf_vec);
            });
      });
    }
  }
}
616
// hardsigmoid: out = clamp(x + 3, 0, 6) / 6, i.e. a piecewise-linear
// approximation of sigmoid. Reduced floating types compute in float.
void hardsigmoid_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "hardsigmoid_cpu", [&]() {
      const float zero(0.0f);
      const float three(3.0f);
      const float six(6.0f);
      using Vec = vec::Vectorized<float>;
      const Vec kZeroVec(zero);
      const Vec kThreeVec(three);
      const Vec kSixVec(six);
      cpu_kernel_vec(
          iter,
          [&](scalar_t self_val) -> scalar_t {
            return std::min(std::max(float(self_val) + three, zero), six) / six;
          },
          [&](vec::Vectorized<scalar_t> self_val) -> vec::Vectorized<scalar_t> {
            auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
            // min(max(x + 3, 0), 6) / 6 on each widened half.
            self_val0 = minimum(
                maximum(self_val0 + kThreeVec, kZeroVec),
                kSixVec
            ) / kSixVec;
            self_val1 = minimum(
                maximum(self_val1 + kThreeVec, kZeroVec),
                kSixVec
            ) / kSixVec;
            return convert_from_float<scalar_t>(self_val0, self_val1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardsigmoid_cpu", [&] {
      const scalar_t zero(0.0f);
      const scalar_t three(3.0f);
      const scalar_t six(6.0f);
      using Vec = vec::Vectorized<scalar_t>;
      const Vec kZeroVec(zero);
      const Vec kThreeVec(three);
      const Vec kSixVec(six);
      cpu_kernel_vec(
          iter,
          [&](scalar_t self_val) {
            return std::min(std::max(self_val + three, zero), six) / six;
          },
          [&](Vec self_val) {
            return vec::minimum(
                vec::maximum(self_val + kThreeVec, kZeroVec),
                kSixVec
            ) / kSixVec;
          });
    });
  }
}
668
// Backward of hardsigmoid: gradient is grad * 1/6 where the input lies
// strictly inside (-3, 3), and 0 elsewhere (the clamped regions).
// NOTE(review): the branch condition checks iter.dtype() while the reduced
// dispatch below uses iter.common_dtype() — looks inconsistent; confirm the
// two cannot differ for this iterator.
void hardsigmoid_backward_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.common_dtype(), "hardsigmoid_backward", [&]() {
      const float zero(0.0f);
      const float three(3.0f);
      const float neg_three(-3.0f);
      const float one_sixth(1.0f / 6.0f);
      using Vec = Vectorized<float>;
      Vec kZeroVec(0.0f);
      Vec kOneSixthVec(1.0f / 6.0f);
      cpu_kernel_vec(
          iter,
          [=](scalar_t grad_val, scalar_t self_val) -> scalar_t {
            return (float(self_val) > neg_three && float(self_val) < three)
                ? float(grad_val) * one_sixth
                : zero;
          },
          [=](Vectorized<scalar_t> grad_val, Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
            auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
            auto [grad_val0, grad_val1] = convert_to_float<scalar_t>(grad_val);
            // Lanes with -3 < self < 3 keep grad/6; all others become zero.
            Vec gradNonZeroMask = (self_val0 > neg_three) & (self_val0 < three);
            self_val0 = Vec::blendv(kZeroVec, grad_val0 * kOneSixthVec, gradNonZeroMask);
            gradNonZeroMask = (self_val1 > neg_three) & (self_val1 < three);
            self_val1 = Vec::blendv(kZeroVec, grad_val1 * kOneSixthVec, gradNonZeroMask);
            return convert_from_float<scalar_t>(self_val0, self_val1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardsigmoid_backward", [&] {
      const scalar_t zero(0.0f);
      const scalar_t three(3.0f);
      const scalar_t neg_three(-3.0f);
      const scalar_t one_sixth(1.0f / 6.0f);
      using Vec = Vectorized<scalar_t>;
      Vec kZeroVec(0.0f);
      Vec kOneSixthVec(1.0f / 6.0f);
      cpu_kernel_vec(
          iter,
          [=](scalar_t grad_val, scalar_t self_val) {
            return (self_val > neg_three && self_val < three)
                ? grad_val * one_sixth
                : zero;
          },
          [=](Vec grad_val, Vec self_val) {
            Vec gradNonZeroMask = (self_val > neg_three) & (self_val < three);
            return Vec::blendv(kZeroVec, grad_val * kOneSixthVec, gradNonZeroMask);
          });
    });
  }
}
719
hardshrink_kernel(TensorIteratorBase & iter,const Scalar & lambd)720 void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
721 AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "hardshrink_cpu", [&] {
722 auto lambd_val = lambd.to<scalar_t>();
723 using Vec = Vectorized<scalar_t>;
724 cpu_kernel_vec(
725 iter,
726 [=](scalar_t self_val) {
727 return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
728 : self_val;
729 },
730 [=](Vec self_val) {
731 return Vec::blendv(self_val, Vec(0), (self_val >= -lambd_val) & (self_val <= lambd_val));
732 });
733 });
734 }
735
// softshrink: out = x - lambda for x > lambda, x + lambda for x < -lambda,
// and 0 inside the band. The vector paths use the branchless trick
// (mask & value) | (mask' & value'): an all-ones comparison mask ANDed with a
// value keeps it, an all-zeros mask yields 0, and the two disjoint cases are
// OR-combined.
void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.common_dtype(), "softshrink_cpu", [&]() {
      auto lambd_val = lambd.to<float>();
      auto lambdVec = Vectorized<float>(lambd_val);
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t {
            // Arithmetic is performed in float and narrowed on return.
            return float(a) > lambd_val ? a - lambd_val : (float(a) < -lambd_val ? a + lambd_val : float(0));
          },
          [=](Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
            auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
            auto self_val_t0 = convert_from_float<scalar_t>((self_val0 > lambdVec) & (self_val0 - lambdVec), (self_val1 > lambdVec) & (self_val1 - lambdVec));
            auto self_val_t1 = convert_from_float<scalar_t>((self_val0 < -lambd_val) & (self_val0 + lambdVec), (self_val1 < -lambd_val) & (self_val1 + lambdVec));
            return (self_val_t0 | self_val_t1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "softshrink_cpu", [&]() {
      auto lambd_val = lambd.to<scalar_t>();
      auto lambdVec = Vectorized<scalar_t>(lambd_val);
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t {
            return a > lambd_val ? a - lambd_val : (a < -lambd_val ? a + lambd_val : scalar_t(0));
          },
          [=](Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
            Vectorized<scalar_t> self_val_t0, self_val_t1;
            self_val_t0 = (self_val > lambdVec) & (self_val - lambdVec);
            self_val_t1 = (self_val < -lambd_val) & (self_val + lambdVec);
            return (self_val_t0 | self_val_t1);
          });
    });
  }
}
771
shrink_backward_kernel(TensorIteratorBase & iter,const Scalar & lambd)772 void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
773 AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "shrink_backward_cpu", [&] {
774 auto lambd_val = lambd.to<scalar_t>();
775 cpu_kernel_vec(
776 iter,
777 [=](scalar_t grad_val, scalar_t self_val) {
778 return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
779 : grad_val;
780 },
781 [=](Vectorized<scalar_t> grad_val, Vectorized<scalar_t> self_val) {
782 return ((self_val < -lambd_val) | (self_val > lambd_val)) & grad_val;
783 });
784 });
785 }
786
hardtanh_backward_kernel(TensorIterator & iter,const Scalar & min,const Scalar & max)787 void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Scalar& max) {
788 if (at::isReducedFloatingType(iter.dtype())) {
789 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&]() {
790 auto min_val = min.to<float>();
791 auto max_val = max.to<float>();
792 cpu_kernel_vec(
793 iter,
794 [=](scalar_t grad_val, scalar_t self_val) -> scalar_t {
795 return (float(self_val) <= min_val || float(self_val) >= max_val) ? scalar_t(0) : grad_val;
796 },
797 [=](Vectorized<scalar_t> grad_val, Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
798 auto [grad_val0, grad_val1] = convert_to_float<scalar_t>(grad_val);
799 auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
800 return convert_from_float<scalar_t>(
801 ((self_val0 > min_val) & (self_val0 < max_val)) & grad_val0,
802 ((self_val1 > min_val) & (self_val1 < max_val)) & grad_val1
803 );
804 });
805 });
806 } else {
807 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] {
808 auto min_val = min.to<scalar_t>();
809 auto max_val = max.to<scalar_t>();
810 cpu_kernel_vec(
811 iter,
812 [=](scalar_t grad_val, scalar_t self_val) {
813 return (self_val <= min_val || self_val >= max_val) ? scalar_t(0) : grad_val;
814 },
815 [=](Vectorized<scalar_t> grad_val, Vectorized<scalar_t> self_val) {
816 return ((self_val > min_val) & (self_val < max_val)) & grad_val;
817 });
818 });
819 }
820 }
821
hardswish_kernel(TensorIterator & iter)822 void hardswish_kernel(TensorIterator& iter) {
823 if (at::isReducedFloatingType(iter.dtype())) {
824 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "hardswish_cpu", [&]() {
825 const float zero(0.0f);
826 const float three(3.0f);
827 const float six(6.0f);
828 using Vec = vec::Vectorized<float>;
829 const Vec kZeroVec(zero);
830 const Vec kThreeVec(three);
831 const Vec kSixVec(six);
832 cpu_kernel_vec(
833 iter,
834 [&](scalar_t x) -> scalar_t {
835 return float(x) * std::min(std::max(float(x) + three, zero), six) / six;
836 },
837 [&](vec::Vectorized<scalar_t> x_vec) {
838 auto [x_vec0, x_vec1] = convert_to_float<scalar_t>(x_vec);
839 x_vec0 = x_vec0 * minimum(
840 maximum(x_vec0 + kThreeVec, kZeroVec),
841 kSixVec
842 ) / kSixVec;
843 x_vec1 = x_vec1 * minimum(
844 maximum(x_vec1 + kThreeVec, kZeroVec),
845 kSixVec
846 ) / kSixVec;
847 return convert_from_float<scalar_t>(x_vec0, x_vec1);
848 });
849 });
850 } else {
851 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardswish_cpu", [&]() {
852 const scalar_t zero(0.0f);
853 const scalar_t three(3.0f);
854 const scalar_t six(6.0f);
855 using Vec = vec::Vectorized<scalar_t>;
856 const Vec kZeroVec(zero);
857 const Vec kThreeVec(three);
858 const Vec kSixVec(six);
859 cpu_kernel_vec(
860 iter,
861 [&](scalar_t x) {
862 return x * std::min(std::max(x + three, zero), six) / six;
863 },
864 [&](Vec x_vec) {
865 return x_vec * vec::minimum(
866 vec::maximum(x_vec + kThreeVec, kZeroVec),
867 kSixVec
868 ) / kSixVec;
869 }
870 );
871 });
872 }
873 }
874
void hardswish_backward_kernel(TensorIterator& iter) {
  // Gradient of hardswish(x) = x * relu6(x + 3) / 6:
  //   dx = 0                        if x < -3
  //   dx = grad * (x / 3 + 0.5)     if -3 <= x <= 3
  //   dx = grad                     if x > 3
  if (at::isReducedFloatingType(iter.dtype())) {
    // BFloat16/Half path: widen both lanes to float, compute, narrow on store.
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "hardswish_backward_cpu", [&]() {
      const float zero(0.0f);
      const float three(3.0f);
      const float neg_three(-3.0f);
      const float one_half(0.5f);
      using Vec = vec::Vectorized<float>;
      const Vec kZeroVec(zero);
      const Vec kThreeVec(three);
      const Vec kNegThreeVec(neg_three);
      const Vec kOneHalfVec(one_half);
      cpu_kernel_vec(
        iter,
        [&](scalar_t grad_val, scalar_t self_val) -> scalar_t {
          if (float(self_val) < neg_three) {
            return zero;
          } else if (float(self_val) <= three) {
            // Linear-ramp region of the derivative.
            return float(grad_val) * ((float(self_val) / three) + one_half);
          } else {
            return grad_val;
          }
        },
        [&](vec::Vectorized<scalar_t> grad_val, vec::Vectorized<scalar_t> self_val) {
          auto [self_val0, self_val1] = convert_to_float<scalar_t>(self_val);
          auto [grad_val0, grad_val1] = convert_to_float<scalar_t>(grad_val);
          // Inner blendv picks grad for x >= 3, else the ramp value;
          // outer blendv overrides with 0 for x < -3.
          self_val0 = Vec::blendv(
            Vec::blendv(
              grad_val0 * ((self_val0 / kThreeVec) + kOneHalfVec),
              grad_val0,
              self_val0 >= kThreeVec
            ),
            kZeroVec,
            self_val0 < kNegThreeVec
          );
          self_val1 = Vec::blendv(
            Vec::blendv(
              grad_val1 * ((self_val1 / kThreeVec) + kOneHalfVec),
              grad_val1,
              self_val1 >= kThreeVec
            ),
            kZeroVec,
            self_val1 < kNegThreeVec
          );
          return convert_from_float<scalar_t>(self_val0, self_val1);
        });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardswish_backward_cpu", [&]() {
      const scalar_t zero(0.0f);
      const scalar_t three(3.0f);
      const scalar_t neg_three(-3.0f);
      const scalar_t one_half(0.5f);
      using Vec = vec::Vectorized<scalar_t>;
      const Vec kZeroVec(zero);
      const Vec kThreeVec(three);
      const Vec kNegThreeVec(neg_three);
      const Vec kOneHalfVec(one_half);
      cpu_kernel_vec(
        iter,
        [&](scalar_t grad_val, scalar_t self_val) {
          if (self_val < neg_three) {
            return zero;
          } else if (self_val <= three) {
            return grad_val * ((self_val / three) + one_half);
          } else {
            return grad_val;
          }
        },
        [&](Vec grad_val, Vec self_val) {
          // Same nested-blendv selection as the reduced-precision path.
          return Vec::blendv(
            Vec::blendv(
              grad_val * ((self_val / kThreeVec) + kOneHalfVec),
              grad_val,
              self_val >= kThreeVec
            ),
            kZeroVec,
            self_val < kNegThreeVec
          );
        }
      );
    });
  }
}
959
leaky_relu_kernel(TensorIteratorBase & iter,const Scalar & negval_)960 static void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) {
961 if (at::isReducedFloatingType(iter.dtype())) {
962 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "leaky_relu_cpu", [&]() {
963 auto zero_vec = Vectorized<float>((float)(0));
964 auto one_vec = Vectorized<float>((float)(1));
965 float negval = negval_.to<float>();
966 Vectorized<float> negval_v = Vectorized<float>(negval);
967 cpu_kernel_vec(
968 iter,
969 [&](scalar_t a) -> scalar_t {
970 return float(a) > float(0) ? float(a) : float(a) * negval;
971 },
972 [&](Vectorized<scalar_t> a) -> Vectorized<scalar_t> {
973 auto [a0, a1] = convert_to_float<scalar_t>(a);
974 auto res0 = a0 * (Vectorized<float>::blendv(negval_v, one_vec, a0 > zero_vec));
975 auto res1 = a1 * (Vectorized<float>::blendv(negval_v, one_vec, a1 > zero_vec));
976 return convert_from_float<scalar_t>(res0, res1);
977 });
978 });
979 } else {
980 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "leaky_relu_cpu", [&] {
981 using Vec = Vectorized<scalar_t>;
982 auto zero_vec = Vec((scalar_t)(0));
983 auto one_vec = Vec((scalar_t)(1));
984 scalar_t negval = negval_.to<scalar_t>();
985 Vec negval_v = Vec(negval);
986 cpu_kernel_vec(
987 iter,
988 [&](scalar_t a) -> scalar_t {
989 return a > scalar_t(0) ? a : a * negval;
990 },
991 [&](Vec a) -> Vec {
992 auto r = Vec::blendv(negval_v, one_vec, a > zero_vec);
993 return a * r;
994 });
995 });
996 }
997 }
998
leaky_relu_backward_kernel(TensorIteratorBase & iter,const Scalar & negval_)999 static void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& negval_) {
1000 if (at::isReducedFloatingType(iter.dtype())) {
1001 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "leaky_relu_backward_cpu", [&]() {
1002 auto zero_vec = Vectorized<float>((float)(0));
1003 auto one_vec = Vectorized<float>((float)(1));
1004 float negval = negval_.to<float>();
1005 Vectorized<float> negval_v = Vectorized<float>(negval);
1006 cpu_kernel_vec(
1007 iter,
1008 [&](scalar_t a, scalar_t b) -> scalar_t {
1009 return float(a) > float(0) ? float(b) : float(b) * negval;
1010 },
1011 [&](Vectorized<scalar_t> a, Vectorized<scalar_t> b) -> Vectorized<scalar_t> {
1012 auto [a0, a1] = convert_to_float<scalar_t>(a);
1013 auto [b0, b1] = convert_to_float<scalar_t>(b);
1014 auto res0 = b0 * (Vectorized<float>::blendv(negval_v, one_vec, a0 > zero_vec));
1015 auto res1 = b1 * (Vectorized<float>::blendv(negval_v, one_vec, a1 > zero_vec));
1016 return convert_from_float<scalar_t>(res0, res1);
1017 });
1018 });
1019 } else {
1020 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "leaky_relu_backward_cpu", [&] {
1021 using Vec = Vectorized<scalar_t>;
1022 auto zero_vec = Vec((scalar_t)(0));
1023 auto one_vec = Vec((scalar_t)(1));
1024 scalar_t negval = negval_.to<scalar_t>();
1025 Vec negval_v = Vec(negval);
1026 cpu_kernel_vec(
1027 iter,
1028 [&](scalar_t a, scalar_t b) -> scalar_t {
1029 return a > scalar_t(0) ? b : b * negval;
1030 },
1031 [&](Vec a, Vec b) -> Vec {
1032 auto r = Vec::blendv(negval_v, one_vec, a > zero_vec);
1033 return b * r;
1034 });
1035 });
1036 }
1037 }
1038
void softplus_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) {
  // softplus(x) = log1p(exp(beta * x)) / beta, with a linear passthrough
  // (returns x unchanged) once beta * x exceeds `threshold` — this avoids
  // overflow in exp() for large inputs.
  if (at::isReducedFloatingType(iter.dtype())) {
    // BFloat16/Half path: compute in float, convert back on store.
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "softplus_cpu", [&]() {
      using Vec = Vectorized<float>;
      auto beta = beta_.to<float>();
      auto threshold = threshold_.to<float>();
      const Vec beta_vec(beta);
      const Vec threshold_vec(threshold);
      cpu_kernel_vec(
          iter,
          [beta, threshold](scalar_t a) -> scalar_t {
            return (float(a) * beta) > threshold ? a
              : static_cast<scalar_t>((std::log1p(std::exp(float(a) * beta))) / beta);
          },
          [beta_vec, threshold_vec](Vectorized<scalar_t> a) -> Vectorized<scalar_t> {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            // blendv keeps `a` where beta*a > threshold, else the softplus value.
            a0 = Vec::blendv((a0 * beta_vec).exp().log1p() / beta_vec, a0, (a0 * beta_vec) > threshold_vec);
            a1 = Vec::blendv((a1 * beta_vec).exp().log1p() / beta_vec, a1, (a1 * beta_vec) > threshold_vec);
            return convert_from_float<scalar_t>(a0, a1);
          }
      );
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "softplus_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto beta = beta_.to<scalar_t>();
      auto threshold = threshold_.to<scalar_t>();
      const Vec beta_vec(beta);
      const Vec threshold_vec(threshold);
      cpu_kernel_vec(
          iter,
          [beta, threshold](scalar_t a) -> scalar_t {
            // NOTE(review): here the cast wraps only log1p(exp(...)), and the
            // divide by beta happens after the cast — the reduced-precision
            // branch casts after dividing; confirm this asymmetry is intended.
            return (a * beta) > threshold ? a
              : static_cast<scalar_t>(std::log1p(std::exp(a * beta))) / beta;
          },
          [beta_vec, threshold_vec](Vec a) -> Vec {
            return Vec::blendv((a * beta_vec).exp().log1p() / beta_vec, a, (a * beta_vec) > threshold_vec);
          }
      );
    });
  }
}
1081
void softplus_backward_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) {
  // Gradient of softplus: with z = exp(beta * b), dx = a * z / (z + 1)
  // (a: incoming gradient, b: forward input). Where beta * b > threshold the
  // forward was the identity, so the gradient passes through unchanged.
  if (at::isReducedFloatingType(iter.dtype())) {
    // BFloat16/Half path: compute in float, convert back on store.
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "softplus_backward_cpu", [&]() {
      using Vec = Vectorized<float>;
      auto beta = beta_.to<float>();
      auto threshold = threshold_.to<float>();
      const Vec beta_vec(beta);
      const Vec threshold_vec(threshold);
      const Vec one_vec(static_cast<float>(1.0));
      cpu_kernel_vec(
          iter,
          [beta, threshold](scalar_t a, scalar_t b) -> scalar_t {
            float z = std::exp(float(b) * beta);
            return (float(b) * beta) > threshold ? a : static_cast<scalar_t>(float(a) * z / (z + float(1.)));
          },
          [beta_vec, one_vec, threshold_vec](Vectorized<scalar_t> a, Vectorized<scalar_t> b) -> Vectorized<scalar_t> {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            auto [b0, b1] = convert_to_float<scalar_t>(b);
            // `z` is reused for the second lane after the first blendv.
            Vec z = (b0 * beta_vec).exp();
            a0 = Vec::blendv(a0 * z / (z + one_vec), a0, (b0 * beta_vec) > threshold_vec);
            z = (b1 * beta_vec).exp();
            a1 = Vec::blendv(a1 * z / (z + one_vec), a1, (b1 * beta_vec) > threshold_vec);
            return convert_from_float<scalar_t>(a0, a1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "softplus_backward_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      auto beta = beta_.to<scalar_t>();
      auto threshold = threshold_.to<scalar_t>();
      const Vec beta_vec(beta);
      const Vec threshold_vec(threshold);
      const Vec one_vec(static_cast<scalar_t>(1.0));
      cpu_kernel_vec(
          iter,
          [beta, threshold](scalar_t a, scalar_t b) -> scalar_t {
            scalar_t z = std::exp(b * beta);
            return (b * beta) > threshold ? a : a * z / (z + scalar_t(1.));
          },
          [beta_vec, one_vec, threshold_vec](Vec a, Vec b) -> Vec {
            const Vec z = (b * beta_vec).exp();
            return Vec::blendv(a * z / (z + one_vec), a, (b * beta_vec) > threshold_vec);
          }
      );
    });
  }
}
1129
glu_kernel(TensorIteratorBase & iter)1130 void glu_kernel(TensorIteratorBase& iter) {
1131 if (at::isReducedFloatingType(iter.dtype())) {
1132 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "glu_cpu", [&]() {
1133 const float float_one_val(1);
1134 const Vectorized<float> float_one_vec(float_one_val);
1135 cpu_kernel_vec(
1136 iter,
1137 [float_one_val](scalar_t a, scalar_t b) -> scalar_t {
1138 return float(a) * (float_one_val / (float_one_val + std::exp(- float(b))));
1139 },
1140 [float_one_vec](Vectorized<scalar_t> a, Vectorized<scalar_t> b) -> Vectorized<scalar_t> {
1141 auto [a0, a1] = convert_to_float<scalar_t>(a);
1142 auto [b0, b1] = convert_to_float<scalar_t>(b);
1143 return convert_from_float<scalar_t>(a0 * (float_one_vec / (float_one_vec + b0.neg().exp())),
1144 a1 * (float_one_vec / (float_one_vec + b1.neg().exp())));
1145 });
1146 });
1147 } else {
1148 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_cpu", [&] {
1149 using Vec = Vectorized<scalar_t>;
1150 const scalar_t one_val(1);
1151 const Vec one_vec(one_val);
1152 cpu_kernel_vec(
1153 iter,
1154 [one_val](scalar_t a, scalar_t b) -> scalar_t {
1155 return a * (one_val / (one_val + std::exp(-b)));
1156 },
1157 [one_vec](Vec a, Vec b) -> Vec {
1158 return a * (one_vec / (one_vec + b.neg().exp()));
1159 }
1160 );
1161 });
1162 }
1163 }
1164
glu_jvp_kernel(TensorIteratorBase & iter)1165 void glu_jvp_kernel(TensorIteratorBase& iter) {
1166 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_jvp_cpu", [&] {
1167 using Vec = Vectorized<scalar_t>;
1168 const scalar_t one(1);
1169 const Vec ones(one);
1170 cpu_kernel_vec(
1171 iter,
1172 [one](scalar_t res, scalar_t b, scalar_t da, scalar_t db) -> scalar_t {
1173 const auto sig_b = one / (one + std::exp(-b));
1174 return da * sig_b + res * (db - sig_b * db);
1175 },
1176 [ones](Vec res, Vec b, Vec da, Vec db) -> Vec {
1177 const auto sig_b = ones / (ones + b.neg().exp());
1178 return da * sig_b + res * (db - sig_b * db);
1179 }
1180 );
1181 });
1182 }
1183
glu_backward_kernel(TensorIterator & iter)1184 void glu_backward_kernel(TensorIterator& iter) {
1185 if (at::isReducedFloatingType(iter.dtype())) {
1186 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "glu_backward_cpu", [&]() {
1187 const float float_one_val(1);
1188 const Vectorized<float> float_one_vec(float_one_val);
1189 cpu_kernel_vec(
1190 iter,
1191 [float_one_val](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
1192 return (float_one_val - float(a)) * float(a) * float(b) * float(c);
1193 },
1194 [float_one_vec](Vectorized<scalar_t> a, Vectorized<scalar_t> b, Vectorized<scalar_t> c) -> Vectorized<scalar_t> {
1195 auto [a0, a1] = convert_to_float<scalar_t>(a);
1196 auto [b0, b1] = convert_to_float<scalar_t>(b);
1197 auto [c0, c1] = convert_to_float<scalar_t>(c);
1198 a0 = (float_one_vec - a0) * a0 * b0 * c0;
1199 a1 = (float_one_vec - a1) * a1 * b1 * c1;
1200 return convert_from_float<scalar_t>(a0, a1);
1201 });
1202 });
1203 } else {
1204 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_backward_cpu", [&] {
1205 using Vec = Vectorized<scalar_t>;
1206 const scalar_t one_val(1);
1207 const Vec one_vec(one_val);
1208 cpu_kernel_vec(
1209 iter,
1210 [one_val](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
1211 return (one_val - a) * a * b * c;
1212 },
1213 [one_vec](Vec a, Vec b, Vec c) -> Vec {
1214 return (one_vec - a) * a * b * c;
1215 }
1216 );
1217 });
1218 }
1219 }
1220
silu_kernel(TensorIteratorBase & iter)1221 void silu_kernel(TensorIteratorBase& iter) {
1222 if (at::isReducedFloatingType(iter.dtype())) {
1223 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "silu_cpu", [&]() {
1224 const Vectorized<float> kOneVec(1.0f);
1225 cpu_kernel_vec(
1226 iter,
1227 [](scalar_t x) -> scalar_t {
1228 return float(x) / (1.0f + std::exp(-float(x)));
1229 },
1230 [kOneVec](Vectorized<scalar_t> x_vec) -> Vectorized<scalar_t> {
1231 auto [x_vec0, x_vec1] = convert_to_float<scalar_t>(x_vec);
1232 return convert_from_float<scalar_t>(
1233 x_vec0 / (kOneVec + x_vec0.neg().exp()),
1234 x_vec1 / (kOneVec + x_vec1.neg().exp()));
1235 });
1236 });
1237 } else {
1238 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
1239 iter.dtype(), "silu_cpu", [&]() {
1240 const Vectorized<scalar_t> kOneVec(scalar_t(1));
1241 cpu_kernel_vec(
1242 iter,
1243 [](scalar_t x) {
1244 return x / (scalar_t(1) + std::exp(-x));
1245 },
1246 [kOneVec](Vectorized<scalar_t> x_vec) {
1247 return x_vec / (kOneVec + x_vec.neg().exp());
1248 });
1249 });
1250 }
1251 }
1252
void silu_backward_kernel(TensorIteratorBase& iter) {
  // Gradient of silu: with s = sigmoid(x),
  //   dx = dy * s * (1 + x * (1 - s))
  if (at::isReducedFloatingType(iter.dtype())) {
    // BFloat16/Half path: compute in float, convert back on store.
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "silu_backward_cpu", [&]() {
      const Vectorized<float> kOneVec(1.0f);
      cpu_kernel_vec(
          iter,
          [](scalar_t dy, scalar_t x) -> scalar_t {
            const float sigmoid =
                1.0f / (1.0f + std::exp(-float(x)));
            return dy * sigmoid * (1.0f + x * (1.0f - sigmoid));
          },
          [kOneVec](Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) -> Vectorized<scalar_t> {
            auto [x_vec0, x_vec1] = convert_to_float<scalar_t>(x_vec);
            auto [dy_vec0, dy_vec1] = convert_to_float<scalar_t>(dy_vec);
            const Vectorized<float> sigmoid0 =
                kOneVec / (kOneVec + x_vec0.neg().exp());
            const Vectorized<float> sigmoid1 =
                kOneVec / (kOneVec + x_vec1.neg().exp());
            return convert_from_float<scalar_t>(
                dy_vec0 * sigmoid0 * (kOneVec + x_vec0 * (kOneVec - sigmoid0)),
                dy_vec1 * sigmoid1 * (kOneVec + x_vec1 * (kOneVec - sigmoid1)));
          });
    });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
        iter.dtype(), "silu_backward_cpu", [&]() {
          const Vectorized<scalar_t> kOneVec(scalar_t(1));
          cpu_kernel_vec(
              iter,
              [](scalar_t dy, scalar_t x) {
                const scalar_t sigmoid =
                    scalar_t(1) / (scalar_t(1) + std::exp(-x));
                return dy * sigmoid * (scalar_t(1) + x * (scalar_t(1) - sigmoid));
              },
              [kOneVec](Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) {
                const Vectorized<scalar_t> sigmoid =
                    kOneVec / (kOneVec + x_vec.neg().exp());
                return dy_vec * sigmoid * (kOneVec + x_vec * (kOneVec - sigmoid));
              });
        });
  }
}
1295
mish_kernel(TensorIteratorBase & iter)1296 void mish_kernel(TensorIteratorBase& iter) {
1297 if (at::isReducedFloatingType(iter.dtype())) {
1298 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "mish_cpu", [&]() {
1299 cpu_kernel_vec(
1300 iter,
1301 [](scalar_t x) -> scalar_t{
1302 return static_cast<scalar_t>(float(x) * std::tanh(std::log1p(std::exp(float(x)))));
1303 },
1304 [](Vectorized<scalar_t> x_vec) -> Vectorized<scalar_t> {
1305 auto [x_vec0, x_vec1] = convert_to_float<scalar_t>(x_vec);
1306 return convert_from_float<scalar_t>(
1307 x_vec0 * x_vec0.exp().log1p().tanh(),
1308 x_vec1 * x_vec1.exp().log1p().tanh()
1309 );
1310 });
1311 });
1312 } else {
1313 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_cpu", [&]() {
1314 using Vec = Vectorized<scalar_t>;
1315 cpu_kernel_vec(
1316 iter,
1317 [](scalar_t x) -> scalar_t{
1318 return static_cast<scalar_t>(x * std::tanh(std::log1p(std::exp(x))));
1319 },
1320 [](Vec x_vec) -> Vec {
1321 return x_vec * x_vec.exp().log1p().tanh();
1322 });
1323 });
1324 }
1325 }
1326
void mish_backward_kernel(TensorIterator& iter) {
  // Gradient of mish: with s = sigmoid(x) and t = tanh(softplus(x)),
  //   dx = dy * (t + x * s * (1 - t^2))
  if (at::isReducedFloatingType(iter.dtype())) {
    // BFloat16/Half path: compute in float, convert back on store.
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "mish_backward_cpu", [&]() {
      using Vec = Vectorized<float>;
      const Vec kOneVec(1.0f);
      cpu_kernel_vec(
          iter,
          [](scalar_t dy, scalar_t x) -> scalar_t {
            const float sigmoid =
                1.0f / (1.0f + std::exp(-float(x)));
            const float tanh_softplus = std::tanh(std::log1p(std::exp(float(x))));
            return dy * (tanh_softplus + x * sigmoid * (1.0f - tanh_softplus * tanh_softplus));
          },
          [kOneVec](Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) -> Vectorized<scalar_t> {
            auto [x_vec0, x_vec1] = convert_to_float<scalar_t>(x_vec);
            auto [dy_vec0, dy_vec1] = convert_to_float<scalar_t>(dy_vec);
            const Vec sigmoid0 = kOneVec / (kOneVec + x_vec0.neg().exp());
            const Vec sigmoid1 = kOneVec / (kOneVec + x_vec1.neg().exp());
            const Vec tanh_softplus0 = x_vec0.exp().log1p().tanh();
            const Vec tanh_softplus1 = x_vec1.exp().log1p().tanh();
            return convert_from_float<scalar_t>(
                dy_vec0 * (tanh_softplus0 + x_vec0 * sigmoid0 * (kOneVec - tanh_softplus0 * tanh_softplus0)),
                dy_vec1 * (tanh_softplus1 + x_vec1 * sigmoid1 * (kOneVec - tanh_softplus1 * tanh_softplus1))
            );
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_backward_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      const Vec kOneVec(scalar_t(1));
      cpu_kernel_vec(
          iter,
          [](scalar_t dy, scalar_t x) -> scalar_t {
            const scalar_t sigmoid =
                scalar_t(1) / (scalar_t(1) + std::exp(-x));
            const scalar_t tanh_softplus = std::tanh(std::log1p(std::exp(x)));
            return dy * (tanh_softplus + x * sigmoid * (scalar_t(1) - tanh_softplus * tanh_softplus));
          },
          [kOneVec](Vec dy_vec, Vec x_vec) -> Vec {
            const Vec sigmoid = kOneVec / (kOneVec + x_vec.neg().exp());
            const Vec tanh_softplus = x_vec.exp().log1p().tanh();
            return dy_vec * (tanh_softplus + x_vec * sigmoid * (kOneVec - tanh_softplus * tanh_softplus));
          });
    });
  }
}
1373
prelu_kernel(TensorIterator & iter)1374 void prelu_kernel(TensorIterator& iter) {
1375 AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "prelu_cpu", [&]() {
1376 using Vec = Vectorized<scalar_t>;
1377 cpu_kernel_vec(
1378 iter,
1379 [](scalar_t input, scalar_t weight) {
1380 return (input > scalar_t(0)) ? input : weight * input;
1381 },
1382 [](Vec input, Vec weight) {
1383 return Vec::blendv(weight * input, input, input > Vec(0));
1384 });
1385 });
1386 }
1387
prelu_backward_kernel(TensorIterator & iter)1388 void prelu_backward_kernel(TensorIterator& iter) {
1389 AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "prelu_backward_cpu", [&]() {
1390 cpu_kernel_multiple_outputs(iter,
1391 [](scalar_t input, scalar_t weight, scalar_t grad) -> std::tuple<scalar_t, scalar_t> {
1392 auto mask = input > scalar_t{0};
1393 auto grad_input = mask ? grad : weight * grad;
1394 auto grad_weight = mask ? scalar_t{0} : input * grad;
1395 return {grad_input, grad_weight};
1396 });
1397 });
1398 }
1399
1400 } // namespace
1401
1402
// Register the kernels above with their dispatch stubs.
REGISTER_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel);
REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel);
REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel);
REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel);
REGISTER_DISPATCH(prelu_stub, &prelu_kernel);
REGISTER_DISPATCH(prelu_backward_stub, &prelu_backward_kernel);
REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel);
REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel);
REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel);
REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel);

// NOTE(review): per the macro name, these registrations presumably also emit
// an AVX-512 variant in addition to the default capability set — confirm
// against the ALSO_REGISTER_AVX512_DISPATCH definition in DispatchStub.h.
ALSO_REGISTER_AVX512_DISPATCH(log_sigmoid_cpu_stub, &log_sigmoid_cpu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_cpu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(glu_stub, &glu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(glu_backward_stub, &glu_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(glu_jvp_stub, &glu_jvp_kernel);
ALSO_REGISTER_AVX512_DISPATCH(elu_stub, &elu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(elu_backward_stub, &elu_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(GeluKernel, &GeluKernelImpl);
ALSO_REGISTER_AVX512_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);
ALSO_REGISTER_AVX512_DISPATCH(hardswish_stub, &hardswish_kernel);
ALSO_REGISTER_AVX512_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(softplus_stub, &softplus_kernel);
ALSO_REGISTER_AVX512_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(silu_stub, &silu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(silu_backward_stub, &silu_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(mish_stub, &mish_kernel);
ALSO_REGISTER_AVX512_DISPATCH(mish_backward_stub, &mish_backward_kernel);
1432
1433 } // namespace at::native
1434