1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <cmath>
#include <limits>

#include <executorch/kernels/portable/cpu/math_constants.h>
#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19
20 using Tensor = exec_aten::Tensor;
21 using ScalarType = exec_aten::ScalarType;
22 using string_view = exec_aten::string_view;
23
gelu_out(KernelRuntimeContext & ctx,const Tensor & in,string_view approximate,Tensor & out)24 Tensor& gelu_out(
25 KernelRuntimeContext& ctx,
26 const Tensor& in,
27 string_view approximate,
28 Tensor& out) {
29 (void)ctx;
30
31 ET_KERNEL_CHECK(
32 ctx, check_gelu_args(in, approximate, out), InvalidArgument, out);
33
34 ET_KERNEL_CHECK(
35 ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
36
37 ET_KERNEL_CHECK(
38 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
39
40 ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
41 if (approximate == "tanh") {
42 apply_unary_map_fn(
43 [](const CTYPE x) {
44 if (x == -std::numeric_limits<CTYPE>::infinity()) {
45 return static_cast<CTYPE>(0.0);
46 } else if (x == std::numeric_limits<CTYPE>::infinity()) {
47 return std::numeric_limits<CTYPE>::infinity();
48 }
49 const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
50 const CTYPE kKappa = static_cast<float>(0.044715);
51
52 const CTYPE x_cubed = x * x * x;
53 const CTYPE inner = kBeta * (x + kKappa * x_cubed);
54 const CTYPE ret = 0.5 * x * (1 + std::tanh(inner));
55
56 return ret;
57 },
58 in.const_data_ptr<CTYPE>(),
59 out.mutable_data_ptr<CTYPE>(),
60 in.numel());
61 } else if (approximate == "none") {
62 apply_unary_map_fn(
63 [](const CTYPE x) {
64 if (x == -std::numeric_limits<CTYPE>::infinity()) {
65 return static_cast<CTYPE>(0.0);
66 } else if (x == std::numeric_limits<CTYPE>::infinity()) {
67 return std::numeric_limits<CTYPE>::infinity();
68 }
69 return static_cast<CTYPE>(0.5 * x * (1 + std::erf(x * M_SQRT1_2)));
70 },
71 in.const_data_ptr<CTYPE>(),
72 out.mutable_data_ptr<CTYPE>(),
73 in.numel());
74 } else {
75 ET_CHECK_MSG(
76 false,
77 "Invalid approximation format: %.*s for gelu",
78 static_cast<int>(approximate.length()),
79 approximate.data());
80 }
81 });
82
83 return out;
84 }
85
86 } // namespace native
87 } // namespace executor
88 } // namespace torch
89