1 /*
2 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ref_functions.h"
20
arm_fully_connected_mat_q7_vec_q15_opt_ref(const q15_t * pV,const q7_t * pM,const uint16_t dim_vec,const uint16_t num_of_rows,const uint16_t bias_shift,const uint16_t out_shift,const q7_t * bias,q15_t * pOut,q15_t * vec_buffer)21 void arm_fully_connected_mat_q7_vec_q15_opt_ref(const q15_t * pV, // pointer to vector
22 const q7_t * pM, // pointer to matrix
23 const uint16_t dim_vec, // length of the vector
24 const uint16_t num_of_rows, // numCol of A
25 const uint16_t bias_shift, // amount of left-shift for bias
26 const uint16_t out_shift, // amount of right-shift for output
27 const q7_t * bias, q15_t * pOut, // output operand
28 q15_t * vec_buffer)
29 {
30
31 uint16_t rowCnt = num_of_rows >> 2;
32 const q7_t *pB = pM;
33 const q15_t *pA;
34 q15_t *pO = pOut;
35 const q7_t *pBias = bias;
36
37 while (rowCnt)
38 {
39 pA = pV;
40 #ifndef ARM_NN_TRUNCATE
41 q31_t sum = (*pBias++ << bias_shift) + (0x1 << (out_shift - 1));
42 q31_t sum2 = (*pBias++ << bias_shift) + (0x1 << (out_shift - 1));
43 q31_t sum3 = (*pBias++ << bias_shift) + (0x1 << (out_shift - 1));
44 q31_t sum4 = (*pBias++ << bias_shift) + (0x1 << (out_shift - 1));
45 #else
46 q31_t sum = *pBias++ << bias_shift;
47 q31_t sum2 = *pBias++ << bias_shift;
48 q31_t sum3 = *pBias++ << bias_shift;
49 q31_t sum4 = *pBias++ << bias_shift;
50 #endif
51
52 uint16_t colCnt = dim_vec >> 1;
53
54 while (colCnt)
55 {
56 q15_t inA1 = *pA++;
57 q15_t inA2 = *pA++;
58
59 q7_t inB1 = *pB++;
60 q7_t inB3 = *pB++;
61 q7_t inB2 = *pB++;
62 q7_t inB4 = *pB++;
63
64 sum += inA1 * inB1 + inA2 * inB2;
65 sum2 += inA1 * inB3 + inA2 * inB4;
66
67 inB1 = *pB++;
68 inB3 = *pB++;
69 inB2 = *pB++;
70 inB4 = *pB++;
71
72 sum3 += inA1 * inB1 + inA2 * inB2;
73 sum4 += inA1 * inB3 + inA2 * inB4;
74
75 colCnt--;
76 }
77 colCnt = dim_vec & 0x1;
78 while (colCnt)
79 {
80 q15_t inA = *pA++;
81 q7_t inB = *pB++;
82 sum += inA * inB;
83 inB = *pB++;
84 sum2 += inA * inB;
85 inB = *pB++;
86 sum3 += inA * inB;
87 inB = *pB++;
88 sum4 += inA * inB;
89
90 colCnt--;
91 }
92 *pO++ = (q15_t) __SSAT((sum >> out_shift), 16);
93 *pO++ = (q15_t) __SSAT((sum2 >> out_shift), 16);
94 *pO++ = (q15_t) __SSAT((sum3 >> out_shift), 16);
95 *pO++ = (q15_t) __SSAT((sum4 >> out_shift), 16);
96
97 rowCnt--;
98 }
99
100 rowCnt = num_of_rows & 0x3;
101
102 while (rowCnt)
103 {
104 pA = pV;
105 #ifndef ARM_NN_TRUNCATE
106 int ip_out = (*pBias++ << bias_shift) + (0x1 << (out_shift - 1));
107 #else
108 int ip_out = *pBias++ << bias_shift;
109 #endif
110 for (int j = 0; j < dim_vec; j++)
111 {
112 q15_t inA = *pA++;
113 q7_t inB = *pB++;
114 ip_out += inA * inB;
115 }
116 *pO++ = (q15_t) __SSAT((ip_out >> out_shift), 16);
117
118 rowCnt--;
119 }
120 }
121