// xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/elementwise_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

namespace torch {
namespace executor {
namespace native {
namespace utils {

/*
 * Convert Scalar to C++ type
 */

template <typename T>
T scalar_to(const Scalar& s) {
  if (s.isBoolean()) {
    return static_cast<T>(s.to<bool>());
  } else if (s.isFloatingPoint()) {
    return static_cast<T>(s.to<double>());
  } else {
    return static_cast<T>(s.to<int64_t>());
  }
}

template <>
inline double scalar_to<double>(const Scalar& s) {
  return s.isFloatingPoint() ? s.to<double>()
                             : static_cast<double>(s.to<int64_t>());
}

template <>
inline int64_t scalar_to<int64_t>(const Scalar& s) {
  return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
                             : s.to<int64_t>();
}
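
/*
 * Illustrative usage sketch (an assumption, not part of the original header):
 * an op that takes a Scalar argument such as `alpha` can convert it to the
 * kernel's compute type before entering its elementwise loop. The variable
 * names below are hypothetical.
 *
 *   const float alpha_val = utils::scalar_to<float>(alpha);
 *   // alpha_val now holds the Scalar's value regardless of whether it was
 *   // stored as a bool, a double, or an int64_t.
 */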

/**
 * Useful for unary elementwise operators. For each element of the input,
 * perform a computation and write to the corresponding element of the
 * output. Unlike the bi- and tri-tensor variants below, no broadcasting is
 * applied: the input and output are indexed element-for-element.
 */
template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_unitensor_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& a,
    SupportedTensorDtypes a_dtypes,
    const Tensor& out,
    SupportedTensorDtypes out_dtypes) {
  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

  ET_KERNEL_CHECK(
      ctx,
      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
      InvalidArgument, );

  const auto load_a_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
  const auto store_common_to_out =
      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
          out, out_dtypes);
  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
  const auto a_element_size = a.element_size();
  const auto out_element_size = out.element_size();
  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

  auto out_numel = out.numel();
  for (size_t i = 0; i < out_numel; ++i) {
    auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
    store_common_to_out(result, &data_out[i * out_element_size]);
  }
}
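
/*
 * Illustrative usage sketch (an assumption, not from the original header): a
 * unary negation kernel might call apply_unitensor_elementwise_fn roughly
 * like this. The op_name, compute type, and dtype sets are hypothetical;
 * pick whatever matches the op being implemented.
 *
 *   static constexpr const char op_name[] = "my_neg.out";
 *   utils::apply_unitensor_elementwise_fn<float, op_name>(
 *       [](const float val_a) { return -val_a; },
 *       ctx,
 *       a,
 *       utils::SupportedTensorDtypes::REALHBBF16,
 *       out,
 *       utils::SupportedTensorDtypes::REALHBBF16);
 */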

/**
 * Useful for bi-tensor elementwise operators. For each element of the inputs,
 * perform a computation and write to the corresponding element of the output.
 * Tensor broadcasting is applied wherever it is required.
 */
template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_bitensor_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& a,
    SupportedTensorDtypes a_dtypes,
    const Tensor& b,
    SupportedTensorDtypes b_dtypes,
    const Tensor& out,
    SupportedTensorDtypes out_dtypes) {
  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

  ET_KERNEL_CHECK(
      ctx,
      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
      InvalidArgument, );

  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);

  const auto load_a_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
  const auto load_b_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
  const auto store_common_to_out =
      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
          out, out_dtypes);
  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
  const auto a_element_size = a.element_size();
  const auto b_element_size = b.element_size();
  const auto out_element_size = out.element_size();
  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

  auto out_numel = out.numel();
  for (size_t i = 0; i < out_numel; ++i) {
    size_t a_linear_index = i;
    size_t b_linear_index = i;

    if (any_is_broadcasted) {
      size_t out_indexes[kTensorDimensionLimit];
      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

      if (a_is_broadcasted) {
        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
      }
      if (b_is_broadcasted) {
        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
      }
    }

    auto result = compute_fun(
        load_a_to_common(&data_a[a_linear_index * a_element_size]),
        load_b_to_common(&data_b[b_linear_index * b_element_size]));
    store_common_to_out(result, &data_out[i * out_element_size]);
  }
}
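
/*
 * Illustrative usage sketch (an assumption, not from the original header): a
 * binary add kernel might call apply_bitensor_elementwise_fn roughly like
 * this. The op_name, compute type, and SupportedTensorDtypes choices below
 * are hypothetical; pick whatever matches the op being implemented.
 *
 *   static constexpr const char op_name[] = "my_add.out";
 *   utils::apply_bitensor_elementwise_fn<float, op_name>(
 *       [](const float val_a, const float val_b) { return val_a + val_b; },
 *       ctx,
 *       a,
 *       utils::SupportedTensorDtypes::REALHBBF16,
 *       b,
 *       utils::SupportedTensorDtypes::REALHBBF16,
 *       out,
 *       utils::SupportedTensorDtypes::REALHBBF16);
 */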

/**
 * Useful for tri-tensor elementwise operators. For each element of the
 * inputs, perform a computation and write to the corresponding element of the
 * output. Tensor broadcasting is applied wherever it is required.
 *
 * In order to mitigate build time cost (which would otherwise scale as
 * |CTYPE_A| * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to
 * compute_fun are passed as CTYPE_COMMON.
 *
 * Each tensor's supported dtypes set must be provided. The tensor
 * will be checked to ensure that its dtype falls into that set.
 *
 * op_name is used to support dtype selective build, as with the
 * ET_SWITCH family of macros. Note: because of C++17 quirks, you
 * can't pass a string literal for op_name. Instead, you should do the
 * following:
 *
 * static constexpr const char op_name[] = "my_op";
 * apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>.
 */
template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_tritensor_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& a,
    SupportedTensorDtypes a_dtypes,
    const Tensor& b,
    SupportedTensorDtypes b_dtypes,
    const Tensor& c,
    SupportedTensorDtypes c_dtypes,
    const Tensor& out,
    SupportedTensorDtypes out_dtypes) {
  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

  ET_KERNEL_CHECK(
      ctx,
      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
       internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
      InvalidArgument, );

  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
  const bool any_is_broadcasted =
      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);

  const auto load_a_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
  const auto load_b_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
  const auto load_c_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
  const auto store_common_to_out =
      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
          out, out_dtypes);
  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
  const auto a_element_size = a.element_size();
  const auto b_element_size = b.element_size();
  const auto c_element_size = c.element_size();
  const auto out_element_size = out.element_size();
  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

  auto out_numel = out.numel();
  for (size_t i = 0; i < out_numel; ++i) {
    size_t a_linear_index = i;
    size_t b_linear_index = i;
    size_t c_linear_index = i;

    if (any_is_broadcasted) {
      size_t out_indexes[kTensorDimensionLimit];
      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

      if (a_is_broadcasted) {
        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
      }
      if (b_is_broadcasted) {
        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
      }
      if (c_is_broadcasted) {
        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
      }
    }

    auto result = compute_fun(
        load_a_to_common(&data_a[a_linear_index * a_element_size]),
        load_b_to_common(&data_b[b_linear_index * b_element_size]),
        load_c_to_common(&data_c[c_linear_index * c_element_size]));
    store_common_to_out(result, &data_out[i * out_element_size]);
  }
}
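
/*
 * Illustrative usage sketch (an assumption, not from the original header):
 * analogous to the bi-tensor sketch above, a fused multiply-add style kernel
 * might call apply_tritensor_elementwise_fn roughly like this. The op_name,
 * compute type, and dtype sets are hypothetical.
 *
 *   static constexpr const char op_name[] = "my_addcmul.out";
 *   utils::apply_tritensor_elementwise_fn<float, op_name>(
 *       [](const float val_a, const float val_b, const float val_c) {
 *         return val_a + val_b * val_c;
 *       },
 *       ctx,
 *       a,
 *       utils::SupportedTensorDtypes::REALHBBF16,
 *       b,
 *       utils::SupportedTensorDtypes::REALHBBF16,
 *       c,
 *       utils::SupportedTensorDtypes::REALHBBF16,
 *       out,
 *       utils::SupportedTensorDtypes::REALHBBF16);
 */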

/**
 * Returns the dtype that kernels should compute in for the given common
 * dtype: the reduced-precision floating point types (Half and BFloat16) are
 * promoted to Float, and every other dtype is returned unchanged.
 */
inline ScalarType get_compute_type(ScalarType& common_type) {
  ScalarType compute_type = common_type;
  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
    compute_type = ScalarType::Float;
  }
  return compute_type;
}

} // namespace utils
} // namespace native
} // namespace executor
} // namespace torch