#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES

#include <ATen/native/Activation.h>

#include <cmath>

#include <thrust/tuple.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>

namespace at::native {

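// Forward GELU. The exact form is GELU(x) = x * Phi(x)
// = 0.5 * x * (1 + erf(x / sqrt(2))), where Phi is the standard normal CDF.
// The GeluType::Tanh branch uses the common tanh approximation
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// Half/BFloat16 inputs are upcast to float (opmath_t) for the intermediate
// math in both branches.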
void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  if (approximate == GeluType::Tanh) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
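        // kBeta = sqrt(2) * (2 / sqrt(pi)) * 0.5 = sqrt(2 / pi), the scale
        // inside the tanh; kKappa is the cubic coefficient of the
        // approximation.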
        constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5);
        constexpr opmath_t kKappa = 0.044715;
        auto x_cube = static_cast<opmath_t>(x) * static_cast<opmath_t>(x) * static_cast<opmath_t>(x);
        auto inner = kBeta * (static_cast<opmath_t>(x) + kKappa * x_cube);
        return opmath_t(0.5) * static_cast<opmath_t>(x) * (opmath_t(1) + c10::cuda::compat::tanh(inner));
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
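        // kAlpha = 1 / sqrt(2), so erf(x * kAlpha) = 2 * Phi(x) - 1 and the
        // expression below evaluates the exact GELU, x * Phi(x).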
        constexpr opmath_t kAlpha = M_SQRT1_2;
        return static_cast<opmath_t>(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast<opmath_t>(x) * kAlpha));
      });
    });
  }
}

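// Backward GELU: computes dy * d/dx GELU(x).
// Exact: d/dx [x * Phi(x)] = Phi(x) + x * phi(x), with phi the normal PDF.
// Tanh: by the product rule on 0.5 * x * (1 + tanh(inner)),
//   d/dx = 0.5 * (1 + tanh(inner))
//        + 0.5 * x * (1 - tanh(inner)^2) * d(inner)/dx,
// where inner = kBeta * (x + kKappa * x^3).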
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  if (approximate == GeluType::Tanh) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
        it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5);
        constexpr opmath_t kKappa = 0.044715;
        auto x_sq = static_cast<opmath_t>(x) * static_cast<opmath_t>(x);
        auto x_cube = x_sq * static_cast<opmath_t>(x);
        auto inner = kBeta * (static_cast<opmath_t>(x) + kKappa * x_cube);
        auto tanh_inner = c10::cuda::compat::tanh(inner);

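        // Product rule: with left = 0.5 * x and right = 1 + tanh(inner),
        // d/dx (left * right) = left' * right + left * right'.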
        auto left = opmath_t(0.5) * static_cast<opmath_t>(x);
        auto right = opmath_t(1) + tanh_inner;

        auto left_derivative = opmath_t(0.5) * right;

        auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner;
        auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq);
        auto right_derivative = left * tanh_derivative * inner_derivative;

        return static_cast<opmath_t>(dy) * (left_derivative + right_derivative);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
        it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
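        // kBeta = (2 / sqrt(pi)) * (1 / sqrt(2)) * 0.5 = 1 / sqrt(2 * pi),
        // the normalizing constant of the standard normal PDF, so below
        // pdf = phi(x) and cdf = Phi(x).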
        constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5);
        constexpr opmath_t kAlpha = M_SQRT1_2;
        const opmath_t cdf =
            opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast<opmath_t>(x) * kAlpha));
        const opmath_t pdf =
            c10::cuda::compat::exp(
                opmath_t(-0.5) * static_cast<opmath_t>(x) * static_cast<opmath_t>(x)) *
            kBeta;
        return static_cast<opmath_t>(dy) * (cdf + static_cast<opmath_t>(x) * pdf);
      });
    });
  }
}

} // namespace at::native