// xref: /aosp_15_r20/external/pytorch/c10/util/safe_numerics.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <c10/macros/Macros.h>

#include <cstdint>
5*da0073e9SAndroid Build Coastguard Worker 
// Selects how the overflow checks in this header are implemented.
// C10_HAS_BUILTIN_OVERFLOW() expands to (0) under MSVC, where the hand-written
// fallbacks are used together with the LLVM bit-math helpers and <intrin.h>.
// It expands to (1) for every other compiler, which is assumed to provide the
// __builtin_*_overflow family. GCC shipped __builtin_mul_overflow before it
// supported __has_builtin, so that feature test cannot be used to detect it.
#if defined(_MSC_VER)
#define C10_HAS_BUILTIN_OVERFLOW() (0)
#include <c10/util/llvmMathExtras.h>
#include <intrin.h>
#else
#define C10_HAS_BUILTIN_OVERFLOW() (1)
#endif
14*da0073e9SAndroid Build Coastguard Worker 
namespace c10 {
16*da0073e9SAndroid Build Coastguard Worker 
add_overflows(uint64_t a,uint64_t b,uint64_t * out)17*da0073e9SAndroid Build Coastguard Worker C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) {
18*da0073e9SAndroid Build Coastguard Worker #if C10_HAS_BUILTIN_OVERFLOW()
19*da0073e9SAndroid Build Coastguard Worker   return __builtin_add_overflow(a, b, out);
20*da0073e9SAndroid Build Coastguard Worker #else
21*da0073e9SAndroid Build Coastguard Worker   unsigned long long tmp;
22*da0073e9SAndroid Build Coastguard Worker #if defined(_M_IX86) || defined(_M_X64)
23*da0073e9SAndroid Build Coastguard Worker   auto carry = _addcarry_u64(0, a, b, &tmp);
24*da0073e9SAndroid Build Coastguard Worker #else
25*da0073e9SAndroid Build Coastguard Worker   tmp = a + b;
26*da0073e9SAndroid Build Coastguard Worker   unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp);
27*da0073e9SAndroid Build Coastguard Worker   auto carry = vector >> 63;
28*da0073e9SAndroid Build Coastguard Worker #endif
29*da0073e9SAndroid Build Coastguard Worker   *out = tmp;
30*da0073e9SAndroid Build Coastguard Worker   return carry;
31*da0073e9SAndroid Build Coastguard Worker #endif
32*da0073e9SAndroid Build Coastguard Worker }
33*da0073e9SAndroid Build Coastguard Worker 
mul_overflows(uint64_t a,uint64_t b,uint64_t * out)34*da0073e9SAndroid Build Coastguard Worker C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) {
35*da0073e9SAndroid Build Coastguard Worker #if C10_HAS_BUILTIN_OVERFLOW()
36*da0073e9SAndroid Build Coastguard Worker   return __builtin_mul_overflow(a, b, out);
37*da0073e9SAndroid Build Coastguard Worker #else
38*da0073e9SAndroid Build Coastguard Worker   *out = a * b;
39*da0073e9SAndroid Build Coastguard Worker   // This test isnt exact, but avoids doing integer division
40*da0073e9SAndroid Build Coastguard Worker   return (
41*da0073e9SAndroid Build Coastguard Worker       (c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < 64);
42*da0073e9SAndroid Build Coastguard Worker #endif
43*da0073e9SAndroid Build Coastguard Worker }
44*da0073e9SAndroid Build Coastguard Worker 
mul_overflows(int64_t a,int64_t b,int64_t * out)45*da0073e9SAndroid Build Coastguard Worker C10_ALWAYS_INLINE bool mul_overflows(int64_t a, int64_t b, int64_t* out) {
46*da0073e9SAndroid Build Coastguard Worker #if C10_HAS_BUILTIN_OVERFLOW()
47*da0073e9SAndroid Build Coastguard Worker   return __builtin_mul_overflow(a, b, out);
48*da0073e9SAndroid Build Coastguard Worker #else
49*da0073e9SAndroid Build Coastguard Worker   volatile int64_t tmp = a * b;
50*da0073e9SAndroid Build Coastguard Worker   *out = tmp;
51*da0073e9SAndroid Build Coastguard Worker   if (a == 0 || b == 0) {
52*da0073e9SAndroid Build Coastguard Worker     return false;
53*da0073e9SAndroid Build Coastguard Worker   }
54*da0073e9SAndroid Build Coastguard Worker   return !(a == tmp / b);
55*da0073e9SAndroid Build Coastguard Worker #endif
56*da0073e9SAndroid Build Coastguard Worker }
57*da0073e9SAndroid Build Coastguard Worker 
58*da0073e9SAndroid Build Coastguard Worker template <typename It>
safe_multiplies_u64(It first,It last,uint64_t * out)59*da0073e9SAndroid Build Coastguard Worker bool safe_multiplies_u64(It first, It last, uint64_t* out) {
60*da0073e9SAndroid Build Coastguard Worker #if C10_HAS_BUILTIN_OVERFLOW()
61*da0073e9SAndroid Build Coastguard Worker   uint64_t prod = 1;
62*da0073e9SAndroid Build Coastguard Worker   bool overflow = false;
63*da0073e9SAndroid Build Coastguard Worker   for (; first != last; ++first) {
64*da0073e9SAndroid Build Coastguard Worker     overflow |= c10::mul_overflows(prod, *first, &prod);
65*da0073e9SAndroid Build Coastguard Worker   }
66*da0073e9SAndroid Build Coastguard Worker   *out = prod;
67*da0073e9SAndroid Build Coastguard Worker   return overflow;
68*da0073e9SAndroid Build Coastguard Worker #else
69*da0073e9SAndroid Build Coastguard Worker   uint64_t prod = 1;
70*da0073e9SAndroid Build Coastguard Worker   uint64_t prod_log2 = 0;
71*da0073e9SAndroid Build Coastguard Worker   bool is_zero = false;
72*da0073e9SAndroid Build Coastguard Worker   for (; first != last; ++first) {
73*da0073e9SAndroid Build Coastguard Worker     auto x = static_cast<uint64_t>(*first);
74*da0073e9SAndroid Build Coastguard Worker     prod *= x;
75*da0073e9SAndroid Build Coastguard Worker     // log2(0) isn't valid, so need to track it specially
76*da0073e9SAndroid Build Coastguard Worker     is_zero |= (x == 0);
77*da0073e9SAndroid Build Coastguard Worker     prod_log2 += c10::llvm::Log2_64_Ceil(x);
78*da0073e9SAndroid Build Coastguard Worker   }
79*da0073e9SAndroid Build Coastguard Worker   *out = prod;
80*da0073e9SAndroid Build Coastguard Worker   // This test isnt exact, but avoids doing integer division
81*da0073e9SAndroid Build Coastguard Worker   return !is_zero && (prod_log2 >= 64);
82*da0073e9SAndroid Build Coastguard Worker #endif
83*da0073e9SAndroid Build Coastguard Worker }
84*da0073e9SAndroid Build Coastguard Worker 
// Convenience overload: multiplies all elements of `c` by forwarding
// c.begin() / c.end() to the iterator-range overload.
//
//   c   - any container exposing begin() and end().
//   out - receives the product, as written by the iterator-range overload.
//
// Returns whatever the iterator-range overload returns for that range.
template <typename Container>
bool safe_multiplies_u64(const Container& c, uint64_t* out) {
  auto first = c.begin();
  auto last = c.end();
  return safe_multiplies_u64(first, last, out);
}
89*da0073e9SAndroid Build Coastguard Worker 
} // namespace c10