xref: /aosp_15_r20/external/ruy/ruy/pack_arm.h (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef RUY_RUY_PACK_ARM_H_
17 #define RUY_RUY_PACK_ARM_H_
18 
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <type_traits>

#include "ruy/asm_helpers.h"
#include "ruy/check_macros.h"
#include "ruy/mat.h"
#include "ruy/opt_set.h"
#include "ruy/pack_common.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/tune.h"
32 
33 namespace ruy {
34 
#if RUY_PLATFORM_NEON
// Fall back to the next-lower path's pack implementation when no NEON
// specialization exists for a given kernel layout.
RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon)
RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod)

// Row-major float packing for the 8-wide (and, on 32-bit NEON, 4-wide)
// float kernel layouts is a plain per-row memcpy.
RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 8)
#if RUY_PLATFORM_NEON_32
RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 4)
#endif

// On NEON, uint8 source data is packed as int8 (each byte XOR-ed with 0x80
// by the packing code below, see kInputXor) so that the same int8 kernels
// serve both signednesses.
template <>
struct PackedTypeImpl<Path::kNeon, std::uint8_t> {
  using Type = std::int8_t;
};
template <>
struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> {
  using Type = std::int8_t;
};
#endif
53 
#if RUY_PLATFORM_NEON
// Out-of-line packing of an 8-bit row-major source block into the col-major
// packed layout of a 16-row x `kernel_cols` kernel block. `input_xor` is
// XOR-ed into each source byte (0x80 converts uint8 to int8); per-column
// sums are accumulated into `sums_ptr`, which the callers below zero
// beforehand.
void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
                             int src_rows, int src_cols, int block_row,
                             int start_col, int end_col,
                             std::int8_t* packed_ptr, int packed_stride,
                             int packed_zero_point, std::int32_t* sums_ptr,
                             int input_xor, int kernel_cols);
#endif
62 
#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)

// Out-of-line (assembly) 8-bit col-major packing routines for 64-bit NEON.
// Each call packs 4 source columns. src_incN is the increment applied to
// src_ptrN (16 or 0 in the callers below; 0 keeps the pointer parked on a
// small zero-point buffer for columns past the matrix edge). `input_xor`
// is XOR-ed into every source byte (0x80 converts uint8 to int8), and
// per-column sums are written through `sums_ptr` (callers pass nullptr
// when sums are not needed). The *A55ish variants are tuned for in-order
// cores such as the Cortex-A55.
void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1,
                             const void* src_ptr2, const void* src_ptr3,
                             int src_inc0, int src_inc1, int src_inc2,
                             int src_inc3, int src_rows, int src_zero_point,
                             std::int8_t* packed_ptr, std::int32_t* sums_ptr,
                             int input_xor);
void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1,
                                   const void* src_ptr2, const void* src_ptr3,
                                   int src_inc0, int src_inc1, int src_inc2,
                                   int src_inc3, int src_rows,
                                   int src_zero_point, std::int8_t* packed_ptr,
                                   std::int32_t* sums_ptr, int input_xor);
void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
                                    const void* src_ptr2, const void* src_ptr3,
                                    int src_inc0, int src_inc1, int src_inc2,
                                    int src_inc3, int src_rows,
                                    int src_zero_point, std::int8_t* packed_ptr,
                                    std::int32_t* sums_ptr, int input_xor);
void Pack8bitColMajorForNeonDotprodA55ish(
    const void* src_ptr0, const void* src_ptr1, const void* src_ptr2,
    const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2,
    int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr,
    std::int32_t* sums_ptr, int input_xor);
// Row-major-source variant for the dotprod path: packs 4 source rows at a
// time into 8-column packed blocks, taking the packed stride explicitly.
void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
                                    const void* src_ptr2, const void* src_ptr3,
                                    int src_inc0, int src_inc1, int src_inc2,
                                    int src_inc3, int src_cols,
                                    int src_zero_point, std::int8_t* packed_ptr,
                                    int packed_stride, std::int32_t* sums_ptr,
                                    int input_xor);
