1*a58d3d2aSXin Li /* Copyright (c) 2018 David Rowe
2*a58d3d2aSXin Li 2018 Mozilla
3*a58d3d2aSXin Li 2008-2011 Octasic Inc.
4*a58d3d2aSXin Li 2012-2017 Jean-Marc Valin */
5*a58d3d2aSXin Li /*
6*a58d3d2aSXin Li Redistribution and use in source and binary forms, with or without
7*a58d3d2aSXin Li modification, are permitted provided that the following conditions
8*a58d3d2aSXin Li are met:
9*a58d3d2aSXin Li
10*a58d3d2aSXin Li - Redistributions of source code must retain the above copyright
11*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer.
12*a58d3d2aSXin Li
13*a58d3d2aSXin Li - Redistributions in binary form must reproduce the above copyright
14*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer in the
15*a58d3d2aSXin Li documentation and/or other materials provided with the distribution.
16*a58d3d2aSXin Li
17*a58d3d2aSXin Li THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18*a58d3d2aSXin Li ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19*a58d3d2aSXin Li LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20*a58d3d2aSXin Li A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
21*a58d3d2aSXin Li CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22*a58d3d2aSXin Li EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23*a58d3d2aSXin Li PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24*a58d3d2aSXin Li PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25*a58d3d2aSXin Li LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26*a58d3d2aSXin Li NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27*a58d3d2aSXin Li SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28*a58d3d2aSXin Li */
29*a58d3d2aSXin Li /* NEON support for ARM machines */
30*a58d3d2aSXin Li
31*a58d3d2aSXin Li #ifndef VEC_NEON_H
32*a58d3d2aSXin Li #define VEC_NEON_H
33*a58d3d2aSXin Li
34*a58d3d2aSXin Li #include <arm_neon.h>
35*a58d3d2aSXin Li #include "os_support.h"
36*a58d3d2aSXin Li
37*a58d3d2aSXin Li #if defined(__arm__) && !defined(__aarch64__)
/* Emulate vcvtnq_s32_f32() (round-to-nearest float->int conversion) for
   ARMv7 Neon: convert with 8 fractional bits, then round those bits away
   with a rounding shift.  NOTE(review): the fixed-point intermediate
   narrows the usable input range vs. a true vcvtnq; ample for the Q7
   quantization in this file, but confirm before reusing elsewhere. */
static OPUS_INLINE int32x4_t vcvtnq_s32_f32(float32x4_t x) {
   return vrshrq_n_s32(vcvtq_n_s32_f32(x, 8), 8);
}
42*a58d3d2aSXin Li
/* Emulate the AArch64 vpaddq_s16() (pairwise add across a q-register pair)
   with two d-register vpadd_s16 operations. */
static OPUS_INLINE int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
   return vcombine_s16(vpadd_s16(vget_low_s16(a), vget_high_s16(a)), vpadd_s16(vget_low_s16(b), vget_high_s16(b)));
}
46*a58d3d2aSXin Li
/* Emulate the AArch64 vmull_high_s8(): widening multiply of the upper
   eight int8 lanes of each operand. */
static OPUS_INLINE int16x8_t vmull_high_s8(int8x16_t a, int8x16_t b) {
   return vmull_s8(vget_high_s8(a), vget_high_s8(b));
}
50*a58d3d2aSXin Li #endif
51*a58d3d2aSXin Li
52*a58d3d2aSXin Li #ifdef __ARM_FEATURE_FMA
53*a58d3d2aSXin Li /* If we can, force the compiler to use an FMA instruction rather than break
54*a58d3d2aSXin Li vmlaq_f32() into fmul/fadd. */
55*a58d3d2aSXin Li #define vmlaq_f32(a,b,c) vfmaq_f32(a,b,c)
56*a58d3d2aSXin Li #endif
57*a58d3d2aSXin Li
58*a58d3d2aSXin Li #ifndef LPCNET_TEST
/* Four-lane approximation of exp(x).
   exp(x) is rewritten as 2^(x*log2(e)); the integer part of the exponent
   is placed directly into the IEEE-754 exponent field and the fractional
   part is covered by a degree-3 polynomial. */
