xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_gelu.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cmath>
10 
11 #include <executorch/kernels/portable/cpu/math_constants.h>
12 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
13 #include <executorch/kernels/portable/cpu/util/functional_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
19 
20 using Tensor = exec_aten::Tensor;
21 using ScalarType = exec_aten::ScalarType;
22 using string_view = exec_aten::string_view;
23 
gelu_out(KernelRuntimeContext & ctx,const Tensor & in,string_view approximate,Tensor & out)24 Tensor& gelu_out(
25     KernelRuntimeContext& ctx,
26     const Tensor& in,
27     string_view approximate,
28     Tensor& out) {
29   (void)ctx;
30 
31   ET_KERNEL_CHECK(
32       ctx, check_gelu_args(in, approximate, out), InvalidArgument, out);
33 
34   ET_KERNEL_CHECK(
35       ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
36 
37   ET_KERNEL_CHECK(
38       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
39 
40   ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
41     if (approximate == "tanh") {
42       apply_unary_map_fn(
43           [](const CTYPE x) {
44             if (x == -std::numeric_limits<CTYPE>::infinity()) {
45               return static_cast<CTYPE>(0.0);
46             } else if (x == std::numeric_limits<CTYPE>::infinity()) {
47               return std::numeric_limits<CTYPE>::infinity();
48             }
49             const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
50             const CTYPE kKappa = static_cast<float>(0.044715);
51 
52             const CTYPE x_cubed = x * x * x;
53             const CTYPE inner = kBeta * (x + kKappa * x_cubed);
54             const CTYPE ret = 0.5 * x * (1 + std::tanh(inner));
55 
56             return ret;
57           },
58           in.const_data_ptr<CTYPE>(),
59           out.mutable_data_ptr<CTYPE>(),
60           in.numel());
61     } else if (approximate == "none") {
62       apply_unary_map_fn(
63           [](const CTYPE x) {
64             if (x == -std::numeric_limits<CTYPE>::infinity()) {
65               return static_cast<CTYPE>(0.0);
66             } else if (x == std::numeric_limits<CTYPE>::infinity()) {
67               return std::numeric_limits<CTYPE>::infinity();
68             }
69             return static_cast<CTYPE>(0.5 * x * (1 + std::erf(x * M_SQRT1_2)));
70           },
71           in.const_data_ptr<CTYPE>(),
72           out.mutable_data_ptr<CTYPE>(),
73           in.numel());
74     } else {
75       ET_CHECK_MSG(
76           false,
77           "Invalid approximation format: %.*s for gelu",
78           static_cast<int>(approximate.length()),
79           approximate.data());
80     }
81   });
82 
83   return out;
84 }
85 
86 } // namespace native
87 } // namespace executor
88 } // namespace torch
89