xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/scalar_utils.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

#include <executorch/kernels/portable/cpu/selective_build.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/portable_type/scalar.h>
19 
/**
 * Fails an ET_CHECK_MSG check unless the two Scalars fall into the same value
 * category. "Same category" means isBoolean(), isIntegral(false) and
 * isFloatingPoint() all agree between them.
 *
 * Each argument is evaluated exactly once. The expansion is a single
 * `do { ... } while (0)` statement, so it is portable (no GNU
 * statement-expression) and safe as the unbraced body of an if/else.
 */
#define ET_CHECK_SCALAR_SAME_TYPE(a__, b__)                                   \
  do {                                                                        \
    auto&& et_check_scalar_a_ = (a__);                                        \
    auto&& et_check_scalar_b_ = (b__);                                        \
    ET_CHECK_MSG(                                                             \
        et_check_scalar_a_.isBoolean() == et_check_scalar_b_.isBoolean(),     \
        "Scalars type do not match, isBoolean() %d vs %d",                    \
        et_check_scalar_a_.isBoolean(),                                       \
        et_check_scalar_b_.isBoolean());                                      \
    ET_CHECK_MSG(                                                             \
        et_check_scalar_a_.isIntegral(false) ==                               \
            et_check_scalar_b_.isIntegral(false),                             \
        "Scalars type do not match, isIntegral(false) %d vs %d",              \
        et_check_scalar_a_.isIntegral(false),                                 \
        et_check_scalar_b_.isIntegral(false));                                \
    ET_CHECK_MSG(                                                             \
        et_check_scalar_a_.isFloatingPoint() ==                               \
            et_check_scalar_b_.isFloatingPoint(),                             \
        "Scalars type do not match, isFloatingPoint() %d vs %d",              \
        et_check_scalar_a_.isFloatingPoint(),                                 \
        et_check_scalar_b_.isFloatingPoint());                                \
  } while (0)
38 
/**
 * Convenience macro to extract a Scalar into a value.
 *
 * Calls `utils::extract_scalar(scalar, &out_val)` and fails an ET_CHECK_MSG
 * check, with a message naming `scalar`, if the extraction is rejected. A
 * rejection means a wrong Scalar category or a value out of range for
 * out_val's type.
 *
 * `utils::` is looked up at the expansion site, so the macro is meant to be
 * used from a scope where `utils` names this header's namespace.
 *
 * The expansion intentionally has no trailing semicolon. Callers write
 * `ET_EXTRACT_SCALAR(s, v);` like any other statement, and a semicolon inside
 * the macro would add an empty statement. That empty statement breaks use as
 * the unbraced body of an if/else.
 */