static inline float32x4_t exp4_approx(float32x4_t x) {
   int32x4_t i;
   float32x4_t xf;

   /* Clamp so the biased exponent computed below stays in valid range. */
   x = vmaxq_f32(vminq_f32(x, vdupq_n_f32(88.f)), vdupq_n_f32(-88.f));

   /* express exp(x) as exp2(x/log(2)), add 127 for the exponent later */
   x = vmlaq_f32(vdupq_n_f32(127.f), x, vdupq_n_f32(1.44269504f));

   /* split into integer and fractional parts */
   i = vcvtq_s32_f32(x);
   xf = vcvtq_f32_s32(i);
   x = vsubq_f32(x, xf);

   /* degree-3 polynomial approximation of 2^x for the fractional part;
      note vmlaq_f32(a, b, c) computes a + b*c */
   float32x4_t K0 = vdupq_n_f32(0.99992522f);
   float32x4_t K1 = vdupq_n_f32(0.69583354f);
   float32x4_t K2 = vdupq_n_f32(0.22606716f);
   float32x4_t K3 = vdupq_n_f32(0.078024523f);
   float32x4_t Y = vmlaq_f32(K0, x, vmlaq_f32(K1, x, vmlaq_f32(K2, K3, x)));

   /* compute 2^i: i already contains the +127 bias, so shifting it into
      the exponent field reinterprets as the float 2^(i-127) */
   float32x4_t exponent = vreinterpretq_f32_s32(vshlq_n_s32(i, 23));

   Y = vmulq_f32(Y, exponent);
   return Y;
}
85*a58d3d2aSXin Li
/* Four-lane tanh(X) approximation: odd rational function
   X*num(X^2)/den(X^2), clamped to [-1, 1].  The division uses the raw
   vrecpeq_f32 reciprocal estimate with no Newton-Raphson refinement, so
   this is a deliberately low-precision approximation. */
static inline float32x4_t tanh4_approx(float32x4_t X)
{
   const float32x4_t N0 = vdupq_n_f32(952.52801514f);
   const float32x4_t N1 = vdupq_n_f32(96.39235687f);
   const float32x4_t N2 = vdupq_n_f32(0.60863042f);
   const float32x4_t D0 = vdupq_n_f32(952.72399902f);
   const float32x4_t D1 = vdupq_n_f32(413.36801147f);
   const float32x4_t D2 = vdupq_n_f32(11.88600922f);
   const float32x4_t max_out = vdupq_n_f32(1.f);
   const float32x4_t min_out = vdupq_n_f32(-1.f);
   float32x4_t X2, num, den;
   X2 = vmulq_f32(X, X);
   /* vmlaq_f32(a, b, c) = a + b*c: evaluate both quadratics in X^2 */
   num = vmlaq_f32(N0, X2, vmlaq_f32(N1, N2, X2));
   den = vmlaq_f32(D0, X2, vmlaq_f32(D1, D2, X2));
   num = vmulq_f32(num, X);
   den = vrecpeq_f32(den); /* coarse reciprocal estimate of den */
   num = vmulq_f32(num, den);
   return vmaxq_f32(min_out, vminq_f32(max_out, num));
}
105*a58d3d2aSXin Li
/* Four-lane sigmoid(X) approximation: 0.5 + X*num(X^2)/den(X^2),
   clamped to [0, 1].  Like tanh4_approx(), the reciprocal comes from
   the unrefined vrecpeq_f32 estimate, so precision is deliberately
   limited. */
static inline float32x4_t sigmoid4_approx(float32x4_t X)
{
   const float32x4_t N0 = vdupq_n_f32(238.13200378f);
   const float32x4_t N1 = vdupq_n_f32(6.02452230f);
   const float32x4_t N2 = vdupq_n_f32(0.00950985f);
   const float32x4_t D0 = vdupq_n_f32(952.72399902f);
   const float32x4_t D1 = vdupq_n_f32(103.34200287f);
   const float32x4_t D2 = vdupq_n_f32(0.74287558f);
   const float32x4_t half = vdupq_n_f32(0.5f);
   const float32x4_t max_out = vdupq_n_f32(1.f);
   const float32x4_t min_out = vdupq_n_f32(0.f);
   float32x4_t X2, num, den;
   X2 = vmulq_f32(X, X);
   /* vmlaq_f32(a, b, c) = a + b*c: evaluate both quadratics in X^2 */
   num = vmlaq_f32(N0, X2, vmlaq_f32(N1, N2, X2));
   den = vmlaq_f32(D0, X2, vmlaq_f32(D1, D2, X2));
   num = vmulq_f32(num, X);
   den = vrecpeq_f32(den); /* coarse reciprocal estimate of den */
   num = vmlaq_f32(half, num, den); /* 0.5 + num/den */
   return vmaxq_f32(min_out, vminq_f32(max_out, num));
}
126*a58d3d2aSXin Li
/* Scalar exp() built on the vector approximation: broadcast the input to
   all four lanes, evaluate, and extract lane 0. */
