/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/pack_x86.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))

void Pack8bitColMajorForAvx2(const std::int8_t*, std::int8_t,
                             const std::int8_t*, int, int, int, std::int8_t*,
                             std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void PackFloatColMajorForAvx2(const float*, const float*, int, int, int,
                              float*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Pack8bitRowMajorForAvx2(const std::uint8_t*, int, int, std::int8_t*, int,
                             int, int, int, int, int, int, std::int32_t*) {
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

// The first int8_t template parameter is arbitrary: this routine is common to
// all 8-bit source matrix types.
using PackImpl8bitAvx2 =
    PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
             std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;
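// Note that with FixedKernelLayout<Order::kColMajor, 4, 8>, the packed 8-bit
// data is stored in 4x8 blocks of 32 bytes, so a source element (row r, col c)
// effectively lands at packed offset (r / 4) * 32 + c * 4 + (r % 4), with rows
// padded up to a multiple of 4.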

using PackImplFloatAvx2 =
    PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
             float, float, Order::kColMajor>;

namespace {

inline void Pack8bitColMajorForAvx2Packer(
    const std::int8_t* src_ptr, std::int8_t input_xor,
    const std::int8_t* zerobuf, int src_stride, int remaining_src_cols,
    int src_rows, std::int8_t* packed_ptr, std::int32_t* sums_ptr,
    std::int8_t* trailing_buf) {
  using Layout = PackImpl8bitAvx2::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 8);
  RUY_DCHECK_EQ(Layout::kRows, 4);
  // Each chunk is Layout::kRows (4) contiguous input elements that remain
  // contiguous in the packed output. We process 8 of these chunks at a time,
  // padding short input chunks.
  constexpr int kNumRowChunks = 8;
  constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
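  // With kNumRowChunks = 8 and Layout::kRows = 4, each iteration of the main
  // loop below consumes kNumChunkedSrcRows = 32 source rows per column and
  // emits 8 columns * 32 rows = 256 packed bytes.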

  const std::int8_t* src_ptr0 = src_ptr;
  const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
  const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
  const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
  const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
  const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
  const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
  const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = kNumChunkedSrcRows;
  std::int64_t src_inc1 = kNumChunkedSrcRows;
  std::int64_t src_inc2 = kNumChunkedSrcRows;
  std::int64_t src_inc3 = kNumChunkedSrcRows;
  std::int64_t src_inc4 = kNumChunkedSrcRows;
  std::int64_t src_inc5 = kNumChunkedSrcRows;
  std::int64_t src_inc6 = kNumChunkedSrcRows;
  std::int64_t src_inc7 = kNumChunkedSrcRows;
  // Handle cases where source does not have Layout::kCols (8) columns.
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  const std::int8_t zero_point = zerobuf[0];

  if (sums_ptr) {
    // i: Layout::kCols.
    for (int i = 0; i < 8; ++i) {
      sums_ptr[i] = 0;
    }
  }
  std::int32_t sums_adjustment = 0;
  const __m256i ones_16bit = _mm256_set1_epi16(1);
  __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0);
  __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0);
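  // sums_4x2_32bit_lo accumulates partial sums for columns 0-3 and
  // sums_4x2_32bit_hi for columns 4-7; each column contributes two 32-bit
  // partials that are combined after the loop.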

  // The overall packing effectively pads the source rows to
  // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
  // only pack for (src_rows + 31) & ~31. When there is an incomplete
  // destination block, this is stored into trailing_buf instead of packed_ptr.
  for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) {
    // Available source rows.
    // If this is less than 0 (for m=1), we skip, having filled trailing
    // buffer for m=0. Also, if source rows is zero on m=1, then we filled
    // exactly to the end of the column in the packed buffer.
    const int available_src_rows = src_rows - k;
    // Effectively,
    // available rows = std::max(0, std::min(kNumChunkedSrcRows, src_rows - k));
    // treat each case separately.
    if (available_src_rows >= kNumChunkedSrcRows) {
      if (sums_ptr) {
        __m256i t0, t1, t2, t3, t4, t5, t6, t7;
        __m256i r0, r1, r2, r3, r4, r5, r6, r7;
        const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

        t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
        t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
        t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
        t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
        t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
        t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
        t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
        t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));

        r0 = _mm256_unpacklo_epi32(t0, t1);
        r4 = _mm256_unpacklo_epi32(t4, t5);
        r2 = _mm256_unpackhi_epi32(t0, t1);
        r6 = _mm256_unpackhi_epi32(t4, t5);
        r1 = _mm256_unpacklo_epi32(t2, t3);
        r5 = _mm256_unpacklo_epi32(t6, t7);
        r3 = _mm256_unpackhi_epi32(t2, t3);
        r7 = _mm256_unpackhi_epi32(t6, t7);

        t0 = _mm256_unpacklo_epi64(r0, r1);
        t4 = _mm256_unpacklo_epi64(r4, r5);
        t2 = _mm256_unpackhi_epi64(r0, r1);
        t6 = _mm256_unpackhi_epi64(r4, r5);
        t1 = _mm256_unpacklo_epi64(r2, r3);
        t5 = _mm256_unpacklo_epi64(r6, r7);
        t3 = _mm256_unpackhi_epi64(r2, r3);
        t7 = _mm256_unpackhi_epi64(r6, r7);

        // The preceding sets of rearrangement operations interleaved by 4 bytes
        // and then by 8 bytes *within* lanes. The following set interleaves by
        // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
        // t4) are interleaved to create (r0, r1). This complexity follows from
        // the way that AVX is centered around 128-bit lanes.
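        // Concretely, _mm256_permute2x128_si256(a, b, 0x20) yields
        // [a.lane0, b.lane0] and 0x31 yields [a.lane1, b.lane1], where laneN
        // denotes a 128-bit half.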
        r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
        r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
        r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
        r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
        r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
        r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
        r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
        r7 = _mm256_permute2x128_si256(t3, t7, 0x31);

        r0 = _mm256_xor_si256(r0, input_xor_v);
        r1 = _mm256_xor_si256(r1, input_xor_v);
        r2 = _mm256_xor_si256(r2, input_xor_v);
        r3 = _mm256_xor_si256(r3, input_xor_v);
        r4 = _mm256_xor_si256(r4, input_xor_v);
        r5 = _mm256_xor_si256(r5, input_xor_v);
        r6 = _mm256_xor_si256(r6, input_xor_v);
        r7 = _mm256_xor_si256(r7, input_xor_v);

        __m256i sums_4x4_16bit_lo;
        sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));

        // The sums have been performed across columns, and now we have 4x16-bit
        // sums packed together. We use madd for pairwise 32-bit sums.
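        // madd_epi16 against a vector of ones adds adjacent 16-bit lanes into
        // 32-bit lanes, so each of columns 0-3 ends up with two 32-bit partial
        // sums (and likewise for columns 4-7 below).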
        const __m256i sums_4x2_32bit_lo_new =
            _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
        sums_4x2_32bit_lo =
            _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);

        __m256i sums_4x4_16bit_hi;
        sums_4x4_16bit_hi =
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1)));

        const __m256i sums_4x2_32bit_hi_new =
            _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
        sums_4x2_32bit_hi =
            _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);

        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
                            r0);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
                            r4);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
                            r1);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
                            r5);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
                            r2);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
                            r6);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
                            r3);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
                            r7);
      } else {
        __m256i t0, t1, t2, t3, t4, t5, t6, t7;
        __m256i r0, r1, r2, r3, r4, r5, r6, r7;
        const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

        t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
        t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
        t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
        t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
        t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
        t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
        t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
        t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));

        r0 = _mm256_unpacklo_epi32(t0, t1);
        r4 = _mm256_unpacklo_epi32(t4, t5);
        r2 = _mm256_unpackhi_epi32(t0, t1);
        r6 = _mm256_unpackhi_epi32(t4, t5);
        r1 = _mm256_unpacklo_epi32(t2, t3);
        r5 = _mm256_unpacklo_epi32(t6, t7);
        r3 = _mm256_unpackhi_epi32(t2, t3);
        r7 = _mm256_unpackhi_epi32(t6, t7);

        t0 = _mm256_unpacklo_epi64(r0, r1);
        t4 = _mm256_unpacklo_epi64(r4, r5);
        t2 = _mm256_unpackhi_epi64(r0, r1);
        t6 = _mm256_unpackhi_epi64(r4, r5);
        t1 = _mm256_unpacklo_epi64(r2, r3);
        t5 = _mm256_unpacklo_epi64(r6, r7);
        t3 = _mm256_unpackhi_epi64(r2, r3);
        t7 = _mm256_unpackhi_epi64(r6, r7);

        // The preceding sets of rearrangement operations interleaved by 4 bytes
        // and then by 8 bytes *within* lanes. The following set interleaves by
        // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
        // t4) are interleaved to create (r0, r1). This complexity follows from
        // the way that AVX is centered around 128-bit lanes.
        r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
        r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
        r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
        r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
        r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
        r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
        r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
        r7 = _mm256_permute2x128_si256(t3, t7, 0x31);

        r0 = _mm256_xor_si256(r0, input_xor_v);
        r1 = _mm256_xor_si256(r1, input_xor_v);
        r2 = _mm256_xor_si256(r2, input_xor_v);
        r3 = _mm256_xor_si256(r3, input_xor_v);
        r4 = _mm256_xor_si256(r4, input_xor_v);
        r5 = _mm256_xor_si256(r5, input_xor_v);
        r6 = _mm256_xor_si256(r6, input_xor_v);
        r7 = _mm256_xor_si256(r7, input_xor_v);

        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
                            r0);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
                            r4);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
                            r1);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
                            r5);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
                            r2);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
                            r6);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
                            r3);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
                            r7);
      }
    } else if (available_src_rows > 0) {
      RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows);
      // We do not care what goes into the trailing buffer, but we want
      // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
      //
      // We compensate for padding-with-zero_point by initializing the
      // summations with the compensating offset, effectively
      // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
      //                         4 * (8 - ((available_src_rows + 3) >> 2)).
      //
      // Note that (zero_point ^ input_xor) is performed in 8-bits and then
      // cast.
      sums_adjustment +=
          -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2));
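      // For example, with available_src_rows = 10, ceil(10 / 4) = 3 four-row
      // groups are kept, so the remaining 5 groups (20 padded rows per column)
      // are subtracted: sums_adjustment -= (zero_point ^ input_xor) * 20.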

      __m256i t0, t1, t2, t3, t4, t5, t6, t7;
      __m256i r0, r1, r2, r3, r4, r5, r6, r7;
      const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

      t0 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr0);
      t4 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr4);
      t1 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr1);
      t5 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr5);
      t2 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr2);
      t6 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr6);
      t3 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr3);
      t7 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr7);
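      // MaskLoadu reads only available_src_rows bytes from each source column
      // and fills the remaining lanes with zero_point, avoiding reads past the
      // end of the source; the padded 4-row groups beyond the destination
      // padding are compensated for by sums_adjustment above.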

      r0 = _mm256_unpacklo_epi32(t0, t1);
      r4 = _mm256_unpacklo_epi32(t4, t5);
      r2 = _mm256_unpackhi_epi32(t0, t1);
      r6 = _mm256_unpackhi_epi32(t4, t5);
      r1 = _mm256_unpacklo_epi32(t2, t3);
      r5 = _mm256_unpacklo_epi32(t6, t7);
      r3 = _mm256_unpackhi_epi32(t2, t3);
      r7 = _mm256_unpackhi_epi32(t6, t7);

      t0 = _mm256_unpacklo_epi64(r0, r1);
      t4 = _mm256_unpacklo_epi64(r4, r5);
      t2 = _mm256_unpackhi_epi64(r0, r1);
      t6 = _mm256_unpackhi_epi64(r4, r5);
      t1 = _mm256_unpacklo_epi64(r2, r3);
      t5 = _mm256_unpacklo_epi64(r6, r7);
      t3 = _mm256_unpackhi_epi64(r2, r3);
      t7 = _mm256_unpackhi_epi64(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
      // t4) are interleaved to create (r0, r1). This complexity follows from
      // the way that AVX is centered around 128-bit lanes.
      r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
      r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
      r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
      r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
      r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
      r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
      r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
      r7 = _mm256_permute2x128_si256(t3, t7, 0x31);

      r0 = _mm256_xor_si256(r0, input_xor_v);
      r1 = _mm256_xor_si256(r1, input_xor_v);
      r2 = _mm256_xor_si256(r2, input_xor_v);
      r3 = _mm256_xor_si256(r3, input_xor_v);
      r4 = _mm256_xor_si256(r4, input_xor_v);
      r5 = _mm256_xor_si256(r5, input_xor_v);
      r6 = _mm256_xor_si256(r6, input_xor_v);
      r7 = _mm256_xor_si256(r7, input_xor_v);

      __m256i sums_4x4_16bit_lo;
      sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));

      // The sums have been performed across columns, and now we have 4x16-bit
      // sums packed together. We use madd for pairwise 32-bit sums.
      const __m256i sums_4x2_32bit_lo_new =
          _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
      sums_4x2_32bit_lo =
          _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);

      __m256i sums_4x4_16bit_hi;
      sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1)));

      const __m256i sums_4x2_32bit_hi_new =
          _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
      sums_4x2_32bit_hi =
          _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);

      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4),
                          r0);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4),
                          r4);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4),
                          r1);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4),
                          r5);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4),
                          r2);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4),
                          r6);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4),
                          r3);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4),
                          r7);
    }

    packed_ptr += 8 * kNumChunkedSrcRows;
    src_ptr0 += src_inc0;
    src_ptr1 += src_inc1;
    src_ptr2 += src_inc2;
    src_ptr3 += src_inc3;
    src_ptr4 += src_inc4;
    src_ptr5 += src_inc5;
    src_ptr6 += src_inc6;
    src_ptr7 += src_inc7;
  }

  if (sums_ptr) {
    const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);

    __m256i sums =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
    const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);

    // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the
    // neighbours, finishing up by adding them to the stored accumulated sums.
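    // Each of sums_4x2_32bit_{lo,hi} holds lanes [c0a, c0b, c1a, c1b, ...] for
    // four columns. Permuting with idx groups the 'a' partials in the low
    // 128 bits and the 'b' partials in the high 128 bits; the permute2x128
    // selections below then line up all eight columns so that adding the 'a'
    // and 'b' vectors yields the per-column totals.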
    const __m256i sums_2x4_32bit_lo =
        _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx);
    const __m256i sums_2x4_32bit_hi =
        _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx);
    const __m256i sums_2x4_32bit_a =
        _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20);
    const __m256i sums_2x4_32bit_b =
        _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31);
    sums = _mm256_add_epi32(sums, sums_adjustment_v);
    sums = _mm256_add_epi32(sums, sums_2x4_32bit_a);
    sums = _mm256_add_epi32(sums, sums_2x4_32bit_b);

    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
  }
}

// Use the AVX2-specific intrinsic for greater-than comparison.
template <>
inline __m256i CompareGreaterThan<Path::kAvx2Fma>(const __m256i& a,
                                                  const __m256i& b) {
  return _mm256_cmpgt_epi32(a, b);
}
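// The CompareGreaterThan specialization above is presumably picked up by the
// shared AVX packing helpers in pack_x86.h (which are templated on Path), e.g.
// for building column masks in the common float packer.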

}  // namespace.

void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
                             const std::int8_t* zerobuf, int src_stride,
                             int remaining_src_cols, int src_rows,
                             std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
  profiler::ScopeLabel label("Pack kAvx2Fma 8bit");

  using Layout = PackImpl8bitAvx2::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 8);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  // Each chunk is Layout::kRows (4) contiguous input elements that remain
  // contiguous in the packed output. We process 8 of these chunks at a time,
  // padding short input chunks.
  static constexpr int kNumRowChunks = 8;  // Short input is padded.

  // Each packed block is 4*8, and there are normally 8. The trailing block is
  // only slightly shorter.
  constexpr int kTrailingBufSize =
      kNumRowChunks * Layout::kCols * Layout::kRows;
  std::int8_t trailing_buf[kTrailingBufSize];
  memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));

  Pack8bitColMajorForAvx2Packer(src_ptr, input_xor, zerobuf, src_stride,
                                remaining_src_cols, src_rows, packed_ptr,
                                sums_ptr, trailing_buf);

  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
  const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
  // If the number of source rows is not a multiple of the chunk size
  // (kChunkedRowMask + 1), there is data in the trailing buffer that still
  // needs to be copied into the packed matrix.
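  // For example, src_rows = 70 gives non_trailing_rows = 64, dst_rows = 72 and
  // trailing_rows = 8, so 8 columns * 8 rows = 64 bytes are copied out of
  // trailing_buf.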
  if (trailing_data) {
    const int non_trailing_rows = src_rows & ~kChunkedRowMask;
    // Destination "rows" are padded to the next multiple of Layout::kRows.
    const int dst_rows = (src_rows + 3) & ~3;
    const int trailing_rows = dst_rows - non_trailing_rows;
    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
           Layout::kCols * trailing_rows * sizeof(std::int8_t));
  }
}

void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf,
                              int src_stride, int remaining_src_cols,
                              int src_rows, float* packed_ptr) {
  profiler::ScopeLabel label("Pack kAvx2Fma float");
  static constexpr int kPackCols = 8;  // Source cols packed together.
  static constexpr int kPackRows = 8;  // Short input is padded.
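  // Only the final, incomplete block of at most kPackRows - 1 rows ever lands
  // in the trailing buffer, hence its (kPackRows - 1) * kPackCols size.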
  float trailing_buf[(kPackRows - 1) * kPackCols];
  if (remaining_src_cols < 8) {
    memset(trailing_buf, 0, sizeof(trailing_buf));
  }
  PackFloatColMajorForAvxCommonPacker<PackImplFloatAvx2, Path::kAvx2Fma>(
      src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr,
      trailing_buf);

  const int trailing_rows = src_rows & (kPackRows - 1);
  if (trailing_rows > 0) {
    const int non_trailing_rows = src_rows & ~(kPackRows - 1);
    memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
           kPackCols * trailing_rows * sizeof(float));
  }
}

void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
                             int src_zero_point, std::int8_t* packed_ptr,
                             int packed_stride, int start_col, int end_col,
                             int src_cols, int block_row, int src_rows,
                             int input_xor, std::int32_t* sums) {
  int col = start_col;
  int src_end_col = std::min(end_col, src_cols);
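  // Three phases below: a vectorized loop over full groups of 8 source
  // columns, a scalar loop over the remaining real columns, and zero-fill of
  // the packed block for padding columns between src_cols and end_col.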

  for (; col <= src_end_col - 8; col += 8) {
    std::int8_t* dst_ptr = packed_ptr;
    __m128i val0, val1, val2, val3;
    __m128i input_xor_dup = _mm_set1_epi8(input_xor);
    // Load a 4x8 block.
    if (block_row + 4 <= src_rows) {
      val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
      val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
      val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
      val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
    } else {
      val0 = _mm_set1_epi8(src_zero_point);
      val1 = val0;
      val2 = val0;
      val3 = val0;
      if (block_row + 0 < src_rows)
        val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
      if (block_row + 1 < src_rows)
        val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
      if (block_row + 2 < src_rows)
        val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
      if (block_row + 3 < src_rows)
        val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
    }
    // Maybe xor the sign bit to convert from uint8 to int8.
    val0 = _mm_xor_si128(val0, input_xor_dup);
    val1 = _mm_xor_si128(val1, input_xor_dup);
    val2 = _mm_xor_si128(val2, input_xor_dup);
    val3 = _mm_xor_si128(val3, input_xor_dup);
    // Update the sums.
    __m128i val16_0 = _mm_cvtepi8_epi16(val0);
    __m128i val16_1 = _mm_cvtepi8_epi16(val1);
    __m128i val16_2 = _mm_cvtepi8_epi16(val2);
    __m128i val16_3 = _mm_cvtepi8_epi16(val3);
    __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1),
                                      _mm_add_epi16(val16_2, val16_3));
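    // new_sum16 now holds, per column, the 16-bit sum of this block's 4
    // (xor-adjusted) row values; widen to 32-bit and accumulate into
    // sums[col .. col + 7].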
    __m256i sum =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col));
    sum = _mm256_add_epi32(sum, _mm256_cvtepi16_epi32(new_sum16));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum);
    // Perform the transposition of 4x4 blocks
    __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1);
    __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3);
    __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1);
    __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1);
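    // After the unpacks, t4_val0 holds columns 0-3 and t4_val1 holds columns
    // 4-7, each as 4 contiguous row bytes, matching the packed col-major 4x8
    // layout.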
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1);
    src_ptr += 8;
    packed_ptr += packed_stride * 8;
  }
  for (; col < src_end_col; col++) {
    std::int32_t accum = 0;
    for (int r = 0; r < 4; r++) {
      std::int8_t packed_val;
      if (block_row + r < src_rows) {
        packed_val = input_xor ^ src_ptr[r * src_stride];
      } else {
        packed_val = input_xor ^ src_zero_point;
      }
      accum += packed_val;
      *packed_ptr++ = packed_val;
    }
    if (sums) {
      sums[col] += accum;
    }
    src_ptr++;
  }
  for (; col < end_col; col++) {
    std::memset(packed_ptr, 0, 4);
    packed_ptr += 4;
  }
}

#endif  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

}  // namespace ruy