/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/kernel_common.h"
#include "ruy/kernel_x86.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))

void Kernel8bitAvx2(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx2(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

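// The 8-bit kernels below compute an 8x8 block of the destination per
// iteration, consuming the depth dimension 4 levels at a time.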
static constexpr int kAvx8bitBlockSize = 8;
static constexpr int kAvx8bitInnerSize = 4;

namespace {
namespace intrin_utils {

template <>
inline __m256i mm256_shuffle_epi8<Path::kAvx2Fma>(const __m256i& a,
                                                  const __m256i& b) {
  return _mm256_shuffle_epi8(a, b);
}

// Make an inline function for FMA so we can share the float kernels
// with non-FMA code.
template <>
inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b,
                                     const __m256& c) {
  return _mm256_fmadd_ps(a, b, c);
}

template <>
inline __m128i mm256_extracti128_si256<Path::kAvx2Fma>(const __m256i& a,
                                                       const int imm) {
  switch (imm) {
    case 0:
      return _mm256_extracti128_si256(a, 0);
    case 1:
      return _mm256_extracti128_si256(a, 1);
    default:
      RUY_DCHECK_LT(imm, 2);
      return _mm_setzero_si128();
  }
}

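// AVX2 has no 32-bit integer blendv, so emulate one with the float blendv,
// which selects each 32-bit lane according to the sign bit of the
// corresponding mask lane.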
__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
                           const __m256i& mask) {
  __m256 result =
      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
                       _mm256_castsi256_ps(mask));
  return _mm256_castps_si256(result);
}

}  // namespace intrin_utils
}  // namespace

template <Path path>
void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit");
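  // Shuffle indices applied to the packed LHS data. Within each 128-bit lane
  // they gather the (d0, d1) byte pairs of each row into the low 8 bytes and
  // the (d2, d3) pairs into the high 8 bytes so that, after sign extension to
  // 16 bits, _mm256_madd_epi16 can multiply-accumulate two depth levels per
  // 32-bit lane. (This assumes the packed layout of 4 consecutive depth
  // values per row.)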
  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };

  std::int32_t dst_stride = 0;
  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
    dst_stride = params.dst_stride;
  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int16_t);
  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int32_t);
  } else {
    RUY_DCHECK(false);
  }

  const void* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;

  for (int col = params.start_col; col <= params.last_col;
       col += kAvx8bitBlockSize) {
    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
    void* dst_ptr = dst_col_ptr;

    const std::int32_t lhs_zero_point = params.lhs_zero_point;
    const bool has_rhs_sums_offsets =
        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
    std::int32_t rhs_sums_offsets[8];
    if (has_rhs_sums_offsets) {
      const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
          _mm256_set1_epi32(lhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                          rhs_sums_offset_v);
    }

    for (int row = params.start_row; row <= params.last_row;
         row += kAvx8bitBlockSize) {
      int channel =
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
      int multiplier_channel =
          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
      const int residual_rows =
          std::min(params.dst_rows - row, kAvx8bitBlockSize);
      const int residual_cols =
          std::min(params.dst_cols - col, kAvx8bitBlockSize);

      const __m256i splitter_idx = _mm256_loadu_si256(
          reinterpret_cast<__m256i const*>(splitter_idx_data));

      __m256i accum_data_v0;
      __m256i accum_data_v1;
      __m256i accum_data_v2;
      __m256i accum_data_v3;
      __m256i accum_data_v4;
      __m256i accum_data_v5;
      __m256i accum_data_v6;
      __m256i accum_data_v7;

      // initial_accum_data holds the initial value of each of the
      // accum_data_* accumulator registers. We compute into it the terms that
      // are identical across columns.
      __m256i initial_accum_data = _mm256_set1_epi32(params.prod_zp_depth);

      // In the channels-are-rows case, we can load bias here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        initial_accum_data = _mm256_add_epi32(
            initial_accum_data,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(params.bias + row)));
      }

      // Adjustments common across columns.
      const std::int32_t rhs_zero_point = params.rhs_zero_point;
      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
        const __m256i lhs_sums_offset = _mm256_mullo_epi32(
            _mm256_set1_epi32(rhs_zero_point),
            _mm256_loadu_si256(
                reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
        initial_accum_data =
            _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
      }

      // Adjustments differing across columns.
      if (has_rhs_sums_offsets) {
        accum_data_v0 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
        accum_data_v1 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
        accum_data_v2 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
        accum_data_v3 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
        accum_data_v4 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
        accum_data_v5 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
        accum_data_v6 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
        accum_data_v7 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
      } else {
        accum_data_v0 = initial_accum_data;
        accum_data_v1 = initial_accum_data;
        accum_data_v2 = initial_accum_data;
        accum_data_v3 = initial_accum_data;
        accum_data_v4 = initial_accum_data;
        accum_data_v5 = initial_accum_data;
        accum_data_v6 = initial_accum_data;
        accum_data_v7 = initial_accum_data;
      }

      // Finally, in the channels-are-columns case, load bias data here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        const __m256i bias_data = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(params.bias + col));
        accum_data_v0 = _mm256_add_epi32(
            accum_data_v0,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(0)));
        accum_data_v1 = _mm256_add_epi32(
            accum_data_v1,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(1)));
        accum_data_v2 = _mm256_add_epi32(
            accum_data_v2,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(2)));
        accum_data_v3 = _mm256_add_epi32(
            accum_data_v3,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(3)));
        accum_data_v4 = _mm256_add_epi32(
            accum_data_v4,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(4)));
        accum_data_v5 = _mm256_add_epi32(
            accum_data_v5,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(5)));
        accum_data_v6 = _mm256_add_epi32(
            accum_data_v6,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(6)));
        accum_data_v7 = _mm256_add_epi32(
            accum_data_v7,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(7)));
      }

      const std::int8_t* lhs_ptr = lhs_col_ptr;
      const void* rhs_ptr = rhs_col_ptr;
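      // Accumulation loop over the depth dimension, kAvx8bitInnerSize (4)
      // levels per iteration. The packed LHS block is assumed to hold
      // 8 rows x 4 depth values per 32-byte load, and the packed RHS block
      // 8 columns x 4 depth values.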
      for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
        const __m256i lhs_data =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
        const __m256i rhs_data_8bit =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));

        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
        std::int32_t rhs_data_buf[16];
        const std::int32_t* rhs_data =
            reinterpret_cast<const std::int32_t*>(rhs_ptr);

        if (params.rhs_scalar_size == 1) {
          rhs_data = rhs_data_buf;
          const __m128i rhs_data_bottom_lane =
              _mm256_castsi256_si128(rhs_data_8bit);
          const __m128i rhs_data_top_lane =
              _mm256_extracti128_si256(rhs_data_8bit, 1);
          const __m256i rhs_16_bit_dup_low =
              _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
          const __m256i rhs_16_bit_dup_high =
              _mm256_cvtepi8_epi16(rhs_data_top_lane);
          // Now that we have cast the RHS data, we store it so that each value
          // can be separately loaded in the accumulation loop.
          _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf),
                              rhs_16_bit_dup_low);
          _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf + 8),
                              rhs_16_bit_dup_high);
        } else {
          RUY_DCHECK(params.rhs_scalar_size == 2);
        }

        const __m256i lhs_data_split =
            _mm256_shuffle_epi8(lhs_data, splitter_idx);
        const __m256i lhs_data_split_expand_bottom =
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
        const __m256i lhs_data_split_expand_top =
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));

        // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
        const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
            lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
        // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
        const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
            lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);

        __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
            rhs_data));  // Load [0 1 2 3 4 5 6 7]
        __m256i rhs1 = _mm256_lddqu_si256(
            reinterpret_cast<const __m256i*>(rhs_data + 8));  // Load [8 - 15]
        __m256i rhs0_3 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0);  // [0 1 2 3 0 1 2 3]
        __m256i rhs4_7 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0x11);  // [4 5 6 7 4 5 6 7]
        __m256i rhs8_11 =
            _mm256_permute2f128_si256(rhs1, rhs1, 0);  // [8 9 10 11 8 9 10 11]
        __m256i rhs12_15 =
            _mm256_permute2f128_si256(rhs1, rhs1, 17);  // [12 - 15, 12 - 15]

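        // process_column adds the contribution of all four depth levels to
        // one column's accumulators: each _mm256_madd_epi16 multiplies pairs
        // of 16-bit values and sums adjacent products into 32-bit lanes.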
        auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi,
                                  __m256i& accum) {
          accum = _mm256_add_epi32(
              accum, _mm256_madd_epi16(lhs_16_bit_low, rhs_dup_lo));
          accum = _mm256_add_epi32(
              accum, _mm256_madd_epi16(lhs_16_bit_high, rhs_dup_hi));
        };
        __m256i tmp0, tmp1, tmp2, tmp3;
        tmp0 = _mm256_shuffle_epi32(rhs0_3, 0);
        tmp1 = _mm256_shuffle_epi32(rhs0_3, 0x55);
        process_column(tmp0, tmp1, accum_data_v0);
        tmp2 = _mm256_shuffle_epi32(rhs0_3, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs0_3, 0xff);
        process_column(tmp2, tmp3, accum_data_v1);

        tmp0 = _mm256_shuffle_epi32(rhs4_7, 0);
        tmp1 = _mm256_shuffle_epi32(rhs4_7, 0x55);
        process_column(tmp0, tmp1, accum_data_v2);
        tmp2 = _mm256_shuffle_epi32(rhs4_7, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs4_7, 0xff);
        process_column(tmp2, tmp3, accum_data_v3);

        tmp0 = _mm256_shuffle_epi32(rhs8_11, 0);
        tmp1 = _mm256_shuffle_epi32(rhs8_11, 0x55);
        process_column(tmp0, tmp1, accum_data_v4);
        tmp2 = _mm256_shuffle_epi32(rhs8_11, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs8_11, 0xff);
        process_column(tmp2, tmp3, accum_data_v5);

        tmp0 = _mm256_shuffle_epi32(rhs12_15, 0);
        tmp1 = _mm256_shuffle_epi32(rhs12_15, 0x55);
        process_column(tmp0, tmp1, accum_data_v6);
        tmp2 = _mm256_shuffle_epi32(rhs12_15, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs12_15, 0xff);
        process_column(tmp2, tmp3, accum_data_v7);

        lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
        rhs_ptr = static_cast<const void*>(
            static_cast<const char*>(rhs_ptr) +
            kAvx8bitBlockSize * kAvx8bitInnerSize * params.rhs_scalar_size);
      }

      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
        __m256i m_vector;
        __m256i e_vector;
        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
        m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_fixedpoint + multiplier_channel));
        e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_exponent + multiplier_channel));

        const __m256i m_64bit_low =
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
        const __m256i m_64bit_high =
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));

        const __m256i zero_vector = _mm256_setzero_si256();
        const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
        const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
        const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
        const __m256i final_right_shift = _mm256_set1_epi32(31);
        const __m256i final_right_shift_low = _mm256_cvtepi32_epi64(
            _mm256_extracti128_si256(final_right_shift, 0));
        const __m256i final_right_shift_high = _mm256_cvtepi32_epi64(
            _mm256_extracti128_si256(final_right_shift, 1));
        const __m256i convert_to_unsigned_64 =
            _mm256_set1_epi64x(0x8000000000000000);

        __m256i post_scaling_offset = _mm256_setzero_si256();
        // A "half" added for rounding prior to truncation of the 64-bit value.
        const __m256i offset_vector = _mm256_add_epi64(
            _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
            convert_to_unsigned_64);

        if (params.dst_zero_point) {
          post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
        }

        const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

        // We cannot do
        //
        // scaled_v_low =
        //     _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
        // scaled_v_high =
        //     _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
        //
        // since this instruction is not in AVX2. Instead we use
        // _mm256_srlv_epi64, but this is an unsigned shift, so we applied
        // offsets before (convert_to_unsigned_64) and after
        // (convert_to_signed_halved).
        //
        // The overall process is, for 64-bit scaled accumulator:
        // unsigned_accum = signed_accum + 1 << 63;
        // unsigned_accum = (unsigned_accum >> right_shift) >> 31;
        // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2;

        // There are various ways to repack the results, in the absence of
        // _mm256_cvtepi64_epi32() or anything like it.
        // A.
        // accum_data_v[j] =
        //     _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
        //                      _mm256_extract_epi32(scaled_v_high, 4),
        //                      _mm256_extract_epi32(scaled_v_high, 2),
        //                      _mm256_extract_epi32(scaled_v_high, 0),
        //                      _mm256_extract_epi32(scaled_v_low, 6),
        //                      _mm256_extract_epi32(scaled_v_low, 4),
        //                      _mm256_extract_epi32(scaled_v_low, 2),
        //                      _mm256_extract_epi32(scaled_v_low, 0));
        // B.
        // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
        // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
        // accum_data_v[j] =
        //     _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
        //                       _mm256_extract_epi64(scaled_v_high, 0),
        //                       _mm256_extract_epi64(scaled_v_low, 2),
        //                       _mm256_extract_epi64(scaled_v_low, 0));
        // C.
        // scaled_v_low =
        //     _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
        // scaled_v_high =
        //     _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
        // accum_data_v[j] =
        //     _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
        //
        // However, we choose the following because it uses two lighter
        // instructions. The permutation does have a longer latency, but this
        // loop can be unrolled.
        // D.
        // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
        // __m256i results =
        //     _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
        // results = _mm256_permutevar8x32_epi32(results, repack_perm);
        // accum_data_v[j] = _mm256_add_epi32(results, post_scaling_offset);

        // This multiplier code is complex and expensive enough on x86 that we
        // prefer to implement the channels-are-columns case by transposing
        // around it, rather than duplicating it (which would also require
        // duplicating the above code computing the multiplier constants).
        // This is one instance where channels-are-columns has lower
        // performance than channels-are-rows.
        const bool transpose_around_multiplier =
            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
        if (transpose_around_multiplier) {
          // Transpose the 8x8 block of accumulators. It will be un-transposed
          // below, after the multiplier implementation.
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }

        auto rounding_right_shift = [=](__m256i& results,
                                        const __m256i& exponent) {
          // Construct the "nudge" value for each lane if the exponent is
          // greater than 0. Otherwise, the nudge is 0.
          const __m256i zeros = _mm256_setzero_si256();
          const __m256i mask_rightshift_gtz =
              _mm256_cmpgt_epi32(exponent, zeros);
          const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
              _mm256_set1_epi32(1),
              _mm256_sub_epi32(exponent, _mm256_set1_epi32(1)));
          __m256i nudge = intrin_utils::mm256_blendv_epi32(
              zeros, one_shift_exp_minus1, mask_rightshift_gtz);
          // Calculate the shifted sum (results + nudge) >> exp.
          const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
          const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, exponent);

          // Identify overflow in each lane and create mask.
          const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
              _mm256_set1_epi32(1),
              _mm256_sub_epi32(_mm256_set1_epi32(31), exponent));
          const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
              results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
          // Fill results with either (results + nudge) >> exponent or
          // 1 << (31 - exp) in the case of overflow.
          results = intrin_utils::mm256_blendv_epi32(
              shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
        };

        auto apply_multiplier = [=](__m256i& accum) {
          __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
          // Apply the fixed-point part of the multiplier.
          __m256i scaled_v_low = _mm256_mul_epi32(
              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
              m_64bit_low);
          __m256i scaled_v_high = _mm256_mul_epi32(
              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
              m_64bit_high);

          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);

          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
          scaled_v_high =
              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);

          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
          __m256i results =
              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
          results = _mm256_permutevar8x32_epi32(results, repack_perm);
          // Now do a Rounding Right Shift.
          rounding_right_shift(results, right_shift);
          accum = _mm256_add_epi32(results, post_scaling_offset);
        };
        apply_multiplier(accum_data_v0);
        apply_multiplier(accum_data_v1);
        apply_multiplier(accum_data_v2);
        apply_multiplier(accum_data_v3);
        apply_multiplier(accum_data_v4);
        apply_multiplier(accum_data_v5);
        apply_multiplier(accum_data_v6);
        apply_multiplier(accum_data_v7);
        // See above comment: here we transpose again to undo the transposition
        // of the 8x8 block of accumulators used to implement the
        // channels-are-columns case.
        if (transpose_around_multiplier) {
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }
      }
      const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
      const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
      const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
                                    (residual_cols == kAvx8bitBlockSize);

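      // For partial (edge) blocks, spill the accumulators into an indexed
      // array so the residual store loops below can address them by column.
      // The full-block paths keep using the named registers directly.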
      __m256i accum_data_v[kAvx8bitBlockSize];
      if (!store_full_block) {
        accum_data_v[0] = accum_data_v0;
        accum_data_v[1] = accum_data_v1;
        accum_data_v[2] = accum_data_v2;
        accum_data_v[3] = accum_data_v3;
        accum_data_v[4] = accum_data_v4;
        accum_data_v[5] = accum_data_v5;
        accum_data_v[6] = accum_data_v6;
        accum_data_v[7] = accum_data_v7;
      }

      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[0 * dst_stride], accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[1 * dst_stride], accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0],
                                                         accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride],
                                                         accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0],
                                                          accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride],
                                                          accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
        if (store_full_block) {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride],
                                                 accum_data_v1);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride],
                                                 accum_data_v2);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride],
                                                 accum_data_v3);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride],
                                                 accum_data_v4);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride],
                                                 accum_data_v5);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride],
                                                 accum_data_v6);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride],
                                                 accum_data_v7);
        } else {
          std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < residual_cols; ++j) {
            intrin_utils::mm256_n_storeu_epi32<path>(
                dst_block_ptr, residual_rows, accum_data_v[j]);
            dst_block_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else {
        RUY_DCHECK(false);
      }

      lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
    }  // End row-block loop.

    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                     kAvx8bitBlockSize * params.dst_stride);
    rhs_col_ptr =
        static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) +
                                 kAvx8bitBlockSize * params.rhs_stride);
  }  // End col-block loop.
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvx2Impl<Path::kAvx2Fma>(params);
}

template <Path path>
void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit GEMV");

  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };

  int bias_ptr_block_increment =
      params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;

  const void* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  const std::int32_t* bias_col_ptr = params.bias;
  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
    bias_col_ptr += params.start_row;
  }

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  void* dst_ptr = dst_col_ptr;
  const std::int32_t* bias_ptr = bias_col_ptr;

  const std::int32_t lhs_zero_point = params.lhs_zero_point;
  const bool has_rhs_sums_offsets =
      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
  std::int32_t rhs_sums_offsets[8];
  if (has_rhs_sums_offsets) {
    const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
        _mm256_set1_epi32(lhs_zero_point),
        _mm256_loadu_si256(
            reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                        rhs_sums_offset_v);
  }

  for (int row = params.start_row; row <= params.last_row;
       row += kAvx8bitBlockSize) {
    const int residual_rows =
        std::min(params.dst_rows - row, kAvx8bitBlockSize);

    const __m256i splitter_idx =
        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));

    __m256i accum_data_v0;

    // Initialize with bias.
    __m256i initial_accum_data =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
    bias_ptr += bias_ptr_block_increment;

    // Adjustments common across columns.
    const std::int32_t rhs_zero_point = params.rhs_zero_point;
    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
      const __m256i lhs_sums_offset = _mm256_mullo_epi32(
          _mm256_set1_epi32(rhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
      initial_accum_data =
          _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
    }
    const std::int32_t prod_zp_depth = params.prod_zp_depth;
    if (prod_zp_depth) {
      initial_accum_data = _mm256_add_epi32(initial_accum_data,
                                            _mm256_set1_epi32(prod_zp_depth));
    }

    // Adjustments differing across columns.
    if (has_rhs_sums_offsets) {
      accum_data_v0 = _mm256_sub_epi32(initial_accum_data,
                                       _mm256_set1_epi32(rhs_sums_offsets[0]));
    } else {
      accum_data_v0 = initial_accum_data;
    }

    const std::int8_t* lhs_ptr = lhs_col_ptr;
    const void* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
      const __m256i lhs_data =
          _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
      const std::int32_t* rhs_data =
          reinterpret_cast<const std::int32_t*>(rhs_ptr);

      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
      // For simplicity we load 4x the data that we need and process twice the
      // data that we need, and store only the data we need.
      std::int32_t rhs_data_buf[2];
      if (params.rhs_scalar_size == 1) {
        rhs_data = rhs_data_buf;
        const __m128i rhs_data_8bit =
            intrin_utils::mm_loadu_si32<path>(rhs_ptr);
        const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
        // Now that we have cast the RHS data, we store it so that each value
        // can be separately loaded in the accumulation loop.
        _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf),
                        rhs_16_bit_dup);
      } else {
        RUY_DCHECK(params.rhs_scalar_size == 2);
      }

      // NOTE: There may be opportunities for permuting the data in the packing
      // code instead of here.
      const __m256i lhs_data_split =
          _mm256_shuffle_epi8(lhs_data, splitter_idx);
      const __m256i lhs_data_split_expand_bottom =
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
      const __m256i lhs_data_split_expand_top =
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));

      // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
      const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
      // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
      const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
      // Accumulate for column 0.
      const std::int32_t low_rhs_value = rhs_data[0];
      const std::int32_t high_rhs_value = rhs_data[1];

      const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
      const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);

      accum_data_v0 = _mm256_add_epi32(
          accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
      accum_data_v0 = _mm256_add_epi32(
          accum_data_v0,
          _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));

      lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
                                         kAvx8bitBlockSize * kAvx8bitInnerSize *
                                             params.rhs_scalar_size);
    }

    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
      __m256i m_vector;
      __m256i e_vector;
      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
      int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
      m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_fixedpoint + channel));
      e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_exponent + channel));

      const __m256i m_64bit_low =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
      const __m256i m_64bit_high =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));

      const __m256i zero_vector = _mm256_setzero_si256();
      const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
      const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
      const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
      const __m256i final_right_shift = _mm256_set1_epi32(31);
      const __m256i final_right_shift_low =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0));
      const __m256i final_right_shift_high =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1));
      const __m256i convert_to_unsigned_64 =
          _mm256_set1_epi64x(0x8000000000000000);

      __m256i post_scaling_offset = _mm256_setzero_si256();
      // A "half" added for rounding prior to truncation of the 64-bit value.
      const __m256i offset_vector = _mm256_add_epi64(
          _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
          convert_to_unsigned_64);

      if (params.dst_zero_point) {
        post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
      }

      const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

      // See GEMM version for details of this process.
      {
        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
        // Apply the fixed-point part of the multiplier.
        __m256i scaled_v_low = _mm256_mul_epi32(
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
            m_64bit_low);
        __m256i scaled_v_high = _mm256_mul_epi32(
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
            m_64bit_high);

        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);

        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
        scaled_v_high =
            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);

        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
        __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
        results = _mm256_permutevar8x32_epi32(results, repack_perm);

        // Now do a Rounding Right Shift.
        // First, construct the nudge value for each lane.
        const __m256i zeros = _mm256_setzero_si256();
        const __m256i mask_rightshift_gtz =
            _mm256_cmpgt_epi32(right_shift, zeros);
        const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
            _mm256_set1_epi32(1),
            _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1)));
        __m256i nudge = intrin_utils::mm256_blendv_epi32(
            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
        // Calculate the shifted sum (results + nudge) >> exp.
        const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
        const __m256i shifted_sum =
            _mm256_srav_epi32(r_plus_nudge, right_shift);

        // Identify overflow in each lane and create mask.
        const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
            _mm256_set1_epi32(1),
            _mm256_sub_epi32(_mm256_set1_epi32(31), right_shift));
        const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
            results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
        // Fill results with either (results + nudge) >> exponent or
        // 1 << (31 - exp) in the case of overflow.
        results = intrin_utils::mm256_blendv_epi32(
            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);

        accum_data_v0 = _mm256_add_epi32(results, post_scaling_offset);
      }
    }
    const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
    const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);

    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows,
                                                        result);
      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
      std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
      intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows,
                                               accum_data_v0);
      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else {
      RUY_DCHECK(false);
    }

    lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
  }  // End row-block loop.

  dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                   kAvx8bitBlockSize * params.dst_stride);
  rhs_col_ptr = static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) +
                                         kAvx8bitBlockSize * params.rhs_stride);
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvx2SingleColImpl<Path::kAvx2Fma>(params);
}

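// The float kernels are shared with the plain AVX path; the
// MulAdd<Path::kAvx2Fma> specialization above is what makes this variant use
// the FMA instruction.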
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma float");
  KernelFloatAvxCommon<Path::kAvx2Fma>(params);
}

void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV");
  KernelFloatAvxCommonSingleCol<Path::kAvx2Fma>(params);
}

#endif  //  RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

}  // namespace ruy