static inline float lpcnet_exp(float x)
{
   return vgetq_lane_f32(exp4_approx(vdupq_n_f32(x)), 0);
}
136*a58d3d2aSXin Li
/* Scalar tanh() approximation: single-lane use of tanh4_approx(). */
static inline float tanh_approx(float x)
{
   return vgetq_lane_f32(tanh4_approx(vdupq_n_f32(x)), 0);
}
146*a58d3d2aSXin Li
/* Scalar sigmoid() approximation: single-lane use of sigmoid4_approx(). */
static inline float sigmoid_approx(float x)
{
   return vgetq_lane_f32(sigmoid4_approx(vdupq_n_f32(x)), 0);
}
156*a58d3d2aSXin Li
/* Elementwise y[i] = exp(x[i]) for i in [0, N).
   NOTE: despite the name, this computes only the exponentials; the
   normalization of a softmax is left to the caller. */
static inline void softmax(float *y, const float *x, int N)
{
   int i = 0;
   /* vector path: four values per iteration */
   while (i + 4 <= N) {
      vst1q_f32(&y[i], exp4_approx(vld1q_f32(&x[i])));
      i += 4;
   }
   /* scalar tail */
   while (i < N) {
      y[i] = lpcnet_exp(x[i]);
      i++;
   }
}
170*a58d3d2aSXin Li
/* Elementwise y[i] = tanh(x[i]) for i in [0, N). */
static inline void vec_tanh(float *y, const float *x, int N)
{
   int i = 0;
   /* vector path: four values per iteration */
   for (; i + 4 <= N; i += 4)
      vst1q_f32(&y[i], tanh4_approx(vld1q_f32(&x[i])));
   /* scalar tail via tanh(x) = (e^(2x) - 1)/(e^(2x) + 1) */
   for (; i < N; i++) {
      const float e2x = lpcnet_exp(2*x[i]);
      y[i] = (e2x - 1)/(e2x + 1);
   }
}
188*a58d3d2aSXin Li
/* Elementwise y[i] = sigmoid(x[i]) for i in [0, N). */
static inline void vec_sigmoid(float *y, const float *x, int N)
{
   int i = 0;
   /* vector path: four values per iteration */
   for (; i + 4 <= N; i += 4)
      vst1q_f32(&y[i], sigmoid4_approx(vld1q_f32(&x[i])));
   /* scalar tail via sigmoid(x) = e^x/(e^x + 1) */
   for (; i < N; i++) {
      const float ex = lpcnet_exp(x[i]);
      y[i] = (ex)/(ex+1);
   }
}
206*a58d3d2aSXin Li #endif
207*a58d3d2aSXin Li
/* out = W*x for a column-major weight matrix, computed 16 output rows at
   a time.  rows must be a multiple of 16; col_stride is the distance (in
   floats) between consecutive columns of W. */
static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int r, c;
   for (r = 0; r < rows; r += 16)
   {
      /* hold out[r..r+15] in four q-registers for the whole inner loop */
      float32x4_t acc0 = vdupq_n_f32(0);
      float32x4_t acc1 = vdupq_n_f32(0);
      float32x4_t acc2 = vdupq_n_f32(0);
      float32x4_t acc3 = vdupq_n_f32(0);

      for (c = 0; c < cols; c++)
      {
         /* one column of W, rows r..r+15 */
         const float * restrict col = &weights[c*col_stride + r];
         float32x4_t xc = vld1q_dup_f32(&x[c]);
         acc0 = vmlaq_f32(acc0, vld1q_f32(&col[0]),  xc);
         acc1 = vmlaq_f32(acc1, vld1q_f32(&col[4]),  xc);
         acc2 = vmlaq_f32(acc2, vld1q_f32(&col[8]),  xc);
         acc3 = vmlaq_f32(acc3, vld1q_f32(&col[12]), xc);
      }

      /* write the 16 accumulated outputs back */
      vst1q_f32(&out[r],    acc0);
      vst1q_f32(&out[r+4],  acc1);
      vst1q_f32(&out[r+8],  acc2);
      vst1q_f32(&out[r+12], acc3);
   }
}
251*a58d3d2aSXin Li
/* out = W*x for a column-major weight matrix, computed 8 output rows at
   a time.  rows must be a multiple of 8; col_stride is the distance (in
   floats) between consecutive columns of W. */
