xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/dtype_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 
13 namespace torch {
14 namespace executor {
15 namespace native {
16 namespace utils {
17 namespace internal {
18 
/// Reads a single `From` value through the untyped pointer `fromPtr` and
/// returns it converted to `To` with static_cast.
///
/// `fromPtr` must point at a live `From` object.
template <typename To, typename From>
To load_and_convert(const void* fromPtr) {
  const From* const typed_src = static_cast<const From*>(fromPtr);
  return static_cast<To>(*typed_src);
}
23 
/// Converts `f` to `To` with static_cast and writes the result through the
/// untyped destination pointer `dst`.
///
/// `dst` must point at writable storage for a `To` object.
template <typename To, typename From>
void convert_and_store(From f, void* dst) {
  To* const typed_dst = static_cast<To*>(dst);
  *typed_dst = static_cast<To>(f);
}
28 
/// Pointer to a routine that reads one element from untyped storage and
/// yields it as a CTYPE_COMMON value.
template <typename CTYPE_COMMON>
using load_to_common_fn = auto (*)(const void*) -> CTYPE_COMMON;
31 
32 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_realhbbf16(const Tensor & t)33 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
34     const Tensor& t) {
35   CTYPE_COMMON (*result)(const void*) = nullptr;
36   ET_SWITCH_REALHBBF16_TYPES(
37       t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
38         result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
39       });
40   return result;
41 }
42 
43 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_realhbf16(const Tensor & t)44 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
45     const Tensor& t) {
46   CTYPE_COMMON (*result)(const void*) = nullptr;
47   ET_SWITCH_REALHBF16_TYPES(
48       t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
49         result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
50       });
51   return result;
52 }
53 
54 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_floathbf16(const Tensor & t)55 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16(
56     const Tensor& t) {
57   CTYPE_COMMON (*result)(const void*) = nullptr;
58   ET_SWITCH_FLOATHBF16_TYPES(
59       t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
60         result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
61       });
62   return result;
63 }
64 
65 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_intb(const Tensor & t)66 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
67   CTYPE_COMMON (*result)(const void*) = nullptr;
68   ET_SWITCH_INT_TYPES_AND(
69       Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
70         result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
71       });
72   return result;
73 }
74 
75 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_bool_or_byte(const Tensor & t)76 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
77     const Tensor& t) {
78   CTYPE_COMMON (*result)(const void*) = nullptr;
79   ET_SWITCH_TWO_TYPES(
80       Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
81         result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
82       });
83   return result;
84 }
85 
86 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_same_as_compute(const Tensor & t)87 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute(
88     const Tensor& t) {
89   constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
90   ET_CHECK_MSG(
91       t.scalar_type() == common_scalar_type,
92       "Unhandled dtype %s for %s",
93       ::executorch::runtime::toString(common_scalar_type),
94       op_name);
95   return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
96 }
97 
98 template <
99     typename CTYPE_COMMON,
100     const char* op_name,
101     std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
get_load_to_common_fn_same_as_common(const Tensor & t)102 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
103     const Tensor& t) {
104   CTYPE_COMMON (*result)(const void*) = nullptr;
105   ET_SWITCH_THREE_TYPES(
106       Float, Half, BFloat16, t.scalar_type(), unused, op_name, T, [&]() {
107         result = internal::load_and_convert<CTYPE_COMMON, T>;
108       });
109   return result;
110 }
111 
112 template <
113     typename CTYPE_COMMON,
114     const char* op_name,
115     std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
get_load_to_common_fn_same_as_common(const Tensor & t)116 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
117     const Tensor& t) {
118   return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
119 }
120 
/// Pointer to a routine that converts a CTYPE_COMMON value and writes it
/// into untyped tensor storage.
template <typename CTYPE_COMMON>
using store_common_to_tensor_fn = auto (*)(CTYPE_COMMON, void*) -> void;
123 
124 template <typename CTYPE_COMMON, const char* op_name>
125 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_realhbbf16(const Tensor & t)126 get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
127   void (*result)(CTYPE_COMMON, void*) = nullptr;
128   ET_SWITCH_REALHBBF16_TYPES(
129       t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
130         result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
131       });
132   return result;
133 }
134 
135 template <typename CTYPE_COMMON, const char* op_name>
get_store_common_to_tensor_fn_realhbf16(const Tensor & t)136 store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
137     const Tensor& t) {
138   void (*result)(CTYPE_COMMON, void*) = nullptr;
139   ET_SWITCH_REALHBF16_TYPES(
140       t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
141         result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
142       });
143   return result;
144 }
145 
146 template <typename CTYPE_COMMON, const char* op_name>
147 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_floathbf16(const Tensor & t)148 get_store_common_to_tensor_fn_floathbf16(const Tensor& t) {
149   void (*result)(CTYPE_COMMON, void*) = nullptr;
150   ET_SWITCH_FLOATHBF16_TYPES(
151       t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
152         result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
153       });
154   return result;
155 }
156 
157 template <typename CTYPE_COMMON, const char* op_name>
get_store_common_to_tensor_fn_intb(const Tensor & t)158 store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
159     const Tensor& t) {
160   void (*result)(CTYPE_COMMON, void*) = nullptr;
161   ET_SWITCH_INT_TYPES_AND(
162       Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
163         result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
164       });
165   return result;
166 }
167 
168 template <typename CTYPE_COMMON, const char* op_name>
169 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_bool_or_byte(const Tensor & t)170 get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
171   void (*result)(CTYPE_COMMON, void*) = nullptr;
172   ET_SWITCH_TWO_TYPES(
173       Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
174         result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
175       });
176   return result;
177 }
178 
179 template <typename CTYPE_COMMON, const char* op_name>
180 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_compute(const Tensor & t)181 get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
182   constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
183   ET_CHECK_MSG(
184       t.scalar_type() == common_scalar_type,
185       "Unhandled dtype %s for %s",
186       ::executorch::runtime::toString(common_scalar_type),
187       op_name);
188   return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
189 }
190 
191 template <
192     typename CTYPE_COMMON,
193     const char* op_name,
194     std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
195 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_common(const Tensor & t)196 get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
197   void (*result)(CTYPE_COMMON, void*) = nullptr;
198   ET_SWITCH_THREE_TYPES(
199       Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() {
200         result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
201       });
202   return result;
203 }
204 
205 template <
206     typename CTYPE_COMMON,
207     const char* op_name,
208     std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
209 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_common(const Tensor & t)210 get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
211   return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
212       t);
213 }
214 
215 } // namespace internal
216 
/// Families of dtypes a kernel input/output tensor may hold. Used by
/// get_load_to_common_fn / get_store_common_to_tensor_fn to choose the
/// element conversion routine.
enum class SupportedTensorDtypes {
  /// Dtypes dispatched by ET_SWITCH_REALHBBF16_TYPES.
  REALHBBF16 = 0,
  /// Dtypes dispatched by ET_SWITCH_REALHBF16_TYPES.
  REALHBF16 = 1,
  /// Dtypes dispatched by ET_SWITCH_FLOATHBF16_TYPES.
  FLOATHBF16 = 2,
  /// Integer dtypes (ET_SWITCH_INT_TYPES_AND) plus Bool.
  INTB = 3,
  /// Exactly Bool or Byte.
  BOOL_OR_BYTE = 4,
  /// Must equal the scalar type of the compute type CTYPE_COMMON.
  SAME_AS_COMPUTE = 5,
  /// Like SAME_AS_COMPUTE, except that a float compute type also admits
  /// Half and BFloat16 tensors.
  SAME_AS_COMMON = 6,
};
226 
227 namespace internal {
228 
229 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn(const Tensor & t,SupportedTensorDtypes dtypes)230 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
231     const Tensor& t,
232     SupportedTensorDtypes dtypes) {
233   switch (dtypes) {
234     case SupportedTensorDtypes::REALHBBF16:
235       return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
236     case SupportedTensorDtypes::REALHBF16:
237       return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
238     case SupportedTensorDtypes::FLOATHBF16:
239       return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
240     case SupportedTensorDtypes::INTB:
241       return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
242     case SupportedTensorDtypes::BOOL_OR_BYTE:
243       return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
244     case SupportedTensorDtypes::SAME_AS_COMPUTE:
245       return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
246     case SupportedTensorDtypes::SAME_AS_COMMON:
247       return get_load_to_common_fn_same_as_common<CTYPE_COMMON, op_name>(t);
248   }
249   ET_CHECK(false);
250   return nullptr;
251 }
252 
253 template <typename CTYPE_COMMON, const char* op_name>
get_store_common_to_tensor_fn(const Tensor & t,SupportedTensorDtypes dtypes)254 store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
255     const Tensor& t,
256     SupportedTensorDtypes dtypes) {
257   switch (dtypes) {
258     case SupportedTensorDtypes::REALHBBF16:
259       return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
260     case SupportedTensorDtypes::REALHBF16:
261       return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
262     case SupportedTensorDtypes::FLOATHBF16:
263       return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
264     case SupportedTensorDtypes::INTB:
265       return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
266     case SupportedTensorDtypes::BOOL_OR_BYTE:
267       return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
268           t);
269     case SupportedTensorDtypes::SAME_AS_COMPUTE:
270       return get_store_common_to_tensor_fn_same_as_compute<
271           CTYPE_COMMON,
272           op_name>(t);
273     case SupportedTensorDtypes::SAME_AS_COMMON: {
274       return get_store_common_to_tensor_fn_same_as_common<
275           CTYPE_COMMON,
276           op_name>(t);
277     }
278   }
279   ET_CHECK(false);
280   return nullptr;
281 }
282 
283 bool check_tensor_dtype(
284     const Tensor t,
285     SupportedTensorDtypes dtypes,
286     const ScalarType compute_type);
287 
288 } // namespace internal
289 } // namespace utils
290 } // namespace native
291 } // namespace executor
292 } // namespace torch
293