xref: /aosp_15_r20/external/pytorch/c10/util/safe_numerics.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/macros/Macros.h>
3 
4 #include <cstdint>
5 
6 // GCC has __builtin_mul_overflow from before it supported __has_builtin
7 #ifdef _MSC_VER
8 #define C10_HAS_BUILTIN_OVERFLOW() (0)
9 #include <c10/util/llvmMathExtras.h>
10 #include <intrin.h>
11 #else
12 #define C10_HAS_BUILTIN_OVERFLOW() (1)
13 #endif
14 
15 namespace c10 {
16 
// Adds two unsigned 64-bit integers with overflow detection.
//
// Stores the sum, reduced modulo 2^64, in *out and returns true iff the exact
// mathematical sum does not fit in 64 bits.
//
// The implementation is chosen at compile time:
//   - when C10_HAS_BUILTIN_OVERFLOW() is non-zero, __builtin_add_overflow;
//   - otherwise, on x64, the _addcarry_u64 intrinsic;
//   - otherwise, a portable bitwise computation of the carry out of bit 63.
C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) {
#if C10_HAS_BUILTIN_OVERFLOW()
  return __builtin_add_overflow(a, b, out);
#else
  unsigned long long tmp;
#if defined(_M_X64)
  // _addcarry_u64 is only provided for x64 targets; 32-bit x86 (_M_IX86)
  // has no such intrinsic and must use the portable branch below.
  auto carry = _addcarry_u64(0, a, b, &tmp);
#else
  tmp = a + b;
  // Carry-out of every bit position: a carry leaves a position when both
  // operand bits are set, or when exactly one is set and the sum bit is
  // clear. The two terms are disjoint, so XOR acts as OR here.
  unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp);
  // Only the carry out of the top bit signals overflow.
  auto carry = vector >> 63;
#endif
  *out = tmp;
  return carry;
#endif
}
33 
// Multiplies two unsigned 64-bit integers with overflow detection.
//
// Stores the product, reduced modulo 2^64, in *out and returns true iff the
// exact mathematical product does not fit in 64 bits. Both compile-time paths
// give an exact answer, so the result does not depend on the toolchain.
C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) {
#if C10_HAS_BUILTIN_OVERFLOW()
  return __builtin_mul_overflow(a, b, out);
#else
  *out = a * b;
  // Fast path without integer division: if the operands' leading zeros add up
  // to at least 64, their significant bits total at most 64 and the product
  // is guaranteed to fit.
  if ((c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) >=
      64) {
    return false;
  }
  // The leading-zero test alone is only conservative here (for example
  // 1 * UINT64_MAX fits but fails it), so settle the remaining cases exactly.
  // The a != 0 guard keeps the division well defined in every case.
  return a != 0 && b > UINT64_MAX / a;
#endif
}
44 
// Multiplies two signed 64-bit integers with overflow detection.
//
// Stores the product, wrapped to 64 bits (two's complement), in *out and
// returns true iff the exact mathematical product is not representable as an
// int64_t.
C10_ALWAYS_INLINE bool mul_overflows(int64_t a, int64_t b, int64_t* out) {
#if C10_HAS_BUILTIN_OVERFLOW()
  return __builtin_mul_overflow(a, b, out);
#else
  // Multiply in unsigned arithmetic, where wraparound is well defined; a
  // signed multiplication that overflows is undefined behavior.
  const auto wrapped =
      static_cast<int64_t>(static_cast<uint64_t>(a) * static_cast<uint64_t>(b));
  *out = wrapped;
  if (a == 0 || b == 0) {
    return false;
  }
  if (b == -1) {
    // Negation overflows only for the most negative value. Handling this
    // case separately also avoids evaluating INT64_MIN / -1 below, which is
    // itself an overflowing division.
    return a == INT64_MIN;
  }
  // Without overflow, dividing the product by b recovers a exactly. With
  // overflow, the wrapped value is off by a multiple of 2^64, which is too
  // far for the truncated quotient to land back on a.
  return wrapped / b != a;
#endif
}
57 
58 template <typename It>
safe_multiplies_u64(It first,It last,uint64_t * out)59 bool safe_multiplies_u64(It first, It last, uint64_t* out) {
60 #if C10_HAS_BUILTIN_OVERFLOW()
61   uint64_t prod = 1;
62   bool overflow = false;
63   for (; first != last; ++first) {
64     overflow |= c10::mul_overflows(prod, *first, &prod);
65   }
66   *out = prod;
67   return overflow;
68 #else
69   uint64_t prod = 1;
70   uint64_t prod_log2 = 0;
71   bool is_zero = false;
72   for (; first != last; ++first) {
73     auto x = static_cast<uint64_t>(*first);
74     prod *= x;
75     // log2(0) isn't valid, so need to track it specially
76     is_zero |= (x == 0);
77     prod_log2 += c10::llvm::Log2_64_Ceil(x);
78   }
79   *out = prod;
80   // This test isnt exact, but avoids doing integer division
81   return !is_zero && (prod_log2 >= 64);
82 #endif
83 }
84 
// Container convenience overload: multiplies all elements of `c`, from
// c.begin() to c.end(), by forwarding to the iterator-pair overload.
// The product is written to *out and the forwarded overflow flag is returned.
template <typename Container>
bool safe_multiplies_u64(const Container& c, uint64_t* out) {
  auto first = c.begin();
  auto last = c.end();
  return safe_multiplies_u64(first, last, out);
}
89 
90 } // namespace c10
91