1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <algorithm>  // for std::min used in Pack8bitRowMajorForAvx512
16 #include <cstdint>
17 #include <cstring>
18 
19 #include "ruy/check_macros.h"
20 #include "ruy/opt_set.h"
21 #include "ruy/pack_x86.h"
22 #include "ruy/path.h"
23 #include "ruy/platform.h"
24 #include "ruy/profiler/instrumentation.h"
25 
26 #if RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS)
27 #include <immintrin.h>  // IWYU pragma: keep
28 #endif
29 
30 namespace ruy {
31 
32 #if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))
33 
34 void Pack8bitColMajorForAvx512(const std::int8_t*, std::int8_t,
35                                const std::int8_t*, int, int, int, std::int8_t*,
36                                std::int32_t*) {
37   // CPU-ID-based checks should disable the path that would reach this point.
38   RUY_DCHECK(false);
39 }
40 
41 void Pack16bitColMajorForAvx512(const std::int16_t*, const std::int16_t*, int,
42                                 int, int, std::int16_t*, std::int32_t*) {
43   // CPU-ID-based checks should disable the path that would reach this point.
44   RUY_DCHECK(false);
45 }
46 
47 void PackFloatColMajorForAvx512(const float*, const float*, int, int, int,
48                                 float*) {
49   // CPU-ID-based checks should disable the path that would reach this point.
50   RUY_DCHECK(false);
51 }
52 
53 void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int,
54                                int, int, int, int, int, int, std::int32_t*) {
55   RUY_DCHECK(false);
56 }
57 
58 #else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
59 
60 // The first int8_t template parameter is arbitrary: this routine is common to
61 // all 8-bit source matrix types.
62 using PackImpl8bitAvx512 =
63     PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
64              std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;
65 using PackImpl16bitAvx512 =
66     PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
67              std::int16_t, std::int16_t, std::int32_t, Order::kColMajor>;
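
// Layout sketch (illustrative): each packed block is Layout::kRows x Layout::kCols
// = 4x16, stored col-major within the block, so a source element at (row, col) of
// a block column lands at a packed offset of roughly
//   (row / 4) * (16 * 4) + col * 4 + (row % 4).
// The "Half" routines below fill 8 of the 16 block columns at a time.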
68 
69 namespace {
70 
71 template <typename PackImplAvx512, typename Scalar>
72 inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point,
73                            Scalar* packed_ptr, int chunked_row_mask) {
74   using Layout = typename PackImplAvx512::Layout;
75   static constexpr int kHalfLayoutCols =
76       PackImplAvx512::kHalfLayoutCols;  // Half the number of cols in a
77                                         // block.
78   RUY_DCHECK_EQ(kHalfLayoutCols, 8);
79   RUY_DCHECK_EQ(Layout::kCols, 16);
80   RUY_DCHECK_EQ(Layout::kRows, 4);
81 
82   const int non_trailing_blocks = (src_rows & ~chunked_row_mask) >> 2;
83   // This routine fills half blocks, and typically fills the second halves.
84   // Thus packed_ptr is already offset by 8 * 4.
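  // For example, with src_rows = 70 and chunked_row_mask = 31 (the 8-bit path),
  // non_trailing_blocks = (70 & ~31) >> 2 = 16, i.e. 16 full 4-row blocks below.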
85   for (int k = 0; k < non_trailing_blocks; ++k) {
86     for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) {
87       packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point;
88     }
89   }
90 }
91 
92 template <typename Scalar>
93 inline __m512i LoaduTwo(const Scalar* addr_lo, const Scalar* addr_hi) {
94   __m512i lower_filled = _mm512_castsi256_si512(
95       _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo)));
96   return _mm512_inserti32x8(
97       lower_filled,
98       _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_hi)), 1);
99 }
100 
101 inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
102                             const std::int8_t* addr_lo,
103                             const std::int8_t* addr_hi) {
104   const __m512i lower_filled = _mm512_castsi256_si512(
105       _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo));
106   return _mm512_inserti32x8(
107       lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi),
108       1);
109 }
110 
111 inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
112                             const std::int16_t* addr_lo,
113                             const std::int16_t* addr_hi) {
114   const __m512i lower_filled = _mm512_castsi256_si512(
115       _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_lo));
116   return _mm512_inserti32x8(
117       lower_filled, _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_hi),
118       1);
119 }
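
// Note: LoaduTwo concatenates two unaligned 256-bit column loads into a single
// 512-bit register (addr_lo -> low half, addr_hi -> high half). The MaskLoaduTwo
// overloads above do the same but substitute default_value_v (the zero point) in
// lanes not covered by row_mask, padding short trailing row chunks in-register.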
120 
121 inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
122                                std::int8_t input_xor,
123                                const std::int8_t* zerobuf, int src_stride,
124                                int remaining_src_cols, int src_rows,
125                                std::int8_t* packed_ptr, std::int32_t* sums_ptr,
126                                std::int8_t* trailing_buf) {
127   using Layout = PackImpl8bitAvx512::Layout;
128   RUY_DCHECK_EQ(Layout::kCols, 16);
129   RUY_DCHECK_EQ(Layout::kRows, 4);
130   // Each Layout::kRows group of 4 elements is contiguous in the input and in
131   // the packed output. We process 8 such chunks at a time, padding short ones.
132   constexpr int kNumRowChunks = 8;
133   constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
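  // With kNumRowChunks = 8 and Layout::kRows = 4, kNumChunkedSrcRows = 32: each
  // inner (m) step below consumes up to 32 source rows per column, each outer (k)
  // step up to 64.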
134 
135   const std::int8_t* src_ptr0 = src_ptr;
136   const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
137   const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
138   const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
139   const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
140   const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
141   const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
142   const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
143   std::int64_t src_inc0 = kNumChunkedSrcRows;
144   std::int64_t src_inc1 = kNumChunkedSrcRows;
145   std::int64_t src_inc2 = kNumChunkedSrcRows;
146   std::int64_t src_inc3 = kNumChunkedSrcRows;
147   std::int64_t src_inc4 = kNumChunkedSrcRows;
148   std::int64_t src_inc5 = kNumChunkedSrcRows;
149   std::int64_t src_inc6 = kNumChunkedSrcRows;
150   std::int64_t src_inc7 = kNumChunkedSrcRows;
151   // Handle cases where source does not have kHalfLayoutCols (8) columns.
152   if (remaining_src_cols < 8) {
153     if (remaining_src_cols <= 0) {
154       src_ptr0 = zerobuf;
155       src_inc0 = 0;
156     }
157     if (remaining_src_cols <= 1) {
158       src_ptr1 = zerobuf;
159       src_inc1 = 0;
160     }
161     if (remaining_src_cols <= 2) {
162       src_ptr2 = zerobuf;
163       src_inc2 = 0;
164     }
165     if (remaining_src_cols <= 3) {
166       src_ptr3 = zerobuf;
167       src_inc3 = 0;
168     }
169     if (remaining_src_cols <= 4) {
170       src_ptr4 = zerobuf;
171       src_inc4 = 0;
172     }
173     if (remaining_src_cols <= 5) {
174       src_ptr5 = zerobuf;
175       src_inc5 = 0;
176     }
177     if (remaining_src_cols <= 6) {
178       src_ptr6 = zerobuf;
179       src_inc6 = 0;
180     }
181     src_ptr7 = zerobuf;
182     src_inc7 = 0;
183   }
184 
185   const std::int8_t zero_point = zerobuf[0];
186 
187   if (sums_ptr) {
188     // i: kHalfLayoutCols.
189     for (int i = 0; i < 8; ++i) {
190       sums_ptr[i] = 0;
191     }
192   }
193   std::int32_t sums_adjustment = 0;
194   const __m512i ones_16bit = _mm512_set1_epi16(1);
195   __m512i sums_8x2_32bit = _mm512_set1_epi32(0);
196 
197   // The overall packing effectively pads the source rows to
198   // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
199   // only pack for (src_rows + 31) & ~31. When there is an incomplete
200   // destination block, this is stored into trailing_buf instead of packed_ptr.
201   for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
202     // m: {0, 1} for 2 chunks of rows.
203     for (int m = 0; m < 2; ++m) {
204       // Available source rows.
205       // If this is less than 0 (for m=1), we skip, having filled trailing
206       // buffer for m=0. Also, if source rows is zero on m=1, then we filled
207       // exactly to the end of the column in the packed buffer.
208       const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
209       // Effectively,
210       // available_src_rows = std::max(0, std::min(8 * 4, src_rows - k - 8 * 4 * m));
211       // but we treat each case separately.
212       if (available_src_rows >= kNumChunkedSrcRows) {
213         // i: chunks, s: Layout::Rows.
214         if (sums_ptr) {
215           __m512i t0, t1, t2, t3;
216           __m512i r0, r1, r2, r3;
217           const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
218 
219           t0 = LoaduTwo(src_ptr0, src_ptr4);
220           t1 = LoaduTwo(src_ptr1, src_ptr5);
221           t2 = LoaduTwo(src_ptr2, src_ptr6);
222           t3 = LoaduTwo(src_ptr3, src_ptr7);
223 
224           r0 = _mm512_unpacklo_epi32(t0, t1);
225           r2 = _mm512_unpackhi_epi32(t0, t1);
226           r1 = _mm512_unpacklo_epi32(t2, t3);
227           r3 = _mm512_unpackhi_epi32(t2, t3);
228 
229           t0 = _mm512_unpacklo_epi64(r0, r1);
230           t2 = _mm512_unpackhi_epi64(r0, r1);
231           t1 = _mm512_unpacklo_epi64(r2, r3);
232           t3 = _mm512_unpackhi_epi64(r2, r3);
233 
234           r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
235           r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
236           r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
237           r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
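          // In effect, each 256-bit half of r0..r3 now holds one 4-row chunk of
          // all 8 source columns, with each column's 4 values contiguous; the
          // eight 32-byte stores below each write one such chunk into the packed
          // block at a stride of 16 * 4 bytes.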
238 
239           r0 = _mm512_xor_si512(r0, input_xor_v);
240           r1 = _mm512_xor_si512(r1, input_xor_v);
241           r2 = _mm512_xor_si512(r2, input_xor_v);
242           r3 = _mm512_xor_si512(r3, input_xor_v);
243 
244           const __m256i r0_0 = _mm512_castsi512_si256(r0);
245           const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
246           const __m256i r1_0 = _mm512_castsi512_si256(r1);
247           const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
248           const __m256i r2_0 = _mm512_castsi512_si256(r2);
249           const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
250           const __m256i r3_0 = _mm512_castsi512_si256(r3);
251           const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
252 
253           __m512i sums_8x4_16bit;
254           sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
255           sums_8x4_16bit =
256               _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
257           sums_8x4_16bit =
258               _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
259           sums_8x4_16bit =
260               _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
261           sums_8x4_16bit =
262               _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
263           sums_8x4_16bit =
264               _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
265           sums_8x4_16bit =
266               _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
267           sums_8x4_16bit =
268               _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
269           // The sums have been performed across columns, and now we have
270           // 4x16-bit sums packed together. We use madd for pairwise 32-bit
271           // sums.
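          // (pmaddwd semantics: for each pair of adjacent 16-bit lanes (a, b),
          // _mm512_madd_epi16 with all-ones produces the 32-bit value a + b, i.e.
          // a widening pairwise horizontal add.)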
272           const __m512i sums_8x2_32bit_new =
273               _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
274           sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);
275 
276           _mm256_storeu_si256(
277               reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
278           _mm256_storeu_si256(
279               reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
280           _mm256_storeu_si256(
281               reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
282           _mm256_storeu_si256(
283               reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
284           _mm256_storeu_si256(
285               reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
286           _mm256_storeu_si256(
287               reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
288           _mm256_storeu_si256(
289               reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
290           _mm256_storeu_si256(
291               reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
292         } else {
293           __m512i t0, t1, t2, t3;
294           __m512i r0, r1, r2, r3;
295           const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
296 
297           t0 = LoaduTwo(src_ptr0, src_ptr4);
298           t1 = LoaduTwo(src_ptr1, src_ptr5);
299           t2 = LoaduTwo(src_ptr2, src_ptr6);
300           t3 = LoaduTwo(src_ptr3, src_ptr7);
301 
302           r0 = _mm512_unpacklo_epi32(t0, t1);
303           r2 = _mm512_unpackhi_epi32(t0, t1);
304           r1 = _mm512_unpacklo_epi32(t2, t3);
305           r3 = _mm512_unpackhi_epi32(t2, t3);
306 
307           t0 = _mm512_unpacklo_epi64(r0, r1);
308           t2 = _mm512_unpackhi_epi64(r0, r1);
309           t1 = _mm512_unpacklo_epi64(r2, r3);
310           t3 = _mm512_unpackhi_epi64(r2, r3);
311 
312           r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
313           r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
314           r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
315           r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
316 
317           r0 = _mm512_xor_si512(r0, input_xor_v);
318           r1 = _mm512_xor_si512(r1, input_xor_v);
319           r2 = _mm512_xor_si512(r2, input_xor_v);
320           r3 = _mm512_xor_si512(r3, input_xor_v);
321 
322           const __m256i r0_0 = _mm512_castsi512_si256(r0);
323           const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
324           const __m256i r1_0 = _mm512_castsi512_si256(r1);
325           const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
326           const __m256i r2_0 = _mm512_castsi512_si256(r2);
327           const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
328           const __m256i r3_0 = _mm512_castsi512_si256(r3);
329           const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
330           _mm256_storeu_si256(
331               reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
332           _mm256_storeu_si256(
333               reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
334           _mm256_storeu_si256(
335               reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
336           _mm256_storeu_si256(
337               reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
338           _mm256_storeu_si256(
339               reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
340           _mm256_storeu_si256(
341               reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
342           _mm256_storeu_si256(
343               reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
344           _mm256_storeu_si256(
345               reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
346         }
347       } else if (available_src_rows > 0) {
348         RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
349         const __mmask32 row_mask =
350             (static_cast<std::uint64_t>(1) << available_src_rows) - 1;
351 
352         // We do not care what goes into the trailing buffer, but we want
353         // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
354         //
355         // We compensate for padding-with-zero_point by initializing the
356         // summations with the compensating offset, effectively
357         // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
358         //                         4 * (8 - ((available_src_rows + 3) >> 2)).
359         //
360         // Note that (zero_point ^ input_xor) is performed in 8-bits and then
361         // cast.
362         sums_adjustment += -(zero_point ^ input_xor) * 4 *
363                            (8 - ((available_src_rows + 3) >> 2));
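        // For example, if available_src_rows = 10: ((10 + 3) >> 2) = 3 chunks are
        // kept, so 8 - 3 = 5 chunks (20 rows) per column are irrelevant padding
        // and 20 * (zero_point ^ input_xor) is subtracted from each column sum.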
364 
365         __m512i t0, t1, t2, t3;
366         __m512i r0, r1, r2, r3;
367         const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
368         const __m256i zero_point_v = _mm256_set1_epi8(zero_point);
369 
370         t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
371         t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
372         t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
373         t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);
374 
375         r0 = _mm512_unpacklo_epi32(t0, t1);
376         r2 = _mm512_unpackhi_epi32(t0, t1);
377         r1 = _mm512_unpacklo_epi32(t2, t3);
378         r3 = _mm512_unpackhi_epi32(t2, t3);
379 
380         t0 = _mm512_unpacklo_epi64(r0, r1);
381         t2 = _mm512_unpackhi_epi64(r0, r1);
382         t1 = _mm512_unpacklo_epi64(r2, r3);
383         t3 = _mm512_unpackhi_epi64(r2, r3);
384 
385         r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
386         r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
387         r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
388         r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
389 
390         r0 = _mm512_xor_si512(r0, input_xor_v);
391         r1 = _mm512_xor_si512(r1, input_xor_v);
392         r2 = _mm512_xor_si512(r2, input_xor_v);
393         r3 = _mm512_xor_si512(r3, input_xor_v);
394 
395         const __m256i r0_0 = _mm512_castsi512_si256(r0);
396         const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
397         const __m256i r1_0 = _mm512_castsi512_si256(r1);
398         const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
399         const __m256i r2_0 = _mm512_castsi512_si256(r2);
400         const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
401         const __m256i r3_0 = _mm512_castsi512_si256(r3);
402         const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
403 
404         __m512i sums_8x4_16bit;
405         sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
406         sums_8x4_16bit =
407             _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
408         sums_8x4_16bit =
409             _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
410         sums_8x4_16bit =
411             _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
412         sums_8x4_16bit =
413             _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
414         sums_8x4_16bit =
415             _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
416         sums_8x4_16bit =
417             _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
418         sums_8x4_16bit =
419             _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
420         // The sums have been performed across columns, and now we have
421         // 4x16-bit sums packed together. We use madd for pairwise 32-bit
422         // sums.
423         const __m512i sums_8x2_32bit_new =
424             _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
425         sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);
426 
427         _mm256_storeu_si256(
428             reinterpret_cast<__m256i*>(trailing_buf + 0 * 16 * 4), r0_0);
429         _mm256_storeu_si256(
430             reinterpret_cast<__m256i*>(trailing_buf + 2 * 16 * 4), r0_1);
431         _mm256_storeu_si256(
432             reinterpret_cast<__m256i*>(trailing_buf + 4 * 16 * 4), r1_0);
433         _mm256_storeu_si256(
434             reinterpret_cast<__m256i*>(trailing_buf + 6 * 16 * 4), r1_1);
435         _mm256_storeu_si256(
436             reinterpret_cast<__m256i*>(trailing_buf + 1 * 16 * 4), r2_0);
437         _mm256_storeu_si256(
438             reinterpret_cast<__m256i*>(trailing_buf + 3 * 16 * 4), r2_1);
439         _mm256_storeu_si256(
440             reinterpret_cast<__m256i*>(trailing_buf + 5 * 16 * 4), r3_0);
441         _mm256_storeu_si256(
442             reinterpret_cast<__m256i*>(trailing_buf + 7 * 16 * 4), r3_1);
443       }
444 
445       packed_ptr += 16 * kNumChunkedSrcRows;
446       src_ptr0 += src_inc0;
447       src_ptr1 += src_inc1;
448       src_ptr2 += src_inc2;
449       src_ptr3 += src_inc3;
450       src_ptr4 += src_inc4;
451       src_ptr5 += src_inc5;
452       src_ptr6 += src_inc6;
453       src_ptr7 += src_inc7;
454     }
455   }
456 
457   if (sums_ptr) {
458     const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
459 
460     __m256i sums =
461         _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
462     const __m512i idx =
463         _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
464 
465     // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the
466     // neighbours, finishing up by adding them to the stored accumulated sums.
467     const __m512i sums_2x8_32bit =
468         _mm512_permutexvar_epi32(idx, sums_8x2_32bit);
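    // Concretely, lane c of sums_2x8_32bit receives partial-sum lane 2 * c and
    // lane 8 + c receives lane 2 * c + 1, so the low and high 256-bit halves each
    // hold one 32-bit partial per column; the adds below fold both into sums_ptr.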
469     sums = _mm256_add_epi32(sums, sums_adjustment_v);
470     sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
471     sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));
472 
473     _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
474   }
475 }
476 
477 inline void HalfPack16bitAvx512(const std::int16_t* src_ptr,
478                                 const std::int16_t* zerobuf, int src_stride,
479                                 int remaining_src_cols, int src_rows,
480                                 std::int16_t* packed_ptr,
481                                 std::int32_t* sums_ptr,
482                                 std::int16_t* trailing_buf) {
483   using Layout = PackImpl16bitAvx512::Layout;
484   RUY_DCHECK_EQ(Layout::kCols, 16);
485   RUY_DCHECK_EQ(Layout::kRows, 4);
486   // Each Layout::kRows group of 4 elements is contiguous in the input and in
487   // the packed output. We process 4 such chunks at a time, padding short ones.
488   constexpr int kNumRowChunks = 4;
489   constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
490 
491   const std::int16_t* src_ptr0 = src_ptr;
492   const std::int16_t* src_ptr1 = src_ptr0 + src_stride;
493   const std::int16_t* src_ptr2 = src_ptr1 + src_stride;
494   const std::int16_t* src_ptr3 = src_ptr2 + src_stride;
495   const std::int16_t* src_ptr4 = src_ptr3 + src_stride;
496   const std::int16_t* src_ptr5 = src_ptr4 + src_stride;
497   const std::int16_t* src_ptr6 = src_ptr5 + src_stride;
498   const std::int16_t* src_ptr7 = src_ptr6 + src_stride;
499   std::int64_t src_inc0 = kNumChunkedSrcRows;
500   std::int64_t src_inc1 = kNumChunkedSrcRows;
501   std::int64_t src_inc2 = kNumChunkedSrcRows;
502   std::int64_t src_inc3 = kNumChunkedSrcRows;
503   std::int64_t src_inc4 = kNumChunkedSrcRows;
504   std::int64_t src_inc5 = kNumChunkedSrcRows;
505   std::int64_t src_inc6 = kNumChunkedSrcRows;
506   std::int64_t src_inc7 = kNumChunkedSrcRows;
507   // Handle cases where source does not have kHalfLayoutCols (8) columns.
508   if (remaining_src_cols < 8) {
509     if (remaining_src_cols <= 0) {
510       src_ptr0 = zerobuf;
511       src_inc0 = 0;
512     }
513     if (remaining_src_cols <= 1) {
514       src_ptr1 = zerobuf;
515       src_inc1 = 0;
516     }
517     if (remaining_src_cols <= 2) {
518       src_ptr2 = zerobuf;
519       src_inc2 = 0;
520     }
521     if (remaining_src_cols <= 3) {
522       src_ptr3 = zerobuf;
523       src_inc3 = 0;
524     }
525     if (remaining_src_cols <= 4) {
526       src_ptr4 = zerobuf;
527       src_inc4 = 0;
528     }
529     if (remaining_src_cols <= 5) {
530       src_ptr5 = zerobuf;
531       src_inc5 = 0;
532     }
533     if (remaining_src_cols <= 6) {
534       src_ptr6 = zerobuf;
535       src_inc6 = 0;
536     }
537     src_ptr7 = zerobuf;
538     src_inc7 = 0;
539   }
540 
541   const std::int16_t zero_point = zerobuf[0];
542 
543   if (sums_ptr) {
544     // i: kHalfLayoutCols.
545     for (int i = 0; i < 8; ++i) {
546       sums_ptr[i] = 0;
547     }
548   }
549   std::int32_t sums_adjustment = 0;
550   const __m512i ones_16bit = _mm512_set1_epi16(1);
551   __m512i sums_8x2_32bit = _mm512_set1_epi32(0);
552 
553   // The overall packing effectively pads the source rows to
554   // (src_rows + 31) & ~31. The iteration over k may skip when m=1, and then we
555   // only pack for (src_rows + 15) & ~15. When there is an incomplete
556   // destination block, this is stored into trailing_buf instead of packed_ptr.
557   for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
558     // m: {0, 1} for 2 chunks of rows.
559     for (int m = 0; m < 2; ++m) {
560       const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
561 
562       // Available source rows.
563       // If this is less than 0 (for m=1), we skip, having filled trailing
564       // buffer for m=0. Also, if source rows is zero on m=1, then we filled
565       // exactly to the end of the column in the packed buffer.
566       if (available_src_rows > 0) {
567         __m512i t0, t1, t2, t3;
568         __m512i r0, r1, r2, r3;
569         std::int16_t* dst_ptr = packed_ptr;
570 
571         if (available_src_rows >= kNumChunkedSrcRows) {
572           t0 = LoaduTwo(src_ptr0, src_ptr4);
573           t1 = LoaduTwo(src_ptr1, src_ptr5);
574           t2 = LoaduTwo(src_ptr2, src_ptr6);
575           t3 = LoaduTwo(src_ptr3, src_ptr7);
576         } else {
577           RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
578           // We do not care what goes into the trailing buffer, but we want
579           // in_data[...] == zero_point for irrelevant values in the summation.
580           //
581           // We compensate for padding-with-zero_point by initializing the
582           // summations with the compensating offset.
583           sums_adjustment +=
584               -(zero_point)*4 * (4 - ((available_src_rows + 3) >> 2));
585 
586           const __m256i zero_point_v = _mm256_set1_epi16(zero_point);
587           const __mmask32 row_mask =
588               (static_cast<std::uint64_t>(1) << available_src_rows) - 1;
589 
590           t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
591           t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
592           t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
593           t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);
594           dst_ptr = trailing_buf;
595         }
596 
597         r0 = _mm512_unpacklo_epi64(t0, t1);
598         r2 = _mm512_unpackhi_epi64(t0, t1);
599         r1 = _mm512_unpacklo_epi64(t2, t3);
600         r3 = _mm512_unpackhi_epi64(t2, t3);
601 
602         r1 = _mm512_permutex_epi64(r1, 0x4e);
603         r3 = _mm512_permutex_epi64(r3, 0x4e);
604 
605         t0 = _mm512_mask_blend_epi64(0xcc, r0, r1);
606         t1 = _mm512_mask_blend_epi64(0x33, r0, r1);
607         t2 = _mm512_mask_blend_epi64(0xcc, r2, r3);
608         t3 = _mm512_mask_blend_epi64(0x33, r2, r3);
609 
610         t1 = _mm512_permutex_epi64(t1, 0x4e);
611         t3 = _mm512_permutex_epi64(t3, 0x4e);
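        // In effect, the unpack/blend/permutex sequence above transposes 64-bit
        // groups (one column's 4 int16 row values), so each 512-bit store below
        // writes the 8-column half of one 4-row destination chunk.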
612 
613         _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 0 * 16 * 4),
614                             t0);
615         _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 2 * 16 * 4),
616                             t1);
617         _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 1 * 16 * 4),
618                             t2);
619         _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 3 * 16 * 4),
620                             t3);
621 
622         if (sums_ptr) {
623           sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
624                                             _mm512_madd_epi16(t0, ones_16bit));
625           sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
626                                             _mm512_madd_epi16(t1, ones_16bit));
627           sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
628                                             _mm512_madd_epi16(t2, ones_16bit));
629           sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
630                                             _mm512_madd_epi16(t3, ones_16bit));
631         }
632       }
633 
634       packed_ptr += 16 * kNumChunkedSrcRows;
635       src_ptr0 += src_inc0;
636       src_ptr1 += src_inc1;
637       src_ptr2 += src_inc2;
638       src_ptr3 += src_inc3;
639       src_ptr4 += src_inc4;
640       src_ptr5 += src_inc5;
641       src_ptr6 += src_inc6;
642       src_ptr7 += src_inc7;
643     }
644   }
645 
646   if (sums_ptr) {
647     const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
648 
649     __m256i sums =
650         _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
651     const __m512i idx =
652         _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
653 
654     const __m512i sums_2x8_32bit =
655         _mm512_permutexvar_epi32(idx, sums_8x2_32bit);
656     sums = _mm256_add_epi32(sums, sums_adjustment_v);
657     sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
658     sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));
659 
660     _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
661   }
662 }
663 
664 inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
665   const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
666   return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
667 }
668 
669 inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo,
670                            const float* addr_hi) {
671   const __m512 lower_filled =
672       _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo));
673   return _mm512_insertf32x8(lower_filled,
674                             _mm256_maskz_loadu_ps(row_mask, addr_hi), 1);
675 }
676 
677 inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) {
678   return _mm512_castpd_ps(
679       _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
680 }
681 
682 inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) {
683   return _mm512_castpd_ps(
684       _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
685 }
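
// Casting to double lanes lets unpacklo/unpackhi move two floats (one 64-bit lane)
// at a time; this is the middle step of the 4x4-group transpose in
// HalfPackFloatAvx512 below.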
686 
687 inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf,
688                                 int src_stride, int remaining_src_cols,
689                                 int src_rows, float* packed_ptr,
690                                 float* trailing_buf) {
691   const float* src_ptr0 = src_ptr;
692   const float* src_ptr1 = src_ptr0 + src_stride;
693   const float* src_ptr2 = src_ptr1 + src_stride;
694   const float* src_ptr3 = src_ptr2 + src_stride;
695   const float* src_ptr4 = src_ptr3 + src_stride;
696   const float* src_ptr5 = src_ptr4 + src_stride;
697   const float* src_ptr6 = src_ptr5 + src_stride;
698   const float* src_ptr7 = src_ptr6 + src_stride;
699   std::int64_t src_inc0 = 8;
700   std::int64_t src_inc1 = 8;
701   std::int64_t src_inc2 = 8;
702   std::int64_t src_inc3 = 8;
703   std::int64_t src_inc4 = 8;
704   std::int64_t src_inc5 = 8;
705   std::int64_t src_inc6 = 8;
706   std::int64_t src_inc7 = 8;
707   if (remaining_src_cols < 8) {
708     if (remaining_src_cols <= 0) {
709       src_ptr0 = zerobuf;
710       src_inc0 = 0;
711     }
712     if (remaining_src_cols <= 1) {
713       src_ptr1 = zerobuf;
714       src_inc1 = 0;
715     }
716     if (remaining_src_cols <= 2) {
717       src_ptr2 = zerobuf;
718       src_inc2 = 0;
719     }
720     if (remaining_src_cols <= 3) {
721       src_ptr3 = zerobuf;
722       src_inc3 = 0;
723     }
724     if (remaining_src_cols <= 4) {
725       src_ptr4 = zerobuf;
726       src_inc4 = 0;
727     }
728     if (remaining_src_cols <= 5) {
729       src_ptr5 = zerobuf;
730       src_inc5 = 0;
731     }
732     if (remaining_src_cols <= 6) {
733       src_ptr6 = zerobuf;
734       src_inc6 = 0;
735     }
736     src_ptr7 = zerobuf;
737     src_inc7 = 0;
738   }
739 
740   for (int k = 0; k < src_rows; k += 16) {
741     for (int m = 0; m < 2; ++m) {
742       const int available_src_rows = src_rows - k - 8 * m;
743       // Effectively,
744       // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m));
745       // but treat each case separately.
746       if (available_src_rows > 7) {
747         __m512 t0, t1, t2, t3;
748         __m512 r0, r1, r2, r3;
749 
750         t0 = LoaduTwo(src_ptr0, src_ptr4);
751         t1 = LoaduTwo(src_ptr1, src_ptr5);
752         t2 = LoaduTwo(src_ptr2, src_ptr6);
753         t3 = LoaduTwo(src_ptr3, src_ptr7);
754 
755         r0 = _mm512_unpacklo_ps(t0, t1);
756         r2 = _mm512_unpackhi_ps(t0, t1);
757         r1 = _mm512_unpacklo_ps(t2, t3);
758         r3 = _mm512_unpackhi_ps(t2, t3);
759 
760         t0 = Mm512UnpackloPsx2(r0, r1);
761         t2 = Mm512UnpackhiPsx2(r0, r1);
762         t1 = Mm512UnpackloPsx2(r2, r3);
763         t3 = Mm512UnpackhiPsx2(r2, r3);
764 
765         r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
766         r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
767         r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
768         r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);
769 
770         _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0));
771         _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
772         _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1));
773         _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
774         _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2));
775         _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
776         _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3));
777         _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1));
778       } else if (available_src_rows > 0) {
779         const __mmask8 row_mask =
780             (static_cast<std::uint32_t>(1) << available_src_rows) - 1;
781 
782         __m512 t0, t1, t2, t3;
783         __m512 r0, r1, r2, r3;
784 
785         t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4);
786         t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5);
787         t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6);
788         t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7);
789 
790         r0 = _mm512_unpacklo_ps(t0, t1);
791         r2 = _mm512_unpackhi_ps(t0, t1);
792         r1 = _mm512_unpacklo_ps(t2, t3);
793         r3 = _mm512_unpackhi_ps(t2, t3);
794 
795         t0 = Mm512UnpackloPsx2(r0, r1);
796         t2 = Mm512UnpackhiPsx2(r0, r1);
797         t1 = Mm512UnpackloPsx2(r2, r3);
798         t3 = Mm512UnpackhiPsx2(r2, r3);
799 
800         r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
801         r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
802         r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
803         r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);
804 
805         _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0));
806         _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
807         _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1));
808         _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
809         _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2));
810         _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
811         _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3));
812         // Do not store _mm512_extractf32x8_ps(r3, 1).
813       }
814 
815       packed_ptr += 16 * 8;
816       src_ptr0 += src_inc0;
817       src_ptr1 += src_inc1;
818       src_ptr2 += src_inc2;
819       src_ptr3 += src_inc3;
820       src_ptr4 += src_inc4;
821       src_ptr5 += src_inc5;
822       src_ptr6 += src_inc6;
823       src_ptr7 += src_inc7;
824     }
825   }
826 }
827 
828 inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) {
829   const int non_trailing_rows = src_rows & ~7;
830   for (int k = 0; k < non_trailing_rows; ++k) {
831     for (int j = 0; j < 8; ++j) {
832       packed_ptr[j] = 0.0f;
833     }
834     packed_ptr += 16;
835   }
836 }
837 
838 }  // namespace.
839 
840 void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
841                                std::int8_t input_xor,
842                                const std::int8_t* zerobuf, int src_stride,
843                                int remaining_src_cols, int src_rows,
844                                std::int8_t* packed_ptr,
845                                std::int32_t* sums_ptr) {
846   profiler::ScopeLabel label("Pack kAvx512 8bit");
847 
848   using Layout = PackImpl8bitAvx512::Layout;
849   constexpr int kHalfBlockOffset = 32;
850   RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
851   static constexpr int kHalfLayoutCols =
852       PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
853                                             // block.
854   RUY_DCHECK_EQ(kHalfLayoutCols, 8);
855   RUY_DCHECK_EQ(Layout::kCols, 16);
856   RUY_DCHECK_EQ(Layout::kRows, 4);
857 
858   // Each Layout::kRows group of 4 elements is contiguous in the input and in
859   // the packed output. We process 8 such chunks at a time, padding short ones.
860   constexpr int kNumRowChunks = 8;
861 
862   // Each packed block is 4*16, and there are normally 8. The trailing block is
863   // only slightly shorter.
864   constexpr int kTrailingBufSize =
865       kNumRowChunks * Layout::kCols * Layout::kRows;
866   std::int8_t trailing_buf[kTrailingBufSize];
867   memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
868   constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
869 
870   std::int32_t* second_sums_ptr =
871       sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
872   if (remaining_src_cols > kHalfLayoutCols) {
873     HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
874                        remaining_src_cols, src_rows, packed_ptr, sums_ptr,
875                        trailing_buf);
876     HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor,
877                        zerobuf, src_stride,
878                        remaining_src_cols - kHalfLayoutCols, src_rows,
879                        packed_ptr + kHalfBlockOffset, second_sums_ptr,
880                        trailing_buf + kHalfBlockOffset);
881   } else {
882     HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
883                        remaining_src_cols, src_rows, packed_ptr, sums_ptr,
884                        trailing_buf);
885     ZeroHalfAvx512<PackImpl8bitAvx512, std::int8_t>(
886         src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset,
887         kChunkedRowMask);
888     // The kernel may not need the second half-block's sums to be set.
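    // The absent half-block is zero-point filled (zerobuf[0] ^ input_xor), so each
    // column sum is taken to be that value times the padded row count,
    // (src_rows + 3) & ~3.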
889     if (second_sums_ptr) {
890       for (int i = 0; i < kHalfLayoutCols; ++i) {
891         second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
892       }
893     }
894   }
895   const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
896   // If the number of source rows is not a multiple of the chunked row count
897   // (kChunkedRowMask + 1), there will be data in the trailing buffer.
898   if (trailing_data) {
899     const int non_trailing_rows = src_rows & ~kChunkedRowMask;
900     // Destination "rows" are padded to next highest multiple of Layout::kRows.
901     const int dst_rows = (src_rows + 3) & ~3;
902     const int trailing_rows = dst_rows - non_trailing_rows;
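    // For example, with src_rows = 70: non_trailing_rows = 64, dst_rows = 72 and
    // trailing_rows = 8, so 16 * 8 trailing values are copied below.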
903     memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
904            Layout::kCols * trailing_rows * sizeof(std::int8_t));
905   }
906 }
907 
908 void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
909                                 const std::int16_t* zerobuf, int src_stride,
910                                 int remaining_src_cols, int src_rows,
911                                 std::int16_t* packed_ptr,
912                                 std::int32_t* sums_ptr) {
913   profiler::ScopeLabel label("Pack kAvx512 16bit");
914 
915   using Layout = PackImpl16bitAvx512::Layout;
916   constexpr int kHalfBlockOffset = 32;
917   RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
918   static constexpr int kHalfLayoutCols =
919       PackImpl16bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
920                                              // block.
921   RUY_DCHECK_EQ(kHalfLayoutCols, 8);
922   RUY_DCHECK_EQ(Layout::kCols, 16);
923   RUY_DCHECK_EQ(Layout::kRows, 4);
924 
925   // Each Layout::kRows group of 4 elements is contiguous in the input and in
926   // the packed output. We process 4 such chunks at a time, padding short ones.
927   constexpr int kNumRowChunks = 4;
928 
929   // Each packed block is 4*16, and there are normally 4. The trailing block is
930   // only slightly shorter.
931   constexpr int kTrailingBufSize =
932       kNumRowChunks * Layout::kCols * Layout::kRows;
933   std::int16_t trailing_buf[kTrailingBufSize] = {0};
934   constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
935 
936   std::int32_t* second_sums_ptr =
937       sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
938   if (remaining_src_cols > kHalfLayoutCols) {
939     HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
940                         src_rows, packed_ptr, sums_ptr, trailing_buf);
941     HalfPack16bitAvx512(src_ptr + src_stride * kHalfLayoutCols, zerobuf,
942                         src_stride, remaining_src_cols - kHalfLayoutCols,
943                         src_rows, packed_ptr + kHalfBlockOffset,
944                         second_sums_ptr, trailing_buf + kHalfBlockOffset);
945   } else {
946     HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
947                         src_rows, packed_ptr, sums_ptr, trailing_buf);
948     ZeroHalfAvx512<PackImpl16bitAvx512, std::int16_t>(
949         src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset, kChunkedRowMask);
950     // The kernel may not need the second half-blocks sums to be set.
951     if (second_sums_ptr) {
952       for (int i = 0; i < kHalfLayoutCols; ++i) {
953         second_sums_ptr[i] = (zerobuf[0]) * ((src_rows + 3) & ~3);
954       }
955     }
956   }
957   const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
958   // If the number of source rows is not a multiple of the chunked row count
959   // (kChunkedRowMask + 1), there will be data in the trailing buffer.
960   if (trailing_data) {
961     const int non_trailing_rows = src_rows & ~kChunkedRowMask;
962     // Destination "rows" are padded to next highest multiple of Layout::kRows.
963     const int dst_rows = (src_rows + 3) & ~3;
964     const int trailing_rows = dst_rows - non_trailing_rows;
965     memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
966            Layout::kCols * trailing_rows * sizeof(std::int16_t));
967   }
968 }
969 
970 void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
971                                 int src_stride, int remaining_src_cols,
972                                 int src_rows, float* packed_ptr) {
973   profiler::ScopeLabel label("Pack kAvx512 float");
974   float trailing_buf[7 * 16];
975   if (remaining_src_cols > 8) {
976     HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
977                         src_rows, packed_ptr, trailing_buf);
978     HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
979                         remaining_src_cols - 8, src_rows, packed_ptr + 8,
980                         trailing_buf + 8);
981   } else {
982     memset(trailing_buf, 0, sizeof(trailing_buf));
983     HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
984                         src_rows, packed_ptr, trailing_buf);
985     ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
986   }
987   const int trailing_rows = src_rows & 7;
988   if (trailing_rows > 0) {
989     const int non_trailing_rows = src_rows & ~7;
990     memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
991            16 * trailing_rows * sizeof(float));
992   }
993 }
994 
995 void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
996                                int src_zero_point, std::int8_t* packed_ptr,
997                                int packed_stride, int start_col, int end_col,
998                                int src_cols, int block_row, int src_rows,
999                                int input_xor, std::int32_t* sums) {
1000   int col = start_col;
1001   int src_end_col = std::min(end_col, src_cols);
1002 
1003   for (; col <= src_end_col - 16; col += 16) {
1004     std::int8_t* dst_ptr = packed_ptr;
1005     __m128i val0, val1, val2, val3;
1006     __m128i input_xor_dup = _mm_set1_epi8(input_xor);
1007     // Load a 4x16 block.
1008     if (block_row + 4 <= src_rows) {
1009       val0 = _mm_loadu_si128(
1010           reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
1011       val1 = _mm_loadu_si128(
1012           reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
1013       val2 = _mm_loadu_si128(
1014           reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
1015       val3 = _mm_loadu_si128(
1016           reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
1017     } else {
1018       val0 = _mm_set1_epi8(src_zero_point);
1019       val1 = val0;
1020       val2 = val0;
1021       val3 = val0;
1022       if (block_row + 0 < src_rows)
1023         val0 = _mm_loadu_si128(
1024             reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
1025       if (block_row + 1 < src_rows)
1026         val1 = _mm_loadu_si128(
1027             reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
1028       if (block_row + 2 < src_rows)
1029         val2 = _mm_loadu_si128(
1030             reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
1031       if (block_row + 3 < src_rows)
1032         val3 = _mm_loadu_si128(
1033             reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
1034     }
1035     // Maybe xor the sign bit to convert from uint8 to int8.
1036     val0 = _mm_xor_si128(val0, input_xor_dup);
1037     val1 = _mm_xor_si128(val1, input_xor_dup);
1038     val2 = _mm_xor_si128(val2, input_xor_dup);
1039     val3 = _mm_xor_si128(val3, input_xor_dup);
1040     // Update the sums.
1041     __m256i val16_0 = _mm256_cvtepi8_epi16(val0);
1042     __m256i val16_1 = _mm256_cvtepi8_epi16(val1);
1043     __m256i val16_2 = _mm256_cvtepi8_epi16(val2);
1044     __m256i val16_3 = _mm256_cvtepi8_epi16(val3);
1045     __m256i new_sum16 = _mm256_add_epi16(_mm256_add_epi16(val16_0, val16_1),
1046                                          _mm256_add_epi16(val16_2, val16_3));
1047     __m512i sum =
1048         _mm512_loadu_si512(reinterpret_cast<const __m512i*>(sums + col));
1049     sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(new_sum16));
1050     _mm512_storeu_si512(reinterpret_cast<__m512i*>(sums + col), sum);
1051     auto zip = [](__m128i x, __m128i y) {
1052       auto perm_64_0_64_0 = [](__m128i x) {
1053         return _mm256_permutexvar_epi64(_mm256_setr_epi64x(0, 2, 1, 3),
1054                                         _mm256_castsi128_si256(x));
1055       };
1056       return _mm256_unpacklo_epi8(perm_64_0_64_0(x), perm_64_0_64_0(y));
1057     };
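    // zip(x, y) interleaves the bytes of two source rows column by column
    // (x0, y0, x1, y1, ...); the 16-bit unpacks below then merge two such pairs so
    // that each of the 16 destination columns gets its 4 row values contiguously,
    // matching the 4x16 col-major packed layout.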
1058     __m256i t2_val0 = zip(val0, val1);
1059     __m256i t2_val1 = zip(val2, val3);
1060     __m256i t4_val0 = _mm256_unpacklo_epi16(t2_val0, t2_val1);
1061     __m256i t4_val1 = _mm256_unpackhi_epi16(t2_val0, t2_val1);
1062     _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr),
1063                      _mm256_extractf128_si256(t4_val0, 0));
1064     _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16),
1065                      _mm256_extractf128_si256(t4_val1, 0));
1066     _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 32),
1067                      _mm256_extractf128_si256(t4_val0, 1));
1068     _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 48),
1069                      _mm256_extractf128_si256(t4_val1, 1));
1070     src_ptr += 16;
1071     packed_ptr += packed_stride * 16;
1072   }
1073   for (; col < src_end_col; col++) {
1074     std::int32_t accum = 0;
1075     for (int r = 0; r < 4; r++) {
1076       std::int8_t packed_val;
1077       if (block_row + r < src_rows) {
1078         packed_val = input_xor ^ src_ptr[r * src_stride];
1079       } else {
1080         packed_val = input_xor ^ src_zero_point;
1081       }
1082       accum += packed_val;
1083       *packed_ptr++ = packed_val;
1084     }
1085     if (sums) {
1086       sums[col] += accum;
1087     }
1088     src_ptr++;
1089   }
1090   for (; col < end_col; col++) {
1091     std::memset(packed_ptr, 0, 4);
1092     packed_ptr += 4;
1093   }
1094 }
1095 
1096 #endif  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
1097 
1098 }  // namespace ruy
1099