1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include <executorch/kernels/portable/cpu/selective_build.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/portable_type/scalar.h>
19
/**
 * Fails an ET_CHECK_MSG unless the two Scalars report the same value
 * category. isBoolean(), isIntegral(false) and isFloatingPoint() are compared
 * in that order, and the first mismatch is reported with both values.
 *
 * Each argument is evaluated exactly once and bound to a local reference.
 * Passing an expression with side effects, or an expensive call, is therefore
 * safe.
 */
#define ET_CHECK_SCALAR_SAME_TYPE(a__, b__)                                \
  ({                                                                       \
    auto&& et_check_scalar_a_ = (a__);                                     \
    auto&& et_check_scalar_b_ = (b__);                                     \
    ET_CHECK_MSG(                                                          \
        et_check_scalar_a_.isBoolean() == et_check_scalar_b_.isBoolean(),  \
        "Scalars type do not match, isBoolean() %d vs %d",                 \
        et_check_scalar_a_.isBoolean(),                                    \
        et_check_scalar_b_.isBoolean());                                   \
    ET_CHECK_MSG(                                                          \
        et_check_scalar_a_.isIntegral(false) ==                            \
            et_check_scalar_b_.isIntegral(false),                          \
        "Scalars type do not match, isIntegral(false) %d vs %d",           \
        et_check_scalar_a_.isIntegral(false),                              \
        et_check_scalar_b_.isIntegral(false));                             \
    ET_CHECK_MSG(                                                          \
        et_check_scalar_a_.isFloatingPoint() ==                            \
            et_check_scalar_b_.isFloatingPoint(),                          \
        "Scalars type do not match, isFloatingPoint() %d vs %d",           \
        et_check_scalar_a_.isFloatingPoint(),                              \
        et_check_scalar_b_.isFloatingPoint());                             \
  })
38
/**
 * Convenience macro to extract a Scalar into a value.
 *
 * Calls utils::extract_scalar() with a pointer to `out_val`. The macro fails
 * an ET_CHECK_MSG if the Scalar has the wrong category, or if its value cannot
 * be represented by `out_val`'s type. The helper is named by its fully
 * qualified namespace, so the macro also works outside
 * torch::executor::native.
 *
 * NOTE: the expansion already ends in a semicolon. This is kept for
 * compatibility with existing call sites. As a consequence, do not use the
 * macro as the sole statement of an unbraced if/else.
 */
