xref: /aosp_15_r20/external/ruy/ruy/kernel_avx512.cc (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <cstdint>
18 
19 #include "ruy/check_macros.h"
20 #include "ruy/kernel_x86.h"
21 #include "ruy/opt_set.h"
22 #include "ruy/platform.h"
23 #include "ruy/profiler/instrumentation.h"
24 
25 #if RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
26 #include <immintrin.h>  // IWYU pragma: keep
27 #endif
28 
29 namespace ruy {
30 
31 #if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))
32 
// Fallback stub compiled when AVX-512 code generation is unavailable.
// Runtime CPU detection must never dispatch to this symbol.
void Kernel8bitAvx512(const KernelParams8bit<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
37 
// Fallback stub for the single-column (GEMV) 8-bit kernel, compiled when
// AVX-512 code generation is unavailable. Must never be reached at runtime.
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
42 
// Fallback stub for the float kernel, compiled when AVX-512 code generation
// is unavailable. Must never be reached at runtime.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
47 
// Fallback stub for the single-column (GEMV) float kernel, compiled when
// AVX-512 code generation is unavailable. Must never be reached at runtime.
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
52 
53 #else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
54 
55 void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
56   profiler::ScopeLabel label("Kernel kAvx512 8-bit");
57 
58   std::int32_t dst_stride = 0;
59   if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
60       (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
61     dst_stride = params.dst_stride;
62   } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
63     dst_stride = params.dst_stride / sizeof(std::int16_t);
64   } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
65     dst_stride = params.dst_stride / sizeof(std::int32_t);
66   } else {
67     RUY_DCHECK(false);
68   }
69 
70   const void* rhs_col_ptr = params.rhs_base_ptr;
71   void* dst_col_ptr = params.dst_base_ptr;
72 
73   for (int col = params.start_col; col <= params.last_col; col += 16) {
74     const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
75     void* dst_ptr = dst_col_ptr;
76 
77     const std::int32_t lhs_zero_point = params.lhs_zero_point;
78     const bool has_rhs_sums_offsets =
79         (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
80     std::int32_t rhs_sums_offsets[16];
81     if (has_rhs_sums_offsets) {
82       const __m512i rhs_sums_offset_v =
83           _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
84                              _mm512_loadu_si512(&params.rhs_sums[col]));
85       _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
86                           rhs_sums_offset_v);
87     }
88 
89     for (int row = params.start_row; row <= params.last_row; row += 16) {
90       int channel =
91           (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
92       int multiplier_channel =
93           (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
94 
95       const int residual_rows = std::min(params.dst_rows - row, 16);
96       const int residual_cols = std::min(params.dst_cols - col, 16);
97 
98       __m512i accum_data_v0;
99       __m512i accum_data_v1;
100       __m512i accum_data_v2;
101       __m512i accum_data_v3;
102       __m512i accum_data_v4;
103       __m512i accum_data_v5;
104       __m512i accum_data_v6;
105       __m512i accum_data_v7;
106       __m512i accum_data_v8;
107       __m512i accum_data_v9;
108       __m512i accum_data_va;
109       __m512i accum_data_vb;
110       __m512i accum_data_vc;
111       __m512i accum_data_vd;
112       __m512i accum_data_ve;
113       __m512i accum_data_vf;
114 
115       const __mmask16 row_mask =
116           (static_cast<std::uint32_t>(1) << residual_rows) - 1;
117 
118       // initial_accum_data will be the initialize of each of the
119       // accum_data_* accumulator registers. We compute into it terms that are
120       // identical across columns.
121       __m512i initial_accum_data = _mm512_set1_epi32(params.prod_zp_depth);
122 
123       // In the channels-are-rows case, we can load bias here.
124       if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
125           !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
126         initial_accum_data = _mm512_add_epi32(
127             initial_accum_data,
128             _mm512_loadu_si512(
129                 reinterpret_cast<const __m512i*>(params.bias + row)));
130       }
131 
132       const std::int32_t rhs_zero_point = params.rhs_zero_point;
133       if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
134         const __m512i lhs_sums_offset =
135             _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
136                                _mm512_loadu_si512(&params.lhs_sums[row]));
137         initial_accum_data =
138             _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
139       }
140 
141       // Adjustments differing across columns.
142       if (has_rhs_sums_offsets) {
143         accum_data_v0 = _mm512_sub_epi32(
144             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0]));
145         accum_data_v1 = _mm512_sub_epi32(
146             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1]));
147         accum_data_v2 = _mm512_sub_epi32(
148             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2]));
149         accum_data_v3 = _mm512_sub_epi32(
150             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3]));
151         accum_data_v4 = _mm512_sub_epi32(
152             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4]));
153         accum_data_v5 = _mm512_sub_epi32(
154             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5]));
155         accum_data_v6 = _mm512_sub_epi32(
156             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6]));
157         accum_data_v7 = _mm512_sub_epi32(
158             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7]));
159         accum_data_v8 = _mm512_sub_epi32(
160             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8]));
161         accum_data_v9 = _mm512_sub_epi32(
162             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9]));
163         accum_data_va = _mm512_sub_epi32(
164             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10]));
165         accum_data_vb = _mm512_sub_epi32(
166             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11]));
167         accum_data_vc = _mm512_sub_epi32(
168             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12]));
169         accum_data_vd = _mm512_sub_epi32(
170             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13]));
171         accum_data_ve = _mm512_sub_epi32(
172             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14]));
173         accum_data_vf = _mm512_sub_epi32(
174             initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15]));
175       } else {
176         accum_data_v0 = initial_accum_data;
177         accum_data_v1 = initial_accum_data;
178         accum_data_v2 = initial_accum_data;
179         accum_data_v3 = initial_accum_data;
180         accum_data_v4 = initial_accum_data;
181         accum_data_v5 = initial_accum_data;
182         accum_data_v6 = initial_accum_data;
183         accum_data_v7 = initial_accum_data;
184         accum_data_v8 = initial_accum_data;
185         accum_data_v9 = initial_accum_data;
186         accum_data_va = initial_accum_data;
187         accum_data_vb = initial_accum_data;
188         accum_data_vc = initial_accum_data;
189         accum_data_vd = initial_accum_data;
190         accum_data_ve = initial_accum_data;
191         accum_data_vf = initial_accum_data;
192       }
193 
194       // Finally, in the channels-are-columns case, load bias data here.
195       if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
196           (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
197         const __m512i bias_data = _mm512_loadu_si512(
198             reinterpret_cast<const __m512i*>(params.bias + col));
199         accum_data_v0 = _mm512_add_epi32(
200             accum_data_v0,
201             _mm512_permutexvar_epi32(_mm512_set1_epi32(0), bias_data));
202         accum_data_v1 = _mm512_add_epi32(
203             accum_data_v1,
204             _mm512_permutexvar_epi32(_mm512_set1_epi32(1), bias_data));
205         accum_data_v2 = _mm512_add_epi32(
206             accum_data_v2,
207             _mm512_permutexvar_epi32(_mm512_set1_epi32(2), bias_data));
208         accum_data_v3 = _mm512_add_epi32(
209             accum_data_v3,
210             _mm512_permutexvar_epi32(_mm512_set1_epi32(3), bias_data));
211         accum_data_v4 = _mm512_add_epi32(
212             accum_data_v4,
213             _mm512_permutexvar_epi32(_mm512_set1_epi32(4), bias_data));
214         accum_data_v5 = _mm512_add_epi32(
215             accum_data_v5,
216             _mm512_permutexvar_epi32(_mm512_set1_epi32(5), bias_data));
217         accum_data_v6 = _mm512_add_epi32(
218             accum_data_v6,
219             _mm512_permutexvar_epi32(_mm512_set1_epi32(6), bias_data));
220         accum_data_v7 = _mm512_add_epi32(
221             accum_data_v7,
222             _mm512_permutexvar_epi32(_mm512_set1_epi32(7), bias_data));
223         accum_data_v8 = _mm512_add_epi32(
224             accum_data_v8,
225             _mm512_permutexvar_epi32(_mm512_set1_epi32(8), bias_data));
226         accum_data_v9 = _mm512_add_epi32(
227             accum_data_v9,
228             _mm512_permutexvar_epi32(_mm512_set1_epi32(9), bias_data));
229         accum_data_va = _mm512_add_epi32(
230             accum_data_va,
231             _mm512_permutexvar_epi32(_mm512_set1_epi32(10), bias_data));
232         accum_data_vb = _mm512_add_epi32(
233             accum_data_vb,
234             _mm512_permutexvar_epi32(_mm512_set1_epi32(11), bias_data));
235         accum_data_vc = _mm512_add_epi32(
236             accum_data_vc,
237             _mm512_permutexvar_epi32(_mm512_set1_epi32(12), bias_data));
238         accum_data_vd = _mm512_add_epi32(
239             accum_data_vd,
240             _mm512_permutexvar_epi32(_mm512_set1_epi32(13), bias_data));
241         accum_data_ve = _mm512_add_epi32(
242             accum_data_ve,
243             _mm512_permutexvar_epi32(_mm512_set1_epi32(14), bias_data));
244         accum_data_vf = _mm512_add_epi32(
245             accum_data_vf,
246             _mm512_permutexvar_epi32(_mm512_set1_epi32(15), bias_data));
247       }
248 
249       const std::int8_t* lhs_ptr = lhs_col_ptr;
250       const void* rhs_ptr = rhs_col_ptr;
251       for (int d = 0; d < params.depth; d += 4) {
252         const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
253         __m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr);
254 
255         // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
256         std::int32_t rhs_data_buf[32];
257         const std::int32_t* rhs_data =
258             reinterpret_cast<const std::int32_t*>(rhs_ptr);
259         if (params.rhs_scalar_size == 1) {
260           rhs_data = rhs_data_buf;
261           const __m256i rhs_data_bottom_lane =
262               _mm512_castsi512_si256(rhs_data_8bit);
263           const __m256i rhs_data_top_lane =
264               _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
265           const __m512i rhs_16_bit_dup_low =
266               _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
267           const __m512i rhs_16_bit_dup_high =
268               _mm512_cvtepi8_epi16(rhs_data_top_lane);
269           // Now that we have cast the RHS data, we store it so that each value
270           // can be separately loaded in the accumulation loop.
271           _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf),
272                               rhs_16_bit_dup_low);
273           _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf + 16),
274                               rhs_16_bit_dup_high);
275         } else {
276           RUY_DCHECK(params.rhs_scalar_size == 2);
277         }
278 
279         // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
280         const __m512i lhs_16_bit_low =
281             _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
282         // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
283         const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
284             _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
285 
286         auto process_column = [=](int col, __m512i& accum) {
287           const __m512i rhs_16_bit_dup_low =
288               _mm512_set1_epi32(rhs_data[2 * col]);
289           const __m512i rhs_16_bit_dup_high =
290               _mm512_set1_epi32(rhs_data[2 * col + 1]);
291 
292           accum = _mm512_add_epi32(
293               accum, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
294           accum = _mm512_add_epi32(
295               accum, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
296         };
297         process_column(0, accum_data_v0);
298         process_column(1, accum_data_v1);
299         process_column(2, accum_data_v2);
300         process_column(3, accum_data_v3);
301         process_column(4, accum_data_v4);
302         process_column(5, accum_data_v5);
303         process_column(6, accum_data_v6);
304         process_column(7, accum_data_v7);
305         process_column(8, accum_data_v8);
306         process_column(9, accum_data_v9);
307         process_column(10, accum_data_va);
308         process_column(11, accum_data_vb);
309         process_column(12, accum_data_vc);
310         process_column(13, accum_data_vd);
311         process_column(14, accum_data_ve);
312         process_column(15, accum_data_vf);
313 
314         lhs_ptr += 16 * 4;
315         rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
316                                            16 * 4 * params.rhs_scalar_size);
317       }
318 
319       if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
320         // The non-per-channel case could equivalently be handled in the per-row
321         // or per-column code path. The per-row code path is slightly more
322         // efficient so we handle it there.
323         const bool per_column_multiplier =
324             (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
325             (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
326 
327         __m512i m_vector;
328         __m512i e_vector;
329         // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
330         m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
331             params.multiplier_fixedpoint + multiplier_channel));
332         e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
333             params.multiplier_exponent + multiplier_channel));
334 
335         const __m512i m_64bit_low =
336             _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
337         const __m512i m_64bit_high =
338             _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
339 
340         const __m512i zero_vector = _mm512_setzero_epi32();
341         const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
342         const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
343         const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
344         const __m512i final_right_shift = _mm512_set1_epi32(31);
345         const __m512i right_shift_low =
346             _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
347         const __m512i right_shift_high =
348             _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
349         const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
350             _mm512_extracti32x8_epi32(final_right_shift, 0));
351         const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
352             _mm512_extracti32x8_epi32(final_right_shift, 1));
353 
354         // A "half" added for rounding prior to truncation of 64-bit value.
355         const __m512i offset_vector =
356             _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
357 
358         auto rounding_right_shift = [=](__m512i& results,
359                                         const __m512i& exponent) {
360           // Construct the "nudge" value for each lane if the exponent is
361           // greater than 0. Otherwise, the nudge is 0.
362           const __m512i zeros = _mm512_setzero_si512();
363           const auto mask_rightshift_gtz =
364               _mm512_cmpgt_epi64_mask(exponent, zeros);
365           const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
366               _mm512_set1_epi64(1),
367               _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
368           __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
369                                                 one_shift_exp_minus1);
370           // Calculate the shifted sum (results + nudge) >> exp.
371           const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
372           const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
373 
374           // Identify overflow in each lane and create mask.
375           const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
376               _mm512_set1_epi64(1),
377               _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
378           const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
379               results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
380           // Fill results with either (results + nudge) >> exponent or
381           // 1 << (31 - exp) in the case of overflow.
382           results = _mm512_mask_mov_epi64(
383               shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
384         };
385 
386         if (per_column_multiplier) {
387           auto apply_multiplier = [=](__m512i& accum, int col) {
388             __m512i perm_64bit_vals = _mm512_set1_epi64(col % 8);
389             // Apply the fixed-point part of the multiplier.
390             __m512i left_shift_val =
391                 _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
392             __m512i m_64bit_val = _mm512_permutexvar_epi64(
393                 perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
394             __m512i offset_vector_val =
395                 _mm512_permutexvar_epi64(perm_64bit_vals, offset_vector);
396             __m512i final_right_shift_val = _mm512_permutexvar_epi64(
397                 perm_64bit_vals,
398                 col < 8 ? final_right_shift_low : final_right_shift_high);
399             __m512i right_shift_val = _mm512_permutexvar_epi64(
400                 perm_64bit_vals, col < 8 ? right_shift_low : right_shift_high);
401 
402             accum = _mm512_sllv_epi32(accum, left_shift_val);
403             __m512i scaled_v_low = _mm512_mul_epi32(
404                 _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
405                 m_64bit_val);
406             __m512i scaled_v_high = _mm512_mul_epi32(
407                 _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
408                 m_64bit_val);
409 
410             scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_val);
411             scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_val);
412 
413             scaled_v_low =
414                 _mm512_srav_epi64(scaled_v_low, final_right_shift_val);
415             scaled_v_high =
416                 _mm512_srav_epi64(scaled_v_high, final_right_shift_val);
417 
418             rounding_right_shift(scaled_v_low, right_shift_val);
419             rounding_right_shift(scaled_v_high, right_shift_val);
420 
421             accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
422             accum = _mm512_inserti32x8(accum,
423                                        _mm512_cvtepi64_epi32(scaled_v_high), 1);
424           };
425           apply_multiplier(accum_data_v0, 0);
426           apply_multiplier(accum_data_v1, 1);
427           apply_multiplier(accum_data_v2, 2);
428           apply_multiplier(accum_data_v3, 3);
429           apply_multiplier(accum_data_v4, 4);
430           apply_multiplier(accum_data_v5, 5);
431           apply_multiplier(accum_data_v6, 6);
432           apply_multiplier(accum_data_v7, 7);
433           apply_multiplier(accum_data_v8, 8);
434           apply_multiplier(accum_data_v9, 9);
435           apply_multiplier(accum_data_va, 10);
436           apply_multiplier(accum_data_vb, 11);
437           apply_multiplier(accum_data_vc, 12);
438           apply_multiplier(accum_data_vd, 13);
439           apply_multiplier(accum_data_ve, 14);
440           apply_multiplier(accum_data_vf, 15);
441         } else {  // not per-column, so per-row
442           auto apply_multiplier = [=](__m512i& accum) {
443             accum = _mm512_sllv_epi32(accum, left_shift);
444             // Apply the fixed-point part of the multiplier.
445             __m512i scaled_v_low = _mm512_mul_epi32(
446                 _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
447                 m_64bit_low);
448             __m512i scaled_v_high = _mm512_mul_epi32(
449                 _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
450                 m_64bit_high);
451 
452             scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
453             scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);
454 
455             scaled_v_low =
456                 _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
457             scaled_v_high =
458                 _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
459 
460             rounding_right_shift(scaled_v_low, right_shift_low);
461             rounding_right_shift(scaled_v_high, right_shift_high);
462             accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
463             accum = _mm512_inserti32x8(accum,
464                                        _mm512_cvtepi64_epi32(scaled_v_high), 1);
465           };
466           apply_multiplier(accum_data_v0);
467           apply_multiplier(accum_data_v1);
468           apply_multiplier(accum_data_v2);
469           apply_multiplier(accum_data_v3);
470           apply_multiplier(accum_data_v4);
471           apply_multiplier(accum_data_v5);
472           apply_multiplier(accum_data_v6);
473           apply_multiplier(accum_data_v7);
474           apply_multiplier(accum_data_v8);
475           apply_multiplier(accum_data_v9);
476           apply_multiplier(accum_data_va);
477           apply_multiplier(accum_data_vb);
478           apply_multiplier(accum_data_vc);
479           apply_multiplier(accum_data_vd);
480           apply_multiplier(accum_data_ve);
481           apply_multiplier(accum_data_vf);
482         }
483 
484         if (params.dst_zero_point != 0) {
485           __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
486           accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
487           accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point);
488           accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point);
489           accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point);
490           accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point);
491           accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point);
492           accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point);
493           accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point);
494           accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point);
495           accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point);
496           accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point);
497           accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point);
498           accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point);
499           accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point);
500           accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point);
501           accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point);
502         }
503       }
504 
505       const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
506       const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
507 
508       const bool store_full_block =
509           (residual_rows == 16) && (residual_cols == 16);
510 
511       __m512i accum_data_v[16];
512 
513       // In most cases we would make this conditional on (!store_full_block) and
514       // unwind the clamp-and-store loop, but the benefit appears small.
515       {
516         accum_data_v[0] = accum_data_v0;
517         accum_data_v[1] = accum_data_v1;
518         accum_data_v[2] = accum_data_v2;
519         accum_data_v[3] = accum_data_v3;
520         accum_data_v[4] = accum_data_v4;
521         accum_data_v[5] = accum_data_v5;
522         accum_data_v[6] = accum_data_v6;
523         accum_data_v[7] = accum_data_v7;
524         accum_data_v[8] = accum_data_v8;
525         accum_data_v[9] = accum_data_v9;
526         accum_data_v[10] = accum_data_va;
527         accum_data_v[11] = accum_data_vb;
528         accum_data_v[12] = accum_data_vc;
529         accum_data_v[13] = accum_data_vd;
530         accum_data_v[14] = accum_data_ve;
531         accum_data_v[15] = accum_data_vf;
532       }
533 
534       if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
535         std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
536         const int block_col_offset = dst_stride;
537         if (store_full_block) {
538           for (int j = 0; j < 16; ++j) {
539             __m512i result = accum_data_v[j];
540             result = _mm512_min_epi32(result, clamp_max_v);
541             result = _mm512_max_epi32(result, clamp_min_v);
542             _mm_storeu_si128(
543                 reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
544                 _mm512_cvtepi32_epi8(result));
545           }
546         } else {
547           for (int j = 0; j < residual_cols; ++j) {
548             __m512i result = accum_data_v[j];
549             result = _mm512_min_epi32(result, clamp_max_v);
550             result = _mm512_max_epi32(result, clamp_min_v);
551             _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
552                                  _mm512_cvtepi32_epi8(result));
553           }
554         }
555         dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
556       } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
557         std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
558         const int block_col_offset = dst_stride;
559         if (store_full_block) {
560           for (int j = 0; j < residual_cols; ++j) {
561             __m512i result = accum_data_v[j];
562             result = _mm512_min_epi32(result, clamp_max_v);
563             result = _mm512_max_epi32(result, clamp_min_v);
564             _mm_storeu_si128(
565                 reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
566                 _mm512_cvtepi32_epi8(result));
567           }
568         } else {
569           for (int j = 0; j < residual_cols; ++j) {
570             __m512i result = accum_data_v[j];
571             result = _mm512_min_epi32(result, clamp_max_v);
572             result = _mm512_max_epi32(result, clamp_min_v);
573             _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
574                                  _mm512_cvtepi32_epi8(result));
575           }
576         }
577         dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
578       } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
579         std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
580         const int block_col_offset = dst_stride;
581         if (store_full_block) {
582           for (int j = 0; j < 16; ++j) {
583             __m512i result = accum_data_v[j];
584             result = _mm512_min_epi32(result, clamp_max_v);
585             result = _mm512_max_epi32(result, clamp_min_v);
586             _mm256_storeu_si256(
587                 reinterpret_cast<__m256i*>(tmp_ptr + j * block_col_offset),
588                 _mm512_cvtepi32_epi16(result));
589           }
590         } else {
591           for (int j = 0; j < residual_cols; ++j) {
592             __m512i result = accum_data_v[j];
593             result = _mm512_min_epi32(result, clamp_max_v);
594             result = _mm512_max_epi32(result, clamp_min_v);
595             _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask,
596                                      _mm512_cvtepi32_epi16(result));
597           }
598         }
599         dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
600       } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
601         if (store_full_block) {
602           std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
603           for (int j = 0; j < 16; ++j) {
604             _mm512_storeu_si512(tmp_ptr + j * dst_stride, accum_data_v[j]);
605           }
606         } else {
607           std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
608           for (int j = 0; j < residual_cols; ++j) {
609             _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask,
610                                      accum_data_v[j]);
611           }
612         }
613         dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
614       } else {
615         RUY_DCHECK(false);
616       }
617 
618       lhs_col_ptr += 16 * params.lhs_stride;
619     }  // End row-block loop.
620 
621     dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
622                                      16 * params.dst_stride);
623     rhs_col_ptr = static_cast<const void*>(
624         static_cast<const char*>(rhs_col_ptr) + 16 * params.rhs_stride);
625   }  // End col-block loop.
626 }  // NOLINT(readability/fn_size)
627 
// 8-bit GEMV kernel (single destination column) for AVX-512.
//
// Computes one column: dst = clamp(requantize(lhs * rhs + bias + zero-point
// corrections)). Rows are processed 16 at a time (one __m512i of int32
// accumulators); depth is processed 4 levels at a time. The requantization
// path applies a fixed-point multiplier plus exponent with a rounding right
// shift and saturating overflow handling (see rounding_right_shift below),
// then adds dst_zero_point and clamps before storing in the requested
// destination integer type.
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV");

  // GEMV specialization: exactly one destination column.
  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  // Advance the bias pointer by 16 per row-block only when a real bias vector
  // is present; otherwise keep re-reading the same (zero) bias block.
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;

  const void* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  const std::int32_t* bias_col_ptr = params.bias;
  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
    bias_col_ptr += params.start_row;
  }

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  void* dst_ptr = dst_col_ptr;
  const std::int32_t* bias_ptr = bias_col_ptr;

  // Precompute per-column offsets lhs_zero_point * rhs_sums. Only element 0
  // is consumed below since this kernel handles a single column.
  const std::int32_t lhs_zero_point = params.lhs_zero_point;
  const bool has_rhs_sums_offsets =
      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
  std::int32_t rhs_sums_offsets[16];
  if (has_rhs_sums_offsets) {
    const __m512i rhs_sums_offset_v =
        _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
                           _mm512_loadu_si512(&params.rhs_sums[0]));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
                        rhs_sums_offset_v);
  }

  for (int row = params.start_row; row <= params.last_row; row += 16) {
    // Number of valid rows in this block (may be < 16 at the matrix edge).
    const int residual_rows = std::min(params.dst_rows - row, 16);

    __m512i accum_data_v0;

    // Initialize with bias.
    // row_mask selects the valid rows for masked stores of partial blocks.
    const __mmask16 row_mask =
        (static_cast<std::uint32_t>(1) << residual_rows) - 1;
    __m512i initial_accum_data =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(bias_ptr));
    bias_ptr += bias_ptr_block_increment;

    // Zero-point correction: subtract rhs_zero_point * lhs_sums for this
    // row block when LHS sums are provided and the correction is non-zero.
    const std::int32_t rhs_zero_point = params.rhs_zero_point;
    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
      const __m512i lhs_sums_offset =
          _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
                             _mm512_loadu_si512(&params.lhs_sums[row]));
      initial_accum_data =
          _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
    }

    // Constant term lhs_zero_point * rhs_zero_point * depth, precomputed by
    // the caller and added once per accumulator.
    const std::int32_t prod_zp_depth = params.prod_zp_depth;
    if (prod_zp_depth != 0) {
      initial_accum_data = _mm512_add_epi32(initial_accum_data,
                                            _mm512_set1_epi32(prod_zp_depth));
    }

    // Adjustments differing across columns.
    if (has_rhs_sums_offsets) {
      accum_data_v0 = _mm512_sub_epi32(initial_accum_data,
                                       _mm512_set1_epi32(rhs_sums_offsets[0]));
    } else {
      accum_data_v0 = initial_accum_data;
    }

    // Multiply-accumulate over the depth dimension, 4 levels per iteration.
    const std::int8_t* lhs_ptr = lhs_col_ptr;
    const void* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; d += 4) {
      // Packed LHS block: 16 rows x 4 depth entries of int8 (64 bytes).
      const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
      const std::int32_t* rhs_data =
          reinterpret_cast<const std::int32_t*>(rhs_ptr);

      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
      // For simplicity we load 4x the data that we need and process twice the
      // data  that we need  and store only the data we need.
      std::int32_t rhs_data_buf[2];
      if (params.rhs_scalar_size == 1) {
        // 8-bit RHS path: widen to 16-bit, then spill to a small buffer so
        // the two halves can be broadcast separately below.
        rhs_data = rhs_data_buf;
        const __m128i rhs_data_8bit =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));
        const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
        // Now that we have cast the RHS data, we store it so that each value
        // can be separately loaded in the accumulation loop.
        _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf),
                        rhs_16_bit_dup);
      } else {
        // 16-bit RHS path: data is already in the layout rhs_data expects.
        RUY_DCHECK(params.rhs_scalar_size == 2);
      }

      // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
      const __m512i lhs_16_bit_low =
          _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
      // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
      const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
          _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));

      // Process column 0.
      __m512i accum_v = accum_data_v0;
      constexpr int index = 0;

      // Broadcast the pair of 16-bit RHS values for depth {0,1} and {2,3};
      // madd_epi16 then produces dot products of 16-bit pairs into int32.
      const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
      const __m512i rhs_16_bit_dup_high =
          _mm512_set1_epi32(rhs_data[index + 1]);

      accum_v = _mm512_add_epi32(
          accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
      accum_v = _mm512_add_epi32(
          accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
      accum_data_v0 = accum_v;

      lhs_ptr += 16 * 4;
      rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
                                         16 * 4 * params.rhs_scalar_size);
    }

    // Requantization: skipped entirely for raw int32 destinations.
    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
      __m512i m_vector;
      __m512i e_vector;
      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
      // Per-channel multipliers index by row; otherwise channel 0 holds the
      // single uniform multiplier/exponent.
      int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
      m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
          params.multiplier_fixedpoint + channel));
      e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
          params.multiplier_exponent + channel));

      // Widen multipliers to 64-bit so the fixed-point multiply below can be
      // done with full precision in two 8-lane halves.
      const __m512i m_64bit_low =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
      const __m512i m_64bit_high =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));

      // Split the exponent into a left shift (e > 0) and a right shift
      // (e < 0); the final right shift by 31 completes the Q31 multiply.
      const __m512i zero_vector = _mm512_setzero_epi32();
      const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
      const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
      const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
      const __m512i final_right_shift = _mm512_set1_epi32(31);
      const __m512i right_shift_low =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
      const __m512i right_shift_high =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
      const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
          _mm512_extracti32x8_epi32(final_right_shift, 0));
      const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
          _mm512_extracti32x8_epi32(final_right_shift, 1));

      // A "half" added for rounding prior to truncation of 64-bit value.
      const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30);

      // Rounding arithmetic right shift of 64-bit lanes by per-lane
      // exponents, with saturation when (results + nudge) would overflow
      // the int32 range.
      auto rounding_right_shift = [=](__m512i& results,
                                      const __m512i& exponent) {
        // Construct the "nudge" value for each lane if the exponent is
        // greater than 0. Otherwise, the nudge is 0.
        const __m512i zeros = _mm512_setzero_si512();
        const auto mask_rightshift_gtz =
            _mm512_cmpgt_epi64_mask(exponent, zeros);
        const __m512i one_shift_exp_minus1 =
            _mm512_sllv_epi64(_mm512_set1_epi64(1),
                              _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
        __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
                                              one_shift_exp_minus1);
        // Calculate the shifted sum (results + nudge) >> exp.
        const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
        const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);

        // Identify overflow in each lane and create mask.
        const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
            _mm512_set1_epi64(1),
            _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
        const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
            results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
        // Fill results with either (results + nudge) >> exponent or
        // 1 << (31 - exp) in the case of overflow.
        results = _mm512_mask_mov_epi64(
            shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
      };

      // Shift and round column 0.
      accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
      // Apply the fixed-point part of the multiplier.
      __m512i scaled_v_low = _mm512_mul_epi32(
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)),
          m_64bit_low);
      __m512i scaled_v_high = _mm512_mul_epi32(
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
          m_64bit_high);

      scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
      scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);

      scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
      scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);

      rounding_right_shift(scaled_v_low, right_shift_low);
      rounding_right_shift(scaled_v_high, right_shift_high);

      // Narrow the two 64-bit halves back into one vector of 16 int32 lanes.
      accum_data_v0 =
          _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
      accum_data_v0 = _mm512_inserti32x8(
          accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1);

      if (params.dst_zero_point != 0) {
        __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
        accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
      }
    }

    // Clamp to [clamp_min, clamp_max] and store in the requested destination
    // type, using row_mask so rows beyond dst_rows are never written.
    const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
    const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);

    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
      __m512i result = accum_data_v0;
      result = _mm512_min_epi32(result, clamp_max_v);
      result = _mm512_max_epi32(result, clamp_min_v);
      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
      __m512i result = accum_data_v0;
      result = _mm512_min_epi32(result, clamp_max_v);
      result = _mm512_max_epi32(result, clamp_min_v);
      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
      __m512i result = accum_data_v0;
      result = _mm512_min_epi32(result, clamp_max_v);
      result = _mm512_max_epi32(result, clamp_min_v);
      _mm256_mask_storeu_epi16(tmp_ptr, row_mask,
                               _mm512_cvtepi32_epi16(result));
      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
      // Raw accumulator output: no clamping was requested for int32.
      std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
      _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0);
      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
    } else {
      RUY_DCHECK(false);
    }

    lhs_col_ptr += 16 * params.lhs_stride;
  }  // End row-block loop.
}  // NOLINT(readability/fn_size)
871 
// Float GEMM kernel for AVX-512, operating on 16x16 destination blocks.
//
// Layout of the computation: the 16x16 block is processed as two halves of
// 8 columns each ("mmm" loop). Within a half, 8 accumulator vectors
// (one __m512 of 16 rows per column) are kept in registers; the depth loop
// performs one 16-row LHS load and 8 broadcast-FMA updates per level.
// The hand-unrolled accumulators and the nested blocks around the final
// depth step are deliberate: they were measured to be faster than the
// array/loop form (which is still used for the residual-columns tail).
// FMA ordering is part of the observable float results, so the statement
// order here must be preserved exactly.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 float");

  // As parameters are defined, we need to scale by sizeof(float).
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;

  // Bias advances one float per channel when present, else stays at a
  // single (zero) element.
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  const int end_row = std::min(params.dst_rows, params.last_row + 16);
  const int end_col = std::min(params.dst_cols, params.last_col + 16);

  // "adj_" pointers are pre-biased by start_row/start_col so that indexing
  // with the absolute row/col below lands at the right address.
  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_ptr = params.bias;

  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
  // Selects whether bias is indexed by column (per-col channels) or by row.
  const bool channel_dimension_is_col =
      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;

  // Main loop over full 16-column blocks.
  int col = params.start_col;
  for (; col <= end_col - 16; col += 16) {
    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    // Full 16-row blocks.
    int row = params.start_row;
    for (; row <= end_row - 16; row += 16) {
      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Process block in two halves, split by columns.
#pragma unroll(1)
      for (int mmm = 0; mmm < 2; ++mmm) {
        __m512 accum_data_v0;
        __m512 accum_data_v1;
        __m512 accum_data_v2;
        __m512 accum_data_v3;
        __m512 accum_data_v4;
        __m512 accum_data_v5;
        __m512 accum_data_v6;
        __m512 accum_data_v7;

        // Initialize with bias.
        if (channel_dimension_is_col) {
          // One scalar bias per column, broadcast across the 16 rows.
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
        } else {
          // One bias vector across the 16 rows, shared by all columns.
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);

          accum_data_v0 = initial_accum_data;
          accum_data_v1 = initial_accum_data;
          accum_data_v2 = initial_accum_data;
          accum_data_v3 = initial_accum_data;
          accum_data_v4 = initial_accum_data;
          accum_data_v5 = initial_accum_data;
          accum_data_v6 = initial_accum_data;
          accum_data_v7 = initial_accum_data;
        }

        // Depth loop, leaving the last level for the store-fused epilogue.
        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          lhs_ptr += 16;
          rhs_ptr += 16;

          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
          // so if given an rvalue.
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
        }
        {  // nested extra blocks lead to measurable speed gains
          // Final depth level, fused with clamp + store.
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
          {
            // Clamp and store all 8 columns of this half-block (full rows,
            // so unmasked stores are safe).
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
          }
        }
      }
    }    // End row-block loop.

    // The unrolling within this conditional may be somewhat pointless. It
    // depends on the kinds of models.
    if (row < end_row) {
      // Residual rows (< 16) within a full column block: same unrolled body
      // as above, but stores are masked to the valid rows.
      const int residual_rows = end_row - row;

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // Process block in two halves, split by columns.
      for (int mmm = 0; mmm < 2; ++mmm) {
        __m512 accum_data_v0;
        __m512 accum_data_v1;
        __m512 accum_data_v2;
        __m512 accum_data_v3;
        __m512 accum_data_v4;
        __m512 accum_data_v5;
        __m512 accum_data_v6;
        __m512 accum_data_v7;

        // Initialize with bias.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);

          accum_data_v0 = initial_accum_data;
          accum_data_v1 = initial_accum_data;
          accum_data_v2 = initial_accum_data;
          accum_data_v3 = initial_accum_data;
          accum_data_v4 = initial_accum_data;
          accum_data_v5 = initial_accum_data;
          accum_data_v6 = initial_accum_data;
          accum_data_v7 = initial_accum_data;
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          lhs_ptr += 16;
          rhs_ptr += 16;
          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
          // so if given an rvalue.
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
        }
        {
          // Final depth level, fused with clamp + masked store.
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
          {
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
                                  accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
                                  accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
                                  accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
                                  accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
                                  accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
                                  accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
                                  accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
                                  accum_data_v7);
          }
        }
      }  // Inner half-block loop.
    }    // Residual rows, main col-block loop.
  }      // End col-block loop.

  if (col < end_col) {
    // Residual columns (< 16): generic array/loop form, masked row stores
    // when rows are also partial.
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 16);

    __m512 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 16) {
      const int residual_rows = std::min(end_row - row, 16);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // Process block in two halves, split by columns.
      for (int mmm = 0; mmm < 2; ++mmm) {
        // Initialize with bias.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          for (int j = 0; j < 8; ++j) {
            accum_data_v[j] = _mm512_set1_ps(bias_elem_ptr[j]);
          }
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
          for (int j = 0; j < 8; ++j) {
            accum_data_v[j] = initial_accum_data;
          }
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < params.depth; ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;

          for (int j = 0; j < 8; ++j) {
            const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
            accum_data_v[j] =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
          }
          lhs_ptr += 16;
          rhs_ptr += 16;
        }

        // Columns actually valid in this half (may be < 8 at the edge).
        const int residual_cols = std::min(end_col - col - 8 * mmm, 8);

        if (residual_rows == 16) {
          if (residual_cols == 8) {
            for (int j = 0; j < 8; ++j) {
              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
            }
          } else {
            for (int j = 0; j < residual_cols; ++j) {
              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
            }
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
            accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
            accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
          }
        }
      }  // Inner half-block loop.
    }    // End row-block loop.
  }      // Residual cols.
}
1237 
1238 void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
1239   profiler::ScopeLabel label("Kernel kAvx512 float GEMV");
1240 
1241   RUY_DCHECK_EQ(params.dst_cols, 1);
1242   RUY_DCHECK_EQ(params.last_col, 0);
1243   RUY_DCHECK_EQ(params.start_col, 0);
1244 
1245   // As parameters are defined, we need to scale by sizeof(float).
1246   const std::int64_t lhs_stride = params.lhs_stride >> 2;
1247 
1248   int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
1249   const int end_row = std::min(params.dst_rows, params.last_row + 16);
1250 
1251   float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
1252   const float* adj_lhs_col_ptr =
1253       params.lhs_base_ptr - params.start_row * lhs_stride;
1254   const float* bias_col_ptr = params.bias;
1255 
1256   const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
1257   const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
1258 
1259   __m512 accum_data_v;
1260 
1261   const float* rhs_col_ptr = params.rhs_base_ptr;
1262   float* dst_col_ptr = adj_dst_col_ptr;
1263 
1264   int row = params.start_row;
1265   for (; row <= end_row - 16; row += 16) {
1266     const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
1267     float* dst_ptr = dst_col_ptr + row;
1268     const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
1269 
1270     // Initialize with bias.
1271     accum_data_v = _mm512_loadu_ps(bias_ptr);
1272 
1273     const float* lhs_ptr = lhs_col_ptr;
1274     const float* rhs_ptr = rhs_col_ptr;
1275     for (int d = 0; d < params.depth; ++d) {
1276       const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
1277       const float rhs_data = *rhs_ptr;
1278 
1279       const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
1280       accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
1281       lhs_ptr += 16;
1282       rhs_ptr += 16;
1283     }
1284 
1285     accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
1286     accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
1287     _mm512_storeu_ps(dst_ptr, accum_data_v);
1288   }  // End row-block loop.
1289 
1290   if (row < end_row) {
1291     const int residual_rows = end_row - row;
1292     RUY_CHECK_GE(residual_rows, 1);
1293     RUY_CHECK_LT(residual_rows, 16);
1294 
1295     const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
1296     float* dst_ptr = dst_col_ptr + row;
1297     const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
1298 
1299     // Initialize with bias.
1300     const __mmask16 row_mask =
1301         (static_cast<std::uint32_t>(1) << residual_rows) - 1;
1302     accum_data_v = _mm512_loadu_ps(bias_ptr);
1303 
1304     const float* lhs_ptr = lhs_col_ptr;
1305     const float* rhs_ptr = rhs_col_ptr;
1306     for (int d = 0; d < params.depth; ++d) {
1307       const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
1308       const float rhs_data = *rhs_ptr;
1309 
1310       const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
1311       accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
1312       lhs_ptr += 16;
1313       rhs_ptr += 16;
1314     }
1315 
1316     accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
1317     accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
1318     _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v);
1319   }  // End handling of residual rows.
1320 }
1321 
1322 #endif  //  RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
1323 
1324 }  // namespace ruy
1325