/* Copyright (c) 2018 David Rowe
                 2018 Mozilla
                 2008-2011 Octasic Inc.
                 2012-2017 Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/* NEON support for ARM machines */

#ifndef VEC_NEON_H
#define VEC_NEON_H

#include <arm_neon.h>
#include "os_support.h"

#if defined(__arm__) && !defined(__aarch64__)
/* Emulate vcvtnq_s32_f32() for ARMv7 Neon. */
static OPUS_INLINE int32x4_t vcvtnq_s32_f32(float32x4_t x) {
  return vrshrq_n_s32(vcvtq_n_s32_f32(x, 8), 8);
}
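
/* The emulation converts to Q8 fixed point (vcvtq_n_s32_f32() with 8
   fractional bits) and then applies a rounding right shift by the same 8
   bits. This approximates round-to-nearest: exact ties need not match the
   ties-to-even behavior of the AArch64 instruction, and the input must fit
   in a signed Q24.8 value. A scalar sketch of the same idea, illustrative
   only and not part of this header:

   static inline opus_int32 round_nearest_sketch(float f) {
      opus_int32 q8 = (opus_int32)(f*256.f);  // truncate to Q8 fixed point
      return (q8 + 128) >> 8;                 // rounding right shift by 8
   }
*/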

static OPUS_INLINE int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
  return vcombine_s16(vpadd_s16(vget_low_s16(a), vget_high_s16(a)), vpadd_s16(vget_low_s16(b), vget_high_s16(b)));
}

static OPUS_INLINE int16x8_t vmull_high_s8(int8x16_t a, int8x16_t b) {
  return vmull_s8(vget_high_s8(a), vget_high_s8(b));
}
#endif

#ifdef __ARM_FEATURE_FMA
/* If we can, force the compiler to use an FMA instruction rather than break
   vmlaq_f32() into fmul/fadd. */
#define vmlaq_f32(a,b,c) vfmaq_f32(a,b,c)
#endif
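
/* Note: vfmaq_f32(a,b,c) computes a + b*c with a single rounding (fused
   multiply-add), whereas a separate fmul/fadd pair rounds twice, so this
   substitution can slightly change results as well as improve speed. */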

#ifndef LPCNET_TEST
static inline float32x4_t exp4_approx(float32x4_t x) {
  int32x4_t i;
  float32x4_t xf;

  x = vmaxq_f32(vminq_f32(x, vdupq_n_f32(88.f)), vdupq_n_f32(-88.f));

  /* express exp(x) as exp2(x/log(2)), add 127 for the exponent later */
  x = vmlaq_f32(vdupq_n_f32(127.f), x, vdupq_n_f32(1.44269504f));

  /* split into integer and fractional parts */
  i = vcvtq_s32_f32(x);
  xf = vcvtq_f32_s32(i);
  x = vsubq_f32(x, xf);

  float32x4_t K0 = vdupq_n_f32(0.99992522f);
  float32x4_t K1 = vdupq_n_f32(0.69583354f);
  float32x4_t K2 = vdupq_n_f32(0.22606716f);
  float32x4_t K3 = vdupq_n_f32(0.078024523f);
  float32x4_t Y = vmlaq_f32(K0, x, vmlaq_f32(K1, x, vmlaq_f32(K2, K3, x)));

  /* compute 2^i */
  float32x4_t exponent = vreinterpretq_f32_s32(vshlq_n_s32(i, 23));

  Y = vmulq_f32(Y, exponent);
  return Y;
}
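
/* Why the shift above yields 2^i: the bias of 127 was folded in before the
   float-to-int conversion, so i already holds a biased exponent, and
   shifting it left by 23 places it directly in the exponent field of an
   IEEE-754 single. A scalar sketch of the same construction, illustrative
   only:

   union { opus_int32 i; float f; } u;
   u.i = biased_exponent << 23;   // biased_exponent in [1, 254] for normals
   // u.f is now 2^(biased_exponent - 127)
*/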

static inline float32x4_t tanh4_approx(float32x4_t X)
{
  const float32x4_t N0 = vdupq_n_f32(952.52801514f);
  const float32x4_t N1 = vdupq_n_f32(96.39235687f);
  const float32x4_t N2 = vdupq_n_f32(0.60863042f);
  const float32x4_t D0 = vdupq_n_f32(952.72399902f);
  const float32x4_t D1 = vdupq_n_f32(413.36801147f);
  const float32x4_t D2 = vdupq_n_f32(11.88600922f);
  const float32x4_t max_out = vdupq_n_f32(1.f);
  const float32x4_t min_out = vdupq_n_f32(-1.f);
  float32x4_t X2, num, den;
  X2 = vmulq_f32(X, X);
  num = vmlaq_f32(N0, X2, vmlaq_f32(N1, N2, X2));
  den = vmlaq_f32(D0, X2, vmlaq_f32(D1, D2, X2));
  num = vmulq_f32(num, X);
  den = vrecpeq_f32(den);
  num = vmulq_f32(num, den);
  return vmaxq_f32(min_out, vminq_f32(max_out, num));
}

static inline float32x4_t sigmoid4_approx(float32x4_t X)
{
  const float32x4_t N0 = vdupq_n_f32(238.13200378f);
  const float32x4_t N1 = vdupq_n_f32(6.02452230f);
  const float32x4_t N2 = vdupq_n_f32(0.00950985f);
  const float32x4_t D0 = vdupq_n_f32(952.72399902f);
  const float32x4_t D1 = vdupq_n_f32(103.34200287f);
  const float32x4_t D2 = vdupq_n_f32(0.74287558f);
  const float32x4_t half = vdupq_n_f32(0.5f);
  const float32x4_t max_out = vdupq_n_f32(1.f);
  const float32x4_t min_out = vdupq_n_f32(0.f);
  float32x4_t X2, num, den;
  X2 = vmulq_f32(X, X);
  num = vmlaq_f32(N0, X2, vmlaq_f32(N1, N2, X2));
  den = vmlaq_f32(D0, X2, vmlaq_f32(D1, D2, X2));
  num = vmulq_f32(num, X);
  den = vrecpeq_f32(den);
  num = vmlaq_f32(half, num, den);
  return vmaxq_f32(min_out, vminq_f32(max_out, num));
}
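
/* The vrecpeq_f32() used in both rational approximations above returns only
   an ~8-bit reciprocal estimate, which is adequate here because the outputs
   are clamped anyway. If more precision were needed, one Newton-Raphson
   refinement step would look like this (illustrative sketch, not used by
   this code):

   static inline float32x4_t recip_refine_sketch(float32x4_t d) {
      float32x4_t r = vrecpeq_f32(d);          // initial ~8-bit estimate
      return vmulq_f32(r, vrecpsq_f32(d, r));  // r *= (2 - d*r)
   }
*/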

/* Scalar helpers: broadcast the input to all four lanes, evaluate the
   vector approximation, and return lane 0. */
static inline float lpcnet_exp(float x)
{
   float out[4];
   float32x4_t X, Y;
   X = vdupq_n_f32(x);
   Y = exp4_approx(X);
   vst1q_f32(out, Y);
   return out[0];
}

static inline float tanh_approx(float x)
{
   float out[4];
   float32x4_t X, Y;
   X = vdupq_n_f32(x);
   Y = tanh4_approx(X);
   vst1q_f32(out, Y);
   return out[0];
}

static inline float sigmoid_approx(float x)
{
   float out[4];
   float32x4_t X, Y;
   X = vdupq_n_f32(x);
   Y = sigmoid4_approx(X);
   vst1q_f32(out, Y);
   return out[0];
}

static inline void softmax(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-3;i+=4)
    {
        float32x4_t X, Y;
        X = vld1q_f32(&x[i]);
        Y = exp4_approx(X);
        vst1q_f32(&y[i], Y);
    }
    for (;i<N;i++)
        y[i] = lpcnet_exp(x[i]);
}
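
/* Note: despite its name, softmax() above only computes y[i] = exp(x[i]);
   the 1/sum(exp) normalization is left to the caller. A caller-side sketch,
   illustrative only:

   float sum = 0.f;
   softmax(y, x, N);
   for (i=0;i<N;i++) sum += y[i];
   for (i=0;i<N;i++) y[i] *= 1.f/sum;
*/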

static inline void vec_tanh(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-3;i+=4)
    {
        float32x4_t X, Y;
        X = vld1q_f32(&x[i]);
        Y = tanh4_approx(X);
        vst1q_f32(&y[i], Y);
    }
    for (;i<N;i++)
    {
        /* scalar tail: tanh(x) = (e^(2x)-1)/(e^(2x)+1) */
        float ex2;
        ex2 = lpcnet_exp(2*x[i]);
        y[i] = (ex2-1)/(ex2+1);
    }
}

static inline void vec_sigmoid(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-3;i+=4)
    {
        float32x4_t X, Y;
        X = vld1q_f32(&x[i]);
        Y = sigmoid4_approx(X);
        vst1q_f32(&y[i], Y);
    }
    for (;i<N;i++)
    {
        /* scalar tail: sigmoid(x) = e^x/(e^x+1) */
        float ex;
        ex = lpcnet_exp(x[i]);
        y[i] = ex/(ex+1);
    }
}
#endif

static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
    int i, j;
    for (i=0;i<rows;i+=16)
    {
        float * restrict y = &out[i];

        /* keep y[0..15] in registers for duration of inner loop */

        float32x4_t y0_3 = vdupq_n_f32(0);
        float32x4_t y4_7 = vdupq_n_f32(0);
        float32x4_t y8_11 = vdupq_n_f32(0);
        float32x4_t y12_15 = vdupq_n_f32(0);

        for (j=0;j<cols;j++)
        {
            const float * restrict w;
            float32x4_t wvec0_3, wvec4_7, wvec8_11, wvec12_15;
            float32x4_t xj;

            w = &weights[j*col_stride + i];
            wvec0_3 = vld1q_f32(&w[0]);
            wvec4_7 = vld1q_f32(&w[4]);
            wvec8_11 = vld1q_f32(&w[8]);
            wvec12_15 = vld1q_f32(&w[12]);

            xj = vld1q_dup_f32(&x[j]);

            y0_3 = vmlaq_f32(y0_3, wvec0_3, xj);
            y4_7 = vmlaq_f32(y4_7, wvec4_7, xj);
            y8_11 = vmlaq_f32(y8_11, wvec8_11, xj);
            y12_15 = vmlaq_f32(y12_15, wvec12_15, xj);
        }

        /* save y[0..15] back to memory */

        vst1q_f32(&y[0], y0_3);
        vst1q_f32(&y[4], y4_7);
        vst1q_f32(&y[8], y8_11);
        vst1q_f32(&y[12], y12_15);
    }
}
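
/* sgemv16x1() above computes out = weights*x for column-major weights with
   col_stride floats between consecutive columns, sixteen output rows at a
   time; it assumes rows is a multiple of 16. The equivalent scalar
   reference, which is also the generic fallback in sgemv() below:

   for (i=0;i<rows;i++) {
      out[i] = 0;
      for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
   }
*/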

static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
    int i, j;
    for (i=0;i<rows;i+=8)
    {
        float * restrict y = &out[i];

        /* keep y[0..7] in registers for duration of inner loop */

        float32x4_t y0_3 = vdupq_n_f32(0);
        float32x4_t y4_7 = vdupq_n_f32(0);

        for (j=0;j<cols;j++)
        {
            const float * restrict w;
            float32x4_t wvec0_3, wvec4_7;
            float32x4_t xj;

            w = &weights[j*col_stride + i];
            wvec0_3 = vld1q_f32(&w[0]);
            wvec4_7 = vld1q_f32(&w[4]);

            xj = vld1q_dup_f32(&x[j]);

            y0_3 = vmlaq_f32(y0_3, wvec0_3, xj);
            y4_7 = vmlaq_f32(y4_7, wvec4_7, xj);
        }

        /* save y[0..7] back to memory */

        vst1q_f32(&y[0], y0_3);
        vst1q_f32(&y[4], y4_7);
    }
}

static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
   else {
      int i, j;
      for (i=0;i<rows;i++)
      {
         out[i] = 0;
         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
      }
   }
}

/* Temporarily use unoptimized version */
static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
{
   int i, j;
   OPUS_CLEAR(out, rows);
   for (i=0;i<rows;i+=8)
   {
      int cols;
      cols = *idx++;
      for (j=0;j<cols;j++)
      {
         int pos;
         float * restrict y;
         float xj0, xj1, xj2, xj3;
         pos = (*idx++);
         xj0 = x[pos+0];
         xj1 = x[pos+1];
         xj2 = x[pos+2];
         xj3 = x[pos+3];
         y = &out[i];
         y[0] += w[0]*xj0;
         y[1] += w[1]*xj0;
         y[2] += w[2]*xj0;
         y[3] += w[3]*xj0;
         y[4] += w[4]*xj0;
         y[5] += w[5]*xj0;
         y[6] += w[6]*xj0;
         y[7] += w[7]*xj0;

         y[0] += w[8]*xj1;
         y[1] += w[9]*xj1;
         y[2] += w[10]*xj1;
         y[3] += w[11]*xj1;
         y[4] += w[12]*xj1;
         y[5] += w[13]*xj1;
         y[6] += w[14]*xj1;
         y[7] += w[15]*xj1;

         y[0] += w[16]*xj2;
         y[1] += w[17]*xj2;
         y[2] += w[18]*xj2;
         y[3] += w[19]*xj2;
         y[4] += w[20]*xj2;
         y[5] += w[21]*xj2;
         y[6] += w[22]*xj2;
         y[7] += w[23]*xj2;

         y[0] += w[24]*xj3;
         y[1] += w[25]*xj3;
         y[2] += w[26]*xj3;
         y[3] += w[27]*xj3;
         y[4] += w[28]*xj3;
         y[5] += w[29]*xj3;
         y[6] += w[30]*xj3;
         y[7] += w[31]*xj3;
         w += 32;
      }
   }
}
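
/* Block-sparse layout implied by sparse_sgemv8x4() above: the matrix is a
   sequence of dense 8x4 blocks, and for each group of 8 output rows, idx[]
   holds the number of 4-column blocks followed by the starting input index
   of each block. Sketch of walking one row group, illustrative only:

   int nblocks = *idx++;   // 4-column blocks in this group of 8 rows
   for (j=0;j<nblocks;j++) {
      int pos = *idx++;    // x[pos..pos+3] feed this 8x4 block
      // consume 32 weights w[0..31], then w += 32
   }
*/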


/* Quantization scales: the activations are scaled by 127 before being
   rounded to 8 bits (see const127 below); the weights are presumably stored
   in Q7 (scaled by 128). SCALE_1 undoes both factors. */
#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)

#define MAX_INPUTS 2048
#define MAX_OUTPUTS 8192

#if __ARM_FEATURE_DOTPROD
static inline int32x4_t vdotprod(int32x4_t acc, int8x16_t a, int8x16_t b) {
  return vdotq_s32(acc, a, b);
}
#else
static inline int32x4_t vdotprod(int32x4_t acc, int8x16_t a, int8x16_t b)
{
  return vpadalq_s16(acc, vpaddq_s16(vmull_s8(vget_low_s8(a), vget_low_s8(b)), vmull_high_s8(a, b)));
}
#endif
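
/* In the fallback vdotprod() above, vmull_s8()/vmull_high_s8() widen the
   byte products to 16 bits, vpaddq_s16() sums adjacent pairs, and
   vpadalq_s16() adds the remaining pairs into the 32-bit accumulator, so
   each 32-bit lane receives the dot product of one aligned group of four
   signed bytes -- the same result vdotq_s32() produces in a single
   instruction when __ARM_FEATURE_DOTPROD is available. */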

static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   opus_int32 x_int[MAX_INPUTS/4];
   opus_int8 *x = (opus_int8*) x_int;
   const float32x4_t const127 = vdupq_n_f32(127.);
   for (i=0;i<cols;i+=8) {
      int32x4_t xi0, xi4;
      int16x8_t x_short;
      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
      vst1_s8(&x[i], vmovn_s16(x_short));
   }
   for (i=0;i<rows;i+=8)
   {
      int32x4_t acc0, acc1;
      int32x4_t acc2, acc3;
      acc0 = vdupq_n_s32(0);
      acc1 = vdupq_n_s32(0);
      acc2 = vdupq_n_s32(0);
      acc3 = vdupq_n_s32(0);
      j=0;
      for (;j<cols-4;j+=8)
      {
         int8x16_t vw0, vw1, vw2, vw3, vx0, vx1;
         vx0 = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[j]);
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx0);
         acc1 = vdotprod(acc1, vw1, vx0);
         vx1 = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[j+4]);
         vw2 = vld1q_s8(&w[32]);
         vw3 = vld1q_s8(&w[48]);
         acc2 = vdotprod(acc2, vw2, vx1);
         acc3 = vdotprod(acc3, vw3, vx1);
         w += 64;
      }
      acc0 = vaddq_s32(acc0, acc2);
      acc1 = vaddq_s32(acc1, acc3);
      for (;j<cols;j+=4)
      {
         int8x16_t vw0, vw1, vx;
         vx = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[j]);
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx);
         acc1 = vdotprod(acc1, vw1, vx);
         w += 32;
      }
      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
   }
}
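
/* cgemv8x4() above runs in two passes: it first quantizes the input to
   signed 8 bits by scaling with 127 and rounding, then accumulates 8-bit
   dot products in 32-bit lanes and converts back to float using the
   per-row scale[] factors (which presumably fold in SCALE_1 and any
   per-row weight scaling). It assumes rows is a multiple of 8, cols is a
   multiple of 4 (with the quantization pass reading _x in groups of 8),
   and cols <= MAX_INPUTS. */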

static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   opus_int32 x_int[MAX_INPUTS/4];
   opus_int8 *x = (opus_int8*) x_int;
   const float32x4_t const127 = vdupq_n_f32(127.);
   for (i=0;i<cols;i+=8) {
      int32x4_t xi0, xi4;
      int16x8_t x_short;
      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
      vst1_s8(&x[i], vmovn_s16(x_short));
   }
   for (i=0;i<rows;i+=8)
   {
      int colblocks;
      int32x4_t acc0, acc1;
      acc0 = vdupq_n_s32(0);
      acc1 = vdupq_n_s32(0);
      colblocks = *idx++;
      for (j=0;j<colblocks;j++)
      {
         int pos;
         int8x16_t vw0, vw1, vx;
         pos = (*idx++);
         vx = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[pos]);
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx);
         acc1 = vdotprod(acc1, vw1, vx);
         w += 32;
      }
      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
   }
}
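
/* sparse_cgemv8x4() above combines the block-sparse idx[] format of
   sparse_sgemv8x4() with the 8-bit quantized inner loop of cgemv8x4(). */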


#endif