/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#ifdef __aarch64__
#include <arm_neon.h>
#include <sleef.h>
#endif

#include <cinttypes> // for PRId8 in error messages
#include <cmath>
#include <type_traits> // for std::is_same

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using string_view = exec_aten::string_view;

namespace {

/**
 * Element-wise GELU activation of `input`, overwriting `output`.
 *
 * `approximate` selects how the GELU function is computed: "none" evaluates
 * the exact erf-based formula, and "tanh" uses the tanh approximation.
 *
 * Assumes that the tensors are contiguous, are the same shape, and have the
 * same dtype. CTYPE should be the C type (like `float` or `double`) that
 * matches the dtype of the tensors.
 */
template <typename CTYPE>
void gelu(
    executorch::runtime::KernelRuntimeContext& context,
    const Tensor& input,
    string_view approximate,
    Tensor& output) {
  const CTYPE* in_data = input.const_data_ptr<CTYPE>();
  CTYPE* out_data = output.mutable_data_ptr<CTYPE>();
  size_t lim = input.numel();

  // TODO: Add a fast path for "tanh" using Sleef's vectorized tanh
  // (e.g. Sleef_tanhf4_u10), mirroring the erf path below.
  if (approximate == "tanh") {
    // 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
    const CTYPE kKappa = 0.044715;
    for (size_t i = 0; i < lim; ++i) {
      const CTYPE x = in_data[i];
      auto x_cube = x * x * x;
      auto inner = kBeta * (x + kKappa * x_cube);
      out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::tanh(inner));
    }
  } else if (approximate == "none") { // exact, no approximation
    // GELU(x) = x * Φ(x), where Φ(x) is the Cumulative Distribution
    // Function of the standard Gaussian distribution.

#ifndef __aarch64__
    for (size_t i = 0; i < lim; ++i) {
      const CTYPE x = in_data[i];
      out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
    }
#else
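    // aarch64 fast path: compute GELU four floats at a time with NEON
    // intrinsics, delegating erf to Sleef's vectorized Sleef_erff4_u10
    // (the _u10 suffix denotes a maximum error of 1.0 ULP).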
    size_t i = 0;
    // `if constexpr` keeps the float-only NEON code out of any non-float
    // instantiation (e.g. a future double path).
    if constexpr (std::is_same<CTYPE, float>::value) {
      const float32x4_t m_sqrt1_2x4 = vmovq_n_f32(M_SQRT1_2);
      const float32x4_t ones = vmovq_n_f32(1.0);
      const float32x4_t halves = vmovq_n_f32(0.5);
      // `i + 4 <= lim` lets the vector loop consume every full group of
      // four elements; the scalar loop below handles any remainder.
      for (; i + 4 <= lim; i += 4) {
        const float32x4_t in = vld1q_f32(&in_data[i]);
        const float32x4_t erf = Sleef_erff4_u10(vmulq_f32(in, m_sqrt1_2x4));
        vst1q_f32(
            &out_data[i],
            vmulq_f32(vmulq_f32(vaddq_f32(erf, ones), in), halves));
      }
    }
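    // Scalar tail: finishes any elements the vector loop left over, and
    // handles all elements when CTYPE is not float.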
    for (; i < lim; ++i) {
      const CTYPE x = in_data[i];
      out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
    }
#endif // __aarch64__

  } else {
    ET_KERNEL_CHECK_MSG(
        context,
        false,
        InvalidArgument,
        , // empty retval argument: gelu() returns void
        "Invalid approximation format: %.*s for gelu",
        static_cast<int>(approximate.length()),
        approximate.data());
  }
}

} // namespace

/**
 * Element-wise GELU of `input`, overwriting `out`.
 *
 * Checks that `input` and `out` have the same shape and dtype.
 *
 * gelu.out(Tensor self, str approximate, *, Tensor(a!) out) -> Tensor(a!)
 */
Tensor& opt_gelu_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    string_view approximate,
    Tensor& out) {
  ET_KERNEL_CHECK(
      context,
      tensors_have_same_shape_and_dtype(input, out),
      InvalidArgument,
      out);

// Helper for generating the switch cases for the supported data types.
#define GELU(ctype, dtype)                           \
  case ScalarType::dtype:                            \
    gelu<ctype>(context, input, approximate, out);   \
    break;

  switch (input.scalar_type()) {
    // TODO: support Double as well
    GELU(float, Float)
    default:
      ET_KERNEL_CHECK_MSG(
          context,
          false,
          InvalidArgument,
          out,
          "Unhandled dtype %" PRId8,
          static_cast<int8_t>(input.scalar_type()));
  }
#undef GELU

  return out;
}
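
// A minimal usage sketch (illustrative only, not part of this file; assumes a
// runtime context and float tensors built with a test helper such as
// torch::executor::testing::TensorFactory):
//
//   executorch::runtime::KernelRuntimeContext ctx;
//   testing::TensorFactory<ScalarType::Float> tf;
//   Tensor in = tf.make({4}, {-1.0f, 0.0f, 1.0f, 2.0f});
//   Tensor out = tf.zeros({4});
//   opt_gelu_out(ctx, in, "none", out); // exact erf-based GELU
//   opt_gelu_out(ctx, in, "tanh", out); // tanh approximation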

} // namespace native
} // namespace executor
} // namespace torch