xref: /aosp_15_r20/external/tensorflow/tensorflow/core/lib/math/math_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_
17 #define TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_
18 
#include <cmath>
#include <limits>
#include <type_traits>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
23 
24 namespace tensorflow {
25 
// Collection of static, templated integer/floating-point math helpers.
// The class has no state; every member is a static function template.
class MathUtil {
 public:
  // ----------------------------------------------------------------------
  // CeilOfRatio<IntegralType>
  // FloorOfRatio<IntegralType>
  //   Returns the ceil (resp. floor) of the ratio of two integers.
  //
  //  * IntegralType: any integral type, whether signed or not.
  //  * numerator: any integer: positive, negative, or zero.
  //  * denominator: a non-zero integer, positive or negative.
  //
  // The computation uses only integer arithmetic in IntegralType, so there is
  // no precision loss. However, if the type is signed, having
  // numerator == std::numeric_limits<IntegralType>::min() and
  // denominator == -1 is not a valid input, because the minimum value has a
  // greater absolute value than the maximum value.
  //
  // A zero denominator is DCHECKed. When not in debug mode, invalid inputs are
  // undefined behavior (on many platforms they raise SIGFPE).
  //
  // Prefer these functions to ad-hoc recipes such as casting to double, which
  // are in general incorrect because of precision loss.
  template <typename IntegralType>
  static IntegralType CeilOfRatio(IntegralType numerator,
                                  IntegralType denominator) {
    return CeilOrFloorOfRatio<IntegralType, true>(numerator, denominator);
  }
  template <typename IntegralType>
  static IntegralType FloorOfRatio(IntegralType numerator,
                                   IntegralType denominator) {
    return CeilOrFloorOfRatio<IntegralType, false>(numerator, denominator);
  }

  // Shared implementation of CeilOfRatio (ceil == true) and FloorOfRatio
  // (ceil == false). Same input requirements as those two functions.
  template <typename IntegralType, bool ceil>
  static IntegralType CeilOrFloorOfRatio(IntegralType numerator,
                                         IntegralType denominator);

  // Returns the greatest common divisor of x and y.
  // IntegralType must be an unsigned type.
  template <typename IntegralType>
  static IntegralType GCD(IntegralType x, IntegralType y);

  // ----------------------------------------------------------------------
  // IPow<T>
  //   Computes the result of raising a number to a non-negative integral power.
  //
  //  * T: An integral type, floating-point type, or user-defined type for which
  //    operator*= is defined.
  //  * base: the base "v" of the operation
  //  * exp: the exponent "i" of the operation; must be non-negative.
  //
  // Computes v^i, in a way that is faster than std::pow (which supports
  // arbitrary real exponents).
  //
  // When T is a floating point type, this has the same semantics as std::pow,
  // but it is much faster. When T is an integral type, computations are
  // performed in the value domain of T, and overflow semantics are those of T.
  //
  // Input validity is DCHECKed.
  template <typename T>
  static T IPow(T base, int exp);

  // Retrieves the sign of `x`:
  //  nan if x is nan.
  //   -1 if x < 0,
  //   +1 if x > 0,
  //    0 if x = 0.
  //
  // Two overloads are selected via std::enable_if so that std::isnan is only
  // ever instantiated for non-integral types.
  template <typename T,
            typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
  static T Sign(const T x) {
    // Integral values can never be NaN; go straight to the comparison helper.
    return SignHelper<T>(x);
  }
  template <typename T,
            typename std::enable_if<!std::is_integral<T>::value, int>::type = 0>
  static T Sign(const T x) {
    // NaN compares false against everything, so it has to be propagated
    // explicitly before the ordinary comparisons run.
    return std::isnan(x) ? x : SignHelper<T>(x);
  }

 private:
  // A helper function to reduce duplication between two MathUtil::Sign
  // functions, which are required to be split to avoid ambiguity for integral
  // types with std::isnan for some builds.
  //
  // Returns T(0), T(1) or T(-1) depending on how x compares with T(0).
  template <typename T>
  static T SignHelper(const T x) {
    return x == T(0) ? T(0) : (x > T(0) ? T(1) : T(-1));
  }
};
111 
112 // ---- CeilOrFloorOfRatio ----
113 // This is a branching-free, cast-to-double-free implementation.
114 //
115 // Casting to double is in general incorrect because of loss of precision
116 // when casting an int64 into a double.
117 //
118 // There's a bunch of 'recipes' to compute a integer ceil (or floor) on the web,
119 // and most of them are incorrect.
120 template <typename IntegralType, bool ceil>
CeilOrFloorOfRatio(IntegralType numerator,IntegralType denominator)121 IntegralType MathUtil::CeilOrFloorOfRatio(IntegralType numerator,
122                                           IntegralType denominator) {
123   DCHECK_NE(0, denominator) << "Division by zero is not supported.";
124 
125   const IntegralType rounded_toward_zero = numerator / denominator;
126   const IntegralType intermediate_product = rounded_toward_zero * denominator;
127 
128   if (ceil) {  // Compile-time condition: not an actual branching
129     // When rounded_toward_zero is negative, then an adjustment is never needed:
130     // the real ratio is negative, and so rounded toward zero is the ceil.
131     // When rounded_toward_zero is non-negative, an adjustment is needed if the
132     // sign of the difference numerator - intermediate_product is the same as
133     // the sign of the denominator.
134     //
135     //
136     // Using a bool and then a static_cast to IntegralType is not strictly
137     // necessary, but it makes the code clear, and anyway the compiler should
138     // get rid of it.
139     const bool needs_adjustment =
140         (rounded_toward_zero >= 0) &&
141         ((denominator > 0 && numerator > intermediate_product) ||
142          (denominator < 0 && numerator < intermediate_product));
143     const IntegralType adjustment = static_cast<IntegralType>(needs_adjustment);
144     const IntegralType ceil_of_ratio = rounded_toward_zero + adjustment;
145     return ceil_of_ratio;
146   } else {
147     // Floor case: symmetrical to the previous one
148     const bool needs_adjustment =
149         (rounded_toward_zero <= 0) &&
150         ((denominator > 0 && numerator < intermediate_product) ||
151          (denominator < 0 && numerator > intermediate_product));
152     const IntegralType adjustment = static_cast<IntegralType>(needs_adjustment);
153     const IntegralType floor_of_ratio = rounded_toward_zero - adjustment;
154     return floor_of_ratio;
155   }
156 }
157 
158 template <typename IntegralType>
GCD(IntegralType x,IntegralType y)159 IntegralType MathUtil::GCD(IntegralType x, IntegralType y) {
160   static_assert(std::is_unsigned<IntegralType>::value,
161                 "signed GCD not supported!");
162   while (y != 0) {
163     IntegralType r = x % y;
164     x = y;
165     y = r;
166   }
167   return x;
168 }
169 
170 // ---- IPow ----
171 // Implemented with the squared exponentiation method (a.k.a. double-and-add).
172 //
173 // Note that "exp >>= 1" is faster than "exp /= 2" on at least one platform.
174 template <typename T>
IPow(T base,int exp)175 T MathUtil::IPow(T base, int exp) {
176   DCHECK_GE(exp, 0);
177   for (T result(1);; base *= base) {
178     if ((exp & 1) != 0) result *= base;
179     exp >>= 1;
180     if (exp == 0) return result;
181   }
182 }
183 
184 }  // namespace tensorflow
185 
186 #endif  // TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_
187