/* Copyright (c) 2018 Mozilla
                 2012-2017 Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
  AVX implementation of vector operations, compile with -mavx
  AVX2/FMA implementation of vector operations, compile with -mavx2 -mfma
*/
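/* For example (illustrative only; the actual flags come from the build
   system), the AVX2/FMA code path could be selected with something like
       cc -O2 -mavx2 -mfma -c <file>.c
   while a plain AVX build would pass just -mavx. */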

#ifndef VEC_AVX_H
#define VEC_AVX_H

#include <immintrin.h>
#include <math.h>
#include "celt/x86/x86cpu.h"

#define MAX_INPUTS (2048)

#define USE_SU_BIAS

#ifndef __SSE4_1__
static inline __m128 mm_floor_ps(__m128 x) {
  __m128 half = _mm_set1_ps(0.5);
  return _mm_cvtepi32_ps(_mm_cvtps_epi32(_mm_sub_ps(x, half)));
}
#undef _mm_floor_ps
#define _mm_floor_ps(x) mm_floor_ps(x)
#endif


/* If we don't have AVX available, emulate what we need with SSE up to 4.1. */
#ifndef __AVX__

typedef struct {
  __m128 lo;
  __m128 hi;
} mm256_emu;
#define __m256 mm256_emu

static inline mm256_emu mm256_loadu_ps(const float *src) {
  mm256_emu ret;
  ret.lo = _mm_loadu_ps(&src[0]);
  ret.hi = _mm_loadu_ps(&src[4]);
  return ret;
}
#define _mm256_loadu_ps(src) mm256_loadu_ps(src)


static inline void mm256_storeu_ps(float *dst, mm256_emu src) {
  _mm_storeu_ps(dst, src.lo);
  _mm_storeu_ps(&dst[4], src.hi);
}
#define _mm256_storeu_ps(dst, src) mm256_storeu_ps(dst, src)


static inline mm256_emu mm256_setzero_ps(void) {
  mm256_emu ret;
  ret.lo = _mm_setzero_ps();
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_setzero_ps mm256_setzero_ps

static inline mm256_emu mm256_broadcast_ss(const float *x) {
  mm256_emu ret;
  ret.lo = _mm_set1_ps(*x);
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_broadcast_ss(x) mm256_broadcast_ss(x)

static inline mm256_emu mm256_set1_ps(float x) {
  mm256_emu ret;
  ret.lo = _mm_set1_ps(x);
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_set1_ps(x) mm256_set1_ps(x)



static inline mm256_emu mm256_mul_ps(mm256_emu a, mm256_emu b) {
  mm256_emu ret;
  ret.lo = _mm_mul_ps(a.lo, b.lo);
  ret.hi = _mm_mul_ps(a.hi, b.hi);
  return ret;
}
#define _mm256_mul_ps(a,b) mm256_mul_ps(a,b)

static inline mm256_emu mm256_add_ps(mm256_emu a, mm256_emu b) {
  mm256_emu ret;
  ret.lo = _mm_add_ps(a.lo, b.lo);
  ret.hi = _mm_add_ps(a.hi, b.hi);
  return ret;
}
#define _mm256_add_ps(a,b) mm256_add_ps(a,b)


static inline mm256_emu mm256_max_ps(mm256_emu a, mm256_emu b) {
  mm256_emu ret;
  ret.lo = _mm_max_ps(a.lo, b.lo);
  ret.hi = _mm_max_ps(a.hi, b.hi);
  return ret;
}
#define _mm256_max_ps(a,b) mm256_max_ps(a,b)

static inline mm256_emu mm256_min_ps(mm256_emu a, mm256_emu b) {
  mm256_emu ret;
  ret.lo = _mm_min_ps(a.lo, b.lo);
  ret.hi = _mm_min_ps(a.hi, b.hi);
  return ret;
}
#define _mm256_min_ps(a,b) mm256_min_ps(a,b)

static inline mm256_emu mm256_rcp_ps(mm256_emu a) {
  mm256_emu ret;
  ret.lo = _mm_rcp_ps(a.lo);
  ret.hi = _mm_rcp_ps(a.hi);
  return ret;
}
#define _mm256_rcp_ps(a) mm256_rcp_ps(a)


static inline __m128 mm256_extractf128_ps(mm256_emu x, int i) {
    return (i==0) ? x.lo : x.hi;
}
#undef _mm256_extractf128_ps
#define _mm256_extractf128_ps(x,i) mm256_extractf128_ps(x,i)

static inline mm256_emu mm256_insertf128_ps(mm256_emu dst, __m128 src, int i) {
    if (i==0) dst.lo = src;
    else dst.hi = src;
    return dst;
}
#undef _mm256_insertf128_ps
#define _mm256_insertf128_ps(dst,src,i) mm256_insertf128_ps(dst,src,i)

#endif /* __AVX__ */



/* If we don't have AVX2 available, emulate what we need with SSE up to 4.1. */
#ifndef __AVX2__

typedef struct {
  __m128i lo;
  __m128i hi;
} mm256i_emu;
typedef __m256i real_m256i;
#define __m256i mm256i_emu

static inline mm256i_emu mm256_setzero_si256(void) {
  mm256i_emu ret;
  ret.lo = _mm_setzero_si128();
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_setzero_si256 mm256_setzero_si256


static inline mm256i_emu mm256_loadu_si256(const mm256i_emu *src) {
  mm256i_emu ret;
  ret.lo = _mm_loadu_si128((const __m128i*)src);
  ret.hi = _mm_loadu_si128(&((const __m128i*)src)[1]);
  return ret;
}
#define _mm256_loadu_si256(src) mm256_loadu_si256(src)


static inline void mm256_storeu_si256(mm256i_emu *dst, mm256i_emu src) {
  _mm_storeu_si128((__m128i*)dst, src.lo);
  _mm_storeu_si128(&((__m128i*)dst)[1], src.hi);
}
#define _mm256_storeu_si256(dst, src) mm256_storeu_si256(dst, src)


static inline mm256i_emu mm256_broadcastd_epi32(__m128i x) {
  mm256i_emu ret;
  ret.hi = ret.lo = _mm_shuffle_epi32(x, 0);
  return ret;
}
#define _mm256_broadcastd_epi32(x) mm256_broadcastd_epi32(x)


static inline mm256i_emu mm256_set1_epi32(int x) {
  mm256i_emu ret;
  ret.lo = _mm_set1_epi32(x);
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_set1_epi32(x) mm256_set1_epi32(x)

static inline mm256i_emu mm256_set1_epi16(int x) {
  mm256i_emu ret;
  ret.lo = _mm_set1_epi16(x);
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_set1_epi16(x) mm256_set1_epi16(x)


static inline mm256i_emu mm256_add_epi32(mm256i_emu a, mm256i_emu b) {
  mm256i_emu ret;
  ret.lo = _mm_add_epi32(a.lo, b.lo);
  ret.hi = _mm_add_epi32(a.hi, b.hi);
  return ret;
}
#define _mm256_add_epi32(a,b) mm256_add_epi32(a,b)

static inline mm256i_emu mm256_madd_epi16(mm256i_emu a, mm256i_emu b) {
  mm256i_emu ret;
  ret.lo = _mm_madd_epi16(a.lo, b.lo);
  ret.hi = _mm_madd_epi16(a.hi, b.hi);
  return ret;
}
#define _mm256_madd_epi16(a,b) mm256_madd_epi16(a,b)

static inline mm256i_emu mm256_maddubs_epi16(mm256i_emu a, mm256i_emu b) {
  mm256i_emu ret;
  ret.lo = _mm_maddubs_epi16(a.lo, b.lo);
  ret.hi = _mm_maddubs_epi16(a.hi, b.hi);
  return ret;
}
#define _mm256_maddubs_epi16(a,b) mm256_maddubs_epi16(a,b)


/* Emulating the conversion functions is tricky because they use __m256i but are defined in AVX.
   So we need a special case for when only AVX is available. */
#ifdef __AVX__

typedef union {
  mm256i_emu fake;
  real_m256i real;
} mm256_union;

static inline __m256 mm256_cvtepi32_ps(mm256i_emu a) {
  mm256_union src;
  src.fake = a;
  return _mm256_cvtepi32_ps(src.real);
}
#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)

static inline mm256i_emu mm256_cvtps_epi32(__m256 a) {
  mm256_union ret;
  ret.real = _mm256_cvtps_epi32(a);
  return ret.fake;
}
#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)


#else

static inline mm256_emu mm256_cvtepi32_ps(mm256i_emu a) {
  mm256_emu ret;
  ret.lo = _mm_cvtepi32_ps(a.lo);
  ret.hi = _mm_cvtepi32_ps(a.hi);
  return ret;
}
#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)

static inline mm256i_emu mm256_cvtps_epi32(mm256_emu a) {
  mm256i_emu ret;
  ret.lo = _mm_cvtps_epi32(a.lo);
  ret.hi = _mm_cvtps_epi32(a.hi);
  return ret;
}
#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)

#endif /* __AVX__ */


#endif /* __AVX2__ */

/* In case we don't have FMA, make it a mul and an add. */
#if !(defined(__FMA__) && defined(__AVX__))
#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
#endif

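/* exp() approximation used below: exp(x) is computed as 2^(x*log2(e)).
   The input is first clamped, then split into an integer part I and a
   fractional part f. 2^f on [0,1) is approximated by the cubic polynomial
   K0 + K1*f + K2*f^2 + K3*f^3, and the factor 2^I is applied by adding
   I (shifted left by 23 bits) directly to the IEEE-754 exponent field of
   the result. */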
#ifdef __AVX2__
static inline __m256 exp8_approx(__m256 X)
{
   const __m256 K0 = _mm256_set1_ps(0.99992522f);
   const __m256 K1 = _mm256_set1_ps(0.69583354f);
   const __m256 K2 = _mm256_set1_ps(0.22606716f);
   const __m256 K3 = _mm256_set1_ps(0.078024523f);
   const __m256 log2_E = _mm256_set1_ps(1.44269504f);
   const __m256 max_in = _mm256_set1_ps(50.f);
   const __m256 min_in = _mm256_set1_ps(-50.f);
   __m256 XF, Y;
   __m256i I;
   X = _mm256_mul_ps(X, log2_E);
   X = _mm256_max_ps(min_in, _mm256_min_ps(max_in, X));
   XF = _mm256_floor_ps(X);
   I = _mm256_cvtps_epi32(XF);
   X = _mm256_sub_ps(X, XF);
   Y = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(K3, X, K2), X, K1), X, K0);
   I = _mm256_slli_epi32(I, 23);
   Y = _mm256_castsi256_ps(_mm256_add_epi32(I, _mm256_castps_si256(Y)));
   return Y;
}

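/* Quantize a float vector to unsigned 8-bit values: each element is mapped
   to round(127*x + 127), with the pack instructions providing unsigned
   saturation. This mirrors the scalar fallback defined further down. */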
static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
   int i;
   __m256 const127 = _mm256_set1_ps(127.f);
   for (i=0;i<len;i+=8) {
      __m256 xf;
      __m256i xi;
      xf = _mm256_loadu_ps(&_x[i]);
      xf = _mm256_fmadd_ps(xf, const127, const127);
      xi = _mm256_cvtps_epi32(xf);
      xi = _mm256_packus_epi32(xi, _mm256_setzero_si256());
      xi = _mm256_permute4x64_epi64(xi, 0xD8);
      xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
      xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
      _mm256_storeu_si256 ((__m256i *)(void*)&x[i], xi);
   }
}

#else
static inline __m128 exp4_approx(__m128 X)
{
   const __m128 K0 = _mm_set1_ps(0.99992522f);
   const __m128 K1 = _mm_set1_ps(0.69583354f);
   const __m128 K2 = _mm_set1_ps(0.22606716f);
   const __m128 K3 = _mm_set1_ps(0.078024523f);
   const __m128 log2_E = _mm_set1_ps(1.44269504);
   const __m128 max_in = _mm_set1_ps(50.f);
   const __m128 min_in = _mm_set1_ps(-50.f);
   const __m128i mask = _mm_set1_epi32(0x7fffffff);
   __m128 XF, Y;
   __m128i I;
   X = _mm_mul_ps(X, log2_E);
   X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
   XF = _mm_floor_ps(X);
   I = _mm_cvtps_epi32(XF);
   X = _mm_sub_ps(X, XF);
   Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
   I = _mm_slli_epi32(I, 23);
   Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
   return Y;
}
static inline __m256 exp8_approx(__m256 X)
{
   __m256 Y;
   __m128 Xhi, Xlo, Yhi, Ylo;
   Xhi = _mm256_extractf128_ps(X, 1);
   Xlo = _mm256_extractf128_ps(X, 0);
   Yhi = exp4_approx(Xhi);
   Ylo = exp4_approx(Xlo);
   Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
   Y = _mm256_insertf128_ps(Y, Ylo, 0);
   return Y;
}

static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
    int i;
    for (i=0;i<len;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
}

#endif


#ifdef __AVX__

/* Approximating tanh() using a Padé-like rational function:
   tanh(x) ~= x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
   subject to the +/- 1 bounds.
   The coefficients were determined by gradient descent trying to minimize
   the maximum deviation over the whole range (this is only possible because
   of the bounds). The max error is around 3e-4 and is dominated by the
   reciprocal approximation (the max error of the rational function is
   around 6e-5).
   */
static inline __m256 tanh8_approx(__m256 X)
{
   const __m256 N0 = _mm256_set1_ps(952.52801514f);
   const __m256 N1 = _mm256_set1_ps(96.39235687f);
   const __m256 N2 = _mm256_set1_ps(0.60863042f);
   const __m256 D0 = _mm256_set1_ps(952.72399902f);
   const __m256 D1 = _mm256_set1_ps(413.36801147f);
   const __m256 D2 = _mm256_set1_ps(11.88600922f);
   const __m256 max_out = _mm256_set1_ps(1.f);
   const __m256 min_out = _mm256_set1_ps(-1.f);
   __m256 X2, num, den;
   X2 = _mm256_mul_ps(X, X);
   num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
   den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
   num = _mm256_mul_ps(num, X);
   den = _mm256_rcp_ps(den);
   num = _mm256_mul_ps(num, den);
   return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}
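/* For reference, a scalar sketch of the same rational approximation
   (illustrative only, not used by the code):

      float x2  = x*x;
      float num = x*(952.52801514f + x2*(96.39235687f + x2*0.60863042f));
      float den = 952.72399902f + x2*(413.36801147f + x2*11.88600922f);
      float y   = num/den;
      return y < -1.f ? -1.f : (y > 1.f ? 1.f : y);

   The vector version above replaces the division with _mm256_rcp_ps(),
   which is what dominates the overall error. */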

/* Sigmoid approximation using a Padé-like rational function:
   1/(1+exp(-x)) ~= 0.5 + x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
   subject to the [0, 1] bounds.
   The coefficients are directly derived by dividing the tanh() coefficients
   by powers of two to get the correct scaling. The max error is around 1.5e-4
   and is dominated by the reciprocal approximation (the max error of the
   rational function is around 3e-5).
   */
static inline __m256 sigmoid8_approx(__m256 X)
{
   const __m256 N0 = _mm256_set1_ps(238.13200378f);
   const __m256 N1 = _mm256_set1_ps(6.02452230f);
   const __m256 N2 = _mm256_set1_ps(0.00950985f);
   const __m256 D0 = _mm256_set1_ps(952.72399902f);
   const __m256 D1 = _mm256_set1_ps(103.34200287f);
   const __m256 D2 = _mm256_set1_ps(0.74287558f);
   const __m256 half = _mm256_set1_ps(0.5);
   const __m256 max_out = _mm256_set1_ps(1.f);
   const __m256 min_out = _mm256_set1_ps(0.f);
   __m256 X2, num, den;
   X2 = _mm256_mul_ps(X, X);
   num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
   den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
   num = _mm256_mul_ps(num, X);
   den = _mm256_rcp_ps(den);
   num = _mm256_fmadd_ps(num, den, half);
   return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}
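/* The relation behind the "divide by powers of two" remark above:
   sigmoid(x) = 0.5 + 0.5*tanh(x/2), so substituting x/2 into the tanh()
   rational function gives numerator coefficients N0/4, N1/16, N2/64 and
   denominator coefficients D0, D1/4, D2/16 (e.g. 952.528/4 = 238.132,
   413.368/4 = 103.342, 11.886/16 = 0.74288). */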

static inline float tanh_approx(float x)
{
   float out[8];
   __m256 X, Y;
   X = _mm256_set1_ps(x);
   Y = tanh8_approx(X);
   _mm256_storeu_ps(out, Y);
   return out[0];
}

static inline float sigmoid_approx(float x)
{
   float out[8];
   __m256 X, Y;
   X = _mm256_set1_ps(x);
   Y = sigmoid8_approx(X);
   _mm256_storeu_ps(out, Y);
   return out[0];
}

#else

static inline __m128 tanh4_approx(__m128 X)
{
   const __m128 N0 = _mm_set1_ps(952.52801514f);
   const __m128 N1 = _mm_set1_ps(96.39235687f);
   const __m128 N2 = _mm_set1_ps(0.60863042f);
   const __m128 D0 = _mm_set1_ps(952.72399902f);
   const __m128 D1 = _mm_set1_ps(413.36801147f);
   const __m128 D2 = _mm_set1_ps(11.88600922f);
   const __m128 max_out = _mm_set1_ps(1.f);
   const __m128 min_out = _mm_set1_ps(-1.f);
   __m128 X2, num, den;
   X2 = _mm_mul_ps(X, X);
   num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
   den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
   num = _mm_mul_ps(num, X);
   den = _mm_rcp_ps(den);
   num = _mm_mul_ps(num, den);
   return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}

static inline __m128 sigmoid4_approx(__m128 X)
{
   const __m128 N0 = _mm_set1_ps(238.13200378f);
   const __m128 N1 = _mm_set1_ps(6.02452230f);
   const __m128 N2 = _mm_set1_ps(0.00950985f);
   const __m128 D0 = _mm_set1_ps(952.72399902f);
   const __m128 D1 = _mm_set1_ps(103.34200287f);
   const __m128 D2 = _mm_set1_ps(0.74287558f);
   const __m128 half = _mm_set1_ps(0.5);
   const __m128 max_out = _mm_set1_ps(1.f);
   const __m128 min_out = _mm_set1_ps(0.f);
   __m128 X2, num, den;
   X2 = _mm_mul_ps(X, X);
   num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
   den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
   num = _mm_mul_ps(num, X);
   den = _mm_rcp_ps(den);
   num = _mm_fmadd_ps(num, den, half);
   return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}

static inline float tanh_approx(float x)
{
   float out[4];
   __m128 X, Y;
   X = _mm_set1_ps(x);
   Y = tanh4_approx(X);
   _mm_storeu_ps(out, Y);
   return out[0];
}

static inline float sigmoid_approx(float x)
{
   float out[4];
   __m128 X, Y;
   X = _mm_set1_ps(x);
   Y = sigmoid4_approx(X);
   _mm_storeu_ps(out, Y);
   return out[0];
}

#endif

static inline float lpcnet_exp(float x)
{
   float out[8];
   __m256 X, Y;
   X = _mm256_set1_ps(x);
   Y = exp8_approx(X);
   _mm256_storeu_ps(out, Y);
   return out[0];
}

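/* Elementwise exp() of x into y using the vectorized approximation above;
   note that no normalization is performed here. */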
static inline void softmax(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-7;i+=8)
    {
        __m256 X, Y;
        X = _mm256_loadu_ps(&x[i]);
        Y = exp8_approx(X);
        _mm256_storeu_ps(&y[i], Y);
    }
    for (;i<N;i++)
        y[i] = lpcnet_exp(x[i]);
}

#ifdef __AVX__
static inline void vec_tanh(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-7;i+=8)
    {
        __m256 X, Y;
        X = _mm256_loadu_ps(&x[i]);
        Y = tanh8_approx(X);
        _mm256_storeu_ps(&y[i], Y);
    }
    for (;i<N;i++)
    {
        y[i] = tanh_approx(x[i]);
    }
}

static inline void vec_sigmoid(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-7;i+=8)
    {
        __m256 X, Y;
        X = _mm256_loadu_ps(&x[i]);
        Y = sigmoid8_approx(X);
        _mm256_storeu_ps(&y[i], Y);
    }
    for (;i<N;i++)
    {
        y[i] = sigmoid_approx(x[i]);
    }
}
#else
static inline void vec_tanh(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-3;i+=4)
    {
        __m128 X, Y;
        X = _mm_loadu_ps(&x[i]);
        Y = tanh4_approx(X);
        _mm_storeu_ps(&y[i], Y);
    }
    for (;i<N;i++)
    {
        y[i] = tanh_approx(x[i]);
    }
}

static inline void vec_sigmoid(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-3;i+=4)
    {
        __m128 X, Y;
        X = _mm_loadu_ps(&x[i]);
        Y = sigmoid4_approx(X);
        _mm_storeu_ps(&y[i], Y);
    }
    for (;i<N;i++)
    {
        y[i] = sigmoid_approx(x[i]);
    }
}

#endif

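/* opus_mm256_dpbusds_epi32(src, a, b): for each 32-bit lane, add to src the
   sum of the four products of the unsigned 8-bit elements of a with the
   corresponding signed 8-bit elements of b, i.e. the operation of the
   AVX-VNNI vpdpbusds instruction. The fallbacks below emulate it
   (saturation behavior aside) with progressively older instruction sets. */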
#if defined(__AVXVNNI__) || defined(__AVX512VNNI__)

#define opus_mm256_dpbusds_epi32(src, a, b) _mm256_dpbusds_epi32(src, a, b)

#elif defined(__AVX2__)

static inline __m256i opus_mm256_dpbusds_epi32(__m256i src, __m256i a, __m256i b) {
  __m256i ones, tmp;
  ones = _mm256_set1_epi16(1);
  tmp = _mm256_maddubs_epi16(a, b);
  tmp = _mm256_madd_epi16(tmp, ones);
  return _mm256_add_epi32(src, tmp);
}

#elif defined(__SSSE3__)

static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
  mm256i_emu ones, tmp;
  ones = _mm256_set1_epi16(1);
  tmp = _mm256_maddubs_epi16(a, b);
  tmp = _mm256_madd_epi16(tmp, ones);
  return _mm256_add_epi32(src, tmp);
}

#elif defined(__SSE2__)

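/* SSE2 has no unsigned-by-signed byte multiply, so split each 16-bit lane
   into its high and low bytes: srli zero-extends the unsigned bytes of a,
   srai sign-extends the signed bytes of b, and two _mm_madd_epi16 calls
   recombine the partial products into 32-bit sums. */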
static inline __m128i mm_dpbusds_epi32(__m128i src, __m128i a, __m128i b) {
  __m128i ah, al, bh, bl, tmp;
  ah = _mm_srli_epi16(a, 8);
  bh = _mm_srai_epi16(b, 8);
  al = _mm_srli_epi16(_mm_slli_epi16(a, 8), 8);
  bl = _mm_srai_epi16(_mm_slli_epi16(b, 8), 8);
  tmp = _mm_add_epi32(_mm_madd_epi16(ah, bh), _mm_madd_epi16(al, bl));
  return _mm_add_epi32(src, tmp);
}

static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
  mm256i_emu res;
  res.hi = mm_dpbusds_epi32(src.hi, a.hi, b.hi);
  res.lo = mm_dpbusds_epi32(src.lo, a.lo, b.lo);
  return res;
}


#else

#error "No optimizations in vec_avx.h. This should never happen."
#endif

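/* Dense matrix-vector product out = W*x, with W stored column-major and
   col_stride floats between consecutive columns. Rows are processed in
   blocks of 16, 8 and 4, with a scalar loop catching the remainder. */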
static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
  int i, j;
  i=0;
  for (;i<rows-15;i+=16)
  {
     float *y;
     __m256 vy0, vy8;
     y = &out[i];
     vy0 = _mm256_setzero_ps();
     vy8 = _mm256_setzero_ps();
     for (j=0;j<cols;j++)
     {
        __m256 vxj;
        __m256 vw;
        vxj = _mm256_broadcast_ss(&x[j]);

        vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
        vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

        vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
        vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
     }
     _mm256_storeu_ps (&y[0], vy0);
     _mm256_storeu_ps (&y[8], vy8);
  }
  for (;i<rows-7;i+=8)
  {
     float *y;
     __m256 vy0;
     y = &out[i];
     vy0 = _mm256_setzero_ps();
     for (j=0;j<cols;j++)
     {
        __m256 vxj;
        __m256 vw;
        vxj = _mm256_broadcast_ss(&x[j]);

        vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
        vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
     }
     _mm256_storeu_ps (&y[0], vy0);
  }
  for (;i<rows-3;i+=4)
  {
     float *y;
     __m128 vy0;
     y = &out[i];
     vy0 = _mm_setzero_ps();
     for (j=0;j<cols;j++)
     {
        __m128 vxj;
        __m128 vw;
        vxj = _mm_set1_ps(x[j]);

        vw = _mm_loadu_ps(&weights[j*col_stride + i]);
        vy0 = _mm_fmadd_ps(vw, vxj, vy0);
     }
     _mm_storeu_ps (&y[0], vy0);
  }
  for (;i<rows;i++)
  {
    out[i] = 0;
    for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
  }
}

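/* Sparse matrix-vector product for weights stored as non-zero 8x4 blocks.
   For each group of 8 output rows, the idx stream first gives the number of
   4-column blocks, then the starting input index of each block; the
   corresponding 8x4 weight blocks are stored contiguously in column order. */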
static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
{
   int i, j;
   for (i=0;i<rows;i+=8)
   {
      float *y;
      int cols;
      __m256 vy0;
      y = &out[i];
      vy0 = _mm256_setzero_ps();
      cols = *idx++;
      for (j=0;j<cols;j++)
      {
         int id;
         __m256 vxj;
         __m256 vw;
         id = *idx++;
         vxj = _mm256_broadcast_ss(&x[id]);
         vw = _mm256_loadu_ps(&weights[0]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

         vxj = _mm256_broadcast_ss(&x[id+1]);
         vw = _mm256_loadu_ps(&weights[8]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

         vxj = _mm256_broadcast_ss(&x[id+2]);
         vw = _mm256_loadu_ps(&weights[16]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

         vxj = _mm256_broadcast_ss(&x[id+3]);
         vw = _mm256_loadu_ps(&weights[24]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

         weights += 32;
      }
      _mm256_storeu_ps (&y[0], vy0);
   }
}

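/* 8-bit variants of the products above: the input vector is quantized to
   unsigned bytes with vector_ps_to_epi8(), the weights are signed 8-bit,
   accumulation is done in 32-bit integers via opus_mm256_dpbusds_epi32(),
   and the result is scaled back to float with the per-row scale[] factors.
   sparse_cgemv8x4() uses the same block-sparse layout as sparse_sgemv8x4();
   cgemv8x4() below is the dense version. */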
static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   unsigned char x[MAX_INPUTS];
   /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
   vector_ps_to_epi8(x, _x, cols);
   for (i=0;i<rows;i+=8)
   {
      int colblocks;
      __m256i vy0;
      __m256 vout;
      colblocks = *idx++;
      vy0 = _mm256_setzero_si256();
      j=0;
#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
      for (;j<colblocks-3;j+=4)
      {
         __m256i vxj;
         __m256i vw;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
      }
#endif
      for (;j<colblocks;j++)
      {
         __m256i vxj;
         __m256i vw;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
      }
      vout = _mm256_cvtepi32_ps(vy0);
      vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
      _mm256_storeu_ps(&_out[i], vout);
   }
}
static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   unsigned char x[MAX_INPUTS];
   /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
   vector_ps_to_epi8(x, _x, cols);
   for (i=0;i<rows;i+=8)
   {
      __m256i vy0;
      __m256 vout;
      vy0 = _mm256_setzero_si256();
      j=0;
#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
      for (;j<cols-12;j+=16)
      {
         __m256i vxj;
         __m256i vw;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+4]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+8]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+12]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
      }
#endif
      for (;j<cols;j+=4)
      {
         __m256i vxj;
         __m256i vw;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
      }
      vout = _mm256_cvtepi32_ps(vy0);
      vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
      _mm256_storeu_ps(&_out[i], vout);
   }
}

#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)
#define USE_SU_BIAS


#endif /*VEC_AVX_H*/