/// A trait for describing vector operations used by vectorized searchers.
///
/// The trait is highly constrained to the low level vector operations needed.
/// In general, it was invented mostly to be generic over x86's __m128i and
/// __m256i types. At time of writing, it also supports wasm and aarch64
/// 128-bit vector types.
///
/// # Safety
///
/// All methods are unsafe since they are intended to be implemented using
/// vendor intrinsics, which are themselves unsafe. Callers must ensure that
/// the appropriate target features are enabled in the calling function, and
/// that the current CPU supports them. All implementations should avoid
/// marking the routines with #[target_feature] and instead mark them as
/// #[inline(always)] to ensure they get appropriately inlined. (inline(always)
/// cannot be used with target_feature.)
pub(crate) trait Vector: Copy + core::fmt::Debug {
    /// The number of bits in the vector.
    const BITS: usize;
    /// The number of bytes in the vector. That is, this is the size of the
    /// vector in memory.
    const BYTES: usize;
    /// The bits that must be zero in order for a `*const u8` pointer to be
    /// correctly aligned to read vector values.
    const ALIGN: usize;

    /// The type of the value returned by `Vector::movemask`.
    ///
    /// This supports abstracting over the specific representation used in
    /// order to accommodate different representations in different ISAs.
    type Mask: MoveMask;

    /// Create a vector with 8-bit lanes with the given byte repeated into
    /// each lane.
    unsafe fn splat(byte: u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// must be aligned to the size of the vector.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data` and that `data` is aligned to a `BYTES` boundary.
    unsafe fn load_aligned(data: *const u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// does not need to be aligned.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data`.
    unsafe fn load_unaligned(data: *const u8) -> Self;

    /// _mm_movemask_epi8 or _mm256_movemask_epi8
    unsafe fn movemask(self) -> Self::Mask;
    /// _mm_cmpeq_epi8 or _mm256_cmpeq_epi8
    unsafe fn cmpeq(self, vector2: Self) -> Self;
    /// _mm_and_si128 or _mm256_and_si256
    unsafe fn and(self, vector2: Self) -> Self;
    /// _mm_or_si128 or _mm256_or_si256
    unsafe fn or(self, vector2: Self) -> Self;
    /// Returns true if and only if `Self::movemask` would return a mask that
    /// contains at least one non-zero bit.
    unsafe fn movemask_will_have_non_zero(self) -> bool {
        self.movemask().has_non_zero()
    }
}
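
// NOTE: The function below is purely illustrative and is not used by the
// searchers in this crate. It sketches how a vectorized searcher might drive
// `Vector` and `MoveMask` together: splat the needle into every lane, compare
// one vector-width chunk at a time and use the resulting mask to recover the
// offset of the first match. The name `example_find_first` and the (omitted)
// handling of the final partial chunk are hypothetical.
//
// SAFETY: Callers must ensure that the target features required by `V` are
// enabled in the calling function and supported by the current CPU.
#[allow(dead_code)]
#[inline(always)]
unsafe fn example_find_first<V: Vector>(
    needle: u8,
    haystack: &[u8],
) -> Option<usize> {
    let vn = V::splat(needle);
    let mut at = 0;
    while at + V::BYTES <= haystack.len() {
        let chunk = V::load_unaligned(haystack.as_ptr().add(at));
        let mask = chunk.cmpeq(vn).movemask();
        if mask.has_non_zero() {
            return Some(at + mask.first_offset());
        }
        at += V::BYTES;
    }
    // A real searcher would also handle the final partial chunk here, e.g.
    // via an overlapping unaligned load or a scalar fallback.
    None
}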

/// A trait that abstracts over a vector-to-scalar operation called
/// "move mask."
///
/// On x86-64, this is `_mm_movemask_epi8` for SSE2 and `_mm256_movemask_epi8`
/// for AVX2. It takes a vector of `u8` lanes and returns a scalar where the
/// `i`th bit is set if and only if the most significant bit in the `i`th lane
/// of the vector is set. The simd128 ISA for wasm32 also supports this exact
/// same operation natively.
///
/// ... But aarch64 doesn't. So we have to fake it with more instructions and
/// a slightly different representation. We could do extra work to unify the
/// representations, but that would require additional costs in the hot path
/// for `memchr` and `packedpair`. So instead, we abstract over the specific
/// representation with this trait and define the operations we actually need.
pub(crate) trait MoveMask: Copy + core::fmt::Debug {
    /// Return a mask that is all zeros except for the least significant `n`
    /// lanes in a corresponding vector.
    fn all_zeros_except_least_significant(n: usize) -> Self;

    /// Returns true if and only if this mask has a non-zero bit anywhere.
    fn has_non_zero(self) -> bool;

    /// Returns the number of bits set to 1 in this mask.
    fn count_ones(self) -> usize;

    /// Does a bitwise `and` operation between `self` and `other`.
    fn and(self, other: Self) -> Self;

    /// Does a bitwise `or` operation between `self` and `other`.
    fn or(self, other: Self) -> Self;

    /// Returns a mask that is equivalent to `self` but with the least
    /// significant 1-bit set to 0.
    fn clear_least_significant_bit(self) -> Self;

    /// Returns the offset of the first non-zero lane this mask represents.
    fn first_offset(self) -> usize;

    /// Returns the offset of the last non-zero lane this mask represents.
    fn last_offset(self) -> usize;
}

/// This is a "sensible" movemask implementation where each bit represents
/// whether the most significant bit is set in each corresponding lane of a
/// vector. This is used on x86-64 and wasm, but such a mask is more expensive
/// to get on aarch64 so we use something a little different.
///
/// We call this "sensible" because this is what we get using native sse/avx
/// movemask instructions. But neon has no such native equivalent.
#[derive(Clone, Copy, Debug)]
pub(crate) struct SensibleMoveMask(u32);
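
// For example (on little endian, which `get_for_offset` below normalizes
// to): if only lanes 1 and 4 of a 16-lane vector have their most significant
// bits set, the mask is 0b10010, so `first_offset` is 1 (the number of
// trailing zeros) and `last_offset` is 4 (31 minus the number of leading
// zeros).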

impl SensibleMoveMask {
    /// Get the mask in a form suitable for computing offsets.
    ///
    /// Basically, this normalizes to little endian. On big endian, this swaps
    /// the bytes.
    #[inline(always)]
    fn get_for_offset(self) -> u32 {
        #[cfg(target_endian = "big")]
        {
            self.0.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            self.0
        }
    }
}

impl MoveMask for SensibleMoveMask {
    #[inline(always)]
    fn all_zeros_except_least_significant(n: usize) -> SensibleMoveMask {
        debug_assert!(n < 32);
        SensibleMoveMask(!((1 << n) - 1))
    }

    #[inline(always)]
    fn has_non_zero(self) -> bool {
        self.0 != 0
    }

    #[inline(always)]
    fn count_ones(self) -> usize {
        self.0.count_ones() as usize
    }

    #[inline(always)]
    fn and(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & other.0)
    }

    #[inline(always)]
    fn or(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 | other.0)
    }

    #[inline(always)]
    fn clear_least_significant_bit(self) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & (self.0 - 1))
    }

    #[inline(always)]
    fn first_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte
        // is at a higher address. That means the least significant bit that
        // is set corresponds to the position of our first matching byte.
        // That position corresponds to the number of zeros after the least
        // significant bit.
        self.get_for_offset().trailing_zeros() as usize
    }

    #[inline(always)]
    fn last_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte is
        // at a higher address. That means the most significant bit that is set
        // corresponds to the position of our last matching byte. The position
        // from the end of the mask is therefore the number of leading zeros
        // in a 32 bit integer, and the position from the start of the mask is
        // therefore 32 - (leading zeros) - 1.
        32 - self.get_for_offset().leading_zeros() as usize - 1
    }
}

#[cfg(target_arch = "x86_64")]
mod x86sse2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m128i {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m128i {
            _mm_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m128i {
            _mm_load_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m128i {
            _mm_loadu_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m128i {
            _mm_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m128i {
            _mm_and_si128(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m128i {
            _mm_or_si128(self, vector2)
        }
    }
}

#[cfg(target_arch = "x86_64")]
mod x86avx2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m256i {
        const BITS: usize = 256;
        const BYTES: usize = 32;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m256i {
            _mm256_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m256i {
            _mm256_load_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m256i {
            _mm256_loadu_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm256_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m256i {
            _mm256_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m256i {
            _mm256_and_si256(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m256i {
            _mm256_or_si256(self, vector2)
        }
    }
}
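
// The test below is an illustrative sketch rather than part of the crate's
// test suite. It exercises the SSE2 `Vector` implementation above end to end:
// load a chunk, compare it against a splatted needle and inspect the
// resulting "sensible" mask. SSE2 is part of the x86-64 baseline, so no
// runtime feature detection is needed. The module and test names are
// hypothetical.
#[cfg(all(test, target_arch = "x86_64"))]
mod example_tests_x86 {
    use core::arch::x86_64::__m128i;

    use super::{MoveMask, Vector};

    #[test]
    fn sse2_movemask_reports_match_offsets() {
        // 16 bytes, with the needle `z` in lanes 8 and 15.
        let haystack = b"xxxxxxxxzxxxxxxz";
        // SAFETY: SSE2 is always available on x86-64 and the haystack is
        // exactly one vector (16 bytes) long.
        unsafe {
            let chunk = __m128i::load_unaligned(haystack.as_ptr());
            let needle = __m128i::splat(b'z');
            let mask = chunk.cmpeq(needle).movemask();
            assert!(mask.has_non_zero());
            assert_eq!(2, mask.count_ones());
            assert_eq!(8, mask.first_offset());
            assert_eq!(15, mask.last_offset());
        }
    }
}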

#[cfg(target_arch = "aarch64")]
mod aarch64neon {
    use core::arch::aarch64::*;

    use super::{MoveMask, Vector};

    impl Vector for uint8x16_t {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = NeonMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> uint8x16_t {
            vdupq_n_u8(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> uint8x16_t {
            // I've tried `data.cast::<uint8x16_t>().read()` instead, but
            // couldn't observe any benchmark differences.
            Self::load_unaligned(data)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> uint8x16_t {
            vld1q_u8(data)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> NeonMoveMask {
            let asu16s = vreinterpretq_u16_u8(self);
            let mask = vshrn_n_u16(asu16s, 4);
            let asu64 = vreinterpret_u64_u8(mask);
            let scalar64 = vget_lane_u64(asu64, 0);
            NeonMoveMask(scalar64 & 0x8888888888888888)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> uint8x16_t {
            vceqq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> uint8x16_t {
            vandq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> uint8x16_t {
            vorrq_u8(self, vector2)
        }

        /// This is the only interesting implementation of this routine.
        /// Basically, instead of doing the "shift right narrow" dance, we use
        /// adjacent folding max to determine whether there are any non-zero
        /// bytes in our mask. If there are, *then* we'll do the "shift right
        /// narrow" dance. In benchmarks, this does lead to slightly better
        /// throughput, but the win doesn't appear huge.
        #[inline(always)]
        unsafe fn movemask_will_have_non_zero(self) -> bool {
            let low = vreinterpretq_u64_u8(vpmaxq_u8(self, self));
            vgetq_lane_u64(low, 0) != 0
        }
    }

    /// Neon doesn't have a `movemask` that works like the one in x86-64, so we
    /// wind up using a different method[1]. The different method also produces
    /// a mask, but 4 bits are set in the neon case instead of a single bit set
    /// in the x86-64 case. We do an extra step to zero out 3 of the 4 bits,
    /// but we still wind up with at least 3 zeroes between each set bit. This
    /// generally means that we need to do some division by 4 before extracting
    /// offsets.
    ///
    /// In fact, the existence of this type is the entire reason that we have
    /// the `MoveMask` trait in the first place. This basically lets us keep
    /// the different representations of masks without being forced to unify
    /// them into a single representation, which could result in extra and
    /// unnecessary work.
    ///
    /// [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
    #[derive(Clone, Copy, Debug)]
    pub(crate) struct NeonMoveMask(u64);

    impl NeonMoveMask {
        /// Get the mask in a form suitable for computing offsets.
        ///
        /// Basically, this normalizes to little endian. On big endian, this
        /// swaps the bytes.
        #[inline(always)]
        fn get_for_offset(self) -> u64 {
            #[cfg(target_endian = "big")]
            {
                self.0.swap_bytes()
            }
            #[cfg(target_endian = "little")]
            {
                self.0
            }
        }
    }

    impl MoveMask for NeonMoveMask {
        #[inline(always)]
        fn all_zeros_except_least_significant(n: usize) -> NeonMoveMask {
            debug_assert!(n < 16);
            NeonMoveMask(!(((1 << n) << 2) - 1))
        }

        #[inline(always)]
        fn has_non_zero(self) -> bool {
            self.0 != 0
        }

        #[inline(always)]
        fn count_ones(self) -> usize {
            self.0.count_ones() as usize
        }

        #[inline(always)]
        fn and(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 & other.0)
        }

        #[inline(always)]
        fn or(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 | other.0)
        }

        #[inline(always)]
        fn clear_least_significant_bit(self) -> NeonMoveMask {
            NeonMoveMask(self.0 & (self.0 - 1))
        }

        #[inline(always)]
        fn first_offset(self) -> usize {
            // We are dealing with little endian here (and if we aren't,
            // we swap the bytes so we are in practice), where the most
            // significant byte is at a higher address. That means the least
            // significant bit that is set corresponds to the position of our
            // first matching byte. That position corresponds to the number of
            // zeros after the least significant bit.
            //
            // Note that unlike `SensibleMoveMask`, this mask has its bits
            // spread out over 64 bits instead of 16 bits (for a 128 bit
            // vector). Namely, whereas x86-64 will turn
            //
            // 0x00 0xFF 0x00 0x00 0xFF
            //
            // into 10010, our neon approach will turn it into
            //
            // 10000000000010000000
            //
            // And this happens because neon doesn't have a native `movemask`
            // instruction, so we kind of fake it[1]. Thus, we divide the
            // number of trailing zeros by 4 to get the "real" offset.
            //
            // [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
            (self.get_for_offset().trailing_zeros() >> 2) as usize
        }

        #[inline(always)]
        fn last_offset(self) -> usize {
            // See comment in `first_offset` above. This is basically the same,
            // but coming from the other direction.
            16 - (self.get_for_offset().leading_zeros() >> 2) as usize - 1
        }
    }
}

#[cfg(target_arch = "wasm32")]
mod wasm_simd128 {
    use core::arch::wasm32::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for v128 {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> v128 {
            u8x16_splat(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> v128 {
            *data.cast()
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> v128 {
            v128_load(data.cast())
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(u8x16_bitmask(self).into())
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> v128 {
            u8x16_eq(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> v128 {
            v128_and(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> v128 {
            v128_or(self, vector2)
        }
    }
}
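
// The tests below are illustrative sketches rather than part of the crate's
// test suite. The first checks the offset arithmetic of `SensibleMoveMask`
// directly on a hand-built mask (and is therefore restricted to little
// endian, where no byte swapping happens). The second exercises the neon
// implementation above end to end; neon is enabled by default on standard
// aarch64 targets. The module and test names are hypothetical.
#[cfg(all(test, target_endian = "little"))]
mod example_tests_movemask {
    use super::{MoveMask, SensibleMoveMask};

    #[test]
    fn sensible_movemask_offsets() {
        // Lanes 1 and 4 match, i.e. bits 1 and 4 are set.
        let mask = SensibleMoveMask(0b10010);
        assert!(mask.has_non_zero());
        assert_eq!(2, mask.count_ones());
        assert_eq!(1, mask.first_offset());
        assert_eq!(4, mask.last_offset());
        // Clearing the least significant set bit leaves only lane 4.
        assert_eq!(4, mask.clear_least_significant_bit().first_offset());
    }
}

#[cfg(all(test, target_arch = "aarch64"))]
mod example_tests_neon {
    use core::arch::aarch64::uint8x16_t;

    use super::{MoveMask, Vector};

    #[test]
    fn neon_movemask_reports_match_offsets() {
        // 16 bytes, with the needle `z` in lanes 8 and 15.
        let haystack = b"xxxxxxxxzxxxxxxz";
        // SAFETY: neon is available by default on standard aarch64 targets
        // and the haystack is exactly one vector (16 bytes) long.
        unsafe {
            let chunk = uint8x16_t::load_unaligned(haystack.as_ptr());
            let needle = uint8x16_t::splat(b'z');
            let eqs = chunk.cmpeq(needle);
            assert!(eqs.movemask_will_have_non_zero());
            let mask = eqs.movemask();
            // Even though the neon mask spreads each lane over 4 bits, the
            // offsets it reports are plain lane indices.
            assert_eq!(8, mask.first_offset());
            assert_eq!(15, mask.last_offset());
        }
    }
}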