#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Activation.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <ATen/ScalarOps.h>
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
#include <ATen/native/xnnpack/Engine.h>
#endif
#include <ATen/core/DistributionsHelper.h>

#include <c10/util/irange.h>
#include <c10/core/ScalarType.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/celu_native.h>
#include <ATen/ops/clamp.h>
#include <ATen/ops/clamp_min.h>
#include <ATen/ops/elu.h>
#include <ATen/ops/elu_backward_native.h>
#include <ATen/ops/elu_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/gelu_backward_native.h>
#include <ATen/ops/gelu_native.h>
#include <ATen/ops/hardshrink_backward_native.h>
#include <ATen/ops/hardshrink_native.h>
#include <ATen/ops/hardsigmoid_backward_native.h>
#include <ATen/ops/hardsigmoid_native.h>
#include <ATen/ops/hardswish_backward_native.h>
#include <ATen/ops/hardswish_native.h>
#include <ATen/ops/hardtanh.h>
#include <ATen/ops/hardtanh_backward_native.h>
#include <ATen/ops/hardtanh_native.h>
#include <ATen/ops/infinitely_differentiable_gelu_backward_native.h>
#include <ATen/ops/leaky_relu.h>
#include <ATen/ops/leaky_relu_backward.h>
#include <ATen/ops/leaky_relu_backward_native.h>
#include <ATen/ops/leaky_relu_native.h>
#include <ATen/ops/log_sigmoid_backward_native.h>
#include <ATen/ops/log_sigmoid_forward.h>
#include <ATen/ops/log_sigmoid_forward_native.h>
#include <ATen/ops/log_sigmoid_native.h>
#include <ATen/ops/mish_backward_native.h>
#include <ATen/ops/mish_native.h>
#include <ATen/ops/prelu_native.h>
#include <ATen/ops/_prelu_kernel.h>
#include <ATen/ops/_prelu_kernel_native.h>
#include <ATen/ops/_prelu_kernel_backward_native.h>
#include <ATen/ops/relu6_native.h>
#include <ATen/ops/relu_native.h>
#include <ATen/ops/rrelu_native.h>
#include <ATen/ops/rrelu_with_noise.h>
#include <ATen/ops/rrelu_with_noise_backward_native.h>
#include <ATen/ops/rrelu_with_noise_native.h>
#include <ATen/ops/selu_native.h>
#include <ATen/ops/sigmoid.h>
#include <ATen/ops/silu_backward_native.h>
#include <ATen/ops/silu_native.h>
#include <ATen/ops/softplus.h>
#include <ATen/ops/softplus_backward_native.h>
#include <ATen/ops/softplus_native.h>
#include <ATen/ops/softshrink_backward_native.h>
#include <ATen/ops/softshrink_native.h>
#include <ATen/ops/tanh.h>
#include <ATen/ops/threshold_backward_native.h>
#include <ATen/ops/threshold_native.h>

#include <utility>
#endif

namespace at::meta {
// computes `result = self <= threshold ? value : other`
// other is `self` in threshold() and `grad` in threshold_backward()
TORCH_META_FUNC(threshold)(const Tensor& self, const Scalar& threshold, const Scalar& value) {
  const Tensor& result = maybe_get_output();
  build(TensorIteratorConfig()
    .set_check_mem_overlap(false)  // threshold is idempotent, so overlap is okay
    .add_output(result)
    .add_const_input(self)
    .add_const_input(self) // other
    .allow_cpu_scalars(true)
    .promote_inputs_to_common_dtype(true)
    .cast_common_dtype_to_outputs(true)
    .enforce_safe_casting_to_output(true));
}
// computes `result = self <= threshold ? value : other`
// other is `self` in threshold() and `grad` in threshold_backward()
TORCH_META_FUNC(threshold_backward)(const Tensor& grad, const Tensor& self, const Scalar& threshold) {
  const Tensor& gradInput = maybe_get_output();
  build(TensorIteratorConfig()
    .set_check_mem_overlap(false)  // threshold is idempotent, so overlap is okay
    .add_output(gradInput)
    .add_const_input(self)
    .add_const_input(grad)  // other
    .allow_cpu_scalars(true)
    .promote_inputs_to_common_dtype(true)
    .cast_common_dtype_to_outputs(true)
    .enforce_safe_casting_to_output(true));
}

TORCH_META_FUNC(elu) (
  const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale
) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(elu_backward) (
  const Tensor& grad_output,
  const Scalar& alpha,
  const Scalar& scale,
  const Scalar& input_scale,
  bool is_result,
  const Tensor& self_or_result
) {
  TORCH_CHECK(
    !is_result || alpha.to<double>() >= 0.0,
    "In-place elu backward calculation is triggered with a negative slope which is not supported. "
    "This is caused by calling in-place forward function with a negative slope, "
    "please call out-of-place version instead.");

  build_borrowing_binary_op(maybe_get_output(), grad_output, self_or_result);
}

TORCH_META_FUNC(silu) (const Tensor& self) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(silu_backward) (
  const Tensor& grad_output, const Tensor& input
) {
  build_borrowing_binary_op(maybe_get_output(), grad_output, input);
}

TORCH_META_FUNC(mish) (const Tensor& self) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(softplus) (
  const Tensor& self, const Scalar& beta, const Scalar& threshold
) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(softplus_backward) (
  const Tensor& grad_output,
  const Tensor& self,
  const Scalar& beta,
  const Scalar& threshold
) {
  build_borrowing_binary_op(maybe_get_output(), grad_output, self);
}

TORCH_META_FUNC(leaky_relu) (
  const Tensor& self, const Scalar& negval
) {
  build_unary_op(maybe_get_output(), self);
}

// Note: the leaky ReLU backward calculation does not support an in-place call with a negative slope.
// The reason is that for an in-place forward call, the forward result (not the original input) is
// saved into the autograd node. When computing the backward gradient with a negative slope, there is
// then no way to know whether the original input for the current node was positive or negative.
// E.g. if the forward result is 2 and the slope is -0.2, the original input could have been either
// 2 or -10, so a correct backward gradient cannot be recovered in this case.
TORCH_META_FUNC(leaky_relu_backward) (
  const Tensor& grad_output,
  const Tensor& self_or_result,
  const Scalar& negval,
  bool is_result
) {
  TORCH_CHECK(
    !is_result || negval.to<double>() >= 0.0,
    "In-place leaky ReLU backward calculation is triggered with a negative slope which is not supported. "
    "This is caused by calling the in-place forward function with a negative slope; "
    "please call the out-of-place version instead. File an issue at https://github.com/pytorch/pytorch if you "
    "require support for in-place leaky ReLU backward calculation with a negative slope.");

  build_borrowing_binary_op(maybe_get_output(), self_or_result, grad_output);
}

TORCH_META_FUNC(hardsigmoid) (const Tensor& self) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(hardsigmoid_backward) (const Tensor& grad_output, const Tensor& self) {
  build_borrowing_binary_op(maybe_get_output(), grad_output, self);
}

TORCH_META_FUNC(hardshrink) (const Tensor & self, const Scalar& lambd) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(hardshrink_backward) (
  const Tensor & grad, const Tensor & self, const Scalar& lambd
) {
  build_borrowing_binary_op(maybe_get_output(), grad, self);
}

static inline void softshrink_check(const Scalar& lambd) {
  double lamb = lambd.to<double>();
  TORCH_CHECK(lamb >= 0, "lambda must be greater than or equal to 0, but found to be ", lamb, ".");
}

TORCH_META_FUNC(softshrink) (
  const Tensor & self, const Scalar& lambd
) {
  softshrink_check(lambd);
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(softshrink_backward) (
  const Tensor & grad, const Tensor & self, const Scalar& lambd
) {
  build_borrowing_binary_op(maybe_get_output(), grad, self);
}

TORCH_META_FUNC(gelu) (const Tensor & self, c10::string_view approximate) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(gelu_backward) (
  const Tensor& grad, const Tensor& self, c10::string_view approximate
) {
  build_borrowing_binary_op(maybe_get_output(), grad, self);
}

} // namespace at::meta

namespace at::native {

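// The SELU constants below are the fixed alpha/scale values from "Self-Normalizing Neural
// Networks" (Klambauer et al., 2017); selu() and selu_() further down simply forward to
// elu(x, SELU_ALPHA, SELU_SCALE), i.e. selu(x) = SELU_SCALE * (x > 0 ? x : SELU_ALPHA * (exp(x) - 1)).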
static const double SELU_ALPHA = 1.6732632423543772848170429916717;
static const double SELU_SCALE = 1.0507009873554804934193349852946;

DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);
DEFINE_DISPATCH(softplus_stub);
DEFINE_DISPATCH(softplus_backward_stub);
DEFINE_DISPATCH(log_sigmoid_cpu_stub);
DEFINE_DISPATCH(log_sigmoid_backward_stub);
DEFINE_DISPATCH(threshold_stub);
DEFINE_DISPATCH(hardtanh_backward_stub);
DEFINE_DISPATCH(hardsigmoid_stub);
DEFINE_DISPATCH(hardsigmoid_backward_stub);
DEFINE_DISPATCH(hardswish_stub);
DEFINE_DISPATCH(hardswish_backward_stub);
DEFINE_DISPATCH(hardshrink_stub);
DEFINE_DISPATCH(softshrink_stub);
DEFINE_DISPATCH(shrink_backward_stub);
DEFINE_DISPATCH(leaky_relu_stub);
DEFINE_DISPATCH(leaky_relu_backward_stub);
DEFINE_DISPATCH(silu_stub);
DEFINE_DISPATCH(silu_backward_stub);
DEFINE_DISPATCH(mish_stub);
DEFINE_DISPATCH(mish_backward_stub);
DEFINE_DISPATCH(prelu_stub);
DEFINE_DISPATCH(prelu_backward_stub);

TORCH_IMPL_FUNC(elu_out) (
  const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result
) {
  elu_stub(device_type(), *this, alpha, scale, input_scale);
}

TORCH_IMPL_FUNC(elu_backward_out) (
  const Tensor& grad_output,
  const Scalar& alpha,
  const Scalar& scale,
  const Scalar& input_scale,
  bool is_result,
  const Tensor& self_or_result,
  const Tensor& grad_input
) {
  elu_backward_stub(device_type(), *this, alpha, scale, input_scale, is_result);
}

TORCH_IMPL_FUNC(silu_out) (
  const Tensor& self, const Tensor& result
) {
  silu_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(silu_backward_out) (
  const Tensor& grad_output, const Tensor& input, const Tensor& grad_input
) {
  silu_backward_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(mish_out) (
  const Tensor& self, const Tensor& result
) {
  mish_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(softplus_out) (
  const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result
) {
  softplus_stub(device_type(), *this, beta, threshold);
}

TORCH_IMPL_FUNC(softplus_backward_out) (
  const Tensor& grad_output,
  const Tensor& self,
  const Scalar& beta,
  const Scalar& threshold,
  const Tensor& grad_input
) {
  softplus_backward_stub(device_type(), *this, beta, threshold);
}

TORCH_IMPL_FUNC(leaky_relu_out) (
  const Tensor& self, const Scalar& negval, const Tensor& result
) {
  leaky_relu_stub(device_type(), *this, negval);
}

TORCH_IMPL_FUNC(leaky_relu_backward_out) (
  const Tensor& grad_output,
  const Tensor& self_or_result,
  const Scalar& negval,
  bool is_result,
  const Tensor& grad_input
) {
  leaky_relu_backward_stub(device_type(), *this, negval);
}

TORCH_IMPL_FUNC(hardsigmoid_out) (
  const Tensor& self, const Tensor& result
) {
  hardsigmoid_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(hardsigmoid_backward_out) (
  const Tensor& grad_output, const Tensor& self, const Tensor& grad_input
) {
  hardsigmoid_backward_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(hardshrink_out) (
  const Tensor & self, const Scalar& lambd, const Tensor& result
) {
  hardshrink_stub(device_type(), *this, lambd);
}

TORCH_IMPL_FUNC(hardshrink_backward_out) (
  const Tensor & grad, const Tensor & self, const Scalar& lambd, const Tensor& grad_input
) {
  shrink_backward_stub(device_type(), *this, lambd);
}

TORCH_IMPL_FUNC(softshrink_out) (
  const Tensor & self, const Scalar& lambd, const Tensor& result
) {
  softshrink_stub(device_type(), *this, lambd);
}

TORCH_IMPL_FUNC(softshrink_backward_out) (
  const Tensor & grad, const Tensor & self, const Scalar& lambd, const Tensor& grad_input
) {
  shrink_backward_stub(device_type(), *this, lambd);
}

#if AT_MKLDNN_ENABLED()
static bool use_mkldnn(const Tensor& input) {
  if (!at::globalContext().userEnabledMkldnn()) {
    return false;
  }
  if (!input.is_contiguous() || input.numel() <= 1) {
    return false;
  }
  return (input.is_mkldnn()) || // input is mkldnn Tensor
    (input.device().is_cpu() &&
    (((input.scalar_type() == kBFloat16) && mkldnn_bf16_device_check()) ||
    (input.scalar_type() == kFloat))); // input is dense layout and bfloat16/float32
}
#endif

TORCH_IMPL_FUNC(gelu_out_cpu) (
  const Tensor& self, c10::string_view approximate, const Tensor& result
) {
  auto approximate_type = get_gelutype_enum(approximate);
#if AT_MKLDNN_ENABLED()
  if (use_mkldnn(self) && (approximate_type == GeluType::None)) {
    const ideep::tensor& x = itensor_from_tensor(self, /*from_const_data_ptr*/true);
    ideep::tensor y = itensor_from_tensor(result);
    ideep::eltwise_forward::compute(
      x, y, ideep::algorithm::eltwise_gelu_erf, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
#ifdef __aarch64__
  } else if (use_mkldnn(self) && (approximate_type == GeluType::Tanh)) {
    const ideep::tensor& x = itensor_from_tensor(self);
    ideep::tensor y = itensor_from_tensor(result);
    ideep::eltwise_forward::compute(
      x, y, ideep::algorithm::eltwise_gelu_tanh, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
#endif  // ifdef __aarch64__
  } else {
    GeluKernel(kCPU, *this, approximate_type);
  }
#else
  GeluKernel(kCPU, *this, approximate_type);
#endif
}

TORCH_IMPL_FUNC(gelu_backward_out_cpu) (
  const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input
) {
  auto approximate_type = get_gelutype_enum(approximate);
#if AT_MKLDNN_ENABLED()
  if (use_mkldnn(self) && (approximate_type == GeluType::None)) {
    const ideep::tensor& x = itensor_from_tensor(self, /*from_const_data_ptr*/true);
    ideep::tensor grady = itensor_from_tensor(grad, /*from_const_data_ptr*/true);
    ideep::tensor gradx = itensor_from_tensor(grad_input);
    ideep::eltwise_backward::compute(x, grady, gradx,
      ideep::algorithm::eltwise_gelu_erf, /*alpha*/ 0.0);
  } else {
    GeluBackwardKernel(kCPU, *this, approximate_type);
  }
#else
  GeluBackwardKernel(kCPU, *this, approximate_type);
#endif
}

Tensor hardtanh(const Tensor& self, const Scalar& min, const Scalar& max) {
  Tensor result = at::empty_like(self);
  return at::hardtanh_out(result, self, min, max);
}

Tensor& hardtanh_out(const Tensor& self, const Scalar& min, const Scalar& max, Tensor& result) {
  TORCH_CHECK(self.scalar_type() != at::kBool,
  "Bool inputs not supported for hardtanh");
  // preserve legacy behavior of boundaries not causing type promotion
  Scalar min_, max_;
  if (at::isIntegralType(self.scalar_type(), /*include_bool*/false)) {
    int64_t minval = min.toLong();
    int64_t maxval = max.toLong();
    TORCH_CHECK(self.dtype() != at::kByte || (minval >= 0 &&
       maxval >= 0), "cannot do hardtanh on an unsigned type with negative limits");
    min_ = minval;
    max_ = maxval;
  } else {
    min_ = min;
    max_ = max;
  }
  return at::clamp_out(result, self, min_, max_);
}

Tensor& hardtanh_(Tensor& self, const Scalar& min, const Scalar& max) {
  return at::hardtanh_out(self, self, min, max);
}

Tensor& hardtanh_backward_out(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max, Tensor& grad_input) {
  auto iter = TensorIterator::borrowing_binary_op(grad_input, grad_output, self);
  hardtanh_backward_stub(iter.device_type(), iter, min, max);
  return grad_input;
}

Tensor hardtanh_backward(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max) {
  Tensor result;
  auto iter = TensorIterator::borrowing_binary_op(result, grad_output, self);
  hardtanh_backward_stub(iter.device_type(), iter, min, max);
  return iter.output();
}

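// hardswish(x) = x * min(max(x + 3, 0), 6) / 6, i.e. x * relu6(x + 3) / 6. The elementwise
// kernel lives in hardswish_stub, with an optional XNNPACK fast path on mobile builds.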
Tensor hardswish(const Tensor& self) {
  #if defined(C10_MOBILE) && defined(USE_XNNPACK)
  if (xnnpack::use_hardswish(self)) {
    return xnnpack::hardswish(self);
  }
  #endif
  Tensor result;
  auto iter = TensorIterator::unary_op(result, self);
  hardswish_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor& hardswish_out(const Tensor& self, Tensor& result) {
  auto iter = TensorIterator::unary_op(result, self);
  hardswish_stub(iter.device_type(), iter);
  return result;
}

Tensor& hardswish_(Tensor& self) {
  #if defined(C10_MOBILE) && defined(USE_XNNPACK)
  if (xnnpack::use_hardswish(self)) {
    xnnpack::hardswish_(self);
    return self;
  }
  #endif
  auto iter = TensorIterator::unary_op(self, self);
  hardswish_stub(iter.device_type(), iter);
  return self;
}

Tensor hardswish_backward(const Tensor& grad_output, const Tensor& self) {
  Tensor grad_input;
  auto iter = TensorIterator::borrowing_binary_op(grad_input, grad_output, self);
  hardswish_backward_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor relu(const Tensor & self) {
  TORCH_CHECK(self.scalar_type() != at::kBool, "Boolean inputs not supported for relu");
  return at::clamp_min(self, 0);
}

Tensor & relu_(Tensor & self) {
  TORCH_CHECK(self.scalar_type() != at::kBool, "Boolean inputs not supported for relu");
  return at::clamp_min_(self, 0);
}

Tensor selu(const Tensor & self) {
  return at::elu(self, SELU_ALPHA, SELU_SCALE);
}

Tensor relu6(const Tensor & self) {
  return at::hardtanh(self, /*min_val=*/0, /*max_val=*/6);
}

Tensor & selu_(Tensor & self) {
  return at::elu_(self, SELU_ALPHA, SELU_SCALE);
}

Tensor & relu6_(Tensor & self) {
  return at::hardtanh_(self, /*min_val=*/0, /*max_val=*/6);
}

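// celu(x, alpha) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)). Since
// elu(x, alpha, scale, input_scale) = scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1))),
// CELU is implemented below as elu(x, alpha, 1.0, 1.0 / alpha).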
Tensor celu(const Tensor & self, const Scalar& alpha) {
  TORCH_CHECK(alpha.to<double>() != 0,
      "ZeroDivisionError: alpha cannot be 0 for CELU");
  double inv_alpha = 1. / alpha.to<double>();
  return at::elu(self, alpha, Scalar(1.0), Scalar(inv_alpha));
}

Tensor & celu_(Tensor & self, const Scalar& alpha) {
  TORCH_CHECK(alpha.to<double>() != 0,
      "ZeroDivisionError: alpha cannot be 0 for CELU");
  double inv_alpha = 1. / alpha.to<double>();
  return at::elu_(self, alpha, Scalar(1.0), Scalar(inv_alpha));
}

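// silu(x) = x * sigmoid(x), so by the product rule
//   d/dx silu(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
//                = sigmoid(x) * (1 + x * (1 - sigmoid(x))),
// which is the composite-op expression used below.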
Tensor math_silu_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  auto input_sigmoid = at::sigmoid(input);
  return grad_output * (input_sigmoid * (1 + input * (1 - input_sigmoid)));
}

Tensor mish_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  Tensor grad_input = at::empty({0}, input.options());
  auto iter = TensorIterator::binary_op(grad_input, grad_output, input);
  mish_backward_stub(iter.device_type(), iter);
  return grad_input;
}

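// mish(x) = x * tanh(softplus(x)). With t = tanh(softplus(x)) and softplus'(x) = sigmoid(x),
//   d/dx mish(x) = t + x * sigmoid(x) * (1 - t^2),
// which matches the composite-op expression used below.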
Tensor math_mish_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  auto input_tanh_softplus = at::tanh(at::softplus(input));
  auto input_sigmoid = at::sigmoid(input);
  return grad_output * (input_tanh_softplus + (input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)));
}

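// RReLU training mode: each negative input element is scaled by a factor r drawn from
// U(lower, upper); the sampled factor is written into `noise` (1 for non-negative elements)
// so that rrelu_with_noise_backward can compute the gradient as noise * grad_output.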
template <typename scalar_t>
inline void _rrelu_with_noise_train(
    Tensor& output,
    const Tensor& input,
    const Tensor& noise,
    const Scalar& lower_,
    const Scalar& upper_,
    std::optional<Generator> generator) {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t lower = lower_.to<opmath_t>();
  opmath_t upper = upper_.to<opmath_t>();
  Tensor tmp_tensor = output.contiguous();
  scalar_t* output_data = tmp_tensor.data_ptr<scalar_t>();
  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
  scalar_t* noise_data = noise.data_ptr<scalar_t>();
  auto gen  = at::get_generator_or_default<CPUGeneratorImpl>(generator, detail::getDefaultCPUGenerator());
  std::lock_guard<std::mutex> lock(gen->mutex_);
  for (const auto i : c10::irange(input.numel())) {
    if (input_data[i] <= 0) {
      at::uniform_real_distribution<double> uniform(lower, upper);
      const opmath_t r = (opmath_t)uniform(gen);
      output_data[i] = input_data[i] * r;
      noise_data[i] = r;
    } else {
      noise_data[i] = 1;
      output_data[i] = input_data[i];
    }
  }
  if (!output.is_contiguous()) {
    output.copy_(tmp_tensor);
  }
}

Tensor& rrelu_with_noise_out_cpu(const Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    std::optional<Generator> generator,
    Tensor& output) {
  TORCH_CHECK(self.sym_sizes() == noise.sym_sizes(), "noise tensor shape must match self tensor shape. Got self.shape = ", self.sym_sizes(), " noise.shape = ", noise.sym_sizes());
  if (training) {
    AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "rrelu_with_noise_out_cpu", [&] {
      _rrelu_with_noise_train<scalar_t>(output, self.contiguous(), noise, lower, upper, generator);
    });
    return output;
  } else {
    auto lower_tensor = scalar_to_tensor(lower);
    auto upper_tensor = scalar_to_tensor(upper);
    auto negative = (lower_tensor + upper_tensor) / 2;
    Scalar negative_slope = negative.item();
    return at::leaky_relu_out(output, self, negative_slope);
  }
}

Tensor rrelu_with_noise_cpu(
    const Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    std::optional<Generator> generator) {
  auto output = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::native::rrelu_with_noise_out_cpu(
      self, noise, lower, upper, training, std::move(generator), output);
}

Tensor& rrelu_with_noise_cpu_(
    Tensor& self,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    std::optional<Generator> generator) {
  return at::native::rrelu_with_noise_out_cpu(
      self, noise, lower, upper, training, std::move(generator), self);
}

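// In evaluation mode RReLU degenerates to leaky_relu with the fixed slope (lower + upper) / 2,
// so the backward below reuses leaky_relu_backward; in training mode the per-element random
// factors recorded in `noise` during the forward pass are the exact local derivatives.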
Tensor rrelu_with_noise_backward(
    const Tensor& grad_output,
    const Tensor& self_or_result,
    const Tensor& noise,
    const Scalar& lower,
    const Scalar& upper,
    bool training,
    bool is_result) {
  if (training) {
    return noise * grad_output;
  } else {
    auto l = lower.toDouble();
    auto u = upper.toDouble();
    auto mid = (l + u) / 2.;
    return at::leaky_relu_backward(grad_output, self_or_result, mid, is_result);
  }
}

Tensor rrelu(const Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
  TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
  return at::rrelu_with_noise(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
}

Tensor & rrelu_(Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
  TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
  return at::rrelu_with_noise_(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, std::move(generator));
}

TORCH_IMPL_FUNC(threshold_out)(const Tensor& self, const Scalar& threshold, const Scalar& value, const Tensor& result) {
  threshold_stub(device_type(), *this, threshold, value);
}

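// Reuses threshold_stub with value = 0: the meta function built the iterator over (self, grad),
// so this computes gradInput = self <= threshold ? 0 : grad.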
TORCH_IMPL_FUNC(threshold_backward_out)(const Tensor& grad, const Tensor& self, const Scalar& threshold, const Tensor& gradInput) {
  threshold_stub(device_type(), *this, threshold, 0);
}

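// For a non-scalar weight of C elements and an N-D input, the weight is reshaped below to
// [1, C, 1, ..., 1] so that it broadcasts elementwise against the channel dimension (dim 1)
// of the input, e.g. a weight of shape [C] paired with an NCHW input becomes [1, C, 1, 1].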
Tensor prelu(const Tensor& self, const Tensor& weight_) {
  TORCH_INTERNAL_ASSERT(weight_.defined());
  auto self_dim = self.dim();
  TORCH_CHECK(self.scalar_type() == weight_.scalar_type(),
              "prelu: Type promotion is not supported. Got ",
              self.scalar_type(), " and ", weight_.scalar_type());
  if (weight_.sym_numel() != 1) {
    TORCH_CHECK(self_dim > 0, "prelu: zero-dim input tensors are not allowed.");

    auto channel_size = self_dim > 1 ? self.sym_size(1) : 1; // channel_size defaults to 1
    TORCH_CHECK(channel_size == weight_.sym_numel(),
      "Mismatch between the number of parameters and the input channel size. Found number of parameters = ", weight_.numel(),
      " and channel size = ", channel_size, ".");
  }

  TORCH_CHECK(
    weight_.dim() <= 1,
    "prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = ", weight_.dim());
  // Adjust weight to broadcast over self and have weight.ndim == self.ndim
  auto weight = weight_;
  if (self_dim != weight.dim()) {
    SymDimVector dim_w(self_dim, 1);
    if (self_dim > 1) {
      dim_w[1] = weight_.sym_numel();
    }
    // This will always be a view in CPU/CUDA, but some backends
    // like MKLDNN do not support views
    weight = weight.reshape_symint(dim_w);
  }
  return at::_prelu_kernel(self, weight);
}

Tensor _prelu_kernel(const Tensor& self, const Tensor& weight) {
  // Weight broadcasts over self and they have the same dtype
  auto result = at::empty_like(self);
  auto iter = TensorIteratorConfig()
    .add_output(result)
    .add_const_input(self)
    .add_const_input(weight)
    .build();
  prelu_stub(iter.device_type(), iter);
  return result;
}

std::tuple<Tensor, Tensor> _prelu_kernel_backward(const Tensor& grad_out, const Tensor& self, const Tensor& weight) {
  Tensor grad_self = at::empty({0}, self.options());
  Tensor grad_weight = at::empty({0}, weight.options());
  auto iter = TensorIteratorConfig()
    .add_output(grad_self)
    .add_output(grad_weight)
    .add_const_input(self)
    .add_const_input(weight)
    .add_const_input(grad_out)
    .build();
  prelu_backward_stub(iter.device_type(), iter);
  return {grad_self, grad_weight};
}

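// GELU(x) = x * Phi(x), where Phi is the standard normal CDF, so
//   d/dx GELU(x) = Phi(x) + x * phi(x)  with  phi(x) = exp(-x^2 / 2) / sqrt(2 * pi).
// kAlpha below equals 1 / sqrt(2 * pi): M_2_SQRTPI * M_SQRT1_2 * 0.5 = (2 / sqrt(pi)) * (1 / sqrt(2)) * 0.5.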
Tensor infinitely_differentiable_gelu_backward(
    const Tensor& grad,
    const Tensor& self) {
  constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
  Tensor cdf = (1.0 + (self * M_SQRT1_2).erf_()).mul_(0.5);
  Tensor pdf = (-0.5 * self * self).exp_();
  return cdf.addcmul_(self, pdf, kAlpha).mul_(grad);
}

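// Numerically stable evaluation: log_sigmoid(x) = -softplus(-x) = min(x, 0) - log1p(exp(-|x|)).
// The forward also returns a `buffer` tensor holding an intermediate that only the CPU backward
// kernel consumes; the CUDA backward ignores it (see log_sigmoid_backward_cuda below).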
std::tuple<Tensor, Tensor> log_sigmoid_forward_cpu(const Tensor& input) {
  auto result = at::empty_like(input, at::MemoryFormat::Contiguous);
  auto buffer = at::empty_like(input, at::MemoryFormat::Contiguous);
  log_sigmoid_cpu_stub(kCPU, result, buffer, input.contiguous());
  return std::make_tuple(result, buffer);
}

std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_cpu(const Tensor& input, Tensor& result, Tensor& buffer) {
  result.resize_as_(input);
  buffer.resize_as_(input, at::MemoryFormat::Contiguous);
  TORCH_CHECK(buffer.is_contiguous(), "Contiguous buffer required for log_sigmoid with out parameter");
  Tensor result_tmp = result.is_contiguous() ? result : at::empty_like(result, at::MemoryFormat::Contiguous);
  log_sigmoid_cpu_stub(kCPU, result_tmp, buffer, input.contiguous());
  if (!result.is_contiguous()) {
    result.copy_(result_tmp);
  }
  return std::forward_as_tuple(result, buffer);
}

Tensor & log_sigmoid_out(const Tensor & self, Tensor & output) {
  Tensor buffer = at::empty({0}, self.options());
  return std::get<0>(at::log_sigmoid_forward_out(output, buffer, self));
}

Tensor log_sigmoid(const Tensor & self) {
  return std::get<0>(at::log_sigmoid_forward(self));
}

Tensor log_sigmoid_backward_cuda(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) {
  auto grad_input = at::empty_like(grad_output);
  // NOTE: buffer is only used by CPU dispatch, we just ignore it here
  auto iter = at::TensorIteratorConfig()
      .add_output(grad_input)
      .add_const_input(input)
      .add_const_input(grad_output)
      .build();
  log_sigmoid_backward_stub(kCUDA, iter);
  return iter.output();
}

Tensor log_sigmoid_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) {
  auto grad_input = at::empty_like(grad_output);
  auto iter = at::TensorIteratorConfig()
      .add_output(grad_input)
      .add_const_input(input)
      .add_const_input(buffer)
      .add_const_input(grad_output)
      .build();
  log_sigmoid_backward_stub(kCPU, iter);
  return iter.output();
}

Tensor& log_sigmoid_backward_cuda_out(const Tensor& grad_output, const Tensor& input,
                                      const Tensor& buffer, Tensor& grad_input) {
  auto iter = TensorIteratorConfig()
      .add_output(grad_input)
      .add_const_input(input)
      .add_const_input(grad_output)
      .build();
  log_sigmoid_backward_stub(kCUDA, iter);
  return grad_input;
}

Tensor& log_sigmoid_backward_cpu_out(const Tensor& grad_output,
    const Tensor& input,
    const Tensor& buffer,
    Tensor& grad_input) {
  auto iter = TensorIteratorConfig()
      .add_output(grad_input)
      .add_const_input(input)
      .add_const_input(buffer)
      .add_const_input(grad_output)
      .build();
  log_sigmoid_backward_stub(kCPU, iter);
  return grad_input;
}

DEFINE_DISPATCH(GeluKernel);
DEFINE_DISPATCH(GeluBackwardKernel);

}  // namespace at::native