95 #elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
96 
// Argument bundle for the 32-bit NEON 8-bit packing paths: ARMv7 has too
// few general-purpose registers to pass these individually, so they are
// handed to the assembly in a single struct. NOTE(review): the assembly
// presumably reads these fields by byte offset — do not reorder without
// checking the asm.
struct PackParams8bit {
  const void* src_ptr0;
  const void* src_ptr1;
  const void* src_ptr2;
  const void* src_ptr3;
  const std::int32_t* sums_ptr;
  const std::int8_t* packed_ptr;
  int src_inc0;
  int src_inc1;
  int src_inc2;
  int src_inc3;
  int src_rows;
  int src_zero_point;
  int input_xor;
};

// Fills `params` with the given values — a plain field-for-field copy.
inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1,
                               const void* src_ptr2, const void* src_ptr3,
                               const std::int32_t* sums_ptr,
                               const std::int8_t* packed_ptr, int src_inc0,
                               int src_inc1, int src_inc2, int src_inc3,
                               int src_rows, int src_zero_point, int input_xor,
                               PackParams8bit* params) {
  // Aggregate initialization: the parameter order matches the field
  // declaration order exactly, so this assigns every field in one shot.
  *params = PackParams8bit{src_ptr0, src_ptr1,   src_ptr2, src_ptr3,
                           sums_ptr, packed_ptr, src_inc0, src_inc1,
                           src_inc2, src_inc3,   src_rows, src_zero_point,
                           input_xor};
}
134 
// Assembly entry points (32-bit NEON) packing 4 or 2 source columns per
// call; arguments are bundled in a PackParams8bit.
void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params);
void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params);

#endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
139 
140 #if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM)
141 
// Packs 8-bit col-major source data into the col-major packed format
// consumed by the 16x4 NEON kernel layout. Scalar is int8 or uint8; uint8
// bytes are XOR-ed with 0x80 so the packed output is always int8.
template <typename Scalar>
struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 4>, Scalar,
                std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  // Mask XOR-ed into every source byte by the packing routine: a no-op for
  // int8 input, a sign-bit flip (uint8 -> int8) for uint8 input.
  static constexpr int kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  // Packs columns [start_col, end_col) of src_matrix, 4 at a time.
  static void Run(Tuning tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ(start_col % 4, 0);
    std::int32_t* sums = packed_matrix->sums;
    // Zero-point-filled stand-in source for columns past the right edge of
    // the matrix, so the packing routine can always read 4 columns.
    Scalar zerobuf[16];
    memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
    for (int block_col = start_col; block_col < end_col; block_col += 4) {
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
      const Scalar* src_ptr1 = src_ptr0 + src_stride;
      const Scalar* src_ptr2 = src_ptr1 + src_stride;
      const Scalar* src_ptr3 = src_ptr2 + src_stride;
      int src_inc0 = 16;
      int src_inc1 = 16;
      int src_inc2 = 16;
      int src_inc3 = 16;
      // Redirect any of the 4 column pointers that fall outside the matrix
      // to zerobuf, with a zero increment so they keep re-reading it.
      if (block_col >= src_matrix.layout.cols - 3) {
        if (block_col >= src_matrix.layout.cols - 0) {
          src_ptr0 = zerobuf;
          src_inc0 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 1) {
          src_ptr1 = zerobuf;
          src_inc1 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 2) {
          src_ptr2 = zerobuf;
          src_inc2 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 3) {
          src_ptr3 = zerobuf;
          src_inc3 = 0;
        }
      }
      std::int8_t* packed_ptr =
          packed_matrix->data + packed_matrix->layout.stride * block_col;
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
#if RUY_PLATFORM_NEON_64
      // Dispatch to the in-order-core-tuned path when requested.
      if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
        Pack8bitColMajorForNeonA55ish(
            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
            packed_ptr, sums_ptr, kInputXor);
      } else {
        Pack8bitColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
                                src_inc0, src_inc1, src_inc2, src_inc3,
                                src_matrix.layout.rows, src_matrix.zero_point,
                                packed_ptr, sums_ptr, kInputXor);
      }
