/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <cinttypes>
#include <cmath>
#include <cstdint>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

namespace {

double exp_overload(double d) {
  return exp(d);
}

float exp_overload(float f) {
  return expf(f);
}

/**
 * In-place element-wise sigmoid function, i.e., f(x) = 1 / (1 + e^{-x})
 */
// TODO: T146333648, refactor this as a common helper function
template <typename CTYPE_OUT>
void sigmoid_tensor(Tensor& out) {
  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
  for (size_t i = 0; i < out.numel(); i++) {
    out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
  }
}

/**
 * Element-wise multiplication of the first half of `in` along the specified
 * dimension with `out`, overwriting `out`.
 */
template <typename CTYPE_IN, typename CTYPE_OUT>
void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
  size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
  size_t dim_length_in = static_cast<size_t>(in.size(dim));
  size_t dim_length_out = static_cast<size_t>(out.size(dim));
  size_t leading_dims = getLeadingDims(in, dim);
  size_t trailing_dims = getTrailingDims(in, dim);

  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();

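  // Walk both tensors as if viewed as [leading_dims, dim_length,
  // trailing_dims]; only the first `num_values` slices of `in` along `dim`
  // (i.e., its first half) are visited.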
  for (size_t i = 0; i < leading_dims; i++) {
    const CTYPE_IN* input_data =
        input_data_base + i * dim_length_in * trailing_dims;
    CTYPE_OUT* output_data =
        output_data_base + i * dim_length_out * trailing_dims;
    for (size_t j = 0; j < num_values; j++) {
      for (size_t k = 0; k < trailing_dims; ++k) {
        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
      }
      input_data += trailing_dims;
      output_data += trailing_dims;
    }
  }
}

/**
 * Slices the tensor along the given dim, from start to end. Assumes `in` and
 * `out` have the same dtype and the same shape except possibly along dim,
 * that dim is non-negative, and that start and end are valid non-negative
 * indices into `in` along dim.
 */
template <typename CTYPE_IN, typename CTYPE_OUT>
void slice_tensor(
    const Tensor& in,
    int64_t dim,
    int64_t start,
    int64_t end,
    Tensor& out) {
  size_t num_values = static_cast<size_t>(end - start);
  size_t dim_length_in = static_cast<size_t>(in.size(dim));
  size_t dim_length_out = static_cast<size_t>(out.size(dim));
  size_t non_negative_start = static_cast<size_t>(start);
  size_t leading_dims = getLeadingDims(in, dim);
  size_t trailing_dims = getTrailingDims(in, dim);

  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();

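  // For each leading index, copy `num_values` contiguous runs of
  // `trailing_dims` elements, starting at offset `start` along `dim`.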
  for (size_t i = 0; i < leading_dims; i++) {
    const CTYPE_IN* input_data = input_data_base +
        (i * dim_length_in + non_negative_start) * trailing_dims;
    CTYPE_OUT* output_data =
        output_data_base + i * dim_length_out * trailing_dims;
    for (size_t j = 0; j < num_values; j++) {
      for (size_t k = 0; k < trailing_dims; ++k) {
        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
      }
      input_data += trailing_dims;
      output_data += trailing_dims;
    }
  }
}

/**
 * Applies the gated linear unit function.
 *
 * Given the nature of the glu function, the output must be a floating-point
 * type (Float or Double). The input and output tensors don't necessarily
 * need to have the same type. Here are the assertions:
 *  1. The input shall be a float type (Float or Double)
 *  2. The output shall be a float type (Float or Double)
 */
template <typename CTYPE_IN, typename CTYPE_OUT>
Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
  const auto self_size = self.size(dim);
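  // Copy the second half of `self` along `dim` into `out`.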
  slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
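  // Apply sigmoid to the copied half in place.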
  sigmoid_tensor<CTYPE_OUT>(out);
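  // Multiply in the first half of `self`:
  // out = first_half * sigmoid(second_half).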
  mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
  return out;
}
} // namespace

/**
 * Applies the gated linear unit function.
 *
 * Given the nature of the glu function, the output must be a floating-point
 * type (Float or Double). The input and output tensors don't necessarily
 * need to have the same type. Here are the assertions:
 *  1. The input shall be a float type (Float or Double)
 *  2. The output shall be a float type (Float or Double)
 */
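// Example: for `self` of shape [2, 4] and dim = 1, `out` has shape [2, 2]
// and out[i][j] = self[i][j] * sigmoid(self[i][j + 2]).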
Tensor& glu_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    int64_t dim,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, resize_glu_out(self, dim, out) == Error::Ok, InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, check_glu_args(self, dim, out), InvalidArgument, out);

  const size_t non_negative_dim = dim < 0 ? dim + self.dim() : dim;
  const auto in_dtype = self.scalar_type();

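  // Dispatch on the input dtype; the output dtype (Float or Double) selects
  // the second template argument of glu_out_tensor.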
  ET_SWITCH_FLOAT_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
    if (out.scalar_type() == ScalarType::Float) {
      glu_out_tensor<CTYPE_IN, float>(self, non_negative_dim, out);
    } else {
      glu_out_tensor<CTYPE_IN, double>(self, non_negative_dim, out);
    }
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch