1 #pragma once
2 #include <c10/macros/Macros.h>
3
4 #include <cstdint>
5
6 // GCC has __builtin_mul_overflow from before it supported __has_builtin
7 #ifdef _MSC_VER
8 #define C10_HAS_BUILTIN_OVERFLOW() (0)
9 #include <c10/util/llvmMathExtras.h>
10 #include <intrin.h>
11 #else
12 #define C10_HAS_BUILTIN_OVERFLOW() (1)
13 #endif
14
15 namespace c10 {
16
add_overflows(uint64_t a,uint64_t b,uint64_t * out)17 C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) {
18 #if C10_HAS_BUILTIN_OVERFLOW()
19 return __builtin_add_overflow(a, b, out);
20 #else
21 unsigned long long tmp;
22 #if defined(_M_IX86) || defined(_M_X64)
23 auto carry = _addcarry_u64(0, a, b, &tmp);
24 #else
25 tmp = a + b;
26 unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp);
27 auto carry = vector >> 63;
28 #endif
29 *out = tmp;
30 return carry;
31 #endif
32 }
33
mul_overflows(uint64_t a,uint64_t b,uint64_t * out)34 C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) {
35 #if C10_HAS_BUILTIN_OVERFLOW()
36 return __builtin_mul_overflow(a, b, out);
37 #else
38 *out = a * b;
39 // This test isnt exact, but avoids doing integer division
40 return (
41 (c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < 64);
42 #endif
43 }
44
mul_overflows(int64_t a,int64_t b,int64_t * out)45 C10_ALWAYS_INLINE bool mul_overflows(int64_t a, int64_t b, int64_t* out) {
46 #if C10_HAS_BUILTIN_OVERFLOW()
47 return __builtin_mul_overflow(a, b, out);
48 #else
49 volatile int64_t tmp = a * b;
50 *out = tmp;
51 if (a == 0 || b == 0) {
52 return false;
53 }
54 return !(a == tmp / b);
55 #endif
56 }
57
58 template <typename It>
safe_multiplies_u64(It first,It last,uint64_t * out)59 bool safe_multiplies_u64(It first, It last, uint64_t* out) {
60 #if C10_HAS_BUILTIN_OVERFLOW()
61 uint64_t prod = 1;
62 bool overflow = false;
63 for (; first != last; ++first) {
64 overflow |= c10::mul_overflows(prod, *first, &prod);
65 }
66 *out = prod;
67 return overflow;
68 #else
69 uint64_t prod = 1;
70 uint64_t prod_log2 = 0;
71 bool is_zero = false;
72 for (; first != last; ++first) {
73 auto x = static_cast<uint64_t>(*first);
74 prod *= x;
75 // log2(0) isn't valid, so need to track it specially
76 is_zero |= (x == 0);
77 prod_log2 += c10::llvm::Log2_64_Ceil(x);
78 }
79 *out = prod;
80 // This test isnt exact, but avoids doing integer division
81 return !is_zero && (prod_log2 >= 64);
82 #endif
83 }
84
85 template <typename Container>
safe_multiplies_u64(const Container & c,uint64_t * out)86 bool safe_multiplies_u64(const Container& c, uint64_t* out) {
87 return safe_multiplies_u64(c.begin(), c.end(), out);
88 }
89
90 } // namespace c10
91