xref: /aosp_15_r20/external/ruy/ruy/pack_arm.cc (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "ruy/pack_arm.h"
17 
18 #include <cstdint>
19 
20 #include "ruy/asm_helpers.h"
21 #include "ruy/opt_set.h"
22 #include "ruy/pack_common.h"
23 #include "ruy/platform.h"
24 #include "ruy/profiler/instrumentation.h"
25 
26 #if RUY_PLATFORM_NEON
27 #include <arm_neon.h>
28 #endif
29 
30 namespace ruy {
31 
32 #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
33 
// Packs 4 columns of column-major 8-bit source data into the layout expected
// by ruy's 8-bit NEON kernels, 16 rows per 64-byte block, while accumulating
// the per-column sums of the packed (XOR-converted) values.
//
// src_ptr0..src_ptr3: start of each of the 4 source columns.
// src_inc0..src_inc3: byte increment applied to each source pointer after
//     each 16-row block load (an increment of 0 re-reads the same 16 bytes).
// src_rows: number of rows to pack; a final partial block of fewer than 16
//     rows is padded with src_zero_point up to a multiple of 16.
// src_zero_point: padding value used for the partial last block.
// packed_ptr: destination, written in 16x4 = 64-byte blocks.
// sums_ptr: if non-null, receives the 4 int32 column sums.
// input_xor: value XORed into every input byte before storing/summing,
//     e.g. 0x80 to convert uint8 source values to int8.
void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1,
                             const void* src_ptr2, const void* src_ptr3,
                             int src_inc0, int src_inc1, int src_inc2,
                             int src_inc3, int src_rows, int src_zero_point,
                             std::int8_t* packed_ptr, std::int32_t* sums_ptr,
                             int input_xor) {
  profiler::ScopeLabel label("Pack (kNeon)");
  asm volatile(
      // clang-format off
          // v26 will be the vector to XOR input values with to perform
          // any input data type conversion (e.g. uint8 to int8).
          "dup v26.16b, %w[input_xor]\n"
          // w1 will be the number of rows already loaded.
          "mov w1, #0\n"
          // v28--v31 will be used to accumulate the sums, one 4x32-bit
          // accumulator vector per source column.
          "movi v28.4s, #0\n"
          "movi v29.4s, #0\n"
          "movi v30.4s, #0\n"
          "movi v31.4s, #0\n"
          // Let w2 be `rows` rounded down to multiple of 16.
          "ands w2, %w[rows], #-16\n"
          // If there are no full blocks of 16 rows to process, jump to the
          // code handling the last < 16 rows.
          "beq 3f\n"
          // Load the first block of 16 rows.
          "add w1, w1, #16\n"
          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
          // Check if these were the only full block of 16 rows to load.
          "cmp w1, w2\n"
          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
          // In that case, jump to the code handling the last loaded block of
          // 16 rows.
          "beq 2f\n"
          // Main loop processing blocks of 16 rows.
          "1:\n"
          // Load the next 16 rows, interleaved with the XOR input type
          // conversion (e.g. uint8->int8) on the already loaded inputs.
          "add w1, w1, #16\n"
          "eor v4.16b, v0.16b, v26.16b\n"
          "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
          "eor v5.16b, v1.16b, v26.16b\n"
          "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
          "eor v6.16b, v2.16b, v26.16b\n"
          "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
          "eor v7.16b, v3.16b, v26.16b\n"
          "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
          // Compute the sums, interleaved with storing to the packed matrix.
          // saddlp: pairwise-add 16 x int8 into 8 x int16 partial sums.
          "saddlp v16.8h, v4.16b\n"
          "str q4, [%[packed_ptr], #0]\n"
          "saddlp v17.8h, v5.16b\n"
          "str q5, [%[packed_ptr], #16]\n"
          "saddlp v18.8h, v6.16b\n"
          "str q6, [%[packed_ptr], #32]\n"
          "saddlp v19.8h, v7.16b\n"
          "str q7, [%[packed_ptr], #48]\n"
          // sadalp: pairwise-add-accumulate the int16 partials into the
          // 4x32-bit per-column accumulators.
          "sadalp v28.4s, v16.8h\n"
          // Was this the last block of 16 rows to load?
          "cmp w1, w2\n"
          "sadalp v29.4s, v17.8h\n"
          "add %[packed_ptr], %[packed_ptr], #64\n"
          "sadalp v30.4s, v18.8h\n"
          "sadalp v31.4s, v19.8h\n"
          // End of main loop on blocks of 16 rows.
          "bne 1b\n"

          // Code handling the last already-loaded block of 16 rows.
          "2:\n"

          // Process the last loaded full 16x4 block.
          "eor v4.16b, v0.16b, v26.16b\n"
          "eor v5.16b, v1.16b, v26.16b\n"
          "eor v6.16b, v2.16b, v26.16b\n"
          "eor v7.16b, v3.16b, v26.16b\n"

          "saddlp v16.8h, v4.16b\n"
          "str q4, [%[packed_ptr], #0]\n"
          "saddlp v17.8h, v5.16b\n"
          "str q5, [%[packed_ptr], #16]\n"
          "saddlp v18.8h, v6.16b\n"
          "str q6, [%[packed_ptr], #32]\n"
          "saddlp v19.8h, v7.16b\n"
          "str q7, [%[packed_ptr], #48]\n"
          "sadalp v28.4s, v16.8h\n"
          "sadalp v29.4s, v17.8h\n"
          "sadalp v30.4s, v18.8h\n"
          "sadalp v31.4s, v19.8h\n"

          "add %[packed_ptr], %[packed_ptr], #64\n"

          // End of code handling full blocks of 16 rows.
          // Now we handle any remaining rows.
          "3:\n"
          // Let w2 be the number of rows left to handle.
          "ands w2, %w[rows], #15\n"
          // If w2==0, there are no remaining rows, jump to the end.
          "beq 4f\n"
          // Zero out a 16x4 block in registers, which we'll partially overwrite
          // with any remaining rows. The padding value is src_zero_point, so
          // padded lanes contribute like zero-point-valued source entries.
          "dup v0.16b, %w[src_zero_point]\n"
          "dup v1.16b, %w[src_zero_point]\n"
          "dup v2.16b, %w[src_zero_point]\n"
          "dup v3.16b, %w[src_zero_point]\n"
// Load one row (one byte from each of the 4 columns) into lane R, or branch
// out if only R rows remained to be loaded.
#define RUY_LOAD_ONE_ROW(R)                   \
  "cmp w2, #" #R "\n"                         \
  "beq 5f\n"                                  \
  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"

          RUY_LOAD_ONE_ROW(0)
          RUY_LOAD_ONE_ROW(1)
          RUY_LOAD_ONE_ROW(2)
          RUY_LOAD_ONE_ROW(3)
          RUY_LOAD_ONE_ROW(4)
          RUY_LOAD_ONE_ROW(5)
          RUY_LOAD_ONE_ROW(6)
          RUY_LOAD_ONE_ROW(7)
          RUY_LOAD_ONE_ROW(8)
          RUY_LOAD_ONE_ROW(9)
          RUY_LOAD_ONE_ROW(10)
          RUY_LOAD_ONE_ROW(11)
          RUY_LOAD_ONE_ROW(12)
          RUY_LOAD_ONE_ROW(13)
          RUY_LOAD_ONE_ROW(14)
          // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
#undef RUY_LOAD_ONE_ROW
          "5:\n"

          // Process the last zero-padded 16x4 block.
          "eor v4.16b, v0.16b, v26.16b\n"
          "eor v5.16b, v1.16b, v26.16b\n"
          "eor v6.16b, v2.16b, v26.16b\n"
          "eor v7.16b, v3.16b, v26.16b\n"

          "saddlp v16.8h, v4.16b\n"
          "saddlp v17.8h, v5.16b\n"
          "saddlp v18.8h, v6.16b\n"
          "saddlp v19.8h, v7.16b\n"
          "sadalp v28.4s, v16.8h\n"
          "sadalp v29.4s, v17.8h\n"
          "sadalp v30.4s, v18.8h\n"
          "sadalp v31.4s, v19.8h\n"

          "str q4, [%[packed_ptr], #0]\n"
          "str q5, [%[packed_ptr], #16]\n"
          "str q6, [%[packed_ptr], #32]\n"
          "str q7, [%[packed_ptr], #48]\n"
          "add %[packed_ptr], %[packed_ptr], #64\n"

          "4:\n"

          // Horizontal reduction of the registers used to accumulate sums,
          // leaving v28 = {sum(col0), sum(col1), sum(col2), sum(col3)}.
          "addp v28.4s, v28.4s, v29.4s\n"
          "addp v30.4s, v30.4s, v31.4s\n"
          "addp v28.4s, v28.4s, v30.4s\n"

          // Store the sums, skipping the store if sums_ptr is null.
          "cmp %[sums_ptr], #0\n"
          "beq 6f\n"
          "st1 {v28.4s}, [%[sums_ptr]], #16\n"
          "6:\n"
      // clang-format on

      : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
        [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
        [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
      : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
        [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
        [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
        [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
        [rows] "r"(src_rows), [src_zero_point] "r"(src_zero_point),
        [input_xor] "r"(input_xor)
      : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
        "v27", "v28", "v29", "v30", "v31");
}
214 #endif
215 
216 #if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
217 
// Byte offsets of each PackParams8bit field, hardcoded so they can be
// spliced into the inline-assembly `ldr` instructions below via RUY_STR.
#define RUY_OFFSET_SRC_PTR0 0
#define RUY_OFFSET_SRC_PTR1 4
#define RUY_OFFSET_SRC_PTR2 8
#define RUY_OFFSET_SRC_PTR3 12
#define RUY_OFFSET_SUMS_PTR 16
#define RUY_OFFSET_PACKED_PTR 20
#define RUY_OFFSET_SRC_INC0 24
#define RUY_OFFSET_SRC_INC1 28
#define RUY_OFFSET_SRC_INC2 32
#define RUY_OFFSET_SRC_INC3 36
#define RUY_OFFSET_SRC_ROWS 40
#define RUY_OFFSET_SRC_ZERO_POINT 44
#define RUY_OFFSET_INPUT_XOR 48

// Compile-time verification that the hardcoded RUY_OFFSET_* values above
// match the actual layout of the params struct. Purely a static check:
// calling it generates no code.
template <typename Params>
void CheckOffsetsInPackParams8bit(const Params&) {
#define RUY_CHECK_PARAM_OFFSET(field, value) \
  static_assert(offsetof(Params, field) == (value), "")
  RUY_CHECK_PARAM_OFFSET(src_ptr0, RUY_OFFSET_SRC_PTR0);
  RUY_CHECK_PARAM_OFFSET(src_ptr1, RUY_OFFSET_SRC_PTR1);
  RUY_CHECK_PARAM_OFFSET(src_ptr2, RUY_OFFSET_SRC_PTR2);
  RUY_CHECK_PARAM_OFFSET(src_ptr3, RUY_OFFSET_SRC_PTR3);
  RUY_CHECK_PARAM_OFFSET(sums_ptr, RUY_OFFSET_SUMS_PTR);
  RUY_CHECK_PARAM_OFFSET(packed_ptr, RUY_OFFSET_PACKED_PTR);
  RUY_CHECK_PARAM_OFFSET(src_inc0, RUY_OFFSET_SRC_INC0);
  RUY_CHECK_PARAM_OFFSET(src_inc1, RUY_OFFSET_SRC_INC1);
  RUY_CHECK_PARAM_OFFSET(src_inc2, RUY_OFFSET_SRC_INC2);
  RUY_CHECK_PARAM_OFFSET(src_inc3, RUY_OFFSET_SRC_INC3);
  RUY_CHECK_PARAM_OFFSET(src_rows, RUY_OFFSET_SRC_ROWS);
  RUY_CHECK_PARAM_OFFSET(src_zero_point, RUY_OFFSET_SRC_ZERO_POINT);
  RUY_CHECK_PARAM_OFFSET(input_xor, RUY_OFFSET_INPUT_XOR);
#undef RUY_CHECK_PARAM_OFFSET
}
249 
250 // No attempt made at making this code efficient on A55-ish cores yet.
Pack8bitColMajorForNeon4Cols(const PackParams8bit & params)251 void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params) {
252   CheckOffsetsInPackParams8bit(params);
253   profiler::ScopeLabel label("Pack (kNeon)");
254   const void* src_ptr0 = params.src_ptr0;
255   const void* src_ptr1 = params.src_ptr1;
256   const void* src_ptr2 = params.src_ptr2;
257   const void* src_ptr3 = params.src_ptr3;
258   const int src_inc0 = params.src_inc0;
259   const int src_inc1 = params.src_inc1;
260   const int src_inc2 = params.src_inc2;
261   const int src_inc3 = params.src_inc3;
262   const std::int8_t* packed_ptr = params.packed_ptr;
263 
264   asm volatile(
265       // clang-format off
266 
267           "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n"
268           "vdup.8 q11, r2\n"
269           "mov r1, #0\n"
270           // Zero-out the accumulators
271           "vmov.i32 q12, #0\n"
272           "vmov.i32 q13, #0\n"
273           "vmov.i32 q14, #0\n"
274           "vmov.i32 q15, #0\n"
275 
276           // Round down src_rows to nearest multiple of 16.
277           "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
278           "and r2, r3, #-16\n"
279           "cmp r1, r2\n"
280           "beq 3f\n"
281 
282           "1:\n"
283           "add r1, r1, #16\n"
284           /* Load q0 */
285           "vld1.8 {d0, d1}, [%[src_ptr0]]\n"
286           "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n"
287           RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n")
288 
289           /* Load q1 */
290           "vld1.8 {d2, d3}, [%[src_ptr1]]\n"
291           "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n"
292           RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n")
293 
294           "veor.8 q4, q0, q11\n"
295           "veor.8 q5, q1, q11\n"
296 
297           // Pairwise add in to 16b accumulators.
298           "vpaddl.s8 q8, q4\n"
299           "vpaddl.s8 q9, q5\n"
300 
301           "vst1.32 {q4}, [%[packed_ptr]]!\n"
302           "vst1.32 {q5}, [%[packed_ptr]]!\n"
303 
304           // Pairwise add accumulate into 32b accumulators.
305           // q12 and q13 contain 4x32b accumulators
306           "vpadal.s16 q12, q8\n"
307           "vpadal.s16 q13, q9\n"
308 
309           // Now do the same for src_ptr2 and src_ptr3.
310           "vld1.8 {d0, d1}, [%[src_ptr2]]\n"
311           "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n"
312           RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n")
313 
314           "vld1.8 {d2, d3}, [%[src_ptr3]]\n"
315           "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n"
316           RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n")
317 
318           "veor.8 q4, q0, q11\n"
319           "veor.8 q5, q1, q11\n"
320 
321           "vpaddl.s8 q8, q4\n"
322           "vpaddl.s8 q9, q5\n"
323 
324           "vst1.32 {q4}, [%[packed_ptr]]!\n"
325           "vst1.32 {q5}, [%[packed_ptr]]!\n"
326 
327           // Pairwise add accumulate into 32b accumulators.
328           // q14 and q15 contain 4x32b accumulators
329           "vpadal.s16 q14, q8\n"
330           "vpadal.s16 q15, q9\n"
331 
332           "cmp r1, r2\n"
333           "bne 1b\n"
334 
335           "3:\n"
336 
337           // Now pack the last (num_rows % 16) rows.
338           "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
339           "ands r2, r3, #15\n"
340           "beq 4f\n"
341           "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n"
342           "vdup.8 q0, r3\n"
343           "vdup.8 q1, r3\n"
344 
345 // First, read/accumulate/write for src_ptr0 and src_ptr1.
346 #define RUY_LOAD_ONE_ROW1(I, R)            \
347   "cmp r2, #" #I "\n"                      \
348   "beq 5f\n"                               \
349   "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \
350   "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \
351 
352           RUY_LOAD_ONE_ROW1(0, 0)
353           RUY_LOAD_ONE_ROW1(1, 1)
354           RUY_LOAD_ONE_ROW1(2, 2)
355           RUY_LOAD_ONE_ROW1(3, 3)
356           RUY_LOAD_ONE_ROW1(4, 4)
357           RUY_LOAD_ONE_ROW1(5, 5)
358           RUY_LOAD_ONE_ROW1(6, 6)
359           RUY_LOAD_ONE_ROW1(7, 7)
360 #undef RUY_LOAD_ONE_ROW1
361 
362 #define RUY_LOAD_ONE_ROW2(I, R)            \
363   "cmp r2, #" #I "\n"                      \
364   "beq 5f\n"                               \
365   "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \
366   "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \
367 
368           RUY_LOAD_ONE_ROW2(8, 0)
369           RUY_LOAD_ONE_ROW2(9, 1)
370           RUY_LOAD_ONE_ROW2(10, 2)
371           RUY_LOAD_ONE_ROW2(11, 3)
372           RUY_LOAD_ONE_ROW2(12, 4)
373           RUY_LOAD_ONE_ROW2(13, 5)
374           RUY_LOAD_ONE_ROW2(14, 6)
375           RUY_LOAD_ONE_ROW2(15, 7)
376 #undef RUY_LOAD_ONE_ROW2
377 
378           "5:\n"
379 
380           "veor.16 q4, q0, q11\n"
381           "veor.16 q5, q1, q11\n"
382 
383           "vpaddl.s8 q8, q4\n"
384           "vpaddl.s8 q9, q5\n"
385 
386           // Pairwise add accumulate to 4x32b accumulators.
387           "vpadal.s16 q12, q8\n"
388           "vpadal.s16 q13, q9\n"
389 
390           "vst1.32 {q4}, [%[packed_ptr]]!\n"
391           "vst1.32 {q5}, [%[packed_ptr]]!\n"
392 
393           // Reset to src_zero for src_ptr2 and src_ptr3.
394           "vdup.8 q0, r3\n"
395           "vdup.8 q1, r3\n"
396 
397 // Next, read/accumulate/write for src_ptr2 and src_ptr3.
398 #define RUY_LOAD_ONE_ROW1(I, R)            \
399   "cmp r2, #" #I "\n"                      \
400   "beq 5f\n"                               \
401   "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \
402   "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n" \
403 
404           RUY_LOAD_ONE_ROW1(0, 0)
405           RUY_LOAD_ONE_ROW1(1, 1)
406           RUY_LOAD_ONE_ROW1(2, 2)
407           RUY_LOAD_ONE_ROW1(3, 3)
408           RUY_LOAD_ONE_ROW1(4, 4)
409           RUY_LOAD_ONE_ROW1(5, 5)
410           RUY_LOAD_ONE_ROW1(6, 6)
411           RUY_LOAD_ONE_ROW1(7, 7)
412 #undef RUY_LOAD_ONE_ROW1
413 
414 #define RUY_LOAD_ONE_ROW2(I, R)            \
415   "cmp r2, #" #I "\n"                      \
416   "beq 5f\n"                               \
417   "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \
418   "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n" \
419 
420           RUY_LOAD_ONE_ROW2(8, 0)
421           RUY_LOAD_ONE_ROW2(9, 1)
422           RUY_LOAD_ONE_ROW2(10, 2)
423           RUY_LOAD_ONE_ROW2(11, 3)
424           RUY_LOAD_ONE_ROW2(12, 4)
425           RUY_LOAD_ONE_ROW2(13, 5)
426           RUY_LOAD_ONE_ROW2(14, 6)
427           RUY_LOAD_ONE_ROW2(15, 7)
428 #undef RUY_LOAD_ONE_ROW2
429 
430           "5:\n"
431 
432           "veor.16 q4, q0, q11\n"
433           "veor.16 q5, q1, q11\n"
434 
435           "vpaddl.s8 q8, q4\n"
436           "vpaddl.s8 q9, q5\n"
437 
438           // Pairwise add accumulate to 4x32b accumulators.
439           "vpadal.s16 q14, q8\n"
440           "vpadal.s16 q15, q9\n"
441 
442           "vst1.32 {q4}, [%[packed_ptr]]!\n"
443           "vst1.32 {q5}, [%[packed_ptr]]!\n"
444 
445           "4:\n"
446           // Pairwise add 32-bit accumulators
447           "vpadd.i32 d24, d24, d25\n"
448           "vpadd.i32 d26, d26, d27\n"
449           "vpadd.i32 d28, d28, d29\n"
450           "vpadd.i32 d30, d30, d31\n"
451           // Final 32-bit values per row
452           "vpadd.i32 d25, d24, d26\n"
453           "vpadd.i32 d27, d28, d30\n"
454 
455           "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n"
456           "cmp r3, #0\n"
457           "beq 6f\n"
458           "vst1.32 {d25}, [r3]!\n"
459           "vst1.32 {d27}, [r3]!\n"
460           "6:\n"
461       // clang-format on
462 
463       : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
464         [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3)
465       : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1),
466         [ src_inc2 ] "r"(src_inc2), [ src_inc3 ] "r"(src_inc3),
467         [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(&params)
468       : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
469         "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13");
470 }
471 
472 // Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9.
473 // No attempt made at making this code efficient on in-order cores yet.
474 // This version differs from the above in that we only handle two columns
475 // at a time.
Pack8bitColMajorForNeon2Cols(const PackParams8bit & params)476 void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params) {
477   CheckOffsetsInPackParams8bit(params);
478   profiler::ScopeLabel label("Pack (kNeon)");
479   const void* src_ptr0 = params.src_ptr0;
480   const void* src_ptr1 = params.src_ptr1;
481   const int src_inc0 = params.src_inc0;
482   const int src_inc1 = params.src_inc1;
483   const std::int8_t* packed_ptr = params.packed_ptr;
484 
485   asm volatile(
486       // clang-format off
487 
488           "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n"
489           "vdup.8 q11, r2\n"
490           "mov r1, #0\n"
491           // Zero-out the accumulators
492           "vmov.i32 q12, #0\n"
493           "vmov.i32 q13, #0\n"
494 
495           // Round down src_rows to nearest multiple of 16.
496           "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
497           "and r2, r3, #-16\n"
498           "cmp r1, r2\n"
499           "beq 3f\n"
500 
501           "1:\n"
502           "add r1, r1, #16\n"
503           /* Load q0 */
504           "vld1.8 {d0, d1}, [%[src_ptr0]]\n"
505           "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n"
506 
507           /* Load q1 */
508           "vld1.8 {d2, d3}, [%[src_ptr1]]\n"
509           "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n"
510 
511           "veor.8 q4, q0, q11\n"
512           "veor.8 q5, q1, q11\n"
513 
514           // Pairwise add in to 16b accumulators.
515           "vpaddl.s8 q8, q4\n"
516           "vpaddl.s8 q9, q5\n"
517 
518           "vst1.32 {q4}, [%[packed_ptr]]!\n"
519           "vst1.32 {q5}, [%[packed_ptr]]!\n"
520 
521           // Pairwise add accumulate into 32b accumulators.
522           // q12 and q13 contain 4x32b accumulators
523           "vpadal.s16 q12, q8\n"
524           "vpadal.s16 q13, q9\n"
525 
526           "cmp r1, r2\n"
527 
528           "bne 1b\n"
529 
530           "3:\n"
531 
532           // Now pack the last (num_rows % 16) rows.
533           "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
534           "ands r2, r3, #15\n"
535           "beq 4f\n"
536           "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n"
537           "vdup.8 q0, r3\n"
538           "vdup.8 q1, r3\n"
539 
540 // Read/accumulate/write for src_ptr0 and src_ptr1.
541 #define RUY_LOAD_ONE_ROW1(I, R)            \
542   "cmp r2, #" #I "\n"                      \
543   "beq 5f\n"                               \
544   "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \
545   "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \
546 
547           RUY_LOAD_ONE_ROW1(0, 0)
548           RUY_LOAD_ONE_ROW1(1, 1)
549           RUY_LOAD_ONE_ROW1(2, 2)
550           RUY_LOAD_ONE_ROW1(3, 3)
551           RUY_LOAD_ONE_ROW1(4, 4)
552           RUY_LOAD_ONE_ROW1(5, 5)
553           RUY_LOAD_ONE_ROW1(6, 6)
554           RUY_LOAD_ONE_ROW1(7, 7)
555 #undef RUY_LOAD_ONE_ROW1
556 
557 #define RUY_LOAD_ONE_ROW2(I, R)            \
558   "cmp r2, #" #I "\n"                      \
559   "beq 5f\n"                               \
560   "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \
561   "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \
562 
563           RUY_LOAD_ONE_ROW2(8, 0)
564           RUY_LOAD_ONE_ROW2(9, 1)
565           RUY_LOAD_ONE_ROW2(10, 2)
566           RUY_LOAD_ONE_ROW2(11, 3)
567           RUY_LOAD_ONE_ROW2(12, 4)
568           RUY_LOAD_ONE_ROW2(13, 5)
569           RUY_LOAD_ONE_ROW2(14, 6)
570           RUY_LOAD_ONE_ROW2(15, 7)
571 #undef RUY_LOAD_ONE_ROW2
572 
573           "5:\n"
574 
575           "veor.16 q4, q0, q11\n"
576           "veor.16 q5, q1, q11\n"
577 
578           "vpaddl.s8 q8, q4\n"
579           "vpaddl.s8 q9, q5\n"
580 
581 
582           // Pairwise add accumulate to 4x32b accumulators.
583           "vpadal.s16 q12, q8\n"
584           "vpadal.s16 q13, q9\n"
585 
586           "vst1.32 {q4}, [%[packed_ptr]]!\n"
587           "vst1.32 {q5}, [%[packed_ptr]]!\n"
588 
589           "4:\n"
590 
591           // Pairwise add 32-bit accumulators
592           "vpadd.i32 d24, d24, d25\n"
593           "vpadd.i32 d26, d26, d27\n"
594           // Final 32-bit values per row
595           "vpadd.i32 d25, d24, d26\n"
596 
597           "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n"
598           "cmp r3, #0\n"
599           "beq 6f\n"
600           "vst1.32 {d25}, [r3]!\n"
601           "6:\n"
602       // clang-format on
603 
604       : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1)
605       : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1),
606         [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(&params)
607       : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
608         "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13");
609 }
610 
// Undefine every RUY_OFFSET_* macro defined above so they don't leak out of
// this section. Fixed two typos in the previous revision: RUY_OFFSET_SRC_PTR32
// -> RUY_OFFSET_SRC_PTR3 and RUY_OFFSET_PACKED_PTR0 -> RUY_OFFSET_PACKED_PTR
// (the misspelled names were never defined, so those two macros leaked).
#undef RUY_OFFSET_SRC_PTR0
#undef RUY_OFFSET_SRC_PTR1
#undef RUY_OFFSET_SRC_PTR2
#undef RUY_OFFSET_SRC_PTR3
#undef RUY_OFFSET_SUMS_PTR
#undef RUY_OFFSET_PACKED_PTR
#undef RUY_OFFSET_SRC_INC0
#undef RUY_OFFSET_SRC_INC1
#undef RUY_OFFSET_SRC_INC2
#undef RUY_OFFSET_SRC_INC3
#undef RUY_OFFSET_SRC_ROWS
#undef RUY_OFFSET_SRC_ZERO_POINT
#undef RUY_OFFSET_INPUT_XOR
624 
625 #endif  //  RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
626 
627 #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
628 
Pack8bitColMajorForNeonA55ish(const void * src_ptr0,const void * src_ptr1,const void * src_ptr2,const void * src_ptr3,int src_inc0,int src_inc1,int src_inc2,int src_inc3,int src_rows,int src_zero_point,std::int8_t * packed_ptr,std::int32_t * sums_ptr,int input_xor)629 void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1,
630                                    const void* src_ptr2, const void* src_ptr3,
631                                    int src_inc0, int src_inc1, int src_inc2,
632                                    int src_inc3, int src_rows,
633                                    int src_zero_point, std::int8_t* packed_ptr,
634                                    std::int32_t* sums_ptr, int input_xor) {
635   profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)");
636   asm volatile(
637       // clang-format off
638           // v26 will be the vector to XOR input values with to perform
639           // any input data type conversion (e.g. uint8 to int8).
640           "dup v26.16b, %w[input_xor]\n"
641           // w1 will be the number of rows already loaded.
642           "mov w1, #0\n"
643           // v28--v32 will be used to accumulate the sums
644           "movi v28.4s, #0\n"
645           "movi v29.4s, #0\n"
646           "movi v30.4s, #0\n"
647           "movi v31.4s, #0\n"
648           // Let w2 be `rows` rounded down to multiple of 16.
649           "ands w2, %w[rows], #-16\n"
650           // If there are no full blocks of 16 rows to process, jump to the
651           // code handling the last < 16 rows.
652           "beq 3f\n"
653           // Load the first block of 16 rows.
654           "add w1, w1, #16\n"
655           "ldr x10, [%[src_ptr0], #8]\n"
656           "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
657           "ldr x11, [%[src_ptr1], #8]\n"
658           "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
659           "ldr x12, [%[src_ptr2], #8]\n"
660           "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
661           "ldr x13, [%[src_ptr3], #8]\n"
662           "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
663           // Check if these were the only full block of 16 rows to load.
664           "cmp w1, w2\n"
665           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
666           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
667           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
668           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
669           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
670           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
671           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
672           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
673           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
674           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
675           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
676           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
677           // In that case, jump to the code handling the last loaded block of
678           // 16 rows.
679           "beq 2f\n"
680           // Main loop processing blocks of 16 rows.
681           "1:\n"
682           // Load the next 16 rows, interleaved with the XOR input type
683           // conversion (e.g. uint8->int8) on the already loaded inputs.
684           "add w1, w1, #16\n"
685           "ins v0.d[1], x10\n"
686           "ldr x10, [%[src_ptr0], #8]\n"
687           "ins v1.d[1], x11\n"
688           "ldr x11, [%[src_ptr1], #8]\n"
689           "ins v2.d[1], x12\n"
690           "ldr x12, [%[src_ptr2], #8]\n"
691           "ins v3.d[1], x13\n"
692           "ldr x13, [%[src_ptr3], #8]\n"
693           "eor v4.16b, v0.16b, v26.16b\n"
694           "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
695           "eor v5.16b, v1.16b, v26.16b\n"
696           "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
697           "eor v6.16b, v2.16b, v26.16b\n"
698           "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
699           "eor v7.16b, v3.16b, v26.16b\n"
700           "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
701           // Compute the sums, interleaved with storing to the packed matrix.
702           "saddlp v16.8h, v4.16b\n"
703           "str q4, [%[packed_ptr], #0]\n"
704           "saddlp v17.8h, v5.16b\n"
705           "str q5, [%[packed_ptr], #16]\n"
706           "saddlp v18.8h, v6.16b\n"
707           "str q6, [%[packed_ptr], #32]\n"
708           "saddlp v19.8h, v7.16b\n"
709           "str q7, [%[packed_ptr], #48]\n"
710           "sadalp v28.4s, v16.8h\n"
711           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
712           // Was this the last block of 16 rows to load?
713           "cmp w1, w2\n"
714           "sadalp v29.4s, v17.8h\n"
715           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
716           "add %[packed_ptr], %[packed_ptr], #64\n"
717           "sadalp v30.4s, v18.8h\n"
718           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
719           "sadalp v31.4s, v19.8h\n"
720           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")
721           // End of main loop on blocks of 16 rows.
722           "bne 1b\n"
723 
724           // Code handling the last already-loaded block of 16 rows.
725           "2:\n"
726           // Process the last loaded full 16x4 block.
727           "ins v0.d[1], x10\n"
728           "ins v1.d[1], x11\n"
729           "ins v2.d[1], x12\n"
730           "ins v3.d[1], x13\n"
731           "eor v4.16b, v0.16b, v26.16b\n"
732           "eor v5.16b, v1.16b, v26.16b\n"
733           "eor v6.16b, v2.16b, v26.16b\n"
734           "eor v7.16b, v3.16b, v26.16b\n"
735 
736           "saddlp v16.8h, v4.16b\n"
737           "str q4, [%[packed_ptr], #0]\n"
738           "saddlp v17.8h, v5.16b\n"
739           "str q5, [%[packed_ptr], #16]\n"
740           "saddlp v18.8h, v6.16b\n"
741           "str q6, [%[packed_ptr], #32]\n"
742           "saddlp v19.8h, v7.16b\n"
743           "str q7, [%[packed_ptr], #48]\n"
744           "sadalp v28.4s, v16.8h\n"
745           "sadalp v29.4s, v17.8h\n"
746           "sadalp v30.4s, v18.8h\n"
747           "sadalp v31.4s, v19.8h\n"
748 
749           "add %[packed_ptr], %[packed_ptr], #64\n"
750 
751           // End of code handling full blocks of 16 rows.
752           // Now we handle any remaining rows.
753           "3:\n"
754           // Let w2 be the number of rows left to handle.
755           "ands w2, %w[rows], #15\n"
756           // If w2==0, there are no remaining rows, jump to the end.
757           "beq 4f\n"
758           // Zero out a 16x4 block in registers, which we'll partially overwrite
759           // with any remaining rows.
760           "dup v0.16b, %w[src_zero_point]\n"
761           "dup v1.16b, %w[src_zero_point]\n"
762           "dup v2.16b, %w[src_zero_point]\n"
763           "dup v3.16b, %w[src_zero_point]\n"
764 #define RUY_LOAD_ONE_ROW(R)                   \
765   "cmp w2, #" #R "\n"                         \
766   "beq 5f\n"                                  \
767   "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
768   "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
769   "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
770   "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
771 
772           RUY_LOAD_ONE_ROW(0)
773           RUY_LOAD_ONE_ROW(1)
774           RUY_LOAD_ONE_ROW(2)
775           RUY_LOAD_ONE_ROW(3)
776           RUY_LOAD_ONE_ROW(4)
777           RUY_LOAD_ONE_ROW(5)
778           RUY_LOAD_ONE_ROW(6)
779           RUY_LOAD_ONE_ROW(7)
780           RUY_LOAD_ONE_ROW(8)
781           RUY_LOAD_ONE_ROW(9)
782           RUY_LOAD_ONE_ROW(10)
783           RUY_LOAD_ONE_ROW(11)
784           RUY_LOAD_ONE_ROW(12)
785           RUY_LOAD_ONE_ROW(13)
786           RUY_LOAD_ONE_ROW(14)
787           // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
788 #undef RUY_LOAD_ONE_ROW
789           "5:\n"
790 
791           // Process the last zero-padded 16x4 block.
792           "eor v4.16b, v0.16b, v26.16b\n"
793           "eor v5.16b, v1.16b, v26.16b\n"
794           "eor v6.16b, v2.16b, v26.16b\n"
795           "eor v7.16b, v3.16b, v26.16b\n"
796 
797           "saddlp v16.8h, v4.16b\n"
798           "saddlp v17.8h, v5.16b\n"
799           "saddlp v18.8h, v6.16b\n"
800           "saddlp v19.8h, v7.16b\n"
801           "sadalp v28.4s, v16.8h\n"
802           "sadalp v29.4s, v17.8h\n"
803           "sadalp v30.4s, v18.8h\n"
804           "sadalp v31.4s, v19.8h\n"
805 
806           "str q4, [%[packed_ptr], #0]\n"
807           "str q5, [%[packed_ptr], #16]\n"
808           "str q6, [%[packed_ptr], #32]\n"
809           "str q7, [%[packed_ptr], #48]\n"
810           "add %[packed_ptr], %[packed_ptr], #64\n"
811 
812           "4:\n"
813 
814           // Horizontal reduction of the registers used to accumulate sums.
815           "addp v28.4s, v28.4s, v29.4s\n"
816           "addp v30.4s, v30.4s, v31.4s\n"
817           "addp v28.4s, v28.4s, v30.4s\n"
818 
819           // Store the sums.
820           "cmp %[sums_ptr], #0\n"
821           "beq 6f\n"
822           "st1 {v28.4s}, [%[sums_ptr]], #16\n"
823           "6:\n"
824           // clang-format on
825 
826           : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
827             [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
828             [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
829           : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
830             [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
831             [ rows ] "r"(src_rows),
832             [ src_zero_point ] "r"(src_zero_point),
833             [input_xor] "r"(input_xor)
834           : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5",
835             "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
836             "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
837             "v25", "v26", "v27", "v28", "v29", "v30", "v31");
838 }
839 
// Packs four columns of 8-bit source data into the layout expected by the
// kNeonDotprod kernels, accumulating per-column sums via SDOT along the way.
// This variant is tuned for in-order cores ("A55ish"): each 16-byte column
// block is loaded as an 8-byte NEON load plus an 8-byte GPR load (x10--x13)
// that is INS'd into the vector's upper half later -- presumably so the two
// loads can dual-issue on such cores (TODO confirm against ruy's in-order
// tuning notes).
//
// src_ptr0..3:  pointers to the 4 source columns; updated in place.
// src_inc0..3:  byte increments applied to each source pointer after every
//               16-row block.
// src_rows:     total number of rows to pack.
// src_zero_point: padding value used to fill the final partial (<16 rows)
//               block before packing it.
// packed_ptr:   destination buffer for the packed data; updated in place.
// sums_ptr:     if non-null, receives the four 32-bit column sums.
// input_xor:    XOR'd into every input byte to perform input type
//               conversion (e.g. 0x80 flips uint8 <-> int8).
void Pack8bitColMajorForNeonDotprodA55ish(
    const void* src_ptr0, const void* src_ptr1, const void* src_ptr2,
    const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2,
    int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr,
    std::int32_t* sums_ptr, int input_xor) {
  profiler::ScopeLabel label(
      "Pack (kNeonDotprod, optimized for in-order cores)");
  asm volatile(
          // clang-format off
          // v26 will be the vector to XOR input values with to perform
          // any input data type conversion (e.g. uint8 to int8).
          "dup v26.16b, %w[input_xor]\n"
          // v27 will be filled with 1's. It will be used as an operand
          // to SDOT to compute the sums.
          "mov w1, #1\n"
          "dup v27.16b, w1\n"
          // w1 will be the number of rows already loaded.
          "mov w1, #0\n"
          // v28--v31 will be used to accumulate the sums.
          "movi v28.4s, #0\n"
          "movi v29.4s, #0\n"
          "movi v30.4s, #0\n"
          "movi v31.4s, #0\n"

          // Let w2 be `rows` rounded down to multiple of 16.
          "ands w2, %w[rows], #-16\n"
          // If there are no full blocks of 16 rows to process, jump to the
          // code handling the last < 16 rows.
          "beq 3f\n"
          // Load the first block of 16 rows. Each 16-byte block is split
          // into an 8-byte NEON load (low half) and an 8-byte GPR load
          // (high half, x10--x13) that is INS'd into the vector later.
          "add w1, w1, #16\n"
          "ldr x10, [%[src_ptr0], #8]\n"
          "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
          "ldr x11, [%[src_ptr1], #8]\n"
          "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
          "ldr x12, [%[src_ptr2], #8]\n"
          "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
          "ldr x13, [%[src_ptr3], #8]\n"
          "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
          // Check if these were the only full block of 16 rows to load.
          "cmp w1, w2\n"
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
          // In that case, jump to the code handling the last loaded block of
          // 16 rows.
          "beq 2f\n"

          // Main loop processing blocks of 16 rows.
          "1:\n"
          "add w1, w1, #16\n"
          // Prepare the already-loaded 16 rows by inserting the parts
          // loaded into general purpose registers x10--x13 into the
          // NEON registers v0--v3 where the other parts had already been
          // loaded.
          "ins v0.d[1], x10\n"
          "ldr x10, [%[src_ptr0], #8]\n"
          "ins v1.d[1], x11\n"
          "ldr x11, [%[src_ptr1], #8]\n"
          "ins v2.d[1], x12\n"
          "ldr x12, [%[src_ptr2], #8]\n"
          "ins v3.d[1], x13\n"
          "ldr x13, [%[src_ptr3], #8]\n"

          // Load the next 16 rows and, interleaved with that,
          // perform the input type conversion (e.g. uint8->int8) on the
          // current 16 rows.
          "eor v4.16b, v0.16b, v26.16b\n"
          "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
          "eor v5.16b, v1.16b, v26.16b\n"
          "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
          "eor v6.16b, v2.16b, v26.16b\n"
          "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
          "eor v7.16b, v3.16b, v26.16b\n"
          "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"

          // Transposition of 4x4 blocks, part 1
          "trn1 v16.4s, v4.4s, v5.4s\n"
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
          "trn2 v17.4s, v4.4s, v5.4s\n"
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
          "trn1 v18.4s, v6.4s, v7.4s\n"
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
          "trn2 v19.4s, v6.4s, v7.4s\n"
          RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")

          // Transposition of 4x4 blocks, part 2
          "trn1 v20.2d, v16.2d, v18.2d\n"
          "trn2 v22.2d, v16.2d, v18.2d\n"
          "trn1 v21.2d, v17.2d, v19.2d\n"
          "trn2 v23.2d, v17.2d, v19.2d\n"
          "cmp w1, w2\n"

          // Store the block to the packed matrix and, interleaved with
          // that, compute sums using sdot instructions.
          // The SDOT instructions are emitted as raw .word encodings,
          // presumably so assemblers without the dotprod extension enabled
          // can still build this file -- TODO confirm.
          // NOTE(review): stores are 32 bytes apart; the 16-byte gaps are
          // presumably filled by a companion call packing the adjacent
          // columns -- confirm against the caller.
          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
          "str q20, [%[packed_ptr], #0]\n"
          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
          "str q21, [%[packed_ptr], #32]\n"
          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
          "str q22, [%[packed_ptr], #64]\n"
          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
          "str q23, [%[packed_ptr], #96]\n"
          "add %[packed_ptr], %[packed_ptr], #128\n"
          // End of main loop on blocks of 16 rows.
          "bne 1b\n"

          // Code handling the last already-loaded block of 16 rows.
          "2:\n"
          // Process the last loaded full 16x4 block.
          "ins v0.d[1], x10\n"
          "ins v1.d[1], x11\n"
          "ins v2.d[1], x12\n"
          "ins v3.d[1], x13\n"
          "eor v0.16b, v0.16b, v26.16b\n"
          "eor v1.16b, v1.16b, v26.16b\n"
          "eor v2.16b, v2.16b, v26.16b\n"
          "eor v3.16b, v3.16b, v26.16b\n"

          "trn1 v16.4s, v0.4s, v1.4s\n"
          "trn2 v17.4s, v0.4s, v1.4s\n"
          "trn1 v18.4s, v2.4s, v3.4s\n"
          "trn2 v19.4s, v2.4s, v3.4s\n"

          "trn1 v20.2d, v16.2d, v18.2d\n"
          "trn2 v22.2d, v16.2d, v18.2d\n"
          "trn1 v21.2d, v17.2d, v19.2d\n"
          "trn2 v23.2d, v17.2d, v19.2d\n"

          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
          "str q20, [%[packed_ptr], #0]\n"
          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
          "str q21, [%[packed_ptr], #32]\n"
          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
          "str q22, [%[packed_ptr], #64]\n"
          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
          "str q23, [%[packed_ptr], #96]\n"
          "add %[packed_ptr], %[packed_ptr], #128\n"

          // End of code handling full blocks of 16 rows.
          // Now we handle any remaining rows.
          "3:\n"
          // Let w2 be the number of rows left to handle.
          "ands w2, %w[rows], #15\n"
          // If w2==0, there are no remaining rows, jump to the end.
          "beq 4f\n"
          // Zero out a 16x4 block in registers, which we'll partially overwrite
          // with any remaining rows. The padding value is src_zero_point so
          // that padded lanes contribute a known constant to the sums.
          "dup v0.16b, %w[src_zero_point]\n"
          "dup v1.16b, %w[src_zero_point]\n"
          "dup v2.16b, %w[src_zero_point]\n"
          "dup v3.16b, %w[src_zero_point]\n"
// Loads one remaining row (one byte per column) into lane R of v0--v3,
// or jumps to label 5 once all w2 remaining rows have been consumed.
#define RUY_LOAD_ONE_ROW(R)                   \
  "cmp w2, #" #R "\n"                         \
  "beq 5f\n"                                  \
  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"

          RUY_LOAD_ONE_ROW(0)
          RUY_LOAD_ONE_ROW(1)
          RUY_LOAD_ONE_ROW(2)
          RUY_LOAD_ONE_ROW(3)
          RUY_LOAD_ONE_ROW(4)
          RUY_LOAD_ONE_ROW(5)
          RUY_LOAD_ONE_ROW(6)
          RUY_LOAD_ONE_ROW(7)
          RUY_LOAD_ONE_ROW(8)
          RUY_LOAD_ONE_ROW(9)
          RUY_LOAD_ONE_ROW(10)
          RUY_LOAD_ONE_ROW(11)
          RUY_LOAD_ONE_ROW(12)
          RUY_LOAD_ONE_ROW(13)
          RUY_LOAD_ONE_ROW(14)
          // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
#undef RUY_LOAD_ONE_ROW

          "5:\n"
          // Process the last zero-padded 16x4 block.
          "eor v0.16b, v0.16b, v26.16b\n"
          "eor v1.16b, v1.16b, v26.16b\n"
          "eor v2.16b, v2.16b, v26.16b\n"
          "eor v3.16b, v3.16b, v26.16b\n"

          "trn1 v16.4s, v0.4s, v1.4s\n"
          "trn2 v17.4s, v0.4s, v1.4s\n"
          "trn1 v18.4s, v2.4s, v3.4s\n"
          "trn2 v19.4s, v2.4s, v3.4s\n"

          "trn1 v20.2d, v16.2d, v18.2d\n"
          "trn2 v22.2d, v16.2d, v18.2d\n"
          "trn1 v21.2d, v17.2d, v19.2d\n"
          "trn2 v23.2d, v17.2d, v19.2d\n"

          // Store only as many 16-byte groups as the remaining row count
          // requires, bailing out to label 4 after each group. Note that
          // when 12 or fewer rows remain, the final advance of packed_ptr
          // below is skipped.
          ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
          "str q20, [%[packed_ptr], #0]\n"
          "cmp w2, #4\n"
          "ble 4f\n"
          ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
          "str q21, [%[packed_ptr], #32]\n"
          "cmp w2, #8\n"
          "ble 4f\n"
          ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
          "str q22, [%[packed_ptr], #64]\n"
          "cmp w2, #12\n"
          "ble 4f\n"
          ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
          "str q23, [%[packed_ptr], #96]\n"
          "add %[packed_ptr], %[packed_ptr], #128\n"

          "4:\n"

          // Reduction of the registers used to accumulate sums.
          // SDOT already produced per-column partials, so plain adds
          // suffice here (no pairwise/horizontal reduction needed).
          "add v28.4s, v28.4s, v29.4s\n"
          "add v30.4s, v30.4s, v31.4s\n"
          "add v28.4s, v28.4s, v30.4s\n"

          // Store the sums, unless sums_ptr is null.
          "cmp %[sums_ptr], #0\n"
          "beq 6f\n"
          "st1 {v28.4s}, [%[sums_ptr]], #16\n"
          "6:\n"
          // clang-format on

          : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2),
            [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
          : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
            [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
                [rows] "r"(src_rows),
            [src_zero_point] "r"(static_cast<int>(src_zero_point)),
            [input_xor] "r"(input_xor)
          : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
            "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}
1085 
Pack8bitColMajorForNeonDotprod(const void * src_ptr0,const void * src_ptr1,const void * src_ptr2,const void * src_ptr3,int src_inc0,int src_inc1,int src_inc2,int src_inc3,int src_rows,int src_zero_point,std::int8_t * packed_ptr,std::int32_t * sums_ptr,int input_xor)1086 void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
1087                                     const void* src_ptr2, const void* src_ptr3,
1088                                     int src_inc0, int src_inc1, int src_inc2,
1089                                     int src_inc3, int src_rows,
1090                                     int src_zero_point, std::int8_t* packed_ptr,
1091                                     std::int32_t* sums_ptr, int input_xor) {
1092   profiler::ScopeLabel label("Pack (kNeonDotprod)");
1093   asm volatile(
1094       // clang-format off
1095           // v26 will be the vector to XOR input values with to perform
1096           // any input data type conversion (e.g. uint8 to int8).
1097           "dup v26.16b, %w[input_xor]\n"
1098           // v27 will be filled with 1's. It will be used as an operand
1099           // to SDOT to compute the sums.
1100           "mov w1, #1\n"
1101           "dup v27.16b, w1\n"
1102           // w1 will be the number of rows already loaded.
1103           "mov w1, #0\n"
1104           // v28--v32 will be used to accumulate the sums
1105           "movi v28.4s, #0\n"
1106           "movi v29.4s, #0\n"
1107           "movi v30.4s, #0\n"
1108           "movi v31.4s, #0\n"
1109 
1110           // 4x partially unrolled code processing blocks of 64 rows.
1111           // Read the original loop below first, it has more comments.
1112 #if RUY_OPT(MAX_STREAMING)
1113           // Let w2 be `rows` rounded down to multiple of 64.
1114           // Each iteration of this 4x partially unrolled loop handles
1115           // 64 rows.
1116           "ands w2, %w[rows], #-64\n"
1117           // If there are no full blocks of 64 rows to process, jump to
1118           // the main loop below handling 16 rows per iteration.
1119           "beq 9f\n"
1120           // Load the first block of 64 rows.
1121           "add w1, w1, #64\n"
1122           "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
1123           "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
1124           "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
1125           "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
1126           "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
1127           "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
1128           "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
1129           "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
1130           "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
1131           "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
1132           "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
1133           "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
1134           "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
1135           "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
1136           // Was that the last full block of 64 rows to load?
1137           "cmp w1, w2\n"
1138           "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
1139           "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
1140           // Then jump to the end of the 64-rows-at-a-time code.
1141           "beq 8f\n"
1142 
1143           // Start of the main 4x partially unrolled loop.
1144           "7:\n"
1145           // Process rows 0 -- 15 out of 64.
1146           "eor v0.16b, v0.16b, v26.16b\n"
1147           "eor v1.16b, v1.16b, v26.16b\n"
1148           "eor v2.16b, v2.16b, v26.16b\n"
1149           "eor v3.16b, v3.16b, v26.16b\n"
1150 
1151           "trn1 v16.4s, v0.4s, v1.4s\n"
1152           "trn2 v17.4s, v0.4s, v1.4s\n"
1153           "trn1 v18.4s, v2.4s, v3.4s\n"
1154           "trn2 v19.4s, v2.4s, v3.4s\n"
1155 
1156           "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
1157           "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
1158           "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
1159           "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
1160           "add w1, w1, #16\n"
1161 
1162           "trn1 v20.2d, v16.2d, v18.2d\n"
1163           "trn2 v22.2d, v16.2d, v18.2d\n"
1164           "trn1 v21.2d, v17.2d, v19.2d\n"
1165           "trn2 v23.2d, v17.2d, v19.2d\n"
1166 
1167           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1168           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1169           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1170           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1171 
1172           "str q20, [%[packed_ptr], #0]\n"
1173           "str q21, [%[packed_ptr], #32]\n"
1174           "str q22, [%[packed_ptr], #64]\n"
1175           "str q23, [%[packed_ptr], #96]\n"
1176           "add %[packed_ptr], %[packed_ptr], #128\n"
1177 
1178           // Process rows 16 -- 31 out of 64.
1179           "eor v4.16b, v4.16b, v26.16b\n"
1180           "eor v5.16b, v5.16b, v26.16b\n"
1181           "eor v6.16b, v6.16b, v26.16b\n"
1182           "eor v7.16b, v7.16b, v26.16b\n"
1183 
1184           "trn1 v16.4s, v4.4s, v5.4s\n"
1185           "trn2 v17.4s, v4.4s, v5.4s\n"
1186           "trn1 v18.4s, v6.4s, v7.4s\n"
1187           "trn2 v19.4s, v6.4s, v7.4s\n"
1188 
1189           "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
1190           "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
1191           "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
1192           "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
1193           "add w1, w1, #16\n"
1194 
1195           "trn1 v20.2d, v16.2d, v18.2d\n"
1196           "trn2 v22.2d, v16.2d, v18.2d\n"
1197           "trn1 v21.2d, v17.2d, v19.2d\n"
1198           "trn2 v23.2d, v17.2d, v19.2d\n"
1199 
1200           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1201           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1202           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1203           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1204 
1205           "str q20, [%[packed_ptr], #0]\n"
1206           "str q21, [%[packed_ptr], #32]\n"
1207           "str q22, [%[packed_ptr], #64]\n"
1208           "str q23, [%[packed_ptr], #96]\n"
1209           "add %[packed_ptr], %[packed_ptr], #128\n"
1210 
1211           // Process rows 32 -- 47 out of 64.
1212           "eor v8.16b, v8.16b, v26.16b\n"
1213           "eor v9.16b, v9.16b, v26.16b\n"
1214           "eor v10.16b, v10.16b, v26.16b\n"
1215           "eor v11.16b, v11.16b, v26.16b\n"
1216 
1217           "trn1 v16.4s, v8.4s, v9.4s\n"
1218           "trn2 v17.4s, v8.4s, v9.4s\n"
1219           "trn1 v18.4s, v10.4s, v11.4s\n"
1220           "trn2 v19.4s, v10.4s, v11.4s\n"
1221 
1222           "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
1223           "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
1224           "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
1225           "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
1226           "add w1, w1, #16\n"
1227 
1228           "trn1 v20.2d, v16.2d, v18.2d\n"
1229           "trn2 v22.2d, v16.2d, v18.2d\n"
1230           "trn1 v21.2d, v17.2d, v19.2d\n"
1231           "trn2 v23.2d, v17.2d, v19.2d\n"
1232 
1233           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1234           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1235           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1236           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1237 
1238           "str q20, [%[packed_ptr], #0]\n"
1239           "str q21, [%[packed_ptr], #32]\n"
1240           "str q22, [%[packed_ptr], #64]\n"
1241           "str q23, [%[packed_ptr], #96]\n"
1242           "add %[packed_ptr], %[packed_ptr], #128\n"
1243 
1244           // Process rows 48 -- 63 out of 64.
1245           "eor v12.16b, v12.16b, v26.16b\n"
1246           "eor v13.16b, v13.16b, v26.16b\n"
1247           "eor v14.16b, v14.16b, v26.16b\n"
1248           "eor v15.16b, v15.16b, v26.16b\n"
1249 
1250           "trn1 v16.4s, v12.4s, v13.4s\n"
1251           "trn2 v17.4s, v12.4s, v13.4s\n"
1252           "trn1 v18.4s, v14.4s, v15.4s\n"
1253           "trn2 v19.4s, v14.4s, v15.4s\n"
1254 
1255           "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
1256           "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
1257           "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
1258           "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
1259           "add w1, w1, #16\n"
1260 
1261           "trn1 v20.2d, v16.2d, v18.2d\n"
1262           "trn2 v22.2d, v16.2d, v18.2d\n"
1263           "trn1 v21.2d, v17.2d, v19.2d\n"
1264           "trn2 v23.2d, v17.2d, v19.2d\n"
1265 
1266           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1267           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1268           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1269           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1270 
1271           "cmp w1, w2\n"
1272           "str q20, [%[packed_ptr], #0]\n"
1273           "str q21, [%[packed_ptr], #32]\n"
1274           "str q22, [%[packed_ptr], #64]\n"
1275           "str q23, [%[packed_ptr], #96]\n"
1276           "add %[packed_ptr], %[packed_ptr], #128\n"
1277 
1278           // End of main 4x partially unrolled loop.
1279           "bne 7b\n"
1280 
1281           // Last part of the 4x partially unrolled code:
1282           // handle the last already-loaded 64 rows.
1283           "8:\n"
1284 
1285           // Process rows 0 -- 15 out of 64.
1286           "eor v0.16b, v0.16b, v26.16b\n"
1287           "eor v1.16b, v1.16b, v26.16b\n"
1288           "eor v2.16b, v2.16b, v26.16b\n"
1289           "eor v3.16b, v3.16b, v26.16b\n"
1290 
1291           "trn1 v16.4s, v0.4s, v1.4s\n"
1292           "trn2 v17.4s, v0.4s, v1.4s\n"
1293           "trn1 v18.4s, v2.4s, v3.4s\n"
1294           "trn2 v19.4s, v2.4s, v3.4s\n"
1295 
1296           "trn1 v20.2d, v16.2d, v18.2d\n"
1297           "trn2 v22.2d, v16.2d, v18.2d\n"
1298           "trn1 v21.2d, v17.2d, v19.2d\n"
1299           "trn2 v23.2d, v17.2d, v19.2d\n"
1300 
1301           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1302           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1303           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1304           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1305 
1306           "str q20, [%[packed_ptr], #0]\n"
1307           "str q21, [%[packed_ptr], #32]\n"
1308           "str q22, [%[packed_ptr], #64]\n"
1309           "str q23, [%[packed_ptr], #96]\n"
1310           "add %[packed_ptr], %[packed_ptr], #128\n"
1311 
1312           // Process rows 16 -- 31 out of 64.
1313           "eor v4.16b, v4.16b, v26.16b\n"
1314           "eor v5.16b, v5.16b, v26.16b\n"
1315           "eor v6.16b, v6.16b, v26.16b\n"
1316           "eor v7.16b, v7.16b, v26.16b\n"
1317 
1318           "trn1 v16.4s, v4.4s, v5.4s\n"
1319           "trn2 v17.4s, v4.4s, v5.4s\n"
1320           "trn1 v18.4s, v6.4s, v7.4s\n"
1321           "trn2 v19.4s, v6.4s, v7.4s\n"
1322 
1323           "trn1 v20.2d, v16.2d, v18.2d\n"
1324           "trn2 v22.2d, v16.2d, v18.2d\n"
1325           "trn1 v21.2d, v17.2d, v19.2d\n"
1326           "trn2 v23.2d, v17.2d, v19.2d\n"
1327 
1328           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1329           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1330           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1331           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1332 
1333           "str q20, [%[packed_ptr], #0]\n"
1334           "str q21, [%[packed_ptr], #32]\n"
1335           "str q22, [%[packed_ptr], #64]\n"
1336           "str q23, [%[packed_ptr], #96]\n"
1337           "add %[packed_ptr], %[packed_ptr], #128\n"
1338 
1339           // Process rows 32 -- 47 out of 64.
1340           "eor v8.16b, v8.16b, v26.16b\n"
1341           "eor v9.16b, v9.16b, v26.16b\n"
1342           "eor v10.16b, v10.16b, v26.16b\n"
1343           "eor v11.16b, v11.16b, v26.16b\n"
1344 
1345           "trn1 v16.4s, v8.4s, v9.4s\n"
1346           "trn2 v17.4s, v8.4s, v9.4s\n"
1347           "trn1 v18.4s, v10.4s, v11.4s\n"
1348           "trn2 v19.4s, v10.4s, v11.4s\n"
1349 
1350           "trn1 v20.2d, v16.2d, v18.2d\n"
1351           "trn2 v22.2d, v16.2d, v18.2d\n"
1352           "trn1 v21.2d, v17.2d, v19.2d\n"
1353           "trn2 v23.2d, v17.2d, v19.2d\n"
1354 
1355           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1356           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1357           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1358           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1359 
1360           "str q20, [%[packed_ptr], #0]\n"
1361           "str q21, [%[packed_ptr], #32]\n"
1362           "str q22, [%[packed_ptr], #64]\n"
1363           "str q23, [%[packed_ptr], #96]\n"
1364           "add %[packed_ptr], %[packed_ptr], #128\n"
1365 
1366           // Process rows 48 -- 63 out of 64.
1367           "eor v12.16b, v12.16b, v26.16b\n"
1368           "eor v13.16b, v13.16b, v26.16b\n"
1369           "eor v14.16b, v14.16b, v26.16b\n"
1370           "eor v15.16b, v15.16b, v26.16b\n"
1371 
1372           "trn1 v16.4s, v12.4s, v13.4s\n"
1373           "trn2 v17.4s, v12.4s, v13.4s\n"
1374           "trn1 v18.4s, v14.4s, v15.4s\n"
1375           "trn2 v19.4s, v14.4s, v15.4s\n"
1376 
1377           "trn1 v20.2d, v16.2d, v18.2d\n"
1378           "trn2 v22.2d, v16.2d, v18.2d\n"
1379           "trn1 v21.2d, v17.2d, v19.2d\n"
1380           "trn2 v23.2d, v17.2d, v19.2d\n"
1381 
1382           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1383           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1384           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1385           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1386 
1387           "str q20, [%[packed_ptr], #0]\n"
1388           "str q21, [%[packed_ptr], #32]\n"
1389           "str q22, [%[packed_ptr], #64]\n"
1390           "str q23, [%[packed_ptr], #96]\n"
1391           "add %[packed_ptr], %[packed_ptr], #128\n"
1392 
1393           "9:\n"
1394 #endif  // #if RUY_OPT(MAX_STREAMING)
1395           // End of 4x partially unrolled code processing blocks of 64 rows.
1396 
1397           // Main part of the code, processing blocks of 16 rows.
1398 
1399           // Let w2 be `rows` rounded down to multiple of 16.
1400           "and w2, %w[rows], #-16\n"
1401           // If there are no full blocks of 16 rows to process, jump to the
1402           // code handling the last < 16 rows.
1403           "cmp w1, w2\n"
1404           "beq 3f\n"
1405 
1406           // Load the first block of 16 rows.
1407           "add w1, w1, #16\n"
1408           "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
1409           "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
1410           // Check if these were the only full block of 16 rows to load.
1411           "cmp w1, w2\n"
1412           "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
1413           "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
1414           // In that case, jump to the code handling the last loaded block of
1415           // 16 rows.
1416           "beq 2f\n"
1417           // Main loop processing blocks of 16 rows.
1418           "1:\n"
1419           // Input type conversion (e.g. uint8->int8).
1420           "eor v0.16b, v0.16b, v26.16b\n"
1421           "eor v1.16b, v1.16b, v26.16b\n"
1422           "eor v2.16b, v2.16b, v26.16b\n"
1423           "eor v3.16b, v3.16b, v26.16b\n"
1424           // Transposition of 4x4 blocks, part 1
1425           "trn1 v16.4s, v0.4s, v1.4s\n"
1426           "trn2 v17.4s, v0.4s, v1.4s\n"
1427           "trn1 v18.4s, v2.4s, v3.4s\n"
1428           "trn2 v19.4s, v2.4s, v3.4s\n"
1429           // Load the next 16 rows
1430           "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
1431           "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
1432           "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
1433           "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
1434           "add w1, w1, #16\n"
1435           // Transposition of 4x4 blocks, part 2
1436           "trn1 v20.2d, v16.2d, v18.2d\n"
1437           "trn2 v22.2d, v16.2d, v18.2d\n"
1438           "trn1 v21.2d, v17.2d, v19.2d\n"
1439           "trn2 v23.2d, v17.2d, v19.2d\n"
1440           // Compute sums using sdot instructions.
1441           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1442           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1443           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1444           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1445           // Store the block to the packed matrix.
1446           "str q20, [%[packed_ptr], #0]\n"
1447           "str q21, [%[packed_ptr], #32]\n"
1448           "cmp w1, w2\n"
1449           "str q22, [%[packed_ptr], #64]\n"
1450           "str q23, [%[packed_ptr], #96]\n"
1451           "add %[packed_ptr], %[packed_ptr], #128\n"
1452           // End of main loop on blocks of 16 rows.
1453           "bne 1b\n"
1454 
1455           // Code handling the last already-loaded block of 16 rows.
1456           "2:\n"
1457 
1458           // Process the last loaded full 16x4 block.
1459           "eor v0.16b, v0.16b, v26.16b\n"
1460           "eor v1.16b, v1.16b, v26.16b\n"
1461           "eor v2.16b, v2.16b, v26.16b\n"
1462           "eor v3.16b, v3.16b, v26.16b\n"
1463 
1464           "trn1 v16.4s, v0.4s, v1.4s\n"
1465           "trn2 v17.4s, v0.4s, v1.4s\n"
1466           "trn1 v18.4s, v2.4s, v3.4s\n"
1467           "trn2 v19.4s, v2.4s, v3.4s\n"
1468 
1469           "trn1 v20.2d, v16.2d, v18.2d\n"
1470           "trn2 v22.2d, v16.2d, v18.2d\n"
1471           "trn1 v21.2d, v17.2d, v19.2d\n"
1472           "trn2 v23.2d, v17.2d, v19.2d\n"
1473 
1474           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1475           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1476           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1477           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1478 
1479           "str q20, [%[packed_ptr], #0]\n"
1480           "str q21, [%[packed_ptr], #32]\n"
1481           "str q22, [%[packed_ptr], #64]\n"
1482           "str q23, [%[packed_ptr], #96]\n"
1483           "add %[packed_ptr], %[packed_ptr], #128\n"
1484 
1485           // End of code handling full blocks of 16 rows.
1486           // Now we handle any remaining rows.
1487           "3:\n"
1488           // Let w2 be the number of rows left to handle.
1489           "ands w2, %w[rows], #15\n"
1490           // If w2==0, there are no remaining rows, jump to the end.
1491           "beq 4f\n"
1492           // Zero out a 16x4 block in registers, which we'll partially overwrite
1493           // with any remaining rows.
1494           "dup v0.16b, %w[src_zero_point]\n"
1495           "dup v1.16b, %w[src_zero_point]\n"
1496           "dup v2.16b, %w[src_zero_point]\n"
1497           "dup v3.16b, %w[src_zero_point]\n"
1498 #define RUY_LOAD_ONE_ROW(R)                   \
1499   "cmp w2, #" #R "\n"                         \
1500   "beq 5f\n"                                  \
1501   "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
1502   "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
1503   "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
1504   "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
1505 
1506           RUY_LOAD_ONE_ROW(0)
1507           RUY_LOAD_ONE_ROW(1)
1508           RUY_LOAD_ONE_ROW(2)
1509           RUY_LOAD_ONE_ROW(3)
1510           RUY_LOAD_ONE_ROW(4)
1511           RUY_LOAD_ONE_ROW(5)
1512           RUY_LOAD_ONE_ROW(6)
1513           RUY_LOAD_ONE_ROW(7)
1514           RUY_LOAD_ONE_ROW(8)
1515           RUY_LOAD_ONE_ROW(9)
1516           RUY_LOAD_ONE_ROW(10)
1517           RUY_LOAD_ONE_ROW(11)
1518           RUY_LOAD_ONE_ROW(12)
1519           RUY_LOAD_ONE_ROW(13)
1520           RUY_LOAD_ONE_ROW(14)
1521           // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
1522 #undef RUY_LOAD_ONE_ROW
1523 
1524           "5:\n"
1525           // Process the last zero-padded 16x4 block.
1526           "eor v0.16b, v0.16b, v26.16b\n"
1527           "eor v1.16b, v1.16b, v26.16b\n"
1528           "eor v2.16b, v2.16b, v26.16b\n"
1529           "eor v3.16b, v3.16b, v26.16b\n"
1530 
1531           "trn1 v16.4s, v0.4s, v1.4s\n"
1532           "trn2 v17.4s, v0.4s, v1.4s\n"
1533           "trn1 v18.4s, v2.4s, v3.4s\n"
1534           "trn2 v19.4s, v2.4s, v3.4s\n"
1535 
1536           "trn1 v20.2d, v16.2d, v18.2d\n"
1537           "trn2 v22.2d, v16.2d, v18.2d\n"
1538           "trn1 v21.2d, v17.2d, v19.2d\n"
1539           "trn2 v23.2d, v17.2d, v19.2d\n"
1540 
1541           ".word 0x4e9b969c  // sdot v28.4s, v20.16b, v27.16b\n"
1542           "str q20, [%[packed_ptr], #0]\n"
1543           "cmp w2, #4\n"
1544           "ble 4f\n"
1545           ".word 0x4e9b96be  // sdot v30.4s, v21.16b, v27.16b\n"
1546           "str q21, [%[packed_ptr], #32]\n"
1547           "cmp w2, #8\n"
1548           "ble 4f\n"
1549           ".word 0x4e9b96dd  // sdot v29.4s, v22.16b, v27.16b\n"
1550           "str q22, [%[packed_ptr], #64]\n"
1551           "cmp w2, #12\n"
1552           "ble 4f\n"
1553           ".word 0x4e9b96ff  // sdot v31.4s, v23.16b, v27.16b\n"
1554           "str q23, [%[packed_ptr], #96]\n"
1555           "add %[packed_ptr], %[packed_ptr], #128\n"
1556 
1557           "4:\n"
1558 
1559           // Reduction of the registers used to accumulate sums.
1560           "add v28.4s, v28.4s, v29.4s\n"
1561           "add v30.4s, v30.4s, v31.4s\n"
1562           "add v28.4s, v28.4s, v30.4s\n"
1563 
1564           // Store the sums.
1565           "cmp %[sums_ptr], #0\n"
1566           "beq 6f\n"
1567           "st1 {v28.4s}, [%[sums_ptr]], #16\n"
1568           "6:\n"
1569       // clang-format on
1570 
1571       : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
1572         [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
1573         [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
1574       : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
1575         [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
1576         [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
1577         [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
1578         [rows] "r"(src_rows),
1579         [src_zero_point] "r"(static_cast<int>(src_zero_point)),
1580         [input_xor] "r"(input_xor)
1581       : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
1582         "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
1583         "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
1584         "v27", "v28", "v29", "v30", "v31");
1585 }
1586 
// Packs a 4-row slice of a row-major 8-bit source matrix into ruy's packed
// layout, one 4x8 block (8 source columns) per main-loop iteration.
//
// src_ptr0..src_ptr3 point at the 4 source rows; after each 8-column block,
// src_ptrN is advanced by src_incN bytes (sign-extended from 32 bits via
// sxtw). Every loaded byte is XORed with input_xor (converting uint8 source
// data to int8 when needed). The reordered column-major 4x8 block is stored
// as two 16-byte chunks at packed_ptr, which then advances by
// packed_stride * 8 bytes per block.
//
// Column sums are accumulated with sdot against a vector of all-ones
// (v27). Note the 8 int32 sums at sums_ptr are loaded, updated, and stored
// back (read-modify-write), so the caller must have initialized that
// buffer. Leftover columns (src_cols % 8) are handled after label 2, with
// missing lanes padded with src_zero_point.
//
// The sdot instructions are emitted as raw ".word" encodings with the
// mnemonic in a trailing comment — presumably so the code assembles even
// without the dotprod extension enabled in the assembler; confirm against
// the toolchain requirements.
Pack8bitRowMajorForNeonDotprod(const void * src_ptr0,const void * src_ptr1,const void * src_ptr2,const void * src_ptr3,int src_inc0,int src_inc1,int src_inc2,int src_inc3,int src_cols,int src_zero_point,std::int8_t * packed_ptr,int packed_stride,std::int32_t * sums_ptr,int input_xor)1587 void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
1588                                     const void* src_ptr2, const void* src_ptr3,
1589                                     int src_inc0, int src_inc1, int src_inc2,
1590                                     int src_inc3, int src_cols,
1591                                     int src_zero_point, std::int8_t* packed_ptr,
1592                                     int packed_stride, std::int32_t* sums_ptr,
1593                                     int input_xor) {
1594   profiler::ScopeLabel label("Pack (kNeonDotprod, from row-major)");
1595   asm volatile(
1596       // clang-format off
1597           // Prefetch data. This was tuned on Cortex-A55-rev1 cores.
1598           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0]]\n")
1599           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1]]\n")
1600           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2]]\n")
1601           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3]]\n")
1602           // Let w0 = (number of columns to compute) - 8.
1603           "subs w0, %w[src_cols], 8\n"
1604           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], 64]\n")
1605           // Let v26 duplicate the input_xor value in all lanes.
1606           "dup v26.16b, %w[input_xor]\n"
1607           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], 64]\n")
1608           // Let v27 be 1 in all lanes. Used with sdot to compute sums.
1609           "movi v27.16b, 1\n"
1610           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], 64]\n")
1611           // If there isn't a full block of 8 columns to load from, jump to the
1612           // code after the loop handling leftovers.
1613           "blt 2f\n"
1614           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], 64]\n")
1615           // Main loop, each iteration handles a full block of 8 cols.
1616           "1:\n"
1617           // Load the 4x8 block from the source matrix, or zero if we're
1618           // past the bottom of the source matrix.
1619           "ld1 {v0.8b}, [%[src_ptr0]]\n"
1620           "ld1 {v1.8b}, [%[src_ptr1]]\n"
1621           "ld1 {v2.8b}, [%[src_ptr2]]\n"
1622           "ld1 {v3.8b}, [%[src_ptr3]]\n"
1623           // Load values from the sums buffer, and start the reordering
1624           // of the loaded 4x8 block by interleaving 8bit values.
1625           "zip1 v0.16b, v0.16b, v1.16b\n"
1626           "ldr q8, [%[sums_ptr], 0]\n"
1627           "zip1 v1.16b, v2.16b, v3.16b\n"
1628           "ldr q9, [%[sums_ptr], 16]\n"
1629           // Finish the reordering of the 4x8 block, putting it into
1630           // column-major order.
1631           "zip1 v2.8h, v0.8h, v1.8h\n"
1632           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], 128]\n")
1633           "zip2 v3.8h, v0.8h, v1.8h\n"
1634           // Apply input_xor, i.e. convert source values from uint8 to int8
1635           // if needed.
1636           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], 128]\n")
1637           "eor v2.16b, v2.16b, v26.16b\n"
1638           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], 128]\n")
1639           "eor v3.16b, v3.16b, v26.16b\n"
1640           // Update the sums.
1641           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], 128]\n")
1642           ".word 0x4e9b9448  // sdot v8.4s, v2.16b, v27.16b\n"
1643           ".word 0x4e9b9469  // sdot v9.4s, v3.16b, v27.16b\n"
1644           // Store the column-major 4x8 block to the packed matrix, and
1645           // increment some source pointers.
1646           "str q2, [%[packed_ptr], 0]\n"
1647           "add %[src_ptr0], %[src_ptr0], %w[src_inc0], sxtw\n"
1648           "str q3, [%[packed_ptr], 16]\n"
1649           "add %[src_ptr1], %[src_ptr1], %w[src_inc1], sxtw\n"
1650           // Store the updated sums, and increment the remaining pointers
1651           // and the block_col loop index.
1652           "st1 {v8.4s}, [%[sums_ptr]], 16\n"
1653           "add %[packed_ptr], %[packed_ptr], %[packed_stride], lsl 3\n"
1654           "st1 {v9.4s}, [%[sums_ptr]], 16\n"
1655           // Advance by 8 columns and set the condition code.
1656           "subs w0, w0, 8\n"
1657           "add %[src_ptr2], %[src_ptr2], %w[src_inc2], sxtw\n"
1658           "add %[src_ptr3], %[src_ptr3], %w[src_inc3], sxtw\n"
1659           // End of the main loop.
1660           "bge 1b\n"
1661
1662           "2:\n"
1663           // We add back 8 to w0 so that w0 is the number of columns remaining
1664           // to handle.
1665           "adds w0, w0, 8\n"
1666           // Nothing left? Then jump to the end.
1667           "beq 3f\n"
1668           // Here w0 is between 1 and 7. We zero-initialize v0--v3 ...
1669           "dup v0.8b, %w[src_zero_point]\n"
1670           "dup v1.8b, %w[src_zero_point]\n"
1671           "dup v2.8b, %w[src_zero_point]\n"
1672           "dup v3.8b, %w[src_zero_point]\n"
1673           // ... and now we fill lanes one by one with leftover columns.
1674 #define RUY_LOAD_ONE_COL(C)\
1675   "cmp w0, " #C "\n" \
1676   "beq 4f\n"                                  \
1677   "ld1 { v0.b }[" #C "], [%[src_ptr0]], #1\n" \
1678   "ld1 { v1.b }[" #C "], [%[src_ptr1]], #1\n" \
1679   "ld1 { v2.b }[" #C "], [%[src_ptr2]], #1\n" \
1680   "ld1 { v3.b }[" #C "], [%[src_ptr3]], #1\n"
1681
1682           RUY_LOAD_ONE_COL(0)
1683           RUY_LOAD_ONE_COL(1)
1684           RUY_LOAD_ONE_COL(2)
1685           RUY_LOAD_ONE_COL(3)
1686           RUY_LOAD_ONE_COL(4)
1687           RUY_LOAD_ONE_COL(5)
1688           RUY_LOAD_ONE_COL(6)
1689           // Here we know that w0==7, so RUY_LOAD_ONE_COL(7) would be a no-op.
1690 #undef RUY_LOAD_ONE_COL
1691
1692           "4:\n"
1693           // The leftovers source data is loaded, now we can perform the
1694           // computation as usual.
1695           // Load values from the sums buffer, and start the reordering
1696           // of the loaded 4x8 block by interleaving 8bit values.
1697           "zip1 v0.16b, v0.16b, v1.16b\n"
1698           "ldr q8, [%[sums_ptr], 0]\n"
1699           "zip1 v1.16b, v2.16b, v3.16b\n"
1700           "ldr q9, [%[sums_ptr], 16]\n"
1701           // Finish the reordering of the 4x8 block, putting it into
1702           // column-major order.
1703           "zip1 v2.8h, v0.8h, v1.8h\n"
1704           "zip2 v3.8h, v0.8h, v1.8h\n"
1705           // Apply input_xor, i.e. convert source values from uint8 to int8
1706           // if needed.
1707           "eor v2.16b, v2.16b, v26.16b\n"
1708           "eor v3.16b, v3.16b, v26.16b\n"
1709           // Update the sums.
1710           ".word 0x4e9b9448  // sdot v8.4s, v2.16b, v27.16b\n"
1711           ".word 0x4e9b9469  // sdot v9.4s, v3.16b, v27.16b\n"
1712           // Store the column-major 4x8 block to the packed matrix, and
1713           // increment some source pointers.
1714           "str q2, [%[packed_ptr], 0]\n"
1715           "str q3, [%[packed_ptr], 16]\n"
1716           // Store the updated sums, and increment the remaining pointers
1717           // and the block_col loop index.
1718           "st1 {v8.4s}, [%[sums_ptr]], 16\n"
1719           "st1 {v9.4s}, [%[sums_ptr]], 16\n"
1720
1721           // End label.
1722           "3:\n"
1723       // clang-format on
1724       : [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr),
1725         [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
1726         [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3)
1727       : [src_inc0] "r"(src_inc0), [src_inc1] "r"(src_inc1),
1728         [src_inc2] "r"(src_inc2), [src_inc3] "r"(src_inc3),
1729         [input_xor] "r"(input_xor), [src_zero_point] "r"(src_zero_point),
1730         [packed_stride] "r"(static_cast<std::int64_t>(packed_stride)),
1731         [src_cols] "r"(src_cols)
1732       : "cc", "memory", "x0", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5",
1733         "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
1734         "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
1735         "v27", "v28", "v29", "v30", "v31");
1736 }
1737 
// Packs 4 columns of a column-major float matrix into ruy's packed layout,
// processing 4 rows per iteration. Each 4x4 block is transposed with
// trn1/trn2 (32-bit then 64-bit granularity) and its four transposed
// vectors stored at packed_ptr + {0, 32, 64, 96} bytes — i.e. the packed
// row-group stride is 32 bytes — after which packed_ptr advances by 128.
// src_ptrN is advanced by src_incN bytes after each 4-row load.
// The final (src_rows % 4) rows are zero-padded before transposition, and
// only the first (src_rows % 4) transposed vectors are stored.
PackFloatColMajorForNeon(const float * src_ptr0,const float * src_ptr1,const float * src_ptr2,const float * src_ptr3,int src_inc0,int src_inc1,int src_inc2,int src_inc3,int src_rows,float * packed_ptr)1738 void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
1739                               const float* src_ptr2, const float* src_ptr3,
1740                               int src_inc0, int src_inc1, int src_inc2,
1741                               int src_inc3, int src_rows, float* packed_ptr) {
1742   profiler::ScopeLabel label("Pack (kNeon)");
1743   asm volatile(
1744       // clang-format off
1745           // w1 will be the number of rows already loaded.
1746           "mov w1, #0\n"
1747           // Let w2 be `rows` rounded down to multiple of 4.
1748           "ands w2, %w[rows], #-4\n"
1749           // If there are no full blocks of 4 rows to process, jump to the
1750           // code handling the last < 4 rows.
1751           "beq 3f\n"
1752           // Load the first block of 4 rows.
1753           "add w1, w1, #4\n"
1754           "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
1755           "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
1756           // Check if this was the only full block of 4 rows to load.
1757           "cmp w1, w2\n"
1758           "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
1759           "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
1760           // In that case, jump to the code handling the last loaded block of
1761           // 4 rows.
1762           "beq 2f\n"
1763           // Main loop processing blocks of 4 rows.
1764           "1:\n"
1765           // Advance by 4 rows.
1766           "add w1, w1, #4\n"
1767           // Transposition of the already-loaded 4x4 block, part 1.
1768           "trn1 v16.4s, v0.4s, v1.4s\n"
1769           "trn2 v17.4s, v0.4s, v1.4s\n"
1770           "trn1 v18.4s, v2.4s, v3.4s\n"
1771           "trn2 v19.4s, v2.4s, v3.4s\n"
1772           // Load the next 4x4 block.
1773           "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
1774           "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
1775           "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
1776           "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
1777           // Transposition of the already-loaded 4x4 block, part 2.
1778           "trn1 v20.2d, v16.2d, v18.2d\n"
1779           "trn2 v22.2d, v16.2d, v18.2d\n"
1780           "trn1 v21.2d, v17.2d, v19.2d\n"
1781           "trn2 v23.2d, v17.2d, v19.2d\n"
1782           // Was this the last full 4x4 block to load?
1783           "cmp w1, w2\n"
1784           // Store the transposed 4x4 block.
1785           "str q20, [%[packed_ptr], #0]\n"
1786           "str q21, [%[packed_ptr], #32]\n"
1787           "str q22, [%[packed_ptr], #64]\n"
1788           "str q23, [%[packed_ptr], #96]\n"
1789           "add %[packed_ptr], %[packed_ptr], #128\n"
1790           // End of main loop on 4x4 blocks.
1791           "bne 1b\n"
1792
1793           // Code handling the last already-loaded 4x4 block.
1794           "2:\n"
1795
1796           "trn1 v16.4s, v0.4s, v1.4s\n"
1797           "trn2 v17.4s, v0.4s, v1.4s\n"
1798           "trn1 v18.4s, v2.4s, v3.4s\n"
1799           "trn2 v19.4s, v2.4s, v3.4s\n"
1800
1801           "trn1 v20.2d, v16.2d, v18.2d\n"
1802           "trn2 v22.2d, v16.2d, v18.2d\n"
1803           "trn1 v21.2d, v17.2d, v19.2d\n"
1804           "trn2 v23.2d, v17.2d, v19.2d\n"
1805
1806           "str q20, [%[packed_ptr], #0]\n"
1807           "str q21, [%[packed_ptr], #32]\n"
1808           "str q22, [%[packed_ptr], #64]\n"
1809           "str q23, [%[packed_ptr], #96]\n"
1810           "add %[packed_ptr], %[packed_ptr], #128\n"
1811
1812           // End of code handling full 4x4 blocks.
1813           // Now we handle any remaining rows.
1814           "3:\n"
1815           // Let w2 be the number of rows left to handle.
1816           "ands w2, %w[rows], #3\n"
1817           // If w2==0, there are no remaining rows, jump to the end.
1818           "beq 4f\n"
1819           // Zero out a 4x4 block in registers, which we'll partially overwrite
1820           // with any remaining rows.
1821           "movi v0.16b, #0\n"
1822           "movi v1.16b, #0\n"
1823           "movi v2.16b, #0\n"
1824           "movi v3.16b, #0\n"
1825 #define RUY_LOAD_ONE_ROW(R)                   \
1826   "cmp w2, #" #R "\n"                         \
1827   "beq 5f\n"                                  \
1828   "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \
1829   "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \
1830   "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \
1831   "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"
1832
1833           RUY_LOAD_ONE_ROW(0)
1834           RUY_LOAD_ONE_ROW(1)
1835           RUY_LOAD_ONE_ROW(2)
1836           // Here we know that w2==3, so RUY_LOAD_ONE_ROW(3) would be a no-op.
1837 #undef RUY_LOAD_ONE_ROW
1838           "5:\n"
1839
1840           // Transpose that last zero-padded 4x4 block.
1841           "trn1 v16.4s, v0.4s, v1.4s\n"
1842           "trn2 v17.4s, v0.4s, v1.4s\n"
1843           "trn1 v18.4s, v2.4s, v3.4s\n"
1844           "trn2 v19.4s, v2.4s, v3.4s\n"
1845
1846           "trn1 v20.2d, v16.2d, v18.2d\n"
1847           "trn2 v22.2d, v16.2d, v18.2d\n"
1848           "trn1 v21.2d, v17.2d, v19.2d\n"
1849           "trn2 v23.2d, v17.2d, v19.2d\n"
1850
1851           // Store that last zero-padded block to the packed matrix.
1852           // x1 holds the 32-byte post-increment between stored rows.
1853           "mov x1, #32\n"
1854 #define RUY_STORE_ONE_ROW(ROW, REGISTER)                  \
1855           "cmp w2, #" #ROW "\n"                           \
1856           "beq 4f\n"                                      \
1857           "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"
1858
1859           RUY_STORE_ONE_ROW(0, v20)
1860           RUY_STORE_ONE_ROW(1, v21)
1861           RUY_STORE_ONE_ROW(2, v22)
1862           RUY_STORE_ONE_ROW(3, v23)
1863
1864 #undef RUY_STORE_ONE_ROW
1865
1866           "4:\n"
1867
1868       // clang-format on
1869
1870       : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
1871         [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
1872         [packed_ptr] "+r"(packed_ptr)
1873       : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
1874         [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
1875         [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
1876         [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
1877         [rows] "r"(src_rows)
1878       : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1",
1879         "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
1880         "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
1881         "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
1882 }
1882 #endif
1883 
1884 #if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
// 32-bit NEON counterpart of the 64-bit PackFloatColMajorForNeon above.
// Differences from the 64-bit version:
//  - A single `src_inc` bitmask replaces the four per-pointer increments:
//    bit N set means "advance src_ptrN by 16 bytes (4 floats) after each
//    4x4 load". Each RUY_LOAD_FOUR_BY_FOUR step isolates bit N and shifts
//    it left so the added amount is exactly 16 when the bit is set
//    (1<<4 == 2<<3 == 4<<2 == 8<<1 == 16).
//  - The distance between packed row-groups is passed in explicitly as
//    `output_stride` (operand [stride]) instead of being hardcoded to 32.
PackFloatColMajorForNeon(const float * src_ptr0,const float * src_ptr1,const float * src_ptr2,const float * src_ptr3,int src_inc,int src_rows,float * packed_ptr,int output_stride)1885 void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
1886                               const float* src_ptr2, const float* src_ptr3,
1887                               int src_inc, int src_rows, float* packed_ptr,
1888                               int output_stride) {
1889   profiler::ScopeLabel label("Pack (kNeon)");
1890   asm volatile(
1891       // clang-format off
1892           // r1 counts rows already loaded; r2 = rows rounded down to a
1893           // multiple of 4.
1894           "mov r1, #0\n"
1895           "and r2, %[rows], #-4\n"
1896           "cmp r1, r2\n"
1897           // No full block of 4 rows? Jump to the leftover-rows code.
1898           "beq 3f\n"
1899 #define RUY_LOAD_FOUR_BY_FOUR()               \
1900   /* Load q0 */                               \
1901   "vld1.32 {d0, d1}, [%[src_ptr0]]\n"         \
1902   /* if src_inc0 != 0, add 16 to src_ptr0 */  \
1903   "and r3, %[src_inc], #1\n"                  \
1904   "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\
1905   /* Load q1 */                               \
1906   "vld1.32 {d2, d3}, [%[src_ptr1]]\n"         \
1907   /* if src_inc1 != 0, add 16 to src_ptr1 */  \
1908   "and r3, %[src_inc], #2\n"                  \
1909   "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\
1910   /* Load q2 */                               \
1911   "vld1.32 {d4, d5}, [%[src_ptr2]]\n"         \
1912   /* if src_inc2 != 0, add 16 to src_ptr2 */  \
1913   "and r3, %[src_inc], #4\n"                  \
1914   "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\
1915   /* Load q3 */                               \
1916   "vld1.32 {d6, d7}, [%[src_ptr3]]\n"         \
1917   /* if src_inc3 != 0, add 16 to src_ptr3 */  \
1918   "and r3, %[src_inc], #8\n"                  \
1919   "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\
1920
1921           RUY_LOAD_FOUR_BY_FOUR()
1922           "add r1, r1, #4\n"
1923           "cmp r1, r2\n"
1924
1925           // Was that the only full 4x4 block? Then skip the main loop.
1926           "beq 2f\n"
1927
1928           // Main loop processing blocks of 4 rows.
1929           "1:\n"
1930           "add r1, r1, #4\n"
1931
1932           // Transpose 4x4 matrix.
1933           "vzip.32 q0, q1\n"
1934           "vzip.32 q2, q3\n"
1935
1936           "vtrn.32 q0, q2\n"
1937           "vtrn.32 q1, q3\n"
1938
1939           "vzip.32 q0, q2\n"
1940           "vzip.32 q1, q3\n"
1941
1942           // Save the transposed block in q8--q11 so q0--q3 are free for the
1943           // next load.
1944           "vmov q8, q0\n"
1945           "vmov q9, q1\n"
1946           "vmov q10, q2\n"
1947           "vmov q11, q3\n"
1948
1949           RUY_LOAD_FOUR_BY_FOUR()
1950 #undef RUY_LOAD_FOUR_BY_FOUR
1951
1952 #define RUY_STORE_FOUR_BY_FOUR()                  \
1953   /* Store q8, q10, q9, q11 */                    \
1954   /* q8 = d16, d17 */                             \
1955   "vst1.32 {d16, d17}, [%[packed_ptr]]\n"         \
1956   /* q10 = d20, d21 */                            \
1957   "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
1958   "vst1.32 {d20, d21}, [%[packed_ptr]]\n"         \
1959   /* q9 = d18, d19 */                             \
1960   "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
1961   "vst1.32 {d18, d19}, [%[packed_ptr]]\n"         \
1962   /* q11 = d22, d23 */                            \
1963   "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
1964   "vst1.32 {d22, d23}, [%[packed_ptr]]\n"         \
1965   "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
1966
1967           RUY_STORE_FOUR_BY_FOUR()
1968           "cmp r1, r2\n"
1969
1970           // End of main loop on 4x4 blocks.
1971           "bne 1b\n"
1972
1973           // Code handling the last already-loaded 4x4 block.
1974           "2:\n"
1975
1976           // Transpose 4x4 matrix.
1977           "vzip.32 q0, q1\n"
1978           "vzip.32 q2, q3\n"
1979
1980           "vtrn.32 q0, q2\n"
1981           "vtrn.32 q1, q3\n"
1982
1983           "vzip.32 q0, q2\n"
1984           "vzip.32 q1, q3\n"
1985
1986           "vmov q8, q0\n"
1987           "vmov q9, q1\n"
1988           "vmov q10, q2\n"
1989           "vmov q11, q3\n"
1990
1991           RUY_STORE_FOUR_BY_FOUR()
1992 #undef RUY_STORE_FOUR_BY_FOUR
1993           "3:\n"
1994
1995           // Handle the remaining (rows % 4) rows, zero-padded to a full
1996           // 4x4 block.
1997           "ands r2, %[rows], #3\n"
1998           "beq 4f\n"
1999           "mov r0, #0\n"
2000           // Zero out q0 - q3
2001           "vdup.32 q0, r0\n"
2002           "vdup.32 q1, r0\n"
2003           "vdup.32 q2, r0\n"
2004           "vdup.32 q3, r0\n"
2005 #define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I)    \
2006   "cmp r2, #" #R "\n"                        \
2007   "beq 5f\n"                                 \
2008   "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \
2009   "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \
2010   "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \
2011   "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n"
2012
2013 #define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I)      \
2014   "cmp r2, #" #R "\n"                        \
2015   "beq 5f\n"                                 \
2016   "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \
2017   "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \
2018   "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \
2019   "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n"
2020
2021           RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0)
2022           RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1)
2023           RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0)
2024           RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1)
2025 #undef RUY_LOAD_ONE_ROW_SECOND_HALF
2026 #undef RUY_LOAD_ONE_ROW_FIRST_HALF
2027           "5:\n"
2028
2029           // Transpose 4x4 matrix.
2030           "vzip.32 q0, q1\n"
2031           "vzip.32 q2, q3\n"
2032
2033           "vtrn.32 q0, q2\n"
2034           "vtrn.32 q1, q3\n"
2035
2036           "vzip.32 q0, q2\n"
2037           "vzip.32 q1, q3\n"
2038
2039           "vmov q8, q0\n"
2040           "vmov q9, q1\n"
2041           "vmov q10, q2\n"
2042           "vmov q11, q3\n"
2043
2044           // NOTE(review): r1 = 32 appears unused below — the stores advance
2045           // packed_ptr by %[stride], not r1. Looks vestigial from the 64-bit
2046           // variant (where x1 is the store post-increment); confirm.
2047           "mov r1, #32\n"
2048
2049 #define RUY_STORE_ONE_ROW(ROW, REGISTER)      \
2050           "cmp r2, #" #ROW "\n"                           \
2051           "beq 4f\n"                                      \
2052           "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n"    \
2053           "add %[packed_ptr], %[packed_ptr], %[stride]\n"
2054
2055           // Store q8
2056           RUY_STORE_ONE_ROW(0, q8)
2057           // Store q10
2058           RUY_STORE_ONE_ROW(1, q10)
2059           // Store q9
2060           RUY_STORE_ONE_ROW(2, q9)
2061           // Store q11
2062           RUY_STORE_ONE_ROW(3, q11)
2063
2064 #undef RUY_STORE_ONE_ROW
2065
2066           "4:\n"
2067
2068       // clang-format on
2069       : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
2070         [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
2071         [ packed_ptr ] "+r"(packed_ptr)
2072       // NOTE(review): src_inc is widened to int64_t for a 32-bit "r"
2073       // operand; presumably %[src_inc] resolves to the low register of the
2074       // allocated pair — confirm this is intended on this target.
2075       : [ src_inc ] "r"(static_cast<std::int64_t>(src_inc)),
2076         [ rows ] "r"(src_rows), [ stride ] "r"(output_stride)
2077       : "cc", "memory", "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
2078         "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
2079 }
2063 
2064 #endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
2065 
2066 #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
PackFloatColMajorForNeonA55ish(const float * src_ptr0,const float * src_ptr1,const float * src_ptr2,const float * src_ptr3,int src_inc0,int src_inc1,int src_inc2,int src_inc3,int src_rows,float * packed_ptr)2067 void PackFloatColMajorForNeonA55ish(const float* src_ptr0,
2068                                     const float* src_ptr1,
2069                                     const float* src_ptr2,
2070                                     const float* src_ptr3, int src_inc0,
2071                                     int src_inc1, int src_inc2, int src_inc3,
2072                                     int src_rows, float* packed_ptr) {
2073   profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)");
2074 
2075   asm volatile(
2076           // clang-format off
2077           "mov w1, #0\n"
2078 
2079           "and w2, %w[rows], #-4\n"
2080           "cmp w1, w2\n"
2081           "beq 3f\n"
2082           "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
2083           "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
2084           "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
2085           "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
2086           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
2087           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
2088           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
2089           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
2090           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
2091           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
2092           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
2093           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
2094           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
2095           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
2096           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
2097           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
2098           "add w1, w1, #4\n"
2099           "cmp w1, w2\n"
2100 
2101           "beq 2f\n"
2102 
2103           "1:\n"
2104           "add w1, w1, #4\n"
2105 
2106           "ldr x10, [%[src_ptr0], #8]\n"
2107           "trn1 v16.4s, v0.4s, v1.4s\n"
2108           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
2109           "ldr x11, [%[src_ptr1], #8]\n"
2110           "trn2 v17.4s, v0.4s, v1.4s\n"
2111           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
2112           "ldr x12, [%[src_ptr2], #8]\n"
2113           "trn1 v18.4s, v2.4s, v3.4s\n"
2114           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
2115           "ldr x13, [%[src_ptr3], #8]\n"
2116           "trn2 v19.4s, v2.4s, v3.4s\n"
2117           RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")
2118 
2119           "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n"
2120           "trn1 v20.2d, v16.2d, v18.2d\n"
2121           "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n"
2122           "trn2 v22.2d, v16.2d, v18.2d\n"
2123           "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n"
2124           "trn1 v21.2d, v17.2d, v19.2d\n"
2125           "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n"
2126           "trn2 v23.2d, v17.2d, v19.2d\n"
2127           "cmp w1, w2\n"
2128 
2129           "ins v0.d[1], x10\n"
2130           "str q20, [%[packed_ptr], #0]\n"
2131           "ins v1.d[1], x11\n"
2132           "str q21, [%[packed_ptr], #32]\n"
2133           "ins v2.d[1], x12\n"
2134           "str q22, [%[packed_ptr], #64]\n"
2135           "ins v3.d[1], x13\n"
2136           "str q23, [%[packed_ptr], #96]\n"
2137 
2138           "add %[packed_ptr], %[packed_ptr], #128\n"
2139 
2140           "bne 1b\n"
2141 
2142           "2:\n"
2143 
2144           "trn1 v16.4s, v0.4s, v1.4s\n"
2145           "trn2 v17.4s, v0.4s, v1.4s\n"
2146           "trn1 v18.4s, v2.4s, v3.4s\n"
2147           "trn2 v19.4s, v2.4s, v3.4s\n"
2148 
2149           "trn1 v20.2d, v16.2d, v18.2d\n"
2150           "trn2 v22.2d, v16.2d, v18.2d\n"
2151           "trn1 v21.2d, v17.2d, v19.2d\n"
2152           "trn2 v23.2d, v17.2d, v19.2d\n"
2153 
2154           "str q20, [%[packed_ptr], #0]\n"
2155           "str q21, [%[packed_ptr], #32]\n"
2156           "str q22, [%[packed_ptr], #64]\n"
2157           "str q23, [%[packed_ptr], #96]\n"
2158           "add %[packed_ptr], %[packed_ptr], #128\n"
2159 
2160           "3:\n"
2161 
2162           "ands w2, %w[rows], #3\n"
2163           "beq 4f\n"
2164           "movi v0.16b, #0\n"
2165           "movi v1.16b, #0\n"
2166           "movi v2.16b, #0\n"
2167           "movi v3.16b, #0\n"
2168 #define RUY_LOAD_ONE_ROW(R)                   \
2169   "cmp w2, #" #R "\n"                         \
2170   "beq 5f\n"                                  \
2171   "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \
2172   "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \
2173   "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \
2174   "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"
2175 
2176           RUY_LOAD_ONE_ROW(0)
2177           RUY_LOAD_ONE_ROW(1)
2178           RUY_LOAD_ONE_ROW(2)
2179           RUY_LOAD_ONE_ROW(3)
2180 #undef RUY_LOAD_ONE_ROW
2181           "5:\n"
2182 
2183           "trn1 v16.4s, v0.4s, v1.4s\n"
2184           "trn2 v17.4s, v0.4s, v1.4s\n"
2185           "trn1 v18.4s, v2.4s, v3.4s\n"
2186           "trn2 v19.4s, v2.4s, v3.4s\n"
2187 
2188           "trn1 v20.2d, v16.2d, v18.2d\n"
2189           "trn2 v22.2d, v16.2d, v18.2d\n"
2190           "trn1 v21.2d, v17.2d, v19.2d\n"
2191           "trn2 v23.2d, v17.2d, v19.2d\n"
2192 
2193           "mov x1, #32\n"
2194 
2195 #define RUY_STORE_ONE_ROW(ROW, REGISTER)                  \
2196           "cmp w2, #" #ROW "\n"                           \
2197           "beq 4f\n"                                      \
2198           "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"
2199 
2200           RUY_STORE_ONE_ROW(0, v20)
2201           RUY_STORE_ONE_ROW(1, v21)
2202           RUY_STORE_ONE_ROW(2, v22)
2203           RUY_STORE_ONE_ROW(3, v23)
2204 
2205 #undef RUY_STORE_ONE_ROW
2206 
2207           "4:\n"
2208 
2209           // clang-format on
2210 
2211           : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2),
2212             [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr)
2213           : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)), [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
2214             [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)), [rows] "r"(src_rows)
2215           : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
2216             "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
2217             "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
2218 }
2219 #endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
2220 
2221 #if RUY_PLATFORM_NEON
2222 
namespace {
// In-place 2x2 transposition helpers built on the ARM TRN1/TRN2 instruction
// family (reached through the vtrn_* intrinsics). This is one place where
// assembly is more idiomatic than intrinsics: transposing groups of values is
// only exposed as a transposition of a wider element type, which forces a
// dance of vreinterpret's --- made worse by the fact that vtrn_* return NEON
// array types such as int8x8x2_t, for which vreinterpret's are not defined.
// These wrappers hide that noise so call sites read like the assembly would.

// Transposes adjacent 8-bit lanes between `a` and `b`, in place.
void transpose_8bit_vals(int8x8_t& a, int8x8_t& b) {
  const int8x8x2_t transposed = vtrn_s8(a, b);
  a = transposed.val[0];
  b = transposed.val[1];
}

// Transposes adjacent 16-bit groups (pairs of 8-bit lanes) between `a` and
// `b`, in place.
void transpose_16bit_vals(int8x8_t& a, int8x8_t& b) {
  const int16x4x2_t transposed =
      vtrn_s16(vreinterpret_s16_s8(a), vreinterpret_s16_s8(b));
  a = vreinterpret_s8_s16(transposed.val[0]);
  b = vreinterpret_s8_s16(transposed.val[1]);
}

// Transposes adjacent 32-bit groups (quadruples of 8-bit lanes) between `a`
// and `b`, in place.
void transpose_32bit_vals(int8x8_t& a, int8x8_t& b) {
  const int32x2x2_t transposed =
      vtrn_s32(vreinterpret_s32_s8(a), vreinterpret_s32_s8(b));
  a = vreinterpret_s8_s32(transposed.val[0]);
  b = vreinterpret_s8_s32(transposed.val[1]);
}
}  // namespace
2251 
Pack8bitRowMajorForNeon(const std::uint8_t * src_ptr,int src_stride,int src_rows,int src_cols,int block_row,int start_col,int end_col,std::int8_t * packed_ptr,int packed_stride,int packed_zero_point,std::int32_t * sums,int input_xor,int kernel_cols)2252 void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
2253                              int src_rows, int src_cols, int block_row,
2254                              int start_col, int end_col,
2255                              std::int8_t* packed_ptr, int packed_stride,
2256                              int packed_zero_point, std::int32_t* sums,
2257                              int input_xor, int kernel_cols) {
2258   profiler::ScopeLabel label("Pack (kNeon, from row-major)");
2259 
2260   int src_end_col = std::min(end_col, src_cols);
2261   int col = start_col;
2262   for (; col <= src_end_col - 8; col += 8) {
2263     // Each iteration of this loop handles 8 columns, and the kernel format
2264     // has 16 rows, so each iteration handles a 16x8 block.
2265     //
2266     // Since the source is row-major, handling 8 columns at a time means
2267     // loading only 8 bytes i.e. 64bit from each row. This may seem surprising
2268     // on 128bit SIMD like NEON. While we could handle 16 columns at a time,
2269     // we prefer to stick with 8 for the following reasons:
2270     // 1. The arithmetic (computing sums and transposing data) done on these
2271     //    values is such that even though we initially start from 64bit vectors,
2272     //    most of our NEON instructions are full 128bit instructions. For the
2273     //    sums computation, that is because summing 8bit values requires
2274     //    expansion to 16bit anyway. For the matrix transposition code, that is
2275     //    because the ARM ZIP instructions take 64bit of data from two input
2276     //    registers and zip it into a 128bit output. If we had 128bit of data
2277     //    in each input registers, we would need 2x more ARM NEON instructions
2278     //    to zip it.
2279     // 2. The main optimization target for this (ARM, 8bit, non-dotprod)
2280     //    code path is in-order ARM cores such as the Cortex-A53, which prefer
2281     //    64bit loads anyway.
2282     // 3. Handling only 8 columns at a time limits the size of the final
2283     //    leftover columns handled with slow scalar code.
2284     //
2285     // This code is not very optimized anyway, as evidenced from the facts that
2286     // (1) it's written in intrinsics, (2) it's not using separate versions
2287     // tuned for different types of CPU cores. At the level of optimization that
2288     // it's working at, this seems like a fair compromise. If one wanted to
2289     // maximize performance at the cost of more code complexity/size, one could
2290     // have code handling 16 columns at a time (maybe limited to
2291     // Tuning::kGeneric), then 8, then 4 to minimize the amount of slow
2292     // leftovers.
2293     //
2294     // Load 8 sums in sums0, sums1.
2295     int32x4_t sums0 = vld1q_s32(sums + col);
2296     int32x4_t sums1 = vld1q_s32(sums + col + 4);
2297     // Load the 8x16 block from the source matrix.
2298     // Each val* here is the data from one row.
2299     int8x8_t val0, val1, val2, val3, val4, val5, val6, val7, val8, val9, val10,
2300         val11, val12, val13, val14, val15;
2301     // Even though this function takes a uint8_t* src_ptr, that's only a
2302     // type-erased pointer (using uint8_t* so that pointer arithmetic is
2303     // allowed). The actual type may be either uint8_t or int8_t. The only
2304     // difference it makes is that if it's uint8_t then we need to flip the
2305     // sign bit. This is specified by the input_xor value (which is 0x80 if the
2306     // input data is uint8_t, and 0x0 otherwise).
2307     auto load_and_convert = [=](const std::uint8_t* from) {
2308       return vreinterpret_s8_u8(veor_u8(vdup_n_u8(input_xor), vld1_u8(from)));
2309     };
2310     if (block_row <= src_rows - 16) {
2311       // Load data in the regular case: there are still 16 rows to be read from
2312       // the source matrix.
2313       val0 = load_and_convert(src_ptr + 0 * src_stride);
2314       val1 = load_and_convert(src_ptr + 1 * src_stride);
2315       val2 = load_and_convert(src_ptr + 2 * src_stride);
2316       val3 = load_and_convert(src_ptr + 3 * src_stride);
2317       val4 = load_and_convert(src_ptr + 4 * src_stride);
2318       val5 = load_and_convert(src_ptr + 5 * src_stride);
2319       val6 = load_and_convert(src_ptr + 6 * src_stride);
2320       val7 = load_and_convert(src_ptr + 7 * src_stride);
2321       val8 = load_and_convert(src_ptr + 8 * src_stride);
2322       val9 = load_and_convert(src_ptr + 9 * src_stride);
2323       val10 = load_and_convert(src_ptr + 10 * src_stride);
2324       val11 = load_and_convert(src_ptr + 11 * src_stride);
2325       val12 = load_and_convert(src_ptr + 12 * src_stride);
2326       val13 = load_and_convert(src_ptr + 13 * src_stride);
2327       val14 = load_and_convert(src_ptr + 14 * src_stride);
2328       val15 = load_and_convert(src_ptr + 15 * src_stride);
2329     } else {
2330       // Boundary case: there are fewer than 16 rows to be read from the source
2331       // matrix. We pad by the zero_point.
2332       val0 = vdup_n_s8(packed_zero_point);
2333       val1 = val0;
2334       val2 = val0;
2335       val3 = val0;
2336       val4 = val0;
2337       val5 = val0;
2338       val6 = val0;
2339       val7 = val0;
2340       val8 = val0;
2341       val9 = val0;
2342       val10 = val0;
2343       val11 = val0;
2344       val12 = val0;
2345       val13 = val0;
2346       val14 = val0;
2347       val15 = val0;
2348       if (block_row + 0 < src_rows)
2349         val0 = load_and_convert(src_ptr + 0 * src_stride);
2350       if (block_row + 1 < src_rows)
2351         val1 = load_and_convert(src_ptr + 1 * src_stride);
2352       if (block_row + 2 < src_rows)
2353         val2 = load_and_convert(src_ptr + 2 * src_stride);
2354       if (block_row + 3 < src_rows)
2355         val3 = load_and_convert(src_ptr + 3 * src_stride);
2356       if (block_row + 4 < src_rows)
2357         val4 = load_and_convert(src_ptr + 4 * src_stride);
2358       if (block_row + 5 < src_rows)
2359         val5 = load_and_convert(src_ptr + 5 * src_stride);
2360       if (block_row + 6 < src_rows)
2361         val6 = load_and_convert(src_ptr + 6 * src_stride);
2362       if (block_row + 7 < src_rows)
2363         val7 = load_and_convert(src_ptr + 7 * src_stride);
2364       if (block_row + 8 < src_rows)
2365         val8 = load_and_convert(src_ptr + 8 * src_stride);
2366       if (block_row + 9 < src_rows)
2367         val9 = load_and_convert(src_ptr + 9 * src_stride);
2368       if (block_row + 10 < src_rows)
2369         val10 = load_and_convert(src_ptr + 10 * src_stride);
2370       if (block_row + 11 < src_rows)
2371         val11 = load_and_convert(src_ptr + 11 * src_stride);
2372       if (block_row + 12 < src_rows)
2373         val12 = load_and_convert(src_ptr + 12 * src_stride);
2374       if (block_row + 13 < src_rows)
2375         val13 = load_and_convert(src_ptr + 13 * src_stride);
2376       if (block_row + 14 < src_rows)
2377         val14 = load_and_convert(src_ptr + 14 * src_stride);
2378       if (block_row + 15 < src_rows)
2379         val15 = load_and_convert(src_ptr + 15 * src_stride);
2380     }
2381     src_ptr += 8;
2382     // Compute sums.
2383     int16x8_t sums16_0 = vaddl_s8(val0, val1);
2384     int16x8_t sums16_1 = vaddl_s8(val2, val3);
2385     sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val4, val5));
2386     sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val6, val7));
2387     sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val8, val9));
2388     sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val10, val11));
2389     sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val12, val13));
2390     sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val14, val15));
2391     int16x8_t sums16 = vaddq_s16(sums16_0, sums16_1);
2392     sums0 = vaddw_s16(sums0, vget_low_s16(sums16));
2393     sums1 = vaddw_s16(sums1, vget_high_s16(sums16));
2394     // Store sums.
2395     vst1q_s32(sums + col, sums0);
2396     vst1q_s32(sums + col + 4, sums1);
2397 
2398     // Transpose the data, i.e. change the storage order of the
2399     // 16x8 block, to convert from the row-major source to the
2400     // column-major packed format.
2401     //
2402     // Before, for i in [0, 15], val<i> is the i-th row.
2403     // After, for i in [0, 7], { val<i> val<i+8> } is the i-th column.
2404     transpose_8bit_vals(val0, val1);
2405     transpose_8bit_vals(val2, val3);
2406     transpose_8bit_vals(val4, val5);
2407     transpose_8bit_vals(val6, val7);
2408     transpose_8bit_vals(val8, val9);
2409     transpose_8bit_vals(val10, val11);
2410     transpose_8bit_vals(val12, val13);
2411     transpose_8bit_vals(val14, val15);
2412     transpose_16bit_vals(val0, val2);
2413     transpose_16bit_vals(val1, val3);
2414     transpose_16bit_vals(val4, val6);
2415     transpose_16bit_vals(val5, val7);
2416     transpose_16bit_vals(val8, val10);
2417     transpose_16bit_vals(val9, val11);
2418     transpose_16bit_vals(val12, val14);
2419     transpose_16bit_vals(val13, val15);
2420     transpose_32bit_vals(val0, val4);
2421     transpose_32bit_vals(val1, val5);
2422     transpose_32bit_vals(val2, val6);
2423     transpose_32bit_vals(val3, val7);
2424     transpose_32bit_vals(val8, val12);
2425     transpose_32bit_vals(val9, val13);
2426     transpose_32bit_vals(val10, val14);
2427     transpose_32bit_vals(val11, val15);
2428     // Store to the packed_matrix.
2429     std::int8_t* dst_ptr = packed_ptr;
2430     vst1q_s8(dst_ptr, vcombine_s8(val0, val8));
2431     vst1q_s8(dst_ptr + 16, vcombine_s8(val1, val9));
2432     dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
2433     vst1q_s8(dst_ptr, vcombine_s8(val2, val10));
2434     vst1q_s8(dst_ptr + 16, vcombine_s8(val3, val11));
2435     packed_ptr += 4 * packed_stride;
2436     dst_ptr = packed_ptr;
2437     vst1q_s8(dst_ptr, vcombine_s8(val4, val12));
2438     vst1q_s8(dst_ptr + 16, vcombine_s8(val5, val13));
2439     dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
2440     vst1q_s8(dst_ptr, vcombine_s8(val6, val14));
2441     vst1q_s8(dst_ptr + 16, vcombine_s8(val7, val15));
2442     packed_ptr += 4 * packed_stride;
2443   }
2444   // Handle remaining columns, not fitting in a full block of 8 columns, but
2445   // still true columns frome the source matrix (as opposed to the final columns
2446   // below).
2447   for (; col < src_end_col; col++) {
2448     std::int32_t accum = 0;
2449     std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
2450     for (int r = 0; r < 16; r++) {
2451       std::int8_t packed_val = (block_row + r < src_rows)
2452                                    ? (src_ptr[r * src_stride] ^ input_xor)
2453                                    : packed_zero_point;
2454       accum += packed_val;
2455       dst_ptr[r] = packed_val;
2456     }
2457     if (sums) {
2458       sums[col] += accum;
2459     }
2460     src_ptr++;
2461     if (((col + 1) & (kernel_cols - 1)) == 0) {
2462       packed_ptr += kernel_cols * packed_stride;
2463     }
2464   }
2465   // Handle the final columns of the packed matrix, beyond the last column of
2466   // the source matrix. The values here don't matter, we just want to avoid
2467   // leaving uninitialized data. Since the sums are already initialized above,
2468   // we don't need to do anything about them here.
2469   for (; col < end_col; col++) {
2470     std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
2471     std::memset(dst_ptr, 0, 16);
2472     if (((col + 1) & (kernel_cols - 1)) == 0) {
2473       packed_ptr += kernel_cols * packed_stride;
2474     }
2475   }
2476 }
2477 
2478 #endif
2479 
2480 }  // namespace ruy
2481