xref: /aosp_15_r20/external/pytorch/c10/util/Bitset.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cstddef>
#include <cstdint>
#if defined(_MSC_VER)
#include <intrin.h>
#endif
7 
8 namespace c10::utils {
9 
/**
 * This is a simple bitset class with 8 * sizeof(long long int) bits
 * (see NUM_BITS()).
 * You can set bits, unset bits, query bits by index,
 * and query for the first set bit.
 * Before using this class, please also take a look at std::bitset,
 * which has more functionality and is more generic. It is probably
 * a better fit for your use case. The sole reason for c10::utils::bitset
 * to exist is that std::bitset misses a find_first_set() method.
 *
 * Every `index` argument must be smaller than NUM_BITS(); larger values
 * shift by at least the width of the storage integer, which is undefined
 * behavior.
 */
struct bitset final {
 private:
#if defined(_MSC_VER)
  // MSVC's _BitScanForward64 takes a 64-bit mask (unsigned __int64);
  // int64_t converts to it implicitly.
  using bitset_type = int64_t;
#else
  // POSIX ffsll / __builtin_ffsll take a long long int.
  using bitset_type = long long int;
#endif

 public:
  /// Number of bits this bitset can hold.
  static constexpr size_t NUM_BITS() {
    return 8 * sizeof(bitset_type);
  }

  constexpr bitset() noexcept = default;
  constexpr bitset(const bitset&) noexcept = default;
  constexpr bitset(bitset&&) noexcept = default;
  // There is an issue with gcc 5.3.0 when a defaulted function is declared
  // constexpr, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754.
  bitset& operator=(const bitset&) noexcept = default;
  bitset& operator=(bitset&&) noexcept = default;

  /// Sets the bit at `index` (must be < NUM_BITS()).
  constexpr void set(size_t index) noexcept {
    bitset_ |= mask(index);
  }

  /// Clears the bit at `index` (must be < NUM_BITS()).
  constexpr void unset(size_t index) noexcept {
    bitset_ &= ~mask(index);
  }

  /// Returns true iff the bit at `index` (must be < NUM_BITS()) is set.
  constexpr bool get(size_t index) const noexcept {
    return (bitset_ & mask(index)) != 0;
  }

  /// Returns true iff no bit is set.
  constexpr bool is_entirely_unset() const noexcept {
    return 0 == bitset_;
  }

  /// Calls `func(index)` once for each set bit, from the lowest index to the
  /// highest. The bitset itself is not modified; a local copy is consumed.
  template <class Func>
  void for_each_set_bit(Func&& func) const {
    bitset cur = *this;
    // find_first_set() is one-indexed; 0 means no set bits remain.
    for (size_t pos = cur.find_first_set(); pos != 0;
         pos = cur.find_first_set()) {
      const size_t index = pos - 1;
      func(index);
      cur.unset(index);
    }
  }

 private:
  // Single-bit mask for `index`, built in bitset_type so the shifted value
  // always has the same width and signedness as the storage.
  static constexpr bitset_type mask(size_t index) noexcept {
    return static_cast<bitset_type>(1) << index;
  }

  // Return the index of the first set bit. The returned index is one-indexed
  // (i.e. if the very first bit is set, this function returns '1'), and a
  // return of '0' means that there was no bit set.
  size_t find_first_set() const noexcept {
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64))
    unsigned long result;
    if (0 == _BitScanForward64(&result, bitset_)) {
      return 0;
    }
    return result + 1;
#elif defined(_MSC_VER) && defined(_M_IX86)
    // 32-bit MSVC lacks _BitScanForward64, so scan the low half, then the
    // high half.
    unsigned long result;
    const uint32_t low = static_cast<uint32_t>(bitset_);
    if (0 != _BitScanForward(&result, low)) {
      return result + 1;
    }
    // Shift as unsigned to avoid implementation-defined signed right shift.
    const uint32_t high =
        static_cast<uint32_t>(static_cast<uint64_t>(bitset_) >> 32);
    if (0 != _BitScanForward(&result, high)) {
      return result + 33;
    }
    // No bit set in either half: must report 0 like the other branches,
    // otherwise for_each_set_bit never terminates on an empty bitset.
    return 0;
#else
    return static_cast<size_t>(__builtin_ffsll(bitset_));
#endif
  }

  friend bool operator==(bitset lhs, bitset rhs) noexcept {
    return lhs.bitset_ == rhs.bitset_;
  }

  bitset_type bitset_{0};
};
111 
112 inline bool operator!=(bitset lhs, bitset rhs) noexcept {
113   return !(lhs == rhs);
114 }
115 
116 } // namespace c10::utils
117