xref: /aosp_15_r20/external/ruy/ruy/pack_x86.h (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef RUY_RUY_PACK_X86_H_
17 #define RUY_RUY_PACK_X86_H_
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <cstring>
22 #include <type_traits>
23 
24 #include "ruy/check_macros.h"
25 #include "ruy/mat.h"
26 #include "ruy/opt_set.h"
27 #include "ruy/pack_common.h"
28 #include "ruy/path.h"
29 #include "ruy/platform.h"
30 #include "ruy/profiler/instrumentation.h"
31 #include "ruy/tune.h"
32 
33 namespace ruy {
34 
35 #if RUY_PLATFORM_X86
36 
// Chain pack-path inheritance so that each more capable x86 path picks up
// the previous path's pack implementations unless a specialization in this
// file overrides them: kStandardCpp -> kAvx -> kAvx2Fma -> kAvx512.
// (Macro presumably defined in pack_common.h — confirm there.)
RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx)
RUY_INHERIT_PACK(Path::kAvx, Path::kAvx2Fma)
RUY_INHERIT_PACK(Path::kAvx2Fma, Path::kAvx512)

// Row-major float packing for these kernel widths (8 floats on AVX2,
// 16 on AVX-512) is handled by the shared memcpy-based packer.
RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx2Fma, 8)
RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx512, 16)
43 
// On all x86 paths, uint8 sources are packed as int8: the packing routines
// below XOR each byte with 0x80 (see kInputXor) to shift uint8 data into
// the int8 range, so the packed storage type is int8 either way.
template <>
struct PackedTypeImpl<Path::kAvx, std::uint8_t> {
  using Type = std::int8_t;
};

