xref: /aosp_15_r20/external/llvm-libc/src/__support/high_precision_decimal.h (revision 71db0c75aadcf003ffe3238005f61d7618a3fead)
1 //===-- High Precision Decimal ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // -----------------------------------------------------------------------------
10 //                               **** WARNING ****
11 // This file is shared with libc++. You should also be careful when adding
12 // dependencies to this file, since it needs to build for all libc++ targets.
13 // -----------------------------------------------------------------------------
14 
15 #ifndef LLVM_LIBC_SRC___SUPPORT_HIGH_PRECISION_DECIMAL_H
16 #define LLVM_LIBC_SRC___SUPPORT_HIGH_PRECISION_DECIMAL_H
17 
18 #include "src/__support/CPP/limits.h"
19 #include "src/__support/ctype_utils.h"
20 #include "src/__support/macros/config.h"
21 #include "src/__support/str_to_integer.h"
22 #include <stdint.h>
23 
24 namespace LIBC_NAMESPACE_DECL {
25 namespace internal {
26 
// One row of the left-shift lookup table: a digit count paired with the
// decimal spelling of a power of five. Field order matters because rows are
// written with positional aggregate initialization.
struct LShiftTableEntry {
  // Digit count associated with this row.
  uint32_t new_digits;
  // Null-terminated ASCII decimal string holding the power of five.
  const char *power_of_five;
};
31 
32 // -----------------------------------------------------------------------------
33 //                               **** WARNING ****
34 // This interface is shared with libc++, if you change this interface you need
35 // to update it in both libc and libc++.
36 // -----------------------------------------------------------------------------
37 // This is used in both this file and in the main str_to_float.h.
38 // TODO: Figure out where to put this.
enum class RoundDirection {
  // Bump the result whenever stored digits remain at or past the rounding
  // position.
  Up,
  // Never bump the result; the remaining digits are simply dropped.
  Down,
  // Pick the closest value; an exact tie goes to the even neighbour.
  Nearest
};
40 
41 // This is based on the HPD data structure described as part of the Simple
42 // Decimal Conversion algorithm by Nigel Tao, described at this link:
43 // https://nigeltao.github.io/blog/2020/parse-number-f64-simple.html
44 class HighPrecisionDecimal {
45 
46   // This precomputed table speeds up left shifts by having the number of new
47   // digits that will be added by multiplying 5^i by 2^i. If the number is less
48   // than 5^i then it will add one fewer digit. There are only 60 entries since
49   // that's the max shift amount.
50   // This table was generated by the script at
51   // libc/utils/mathtools/GenerateHPDConstants.py
52   static constexpr LShiftTableEntry LEFT_SHIFT_DIGIT_TABLE[] = {
53       {0, ""},
54       {1, "5"},
55       {1, "25"},
56       {1, "125"},
57       {2, "625"},
58       {2, "3125"},
59       {2, "15625"},
60       {3, "78125"},
61       {3, "390625"},
62       {3, "1953125"},
63       {4, "9765625"},
64       {4, "48828125"},
65       {4, "244140625"},
66       {4, "1220703125"},
67       {5, "6103515625"},
68       {5, "30517578125"},
69       {5, "152587890625"},
70       {6, "762939453125"},
71       {6, "3814697265625"},
72       {6, "19073486328125"},
73       {7, "95367431640625"},
74       {7, "476837158203125"},
75       {7, "2384185791015625"},
76       {7, "11920928955078125"},
77       {8, "59604644775390625"},
78       {8, "298023223876953125"},
79       {8, "1490116119384765625"},
80       {9, "7450580596923828125"},
81       {9, "37252902984619140625"},
82       {9, "186264514923095703125"},
83       {10, "931322574615478515625"},
84       {10, "4656612873077392578125"},
85       {10, "23283064365386962890625"},
86       {10, "116415321826934814453125"},
87       {11, "582076609134674072265625"},
88       {11, "2910383045673370361328125"},
89       {11, "14551915228366851806640625"},
90       {12, "72759576141834259033203125"},
91       {12, "363797880709171295166015625"},
92       {12, "1818989403545856475830078125"},
93       {13, "9094947017729282379150390625"},
94       {13, "45474735088646411895751953125"},
95       {13, "227373675443232059478759765625"},
96       {13, "1136868377216160297393798828125"},
97       {14, "5684341886080801486968994140625"},
98       {14, "28421709430404007434844970703125"},
99       {14, "142108547152020037174224853515625"},
100       {15, "710542735760100185871124267578125"},
101       {15, "3552713678800500929355621337890625"},
102       {15, "17763568394002504646778106689453125"},
103       {16, "88817841970012523233890533447265625"},
104       {16, "444089209850062616169452667236328125"},
105       {16, "2220446049250313080847263336181640625"},
106       {16, "11102230246251565404236316680908203125"},
107       {17, "55511151231257827021181583404541015625"},
108       {17, "277555756156289135105907917022705078125"},
109       {17, "1387778780781445675529539585113525390625"},
110       {18, "6938893903907228377647697925567626953125"},
111       {18, "34694469519536141888238489627838134765625"},
112       {18, "173472347597680709441192448139190673828125"},
113       {19, "867361737988403547205962240695953369140625"},
114   };
115 
116   // The maximum amount we can shift is the number of bits used in the
117   // accumulator, minus the number of bits needed to represent the base (in this
118   // case 4).
119   static constexpr uint32_t MAX_SHIFT_AMOUNT = sizeof(uint64_t) - 4;
120 
121   // 800 is an arbitrary number of digits, but should be
122   // large enough for any practical number.
123   static constexpr uint32_t MAX_NUM_DIGITS = 800;
124 
125   uint32_t num_digits = 0;
126   int32_t decimal_point = 0;
127   bool truncated = false;
128   uint8_t digits[MAX_NUM_DIGITS];
129 
130 private:
should_round_up(int32_t round_to_digit,RoundDirection round)131   LIBC_INLINE bool should_round_up(int32_t round_to_digit,
132                                    RoundDirection round) {
133     if (round_to_digit < 0 ||
134         static_cast<uint32_t>(round_to_digit) >= this->num_digits) {
135       return false;
136     }
137 
138     // The above condition handles all cases where all of the trailing digits
139     // are zero. In that case, if the rounding mode is up, then this number
140     // should be rounded up. Similarly, if the rounding mode is down, then it
141     // should always round down.
142     if (round == RoundDirection::Up) {
143       return true;
144     } else if (round == RoundDirection::Down) {
145       return false;
146     }
147     // Else round to nearest.
148 
149     // If we're right in the middle and there are no extra digits
150     if (this->digits[round_to_digit] == 5 &&
151         static_cast<uint32_t>(round_to_digit + 1) == this->num_digits) {
152 
153       // Round up if we've truncated (since that means the result is slightly
154       // higher than what's represented.)
155       if (this->truncated) {
156         return true;
157       }
158 
159       // If this exactly halfway, round to even.
160       if (round_to_digit == 0)
161         // When the input is ".5".
162         return false;
163       return this->digits[round_to_digit - 1] % 2 != 0;
164     }
165     // If there are digits after round_to_digit, they must be non-zero since we
166     // trim trailing zeroes after all operations that change digits.
167     return this->digits[round_to_digit] >= 5;
168   }
169 
170   // Takes an amount to left shift and returns the number of new digits needed
171   // to store the result based on LEFT_SHIFT_DIGIT_TABLE.
get_num_new_digits(uint32_t lshift_amount)172   LIBC_INLINE uint32_t get_num_new_digits(uint32_t lshift_amount) {
173     const char *power_of_five =
174         LEFT_SHIFT_DIGIT_TABLE[lshift_amount].power_of_five;
175     uint32_t new_digits = LEFT_SHIFT_DIGIT_TABLE[lshift_amount].new_digits;
176     uint32_t digit_index = 0;
177     while (power_of_five[digit_index] != 0) {
178       if (digit_index >= this->num_digits) {
179         return new_digits - 1;
180       }
181       if (this->digits[digit_index] != power_of_five[digit_index] - '0') {
182         return new_digits -
183                ((this->digits[digit_index] < power_of_five[digit_index] - '0')
184                     ? 1
185                     : 0);
186       }
187       ++digit_index;
188     }
189     return new_digits;
190   }
191 
192   // Trim all trailing 0s
trim_trailing_zeroes()193   LIBC_INLINE void trim_trailing_zeroes() {
194     while (this->num_digits > 0 && this->digits[this->num_digits - 1] == 0) {
195       --this->num_digits;
196     }
197     if (this->num_digits == 0) {
198       this->decimal_point = 0;
199     }
200   }
201 
202   // Perform a digitwise binary non-rounding right shift on this value by
203   // shift_amount. The shift_amount can't be more than MAX_SHIFT_AMOUNT to
204   // prevent overflow.
right_shift(uint32_t shift_amount)205   LIBC_INLINE void right_shift(uint32_t shift_amount) {
206     uint32_t read_index = 0;
207     uint32_t write_index = 0;
208 
209     uint64_t accumulator = 0;
210 
211     const uint64_t shift_mask = (uint64_t(1) << shift_amount) - 1;
212 
213     // Warm Up phase: we don't have enough digits to start writing, so just
214     // read them into the accumulator.
215     while (accumulator >> shift_amount == 0) {
216       uint64_t read_digit = 0;
217       // If there are still digits to read, read the next one, else the digit is
218       // assumed to be 0.
219       if (read_index < this->num_digits) {
220         read_digit = this->digits[read_index];
221       }
222       accumulator = accumulator * 10 + read_digit;
223       ++read_index;
224     }
225 
226     // Shift the decimal point by the number of digits it took to fill the
227     // accumulator.
228     this->decimal_point -= read_index - 1;
229 
230     // Middle phase: we have enough digits to write, as well as more digits to
231     // read. Keep reading until we run out of digits.
232     while (read_index < this->num_digits) {
233       uint64_t read_digit = this->digits[read_index];
234       uint64_t write_digit = accumulator >> shift_amount;
235       accumulator &= shift_mask;
236       this->digits[write_index] = static_cast<uint8_t>(write_digit);
237       accumulator = accumulator * 10 + read_digit;
238       ++read_index;
239       ++write_index;
240     }
241 
242     // Cool Down phase: All of the readable digits have been read, so just write
243     // the remainder, while treating any more digits as 0.
244     while (accumulator > 0) {
245       uint64_t write_digit = accumulator >> shift_amount;
246       accumulator &= shift_mask;
247       if (write_index < MAX_NUM_DIGITS) {
248         this->digits[write_index] = static_cast<uint8_t>(write_digit);
249         ++write_index;
250       } else if (write_digit > 0) {
251         this->truncated = true;
252       }
253       accumulator = accumulator * 10;
254     }
255     this->num_digits = write_index;
256     this->trim_trailing_zeroes();
257   }
258 
259   // Perform a digitwise binary non-rounding left shift on this value by
260   // shift_amount. The shift_amount can't be more than MAX_SHIFT_AMOUNT to
261   // prevent overflow.
left_shift(uint32_t shift_amount)262   LIBC_INLINE void left_shift(uint32_t shift_amount) {
263     uint32_t new_digits = this->get_num_new_digits(shift_amount);
264 
265     int32_t read_index = this->num_digits - 1;
266     uint32_t write_index = this->num_digits + new_digits;
267 
268     uint64_t accumulator = 0;
269 
270     // No Warm Up phase. Since we're putting digits in at the top and taking
271     // digits from the bottom we don't have to wait for the accumulator to fill.
272 
273     // Middle phase: while we have more digits to read, keep reading as well as
274     // writing.
275     while (read_index >= 0) {
276       accumulator += static_cast<uint64_t>(this->digits[read_index])
277                      << shift_amount;
278       uint64_t next_accumulator = accumulator / 10;
279       uint64_t write_digit = accumulator - (10 * next_accumulator);
280       --write_index;
281       if (write_index < MAX_NUM_DIGITS) {
282         this->digits[write_index] = static_cast<uint8_t>(write_digit);
283       } else if (write_digit != 0) {
284         this->truncated = true;
285       }
286       accumulator = next_accumulator;
287       --read_index;
288     }
289 
290     // Cool Down phase: there are no more digits to read, so just write the
291     // remaining digits in the accumulator.
292     while (accumulator > 0) {
293       uint64_t next_accumulator = accumulator / 10;
294       uint64_t write_digit = accumulator - (10 * next_accumulator);
295       --write_index;
296       if (write_index < MAX_NUM_DIGITS) {
297         this->digits[write_index] = static_cast<uint8_t>(write_digit);
298       } else if (write_digit != 0) {
299         this->truncated = true;
300       }
301       accumulator = next_accumulator;
302     }
303 
304     this->num_digits += new_digits;
305     if (this->num_digits > MAX_NUM_DIGITS) {
306       this->num_digits = MAX_NUM_DIGITS;
307     }
308     this->decimal_point += new_digits;
309     this->trim_trailing_zeroes();
310   }
311 
312 public:
313   // num_string is assumed to be a string of numeric characters. It doesn't
314   // handle leading spaces.
315   LIBC_INLINE
316   HighPrecisionDecimal(
317       const char *__restrict num_string,
318       const size_t num_len = cpp::numeric_limits<size_t>::max()) {
319     bool saw_dot = false;
320     size_t num_cur = 0;
321     // This counts the digits in the number, even if there isn't space to store
322     // them all.
323     uint32_t total_digits = 0;
324     while (num_cur < num_len &&
325            (isdigit(num_string[num_cur]) || num_string[num_cur] == '.')) {
326       if (num_string[num_cur] == '.') {
327         if (saw_dot) {
328           break;
329         }
330         this->decimal_point = total_digits;
331         saw_dot = true;
332       } else {
333         if (num_string[num_cur] == '0' && this->num_digits == 0) {
334           --this->decimal_point;
335           ++num_cur;
336           continue;
337         }
338         ++total_digits;
339         if (this->num_digits < MAX_NUM_DIGITS) {
340           this->digits[this->num_digits] =
341               static_cast<uint8_t>(num_string[num_cur] - '0');
342           ++this->num_digits;
343         } else if (num_string[num_cur] != '0') {
344           this->truncated = true;
345         }
346       }
347       ++num_cur;
348     }
349 
350     if (!saw_dot)
351       this->decimal_point = total_digits;
352 
353     if (num_cur < num_len &&
354         (num_string[num_cur] == 'e' || num_string[num_cur] == 'E')) {
355       ++num_cur;
356       if (isdigit(num_string[num_cur]) || num_string[num_cur] == '+' ||
357           num_string[num_cur] == '-') {
358         auto result =
359             strtointeger<int32_t>(num_string + num_cur, 10, num_len - num_cur);
360         if (result.has_error()) {
361           // TODO: handle error
362         }
363         int32_t add_to_exponent = result.value;
364 
365         // Here we do this operation as int64 to avoid overflow.
366         int64_t temp_exponent = static_cast<int64_t>(this->decimal_point) +
367                                 static_cast<int64_t>(add_to_exponent);
368 
369         // Theoretically these numbers should be MAX_BIASED_EXPONENT for long
370         // double, but that should be ~16,000 which is much less than 1 << 30.
371         if (temp_exponent > (1 << 30)) {
372           temp_exponent = (1 << 30);
373         } else if (temp_exponent < -(1 << 30)) {
374           temp_exponent = -(1 << 30);
375         }
376         this->decimal_point = static_cast<int32_t>(temp_exponent);
377       }
378     }
379 
380     this->trim_trailing_zeroes();
381   }
382 
383   // Binary shift left (shift_amount > 0) or right (shift_amount < 0)
shift(int shift_amount)384   LIBC_INLINE void shift(int shift_amount) {
385     if (shift_amount == 0) {
386       return;
387     }
388     // Left
389     else if (shift_amount > 0) {
390       while (static_cast<uint32_t>(shift_amount) > MAX_SHIFT_AMOUNT) {
391         this->left_shift(MAX_SHIFT_AMOUNT);
392         shift_amount -= MAX_SHIFT_AMOUNT;
393       }
394       this->left_shift(shift_amount);
395     }
396     // Right
397     else {
398       while (static_cast<uint32_t>(shift_amount) < -MAX_SHIFT_AMOUNT) {
399         this->right_shift(MAX_SHIFT_AMOUNT);
400         shift_amount += MAX_SHIFT_AMOUNT;
401       }
402       this->right_shift(-shift_amount);
403     }
404   }
405 
406   // Round the number represented to the closest value of unsigned int type T.
407   // This is done ignoring overflow.
408   template <class T>
409   LIBC_INLINE T
410   round_to_integer_type(RoundDirection round = RoundDirection::Nearest) {
411     T result = 0;
412     uint32_t cur_digit = 0;
413 
414     while (static_cast<int32_t>(cur_digit) < this->decimal_point &&
415            cur_digit < this->num_digits) {
416       result = result * 10 + (this->digits[cur_digit]);
417       ++cur_digit;
418     }
419 
420     // If there are implicit 0s at the end of the number, include those.
421     while (static_cast<int32_t>(cur_digit) < this->decimal_point) {
422       result *= 10;
423       ++cur_digit;
424     }
425     return result + static_cast<unsigned int>(
426                         this->should_round_up(this->decimal_point, round));
427   }
428 
429   // Extra functions for testing.
430 
get_digits()431   LIBC_INLINE uint8_t *get_digits() { return this->digits; }
get_num_digits()432   LIBC_INLINE uint32_t get_num_digits() { return this->num_digits; }
get_decimal_point()433   LIBC_INLINE int32_t get_decimal_point() { return this->decimal_point; }
set_truncated(bool trunc)434   LIBC_INLINE void set_truncated(bool trunc) { this->truncated = trunc; }
435 };
436 
437 } // namespace internal
438 } // namespace LIBC_NAMESPACE_DECL
439 
440 #endif // LLVM_LIBC_SRC___SUPPORT_HIGH_PRECISION_DECIMAL_H
441