xref: /aosp_15_r20/external/marisa-trie/lib/marisa/grimoire/vector/bit-vector.cc (revision ab8db090fce404b23716c4c9194221ee27efe31c)
1 #include "marisa/grimoire/vector/pop-count.h"
2 #include "marisa/grimoire/vector/bit-vector.h"
3 
4 namespace marisa {
5 namespace grimoire {
6 namespace vector {
7 namespace {
8 
9 #ifdef MARISA_USE_BMI2
select_bit(std::size_t i,std::size_t bit_id,UInt64 unit)10 std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) {
11  #ifdef _MSC_VER
12   unsigned long pos;
13   ::_BitScanForward64(&pos, _pdep_u64(1ULL << i, unit));
14   return bit_id + pos;
15  #else  // _MSC_VER
16   return bit_id + ::__builtin_ctzll(_pdep_u64(1ULL << i, unit));
17  #endif  // _MSC_VER
18 }
19 #else  // MARISA_USE_BMI2
// Byte-level select lookup table, indexed as SELECT_TABLE[i][byte].
// Judging by its contents, an entry is the zero-based position of the
// (i + 1)-th lowest set bit of `byte` (e.g. [0][2] == 1, [1][3] == 1,
// [1][5] == 2). Where `byte` has fewer than i + 1 set bits the entry is 7,
// which is also the legitimate answer for [7][255], so callers cannot use 7
// to detect a missing bit.
20 const UInt8 SELECT_TABLE[8][256] = {
21   {
22     7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
23     4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
24     5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
25     4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
26     6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
27     4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
28     5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
29     4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
30     7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
31     4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
32     5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
33     4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
34     6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
35     4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
36     5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
37     4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
38   },
39   {
40     7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1,
41     7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
42     7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
43     5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
44     7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1,
45     6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
46     6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
47     5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
48     7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1,
49     7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
50     7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
51     5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
52     7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1,
53     6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
54     6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
55     5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1
56   },
57   {
58     7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2,
59     7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2,
60     7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2,
61     7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2,
62     7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2,
63     7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2,
64     7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2,
65     6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2,
66     7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2,
67     7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2,
68     7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2,
69     7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2,
70     7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2,
71     7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2,
72     7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2,
73     6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2
74   },
75   {
76     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3,
77     7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3,
78     7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3,
79     7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3,
80     7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3,
81     7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3,
82     7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3,
83     7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3,
84     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3,
85     7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3,
86     7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3,
87     7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3,
88     7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3,
89     7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3,
90     7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3,
91     7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3
92   },
93   {
94     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
95     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4,
96     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
97     7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4,
98     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
99     7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4,
100     7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5,
101     7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4,
102     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
103     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4,
104     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
105     7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4,
106     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
107     7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4,
108     7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5,
109     7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4
110   },
111   {
112     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
113     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
114     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
115     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
116     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
117     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
118     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
119     7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5,
120     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
121     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
122     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
123     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
124     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
125     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
126     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
127     7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5
128   },
129   {
130     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
131     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
132     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
133     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
134     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
135     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
136     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
137     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
138     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
139     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
140     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
141     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
142     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
143     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
144     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
145     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6
146   },
147   {
148     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
149     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
150     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
151     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
152     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
153     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
154     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
155     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
156     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
157     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
158     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
159     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
160     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
161     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
162     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
163     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
164   }
165 };
166 
167  #if MARISA_WORD_SIZE == 64
// Byte-lane masks used by the 64-bit, table-based select_bit() that follows.
// Each constant repeats one byte pattern across all eight bytes of a word.
// MASK_01 is needed in every configuration: multiplying by it turns per-byte
// counts into running totals and broadcasts (i + 1) into every byte.
168 const UInt64 MASK_01 = 0x0101010101010101ULL;
169   #if !defined(MARISA_X64) || !defined(MARISA_USE_SSSE3)
// Masks for the shift-and-add per-byte population count; only the non-SSSE3
// branch of select_bit() refers to them.
170 const UInt64 MASK_0F = 0x0F0F0F0F0F0F0F0FULL;
171 const UInt64 MASK_33 = 0x3333333333333333ULL;
172 const UInt64 MASK_55 = 0x5555555555555555ULL;
173   #endif  // !defined(MARISA_X64) || !defined(MARISA_USE_SSSE3)
174   #if !defined(MARISA_X64) || !defined(MARISA_USE_POPCNT)
// High bit of every byte; only the non-POPCNT branch of select_bit() uses it
// to compare all byte totals against (i + 1) with one subtraction.
175 const UInt64 MASK_80 = 0x8080808080808080ULL;
176   #endif  // !defined(MARISA_X64) || !defined(MARISA_USE_POPCNT)
177 
// Table-based select for one 64-bit word: returns bit_id plus the offset in
// unit of its (i + 1)-th lowest set bit (i is zero-based).
// Assumes unit holds more than i set bits; nothing here checks that, and the
// final SELECT_TABLE row index relies on it.
// Outline: (1) per-byte popcounts turned into running totals, (2) find the
// byte whose running total first exceeds i, (3) finish inside that byte with
// SELECT_TABLE.
178 std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) {
179   UInt64 counts;
180   {
181   #if defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
      // SSSE3 path: split each byte into its two nibbles and use PSHUFB on a
      // 16-entry nibble-popcount table; adding both lookups gives the
      // popcount of every byte.
182     __m128i lower_nibbles = _mm_cvtsi64_si128(
183         static_cast<long long>(unit & 0x0F0F0F0F0F0F0F0FULL));
184     __m128i upper_nibbles = _mm_cvtsi64_si128(
185         static_cast<long long>(unit & 0xF0F0F0F0F0F0F0F0ULL));
186     upper_nibbles = _mm_srli_epi32(upper_nibbles, 4);
187 
188     __m128i lower_counts =
189         _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
190     lower_counts = _mm_shuffle_epi8(lower_counts, lower_nibbles);
191     __m128i upper_counts =
192         _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
193     upper_counts = _mm_shuffle_epi8(upper_counts, upper_nibbles);
194 
195     counts = static_cast<UInt64>(_mm_cvtsi128_si64(
196        _mm_add_epi8(lower_counts, upper_counts)));
197   #else  // defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
      // Portable path: shift-and-add reduction in 2-, 4- and 8-bit lanes,
      // leaving the popcount of every byte in that byte.
198     counts = unit - ((unit >> 1) & MASK_55);
199     counts = (counts & MASK_33) + ((counts >> 2) & MASK_33);
200     counts = (counts + (counts >> 4)) & MASK_0F;
201   #endif  // defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
      // Multiplying by 0x0101...01 makes byte k the sum of the counts of
      // bytes 0..k, i.e. a running total (at most 64, so no byte overflows).
202     counts *= MASK_01;
203   }
204 
  // `skip` becomes the bit offset (a multiple of 8) of the byte that
  // contains the wanted bit.
205   #if defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
206   UInt8 skip;
207   {
      // Broadcast (i + 1) into every byte and compare per byte: bytes whose
      // running total is still below (i + 1) become 0xFF. Assuming
      // PopCount::count returns the number of set bits, the result is eight
      // times the number of such bytes.
208     __m128i x = _mm_cvtsi64_si128(static_cast<long long>((i + 1) * MASK_01));
209     __m128i y = _mm_cvtsi64_si128(static_cast<long long>(counts));
210     x = _mm_cmpgt_epi8(x, y);
211     skip = (UInt8)PopCount::count(static_cast<UInt64>(_mm_cvtsi128_si64(x)));
212   }
213   #else  // defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
  // Setting bit 7 of every byte before subtracting (i + 1) keeps borrows
  // from crossing bytes; bit 7 survives exactly where the running total is
  // at least (i + 1). The lowest surviving marker identifies the target byte.
214   const UInt64 x = (counts | MASK_80) - ((i + 1) * MASK_01);
215    #ifdef _MSC_VER
216   unsigned long skip;
217   ::_BitScanForward64(&skip, (x & MASK_80) >> 7);
218    #else  // _MSC_VER
219   const int skip = ::__builtin_ctzll((x & MASK_80) >> 7);
220    #endif  // _MSC_VER
221   #endif  // defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
222 
  // Move the target byte to the bottom of unit, then subtract the set bits
  // that precede it: (counts << 8) shifts the running totals up one byte so
  // that the byte at offset `skip` holds the total of all earlier bytes.
223   bit_id += static_cast<std::size_t>(skip);
224   unit >>= skip;
225   i -= ((counts << 8) >> skip) & 0xFF;
226 
227   return bit_id + SELECT_TABLE[i][unit & 0xFF];
228 }
229  #else  // MARISA_WORD_SIZE == 64
230   #ifdef MARISA_USE_SSE2
// Lookup table indexed by an 8-bit mask. Judging by its contents, each
// entry is 8 times the number of set bits in the index (e.g. [3] == 16,
// [255] == 64). The SSE2 select_bit() below uses it to convert a byte-compare
// mask into a bit offset.
231 const UInt8 POPCNT_TABLE[256] = {
232    0,  8,  8, 16,  8, 16, 16, 24,  8, 16, 16, 24, 16, 24, 24, 32,
233    8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
234    8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
235   16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
236    8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
237   16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
238   16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
239   24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
240    8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
241   16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
242   16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
243   24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
244   16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
245   24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
246   24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
247   32, 40, 40, 48, 40, 48, 48, 56, 40, 48, 48, 56, 48, 56, 56, 64
248 };
249 
// SSE2 select over a 64-bit span given as two 32-bit words: returns bit_id
// plus the offset of the (i + 1)-th lowest set bit (i is zero-based) in the
// value whose low half is unit_lo and high half is unit_hi.
// Assumes the two words together hold more than i set bits (not checked).
250 std::size_t select_bit(std::size_t i, std::size_t bit_id,
251     UInt32 unit_lo, UInt32 unit_hi) {
  // Pack both words into the low 8 bytes of one register (the byte shift by
  // 4 places unit_hi directly above unit_lo).
252   __m128i unit;
253   {
254     __m128i lower_dword = _mm_cvtsi32_si128(unit_lo);
255     __m128i upper_dword = _mm_cvtsi32_si128(unit_hi);
256     upper_dword = _mm_slli_si128(upper_dword, 4);
257     unit = _mm_or_si128(lower_dword, upper_dword);
258   }
259 
  // Per-byte population counts.
260   __m128i counts;
261   {
262    #ifdef MARISA_USE_SSSE3
      // Nibble lookup through PSHUFB, low and high nibbles added per byte.
263     __m128i lower_nibbles = _mm_set1_epi8(0x0F);
264     lower_nibbles = _mm_and_si128(lower_nibbles, unit);
265     __m128i upper_nibbles = _mm_set1_epi8((UInt8)0xF0);
266     upper_nibbles = _mm_and_si128(upper_nibbles, unit);
267     upper_nibbles = _mm_srli_epi32(upper_nibbles, 4);
268 
269     __m128i lower_counts =
270         _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
271     lower_counts = _mm_shuffle_epi8(lower_counts, lower_nibbles);
272     __m128i upper_counts =
273         _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
274     upper_counts = _mm_shuffle_epi8(upper_counts, upper_nibbles);
275 
276     counts = _mm_add_epi8(lower_counts, upper_counts);
277    #else  // MARISA_USE_SSSE3
      // Plain SSE2: shift-and-add reduction in 2-, 4- and 8-bit lanes.
278     __m128i x = _mm_srli_epi32(unit, 1);
279     x = _mm_and_si128(x, _mm_set1_epi8(0x55));
280     x = _mm_sub_epi8(unit, x);
281 
282     __m128i y = _mm_srli_epi32(x, 2);
283     y = _mm_and_si128(y, _mm_set1_epi8(0x33));
284     x = _mm_and_si128(x, _mm_set1_epi8(0x33));
285     x = _mm_add_epi8(x, y);
286 
287     y = _mm_srli_epi32(x, 4);
288     x = _mm_add_epi8(x, y);
289     counts = _mm_and_si128(x, _mm_set1_epi8(0x0F));
290    #endif  // MARISA_USE_SSSE3
291   }
292 
  // Running totals: adding copies shifted up by 1, 2 and 4 bytes leaves in
  // byte k the sum of the counts of bytes 0..k. The upper 8 bytes, which
  // carry no data, are then forced to 0x7F so the signed byte comparison
  // below can never flag them.
293   __m128i accumulated_counts;
294   {
295     __m128i x = counts;
296     x = _mm_slli_si128(x, 1);
297     __m128i y = counts;
298     y = _mm_add_epi32(y, x);
299 
300     x = y;
301     y = _mm_slli_si128(y, 2);
302     x = _mm_add_epi32(x, y);
303 
304     y = x;
305     x = _mm_slli_si128(x, 4);
306     y = _mm_add_epi32(y, x);
307 
308     accumulated_counts = _mm_set_epi32(0x7F7F7F7FU, 0x7F7F7F7FU, 0, 0);
309     accumulated_counts = _mm_or_si128(accumulated_counts, y);
310   }
311 
  // One mask bit per byte whose running total is still below (i + 1);
  // POPCNT_TABLE converts that mask into the bit offset of the target byte.
312   UInt8 skip;
313   {
314     __m128i x = _mm_set1_epi8((UInt8)(i + 1));
315     x = _mm_cmpgt_epi8(x, accumulated_counts);
316     skip = POPCNT_TABLE[_mm_movemask_epi8(x)];
317   }
318 
  // Spill to aligned byte arrays to fetch the target byte and the number of
  // set bits before it. Shifting the totals up one byte first makes entry k
  // the total of bytes 0..k-1.
319   UInt8 byte;
320   {
321    #ifdef _MSC_VER
322     __declspec(align(16)) UInt8 unit_bytes[16];
323     __declspec(align(16)) UInt8 accumulated_counts_bytes[16];
324    #else  // _MSC_VER
325     UInt8 unit_bytes[16] __attribute__ ((aligned (16)));
326     UInt8 accumulated_counts_bytes[16] __attribute__ ((aligned (16)));
327    #endif  // _MSC_VER
328     accumulated_counts = _mm_slli_si128(accumulated_counts, 1);
329     _mm_store_si128(reinterpret_cast<__m128i *>(unit_bytes), unit);
330     _mm_store_si128(reinterpret_cast<__m128i *>(accumulated_counts_bytes),
331         accumulated_counts);
332 
333     bit_id += skip;
334     byte = unit_bytes[skip / 8];
335     i -= accumulated_counts_bytes[skip / 8];
336   }
337 
338   return bit_id + SELECT_TABLE[i][byte];
339 }
340   #endif  // MARISA_USE_SSE2
341  #endif  // MARISA_WORD_SIZE == 64
342 #endif  // MARISA_USE_BMI2
343 
344 }  // namespace
345 
346 #if MARISA_WORD_SIZE == 64
347 
rank1(std::size_t i) const348 std::size_t BitVector::rank1(std::size_t i) const {
349   MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR);
350   MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR);
351 
352   const RankIndex &rank = ranks_[i / 512];
353   std::size_t offset = rank.abs();
354   switch ((i / 64) % 8) {
355     case 1: {
356       offset += rank.rel1();
357       break;
358     }
359     case 2: {
360       offset += rank.rel2();
361       break;
362     }
363     case 3: {
364       offset += rank.rel3();
365       break;
366     }
367     case 4: {
368       offset += rank.rel4();
369       break;
370     }
371     case 5: {
372       offset += rank.rel5();
373       break;
374     }
375     case 6: {
376       offset += rank.rel6();
377       break;
378     }
379     case 7: {
380       offset += rank.rel7();
381       break;
382     }
383   }
384   offset += PopCount::count(units_[i / 64] & ((1ULL << (i % 64)) - 1));
385   return offset;
386 }
387 
select0(std::size_t i) const388 std::size_t BitVector::select0(std::size_t i) const {
389   MARISA_DEBUG_IF(select0s_.empty(), MARISA_STATE_ERROR);
390   MARISA_DEBUG_IF(i >= num_0s(), MARISA_BOUND_ERROR);
391 
392   const std::size_t select_id = i / 512;
393   MARISA_DEBUG_IF((select_id + 1) >= select0s_.size(), MARISA_BOUND_ERROR);
394   if ((i % 512) == 0) {
395     return select0s_[select_id];
396   }
397   std::size_t begin = select0s_[select_id] / 512;
398   std::size_t end = (select0s_[select_id + 1] + 511) / 512;
399   if (begin + 10 >= end) {
400     while (i >= ((begin + 1) * 512) - ranks_[begin + 1].abs()) {
401       ++begin;
402     }
403   } else {
404     while (begin + 1 < end) {
405       const std::size_t middle = (begin + end) / 2;
406       if (i < (middle * 512) - ranks_[middle].abs()) {
407         end = middle;
408       } else {
409         begin = middle;
410       }
411     }
412   }
413   const std::size_t rank_id = begin;
414   i -= (rank_id * 512) - ranks_[rank_id].abs();
415 
416   const RankIndex &rank = ranks_[rank_id];
417   std::size_t unit_id = rank_id * 8;
418   if (i < (256U - rank.rel4())) {
419     if (i < (128U - rank.rel2())) {
420       if (i >= (64U - rank.rel1())) {
421         unit_id += 1;
422         i -= 64 - rank.rel1();
423       }
424     } else if (i < (192U - rank.rel3())) {
425       unit_id += 2;
426       i -= 128 - rank.rel2();
427     } else {
428       unit_id += 3;
429       i -= 192 - rank.rel3();
430     }
431   } else if (i < (384U - rank.rel6())) {
432     if (i < (320U - rank.rel5())) {
433       unit_id += 4;
434       i -= 256 - rank.rel4();
435     } else {
436       unit_id += 5;
437       i -= 320 - rank.rel5();
438     }
439   } else if (i < (448U - rank.rel7())) {
440     unit_id += 6;
441     i -= 384 - rank.rel6();
442   } else {
443     unit_id += 7;
444     i -= 448 - rank.rel7();
445   }
446 
447   return select_bit(i, unit_id * 64, ~units_[unit_id]);
448 }
449 
select1(std::size_t i) const450 std::size_t BitVector::select1(std::size_t i) const {
451   MARISA_DEBUG_IF(select1s_.empty(), MARISA_STATE_ERROR);
452   MARISA_DEBUG_IF(i >= num_1s(), MARISA_BOUND_ERROR);
453 
454   const std::size_t select_id = i / 512;
455   MARISA_DEBUG_IF((select_id + 1) >= select1s_.size(), MARISA_BOUND_ERROR);
456   if ((i % 512) == 0) {
457     return select1s_[select_id];
458   }
459   std::size_t begin = select1s_[select_id] / 512;
460   std::size_t end = (select1s_[select_id + 1] + 511) / 512;
461   if (begin + 10 >= end) {
462     while (i >= ranks_[begin + 1].abs()) {
463       ++begin;
464     }
465   } else {
466     while (begin + 1 < end) {
467       const std::size_t middle = (begin + end) / 2;
468       if (i < ranks_[middle].abs()) {
469         end = middle;
470       } else {
471         begin = middle;
472       }
473     }
474   }
475   const std::size_t rank_id = begin;
476   i -= ranks_[rank_id].abs();
477 
478   const RankIndex &rank = ranks_[rank_id];
479   std::size_t unit_id = rank_id * 8;
480   if (i < rank.rel4()) {
481     if (i < rank.rel2()) {
482       if (i >= rank.rel1()) {
483         unit_id += 1;
484         i -= rank.rel1();
485       }
486     } else if (i < rank.rel3()) {
487       unit_id += 2;
488       i -= rank.rel2();
489     } else {
490       unit_id += 3;
491       i -= rank.rel3();
492     }
493   } else if (i < rank.rel6()) {
494     if (i < rank.rel5()) {
495       unit_id += 4;
496       i -= rank.rel4();
497     } else {
498       unit_id += 5;
499       i -= rank.rel5();
500     }
501   } else if (i < rank.rel7()) {
502     unit_id += 6;
503     i -= rank.rel6();
504   } else {
505     unit_id += 7;
506     i -= rank.rel7();
507   }
508 
509   return select_bit(i, unit_id * 64, units_[unit_id]);
510 }
511 
512 #else  // MARISA_WORD_SIZE == 64
513 
rank1(std::size_t i) const514 std::size_t BitVector::rank1(std::size_t i) const {
515   MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR);
516   MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR);
517 
518   const RankIndex &rank = ranks_[i / 512];
519   std::size_t offset = rank.abs();
520   switch ((i / 64) % 8) {
521     case 1: {
522       offset += rank.rel1();
523       break;
524     }
525     case 2: {
526       offset += rank.rel2();
527       break;
528     }
529     case 3: {
530       offset += rank.rel3();
531       break;
532     }
533     case 4: {
534       offset += rank.rel4();
535       break;
536     }
537     case 5: {
538       offset += rank.rel5();
539       break;
540     }
541     case 6: {
542       offset += rank.rel6();
543       break;
544     }
545     case 7: {
546       offset += rank.rel7();
547       break;
548     }
549   }
550   if (((i / 32) & 1) == 1) {
551     offset += PopCount::count(units_[(i / 32) - 1]);
552   }
553   offset += PopCount::count(units_[i / 32] & ((1U << (i % 32)) - 1));
554   return offset;
555 }
556 
select0(std::size_t i) const557 std::size_t BitVector::select0(std::size_t i) const {
558   MARISA_DEBUG_IF(select0s_.empty(), MARISA_STATE_ERROR);
559   MARISA_DEBUG_IF(i >= num_0s(), MARISA_BOUND_ERROR);
560 
561   const std::size_t select_id = i / 512;
562   MARISA_DEBUG_IF((select_id + 1) >= select0s_.size(), MARISA_BOUND_ERROR);
563   if ((i % 512) == 0) {
564     return select0s_[select_id];
565   }
566   std::size_t begin = select0s_[select_id] / 512;
567   std::size_t end = (select0s_[select_id + 1] + 511) / 512;
568   if (begin + 10 >= end) {
569     while (i >= ((begin + 1) * 512) - ranks_[begin + 1].abs()) {
570       ++begin;
571     }
572   } else {
573     while (begin + 1 < end) {
574       const std::size_t middle = (begin + end) / 2;
575       if (i < (middle * 512) - ranks_[middle].abs()) {
576         end = middle;
577       } else {
578         begin = middle;
579       }
580     }
581   }
582   const std::size_t rank_id = begin;
583   i -= (rank_id * 512) - ranks_[rank_id].abs();
584 
585   const RankIndex &rank = ranks_[rank_id];
586   std::size_t unit_id = rank_id * 16;
587   if (i < (256U - rank.rel4())) {
588     if (i < (128U - rank.rel2())) {
589       if (i >= (64U - rank.rel1())) {
590         unit_id += 2;
591         i -= 64 - rank.rel1();
592       }
593     } else if (i < (192U - rank.rel3())) {
594       unit_id += 4;
595       i -= 128 - rank.rel2();
596     } else {
597       unit_id += 6;
598       i -= 192 - rank.rel3();
599     }
600   } else if (i < (384U - rank.rel6())) {
601     if (i < (320U - rank.rel5())) {
602       unit_id += 8;
603       i -= 256 - rank.rel4();
604     } else {
605       unit_id += 10;
606       i -= 320 - rank.rel5();
607     }
608   } else if (i < (448U - rank.rel7())) {
609     unit_id += 12;
610     i -= 384 - rank.rel6();
611   } else {
612     unit_id += 14;
613     i -= 448 - rank.rel7();
614   }
615 
616 #ifdef MARISA_USE_SSE2
617   return select_bit(i, unit_id * 32, ~units_[unit_id], ~units_[unit_id + 1]);
618 #else  // MARISA_USE_SSE2
619   UInt32 unit = ~units_[unit_id];
620   PopCount count(unit);
621   if (i >= count.lo32()) {
622     ++unit_id;
623     i -= count.lo32();
624     unit = ~units_[unit_id];
625     count = PopCount(unit);
626   }
627 
628   std::size_t bit_id = unit_id * 32;
629   if (i < count.lo16()) {
630     if (i >= count.lo8()) {
631       bit_id += 8;
632       unit >>= 8;
633       i -= count.lo8();
634     }
635   } else if (i < count.lo24()) {
636     bit_id += 16;
637     unit >>= 16;
638     i -= count.lo16();
639   } else {
640     bit_id += 24;
641     unit >>= 24;
642     i -= count.lo24();
643   }
644   return bit_id + SELECT_TABLE[i][unit & 0xFF];
645 #endif  // MARISA_USE_SSE2
646 }
647 
select1(std::size_t i) const648 std::size_t BitVector::select1(std::size_t i) const {
649   MARISA_DEBUG_IF(select1s_.empty(), MARISA_STATE_ERROR);
650   MARISA_DEBUG_IF(i >= num_1s(), MARISA_BOUND_ERROR);
651 
652   const std::size_t select_id = i / 512;
653   MARISA_DEBUG_IF((select_id + 1) >= select1s_.size(), MARISA_BOUND_ERROR);
654   if ((i % 512) == 0) {
655     return select1s_[select_id];
656   }
657   std::size_t begin = select1s_[select_id] / 512;
658   std::size_t end = (select1s_[select_id + 1] + 511) / 512;
659   if (begin + 10 >= end) {
660     while (i >= ranks_[begin + 1].abs()) {
661       ++begin;
662     }
663   } else {
664     while (begin + 1 < end) {
665       const std::size_t middle = (begin + end) / 2;
666       if (i < ranks_[middle].abs()) {
667         end = middle;
668       } else {
669         begin = middle;
670       }
671     }
672   }
673   const std::size_t rank_id = begin;
674   i -= ranks_[rank_id].abs();
675 
676   const RankIndex &rank = ranks_[rank_id];
677   std::size_t unit_id = rank_id * 16;
678   if (i < rank.rel4()) {
679     if (i < rank.rel2()) {
680       if (i >= rank.rel1()) {
681         unit_id += 2;
682         i -= rank.rel1();
683       }
684     } else if (i < rank.rel3()) {
685       unit_id += 4;
686       i -= rank.rel2();
687     } else {
688       unit_id += 6;
689       i -= rank.rel3();
690     }
691   } else if (i < rank.rel6()) {
692     if (i < rank.rel5()) {
693       unit_id += 8;
694       i -= rank.rel4();
695     } else {
696       unit_id += 10;
697       i -= rank.rel5();
698     }
699   } else if (i < rank.rel7()) {
700     unit_id += 12;
701     i -= rank.rel6();
702   } else {
703     unit_id += 14;
704     i -= rank.rel7();
705   }
706 
707 #ifdef MARISA_USE_SSE2
708   return select_bit(i, unit_id * 32, units_[unit_id], units_[unit_id + 1]);
709 #else  // MARISA_USE_SSE2
710   UInt32 unit = units_[unit_id];
711   PopCount count(unit);
712   if (i >= count.lo32()) {
713     ++unit_id;
714     i -= count.lo32();
715     unit = units_[unit_id];
716     count = PopCount(unit);
717   }
718 
719   std::size_t bit_id = unit_id * 32;
720   if (i < count.lo16()) {
721     if (i >= count.lo8()) {
722       bit_id += 8;
723       unit >>= 8;
724       i -= count.lo8();
725     }
726   } else if (i < count.lo24()) {
727     bit_id += 16;
728     unit >>= 16;
729     i -= count.lo16();
730   } else {
731     bit_id += 24;
732     unit >>= 24;
733     i -= count.lo24();
734   }
735   return bit_id + SELECT_TABLE[i][unit & 0xFF];
736 #endif  // MARISA_USE_SSE2
737 }
738 
739 #endif  // MARISA_WORD_SIZE == 64
740 
build_index(const BitVector & bv,bool enables_select0,bool enables_select1)741 void BitVector::build_index(const BitVector &bv,
742     bool enables_select0, bool enables_select1) {
743   ranks_.resize((bv.size() / 512) + (((bv.size() % 512) != 0) ? 1 : 0) + 1);
744 
745   std::size_t num_0s = 0;
746   std::size_t num_1s = 0;
747 
748   for (std::size_t i = 0; i < bv.size(); ++i) {
749     if ((i % 64) == 0) {
750       const std::size_t rank_id = i / 512;
751       switch ((i / 64) % 8) {
752         case 0: {
753           ranks_[rank_id].set_abs(num_1s);
754           break;
755         }
756         case 1: {
757           ranks_[rank_id].set_rel1(num_1s - ranks_[rank_id].abs());
758           break;
759         }
760         case 2: {
761           ranks_[rank_id].set_rel2(num_1s - ranks_[rank_id].abs());
762           break;
763         }
764         case 3: {
765           ranks_[rank_id].set_rel3(num_1s - ranks_[rank_id].abs());
766           break;
767         }
768         case 4: {
769           ranks_[rank_id].set_rel4(num_1s - ranks_[rank_id].abs());
770           break;
771         }
772         case 5: {
773           ranks_[rank_id].set_rel5(num_1s - ranks_[rank_id].abs());
774           break;
775         }
776         case 6: {
777           ranks_[rank_id].set_rel6(num_1s - ranks_[rank_id].abs());
778           break;
779         }
780         case 7: {
781           ranks_[rank_id].set_rel7(num_1s - ranks_[rank_id].abs());
782           break;
783         }
784       }
785     }
786 
787     if (bv[i]) {
788       if (enables_select1 && ((num_1s % 512) == 0)) {
789         select1s_.push_back(static_cast<UInt32>(i));
790       }
791       ++num_1s;
792     } else {
793       if (enables_select0 && ((num_0s % 512) == 0)) {
794         select0s_.push_back(static_cast<UInt32>(i));
795       }
796       ++num_0s;
797     }
798   }
799 
800   if ((bv.size() % 512) != 0) {
801     const std::size_t rank_id = (bv.size() - 1) / 512;
802     switch (((bv.size() - 1) / 64) % 8) {
803       case 0: {
804         ranks_[rank_id].set_rel1(num_1s - ranks_[rank_id].abs());
805       }  // fall through
806       case 1: {
807         ranks_[rank_id].set_rel2(num_1s - ranks_[rank_id].abs());
808       }  // fall through
809       case 2: {
810         ranks_[rank_id].set_rel3(num_1s - ranks_[rank_id].abs());
811       }  // fall through
812       case 3: {
813         ranks_[rank_id].set_rel4(num_1s - ranks_[rank_id].abs());
814       }  // fall through
815       case 4: {
816         ranks_[rank_id].set_rel5(num_1s - ranks_[rank_id].abs());
817       }  // fall through
818       case 5: {
819         ranks_[rank_id].set_rel6(num_1s - ranks_[rank_id].abs());
820       }  // fall through
821       case 6: {
822         ranks_[rank_id].set_rel7(num_1s - ranks_[rank_id].abs());
823         break;
824       }
825     }
826   }
827 
828   size_ = bv.size();
829   num_1s_ = bv.num_1s();
830 
831   ranks_.back().set_abs(num_1s);
832   if (enables_select0) {
833     select0s_.push_back(static_cast<UInt32>(bv.size()));
834     select0s_.shrink();
835   }
836   if (enables_select1) {
837     select1s_.push_back(static_cast<UInt32>(bv.size()));
838     select1s_.shrink();
839   }
840 }
841 
842 }  // namespace vector
843 }  // namespace grimoire
844 }  // namespace marisa
845