1 #define TORCH_ASSERT_NO_OPERATORS
2 #define _USE_MATH_DEFINES
3
4 #include <ATen/native/Activation.h>
5
6 #include <cmath>
7
8 #include <thrust/tuple.h>
9
10 #include <ATen/AccumulateType.h>
11 #include <ATen/Dispatch.h>
12 #include <ATen/core/TensorBase.h>
13 #include <c10/core/Scalar.h>
14 #include <c10/cuda/CUDAMathCompat.h>
15 #include <ATen/cuda/ApplyGridUtils.cuh>
16 #include <ATen/cuda/detail/OffsetCalculator.cuh>
17 #include <ATen/native/cuda/Loops.cuh>
18
19 namespace at::native {
20
21 // -----------------------------------
22 // log_sigmoid forward
23 // -----------------------------------
24
launch_log_sigmoid_forward_kernel(TensorIteratorBase & iter)25 void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) {
26 AT_DISPATCH_FLOATING_TYPES_AND2(
27 kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_forward_cuda", [&] {
28 using opmath_t = at::opmath_type<scalar_t>;
29
30 gpu_kernel(iter, [] GPU_LAMBDA(scalar_t in_) -> scalar_t {
31 const opmath_t in = in_;
32 const auto min = std::min(opmath_t(0), in);
33 const auto z = std::exp(-std::abs(in));
34 return min - std::log1p(z);
35 });
36 });
37 }
38
39 namespace {
40 // -----------------------------------
41 // log_sigmoid backward
42 // -----------------------------------
log_sigmoid_backward_kernel(TensorIterator & iter)43 void log_sigmoid_backward_kernel(TensorIterator& iter) {
44 AT_DISPATCH_FLOATING_TYPES_AND2(
45 kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_backward_cuda", [&] {
46 using opmath_t = at::opmath_type<scalar_t>;
47 gpu_kernel(
48 iter, [] GPU_LAMBDA(scalar_t in_, scalar_t grad_out_) -> scalar_t {
49 const opmath_t in = in_;
50 const opmath_t grad_out = grad_out_;
51
52 auto in_negative = in < opmath_t(0);
53 auto max_deriv = in_negative ? opmath_t(1) : opmath_t(0);
54 auto sign = in_negative ? opmath_t(1) : -opmath_t(1);
55 const auto z = std::exp(-std::abs(in));
56 return grad_out * (max_deriv - sign * (z / (opmath_t(1) + z)));
57 });
58 });
59 }
60 } // namespace
61
// Binds log_sigmoid_backward_kernel (defined in the anonymous namespace above)
// as the implementation of log_sigmoid_backward_stub for this translation unit.
// NOTE(review): presumably this fills the CUDA slot of the stub, which is
// declared in ATen/native/Activation.h. The forward pass is not registered
// here; it is exposed as the plain function launch_log_sigmoid_forward_kernel
// instead. Confirm against the caller.
REGISTER_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_kernel);
63
64 } // namespace at::native
65