/* Copyright (c) 2018-2019 Mozilla
                 2023 Amazon */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#ifndef NNET_ARCH_H
#define NNET_ARCH_H

#include "nnet.h"
#include "arch.h"
#include "os_support.h"
#include "vec.h"

#define CAT_SUFFIX2(a,b) a ## b
#define CAT_SUFFIX(a,b) CAT_SUFFIX2(a, b)

#define RTCD_SUF(name) CAT_SUFFIX(name, RTCD_ARCH)

/* Force vectorization on for DNN code because some of the loops rely on
   compiler vectorization rather than explicitly using intrinsics. */
#if OPUS_GNUC_PREREQ(5,1)
#define GCC_POP_OPTIONS
#pragma GCC push_options
#pragma GCC optimize("tree-vectorize")
#endif


#define MAX_ACTIVATIONS (4096)

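/* Swish activation: y[i] = x[i]*sigmoid(x[i]), computed with the vectorized sigmoid
   from vec.h into a stack buffer (hence the MAX_ACTIVATIONS bound). */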
static OPUS_INLINE void vec_swish(float *y, const float *x, int N)
{
   int i;
   float tmp[MAX_ACTIVATIONS];
   celt_assert(N <= MAX_ACTIVATIONS);
   vec_sigmoid(tmp, x, N);
   for (i=0;i<N;i++)
      y[i] = x[i]*tmp[i];
}

static OPUS_INLINE float relu(float x)
{
   return x < 0 ? 0 : x;
}

/*#define HIGH_ACCURACY */

void RTCD_SUF(compute_activation_)(float *output, const float *input, int N, int activation)
{
   int i;
   if (activation == ACTIVATION_SIGMOID) {
#ifdef HIGH_ACCURACY
      for (int n=0; n<N; n++)
      {
         output[n] = 1.f / (1 + exp(-input[n]));
      }
#else
      vec_sigmoid(output, input, N);
#endif
   } else if (activation == ACTIVATION_TANH) {
#ifdef HIGH_ACCURACY
      for (int n=0; n<N; n++)
      {
         output[n] = tanh(input[n]);
      }
#else
      vec_tanh(output, input, N);
#endif
   } else if (activation == ACTIVATION_SWISH) {
      vec_swish(output, input, N);
   } else if (activation == ACTIVATION_RELU) {
      for (i=0;i<N;i++)
         output[i] = relu(input[i]);
   } else if (activation == ACTIVATION_SOFTMAX) {
#ifdef SOFTMAX_HACK
      OPUS_COPY(output, input, N);
      /*for (i=0;i<N;i++)
         output[i] = input[i];*/
#else
      float sum = 0;
      softmax(output, input, N);
      for (i=0;i<N;i++) {
         sum += output[i];
      }
      sum = 1.f/(sum+1e-30);
      for (i=0;i<N;i++)
         output[i] = sum*output[i];
#endif
   } else {
      celt_assert(activation == ACTIVATION_LINEAR);
      if (input != output) {
         for (i=0;i<N;i++)
            output[i] = input[i];
      }
   }
}


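/* Dense/sparse matrix-vector product with bias: out = weights*in + bias.
   Float weights are used when present (dense sgemv or sparse 8x4-block kernel
   selected by weights_idx); otherwise the quantized weights are used with the
   per-layer scale. The optional diag term adds per-gate diagonal contributions
   for GRU recurrent weights. */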
void RTCD_SUF(compute_linear_) (const LinearLayer *linear, float *out, const float *in)
{
   int i, M, N;
   const float *bias;
   celt_assert(in != out);
   bias = linear->bias;
   M = linear->nb_inputs;
   N = linear->nb_outputs;
   if (linear->float_weights != NULL) {
      if (linear->weights_idx != NULL) sparse_sgemv8x4(out, linear->float_weights, linear->weights_idx, N, in);
      else sgemv(out, linear->float_weights, N, M, N, in);
   } else if (linear->weights != NULL) {
      if (linear->weights_idx != NULL) sparse_cgemv8x4(out, linear->weights, linear->weights_idx, linear->scale, N, M, in);
      else cgemv8x4(out, linear->weights, linear->scale, N, M, in);
      /* Only use SU biases for integer matrices on SU archs. */
#ifdef USE_SU_BIAS
      bias = linear->subias;
#endif
   }
   else OPUS_CLEAR(out, N);
   if (bias != NULL) {
      for (i=0;i<N;i++) out[i] += bias[i];
   }
   if (linear->diag) {
      /* Diag is only used for GRU recurrent weights. */
      celt_assert(3*M == N);
      for (i=0;i<M;i++) {
         out[i] += linear->diag[i]*in[i];
         out[i+M] += linear->diag[i+M]*in[i];
         out[i+2*M] += linear->diag[i+2*M]*in[i];
      }
   }
}

/* Computes non-padded convolution for input [ ksize1 x in_channels x (len2+ksize2) ],
   kernel [ out_channels x in_channels x ksize1 x ksize2 ],
   storing the output as [ out_channels x len2 ].
   We assume that the output dimension along the ksize1 axis is 1,
   i.e. processing one frame at a time. */
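/* Concretely, with in_stride = height + kheight - 1, the input sample for kernel
   time t, input channel m, and height position j+h is read from
   in[t*in_channels*in_stride + m*in_stride + j + h], and the matching weight from
   weights[((i*in_channels + m)*ktime + t)*kheight + h] for output channel i. */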
static void conv2d_float(float *out, const float *weights, int in_channels, int out_channels, int ktime, int kheight, const float *in, int height, int hstride)
{
   int i;
   int in_stride;
   in_stride = height+kheight-1;
   for (i=0;i<out_channels;i++) {
      int m;
      OPUS_CLEAR(&out[i*hstride], height);
      for (m=0;m<in_channels;m++) {
         int t;
         for (t=0;t<ktime;t++) {
            int h;
            for (h=0;h<kheight;h++) {
               int j;
               for (j=0;j<height;j++) {
                  out[i*hstride + j] += weights[i*in_channels*ktime*kheight + m*ktime*kheight + t*kheight + h] *
                                        in[t*in_channels*in_stride + m*in_stride + j + h];
               }
            }
         }
      }
   }
}

/* There are no intrinsics in this function (or the one above) because the gcc (and hopefully other compilers') auto-vectorizer is smart enough to
   produce the right code by itself based on the compile flags. */
static void conv2d_3x3_float(float *out, const float *weights, int in_channels, int out_channels, const float *in, int height, int hstride)
{
   int i;
   int in_stride;
   int kheight, ktime;
   kheight = ktime = 3;
   in_stride = height+kheight-1;
   for (i=0;i<out_channels;i++) {
      int m;
      OPUS_CLEAR(&out[i*hstride], height);
      for (m=0;m<in_channels;m++) {
         int j;
         for (j=0;j<height;j++) {
            /* Unrolled version of previous function -- compiler will figure out the indexing simplifications. */
            out[i*hstride + j] += weights[i*in_channels*ktime*kheight + m*ktime*kheight + 0*kheight + 0]*in[0*in_channels*in_stride + m*in_stride + j + 0]
                                + weights[i*in_channels*ktime*kheight + m*ktime*kheight + 0*kheight + 1]*in[0*in_channels*in_stride + m*in_stride + j + 1]
                                + weights[i*in_channels*ktime*kheight + m*ktime*kheight + 0*kheight + 2]*in[0*in_channels*in_stride + m*in_stride + j + 2]
                                + weights[i*in_channels*ktime*kheight + m*ktime*kheight + 1*kheight + 0]*in[1*in_channels*in_stride + m*in_stride + j + 0]
                                + weights[i*in_channels*ktime*kheight + m*ktime*kheight + 1*kheight + 1]*in[1*in_channels*in_stride + m*in_stride + j + 1]
                                + weights[i*in_channels*ktime*kheight + m*ktime*kheight + 1*kheight + 2]*in[1*in_channels*in_stride + m*in_stride + j + 2]
                                + weights[i*in_channels*ktime*kheight + m*ktime*kheight + 2*kheight + 0]*in[2*in_channels*in_stride + m*in_stride + j + 0]
                                + weights[i*in_channels*ktime*kheight + m*ktime*kheight + 2*kheight + 1]*in[2*in_channels*in_stride + m*in_stride + j + 1]
                                + weights[i*in_channels*ktime*kheight + m*ktime*kheight + 2*kheight + 2]*in[2*in_channels*in_stride + m*in_stride + j + 2];
         }
      }
   }
}

#define MAX_CONV2D_INPUTS 8192

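/* Streaming 2D convolution: mem holds the previous (ktime-1) input frames, which are
   concatenated with the current frame before convolving, so each call advances by one
   frame. The bias and the requested activation are then applied per output channel. */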
void RTCD_SUF(compute_conv2d_)(const Conv2dLayer *conv, float *out, float *mem, const float *in, int height, int hstride, int activation)
{
   int i;
   const float *bias;
   float in_buf[MAX_CONV2D_INPUTS];
   int time_stride;
   celt_assert(in != out);
   time_stride = conv->in_channels*(height+conv->kheight-1);
   celt_assert(conv->ktime*time_stride <= MAX_CONV2D_INPUTS);
   OPUS_COPY(in_buf, mem, (conv->ktime-1)*time_stride);
   OPUS_COPY(&in_buf[(conv->ktime-1)*time_stride], in, time_stride);
   OPUS_COPY(mem, &in_buf[time_stride], (conv->ktime-1)*time_stride);
   bias = conv->bias;
   if (conv->kheight == 3 && conv->ktime == 3)
      conv2d_3x3_float(out, conv->float_weights, conv->in_channels, conv->out_channels, in_buf, height, hstride);
   else
      conv2d_float(out, conv->float_weights, conv->in_channels, conv->out_channels, conv->ktime, conv->kheight, in_buf, height, hstride);
   if (bias != NULL) {
      for (i=0;i<conv->out_channels;i++) {
         int j;
         for (j=0;j<height;j++) out[i*hstride+j] += bias[i];
      }
   }
   for (i=0;i<conv->out_channels;i++) {
      RTCD_SUF(compute_activation_)(&out[i*hstride], &out[i*hstride], height, activation);
   }
}

#ifdef GCC_POP_OPTIONS
#pragma GCC pop_options
#endif

#endif