#else
      (void)tuning;
      // We have a more limited set of general purpose registers in ARMv7, so
      // we use the "params" struct technique from the kernel code to save
      // registers.
      PackParams8bit params;
      MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr,
                         packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3,
                         src_matrix.layout.rows, src_matrix.zero_point,
                         kInputXor, &params);
      Pack8bitColMajorForNeon4Cols(params);
#endif  // RUY_PLATFORM_NEON_64
    }
  }
};
218 
219 #endif  // (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) &&
220         // RUY_OPT(ASM)
221 
#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
// The 32-bit float kernel is 4 rows X 2 columns, so we need an additional
// partial specialization for the RHS, which has a FixedKernelLayout with 2
// columns.
template <typename Scalar>
struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 2>, Scalar,
                std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  // 0x80 flips the sign bit, converting uint8 source bytes to int8.
  static constexpr int kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
  // Packs columns [start_col, end_col) of src_matrix, 2 at a time.
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ(start_col % 2, 0);
    std::int32_t* sums = packed_matrix->sums;
    // Zero-point-filled stand-in source for out-of-bounds columns.
    Scalar zerobuf[16];
    memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
    for (int block_col = start_col; block_col < end_col; block_col += 2) {
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
      const Scalar* src_ptr1 = src_ptr0 + src_stride;
      int src_inc0 = 16;
      int src_inc1 = 16;
      // Redirect out-of-bounds column pointers to zerobuf with increment 0.
      if (block_col >= src_matrix.layout.cols - 2) {
        if (block_col >= src_matrix.layout.cols - 0) {
          src_ptr0 = zerobuf;
          src_inc0 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 1) {
          src_ptr1 = zerobuf;
          src_inc1 = 0;
        }
      }
      std::int8_t* packed_ptr =
          packed_matrix->data + packed_matrix->layout.stride * block_col;
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      // Bundle the arguments for the register-starved ARMv7 asm path.
      PackParams8bit params;
      MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr,
                         packed_ptr, src_inc0, src_inc1, -1, -1,
                         src_matrix.layout.rows, src_matrix.zero_point,
                         kInputXor, &params);
      Pack8bitColMajorForNeon2Cols(params);
    }
  }
};
#endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
272 
#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
// Packs 8-bit col-major sources for the dotprod kernel layout (4 rows x 8
// cols). Although the kernel layout is 8 columns wide, the loop advances 4
// source columns per iteration and the two 4-column halves are interleaved
// into each 8-column packed block (see packed_ptr below).
template <typename Scalar>
struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
                Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  // 0x80 flips the sign bit, converting uint8 source bytes to int8.
  static constexpr int kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ(start_col % 8, 0);
    std::int32_t* sums = packed_matrix->sums;
    // Zero-point-filled stand-in source for out-of-bounds columns.
    Scalar zerobuf[16];
    memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
    for (int block_col = start_col; block_col < end_col; block_col += 4) {
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
      const Scalar* src_ptr1 = src_ptr0 + src_stride;
      const Scalar* src_ptr2 = src_ptr1 + src_stride;
      const Scalar* src_ptr3 = src_ptr2 + src_stride;
      std::int64_t src_inc0 = 16;
      std::int64_t src_inc1 = 16;
      std::int64_t src_inc2 = 16;
      std::int64_t src_inc3 = 16;
      // Redirect out-of-bounds column pointers to zerobuf with increment 0.
      if (block_col >= src_matrix.layout.cols - 3) {
        if (block_col >= src_matrix.layout.cols - 0) {
          src_ptr0 = zerobuf;
          src_inc0 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 1) {
          src_ptr1 = zerobuf;
          src_inc1 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 2) {
          src_ptr2 = zerobuf;
          src_inc2 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 3) {
          src_ptr3 = zerobuf;
          src_inc3 = 0;
        }
      }
      // Destination: (block_col & ~7) selects the 8-column packed block;
      // (block_col & 4) * 4 offsets (in bytes) into its second 4-column half.
      std::int8_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & ~7) +
          ((block_col & 4) * 4);
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      // Dispatch to the in-order-core-tuned path when requested.
      if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
        Pack8bitColMajorForNeonDotprodA55ish(
            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
            packed_ptr, sums_ptr, kInputXor);
      } else {
        Pack8bitColMajorForNeonDotprod(
            src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
            src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
            packed_ptr, sums_ptr, kInputXor);
      }
    }
  }
};
#endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
340 
#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
// Out-of-line float col-major packing routines (64-bit NEON): each call
// packs 4 source columns; src_incN is the byte increment for src_ptrN (16
// or 0 in the callers below). The A55ish variant is tuned for in-order
// cores.
void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
                              const float* src_ptr2, const float* src_ptr3,
                              int src_inc0, int src_inc1, int src_inc2,
                              int src_inc3, int src_rows, float* packed_ptr);
