xref: /aosp_15_r20/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Benoit Steiner <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_UINT128_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_UINT128_H
12 
13 namespace Eigen {
14 namespace internal {
15 
16 
17 template <uint64_t n>
18 struct static_val {
19   static const uint64_t value = n;
uint64_tstatic_val20   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE operator uint64_t() const { return n; }
21 
static_valstatic_val22   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static_val() { }
23 
24   template <typename T>
static_valstatic_val25   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static_val(const T& v) {
26     EIGEN_UNUSED_VARIABLE(v);
27     eigen_assert(v == n);
28   }
29 };
30 
31 
32 template <typename HIGH = uint64_t, typename LOW = uint64_t>
33 struct TensorUInt128
34 {
35   HIGH high;
36   LOW low;
37 
38   template<typename OTHER_HIGH, typename OTHER_LOW>
39   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
TensorUInt128TensorUInt12840   TensorUInt128(const TensorUInt128<OTHER_HIGH, OTHER_LOW>& other) : high(other.high), low(other.low) {
41     EIGEN_STATIC_ASSERT(sizeof(OTHER_HIGH) <= sizeof(HIGH), YOU_MADE_A_PROGRAMMING_MISTAKE);
42     EIGEN_STATIC_ASSERT(sizeof(OTHER_LOW) <= sizeof(LOW), YOU_MADE_A_PROGRAMMING_MISTAKE);
43   }
44 
45   template<typename OTHER_HIGH, typename OTHER_LOW>
46   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
47   TensorUInt128& operator = (const TensorUInt128<OTHER_HIGH, OTHER_LOW>& other) {
48     EIGEN_STATIC_ASSERT(sizeof(OTHER_HIGH) <= sizeof(HIGH), YOU_MADE_A_PROGRAMMING_MISTAKE);
49     EIGEN_STATIC_ASSERT(sizeof(OTHER_LOW) <= sizeof(LOW), YOU_MADE_A_PROGRAMMING_MISTAKE);
50     high = other.high;
51     low = other.low;
52     return *this;
53   }
54 
55   template<typename T>
56   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
TensorUInt128TensorUInt12857   explicit TensorUInt128(const T& x) : high(0), low(x) {
58     eigen_assert((static_cast<typename conditional<sizeof(T) == 8, uint64_t, uint32_t>::type>(x) <= NumTraits<uint64_t>::highest()));
59     eigen_assert(x >= 0);
60   }
61 
62   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
TensorUInt128TensorUInt12863   TensorUInt128(HIGH y, LOW x) : high(y), low(x) { }
64 
LOWTensorUInt12865   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE operator LOW() const {
66     return low;
67   }
lowerTensorUInt12868   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LOW lower() const {
69     return low;
70   }
upperTensorUInt12871   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HIGH upper() const {
72     return high;
73   }
74 };
75 
76 
77 template <typename HL, typename LL, typename HR, typename LR>
78 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
79 bool operator == (const TensorUInt128<HL, LL>& lhs, const TensorUInt128<HR, LR>& rhs)
80 {
81   return (lhs.high == rhs.high) & (lhs.low == rhs.low);
82 }
83 
84 template <typename HL, typename LL, typename HR, typename LR>
85 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
86 bool operator != (const TensorUInt128<HL, LL>& lhs, const TensorUInt128<HR, LR>& rhs)
87 {
88   return (lhs.high != rhs.high) | (lhs.low != rhs.low);
89 }
90 
91 template <typename HL, typename LL, typename HR, typename LR>
92 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
93 bool operator >= (const TensorUInt128<HL, LL>& lhs, const TensorUInt128<HR, LR>& rhs)
94 {
95   if (lhs.high != rhs.high) {
96     return lhs.high > rhs.high;
97   }
98   return lhs.low >= rhs.low;
99 }
100 
101 template <typename HL, typename LL, typename HR, typename LR>
102 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
103 bool operator < (const TensorUInt128<HL, LL>& lhs, const TensorUInt128<HR, LR>& rhs)
104 {
105   if (lhs.high != rhs.high) {
106     return lhs.high < rhs.high;
107   }
108   return lhs.low < rhs.low;
109 }
110 
111 template <typename HL, typename LL, typename HR, typename LR>
112 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
113 TensorUInt128<uint64_t, uint64_t> operator + (const TensorUInt128<HL, LL>& lhs, const TensorUInt128<HR, LR>& rhs)
114 {
115   TensorUInt128<uint64_t, uint64_t> result(lhs.high + rhs.high, lhs.low + rhs.low);
116   if (result.low < rhs.low) {
117     result.high += 1;
118   }
119   return result;
120 }
121 
122 template <typename HL, typename LL, typename HR, typename LR>
123 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
124 TensorUInt128<uint64_t, uint64_t> operator - (const TensorUInt128<HL, LL>& lhs, const TensorUInt128<HR, LR>& rhs)
125 {
126   TensorUInt128<uint64_t, uint64_t> result(lhs.high - rhs.high, lhs.low - rhs.low);
127   if (result.low > lhs.low) {
128     result.high -= 1;
129   }
130   return result;
131 }
132 
133 
134 template <typename HL, typename LL, typename HR, typename LR>
135 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
136 TensorUInt128<uint64_t, uint64_t> operator * (const TensorUInt128<HL, LL>& lhs, const TensorUInt128<HR, LR>& rhs)
137 {
138   // Split each 128-bit integer into 4 32-bit integers, and then do the
139   // multiplications by hand as follow:
140   //   lhs      a  b  c  d
141   //   rhs      e  f  g  h
142   //           -----------
143   //           ah bh ch dh
144   //           bg cg dg
145   //           cf df
146   //           de
147   // The result is stored in 2 64bit integers, high and low.
148 
149   const uint64_t LOW = 0x00000000FFFFFFFFLL;
150   const uint64_t HIGH = 0xFFFFFFFF00000000LL;
151 
152   uint64_t d = lhs.low & LOW;
153   uint64_t c = (lhs.low & HIGH) >> 32LL;
154   uint64_t b = lhs.high & LOW;
155   uint64_t a = (lhs.high & HIGH) >> 32LL;
156 
157   uint64_t h = rhs.low & LOW;
158   uint64_t g = (rhs.low & HIGH) >> 32LL;
159   uint64_t f = rhs.high & LOW;
160   uint64_t e = (rhs.high & HIGH) >> 32LL;
161 
162   // Compute the low 32 bits of low
163   uint64_t acc = d * h;
164   uint64_t low = acc & LOW;
165   //  Compute the high 32 bits of low. Add a carry every time we wrap around
166   acc >>= 32LL;
167   uint64_t carry = 0;
168   uint64_t acc2 = acc + c * h;
169   if (acc2 < acc) {
170     carry++;
171   }
172   acc = acc2 + d * g;
173   if (acc < acc2) {
174     carry++;
175   }
176   low |= (acc << 32LL);
177 
178   // Carry forward the high bits of acc to initiate the computation of the
179   // low 32 bits of high
180   acc2 = (acc >> 32LL) | (carry << 32LL);
181   carry = 0;
182 
183   acc = acc2 + b * h;
184   if (acc < acc2) {
185     carry++;
186   }
187   acc2 = acc + c * g;
188   if (acc2 < acc) {
189     carry++;
190   }
191   acc = acc2 + d * f;
192   if (acc < acc2) {
193     carry++;
194   }
195   uint64_t high = acc & LOW;
196 
197   // Start to compute the high 32 bits of high.
198   acc2 = (acc >> 32LL) | (carry << 32LL);
199 
200   acc = acc2 + a * h;
201   acc2 = acc + b * g;
202   acc = acc2 + c * f;
203   acc2 = acc + d * e;
204   high |= (acc2 << 32LL);
205 
206   return TensorUInt128<uint64_t, uint64_t>(high, low);
207 }
208 
209 template <typename HL, typename LL, typename HR, typename LR>
210 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
211 TensorUInt128<uint64_t, uint64_t> operator / (const TensorUInt128<HL, LL>& lhs, const TensorUInt128<HR, LR>& rhs)
212 {
213   if (rhs == TensorUInt128<static_val<0>, static_val<1> >(1)) {
214     return TensorUInt128<uint64_t, uint64_t>(lhs.high, lhs.low);
215   } else if (lhs < rhs) {
216     return TensorUInt128<uint64_t, uint64_t>(0);
217   } else {
218     // calculate the biggest power of 2 times rhs that's less than or equal to lhs
219     TensorUInt128<uint64_t, uint64_t> power2(1);
220     TensorUInt128<uint64_t, uint64_t> d(rhs);
221     TensorUInt128<uint64_t, uint64_t> tmp(lhs - d);
222     while (lhs >= d) {
223       tmp = tmp - d;
224       d = d + d;
225       power2 = power2 + power2;
226     }
227 
228     tmp = TensorUInt128<uint64_t, uint64_t>(lhs.high, lhs.low);
229     TensorUInt128<uint64_t, uint64_t> result(0);
230     while (power2 != TensorUInt128<static_val<0>, static_val<0> >(0)) {
231       if (tmp >= d) {
232         tmp = tmp - d;
233         result = result + power2;
234       }
235       // Shift right
236       power2 = TensorUInt128<uint64_t, uint64_t>(power2.high >> 1, (power2.low >> 1) | (power2.high << 63));
237       d = TensorUInt128<uint64_t, uint64_t>(d.high >> 1, (d.low >> 1) | (d.high << 63));
238     }
239 
240     return result;
241   }
242 }
243 
244 
245 }  // namespace internal
246 }  // namespace Eigen
247 
248 
249 #endif  // EIGEN_CXX11_TENSOR_TENSOR_UINT128_H
250