template <>
struct PackedTypeImpl<Path::kAvx2Fma, std::uint8_t> {
  using Type = std::int8_t;
};
template <>
struct PackedTypeImpl<Path::kAvx512, std::uint8_t> {
  using Type = std::int8_t;
};
57 
58 // Note that source and zero buffers can be uint8 type, but in the packing
59 // function are reinterpreted as int8, and are XOR-ed with input_xor.
60 void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
61                              const std::int8_t* zerobuf, int src_stride,
62                              int remaining_src_cols, int src_rows,
63                              std::int8_t* packed_ptr, std::int32_t* sums_ptr);
64 
65 template <typename Scalar>
66 struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
67                 Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
68   static_assert(std::is_same<Scalar, std::int8_t>::value ||
69                     std::is_same<Scalar, std::uint8_t>::value,
70                 "");
71   using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
72   static constexpr std::int8_t kInputXor =
73       std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
74 
75   static void Run(Tuning, const Mat<Scalar>& src_matrix,
76                   PMat<std::int8_t>* packed_matrix, int start_col,
77                   int end_col) {
78     profiler::ScopeLabel label("Pack (AVX2 8-bit)");
79 
80     RUY_DCHECK(IsColMajor(src_matrix.layout));
81     RUY_DCHECK(IsColMajor(packed_matrix->layout));
82     RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
83     RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
84     std::int32_t* sums = packed_matrix->sums;
85     Scalar zerobuf[Layout::kCols * Layout::kRows];
86     memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
87            Layout::kCols * Layout::kRows * sizeof(Scalar));
88     for (int block_col = start_col; block_col < end_col;
89          block_col += Layout::kCols) {
90       std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
91       int src_stride = src_matrix.layout.stride;
92       const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
93       int remaining_src_cols = src_matrix.layout.cols - block_col;
94 
95       static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
96       std::int8_t* packed_ptr =
97           packed_matrix->data +
98           packed_matrix->layout.stride * (block_col & block_col_mask);
99       Pack8bitColMajorForAvx2(
100           reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
101           reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
102           remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
103     }
104   }
105 };
106 
107 void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor,
108                             const std::int8_t* zerobuf, int src_stride,
109                             int remaining_src_cols, int src_rows,
110                             std::int8_t* packed_ptr, std::int32_t* sums_ptr);
111 
112 template <typename Scalar>
113 struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
114                 std::int8_t, std::int32_t, Order::kColMajor> {
115   static_assert(std::is_same<Scalar, std::int8_t>::value ||
116                     std::is_same<Scalar, std::uint8_t>::value,
117                 "");
118   using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
119   static constexpr std::int8_t kInputXor =
120       std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
121 
122   static void Run(Tuning, const Mat<Scalar>& src_matrix,
123                   PMat<std::int8_t>* packed_matrix, int start_col,
124                   int end_col) {
125     profiler::ScopeLabel label("Pack (AVX 8-bit)");
126 
127     RUY_DCHECK(IsColMajor(src_matrix.layout));
128     RUY_DCHECK(IsColMajor(packed_matrix->layout));
129     RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
130     RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
131     std::int32_t* sums = packed_matrix->sums;
132     Scalar zerobuf[Layout::kCols * Layout::kRows];
133     memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
134            Layout::kCols * Layout::kRows * sizeof(Scalar));
135     for (int block_col = start_col; block_col < end_col;
136          block_col += Layout::kCols) {
137       std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
138       int src_stride = src_matrix.layout.stride;
139       const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
140       int remaining_src_cols = src_matrix.layout.cols - block_col;
141 
142       static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
143       std::int8_t* packed_ptr =
144           packed_matrix->data +
145           packed_matrix->layout.stride * (block_col & block_col_mask);
146       Pack8bitColMajorForAvx(
147           reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
148           reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
149           remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
150     }
151   }
152 };
153 
154 void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf,
155                              int src_stride, int remaining_src_cols,
156                              int src_rows, float* packed_ptr);
157 
158 template <>
159 struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
160                 float, float, Order::kColMajor> {
161   using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
162   static void Run(Tuning, const Mat<float>& src_matrix,
163                   PMat<float>* packed_matrix, int start_col, int end_col) {
164     profiler::ScopeLabel label("Pack (AVX float)");
165 
166     RUY_DCHECK(IsColMajor(src_matrix.layout));
167     RUY_DCHECK(IsColMajor(packed_matrix->layout));
168     RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
169     RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
170     const float zerobuf[Layout::kCols] = {
171         0.0f};  // Remainder default inits to 0.0f.
172     for (int block_col = start_col; block_col < end_col;
173          block_col += Layout::kCols) {
174       int src_stride = src_matrix.layout.stride;
175       const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
176       int remaining_src_cols = src_matrix.layout.cols - block_col;
177 
178       static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
179       float* packed_ptr =
180           packed_matrix->data +
181           packed_matrix->layout.stride * (block_col & block_col_mask);
182       PackFloatColMajorForAvx(src_ptr, zerobuf, src_stride, remaining_src_cols,
183                               src_matrix.layout.rows, packed_ptr);
184     }
185   }
186 };
187 
188 void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf,
189                               int src_stride, int remaining_src_cols,
190                               int src_rows, float* packed_ptr);
191 
192 template <>
193 struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>,
194                 float, float, float, Order::kColMajor> {
195   using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
196   static void Run(Tuning, const Mat<float>& src_matrix,
197                   PMat<float>* packed_matrix, int start_col, int end_col) {
198     profiler::ScopeLabel label("Pack (AVX2 float)");
199 
200     RUY_DCHECK(IsColMajor(src_matrix.layout));
201     RUY_DCHECK(IsColMajor(packed_matrix->layout));
202     RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
203     RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
204     const float zerobuf[Layout::kCols] = {
205         0.0f};  // Remainder default inits to 0.0f.
206     for (int block_col = start_col; block_col < end_col;
207          block_col += Layout::kCols) {
208       int src_stride = src_matrix.layout.stride;
209       const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
210       int remaining_src_cols = src_matrix.layout.cols - block_col;
211 
212       static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
213       float* packed_ptr =
214           packed_matrix->data +
215           packed_matrix->layout.stride * (block_col & block_col_mask);
216       PackFloatColMajorForAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols,
217                                src_matrix.layout.rows, packed_ptr);
218     }
219   }
220 };
221 
222 // Note that source and zero buffers can be uint8 type, but in the packing
223 // function are reinterpreted as int8, and are XOR-ed with input_xor.
224 void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
225                                std::int8_t input_xor,
226                                const std::int8_t* zerobuf, int src_stride,
227                                int remaining_src_cols, int src_rows,
228                                std::int8_t* packed_ptr, std::int32_t* sums_ptr);
229 
230 template <typename Scalar>
231 struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
232                 Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
233   static_assert(std::is_same<Scalar, std::int8_t>::value ||
234                     std::is_same<Scalar, std::uint8_t>::value,
235                 "");
236   using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
237   static constexpr int kHalfLayoutCols =
238       8;  // Half the number of cols in a block.
239   static constexpr std::int8_t kInputXor =
240       std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
241 
242   static void Run(Tuning, const Mat<Scalar>& src_matrix,
243                   PMat<std::int8_t>* packed_matrix, int start_col,
244                   int end_col) {
245     profiler::ScopeLabel label("Pack (AVX-512 8-bit)");
246 
247     RUY_DCHECK(IsColMajor(src_matrix.layout));
248     RUY_DCHECK(IsColMajor(packed_matrix->layout));
249     RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
250     RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
251     RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
252     std::int32_t* sums = packed_matrix->sums;
253     Scalar zerobuf[kHalfLayoutCols * Layout::kRows];
254     memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
255            kHalfLayoutCols * Layout::kRows * sizeof(Scalar));
256     for (int block_col = start_col; block_col < end_col;
257          block_col += Layout::kCols) {
258       std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
259       int src_stride = src_matrix.layout.stride;
260       const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
261       int remaining_src_cols = src_matrix.layout.cols - block_col;
262 
263       static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
264       std::int8_t* packed_ptr =
265           packed_matrix->data +
266           packed_matrix->layout.stride * (block_col & block_col_mask);
267       Pack8bitColMajorForAvx512(
268           reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
269           reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
270           remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
271     }
272   }
273 };
274 
275 void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
276                                 const std::int16_t* zerobuf, int src_stride,
277                                 int remaining_src_cols, int src_rows,
278                                 std::int16_t* packed_ptr,
279                                 std::int32_t* sums_ptr);
280 
281 template <>
282 struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
283                 std::int16_t, std::int16_t, std::int32_t, Order::kColMajor> {
284   using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
285   static constexpr int kHalfLayoutCols =
286       8;  // Half the number of cols in a block.
287 
288   static void Run(Tuning, const Mat<std::int16_t>& src_matrix,
289                   PMat<std::int16_t>* packed_matrix, int start_col,
290                   int end_col) {
291     profiler::ScopeLabel label("Pack (AVX-512 16-bit)");
292 
293     RUY_DCHECK(IsColMajor(src_matrix.layout));
294     RUY_DCHECK(IsColMajor(packed_matrix->layout));
295     RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
296     RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
297     RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
298     std::int32_t* sums = packed_matrix->sums;
299     std::int16_t zerobuf[kHalfLayoutCols * Layout::kRows];
300     std::fill(zerobuf, zerobuf + kHalfLayoutCols * Layout::kRows,
301               static_cast<int16_t>(packed_matrix->zero_point));
302     for (int block_col = start_col; block_col < end_col;
303          block_col += Layout::kCols) {
304       std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
305       int src_stride = src_matrix.layout.stride;
306       const std::int16_t* src_ptr =
307           src_matrix.data.get() + src_stride * block_col;
308       int remaining_src_cols = src_matrix.layout.cols - block_col;
309 
310       static constexpr int block_col_mask = ~(Layout::kCols - 1);
311       std::int16_t* packed_ptr =
312           packed_matrix->data +
313           packed_matrix->layout.stride * (block_col & block_col_mask);
314       Pack16bitColMajorForAvx512(src_ptr, zerobuf, src_stride,
315                                  remaining_src_cols, src_matrix.layout.rows,
316                                  packed_ptr, sums_ptr);
317     }
318   }
319 };
320 
321 void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
322                                 int src_stride, int remaining_src_cols,
323                                 int src_rows, float* packed_ptr);
324 
325 template <>
326 struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>,
327                 float, float, float, Order::kColMajor> {
328   static void Run(Tuning, const Mat<float>& src_matrix,
329                   PMat<float>* packed_matrix, int start_col, int end_col) {
330     profiler::ScopeLabel label("Pack (AVX-512 float)");
331     using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
332     RUY_DCHECK(IsColMajor(src_matrix.layout));
333     RUY_DCHECK(IsColMajor(packed_matrix->layout));
334     RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
335     RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
336     const float zerobuf[Layout::kCols] = {
337         0.0f};  // Remainder default inits to 0.0f.
338     for (int block_col = start_col; block_col < end_col;
339          block_col += Layout::kCols) {
340       int src_stride = src_matrix.layout.stride;
341       const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
342       int remaining_src_cols = src_matrix.layout.cols - block_col;
343 
344       static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
345       float* packed_ptr =
346           packed_matrix->data +
347           packed_matrix->layout.stride * (block_col & block_col_mask);
348       PackFloatColMajorForAvx512(src_ptr, zerobuf, src_stride,
349                                  remaining_src_cols, src_matrix.layout.rows,
350                                  packed_ptr);
351     }
352   }
353 };
354 
355 void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
356                              int src_zero_point, std::int8_t* packed_ptr,
357                              int packed_stride, int start_col, int end_col,
358                              int src_cols, int block_row, int src_rows,
359                              int input_xor, std::int32_t* sums);
360 
361 template <typename Scalar>
362 struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
363                 Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
364   static void Run(Tuning, const Mat<Scalar>& src_matrix,
365                   PMat<std::int8_t>* packed_matrix, int start_col,
366                   int end_col) {
367     profiler::ScopeLabel label("Pack (kAvx2Fma 8bit row-major)");
368     RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
369     static constexpr int kInputXor =
370         std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
371     std::int32_t* sums = packed_matrix->sums;
372     std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
373     int block_row = 0;
374     for (; block_row < packed_matrix->layout.rows; block_row += 4) {
375       int src_stride = src_matrix.layout.stride;
376       int packed_stride = packed_matrix->layout.stride;
377       const Scalar* src_ptr =
378           src_matrix.data.get() + block_row * src_stride + start_col;
379       std::int8_t* packed_ptr =
380           packed_matrix->data + start_col * packed_stride + block_row * 8;
381       Pack8bitRowMajorForAvx2(reinterpret_cast<const std::uint8_t*>(src_ptr),
382                               src_stride, src_matrix.zero_point, packed_ptr,
383                               packed_stride, start_col, end_col,
384                               src_matrix.layout.cols, block_row,
385                               src_matrix.layout.rows, kInputXor, sums);
386     }
387   }
388 };
389 
390 void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride,
391                             int src_zero_point, std::int8_t* packed_ptr,
392                             int packed_stride, int start_col, int end_col,
393                             int src_cols, int block_row, int src_rows,
394                             int input_xor, std::int32_t* sums);
395 
396 template <typename Scalar>
397 struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
398                 std::int8_t, std::int32_t, Order::kRowMajor> {
399   static void Run(Tuning, const Mat<Scalar>& src_matrix,
400                   PMat<std::int8_t>* packed_matrix, int start_col,
401                   int end_col) {
402     profiler::ScopeLabel label("Pack (AVX 8bit row-major)");
403     RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
404     static constexpr int kInputXor =
405         std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
406     std::int32_t* sums = packed_matrix->sums;
407     std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
408     int block_row = 0;
409     for (; block_row < packed_matrix->layout.rows; block_row += 4) {
410       int src_stride = src_matrix.layout.stride;
411       int packed_stride = packed_matrix->layout.stride;
412       const Scalar* src_ptr =
413           src_matrix.data.get() + block_row * src_stride + start_col;
414       std::int8_t* packed_ptr =
415           packed_matrix->data + start_col * packed_stride + block_row * 8;
416       Pack8bitRowMajorForAvx(reinterpret_cast<const std::uint8_t*>(src_ptr),
417                              src_stride, src_matrix.zero_point, packed_ptr,
418                              packed_stride, start_col, end_col,
419                              src_matrix.layout.cols, block_row,
420                              src_matrix.layout.rows, kInputXor, sums);
421     }
422   }
423 };
424 
425 void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
426                                int src_zero_point, std::int8_t* packed_ptr,
427                                int packed_stride, int start_col, int end_col,
428                                int src_cols, int block_row, int src_rows,
429                                int input_xor, std::int32_t* sums);
430 
431 template <typename Scalar>
432 struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
433                 Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
434   static void Run(Tuning, const Mat<Scalar>& src_matrix,
435                   PMat<std::int8_t>* packed_matrix, int start_col,
436                   int end_col) {
437     profiler::ScopeLabel label("Pack (kAvx512 8bit row-major)");
438     RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
439     static constexpr int kInputXor =
440         std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
441     std::int32_t* sums = packed_matrix->sums;
442     std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
443     int block_row = 0;
444     for (; block_row < packed_matrix->layout.rows; block_row += 4) {
445       int src_stride = src_matrix.layout.stride;
446       int packed_stride = packed_matrix->layout.stride;
447       const Scalar* src_ptr =
448           src_matrix.data.get() + block_row * src_stride + start_col;
449       std::int8_t* packed_ptr =
450           packed_matrix->data + start_col * packed_stride + block_row * 16;
451       Pack8bitRowMajorForAvx512(reinterpret_cast<const std::uint8_t*>(src_ptr),
452                                 src_stride, src_matrix.zero_point, packed_ptr,
453                                 packed_stride, start_col, end_col,
454                                 src_matrix.layout.cols, block_row,
455                                 src_matrix.layout.rows, kInputXor, sums);
456     }
457   }
458 };
459 #endif  // RUY_PLATFORM_X86
460 
461 }  // namespace ruy
462 
463 #if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))
464 
465 #include <immintrin.h>  // IWYU pragma: keep
466 
467 namespace ruy {
468 namespace {
469 
470 template <Path path>
471 inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) {
472   return _mm256_castpd_ps(
473       _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
474 }
475 
476 template <Path path>
477 inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) {
478   return _mm256_castpd_ps(
479       _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
480 }
481 
// Base template: intentionally unusable. Each path is expected to supply a
// specialization (not visible in this chunk — confirm at the definitions);
// reaching this fallback at runtime is a bug, hence the RUY_DCHECK(false).
template <Path path>
inline __m256i CompareGreaterThan(const __m256i&, const __m256i&) {
  RUY_DCHECK(false);
  return _mm256_set1_epi32(0);
}
487 
// Shared between AVX and AVX2+FMA.
// Loads available_src_rows (< 32) bytes from `addr` into a 256-bit value,
// filling the remaining high bytes with `zero_point`. Partial loads go
// through memcpy into a pre-filled __m128i so that no bytes beyond the
// available source are ever dereferenced.
template <Path path>
inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point,
                         const std::int8_t* addr) {
  RUY_DCHECK_LT(available_src_rows, 32);
  __m256i padded_data;

  if (available_src_rows >= 16) {
    // Low 16 bytes are fully present: load them whole, then copy the 0..15
    // remaining bytes over the zero_point-filled high half.
    __m128i load_hi = _mm_set1_epi8(zero_point);
    __m128i load_lo = _mm_loadu_si128(reinterpret_cast<const __m128i*>(addr));
    memcpy(&load_hi, addr + 16, available_src_rows - 16);
    padded_data = _mm256_set_m128i(load_hi, load_lo);
  } else {
    // Fewer than 16 bytes: copy them into the low half; the high half
    // stays entirely at zero_point.
    __m128i load_hi = _mm_set1_epi8(zero_point);
    __m128i load_lo = load_hi;
    memcpy(&load_lo, addr, available_src_rows);
    padded_data = _mm256_set_m128i(load_hi, load_lo);
  }
  return padded_data;
}
508 
509 }  // namespace.
510 
// Transposes up to 8 source columns by src_rows rows of floats into packed
// 8x8 blocks. Full groups of 8 rows are written to packed_ptr; a trailing
// partial group (src_rows % 8 rows) is written to trailing_buf instead
// (presumably merged into the packed matrix by the caller — confirm at the
// call sites). Source columns beyond remaining_src_cols are read from the
// all-zero zerobuf with a zero increment, so they contribute zeros.
template <typename PackImpl, Path path>
inline void PackFloatColMajorForAvxCommonPacker(const float* src_ptr,
                                                const float* zerobuf,
                                                int src_stride,
                                                int remaining_src_cols,
                                                int src_rows, float* packed_ptr,
                                                float* trailing_buf) {
  RUY_DCHECK_EQ(PackImpl::Layout::kCols, 8);
  RUY_DCHECK_EQ(PackImpl::Layout::kRows, 1);

  // This packing amounts to transposition of 8x8 blocks.
  static constexpr int kPackCols = 8;  // Source cols packed together.
  static constexpr int kPackRows = 8;  // Short input is padded.

  // One read pointer per source column; columns past the edge alias zerobuf.
  const float* src_ptr0 = src_ptr;
  const float* src_ptr1 = src_ptr0 + src_stride;
  const float* src_ptr2 = src_ptr1 + src_stride;
  const float* src_ptr3 = src_ptr2 + src_stride;
  const float* src_ptr4 = src_ptr3 + src_stride;
  const float* src_ptr5 = src_ptr4 + src_stride;
  const float* src_ptr6 = src_ptr5 + src_stride;
  const float* src_ptr7 = src_ptr6 + src_stride;
  // Per-column advance in floats; zeroed for columns that alias zerobuf so
  // the same zero data is re-read every iteration.
  std::int64_t src_inc0 = 8;
  std::int64_t src_inc1 = 8;
  std::int64_t src_inc2 = 8;
  std::int64_t src_inc3 = 8;
  std::int64_t src_inc4 = 8;
  std::int64_t src_inc5 = 8;
  std::int64_t src_inc6 = 8;
  std::int64_t src_inc7 = 8;
  // Handle cases where source does not have kPackDim (8) columns.
  if (remaining_src_cols < kPackCols) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    // remaining_src_cols < kPackCols implies column 7 is always absent.
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  for (int k = 0; k < src_rows; k += kPackRows) {
    const int available_src_rows = src_rows - k;
    // Effectively,
    // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k));
    // but treat each case separately.
    if (available_src_rows >= kPackRows) {
      // Full 8x8 tile: unaligned loads, 3-stage in-register transpose,
      // stores into packed_ptr.
      __m256 t0, t1, t2, t3, t4, t5, t6, t7;
      __m256 r0, r1, r2, r3, r4, r5, r6, r7;

      t0 = _mm256_loadu_ps(src_ptr0);
      t4 = _mm256_loadu_ps(src_ptr4);
      t1 = _mm256_loadu_ps(src_ptr1);
      t5 = _mm256_loadu_ps(src_ptr5);
      t2 = _mm256_loadu_ps(src_ptr2);
      t6 = _mm256_loadu_ps(src_ptr6);
      t3 = _mm256_loadu_ps(src_ptr3);
      t7 = _mm256_loadu_ps(src_ptr7);

      r0 = _mm256_unpacklo_ps(t0, t1);
      r4 = _mm256_unpacklo_ps(t4, t5);
      r2 = _mm256_unpackhi_ps(t0, t1);
      r6 = _mm256_unpackhi_ps(t4, t5);
      r1 = _mm256_unpacklo_ps(t2, t3);
      r5 = _mm256_unpacklo_ps(t6, t7);
      r3 = _mm256_unpackhi_ps(t2, t3);
      r7 = _mm256_unpackhi_ps(t6, t7);

      t0 = Mm256UnpackloPsx2<path>(r0, r1);
      t4 = Mm256UnpackloPsx2<path>(r4, r5);
      t2 = Mm256UnpackhiPsx2<path>(r0, r1);
      t6 = Mm256UnpackhiPsx2<path>(r4, r5);
      t1 = Mm256UnpackloPsx2<path>(r2, r3);
      t5 = Mm256UnpackloPsx2<path>(r6, r7);
      t3 = Mm256UnpackhiPsx2<path>(r2, r3);
      t7 = Mm256UnpackhiPsx2<path>(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleave by 16
      // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4)
      // are interleaved to create (r0, r1). This complexity follows from the
      // way that AVX is centered around MM 128-bit lanes.
      r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
      r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
      r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
      r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
      r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
      r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
      r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
      r7 = _mm256_permute2f128_ps(t3, t7, 0x31);

      // Note the interleaved store order (r0, r4, r1, r5, ...); it matches
      // the register shuffle above and must not be reordered.
      _mm256_storeu_ps(packed_ptr + 0 * 8, r0);
      _mm256_storeu_ps(packed_ptr + 2 * 8, r4);
      _mm256_storeu_ps(packed_ptr + 4 * 8, r1);
      _mm256_storeu_ps(packed_ptr + 6 * 8, r5);
      _mm256_storeu_ps(packed_ptr + 1 * 8, r2);
      _mm256_storeu_ps(packed_ptr + 3 * 8, r6);
      _mm256_storeu_ps(packed_ptr + 5 * 8, r3);
      _mm256_storeu_ps(packed_ptr + 7 * 8, r7);
    } else if (available_src_rows > 0) {
      // Partial tail tile: masked loads read only the available rows; the
      // transposed result goes to trailing_buf instead of packed_ptr.
      const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
      const __m256i row_mask_v = CompareGreaterThan<path>(
          _mm256_set1_epi32(available_src_rows), series);

      __m256 t0, t1, t2, t3, t4, t5, t6, t7;
      __m256 r0, r1, r2, r3, r4, r5, r6, r7;

      t0 = _mm256_maskload_ps(src_ptr0, row_mask_v);
      t4 = _mm256_maskload_ps(src_ptr4, row_mask_v);
      t1 = _mm256_maskload_ps(src_ptr1, row_mask_v);
      t5 = _mm256_maskload_ps(src_ptr5, row_mask_v);
      t2 = _mm256_maskload_ps(src_ptr2, row_mask_v);
      t6 = _mm256_maskload_ps(src_ptr6, row_mask_v);
      t3 = _mm256_maskload_ps(src_ptr3, row_mask_v);
      t7 = _mm256_maskload_ps(src_ptr7, row_mask_v);

      r0 = _mm256_unpacklo_ps(t0, t1);
      r4 = _mm256_unpacklo_ps(t4, t5);
      r2 = _mm256_unpackhi_ps(t0, t1);
      r6 = _mm256_unpackhi_ps(t4, t5);
      r1 = _mm256_unpacklo_ps(t2, t3);
      r5 = _mm256_unpacklo_ps(t6, t7);
      r3 = _mm256_unpackhi_ps(t2, t3);
      r7 = _mm256_unpackhi_ps(t6, t7);

      t0 = Mm256UnpackloPsx2<path>(r0, r1);
      t4 = Mm256UnpackloPsx2<path>(r4, r5);
      t2 = Mm256UnpackhiPsx2<path>(r0, r1);
      t6 = Mm256UnpackhiPsx2<path>(r4, r5);
      t1 = Mm256UnpackloPsx2<path>(r2, r3);
      t5 = Mm256UnpackloPsx2<path>(r6, r7);
      t3 = Mm256UnpackhiPsx2<path>(r2, r3);
      t7 = Mm256UnpackhiPsx2<path>(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleave by 16
      // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4)
      // are interleaved to create (r0, r1). This complexity follows from the
      // way that AVX is centered around MM 128-bit lanes.
      r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
      r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
      r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
      r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
      r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
      r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
      r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
      // r7 no longer needed.

      _mm256_storeu_ps(trailing_buf + 0 * 8, r0);
      _mm256_storeu_ps(trailing_buf + 2 * 8, r4);
      _mm256_storeu_ps(trailing_buf + 4 * 8, r1);
      _mm256_storeu_ps(trailing_buf + 6 * 8, r5);
      _mm256_storeu_ps(trailing_buf + 1 * 8, r2);
      _mm256_storeu_ps(trailing_buf + 3 * 8, r6);
      _mm256_storeu_ps(trailing_buf + 5 * 8, r3);
      // No store to (trailing_buf + 7 * 8), space not allocated.
    }

    packed_ptr += kPackRows * kPackCols;
    src_ptr0 += src_inc0;
    src_ptr1 += src_inc1;
    src_ptr2 += src_inc2;
    src_ptr3 += src_inc3;
    src_ptr4 += src_inc4;
    src_ptr5 += src_inc5;
    src_ptr6 += src_inc6;
    src_ptr7 += src_inc7;
  }
}
703 }  // namespace ruy
704 #endif  //  (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)
705 
706 #endif  // RUY_RUY_PACK_X86_H_
707