static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int r, c;
   for (r = 0; r < rows; r += 8)
   {
      /* hold out[r..r+7] in two q-registers for the whole inner loop */
      float32x4_t acc_lo = vdupq_n_f32(0);
      float32x4_t acc_hi = vdupq_n_f32(0);

      for (c = 0; c < cols; c++)
      {
         /* one column of W, rows r..r+7 */
         const float * restrict col = &weights[c*col_stride + r];
         float32x4_t xc = vld1q_dup_f32(&x[c]);
         acc_lo = vmlaq_f32(acc_lo, vld1q_f32(&col[0]), xc);
         acc_hi = vmlaq_f32(acc_hi, vld1q_f32(&col[4]), xc);
      }

      /* write the 8 accumulated outputs back */
      vst1q_f32(&out[r],   acc_lo);
      vst1q_f32(&out[r+4], acc_hi);
   }
}
286*a58d3d2aSXin Li
/* Dispatching matrix-vector product: pick the widest kernel the row
   count allows, falling back to a plain scalar loop otherwise. */
static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   if (rows % 16 == 0) {
      sgemv16x1(out, weights, rows, cols, col_stride, x);
   } else if (rows % 8 == 0) {
      sgemv8x1(out, weights, rows, cols, col_stride, x);
   } else {
      /* generic scalar fallback for odd row counts */
      int i, j;
      for (i = 0; i < rows; i++)
      {
         out[i] = 0;
         for (j = 0; j < cols; j++) out[i] += weights[j*col_stride + i]*x[j];
      }
   }
}
300*a58d3d2aSXin Li
/* Temporarily use unoptimized (scalar) version.
   Sparse block matrix-vector product: weights are stored as dense 8x4
   column-major tiles.  idx encodes, per 8-row band, the number of tiles
   followed by each tile's starting column in x.  rows must be a multiple
   of 8. */
static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
{
   int i;
   OPUS_CLEAR(out, rows);
   for (i = 0; i < rows; i += 8)
   {
      int j;
      const int nblocks = *idx++;
      for (j = 0; j < nblocks; j++)
      {
         const int pos = *idx++;
         float * restrict y = &out[i];
         int c;
         /* accumulate one 8x4 tile: column c of the tile scales x[pos+c] */
         for (c = 0; c < 4; c++)
         {
            const float xj = x[pos + c];
            int r;
            for (r = 0; r < 8; r++)
               y[r] += w[8*c + r]*xj;
         }
         w += 32;
      }
   }
}
360*a58d3d2aSXin Li
361*a58d3d2aSXin Li
362*a58d3d2aSXin Li #define SCALE (128.f*127.f)
363*a58d3d2aSXin Li #define SCALE_1 (1.f/128.f/127.f)
364*a58d3d2aSXin Li
365*a58d3d2aSXin Li #define MAX_INPUTS 2048
366*a58d3d2aSXin Li #define MAX_OUTPUTS 8192
367*a58d3d2aSXin Li
368*a58d3d2aSXin Li #if __ARM_FEATURE_DOTPROD
/* Native SDOT path: per 32-bit lane, acc += dot product of four int8
   pairs (requires the ARMv8.2 dot-product extension). */
static inline int32x4_t vdotprod(int32x4_t acc, int8x16_t a, int8x16_t b) {
   return vdotq_s32(acc, a, b);
}
372*a58d3d2aSXin Li #else
/* Fallback int8 dot product: widening multiplies to int16, pairwise adds
   to collapse 16 products into 8 sums, then pairwise-add-accumulate into
   the four int32 lanes.  Same result as vdotq_s32(). */
static inline int32x4_t vdotprod(int32x4_t acc, int8x16_t a, int8x16_t b)
{
   return vpadalq_s16(acc, vpaddq_s16(vmull_s8(vget_low_s8(a), vget_low_s8(b)), vmull_high_s8(a, b)));
}
377*a58d3d2aSXin Li #endif
378*a58d3d2aSXin Li
/* Dense matrix-vector product with 8-bit weights: _out = scale .* (W*x).
   The float input is quantized on the fly to Q7 int8; weights are stored
   as consecutive 8x4 column-major tiles (32 bytes each).  rows must be a
   multiple of 8, cols a multiple of 4 and at most MAX_INPUTS.
   Fix vs. previous revision: vector reinterpretation now uses the ACLE
   vreinterpretq_s8_s32() intrinsic instead of a plain C cast between NEON
   vector types, which is a GCC/Clang extension and not portable. */
