1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 #include <executorch/runtime/kernel/kernel_includes.h>
12
13 namespace torch {
14 namespace executor {
15 namespace native {
16 namespace utils {
17 namespace internal {
18
// Reads one value of type `From` located at the untyped address `fromPtr`
// and returns it converted to `To` via static_cast.
template <typename To, typename From>
To load_and_convert(const void* fromPtr) {
  const From* const typed_src = static_cast<const From*>(fromPtr);
  const From& value = *typed_src;
  return static_cast<To>(value);
}
23
// Converts `f` to `To` via static_cast and writes the result to the untyped
// destination address `dst`.
template <typename To, typename From>
void convert_and_store(From f, void* dst) {
  const To converted = static_cast<To>(f);
  To* const typed_dst = static_cast<To*>(dst);
  *typed_dst = converted;
}
28
// Function-pointer type produced by the get_load_to_common_fn_* factories in
// this header. Each such pointer is an instantiation of load_and_convert: it
// takes an untyped element address and returns the value as CTYPE_COMMON.
template <typename CTYPE_COMMON>
using load_to_common_fn = CTYPE_COMMON (*)(const void*);
31
32 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_realhbbf16(const Tensor & t)33 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
34 const Tensor& t) {
35 CTYPE_COMMON (*result)(const void*) = nullptr;
36 ET_SWITCH_REALHBBF16_TYPES(
37 t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
38 result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
39 });
40 return result;
41 }
42
43 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_realhbf16(const Tensor & t)44 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
45 const Tensor& t) {
46 CTYPE_COMMON (*result)(const void*) = nullptr;
47 ET_SWITCH_REALHBF16_TYPES(
48 t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
49 result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
50 });
51 return result;
52 }
53
54 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_floathbf16(const Tensor & t)55 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16(
56 const Tensor& t) {
57 CTYPE_COMMON (*result)(const void*) = nullptr;
58 ET_SWITCH_FLOATHBF16_TYPES(
59 t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
60 result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
61 });
62 return result;
63 }
64
65 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_intb(const Tensor & t)66 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
67 CTYPE_COMMON (*result)(const void*) = nullptr;
68 ET_SWITCH_INT_TYPES_AND(
69 Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
70 result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
71 });
72 return result;
73 }
74
75 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_bool_or_byte(const Tensor & t)76 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
77 const Tensor& t) {
78 CTYPE_COMMON (*result)(const void*) = nullptr;
79 ET_SWITCH_TWO_TYPES(
80 Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
81 result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
82 });
83 return result;
84 }
85
86 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn_same_as_compute(const Tensor & t)87 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute(
88 const Tensor& t) {
89 constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
90 ET_CHECK_MSG(
91 t.scalar_type() == common_scalar_type,
92 "Unhandled dtype %s for %s",
93 ::executorch::runtime::toString(common_scalar_type),
94 op_name);
95 return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
96 }
97
98 template <
99 typename CTYPE_COMMON,
100 const char* op_name,
101 std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
get_load_to_common_fn_same_as_common(const Tensor & t)102 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
103 const Tensor& t) {
104 CTYPE_COMMON (*result)(const void*) = nullptr;
105 ET_SWITCH_THREE_TYPES(
106 Float, Half, BFloat16, t.scalar_type(), unused, op_name, T, [&]() {
107 result = internal::load_and_convert<CTYPE_COMMON, T>;
108 });
109 return result;
110 }
111
112 template <
113 typename CTYPE_COMMON,
114 const char* op_name,
115 std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
get_load_to_common_fn_same_as_common(const Tensor & t)116 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
117 const Tensor& t) {
118 return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
119 }
120
// Function-pointer type produced by the get_store_common_to_tensor_fn_*
// factories in this header. Each such pointer is an instantiation of
// convert_and_store: it takes a CTYPE_COMMON value and an untyped destination
// address, and writes the converted value there.
template <typename CTYPE_COMMON>
using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
123
124 template <typename CTYPE_COMMON, const char* op_name>
125 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_realhbbf16(const Tensor & t)126 get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
127 void (*result)(CTYPE_COMMON, void*) = nullptr;
128 ET_SWITCH_REALHBBF16_TYPES(
129 t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
130 result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
131 });
132 return result;
133 }
134
135 template <typename CTYPE_COMMON, const char* op_name>
get_store_common_to_tensor_fn_realhbf16(const Tensor & t)136 store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
137 const Tensor& t) {
138 void (*result)(CTYPE_COMMON, void*) = nullptr;
139 ET_SWITCH_REALHBF16_TYPES(
140 t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
141 result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
142 });
143 return result;
144 }
145
146 template <typename CTYPE_COMMON, const char* op_name>
147 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_floathbf16(const Tensor & t)148 get_store_common_to_tensor_fn_floathbf16(const Tensor& t) {
149 void (*result)(CTYPE_COMMON, void*) = nullptr;
150 ET_SWITCH_FLOATHBF16_TYPES(
151 t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
152 result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
153 });
154 return result;
155 }
156
157 template <typename CTYPE_COMMON, const char* op_name>
get_store_common_to_tensor_fn_intb(const Tensor & t)158 store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
159 const Tensor& t) {
160 void (*result)(CTYPE_COMMON, void*) = nullptr;
161 ET_SWITCH_INT_TYPES_AND(
162 Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
163 result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
164 });
165 return result;
166 }
167
168 template <typename CTYPE_COMMON, const char* op_name>
169 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_bool_or_byte(const Tensor & t)170 get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
171 void (*result)(CTYPE_COMMON, void*) = nullptr;
172 ET_SWITCH_TWO_TYPES(
173 Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
174 result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
175 });
176 return result;
177 }
178
179 template <typename CTYPE_COMMON, const char* op_name>
180 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_compute(const Tensor & t)181 get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
182 constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
183 ET_CHECK_MSG(
184 t.scalar_type() == common_scalar_type,
185 "Unhandled dtype %s for %s",
186 ::executorch::runtime::toString(common_scalar_type),
187 op_name);
188 return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
189 }
190
191 template <
192 typename CTYPE_COMMON,
193 const char* op_name,
194 std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
195 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_common(const Tensor & t)196 get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
197 void (*result)(CTYPE_COMMON, void*) = nullptr;
198 ET_SWITCH_THREE_TYPES(
199 Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() {
200 result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
201 });
202 return result;
203 }
204
205 template <
206 typename CTYPE_COMMON,
207 const char* op_name,
208 std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
209 store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_common(const Tensor & t)210 get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
211 return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
212 t);
213 }
214
215 } // namespace internal
216
// Categories of tensor dtypes that an operator accepts for one of its
// inputs or outputs. get_load_to_common_fn and get_store_common_to_tensor_fn
// switch on this value to pick a per-category helper.
enum class SupportedTensorDtypes {
  // Dtypes matched by ET_SWITCH_REALHBBF16_TYPES. Going by the macro name,
  // presumably the real types plus Half, Bool and BFloat16 (confirm in the
  // macro definition).
  REALHBBF16,
  // Dtypes matched by ET_SWITCH_REALHBF16_TYPES. Presumably the same as
  // REALHBBF16 but without Bool (confirm in the macro definition).
  REALHBF16,
  // Dtypes matched by ET_SWITCH_FLOATHBF16_TYPES. Presumably the floating
  // types plus Half and BFloat16 (confirm in the macro definition).
  FLOATHBF16,
  // Integer dtypes plus Bool (ET_SWITCH_INT_TYPES_AND with Bool).
  INTB,
  // Exactly Bool or Byte (ET_SWITCH_TWO_TYPES with Bool, Byte).
  BOOL_OR_BYTE,
  // The tensor dtype must equal the dtype of the compute type CTYPE_COMMON.
  // Elements are read/written as CTYPE_COMMON with no conversion.
  SAME_AS_COMPUTE,
  // If CTYPE_COMMON is float, accepts Float, Half or BFloat16. Otherwise it
  // behaves like SAME_AS_COMPUTE.
  SAME_AS_COMMON,
};
226
227 namespace internal {
228
229 template <typename CTYPE_COMMON, const char* op_name>
get_load_to_common_fn(const Tensor & t,SupportedTensorDtypes dtypes)230 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
231 const Tensor& t,
232 SupportedTensorDtypes dtypes) {
233 switch (dtypes) {
234 case SupportedTensorDtypes::REALHBBF16:
235 return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
236 case SupportedTensorDtypes::REALHBF16:
237 return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
238 case SupportedTensorDtypes::FLOATHBF16:
239 return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
240 case SupportedTensorDtypes::INTB:
241 return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
242 case SupportedTensorDtypes::BOOL_OR_BYTE:
243 return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
244 case SupportedTensorDtypes::SAME_AS_COMPUTE:
245 return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
246 case SupportedTensorDtypes::SAME_AS_COMMON:
247 return get_load_to_common_fn_same_as_common<CTYPE_COMMON, op_name>(t);
248 }
249 ET_CHECK(false);
250 return nullptr;
251 }
252
253 template <typename CTYPE_COMMON, const char* op_name>
get_store_common_to_tensor_fn(const Tensor & t,SupportedTensorDtypes dtypes)254 store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
255 const Tensor& t,
256 SupportedTensorDtypes dtypes) {
257 switch (dtypes) {
258 case SupportedTensorDtypes::REALHBBF16:
259 return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
260 case SupportedTensorDtypes::REALHBF16:
261 return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
262 case SupportedTensorDtypes::FLOATHBF16:
263 return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
264 case SupportedTensorDtypes::INTB:
265 return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
266 case SupportedTensorDtypes::BOOL_OR_BYTE:
267 return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
268 t);
269 case SupportedTensorDtypes::SAME_AS_COMPUTE:
270 return get_store_common_to_tensor_fn_same_as_compute<
271 CTYPE_COMMON,
272 op_name>(t);
273 case SupportedTensorDtypes::SAME_AS_COMMON: {
274 return get_store_common_to_tensor_fn_same_as_common<
275 CTYPE_COMMON,
276 op_name>(t);
277 }
278 }
279 ET_CHECK(false);
280 return nullptr;
281 }
282
// Declaration only; the definition is not in this header.
// NOTE(review): presumably returns whether `t`'s dtype belongs to the
// category described by `dtypes`. `compute_type` is presumably consulted for
// the SAME_AS_COMPUTE / SAME_AS_COMMON categories. Confirm against the
// definition.
// NOTE(review): `t` is taken by value. This signature must stay identical to
// the out-of-line definition, so change both together if switching to a
// const reference.
bool check_tensor_dtype(
    const Tensor t,
    SupportedTensorDtypes dtypes,
    const ScalarType compute_type);
287
288 } // namespace internal
289 } // namespace utils
290 } // namespace native
291 } // namespace executor
292 } // namespace torch
293