#define ET_EXTRACT_SCALAR(scalar, out_val)                      \
  ET_CHECK_MSG(                                                 \
      ::torch::executor::native::utils::extract_scalar(         \
          (scalar), &(out_val)),                                \
      #scalar " could not be extracted: wrong type or out of range");
46
47 namespace torch {
48 namespace executor {
49 namespace native {
50 namespace utils {
51
52 /**
53 * Returns the dtype associated with a Scalar that reflects the category
54 * of value stored by the Scalar.
55 */
get_scalar_dtype(Scalar scalar)56 inline ScalarType get_scalar_dtype(Scalar scalar) {
57 if (scalar.isBoolean()) {
58 return ScalarType::Bool;
59 }
60 if (scalar.isIntegral(false)) {
61 return ScalarType::Long;
62 }
63 if (scalar.isFloatingPoint()) {
64 return ScalarType::Double;
65 }
66 ET_CHECK_MSG(false, "Scalar must be Boolean, Integral or Floating.");
67 }
68
scalars_have_same_dtype(Scalar a,Scalar b)69 inline bool scalars_have_same_dtype(Scalar a, Scalar b) {
70 ScalarType a_dtype = get_scalar_dtype(a);
71 ScalarType b_dtype = get_scalar_dtype(b);
72 if (a_dtype == b_dtype) {
73 return true;
74 }
75 ET_LOG(
76 Error,
77 "Expected scalars to have the same dtype, but found %s and %s",
78 toString(a_dtype),
79 toString(b_dtype));
80 return false;
81 }
82
83 template <typename T1, typename T2, bool half_to_float = false>
84 struct promote_type_with_scalar_type {
85 private:
86 static_assert(
87 std::is_same<T2, torch::executor::internal::B1>::value ||
88 std::is_same<T2, torch::executor::internal::I8>::value ||
89 std::is_same<T2, torch::executor::internal::F8>::value,
90 "scalar type can only be Bool, Long or Double");
91 static_assert(
92 !is_qint_type<T1>::value,
93 "promote_type_with_scalar_type not valid for quantized dtypes");
94 static_assert(
95 !is_bits_type<T1>::value,
96 "promote_type_with_scalar_type not valid for bits dtypes");
97 using promote_type_with_scalar_type_not_respecting_half_to_float =
98 typename std::conditional<
99 is_complex_type<T1>::value ||
100 std::is_same<T2, torch::executor::internal::B1>::value,
101 T1,
102 typename std::conditional<
103 std::is_same<T2, torch::executor::internal::I8>::value,
104 typename std::conditional<
105 std::is_same<T1, torch::executor::internal::B1>::value,
106 torch::executor::internal::I8,
107 T1>::type,
108 typename std::conditional<
109 is_floating_point<T1>::value,
110 T1,
111 torch::executor::internal::F4>::type>::type>::type;
112
113 public:
114 using type = typename std::conditional<
115 half_to_float &&
116 (std::is_same<
117 promote_type_with_scalar_type_not_respecting_half_to_float,
118 typename ScalarTypeToCppType<
119 exec_aten::ScalarType::Half>::type>::value ||
120 std::is_same<
121 promote_type_with_scalar_type_not_respecting_half_to_float,
122 typename ScalarTypeToCppType<
123 exec_aten::ScalarType::BFloat16>::type>::value),
124 typename ScalarTypeToCppType<exec_aten::ScalarType::Float>::type,
125 promote_type_with_scalar_type_not_respecting_half_to_float>::type;
126 };
127
128 /**
129 * Implement type promotion between a tensor's ScalarType with a Scalar.
130 * If the Scalar contains a value in the same category of the tensor's
131 * ScalarType, the tensor's ScalarType will be preserved. Otherwise, a type
132 * promotion will occur and the dtype associated with the Scalar will be
133 * returned.
134 *
135 * If t is a complex type, then it will be preserved.
136 */
137 inline ScalarType promote_type_with_scalar(
138 ScalarType t,
139 Scalar scalar,
140 bool half_to_float = false) {
141 if (half_to_float && t == ScalarType::Half) {
142 t = ScalarType::Float;
143 }
144
145 // QInt, and Bits types not supported
146 ET_CHECK(!isQIntType(t));
147 ET_CHECK(!isBitsType(t));
148
149 if (isComplexType(t)) {
150 return t;
151 }
152 if (scalar.isFloatingPoint()) {
153 if (isFloatingType(t)) {
154 return t;
155 } else {
156 // ATen will promote to Float instead of Double
157 return ScalarType::Float;
158 }
159 }
160 if (scalar.isIntegral(false)) {
161 if (isFloatingType(t) || isIntegralType(t, false)) {
162 return t;
163 } else {
164 return ScalarType::Long;
165 }
166 }
167 if (scalar.isBoolean()) {
168 return t;
169 }
170 ET_CHECK_MSG(false, "Scalar must be Boolean, Integral or Floating.");
171 }
172
173 /**
174 * Extracts an integer value from a Scalar.
175 *
176 * @param[in] scalar The source of the value to extract.
177 * @param[out] out_val The extracted value, on success.
178 * @returns `true` if a value was extracted, and sets `*out_val` to that value.
179 * `false` if a value could not be extracted: either it was not an integer
180 * Scalar, or the value of that Scalar could not be represented by INT_T.
181 */
182 template <
183 typename INT_T,
184 typename std::enable_if<
185 std::is_integral<INT_T>::value && !std::is_same<INT_T, bool>::value,
186 bool>::type = true>
extract_scalar(Scalar scalar,INT_T * out_val)187 bool extract_scalar(Scalar scalar, INT_T* out_val) {
188 if (!scalar.isIntegral(/*includeBool=*/false)) {
189 return false;
190 }
191 int64_t val = scalar.to<int64_t>();
192 if (val < std::numeric_limits<INT_T>::lowest() ||
193 val > std::numeric_limits<INT_T>::max()) {
194 // PyTorch's implementation of clamp() raises an exception if the min/max
195 // values cannot be represented as the dtype, so we should fail too.
196 return false;
197 }
198 *out_val = static_cast<INT_T>(val);
199 return true;
200 }
201
202 /**
203 * Extracts a floating point value from a Scalar.
204 *
205 * @param[in] scalar The source of the value to extract.
206 * @param[out] out_val The extracted value, on success.
207 * @returns `true` if a value was extracted, and sets `*out_val` to that value.
208 * `false` if a value could not be extracted: either it was not a floating
209 * point Scalar, or the value of that Scalar could not be represented by
210 * FLOAT_T.
211 */
212 template <
213 typename FLOAT_T,
214 typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
215 type = true>
extract_scalar(Scalar scalar,FLOAT_T * out_val)216 bool extract_scalar(Scalar scalar, FLOAT_T* out_val) {
217 double val;
218 if (scalar.isFloatingPoint()) {
219 val = scalar.to<double>();
220 // If the double is outside the finite range supported by float, it cannot
221 // be represented when FLOAT_T == float. float can, however, represent
222 // infinite and NaN values.
223 if (std::isfinite(val) &&
224 (val < std::numeric_limits<FLOAT_T>::lowest() ||
225 val > std::numeric_limits<FLOAT_T>::max())) {
226 // PyTorch's implementation of clamp() raises an exception if the min/max
227 // values cannot be represented as the dtype, so we should fail too.
228 return false;
229 }
230 } else if (scalar.isIntegral(/*includeBool=*/false)) {
231 val = static_cast<double>(scalar.to<int64_t>());
232 } else {
233 // Not a numeric Scalar.
234 return false;
235 }
236 *out_val = static_cast<FLOAT_T>(val);
237 return true;
238 }
239
240 /**
241 * Extracts a boolean value from a Scalar.
242 *
243 * @param[in] scalar The source of the value to extract.
244 * @param[out] out_val The extracted value, on success.
245 * @returns `true` if a value was extracted, and sets `*out_val` to that value.
246 * `false` if a value could not be extracted, i.e. not a boolean
247 */
248 template <
249 typename BOOL_T,
250 typename std::enable_if<std::is_same<BOOL_T, bool>::value, bool>::type =
251 true>
extract_scalar(Scalar scalar,BOOL_T * out_val)252 bool extract_scalar(Scalar scalar, BOOL_T* out_val) {
253 if (scalar.isIntegral(false)) {
254 *out_val = static_cast<bool>(scalar.to<int64_t>());
255 return true;
256 }
257 if (scalar.isBoolean()) {
258 *out_val = scalar.to<bool>();
259 return true;
260 }
261 return false;
262 }
263
264 } // namespace utils
265 } // namespace native
266 } // namespace executor
267 } // namespace torch
268