#define ET_EXTRACT_SCALAR(scalar, out_val)          \
  ET_CHECK_MSG(                                     \
      utils::extract_scalar((scalar), &(out_val)), \
      #scalar " could not be extracted: wrong type or out of range")
46 
47 namespace torch {
48 namespace executor {
49 namespace native {
50 namespace utils {
51 
52 /**
53  * Returns the dtype associated with a Scalar that reflects the category
54  * of value stored by the Scalar.
55  */
get_scalar_dtype(Scalar scalar)56 inline ScalarType get_scalar_dtype(Scalar scalar) {
57   if (scalar.isBoolean()) {
58     return ScalarType::Bool;
59   }
60   if (scalar.isIntegral(false)) {
61     return ScalarType::Long;
62   }
63   if (scalar.isFloatingPoint()) {
64     return ScalarType::Double;
65   }
66   ET_CHECK_MSG(false, "Scalar must be Boolean, Integral or Floating.");
67 }
68 
scalars_have_same_dtype(Scalar a,Scalar b)69 inline bool scalars_have_same_dtype(Scalar a, Scalar b) {
70   ScalarType a_dtype = get_scalar_dtype(a);
71   ScalarType b_dtype = get_scalar_dtype(b);
72   if (a_dtype == b_dtype) {
73     return true;
74   }
75   ET_LOG(
76       Error,
77       "Expected scalars to have the same dtype, but found %s and %s",
78       toString(a_dtype),
79       toString(b_dtype));
80   return false;
81 }
82 
83 template <typename T1, typename T2, bool half_to_float = false>
84 struct promote_type_with_scalar_type {
85  private:
86   static_assert(
87       std::is_same<T2, torch::executor::internal::B1>::value ||
88           std::is_same<T2, torch::executor::internal::I8>::value ||
89           std::is_same<T2, torch::executor::internal::F8>::value,
90       "scalar type can only be Bool, Long or Double");
91   static_assert(
92       !is_qint_type<T1>::value,
93       "promote_type_with_scalar_type not valid for quantized dtypes");
94   static_assert(
95       !is_bits_type<T1>::value,
96       "promote_type_with_scalar_type not valid for bits dtypes");
97   using promote_type_with_scalar_type_not_respecting_half_to_float =
98       typename std::conditional<
99           is_complex_type<T1>::value ||
100               std::is_same<T2, torch::executor::internal::B1>::value,
101           T1,
102           typename std::conditional<
103               std::is_same<T2, torch::executor::internal::I8>::value,
104               typename std::conditional<
105                   std::is_same<T1, torch::executor::internal::B1>::value,
106                   torch::executor::internal::I8,
107                   T1>::type,
108               typename std::conditional<
109                   is_floating_point<T1>::value,
110                   T1,
111                   torch::executor::internal::F4>::type>::type>::type;
112 
113  public:
114   using type = typename std::conditional<
115       half_to_float &&
116           (std::is_same<
117                promote_type_with_scalar_type_not_respecting_half_to_float,
118                typename ScalarTypeToCppType<
119                    exec_aten::ScalarType::Half>::type>::value ||
120            std::is_same<
121                promote_type_with_scalar_type_not_respecting_half_to_float,
122                typename ScalarTypeToCppType<
123                    exec_aten::ScalarType::BFloat16>::type>::value),
124       typename ScalarTypeToCppType<exec_aten::ScalarType::Float>::type,
125       promote_type_with_scalar_type_not_respecting_half_to_float>::type;
126 };
127 
128 /**
129  * Implement type promotion between a tensor's ScalarType with a Scalar.
130  * If the Scalar contains a value in the same category of the tensor's
131  * ScalarType, the tensor's ScalarType will be preserved. Otherwise, a type
132  * promotion will occur and the dtype associated with the Scalar will be
133  * returned.
134  *
135  * If t is a complex type, then it will be preserved.
136  */
137 inline ScalarType promote_type_with_scalar(
138     ScalarType t,
139     Scalar scalar,
140     bool half_to_float = false) {
141   if (half_to_float && t == ScalarType::Half) {
142     t = ScalarType::Float;
143   }
144 
145   // QInt, and Bits types not supported
146   ET_CHECK(!isQIntType(t));
147   ET_CHECK(!isBitsType(t));
148 
149   if (isComplexType(t)) {
150     return t;
151   }
152   if (scalar.isFloatingPoint()) {
153     if (isFloatingType(t)) {
154       return t;
155     } else {
156       // ATen will promote to Float instead of Double
157       return ScalarType::Float;
158     }
159   }
160   if (scalar.isIntegral(false)) {
161     if (isFloatingType(t) || isIntegralType(t, false)) {
162       return t;
163     } else {
164       return ScalarType::Long;
165     }
166   }
167   if (scalar.isBoolean()) {
168     return t;
169   }
170   ET_CHECK_MSG(false, "Scalar must be Boolean, Integral or Floating.");
171 }
172 
173 /**
174  * Extracts an integer value from a Scalar.
175  *
176  * @param[in] scalar The source of the value to extract.
177  * @param[out] out_val The extracted value, on success.
178  * @returns `true` if a value was extracted, and sets `*out_val` to that value.
179  *    `false` if a value could not be extracted: either it was not an integer
180  *    Scalar, or the value of that Scalar could not be represented by INT_T.
181  */
182 template <
183     typename INT_T,
184     typename std::enable_if<
185         std::is_integral<INT_T>::value && !std::is_same<INT_T, bool>::value,
186         bool>::type = true>
extract_scalar(Scalar scalar,INT_T * out_val)187 bool extract_scalar(Scalar scalar, INT_T* out_val) {
188   if (!scalar.isIntegral(/*includeBool=*/false)) {
189     return false;
190   }
191   int64_t val = scalar.to<int64_t>();
192   if (val < std::numeric_limits<INT_T>::lowest() ||
193       val > std::numeric_limits<INT_T>::max()) {
194     // PyTorch's implementation of clamp() raises an exception if the min/max
195     // values cannot be represented as the dtype, so we should fail too.
196     return false;
197   }
198   *out_val = static_cast<INT_T>(val);
199   return true;
200 }
201 
202 /**
203  * Extracts a floating point value from a Scalar.
204  *
205  * @param[in] scalar The source of the value to extract.
206  * @param[out] out_val The extracted value, on success.
207  * @returns `true` if a value was extracted, and sets `*out_val` to that value.
208  *    `false` if a value could not be extracted: either it was not a floating
209  *    point Scalar, or the value of that Scalar could not be represented by
210  *    FLOAT_T.
211  */
212 template <
213     typename FLOAT_T,
214     typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
215         type = true>
extract_scalar(Scalar scalar,FLOAT_T * out_val)216 bool extract_scalar(Scalar scalar, FLOAT_T* out_val) {
217   double val;
218   if (scalar.isFloatingPoint()) {
219     val = scalar.to<double>();
220     // If the double is outside the finite range supported by float, it cannot
221     // be represented when FLOAT_T == float. float can, however, represent
222     // infinite and NaN values.
223     if (std::isfinite(val) &&
224         (val < std::numeric_limits<FLOAT_T>::lowest() ||
225          val > std::numeric_limits<FLOAT_T>::max())) {
226       // PyTorch's implementation of clamp() raises an exception if the min/max
227       // values cannot be represented as the dtype, so we should fail too.
228       return false;
229     }
230   } else if (scalar.isIntegral(/*includeBool=*/false)) {
231     val = static_cast<double>(scalar.to<int64_t>());
232   } else {
233     // Not a numeric Scalar.
234     return false;
235   }
236   *out_val = static_cast<FLOAT_T>(val);
237   return true;
238 }
239 
240 /**
241  * Extracts a boolean value from a Scalar.
242  *
243  * @param[in] scalar The source of the value to extract.
244  * @param[out] out_val The extracted value, on success.
245  * @returns `true` if a value was extracted, and sets `*out_val` to that value.
246  *    `false` if a value could not be extracted, i.e. not a boolean
247  */
248 template <
249     typename BOOL_T,
250     typename std::enable_if<std::is_same<BOOL_T, bool>::value, bool>::type =
251         true>
extract_scalar(Scalar scalar,BOOL_T * out_val)252 bool extract_scalar(Scalar scalar, BOOL_T* out_val) {
253   if (scalar.isIntegral(false)) {
254     *out_val = static_cast<bool>(scalar.to<int64_t>());
255     return true;
256   }
257   if (scalar.isBoolean()) {
258     *out_val = scalar.to<bool>();
259     return true;
260   }
261   return false;
262 }
263 
264 } // namespace utils
265 } // namespace native
266 } // namespace executor
267 } // namespace torch
268