xref: /aosp_15_r20/external/gemmlowp/internal/kernel_msa.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1 // Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // kernel_msa.h: a collection of MSA optimized kernels.
16 // Check in kernel_default.h which one(s) are actually used by default.
17 // Others are mere experiments; they are still covered by tests
18 // in case they might be useful some day.
19 
20 #ifndef GEMMLOWP_INTERNAL_KERNEL_MSA_H_
21 #define GEMMLOWP_INTERNAL_KERNEL_MSA_H_
22 
23 #include "kernel.h"
24 
25 #include <msa.h>
26 #include <cassert>
27 
28 namespace gemmlowp {
29 
30 #ifdef GEMMLOWP_MSA
31 
32 // Some convenience macros to hide differences between MIPS32 and MIPS64.
33 #ifdef GEMMLOWP_MIPS_64
34 #define GEMMLOWP_MIPS_XADDU "daddu"
35 #define GEMMLOWP_MIPS_XADDIU "daddiu"
36 #define GEMMLOWP_MIPS_XSLL "dsll"
37 #else
38 #define GEMMLOWP_MIPS_XADDU "addu"
39 #define GEMMLOWP_MIPS_XADDIU "addiu"
40 #define GEMMLOWP_MIPS_XSLL "sll"
41 #endif
42 
43 // Our main GEMM kernel.
44 struct MSA_Kernel12x8Depth2 : KernelBase {
45   typedef KernelFormat<KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>,
46                        KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2> >
47       Format;
48 
NameMSA_Kernel12x8Depth249   const char* Name() const override { return "MSA, 12x8, depth 2"; }
50 
51   // TODO(benoitjacob): reorder function arguments so dst comes last
RunMSA_Kernel12x8Depth252   void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
53            std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
54            const std::uint8_t* rhs_ptr, std::size_t start_depth,
55            std::size_t run_depth) const override {
56     ScopedProfilingLabel label("optimized kernel (MSA 12x8)");
57 // See comments above for why we need local numerical labels in our asm.
58 #define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
59 #define GEMMLOWP_LABEL_BEFORE_LOOP "2"
60 #define GEMMLOWP_LABEL_LOOP "3"
61 #define GEMMLOWP_LABEL_AFTER_LOOP "4"
62 
63     assert(dst_row_stride == 1);
64     asm volatile(
65         // Multiply dst_col_stride by 4 == sizeof(int32) to use
66         // it as a byte offset below.
67         GEMMLOWP_MIPS_XSLL
68         " %[dst_col_stride], %[dst_col_stride], 2\n"
69 
70         // Check if start_depth==0 to decide whether we will clear
71         // accumulators or load existing accumulators.
72         "beqz   %[start_depth], " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"
73 
74         // Load accumulators (start_depth != 0).
75         GEMMLOWP_MIPS_XADDU " $a0, %[dst_ptr], %[dst_col_stride]\n"
76         "ld.w   $w0,  (0*16)(%[dst_ptr])\n"
77         "ld.w   $w4,  (1*16)(%[dst_ptr])\n"
78         "ld.w   $w8,  (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
79         "ld.w   $w1,  (0*16)($a0)\n"
80         "ld.w   $w5,  (1*16)($a0)\n"
81         "ld.w   $w9,  (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
82         "ld.w   $w2,  (0*16)($a1)\n"
83         "ld.w   $w6,  (1*16)($a1)\n"
84         "ld.w   $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
85         "ld.w   $w3,  (0*16)($a0)\n"
86         "ld.w   $w7,  (1*16)($a0)\n"
87         "ld.w   $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
88         "ld.w   $w12, (0*16)($a1)\n"
89         "ld.w   $w16, (1*16)($a1)\n"
90         "ld.w   $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
91         "ld.w   $w13, (0*16)($a0)\n"
92         "ld.w   $w17, (1*16)($a0)\n"
93         "ld.w   $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
94         "ld.w   $w14, (0*16)($a1)\n"
95         "ld.w   $w18, (1*16)($a1)\n"
96         "ld.w   $w22, (2*16)($a1)\n"
97         "ld.w   $w15, (0*16)($a0)\n"
98         "ld.w   $w19, (1*16)($a0)\n"
99         "ld.w   $w23, (2*16)($a0)\n"
100         "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
101 
102         GEMMLOWP_LABEL_CLEAR_ACCUMULATORS ":\n"
103         // Clear accumulators (start_depth == 0).
104         "ldi.w  $w0,  0\n"
105         "ldi.w  $w4,  0\n"
106         "ldi.w  $w8,  0\n"
107         "ldi.w  $w1,  0\n"
108         "ldi.w  $w5,  0\n"
109         "ldi.w  $w9,  0\n"
110         "ldi.w  $w2,  0\n"
111         "ldi.w  $w6,  0\n"
112         "ldi.w  $w10, 0\n"
113         "ldi.w  $w3,  0\n"
114         "ldi.w  $w7,  0\n"
115         "ldi.w  $w11, 0\n"
116         "ldi.w  $w12, 0\n"
117         "ldi.w  $w16, 0\n"
118         "ldi.w  $w20, 0\n"
119         "ldi.w  $w13, 0\n"
120         "ldi.w  $w17, 0\n"
121         "ldi.w  $w21, 0\n"
122         "ldi.w  $w14, 0\n"
123         "ldi.w  $w18, 0\n"
124         "ldi.w  $w22, 0\n"
125         "ldi.w  $w15, 0\n"
126         "ldi.w  $w19, 0\n"
127         "ldi.w  $w23, 0\n"
128 
129         GEMMLOWP_LABEL_BEFORE_LOOP ":\n"
130 
131         GEMMLOWP_LABEL_LOOP ":\n"
132         // Overview of register layout:
133         //
134         // A half of the 2 2x4 cells of Rhs is stored in 16bit in w28-w31
135         // (each register contains 4 replicas of a pair of elements).
136         // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
137         // A 12x8 block of accumulators is stored in 32bit in w0-w23.
138         //
139         //                    +------+------+------+------+
140         //               Rhs  |w28   |w29   |w30   |w31   |
141         //                    +------+------+------+------+
142         //
143         //                    |      |      |      |      |
144         //
145         //       Lhs          |      |      |      |      |
146         //
147         //      +---+ - - - - +------+------+------+------+
148         //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
149         //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
150         //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
151         //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
152         //      +---+ - - - - +------+------+------+------+
153         //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
154         //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
155         //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
156         //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
157         //      +---+ - - - - +------+------+------+------+
158         //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
159         //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
160         //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
161         //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
162         //      +---+ - - - - +------+------+------+------+
163         //
164         //                             Accumulators
165 
166         // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
167         "ld.b   $w24, 0(%[lhs_ptr])\n"
168         "ld.b   $w25, 8(%[lhs_ptr])\n"
169 
170         // Load 2 x 8 bytes of rhs[].
171         "ld.b   $w27, 0(%[rhs_ptr])\n"
172 
173         // Zero-extend 8-bit elements of lhs[] to 16 bits.
174         "ldi.b  $w31, 0\n"
175         "ilvr.b $w24, $w31, $w24\n"
176         "ilvl.b $w26, $w31, $w25\n"
177         "ilvr.b $w25, $w31, $w25\n"
178 
179         // First half of depths 0 and 1.
180         // Zero-extend 8-bit elements of rhs[] to 16 bits.
181         "ilvr.b    $w31, $w31, $w27\n"
182         // Make 4 replicas of every pair of rhs[] elements.
183         "splati.w  $w28, $w31[0]\n"
184         "splati.w  $w29, $w31[1]\n"
185         "splati.w  $w30, $w31[2]\n"
186         "splati.w  $w31, $w31[3]\n"
187         // Dot-product-(and)-add doubles multiplicand width.
188         "dpadd_u.w  $w0, $w24, $w28\n"
189         "dpadd_u.w  $w4, $w25, $w28\n"
190         "dpadd_u.w  $w8, $w26, $w28\n"
191         "dpadd_u.w  $w1, $w24, $w29\n"
192         "dpadd_u.w  $w5, $w25, $w29\n"
193         "dpadd_u.w  $w9, $w26, $w29\n"
194         "dpadd_u.w  $w2, $w24, $w30\n"
195         "dpadd_u.w  $w6, $w25, $w30\n"
196         "dpadd_u.w $w10, $w26, $w30\n"
197         "dpadd_u.w  $w3, $w24, $w31\n"
198         "dpadd_u.w  $w7, $w25, $w31\n"
199         "dpadd_u.w $w11, $w26, $w31\n"
200 
201         // Second half of depths 0 and 1.
202         // Zero-extend 8-bit elements of rhs[] to 16 bits.
203         "ldi.b     $w31, 0\n"
204         "ilvl.b    $w31, $w31, $w27\n"
205         // Make 4 replicas of every pair of rhs[] elements.
206         "splati.w  $w28, $w31[0]\n"
207         "splati.w  $w29, $w31[1]\n"
208         "splati.w  $w30, $w31[2]\n"
209         "splati.w  $w31, $w31[3]\n"
210         // Dot-product-(and)-add doubles multiplicand width.
211         "dpadd_u.w $w12, $w24, $w28\n"
212         "dpadd_u.w $w16, $w25, $w28\n"
213         "dpadd_u.w $w20, $w26, $w28\n"
214         "dpadd_u.w $w13, $w24, $w29\n"
215         "dpadd_u.w $w17, $w25, $w29\n"
216         "dpadd_u.w $w21, $w26, $w29\n"
217         "dpadd_u.w $w14, $w24, $w30\n"
218         "dpadd_u.w $w18, $w25, $w30\n"
219         "dpadd_u.w $w22, $w26, $w30\n"
220         "dpadd_u.w $w15, $w24, $w31\n"
221         "dpadd_u.w $w19, $w25, $w31\n"
222         "dpadd_u.w $w23, $w26, $w31\n"
223 
224         GEMMLOWP_MIPS_XADDIU " %[run_depth], -2\n" GEMMLOWP_MIPS_XADDIU
225         " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
226         "bnez   %[run_depth]," GEMMLOWP_LABEL_LOOP "b\n"
227 
228         GEMMLOWP_LABEL_AFTER_LOOP ":\n"
229 
230         // Store accumulators.
231         GEMMLOWP_MIPS_XADDU " $a0, %[dst_ptr], %[dst_col_stride]\n"
232         "st.w   $w0,  (0*16)(%[dst_ptr])\n"
233         "st.w   $w4,  (1*16)(%[dst_ptr])\n"
234         "st.w   $w8,  (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
235         "st.w   $w1,  (0*16)($a0)\n"
236         "st.w   $w5,  (1*16)($a0)\n"
237         "st.w   $w9,  (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
238         "st.w   $w2,  (0*16)($a1)\n"
239         "st.w   $w6,  (1*16)($a1)\n"
240         "st.w   $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
241         "st.w   $w3,  (0*16)($a0)\n"
242         "st.w   $w7,  (1*16)($a0)\n"
243         "st.w   $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
244         "st.w   $w12, (0*16)($a1)\n"
245         "st.w   $w16, (1*16)($a1)\n"
246         "st.w   $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
247         "st.w   $w13, (0*16)($a0)\n"
248         "st.w   $w17, (1*16)($a0)\n"
249         "st.w   $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
250         "st.w   $w14, (0*16)($a1)\n"
251         "st.w   $w18, (1*16)($a1)\n"
252         "st.w   $w22, (2*16)($a1)\n"
253         "st.w   $w15, (0*16)($a0)\n"
254         "st.w   $w19, (1*16)($a0)\n"
255         "st.w   $w23, (2*16)($a0)\n"
256         :  // outputs
257         [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [run_depth] "+r"(run_depth),
258         [dst_col_stride] "+r"(dst_col_stride)
259         :  // inputs
260         [dst_ptr] "r"(dst_ptr),
261         [start_depth] "r"(start_depth)
262         :  // clobbers
263         "memory", "a0", "a1", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9",
264         "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
265         "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
266 
267 #undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
268 #undef GEMMLOWP_LABEL_BEFORE_LOOP
269 #undef GEMMLOWP_LABEL_LOOP
270 #undef GEMMLOWP_LABEL_AFTER_LOOP
271   }
272 };
273 
// Fast kernel operating on int8 operands.
// It is assumed that one of the two int8 operands only takes values
// in [-127, 127], while the other may freely range in [-128, 127].
// The issue with both operands taking the value -128 is that
// -128*-128 + -128*-128 == 32768, which overflows int16 (max 32767).
// Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
// range. That is the basic idea of this kernel.
281 struct MSA_GEMM_Int8Operands_LhsNonzero : KernelBase {
282   typedef KernelFormat<
283       KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
284       KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
285       Format;
286 
NameMSA_GEMM_Int8Operands_LhsNonzero287   const char* Name() const override {
288     return "MSA, 4x4, depth 16, accumulating two within signed int16";
289   }
290 
291   // TODO(benoitjacob): reorder function arguments so dst comes last
RunMSA_GEMM_Int8Operands_LhsNonzero292   void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
293            std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
294            const std::uint8_t* rhs_ptr, std::size_t start_depth,
295            std::size_t run_depth) const override {
296     (void)dst_row_stride;
297 #define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1"
298 #define GEMMLOWP_LABEL_LOOP "2"
299 #define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
300 #define GEMMLOWP_LABEL_STORE "4"
301     asm volatile(
302         GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
303         // Load lhs[] and rhs[], zero out internal accumulators.
304         "ld.b       $w16, 0(%[lhs_ptr])\n"
305         "ldi.b      $w0, 0\n"
306         "ld.b       $w20, 0(%[rhs_ptr])\n"
307         "ldi.b      $w1, 0\n"
308         "ld.b       $w17, 16(%[lhs_ptr])\n"
309         "ldi.b      $w2, 0\n"
310         "ld.b       $w21, 16(%[rhs_ptr])\n"
311         "ldi.b      $w3, 0\n"
312         "ld.b       $w18, 32(%[lhs_ptr])\n"
313         "ldi.b      $w4, 0\n"
314         "ld.b       $w19, 48(%[lhs_ptr])\n"
315         "ldi.b      $w5, 0\n"
316         "ld.b       $w22, 32(%[rhs_ptr])\n"
317         "ldi.b      $w6, 0\n"
318         "ld.b       $w23, 48(%[rhs_ptr])\n"
319         "ldi.b      $w7, 0\n"
320         "ldi.b      $w8, 0\n"
321         "ldi.b      $w9, 0\n"
322         "ldi.b      $w10, 0\n"
323         "ldi.b      $w11, 0\n"
324         "ldi.b      $w12, 0\n"
325         "ldi.b      $w13, 0\n"
326         "ldi.b      $w14, 0\n"
327         "ldi.b      $w15, 0\n"
328         "ldi.h      $w31, 1\n"
329         // If the loop depth is only 16, then we can skip the general loop
330         // and go straight to the final part of the code.
331         "beqz %[run_depth], " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n"
332 
333         GEMMLOWP_LABEL_LOOP ":\n"
334         // Overview of register layout:
335         //
336         // A 4x16 block of Rhs is stored in 8 bit in w16-w19.
337         // A 4x16 block of Lhs is stored in 8 bit in w20-w23.
338         //
339         // A 4x4 block of accumulators is stored in w0-w15 (as 4x32 bit
340         // components which need to be horizontally added at the end).
341         //
342         // Dot products of Lhs and Rhs are 16-bit values, which can't
343         // immediately be accumulated in 32-bit accumulators by that
344         // same instruction that calculates them.
345         // For example, "dotp_s.h $w25, $w16, $w20" produces 8 16-bit
346         // sums in w25 (note, the 16 sums have already been reduced to 8
347         // by the horizontal addition of the dotp instruction).
348         // They are then sign-extended to 32 bits, horizontally added
349         // (again) to form 4 32-bit sums and then they are finally added
350         // to the 32-bit accumulators, all by "dpadd_s.w $w0, $w25, $w31".
351         //
352         //                    +-----+-----+-----+-----+
353         //               Rhs  | w20 | w21 | w22 | w23 |
354         //                    +-----+-----+-----+-----+
355         //
356         //                    |     |     |     |     |
357         //
358         //       Lhs          |     |     |     |     |
359         //
360         //      +---+ - - - - +-----+-----+-----+-----+
361         //      |w16|         | w0  | w4  | w8  | w12 |
362         //      |w17|         | w1  | w5  | w9  | w13 |
363         //      |w18|         | w2  | w6  | w10 | w14 |
364         //      |w19|         | w3  | w7  | w11 | w15 |
365         //      +---+ - - - - +-----+-----+-----+-----+
366         //
367         //                           Accumulators
368 
369         // Calculate the results for 16 depths and load
370         // lhs[] and rhs[] for the next iteration.
371         GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 64\n"
372         GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 64\n"
373         GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
374 
375         // Dot product: multiply-add pairs of adjacent int8 elements.
376         // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
377         "dotp_s.h   $w25, $w16, $w20\n"
378         "dotp_s.h   $w26, $w17, $w20\n"
379         "dotp_s.h   $w27, $w16, $w21\n"
380         "dotp_s.h   $w28, $w17, $w21\n"
381         "dotp_s.h   $w29, $w18, $w20\n"
382         // Horizontal add of pairs of adjacent int16 sums into internal int32
383         // accumulators.
384         "dpadd_s.w  $w0, $w25, $w31\n"
385         "dpadd_s.w  $w1, $w26, $w31\n"
386         "dpadd_s.w  $w4, $w27, $w31\n"
387         "dpadd_s.w  $w5, $w28, $w31\n"
388         "dpadd_s.w  $w2, $w29, $w31\n"
389 
390         // Dot product: multiply-add pairs of adjacent int8 elements.
391         // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
392         "dotp_s.h   $w24, $w16, $w22\n"
393         "dotp_s.h   $w25, $w19, $w20\n"
394         "dotp_s.h   $w26, $w16, $w23\n"
395         "dotp_s.h   $w27, $w17, $w22\n"
396         "ld.b       $w20, 0(%[rhs_ptr])\n"
397         "dotp_s.h   $w28, $w17, $w23\n"
398         "ld.b       $w16, 0(%[lhs_ptr])\n"
399         "dotp_s.h   $w29, $w18, $w21\n"
400         "ld.b       $w17, 16(%[lhs_ptr])\n"
401         // Horizontal add of pairs of adjacent int16 sums into internal int32
402         // accumulators.
403         "dpadd_s.w  $w8, $w24, $w31\n"
404         "dpadd_s.w  $w3, $w25, $w31\n"
405         "dpadd_s.w  $w12, $w26, $w31\n"
406         "dpadd_s.w  $w9, $w27, $w31\n"
407         "dpadd_s.w  $w13, $w28, $w31\n"
408         "dpadd_s.w  $w6, $w29, $w31\n"
409 
410         // Dot product: multiply-add pairs of adjacent int8 elements.
411         // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
412         "dotp_s.h   $w25, $w19, $w21\n"
413         "dotp_s.h   $w26, $w18, $w22\n"
414         "dotp_s.h   $w27, $w18, $w23\n"
415         "ld.b       $w21, 16(%[rhs_ptr])\n"
416         "dotp_s.h   $w28, $w19, $w22\n"
417         "ld.b       $w18, 32(%[lhs_ptr])\n"
418         "dotp_s.h   $w29, $w19, $w23\n"
419         "ld.b       $w22, 32(%[rhs_ptr])\n"
420         // Horizontal add of pairs of adjacent int16 sums into internal int32
421         // accumulators.
422         "dpadd_s.w  $w7, $w25, $w31\n"
423         "ld.b       $w19, 48(%[lhs_ptr])\n"
424         "dpadd_s.w  $w10, $w26, $w31\n"
425         "ld.b       $w23, 48(%[rhs_ptr])\n"
426         "dpadd_s.w  $w14, $w27, $w31\n"
427         "dpadd_s.w  $w11, $w28, $w31\n"
428         "dpadd_s.w  $w15, $w29, $w31\n"
429 
430         "bnez %[run_depth], " GEMMLOWP_LABEL_LOOP "b\n"
431 
432         GEMMLOWP_LABEL_AFTER_LOOP_LAST16 ":\n"
433         // Calculate the results for the last 16 depths.
434 
435         // Dot product: multiply-add pairs of adjacent int8 elements.
436         // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
437         "dotp_s.h   $w25, $w16, $w20\n"
438         "dotp_s.h   $w26, $w17, $w20\n"
439         "dotp_s.h   $w27, $w16, $w21\n"
440         "dotp_s.h   $w28, $w17, $w21\n"
441         "dotp_s.h   $w29, $w18, $w20\n"
442         // Horizontal add of pairs of adjacent int16 sums into internal int32
443         // accumulators.
444         "dpadd_s.w  $w0, $w25, $w31\n"
445         "dpadd_s.w  $w1, $w26, $w31\n"
446         "dpadd_s.w  $w4, $w27, $w31\n"
447         "dpadd_s.w  $w5, $w28, $w31\n"
448         "dpadd_s.w  $w2, $w29, $w31\n"
449 
450         // Dot product: multiply-add pairs of adjacent int8 elements.
451         // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
452         "dotp_s.h   $w24, $w16, $w22\n"
453         "dotp_s.h   $w25, $w19, $w20\n"
454         "dotp_s.h   $w26, $w16, $w23\n"
455         "dotp_s.h   $w27, $w17, $w22\n"
456         "dotp_s.h   $w28, $w17, $w23\n"
457         "dotp_s.h   $w29, $w18, $w21\n"
458         // Horizontal add of pairs of adjacent int16 sums into internal int32
459         // accumulators.
460         "dpadd_s.w  $w8, $w24, $w31\n"
461         "dpadd_s.w  $w3, $w25, $w31\n"
462         "dpadd_s.w  $w12, $w26, $w31\n"
463         "dpadd_s.w  $w9, $w27, $w31\n"
464         "dpadd_s.w  $w13, $w28, $w31\n"
465         "dpadd_s.w  $w6, $w29, $w31\n"
466 
467         // Dot product: multiply-add pairs of adjacent int8 elements.
468         // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
469         "dotp_s.h   $w25, $w19, $w21\n"
470         "dotp_s.h   $w26, $w18, $w22\n"
471         "dotp_s.h   $w27, $w18, $w23\n"
472         "dotp_s.h   $w28, $w19, $w22\n"
473         "dotp_s.h   $w29, $w19, $w23\n"
474         // Horizontal add of pairs of adjacent int16 sums into internal int32
475         // accumulators.
476         "dpadd_s.w  $w7, $w25, $w31\n"
477         "dpadd_s.w  $w10, $w26, $w31\n"
478         "dpadd_s.w  $w14, $w27, $w31\n"
479         "dpadd_s.w  $w11, $w28, $w31\n"
480         "dpadd_s.w  $w15, $w29, $w31\n"
481 
482         // Horizontal-add internal accumulators.
483         "hadd_s.d   $w0, $w0, $w0\n"
484         "hadd_s.d   $w1, $w1, $w1\n"
485         "hadd_s.d   $w2, $w2, $w2\n"
486         "hadd_s.d   $w3, $w3, $w3\n"
487         "hadd_s.d   $w4, $w4, $w4\n"
488         "hadd_s.d   $w5, $w5, $w5\n"
489         "hadd_s.d   $w6, $w6, $w6\n"
490         "hadd_s.d   $w7, $w7, $w7\n"
491         "hadd_s.d   $w8, $w8, $w8\n"
492         "hadd_s.d   $w9, $w9, $w9\n"
493         "hadd_s.d   $w10, $w10, $w10\n"
494         "hadd_s.d   $w11, $w11, $w11\n"
495         "hadd_s.d   $w12, $w12, $w12\n"
496         "hadd_s.d   $w13, $w13, $w13\n"
497         "hadd_s.d   $w14, $w14, $w14\n"
498         "hadd_s.d   $w15, $w15, $w15\n"
499         "pckev.w    $w0, $w1, $w0\n"
500         "pckev.w    $w2, $w3, $w2\n"
501         "pckev.w    $w4, $w5, $w4\n"
502         "pckev.w    $w6, $w7, $w6\n"
503         "pckev.w    $w8, $w9, $w8\n"
504         "pckev.w    $w10, $w11, $w10\n"
505         "pckev.w    $w12, $w13, $w12\n"
506         "pckev.w    $w14, $w15, $w14\n"
507         "hadd_s.d   $w0, $w0, $w0\n"
508         "hadd_s.d   $w2, $w2, $w2\n"
509         "hadd_s.d   $w4, $w4, $w4\n"
510         "hadd_s.d   $w6, $w6, $w6\n"
511         "hadd_s.d   $w8, $w8, $w8\n"
512         "hadd_s.d   $w10, $w10, $w10\n"
513         "hadd_s.d   $w12, $w12, $w12\n"
514         "hadd_s.d   $w14, $w14, $w14\n"
515         // 4 more pckev instructions follow in both paths below.
516 
517         // Check if start_depth==0 to decide whether we will load
518         // existing accumulators from memory.
519         "bnez %[start_depth], " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "f\n"
520 
521         "pckev.w    $w0, $w2, $w0\n"
522         "pckev.w    $w1, $w6, $w4\n"
523         "pckev.w    $w2, $w10, $w8\n"
524         "pckev.w    $w3, $w14, $w12\n"
525 
526         "b " GEMMLOWP_LABEL_STORE "f\n"
527 
528         GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES ":\n"
529         // Load accumulators from memory.
530         "ld.w       $w16, 0(%[dst_ptr0])\n"
531         "pckev.w    $w0, $w2, $w0\n"
532         "ld.w       $w17, 0(%[dst_ptr1])\n"
533         "pckev.w    $w1, $w6, $w4\n"
534         "ld.w       $w18, 0(%[dst_ptr2])\n"
535         "pckev.w    $w2, $w10, $w8\n"
536         "ld.w       $w19, 0(%[dst_ptr3])\n"
537         "pckev.w    $w3, $w14, $w12\n"
538 
539         // Add them to internal accumulators.
540         "addv.w     $w0, $w0, $w16\n"
541         "addv.w     $w1, $w1, $w17\n"
542         "addv.w     $w2, $w2, $w18\n"
543         "addv.w     $w3, $w3, $w19\n"
544 
545         GEMMLOWP_LABEL_STORE ":\n"
546         // Store accumulators.
547         "st.w       $w0, 0(%[dst_ptr0])\n"
548         "st.w       $w1, 0(%[dst_ptr1])\n"
549         "st.w       $w2, 0(%[dst_ptr2])\n"
550         "st.w       $w3, 0(%[dst_ptr3])\n"
551         :  // outputs
552         [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
553         [run_depth] "+r"(run_depth)
554         :  // inputs
555         [dst_ptr0] "r"(dst_ptr), [dst_ptr1] "r"(dst_ptr + dst_col_stride),
556         [dst_ptr2] "r"(dst_ptr + dst_col_stride * 2),
557         [dst_ptr3] "r"(dst_ptr + dst_col_stride * 3),
558         [start_depth] "r"(start_depth)
559         :  // clobbers
560         "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8",
561         "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17",
562         "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26",
563         "$f27", "$f28", "$f29", "$f30", "$f31");
564 #undef GEMMLOWP_LABEL_LOOP
565 #undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16
566 #undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
567 #undef GEMMLOWP_LABEL_STORE
568   }
569 };
570 
// The pointer-width mnemonic helpers are only needed by the kernels above;
// drop them so they do not leak into code that includes this header.
#undef GEMMLOWP_MIPS_XSLL
#undef GEMMLOWP_MIPS_XADDIU
#undef GEMMLOWP_MIPS_XADDU
574 
575 #endif  // GEMMLOWP_MSA
576 
577 }  // namespace gemmlowp
578 
579 #endif  // GEMMLOWP_INTERNAL_KERNEL_MSA_H_
580