1 /*
2 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ref_functions.h"
20
arm_fully_connected_q7_opt_ref(const q7_t * pV,const q7_t * pM,const uint16_t dim_vec,const uint16_t num_of_rows,const uint16_t bias_shift,const uint16_t out_shift,const q7_t * bias,q7_t * pOut,q15_t * vec_buffer)21 void arm_fully_connected_q7_opt_ref(const q7_t * pV, // pointer to vector
22 const q7_t * pM, // pointer to matrix
23 const uint16_t dim_vec, // length of the vector
24 const uint16_t num_of_rows, // numCol of A
25 const uint16_t bias_shift, // amount of left-shift for bias
26 const uint16_t out_shift, // amount of right-shift for output
27 const q7_t * bias, q7_t * pOut, // output operand
28 q15_t * vec_buffer)
29 {
30
31 uint16_t rowCnt = num_of_rows >> 2;
32 const q7_t *pB = pM;
33 const q7_t *pA;
34 q7_t *pO = pOut;
35 const q7_t *pBias = bias;
36
37 while (rowCnt)
38 {
39 pA = pV;
40 #ifndef ARM_NN_TRUNCATE
41 q31_t sum = (*pBias++ << bias_shift) + (0x1 << (out_shift - 1));
42 q31_t sum2 = (*pBias++ << bias_shift) + (0x1 << (out_shift - 1));
43 q31_t sum3 = (*pBias++ << bias_shift) + (0x1 << (out_shift - 1));
44 q31_t sum4 = (*pBias++ << bias_shift) + (0x1 << (out_shift - 1));
45 #else
46 q31_t sum = *pBias++ << bias_shift;
47 q31_t sum2 = *pBias++ << bias_shift;
48 q31_t sum3 = *pBias++ << bias_shift;
49 q31_t sum4 = *pBias++ << bias_shift;
50 #endif
51
52 uint16_t colCnt = dim_vec >> 2;
53
54 while (colCnt)
55 {
56 q7_t inA1 = *pA++;
57 q7_t inA3 = *pA++;
58 q7_t inA2 = *pA++;
59 q7_t inA4 = *pA++;
60
61 q7_t inB1 = *pB++;
62 q7_t inB3 = *pB++;
63 q7_t inB2 = *pB++;
64 q7_t inB4 = *pB++;
65
66 sum += inA1 * inB1 + inA2 * inB2;
67 sum2 += inA1 * inB3 + inA2 * inB4;
68
69 inB1 = *pB++;
70 inB3 = *pB++;
71 inB2 = *pB++;
72 inB4 = *pB++;
73
74 sum3 += inA1 * inB1 + inA2 * inB2;
75 sum4 += inA1 * inB3 + inA2 * inB4;
76
77 inB1 = *pB++;
78 inB3 = *pB++;
79 inB2 = *pB++;
80 inB4 = *pB++;
81
82 sum += inA3 * inB1 + inA4 * inB2;
83 sum2 += inA3 * inB3 + inA4 * inB4;
84
85 inB1 = *pB++;
86 inB3 = *pB++;
87 inB2 = *pB++;
88 inB4 = *pB++;
89
90 sum3 += inA3 * inB1 + inA4 * inB2;
91 sum4 += inA3 * inB3 + inA4 * inB4;
92
93 colCnt--;
94 }
95 colCnt = dim_vec & 0x3;
96 while (colCnt)
97 {
98 q7_t inA = *pA++;
99 q7_t inB = *pB++;
100 sum += inA * inB;
101 inB = *pB++;
102 sum2 += inA * inB;
103 inB = *pB++;
104 sum3 += inA * inB;
105 inB = *pB++;
106 sum4 += inA * inB;
107
108 colCnt--;
109 }
110 *pO++ = (q7_t) __SSAT((sum >> out_shift), 8);
111 *pO++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
112 *pO++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
113 *pO++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
114
115 rowCnt--;
116 }
117
118 rowCnt = num_of_rows & 0x3;
119
120 while (rowCnt)
121 {
122 pA = pV;
123 #ifndef ARM_NN_TRUNCATE
124 int ip_out = (*pBias++ << bias_shift) + (0x1 << (out_shift - 1));
125 #else
126 int ip_out = *pBias++ << bias_shift;
127 #endif
128 for (int j = 0; j < dim_vec; j++)
129 {
130 q7_t inA = *pA++;
131 q7_t inB = *pB++;
132 ip_out += inA * inB;
133 }
134 *pO++ = (q7_t) __SSAT((ip_out >> out_shift), 8);
135
136 rowCnt--;
137 }
138 }
139