static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   opus_int32 x_int[MAX_INPUTS/4];
   opus_int8 *x = (opus_int8*) x_int;
   const float32x4_t const127 = vdupq_n_f32(127.);
   /* Quantize the input to int8 (x*127, rounded to nearest), 8 per pass.
      NOTE(review): the narrowing is unsaturated, so inputs are presumably
      already in [-1, 1] -- confirm against callers. */
   for (i=0;i<cols;i+=8) {
      int32x4_t xi0, xi4;
      int16x8_t x_short;
      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
      vst1_s8(&x[i], vmovn_s16(x_short));
   }
   for (i=0;i<rows;i+=8)
   {
      /* acc0/acc1 accumulate rows i..i+3 and i+4..i+7; acc2/acc3 are a
         second pair so two tiles can be in flight per main-loop pass. */
      int32x4_t acc0, acc1;
      int32x4_t acc2, acc3;
      acc0 = vdupq_n_s32(0);
      acc1 = vdupq_n_s32(0);
      acc2 = vdupq_n_s32(0);
      acc3 = vdupq_n_s32(0);
      j=0;
      /* main loop: two 4-column tiles (8 input values) per iteration */
      for (;j<cols-4;j+=8)
      {
         int8x16_t vw0, vw1, vw2, vw3, vx0, vx1;
         /* broadcast 4 quantized inputs to all lanes, viewed as int8x16 */
         vx0 = vreinterpretq_s8_s32(vld1q_dup_s32((int*)(void*)&x[j]));
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx0);
         acc1 = vdotprod(acc1, vw1, vx0);
         vx1 = vreinterpretq_s8_s32(vld1q_dup_s32((int*)(void*)&x[j+4]));
         vw2 = vld1q_s8(&w[32]);
         vw3 = vld1q_s8(&w[48]);
         acc2 = vdotprod(acc2, vw2, vx1);
         acc3 = vdotprod(acc3, vw3, vx1);
         w += 64;
      }
      acc0 = vaddq_s32(acc0, acc2);
      acc1 = vaddq_s32(acc1, acc3);
      /* tail: at most one remaining 4-column tile */
      for (;j<cols;j+=4)
      {
         int8x16_t vw0, vw1, vx;
         vx = vreinterpretq_s8_s32(vld1q_dup_s32((int*)(void*)&x[j]));
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx);
         acc1 = vdotprod(acc1, vw1, vx);
         w += 32;
      }
      /* per-row scale undoes the Q7 input/weight scaling */
      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
   }
}
433*a58d3d2aSXin Li
/* Sparse matrix-vector product with 8-bit weights: _out = scale .* (W*x).
   The float input is quantized on the fly to Q7 int8.  Weights are dense
   8x4 column-major tiles; idx encodes, per 8-row band, the tile count
   followed by each tile's starting column in x.  rows must be a multiple
   of 8, cols a multiple of 8 and at most MAX_INPUTS.
   Fix vs. previous revision: vector reinterpretation now uses the ACLE
   vreinterpretq_s8_s32() intrinsic instead of a plain C cast between NEON
   vector types, which is a GCC/Clang extension and not portable. */
static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   opus_int32 x_int[MAX_INPUTS/4];
   opus_int8 *x = (opus_int8*) x_int;
   const float32x4_t const127 = vdupq_n_f32(127.);
   /* Quantize the input to int8 (x*127, rounded to nearest), 8 per pass.
      NOTE(review): the narrowing is unsaturated, so inputs are presumably
      already in [-1, 1] -- confirm against callers. */
   for (i=0;i<cols;i+=8) {
      int32x4_t xi0, xi4;
      int16x8_t x_short;
      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
      vst1_s8(&x[i], vmovn_s16(x_short));
   }
   for (i=0;i<rows;i+=8)
   {
      int colblocks;
      /* acc0/acc1 accumulate output rows i..i+3 and i+4..i+7 */
      int32x4_t acc0, acc1;
      acc0 = vdupq_n_s32(0);
      acc1 = vdupq_n_s32(0);
      colblocks = *idx++;
      for (j=0;j<colblocks;j++)
      {
         int pos;
         pos = (*idx++);
         int8x16_t vw0, vw1, vx;
         /* broadcast 4 quantized inputs to all lanes, viewed as int8x16 */
         vx = vreinterpretq_s8_s32(vld1q_dup_s32((int*)(void*)&x[pos]));
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx);
         acc1 = vdotprod(acc1, vw1, vx);
         w += 32;
      }
      /* per-row scale undoes the Q7 input/weight scaling */
      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
   }
}
471*a58d3d2aSXin Li
472*a58d3d2aSXin Li
473*a58d3d2aSXin Li #endif
474