// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/ActivationGeluKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES

#include <ATen/native/Activation.h>

#include <cmath>

#include <thrust/tuple.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>

namespace at::native {

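// Reference formulas for the forward kernel below:
//
//   exact (erf) form:   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
//   tanh approximation: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
//
// kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2/pi) and kAlpha = M_SQRT1_2 = 1/sqrt(2).
// The arithmetic runs in opmath_t (float when scalar_t is Half or BFloat16), and the
// result is narrowed back to scalar_t when the lambda returns.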
void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  if (approximate == GeluType::Tanh) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5);
        constexpr opmath_t kKappa = 0.044715;
        auto x_cube = static_cast<opmath_t>(x) * static_cast<opmath_t>(x) * static_cast<opmath_t>(x);
        auto inner = kBeta * (static_cast<opmath_t>(x) + kKappa * x_cube);
        return opmath_t(0.5) * static_cast<opmath_t>(x) * (opmath_t(1) + c10::cuda::compat::tanh(inner));
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        constexpr opmath_t kAlpha = M_SQRT1_2;
        return static_cast<opmath_t>(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast<opmath_t>(x) * kAlpha));
      });
    });
  }
}

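// Derivatives implemented by the backward kernel below. With
// inner = kBeta * (x + kKappa * x^3), the tanh branch computes
//
//   d/dx [0.5 * x * (1 + tanh(inner))]
//     = 0.5 * (1 + tanh(inner))                                          // left_derivative
//     + 0.5 * x * (1 - tanh(inner)^2) * kBeta * (1 + 3 * kKappa * x^2)   // right_derivative
//
// and the erf branch uses gelu(x) = x * Phi(x), so
//
//   d/dx [x * Phi(x)] = Phi(x) + x * phi(x),
//
// where Phi is the standard normal CDF (the `cdf` term) and
// phi(x) = exp(-x^2 / 2) / sqrt(2*pi) is its density (the `pdf` term;
// here kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5 = 1 / sqrt(2*pi)).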
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  if (approximate == GeluType::Tanh) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
        it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
          gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
            using opmath_t = at::opmath_type<scalar_t>;
            constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5);
            constexpr opmath_t kKappa = 0.044715;
            auto x_sq = static_cast<opmath_t>(x) * static_cast<opmath_t>(x);
            auto x_cube = x_sq * static_cast<opmath_t>(x);
            auto inner = kBeta * (static_cast<opmath_t>(x) + kKappa * x_cube);
            auto tanh_inner = c10::cuda::compat::tanh(inner);

            auto left = opmath_t(0.5) * static_cast<opmath_t>(x);
            auto right = opmath_t(1) + tanh_inner;

            auto left_derivative = opmath_t(0.5) * right;

            auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner;
            auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq);
            auto right_derivative = left * tanh_derivative * inner_derivative;

            return static_cast<opmath_t>(dy) * (left_derivative + right_derivative);
        });
      });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
        it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
          gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
            using opmath_t = at::opmath_type<scalar_t>;
            constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5);
            constexpr opmath_t kAlpha = M_SQRT1_2;
            const opmath_t cdf =
                opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast<opmath_t>(x) * kAlpha));
            const opmath_t pdf =
                c10::cuda::compat::exp(
                    opmath_t(-0.5) * static_cast<opmath_t>(x) * static_cast<opmath_t>(x)) *
                kBeta;
            return static_cast<opmath_t>(dy) * (cdf + static_cast<opmath_t>(x) * pdf);
          });
        });
  }
}

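// How these impls are typically consumed (a hedged sketch; the wrapper lives in
// the CUDA Activation translation unit rather than in this file, and the exact
// signature below is an assumption, not a quote):
//
//   TORCH_IMPL_FUNC(gelu_out_cuda)(
//       const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*result*/) {
//     GeluCUDAKernelImpl(*this, get_gelutype_enum(approximate));
//   }
//
// get_gelutype_enum() comes from ATen/native/Activation.h, which is included above.
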
} // namespace at::native