void PackFloatColMajorForNeonA55ish(const float* src_ptr0,
                                    const float* src_ptr1,
                                    const float* src_ptr2,
                                    const float* src_ptr3, int src_inc0,
                                    int src_inc1, int src_inc2, int src_inc3,
                                    int src_rows, float* packed_ptr);

#elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
// 32-bit variant: the four per-pointer increments are bit-packed into the
// low 4 bits of the single `src_inc` argument (see the callers below), and
// the packed-output stride is passed explicitly.
void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
                              const float* src_ptr2, const float* src_ptr3,
                              int src_inc, int src_rows, float* packed_ptr,
                              int stride);
#endif  // RUY_PLATFORM_NEON_64/32 && RUY_OPT(ASM)
359 
360 #if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM)
361 
// Float col-major packing for the 8-wide float kernel layout. As in the
// 8-bit dotprod case above, the loop advances 4 source columns per
// iteration and interleaves the two 4-column halves of each 8-column
// packed block.
template <>
struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
                float, float, Order::kColMajor> {
  static void Run(Tuning tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ(start_col % 8, 0);
    // Zero stand-in source for columns past the edge of the matrix.
    const float zerobuf[4] = {0};
    for (int block_col = start_col; block_col < end_col; block_col += 4) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
      const float* src_ptr1 = src_ptr0 + src_stride;
      const float* src_ptr2 = src_ptr1 + src_stride;
      const float* src_ptr3 = src_ptr2 + src_stride;
      std::int64_t src_inc0 = 16;
      std::int64_t src_inc1 = 16;
      std::int64_t src_inc2 = 16;
      std::int64_t src_inc3 = 16;
      // Redirect out-of-bounds column pointers to zerobuf with increment 0.
      if (block_col >= src_matrix.layout.cols - 3) {
        if (block_col >= src_matrix.layout.cols - 0) {
          src_ptr0 = zerobuf;
          src_inc0 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 1) {
          src_ptr1 = zerobuf;
          src_inc1 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 2) {
          src_ptr2 = zerobuf;
          src_inc2 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 3) {
          src_ptr3 = zerobuf;
          src_inc3 = 0;
        }
      }
      // (block_col & ~7) selects the 8-column packed block; (block_col & 4)
      // offsets (in floats) into its second 4-column half.
      float* packed_ptr = packed_matrix->data +
                          packed_matrix->layout.stride * (block_col & ~7) +
                          ((block_col & 4));
#if RUY_PLATFORM_NEON_64
      // Dispatch to the in-order-core-tuned path when requested.
      if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
        PackFloatColMajorForNeonA55ish(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
                                       src_inc0, src_inc1, src_inc2, src_inc3,
                                       src_matrix.layout.rows, packed_ptr);
      } else {
        PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
                                 src_inc0, src_inc1, src_inc2, src_inc3,
                                 src_matrix.layout.rows, packed_ptr);
      }
