/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_KERNEL_X86_H_
#define RUY_RUY_KERNEL_X86_H_

#include <cstdint>
#include <cstring>

#include "ruy/kernel_common.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/tune.h"

namespace ruy {

#if RUY_PLATFORM_X86

RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx)
RUY_INHERIT_KERNEL(Path::kAvx, Path::kAvx2Fma)
RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512)

void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);

template <typename DstScalar>
struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
  static constexpr Path kPath = Path::kAvx512;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx512SingleCol(params);
    } else {
      Kernel8bitAvx512(params);
    }
  }
};

template <typename DstScalar>
struct Kernel<Path::kAvx512, std::int8_t, std::int16_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kAvx512;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx512SingleCol(params);
    } else {
      Kernel8bitAvx512(params);
    }
  }
};

void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params);

template <>
struct Kernel<Path::kAvx512, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx512;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvx512SingleCol(params);
    } else {
      KernelFloatAvx512(params);
    }
  }
};

void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);

template <typename DstScalar>
struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kAvx2Fma;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx2SingleCol(params);
    } else {
      Kernel8bitAvx2(params);
    }
  }
};

template <typename DstScalar>
struct Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kAvx2Fma;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx2SingleCol(params);
    } else {
      Kernel8bitAvx2(params);
    }
  }
};

void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);

template <>
struct Kernel<Path::kAvx2Fma, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx2Fma;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvx2SingleCol(params);
    } else {
      KernelFloatAvx2(params);
    }
  }
};

void KernelFloatAvx(const KernelParamsFloat<8, 8>& params);
void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params);

template <>
struct Kernel<Path::kAvx, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvxSingleCol(params);
    } else {
      KernelFloatAvx(params);
    }
  }
};

void Kernel8bitAvx(const KernelParams8bit<8, 8>& params);
void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params);

template <typename DstScalar>
struct Kernel<Path::kAvx, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
  static constexpr Path kPath = Path::kAvx;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvxSingleCol(params);
    } else {
      Kernel8bitAvx(params);
    }
  }
};

#endif  // RUY_PLATFORM_X86
}  // namespace ruy

#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))

#include <immintrin.h>  // IWYU pragma: keep

namespace ruy {
namespace {
namespace intrin_utils {

// Defined as a template so clang won't detect it as an unneeded
// definition.
template <Path path>
inline float mm256_get1_ps(const __m256 a, int i) {
  __m256i ai = _mm256_castps_si256(a);
  int float_val_as_int;
  switch (i) {
    case 0:
      float_val_as_int = _mm256_extract_epi32(ai, 0);
      break;
    case 1:
      float_val_as_int = _mm256_extract_epi32(ai, 1);
      break;
    case 2:
      float_val_as_int = _mm256_extract_epi32(ai, 2);
      break;
    case 3:
      float_val_as_int = _mm256_extract_epi32(ai, 3);
      break;
    case 4:
      float_val_as_int = _mm256_extract_epi32(ai, 4);
      break;
    case 5:
      float_val_as_int = _mm256_extract_epi32(ai, 5);
      break;
    case 6:
      float_val_as_int = _mm256_extract_epi32(ai, 6);
      break;
    case 7:
      float_val_as_int = _mm256_extract_epi32(ai, 7);
      break;
    default:
      RUY_DCHECK_LT(i, 8);
      return .0f;
  }
  float float_val;
  std::memcpy(&float_val, &float_val_as_int, sizeof(float_val));
  return float_val;
}

// Defined as a template so clang won't detect it as an unneeded
// definition.
template <Path path>
inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
  for (int i = 0; i < residual_rows; ++i) {
    dst[i] = intrin_utils::mm256_get1_ps<path>(v, i);
  }
}

template <Path path>
inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) {
  // Specializations added for AVX and AVX2FMA paths in their respective kernel
  // files.
  RUY_DCHECK(false);
  return _mm256_set1_ps(0);
}

template <Path path>
inline __m256i mm256_shuffle_epi8(const __m256i&, const __m256i&) {
  // Specializations added for AVX and AVX2FMA paths in their respective kernel
  // files.
  RUY_DCHECK(false);
  return _mm256_set1_epi32(0);
}

// Polyfill for _mm_storeu_si16(dst, v).
template <Path path>
inline void mm_storeu_si16(void* dst, __m128i v) {
#if (defined __clang__) || (defined _MSC_VER)
  _mm_storeu_si16(dst, v);
#else
  // GCC 9 lacks support for _mm_storeu_si16.
  *static_cast<std::int16_t*>(dst) = _mm_extract_epi16(v, 0);
#endif
}

// Polyfill for _mm_storeu_si32(dst, v).
template <Path path>
inline void mm_storeu_si32(void* dst, __m128i v) {
#if (defined __clang__) || (defined _MSC_VER)
  _mm_storeu_si32(dst, v);
#else
  // GCC 9 lacks support for _mm_storeu_si32.
  *static_cast<std::int32_t*>(dst) = _mm_extract_epi32(v, 0);
#endif
}

// Polyfill for _mm_loadu_si32(src).
template <Path path>
inline __m128i mm_loadu_si32(const void* src) {
#if (defined __clang__) || (defined _MSC_VER)
  return _mm_loadu_si32(src);
#else
  // GCC 9 lacks support for _mm_loadu_si32.
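  // movss from memory fills the low 32 bits of the xmm register and zeroes
  // the upper bits, matching _mm_loadu_si32 semantics.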
  __m128i res;
  asm("movss %[src], %[res]"
      : [res] "=x"(res)
      : [src] "m"(*static_cast<const int*>(src)));
  return res;
#endif
}

template <Path path>
inline __m128i mm256_extracti128_si256(const __m256i&, const int) {
  RUY_DCHECK(false);
  return _mm_setzero_si128();
}

template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows,
                                         const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
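  // The constant 0x0c080400, replicated into every 32-bit element, spells the
  // byte indices 0x00, 0x04, 0x08, 0x0c consumed by the shuffle below.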
  __m256i shuffled_v;
  if (residual_rows > 1) {
    // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4
    // in each 128-bit lane.
    shuffled_v = intrin_utils::mm256_shuffle_epi8<path>(v, repack_perm);
  }
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      dst[0] = _mm256_extract_epi8(v, 0);
      break;
    case 2:
      mm_storeu_si16<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      break;
    case 3: {
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 0);
      mm_storeu_si16<path>(dst, trailing_packed);
      dst[2] = _mm_extract_epi8(trailing_packed, 2);
      break;
    }
    case 4:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      break;
    case 5:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      dst[4] = _mm256_extract_epi8(shuffled_v, 16);
      break;
    case 6:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      mm_storeu_si16<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    case 7: {
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1);
      mm_storeu_si16<path>(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi8(trailing_packed, 2);
      break;
    }
    case 8:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      mm_storeu_si32<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows,
                                         const __m256i v) {
  intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
      reinterpret_cast<std::uint8_t*>(dst), residual_rows, v);
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows,
                                          const __m256i v) {
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit integer to 16 bits.
  const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
  __m256i shuffled_v;
  __m128i shuffled_v_low;
  if (residual_rows > 1) {
    shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
    shuffled_v_low = mm256_extracti128_si256<path>(shuffled_v, 0);
  } else {
    shuffled_v_low = mm256_extracti128_si256<path>(v, 0);
  }
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      mm_storeu_si16<path>(dst, shuffled_v_low);
      break;
    case 2:
      mm_storeu_si32<path>(dst, shuffled_v_low);
      break;
    case 3: {
      mm_storeu_si32<path>(dst, shuffled_v_low);
      dst[2] = _mm_extract_epi16(shuffled_v_low, 2);
      break;
    }
    case 4:
      _mm_storeu_si64(dst, shuffled_v_low);
      break;
    case 5:
      _mm_storeu_si64(dst, shuffled_v_low);
      dst[4] = _mm256_extract_epi16(shuffled_v, 8);
      break;
    case 6:
      _mm_storeu_si64(dst, shuffled_v_low);
      mm_storeu_si32<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    case 7: {
      _mm_storeu_si64(dst, shuffled_v_low);
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1);
      mm_storeu_si32<path>(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi16(trailing_packed, 2);
      break;
    }
    case 8:
      _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) {
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit integer to 16 bits.
  const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows,
                                 const __m256i v) {
  const __m128i v_low = mm256_extracti128_si256<path>(v, 0);
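  // v_low holds elements 0-3 of v; the upper half of v is only touched when
  // more than four rows must be stored.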
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      mm_storeu_si32<path>(dst, v_low);
      break;
    case 2:
      _mm_storeu_si64(dst, v_low);
      break;
    case 3: {
      __m128i trailing_packed = v_low;
      _mm_storeu_si64(dst, trailing_packed);
      dst[2] = _mm_extract_epi32(trailing_packed, 2);
      break;
    }
    case 4:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      break;
    case 5:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      dst[4] = _mm256_extract_epi32(v, 4);
      break;
    case 6:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(v, 1));
      break;
    case 7: {
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      __m128i trailing_packed = mm256_extracti128_si256<path>(v, 1);
      _mm_storeu_si64(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi32(trailing_packed, 2);
      break;
    }
    case 8:
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) {
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
}

// Transpose an 8x8 matrix of floats.
template <Path path>
void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3,
                           __m256* v4, __m256* v5, __m256* v6, __m256* v7) {
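  // Stage 1: interleave adjacent rows element-wise within each 128-bit lane,
  // producing transposed 2x2 blocks.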
  __m256 t2x2_0 = _mm256_unpacklo_ps(*v0, *v1);
  __m256 t2x2_1 = _mm256_unpackhi_ps(*v0, *v1);
  __m256 t2x2_2 = _mm256_unpacklo_ps(*v2, *v3);
  __m256 t2x2_3 = _mm256_unpackhi_ps(*v2, *v3);
  __m256 t2x2_4 = _mm256_unpacklo_ps(*v4, *v5);
  __m256 t2x2_5 = _mm256_unpackhi_ps(*v4, *v5);
  __m256 t2x2_6 = _mm256_unpacklo_ps(*v6, *v7);
  __m256 t2x2_7 = _mm256_unpackhi_ps(*v6, *v7);
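  // Stage 2: combine the 2x2 blocks into transposed 4x4 blocks, still within
  // each 128-bit lane.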
  __m256 t4x4_0 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_1 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_2 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_3 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_4 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_5 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_6 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_7 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(3, 2, 3, 2));
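  // Stage 3: exchange 128-bit lanes (0x20 takes the low lanes, 0x31 the high
  // lanes) to assemble the final transposed rows.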
  *v0 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x20);
  *v1 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x20);
  *v2 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x20);
  *v3 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x20);
  *v4 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x31);
  *v5 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x31);
  *v6 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x31);
  *v7 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x31);
}

// Transpose an 8x8 matrix of int32s.
template <Path path>
void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2,
                              __m256i* v3, __m256i* v4, __m256i* v5,
                              __m256i* v6, __m256i* v7) {
  mm256_transpose8x8_ps<path>(
      reinterpret_cast<__m256*>(v0), reinterpret_cast<__m256*>(v1),
      reinterpret_cast<__m256*>(v2), reinterpret_cast<__m256*>(v3),
      reinterpret_cast<__m256*>(v4), reinterpret_cast<__m256*>(v5),
      reinterpret_cast<__m256*>(v6), reinterpret_cast<__m256*>(v7));
}

}  // namespace intrin_utils
}  // namespace

template <Path path>
inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
  // The strides in params are given in bytes; scale them down by sizeof(float)
  // to get strides in float elements.
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;
  //
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // AVX/AVX2 float block size = 8.
  const int end_row = std::min(params.dst_rows, params.last_row + 8);
  const int end_col = std::min(params.dst_cols, params.last_col + 8);
  //
  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_ptr = params.bias;

  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
  const bool channel_dimension_is_col =
      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;

  int col = params.start_col;
  // Loop over columns in blocks of 8 (the float block size); an incomplete
  // remainder block, if any, is handled after this loop.
  for (; col <= end_col - 8; col += 8) {
    __m256 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 8) {
      const int residual_rows = std::min(end_row - row, 8);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Initialize with bias.
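      // Column-channel bias broadcasts one bias scalar per destination column;
      // row-channel bias loads 8 per-row values shared by all 8 columns.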
      if (channel_dimension_is_col) {
        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
        }
      } else {
        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);

        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = initial_accum_data;
        }
      }

      const float* lhs_ptr = lhs_col_ptr;
      const float* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; ++d) {
        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
        // Load 8 RHS values, then use permute instructions to broadcast each
        // value to a register. _mm256_permute2f128_ps is slow on AMD.
        __m256 rhs0_3 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
        __m256 rhs4_7 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));

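        // The permute immediates 0, 85, 170, 255 (0x00, 0x55, 0xAA, 0xFF)
        // broadcast elements 0-3 of each 128-bit lane across that lane.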
        const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
        accum_data_v[0] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_0, accum_data_v[0]);

        const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
        accum_data_v[1] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_1, accum_data_v[1]);

        const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
        accum_data_v[2] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_2, accum_data_v[2]);

        const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
        accum_data_v[3] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_3, accum_data_v[3]);

        const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
        accum_data_v[4] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_4, accum_data_v[4]);

        const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
        accum_data_v[5] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_5, accum_data_v[5]);

        const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
        accum_data_v[6] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_6, accum_data_v[6]);

        const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
        accum_data_v[7] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_7, accum_data_v[7]);

        lhs_ptr += 8;
        rhs_ptr += 8;
      }

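      // Clamp the accumulators to [clamp_min, clamp_max] and store the 8x8
      // block (or only residual_rows rows of it).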
      if (residual_rows == 8) {
        for (int j = 0; j < 8; ++j) {
          float* block_ptr = dst_ptr + j * dst_stride;
          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
          _mm256_storeu_ps(block_ptr, accum_data_v[j]);
        }
      } else {
        for (int j = 0; j < 8; ++j) {
          float* block_ptr = dst_ptr + j * dst_stride;
          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
          intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
                                                accum_data_v[j]);
        }
      }
    }  // End row-block loop.
  }    // End col-block loop.

  if (col < end_col) {
    // Remaining cols in [0, float block size).
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 8);

    __m256 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
    const int residual_cols = std::min(end_col - col, 8);

    for (int row = params.start_row; row < end_row; row += 8) {
      const int residual_rows = std::min(end_row - row, 8);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Initialize with bias.
      if (channel_dimension_is_col) {
        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
        }
      } else {
        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);

        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = initial_accum_data;
        }
      }

      const float* lhs_ptr = lhs_col_ptr;
      const float* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; ++d) {
        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);

        __m256 rhs0_3 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
        __m256 rhs4_7 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));

        const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
        accum_data_v[0] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_0, accum_data_v[0]);

        const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
        accum_data_v[1] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_1, accum_data_v[1]);

        const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
        accum_data_v[2] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_2, accum_data_v[2]);

        const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
        accum_data_v[3] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_3, accum_data_v[3]);

        const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
        accum_data_v[4] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_4, accum_data_v[4]);

        const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
        accum_data_v[5] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_5, accum_data_v[5]);

        const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
        accum_data_v[6] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_6, accum_data_v[6]);

        const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
        accum_data_v[7] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_7, accum_data_v[7]);

        lhs_ptr += 8;
        rhs_ptr += 8;
      }

      for (int j = 0; j < residual_cols; ++j) {
        float* block_ptr = dst_ptr + j * dst_stride;
        accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
        accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
        intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
                                              accum_data_v[j]);
      }
    }  // End row-block loop.
  }    // End col-block terminal conditional.
}

template <Path path>
inline void KernelFloatAvxCommonSingleCol(
    const KernelParamsFloat<8, 8>& params) {
  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  // The strides in params are given in bytes; scale them down by sizeof(float)
  // to get strides in float elements.
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  //
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // AVX/AVX2 float block size = 8.
  const int end_row = std::min(params.dst_rows, params.last_row + 8);

  float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_col_ptr = params.bias;

  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);

  __m256 accum_data_v;

  const float* rhs_col_ptr = params.rhs_base_ptr;
  float* dst_col_ptr = adj_dst_col_ptr;

  int row = params.start_row;
  for (; row <= end_row - 8; row += 8) {
    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
    float* dst_ptr = dst_col_ptr + row;
    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

    // Initialize with bias.
    accum_data_v = _mm256_loadu_ps(bias_ptr);

    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr;
    int d = 0;
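    // Unrolled by 4 over the depth dimension: each iteration consumes four
    // slices of 8 LHS floats and the four matching RHS values (at stride 8 in
    // the packed RHS block).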
    for (; d <= params.depth - 4; d += 4) {
      const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
      const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_0, dup_rhs_element_0,
                                                accum_data_v);
      const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
      const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_1, dup_rhs_element_1,
                                                accum_data_v);

      const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
      const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_2, dup_rhs_element_2,
                                                accum_data_v);
      const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
      const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_3, dup_rhs_element_3,
                                                accum_data_v);
      lhs_ptr += 32;  // Loaded 8 * 4 floats.
      rhs_ptr += 32;
    }
    for (; d < params.depth; ++d) {
      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
      const float* rhs_data = rhs_ptr;

      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
      accum_data_v =
          intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
      lhs_ptr += 8;
      rhs_ptr += 8;
    }

    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
    _mm256_storeu_ps(dst_ptr, accum_data_v);
  }  // End row-block loop.

  if (row < end_row) {
    const int residual_rows = end_row - row;
    RUY_CHECK_GE(residual_rows, 1);
    RUY_CHECK_LT(residual_rows, 8);

    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
    float* dst_ptr = dst_col_ptr + row;
    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

    // Initialize with bias.
    accum_data_v = _mm256_loadu_ps(bias_ptr);

    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; ++d) {
      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
      const float* rhs_data = rhs_ptr;

      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
      accum_data_v =
          intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
      lhs_ptr += 8;
      rhs_ptr += 8;
    }

    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
    intrin_utils::mm256_n_storeu_ps<path>(dst_ptr, residual_rows, accum_data_v);
  }  // End handling of residual rows.
}
}  // namespace ruy
#endif  //  (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)

#endif  // RUY_RUY_KERNEL_X86_H_