/* Copyright 2020 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/pack_x86.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX && RUY_OPT(INTRINSICS)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))

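// Stub definitions, used when the optimized AVX path is compiled out.
// CPU-ID-based checks are expected to keep these from ever being called.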
void Pack8bitColMajorForAvx(const std::int8_t*, std::int8_t, const std::int8_t*,
                            int, int, int, std::int8_t*, std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void PackFloatColMajorForAvx(const float*, const float*, int, int, int,
                             float*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Pack8bitRowMajorForAvx(const std::uint8_t*, int, int, std::int8_t*, int,
                            int, int, int, int, int, int, std::int32_t*) {
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX && RUY_OPT(ASM)

// The first int8_t template parameter is arbitrary: this routine is common to
// all 8-bit source matrix types.
using PackImpl8bitAvx =
    PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, std::int8_t,
             std::int8_t, std::int32_t, Order::kColMajor>;

using PackImplFloatAvx =
    PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
             float, float, Order::kColMajor>;

namespace {

// Perform the equivalent of mm256_permutevar8x32 with
// a second argument of {7, 5, 3, 1, 6, 4, 2, 0}
inline __m256i PermuteEpi32EvenOdds(const __m256i& a) {
  // a_lo = 3 2 1 0
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  // a_hi = 7 6 5 4
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  // shuffle a_lo to get 3 1 2 0
  __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8);
  // shuffle a_hi to get 7 5 6 4
  __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8);
  // unpack lo 64 of tmp_lo and tmp_hi to get 6 4 2 0
  __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi);
  // unpack hi 64 of tmp_lo and tmp_hi to get 7 5 3 1
  __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi);
  return _mm256_set_m128i(res_hi, res_lo);
}

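// The helpers below emulate 256-bit integer intrinsics that require AVX2
// (extract, unpack, add, madd, xor and permute variants) using only AVX1:
// each 256-bit operand is split into two 128-bit lanes, the corresponding
// SSE operation is applied per lane, and the halves are recombined with
// _mm256_set_m128i.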
inline __m128i mm256_extracti128_si256(const __m256i& a, const int imm) {
  switch (imm) {
    case 0:
      return _mm256_extractf128_si256(a, 0);
    case 1:
      return _mm256_extractf128_si256(a, 1);
    default:
      RUY_DCHECK_LT(imm, 2);
      return _mm_setzero_si128();
  }
}

inline __m256i mm256_cvtepi8_epi16(const __m128i& a) {
  // Take the upper 64 bits of a and put them in the lower 64 bits of 'hi'.
  __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
  return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a));
}

inline __m256i mm256_cvtepi16_epi32(const __m128i& a) {
  // Take the upper 64 bits of a and put them in the lower 64 bits of 'hi'.
  __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
  return _mm256_set_m128i(_mm_cvtepi16_epi32(hi), _mm_cvtepi16_epi32(a));
}

inline __m256i mm256_xor_si256(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_xor_si128(a_lo, b_lo);
  __m128i hi = _mm_xor_si128(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_unpacklo_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_unpacklo_epi32(a_lo, b_lo);
  __m128i hi = _mm_unpacklo_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_unpacklo_epi64(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_unpacklo_epi64(a_lo, b_lo);
  __m128i hi = _mm_unpacklo_epi64(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_unpackhi_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_unpackhi_epi32(a_lo, b_lo);
  __m128i hi = _mm_unpackhi_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_unpackhi_epi64(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_unpackhi_epi64(a_lo, b_lo);
  __m128i hi = _mm_unpackhi_epi64(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_add_epi32(a_lo, b_lo);
  __m128i hi = _mm_add_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_add_epi16(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_add_epi16(a_lo, b_lo);
  __m128i hi = _mm_add_epi16(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_madd_epi16(a_lo, b_lo);
  __m128i hi = _mm_madd_epi16(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

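// mm_permute_helper and mm256_permute2x128_si256 emulate the AVX2
// _mm256_permute2x128_si256 intrinsic. For each 128-bit half of the result,
// the corresponding imm nibble selects a source lane: bits [1:0] pick one of
// {a.lo, a.hi, b.lo, b.hi} and bit 3 forces zero. For example, imm = 0x20
// yields (lo = a.lo, hi = b.lo) and imm = 0x31 yields (lo = a.hi, hi = b.hi),
// the two combinations used by the packers below.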
inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b,
                                 const int imm) {
  __m128i tmp = _mm_setzero_si128();
  if (!(imm & 8)) {
    switch (imm & 3) {
      case 0:
        return _mm256_extractf128_si256(a, 0);
      case 1:
        return _mm256_extractf128_si256(a, 1);
      case 2:
        return _mm256_extractf128_si256(b, 0);
      case 3:
        return _mm256_extractf128_si256(b, 1);
    }
  }
  return tmp;
}

inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b,
                                        const int imm) {
  const int lo_imm = imm & 15;
  __m128i lo = mm_permute_helper(a, b, lo_imm);
  const int hi_imm = (imm >> 4) & 15;
  __m128i hi = mm_permute_helper(a, b, hi_imm);
  return _mm256_set_m128i(hi, lo);
}

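// Packs an 8-column block of a column-major 8-bit source into the 4x8
// column-major kernel layout. Each byte is XORed with input_xor (used to
// convert uint8 source data to int8), columns past remaining_src_cols are
// read from zerobuf, per-column sums are accumulated when sums_ptr is
// non-null, and the final partial row chunk is written to trailing_buf for
// the caller to copy into place.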
inline void Pack8bitColMajorForAvxPacker(const std::int8_t* src_ptr,
                                         std::int8_t input_xor,
                                         const std::int8_t* zerobuf,
                                         int src_stride, int remaining_src_cols,
                                         int src_rows, std::int8_t* packed_ptr,
                                         std::int32_t* sums_ptr,
                                         std::int8_t* trailing_buf) {
  using Layout = PackImpl8bitAvx::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 8);
  RUY_DCHECK_EQ(Layout::kRows, 4);
  // Each Layout::kRows-sized chunk is 4 elements that are contiguous in the
  // input and contiguous in the packed output. We process 8 of these chunks
  // at a time, padding short input chunks.
  constexpr int kNumRowChunks = 8;
  constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;

  const std::int8_t* src_ptr0 = src_ptr;
  const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
  const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
  const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
  const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
  const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
  const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
  const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = kNumChunkedSrcRows;
  std::int64_t src_inc1 = kNumChunkedSrcRows;
  std::int64_t src_inc2 = kNumChunkedSrcRows;
  std::int64_t src_inc3 = kNumChunkedSrcRows;
  std::int64_t src_inc4 = kNumChunkedSrcRows;
  std::int64_t src_inc5 = kNumChunkedSrcRows;
  std::int64_t src_inc6 = kNumChunkedSrcRows;
  std::int64_t src_inc7 = kNumChunkedSrcRows;
  // Handle cases where source does not have Layout::kCols (8) columns.
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  const std::int8_t zero_point = zerobuf[0];

  if (sums_ptr) {
    // i: Layout::kCols.
    for (int i = 0; i < 8; ++i) {
      sums_ptr[i] = 0;
    }
  }
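  // Column sums are first accumulated in 16-bit lanes over the eight 32-byte
  // registers of each chunk, then pairwise-widened to 32-bit with a madd
  // against ones_16bit and added into sums_4x2_32bit_lo/hi below;
  // sums_adjustment compensates for rows padded with the zero point in the
  // final partial chunk.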
  std::int32_t sums_adjustment = 0;
  const __m256i ones_16bit = _mm256_set1_epi16(1);
  __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0);
  __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0);

  // The overall packing effectively pads the source rows to
  // (src_rows + 31) & ~31, since kNumChunkedSrcRows (32) rows are processed
  // per iteration. When there is an incomplete destination block, it is
  // stored into trailing_buf instead of packed_ptr.
  for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) {
    // Available source rows. When fewer than kNumChunkedSrcRows remain, the
    // final, partial chunk is padded and written to the trailing buffer.
    const int available_src_rows = src_rows - k;
    // Effectively,
    // available rows = std::max(0, std::min(kNumChunkedSrcRows, src_rows - k));
    // treat each case separately.
    if (available_src_rows >= kNumChunkedSrcRows) {
      if (sums_ptr) {
        __m256i t0, t1, t2, t3, t4, t5, t6, t7;
        __m256i r0, r1, r2, r3, r4, r5, r6, r7;
        const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

        t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
        t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
        t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
        t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
        t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
        t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
        t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
        t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));

        r0 = mm256_unpacklo_epi32(t0, t1);
        r4 = mm256_unpacklo_epi32(t4, t5);
        r2 = mm256_unpackhi_epi32(t0, t1);
        r6 = mm256_unpackhi_epi32(t4, t5);
        r1 = mm256_unpacklo_epi32(t2, t3);
        r5 = mm256_unpacklo_epi32(t6, t7);
        r3 = mm256_unpackhi_epi32(t2, t3);
        r7 = mm256_unpackhi_epi32(t6, t7);

        t0 = mm256_unpacklo_epi64(r0, r1);
        t4 = mm256_unpacklo_epi64(r4, r5);
        t2 = mm256_unpackhi_epi64(r0, r1);
        t6 = mm256_unpackhi_epi64(r4, r5);
        t1 = mm256_unpacklo_epi64(r2, r3);
        t5 = mm256_unpacklo_epi64(r6, r7);
        t3 = mm256_unpackhi_epi64(r2, r3);
        t7 = mm256_unpackhi_epi64(r6, r7);

        // The preceding sets of rearrangement operations interleaved by 4 bytes
        // and then by 8 bytes *within* lanes. The following set interleaves by
        // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
        // t4) are interleaved to create (r0, r1). This complexity follows from
        // the way that AVX is centered around MM 128-bit lanes.
        r0 = mm256_permute2x128_si256(t0, t4, 0x20);
        r4 = mm256_permute2x128_si256(t1, t5, 0x20);
        r1 = mm256_permute2x128_si256(t0, t4, 0x31);
        r5 = mm256_permute2x128_si256(t1, t5, 0x31);
        r2 = mm256_permute2x128_si256(t2, t6, 0x20);
        r6 = mm256_permute2x128_si256(t3, t7, 0x20);
        r3 = mm256_permute2x128_si256(t2, t6, 0x31);
        r7 = mm256_permute2x128_si256(t3, t7, 0x31);

        r0 = mm256_xor_si256(r0, input_xor_v);
        r1 = mm256_xor_si256(r1, input_xor_v);
        r2 = mm256_xor_si256(r2, input_xor_v);
        r3 = mm256_xor_si256(r3, input_xor_v);
        r4 = mm256_xor_si256(r4, input_xor_v);
        r5 = mm256_xor_si256(r5, input_xor_v);
        r6 = mm256_xor_si256(r6, input_xor_v);
        r7 = mm256_xor_si256(r7, input_xor_v);

        __m256i sums_4x4_16bit_lo;
        sums_4x4_16bit_lo = mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));

        // The sums have been performed across columns, and now we have 4x16-bit
        // sums packed together. We use madd for pairwise 32-bit sums.
        const __m256i sums_4x2_32bit_lo_new =
            mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
        sums_4x2_32bit_lo =
            mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);

        __m256i sums_4x4_16bit_hi;
        sums_4x4_16bit_hi = mm256_cvtepi8_epi16(mm256_extracti128_si256(r0, 1));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r1, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r2, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r3, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r4, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r5, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r6, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r7, 1)));

        const __m256i sums_4x2_32bit_hi_new =
            mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
        sums_4x2_32bit_hi =
            mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);

        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
                            r0);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
                            r4);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
                            r1);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
                            r5);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
                            r2);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
                            r6);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
                            r3);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
                            r7);
      } else {
        __m256i t0, t1, t2, t3, t4, t5, t6, t7;
        __m256i r0, r1, r2, r3, r4, r5, r6, r7;
        const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

        t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
        t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
        t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
        t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
        t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
        t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
        t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
        t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));

        r0 = mm256_unpacklo_epi32(t0, t1);
        r4 = mm256_unpacklo_epi32(t4, t5);
        r2 = mm256_unpackhi_epi32(t0, t1);
        r6 = mm256_unpackhi_epi32(t4, t5);
        r1 = mm256_unpacklo_epi32(t2, t3);
        r5 = mm256_unpacklo_epi32(t6, t7);
        r3 = mm256_unpackhi_epi32(t2, t3);
        r7 = mm256_unpackhi_epi32(t6, t7);

        t0 = mm256_unpacklo_epi64(r0, r1);
        t4 = mm256_unpacklo_epi64(r4, r5);
        t2 = mm256_unpackhi_epi64(r0, r1);
        t6 = mm256_unpackhi_epi64(r4, r5);
        t1 = mm256_unpacklo_epi64(r2, r3);
        t5 = mm256_unpacklo_epi64(r6, r7);
        t3 = mm256_unpackhi_epi64(r2, r3);
        t7 = mm256_unpackhi_epi64(r6, r7);

        // The preceding sets of rearrangement operations interleaved by 4 bytes
        // and then by 8 bytes *within* lanes. The following set interleaves by
        // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
        // t4) are interleaved to create (r0, r1). This complexity follows from
        // the way that AVX is centered around MM 128-bit lanes.
        r0 = mm256_permute2x128_si256(t0, t4, 0x20);
        r4 = mm256_permute2x128_si256(t1, t5, 0x20);
        r1 = mm256_permute2x128_si256(t0, t4, 0x31);
        r5 = mm256_permute2x128_si256(t1, t5, 0x31);
        r2 = mm256_permute2x128_si256(t2, t6, 0x20);
        r6 = mm256_permute2x128_si256(t3, t7, 0x20);
        r3 = mm256_permute2x128_si256(t2, t6, 0x31);
        r7 = mm256_permute2x128_si256(t3, t7, 0x31);

        r0 = mm256_xor_si256(r0, input_xor_v);
        r1 = mm256_xor_si256(r1, input_xor_v);
        r2 = mm256_xor_si256(r2, input_xor_v);
        r3 = mm256_xor_si256(r3, input_xor_v);
        r4 = mm256_xor_si256(r4, input_xor_v);
        r5 = mm256_xor_si256(r5, input_xor_v);
        r6 = mm256_xor_si256(r6, input_xor_v);
        r7 = mm256_xor_si256(r7, input_xor_v);

        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
                            r0);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
                            r4);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
                            r1);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
                            r5);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
                            r2);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
                            r6);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
                            r3);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
                            r7);
      }
    } else if (available_src_rows > 0) {
      RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows);
      // We do not care what goes into the trailing buffer, but we want
      // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
      //
      // We compensate for padding-with-zero_point by initializing the
      // summations with the compensating offset, effectively
      // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
      //                         4 * (8 - ((available_src_rows + 3) >> 2)).
      //
      // Note that (zero_point ^ input_xor) is performed in 8-bits and then
      // cast.
      sums_adjustment +=
          -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2));

      __m256i t0, t1, t2, t3, t4, t5, t6, t7;
      __m256i r0, r1, r2, r3, r4, r5, r6, r7;
      const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

      t0 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr0);
      t4 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr4);
      t1 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr1);
      t5 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr5);
      t2 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr2);
      t6 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr6);
      t3 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr3);
      t7 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr7);

      r0 = mm256_unpacklo_epi32(t0, t1);
      r4 = mm256_unpacklo_epi32(t4, t5);
      r2 = mm256_unpackhi_epi32(t0, t1);
      r6 = mm256_unpackhi_epi32(t4, t5);
      r1 = mm256_unpacklo_epi32(t2, t3);
      r5 = mm256_unpacklo_epi32(t6, t7);
      r3 = mm256_unpackhi_epi32(t2, t3);
      r7 = mm256_unpackhi_epi32(t6, t7);

      t0 = mm256_unpacklo_epi64(r0, r1);
      t4 = mm256_unpacklo_epi64(r4, r5);
      t2 = mm256_unpackhi_epi64(r0, r1);
      t6 = mm256_unpackhi_epi64(r4, r5);
      t1 = mm256_unpacklo_epi64(r2, r3);
      t5 = mm256_unpacklo_epi64(r6, r7);
      t3 = mm256_unpackhi_epi64(r2, r3);
      t7 = mm256_unpackhi_epi64(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
      // t4) are interleaved to create (r0, r1). This complexity follows from
      // the way that AVX is centered around MM 128-bit lanes.
      r0 = mm256_permute2x128_si256(t0, t4, 0x20);
      r4 = mm256_permute2x128_si256(t1, t5, 0x20);
      r1 = mm256_permute2x128_si256(t0, t4, 0x31);
      r5 = mm256_permute2x128_si256(t1, t5, 0x31);
      r2 = mm256_permute2x128_si256(t2, t6, 0x20);
      r6 = mm256_permute2x128_si256(t3, t7, 0x20);
      r3 = mm256_permute2x128_si256(t2, t6, 0x31);
      r7 = mm256_permute2x128_si256(t3, t7, 0x31);

      r0 = mm256_xor_si256(r0, input_xor_v);
      r1 = mm256_xor_si256(r1, input_xor_v);
      r2 = mm256_xor_si256(r2, input_xor_v);
      r3 = mm256_xor_si256(r3, input_xor_v);
      r4 = mm256_xor_si256(r4, input_xor_v);
      r5 = mm256_xor_si256(r5, input_xor_v);
      r6 = mm256_xor_si256(r6, input_xor_v);
      r7 = mm256_xor_si256(r7, input_xor_v);

      __m256i sums_4x4_16bit_lo;
      sums_4x4_16bit_lo = mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));

      // The sums have been performed across columns, and now we have 4x16-bit
      // sums packed together. We use madd for pairwise 32-bit sums.
      const __m256i sums_4x2_32bit_lo_new =
          mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
      sums_4x2_32bit_lo =
          mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);

      __m256i sums_4x4_16bit_hi;
      sums_4x4_16bit_hi = mm256_cvtepi8_epi16(mm256_extracti128_si256(r0, 1));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r1, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r2, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r3, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r4, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r5, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r6, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r7, 1)));

      const __m256i sums_4x2_32bit_hi_new =
          mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
      sums_4x2_32bit_hi =
          mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);

      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4),
                          r0);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4),
                          r4);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4),
                          r1);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4),
                          r5);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4),
                          r2);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4),
                          r6);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4),
                          r3);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4),
                          r7);
    }

    packed_ptr += 8 * kNumChunkedSrcRows;
    src_ptr0 += src_inc0;
    src_ptr1 += src_inc1;
    src_ptr2 += src_inc2;
    src_ptr3 += src_inc3;
    src_ptr4 += src_inc4;
    src_ptr5 += src_inc5;
    src_ptr6 += src_inc6;
    src_ptr7 += src_inc7;
  }

  if (sums_ptr) {
    const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);

    __m256i sums =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));

    // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the
    // neighbours, finishing up by adding them to the stored accumulated sums.
    const __m256i sums_2x4_32bit_lo = PermuteEpi32EvenOdds(sums_4x2_32bit_lo);
    const __m256i sums_2x4_32bit_hi = PermuteEpi32EvenOdds(sums_4x2_32bit_hi);
    const __m256i sums_2x4_32bit_a =
        mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20);
    const __m256i sums_2x4_32bit_b =
        mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31);
    sums = mm256_add_epi32(sums, sums_adjustment_v);
    sums = mm256_add_epi32(sums, sums_2x4_32bit_a);
    sums = mm256_add_epi32(sums, sums_2x4_32bit_b);

    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
  }
}

// Use a generic AVX intrinsic for greater-than comparison: AVX1 has no
// 256-bit integer compare, so the operands are compared as floats. The
// predicate value 14 is _CMP_GT_OS (greater-than, ordered, signaling).
template <>
inline __m256i CompareGreaterThan<Path::kAvx>(const __m256i& a,
                                              const __m256i& b) {
  constexpr int kGreaterThanSignalling = 14;
  const __m256 v = _mm256_cmp_ps(_mm256_cvtepi32_ps(a), _mm256_cvtepi32_ps(b),
                                 kGreaterThanSignalling);
  return _mm256_cvtps_epi32(v);
}

}  // namespace.

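// Entry point for 8-bit column-major packing on the kAvx path. The main work
// is done by Pack8bitColMajorForAvxPacker above; any trailing partial row
// chunk is then copied from trailing_buf into the packed buffer.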
void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor,
                            const std::int8_t* zerobuf, int src_stride,
                            int remaining_src_cols, int src_rows,
                            std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
  profiler::ScopeLabel label("Pack kAvx 8bit");

  using Layout = PackImpl8bitAvx::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 8);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  // Each Layout::kRows-sized chunk is 4 elements that are contiguous in the
  // input and contiguous in the packed output. We process 8 of these chunks
  // at a time, padding short input chunks.
  static constexpr int kNumRowChunks = 8;  // Short input is padded.

  // Each packed block is 4*8, and there are normally 8. The trailing block is
  // only slightly shorter.
  constexpr int kTrailingBufSize =
      kNumRowChunks * Layout::kCols * Layout::kRows;
  std::int8_t trailing_buf[kTrailingBufSize];
  memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));

  Pack8bitColMajorForAvxPacker(src_ptr, input_xor, zerobuf, src_stride,
                               remaining_src_cols, src_rows, packed_ptr,
                               sums_ptr, trailing_buf);

  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
  const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
  // If the number of source rows is not a multiple of kNumRowChunks *
  // Layout::kRows, there will be data in the trailing buffer.
  if (trailing_data) {
    const int non_trailing_rows = src_rows & ~kChunkedRowMask;
    // Destination "rows" are padded to next highest multiple of Layout::kRows.
    const int dst_rows = (src_rows + 3) & ~3;
    const int trailing_rows = dst_rows - non_trailing_rows;
    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
           Layout::kCols * trailing_rows * sizeof(std::int8_t));
  }
}

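// Float packing delegates to the generic x86 column-major packer and then
// copies any trailing (partial) rows from trailing_buf into packed_ptr.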
void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf,
                             int src_stride, int remaining_src_cols,
                             int src_rows, float* packed_ptr) {
  profiler::ScopeLabel label("Pack kAvx float");
  static constexpr int kPackCols = 8;  // Source cols packed together.
  static constexpr int kPackRows = 8;  // Short input is padded.
  float trailing_buf[(kPackRows - 1) * kPackCols];
  if (remaining_src_cols < 8) {
    memset(trailing_buf, 0, sizeof(trailing_buf));
  }
  PackFloatColMajorForAvxCommonPacker<PackImplFloatAvx, Path::kAvx>(
      src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr,
      trailing_buf);

  const int trailing_rows = src_rows & (kPackRows - 1);
  if (trailing_rows > 0) {
    const int non_trailing_rows = src_rows & ~(kPackRows - 1);
    memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
           kPackCols * trailing_rows * sizeof(float));
  }
}

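// Packs a row-major 8-bit source. For each group of 8 columns, a 4x8 block is
// loaded (padding with the zero point past the end of the source), optionally
// XORed with input_xor to convert uint8 to int8, transposed into 4-element
// column chunks, and the per-column sums are updated.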
void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride,
                            int src_zero_point, std::int8_t* packed_ptr,
                            int packed_stride, int start_col, int end_col,
                            int src_cols, int block_row, int src_rows,
                            int input_xor, std::int32_t* sums) {
  int col = start_col;
  int src_end_col = std::min(end_col, src_cols);

  for (; col <= src_end_col - 8; col += 8) {
    std::int8_t* dst_ptr = packed_ptr;
    __m128i val0, val1, val2, val3;
    __m128i input_xor_dup = _mm_set1_epi8(input_xor);
    // Load a 4x8 block.
    if (block_row + 4 <= src_rows) {
      val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
      val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
      val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
      val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
    } else {
      val0 = _mm_set1_epi8(src_zero_point);
      val1 = val0;
      val2 = val0;
      val3 = val0;
      if (block_row + 0 < src_rows)
        val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
      if (block_row + 1 < src_rows)
        val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
      if (block_row + 2 < src_rows)
        val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
      if (block_row + 3 < src_rows)
        val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
    }
    // Maybe xor the sign bit to convert from uint8 to int8.
    val0 = _mm_xor_si128(val0, input_xor_dup);
    val1 = _mm_xor_si128(val1, input_xor_dup);
    val2 = _mm_xor_si128(val2, input_xor_dup);
    val3 = _mm_xor_si128(val3, input_xor_dup);
    // Update the sums.
    __m128i val16_0 = _mm_cvtepi8_epi16(val0);
    __m128i val16_1 = _mm_cvtepi8_epi16(val1);
    __m128i val16_2 = _mm_cvtepi8_epi16(val2);
    __m128i val16_3 = _mm_cvtepi8_epi16(val3);
    __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1),
                                      _mm_add_epi16(val16_2, val16_3));
    __m256i sum =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col));
    sum = mm256_add_epi32(sum, mm256_cvtepi16_epi32(new_sum16));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum);
    // Perform the transposition of 4x4 blocks
    __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1);
    __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3);
    __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1);
    __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1);
    src_ptr += 8;
    packed_ptr += packed_stride * 8;
  }
  for (; col < src_end_col; col++) {
    std::int32_t accum = 0;
    for (int r = 0; r < 4; r++) {
      std::int8_t packed_val;
      if (block_row + r < src_rows) {
        packed_val = input_xor ^ src_ptr[r * src_stride];
      } else {
        packed_val = input_xor ^ src_zero_point;
      }
      accum += packed_val;
      *packed_ptr++ = packed_val;
    }
    if (sums) {
      sums[col] += accum;
    }
    src_ptr++;
  }
  for (; col < end_col; col++) {
    std::memset(packed_ptr, 0, 4);
    packed_ptr += 4;
  }
}


#endif  // RUY_PLATFORM_AVX && RUY_OPT(ASM)

}  // namespace ruy