xref: /aosp_15_r20/external/pytorch/c10/util/int128.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker // This file is based on the uint128 implementation of protobuf at
2*da0073e9SAndroid Build Coastguard Worker // https://github.com/protocolbuffers/protobuf/blob/1e88936fce10cf773cb72b44c6a7f48b38c7578b/src/google/protobuf/stubs/int128.cc
3*da0073e9SAndroid Build Coastguard Worker //
4*da0073e9SAndroid Build Coastguard Worker // Protocol Buffers - Google's data interchange format
5*da0073e9SAndroid Build Coastguard Worker // Copyright 2008 Google Inc.  All rights reserved.
6*da0073e9SAndroid Build Coastguard Worker // https://developers.google.com/protocol-buffers/
7*da0073e9SAndroid Build Coastguard Worker //
8*da0073e9SAndroid Build Coastguard Worker // Redistribution and use in source and binary forms, with or without
9*da0073e9SAndroid Build Coastguard Worker // modification, are permitted provided that the following conditions are
10*da0073e9SAndroid Build Coastguard Worker // met:
11*da0073e9SAndroid Build Coastguard Worker //
12*da0073e9SAndroid Build Coastguard Worker //     * Redistributions of source code must retain the above copyright
13*da0073e9SAndroid Build Coastguard Worker // notice, this list of conditions and the following disclaimer.
14*da0073e9SAndroid Build Coastguard Worker //     * Redistributions in binary form must reproduce the above
15*da0073e9SAndroid Build Coastguard Worker // copyright notice, this list of conditions and the following disclaimer
16*da0073e9SAndroid Build Coastguard Worker // in the documentation and/or other materials provided with the
17*da0073e9SAndroid Build Coastguard Worker // distribution.
18*da0073e9SAndroid Build Coastguard Worker //     * Neither the name of Google Inc. nor the names of its
19*da0073e9SAndroid Build Coastguard Worker // contributors may be used to endorse or promote products derived from
20*da0073e9SAndroid Build Coastguard Worker // this software without specific prior written permission.
21*da0073e9SAndroid Build Coastguard Worker //
22*da0073e9SAndroid Build Coastguard Worker // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23*da0073e9SAndroid Build Coastguard Worker // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24*da0073e9SAndroid Build Coastguard Worker // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25*da0073e9SAndroid Build Coastguard Worker // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26*da0073e9SAndroid Build Coastguard Worker // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27*da0073e9SAndroid Build Coastguard Worker // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28*da0073e9SAndroid Build Coastguard Worker // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29*da0073e9SAndroid Build Coastguard Worker // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30*da0073e9SAndroid Build Coastguard Worker // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31*da0073e9SAndroid Build Coastguard Worker // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32*da0073e9SAndroid Build Coastguard Worker // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33*da0073e9SAndroid Build Coastguard Worker 
#include <c10/util/Logging.h>
#include <c10/util/int128.h>
#include <iomanip>
#include <ostream> // NOLINT(readability/streams)
#include <sstream>
#include <string>
38*da0073e9SAndroid Build Coastguard Worker 
39*da0073e9SAndroid Build Coastguard Worker namespace c10 {
40*da0073e9SAndroid Build Coastguard Worker 
41*da0073e9SAndroid Build Coastguard Worker const uint128_pod kuint128max = {
42*da0073e9SAndroid Build Coastguard Worker     uint64_t{0xFFFFFFFFFFFFFFFFu},
43*da0073e9SAndroid Build Coastguard Worker     uint64_t{0xFFFFFFFFFFFFFFFFu}};
44*da0073e9SAndroid Build Coastguard Worker 
// Returns the 0-based index of the most significant set bit of n.
// The argument may not be 0 (an argument of 0 yields 0, which is
// indistinguishable from the result for n == 1).
//
// For example:
//   Fls64(5) == 2, because 5 (decimal) == 101 (binary).
static inline int Fls64(uint64_t n) {
  // Binary search for the top bit: each round checks whether anything is set
  // above `step` bits and, if so, discards those low bits while recording
  // their width in `bit`. The steps are 32, 16, 8, 4.
  int bit = 0;
  for (int step = 32; step >= 4; step /= 2) {
    if ((n >> step) != 0) {
      n >>= step;
      bit += step;
    }
  }
  // n is now in [0, 15]. The constant packs sixteen 4-bit entries, where
  // entry i holds the index of the top bit of i (0 for i == 0).
  return bit +
      static_cast<int>((uint64_t{0x3333333322221100u} >> (n * 4)) & 0x3);
}
70*da0073e9SAndroid Build Coastguard Worker 
71*da0073e9SAndroid Build Coastguard Worker // Like Fls64() above, but returns the 0-based position of the last set bit
72*da0073e9SAndroid Build Coastguard Worker // (i.e., most significant bit) in the given uint128. The argument may not be 0.
Fls128(uint128 n)73*da0073e9SAndroid Build Coastguard Worker static inline int Fls128(uint128 n) {
74*da0073e9SAndroid Build Coastguard Worker   if (uint64_t hi = Uint128High64(n)) {
75*da0073e9SAndroid Build Coastguard Worker     return Fls64(hi) + 64;
76*da0073e9SAndroid Build Coastguard Worker   }
77*da0073e9SAndroid Build Coastguard Worker   return Fls64(Uint128Low64(n));
78*da0073e9SAndroid Build Coastguard Worker }
79*da0073e9SAndroid Build Coastguard Worker 
DivModImpl(uint128 dividend,uint128 divisor,uint128 * quotient_ret,uint128 * remainder_ret)80*da0073e9SAndroid Build Coastguard Worker void uint128::DivModImpl(
81*da0073e9SAndroid Build Coastguard Worker     uint128 dividend,
82*da0073e9SAndroid Build Coastguard Worker     uint128 divisor,
83*da0073e9SAndroid Build Coastguard Worker     uint128* quotient_ret,
84*da0073e9SAndroid Build Coastguard Worker     uint128* remainder_ret) {
85*da0073e9SAndroid Build Coastguard Worker   if (divisor == 0) {
86*da0073e9SAndroid Build Coastguard Worker     LOG(FATAL) << "Division or mod by zero: dividend.hi=" << dividend.hi_
87*da0073e9SAndroid Build Coastguard Worker                << ", lo=" << dividend.lo_;
88*da0073e9SAndroid Build Coastguard Worker   } else if (dividend < divisor) {
89*da0073e9SAndroid Build Coastguard Worker     *quotient_ret = 0;
90*da0073e9SAndroid Build Coastguard Worker     *remainder_ret = dividend;
91*da0073e9SAndroid Build Coastguard Worker     return;
92*da0073e9SAndroid Build Coastguard Worker   } else {
93*da0073e9SAndroid Build Coastguard Worker     int dividend_bit_length = Fls128(dividend);
94*da0073e9SAndroid Build Coastguard Worker     int divisor_bit_length = Fls128(divisor);
95*da0073e9SAndroid Build Coastguard Worker     int difference = dividend_bit_length - divisor_bit_length;
96*da0073e9SAndroid Build Coastguard Worker     uint128 quotient = 0;
97*da0073e9SAndroid Build Coastguard Worker     while (difference >= 0) {
98*da0073e9SAndroid Build Coastguard Worker       quotient <<= 1;
99*da0073e9SAndroid Build Coastguard Worker       uint128 shifted_divisor = divisor << difference;
100*da0073e9SAndroid Build Coastguard Worker       if (shifted_divisor <= dividend) {
101*da0073e9SAndroid Build Coastguard Worker         dividend -= shifted_divisor;
102*da0073e9SAndroid Build Coastguard Worker         quotient += 1;
103*da0073e9SAndroid Build Coastguard Worker       }
104*da0073e9SAndroid Build Coastguard Worker       difference -= 1;
105*da0073e9SAndroid Build Coastguard Worker     }
106*da0073e9SAndroid Build Coastguard Worker     // record the final quotient and remainder
107*da0073e9SAndroid Build Coastguard Worker     *quotient_ret = quotient;
108*da0073e9SAndroid Build Coastguard Worker     *remainder_ret = dividend;
109*da0073e9SAndroid Build Coastguard Worker   }
110*da0073e9SAndroid Build Coastguard Worker }
111*da0073e9SAndroid Build Coastguard Worker 
// Replaces *this with the quotient (*this / divisor) and returns *this.
uint128& uint128::operator/=(const uint128& divisor) {
  uint128 result_quotient = 0;
  uint128 discarded_remainder = 0; // DivModImpl always produces both results.
  DivModImpl(*this, divisor, &result_quotient, &discarded_remainder);
  *this = result_quotient;
  return *this;
}
// Replaces *this with the remainder (*this % divisor) and returns *this.
uint128& uint128::operator%=(const uint128& divisor) {
  uint128 discarded_quotient = 0; // DivModImpl always produces both results.
  uint128 result_remainder = 0;
  DivModImpl(*this, divisor, &discarded_quotient, &result_remainder);
  *this = result_remainder;
  return *this;
}
126*da0073e9SAndroid Build Coastguard Worker 
operator <<(std::ostream & o,const uint128 & b)127*da0073e9SAndroid Build Coastguard Worker std::ostream& operator<<(std::ostream& o, const uint128& b) {
128*da0073e9SAndroid Build Coastguard Worker   std::ios_base::fmtflags flags = o.flags();
129*da0073e9SAndroid Build Coastguard Worker 
130*da0073e9SAndroid Build Coastguard Worker   // Select a divisor which is the largest power of the base < 2^64.
131*da0073e9SAndroid Build Coastguard Worker   uint128 div;
132*da0073e9SAndroid Build Coastguard Worker   int div_base_log = 0;
133*da0073e9SAndroid Build Coastguard Worker   switch (flags & std::ios::basefield) {
134*da0073e9SAndroid Build Coastguard Worker     case std::ios::hex:
135*da0073e9SAndroid Build Coastguard Worker       div = (uint64_t)0x1000000000000000u; // 16^15
136*da0073e9SAndroid Build Coastguard Worker       div_base_log = 15;
137*da0073e9SAndroid Build Coastguard Worker       break;
138*da0073e9SAndroid Build Coastguard Worker     case std::ios::oct:
139*da0073e9SAndroid Build Coastguard Worker       div = (uint64_t)01000000000000000000000u; // 8^21
140*da0073e9SAndroid Build Coastguard Worker       div_base_log = 21;
141*da0073e9SAndroid Build Coastguard Worker       break;
142*da0073e9SAndroid Build Coastguard Worker     default: // std::ios::dec
143*da0073e9SAndroid Build Coastguard Worker       div = (uint64_t)10000000000000000000u; // 10^19
144*da0073e9SAndroid Build Coastguard Worker       div_base_log = 19;
145*da0073e9SAndroid Build Coastguard Worker       break;
146*da0073e9SAndroid Build Coastguard Worker   }
147*da0073e9SAndroid Build Coastguard Worker 
148*da0073e9SAndroid Build Coastguard Worker   // Now piece together the uint128 representation from three chunks of
149*da0073e9SAndroid Build Coastguard Worker   // the original value, each less than "div" and therefore representable
150*da0073e9SAndroid Build Coastguard Worker   // as a uint64.
151*da0073e9SAndroid Build Coastguard Worker   std::ostringstream os;
152*da0073e9SAndroid Build Coastguard Worker   std::ios_base::fmtflags copy_mask =
153*da0073e9SAndroid Build Coastguard Worker       std::ios::basefield | std::ios::showbase | std::ios::uppercase;
154*da0073e9SAndroid Build Coastguard Worker   os.setf(flags & copy_mask, copy_mask);
155*da0073e9SAndroid Build Coastguard Worker   uint128 high = b;
156*da0073e9SAndroid Build Coastguard Worker   uint128 low;
157*da0073e9SAndroid Build Coastguard Worker   uint128::DivModImpl(high, div, &high, &low);
158*da0073e9SAndroid Build Coastguard Worker   uint128 mid;
159*da0073e9SAndroid Build Coastguard Worker   uint128::DivModImpl(high, div, &high, &mid);
160*da0073e9SAndroid Build Coastguard Worker   if (high.lo_ != 0) {
161*da0073e9SAndroid Build Coastguard Worker     os << high.lo_;
162*da0073e9SAndroid Build Coastguard Worker     os << std::noshowbase << std::setfill('0') << std::setw(div_base_log);
163*da0073e9SAndroid Build Coastguard Worker     os << mid.lo_;
164*da0073e9SAndroid Build Coastguard Worker     os << std::setw(div_base_log);
165*da0073e9SAndroid Build Coastguard Worker   } else if (mid.lo_ != 0) {
166*da0073e9SAndroid Build Coastguard Worker     os << mid.lo_;
167*da0073e9SAndroid Build Coastguard Worker     os << std::noshowbase << std::setfill('0') << std::setw(div_base_log);
168*da0073e9SAndroid Build Coastguard Worker   }
169*da0073e9SAndroid Build Coastguard Worker   os << low.lo_;
170*da0073e9SAndroid Build Coastguard Worker   std::string rep = os.str();
171*da0073e9SAndroid Build Coastguard Worker 
172*da0073e9SAndroid Build Coastguard Worker   // Add the requisite padding.
173*da0073e9SAndroid Build Coastguard Worker   std::streamsize width = o.width(0);
174*da0073e9SAndroid Build Coastguard Worker   if (width > static_cast<std::streamsize>(rep.size())) {
175*da0073e9SAndroid Build Coastguard Worker     if ((flags & std::ios::adjustfield) == std::ios::left) {
176*da0073e9SAndroid Build Coastguard Worker       rep.append(width - rep.size(), o.fill());
177*da0073e9SAndroid Build Coastguard Worker     } else {
178*da0073e9SAndroid Build Coastguard Worker       rep.insert(
179*da0073e9SAndroid Build Coastguard Worker           static_cast<std::string::size_type>(0), width - rep.size(), o.fill());
180*da0073e9SAndroid Build Coastguard Worker     }
181*da0073e9SAndroid Build Coastguard Worker   }
182*da0073e9SAndroid Build Coastguard Worker 
183*da0073e9SAndroid Build Coastguard Worker   // Stream the final representation in a single "<<" call.
184*da0073e9SAndroid Build Coastguard Worker   return o << rep;
185*da0073e9SAndroid Build Coastguard Worker }
186*da0073e9SAndroid Build Coastguard Worker 
187*da0073e9SAndroid Build Coastguard Worker } // namespace c10
188