// xref: /aosp_15_r20/external/executorch/kernels/optimized/cpu/op_gelu.cpp
// (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#ifdef __aarch64__
#include <arm_neon.h>
#include <sleef.h>
#endif

#include <cinttypes>
#include <cmath>
#include <type_traits>

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
19 namespace torch {
20 namespace executor {
21 namespace native {
22 
23 using Tensor = exec_aten::Tensor;
24 using ScalarType = exec_aten::ScalarType;
25 using string_view = exec_aten::string_view;
26 
27 namespace {
28 
29 /**
30  * Element-wise gelu activation of `input` overwriting `out`.
31  *
32  * 'approximate' specifies the method used to approximation the Gelu function
33  *  either 'none' to not approximate or 'tanh'
34  *
35  * Assumes that the tensors are contiguous, are the same shape, and have the
36  * same dtype. CTYPE should be the C type (like `float` or `double`) that
37  * matches the dtype of the tensors.
38  */
39 template <typename CTYPE>
gelu(executorch::runtime::KernelRuntimeContext & context,const Tensor & input,string_view approximate,Tensor & output)40 void gelu(
41     executorch::runtime::KernelRuntimeContext& context,
42     const Tensor& input,
43     string_view approximate,
44     Tensor& output) {
45   const CTYPE* in_data = input.const_data_ptr<CTYPE>();
46   CTYPE* out_data = output.mutable_data_ptr<CTYPE>();
47   size_t lim = input.numel();
48 
49   // TODO: Add fast path for tanh using sleef's tanh
50   if (approximate == "tanh") {
51     // 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))
52     for (size_t i = 0; i < lim; ++i) {
53       const CTYPE x = in_data[i];
54       const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
55       const CTYPE kKappa = 0.044715;
56       auto x_cube = x * x * x;
57       auto inner = kBeta * (x + kKappa * x_cube);
58       out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::tanh(inner));
59     }
60   } else if (approximate == "none") { // dont appx
61     // GELU(x) = x * Φ(x) where Φ(x) is the is the Cumulative Distribution
62     // Function for Gaussian Distribution.
63 
64 #ifndef __aarch64__
65     for (size_t i = 0; i < lim; ++i) {
66       const CTYPE x = in_data[i];
67       out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
68     }
69 #else
70     size_t i = 0;
71     if (std::is_same<CTYPE, float>::value) {
72       for (; i + 4 < lim; i += 4) {
73         const float32x4_t in =
74             vld1q_f32(static_cast<const float*>(&in_data[i]));
75         const float32x4_t m_sqrt1_2x4 = {
76             M_SQRT1_2, M_SQRT1_2, M_SQRT1_2, M_SQRT1_2};
77         const float32x4_t ones = vmovq_n_f32(1.0);
78         const float32x4_t halves = vmovq_n_f32(0.5);
79         float32x4_t out = Sleef_erff4_u10(vmulq_f32(in, m_sqrt1_2x4));
80         vst1q_f32(
81             static_cast<float*>(&out_data[i]),
82             vmulq_f32(vmulq_f32(vaddq_f32(out, ones), in), halves));
83       }
84     }
85     for (; i < lim; ++i) {
86       const CTYPE x = in_data[i];
87       out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
88     }
89 #endif // __aarch64__
90 
91   } else {
92     ET_KERNEL_CHECK_MSG(
93         context,
94         false,
95         InvalidArgument,
96         ,
97         "Invalid approximation format: %.*s for gelu",
98         static_cast<int>(approximate.length()),
99         approximate.data());
100   }
101 }
102 
103 } // namespace
104 
105 /**
106  * Element-wise Gelu of `input`, overwriting `out`.
107  *
108  * Asserts that all tensors have the same dtype and shape.
109  *
110  * gelu.out(Tensor self, str approximate, *, Tensor(a!) out) -> Tensor(a!)
111  */
opt_gelu_out(KernelRuntimeContext & context,const Tensor & input,string_view approximate,Tensor & out)112 Tensor& opt_gelu_out(
113     KernelRuntimeContext& context,
114     const Tensor& input,
115     string_view approximate,
116     Tensor& out) {
117   (void)context;
118   ET_KERNEL_CHECK(
119       context,
120       tensors_have_same_shape_and_dtype(input, out),
121       InvalidArgument,
122       out);
123 
124 // helper for generating the cases for different data types
125 #define GELU(ctype, dtype)                         \
126   case ScalarType::dtype:                          \
127     gelu<ctype>(context, input, approximate, out); \
128     break;
129 
130   switch (input.scalar_type()) {
131     // TODO support Double as well
132     GELU(float, Float)
133     default:
134       ET_KERNEL_CHECK_MSG(
135           context,
136           false,
137           InvalidArgument,
138           out,
139           "Unhandled dtype %" PRId8,
140           static_cast<int8_t>(input.scalar_type()));
141   }
142 #undef GELU
143 
144   return out;
145 }
146 
147 } // namespace native
148 } // namespace executor
149 } // namespace torch
150