/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

namespace torch {
namespace executor {
namespace native {
namespace utils {

/*
 * Convert Scalar to C++ type
 */

template <typename T>
T scalar_to(const Scalar& s) {
  if (s.isBoolean()) {
    return static_cast<T>(s.to<bool>());
  } else if (s.isFloatingPoint()) {
    return static_cast<T>(s.to<double>());
  } else {
    return static_cast<T>(s.to<int64_t>());
  }
}

template <>
inline double scalar_to<double>(const Scalar& s) {
  return s.isFloatingPoint() ? s.to<double>()
                             : static_cast<double>(s.to<int64_t>());
}

template <>
inline int64_t scalar_to<int64_t>(const Scalar& s) {
  return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
                             : s.to<int64_t>();
}
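
/*
 * Example (an illustrative sketch, not part of this header): extracting a
 * kernel argument in the compute type. The names `alpha` and `alpha_val`
 * below are hypothetical.
 *
 *   const float alpha_val = scalar_to<float>(alpha);
 */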
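
/**
 * Useful for single-tensor elementwise operators. For each element of the
 * input, perform a computation and write to the corresponding element of the
 * output. The input and output are addressed with the same linear index, so
 * they are expected to have matching sizes; no broadcasting is performed.
 */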
template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_unitensor_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& a,
    SupportedTensorDtypes a_dtypes,
    const Tensor& out,
    SupportedTensorDtypes out_dtypes) {
  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

  ET_KERNEL_CHECK(
      ctx,
      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
      InvalidArgument, );

  const auto load_a_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
  const auto store_common_to_out =
      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
          out, out_dtypes);
  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
  const auto a_element_size = a.element_size();
  const auto out_element_size = out.element_size();
  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

  auto out_numel = out.numel();
  for (size_t i = 0; i < out_numel; ++i) {
    auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
    store_common_to_out(result, &data_out[i * out_element_size]);
  }
}
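
/*
 * Example (an illustrative sketch, not part of this header): a negation-style
 * kernel computed in float. The op name and dtype sets below are assumptions
 * made for the sake of the example.
 *
 *   static constexpr const char op_name[] = "my_neg.out";
 *   apply_unitensor_elementwise_fn<float, op_name>(
 *       [](const float val_in) { return -val_in; },
 *       ctx,
 *       in,
 *       SupportedTensorDtypes::REALHBBF16,
 *       out,
 *       SupportedTensorDtypes::SAME_AS_COMMON);
 */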

/**
 * Useful for bi-tensor elementwise operators. For each element of the inputs,
 * perform a computation and write to the corresponding element of the output.
 * Tensor broadcasting is applied wherever it is required.
 */
template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_bitensor_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& a,
    SupportedTensorDtypes a_dtypes,
    const Tensor& b,
    SupportedTensorDtypes b_dtypes,
    const Tensor& out,
    SupportedTensorDtypes out_dtypes) {
  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

  ET_KERNEL_CHECK(
      ctx,
      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
      InvalidArgument, );

  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);

  const auto load_a_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
  const auto load_b_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
  const auto store_common_to_out =
      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
          out, out_dtypes);
  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
  const auto a_element_size = a.element_size();
  const auto b_element_size = b.element_size();
  const auto out_element_size = out.element_size();
  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

  auto out_numel = out.numel();
  for (size_t i = 0; i < out_numel; ++i) {
    size_t a_linear_index = i;
    size_t b_linear_index = i;

    if (any_is_broadcasted) {
      size_t out_indexes[kTensorDimensionLimit];
      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

      if (a_is_broadcasted) {
        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
      }
      if (b_is_broadcasted) {
        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
      }
    }

    auto result = compute_fun(
        load_a_to_common(&data_a[a_linear_index * a_element_size]),
        load_b_to_common(&data_b[b_linear_index * b_element_size]));
    store_common_to_out(result, &data_out[i * out_element_size]);
  }
}
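
/*
 * Example (an illustrative sketch, not part of this header): an add-style
 * kernel computed in float, with broadcasting handled by the helper. The op
 * name and dtype sets below are assumptions made for the sake of the example.
 *
 *   static constexpr const char op_name[] = "my_add.out";
 *   apply_bitensor_elementwise_fn<float, op_name>(
 *       [](const float val_a, const float val_b) { return val_a + val_b; },
 *       ctx,
 *       a,
 *       SupportedTensorDtypes::REALHBBF16,
 *       b,
 *       SupportedTensorDtypes::REALHBBF16,
 *       out,
 *       SupportedTensorDtypes::SAME_AS_COMMON);
 */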

/**
 * Useful for tri-tensor elementwise operators. For each element of the
 * inputs, perform a computation and write to the corresponding element of the
 * output. Tensor broadcasting is applied wherever it is required.
 *
 * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
 * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
 * are passed as CTYPE_COMMON.
 *
 * Each tensor's supported dtypes set must be provided. The tensor
 * will be checked to ensure that its dtype falls into that set.
 *
 * op_name is used to support dtype selective build, as with the
 * ET_SWITCH family of macros. Note: because of C++17 quirks, you
 * can't pass a string literal for op_name. Instead, you should do the
 * following:
 *
 *   static constexpr const char op_name[] = "my_op";
 *   apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(...);
 */
template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_tritensor_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& a,
    SupportedTensorDtypes a_dtypes,
    const Tensor& b,
    SupportedTensorDtypes b_dtypes,
    const Tensor& c,
    SupportedTensorDtypes c_dtypes,
    const Tensor& out,
    SupportedTensorDtypes out_dtypes) {
  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

  ET_KERNEL_CHECK(
      ctx,
      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
       internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
      InvalidArgument, );

  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
  const bool any_is_broadcasted =
      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);

  const auto load_a_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
  const auto load_b_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
  const auto load_c_to_common =
      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
  const auto store_common_to_out =
      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
          out, out_dtypes);
  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
  const auto a_element_size = a.element_size();
  const auto b_element_size = b.element_size();
  const auto c_element_size = c.element_size();
  const auto out_element_size = out.element_size();
  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

  auto out_numel = out.numel();
  for (size_t i = 0; i < out_numel; ++i) {
    size_t a_linear_index = i;
    size_t b_linear_index = i;
    size_t c_linear_index = i;

    if (any_is_broadcasted) {
      size_t out_indexes[kTensorDimensionLimit];
      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

      if (a_is_broadcasted) {
        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
      }
      if (b_is_broadcasted) {
        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
      }
      if (c_is_broadcasted) {
        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
      }
    }

    auto result = compute_fun(
        load_a_to_common(&data_a[a_linear_index * a_element_size]),
        load_b_to_common(&data_b[b_linear_index * b_element_size]),
        load_c_to_common(&data_c[c_linear_index * c_element_size]));
    store_common_to_out(result, &data_out[i * out_element_size]);
  }
}
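
/*
 * Example (an illustrative sketch, not part of this header): an addcmul-style
 * kernel combining three inputs in float. The op name, dtype sets, and the
 * captured `value` scalar are assumptions made for the sake of the example.
 *
 *   static constexpr const char op_name[] = "my_addcmul.out";
 *   apply_tritensor_elementwise_fn<float, op_name>(
 *       [value](const float val_a, const float val_b, const float val_c) {
 *         return val_a + value * val_b * val_c;
 *       },
 *       ctx,
 *       a,
 *       SupportedTensorDtypes::REALHBBF16,
 *       b,
 *       SupportedTensorDtypes::REALHBBF16,
 *       c,
 *       SupportedTensorDtypes::REALHBBF16,
 *       out,
 *       SupportedTensorDtypes::SAME_AS_COMMON);
 */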
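
/**
 * Returns the dtype in which computation should be performed: Half and
 * BFloat16 are promoted to Float so that intermediate arithmetic runs in
 * single precision; all other dtypes are returned unchanged.
 */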
inline ScalarType get_compute_type(ScalarType& common_type) {
  ScalarType compute_type = common_type;
  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
    compute_type = ScalarType::Float;
  }
  return compute_type;
}

} // namespace utils
} // namespace native
} // namespace executor
} // namespace torch