1 /*
2 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_fully_connected_q7_opt.c
22 * Description: Q7 opt fully-connected layer function (interleaved weights)
23 *
24 * $Date: 17. January 2018
25 * $Revision: V.1.0.0
26 *
27 * Target Processor: Cortex-M cores
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_math.h"
32 #include "arm_nnfunctions.h"
33
34 /**
35 * @ingroup groupNN
36 */
37
38 /**
39 * @addtogroup FC
40 * @{
41 */
42
43 /**
44 * @brief Q7 opt fully-connected layer function
45 * @param[in] pV pointer to input vector
46 * @param[in] pM pointer to matrix weights
47 * @param[in] dim_vec length of the vector
48 * @param[in] num_of_rows number of rows in weight matrix
49 * @param[in] bias_shift amount of left-shift for bias
50 * @param[in] out_shift amount of right-shift for output
51 * @param[in] bias pointer to bias
52 * @param[in,out] pOut pointer to output vector
53 * @param[in,out] vec_buffer pointer to buffer space for input
54 * @return The function returns <code>ARM_MATH_SUCCESS</code>
55 *
56 * @details
57 *
58 * <b>Buffer size:</b>
59 *
60 * vec_buffer size: dim_vec
61 *
62 * This opt function is designed to work with interleaved weight
63 * matrix. The vector input is assumed in q7_t format, we call
64 * arm_q7_to_q15_no_shift_shuffle function to expand into
65 * q15_t format with certain weight re-ordering, refer to the function
66 * comments for more details.
67 * Here we use only one pointer to read 4 rows in the weight
68 * matrix. So if the original q7_t matrix looks like this:
69 *
70 * | a11 | a12 | a13 | a14 | a15 | a16 | a17 |
71 *
72 * | a21 | a22 | a23 | a24 | a25 | a26 | a27 |
73 *
74 * | a31 | a32 | a33 | a34 | a35 | a36 | a37 |
75 *
76 * | a41 | a42 | a43 | a44 | a45 | a46 | a47 |
77 *
78 * | a51 | a52 | a53 | a54 | a55 | a56 | a57 |
79 *
80 * | a61 | a62 | a63 | a64 | a65 | a66 | a67 |
81 *
82 *
83 * We operates on multiple-of-4 rows, so the first four rows becomes
84 *
85 * | a11 | a21 | a13 | a23 | a31 | a41 | a33 | a43 |
86 *
87 * | a12 | a22 | a14 | a24 | a32 | a42 | a34 | a44 |
88 *
89 * | a15 | a25 | a35 | a45 | a16 | a26 | a36 | a46 |
90 *
91 * So within the kernel, we first read the re-ordered vector in as:
92 *
93 * | b1 | b3 | and | b2 | b4 |
94 *
95 * the four q31_t weights will look like
96 *
97 * | a11 | a13 |, | a21 | a23 |, | a31 | a33 |, | a41 | a43 |
98 *
99 * | a12 | a14 |, | a22 | a24 |, | a32 | a34 |, | a42 | a44 |
100 *
101 * The column left over will be in-order.
102 * which is:
103 *
104 * | a17 | a27 | a37 | a47 |
105 *
106 * For the left-over rows, we do 1x1 computation, so the data remains
107 * as its original order.
108 *
109 * So the stored weight matrix looks like this:
110 *
111 * | a11 | a21 | a13 | a23 | a31 | a41 |
112 *
113 * | a33 | a43 | a12 | a22 | a14 | a24 |
114 *
115 * | a32 | a42 | a34 | a44 | a15 | a25 |
116 *
117 * | a35 | a45 | a16 | a26 | a36 | a46 |
118 *
119 * | a17 | a27 | a37 | a47 | a51 | a52 |
120 *
121 * | a53 | a54 | a55 | a56 | a57 | a61 |
122 *
123 * | a62 | a63 | a64 | a65 | a66 | a67 |
124 *
125 *
126 */
127
128 arm_status
arm_fully_connected_q7_opt(const q7_t * pV,const q7_t * pM,const uint16_t dim_vec,const uint16_t num_of_rows,const uint16_t bias_shift,const uint16_t out_shift,const q7_t * bias,q7_t * pOut,q15_t * vec_buffer)129 arm_fully_connected_q7_opt(const q7_t * pV,
130 const q7_t * pM,
131 const uint16_t dim_vec,
132 const uint16_t num_of_rows,
133 const uint16_t bias_shift,
134 const uint16_t out_shift,
135 const q7_t * bias,
136 q7_t * pOut,
137 q15_t * vec_buffer)
138 {
139
140 #if defined (ARM_MATH_DSP)
141 /* Run the following code for Cortex-M4 and Cortex-M7 */
142
143 const q7_t *pB = pM;
144 q7_t *pO = pOut;
145 const q7_t *pBias = bias;
146 q15_t *pA;
147 uint16_t rowCnt = num_of_rows >> 2;
148
149 arm_q7_to_q15_reordered_no_shift(pV, vec_buffer, dim_vec);
150
151 while (rowCnt)
152 {
153
154 q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
155 q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
156 q31_t sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
157 q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
158
159 uint16_t colCnt = dim_vec >> 2;
160
161 pA = vec_buffer;
162
163 #ifdef USE_INTRINSIC
164
165 #ifndef ARM_MATH_BIG_ENDIAN
166 while (colCnt)
167 {
168 q31_t inM11, inM12, inM13, inM14;
169 q31_t inV;
170
171 inV = *__SIMD32(pA)++;
172 inM11 = *__SIMD32(pB)++;
173 inM12 = __SXTB16(__ROR(inM11, 8));
174 inM11 = __SXTB16(inM11);
175 sum = __SMLAD(inM11, inV, sum);
176 sum2 = __SMLAD(inM12, inV, sum2);
177 inM13 = *__SIMD32(pB)++;
178 inM14 = __SXTB16(__ROR(inM13, 8));
179 inM13 = __SXTB16(inM13);
180 sum3 = __SMLAD(inM13, inV, sum3);
181 sum4 = __SMLAD(inM14, inV, sum4);
182
183 inV = *__SIMD32(pA)++;
184 inM11 = *__SIMD32(pB)++;
185 inM12 = __SXTB16(__ROR(inM11, 8));
186 inM11 = __SXTB16(inM11);
187 sum = __SMLAD(inM11, inV, sum);
188 sum2 = __SMLAD(inM12, inV, sum2);
189 inM13 = *__SIMD32(pB)++;
190 inM14 = __SXTB16(__ROR(inM13, 8));
191 inM13 = __SXTB16(inM13);
192 sum3 = __SMLAD(inM13, inV, sum3);
193 sum4 = __SMLAD(inM14, inV, sum4);
194 colCnt--;
195 }
196 #else
197 while (colCnt)
198 {
199 q31_t inM11, inM12, inM13, inM14;
200 q31_t inV;
201
202 inV = *__SIMD32(pA)++;
203 inM11 = *__SIMD32(pB)++;
204 inM12 = __SXTB16(__ROR(inM11, 8));
205 inM11 = __SXTB16(inM11);
206 sum = __SMLAD(inM12, inV, sum);
207 sum2 = __SMLAD(inM11, inV, sum2);
208 inM13 = *__SIMD32(pB)++;
209 inM14 = __SXTB16(__ROR(inM13, 8));
210 inM13 = __SXTB16(inM13);
211 sum3 = __SMLAD(inM14, inV, sum3);
212 sum4 = __SMLAD(inM13, inV, sum4);
213
214 inV = *__SIMD32(pA)++;
215 inM11 = *__SIMD32(pB)++;
216 inM12 = __SXTB16(__ROR(inM11, 8));
217 inM11 = __SXTB16(inM11);
218 sum = __SMLAD(inM12, inV, sum);
219 sum2 = __SMLAD(inM11, inV, sum2);
220 inM13 = *__SIMD32(pB)++;
221 inM14 = __SXTB16(__ROR(inM13, 8));
222 inM13 = __SXTB16(inM13);
223 sum3 = __SMLAD(inM14, inV, sum3);
224 sum4 = __SMLAD(inM13, inV, sum4);
225 colCnt--;
226 }
227 #endif /* ARM_MATH_BIG_ENDIAN */
228
229 #else
230
231 /*
232 * register needed:
233 * loop counter: colCnt
234 * accumulators: sum, sum2, sum3, sum4
235 * pointers: pB, pA
236 * weight data: inM11, inM12, inM13, inM14
237 * activation data: inV
238 */
239
240 #ifndef ARM_MATH_BIG_ENDIAN
241 asm volatile ("COL_LOOP_%=:\n"
242 "ldr.w r4, [%[pA]], #8\n"
243 "ldr.w r1, [%[pB]], #16\n"
244 "mov.w r0, r1, ror #8\n"
245 "sxtb16 r0, r0\n"
246 "sxtb16 r1, r1\n"
247 "smlad %[sum], r4, r1, %[sum]\n"
248 "smlad %[sum2], r4, r0, %[sum2]\n"
249 "ldr.w r3, [%[pB], #-12]\n"
250 "mov.w r2, r3, ror #8\n"
251 "sxtb16 r2, r2\n"
252 "sxtb16 r3, r3\n"
253 "smlad %[sum3], r4, r3, %[sum3]\n"
254 "smlad %[sum4], r4, r2, %[sum4]\n"
255 "ldr.w r4, [%[pA], #-4]\n"
256 "ldr.w r1, [%[pB], #-8]\n"
257 "mov.w r0, r1, ror #8\n"
258 "sxtb16 r0, r0\n"
259 "sxtb16 r1, r1\n"
260 "smlad %[sum], r4, r1, %[sum]\n"
261 "smlad %[sum2], r4, r0, %[sum2]\n"
262 "ldr.w r3, [%[pB], #-4]\n"
263 "mov.w r2, r3, ror #8\n"
264 "sxtb16 r2, r2\n"
265 "sxtb16 r3, r3\n"
266 "smlad %[sum3], r4, r3, %[sum3]\n"
267 "smlad %[sum4], r4, r2, %[sum4]\n"
268 "subs %[colCnt], #1\n"
269 "bne COL_LOOP_%=\n":[sum] "+r"(sum),
270 [sum2] "+r"(sum2),[sum3] "+r"(sum3),
271 [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
272 #else
273 asm volatile ("COL_LOOP_%=:\n"
274 "ldr.w r4, [%[pA]], #8\n"
275 "ldr.w r1, [%[pB]], #16\n"
276 "mov.w r0, r1, ror #8\n"
277 "sxtb16 r0, r0\n"
278 "sxtb16 r1, r1\n"
279 "smlad %[sum], r4, r0, %[sum]\n"
280 "smlad %[sum2], r4, r1, %[sum2]\n"
281 "ldr.w r3, [%[pB], #-12]\n"
282 "mov.w r2, r3, ror #8\n"
283 "sxtb16 r2, r2\n"
284 "sxtb16 r3, r3\n"
285 "smlad %[sum3], r4, r2, %[sum3]\n"
286 "smlad %[sum4], r4, r3, %[sum4]\n"
287 "ldr.w r4, [%[pA], #-4]\n"
288 "ldr.w r1, [%[pB], #-8]\n"
289 "mov.w r0, r1, ror #8\n"
290 "sxtb16 r0, r0\n"
291 "sxtb16 r1, r1\n"
292 "smlad %[sum], r4, r0, %[sum]\n"
293 "smlad %[sum2], r4, r1, %[sum2]\n"
294 "ldr.w r3, [%[pB], #-4]\n"
295 "mov.w r2, r3, ror #8\n"
296 "sxtb16 r2, r2\n"
297 "sxtb16 r3, r3\n"
298 "smlad %[sum3], r4, r2, %[sum3]\n"
299 "smlad %[sum4], r4, r3, %[sum4]\n"
300 "subs %[colCnt], #1\n"
301 "bne COL_LOOP_%=\n":[sum] "+r"(sum),
302 [sum2] "+r"(sum2),[sum3] "+r"(sum3),
303 [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
304 #endif /* ARM_MATH_BIG_ENDIAN */
305
306 #endif /* USE_INTRINSIC */
307
308 colCnt = dim_vec & 0x3;
309 while (colCnt)
310 {
311 q15_t inV = *pA++;
312 q7_t inM = *pB++;
313 q7_t inM2 = *pB++;
314 q7_t inM3 = *pB++;
315 q7_t inM4 = *pB++;
316
317 sum += inV * inM;
318 sum2 += inV * inM2;
319 sum3 += inV * inM3;
320 sum4 += inV * inM4;
321 colCnt--;
322 } /* while over colCnt */
323 *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));
324 *pO++ = (q7_t) (__SSAT((sum2 >> out_shift), 8));
325 *pO++ = (q7_t) (__SSAT((sum3 >> out_shift), 8));
326 *pO++ = (q7_t) (__SSAT((sum4 >> out_shift), 8));
327
328 /* adjust the pointers and counters */
329 rowCnt--;
330 }
331
332 /* left-over part of the rows */
333 rowCnt = num_of_rows & 0x3;
334
335 while (rowCnt)
336 {
337 q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
338 uint16_t colCnt = dim_vec >> 2;
339
340 pA = vec_buffer;
341
342 while (colCnt)
343 {
344 q31_t inV1, inV2, inM11, inM12;
345
346 pB = (q7_t *) read_and_pad_reordered((void *)pB, &inM11, &inM12);
347
348 inV1 = *__SIMD32(pA)++;
349 sum = __SMLAD(inV1, inM11, sum);
350
351 inV2 = *__SIMD32(pA)++;
352 sum = __SMLAD(inV2, inM12, sum);
353
354 colCnt--;
355 }
356
357 /* left-over of the vector */
358 colCnt = dim_vec & 0x3;
359 while (colCnt)
360 {
361 q15_t inV = *pA++;
362 q7_t inM = *pB++;
363 sum += inV * inM;
364 colCnt--;
365 }
366
367 *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));
368
369 rowCnt--;
370 }
371
372 #else
373 /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
374 uint16_t rowCnt = num_of_rows >> 2;
375 const q7_t *pB = pM;
376 const q7_t *pA;
377 q7_t *pO = pOut;
378 const q7_t *pBias = bias;
379
380 while (rowCnt)
381 {
382 q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
383 q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
384 q31_t sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
385 q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
386
387 uint16_t colCnt = dim_vec >> 2;
388
389 pA = pV;
390
391 while (colCnt)
392 {
393 q7_t inA1 = *pA++;
394 q7_t inA3 = *pA++;
395 q7_t inA2 = *pA++;
396 q7_t inA4 = *pA++;
397
398 q7_t inB1 = *pB++;
399 q7_t inB3 = *pB++;
400 q7_t inB2 = *pB++;
401 q7_t inB4 = *pB++;
402
403 sum += inA1 * inB1 + inA2 * inB2;
404 sum2 += inA1 * inB3 + inA2 * inB4;
405
406 inB1 = *pB++;
407 inB3 = *pB++;
408 inB2 = *pB++;
409 inB4 = *pB++;
410
411 sum3 += inA1 * inB1 + inA2 * inB2;
412 sum4 += inA1 * inB3 + inA2 * inB4;
413
414 inB1 = *pB++;
415 inB3 = *pB++;
416 inB2 = *pB++;
417 inB4 = *pB++;
418
419 sum += inA3 * inB1 + inA4 * inB2;
420 sum2 += inA3 * inB3 + inA4 * inB4;
421
422 inB1 = *pB++;
423 inB3 = *pB++;
424 inB2 = *pB++;
425 inB4 = *pB++;
426
427 sum3 += inA3 * inB1 + inA4 * inB2;
428 sum4 += inA3 * inB3 + inA4 * inB4;
429
430 colCnt--;
431 }
432 colCnt = dim_vec & 0x3;
433 while (colCnt)
434 {
435 q7_t inA = *pA++;
436 q7_t inB = *pB++;
437 sum += inA * inB;
438 inB = *pB++;
439 sum2 += inA * inB;
440 inB = *pB++;
441 sum3 += inA * inB;
442 inB = *pB++;
443 sum4 += inA * inB;
444
445 colCnt--;
446 }
447 *pO++ = (q7_t) __SSAT((sum >> out_shift), 8);
448 *pO++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
449 *pO++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
450 *pO++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
451
452 rowCnt--;
453 }
454
455 rowCnt = num_of_rows & 0x3;
456
457 while (rowCnt)
458 {
459 int ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
460
461 int j;
462
463 pA = pV;
464 for (j = 0; j < dim_vec; j++)
465 {
466 q7_t inA = *pA++;
467 q7_t inB = *pB++;
468 ip_out += inA * inB;
469 }
470 *pO++ = (q7_t) __SSAT((ip_out >> out_shift), 8);
471
472 rowCnt--;
473 }
474
475 #endif /* ARM_MATH_DSP */
476
477 /* Return to ARM_MATH_SUCCESS */
478 return (ARM_MATH_SUCCESS);
479
480 }
481
482 /**
483 * @} end of FC group
484 */
485