#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Activation.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <ATen/ScalarOps.h>
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
#include <ATen/native/xnnpack/Engine.h>
#endif
#include <ATen/core/DistributionsHelper.h>

#include <c10/util/irange.h>
#include <c10/core/ScalarType.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/celu_native.h>
#include <ATen/ops/clamp.h>
#include <ATen/ops/clamp_min.h>
#include <ATen/ops/elu.h>
#include <ATen/ops/elu_backward_native.h>
#include <ATen/ops/elu_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/gelu_backward_native.h>
#include <ATen/ops/gelu_native.h>
#include <ATen/ops/hardshrink_backward_native.h>
#include <ATen/ops/hardshrink_native.h>
#include <ATen/ops/hardsigmoid_backward_native.h>
#include <ATen/ops/hardsigmoid_native.h>
#include <ATen/ops/hardswish_backward_native.h>
#include <ATen/ops/hardswish_native.h>
#include <ATen/ops/hardtanh.h>
#include <ATen/ops/hardtanh_backward_native.h>
#include <ATen/ops/hardtanh_native.h>
#include <ATen/ops/infinitely_differentiable_gelu_backward_native.h>
#include <ATen/ops/leaky_relu.h>
#include <ATen/ops/leaky_relu_backward.h>
#include <ATen/ops/leaky_relu_backward_native.h>
#include <ATen/ops/leaky_relu_native.h>
#include <ATen/ops/log_sigmoid_backward_native.h>
#include <ATen/ops/log_sigmoid_forward.h>
#include <ATen/ops/log_sigmoid_forward_native.h>
#include <ATen/ops/log_sigmoid_native.h>
#include <ATen/ops/mish_backward_native.h>
#include <ATen/ops/mish_native.h>
#include <ATen/ops/prelu_native.h>
#include <ATen/ops/_prelu_kernel.h>
#include <ATen/ops/_prelu_kernel_native.h>
#include <ATen/ops/_prelu_kernel_backward_native.h>
#include <ATen/ops/relu6_native.h>
#include <ATen/ops/relu_native.h>
#include <ATen/ops/rrelu_native.h>
#include <ATen/ops/rrelu_with_noise.h>
#include <ATen/ops/rrelu_with_noise_backward_native.h>
#include <ATen/ops/rrelu_with_noise_native.h>
#include <ATen/ops/selu_native.h>
#include <ATen/ops/sigmoid.h>
#include <ATen/ops/silu_backward_native.h>
#include <ATen/ops/silu_native.h>
#include <ATen/ops/softplus.h>
#include <ATen/ops/softplus_backward_native.h>
#include <ATen/ops/softplus_native.h>
#include <ATen/ops/softshrink_backward_native.h>
#include <ATen/ops/softshrink_native.h>
#include <ATen/ops/tanh.h>
#include <ATen/ops/threshold_backward_native.h>
#include <ATen/ops/threshold_native.h>

#include <utility>
#endif

namespace at::meta {
// computes `result = self <= threshold ? value : other`
// other is `self` in threshold() and `grad` in threshold_backward()
TORCH_META_FUNC(threshold)(const Tensor& self, const Scalar& threshold, const Scalar& value) {
  const Tensor& result = maybe_get_output();
  build(TensorIteratorConfig()
    .set_check_mem_overlap(false)  // threshold is idempotent, so overlap is okay
    .add_output(result)
    .add_const_input(self)
    .add_const_input(self)  // other
    .allow_cpu_scalars(true)
    .promote_inputs_to_common_dtype(true)
    .cast_common_dtype_to_outputs(true)
    .enforce_safe_casting_to_output(true));
}
// computes `result = self <= threshold ? value : other`
// other is `self` in threshold() and `grad` in threshold_backward()
TORCH_META_FUNC(threshold_backward)(const Tensor& grad, const Tensor& self, const Scalar& threshold) {
  const Tensor& gradInput = maybe_get_output();
  build(TensorIteratorConfig()
    .set_check_mem_overlap(false)  // threshold is idempotent, so overlap is okay
    .add_output(gradInput)
    .add_const_input(self)
    .add_const_input(grad)  // other
    .allow_cpu_scalars(true)
    .promote_inputs_to_common_dtype(true)
    .cast_common_dtype_to_outputs(true)
    .enforce_safe_casting_to_output(true));
}

TORCH_META_FUNC(elu) (
  const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale
) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(elu_backward) (
  const Tensor& grad_output,
  const Scalar& alpha,
  const Scalar& scale,
  const Scalar& input_scale,
  bool is_result,
  const Tensor& self_or_result
) {
  TORCH_CHECK(
    !is_result || alpha.to<double>() >= 0.0,
    "In-place elu backward calculation is triggered with a negative slope which is not supported. "
    "This is caused by calling in-place forward function with a negative slope, "
    "please call out-of-place version instead.");

  build_borrowing_binary_op(maybe_get_output(), grad_output, self_or_result);
}

TORCH_META_FUNC(silu) (const Tensor& self) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(silu_backward) (
  const Tensor& grad_output, const Tensor& input
) {
  build_borrowing_binary_op(maybe_get_output(), grad_output, input);
}

TORCH_META_FUNC(mish) (const Tensor& self) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(softplus) (
  const Tensor& self, const Scalar& beta, const Scalar& threshold
) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(softplus_backward) (
  const Tensor& grad_output,
  const Tensor& self,
  const Scalar& beta,
  const Scalar& threshold
) {
  build_borrowing_binary_op(maybe_get_output(), grad_output, self);
}

TORCH_META_FUNC(leaky_relu) (
  const Tensor& self, const Scalar& negval
) {
  build_unary_op(maybe_get_output(), self);
}

// Note: the leakyReLu backward calculation doesn't support an in-place call with a negative slope.
// The reason is that for an in-place forward call, the forward result is saved into the autograd
// node instead of the input itself. When calculating the backward gradient, there is then no way
// to know whether the original input for the current node was positive or not if the slope is
// negative. E.g. if the forward result is 2 and the slope is -0.2, the original input for this
// node could have been either 2 or -10, so a correct backward gradient cannot be computed.
TORCH_META_FUNC(leaky_relu_backward) (
  const Tensor& grad_output,
  const Tensor& self_or_result,
  const Scalar& negval,
  bool is_result
) {
  TORCH_CHECK(
    !is_result || negval.to<double>() >= 0.0,
    "In-place leakyReLu backward calculation is triggered with a negative slope which is not supported. "
    "This is caused by calling in-place forward function with a negative slope, "
    "please call out-of-place version instead. File an issue at https://github.com/pytorch/pytorch if you do "
    "require supporting in-place leakyReLu backward calculation with negative slope");

  build_borrowing_binary_op(maybe_get_output(), self_or_result, grad_output);
}

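// hardsigmoid(x) = min(max(x + 3, 0), 6) / 6, a piecewise-linear approximation of sigmoid
// that is exactly 0 for x <= -3 and exactly 1 for x >= 3.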
TORCH_META_FUNC(hardsigmoid) (const Tensor& self) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(hardsigmoid_backward) (const Tensor& grad_output, const Tensor& self) {
  build_borrowing_binary_op(maybe_get_output(), grad_output, self);
}

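// hardshrink(x) = x if |x| > lambd, 0 otherwise.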
TORCH_META_FUNC(hardshrink) (const Tensor& self, const Scalar& lambd) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(hardshrink_backward) (
  const Tensor& grad, const Tensor& self, const Scalar& lambd
) {
  build_borrowing_binary_op(maybe_get_output(), grad, self);
}

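// softshrink(x) = x - lambd if x > lambd, x + lambd if x < -lambd, 0 otherwise.
// lambd must be non-negative (see softshrink_check below).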
static inline void softshrink_check(const Scalar& lambd) {
  double lamb = lambd.to<double>();
  TORCH_CHECK(lamb >= 0, "lambda must be greater or equal to 0, but found to be ", lamb, ".");
}

TORCH_META_FUNC(softshrink) (
  const Tensor& self, const Scalar& lambd
) {
  softshrink_check(lambd);
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(softshrink_backward) (
  const Tensor& grad, const Tensor& self, const Scalar& lambd
) {
  build_borrowing_binary_op(maybe_get_output(), grad, self);
}

TORCH_META_FUNC(gelu) (const Tensor& self, c10::string_view approximate) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(gelu_backward) (
  const Tensor& grad, const Tensor& self, c10::string_view approximate
) {
  build_borrowing_binary_op(maybe_get_output(), grad, self);
}

} // namespace at::meta

namespace at::native {

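// selu(x) = SELU_SCALE * (max(0, x) + min(0, SELU_ALPHA * (exp(x) - 1))), i.e. a scaled ELU.
// The fixed alpha/scale constants below come from the SELU paper
// ("Self-Normalizing Neural Networks", Klambauer et al., 2017); selu()/selu_() further down
// simply forward to at::elu with these values.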
static const double SELU_ALPHA = 1.6732632423543772848170429916717;
static const double SELU_SCALE = 1.0507009873554804934193349852946;

DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);
DEFINE_DISPATCH(softplus_stub);
DEFINE_DISPATCH(softplus_backward_stub);
DEFINE_DISPATCH(log_sigmoid_cpu_stub);
DEFINE_DISPATCH(log_sigmoid_backward_stub);
DEFINE_DISPATCH(threshold_stub);
DEFINE_DISPATCH(hardtanh_backward_stub);
DEFINE_DISPATCH(hardsigmoid_stub);
DEFINE_DISPATCH(hardsigmoid_backward_stub);
DEFINE_DISPATCH(hardswish_stub);
DEFINE_DISPATCH(hardswish_backward_stub);
DEFINE_DISPATCH(hardshrink_stub);
DEFINE_DISPATCH(softshrink_stub);
DEFINE_DISPATCH(shrink_backward_stub);
DEFINE_DISPATCH(leaky_relu_stub);
DEFINE_DISPATCH(leaky_relu_backward_stub);
DEFINE_DISPATCH(silu_stub);
DEFINE_DISPATCH(silu_backward_stub);
DEFINE_DISPATCH(mish_stub);
DEFINE_DISPATCH(mish_backward_stub);
DEFINE_DISPATCH(prelu_stub);
DEFINE_DISPATCH(prelu_backward_stub);

TORCH_IMPL_FUNC(elu_out) (
  const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result
) {
  elu_stub(device_type(), *this, alpha, scale, input_scale);
}

TORCH_IMPL_FUNC(elu_backward_out) (
  const Tensor& grad_output,
  const Scalar& alpha,
  const Scalar& scale,
  const Scalar& input_scale,
  bool is_result,
  const Tensor& self_or_result,
  const Tensor& grad_input
) {
  elu_backward_stub(device_type(), *this, alpha, scale, input_scale, is_result);
}

TORCH_IMPL_FUNC(silu_out) (
  const Tensor& self, const Tensor& result
) {
  silu_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(silu_backward_out) (
  const Tensor& grad_output, const Tensor& input, const Tensor& grad_input
) {
  silu_backward_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(mish_out) (
  const Tensor& self, const Tensor& result
) {
  mish_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(softplus_out) (
  const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result
) {
  softplus_stub(device_type(), *this, beta, threshold);
}

TORCH_IMPL_FUNC(softplus_backward_out) (
  const Tensor& grad_output,
  const Tensor& self,
  const Scalar& beta,
  const Scalar& threshold,
  const Tensor& grad_input
) {
  softplus_backward_stub(device_type(), *this, beta, threshold);
}

TORCH_IMPL_FUNC(leaky_relu_out) (
  const Tensor& self, const Scalar& negval, const Tensor& result
) {
  leaky_relu_stub(device_type(), *this, negval);
}

TORCH_IMPL_FUNC(leaky_relu_backward_out) (
  const Tensor& grad_output,
  const Tensor& self_or_result,
  const Scalar& negval,
  bool is_result,
  const Tensor& grad_input
) {
  leaky_relu_backward_stub(device_type(), *this, negval);
}

TORCH_IMPL_FUNC(hardsigmoid_out) (
  const Tensor& self, const Tensor& result
) {
  hardsigmoid_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(hardsigmoid_backward_out) (
  const Tensor& grad_output, const Tensor& self, const Tensor& grad_input
) {
  hardsigmoid_backward_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(hardshrink_out) (
  const Tensor& self, const Scalar& lambd, const Tensor& result
) {
  hardshrink_stub(device_type(), *this, lambd);
}

TORCH_IMPL_FUNC(hardshrink_backward_out) (
  const Tensor& grad, const Tensor& self, const Scalar& lambd, const Tensor& grad_input
) {
  shrink_backward_stub(device_type(), *this, lambd);
}

TORCH_IMPL_FUNC(softshrink_out) (
  const Tensor& self, const Scalar& lambd, const Tensor& result
) {
  softshrink_stub(device_type(), *this, lambd);
}

TORCH_IMPL_FUNC(softshrink_backward_out) (
  const Tensor& grad, const Tensor& self, const Scalar& lambd, const Tensor& grad_input
) {
  shrink_backward_stub(device_type(), *this, lambd);
}

#if AT_MKLDNN_ENABLED()
static bool use_mkldnn(const Tensor& input) {
  if (!at::globalContext().userEnabledMkldnn()) {
    return false;
  }
  if (!input.is_contiguous() || input.numel() <= 1) {
    return false;
  }
  return (input.is_mkldnn()) || // input is mkldnn Tensor
    (input.device().is_cpu() &&
    (((input.scalar_type() == kBFloat16) && mkldnn_bf16_device_check()) ||
      (input.scalar_type() == kFloat))); // input is dense layout and bfloat16/float32
}
#endif

TORCH_IMPL_FUNC(gelu_out_cpu) (
  const Tensor& self, c10::string_view approximate, const Tensor& result
) {
  auto approximate_type = get_gelutype_enum(approximate);
#if AT_MKLDNN_ENABLED()
  if (use_mkldnn(self) && (approximate_type == GeluType::None)) {
    const ideep::tensor& x = itensor_from_tensor(self, /*from_const_data_ptr*/true);
    ideep::tensor y = itensor_from_tensor(result);
    ideep::eltwise_forward::compute(
      x, y, ideep::algorithm::eltwise_gelu_erf, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
#ifdef __aarch64__
  } else if (use_mkldnn(self) && (approximate_type == GeluType::Tanh)) {
    const ideep::tensor& x = itensor_from_tensor(self);
    ideep::tensor y = itensor_from_tensor(result);
    ideep::eltwise_forward::compute(
      x, y, ideep::algorithm::eltwise_gelu_tanh, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
#endif  // ifdef __aarch64__
  } else {
    GeluKernel(kCPU, *this, approximate_type);
  }
#else
  GeluKernel(kCPU, *this, approximate_type);
#endif
}

TORCH_IMPL_FUNC(gelu_backward_out_cpu) (
  const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input
) {
  auto approximate_type = get_gelutype_enum(approximate);
#if AT_MKLDNN_ENABLED()
  if (use_mkldnn(self) && (approximate_type == GeluType::None)) {
    const ideep::tensor& x = itensor_from_tensor(self, /*from_const_data_ptr*/true);
    ideep::tensor grady = itensor_from_tensor(grad, /*from_const_data_ptr*/true);
    ideep::tensor gradx = itensor_from_tensor(grad_input);
    ideep::eltwise_backward::compute(x, grady, gradx,
      ideep::algorithm::eltwise_gelu_erf, /*alpha*/ 0.0);
  } else {
    GeluBackwardKernel(kCPU, *this, approximate_type);
  }
#else
  GeluBackwardKernel(kCPU, *this, approximate_type);
#endif
}

Tensor hardtanh(const Tensor& self, const Scalar& min, const Scalar& max) {
  Tensor result = at::empty_like(self);
  return at::hardtanh_out(result, self, min, max);
}

Tensor& hardtanh_out(const Tensor& self, const Scalar& min, const Scalar& max, Tensor& result) {
  TORCH_CHECK(self.scalar_type() != at::kBool,
      "Bool inputs not supported for hardtanh");
  // Preserve legacy behavior of boundaries not causing type promotion
  Scalar min_, max_;
  if (at::isIntegralType(self.scalar_type(), /*include_bool*/false)) {
    int64_t minval = min.toLong();
    int64_t maxval = max.toLong();
    TORCH_CHECK(self.dtype() != at::kByte || (minval >= 0 &&
        maxval >= 0), "cannot do hardtanh on an unsigned type with negative limits");
    min_ = minval;
    max_ = maxval;
  } else {
    min_ = min;
    max_ = max;
  }
  return at::clamp_out(result, self, min_, max_);
}

Tensor& hardtanh_(Tensor& self, const Scalar& min, const Scalar& max) {
  return at::hardtanh_out(self, self, min, max);
}

Tensor& hardtanh_backward_out(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max, Tensor& grad_input) {
  auto iter = TensorIterator::borrowing_binary_op(grad_input, grad_output, self);
  hardtanh_backward_stub(iter.device_type(), iter, min, max);
  return grad_input;
}

Tensor hardtanh_backward(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max) {
  Tensor result;
  auto iter = TensorIterator::borrowing_binary_op(result, grad_output, self);
  hardtanh_backward_stub(iter.device_type(), iter, min, max);
  return iter.output();
}

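// hardswish(x) = x * min(max(x + 3, 0), 6) / 6, i.e. x * relu6(x + 3) / 6,
// a piecewise-linear approximation of x * sigmoid(x) (SiLU/swish).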
Tensor hardswish(const Tensor& self) {
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
  if (xnnpack::use_hardswish(self)) {
    return xnnpack::hardswish(self);
  }
#endif
  Tensor result;
  auto iter = TensorIterator::unary_op(result, self);
  hardswish_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor& hardswish_out(const Tensor& self, Tensor& result) {
  auto iter = TensorIterator::unary_op(result, self);
  hardswish_stub(iter.device_type(), iter);
  return result;
}

Tensor& hardswish_(Tensor& self) {
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
  if (xnnpack::use_hardswish(self)) {
    xnnpack::hardswish_(self);
    return self;
  }
#endif
  auto iter = TensorIterator::unary_op(self, self);
  hardswish_stub(iter.device_type(), iter);
  return self;
}

Tensor hardswish_backward(const Tensor& grad_output, const Tensor& self) {
  Tensor grad_input;
  auto iter = TensorIterator::borrowing_binary_op(grad_input, grad_output, self);
  hardswish_backward_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor relu(const Tensor& self) {
  TORCH_CHECK(self.scalar_type() != at::kBool, "Boolean inputs not supported for relu");
  return at::clamp_min(self, 0);
}

Tensor& relu_(Tensor& self) {
  TORCH_CHECK(self.scalar_type() != at::kBool, "Boolean inputs not supported for relu");
  return at::clamp_min_(self, 0);
}

Tensor selu(const Tensor& self) {
  return at::elu(self, SELU_ALPHA, SELU_SCALE);
}

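// relu6(x) = min(max(x, 0), 6), implemented as hardtanh with bounds [0, 6].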
Tensor relu6(const Tensor& self) {
  return at::hardtanh(self, /*min_val=*/0, /*max_val=*/6);
}

Tensor& selu_(Tensor& self) {
  return at::elu_(self, SELU_ALPHA, SELU_SCALE);
}

Tensor& relu6_(Tensor& self) {
  return at::hardtanh_(self, /*min_val=*/0, /*max_val=*/6);
}

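// celu(x, alpha) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), which is exactly
// elu with scale = 1 and input_scale = 1 / alpha; e.g. at::celu(x, 2.0) == at::elu(x, 2.0, 1.0, 0.5).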
Tensor celu(const Tensor& self, const Scalar& alpha) {
  TORCH_CHECK(alpha.to<double>() != 0,
      "ZeroDivisionError: alpha cannot be 0 for CELU");
  double inv_alpha = 1. / alpha.to<double>();
  return at::elu(self, alpha, Scalar(1.0), Scalar(inv_alpha));
}

Tensor& celu_(Tensor& self, const Scalar& alpha) {
  TORCH_CHECK(alpha.to<double>() != 0,
      "ZeroDivisionError: alpha cannot be 0 for CELU");
  double inv_alpha = 1. / alpha.to<double>();
  return at::elu_(self, alpha, Scalar(1.0), Scalar(inv_alpha));
}

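// Reference implementation of the silu backward in terms of other ATen ops:
// silu(x) = x * sigmoid(x), so d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).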
Tensor math_silu_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  auto input_sigmoid = at::sigmoid(input);
  return grad_output * (input_sigmoid * (1 + input * (1 - input_sigmoid)));
}

Tensor mish_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  Tensor grad_input = at::empty({0}, input.options());
  auto iter = TensorIterator::binary_op(grad_input, grad_output, input);
  mish_backward_stub(iter.device_type(), iter);
  return grad_input;
}

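// Reference implementation of the mish backward in terms of other ATen ops:
// mish(x) = x * tanh(softplus(x)), so with t = tanh(softplus(x)),
// d/dx mish(x) = t + x * sigmoid(x) * (1 - t^2).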
Tensor math_mish_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  auto input_tanh_softplus = at::tanh(at::softplus(input));
  auto input_sigmoid = at::sigmoid(input);
  return grad_output * (input_tanh_softplus + (input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)));
}

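// RReLU in training mode: for each non-positive input element, a slope r is sampled uniformly
// from [lower, upper], the output is input * r, and r is stored in `noise` so that the backward
// pass (noise * grad_output) sees the same slope. Positive elements pass through with noise = 1.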
template <typename scalar_t>
inline void _rrelu_with_noise_train(
    Tensor& output,
    const Tensor& input,
    const Tensor& noise,
    const Scalar& lower_,
    const Scalar& upper_,
    std::optional<Generator> generator) {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t lower = lower_.to<opmath_t>();
  opmath_t upper = upper_.to<opmath_t>();
  Tensor tmp_tensor = output.contiguous();
  scalar_t* output_data = tmp_tensor.data_ptr<scalar_t>();
  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
  scalar_t* noise_data = noise.data_ptr<scalar_t>();
  auto gen = at::get_generator_or_default<CPUGeneratorImpl>(generator, detail::getDefaultCPUGenerator());
  std::lock_guard<std::mutex> lock(gen->mutex_);
  for (const auto i : c10::irange(input.numel())) {
    if (input_data[i] <= 0) {
      at::uniform_real_distribution<double> uniform(lower, upper);
      const opmath_t r = (opmath_t)uniform(gen);
      output_data[i] = input_data[i] * r;
      noise_data[i] = r;
    } else {
      noise_data[i] = 1;
      output_data[i] = input_data[i];
    }
  }
  if (!output.is_contiguous()) {
    output.copy_(tmp_tensor);
  }
}

Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    std::optional<Generator> generator,
    Tensor& output) {
  TORCH_CHECK(self.sym_sizes() == noise.sym_sizes(), "noise tensor shape must match self tensor shape. Got self.shape = ", self.sym_sizes(), " noise.shape = ", noise.sym_sizes());
  if (training) {
    AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "rrelu_with_noise_out_cpu", [&] {
      _rrelu_with_noise_train<scalar_t>(output, self.contiguous(), noise, lower, upper, generator);
    });
    return output;
  } else {
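    // In evaluation mode RReLU is deterministic: it degenerates to leaky_relu with the
    // negative slope fixed at the midpoint (lower + upper) / 2.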
    auto lower_tensor = scalar_to_tensor(lower);
    auto upper_tensor = scalar_to_tensor(upper);
    auto negative = (lower_tensor + upper_tensor) / 2;
    Scalar negative_slope = negative.item();
    return at::leaky_relu_out(output, self, negative_slope);
  }
}

Tensor rrelu_with_noise_cpu(
    const Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    std::optional<Generator> generator) {
  auto output = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::native::rrelu_with_noise_out_cpu(
      self, noise, lower, upper, training, std::move(generator), output);
}

Tensor& rrelu_with_noise_cpu_(
    Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    std::optional<Generator> generator) {
  return at::native::rrelu_with_noise_out_cpu(
      self, noise, lower, upper, training, std::move(generator), self);
}

Tensor rrelu_with_noise_backward(
    const Tensor& grad_output,
    const Tensor& self_or_result,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    bool is_result) {
  if (training) {
    return noise * grad_output;
  } else {
    auto l = lower.toDouble();
    auto u = upper.toDouble();
    auto mid = (l + u) / 2.;
    return at::leaky_relu_backward(grad_output, self_or_result, mid, is_result);
  }
}

Tensor rrelu(const Tensor& self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
  TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
  return at::rrelu_with_noise(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
}

Tensor& rrelu_(Tensor& self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
  TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
  return at::rrelu_with_noise_(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
}

TORCH_IMPL_FUNC(threshold_out)(const Tensor& self, const Scalar& threshold, const Scalar& value, const Tensor& result) {
  threshold_stub(device_type(), *this, threshold, value);
}

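// The backward reuses threshold_stub with value = 0: as described in the meta functions above,
// the kernel computes `gradInput = self <= threshold ? 0 : grad`.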
TORCH_IMPL_FUNC(threshold_backward_out)(const Tensor& grad, const Tensor& self, const Scalar& threshold, const Tensor& gradInput) {
  threshold_stub(device_type(), *this, threshold, 0);
}

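// prelu(x, w) = x if x >= 0 else w * x, where w is either a single scalar or one value per
// channel (dim 1 of the input). The weight is reshaped here so that it broadcasts against self
// before being handed to the elementwise _prelu_kernel.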
Tensor prelu(const Tensor& self, const Tensor& weight_) {
  TORCH_INTERNAL_ASSERT(weight_.defined());
  auto self_dim = self.dim();
  TORCH_CHECK(self.scalar_type() == weight_.scalar_type(),
      "prelu: Type promoting not supported. Got ",
      self.scalar_type(), " and ", weight_.scalar_type());
  if (weight_.sym_numel() != 1) {
    TORCH_CHECK(self_dim > 0, "Not allow zero-dim input tensor.");

    auto channel_size = self_dim > 1 ? self.sym_size(1) : 1; // channel_size defaults to 1
    TORCH_CHECK(channel_size == weight_.sym_numel(),
        "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_.numel(),
        " and channel size = ", channel_size, ".");
  }

  TORCH_CHECK(
      weight_.dim() <= 1,
      "prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = ", weight_.dim());
  // Adjust weight to broadcast over self and have weight.ndim == self.ndim
  auto weight = weight_;
  if (self_dim != weight.dim()) {
    SymDimVector dim_w(self_dim, 1);
    if (self_dim > 1) {
      dim_w[1] = weight_.sym_numel();
    }
    // This will always be a view in CPU/CUDA, but some backends
    // like MKLDNN do not support views
    weight = weight.reshape_symint(dim_w);
  }
  return at::_prelu_kernel(self, weight);
}


Tensor _prelu_kernel(const Tensor& self, const Tensor& weight) {
  // Weight broadcasts over self and they have the same dtype
  auto result = at::empty_like(self);
  auto iter = TensorIteratorConfig()
    .add_output(result)
    .add_const_input(self)
    .add_const_input(weight)
    .build();
  prelu_stub(iter.device_type(), iter);
  return result;
}

std::tuple<Tensor, Tensor> _prelu_kernel_backward(const Tensor& grad_out, const Tensor& self, const Tensor& weight) {
  Tensor grad_self = at::empty({0}, self.options());
  Tensor grad_weight = at::empty({0}, weight.options());
  auto iter = TensorIteratorConfig()
    .add_output(grad_self)
    .add_output(grad_weight)
    .add_const_input(self)
    .add_const_input(weight)
    .add_const_input(grad_out)
    .build();
  prelu_backward_stub(iter.device_type(), iter);
  return {grad_self, grad_weight};
}

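// gelu(x) = x * Phi(x), where Phi is the standard normal CDF. Hence
// d/dx gelu(x) = Phi(x) + x * phi(x), with phi(x) = exp(-x^2 / 2) / sqrt(2 * pi);
// kAlpha below is exactly 1 / sqrt(2 * pi) = M_2_SQRTPI * M_SQRT1_2 * 0.5.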
Tensor infinitely_differentiable_gelu_backward(
    const Tensor& grad,
    const Tensor& self) {
  constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
  Tensor cdf = (1.0 + (self * M_SQRT1_2).erf_()).mul_(0.5);
  Tensor pdf = (-0.5 * self * self).exp_();
  return cdf.addcmul_(self, pdf, kAlpha).mul_(grad);
}

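// log_sigmoid(x) = log(sigmoid(x)) = -softplus(-x). The forward returns a (result, buffer) pair:
// `buffer` holds an intermediate value that only the CPU backward kernel consumes; the CUDA
// backward recomputes what it needs and ignores the buffer (see log_sigmoid_backward_cuda below).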
std::tuple<Tensor, Tensor> log_sigmoid_forward_cpu(const Tensor& input) {
  auto result = at::empty_like(input, at::MemoryFormat::Contiguous);
  auto buffer = at::empty_like(input, at::MemoryFormat::Contiguous);
  log_sigmoid_cpu_stub(kCPU, result, buffer, input.contiguous());
  return std::make_tuple(result, buffer);
}

std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_cpu(const Tensor& input, Tensor& result, Tensor& buffer) {
  result.resize_as_(input);
  buffer.resize_as_(input, at::MemoryFormat::Contiguous);
  TORCH_CHECK(buffer.is_contiguous(), "Contiguous buffer required for log_sigmoid with out parameter");
  Tensor result_tmp = result.is_contiguous() ? result : at::empty_like(result, at::MemoryFormat::Contiguous);
  log_sigmoid_cpu_stub(kCPU, result_tmp, buffer, input.contiguous());
  if (!result.is_contiguous()) {
    result.copy_(result_tmp);
  }
  return std::forward_as_tuple(result, buffer);
}

Tensor& log_sigmoid_out(const Tensor& self, Tensor& output) {
  Tensor buffer = at::empty({0}, self.options());
  return std::get<0>(at::log_sigmoid_forward_out(output, buffer, self));
}

Tensor log_sigmoid(const Tensor& self) {
  return std::get<0>(at::log_sigmoid_forward(self));
}

Tensor log_sigmoid_backward_cuda(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) {
  auto grad_input = at::empty_like(grad_output);
  // NOTE: buffer is only used by CPU dispatch, we just ignore it here
  auto iter = at::TensorIteratorConfig()
      .add_output(grad_input)
      .add_const_input(input)
      .add_const_input(grad_output)
      .build();
  log_sigmoid_backward_stub(kCUDA, iter);
  return iter.output();
}

Tensor log_sigmoid_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) {
  auto grad_input = at::empty_like(grad_output);
  auto iter = at::TensorIteratorConfig()
      .add_output(grad_input)
      .add_const_input(input)
      .add_const_input(buffer)
      .add_const_input(grad_output)
      .build();
  log_sigmoid_backward_stub(kCPU, iter);
  return iter.output();
}

Tensor& log_sigmoid_backward_cuda_out(const Tensor& grad_output, const Tensor& input,
                                      const Tensor& buffer, Tensor& grad_input) {
  auto iter = TensorIteratorConfig()
      .add_output(grad_input)
      .add_const_input(input)
      .add_const_input(grad_output)
      .build();
  log_sigmoid_backward_stub(kCUDA, iter);
  return grad_input;
}

Tensor& log_sigmoid_backward_cpu_out(const Tensor& grad_output,
                                     const Tensor& input,
                                     const Tensor& buffer,
                                     Tensor& grad_input) {
  auto iter = TensorIteratorConfig()
      .add_output(grad_input)
      .add_const_input(input)
      .add_const_input(buffer)
      .add_const_input(grad_output)
      .build();
  log_sigmoid_backward_stub(kCPU, iter);
  return grad_input;
}

DEFINE_DISPATCH(GeluKernel);
DEFINE_DISPATCH(GeluBackwardKernel);

} // namespace at::native