/* Copyright (c) 2018 Mozilla
                 2012-2017 Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
  AVX implementation of vector operations, compile with -mavx
  AVX2/FMA implementation of vector operations, compile with -mavx2 -mfma
*/
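/* For example (hypothetical file names), the AVX2/FMA build might be
   compiled with GCC or Clang as:
     cc -O2 -mavx2 -mfma -c nnet.c -o nnet.o
   and the plain AVX build as:
     cc -O2 -mavx -c nnet.c -o nnet.o
*/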

#ifndef VEC_AVX_H
#define VEC_AVX_H

#include <immintrin.h>
#include <math.h>
#include "celt/x86/x86cpu.h"

#define MAX_INPUTS (2048)

#define USE_SU_BIAS

#ifndef __SSE4_1__
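/* Approximate the SSE4.1 _mm_floor_ps() by subtracting 0.5 and rounding to
   nearest; close enough for the uses below, though the result can differ
   from a true floor() exactly at integer inputs. */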
static inline __m128 mm_floor_ps(__m128 x) {
  __m128 half = _mm_set1_ps(0.5);
  return _mm_cvtepi32_ps(_mm_cvtps_epi32(_mm_sub_ps(x, half)));
}
#undef _mm_floor_ps
#define _mm_floor_ps(x) mm_floor_ps(x)
#endif


/* If we don't have AVX available, emulate what we need with SSE up to 4.1. */
#ifndef __AVX__

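/* Emulate a 256-bit float vector as a pair of 128-bit SSE vectors. */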
typedef struct {
  __m128 lo;
  __m128 hi;
} mm256_emu;
#define __m256 mm256_emu

static inline mm256_emu mm256_loadu_ps(const float *src) {
  mm256_emu ret;
  ret.lo = _mm_loadu_ps(&src[0]);
  ret.hi = _mm_loadu_ps(&src[4]);
  return ret;
}
#define _mm256_loadu_ps(src) mm256_loadu_ps(src)

static inline void mm256_storeu_ps(float *dst, mm256_emu src) {
  _mm_storeu_ps(dst, src.lo);
  _mm_storeu_ps(&dst[4], src.hi);
}
#define _mm256_storeu_ps(dst, src) mm256_storeu_ps(dst, src)

static inline mm256_emu mm256_setzero_ps(void) {
  mm256_emu ret;
  ret.lo = _mm_setzero_ps();
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_setzero_ps mm256_setzero_ps

static inline mm256_emu mm256_broadcast_ss(const float *x) {
  mm256_emu ret;
  ret.lo = _mm_set1_ps(*x);
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_broadcast_ss(x) mm256_broadcast_ss(x)

static inline mm256_emu mm256_set1_ps(float x) {
  mm256_emu ret;
  ret.lo = _mm_set1_ps(x);
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_set1_ps(x) mm256_set1_ps(x)

static inline mm256_emu mm256_mul_ps(mm256_emu a, mm256_emu b) {
  mm256_emu ret;
  ret.lo = _mm_mul_ps(a.lo, b.lo);
  ret.hi = _mm_mul_ps(a.hi, b.hi);
  return ret;
}
#define _mm256_mul_ps(a,b) mm256_mul_ps(a,b)

static inline mm256_emu mm256_add_ps(mm256_emu a, mm256_emu b) {
  mm256_emu ret;
  ret.lo = _mm_add_ps(a.lo, b.lo);
  ret.hi = _mm_add_ps(a.hi, b.hi);
  return ret;
}
#define _mm256_add_ps(a,b) mm256_add_ps(a,b)

static inline mm256_emu mm256_max_ps(mm256_emu a, mm256_emu b) {
  mm256_emu ret;
  ret.lo = _mm_max_ps(a.lo, b.lo);
  ret.hi = _mm_max_ps(a.hi, b.hi);
  return ret;
}
#define _mm256_max_ps(a,b) mm256_max_ps(a,b)

static inline mm256_emu mm256_min_ps(mm256_emu a, mm256_emu b) {
  mm256_emu ret;
  ret.lo = _mm_min_ps(a.lo, b.lo);
  ret.hi = _mm_min_ps(a.hi, b.hi);
  return ret;
}
#define _mm256_min_ps(a,b) mm256_min_ps(a,b)

static inline mm256_emu mm256_rcp_ps(mm256_emu a) {
  mm256_emu ret;
  ret.lo = _mm_rcp_ps(a.lo);
  ret.hi = _mm_rcp_ps(a.hi);
  return ret;
}
#define _mm256_rcp_ps(a) mm256_rcp_ps(a)

static inline __m128 mm256_extractf128_ps(mm256_emu x, int i) {
  return (i==0) ? x.lo : x.hi;
}
#undef _mm256_extractf128_ps
#define _mm256_extractf128_ps(x,i) mm256_extractf128_ps(x,i)

static inline mm256_emu mm256_insertf128_ps(mm256_emu dst, __m128 src, int i) {
  if (i==0) dst.lo = src;
  else dst.hi = src;
  return dst;
}
#undef _mm256_insertf128_ps
#define _mm256_insertf128_ps(dst,src,i) mm256_insertf128_ps(dst,src,i)

#endif /* __AVX__ */


/* If we don't have AVX2 available, emulate what we need with SSE up to 4.1. */
#ifndef __AVX2__

typedef struct {
  __m128i lo;
  __m128i hi;
} mm256i_emu;
typedef __m256i real_m256i;
#define __m256i mm256i_emu

static inline mm256i_emu mm256_setzero_si256(void) {
  mm256i_emu ret;
  ret.lo = _mm_setzero_si128();
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_setzero_si256 mm256_setzero_si256

static inline mm256i_emu mm256_loadu_si256(const mm256i_emu *src) {
  mm256i_emu ret;
  ret.lo = _mm_loadu_si128((const __m128i*)src);
  ret.hi = _mm_loadu_si128(&((const __m128i*)src)[1]);
  return ret;
}
#define _mm256_loadu_si256(src) mm256_loadu_si256(src)

static inline void mm256_storeu_si256(mm256i_emu *dst, mm256i_emu src) {
  _mm_storeu_si128((__m128i*)dst, src.lo);
  _mm_storeu_si128(&((__m128i*)dst)[1], src.hi);
}
#define _mm256_storeu_si256(dst, src) mm256_storeu_si256(dst, src)

static inline mm256i_emu mm256_broadcastd_epi32(__m128i x) {
  mm256i_emu ret;
  ret.hi = ret.lo = _mm_shuffle_epi32(x, 0);
  return ret;
}
#define _mm256_broadcastd_epi32(x) mm256_broadcastd_epi32(x)

static inline mm256i_emu mm256_set1_epi32(int x) {
  mm256i_emu ret;
  ret.lo = _mm_set1_epi32(x);
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_set1_epi32(x) mm256_set1_epi32(x)

static inline mm256i_emu mm256_set1_epi16(int x) {
  mm256i_emu ret;
  ret.lo = _mm_set1_epi16(x);
  ret.hi = ret.lo;
  return ret;
}
#define _mm256_set1_epi16(x) mm256_set1_epi16(x)

static inline mm256i_emu mm256_add_epi32(mm256i_emu a, mm256i_emu b) {
  mm256i_emu ret;
  ret.lo = _mm_add_epi32(a.lo, b.lo);
  ret.hi = _mm_add_epi32(a.hi, b.hi);
  return ret;
}
#define _mm256_add_epi32(a,b) mm256_add_epi32(a,b)

static inline mm256i_emu mm256_madd_epi16(mm256i_emu a, mm256i_emu b) {
  mm256i_emu ret;
  ret.lo = _mm_madd_epi16(a.lo, b.lo);
  ret.hi = _mm_madd_epi16(a.hi, b.hi);
  return ret;
}
#define _mm256_madd_epi16(a,b) mm256_madd_epi16(a,b)

static inline mm256i_emu mm256_maddubs_epi16(mm256i_emu a, mm256i_emu b) {
  mm256i_emu ret;
  ret.lo = _mm_maddubs_epi16(a.lo, b.lo);
  ret.hi = _mm_maddubs_epi16(a.hi, b.hi);
  return ret;
}
#define _mm256_maddubs_epi16(a,b) mm256_maddubs_epi16(a,b)

/* Emulating the conversion functions is tricky because they use __m256i but are defined in AVX.
   So we need to make a special case for when only AVX is available. */
#ifdef __AVX__

typedef union {
  mm256i_emu fake;
  real_m256i real;
} mm256_union;
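/* The union above lets us reinterpret the emulated integer vector as a real
   __m256i (and back) so the AVX conversion intrinsics can be applied to it. */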

static inline __m256 mm256_cvtepi32_ps(mm256i_emu a) {
  mm256_union src;
  src.fake = a;
  return _mm256_cvtepi32_ps(src.real);
}
#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)

static inline mm256i_emu mm256_cvtps_epi32(__m256 a) {
  mm256_union ret;
  ret.real = _mm256_cvtps_epi32(a);
  return ret.fake;
}
#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)


#else

static inline mm256_emu mm256_cvtepi32_ps(mm256i_emu a) {
  mm256_emu ret;
  ret.lo = _mm_cvtepi32_ps(a.lo);
  ret.hi = _mm_cvtepi32_ps(a.hi);
  return ret;
}
#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)

static inline mm256i_emu mm256_cvtps_epi32(mm256_emu a) {
  mm256i_emu ret;
  ret.lo = _mm_cvtps_epi32(a.lo);
  ret.hi = _mm_cvtps_epi32(a.hi);
  return ret;
}
#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)

#endif /* __AVX__ */


#endif /* __AVX2__ */

/* In case we don't have FMA, make it a mul and an add. */
#if !(defined(__FMA__) && defined(__AVX__))
#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
#endif

#ifdef __AVX2__
static inline __m256 exp8_approx(__m256 X)
{
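  /* Compute e^x per lane: scale by log2(e) to turn the problem into 2^x,
     split into integer and fractional parts, approximate 2^frac with a cubic
     polynomial, then multiply by 2^int by adding the integer part directly
     into the IEEE-754 exponent field (the shift by 23 below). */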
  const __m256 K0 = _mm256_set1_ps(0.99992522f);
  const __m256 K1 = _mm256_set1_ps(0.69583354f);
  const __m256 K2 = _mm256_set1_ps(0.22606716f);
  const __m256 K3 = _mm256_set1_ps(0.078024523f);
  const __m256 log2_E = _mm256_set1_ps(1.44269504f);
  const __m256 max_in = _mm256_set1_ps(50.f);
  const __m256 min_in = _mm256_set1_ps(-50.f);
  __m256 XF, Y;
  __m256i I;
  X = _mm256_mul_ps(X, log2_E);
  X = _mm256_max_ps(min_in, _mm256_min_ps(max_in, X));
  XF = _mm256_floor_ps(X);
  I = _mm256_cvtps_epi32(XF);
  X = _mm256_sub_ps(X, XF);
  Y = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(K3, X, K2), X, K1), X, K0);
  I = _mm256_slli_epi32(I, 23);
  Y = _mm256_castsi256_ps(_mm256_add_epi32(I, _mm256_castps_si256(Y)));
  return Y;
}

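/* Quantize a float vector in [-1, 1] to unsigned bytes as 127 + round(127*x). */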
static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
  int i;
  __m256 const127 = _mm256_set1_ps(127.f);
  for (i=0;i<len;i+=8) {
    __m256 xf;
    __m256i xi;
    xf = _mm256_loadu_ps(&_x[i]);
    xf = _mm256_fmadd_ps(xf, const127, const127);
    xi = _mm256_cvtps_epi32(xf);
    xi = _mm256_packus_epi32(xi, _mm256_setzero_si256());
    xi = _mm256_permute4x64_epi64(xi, 0xD8);
    xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
    xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
    _mm256_storeu_si256((__m256i *)(void*)&x[i], xi);
  }
}

#else
static inline __m128 exp4_approx(__m128 X)
{
  const __m128 K0 = _mm_set1_ps(0.99992522f);
  const __m128 K1 = _mm_set1_ps(0.69583354f);
  const __m128 K2 = _mm_set1_ps(0.22606716f);
  const __m128 K3 = _mm_set1_ps(0.078024523f);
  const __m128 log2_E = _mm_set1_ps(1.44269504f);
  const __m128 max_in = _mm_set1_ps(50.f);
  const __m128 min_in = _mm_set1_ps(-50.f);
  const __m128i mask = _mm_set1_epi32(0x7fffffff);
  __m128 XF, Y;
  __m128i I;
  X = _mm_mul_ps(X, log2_E);
  X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
  XF = _mm_floor_ps(X);
  I = _mm_cvtps_epi32(XF);
  X = _mm_sub_ps(X, XF);
  Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
  I = _mm_slli_epi32(I, 23);
  Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
  return Y;
}
static inline __m256 exp8_approx(__m256 X)
{
  __m256 Y;
  __m128 Xhi, Xlo, Yhi, Ylo;
  Xhi = _mm256_extractf128_ps(X, 1);
  Xlo = _mm256_extractf128_ps(X, 0);
  Yhi = exp4_approx(Xhi);
  Ylo = exp4_approx(Xlo);
  Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
  Y = _mm256_insertf128_ps(Y, Ylo, 0);
  return Y;
}

static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
  int i;
  for (i=0;i<len;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
}

#endif


#ifdef __AVX__

/* Approximating tanh() using a Padé-like rational function:
   tanh(x) ~= x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
   subject to the +/- 1 bounds.
   The coefficients were determined by gradient descent trying to minimize
   the maximum deviation over the whole range (this is only possible because
   of the bounds). The max error is around 3e-4 and is dominated by the
   reciprocal approximation (the max error of the rational function is
   around 6e-5).
*/
static inline __m256 tanh8_approx(__m256 X)
{
  const __m256 N0 = _mm256_set1_ps(952.52801514f);
  const __m256 N1 = _mm256_set1_ps(96.39235687f);
  const __m256 N2 = _mm256_set1_ps(0.60863042f);
  const __m256 D0 = _mm256_set1_ps(952.72399902f);
  const __m256 D1 = _mm256_set1_ps(413.36801147f);
  const __m256 D2 = _mm256_set1_ps(11.88600922f);
  const __m256 max_out = _mm256_set1_ps(1.f);
  const __m256 min_out = _mm256_set1_ps(-1.f);
  __m256 X2, num, den;
  X2 = _mm256_mul_ps(X, X);
  num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
  den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
  num = _mm256_mul_ps(num, X);
  den = _mm256_rcp_ps(den);
  num = _mm256_mul_ps(num, den);
  return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}

/* Sigmoid approximation using a Padé-like rational function:
   1/(1+exp(-x)) ~= 0.5 + x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
   subject to the [0, 1] bounds.
   The coefficients are directly derived by dividing the tanh() coefficients
   by powers of two to get the correct scaling. The max error is around 1.5e-4
   and is dominated by the reciprocal approximation (the max error of the
   rational function is around 3e-5).
*/
static inline __m256 sigmoid8_approx(__m256 X)
{
  const __m256 N0 = _mm256_set1_ps(238.13200378f);
  const __m256 N1 = _mm256_set1_ps(6.02452230f);
  const __m256 N2 = _mm256_set1_ps(0.00950985f);
  const __m256 D0 = _mm256_set1_ps(952.72399902f);
  const __m256 D1 = _mm256_set1_ps(103.34200287f);
  const __m256 D2 = _mm256_set1_ps(0.74287558f);
  const __m256 half = _mm256_set1_ps(0.5);
  const __m256 max_out = _mm256_set1_ps(1.f);
  const __m256 min_out = _mm256_set1_ps(0.f);
  __m256 X2, num, den;
  X2 = _mm256_mul_ps(X, X);
  num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
  den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
  num = _mm256_mul_ps(num, X);
  den = _mm256_rcp_ps(den);
  num = _mm256_fmadd_ps(num, den, half);
  return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}

static inline float tanh_approx(float x)
{
  float out[8];
  __m256 X, Y;
  X = _mm256_set1_ps(x);
  Y = tanh8_approx(X);
  _mm256_storeu_ps(out, Y);
  return out[0];
}

static inline float sigmoid_approx(float x)
{
  float out[8];
  __m256 X, Y;
  X = _mm256_set1_ps(x);
  Y = sigmoid8_approx(X);
  _mm256_storeu_ps(out, Y);
  return out[0];
}

#else

static inline __m128 tanh4_approx(__m128 X)
{
  const __m128 N0 = _mm_set1_ps(952.52801514f);
  const __m128 N1 = _mm_set1_ps(96.39235687f);
  const __m128 N2 = _mm_set1_ps(0.60863042f);
  const __m128 D0 = _mm_set1_ps(952.72399902f);
  const __m128 D1 = _mm_set1_ps(413.36801147f);
  const __m128 D2 = _mm_set1_ps(11.88600922f);
  const __m128 max_out = _mm_set1_ps(1.f);
  const __m128 min_out = _mm_set1_ps(-1.f);
  __m128 X2, num, den;
  X2 = _mm_mul_ps(X, X);
  num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
  den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
  num = _mm_mul_ps(num, X);
  den = _mm_rcp_ps(den);
  num = _mm_mul_ps(num, den);
  return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}

static inline __m128 sigmoid4_approx(__m128 X)
{
  const __m128 N0 = _mm_set1_ps(238.13200378f);
  const __m128 N1 = _mm_set1_ps(6.02452230f);
  const __m128 N2 = _mm_set1_ps(0.00950985f);
  const __m128 D0 = _mm_set1_ps(952.72399902f);
  const __m128 D1 = _mm_set1_ps(103.34200287f);
  const __m128 D2 = _mm_set1_ps(0.74287558f);
  const __m128 half = _mm_set1_ps(0.5);
  const __m128 max_out = _mm_set1_ps(1.f);
  const __m128 min_out = _mm_set1_ps(0.f);
  __m128 X2, num, den;
  X2 = _mm_mul_ps(X, X);
  num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
  den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
  num = _mm_mul_ps(num, X);
  den = _mm_rcp_ps(den);
  num = _mm_fmadd_ps(num, den, half);
  return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}

static inline float tanh_approx(float x)
{
  float out[4];
  __m128 X, Y;
  X = _mm_set1_ps(x);
  Y = tanh4_approx(X);
  _mm_storeu_ps(out, Y);
  return out[0];
}

static inline float sigmoid_approx(float x)
{
  float out[4];
  __m128 X, Y;
  X = _mm_set1_ps(x);
  Y = sigmoid4_approx(X);
  _mm_storeu_ps(out, Y);
  return out[0];
}

#endif

static inline float lpcnet_exp(float x)
{
  float out[8];
  __m256 X, Y;
  X = _mm256_set1_ps(x);
  Y = exp8_approx(X);
  _mm256_storeu_ps(out, Y);
  return out[0];
}

/* Note: this computes only the elementwise exponentials; any softmax
   normalization is left to the caller. */
static inline void softmax(float *y, const float *x, int N)
{
  int i;
  for (i=0;i<N-7;i+=8)
  {
    __m256 X, Y;
    X = _mm256_loadu_ps(&x[i]);
    Y = exp8_approx(X);
    _mm256_storeu_ps(&y[i], Y);
  }
  for (;i<N;i++)
    y[i] = lpcnet_exp(x[i]);
}

#ifdef __AVX__
static inline void vec_tanh(float *y, const float *x, int N)
{
  int i;
  for (i=0;i<N-7;i+=8)
  {
    __m256 X, Y;
    X = _mm256_loadu_ps(&x[i]);
    Y = tanh8_approx(X);
    _mm256_storeu_ps(&y[i], Y);
  }
  for (;i<N;i++)
  {
    y[i] = tanh_approx(x[i]);
  }
}

static inline void vec_sigmoid(float *y, const float *x, int N)
{
  int i;
  for (i=0;i<N-7;i+=8)
  {
    __m256 X, Y;
    X = _mm256_loadu_ps(&x[i]);
    Y = sigmoid8_approx(X);
    _mm256_storeu_ps(&y[i], Y);
  }
  for (;i<N;i++)
  {
    y[i] = sigmoid_approx(x[i]);
  }
}
#else
static inline void vec_tanh(float *y, const float *x, int N)
{
  int i;
  for (i=0;i<N-3;i+=4)
  {
    __m128 X, Y;
    X = _mm_loadu_ps(&x[i]);
    Y = tanh4_approx(X);
    _mm_storeu_ps(&y[i], Y);
  }
  for (;i<N;i++)
  {
    y[i] = tanh_approx(x[i]);
  }
}

static inline void vec_sigmoid(float *y, const float *x, int N)
{
  int i;
  for (i=0;i<N-3;i+=4)
  {
    __m128 X, Y;
    X = _mm_loadu_ps(&x[i]);
    Y = sigmoid4_approx(X);
    _mm_storeu_ps(&y[i], Y);
  }
  for (;i<N;i++)
  {
    y[i] = sigmoid_approx(x[i]);
  }
}

#endif

#if defined(__AVXVNNI__) || defined(__AVX512VNNI__)

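/* opus_mm256_dpbusds_epi32(src, a, b): per 32-bit lane, multiply the unsigned
   bytes of a with the signed bytes of b, sum the four adjacent products, and
   accumulate into src. Uses the VNNI instruction when available, with the
   SIMD fallbacks below otherwise. */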
#define opus_mm256_dpbusds_epi32(src, a, b) _mm256_dpbusds_epi32(src, a, b)

#elif defined(__AVX2__)

static inline __m256i opus_mm256_dpbusds_epi32(__m256i src, __m256i a, __m256i b) {
  __m256i ones, tmp;
  ones = _mm256_set1_epi16(1);
  tmp = _mm256_maddubs_epi16(a, b);
  tmp = _mm256_madd_epi16(tmp, ones);
  return _mm256_add_epi32(src, tmp);
}

#elif defined(__SSSE3__)

static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
  mm256i_emu ones, tmp;
  ones = _mm256_set1_epi16(1);
  tmp = _mm256_maddubs_epi16(a, b);
  tmp = _mm256_madd_epi16(tmp, ones);
  return _mm256_add_epi32(src, tmp);
}

#elif defined(__SSE2__)

static inline __m128i mm_dpbusds_epi32(__m128i src, __m128i a, __m128i b) {
  __m128i ah, al, bh, bl, tmp;
  ah = _mm_srli_epi16(a, 8);
  bh = _mm_srai_epi16(b, 8);
  al = _mm_srli_epi16(_mm_slli_epi16(a, 8), 8);
  bl = _mm_srai_epi16(_mm_slli_epi16(b, 8), 8);
  tmp = _mm_add_epi32(_mm_madd_epi16(ah, bh), _mm_madd_epi16(al, bl));
  return _mm_add_epi32(src, tmp);
}

static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
  mm256i_emu res;
  res.hi = mm_dpbusds_epi32(src.hi, a.hi, b.hi);
  res.lo = mm_dpbusds_epi32(src.lo, a.lo, b.lo);
  return res;
}


#else

#error "No optimizations in vec_avx.h. This should never happen."
#endif

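/* Dense float GEMV: out = weights*x, with weights stored column-major and
   col_stride floats between consecutive columns. Rows are processed in
   batches of 16, 8, and 4, with a scalar loop for the remainder. */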
static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
  int i, j;
  i=0;
  for (;i<rows-15;i+=16)
  {
    float *y;
    __m256 vy0, vy8;
    y = &out[i];
    vy0 = _mm256_setzero_ps();
    vy8 = _mm256_setzero_ps();
    for (j=0;j<cols;j++)
    {
      __m256 vxj;
      __m256 vw;
      vxj = _mm256_broadcast_ss(&x[j]);

      vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
      vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

      vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
      vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
    }
    _mm256_storeu_ps(&y[0], vy0);
    _mm256_storeu_ps(&y[8], vy8);
  }
  for (;i<rows-7;i+=8)
  {
    float *y;
    __m256 vy0;
    y = &out[i];
    vy0 = _mm256_setzero_ps();
    for (j=0;j<cols;j++)
    {
      __m256 vxj;
      __m256 vw;
      vxj = _mm256_broadcast_ss(&x[j]);

      vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
      vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
    }
    _mm256_storeu_ps(&y[0], vy0);
  }
  for (;i<rows-3;i+=4)
  {
    float *y;
    __m128 vy0;
    y = &out[i];
    vy0 = _mm_setzero_ps();
    for (j=0;j<cols;j++)
    {
      __m128 vxj;
      __m128 vw;
      vxj = _mm_set1_ps(x[j]);

      vw = _mm_loadu_ps(&weights[j*col_stride + i]);
      vy0 = _mm_fmadd_ps(vw, vxj, vy0);
    }
    _mm_storeu_ps(&y[0], vy0);
  }
  for (;i<rows;i++)
  {
    out[i] = 0;
    for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
  }
}

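/* Block-sparse float GEMV with 8x4 blocks: for each band of 8 output rows,
   idx first gives the number of non-zero blocks, then the starting input
   index of each 4-wide block; weights stores the corresponding 8x4 blocks
   contiguously. */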
static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
{
  int i, j;
  for (i=0;i<rows;i+=8)
  {
    float *y;
    int cols;
    __m256 vy0;
    y = &out[i];
    vy0 = _mm256_setzero_ps();
    cols = *idx++;
    for (j=0;j<cols;j++)
    {
      int id;
      __m256 vxj;
      __m256 vw;
      id = *idx++;
      vxj = _mm256_broadcast_ss(&x[id]);
      vw = _mm256_loadu_ps(&weights[0]);
      vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

      vxj = _mm256_broadcast_ss(&x[id+1]);
      vw = _mm256_loadu_ps(&weights[8]);
      vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

      vxj = _mm256_broadcast_ss(&x[id+2]);
      vw = _mm256_loadu_ps(&weights[16]);
      vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

      vxj = _mm256_broadcast_ss(&x[id+3]);
      vw = _mm256_loadu_ps(&weights[24]);
      vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

      weights += 32;
    }
    _mm256_storeu_ps(&y[0], vy0);
  }
}

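/* 8-bit quantized block-sparse GEMV: the input is quantized to unsigned
   bytes, weights are signed bytes, products are accumulated in 32-bit
   integers with opus_mm256_dpbusds_epi32, and the result is scaled back to
   float with the per-row scale factors. */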
static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
{
  int i, j;
  unsigned char x[MAX_INPUTS];
  /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
  vector_ps_to_epi8(x, _x, cols);
  for (i=0;i<rows;i+=8)
  {
    int colblocks;
    __m256i vy0;
    __m256 vout;
    colblocks = *idx++;
    vy0 = _mm256_setzero_si256();
    j=0;
#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
    for (;j<colblocks-3;j+=4)
    {
      __m256i vxj;
      __m256i vw;
      vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
      vw = _mm256_loadu_si256((const __m256i *)(void*)w);
      vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
      w += 32;
      vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
      vw = _mm256_loadu_si256((const __m256i *)(void*)w);
      vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
      w += 32;
      vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
      vw = _mm256_loadu_si256((const __m256i *)(void*)w);
      vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
      w += 32;
      vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
      vw = _mm256_loadu_si256((const __m256i *)(void*)w);
      vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
      w += 32;
    }
#endif
    for (;j<colblocks;j++)
    {
      __m256i vxj;
      __m256i vw;
      vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
      vw = _mm256_loadu_si256((const __m256i *)(void*)w);
      vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
      w += 32;
    }
    vout = _mm256_cvtepi32_ps(vy0);
    vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
    _mm256_storeu_ps(&_out[i], vout);
  }
}
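
/* Dense 8-bit quantized GEMV: same quantized arithmetic as above, with the
   weights for every 8x4 block stored contiguously and the columns visited
   four at a time. */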
static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
{
  int i, j;
  unsigned char x[MAX_INPUTS];
  /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
  vector_ps_to_epi8(x, _x, cols);
  for (i=0;i<rows;i+=8)
  {
    __m256i vy0;
    __m256 vout;
    vy0 = _mm256_setzero_si256();
    j=0;
#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
    for (;j<cols-12;j+=16)
    {
      __m256i vxj;
      __m256i vw;
      vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
      vw = _mm256_loadu_si256((const __m256i *)(void*)w);
      vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
      w += 32;
      vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+4]));
      vw = _mm256_loadu_si256((const __m256i *)(void*)w);
      vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
      w += 32;
      vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+8]));
      vw = _mm256_loadu_si256((const __m256i *)(void*)w);
      vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
      w += 32;
      vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+12]));
      vw = _mm256_loadu_si256((const __m256i *)(void*)w);
      vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
      w += 32;
    }
#endif
    for (;j<cols;j+=4)
    {
      __m256i vxj;
      __m256i vw;
      vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
      vw = _mm256_loadu_si256((const __m256i *)(void*)w);
      vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
      w += 32;
    }
    vout = _mm256_cvtepi32_ps(vy0);
    vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
    _mm256_storeu_ps(&_out[i], vout);
  }
}

#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)
#define USE_SU_BIAS
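
/* Example usage (hypothetical sizes and buffers): one dense float layer
   followed by a tanh activation. Weights are column-major, so a dense
   ROWS x COLS matrix uses col_stride == ROWS:

     float W[ROWS*COLS], x[COLS], h[ROWS];
     sgemv(h, W, ROWS, COLS, ROWS, x);
     vec_tanh(h, h, ROWS);
*/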


#endif /*VEC_AVX_H*/