// aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu
#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES

#include <ATen/native/Activation.h>

#include <cmath>

#include <thrust/tuple.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>

namespace at::native {

// -----------------------------------
// log_sigmoid forward
// -----------------------------------

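// Computes log(sigmoid(x)) elementwise. The naive form
// log(1 / (1 + exp(-x))) overflows exp() for large negative inputs, so the
// kernel uses the standard numerically stable rewrite
//   log_sigmoid(x) = min(0, x) - log1p(exp(-|x|))
// which agrees with the naive form for both signs of x.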
void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_forward_cuda", [&] {
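        // opmath_t is float when scalar_t is Half or BFloat16, so the
        // arithmetic below runs at float precision and only the final
        // result is narrowed back to scalar_t.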
        using opmath_t = at::opmath_type<scalar_t>;

        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t in_) -> scalar_t {
          const opmath_t in = in_;
          const auto min = std::min(opmath_t(0), in);
          const auto z = std::exp(-std::abs(in));
          return min - std::log1p(z);
        });
      });
}

namespace {
// -----------------------------------
// log_sigmoid backward
// -----------------------------------
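// The derivative of log_sigmoid(x) is sigmoid(-x) = 1 / (1 + exp(x)). To
// keep exp() from overflowing for large |x|, the kernel evaluates it in
// terms of z = exp(-|x|):
//   x >= 0:  sigmoid(-x) = z / (1 + z)
//   x <  0:  sigmoid(-x) = 1 - z / (1 + z)
// max_deriv and sign below select between the two cases arithmetically
// rather than with an explicit if/else.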
void log_sigmoid_backward_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_backward_cuda", [&] {
        using opmath_t = at::opmath_type<scalar_t>;
        gpu_kernel(
            iter, [] GPU_LAMBDA(scalar_t in_, scalar_t grad_out_) -> scalar_t {
              const opmath_t in = in_;
              const opmath_t grad_out = grad_out_;

              auto in_negative = in < opmath_t(0);
              auto max_deriv = in_negative ? opmath_t(1) : opmath_t(0);
              auto sign = in_negative ? opmath_t(1) : -opmath_t(1);
              const auto z = std::exp(-std::abs(in));
              return grad_out * (max_deriv - sign * (z / (opmath_t(1) + z)));
            });
      });
}
} // namespace

REGISTER_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_kernel);
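// Only the backward kernel is registered through the dispatch stub;
// launch_log_sigmoid_forward_kernel is declared in ATen/native/Activation.h
// (included above) and is called directly by the log_sigmoid forward
// implementation.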

} // namespace at::native