1 #include "marisa/grimoire/vector/pop-count.h"
2 #include "marisa/grimoire/vector/bit-vector.h"
3
4 namespace marisa {
5 namespace grimoire {
6 namespace vector {
7 namespace {
8
9 #ifdef MARISA_USE_BMI2
select_bit(std::size_t i,std::size_t bit_id,UInt64 unit)10 std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) {
11 #ifdef _MSC_VER
12 unsigned long pos;
13 ::_BitScanForward64(&pos, _pdep_u64(1ULL << i, unit));
14 return bit_id + pos;
15 #else // _MSC_VER
16 return bit_id + ::__builtin_ctzll(_pdep_u64(1ULL << i, unit));
17 #endif // _MSC_VER
18 }
19 #else // MARISA_USE_BMI2
// Byte-level select lookup table used by the non-BMI2 select code.
//
// SELECT_TABLE[k][b] is the 0-based position of the (k + 1)-th lowest set bit
// of the byte value b (for example [0][2] == 1 and [1][3] == 1). Entries for
// which b has fewer than (k + 1) set bits are filled with 7; callers index the
// table only after reducing the query to a byte that contains the wanted bit.
const UInt8 SELECT_TABLE[8][256] = {
  // k = 0: position of the lowest set bit.
  {
    7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
  },
  // k = 1: position of the second set bit.
  {
    7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1,
    7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
    7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
    5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
    7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1,
    6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
    6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
    5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
    7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1,
    7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
    7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
    5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
    7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1,
    6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
    6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
    5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1
  },
  // k = 2: position of the third set bit.
  {
    7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2,
    7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2,
    7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2,
    7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2,
    7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2,
    7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2,
    7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2,
    6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2,
    7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2,
    7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2,
    7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2,
    7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2,
    7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2,
    7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2,
    7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2,
    6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2
  },
  // k = 3: position of the fourth set bit.
  {
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3,
    7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3,
    7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3,
    7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3,
    7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3,
    7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3,
    7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3,
    7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3,
    7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3,
    7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3,
    7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3,
    7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3,
    7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3,
    7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3,
    7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3
  },
  // k = 4: position of the fifth set bit.
  {
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
    7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
    7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4,
    7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5,
    7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
    7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
    7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4,
    7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5,
    7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4
  },
  // k = 5: position of the sixth set bit.
  {
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
    7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
    7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5
  },
  // k = 6: position of the seventh set bit.
  {
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6
  },
  // k = 7: position of the eighth set bit (only 0xFF has one, at bit 7).
  {
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
  }
};
166
167 #if MARISA_WORD_SIZE == 64
// Byte-replicated constants used by the 64-bit select_bit() below.
//
// MASK_01 has 0x01 in every byte. Multiplying a word of per-byte counts by it
// turns them into running (prefix) sums, and multiplying a small value by it
// broadcasts that value into every byte.
const UInt64 MASK_01 = 0x0101010101010101ULL;
#if !defined(MARISA_X64) || !defined(MARISA_USE_SSSE3)
// Masks for the portable per-byte population count; only compiled when the
// SSSE3 nibble-lookup path is not in use.
const UInt64 MASK_0F = 0x0F0F0F0F0F0F0F0FULL;
const UInt64 MASK_33 = 0x3333333333333333ULL;
const UInt64 MASK_55 = 0x5555555555555555ULL;
#endif  // !defined(MARISA_X64) || !defined(MARISA_USE_SSSE3)
#if !defined(MARISA_X64) || !defined(MARISA_USE_POPCNT)
// High bit of every byte; used by the subtract-and-test byte comparison that
// replaces the SSE compare + POPCNT path.
const UInt64 MASK_80 = 0x8080808080808080ULL;
#endif  // !defined(MARISA_X64) || !defined(MARISA_USE_POPCNT)
177
178 std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) {
179 UInt64 counts;
180 {
181 #if defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
182 __m128i lower_nibbles = _mm_cvtsi64_si128(
183 static_cast<long long>(unit & 0x0F0F0F0F0F0F0F0FULL));
184 __m128i upper_nibbles = _mm_cvtsi64_si128(
185 static_cast<long long>(unit & 0xF0F0F0F0F0F0F0F0ULL));
186 upper_nibbles = _mm_srli_epi32(upper_nibbles, 4);
187
188 __m128i lower_counts =
189 _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
190 lower_counts = _mm_shuffle_epi8(lower_counts, lower_nibbles);
191 __m128i upper_counts =
192 _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
193 upper_counts = _mm_shuffle_epi8(upper_counts, upper_nibbles);
194
195 counts = static_cast<UInt64>(_mm_cvtsi128_si64(
196 _mm_add_epi8(lower_counts, upper_counts)));
197 #else // defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
198 counts = unit - ((unit >> 1) & MASK_55);
199 counts = (counts & MASK_33) + ((counts >> 2) & MASK_33);
200 counts = (counts + (counts >> 4)) & MASK_0F;
201 #endif // defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
202 counts *= MASK_01;
203 }
204
205 #if defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
206 UInt8 skip;
207 {
208 __m128i x = _mm_cvtsi64_si128(static_cast<long long>((i + 1) * MASK_01));
209 __m128i y = _mm_cvtsi64_si128(static_cast<long long>(counts));
210 x = _mm_cmpgt_epi8(x, y);
211 skip = (UInt8)PopCount::count(static_cast<UInt64>(_mm_cvtsi128_si64(x)));
212 }
213 #else // defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
214 const UInt64 x = (counts | MASK_80) - ((i + 1) * MASK_01);
215 #ifdef _MSC_VER
216 unsigned long skip;
217 ::_BitScanForward64(&skip, (x & MASK_80) >> 7);
218 #else // _MSC_VER
219 const int skip = ::__builtin_ctzll((x & MASK_80) >> 7);
220 #endif // _MSC_VER
221 #endif // defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
222
223 bit_id += static_cast<std::size_t>(skip);
224 unit >>= skip;
225 i -= ((counts << 8) >> skip) & 0xFF;
226
227 return bit_id + SELECT_TABLE[i][unit & 0xFF];
228 }
229 #else // MARISA_WORD_SIZE == 64
230 #ifdef MARISA_USE_SSE2
// POPCNT_TABLE[m] is 8 times the number of set bits in the byte value m
// (e.g. [0] == 0, [1] == 8, [3] == 16, [255] == 64).
// The SSE2 select_bit() indexes it with a byte-comparison mask, so the result
// is a bit offset that is a multiple of 8.
const UInt8 POPCNT_TABLE[256] = {
   0,  8,  8, 16,  8, 16, 16, 24,  8, 16, 16, 24, 16, 24, 24, 32,
   8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
   8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
  16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
   8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
  16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
  16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
  24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
   8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
  16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
  16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
  24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
  16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
  24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
  24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
  32, 40, 40, 48, 40, 48, 48, 56, 40, 48, 48, 56, 48, 56, 56, 64
};
249
// SSE2 variant of select_bit() for builds where MARISA_WORD_SIZE is not 64.
//
// The 64 bits to search are passed as two 32-bit halves (unit_lo holds the
// lower-numbered bits). Returns bit_id plus the position of the i-th
// (0-based) set bit within the combined 64-bit value.
// NOTE(review): the halves are presumably required to contain more than i set
// bits in total; nothing here checks it.
std::size_t select_bit(std::size_t i, std::size_t bit_id,
    UInt32 unit_lo, UInt32 unit_hi) {
  // Place unit_lo in bytes 0-3 and unit_hi in bytes 4-7 of one register.
  __m128i unit;
  {
    __m128i lower_dword = _mm_cvtsi32_si128(unit_lo);
    __m128i upper_dword = _mm_cvtsi32_si128(unit_hi);
    upper_dword = _mm_slli_si128(upper_dword, 4);
    unit = _mm_or_si128(lower_dword, upper_dword);
  }

  // Population count of every byte of `unit`.
  __m128i counts;
  {
#ifdef MARISA_USE_SSSE3
    // Split each byte into its two nibbles and look both up in a 16-entry
    // table of nibble population counts held in a register.
    __m128i lower_nibbles = _mm_set1_epi8(0x0F);
    lower_nibbles = _mm_and_si128(lower_nibbles, unit);
    __m128i upper_nibbles = _mm_set1_epi8((UInt8)0xF0);
    upper_nibbles = _mm_and_si128(upper_nibbles, unit);
    upper_nibbles = _mm_srli_epi32(upper_nibbles, 4);

    __m128i lower_counts =
        _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
    lower_counts = _mm_shuffle_epi8(lower_counts, lower_nibbles);
    __m128i upper_counts =
        _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
    upper_counts = _mm_shuffle_epi8(upper_counts, upper_nibbles);

    counts = _mm_add_epi8(lower_counts, upper_counts);
#else  // MARISA_USE_SSSE3
    // Bit-parallel population count: 2-bit sums, then 4-bit sums, then the
    // two nibble sums of each byte are added and masked to one byte count.
    __m128i x = _mm_srli_epi32(unit, 1);
    x = _mm_and_si128(x, _mm_set1_epi8(0x55));
    x = _mm_sub_epi8(unit, x);

    __m128i y = _mm_srli_epi32(x, 2);
    y = _mm_and_si128(y, _mm_set1_epi8(0x33));
    x = _mm_and_si128(x, _mm_set1_epi8(0x33));
    x = _mm_add_epi8(x, y);

    y = _mm_srli_epi32(x, 4);
    x = _mm_add_epi8(x, y);
    counts = _mm_and_si128(x, _mm_set1_epi8(0x0F));
#endif  // MARISA_USE_SSSE3
  }

  // Running sums of the byte counts: after adding copies shifted by 1, 2 and
  // 4 bytes, byte k holds the number of set bits in bytes 0..k. Only 64 bits
  // are populated, so every sum is at most 64 and the 32-bit additions never
  // carry from one byte into the next.
  __m128i accumulated_counts;
  {
    __m128i x = counts;
    x = _mm_slli_si128(x, 1);
    __m128i y = counts;
    y = _mm_add_epi32(y, x);

    x = y;
    y = _mm_slli_si128(y, 2);
    x = _mm_add_epi32(x, y);

    y = x;
    x = _mm_slli_si128(x, 4);
    y = _mm_add_epi32(y, x);

    // Force bytes 8-15 to 0x7F so that the signed byte comparison below does
    // not fire for them (for any i + 1 that is at most 0x7F).
    accumulated_counts = _mm_set_epi32(0x7F7F7F7FU, 0x7F7F7F7FU, 0, 0);
    accumulated_counts = _mm_or_si128(accumulated_counts, y);
  }

  // One mask bit is set per byte whose running sum is still below i + 1;
  // POPCNT_TABLE converts the mask into 8 * (number of such bytes), i.e. the
  // bit offset of the byte that contains the wanted bit.
  UInt8 skip;
  {
    __m128i x = _mm_set1_epi8((UInt8)(i + 1));
    x = _mm_cmpgt_epi8(x, accumulated_counts);
    skip = POPCNT_TABLE[_mm_movemask_epi8(x)];
  }

  // Spill the vectors to 16-byte-aligned arrays to fetch the selected byte
  // and the number of set bits that precede it.
  UInt8 byte;
  {
#ifdef _MSC_VER
    __declspec(align(16)) UInt8 unit_bytes[16];
    __declspec(align(16)) UInt8 accumulated_counts_bytes[16];
#else  // _MSC_VER
    UInt8 unit_bytes[16] __attribute__ ((aligned (16)));
    UInt8 accumulated_counts_bytes[16] __attribute__ ((aligned (16)));
#endif  // _MSC_VER
    // Shift by one byte so that entry k is the count of set bits strictly
    // before byte k (entry 0 becomes 0).
    accumulated_counts = _mm_slli_si128(accumulated_counts, 1);
    _mm_store_si128(reinterpret_cast<__m128i *>(unit_bytes), unit);
    _mm_store_si128(reinterpret_cast<__m128i *>(accumulated_counts_bytes),
        accumulated_counts);

    bit_id += skip;
    byte = unit_bytes[skip / 8];
    i -= accumulated_counts_bytes[skip / 8];
  }

  // Finish inside the selected byte.
  return bit_id + SELECT_TABLE[i][byte];
}
340 #endif // MARISA_USE_SSE2
341 #endif // MARISA_WORD_SIZE == 64
342 #endif // MARISA_USE_BMI2
343
344 } // namespace
345
346 #if MARISA_WORD_SIZE == 64
347
rank1(std::size_t i) const348 std::size_t BitVector::rank1(std::size_t i) const {
349 MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR);
350 MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR);
351
352 const RankIndex &rank = ranks_[i / 512];
353 std::size_t offset = rank.abs();
354 switch ((i / 64) % 8) {
355 case 1: {
356 offset += rank.rel1();
357 break;
358 }
359 case 2: {
360 offset += rank.rel2();
361 break;
362 }
363 case 3: {
364 offset += rank.rel3();
365 break;
366 }
367 case 4: {
368 offset += rank.rel4();
369 break;
370 }
371 case 5: {
372 offset += rank.rel5();
373 break;
374 }
375 case 6: {
376 offset += rank.rel6();
377 break;
378 }
379 case 7: {
380 offset += rank.rel7();
381 break;
382 }
383 }
384 offset += PopCount::count(units_[i / 64] & ((1ULL << (i % 64)) - 1));
385 return offset;
386 }
387
select0(std::size_t i) const388 std::size_t BitVector::select0(std::size_t i) const {
389 MARISA_DEBUG_IF(select0s_.empty(), MARISA_STATE_ERROR);
390 MARISA_DEBUG_IF(i >= num_0s(), MARISA_BOUND_ERROR);
391
392 const std::size_t select_id = i / 512;
393 MARISA_DEBUG_IF((select_id + 1) >= select0s_.size(), MARISA_BOUND_ERROR);
394 if ((i % 512) == 0) {
395 return select0s_[select_id];
396 }
397 std::size_t begin = select0s_[select_id] / 512;
398 std::size_t end = (select0s_[select_id + 1] + 511) / 512;
399 if (begin + 10 >= end) {
400 while (i >= ((begin + 1) * 512) - ranks_[begin + 1].abs()) {
401 ++begin;
402 }
403 } else {
404 while (begin + 1 < end) {
405 const std::size_t middle = (begin + end) / 2;
406 if (i < (middle * 512) - ranks_[middle].abs()) {
407 end = middle;
408 } else {
409 begin = middle;
410 }
411 }
412 }
413 const std::size_t rank_id = begin;
414 i -= (rank_id * 512) - ranks_[rank_id].abs();
415
416 const RankIndex &rank = ranks_[rank_id];
417 std::size_t unit_id = rank_id * 8;
418 if (i < (256U - rank.rel4())) {
419 if (i < (128U - rank.rel2())) {
420 if (i >= (64U - rank.rel1())) {
421 unit_id += 1;
422 i -= 64 - rank.rel1();
423 }
424 } else if (i < (192U - rank.rel3())) {
425 unit_id += 2;
426 i -= 128 - rank.rel2();
427 } else {
428 unit_id += 3;
429 i -= 192 - rank.rel3();
430 }
431 } else if (i < (384U - rank.rel6())) {
432 if (i < (320U - rank.rel5())) {
433 unit_id += 4;
434 i -= 256 - rank.rel4();
435 } else {
436 unit_id += 5;
437 i -= 320 - rank.rel5();
438 }
439 } else if (i < (448U - rank.rel7())) {
440 unit_id += 6;
441 i -= 384 - rank.rel6();
442 } else {
443 unit_id += 7;
444 i -= 448 - rank.rel7();
445 }
446
447 return select_bit(i, unit_id * 64, ~units_[unit_id]);
448 }
449
select1(std::size_t i) const450 std::size_t BitVector::select1(std::size_t i) const {
451 MARISA_DEBUG_IF(select1s_.empty(), MARISA_STATE_ERROR);
452 MARISA_DEBUG_IF(i >= num_1s(), MARISA_BOUND_ERROR);
453
454 const std::size_t select_id = i / 512;
455 MARISA_DEBUG_IF((select_id + 1) >= select1s_.size(), MARISA_BOUND_ERROR);
456 if ((i % 512) == 0) {
457 return select1s_[select_id];
458 }
459 std::size_t begin = select1s_[select_id] / 512;
460 std::size_t end = (select1s_[select_id + 1] + 511) / 512;
461 if (begin + 10 >= end) {
462 while (i >= ranks_[begin + 1].abs()) {
463 ++begin;
464 }
465 } else {
466 while (begin + 1 < end) {
467 const std::size_t middle = (begin + end) / 2;
468 if (i < ranks_[middle].abs()) {
469 end = middle;
470 } else {
471 begin = middle;
472 }
473 }
474 }
475 const std::size_t rank_id = begin;
476 i -= ranks_[rank_id].abs();
477
478 const RankIndex &rank = ranks_[rank_id];
479 std::size_t unit_id = rank_id * 8;
480 if (i < rank.rel4()) {
481 if (i < rank.rel2()) {
482 if (i >= rank.rel1()) {
483 unit_id += 1;
484 i -= rank.rel1();
485 }
486 } else if (i < rank.rel3()) {
487 unit_id += 2;
488 i -= rank.rel2();
489 } else {
490 unit_id += 3;
491 i -= rank.rel3();
492 }
493 } else if (i < rank.rel6()) {
494 if (i < rank.rel5()) {
495 unit_id += 4;
496 i -= rank.rel4();
497 } else {
498 unit_id += 5;
499 i -= rank.rel5();
500 }
501 } else if (i < rank.rel7()) {
502 unit_id += 6;
503 i -= rank.rel6();
504 } else {
505 unit_id += 7;
506 i -= rank.rel7();
507 }
508
509 return select_bit(i, unit_id * 64, units_[unit_id]);
510 }
511
512 #else // MARISA_WORD_SIZE == 64
513
rank1(std::size_t i) const514 std::size_t BitVector::rank1(std::size_t i) const {
515 MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR);
516 MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR);
517
518 const RankIndex &rank = ranks_[i / 512];
519 std::size_t offset = rank.abs();
520 switch ((i / 64) % 8) {
521 case 1: {
522 offset += rank.rel1();
523 break;
524 }
525 case 2: {
526 offset += rank.rel2();
527 break;
528 }
529 case 3: {
530 offset += rank.rel3();
531 break;
532 }
533 case 4: {
534 offset += rank.rel4();
535 break;
536 }
537 case 5: {
538 offset += rank.rel5();
539 break;
540 }
541 case 6: {
542 offset += rank.rel6();
543 break;
544 }
545 case 7: {
546 offset += rank.rel7();
547 break;
548 }
549 }
550 if (((i / 32) & 1) == 1) {
551 offset += PopCount::count(units_[(i / 32) - 1]);
552 }
553 offset += PopCount::count(units_[i / 32] & ((1U << (i % 32)) - 1));
554 return offset;
555 }
556
select0(std::size_t i) const557 std::size_t BitVector::select0(std::size_t i) const {
558 MARISA_DEBUG_IF(select0s_.empty(), MARISA_STATE_ERROR);
559 MARISA_DEBUG_IF(i >= num_0s(), MARISA_BOUND_ERROR);
560
561 const std::size_t select_id = i / 512;
562 MARISA_DEBUG_IF((select_id + 1) >= select0s_.size(), MARISA_BOUND_ERROR);
563 if ((i % 512) == 0) {
564 return select0s_[select_id];
565 }
566 std::size_t begin = select0s_[select_id] / 512;
567 std::size_t end = (select0s_[select_id + 1] + 511) / 512;
568 if (begin + 10 >= end) {
569 while (i >= ((begin + 1) * 512) - ranks_[begin + 1].abs()) {
570 ++begin;
571 }
572 } else {
573 while (begin + 1 < end) {
574 const std::size_t middle = (begin + end) / 2;
575 if (i < (middle * 512) - ranks_[middle].abs()) {
576 end = middle;
577 } else {
578 begin = middle;
579 }
580 }
581 }
582 const std::size_t rank_id = begin;
583 i -= (rank_id * 512) - ranks_[rank_id].abs();
584
585 const RankIndex &rank = ranks_[rank_id];
586 std::size_t unit_id = rank_id * 16;
587 if (i < (256U - rank.rel4())) {
588 if (i < (128U - rank.rel2())) {
589 if (i >= (64U - rank.rel1())) {
590 unit_id += 2;
591 i -= 64 - rank.rel1();
592 }
593 } else if (i < (192U - rank.rel3())) {
594 unit_id += 4;
595 i -= 128 - rank.rel2();
596 } else {
597 unit_id += 6;
598 i -= 192 - rank.rel3();
599 }
600 } else if (i < (384U - rank.rel6())) {
601 if (i < (320U - rank.rel5())) {
602 unit_id += 8;
603 i -= 256 - rank.rel4();
604 } else {
605 unit_id += 10;
606 i -= 320 - rank.rel5();
607 }
608 } else if (i < (448U - rank.rel7())) {
609 unit_id += 12;
610 i -= 384 - rank.rel6();
611 } else {
612 unit_id += 14;
613 i -= 448 - rank.rel7();
614 }
615
616 #ifdef MARISA_USE_SSE2
617 return select_bit(i, unit_id * 32, ~units_[unit_id], ~units_[unit_id + 1]);
618 #else // MARISA_USE_SSE2
619 UInt32 unit = ~units_[unit_id];
620 PopCount count(unit);
621 if (i >= count.lo32()) {
622 ++unit_id;
623 i -= count.lo32();
624 unit = ~units_[unit_id];
625 count = PopCount(unit);
626 }
627
628 std::size_t bit_id = unit_id * 32;
629 if (i < count.lo16()) {
630 if (i >= count.lo8()) {
631 bit_id += 8;
632 unit >>= 8;
633 i -= count.lo8();
634 }
635 } else if (i < count.lo24()) {
636 bit_id += 16;
637 unit >>= 16;
638 i -= count.lo16();
639 } else {
640 bit_id += 24;
641 unit >>= 24;
642 i -= count.lo24();
643 }
644 return bit_id + SELECT_TABLE[i][unit & 0xFF];
645 #endif // MARISA_USE_SSE2
646 }
647
select1(std::size_t i) const648 std::size_t BitVector::select1(std::size_t i) const {
649 MARISA_DEBUG_IF(select1s_.empty(), MARISA_STATE_ERROR);
650 MARISA_DEBUG_IF(i >= num_1s(), MARISA_BOUND_ERROR);
651
652 const std::size_t select_id = i / 512;
653 MARISA_DEBUG_IF((select_id + 1) >= select1s_.size(), MARISA_BOUND_ERROR);
654 if ((i % 512) == 0) {
655 return select1s_[select_id];
656 }
657 std::size_t begin = select1s_[select_id] / 512;
658 std::size_t end = (select1s_[select_id + 1] + 511) / 512;
659 if (begin + 10 >= end) {
660 while (i >= ranks_[begin + 1].abs()) {
661 ++begin;
662 }
663 } else {
664 while (begin + 1 < end) {
665 const std::size_t middle = (begin + end) / 2;
666 if (i < ranks_[middle].abs()) {
667 end = middle;
668 } else {
669 begin = middle;
670 }
671 }
672 }
673 const std::size_t rank_id = begin;
674 i -= ranks_[rank_id].abs();
675
676 const RankIndex &rank = ranks_[rank_id];
677 std::size_t unit_id = rank_id * 16;
678 if (i < rank.rel4()) {
679 if (i < rank.rel2()) {
680 if (i >= rank.rel1()) {
681 unit_id += 2;
682 i -= rank.rel1();
683 }
684 } else if (i < rank.rel3()) {
685 unit_id += 4;
686 i -= rank.rel2();
687 } else {
688 unit_id += 6;
689 i -= rank.rel3();
690 }
691 } else if (i < rank.rel6()) {
692 if (i < rank.rel5()) {
693 unit_id += 8;
694 i -= rank.rel4();
695 } else {
696 unit_id += 10;
697 i -= rank.rel5();
698 }
699 } else if (i < rank.rel7()) {
700 unit_id += 12;
701 i -= rank.rel6();
702 } else {
703 unit_id += 14;
704 i -= rank.rel7();
705 }
706
707 #ifdef MARISA_USE_SSE2
708 return select_bit(i, unit_id * 32, units_[unit_id], units_[unit_id + 1]);
709 #else // MARISA_USE_SSE2
710 UInt32 unit = units_[unit_id];
711 PopCount count(unit);
712 if (i >= count.lo32()) {
713 ++unit_id;
714 i -= count.lo32();
715 unit = units_[unit_id];
716 count = PopCount(unit);
717 }
718
719 std::size_t bit_id = unit_id * 32;
720 if (i < count.lo16()) {
721 if (i >= count.lo8()) {
722 bit_id += 8;
723 unit >>= 8;
724 i -= count.lo8();
725 }
726 } else if (i < count.lo24()) {
727 bit_id += 16;
728 unit >>= 16;
729 i -= count.lo16();
730 } else {
731 bit_id += 24;
732 unit >>= 24;
733 i -= count.lo24();
734 }
735 return bit_id + SELECT_TABLE[i][unit & 0xFF];
736 #endif // MARISA_USE_SSE2
737 }
738
739 #endif // MARISA_WORD_SIZE == 64
740
build_index(const BitVector & bv,bool enables_select0,bool enables_select1)741 void BitVector::build_index(const BitVector &bv,
742 bool enables_select0, bool enables_select1) {
743 ranks_.resize((bv.size() / 512) + (((bv.size() % 512) != 0) ? 1 : 0) + 1);
744
745 std::size_t num_0s = 0;
746 std::size_t num_1s = 0;
747
748 for (std::size_t i = 0; i < bv.size(); ++i) {
749 if ((i % 64) == 0) {
750 const std::size_t rank_id = i / 512;
751 switch ((i / 64) % 8) {
752 case 0: {
753 ranks_[rank_id].set_abs(num_1s);
754 break;
755 }
756 case 1: {
757 ranks_[rank_id].set_rel1(num_1s - ranks_[rank_id].abs());
758 break;
759 }
760 case 2: {
761 ranks_[rank_id].set_rel2(num_1s - ranks_[rank_id].abs());
762 break;
763 }
764 case 3: {
765 ranks_[rank_id].set_rel3(num_1s - ranks_[rank_id].abs());
766 break;
767 }
768 case 4: {
769 ranks_[rank_id].set_rel4(num_1s - ranks_[rank_id].abs());
770 break;
771 }
772 case 5: {
773 ranks_[rank_id].set_rel5(num_1s - ranks_[rank_id].abs());
774 break;
775 }
776 case 6: {
777 ranks_[rank_id].set_rel6(num_1s - ranks_[rank_id].abs());
778 break;
779 }
780 case 7: {
781 ranks_[rank_id].set_rel7(num_1s - ranks_[rank_id].abs());
782 break;
783 }
784 }
785 }
786
787 if (bv[i]) {
788 if (enables_select1 && ((num_1s % 512) == 0)) {
789 select1s_.push_back(static_cast<UInt32>(i));
790 }
791 ++num_1s;
792 } else {
793 if (enables_select0 && ((num_0s % 512) == 0)) {
794 select0s_.push_back(static_cast<UInt32>(i));
795 }
796 ++num_0s;
797 }
798 }
799
800 if ((bv.size() % 512) != 0) {
801 const std::size_t rank_id = (bv.size() - 1) / 512;
802 switch (((bv.size() - 1) / 64) % 8) {
803 case 0: {
804 ranks_[rank_id].set_rel1(num_1s - ranks_[rank_id].abs());
805 } // fall through
806 case 1: {
807 ranks_[rank_id].set_rel2(num_1s - ranks_[rank_id].abs());
808 } // fall through
809 case 2: {
810 ranks_[rank_id].set_rel3(num_1s - ranks_[rank_id].abs());
811 } // fall through
812 case 3: {
813 ranks_[rank_id].set_rel4(num_1s - ranks_[rank_id].abs());
814 } // fall through
815 case 4: {
816 ranks_[rank_id].set_rel5(num_1s - ranks_[rank_id].abs());
817 } // fall through
818 case 5: {
819 ranks_[rank_id].set_rel6(num_1s - ranks_[rank_id].abs());
820 } // fall through
821 case 6: {
822 ranks_[rank_id].set_rel7(num_1s - ranks_[rank_id].abs());
823 break;
824 }
825 }
826 }
827
828 size_ = bv.size();
829 num_1s_ = bv.num_1s();
830
831 ranks_.back().set_abs(num_1s);
832 if (enables_select0) {
833 select0s_.push_back(static_cast<UInt32>(bv.size()));
834 select0s_.shrink();
835 }
836 if (enables_select1) {
837 select1s_.push_back(static_cast<UInt32>(bv.size()));
838 select1s_.shrink();
839 }
840 }
841
842 } // namespace vector
843 } // namespace grimoire
844 } // namespace marisa
845