/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <cinttypes>
#include <cmath>
#include <cstdint>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

namespace {

// Overloads so that double inputs use exp() and float inputs use expf().
double exp_overload(double d) {
  return exp(d);
}

float exp_overload(float f) {
  return expf(f);
}

/**
 * In-place element-wise sigmoid function, i.e., f(x) = 1 / (1 + e^{-x})
 */
// TODO: T146333648, refactor this as a common helper function
template <typename CTYPE_OUT>
void sigmoid_tensor(Tensor& out) {
  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
  for (size_t i = 0; i < out.numel(); i++) {
    out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
  }
}
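
// A quick sanity check on sigmoid_tensor (hypothetical values): an input of 0
// maps to 1 / (1 + e^0) = 0.5; large positive inputs approach 1 and large
// negative inputs approach 0.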

/**
 * Multiplies the first half of `in` along the specified dimension
 * element-wise with `out`, overwriting `out`.
 */
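// Illustrative sketch (hypothetical values): for a 1-D `in` = {a, b, c, d}
// with dim = 0, num_values == 2 and trailing_dims == 1, so `out` is
// overwritten as {a * out[0], b * out[1]}.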
template <typename CTYPE_IN, typename CTYPE_OUT>
void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
  size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
  size_t dim_length_in = static_cast<size_t>(in.size(dim));
  size_t dim_length_out = static_cast<size_t>(out.size(dim));
  size_t leading_dims = getLeadingDims(in, dim);
  size_t trailing_dims = getTrailingDims(in, dim);

  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();

  for (size_t i = 0; i < leading_dims; i++) {
    const CTYPE_IN* input_data =
        input_data_base + i * dim_length_in * trailing_dims;
    CTYPE_OUT* output_data =
        output_data_base + i * dim_length_out * trailing_dims;
    for (size_t j = 0; j < num_values; j++) {
      for (size_t k = 0; k < trailing_dims; ++k) {
        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
      }
      input_data += trailing_dims;
      output_data += trailing_dims;
    }
  }
}

/**
 * Slices the tensor along the given dim, copying elements from start to end
 * into `out`. Assumes `in` and `out` agree in every dimension except `dim`,
 * that `dim` is non-negative, and that `start` and `end` are valid
 * non-negative indices along `dim`.
 */
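// Illustrative sketch (hypothetical values): for a 1-D `in` = {a, b, c, d}
// with dim = 0, start = 2, end = 4, `out` receives {c, d}.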
template <typename CTYPE_IN, typename CTYPE_OUT>
void slice_tensor(
    const Tensor& in,
    int64_t dim,
    int64_t start,
    int64_t end,
    Tensor& out) {
  size_t num_values = static_cast<size_t>(end - start);
  size_t dim_length_in = static_cast<size_t>(in.size(dim));
  size_t dim_length_out = static_cast<size_t>(out.size(dim));
  size_t non_negative_start = static_cast<size_t>(start);
  size_t leading_dims = getLeadingDims(in, dim);
  size_t trailing_dims = getTrailingDims(in, dim);

  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();

  for (size_t i = 0; i < leading_dims; i++) {
    const CTYPE_IN* input_data = input_data_base +
        (i * dim_length_in + non_negative_start) * trailing_dims;
    CTYPE_OUT* output_data =
        output_data_base + i * dim_length_out * trailing_dims;
    for (size_t j = 0; j < num_values; j++) {
      for (size_t k = 0; k < trailing_dims; ++k) {
        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
      }
      input_data += trailing_dims;
      output_data += trailing_dims;
    }
  }
}

/**
 * Applies the gated linear unit function.
 *
 * Because glu involves a sigmoid, the output must be a floating-point type
 * (Float or Double). The input and output tensors do not need to have the
 * same type. The assertions are:
 * 1. The input shall be a float type (Float or Double)
 * 2. The output shall be a float type (Float or Double)
 */
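// Conceptually, glu splits `self` into halves a (first) and b (second) along
// `dim` and computes a * sigmoid(b). For example (hypothetical values), with
// self = {1, 2, 3, 4} and dim = 0: out = {1 * sigmoid(3), 2 * sigmoid(4)}.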
template <typename CTYPE_IN, typename CTYPE_OUT>
Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
  const auto self_size = self.size(dim);
  // Copy the second half of self along dim into out, apply sigmoid to it in
  // place, then multiply element-wise by the first half of self.
  slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
  sigmoid_tensor<CTYPE_OUT>(out);
  mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
  return out;
}
} // namespace

/**
 * Applies the gated linear unit function.
 *
 * Because glu involves a sigmoid, the output must be a floating-point type
 * (Float or Double). The input and output tensors do not need to have the
 * same type. The assertions are:
 * 1. The input shall be a float type (Float or Double)
 * 2. The output shall be a float type (Float or Double)
 */
Tensor& glu_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    int64_t dim,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, resize_glu_out(self, dim, out) == Error::Ok, InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, check_glu_args(self, dim, out), InvalidArgument, out);

  // Normalize a negative dim to its non-negative equivalent.
  const size_t non_negative_dim = dim < 0 ? dim + self.dim() : dim;
  const auto in_dtype = self.scalar_type();

  ET_SWITCH_FLOAT_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
    if (out.scalar_type() == ScalarType::Float) {
      glu_out_tensor<CTYPE_IN, float>(self, non_negative_dim, out);
    } else {
      glu_out_tensor<CTYPE_IN, double>(self, non_negative_dim, out);
    }
  });

  return out;
}
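
// Usage sketch (hypothetical shapes): for a Float `self` of shape {2, 6} and
// dim = 1, `out` is resized to {2, 3} and receives
// self[:, 0:3] * sigmoid(self[:, 3:6]).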

} // namespace native
} // namespace executor
} // namespace torch