/*
 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_fully_connected_q7_opt.c
 * Description:  Q7 opt fully-connected layer function
 *
 * $Date:        17. January 2018
 * $Revision:    V.1.0.0
 *
 * Target Processor:  Cortex-M cores
 *
 * -------------------------------------------------------------------- */

#include "arm_math.h"
#include "arm_nnfunctions.h"

/**
 *  @ingroup groupNN
 */

/**
 * @addtogroup FC
 * @{
 */

  /**
   * @brief Q7 opt fully-connected layer function
   * @param[in]       pV          pointer to input vector
   * @param[in]       pM          pointer to matrix weights
   * @param[in]       dim_vec     length of the vector
   * @param[in]       num_of_rows number of rows in weight matrix
   * @param[in]       bias_shift  amount of left-shift for bias
   * @param[in]       out_shift   amount of right-shift for output
   * @param[in]       bias        pointer to bias
   * @param[in,out]   pOut        pointer to output vector
   * @param[in,out]   vec_buffer  pointer to buffer space for input
   * @return     The function returns <code>ARM_MATH_SUCCESS</code>
   *
   * @details
   *
   * <b>Buffer size:</b>
   *
   * vec_buffer size: dim_vec
   *
   * This opt function is designed to work with an interleaved weight
   * matrix. The input vector is assumed to be in q7_t format; we call
   *  the arm_q7_to_q15_reordered_no_shift function to expand it into
   *  q15_t format with a pair-wise re-ordering that matches the
   *  interleaved weights. Refer to that function's comments for details.
   *  Here we use only one pointer to read 4 rows of the weight
   *  matrix. So if the original q7_t matrix looks like this:
   *
   *  | a11 | a12 | a13 | a14 | a15 | a16 | a17 |
   *
   *  | a21 | a22 | a23 | a24 | a25 | a26 | a27 |
   *
   *  | a31 | a32 | a33 | a34 | a35 | a36 | a37 |
   *
   *  | a41 | a42 | a43 | a44 | a45 | a46 | a47 |
   *
   *  | a51 | a52 | a53 | a54 | a55 | a56 | a57 |
   *
   *  | a61 | a62 | a63 | a64 | a65 | a66 | a67 |
   *
   *
   *  We operate on the rows four at a time, so the first four rows become
   *
   *  | a11 | a21 | a13 | a23 | a31 | a41 | a33 | a43 |
   *
   *  | a12 | a22 | a14 | a24 | a32 | a42 | a34 | a44 |
   *
   *  | a15 | a25 | a35 | a45 | a16 | a26 | a36 | a46 |
   *
   *  Within the kernel, we first read the re-ordered vector in as:
   *
   *  | b1  | b3  | and | b2  | b4  |
   *
   *  and the four q31_t weight reads will look like
   *
   *  | a11 | a13 |, | a21 | a23 |, | a31 | a33 |, | a41 | a43 |
   *
   *  | a12 | a14 |, | a22 | a24 |, | a32 | a34 |, | a42 | a44 |
   *
   *  The left-over columns are stored in order, 4 row entries per
   *  column, ending with:
   *
   *  | a17 | a27 | a37 | a47 |
   *
   *  For the left-over rows, we do 1x1 computation, so the data remains
   *  in its original order.
   *
   *  So the stored weight matrix looks like this:
   *
   *  | a11 | a21 | a13 | a23 | a31 | a41 |
   *
   *  | a33 | a43 | a12 | a22 | a14 | a24 |
   *
   *  | a32 | a42 | a34 | a44 | a15 | a25 |
   *
   *  | a35 | a45 | a16 | a26 | a36 | a46 |
   *
   *  | a17 | a27 | a37 | a47 | a51 | a52 |
   *
   *  | a53 | a54 | a55 | a56 | a57 | a61 |
   *
   *  | a62 | a63 | a64 | a65 | a66 | a67 |
   *
   *
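   *  <b>Example:</b>
   *
   *  A minimal usage sketch. IP_DIM, OUT_DIM, the shift values and the
   *  buffers below are illustrative placeholders, not part of the library;
   *  the weight array is assumed to have been re-ordered offline into the
   *  interleaved layout shown above.
   *  <pre>
   *      #define IP_DIM     128
   *      #define OUT_DIM     10
   *      #define BIAS_SHIFT   0   // placeholder value
   *      #define OUT_SHIFT    7   // placeholder value
   *
   *      extern const q7_t weights[OUT_DIM * IP_DIM];  // interleaved ("opt") layout
   *      extern const q7_t biases[OUT_DIM];
   *      q7_t  input[IP_DIM];
   *      q7_t  output[OUT_DIM];
   *      q15_t scratch[IP_DIM];                        // vec_buffer: dim_vec entries
   *
   *      arm_fully_connected_q7_opt(input, weights, IP_DIM, OUT_DIM,
   *                                 BIAS_SHIFT, OUT_SHIFT, biases,
   *                                 output, scratch);
   *  </pre>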
   */

arm_status
arm_fully_connected_q7_opt(const q7_t * pV,
                           const q7_t * pM,
                           const uint16_t dim_vec,
                           const uint16_t num_of_rows,
                           const uint16_t bias_shift,
                           const uint16_t out_shift,
                           const q7_t * bias,
                           q7_t * pOut,
                           q15_t * vec_buffer)
{

#if defined (ARM_MATH_DSP)
    /* Run the following code for Cortex-M4 and Cortex-M7 */

    const q7_t *pB = pM;
    q7_t     *pO = pOut;
    const q7_t *pBias = bias;
    q15_t    *pA;
    uint16_t  rowCnt = num_of_rows >> 2;

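    /* Expand the q7_t input vector into the q15_t scratch buffer, using the
     * pair-wise shuffle that matches the interleaved weight layout described
     * above. */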
    arm_q7_to_q15_reordered_no_shift(pV, vec_buffer, dim_vec);

    while (rowCnt)
    {

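        /* Initialize the four row accumulators with the left-shifted bias
         * plus the NN_ROUND(out_shift) rounding term used by the final
         * right shift. */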
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t  colCnt = dim_vec >> 2;

        pA = vec_buffer;

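        /* Inner loop over the input: each iteration consumes 4 q15 activations
         * (two 32-bit reads from pA) and 16 q7 weights (four 32-bit reads from
         * pB, i.e. 4 columns for each of the 4 rows). __SXTB16 sign-extends
         * bytes 0 and 2 of a word to q15, and __SXTB16(__ROR(x, 8)) picks up
         * bytes 1 and 3, so each weight word is expanded into two q15 pairs
         * that line up with the re-ordered activation buffer before the
         * dual-MAC __SMLAD accumulation. The inline-assembly variant below
         * performs the same sequence with ldr/sxtb16/ror/smlad. */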
#ifdef USE_INTRINSIC

#ifndef ARM_MATH_BIG_ENDIAN
        while (colCnt)
        {
            q31_t     inM11, inM12, inM13, inM14;
            q31_t     inV;

            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            inM12 = __SXTB16(__ROR(inM11, 8));
            inM11 = __SXTB16(inM11);
            sum = __SMLAD(inM11, inV, sum);
            sum2 = __SMLAD(inM12, inV, sum2);
            inM13 = *__SIMD32(pB)++;
            inM14 = __SXTB16(__ROR(inM13, 8));
            inM13 = __SXTB16(inM13);
            sum3 = __SMLAD(inM13, inV, sum3);
            sum4 = __SMLAD(inM14, inV, sum4);

            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            inM12 = __SXTB16(__ROR(inM11, 8));
            inM11 = __SXTB16(inM11);
            sum = __SMLAD(inM11, inV, sum);
            sum2 = __SMLAD(inM12, inV, sum2);
            inM13 = *__SIMD32(pB)++;
            inM14 = __SXTB16(__ROR(inM13, 8));
            inM13 = __SXTB16(inM13);
            sum3 = __SMLAD(inM13, inV, sum3);
            sum4 = __SMLAD(inM14, inV, sum4);
            colCnt--;
        }
#else
        while (colCnt)
        {
            q31_t     inM11, inM12, inM13, inM14;
            q31_t     inV;

            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            inM12 = __SXTB16(__ROR(inM11, 8));
            inM11 = __SXTB16(inM11);
            sum = __SMLAD(inM12, inV, sum);
            sum2 = __SMLAD(inM11, inV, sum2);
            inM13 = *__SIMD32(pB)++;
            inM14 = __SXTB16(__ROR(inM13, 8));
            inM13 = __SXTB16(inM13);
            sum3 = __SMLAD(inM14, inV, sum3);
            sum4 = __SMLAD(inM13, inV, sum4);

            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            inM12 = __SXTB16(__ROR(inM11, 8));
            inM11 = __SXTB16(inM11);
            sum = __SMLAD(inM12, inV, sum);
            sum2 = __SMLAD(inM11, inV, sum2);
            inM13 = *__SIMD32(pB)++;
            inM14 = __SXTB16(__ROR(inM13, 8));
            inM13 = __SXTB16(inM13);
            sum3 = __SMLAD(inM14, inV, sum3);
            sum4 = __SMLAD(inM13, inV, sum4);
            colCnt--;
        }
#endif                          /* ARM_MATH_BIG_ENDIAN */

#else

        /*
         * Registers needed:
         * loop counter: colCnt
         * accumulators: sum, sum2, sum3, sum4
         * pointers: pB, pA
         * weight data: inM11, inM12, inM13, inM14
         * activation data: inV
         */

#ifndef ARM_MATH_BIG_ENDIAN
        asm volatile ("COL_LOOP_%=:\n"
                      "ldr.w r4, [%[pA]], #8\n"
                      "ldr.w r1, [%[pB]], #16\n"
                      "mov.w r0, r1, ror #8\n"
                      "sxtb16 r0, r0\n"
                      "sxtb16 r1, r1\n"
                      "smlad %[sum], r4, r1, %[sum]\n"
                      "smlad %[sum2], r4, r0, %[sum2]\n"
                      "ldr.w r3, [%[pB], #-12]\n"
                      "mov.w r2, r3, ror #8\n"
                      "sxtb16 r2, r2\n"
                      "sxtb16 r3, r3\n"
                      "smlad %[sum3], r4, r3, %[sum3]\n"
                      "smlad %[sum4], r4, r2, %[sum4]\n"
                      "ldr.w r4, [%[pA], #-4]\n"
                      "ldr.w r1, [%[pB], #-8]\n"
                      "mov.w r0, r1, ror #8\n"
                      "sxtb16 r0, r0\n"
                      "sxtb16 r1, r1\n"
                      "smlad %[sum], r4, r1, %[sum]\n"
                      "smlad %[sum2], r4, r0, %[sum2]\n"
                      "ldr.w r3, [%[pB], #-4]\n"
                      "mov.w r2, r3, ror #8\n"
                      "sxtb16 r2, r2\n"
                      "sxtb16 r3, r3\n"
                      "smlad %[sum3], r4, r3, %[sum3]\n"
                      "smlad %[sum4], r4, r2, %[sum4]\n"
                      "subs %[colCnt], #1\n"
                      "bne COL_LOOP_%=\n":[sum] "+r"(sum),
                      [sum2] "+r"(sum2),[sum3] "+r"(sum3),
                      [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
#else
        asm volatile ("COL_LOOP_%=:\n"
                      "ldr.w r4, [%[pA]], #8\n"
                      "ldr.w r1, [%[pB]], #16\n"
                      "mov.w r0, r1, ror #8\n"
                      "sxtb16 r0, r0\n"
                      "sxtb16 r1, r1\n"
                      "smlad %[sum], r4, r0, %[sum]\n"
                      "smlad %[sum2], r4, r1, %[sum2]\n"
                      "ldr.w r3, [%[pB], #-12]\n"
                      "mov.w r2, r3, ror #8\n"
                      "sxtb16 r2, r2\n"
                      "sxtb16 r3, r3\n"
                      "smlad %[sum3], r4, r2, %[sum3]\n"
                      "smlad %[sum4], r4, r3, %[sum4]\n"
                      "ldr.w r4, [%[pA], #-4]\n"
                      "ldr.w r1, [%[pB], #-8]\n"
                      "mov.w r0, r1, ror #8\n"
                      "sxtb16 r0, r0\n"
                      "sxtb16 r1, r1\n"
                      "smlad %[sum], r4, r0, %[sum]\n"
                      "smlad %[sum2], r4, r1, %[sum2]\n"
                      "ldr.w r3, [%[pB], #-4]\n"
                      "mov.w r2, r3, ror #8\n"
                      "sxtb16 r2, r2\n"
                      "sxtb16 r3, r3\n"
                      "smlad %[sum3], r4, r2, %[sum3]\n"
                      "smlad %[sum4], r4, r3, %[sum4]\n"
                      "subs %[colCnt], #1\n"
                      "bne COL_LOOP_%=\n":[sum] "+r"(sum),
                      [sum2] "+r"(sum2),[sum3] "+r"(sum3),
                      [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
#endif                          /* ARM_MATH_BIG_ENDIAN */

#endif                          /* USE_INTRINSIC */

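        /* Process the remaining (dim_vec & 0x3) columns. These are stored
         * in order, 4 row entries per column, so each scalar activation is
         * multiplied into all four row accumulators. */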
        colCnt = dim_vec & 0x3;
        while (colCnt)
        {
            q15_t     inV = *pA++;
            q7_t      inM = *pB++;
            q7_t      inM2 = *pB++;
            q7_t      inM3 = *pB++;
            q7_t      inM4 = *pB++;

            sum += inV * inM;
            sum2 += inV * inM2;
            sum3 += inV * inM3;
            sum4 += inV * inM4;
            colCnt--;
        }                       /* while over colCnt */
        *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));
        *pO++ = (q7_t) (__SSAT((sum2 >> out_shift), 8));
        *pO++ = (q7_t) (__SSAT((sum3 >> out_shift), 8));
        *pO++ = (q7_t) (__SSAT((sum4 >> out_shift), 8));

        /* adjust the pointers and counters */
        rowCnt--;
    }

    /* left-over part of the rows */
    rowCnt = num_of_rows & 0x3;

    while (rowCnt)
    {
        q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        uint16_t  colCnt = dim_vec >> 2;

        pA = vec_buffer;

        while (colCnt)
        {
            q31_t     inV1, inV2, inM11, inM12;

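            /* read_and_pad_reordered() reads 4 q7_t weights and expands them
             * into two 32-bit words of q15_t pairs, in the same shuffled
             * order as the activation buffer, so they can be fed straight
             * into __SMLAD. */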
            pB = (q7_t *) read_and_pad_reordered((void *)pB, &inM11, &inM12);

            inV1 = *__SIMD32(pA)++;
            sum = __SMLAD(inV1, inM11, sum);

            inV2 = *__SIMD32(pA)++;
            sum = __SMLAD(inV2, inM12, sum);

            colCnt--;
        }

        /* left-over part of the vector */
        colCnt = dim_vec & 0x3;
        while (colCnt)
        {
            q15_t     inV = *pA++;
            q7_t      inM = *pB++;
            sum += inV * inM;
            colCnt--;
        }

        *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));

        rowCnt--;
    }

#else
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
    uint16_t  rowCnt = num_of_rows >> 2;
    const q7_t *pB = pM;
    const q7_t *pA;
    q7_t     *pO = pOut;
    const q7_t *pBias = bias;

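    /* The weight matrix is expected in the same interleaved "opt" layout as
     * on the DSP path, so the inputs below are consumed in the matching
     * shuffled order (note the inA1/inA3/inA2/inA4 read sequence). */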
    while (rowCnt)
    {
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t  colCnt = dim_vec >> 2;

        pA = pV;

        while (colCnt)
        {
            q7_t      inA1 = *pA++;
            q7_t      inA3 = *pA++;
            q7_t      inA2 = *pA++;
            q7_t      inA4 = *pA++;

            q7_t      inB1 = *pB++;
            q7_t      inB3 = *pB++;
            q7_t      inB2 = *pB++;
            q7_t      inB4 = *pB++;

            sum += inA1 * inB1 + inA2 * inB2;
            sum2 += inA1 * inB3 + inA2 * inB4;

            inB1 = *pB++;
            inB3 = *pB++;
            inB2 = *pB++;
            inB4 = *pB++;

            sum3 += inA1 * inB1 + inA2 * inB2;
            sum4 += inA1 * inB3 + inA2 * inB4;

            inB1 = *pB++;
            inB3 = *pB++;
            inB2 = *pB++;
            inB4 = *pB++;

            sum += inA3 * inB1 + inA4 * inB2;
            sum2 += inA3 * inB3 + inA4 * inB4;

            inB1 = *pB++;
            inB3 = *pB++;
            inB2 = *pB++;
            inB4 = *pB++;

            sum3 += inA3 * inB1 + inA4 * inB2;
            sum4 += inA3 * inB3 + inA4 * inB4;

            colCnt--;
        }
        colCnt = dim_vec & 0x3;
        while (colCnt)
        {
            q7_t      inA = *pA++;
            q7_t      inB = *pB++;
            sum += inA * inB;
            inB = *pB++;
            sum2 += inA * inB;
            inB = *pB++;
            sum3 += inA * inB;
            inB = *pB++;
            sum4 += inA * inB;

            colCnt--;
        }
        *pO++ = (q7_t) __SSAT((sum >> out_shift), 8);
        *pO++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
        *pO++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
        *pO++ = (q7_t) __SSAT((sum4 >> out_shift), 8);

        rowCnt--;
    }

    rowCnt = num_of_rows & 0x3;

    while (rowCnt)
    {
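        /* Left-over rows are stored in their original (row-major) order, so a
         * plain dot product over the full vector is sufficient. */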
        int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        int       j;

        pA = pV;
        for (j = 0; j < dim_vec; j++)
        {
            q7_t      inA = *pA++;
            q7_t      inB = *pB++;
            ip_out += inA * inB;
        }
        *pO++ = (q7_t) __SSAT((ip_out >> out_shift), 8);

        rowCnt--;
    }

#endif                          /* ARM_MATH_DSP */

    /* Return ARM_MATH_SUCCESS */
    return (ARM_MATH_SUCCESS);

}

/**
 * @} end of FC group
 */