#else
      (void)tuning;
      // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits of src_inc
      // to save on registers (we have fewer general purpose registers in
      // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four
      // values that are each either 16 or 0 and use them directly. For the
      // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should
      // use the value 16 (bit is set) or 0 (bit is not set) for the
      // respective increment value.
      std::int64_t src_inc = 0;
      src_inc += src_inc0 == 16 ? 1 : 0;
      src_inc += src_inc1 == 16 ? 2 : 0;
      src_inc += src_inc2 == 16 ? 4 : 0;
      src_inc += src_inc3 == 16 ? 8 : 0;
      const int kOutputStride = 32;
      PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
                               src_matrix.layout.rows, packed_ptr,
                               kOutputStride);
#endif  // RUY_PLATFORM_NEON_64
    }
  }
};
434 
#if RUY_PLATFORM_NEON_32
// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional
// specialization for a FixedKernelLayout with 4 columns.
template <>
struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
                float, float, Order::kColMajor> {
  static void Run(Tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ(start_col % 4, 0);
    // Zero stand-in source for columns past the edge of the matrix.
    const float zerobuf[4] = {0};
    for (int block_col = start_col; block_col < end_col; block_col += 4) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
      const float* src_ptr1 = src_ptr0 + src_stride;
      const float* src_ptr2 = src_ptr1 + src_stride;
      const float* src_ptr3 = src_ptr2 + src_stride;
      std::int64_t src_inc0 = 16;
      std::int64_t src_inc1 = 16;
      std::int64_t src_inc2 = 16;
      std::int64_t src_inc3 = 16;
      // Redirect out-of-bounds column pointers to zerobuf with increment 0.
      if (block_col >= src_matrix.layout.cols - 3) {
        if (block_col >= src_matrix.layout.cols - 0) {
          src_ptr0 = zerobuf;
          src_inc0 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 1) {
          src_ptr1 = zerobuf;
          src_inc1 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 2) {
          src_ptr2 = zerobuf;
          src_inc2 = 0;
        }
        if (block_col >= src_matrix.layout.cols - 3) {
          src_ptr3 = zerobuf;
          src_inc3 = 0;
        }
      }
      float* packed_ptr =
          packed_matrix->data + packed_matrix->layout.stride * (block_col);
      // Encode each of src_inc0, ..., src_inc3 in the lowest 4 bits of
      // src_inc to save registers (same scheme as the 8-column case above).
      std::int64_t src_inc = 0;
      src_inc += src_inc0 == 16 ? 1 : 0;
      src_inc += src_inc1 == 16 ? 2 : 0;
      src_inc += src_inc2 == 16 ? 4 : 0;
      src_inc += src_inc3 == 16 ? 8 : 0;
      const int kOutputStride = 16;
      PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
                               src_matrix.layout.rows, packed_ptr,
                               kOutputStride);
    }
  }
};
#endif  // RUY_PLATFORM_NEON_32
492 #endif  // (RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && \
493         // RUY_OPT(ASM)
494 
495 #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
496 
// Packs 8-bit row-major sources for the dotprod col-major kernel layout
// (4 rows x 8 cols). uint8 sources are converted to int8 via kInputXor.
template <typename Scalar>
struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
                Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  // 0x80 flips the sign bit, converting uint8 source bytes to int8.
  static constexpr int kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    RUY_DCHECK(IsRowMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ(start_col % 8, 0);
    std::int32_t* sums = packed_matrix->sums;
    // Zero the sums for the packed column range first; the packing routine
    // accumulates column sums into them across the row-block iterations.
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    // Zero-point-filled stand-in source for rows past the matrix edge.
    Scalar zerobuf[8];
    memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
    int src_stride = src_matrix.layout.stride;
    // As the source matrix is row-major and the destination packed matrix is
    // column-major, there is no traversal order that will be optimal for both
    // so we choose to favor the source matrix with a row-major traversal order.
    // Loop over groups of 4 rows.
    for (int block_row = 0; block_row < packed_matrix->layout.rows;
         block_row += 4) {
      // src_ptr[0-3] shall point to the positions in the 4 rows of the source
      // matrix that we are loading from, and will be incremented by
      // src_inc[0-3] after each 4x8 block is loaded.
      // First we compute these src_ptr and src_inc values for the case where
      // there are 4 rows left to be loaded from in the source matrix ...
      const Scalar* src_ptr0 =
          src_matrix.data.get() + src_stride * block_row + start_col;
      const Scalar* src_ptr1 = src_ptr0 + src_stride;
      const Scalar* src_ptr2 = src_ptr1 + src_stride;
      const Scalar* src_ptr3 = src_ptr2 + src_stride;
      std::int64_t src_inc0 = 8;
      std::int64_t src_inc1 = 8;
      std::int64_t src_inc2 = 8;
      std::int64_t src_inc3 = 8;
      // ... and now we adjust these values in case there are fewer than 4 rows
      // left to load from in the source matrix. In that case, we set the
      // corresponding src_ptr pointer to load from `zerobuf` and set src_inc
      // to 0 to avoid overrunning that small buffer.
      if (block_row >= src_matrix.layout.rows - 3) {
        if (block_row >= src_matrix.layout.rows - 0) {
          src_ptr0 = zerobuf;
          src_inc0 = 0;
        }
        if (block_row >= src_matrix.layout.rows - 1) {
          src_ptr1 = zerobuf;
          src_inc1 = 0;
        }
        if (block_row >= src_matrix.layout.rows - 2) {
          src_ptr2 = zerobuf;
          src_inc2 = 0;
        }
        if (block_row >= src_matrix.layout.rows - 3) {
          src_ptr3 = zerobuf;
          src_inc3 = 0;
        }
      }
      // Let src_cols be the number of source matrix columns to handle.
      int src_cols = std::min(end_col, src_matrix.layout.cols) - start_col;
      std::int8_t* packed_ptr = packed_matrix->data +
                                packed_matrix->layout.stride * start_col +
                                8 * block_row;
      std::int32_t* sums_ptr = sums + start_col;
      Pack8bitRowMajorForNeonDotprod(
          src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, src_inc2,
          src_inc3, src_cols, src_matrix.zero_point, packed_ptr,
          packed_matrix->layout.stride, sums_ptr, kInputXor);
    }
  }
};
572 
573 #endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
574 
575 #if RUY_PLATFORM_NEON
576 
// Generic 8-bit row-major packing for 16-row x KernelCols col-major kernel
// layouts, delegating to the out-of-line Pack8bitRowMajorForNeon.
template <typename Scalar, int KernelCols>
struct PackImpl<Path::kNeon,
                FixedKernelLayout<Order::kColMajor, 16, KernelCols>, Scalar,
                std::int8_t, std::int32_t, Order::kRowMajor> {
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (KNeon, from row-major source)");
    // 0x80 flips the sign bit, converting uint8 source bytes to int8.
    static constexpr int kInputXor =
        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
    RUY_DCHECK_EQ((end_col - start_col) % KernelCols, 0);
    std::int32_t* sums = packed_matrix->sums;
    // Zero the sums for the packed column range first; the packing routine
    // accumulates column sums into them across the row-block iterations.
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    int block_row = 0;
    // Process 16 source rows (one kernel-layout depth) per iteration.
    for (; block_row < packed_matrix->layout.rows; block_row += 16) {
      int src_stride = src_matrix.layout.stride;
      int packed_stride = packed_matrix->layout.stride;
      const Scalar* src_ptr =
          src_matrix.data.get() + block_row * src_stride + start_col;
      std::int8_t* packed_ptr = packed_matrix->data +
                                start_col * packed_stride +
                                block_row * KernelCols;

      Pack8bitRowMajorForNeon(
          reinterpret_cast<const std::uint8_t*>(src_ptr), src_stride,
          src_matrix.layout.rows, src_matrix.layout.cols, block_row, start_col,
          end_col, packed_ptr, packed_stride, packed_matrix->zero_point, sums,
          kInputXor, KernelCols);
    }
  }
};
609 #endif
610 
611 }  // namespace ruy
612 
613 #endif  // RUY_RUY_PACK_ARM_H_
614