xref: /aosp_15_r20/external/ruy/ruy/kernel_arm32.cc (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "ruy/kernel_arm.h"
17 #include "ruy/opt_set.h"
18 #include "ruy/platform.h"
19 #include "ruy/profiler/instrumentation.h"
20 
21 namespace ruy {
22 
23 #if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
24 
25 #define RUY_ASM_LABEL_STORE_UINT8 91
26 #define RUY_ASM_LABEL_STORE_INT8 92
27 #define RUY_ASM_LABEL_STORE_INT16 93
28 #define RUY_ASM_LABEL_STORE_INT32 94
29 #define RUY_ASM_LABEL_AFTER_STORE 99
30 
31 #define RUY_OFFSET_LHS_BASE_PTR 0
32 #define RUY_OFFSET_RHS_BASE_PTR 4
33 #define RUY_OFFSET_DST_BASE_PTR 8
34 #define RUY_OFFSET_BIAS 12
35 #define RUY_OFFSET_START_ROW 16
36 #define RUY_OFFSET_START_COL 20
37 #define RUY_OFFSET_LAST_ROW 24
38 #define RUY_OFFSET_LAST_COL 28
39 #define RUY_OFFSET_DST_ROWS 32
40 #define RUY_OFFSET_DST_COLS 36
41 #define RUY_OFFSET_LHS_STRIDE 40
42 #define RUY_OFFSET_RHS_STRIDE 44
43 #define RUY_OFFSET_DST_STRIDE 48
44 #define RUY_OFFSET_DEPTH 52
45 #define RUY_OFFSET_CLAMP_MIN 56
46 #define RUY_OFFSET_CLAMP_MAX 60
47 #define RUY_OFFSET_FLAGS 64
48 
49 #define RUY_STACK_OFFSET_SIZE 96
50 #define RUY_STACK_OFFSET_DST_COL_PTR 0
51 #define RUY_STACK_OFFSET_DST_PTR 16
52 #define RUY_STACK_OFFSET_ROW 32
53 #define RUY_STACK_OFFSET_COL 48
54 #define RUY_STACK_OFFSET_LHS_COL_PTR 64
55 #define RUY_STACK_OFFSET_RHS_COL_PTR 80
56 
57 template <typename Params>
CheckOffsetsInKernelParamsFloat32(const Params &)58 void CheckOffsetsInKernelParamsFloat32(const Params&) {
59   static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
60   static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
61   static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
62   static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
63   static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
64   static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
65   static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
66   static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
67   static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, "");
68   static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
69   static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
70   static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
71   static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
72   static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
73   static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
74   static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
75 }
76 
77 // Float kernel for ARM32 out-of-order cores.
78 // Just like Float 64 version, except accumulate in to 8x4 block to only
79 // use 16 128-bit NEON registers. This is a "first pass" kernel and not
80 // tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9.
KernelFloat32Neon(const KernelParamsFloat<8,4> & params)81 void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params) {
82   CheckOffsetsInKernelParamsFloat32(params);
83   profiler::ScopeLabel label("Kernel (kNeon)");
84 
85   const float* lhs_ptr = params.lhs_base_ptr;
86   const float* rhs_ptr = params.rhs_base_ptr;
87   // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are
88   // each composed of two 64-bit "d" registers. The asm kernel below has the
89   // following NEON register allocation:
90   // Registers q3 -- q10 are accumulators. During accumulation,
91   // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1
92   // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block
93   // of RHS, like this:
94 
95   //  Register layout in "q" registers:
96   //                                    RHS 1x4 block
97   //                           /--------------------------|
98   //                           |q2.s[0] ...      q2.s[3]  |
99   //                           \--------------------------/
100   //        LHS 8x1 block
101   //  /---------------------\  /--------------------------|
102   //  |        q0.s[0]      |  | q3.s[0]   ...    q9.s[0] |
103   //  |         ...         |  |  ...               ...   |
104   //  |        q0.s[3]      |  | q3.s[3]          q9.s[3] |
105   //  |        q1.s[0]      |  | q4.s[0]         q10.s[0] |
106   //  |         ...         |  |  ...      ...      ...   |
107   //  |        q1.s[3]      |  | q4.s[3]   ..    q10.s[3] |
108   //  \---------------------/  \--------------------------/
109   //                             accumulators 8x4 block
110   // q11, q14, q15 currently unused. q12 and q13 are used to load
111   // parameters used for the post-accumulation part of the kernel.
112   // For completeness, here is the register layout in "d" registers:
113   //                                    RHS 1x4 block
114   //                           /--------------------------|
115   //                           |d4[0]     ...       d5[1] |
116   //                           \--------------------------/
117   //        LHS 8x1 block
118   //  /---------------------\  /--------------------------|
119   //  |        d0[0]        |  | d6[0]    ...      d18[0] |
120   //  |         ...         |  |  ...               ...   |
121   //  |        d1[1]        |  | d7[1]             d19[1] |
122   //  |        d2[0]        |  | d8[0]             d20[0] |
123   //  |         ...         |  |  ...      ...      ...   |
124   //  |        d3[1]        |  | d9[1]     ...     d21[1] |
125   //  \---------------------/  \--------------------------/
126   //                             accumulators 8x4 block
127   asm volatile(
128 #define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n"
129 
130         // clang-format off
131 
132         // Load the first 32 bytes of LHS and RHS data.
133         // Load q0, q1
134         "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n"
135         "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
136         RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
137         // Load q2
138         "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
139         RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
140 
141         "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
142 
143         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
144         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
145 
146         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
147         "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
148 
149         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
150         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
151 
152         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
153         "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
154 
155         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
156         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
157 
158         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
159         "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
160         // Clear accumulators.
161         RUY_MAKE_ZERO(q3)
162         RUY_MAKE_ZERO(q4)
163         RUY_MAKE_ZERO(q5)
164         RUY_MAKE_ZERO(q6)
165         RUY_MAKE_ZERO(q7)
166         RUY_MAKE_ZERO(q8)
167         RUY_MAKE_ZERO(q9)
168         RUY_MAKE_ZERO(q10)
169 
170         // r1 is the number of levels of depth that we have already loaded
171         // LHS and RHS data for. Corresponding to the initial ld1 instructions
172         // above, this is currently 1.
173         "mov r1, #1\n"
174 
175         // Main loop of the whole GEMM, over rows and columns of the
176         // destination matrix.
177         "1:\n"
178 
179         // Accumulation loop
180         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
181         "cmp r1, r2\n"
182         "beq 79f\n"
183 
184         "2:\n"
185 
186         "vmla.f32 q3, q0, d4[0]\n"
187         "vmla.f32 q5, q0, d4[1]\n"
188         "vmla.f32 q7, q0, d5[0]\n"
189         "vmla.f32 q9, q0, d5[1]\n"
190         "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
191 
192         "vmla.f32 q4, q1, d4[0]\n"
193         "vmla.f32 q6, q1, d4[1]\n"
194         "vmla.f32 q8, q1, d5[0]\n"
195         "vmla.f32 q10, q1, d5[1]\n"
196         "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
197         RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
198         "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS
199         RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
200 
201         "add r1, r1, #1\n"
202         "cmp r1, r2\n"
203 
204         "blt 2b\n"
205 
206         "79:\n"
207 
208         // End of the inner loop on depth. Now perform the remaining
209         // multiply-adds of the last level of depth, for which the LHS
210         // and RHS data is already loaded.
211 
212         "vmla.f32 q3, q0, d4[0]\n"
213         "vmla.f32 q5, q0, d4[1]\n"
214         "vmla.f32 q7, q0, d5[0]\n"
215         "vmla.f32 q9, q0, d5[1]\n"
216 
217         "vmla.f32 q4, q1, d4[0]\n"
218         "vmla.f32 q6, q1, d4[1]\n"
219         "vmla.f32 q8, q1, d5[0]\n"
220         "vmla.f32 q10, q1, d5[1]\n"
221 
222         // End of accumulation. The registers q3 -- q10 contain the final
223         // float32 accumulator values of the current 8x8 destination block.
224         // We now have to compute the final values from these accumulators
225         // and advance to the next 8x8 block. We intertwine
226         // these two aspects whenever possible for optimal pipelining, both
227         // at the data flow level (prefetch data for next block as early as
228         // possible) and instruction pipelining level (some of the next-block
229         // work can dual-issue with some of the final work on the current
230         // block).
231 
232         // Logic to advance to the next block in preparation for the next
233         // iteration of the main loop. For now, we only want to compute
234         // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
235         // not yet ready to update the values of row and col, as we still need
236         // the current values for the rest of the work on the current block.
237 
238         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
239         "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
240         "cmp r1, r3\n"  // Have we finished the last row?
241 
242         "bge 4f\n"      // If finished last row, go to 4
243         // Not finished last row: then advance to next row.
244         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
245         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
246         "add r4, r4, r1, lsl #3\n"
247         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
248         "b 5f\n"
249         "4:\n"  // Finished last row...
250         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
251         // Go back to first row
252         "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
253         // Now we need to advance to the next column. If we already
254         // finished the last column, then in principle we are done, however
255         // we can't just return here, as we need to allow the end work of the
256         // current block to complete. The good news is that at this point it
257         // doesn't matter what data we load for the next column, since
258         // we will exit from the main loop below before actually storing
259         // anything computed from that data.
260         "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
261         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
262         "cmp r8, r4\n"  // Have we finished the last column?
263         "bge 5f\n" // If yes, just carry on without updating the column pointer.
264         // Not finished last column: then advance to next column.
265         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
266         "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
267         "add r10, r10, r1, lsl #2\n"
268         "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
269         "5:\n"
270 
271         // Set the LHS and RHS data pointers to the start of the columns just
272         // computed.
273         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
274         "mov %[lhs_ptr], r4\n"
275         "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
276         "mov %[rhs_ptr], r5\n"
277 
278         // Load some parameters needed for the end work on current block.
279         "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
280         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
281 
282         // Let r8 be stack offset of the row or column variable, whichever
283         // is the channel index.
284         "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
285         "bne 1000f\n"
286         "mov r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
287         "b 1001f\n"
288         "1000:\n"
289         "mov r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
290         "1001:\n"
291         // Let r8 be the channel index.
292         "ldr r8, [sp, r8]\n"
293         // Compute the bias pointer, by conditionally using the channel index
294         // (r8) as offset into bias buffer (r1).
295         "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
296         "beq 1002f\n"
297         "add r1, r1, r8, lsl #2\n"
298         "1002:\n"
299 
300         // Load 4 bias values. When the channel dimension is rows, we will load
301         // another 4 bias values just before performing the bias addition below,
302         // as this kernel has a 8x4 rectangular shape.
303         "vld1.32 {d24, d25}, [r1]!\n"
304 
305         // Now that we know what LHS and RHS data the next iteration of the
306         // main loop will need to load, we start loading the first 32 bytes of
307         // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore
308         // in the rest of the work on the current block.
309         // Load q0, q1
310         "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
311         RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
312         // Load q2
313         "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
314         RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
315 
316         // Perform the bias-addition.
317         // Jump based on channel dimension.
318         "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
319         "bne 6f\n"
320         // Case where channels are rows.
321         // Load the remaining 4 bias values, since we're on the width-8 side
322         // of this 8x4 kernel.
323         "vld1.32 {d26, d27}, [r1]\n"
324         "vadd.f32 q3, q3, q12\n"
325         "vadd.f32 q5, q5, q12\n"
326         "vadd.f32 q7, q7, q12\n"
327         "vadd.f32 q9, q9, q12\n"
328         "vadd.f32 q4, q4, q13\n"
329         "vadd.f32 q6, q6, q13\n"
330         "vadd.f32 q8, q8, q13\n"
331         "vadd.f32 q10, q10, q13\n"
332         "b 7f\n"
333 
334         "6:\n"
335         // Case where channels are columns.
336         "vdup.32 q11, d24[0]\n"
337         "vdup.32 q13, d24[1]\n"
338         "vdup.32 q14, d25[0]\n"
339         "vdup.32 q15, d25[1]\n"
340         "vadd.f32 q3, q3, q11\n"
341         "vadd.f32 q4, q4, q11\n"
342         "vadd.f32 q5, q5, q13\n"
343         "vadd.f32 q6, q6, q13\n"
344         "vadd.f32 q7, q7, q14\n"
345         "vadd.f32 q8, q8, q14\n"
346         "vadd.f32 q9, q9, q15\n"
347         "vadd.f32 q10, q10, q15\n"
348         "7:\n"
349 
350         // Load the clamp_min, clamp_max bounds
351         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
352         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
353         "vdup.32 q12, r2\n"  // clamp_min
354         "vdup.32 q13, r3\n"  // clamp_max
355 
356         // Apply the clamp_min bound
357         "vmax.f32 q3, q3, q12\n"
358         "vmax.f32 q4, q4, q12\n"
359         "vmax.f32 q5, q5, q12\n"
360         "vmax.f32 q6, q6, q12\n"
361         "vmax.f32 q7, q7, q12\n"
362         "vmax.f32 q8, q8, q12\n"
363         "vmax.f32 q9, q9, q12\n"
364         "vmax.f32 q10, q10, q12\n"
365 
366         // Apply the clamp_max bound
367         "vmin.f32 q3, q3, q13\n"
368         "vmin.f32 q4, q4, q13\n"
369         "vmin.f32 q5, q5, q13\n"
370         "vmin.f32 q6, q6, q13\n"
371         "vmin.f32 q7, q7, q13\n"
372         "vmin.f32 q8, q8, q13\n"
373         "vmin.f32 q9, q9, q13\n"
374         "vmin.f32 q10, q10, q13\n"
375 
376         // Compute how much of the 8x4 block of destination values that
377         // we have computed, fit in the destination matrix. Typically, all of
378         // it fits, but when the destination matrix shape is not a multiple
379         // of 8x4, there are some 8x8 blocks along the boundaries that do
380         // not fit entirely.
381         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
382         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
383         "sub r1, r1, r8\n"
384 
385         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
386         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
387         "sub r2, r2, r4\n"
388         "mov r3, #8\n"
389         "mov r5, #4\n"
390         "cmp r1, #8\n"
391         // Compute r1 = how many rows of the 8x4 block fit
392         "it gt\n"
393         "movgt r1, r3\n"
394         "cmp r2, #4\n"
395         // Compute r2 = how many cols of the 8x4 block fit
396         "it gt\n"
397         "movgt r2, r5\n"
398 
399         // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits.
400         "cmp r1, r3\n"
401         "it eq\n"
402         "cmpeq r2, r5\n"
403         // Yes, all of the 8x4 block fits, go to fast path.
404         "beq 30f\n"
405         // Not all of the 8x4 block fits.
406         // Set (r3 address, r4 stride) to write to dst_tmp_buf
407         "mov r3, %[dst_tmp_buf]\n"
408         "mov r4, #32\n"
409         "b 31f\n"
410         "30:\n"
411         // Yes, all of the 8x4 block fits.
412         // Set (r3 address, r4 stride) to write directly to destination matrix.
413         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
414         "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
415         "mov r4, r5\n"
416         "31:\n"
417 
418         // Write our float values to the destination described by
419         // (r3 address, r4 stride)
420         "vst1.32 {d6, d7, d8, d9}, [r3]\n"
421         "add r3, r3, r4\n"
422         RUY_MAKE_ZERO(q3)
423         RUY_MAKE_ZERO(q4)
424         "vst1.32 {d10, d11, d12, d13}, [r3]\n"
425         "add r3, r3, r4\n"
426         RUY_MAKE_ZERO(q5)
427         RUY_MAKE_ZERO(q6)
428         "vst1.32 {d14, d15, d16, d17}, [r3]\n"
429         "add r3, r3, r4\n"
430         RUY_MAKE_ZERO(q7)
431         RUY_MAKE_ZERO(q8)
432         "vst1.32 {d18, d19, d20, d21}, [r3]\n"
433         "add r3, r3, r4\n"
434         RUY_MAKE_ZERO(q9)
435         RUY_MAKE_ZERO(q10)
436 
437         // If all of the 8x4 block fits, we just finished writing it to the
438         // destination, so we skip the next part.
439         "beq 41f\n"
440         // Not all of the 8x8 block fits in the destination matrix.  We just
441         // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
442         // it to copy into the destination matrix the part that fits.
443         "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
444         "mov r3, %[dst_tmp_buf]\n"
445         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
446         "mov r6, #0\n"
447         "50:\n"
448         "mov r5, #0\n"
449         "51:\n"
450         "ldr r10, [r3, r5, lsl #2]\n"
451         "str r10, [r4, r5, lsl #2]\n"
452         "add r5, r5, #1\n"
453         "cmp r5, r1\n"
454         "blt 51b\n"
455         "add r6, r6, #1\n"
456         "add r3, r3, #32\n"
457         "add r4, r4, r8\n"
458         // r2 = how many cols of the 8x4 block fit
459         "cmp r6, r2\n"
460         "blt 50b\n"
461         "41:\n"
462         // Load dst_ptr, increment, and write back.
463         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
464         "add r4, r4, #32\n"
465         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
466         // At this point we have completely finished writing values to the
467         // destination matrix for the current block.
468 
469         // Reload some params --- we had used r3, r5, r10 for a few other things
470         // since the last time we had loaded them.
471         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
472         "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
473         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
474 
475         // Move to the next block of the destination matrix, for the next iter
476         // of the main loop.  Notice that lhs_col_ptr, rhs_col_ptr have already
477         // been updated earlier.
478         // Have we reached the end row?
479         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
480         "cmp r8, r3\n"
481 
482         "beq 20f\n"  // yes, end row.
483         // Not end row. Move to the next row.
484         "add r8, r8, #8\n"
485         // Store new value of row
486         "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
487 
488         "b 21f\n"
489         "20:\n"
490         // Was already at end row.
491         // Move back to first row.
492         "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
493         // Move to the next column.
494         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
495         "add r4, r4, #4\n"
496         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
497 
498         "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
499         "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
500         // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns)
501         "add r1, r1, r8, lsl #2\n"
502         // Store dst_col_ptr
503         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
504         // Store dst_ptr
505         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
506         "21:\n"
507 
508         // Main loop exit condition: have we hit the end column?
509         "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
510         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
511         "cmp r8, r4\n"
512 
513         // r1 is the number of levels of depth that we have already loaded
514         // LHS and RHS data for. Corresponding to the initial ld1 instructions
515         // above, this is currently 1.
516         "mov r1, #1\n"
517 
518         "ble 1b\n"
519 
520         // Restore stack pointer.
521         "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
522 
523         // clang-format on
524         : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
525         : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
526         // Clobber list must specify q registers (and not their constituent
527         // d registers). There is a (currently unexplained) slowdown if
528         // d registers are listed in the clobbers list.
529         : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
530           "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
531           "q9", "q10", "q12", "q13");
532 }
533 
534 #undef RUY_MAKE_ZERO
535 #undef RUY_STACK_OFFSET_SIZE
536 #undef RUY_STACK_OFFSET_DST_COL_PTR
537 #undef RUY_STACK_OFFSET_DST_PTR
538 #undef RUY_STACK_OFFSET_ROW
539 #undef RUY_STACK_OFFSET_COL
540 #undef RUY_STACK_OFFSET_LHS_COL_PTR
541 #undef RUY_STACK_OFFSET_RHS_COL_PTR
542 
543 #undef RUY_OFFSET_LHS_BASE_PTR
544 #undef RUY_OFFSET_RHS_BASE_PTR
545 #undef RUY_OFFSET_DST_BASE_PTR
546 #undef RUY_OFFSET_BIAS
547 #undef RUY_OFFSET_START_ROW
548 #undef RUY_OFFSET_START_COL
549 #undef RUY_OFFSET_LAST_ROW
550 #undef RUY_OFFSET_LAST_COL
551 #undef RUY_OFFSET_DST_ROWS
552 #undef RUY_OFFSET_DST_COLS
553 #undef RUY_OFFSET_LHS_STRIDE
554 #undef RUY_OFFSET_RHS_STRIDE
555 #undef RUY_OFFSET_DST_STRIDE
556 #undef RUY_OFFSET_DEPTH
557 #undef RUY_OFFSET_CLAMP_MIN
558 #undef RUY_OFFSET_CLAMP_MAX
559 #undef RUY_OFFSET_FLAGS
560 
561 #define RUY_OFFSET_BIAS 0
562 #define RUY_OFFSET_LHS_SUMS 4
563 #define RUY_OFFSET_RHS_SUMS 8
564 #define RUY_OFFSET_LHS_BASE_PTR 12
565 #define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16
566 #define RUY_OFFSET_MULTIPLIER_EXPONENT 20
567 #define RUY_OFFSET_RHS_BASE_PTR 24
568 #define RUY_OFFSET_DST_BASE_PTR 28
569 #define RUY_OFFSET_LHS_ZERO_POINT 32
570 #define RUY_OFFSET_RHS_ZERO_POINT 36
571 #define RUY_OFFSET_DST_ZERO_POINT 40
572 #define RUY_OFFSET_PROD_ZP_DEPTH 44
573 #define RUY_OFFSET_START_ROW 48
574 #define RUY_OFFSET_START_COL 52
575 #define RUY_OFFSET_LAST_ROW 56
576 #define RUY_OFFSET_LAST_COL 60
577 #define RUY_OFFSET_DST_ROWS 64
578 #define RUY_OFFSET_DST_COLS 68
579 #define RUY_OFFSET_LHS_STRIDE 72
580 #define RUY_OFFSET_RHS_STRIDE 76
581 #define RUY_OFFSET_DST_STRIDE 80
582 #define RUY_OFFSET_DEPTH 84
583 #define RUY_OFFSET_CLAMP_MIN 88
584 #define RUY_OFFSET_CLAMP_MAX 92
585 #define RUY_OFFSET_FLAGS 96
586 #define RUY_OFFSET_DST_TYPE_ID 97
587 
588 #define RUY_STACK_OFFSET_SIZE 96
589 #define RUY_STACK_OFFSET_DST_COL_PTR 0
590 #define RUY_STACK_OFFSET_DST_PTR 16
591 #define RUY_STACK_OFFSET_ROW 32
592 #define RUY_STACK_OFFSET_COL 48
593 #define RUY_STACK_OFFSET_LHS_COL_PTR 64
594 #define RUY_STACK_OFFSET_RHS_COL_PTR 80
595 
596 template <typename Params>
CheckOffsetsInKernelParams8bit(const Params &)597 void CheckOffsetsInKernelParams8bit(const Params&) {
598   static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
599                 "");
600   static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
601                 "");
602   static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
603                 "");
604   static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
605                 "");
606   static_assert(offsetof(Params, multiplier_fixedpoint) ==
607                     RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
608                 "");
609   static_assert(
610       offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
611       "");
612   static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
613   static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
614   static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
615   static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
616   static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
617   static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
618   static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
619   static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
620   static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
621   static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
622   static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
623   static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
624   static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
625   static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
626 }
627 
628 // Fast-int8 kernel, ported from ARM 64 version.
629 // Relevant target CPUs for this kernel include Krait 400 and A9,
630 // since these are 32-bit, out-of-order CPUs.
Kernel8bitNeon(const KernelParams8bit<4,2> & params)631 void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) {
632   profiler::ScopeLabel label("Kernel (kNeon)");
633 
634   CheckOffsetsInKernelParams8bit(params);
635 
636   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
637   const std::int8_t* rhs_col_ptr =
638       static_cast<const int8_t*>(params.rhs_base_ptr);
639   const std::int8_t* lhs_ptr = lhs_col_ptr;
640   const std::int8_t* rhs_ptr = rhs_col_ptr;
641 
642   // The asm kernel below has the following NEON register allocation:
643   //
644   // q6 - q13 are 128-bit (4x32b) accumulators.
645   // During accumulation, d0 -- d7 are used to load int8 data from LHS and
646   // d8 -- d11 from RHS:
647   //                                      int8 RHS 16x2 block
648   //                              /-----------------------------|
649   //                              |d8.b[0-7]   .....  d10.b[0-7]|
650   //                              |  ...                  ...   |
651   //                              |d9.b[0-7]   .....  d11.b[0-7]|
652   //                              \-----------------------------/
653   //    int8 LHS 4x16 block
654   //  /------------------------\  /-----------------------------|
655   //  |d0.b[0-7] ... d1.b[0-7] |  | q6         .....      q10   |
656   //  |d2.b[0-7] ... d3.b[0-7] |  | q7         .....      q11   |
657   //  (Reload d0, d1, d2, d3)
658   //  |d0.b[0-7] ... d1.b[0-7] |  | q8         .....      q12   |
659   //  |d2.b[0-7] ... d3.b[0-7] |  | q9         .....      q13   |
660   //  \------------------------/  \-----------------------------/
661   //                                128-bit accumulators 4x2 block
662   //
663   // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
664   // optimization for this kernel.
665   asm volatile(
666 #define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
667 
668         // clang-format off
669 
670         // Load the first 64 bytes of LHS and RHS data.
671         "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
672         // Clear accumulators.
673         RUY_MAKE_ZERO(q6)
674         RUY_MAKE_ZERO(q7)
675         "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
676         RUY_MAKE_ZERO(q8)
677         RUY_MAKE_ZERO(q9)
678         RUY_MAKE_ZERO(q10)
679         "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
680         RUY_MAKE_ZERO(q11)
681         "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n"
682 
683         "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
684 
685         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
686         RUY_MAKE_ZERO(q12)
687         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
688 
689         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
690         RUY_MAKE_ZERO(q13)
691         "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
692 
693         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
694         RUY_MAKE_ZERO(q14)
695         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
696 
697         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
698         RUY_MAKE_ZERO(q15)
699         "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
700 
701         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
702         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
703 
704         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
705         "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
706 
707 
708         // r1 is the number of levels of depth that we have already loaded
709         // LHS and RHS data for. Corresponding to the initial ld1 instructions
710         // above, this is currently 16.
711         "mov r1, #16\n"
712 
713         // Main loop of the whole GEMM, over rows and columns of the
714         // destination matrix.
715         "1:\n"
716 
717         // r1 is how many levels of depth we have already loaded
718         // data for, r10 is the total depth.
719         "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
720         "cmp r1, r10\n"
721         "beq 79f\n"
722 
723         "2:\n"
724 
725         // Mult, mult-acc in to q14, q15, q2, q3
726         "vmull.s8 q14, d0, d8\n"
727         "vmull.s8 q2, d0, d10\n"
728 
729         "vmull.s8 q15, d2, d8\n"
730         "vmull.s8 q3, d2, d10\n"
731 
732         "vmlal.s8 q14, d1, d9\n"
733         "vmlal.s8 q2, d1, d11\n"
734         "vmlal.s8 q15, d3, d9\n"
735         "vmlal.s8 q3, d3, d11\n"
736         "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
737 
738         // Then pairwise accumulate in to q6, q7, q10, q11
739         "vpadal.s16 q6, q14\n"
740         "vpadal.s16 q7, q15\n"
741         "vpadal.s16 q10, q2\n"
742         "vpadal.s16 q11, q3\n"
743 
744         // Mult, mult-acc in to q14, q15, q2, q3
745         "vmull.s8 q14, d0, d8\n"
746         "vmull.s8 q2, d0, d10\n"
747 
748         "vmull.s8 q15, d2, d8\n"
749         "vmull.s8 q3, d2, d10\n"
750 
751         "vmlal.s8 q14, d1, d9\n"
752         "vmlal.s8 q2, d1, d11\n"
753         "vmlal.s8 q15, d3, d9\n"
754         "vmlal.s8 q3, d3, d11\n"
755         "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
756 
757         // Then pairwise accumulate in to q8, q9, q12, q13
758         "vpadal.s16 q8, q14\n"
759         "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
760         "vpadal.s16 q9, q15\n"
761         "vpadal.s16 q12, q2\n"
762         "vpadal.s16 q13, q3\n"
763 
764         // Prefetch the next 64 bytes of LHS and RHS data.
765         RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
766         RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
767 
768         // Each iteration of this loop advances by 16 levels of depth.
769         "add r1, r1, #16\n"
770 
771         // Loop termination condition
772         "cmp r1, r10\n"
773 
774         "blt 2b\n"
775 
776         "79:\n"
777 
778         // Mult, mult-acc in to q14, q15, q2, q3
779         "vmull.s8 q14, d0, d8\n"
780         "vmull.s8 q2, d0, d10\n"
781 
782         "vmull.s8 q15, d2, d8\n"
783         "vmull.s8 q3, d2, d10\n"
784 
785         "vmlal.s8 q14, d1, d9\n"
786         "vmlal.s8 q2, d1, d11\n"
787         "vmlal.s8 q15, d3, d9\n"
788         "vmlal.s8 q3, d3, d11\n"
789         "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
790 
791         // Then pairwise accumulate in to q6, q7, q10, q11
792         "vpadal.s16 q6, q14\n"
793         "vpadal.s16 q7, q15\n"
794         "vpadal.s16 q10, q2\n"
795         "vpadal.s16 q11, q3\n"
796 
797         // Mult, mult-acc in to q14, q15, q2, q3
798         "vmull.s8 q14, d0, d8\n"
799         "vmull.s8 q2, d0, d10\n"
800 
801         "vmull.s8 q15, d2, d8\n"
802         "vmull.s8 q3, d2, d10\n"
803 
804         "vmlal.s8 q14, d1, d9\n"
805         "vmlal.s8 q2, d1, d11\n"
806         "vmlal.s8 q15, d3, d9\n"
807         "vmlal.s8 q3, d3, d11\n"
808 
809         // Then pairwise accumulate in to q8, q9, q12, q13
810         "vpadal.s16 q8, q14\n"
811         "vpadal.s16 q9, q15\n"
812         "vpadal.s16 q12, q2\n"
813         "vpadal.s16 q13, q3\n"
814 
815 
816         // All accumulation over depth done. q6 - q13 contain the 4x32b
817         // accumulators for the 4x2 final matrix.
818         // We now have to compute the final 8-bit values from these int32
819         // accumulators, and advance to the next 4x2 block. We intertwine
820         // these two aspects whenever possible for optimal pipelining, both
821         // at the data flow level (prefetch data for next block as early as
822         // possible) and instruction pipelining level (some of the next-block
823         // work can dual-issue with some of the final work on the current
824         // block).
825 
826         // q6-q13 now contain 4 x 32b
827         "vpadd.i32 d0, d12, d13\n"
828         "vpadd.i32 d1, d14, d15\n"
829         "vpadd.i32 d2, d16, d17\n"
830         "vpadd.i32 d3, d18, d19\n"
831         "vpadd.i32 d4, d20, d21\n"
832         "vpadd.i32 d5, d22, d23\n"
833         "vpadd.i32 d6, d24, d25\n"
834         "vpadd.i32 d7, d26, d27\n"
835 
836         // d0-d7 each contain 2 x 32b accumulators.
837         // Need to add pairwise to get 1 x 32b for each of the 4x2 entries
838         // of destination, (Four 'd' registers total)
839         "vpadd.i32 d28, d0, d1\n"
840         "vpadd.i32 d29, d2, d3\n"
841         "vpadd.i32 d30, d4, d5\n"
842         "vpadd.i32 d31, d6, d7\n"
843 
844         // Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries
845 
846         // Logic to advance to the next block in preparation for the next
847         // iteration of the main loop. For now, we only want to compute
848         // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
849         // not yet ready to update the values of row and col, as we still need
850         // the current values for the rest of the work on the current block.
851 
852         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
853         "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
854         "cmp r1, r3\n"  // Have we finished the last row?
855 
856         "bge 4f\n"           // If finished last row, go to 4
857         // Not finished last row: then advance to next row.
858         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
859         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
860         "add r4, r4, r1, lsl #2\n"
861         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
862         "b 5f\n"
863         "4:\n"  // Finished last row...
864         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
865         // Go back to first row
866         "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
867 
868         // Now we need to advance to the next column. If we already
869         // finished the last column, then in principle we are done, however
870         // we can't just return here, as we need to allow the end work of the
871         // current block to complete. The good news is that at this point it
872         // doesn't matter what data we load for the next column, since
873         // we will exit from the main loop below before actually storing
874         // anything computed from that data.
875 
876         "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
877         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
878         "cmp r8, r4\n"  // Have we finished the last column?
879         "bge 5f\n" // If yes, just carry on without updating the column pointer.
880         // Not finished last column: then advance to next column.
881         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
882         "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
883         "add r10, r10, r1, lsl #1\n"
884         "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
885         "5:\n"
886 
887         // Set the LHS and RHS data pointers to the start of the columns just
888         // computed.
889         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
890         "mov %[lhs_ptr], r4\n"
891         "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
892         "mov %[rhs_ptr], r5\n"
893 
894         // Now we load: bias data, LHS sums data, RHS sums data.
895 
896         // First, load the base pointers from the params.
897         "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
898         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
899 
900         // Let r8 be stack offset of the row or column variable, whichever
901         // is the channel index.
902         "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
903         "bne 1000f\n"
904         "mov r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
905         "b 1001f\n"
906         "1000:\n"
907         "mov r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
908         "1001:\n"
909 
910         // Let r8 be the channel index.
911         "ldr r8, [sp, r8]\n"
912         // Compute the bias pointer, by conditionally using the channel index
913         // (r8) as offset into bias buffer (r1).
914         "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
915         "beq 1002f\n"
916         "add r1, r1, r8, lsl #2\n"
917         "1002:\n"
918 
919         // Load 2 bias values. When the channel dimension is rows, we will load
920         // another 2 bias values just before performing the bias addition below,
921         // as this kernel has a 4x2 rectangular shape.
922         "vld1.32 {d24}, [r1]!\n"
923 
924         // Now that we know what LHS and RHS data the next iteration of the
925         // main loop will need to load, we start loading the first 32 bytes of
926         // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
927         // in the rest of the work on the current block.
928         "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
929         RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
930         "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
931         RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
932 
933         // Add to the bias values the product
934         // (depth * lhs_zero_point * rhs_zero_point),
935         // See the term NZ1Z2 in equation (7) in
936         // https://arxiv.org/pdf/1712.05877.pdf
937         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
938         "vdup.32 q9, r3\n"
939         "vadd.i32 d24, d24, d18\n"
940 
941         // Perform the bias-addition (per the above, we have just folded into
942         // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
943         // Jump based on channel dimension.
944         "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
945         "bne 6f\n"
946         // Case where channels are rows.
947         // Load the remaining 2 bias values, since we're on the width-4 side
948         // of this 4x2 kernel.
949         "vld1.32 {d25}, [r1]\n"
950         "vadd.i32 d25, d25, d19\n"
951         "vadd.i32 q14, q14, q12\n"
952         "vadd.i32 q15, q15, q12\n"
953         "b 7f\n"
954 
955         "6:\n"
956         // Case where channels are columns.
957         "vdup.32 q10, d24[0]\n"
958         "vdup.32 q11, d24[1]\n"
959         "vadd.i32 q14, q14, q10\n"
960         "vadd.i32 q15, q15, q11\n"
961         "7:\n"
962 
963         // LHS/RHS zero points
964         // Has RHS sums
965         "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
966         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
967         "beq 401f\n"
968         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
969         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
970         // Offset by current col * number of bytes per value
971         "add r3, r3, r4, lsl #2\n"
972         "vld1.32 { d12 }, [r3]\n"
973         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
974         "vdup.32 q10, r5\n"  // create lhs_zero_point_vec
975         // Subtract rhs_sums * lhs_zero_point, per
976         // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
977         "vmls.i32 q14, q10, d12[0]\n"
978         "vmls.i32 q15, q10, d12[1]\n"
979         "401:\n"
980 
981         // Has LHS sums
982         "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
983         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
984         "beq 402f\n"
985         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
986         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
987         // Offset by current row * number of bytes per value
988         "add r2, r2, r4, lsl #2\n"
989         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
990 
991         // Load 4 lhs_sums values.
992         "vld1.32 {d22, d23}, [r2]\n"
993         "vdup.32 d13, r5\n" // rhs_zero_point
994 
995         // Compute lhs_sums * rhs_zero_point.
996         "vmul.i32 q11, q11, d13[1]\n"
997         // Subtract lhs_sums * rhs_zero_point, per
998         // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
999         "vsub.s32 q14, q14, q11\n"
1000         "vsub.s32 q15, q15, q11\n"
1001 
1002         // If the destination is int32, it means the user asks for the raw
1003         // accumulators, no need for us to downquantize the value.
1004         "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1005         "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
1006         "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
1007 
1008         "402:\n"
1009 
1010         // At this point we have computed the final int32 values. Now we
1011         // start down-quantizing them to obtain the final 8bit values from them.
1012 
1013         // As part of this down-quantization, our int32 values will be
1014         // multiplied by a multiplier that has a fixed-point component and an
1015         // exponent component.
1016 
1017         // Compute the data pointers for the multiplier data
1018         //   r1 = exponent part
1019         //   r2 = fixedpoint part
1020         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
1021         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
1022         // r6 has flags, r8 has channel index
1023         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1024         "beq 1003f\n"
1025         "add r1, r1, r8, lsl #2\n"
1026         "add r2, r2, r8, lsl #2\n"
1027         "1003:\n"
1028 
1029         // Load the first 2 values of multiplier exponent and fixedpoint data
1030         // Since this kernel is rectangular 4x2, we will only conditionally load
1031         // 2 more values below.
1032         "vld1.32 {d20}, [r1]!\n"  // 2 values of multiplier_exponent
1033         "vld1.32 {d12}, [r2]!\n"  // 2 values of multiplier_fixedpoint
1034 
1035         "tst r6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
1036         "vmvn.i32 q8, #0\n"
1037         "bne 8f\n"
1038         // Case where channels are rows.
1039         // Load the remaining 2 bias values, since we're on the width-4 side
1040         // of this 4x2 kernel.
1041         "vld1.32 {d21}, [r1]\n"  // 2 more values of multiplier_exponent
1042         "vld1.32 {d13}, [r2]\n"  // 2 more values of multiplier_fixedpoint
1043         "vmin.s32 q11, q10, q8\n"
1044         "vsub.s32 q10, q10, q11\n"
1045 
1046         // Apply the positive exponent part of the multiplier.
1047         "vshl.s32 q14, q14, q10\n"
1048         "vshl.s32 q15, q15, q10\n"
1049 
1050         // Apply the fixed-point part of the multiplier.
1051         "vqdmulh.s32 q14, q14, q6\n"
1052         "vqdmulh.s32 q15, q15, q6\n"
1053 
1054         // Apply the negative exponent part of the multiplier.
1055         "vrshl.s32 q14, q14, q11\n"
1056         "vrshl.s32 q15, q15, q11\n"
1057         "b 9f\n"
1058 
1059         "8:\n"
1060         // Case where channels are columns.
1061         "vmin.s32 d22, d20, d16\n"
1062         "vsub.s32 d20, d20, d22\n"
1063 
1064         // Apply the positive exponent part of the multiplier.
1065         "vdup.32  q12, d20[0]\n"
1066         "vdup.32  q13, d20[1]\n"
1067         "vshl.s32 q14, q14, q12\n"
1068         "vshl.s32 q15, q15, q13\n"
1069 
1070         // Apply the fixed-point part of the multiplier.
1071         "vqdmulh.s32 q14, q14, d12[0]\n"
1072         "vqdmulh.s32 q15, q15, d12[1]\n"
1073 
1074         // Apply the negative exponent part of the multiplier.
1075         "vdup.32  q12, d22[0]\n"
1076         "vdup.32  q13, d22[1]\n"
1077         "vrshl.s32 q14, q14, q12\n"
1078         "vrshl.s32 q15, q15, q13\n"
1079 
1080         "9:\n"
1081 
1082         "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1083         "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
1084         "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
1085         "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
1086         "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
1087 
1088         // Store uint8 values:
1089         RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
1090 
1091         // Cast-and-saturate from int32 to int16
1092         // After this, all values for output are in q14.
1093         "vqmovn.s32 d28, q14\n"
1094         "vqmovn.s32 d29, q15\n"
1095 
1096         // At this point, d12 -- d27, d30, d31 aren't used anymore for the
1097         // current block, so we can start clearing these accumulators for the
1098         // next block (next iteration of the main loop).
1099         RUY_MAKE_ZERO(q6)
1100         RUY_MAKE_ZERO(q7)
1101         RUY_MAKE_ZERO(q8)
1102         RUY_MAKE_ZERO(q9)
1103         RUY_MAKE_ZERO(q10)
1104         RUY_MAKE_ZERO(q11)
1105         RUY_MAKE_ZERO(q12)
1106         RUY_MAKE_ZERO(q13)
1107         RUY_MAKE_ZERO(q15)
1108 
1109         // Load the destination zero point into each of the 8 16-bit slots
1110         // in a q register.
1111         "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1112         "vdup.16 q13, r4\n" // dst_zero_point
1113 
1114         // Add the destination zero point
1115         "vqadd.s16 q14, q14, q13\n"
1116 
1117         // Cast-and-saturate from int16 to uint8
1118         // Now all 8 1-byte values are in d30.
1119         "vqmovun.s16 d30, q14\n"
1120 
1121         // Load the clamp_min, clamp_max bounds
1122         "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1123         "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1124         "vdup.8 d28, r2\n"  // clamp_min
1125         "vdup.8 d29, r3\n"  // clamp_max
1126 
1127         // Apply the clamp_min bound
1128         "vmax.u8 d30, d30, d28\n"
1129         // Apply the clamp_max bound
1130         "vmin.u8 d30, d30, d29\n"
1131 
1132         // Compute how much of the 4x2 block of destination 8bit values that
1133         // we have computed, fit in the destination matrix. Typically, all of
1134         // it fits, but when the destination matrix shape is not a multiple
1135         // of 4x2, there are some 4x2 blocks along the boundaries that do
1136         // not fit entirely.
1137 
1138         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1139         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1140         "sub r1, r1, r8\n"
1141 
1142         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1143         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1144         "sub r2, r2, r4\n"
1145         "mov r3, #4\n"
1146         "mov r5, #2\n"
1147         "cmp r1, #4\n"
1148         // Compute r1 = how many rows of the 4x2 block fit
1149         "it gt\n"
1150         "movgt r1, r3\n"
1151 
1152         "cmp r2, #2\n"
1153         // Compute r2 = how many cols of the 4x2 block fit
1154         "it gt\n"
1155         "movgt r2, r5\n"
1156 
1157         // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1158         "cmp r1, r3\n"
1159         "it eq\n"
1160         "cmpeq r2, r5\n"
1161         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1162         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1163         // Yes, all of the 4x2 block fits, go to fast path.
1164         "beq 30f\n"
1165         // Not all of the 4x2 block fits.
1166         // Store to dst_tmp_buf
1167         // Set r3 address to write to dst_tmp_buf.
1168         "mov r3, %[dst_tmp_buf]\n"
1169         "vst1.8 {d30}, [r3]\n"
1170 
1171         // Slow loop copying from dst_tmp_buf to dst.
1172         "mov r6, #0\n"
1173         "50:\n"
1174         "mov r8, #0\n"
1175         "51:\n"
1176         "ldrb r10, [r3, r8]\n"
1177         "strb r10, [r4, r8]\n"
1178         "add r8, r8, #1\n"
1179         "cmp r8, r1\n"
1180         "blt 51b\n"
1181         "add r6, r6, #1\n"
1182         "add r3, r3, #4\n"
1183         "add r4, r4, r5\n"
1184         "cmp r6, r2\n"
1185         "blt 50b\n"
1186         "b 31f\n"
1187         "30:\n"
1188         // Yes, all of the 4x2 block fits.
1189         // r3 address, r5 stride
1190         "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1191         "mov r4, r3\n"
1192         "mov r6, #1\n"
1193 
1194         "vst1.32 {d30[0]}, [r3]\n"
1195         "add r4, r4, r5\n"
1196         "mov r3, r4\n"
1197         "vst1.32 {d30[1]}, [r3]\n"
1198 
1199         "31:\n"
1200 
1201         // Load dst_ptr, increment, and write back.
1202         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1203         "add r4, r4, #4\n"
1204         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1205 
1206         RUY_MAKE_ZERO(q13)
1207         RUY_MAKE_ZERO(q14)
1208         RUY_MAKE_ZERO(q15)
1209 
1210         "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1211 
1212         // Store int8 values:
1213         RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
1214 
1215         // Cast-and-saturate from int32 to int16
1216         // After this, all values for output are in q14.
1217         "vqmovn.s32 d28, q14\n"
1218         "vqmovn.s32 d29, q15\n"
1219 
1220         // At this point, d12 -- d27, d30, d31 aren't used anymore for the
1221         // current block, so we can start clearing these accumulators for the
1222         // next block (next iteration of the main loop).
1223         RUY_MAKE_ZERO(q6)
1224         RUY_MAKE_ZERO(q7)
1225         RUY_MAKE_ZERO(q8)
1226         RUY_MAKE_ZERO(q9)
1227         RUY_MAKE_ZERO(q10)
1228         RUY_MAKE_ZERO(q11)
1229         RUY_MAKE_ZERO(q12)
1230         RUY_MAKE_ZERO(q13)
1231         RUY_MAKE_ZERO(q15)
1232 
1233         // Load the destination zero point into each of the 8 16-bit slots
1234         // in a q register.
1235         "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1236         "vdup.16 q13, r4\n" // dst_zero_point
1237 
1238         // Add the destination zero point
1239         "vqadd.s16 q14, q14, q13\n"
1240 
1241         // Cast-and-saturate from int16 to int8
1242         // Now all 8 1-byte values are in d30.
1243         "vqmovn.s16 d30, q14\n"
1244 
1245         // Load the clamp_min, clamp_max bounds
1246         "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1247         "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1248         "vdup.8 d28, r2\n"  // clamp_min
1249         "vdup.8 d29, r3\n"  // clamp_max
1250 
1251         // Apply the clamp_min bound
1252         "vmax.s8 d30, d30, d28\n"
1253         // Apply the clamp_max bound
1254         "vmin.s8 d30, d30, d29\n"
1255 
1256         // Compute how much of the 4x2 block of destination 8bit values that
1257         // we have computed, fit in the destination matrix. Typically, all of
1258         // it fits, but when the destination matrix shape is not a multiple
1259         // of 4x2, there are some 4x2 blocks along the boundaries that do
1260         // not fit entirely.
1261 
1262         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1263         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1264         "sub r1, r1, r8\n"
1265 
1266         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1267         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1268         "sub r2, r2, r4\n"
1269         "mov r3, #4\n"
1270         "mov r5, #2\n"
1271         "cmp r1, #4\n"
1272         // Compute r1 = how many rows of the 4x2 block fit
1273         "it gt\n"
1274         "movgt r1, r3\n"
1275 
1276         "cmp r2, #2\n"
1277         // Compute r2 = how many cols of the 4x2 block fit
1278         "it gt\n"
1279         "movgt r2, r5\n"
1280 
1281         // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1282         "cmp r1, r3\n"
1283         "it eq\n"
1284         "cmpeq r2, r5\n"
1285         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1286         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1287         // Yes, all of the 4x2 block fits, go to fast path.
1288         "beq 30f\n"
1289         // Not all of the 4x2 block fits.
1290         // Store to dst_tmp_buf
1291         // Set r3 address to write to dst_tmp_buf.
1292         "mov r3, %[dst_tmp_buf]\n"
1293         "vst1.8 {d30}, [r3]\n"
1294 
1295         // Slow loop copying from dst_tmp_buf to dst.
1296         "mov r6, #0\n"
1297         "50:\n"
1298         "mov r8, #0\n"
1299         "51:\n"
1300         "ldrb r10, [r3, r8]\n"
1301         "strb r10, [r4, r8]\n"
1302         "add r8, r8, #1\n"
1303         "cmp r8, r1\n"
1304         "blt 51b\n"
1305         "add r6, r6, #1\n"
1306         "add r3, r3, #4\n"
1307         "add r4, r4, r5\n"
1308         "cmp r6, r2\n"
1309         "blt 50b\n"
1310         "b 31f\n"
1311         "30:\n"
1312         // Yes, all of the 4x2 block fits.
1313         // r3 address, r5 stride
1314         "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1315         "mov r4, r3\n"
1316         "mov r6, #1\n"
1317 
1318         "vst1.32 {d30[0]}, [r3]\n"
1319         "add r4, r4, r5\n"
1320         "mov r3, r4\n"
1321         "vst1.32 {d30[1]}, [r3]\n"
1322 
1323         "31:\n"
1324 
1325         // Load dst_ptr, increment, and write back.
1326         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1327         "add r4, r4, #4\n"
1328         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1329 
1330         RUY_MAKE_ZERO(q13)
1331         RUY_MAKE_ZERO(q14)
1332         RUY_MAKE_ZERO(q15)
1333 
1334         "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1335 
1336         RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
1337 
1338         // Load the destination zero point into each of the 4 32-bit slots
1339         // in a q register.
1340         "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1341         "vdup.32 q13, r4\n" // dst_zero_point
1342         // Add the destination zero point
1343         "vadd.s32 q14, q14, q13\n"
1344         "vadd.s32 q15, q15, q13\n"
1345 
1346         // Cast-and-saturate from int32 to int16
1347         // After this, all values for output are in q14.
1348         "vqmovn.s32 d28, q14\n"
1349         "vqmovn.s32 d29, q15\n"
1350 
1351         // At this point, q6 -- q11 and q15 aren't used anymore for the current block,
1352         // so we can start clearing these accumulators for the next block
1353         // (next iteration of the main loop).
1354         RUY_MAKE_ZERO(q6)
1355         RUY_MAKE_ZERO(q7)
1356         RUY_MAKE_ZERO(q8)
1357         RUY_MAKE_ZERO(q9)
1358         RUY_MAKE_ZERO(q10)
1359         RUY_MAKE_ZERO(q11)
1360         RUY_MAKE_ZERO(q15)
1361 
1362          // Load the clamp_min, clamp_max bounds
1363         "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1364         "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1365         "vdup.16 q12, r2\n"  // clamp_min
1366         "vdup.16 q13, r3\n"  // clamp_max
1367 
1368         // Apply the clamp_min bound
1369         "vmax.s16 q14, q14, q12\n"
1370         // Apply the clamp_max bound
1371         "vmin.s16 q14, q14, q13\n"
1372 
1373         RUY_MAKE_ZERO(q12)
1374         RUY_MAKE_ZERO(q13)
1375 
1376         // Compute how much of the 4x2 block of destination 16-bit values that
1377         // we have computed, fit in the destination matrix. Typically, all of
1378         // it fits, but when the destination matrix shape is not a multiple
1379         // of 4x2, there are some 4x2 blocks along the boundaries that do
1380         // not fit entirely.
1381 
1382         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1383         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1384         "sub r1, r1, r8\n"
1385 
1386         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1387         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1388         "sub r2, r2, r4\n"
1389         "mov r3, #4\n"
1390         "mov r5, #2\n"
1391         "cmp r1, #4\n"
1392         // Compute r1 = how many rows of the 4x2 block fit
1393         "it gt\n"
1394         "movgt r1, r3\n"
1395 
1396         "cmp r2, #2\n"
1397         // Compute r2 = how many cols of the 4x2 block fit
1398         "it gt\n"
1399         "movgt r2, r5\n"
1400 
1401         // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1402         "cmp r1, r3\n"
1403         "it eq\n"
1404         "cmpeq r2, r5\n"
1405         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1406         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1407         // Yes, all of the 4x2 block fits, go to fast path.
1408         "beq 30f\n"
1409         // Not all of the 4x2 block fits.
1410         // Store to dst_tmp_buf
1411         // Set r3 address to write to dst_tmp_buf.
1412         "mov r3, %[dst_tmp_buf]\n"
1413         "vst1.16 {q14}, [r3]\n"
1414 
1415         // Slow loop copying from dst_tmp_buf to dst.
1416         "mov r6, #0\n"
1417         "50:\n"
1418         "mov r8, #0\n"
1419         "51:\n"
1420         // Shift of offset register for half-word loads not allowed in A32,
1421         // so we shift, load/store, then shift back r8.
1422         "lsl r8, r8, #1\n"
1423         "ldrh r10, [r3, r8]\n"
1424         "strh r10, [r4, r8]\n"
1425         "lsr r8, r8, #1\n"
1426         "add r8, r8, #1\n"
1427         "cmp r8, r1\n"
1428         "blt 51b\n"
1429         "add r6, r6, #1\n"
1430         "add r3, r3, #8\n"
1431         "add r4, r4, r5\n"
1432         "cmp r6, r2\n"
1433         "blt 50b\n"
1434         "b 31f\n"
1435         "30:\n"
1436         // Yes, all of the 4x2 block fits.
1437         // r3 address, r5 stride
1438         "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1439         "mov r4, r3\n"
1440         "mov r6, #2\n"
1441 
1442         "vst1.16 {d28[0]}, [r3], r6\n"
1443         "add r4, r4, r5\n"
1444         "vst1.16 {d28[1]}, [r3], r6\n"
1445         "vst1.16 {d28[2]}, [r3], r6\n"
1446         "vst1.16 {d28[3]}, [r3], r6\n"
1447         "mov r3, r4\n"
1448         "vst1.16 {d29[0]}, [r3], r6\n"
1449         "vst1.16 {d29[1]}, [r3], r6\n"
1450         "vst1.16 {d29[2]}, [r3], r6\n"
1451         "vst1.16 {d29[3]}, [r3], r6\n"
1452         "31:\n"
1453 
1454          // Load dst_ptr, increment, and write back.
1455         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1456         "add r4, r4, #8\n"
1457         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1458 
1459         RUY_MAKE_ZERO(q14)
1460 
1461         "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1462 
1463         RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
1464 
1465         // Since the store type is the same as the accum type, no need for
1466         // downcast. There's also no need for clamp by min/max.
1467 
1468         // At this point, q6 -- q13 aren't used anymore for the current block,
1469         // so we can start clearing these accumulators for the next block
1470         // (next iteration of the main loop).
1471         // Clear accumulators.
1472         RUY_MAKE_ZERO(q6)
1473         RUY_MAKE_ZERO(q7)
1474         RUY_MAKE_ZERO(q8)
1475         RUY_MAKE_ZERO(q9)
1476         RUY_MAKE_ZERO(q10)
1477         RUY_MAKE_ZERO(q11)
1478         RUY_MAKE_ZERO(q12)
1479         RUY_MAKE_ZERO(q13)
1480 
1481         // Compute how much of the 4x2 block of destination 32 bit values that
1482         // we have computed, fit in the destination matrix. Typically, all of
1483         // it fits, but when the destination matrix shape is not a multiple
1484         // of 4x2, there are some 4x2 blocks along the boundaries that do
1485         // not fit entirely.
1486 
1487         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1488         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1489         "sub r1, r1, r8\n"
1490 
1491         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1492         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1493         "sub r2, r2, r4\n"
1494         "mov r3, #4\n"
1495         "mov r5, #2\n"
1496         "cmp r1, #4\n"
1497         // Compute r1 = how many rows of the 4x2 block fit
1498         "it gt\n"
1499         "movgt r1, r3\n"
1500 
1501         "cmp r2, #2\n"
1502         // Compute r2 = how many cols of the 4x2 block fit
1503         "it gt\n"
1504         "movgt r2, r5\n"
1505 
1506         // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1507         "cmp r1, r3\n"
1508         "it eq\n"
1509         "cmpeq r2, r5\n"
1510         // Yes, all of the 4x2 block fits, go to fast path.
1511         "beq 30f\n"
1512         // Not all of the 4x2 block fits.
1513         // Set (r3 address, r4 stride) to write to dst_tmp_buf
1514         "mov r3, %[dst_tmp_buf]\n"
1515         "mov r4, #16\n"
1516         "b 31f\n"
1517 
1518         "30:\n"
1519         // Yes, all of the 4x2 block fits.
1520         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1521         // r3 address, r4 stride
1522         "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1523         "mov r4, r5\n"
1524 
1525         "31:\n"
1526 
1527         "vst1.32 {d28, d29}, [r3]\n"
1528         "add r3, r3, r4\n"
1529         "vst1.32 {d30, d31}, [r3]\n"
1530 
1531         // If all of the 4x2 block fits, we just finished writing it to the
1532         // destination, so we skip the next part.
1533         "beq 41f\n"
1534         // Not all of the 4x2 block fits in the destination matrix.  We just
1535         // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
1536         // it to copy into the destination matrix the part that fits.
1537         "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1538         "mov r3, %[dst_tmp_buf]\n"
1539         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1540         "mov r6, #0\n"
1541         "50:\n"
1542         "mov r5, #0\n"
1543         "51:\n"
1544         "ldr r10, [r3, r5, lsl #2]\n"
1545         "str r10, [r4, r5, lsl #2]\n"
1546         "add r5, r5, #1\n"
1547         "cmp r5, r1\n"
1548         "blt 51b\n"
1549         "add r6, r6, #1\n"
1550         "add r3, r3, #16\n"
1551         "add r4, r4, r8\n"
1552         // r2 = how many cols of the 4x2 block fit
1553         "cmp r6, r2\n"
1554         "blt 50b\n"
1555 
1556         "41:\n"
1557         // Load dst_ptr, increment, and write back.
1558         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1559         "add r4, r4, #16\n"
1560         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1561 
1562         RUY_MAKE_ZERO(q10)
1563         RUY_MAKE_ZERO(q11)
1564 
1565         "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1566 
1567         RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
1568 
1569         // Reload some params --- we had used r3 -- r6 for a few other things
1570         // since the last time we had loaded them.
1571         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1572         "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1573         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1574 
1575         // Move to the next block of the destination matrix, for the next iter
1576         // of the main loop.  Notice that lhs_col_ptr, rhs_col_ptr have already
1577         // been updated earlier.
1578         // Have we reached the end row?
1579         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1580         "cmp r8, r3\n"
1581 
1582         "beq 20f\n"  // yes, end row.
1583         // Not end row. Move to the next row.
1584         "add r8, r8, #4\n"
1585         // Store new value of row
1586         "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1587 
1588         "b 21f\n"
1589         "20:\n"
1590         // Was already at end row.
1591         // Move back to first row.
1592         "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1593         // Move to the next column.
1594         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1595         "add r4, r4, #2\n"
1596         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1597 
1598         "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1599         "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1600         // Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns)
1601         "add r1, r1, r8, lsl #1\n"
1602         // Store dst_col_ptr
1603         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1604         // Store dst_ptr
1605         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1606         "21:\n"
1607 
1608         // Main loop exit condition: have we hit the end column?
1609         "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
1610         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1611         "cmp r8, r4\n"
1612 
1613         // r1 is the number of levels of depth that we have already loaded
1614         // LHS and RHS data for. Corresponding to the initial ld1 instructions
1615         // above, this is currently 16.
1616         "mov r1, #16\n"
1617 
1618         "ble 1b\n"
1619 
1620         // Restore stack pointer.
1621         "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
1622 
1623         // clang-format on
1624 
1625         : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
1626         : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
1627         : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
1628            // Clobber list must specify q registers (and not their constituent
1629            // d registers). There is a (currently unexplained) slowdown if
1630            // d registers are listed in the clobbers list.
1631           "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
1632           "q9", "q10", "q12", "q13", "q14", "q15");
1633 }
1634 
1635 // Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS
1636 // is still packed as if it had two columns.
Kernel8bitNeon1Col(const KernelParams8bit<4,2> & params)1637 void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
1638   profiler::ScopeLabel label("Kernel (kNeon)");
1639 
1640   CheckOffsetsInKernelParams8bit(params);
1641 
1642   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
1643   const std::int8_t* rhs_col_ptr =
1644       static_cast<const int8_t*>(params.rhs_base_ptr);
1645   const std::int8_t* lhs_ptr = lhs_col_ptr;
1646   const std::int8_t* rhs_ptr = rhs_col_ptr;
1647 
1648   RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
1649 
1650   // The asm kernel below has the following NEON register allocation:
1651   //
1652   // q6 - q13 are 128-bit (4x32b) accumulators.
1653   // During accumulation, d0 -- d7 are used to load int8 data from LHS and
1654   // d8 -- d11 from RHS:
1655   //                                            int8 RHS 16x1 block
1656   //                                               /------------|
1657   //                                               | d8.b[0]    |
1658   //                                               | ...        |
1659   //                                               | d8.b[7]    |
1660   //                                               | d9.b[0]    |
1661   //                                               | ...        |
1662   //                                               | d9.b[7]    |
1663   //                                               \------------/
1664   //    int8 LHS 4x16 block
1665   //  /-----------------------------------------\  /------------|
1666   //  |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7]  |  | q6         |
1667   //  |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7]  |  | q7         |
1668   //  |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7]  |  | q8         |
1669   //  |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7]  |  | q9         |
1670   //  \-----------------------------------------/  \------------/
1671   //                              128-bit accumulators 4x1 block
1672   //
1673   // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
1674   // optimization for this kernel.
1675   asm volatile(
1676 #define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
1677 
1678         // clang-format off
1679 
1680         // Load the first 64 bytes of LHS and RHS data.
1681         "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1682         "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1683         "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1684         "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1685         "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1686         // Skip the other column and advance the pointer.
1687         "add %[rhs_ptr], %[rhs_ptr], #16\n"
1688 
1689         "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
1690 
1691         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
1692         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1693 
1694         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
1695         "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1696 
1697         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1698         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1699 
1700         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
1701         "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1702 
1703         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1704         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1705 
1706         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
1707         "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1708 
1709         // Clear accumulators.
1710         RUY_MAKE_ZERO(q6)
1711         RUY_MAKE_ZERO(q7)
1712         RUY_MAKE_ZERO(q8)
1713         RUY_MAKE_ZERO(q9)
1714         RUY_MAKE_ZERO(q10)
1715         RUY_MAKE_ZERO(q11)
1716         RUY_MAKE_ZERO(q12)
1717         RUY_MAKE_ZERO(q13)
1718         RUY_MAKE_ZERO(q14)
1719         RUY_MAKE_ZERO(q15)
1720 
1721         // r1 is the number of levels of depth that we have already loaded
1722         // LHS and RHS data for. Corresponding to the initial ld1 instructions
1723         // above, this is currently 16.
1724         "mov r1, #16\n"
1725 
1726         // Main loop of the whole GEMM, over rows and columns of the
1727         // destination matrix.
1728         "1:\n"
1729 
1730         // r1 is how many levels of depth we have already loaded
1731         // data for, r10 is the total depth.
1732         "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
1733         "cmp r1, r10\n"
1734         "beq 79f\n"
1735 
1736         "2:\n"
1737 
1738         // Mult, mult-acc in to q14, q15
1739         "vmull.s8 q14, d0, d8\n"
1740         "vmull.s8 q15, d2, d8\n"
1741         "vmlal.s8 q14, d1, d9\n"
1742         "vmlal.s8 q15, d3, d9\n"
1743 
1744         // Then pairwise accumulate in to q6, q7
1745         "vpadal.s16 q6, q14\n"
1746         "vpadal.s16 q7, q15\n"
1747 
1748         // Mult, mult-acc in to q14, q15
1749         "vmull.s8 q14, d4, d8\n"
1750         "vmull.s8 q15, d6, d8\n"
1751         "vmlal.s8 q14, d5, d9\n"
1752         "vmlal.s8 q15, d7, d9\n"
1753 
1754         // Then pairwise accumulate in to q8, q9
1755         "vpadal.s16 q8, q14\n"
1756         "vpadal.s16 q9, q15\n"
1757 
1758 
1759         // Load the next 64 bytes of LHS and RHS data.
1760         "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1761         "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1762         "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1763         "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1764         RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
1765         "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1766         // Skip the other column and advance the pointer.
1767         "add %[rhs_ptr], %[rhs_ptr], #16\n"
1768         RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
1769 
1770         // Each iteration of this loop advances by 16 levels of depth.
1771         "add r1, r1, #16\n"
1772 
1773         // Loop termination condition
1774         "cmp r1, r10\n"
1775 
1776         "blt 2b\n"
1777 
1778         "79:\n"
1779 
1780         // Mult, mult-acc in to q14, q15
1781         "vmull.s8 q14, d0, d8\n"
1782         "vmull.s8 q15, d2, d8\n"
1783         "vmlal.s8 q14, d1, d9\n"
1784         "vmlal.s8 q15, d3, d9\n"
1785 
1786         // Then pairwise accumulate in to q6, q7
1787         "vpadal.s16 q6, q14\n"
1788         "vpadal.s16 q7, q15\n"
1789 
1790         // Mult, mult-acc in to q14, q15
1791         "vmull.s8 q14, d4, d8\n"
1792         "vmull.s8 q15, d6, d8\n"
1793         "vmlal.s8 q14, d5, d9\n"
1794         "vmlal.s8 q15, d7, d9\n"
1795 
1796         // Then pairwise accumulate in to q8, q9
1797         "vpadal.s16 q8, q14\n"
1798         "vpadal.s16 q9, q15\n"
1799 
1800         // All accumulation over depth done. q6 - q9 contain the 4x32b
1801         // accumulators for the 4x1 final matrix.
1802         // We now have to compute the final 8-bit values from these int32
1803         // accumulators, and advance to the next 4x2 block. We intertwine
1804         // these two aspects whenever possible for optimal pipelining, both
1805         // at the data flow level (prefetch data for next block as early as
1806         // possible) and instruction pipelining level (some of the next-block
1807         // work can dual-issue with some of the final work on the current
1808         // block).
1809 
1810         // q6-q9 now contain 4 x 32b
1811         "vpadd.i32 d0, d12, d13\n"
1812         "vpadd.i32 d1, d14, d15\n"
1813         "vpadd.i32 d2, d16, d17\n"
1814         "vpadd.i32 d3, d18, d19\n"
1815 
1816         // d0-d4 each contain 2 x 32b accumulators.
1817         // Need to add pairwise to get 1 x 32b for each of the 4x1 entries
1818         // of destination, (Four 'd' registers total)
1819         "vpadd.i32 d28, d0, d1\n"
1820         "vpadd.i32 d29, d2, d3\n"
1821 
1822         // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries.
1823 
1824         // Logic to advance to the next block in preparation for the next
1825         // iteration of the main loop. For now, we only want to compute
1826         // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
1827         // not yet ready to update the values of row and col, as we still need
1828         // the current values for the rest of the work on the current block.
1829 
1830         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1831         "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1832         "cmp r1, r3\n"  // Have we finished the last row?
1833 
1834         "bge 4f\n"           // If finished last row, go to 4
1835         // Not finished last row: then advance to next row.
1836         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
1837         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1838         "add r4, r4, r1, lsl #2\n"
1839         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1840         "b 5f\n"
1841         "4:\n"  // Finished last row...
1842         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1843         // Go back to first row
1844         "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1845 
1846         // Now we need to advance to the next column. If we already
1847         // finished the last column, then in principle we are done, however
1848         // we can't just return here, as we need to allow the end work of the
1849         // current block to complete. The good news is that at this point it
1850         // doesn't matter what data we load for the next column, since
1851         // we will exit from the main loop below before actually storing
1852         // anything computed from that data.
1853 
1854         "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
1855         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1856         "cmp r8, r4\n"  // Have we finished the last column?
1857         "bge 5f\n" // If yes, just carry on without updating the column pointer.
1858         // Not finished last column: then advance to next column.
1859         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
1860         "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1861         "add r10, r10, r1, lsl #1\n"
1862         "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1863         "5:\n"
1864 
1865         // Set the LHS and RHS data pointers to the start of the columns just
1866         // computed.
1867         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1868         "mov %[lhs_ptr], r4\n"
1869         "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1870         "mov %[rhs_ptr], r5\n"
1871 
1872         // Now we load: bias data, LHS sums data, RHS sums data.
1873 
1874         // First, load the base pointers from the params.
1875         "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1876         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
1877 
1878         // Offset these base pointers as needed given the current row, col.
1879         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1880 
1881         "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
1882         "beq 1000f\n"
1883         "add r1, r1, r8, lsl #2\n"
1884         "1000:\n"
1885 
1886         // Load 4 bias values.
1887         "vld1.32 {d24, d25}, [r1]\n"
1888 
1889         // Now that we know what LHS and RHS data the next iteration of the
1890         // main loop will need to load, we start loading that data into
1891         // d0 -- d9, as we don't need those registers anymore
1892         // in the rest of the work on the current block.
1893         "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1894         "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1895         "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1896         "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1897         RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
1898         "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1899         // Skip the other column and advance the pointer.
1900         "add %[rhs_ptr], %[rhs_ptr], #16\n"
1901         RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
1902 
1903         // Add to the bias values the product
1904         // (depth * lhs_zero_point * rhs_zero_point),
1905         // See the term NZ1Z2 in equation (7) in
1906         // https://arxiv.org/pdf/1712.05877.pdf
1907         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
1908         "vdup.32 q9, r3\n"
1909         "vadd.i32 q12, q12, q9\n"
1910 
1911         // Perform the bias-addition (per the above, we have just folded into
1912         // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
1913         "vadd.i32 q14, q14, q12\n"
1914 
1915         // LHS/RHS zero points
1916         // Has RHS sums
1917         "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1918         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
1919         "beq 401f\n"
1920         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
1921         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1922         // Offset by current col * number of bytes per value
1923         "add r3, r3, r4, lsl #2\n"
1924         "vld1.32 { d12 }, [r3]\n"
1925         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
1926         "vdup.32 q10, r5\n"  // create lhs_zero_point_vec
1927         // Subtract rhs_sums * lhs_zero_point, per
1928         // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1929         "vmls.i32 q14, q10, d12[0]\n"
1930         "401:\n"
1931 
1932         // Has LHS sums
1933         "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1934         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
1935         "beq 402f\n"
1936         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
1937         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1938         // Offset by current row * number of bytes per value
1939         "add r2, r2, r4, lsl #2\n"
1940         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
1941 
1942         // Load 4 lhs_sums values.
1943         "vld1.32 {d22, d23}, [r2]\n"
1944         "vdup.32 d13, r5\n" // rhs_zero_point
1945 
1946         // Compute lhs_sums * rhs_zero_point.
1947         "vmul.i32 q11, q11, d13[1]\n"
1948         // Subtract lhs_sums * rhs_zero_point, per
1949         // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1950         "vsub.s32 q14, q14, q11\n"
1951 
1952         // If the destination is int32, it means the user asks for the raw
1953         // accumulators, no need for us to downquantize the value.
1954         "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1955         "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
1956         "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
1957 
1958         "402:\n"
1959 
1960         // At this point we have computed the final int32 values. Now we
1961         // start down-quantizing them to obtain the final 8bit values from them.
1962 
1963         // As part of this down-quantization, our int32 values will be
1964         // multiplied by a multiplier that has a fixed-point component and an
1965         // exponent component.
1966 
1967         //Load the exponent part of the multiplier.
1968         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
1969         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1970         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1971         "beq 1001f\n"
1972         "add r1, r1, r4, lsl #2\n"
1973         "1001:\n"
1974 
1975         "vld1.32 {q10}, [r1]\n"
1976 
1977         "vmvn.i32 q8, #0\n"
1978         "vmin.s32 q13, q10, q8\n"
1979         "vsub.s32 q12, q10, q13\n"
1980 
1981         // Apply the positive exponent part of the multiplier.
1982         "vshl.s32 q14, q14, q12\n"
1983 
1984         // Load fixed point part of the multiplier
1985         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
1986         // r6 has flags, r4 has row
1987         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1988         "beq 1002f\n"
1989         "add r1, r1, r4, lsl #2\n"
1990         "1002:\n"
1991         "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
1992 
1993         // Apply the fixed-point part of the multiplier.
1994         "vqdmulh.s32 q14, q14, q10\n"
1995 
1996         // Apply the negative exponent part of the multiplier.
1997         "vrshl.s32 q14, q14, q13\n"
1998 
1999         "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
2000         "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
2001         "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
2002         "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
2003         "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
2004 
2005         // Store uint8 values:
2006         RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
2007 
2008         // Cast-and-saturate from int32 to int16
2009         // After this, all values for output are in d28.
2010         "vqmovn.s32 d28, q14\n"
2011 
2012         // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
2013         // current block, so we can start clearing these accumulators for the
2014         // next block (next iteration of the main loop).
2015         RUY_MAKE_ZERO(q6)
2016         RUY_MAKE_ZERO(q7)
2017         RUY_MAKE_ZERO(q8)
2018         RUY_MAKE_ZERO(q9)
2019         RUY_MAKE_ZERO(q10)
2020         RUY_MAKE_ZERO(q11)
2021         RUY_MAKE_ZERO(q12)
2022         RUY_MAKE_ZERO(q13)
2023         RUY_MAKE_ZERO(q15)
2024 
2025         // Load the destination zero point into each of the 8 16-bit slots
2026         // in a q register.
2027         "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2028         "vdup.16 q13, r4\n" // dst_zero_point
2029 
2030         // Add the destination zero point
2031         "vqadd.s16 q14, q14, q13\n"
2032 
2033         // Cast-and-saturate from int16 to uint8
2034         "vqmovun.s16 d30, q14\n"
2035         // At this point, we only need 4 8-bit values in the lower half
2036         // of d30.
2037 
2038 
2039         // Load the clamp_min, clamp_max bounds
2040         "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2041         "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2042         "vdup.8 d28, r2\n"  // clamp_min
2043         "vdup.8 d29, r3\n"  // clamp_max
2044 
2045         // Apply the clamp_min bound
2046         "vmax.u8 d30, d30, d28\n"
2047         // Apply the clamp_max bound
2048         "vmin.u8 d30, d30, d29\n"
2049 
2050         // Compute how much of the 4x1 block of destination 8bit values that
2051         // we have computed, fit in the destination matrix. Typically, all of
2052         // it fits, but when the destination matrix shape is not a multiple
2053         // of 4x1, there are some 4x1 blocks along the boundaries that do
2054         // not fit entirely.
2055 
2056         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2057         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2058         "sub r1, r1, r8\n"
2059 
2060         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2061         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2062         "sub r2, r2, r4\n"
2063         "mov r3, #4\n"
2064         "mov r5, #2\n"
2065         "cmp r1, #4\n"
2066         // Compute r1 = how many rows of the 4x1 block fit
2067         "it gt\n"
2068         "movgt r1, r3\n"
2069 
2070         // Test if r1==4, i.e. if all of the 4x1 block fits.
2071         "cmp r1, r3\n"
2072 
2073         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2074         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2075         // Yes, all of the 4x1 block fits, go to fast path.
2076         "beq 30f\n"
2077         // Not all of the 4x1 block fits.
2078         // Store to dst_tmp_buf
2079         // Set r3 address to write to dst_tmp_buf.
2080         "mov r3, %[dst_tmp_buf]\n"
2081         "vst1.8 {d30}, [r3]\n"
2082 
2083         // Slow loop copying from dst_tmp_buf to dst.
2084         "50:\n"
2085         "mov r8, #0\n"
2086         "51:\n"
2087         "ldrb r10, [r3, r8]\n"
2088         "strb r10, [r4, r8]\n"
2089         "add r8, r8, #1\n"
2090         "cmp r8, r1\n"
2091         "blt 51b\n"
2092         "b 31f\n"
2093         "30:\n"
2094         // Yes, all of the 4x1 block fits.
2095         // r3 address, r5 stride
2096         "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2097         "mov r4, r3\n"
2098         "mov r6, #1\n"
2099 
2100         "vst1.8 {d30[0]}, [r3], r6\n"
2101         "vst1.8 {d30[1]}, [r3], r6\n"
2102         "vst1.8 {d30[2]}, [r3], r6\n"
2103         "vst1.8 {d30[3]}, [r3], r6\n"
2104         "31:\n"
2105 
2106         // Load dst_ptr, increment, and write back.
2107         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2108         "add r4, r4, #4\n"
2109         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2110 
2111         RUY_MAKE_ZERO(q13)
2112         RUY_MAKE_ZERO(q14)
2113         RUY_MAKE_ZERO(q15)
2114 
2115         "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2116 
2117         // Store int8 values:
2118         RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
2119 
2120         // Cast-and-saturate from int32 to int16
2121         // After this, all values for output are in d28.
2122         "vqmovn.s32 d28, q14\n"
2123 
2124         // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
2125         // current block, so we can start clearing these accumulators for the
2126         // next block (next iteration of the main loop).
2127         RUY_MAKE_ZERO(q6)
2128         RUY_MAKE_ZERO(q7)
2129         RUY_MAKE_ZERO(q8)
2130         RUY_MAKE_ZERO(q9)
2131         RUY_MAKE_ZERO(q10)
2132         RUY_MAKE_ZERO(q11)
2133         RUY_MAKE_ZERO(q12)
2134         RUY_MAKE_ZERO(q13)
2135         RUY_MAKE_ZERO(q15)
2136 
2137         // Load the destination zero point into each of the 8 16-bit slots
2138         // in a q register.
2139         "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2140         "vdup.16 q13, r4\n" // dst_zero_point
2141 
2142         // Add the destination zero point
2143         "vqadd.s16 q14, q14, q13\n"
2144 
2145         // Cast-and-saturate from int16 to int8
2146         "vqmovn.s16 d30, q14\n"
2147         // At this point, we only need 4 8-bit values in the lower half
2148         // of d30.
2149 
2150         // Load the clamp_min, clamp_max bounds
2151         "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2152         "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2153         "vdup.8 d28, r2\n"  // clamp_min
2154         "vdup.8 d29, r3\n"  // clamp_max
2155 
2156         // Apply the clamp_min bound
2157         "vmax.s8 d30, d30, d28\n"
2158         // Apply the clamp_max bound
2159         "vmin.s8 d30, d30, d29\n"
2160 
2161         // Compute how much of the 4x1 block of destination 8bit values that
2162         // we have computed, fit in the destination matrix. Typically, all of
2163         // it fits, but when the destination matrix shape is not a multiple
2164         // of 4x1, there are some 4x1 blocks along the boundaries that do
2165         // not fit entirely.
2166 
2167         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2168         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2169         "sub r1, r1, r8\n"
2170 
2171         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2172         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2173         "sub r2, r2, r4\n"
2174         "mov r3, #4\n"
2175         "mov r5, #2\n"
2176         "cmp r1, #4\n"
2177         // Compute r1 = how many rows of the 4x2 block fit
2178         "it gt\n"
2179         "movgt r1, r3\n"
2180 
2181         // Test if r1==4 i.e. if all of the 4x1 block fits.
2182         "cmp r1, r3\n"
2183 
2184         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2185         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2186         // Yes, all of the 4x1 block fits, go to fast path.
2187         "beq 30f\n"
2188         // Not all of the 4x1 block fits.
2189         // Store to dst_tmp_buf
2190         // Set r3 address to write to dst_tmp_buf.
2191         "mov r3, %[dst_tmp_buf]\n"
2192         "vst1.8 {d30}, [r3]\n"
2193 
2194         // Slow loop copying from dst_tmp_buf to dst.
2195         "50:\n"
2196         "mov r8, #0\n"
2197         "51:\n"
2198         "ldrb r10, [r3, r8]\n"
2199         "strb r10, [r4, r8]\n"
2200         "add r8, r8, #1\n"
2201         "cmp r8, r1\n"
2202         "blt 51b\n"
2203         "b 31f\n"
2204         "30:\n"
2205         // Yes, all of the 4x1 block fits.
2206         // r3 address, r5 stride
2207         "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2208         "mov r4, r3\n"
2209         "mov r6, #1\n"
2210 
2211         "vst1.8 {d30[0]}, [r3], r6\n"
2212         "vst1.8 {d30[1]}, [r3], r6\n"
2213         "vst1.8 {d30[2]}, [r3], r6\n"
2214         "vst1.8 {d30[3]}, [r3], r6\n"
2215         "31:\n"
2216 
2217         // Load dst_ptr, increment, and write back.
2218         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2219         "add r4, r4, #4\n"
2220         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2221 
2222         RUY_MAKE_ZERO(q13)
2223         RUY_MAKE_ZERO(q14)
2224         RUY_MAKE_ZERO(q15)
2225 
2226         "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2227 
2228         RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
2229 
2230         // Load the destination zero point into each of the 4 32-bit slots
2231         // in a q register.
2232         "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2233         "vdup.32 q13, r4\n" // dst_zero_point
2234         // Add the destination zero point
2235         "vadd.s32 q14, q14, q13\n"
2236         //"vadd.s32 q15, q15, q13\n"
2237 
2238         // Cast-and-saturate from int32 to int16
2239         // After this, all values for output are in d28.
2240         "vqmovn.s32 d28, q14\n"
2241 
2242         // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the current block,
2243         // so we can start clearing these accumulators for the next block
2244         // (next iteration of the main loop).
2245         RUY_MAKE_ZERO(q6)
2246         RUY_MAKE_ZERO(q7)
2247         RUY_MAKE_ZERO(q8)
2248         RUY_MAKE_ZERO(q9)
2249         RUY_MAKE_ZERO(q10)
2250         RUY_MAKE_ZERO(q11)
2251         RUY_MAKE_ZERO(q15)
2252 
2253          // Load the clamp_min, clamp_max bounds
2254         "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2255         "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2256         "vdup.16 d24, r2\n"  // clamp_min
2257         "vdup.16 d26, r3\n"  // clamp_max
2258 
2259         // Apply the clamp_min bound
2260         "vmax.s16 d28, d28, d24\n"
2261         // Apply the clamp_max bound
2262         "vmin.s16 d28, d28, d26\n"
2263 
2264         RUY_MAKE_ZERO(q12)
2265         RUY_MAKE_ZERO(q13)
2266 
2267         // Compute how much of the 4x1 block of destination 16-bit values that
2268         // we have computed, fit in the destination matrix. Typically, all of
2269         // it fits, but when the destination matrix shape is not a multiple
2270         // of 4x1, there are some 4x1 blocks along the boundaries that do
2271         // not fit entirely.
2272 
2273         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2274         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2275         "sub r1, r1, r8\n"
2276 
2277         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2278         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2279         "sub r2, r2, r4\n"
2280         "mov r3, #4\n"
2281         "mov r5, #2\n"
2282         "cmp r1, #4\n"
2283         // Compute r1 = how many rows of the 4x1 block fit
2284         "it gt\n"
2285         "movgt r1, r3\n"
2286 
2287         // Test if r1==4, i.e. if all of the 4x1 block fits.
2288         "cmp r1, r3\n"
2289 
2290         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2291         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2292         // Yes, all of the 4x1 block fits, go to fast path.
2293         "beq 30f\n"
2294         // Not all of the 4x1 block fits.
2295         // Store to dst_tmp_buf
2296         // Set r3 address to write to dst_tmp_buf.
2297         "mov r3, %[dst_tmp_buf]\n"
2298         "vst1.16 {d28}, [r3]\n"
2299 
2300         // Slow loop copying from dst_tmp_buf to dst.
2301         "50:\n"
2302         "mov r8, #0\n"
2303         "51:\n"
2304         // Shift of offset register for half-word loads not allowed in A32,
2305         // so we shift, load/store, then shift back r8.
2306         "lsl r8, r8, #1\n"
2307         "ldrh r10, [r3, r8]\n"
2308         "strh r10, [r4, r8]\n"
2309         "lsr r8, r8, #1\n"
2310         "add r8, r8, #1\n"
2311         "cmp r8, r1\n"
2312         "blt 51b\n"
2313         "b 31f\n"
2314         "30:\n"
2315         // Yes, all of the 4x1 block fits.
2316         // r3 address, r5 stride
2317         "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2318         "mov r4, r3\n"
2319         "mov r6, #2\n"
2320 
2321         "vst1.16 {d28[0]}, [r3], r6\n"
2322         "vst1.16 {d28[1]}, [r3], r6\n"
2323         "vst1.16 {d28[2]}, [r3], r6\n"
2324         "vst1.16 {d28[3]}, [r3], r6\n"
2325         "31:\n"
2326 
2327          // Load dst_ptr, increment, and write back.
2328         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2329         "add r4, r4, #8\n"
2330         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2331 
2332         RUY_MAKE_ZERO(q14)
2333 
2334         "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2335 
2336         RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
2337 
2338         // Since the store type is the same as the accum type, no need for
2339         // downcast. There's also no need for clamp by min/max.
2340 
2341         // At this point, q6 -- q13 aren't used anymore for the current block,
2342         // so we can start clearing these accumulators for the next block
2343         // (next iteration of the main loop).
2344         // Clear accumulators.
2345         RUY_MAKE_ZERO(q6)
2346         RUY_MAKE_ZERO(q7)
2347         RUY_MAKE_ZERO(q8)
2348         RUY_MAKE_ZERO(q9)
2349         RUY_MAKE_ZERO(q10)
2350         RUY_MAKE_ZERO(q11)
2351         RUY_MAKE_ZERO(q12)
2352         RUY_MAKE_ZERO(q13)
2353 
2354         // Compute how much of the 4x1 block of destination 32 bit values that
2355         // we have computed, fit in the destination matrix. Typically, all of
2356         // it fits, but when the destination matrix shape is not a multiple
2357         // of 4x2, there are some 4x4 blocks along the boundaries that do
2358         // not fit entirely.
2359 
2360         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2361         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2362         "sub r1, r1, r8\n"
2363 
2364         "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2365         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2366         "sub r2, r2, r4\n"
2367         "mov r3, #4\n"
2368         "mov r5, #2\n"
2369         "cmp r1, #4\n"
2370         // Compute r1 = how many rows of the 4x2 block fit
2371         "it gt\n"
2372         "movgt r1, r3\n"
2373 
2374         // Test if r1==4, i.e. if all of the 4x1 block fits.
2375         "cmp r1, r3\n"
2376 
2377         // Yes, all of the 4x1 block fits, go to fast path.
2378         "beq 30f\n"
2379         // Not all of the 4x1 block fits.
2380         // Set (r3 address, r4 stride) to write to dst_tmp_buf
2381         "mov r3, %[dst_tmp_buf]\n"
2382         "mov r4, #16\n"
2383         "b 31f\n"
2384 
2385         "30:\n"
2386         // Yes, all of the 4x1 block fits.
2387         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2388         // r3 address, r4 stride
2389         "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2390         "mov r4, r5\n"
2391 
2392         "31:\n"
2393 
2394         "vst1.32 {d28, d29}, [r3]\n"
2395 
2396         // If all of the 4x1 block fits, we just finished writing it to the
2397         // destination, so we skip the next part.
2398         "beq 41f\n"
2399         // Not all of the 4x1 block fits in the destination matrix.  We just
2400         // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
2401         // it to copy into the destination matrix the part that fits.
2402         "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2403         "mov r3, %[dst_tmp_buf]\n"
2404         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2405         "50:\n"
2406         "mov r5, #0\n"
2407         "51:\n"
2408         "ldr r10, [r3, r5, lsl #2]\n"
2409         "str r10, [r4, r5, lsl #2]\n"
2410         "add r5, r5, #1\n"
2411         "cmp r5, r1\n"
2412         "blt 51b\n"
2413 
2414         "41:\n"
2415         // Load dst_ptr, increment, and write back.
2416         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2417         "add r4, r4, #16\n"
2418         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2419 
2420         RUY_MAKE_ZERO(q10)
2421         RUY_MAKE_ZERO(q11)
2422 
2423         "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2424 
2425         RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
2426 
2427         // Reload some params --- we had used r5 -- r7 for a few other things
2428         // since the last time we had loaded them.
2429         "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
2430         "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
2431         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
2432 
2433         // Move to the next block of the destination matrix, for the next iter
2434         // of the main loop.  Notice that lhs_col_ptr, rhs_col_ptr have already
2435         // been updated earlier.
2436         // Have we reached the end row?
2437         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2438         "cmp r8, r3\n"
2439 
2440         "beq 20f\n"  // yes, end row.
2441         // Not end row. Move to the next row.
2442         "add r8, r8, #4\n"
2443         // Store new value of row
2444         "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2445 
2446         "b 21f\n"
2447         "20:\n"
2448         // Was already at end row.
2449         // Move back to first row.
2450         "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2451         // Move to the next column.
2452         "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2453         "add r4, r4, #2\n"
2454         "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2455 
2456         "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2457         "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
2458         // Increment dst_col_ptr by dst_stride (i.e. 1 column)
2459         "add r1, r1, r8\n"
2460         // Store dst_col_ptr
2461         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
2462         // Store dst_ptr
2463         "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2464         "21:\n"
2465 
2466         // Main loop exit condition: have we hit the end column?
2467         "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
2468         "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2469         "cmp r8, r4\n"
2470 
2471         // r1 is the number of levels of depth that we have already loaded
2472         // LHS and RHS data for. Corresponding to the initial ld1 instructions
2473         // above, this is currently 16.
2474         "mov r1, #16\n"
2475 
2476         "ble 1b\n"
2477 
2478         // Restore stack pointer.
2479         "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
2480 
2481         // clang-format on
2482 
2483         : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
2484         : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
2485         : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
2486            // Clobber list must specify q registers (and not their constituent
2487            // d registers). There is a (currently unexplained) slowdown if
2488            // d registers are listed in the clobbers list.
2489           "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
2490           "q9", "q10", "q12", "q13", "q14", "q15");
2491 }
2492 
2493 #undef RUY_OFFSET_BIAS
2494 #undef RUY_OFFSET_LHS_SUMS
2495 #undef RUY_OFFSET_RHS_SUMS
2496 #undef RUY_OFFSET_LHS_BASE_PTR
2497 #undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
2498 #undef RUY_OFFSET_MULTIPLIER_EXPONENT
2499 #undef RUY_OFFSET_RHS_BASE_PTR
2500 #undef RUY_OFFSET_DST_BASE_PTR
2501 #undef RUY_OFFSET_LHS_ZERO_POINT
2502 #undef RUY_OFFSET_RHS_ZERO_POINT
2503 #undef RUY_OFFSET_DST_ZERO_POINT
2504 #undef RUY_OFFSET_PROD_ZP_DEPTH
2505 #undef RUY_OFFSET_START_ROW
2506 #undef RUY_OFFSET_START_COL
2507 #undef RUY_OFFSET_LAST_ROW
2508 #undef RUY_OFFSET_LAST_COL
2509 #undef RUY_OFFSET_DST_ROWS
2510 #undef RUY_OFFSET_DST_COLS
2511 #undef RUY_OFFSET_LHS_STRIDE
2512 #undef RUY_OFFSET_RHS_STRIDE
2513 #undef RUY_OFFSET_DST_STRIDE
2514 #undef RUY_OFFSET_DEPTH
2515 #undef RUY_OFFSET_CLAMP_MIN
2516 #undef RUY_OFFSET_CLAMP_MAX
2517 #undef RUY_OFFSET_FLAGS
2518 #undef RUY_OFFSET_DST_TYPE_ID
2519 
2520 #undef RUY_STACK_OFFSET_SIZE
2521 #undef RUY_STACK_OFFSET_DST_COL_PTR
2522 #undef RUY_STACK_OFFSET_DST_PTR
2523 #undef RUY_STACK_OFFSET_ROW
2524 #undef RUY_STACK_OFFSET_COL
2525 #undef RUY_STACK_OFFSET_LHS_COL_PTR
2526 #undef RUY_STACK_OFFSET_RHS_COL_PTR
2527 
2528 #endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
2529 }  // namespace ruy
2530