1 #ifndef AVX512_FUNCS_H
2 #define AVX512_FUNCS_H
3 
4 #include <immintrin.h>
5 #include <stdint.h>
/* Horizontal sum of all sixteen 32-bit lanes of x, wrapping modulo 2^32.
 *
 * Hand-written because the compiler-provided reduction (presumably
 * _mm512_reduce_add_epi32) sets off UBSan. Every addition below is a vector
 * lane add (vpaddd), which wraps, so no C-level signed arithmetic is done.
 *
 * NOTE(review): identifiers beginning with an underscore are reserved for the
 * implementation at file scope, and this name mimics the intrinsic naming
 * scheme; it is kept unchanged so existing callers keep working.
 */
static inline uint32_t _mm512_reduce_add_epu32(__m512i x) {
    /* 512 -> 256 bits: add the upper and lower halves lane-wise. */
    __m256i hi256 = _mm512_extracti64x4_epi64(x, 1);
    __m256i lo256 = _mm512_extracti64x4_epi64(x, 0);
    __m256i sum256 = _mm256_add_epi32(hi256, lo256);

    /* 256 -> 128 bits. */
    __m128i hi128 = _mm256_extracti128_si256(sum256, 1);
    __m128i lo128 = _mm256_extracti128_si256(sum256, 0);
    __m128i sum128 = _mm_add_epi32(hi128, lo128);

    /* 4 lanes -> 2: fold lanes 2,3 onto lanes 0,1. */
    __m128i sum64 = _mm_add_epi32(_mm_unpackhi_epi64(sum128, sum128), sum128);
    /* 2 lanes -> 1: shuffle imm 0x01 moves lane 1 into lane 0. */
    __m128i sum32 = _mm_add_epi32(sum64, _mm_shuffle_epi32(sum64, 0x01));

    /* _mm_cvtsi128_si32 yields int; convert explicitly (well-defined modulo
     * 2^32) to match the unsigned return type without a sign-conversion warning. */
    return (uint32_t)_mm_cvtsi128_si32(sum32);
}
23 
/* Sum the even-indexed 32-bit lanes of x (lanes 0, 2, 4, ..., 14), wrapping
 * modulo 2^32. The odd-indexed lanes do not contribute to the result.
 *
 * Method: a single vpermd gathers the eight even lanes into the low 256 bits,
 * the upper 256 bits are dropped, and the remaining eight lanes are folded
 * with add/unpack/shuffle steps. Per the original author's reading of Agner
 * Fog's vectorclass, this beats phadd, which is comparatively slow.
 */
static inline uint32_t partial_hsum(__m512i x) {
    /* Indices 0..7 pick the even lanes. Indices 8..15 merely need to be valid
     * lane numbers: the lanes they fill (copies of lane 1, not zeros) are
     * thrown away by the 512->256 cast below. */
    const __m512i even_idx = _mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14,
                                               1, 1, 1, 1, 1, 1, 1, 1);
    __m256i evens = _mm512_castsi512_si256(_mm512_permutexvar_epi32(even_idx, x));

    /* 8 lanes -> 4 */
    __m128i quad = _mm_add_epi32(_mm256_castsi256_si128(evens),
                                 _mm256_extracti128_si256(evens, 1));
    /* 4 lanes -> 2 */
    __m128i pair = _mm_add_epi32(_mm_unpackhi_epi64(quad, quad), quad);
    /* 2 lanes -> 1, result in lane 0 */
    __m128i single = _mm_add_epi32(_mm_shuffle_epi32(pair, 1), pair);

    return (uint32_t)_mm_cvtsi128_si32(single);
}
45 
46 #endif
47