xref: /aosp_15_r20/external/ruy/ruy/kernel_avx.cc (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1 /* Copyright 2020 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <cstdint>
18 #include <cstring>
19 
20 #include "ruy/check_macros.h"
21 #include "ruy/kernel_common.h"
22 #include "ruy/kernel_x86.h"
23 #include "ruy/opt_set.h"
24 #include "ruy/platform.h"
25 #include "ruy/profiler/instrumentation.h"
26 
27 #if RUY_PLATFORM_AVX && RUY_OPT(ASM)
28 #include <immintrin.h>  // IWYU pragma: keep
29 #endif
30 
31 namespace ruy {
32 
33 #if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))
34 
35 void Kernel8bitAvx(const KernelParams8bit<8, 8>&) {
36   // CPU-ID-based checks should disable the path that would reach this point.
37   RUY_DCHECK(false);
38 }
39 
40 void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>&) {
41   // CPU-ID-based checks should disable the path that would reach this point.
42   RUY_DCHECK(false);
43 }
44 
45 void KernelFloatAvx(const KernelParamsFloat<8, 8>&) {
46   // CPU-ID-based checks should disable the path that would reach this point.
47   RUY_DCHECK(false);
48 }
49 
50 void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) {
51   // CPU-ID-based checks should disable the path that would reach this point.
52   RUY_DCHECK(false);
53 }
54 
55 #else  // RUY_PLATFORM_AVX && RUY_OPT(ASM)
56 
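// The 8-bit kernel computes 8x8 blocks of the destination and consumes the
// depth dimension in steps of 4: each step loads 8x4 LHS and 4x8 RHS int8
// values, which are widened to 16 bits and accumulated in pairs via
// madd_epi16.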
57 static constexpr int kAvx8bitBlockSize = 8;
58 static constexpr int kAvx8bitInnerSize = 4;
59 
60 namespace {
61 namespace intrin_utils {
62 
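// The helpers below emulate 256-bit integer operations that plain AVX lacks:
// each 256-bit input is split into its two 128-bit halves, the corresponding
// 128-bit SSE operation is applied to each half, and the halves are
// recombined with _mm256_set_m128i.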
63 template <>
64 inline __m256i mm256_shuffle_epi8<Path::kAvx>(const __m256i& a,
65                                               const __m256i& b) {
66   __m128i a_lo = _mm256_extractf128_si256(a, 0);
67   __m128i a_hi = _mm256_extractf128_si256(a, 1);
68   __m128i b_lo = _mm256_extractf128_si256(b, 0);
69   __m128i b_hi = _mm256_extractf128_si256(b, 1);
70   __m128i dst_lo = _mm_shuffle_epi8(a_lo, b_lo);
71   __m128i dst_hi = _mm_shuffle_epi8(a_hi, b_hi);
72   return _mm256_set_m128i(dst_hi, dst_lo);
73 }
74 
75 template <>
76 inline __m128i mm256_extracti128_si256<Path::kAvx>(const __m256i& a,
77                                                    const int imm) {
78   switch (imm) {
79     case 0:
80       return _mm256_extractf128_si256(a, 0);
81     case 1:
82       return _mm256_extractf128_si256(a, 1);
83     default:
84       RUY_DCHECK_LT(imm, 2);
85       return _mm_setzero_si128();
86   }
87 }
88 
89 template <Path path>
90 inline __m256i mm256_cvtepi8_epi16(const __m128i& a) {
91   // Take the upper 64 bits of a and put in the first 64 bits of 'hi'
92   __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
93   return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a));
94 }
95 
96 template <Path path>
97 inline __m256i mm256_cvtepi32_epi64(const __m128i& a) {
98   // sign extend the 32-bit values in the lower 64 bits of a.
99   __m128i lo = _mm_cvtepi32_epi64(a);
100   __m128i hi = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(a, _mm_setzero_si128()));
101   return _mm256_set_m128i(hi, lo);
102 }
103 
104 inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b,
105                                  const int imm) {
106   __m128i tmp = _mm_setzero_si128();
107   if (!(imm & 8)) {
108     switch (imm & 3) {
109       case 0:
110         return _mm256_extractf128_si256(a, 0);
111       case 1:
112         return _mm256_extractf128_si256(a, 1);
113       case 2:
114         return _mm256_extractf128_si256(b, 0);
115       case 3:
116         return _mm256_extractf128_si256(b, 1);
117     }
118   }
119   return tmp;
120 }
121 
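// mm256_permute2x128_si256 follows the _mm256_permute2f128_si256 selector
// semantics: the low 4 bits of `imm` choose the result's low 128 bits and the
// high 4 bits choose its high 128 bits (0 = a.lo, 1 = a.hi, 2 = b.lo,
// 3 = b.hi; bit 3 zeroes that half). E.g. imm 0x20 yields [b.lo : a.lo] and
// imm 0x31 yields [b.hi : a.hi], as used to regroup lanes in the kernel below.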
122 template <Path path>
123 inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b,
124                                         const int imm) {
125   const int lo_imm = imm & 15;
126   __m128i lo = mm_permute_helper(a, b, lo_imm);
127   const int hi_imm = (imm >> 4) & 15;
128   __m128i hi = mm_permute_helper(a, b, hi_imm);
129   return _mm256_set_m128i(hi, lo);
130 }
131 
132 template <Path path>
133 inline __m256i mm256_max_epi32(const __m256i& a, const __m256i& b) {
134   __m128i a_lo = _mm256_extractf128_si256(a, 0);
135   __m128i a_hi = _mm256_extractf128_si256(a, 1);
136   __m128i b_lo = _mm256_extractf128_si256(b, 0);
137   __m128i b_hi = _mm256_extractf128_si256(b, 1);
138   __m128i lo = _mm_max_epi32(a_lo, b_lo);
139   __m128i hi = _mm_max_epi32(a_hi, b_hi);
140   return _mm256_set_m128i(hi, lo);
141 }
142 
143 template <Path path>
144 inline __m256i mm256_min_epi32(const __m256i& a, const __m256i& b) {
145   __m128i a_lo = _mm256_extractf128_si256(a, 0);
146   __m128i a_hi = _mm256_extractf128_si256(a, 1);
147   __m128i b_lo = _mm256_extractf128_si256(b, 0);
148   __m128i b_hi = _mm256_extractf128_si256(b, 1);
149   __m128i lo = _mm_min_epi32(a_lo, b_lo);
150   __m128i hi = _mm_min_epi32(a_hi, b_hi);
151   return _mm256_set_m128i(hi, lo);
152 }
153 
154 template <Path path>
155 inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) {
156   __m128i a_lo = _mm256_extractf128_si256(a, 0);
157   __m128i a_hi = _mm256_extractf128_si256(a, 1);
158   __m128i b_lo = _mm256_extractf128_si256(b, 0);
159   __m128i b_hi = _mm256_extractf128_si256(b, 1);
160   __m128i lo = _mm_add_epi32(a_lo, b_lo);
161   __m128i hi = _mm_add_epi32(a_hi, b_hi);
162   return _mm256_set_m128i(hi, lo);
163 }
164 
165 template <Path path>
166 inline __m256i mm256_add_epi64(const __m256i& a, const __m256i& b) {
167   __m128i a_lo = _mm256_extractf128_si256(a, 0);
168   __m128i a_hi = _mm256_extractf128_si256(a, 1);
169   __m128i b_lo = _mm256_extractf128_si256(b, 0);
170   __m128i b_hi = _mm256_extractf128_si256(b, 1);
171   __m128i lo = _mm_add_epi64(a_lo, b_lo);
172   __m128i hi = _mm_add_epi64(a_hi, b_hi);
173   return _mm256_set_m128i(hi, lo);
174 }
175 
176 template <Path path>
177 inline __m256i mm256_slli_epi64(const __m256i& a, int imm) {
178   __m128i a_lo = _mm256_extractf128_si256(a, 0);
179   __m128i a_hi = _mm256_extractf128_si256(a, 1);
180   __m128i lo = _mm_slli_epi64(a_lo, imm);
181   __m128i hi = _mm_slli_epi64(a_hi, imm);
182   return _mm256_set_m128i(hi, lo);
183 }
184 
185 template <Path path>
186 inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) {
187   __m128i a_lo = _mm256_extractf128_si256(a, 0);
188   __m128i a_hi = _mm256_extractf128_si256(a, 1);
189   __m128i b_lo = _mm256_extractf128_si256(b, 0);
190   __m128i b_hi = _mm256_extractf128_si256(b, 1);
191   __m128i lo = _mm_mullo_epi32(a_lo, b_lo);
192   __m128i hi = _mm_mullo_epi32(a_hi, b_hi);
193   return _mm256_set_m128i(hi, lo);
194 }
195 
196 // Defined as a macro since `imm` must be an immediate.
197 #define BlendM128_epi32(a, b, imm) \
198   _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm))
199 
200 // Defined as a macro since `imm` must be an immediate.
201 #define BlendM128_epi64(a, b, imm) \
202   _mm_castpd_si128(_mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), imm))
203 
204 // Defined as a macro since `imm` must be an immediate.
205 #define mm256_blend_epi32(ans, a, b, imm)              \
206   __m128i a_lo = _mm256_extractf128_si256(a, 0);       \
207   __m128i a_hi = _mm256_extractf128_si256(a, 1);       \
208   __m128i b_lo = _mm256_extractf128_si256(b, 0);       \
209   __m128i b_hi = _mm256_extractf128_si256(b, 1);       \
210   __m128i lo = BlendM128_epi32(a_lo, b_lo, imm & 0xe); \
211   __m128i hi = BlendM128_epi32(a_hi, b_hi, imm >> 4);  \
212   ans = _mm256_set_m128i(hi, lo);
213 
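// Defined as a macro since `imm` must be an immediate, and because it writes
// into `ans` (using the caller-provided `a_lo`/`a_hi` temporaries) rather
// than returning a value.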
214 #define mm256_shuffle_epi32(ans, a, a_lo, a_hi, imm)   \
215   a_lo = _mm256_extractf128_si256(a, 0);               \
216   a_hi = _mm256_extractf128_si256(a, 1);               \
217   ans = _mm256_set_m128i(_mm_shuffle_epi32(a_hi, imm), \
218                          _mm_shuffle_epi32(a_lo, imm));
219 
220 template <Path path>
221 inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) {
222   __m128i a_lo = _mm256_extractf128_si256(a, 0);
223   __m128i a_hi = _mm256_extractf128_si256(a, 1);
224   __m128i b_lo = _mm256_extractf128_si256(b, 0);
225   __m128i b_hi = _mm256_extractf128_si256(b, 1);
226   __m128i lo = _mm_madd_epi16(a_lo, b_lo);
227   __m128i hi = _mm_madd_epi16(a_hi, b_hi);
228   return _mm256_set_m128i(hi, lo);
229 }
230 
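// mm_srlv_epi64 below emulates AVX2's _mm_srlv_epi64: as a scalar sketch,
// dst[0] = a[0] >> b[0] and dst[1] = a[1] >> b[1], with logical shifts on the
// two 64-bit lanes.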
231 inline __m128i mm_srlv_epi64(const __m128i& a, const __m128i& b) {
232   // shift both elements of a by lower 64bits of b.
233   __m128i res_lo = _mm_srl_epi64(a, b);
234   // shift both elements of a by upper 64bits of b.
235   __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128());
236   __m128i res_hi = _mm_srl_epi64(a, hi_count);
237   // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
238   // 1. Swap the upper and lower 64 bits of res_hi.
239   __m128i tmp_hi =
240       _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1));
241   // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
242   return _mm_unpacklo_epi64(res_lo, tmp_hi);
243 }
244 
245 template <Path path>
246 inline __m256i mm256_srlv_epi64(const __m256i& a, const __m256i& b) {
247   __m128i a_lo = _mm256_extractf128_si256(a, 0);
248   __m128i a_hi = _mm256_extractf128_si256(a, 1);
249   __m128i b_lo = _mm256_extractf128_si256(b, 0);
250   __m128i b_hi = _mm256_extractf128_si256(b, 1);
251   __m128i lo = mm_srlv_epi64(a_lo, b_lo);
252   __m128i hi = mm_srlv_epi64(a_hi, b_hi);
253   return _mm256_set_m128i(hi, lo);
254 }
255 
256 template <Path path>
257 inline __m128i mm_sllv_epi64(const __m128i& a, const __m128i& b) {
258   // shift both elements of a by lower 64bits of b.
259   __m128i res_lo = _mm_sll_epi64(a, b);
260   // shift both elements of a by upper 64bits of b.
261   __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128());
262   __m128i res_hi = _mm_sll_epi64(a, hi_count);
263   // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
264   // 1. Swap the upper and lower 64 bits of res_hi.
265   __m128i tmp_hi =
266       _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1));
267   // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
268   return _mm_unpacklo_epi64(res_lo, tmp_hi);
269 }
270 
271 template <Path path>
272 inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) {
273   __m128i a_lo = _mm256_extractf128_si256(a, 0);
274   __m128i a_hi = _mm256_extractf128_si256(a, 1);
275   __m128i b_lo = _mm256_extractf128_si256(b, 0);
276   __m128i b_hi = _mm256_extractf128_si256(b, 1);
277   __m128i lo = mm_sllv_epi64<path>(a_lo, b_lo);
278   __m128i hi = mm_sllv_epi64<path>(a_hi, b_hi);
279   return _mm256_set_m128i(hi, lo);
280 }
281 
282 #define PermuteM128_epi32(a, imm) \
283   _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm));
284 
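// mm_sllv_epi32 below emulates AVX2's _mm_sllv_epi32: as a scalar sketch,
// dst[i] = a[i] << b[i] for each of the four 32-bit lanes.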
285 inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) {
286   // shift all elements of a by first 32bits of b.
287   __m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1));
288 
289   // put bits 32-63 of b in the first slot.
290   __m128i tmp1 = PermuteM128_epi32(b, 1);
291   // put bits 32-63 of a in the first slot.
292   __m128i a1 = PermuteM128_epi32(a, 1);
293   // shift all elements of a by second 32bits of b.
294   __m128i res1 =
295       _mm_sll_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1));
296 
297   // put bits 64-95 of b in the first slot.
298   __m128i tmp2 = PermuteM128_epi32(b, 2);
299   // shift all elements of a by third 32bits of b.
300   __m128i res2 =
301       _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1));
302 
303   // put bits 96-127 of b in the first slot.
304   __m128i tmp3 = PermuteM128_epi32(b, 3);
305   // put bits 96-127 of a in the third slot.
306   __m128i a3 = PermuteM128_epi32(a, 48);
307   // shift all elements of a3 by fourth 32bits of b.
308   __m128i res3 =
309       _mm_sll_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1));
310 
311   // Take bits 0-31 of res0, bits 0-31 of res1,
312   // bits 64-95 of res2, and bits 64-95 of res3.
313   // res0 _ _ _ 0
314   // res1 _ _ _ 1
315   // res2 _ 2 _ _
316   // res3 _ 3 _ _
317   // f_01 _ _ 1 0
318   // f_23 _ _ 3 2
319   __m128i f_01 = _mm_unpacklo_epi32(res0, res1);
320   __m128i f_23 = _mm_unpackhi_epi32(res2, res3);
321   // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
322   return _mm_unpacklo_epi64(f_01, f_23);
323 }
324 
325 template <Path path>
326 inline __m256i mm256_sllv_epi32(const __m256i& a, const __m256i& b) {
327   __m128i a_lo = _mm256_extractf128_si256(a, 0);
328   __m128i a_hi = _mm256_extractf128_si256(a, 1);
329   __m128i b_lo = _mm256_extractf128_si256(b, 0);
330   __m128i b_hi = _mm256_extractf128_si256(b, 1);
331   __m128i lo = mm_sllv_epi32(a_lo, b_lo);
332   __m128i hi = mm_sllv_epi32(a_hi, b_hi);
333   return _mm256_set_m128i(hi, lo);
334 }
335 
336 template <Path path>
337 inline __m256i mm256_sub_epi32(const __m256i& a, const __m256i& b) {
338   __m128i a_lo = _mm256_extractf128_si256(a, 0);
339   __m128i a_hi = _mm256_extractf128_si256(a, 1);
340   __m128i b_lo = _mm256_extractf128_si256(b, 0);
341   __m128i b_hi = _mm256_extractf128_si256(b, 1);
342   __m128i lo = _mm_sub_epi32(a_lo, b_lo);
343   __m128i hi = _mm_sub_epi32(a_hi, b_hi);
344   return _mm256_set_m128i(hi, lo);
345 }
346 
347 template <Path path>
348 inline __m256i mm256_mul_epi32(const __m256i& a, const __m256i& b) {
349   __m128i a_lo = _mm256_extractf128_si256(a, 0);
350   __m128i a_hi = _mm256_extractf128_si256(a, 1);
351   __m128i b_lo = _mm256_extractf128_si256(b, 0);
352   __m128i b_hi = _mm256_extractf128_si256(b, 1);
353   __m128i lo = _mm_mul_epi32(a_lo, b_lo);
354   __m128i hi = _mm_mul_epi32(a_hi, b_hi);
355   return _mm256_set_m128i(hi, lo);
356 }
357 
358 // Perform the equivalent of mm256_permutevar8x32 with
359 // a second argument of {7, 5, 3, 1, 6, 4, 2, 0}
360 template <Path path>
361 inline __m256i PermuteEpi32EvenOdds(const __m256i& a) {
362   // a_lo = 3 2 1 0
363   __m128i a_lo = _mm256_extractf128_si256(a, 0);
364   // a_hi = 7 6 5 4
365   __m128i a_hi = _mm256_extractf128_si256(a, 1);
366   // shuffle a_lo to get 3 1 2 0
367   __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8);
368   // shuffle a_hi to get 7 5 6 4
369   __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8);
370   // unpack lo 64 of tmp_lo and tmp_hi to get 6 4 2 0
371   __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi);
372   // unpack hi 64 of tmp_lo and tmp_hi to get 7 5 3 1
373   __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi);
374   return _mm256_set_m128i(res_hi, res_lo);
375 }
376 
377 template <Path path>
378 inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) {
379   const __m256i bias0 = _mm256_set1_epi32(*(bias + offset));
380   return mm256_add_epi32<path>(a, bias0);
381 }
382 
383 __m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
384                            const __m256i& mask) {
385   __m256 result =
386       _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
387                        _mm256_castsi256_ps(mask));
388   return _mm256_castps_si256(result);
389 }
390 
391 template <Path path>
392 inline __m256i mm256_cmpgt_epi32(const __m256i& a, const __m256i& b) {
393   __m128i a_lo = _mm256_extractf128_si256(a, 0);
394   __m128i a_hi = _mm256_extractf128_si256(a, 1);
395   __m128i b_lo = _mm256_extractf128_si256(b, 0);
396   __m128i b_hi = _mm256_extractf128_si256(b, 1);
397   __m128i lo = _mm_cmpgt_epi32(a_lo, b_lo);
398   __m128i hi = _mm_cmpgt_epi32(a_hi, b_hi);
399   return _mm256_set_m128i(hi, lo);
400 }
401 
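// In mm256_srav_epi32 below, each of r0..r7 shifts a whole 128-bit half by a
// single lane's count from b, and the blends then keep only that lane, so the
// net effect is a per-lane arithmetic shift: dst[i] = a[i] >> b[i].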
402 template <Path path>
403 inline __m256i mm256_srav_epi32(const __m256i& a, const __m256i& b) {
404   __m128i a_lo = _mm256_extractf128_si256(a, 0);
405   __m128i a_hi = _mm256_extractf128_si256(a, 1);
406 
407   __m128i r0 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 0));
408   __m128i r1 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 1));
409   __m128i r2 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 2));
410   __m128i r3 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 3));
411   __m128i r4 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 4));
412   __m128i r5 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 5));
413   __m128i r6 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 6));
414   __m128i r7 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 7));
415 
416   // get element 0 from r0, element 1 from r1
417   __m128i r01 = BlendM128_epi32(r0, r1, 2);
418   // get element 2 from r2, element 3 from r3
419   __m128i r23 = BlendM128_epi32(r2, r3, 8);
420   // get element 0 from r4, element 1 from r5
421   __m128i r45 = BlendM128_epi32(r4, r5, 2);
422   // get element 2 from r6, element 3 from r7
423   __m128i r67 = BlendM128_epi32(r6, r7, 8);
424   // get lower 64 bits of r01, upper 64 bits of r23
425   __m128i r0123 = BlendM128_epi64(r01, r23, 2);
426   // get lower 64 bits of r45, upper 64 bits of r67
427   __m128i r4567 = BlendM128_epi64(r45, r67, 2);
428   return _mm256_set_m128i(r4567, r0123);
429 }
430 
431 // AVX doesn't have fused multiply-add so we define an inline function to be
432 // used in the common code following.
433 template <>
434 inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b,
435                                  const __m256& c) {
436   const __m256 prod = _mm256_mul_ps(a, b);
437   return _mm256_add_ps(prod, c);
438 }
439 
440 }  // namespace intrin_utils
441 }  // namespace
442 
443 template <Path path>
444 void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
445   profiler::ScopeLabel label("Kernel kAvx 8-bit");
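  // splitter_idx_data below is used with mm256_shuffle_epi8 to gather bytes
  // 0,1,4,5,8,9,12,13 of each 128-bit LHS lane into its low 8 bytes and bytes
  // 2,3,6,7,10,11,14,15 into its high 8 bytes, so that each half can be
  // sign-extended to 16 bits and fed to madd_epi16 in the depth loop.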
446   const std::int8_t splitter_idx_data[32] = {
447       0, 1, 4, 5, 8,  9,  12, 13,  //
448       2, 3, 6, 7, 10, 11, 14, 15,  //
449       0, 1, 4, 5, 8,  9,  12, 13,  //
450       2, 3, 6, 7, 10, 11, 14, 15   //
451   };
452 
453   std::int32_t dst_stride = 0;
454   if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
455       (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
456     dst_stride = params.dst_stride;
457   } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
458     dst_stride = params.dst_stride / sizeof(std::int16_t);
459   } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
460     dst_stride = params.dst_stride / sizeof(std::int32_t);
461   } else {
462     RUY_DCHECK(false);
463   }
464 
465   const std::int8_t* rhs_col_ptr =
466       static_cast<const int8_t*>(params.rhs_base_ptr);
467   void* dst_col_ptr = params.dst_base_ptr;
468 
469   for (int col = params.start_col; col <= params.last_col;
470        col += kAvx8bitBlockSize) {
471     const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
472     void* dst_ptr = dst_col_ptr;
473 
474     const std::int32_t lhs_zero_point = params.lhs_zero_point;
475     const bool has_rhs_sums_offsets =
476         (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
477     std::int32_t rhs_sums_offsets[8];
478     if (has_rhs_sums_offsets) {
479       const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>(
480           _mm256_set1_epi32(lhs_zero_point),
481           _mm256_loadu_si256(
482               reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
483       _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
484                           rhs_sums_offset_v);
485     }
486 
487     for (int row = params.start_row; row <= params.last_row;
488          row += kAvx8bitBlockSize) {
489       int channel =
490           (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
491       int multiplier_channel =
492           (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
493       const int residual_rows =
494           std::min(params.dst_rows - row, kAvx8bitBlockSize);
495       const int residual_cols =
496           std::min(params.dst_cols - col, kAvx8bitBlockSize);
497 
498       const __m256i splitter_idx = _mm256_loadu_si256(
499           reinterpret_cast<__m256i const*>(splitter_idx_data));
500 
501       __m256i accum_data_v0;
502       __m256i accum_data_v1;
503       __m256i accum_data_v2;
504       __m256i accum_data_v3;
505       __m256i accum_data_v4;
506       __m256i accum_data_v5;
507       __m256i accum_data_v6;
508       __m256i accum_data_v7;
509 
510       // initial_accum_data will be used to initialize each of the
511       // accum_data_* accumulator registers. We compute into it the terms
512       // that are identical across columns.
513       __m128i initial_accum_data_lo = _mm_set1_epi32(params.prod_zp_depth);
514       __m128i initial_accum_data_hi = _mm_set1_epi32(params.prod_zp_depth);
515 
516       // In the channels-are-rows case, we can load bias here.
517       if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
518           !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
519         initial_accum_data_lo = _mm_add_epi32(
520             initial_accum_data_lo,
521             _mm_loadu_si128(
522                 reinterpret_cast<const __m128i*>(params.bias + row)));
523         initial_accum_data_hi = _mm_add_epi32(
524             initial_accum_data_hi,
525             _mm_loadu_si128(
526                 reinterpret_cast<const __m128i*>(params.bias + row + 4)));
527       }
528 
529       // Adjustments common across columns.
530       const std::int32_t rhs_zero_point = params.rhs_zero_point;
531       if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
532         const __m128i rhs_zp = _mm_set1_epi32(rhs_zero_point);
533         const __m128i lhs_sums_offset_lo = _mm_mullo_epi32(
534             rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>(
535                         &params.lhs_sums[row])));
536         const __m128i lhs_sums_offset_hi = _mm_mullo_epi32(
537             rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>(
538                         &params.lhs_sums[row + 4])));
539 
540         initial_accum_data_lo =
541             _mm_sub_epi32(initial_accum_data_lo, lhs_sums_offset_lo);
542         initial_accum_data_hi =
543             _mm_sub_epi32(initial_accum_data_hi, lhs_sums_offset_hi);
544       }
545 
546       // Adjustments differing across columns.
547       if (has_rhs_sums_offsets) {
548         __m256i initial_accum_data =
549             _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo);
550 
551         accum_data_v0 = intrin_utils::mm256_sub_epi32<path>(
552             initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
553         accum_data_v1 = intrin_utils::mm256_sub_epi32<path>(
554             initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
555         accum_data_v2 = intrin_utils::mm256_sub_epi32<path>(
556             initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
557         accum_data_v3 = intrin_utils::mm256_sub_epi32<path>(
558             initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
559         accum_data_v4 = intrin_utils::mm256_sub_epi32<path>(
560             initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
561         accum_data_v5 = intrin_utils::mm256_sub_epi32<path>(
562             initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
563         accum_data_v6 = intrin_utils::mm256_sub_epi32<path>(
564             initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
565         accum_data_v7 = intrin_utils::mm256_sub_epi32<path>(
566             initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
567       } else {
568         __m256i initial_accum_data =
569             _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo);
570         accum_data_v0 = initial_accum_data;
571         accum_data_v1 = initial_accum_data;
572         accum_data_v2 = initial_accum_data;
573         accum_data_v3 = initial_accum_data;
574         accum_data_v4 = initial_accum_data;
575         accum_data_v5 = initial_accum_data;
576         accum_data_v6 = initial_accum_data;
577         accum_data_v7 = initial_accum_data;
578       }
579 
580       // Finally, in the channels-are-columns case, load bias data here.
581       if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
582           (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
583         accum_data_v0 = intrin_utils::AddBiasEpi32<path>(accum_data_v0,
584                                                          params.bias + col, 0);
585         accum_data_v1 = intrin_utils::AddBiasEpi32<path>(accum_data_v1,
586                                                          params.bias + col, 1);
587         accum_data_v2 = intrin_utils::AddBiasEpi32<path>(accum_data_v2,
588                                                          params.bias + col, 2);
589         accum_data_v3 = intrin_utils::AddBiasEpi32<path>(accum_data_v3,
590                                                          params.bias + col, 3);
591         accum_data_v4 = intrin_utils::AddBiasEpi32<path>(accum_data_v4,
592                                                          params.bias + col, 4);
593         accum_data_v5 = intrin_utils::AddBiasEpi32<path>(accum_data_v5,
594                                                          params.bias + col, 5);
595         accum_data_v6 = intrin_utils::AddBiasEpi32<path>(accum_data_v6,
596                                                          params.bias + col, 6);
597         accum_data_v7 = intrin_utils::AddBiasEpi32<path>(accum_data_v7,
598                                                          params.bias + col, 7);
599       }
600 
601       const std::int8_t* lhs_ptr = lhs_col_ptr;
602       const std::int8_t* rhs_ptr = rhs_col_ptr;
603       for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
604         const __m256i lhs_data =
605             _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
606         const __m256i rhs_data_8bit =
607             _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));
608 
609         // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
610         std::int32_t rhs_data[16];
611         const __m128i rhs_data_bottom_lane =
612             _mm256_castsi256_si128(rhs_data_8bit);
613         const __m128i rhs_data_top_lane =
614             _mm256_extractf128_si256(rhs_data_8bit, 1);
615         const __m256i rhs_16_bit_dup_low =
616             intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_bottom_lane);
617         const __m256i rhs_16_bit_dup_high =
618             intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_top_lane);
619         // Now that we have cast the RHS data, we store it so that each value
620         // can be separately loaded in the accumulation loop.
621         _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
622                             rhs_16_bit_dup_low);
623         _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
624                             rhs_16_bit_dup_high);
625 
626         // NOTE: There may be opportunities for permuting the data in the
627         // packing code instead of here.
628         const __m256i lhs_data_split =
629             intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx);
630         const __m256i lhs_data_split_expand_bottom =
631             intrin_utils::mm256_cvtepi8_epi16<path>(
632                 _mm256_extractf128_si256(lhs_data_split, 0));
633         const __m256i lhs_data_split_expand_top =
634             intrin_utils::mm256_cvtepi8_epi16<path>(
635                 _mm256_extractf128_si256(lhs_data_split, 1));
636 
637         // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
638         const __m256i lhs_16_bit_low =
639             intrin_utils::mm256_permute2x128_si256<path>(
640                 lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
641         // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
642         const __m256i lhs_16_bit_high =
643             intrin_utils::mm256_permute2x128_si256<path>(
644                 lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
645 
646         __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
647             rhs_data));  // Load [0 1 2 3 4 5 6 7]
648         __m256i rhs1 = _mm256_lddqu_si256(
649             reinterpret_cast<const __m256i*>(rhs_data + 8));  // Load [8 - 15]
650         __m256i rhs0_3 =
651             _mm256_permute2f128_si256(rhs0, rhs0, 0);  // [0 1 2 3 0 1 2 3]
652         __m256i rhs4_7 =
653             _mm256_permute2f128_si256(rhs0, rhs0, 0x11);  // [4 5 6 7 4 5 6 7]
654         __m256i rhs8_11 =
655             _mm256_permute2f128_si256(rhs1, rhs1, 0);  // [8 9 10 11 8 9 10 11]
656         __m256i rhs12_15 =
657             _mm256_permute2f128_si256(rhs1, rhs1, 17);  // [12 - 15, 12 - 15]
658 
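        // process_column performs, for one destination column, the 16-bit
        // multiply-accumulate (madd_epi16 followed by add_epi32) on the low
        // and high 128-bit halves of the accumulator separately.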
659         auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi,
660                                   __m256i& accum) {
661           // Perform mul-adds on low and high components of accum separately.
662           __m128i accum_lo = _mm256_extractf128_si256(accum, 0);
663           __m128i accum_hi = _mm256_extractf128_si256(accum, 1);
664 
665           __m128i lhs_lo_0 = _mm256_extractf128_si256(lhs_16_bit_low, 0);
666           __m128i lhs_lo_1 = _mm256_extractf128_si256(lhs_16_bit_low, 1);
667           __m128i rhs_dup_lo_0 = _mm256_extractf128_si256(rhs_dup_lo, 0);
668           __m128i rhs_dup_lo_1 = _mm256_extractf128_si256(rhs_dup_lo, 1);
669           __m128i lo_0 = _mm_madd_epi16(lhs_lo_0, rhs_dup_lo_0);
670           __m128i lo_1 = _mm_madd_epi16(lhs_lo_1, rhs_dup_lo_1);
671 
672           accum_lo = _mm_add_epi32(accum_lo, lo_0);
673           accum_hi = _mm_add_epi32(accum_hi, lo_1);
674 
675           __m128i lhs_hi_0 = _mm256_extractf128_si256(lhs_16_bit_high, 0);
676           __m128i lhs_hi_1 = _mm256_extractf128_si256(lhs_16_bit_high, 1);
677           __m128i rhs_dup_hi_0 = _mm256_extractf128_si256(rhs_dup_hi, 0);
678           __m128i rhs_dup_hi_1 = _mm256_extractf128_si256(rhs_dup_hi, 1);
679           __m128i hi_0 = _mm_madd_epi16(lhs_hi_0, rhs_dup_hi_0);
680           __m128i hi_1 = _mm_madd_epi16(lhs_hi_1, rhs_dup_hi_1);
681 
682           accum_lo = _mm_add_epi32(accum_lo, hi_0);
683           accum_hi = _mm_add_epi32(accum_hi, hi_1);
684           accum = _mm256_set_m128i(accum_hi, accum_lo);
685         };
686         __m256i tmp0, tmp1, tmp2, tmp3;
687         __m128i lo0, lo1, hi0, hi1;
688         mm256_shuffle_epi32(tmp0, rhs0_3, lo0, hi0, 0);
689         mm256_shuffle_epi32(tmp1, rhs0_3, lo1, hi1, 0x55);
690         process_column(tmp0, tmp1, accum_data_v0);
691         mm256_shuffle_epi32(tmp2, rhs0_3, lo0, hi0, 0xaa);
692         mm256_shuffle_epi32(tmp3, rhs0_3, lo1, hi1, 0xff);
693         process_column(tmp2, tmp3, accum_data_v1);
694 
695         mm256_shuffle_epi32(tmp0, rhs4_7, lo0, hi0, 0);
696         mm256_shuffle_epi32(tmp1, rhs4_7, lo1, hi1, 0x55);
697         process_column(tmp0, tmp1, accum_data_v2);
698         mm256_shuffle_epi32(tmp2, rhs4_7, lo0, hi0, 0xaa);
699         mm256_shuffle_epi32(tmp3, rhs4_7, lo1, hi1, 0xff);
700         process_column(tmp2, tmp3, accum_data_v3);
701 
702         mm256_shuffle_epi32(tmp0, rhs8_11, lo0, hi0, 0);
703         mm256_shuffle_epi32(tmp1, rhs8_11, lo1, hi1, 0x55);
704         process_column(tmp0, tmp1, accum_data_v4);
705         mm256_shuffle_epi32(tmp2, rhs8_11, lo0, hi0, 0xaa);
706         mm256_shuffle_epi32(tmp3, rhs8_11, lo1, hi1, 0xff);
707         process_column(tmp2, tmp3, accum_data_v5);
708 
709         mm256_shuffle_epi32(tmp0, rhs12_15, lo0, hi0, 0);
710         mm256_shuffle_epi32(tmp1, rhs12_15, lo1, hi1, 0x55);
711         process_column(tmp0, tmp1, accum_data_v6);
712         mm256_shuffle_epi32(tmp2, rhs12_15, lo0, hi0, 0xaa);
713         mm256_shuffle_epi32(tmp3, rhs12_15, lo1, hi1, 0xff);
714         process_column(tmp2, tmp3, accum_data_v7);
715 
716         lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
717         rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
718       }
719 
720       if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
721         __m256i m_vector;
722         __m256i e_vector;
723         // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
724         m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
725             params.multiplier_fixedpoint + multiplier_channel));
726         e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
727             params.multiplier_exponent + multiplier_channel));
728 
729         const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>(
730             _mm256_extractf128_si256(m_vector, 0));
731         const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>(
732             _mm256_extractf128_si256(m_vector, 1));
733 
734         const __m256i zero_vector = _mm256_setzero_si256();
735         const __m256i left_shift =
736             intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector);
737         const __m256i neg_e_vector =
738             intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
739         const __m256i right_shift =
740             intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
741         const __m256i final_right_shift = _mm256_set1_epi32(31);
742         const __m256i final_right_shift_low =
743             intrin_utils::mm256_cvtepi32_epi64<path>(
744                 _mm256_extractf128_si256(final_right_shift, 0));
745         const __m256i final_right_shift_high =
746             intrin_utils::mm256_cvtepi32_epi64<path>(
747                 _mm256_extractf128_si256(final_right_shift, 1));
748         const __m256i convert_to_unsigned_64 =
749             _mm256_set1_epi64x(0x8000000000000000);
750 
751         __m256i post_scaling_offset = _mm256_setzero_si256();
752 
753         // A "half" added for rounding prior to truncation of 64-bit value.
754         const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
755             intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
756             convert_to_unsigned_64);
757 
758         if (params.dst_zero_point) {
759           post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
760         }
761 
762         // We cannot do
763         //
764         // scaled_v_low =
765         //     _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
766         // scaled_v_high =
767         //     _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
768         //
769         // since this instruction is not in AVX2. Instead we use
770         // _mm256_srlv_epi64, but this is an unsigned shift, so we applied
771         // offsets before (convert_to_unsigned_64) and after
772         // (convert_to_signed_halved).
773         //
774         // The overall process is, for 64-bit scaled accumulator:
775         // unsigned_accum = signed_accum + 1 << 63;
776         // unsigned_accum = (unsigned_accum >> right_shift) >> 31;
777         // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2;
778 
779         // There are various ways to repack the results, in the absence of
780         // _mm256_cvtepi64_epi32() or anything like it.
781         // A.
782         // accum_data_v[j] =
783         //     _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
784         //                      _mm256_extract_epi32(scaled_v_high, 4),
785         //                      _mm256_extract_epi32(scaled_v_high, 2),
786         //                      _mm256_extract_epi32(scaled_v_high, 0),
787         //                      _mm256_extract_epi32(scaled_v_low, 6),
788         //                      _mm256_extract_epi32(scaled_v_low, 4),
789         //                      _mm256_extract_epi32(scaled_v_low, 2),
790         //                      _mm256_extract_epi32(scaled_v_low, 0));
791         // B.
792         // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
793         // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
794         // accum_data_v[j] =
795         //     _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
796         //                       _mm256_extract_epi64(scaled_v_high, 0),
797         //                       _mm256_extract_epi64(scaled_v_low, 2),
798         //                       _mm256_extract_epi64(scaled_v_low, 0));
799         // C.
800         // scaled_v_low =
801         //     _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
802         // scaled_v_high =
803         //     _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
804         // accum_data_v[j] =
805         //     _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
806         //
807         // However, we choose the following because it uses two lighter
808         // instructions. The permutation does have a longer latency, but this
809         // loop can be unrolled.
810         // D.
811         // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
812         // __m256i results =
813         //     _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
814         // results = _mm256_permutevar8x32_epi32(results, repack_perm);
815         // accum_data_v[j] = intrin_utils::mm256_add_epi32<path>(results,
816         // post_scaling_offset);
817 
818         // This multiplier code is complex and expensive enough on x86, that
819         // we prefer to implement the channels-are-columns case by transposing
820         // around it, rather than duplicate it (which would also require
821         // duplicating the above code computing the multiplier constants).
822         // This is one instance where channels-are-columns has lower performance
823         // than channels-are-rows.
824         const bool transpose_around_multiplier =
825             (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
826             (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
827         if (transpose_around_multiplier) {
828           // Transpose the 8x8 block of accumulators. It will be un-transposed
829           // below, after the multiplier implementation.
830           intrin_utils::mm256_transpose8x8_epi32<path>(
831               &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
832               &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
833         }
834 
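        // Scalar sketch of the rounding right shift implemented below
        // (exp == 0 leaves the value unchanged):
        //   nudge  = (exp > 0) ? (1 << (exp - 1)) : 0;
        //   result = (x + nudge) >> exp;  // rounded shift
        //   if (x > INT32_MAX - nudge) result = 1 << (31 - exp);  // overflow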
835         auto rounding_right_shift = [=](__m256i& results,
836                                         const __m256i& exponent) {
837           // Construct the "nudge" value for each lane if the exponent is
838           // greater than 0. Otherwise, the nudge is 0.
839           const __m256i zeros = _mm256_setzero_si256();
840           const __m256i mask_rightshift_gtz =
841               intrin_utils::mm256_cmpgt_epi32<path>(exponent, zeros);
842           const __m256i one_shift_exp_minus1 =
843               intrin_utils::mm256_sllv_epi32<path>(
844                   _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
845                                             exponent, _mm256_set1_epi32(1)));
846           __m256i nudge = intrin_utils::mm256_blendv_epi32(
847               zeros, one_shift_exp_minus1, mask_rightshift_gtz);
848           // Calculate the shifted sum (results + nudge) >> exp.
849           const __m256i r_plus_nudge =
850               intrin_utils::mm256_add_epi32<path>(results, nudge);
851           const __m256i shifted_sum =
852               intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, exponent);
853 
854           // Identify overflow in each lane and create mask.
855           const __m256i one_shift_31minus_exp =
856               intrin_utils::mm256_sllv_epi32<path>(
857                   _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
858                                             _mm256_set1_epi32(31), exponent));
859           const __m256i mask_num_plus_nudge_overflow =
860               intrin_utils::mm256_cmpgt_epi32<path>(
861                   results, intrin_utils::mm256_sub_epi32<path>(
862                                _mm256_set1_epi32(0x7fffffff), nudge));
863           // Fill results with either (results + nudge) >> exponent or
864           // 1 << (31 - exp) in the case of overflow.
865           results = intrin_utils::mm256_blendv_epi32(
866               shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
867         };
868 
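        // apply_multiplier rescales one accumulator vector by the fixed-point
        // multiplier; roughly, as a scalar sketch:
        //   acc = rounding_right_shift(
        //             ((acc << left_shift) * m + (1 << 30)) >> 31,
        //             right_shift) + dst_zero_point;
        // using unsigned 64-bit shifts (with the 2^63 bias described above)
        // since AVX has no 64-bit arithmetic right shift.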
869         auto apply_multiplier = [=](__m256i& accum) {
870           __m256i shifted_accum =
871               intrin_utils::mm256_sllv_epi32<path>(accum, left_shift);
872           // Apply the fixed-point part of the multiplier.
873           __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>(
874               intrin_utils::mm256_cvtepi32_epi64<path>(
875                   _mm256_extractf128_si256(shifted_accum, 0)),
876               m_64bit_low);
877           __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>(
878               intrin_utils::mm256_cvtepi32_epi64<path>(
879                   _mm256_extractf128_si256(shifted_accum, 1)),
880               m_64bit_high);
881           scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
882                                                              offset_vector);
883           scaled_v_high = intrin_utils::mm256_add_epi64<path>(
884               scaled_v_high, offset_vector);
885 
886           scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
887               scaled_v_low, final_right_shift_low);
888           scaled_v_high = intrin_utils::mm256_srlv_epi64<path>(
889               scaled_v_high, final_right_shift_high);
890           scaled_v_high =
891               intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32);
892           __m256i results;
893           mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa);
894           // Permute results to this ordering of int32 elements
895           // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
896           results = intrin_utils::PermuteEpi32EvenOdds<path>(results);
897 
898           rounding_right_shift(results, right_shift);
899           accum =
900               intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
901         };
902         apply_multiplier(accum_data_v0);
903         apply_multiplier(accum_data_v1);
904         apply_multiplier(accum_data_v2);
905         apply_multiplier(accum_data_v3);
906         apply_multiplier(accum_data_v4);
907         apply_multiplier(accum_data_v5);
908         apply_multiplier(accum_data_v6);
909         apply_multiplier(accum_data_v7);
910         // See above comment: here we transpose again to undo the transposition
911         // of the 8x8 block of accumulators used to implement the
912         // channels-are-columns case.
913         if (transpose_around_multiplier) {
914           intrin_utils::mm256_transpose8x8_epi32<path>(
915               &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
916               &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
917         }
918       }
919       const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
920       const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
921       const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
922                                     (residual_cols == kAvx8bitBlockSize);
923 
924       __m256i accum_data_v[kAvx8bitBlockSize];
925       if (!store_full_block) {
926         accum_data_v[0] = accum_data_v0;
927         accum_data_v[1] = accum_data_v1;
928         accum_data_v[2] = accum_data_v2;
929         accum_data_v[3] = accum_data_v3;
930         accum_data_v[4] = accum_data_v4;
931         accum_data_v[5] = accum_data_v5;
932         accum_data_v[6] = accum_data_v6;
933         accum_data_v[7] = accum_data_v7;
934       }
935 
936       if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
937         std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
938         if (store_full_block) {
939           accum_data_v0 =
940               intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
941           accum_data_v0 =
942               intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
943           accum_data_v1 =
944               intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
945           accum_data_v1 =
946               intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
947           accum_data_v2 =
948               intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
949           accum_data_v2 =
950               intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
951           accum_data_v3 =
952               intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
953           accum_data_v3 =
954               intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
955           accum_data_v4 =
956               intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
957           accum_data_v4 =
958               intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
959           accum_data_v5 =
960               intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
961           accum_data_v5 =
962               intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
963           accum_data_v6 =
964               intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
965           accum_data_v6 =
966               intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
967           accum_data_v7 =
968               intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
969           accum_data_v7 =
970               intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
971           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
972               &tmp_ptr[0 * dst_stride], accum_data_v0);
973           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
974               &tmp_ptr[1 * dst_stride], accum_data_v1);
975           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
976               &tmp_ptr[2 * dst_stride], accum_data_v2);
977           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
978               &tmp_ptr[3 * dst_stride], accum_data_v3);
979           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
980               &tmp_ptr[4 * dst_stride], accum_data_v4);
981           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
982               &tmp_ptr[5 * dst_stride], accum_data_v5);
983           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
984               &tmp_ptr[6 * dst_stride], accum_data_v6);
985           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
986               &tmp_ptr[7 * dst_stride], accum_data_v7);
987         } else {
988           for (int j = 0; j < residual_cols; ++j) {
989             __m256i result = accum_data_v[j];
990             result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
991             result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
992             intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
993                 tmp_ptr, residual_rows, result);
994             tmp_ptr += dst_stride;
995           }
996         }
997         dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
998                                      kAvx8bitBlockSize);
999       } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
1000         std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
1001         if (store_full_block) {
1002           accum_data_v0 =
1003               intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
1004           accum_data_v0 =
1005               intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
1006           accum_data_v1 =
1007               intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
1008           accum_data_v1 =
1009               intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
1010           accum_data_v2 =
1011               intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
1012           accum_data_v2 =
1013               intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
1014           accum_data_v3 =
1015               intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
1016           accum_data_v3 =
1017               intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
1018           accum_data_v4 =
1019               intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
1020           accum_data_v4 =
1021               intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
1022           accum_data_v5 =
1023               intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
1024           accum_data_v5 =
1025               intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
1026           accum_data_v6 =
1027               intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
1028           accum_data_v6 =
1029               intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
1030           accum_data_v7 =
1031               intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
1032           accum_data_v7 =
1033               intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
1034           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0],
1035                                                          accum_data_v0);
1036           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride],
1037                                                          accum_data_v1);
1038           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1039               &tmp_ptr[2 * dst_stride], accum_data_v2);
1040           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1041               &tmp_ptr[3 * dst_stride], accum_data_v3);
1042           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1043               &tmp_ptr[4 * dst_stride], accum_data_v4);
1044           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1045               &tmp_ptr[5 * dst_stride], accum_data_v5);
1046           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1047               &tmp_ptr[6 * dst_stride], accum_data_v6);
1048           intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1049               &tmp_ptr[7 * dst_stride], accum_data_v7);
1050         } else {
1051           for (int j = 0; j < residual_cols; ++j) {
1052             __m256i result = accum_data_v[j];
1053             result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
1054             result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
1055             intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
1056                 tmp_ptr, residual_rows, result);
1057             tmp_ptr += dst_stride;
1058           }
1059         }
1060         dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
1061                                      kAvx8bitBlockSize);
1062       } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
1063         std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
1064         if (store_full_block) {
1065           accum_data_v0 =
1066               intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
1067           accum_data_v0 =
1068               intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
1069           accum_data_v1 =
1070               intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
1071           accum_data_v1 =
1072               intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
1073           accum_data_v2 =
1074               intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
1075           accum_data_v2 =
1076               intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
1077           accum_data_v3 =
1078               intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
1079           accum_data_v3 =
1080               intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
1081           accum_data_v4 =
1082               intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
1083           accum_data_v4 =
1084               intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
1085           accum_data_v5 =
1086               intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
1087           accum_data_v5 =
1088               intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
1089           accum_data_v6 =
1090               intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
1091           accum_data_v6 =
1092               intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
1093           accum_data_v7 =
1094               intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
1095           accum_data_v7 =
1096               intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
1097           intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0],
1098                                                           accum_data_v0);
1099           intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride],
1100                                                           accum_data_v1);
1101           intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1102               &tmp_ptr[2 * dst_stride], accum_data_v2);
1103           intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1104               &tmp_ptr[3 * dst_stride], accum_data_v3);
1105           intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1106               &tmp_ptr[4 * dst_stride], accum_data_v4);
1107           intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1108               &tmp_ptr[5 * dst_stride], accum_data_v5);
1109           intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1110               &tmp_ptr[6 * dst_stride], accum_data_v6);
1111           intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1112               &tmp_ptr[7 * dst_stride], accum_data_v7);
1113         } else {
1114           for (int j = 0; j < residual_cols; ++j) {
1115             __m256i result = accum_data_v[j];
1116             result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
1117             result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
1118             intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(
1119                 tmp_ptr, residual_rows, result);
1120             tmp_ptr += dst_stride;
1121           }
1122         }
1123         dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
1124                                      kAvx8bitBlockSize);
1125       } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
1126         if (store_full_block) {
1127           std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
1128           intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0);
1129           intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride],
1130                                                  accum_data_v1);
1131           intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride],
1132                                                  accum_data_v2);
1133           intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride],
1134                                                  accum_data_v3);
1135           intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride],
1136                                                  accum_data_v4);
1137           intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride],
1138                                                  accum_data_v5);
1139           intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride],
1140                                                  accum_data_v6);
1141           intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride],
1142                                                  accum_data_v7);
1143         } else {
1144           std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
1145           for (int j = 0; j < residual_cols; ++j) {
1146             intrin_utils::mm256_n_storeu_epi32<path>(
1147                 dst_block_ptr, residual_rows, accum_data_v[j]);
1148             dst_block_ptr += dst_stride;
1149           }
1150         }
1151         dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
1152                                      kAvx8bitBlockSize);
1153       } else {
1154         RUY_DCHECK(false);
1155       }
1156 
1157       lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
1158     }  // End row-block loop.
1159 
1160     dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
1161                                      kAvx8bitBlockSize * params.dst_stride);
1162     rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
1163   }  // End col-block loop.
1164 }  // NOLINT(readability/fn_size)
1165 
1166 void Kernel8bitAvx(const KernelParams8bit<8, 8>& params) {
1167   Kernel8bitAvxImpl<Path::kAvx>(params);
1168 }
1169 
1170 template <Path path>
1171 void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
1172   profiler::ScopeLabel label("Kernel kAvx 8-bit GEMV");
1173 
1174   RUY_DCHECK_EQ(params.dst_cols, 1);
1175   RUY_DCHECK_EQ(params.last_col, 0);
1176   RUY_DCHECK_EQ(params.start_col, 0);
1177 
1178   const std::int8_t splitter_idx_data[32] = {
1179       0, 1, 4, 5, 8,  9,  12, 13,  //
1180       2, 3, 6, 7, 10, 11, 14, 15,  //
1181       0, 1, 4, 5, 8,  9,  12, 13,  //
1182       2, 3, 6, 7, 10, 11, 14, 15   //
1183   };
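       // Note (added): applied via shuffle_epi8, this pattern gathers the first
       // byte pair of every 4-byte depth group into one half of each 128-bit
       // lane and the second pair into the other half, ready for widening to
       // 16 bits and madd_epi16.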
1184 
1185   int bias_ptr_block_increment =
1186       params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
1187 
1188   const std::int8_t* rhs_col_ptr =
1189       static_cast<const std::int8_t*>(params.rhs_base_ptr);
1190   void* dst_col_ptr = params.dst_base_ptr;
1191   const std::int32_t* bias_col_ptr = params.bias;
1192   if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
1193     bias_col_ptr += params.start_row;
1194   }
1195 
1196   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
1197   void* dst_ptr = dst_col_ptr;
1198   const std::int32_t* bias_ptr = bias_col_ptr;
1199 
1200   const std::int32_t lhs_zero_point = params.lhs_zero_point;
1201   const bool has_rhs_sums_offsets =
1202       (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
1203   std::int32_t rhs_sums_offsets[8];
1204   if (has_rhs_sums_offsets) {
1205     const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>(
1206         _mm256_set1_epi32(lhs_zero_point),
1207         _mm256_loadu_si256(
1208             reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
1209     _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
1210                         rhs_sums_offset_v);
1211   }
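       // Note (added): these offsets come from the standard expansion of the
       // zero-point-adjusted dot product,
       //   sum_k (lhs[r,k] - lhs_zp) * (rhs[k,c] - rhs_zp)
       //     = sum_k lhs[r,k] * rhs[k,c]
       //       - lhs_zp * sum_k rhs[k,c]
       //       - rhs_zp * sum_k lhs[r,k]
       //       + depth * lhs_zp * rhs_zp,
       // where lhs_zp * sum_k rhs[k,c] is the per-column piece precomputed
       // here; the lhs_sums and prod_zp_depth terms are applied per row block
       // below.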
1212 
1213   for (int row = params.start_row; row <= params.last_row;
1214        row += kAvx8bitBlockSize) {
1215     const int residual_rows =
1216         std::min(params.dst_rows - row, kAvx8bitBlockSize);
1217 
1218     const __m256i splitter_idx =
1219         _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));
1220 
1221     __m256i accum_data_v0;
1222 
1223     // Initialize with bias.
1224     __m256i initial_accum_data =
1225         _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
1226     bias_ptr += bias_ptr_block_increment;
1227 
1228     // Adjustments common across columns.
1229     const std::int32_t rhs_zero_point = params.rhs_zero_point;
1230     if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
1231       const __m256i lhs_sums_offset = intrin_utils::mm256_mullo_epi32<path>(
1232           _mm256_set1_epi32(rhs_zero_point),
1233           _mm256_loadu_si256(
1234               reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
1235       initial_accum_data = intrin_utils::mm256_sub_epi32<path>(
1236           initial_accum_data, lhs_sums_offset);
1237     }
1238     const std::int32_t prod_zp_depth = params.prod_zp_depth;
1239     if (prod_zp_depth) {
1240       initial_accum_data = intrin_utils::mm256_add_epi32<path>(
1241           initial_accum_data, _mm256_set1_epi32(prod_zp_depth));
1242     }
1243 
1244     // Adjustments differing across columns.
1245     if (has_rhs_sums_offsets) {
1246       accum_data_v0 = intrin_utils::mm256_sub_epi32<path>(
1247           initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
1248     } else {
1249       accum_data_v0 = initial_accum_data;
1250     }
1251 
1252     const std::int8_t* lhs_ptr = lhs_col_ptr;
1253     const std::int8_t* rhs_ptr = rhs_col_ptr;
1254     for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
1255       const __m256i lhs_data =
1256           _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
1257       const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);
1258 
1259       // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
1260       // For simplicity we load 4x the data we need, process twice the data
1261       // we need, and store only the data we need.
1262       std::int32_t rhs_data[2];
1263       const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
1264       // Now that we have cast the RHS data, we store it so that each value
1265       // can be separately loaded in the accumulation loop.
1266       _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
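           // Note (added): rhs_data[0] now holds the first two widened 16-bit
           // RHS values and rhs_data[1] the next two; each pair is broadcast
           // below so that madd_epi16 can pair it with the matching LHS
           // values.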
1267 
1268       // NOTE: There may be opportunities for permuting the data in the packing
1269       // code instead of here.
1270       const __m256i lhs_data_split =
1271           intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx);
1272       const __m256i lhs_data_split_expand_bottom =
1273           intrin_utils::mm256_cvtepi8_epi16<path>(
1274               _mm256_extractf128_si256(lhs_data_split, 0));
1275       const __m256i lhs_data_split_expand_top =
1276           intrin_utils::mm256_cvtepi8_epi16<path>(
1277               _mm256_extractf128_si256(lhs_data_split, 1));
1278 
1279       // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
1280       const __m256i lhs_16_bit_low =
1281           intrin_utils::mm256_permute2x128_si256<path>(
1282               lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
1283       // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
1284       const __m256i lhs_16_bit_high =
1285           intrin_utils::mm256_permute2x128_si256<path>(
1286               lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
1287       // Accumulate for column 0.
1288       const std::int32_t low_rhs_value = rhs_data[0];
1289       const std::int32_t high_rhs_value = rhs_data[1];
1290 
1291       const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
1292       const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
1293 
1294       accum_data_v0 = intrin_utils::mm256_add_epi32<path>(
1295           accum_data_v0, intrin_utils::mm256_madd_epi16<path>(
1296                              lhs_16_bit_low, rhs_16_bit_dup_low));
1297       accum_data_v0 = intrin_utils::mm256_add_epi32<path>(
1298           accum_data_v0, intrin_utils::mm256_madd_epi16<path>(
1299                              lhs_16_bit_high, rhs_16_bit_dup_high));
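           // Note (added): madd_epi16 multiplies adjacent int16 elements and
           // adds each pair into a single int32 lane, so the two additions
           // above fold all four depth values of this iteration,
           //   sum_{k = d..d+3} lhs[r, k] * rhs[k],
           // into row r's lane of accum_data_v0.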
1300 
1301       lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
1302       rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
1303     }
1304 
1305     if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
1306       __m256i m_vector;
1307       __m256i e_vector;
1308       // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
1309       int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
1310       m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
1311           params.multiplier_fixedpoint + channel));
1312       e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
1313           params.multiplier_exponent + channel));
1314 
1315       const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>(
1316           _mm256_extractf128_si256(m_vector, 0));
1317       const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>(
1318           _mm256_extractf128_si256(m_vector, 1));
1319 
1320       const __m256i zero_vector = _mm256_setzero_si256();
1321       const __m256i left_shift =
1322           intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector);
1323       const __m256i neg_e_vector =
1324           intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
1325       const __m256i right_shift =
1326           intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
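           // Note (added): the per-channel exponent e is split into
           // left_shift = max(e, 0), applied before the fixed-point multiply,
           // and right_shift = max(-e, 0), applied after it, so every lane
           // uses only non-negative shift amounts.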
1327       const __m256i final_right_shift = _mm256_set1_epi32(31);
1328       const __m256i final_right_shift_low =
1329           intrin_utils::mm256_cvtepi32_epi64<path>(
1330               _mm256_extractf128_si256(final_right_shift, 0));
1331       const __m256i final_right_shift_high =
1332           intrin_utils::mm256_cvtepi32_epi64<path>(
1333               _mm256_extractf128_si256(final_right_shift, 1));
1334       const __m256i convert_to_unsigned_64 =
1335           _mm256_set1_epi64x(0x8000000000000000);
1336 
1337       __m256i post_scaling_offset = _mm256_setzero_si256();
1338 
1339       // A "half" is added for rounding prior to truncation of the 64-bit value.
1340       const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
1341           intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
1342           convert_to_unsigned_64);
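           // Note (added): (1 << 30) is the rounding "half" for the >> 31 that
           // follows the fixed-point multiply, and the 2^63 term biases the
           // signed 64-bit product so that the logical srlv shift below
           // matches an arithmetic shift; the bias only reaches bits that are
           // discarded when the low 32 bits of each lane are kept.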
1343 
1344       if (params.dst_zero_point) {
1345         post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
1346       }
1347 
1348       // See GEMM version for details of this process.
1349       {
1350         __m256i shifted_accum =
1351             intrin_utils::mm256_sllv_epi32<path>(accum_data_v0, left_shift);
1352         // Apply the fixed-point part of the multiplier.
1353         __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>(
1354             intrin_utils::mm256_cvtepi32_epi64<path>(
1355                 _mm256_extractf128_si256(shifted_accum, 0)),
1356             m_64bit_low);
1357         __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>(
1358             intrin_utils::mm256_cvtepi32_epi64<path>(
1359                 _mm256_extractf128_si256(shifted_accum, 1)),
1360             m_64bit_high);
1361 
1362         scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
1363                                                            offset_vector);
1364         scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high,
1365                                                             offset_vector);
1366 
1367         scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
1368             scaled_v_low, final_right_shift_low);
1369         scaled_v_high = intrin_utils::mm256_srlv_epi64<path>(
1370             scaled_v_high, final_right_shift_high);
1371 
1372         scaled_v_high = intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32);
1373         __m256i results;
1374         mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa);
1375         // Permute results to this ordering of int32 elements,
1376         // lo->hi: (0, 2, 4, 6, 1, 3, 5, 7).
1377         results = intrin_utils::PermuteEpi32EvenOdds<path>(results);
1378 
1379         // Now perform the Rounding Right Shift.
1380         // First, construct the "nudge" value for each lane if the exponent is
1381         // greater than 0. Otherwise, the nudge is 0.
1382         const __m256i zeros = _mm256_setzero_si256();
1383         const __m256i mask_rightshift_gtz =
1384             intrin_utils::mm256_cmpgt_epi32<path>(right_shift, zeros);
1385         const __m256i one_shift_exp_minus1 =
1386             intrin_utils::mm256_sllv_epi32<path>(
1387                 _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
1388                                           right_shift, _mm256_set1_epi32(1)));
1389         __m256i nudge = intrin_utils::mm256_blendv_epi32(
1390             zeros, one_shift_exp_minus1, mask_rightshift_gtz);
1391         // Calculate the shifted sum (results + nudge) >> exp.
1392         const __m256i r_plus_nudge =
1393             intrin_utils::mm256_add_epi32<path>(results, nudge);
1394         const __m256i shifted_sum =
1395             intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, right_shift);
1396 
1397         // Identify overflow in each lane and create mask.
1398         const __m256i one_shift_31minus_exp =
1399             intrin_utils::mm256_sllv_epi32<path>(
1400                 _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
1401                                           _mm256_set1_epi32(31), right_shift));
1402         const __m256i mask_num_plus_nudge_overflow =
1403             intrin_utils::mm256_cmpgt_epi32<path>(
1404                 results, intrin_utils::mm256_sub_epi32<path>(
1405                              _mm256_set1_epi32(0x7fffffff), nudge));
1406         // Fill results with either (results + nudge) >> exponent or
1407         // 1 << (31 - exp) in the case of overflow.
1408         results = intrin_utils::mm256_blendv_epi32(
1409             shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
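             // Scalar sketch (added, illustrative only) of the per-lane
             // computation above, with x = results and exp = right_shift:
             //   nudge  = exp > 0 ? (1 << (exp - 1)) : 0;
             //   result = x > INT32_MAX - nudge ? (1 << (31 - exp))
             //                                  : (x + nudge) >> exp;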
1410         accum_data_v0 =
1411             intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
1412       }
1413     }
1414     const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
1415     const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
1416 
1417     if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
1418       std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
1419       __m256i result = accum_data_v0;
1420       result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
1421       result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
1422       intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
1423                                                        result);
1424       dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
1425                                    kAvx8bitBlockSize);
1426     } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
1427       std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
1428       __m256i result = accum_data_v0;
1429       result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
1430       result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
1431       intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
1432                                                        result);
1433       dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
1434                                    kAvx8bitBlockSize);
1435     } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
1436       std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
1437       __m256i result = accum_data_v0;
1438       result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
1439       result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
1440       intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows,
1441                                                         result);
1442       dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
1443                                    kAvx8bitBlockSize);
1444     } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
1445       std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
1446       intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows,
1447                                                accum_data_v0);
1448       dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
1449                                    kAvx8bitBlockSize);
1450     } else {
1451       RUY_DCHECK(false);
1452     }
1453 
1454     lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
1455   }  // End row-block loop.
1456 
1457   dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
1458                                    kAvx8bitBlockSize * params.dst_stride);
1459   rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
1460 }  // NOLINT(readability/fn_size)
1461 
1462 void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params) {
1463   Kernel8bitAvxSingleColImpl<Path::kAvx>(params);
1464 }
1465 
1466 void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) {
1467   profiler::ScopeLabel label("Kernel kAvx float");
1468   KernelFloatAvxCommon<Path::kAvx>(params);
1469 }
1470 
1471 void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params) {
1472   profiler::ScopeLabel label("Kernel kAvx float GEMV");
1473   KernelFloatAvxCommonSingleCol<Path::kAvx>(params);
1474 }
1475 
1476 #endif  //  RUY_PLATFORM_AVX && RUY_OPT(ASM)
1477 
1478 }  // namespace ruy
1479