xref: /aosp_15_r20/external/executorch/kernels/optimized/utils/math_utils.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <cstdint>
12 
13 #include <executorch/kernels/optimized/utils/llvmMathExtras.h>
14 
15 namespace executorch {
16 namespace utils {
17 
// Maps an element type `T` to the type that ops should use for their
// internal math. Unless a specialization below says otherwise, the math is
// done in `T` itself.
template <typename T>
struct ComputeDTypeTraits {
  using type = T;
};

// 8 bit integer types: internal math is done in the 32 bit integer type of
// the same signedness (int32_t for signed, uint32_t for unsigned).
template <>
struct ComputeDTypeTraits<int8_t> {
  using type = int32_t;
};
template <>
struct ComputeDTypeTraits<uint8_t> {
  using type = uint32_t;
};

// 16 bit integer types: likewise widened to the 32 bit integer type of the
// same signedness.
template <>
struct ComputeDTypeTraits<int16_t> {
  using type = int32_t;
};
template <>
struct ComputeDTypeTraits<uint16_t> {
  using type = uint32_t;
};
40 
41 template <typename T>
42 using compute_dtype = typename ComputeDTypeTraits<T>::type;
43 
// Integer division of `x` by `y`, rounded up (towards positive infinity).
//
// The result is derived from the truncated quotient and the remainder rather
// than from `(x + y - 1) / y`, because that sum overflows int64_t (undefined
// behavior) when `x` is close to INT64_MAX, and because it rounds the wrong
// way for negative operands (e.g. it yields -1 for x = -4, y = 2).
//
// @param x Dividend.
// @param y Divisor. Must be non-zero; division by zero is undefined, as is
//     the single overflowing case x == INT64_MIN with y == -1.
// @return ceil(x / y).
inline int64_t divup(int64_t x, int64_t y) {
  const int64_t quotient = x / y; // Truncates towards zero.
  const int64_t remainder = x % y; // Zero, or carries the sign of `x`.
  // Truncation already rounded up when the exact result is negative. It
  // rounded down only when the exact result is positive and inexact, which
  // is precisely when the remainder is non-zero and has the sign of `y`.
  const bool rounded_down = remainder != 0 && ((remainder < 0) == (y < 0));
  return rounded_down ? quotient + 1 : quotient;
}
47 
48 template <typename T>
49 T CeilLog2(const T& x) {
50   if (x <= 2) {
51     return 1;
52   }
53   // Last set bit is floor(log2(x)), floor + 1 is ceil
54   // except when x is an exact powers of 2, so subtract 1 first
55   return static_cast<T>(executorch::llvm::findLastSet(
56       static_cast<uint64_t>(x) - 1)) + 1;
57 }
58 
59 } // namespace utils
60 } // namespace executorch
61