xref: /aosp_15_r20/external/libopus/dnn/nnet_arch.h (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li /* Copyright (c) 2018-2019 Mozilla
2*a58d3d2aSXin Li                  2023 Amazon */
3*a58d3d2aSXin Li /*
4*a58d3d2aSXin Li    Redistribution and use in source and binary forms, with or without
5*a58d3d2aSXin Li    modification, are permitted provided that the following conditions
6*a58d3d2aSXin Li    are met:
7*a58d3d2aSXin Li 
8*a58d3d2aSXin Li    - Redistributions of source code must retain the above copyright
9*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer.
10*a58d3d2aSXin Li 
11*a58d3d2aSXin Li    - Redistributions in binary form must reproduce the above copyright
12*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer in the
13*a58d3d2aSXin Li    documentation and/or other materials provided with the distribution.
14*a58d3d2aSXin Li 
15*a58d3d2aSXin Li    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16*a58d3d2aSXin Li    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17*a58d3d2aSXin Li    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18*a58d3d2aSXin Li    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19*a58d3d2aSXin Li    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20*a58d3d2aSXin Li    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21*a58d3d2aSXin Li    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22*a58d3d2aSXin Li    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23*a58d3d2aSXin Li    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24*a58d3d2aSXin Li    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25*a58d3d2aSXin Li    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*a58d3d2aSXin Li */
27*a58d3d2aSXin Li 
28*a58d3d2aSXin Li #ifndef NNET_ARCH_H
29*a58d3d2aSXin Li #define NNET_ARCH_H
30*a58d3d2aSXin Li 
31*a58d3d2aSXin Li #include "nnet.h"
32*a58d3d2aSXin Li #include "arch.h"
33*a58d3d2aSXin Li #include "os_support.h"
34*a58d3d2aSXin Li #include "vec.h"
35*a58d3d2aSXin Li 
/* Token-pasting helpers. CAT_SUFFIX goes through CAT_SUFFIX2 so that its
   arguments are macro-expanded *before* being pasted; pasting directly with
   ## would glue the unexpanded macro names instead. */
#define CAT_SUFFIX2(prefix, suffix) prefix ## suffix
#define CAT_SUFFIX(prefix, suffix) CAT_SUFFIX2(prefix, suffix)

/* Appends the expansion of RTCD_ARCH to a function name, e.g. to build the
   per-architecture variants of compute_linear_ etc. RTCD_ARCH is not defined
   in this header; the including translation unit must provide it. */
#define RTCD_SUF(name) CAT_SUFFIX(name, RTCD_ARCH)
40*a58d3d2aSXin Li 
/* Force tree vectorization on for the DNN code in this header: some of the
   loops below (the conv2d helpers in particular) rely on compiler
   auto-vectorization rather than explicitly using intrinsics.
   GCC_POP_OPTIONS records that the options were actually pushed, so that the
   matching "#pragma GCC pop_options" at the end of this header only runs when
   this push happened (OPUS_GNUC_PREREQ(5,1) -- presumably GCC 5.1 or newer). */
#if OPUS_GNUC_PREREQ(5,1)
#define GCC_POP_OPTIONS
#pragma GCC push_options
#pragma GCC optimize("tree-vectorize")
#endif
48*a58d3d2aSXin Li 
49*a58d3d2aSXin Li 
50*a58d3d2aSXin Li #define MAX_ACTIVATIONS (4096)
51*a58d3d2aSXin Li 
vec_swish(float * y,const float * x,int N)52*a58d3d2aSXin Li static OPUS_INLINE void vec_swish(float *y, const float *x, int N)
53*a58d3d2aSXin Li {
54*a58d3d2aSXin Li    int i;
55*a58d3d2aSXin Li    float tmp[MAX_ACTIVATIONS];
56*a58d3d2aSXin Li    celt_assert(N <= MAX_ACTIVATIONS);
57*a58d3d2aSXin Li    vec_sigmoid(tmp, x, N);
58*a58d3d2aSXin Li    for (i=0;i<N;i++)
59*a58d3d2aSXin Li       y[i] = x[i]*tmp[i];
60*a58d3d2aSXin Li }
61*a58d3d2aSXin Li 
relu(float x)62*a58d3d2aSXin Li static OPUS_INLINE float relu(float x)
63*a58d3d2aSXin Li {
64*a58d3d2aSXin Li    return x < 0 ? 0 : x;
65*a58d3d2aSXin Li }
66*a58d3d2aSXin Li 
67*a58d3d2aSXin Li /*#define HIGH_ACCURACY */
68*a58d3d2aSXin Li 
RTCD_SUF(compute_activation_)69*a58d3d2aSXin Li void RTCD_SUF(compute_activation_)(float *output, const float *input, int N, int activation)
70*a58d3d2aSXin Li {
71*a58d3d2aSXin Li    int i;
72*a58d3d2aSXin Li    if (activation == ACTIVATION_SIGMOID) {
73*a58d3d2aSXin Li #ifdef HIGH_ACCURACY
74*a58d3d2aSXin Li       for (int n=0; n<N; n++)
75*a58d3d2aSXin Li       {
76*a58d3d2aSXin Li          output[n] = 1.f  / (1 + exp(-input[n]));
77*a58d3d2aSXin Li       }
78*a58d3d2aSXin Li #else
79*a58d3d2aSXin Li       vec_sigmoid(output, input, N);
80*a58d3d2aSXin Li #endif
81*a58d3d2aSXin Li    } else if (activation == ACTIVATION_TANH) {
82*a58d3d2aSXin Li #ifdef HIGH_ACCURACY
83*a58d3d2aSXin Li       for (int n=0; n<N; n++)
84*a58d3d2aSXin Li       {
85*a58d3d2aSXin Li          output[n] = tanh(input[n]);
86*a58d3d2aSXin Li       }
87*a58d3d2aSXin Li #else
88*a58d3d2aSXin Li       vec_tanh(output, input, N);
89*a58d3d2aSXin Li #endif
90*a58d3d2aSXin Li    } else if (activation == ACTIVATION_SWISH) {
91*a58d3d2aSXin Li       vec_swish(output, input, N);
92*a58d3d2aSXin Li    } else if (activation == ACTIVATION_RELU) {
93*a58d3d2aSXin Li       for (i=0;i<N;i++)
94*a58d3d2aSXin Li          output[i] = relu(input[i]);
95*a58d3d2aSXin Li    } else if (activation == ACTIVATION_SOFTMAX) {
96*a58d3d2aSXin Li #ifdef SOFTMAX_HACK
97*a58d3d2aSXin Li       OPUS_COPY(output, input, N);
98*a58d3d2aSXin Li       /*for (i=0;i<N;i++)
99*a58d3d2aSXin Li          output[i] = input[i];*/
100*a58d3d2aSXin Li #else
101*a58d3d2aSXin Li       float sum = 0;
102*a58d3d2aSXin Li       softmax(output, input, N);
103*a58d3d2aSXin Li       for (i=0;i<N;i++) {
104*a58d3d2aSXin Li          sum += output[i];
105*a58d3d2aSXin Li       }
106*a58d3d2aSXin Li       sum = 1.f/(sum+1e-30);
107*a58d3d2aSXin Li       for (i=0;i<N;i++)
108*a58d3d2aSXin Li          output[i] = sum*output[i];
109*a58d3d2aSXin Li #endif
110*a58d3d2aSXin Li    } else {
111*a58d3d2aSXin Li       celt_assert(activation == ACTIVATION_LINEAR);
112*a58d3d2aSXin Li       if (input != output) {
113*a58d3d2aSXin Li          for (i=0;i<N;i++)
114*a58d3d2aSXin Li             output[i] = input[i];
115*a58d3d2aSXin Li       }
116*a58d3d2aSXin Li    }
117*a58d3d2aSXin Li }
118*a58d3d2aSXin Li 
119*a58d3d2aSXin Li 
RTCD_SUF(compute_linear_)120*a58d3d2aSXin Li void RTCD_SUF(compute_linear_) (const LinearLayer *linear, float *out, const float *in)
121*a58d3d2aSXin Li {
122*a58d3d2aSXin Li    int i, M, N;
123*a58d3d2aSXin Li    const float *bias;
124*a58d3d2aSXin Li    celt_assert(in != out);
125*a58d3d2aSXin Li    bias = linear->bias;
126*a58d3d2aSXin Li    M = linear->nb_inputs;
127*a58d3d2aSXin Li    N = linear->nb_outputs;
128*a58d3d2aSXin Li    if (linear->float_weights != NULL) {
129*a58d3d2aSXin Li      if (linear->weights_idx != NULL) sparse_sgemv8x4(out, linear->float_weights, linear->weights_idx, N, in);
130*a58d3d2aSXin Li      else sgemv(out, linear->float_weights, N, M, N, in);
131*a58d3d2aSXin Li    } else if (linear->weights != NULL) {
132*a58d3d2aSXin Li      if (linear->weights_idx != NULL) sparse_cgemv8x4(out, linear->weights, linear->weights_idx, linear->scale, N, M, in);
133*a58d3d2aSXin Li      else cgemv8x4(out, linear->weights, linear->scale, N, M, in);
134*a58d3d2aSXin Li      /* Only use SU biases on for integer matrices on SU archs. */
135*a58d3d2aSXin Li #ifdef USE_SU_BIAS
136*a58d3d2aSXin Li      bias = linear->subias;
137*a58d3d2aSXin Li #endif
138*a58d3d2aSXin Li    }
139*a58d3d2aSXin Li    else OPUS_CLEAR(out, N);
140*a58d3d2aSXin Li    if (bias != NULL) {
141*a58d3d2aSXin Li       for (i=0;i<N;i++) out[i] += bias[i];
142*a58d3d2aSXin Li    }
143*a58d3d2aSXin Li    if (linear->diag) {
144*a58d3d2aSXin Li       /* Diag is only used for GRU recurrent weights. */
145*a58d3d2aSXin Li       celt_assert(3*M == N);
146*a58d3d2aSXin Li       for (i=0;i<M;i++) {
147*a58d3d2aSXin Li          out[i] += linear->diag[i]*in[i];
148*a58d3d2aSXin Li          out[i+M] += linear->diag[i+M]*in[i];
149*a58d3d2aSXin Li          out[i+2*M] += linear->diag[i+2*M]*in[i];
150*a58d3d2aSXin Li       }
151*a58d3d2aSXin Li    }
152*a58d3d2aSXin Li }
153*a58d3d2aSXin Li 
/* Non-padded ("valid") 2-D convolution that produces a single output frame.

   Layouts (row-major, last index varies fastest):
     in      : [ ktime ][ in_channels ][ height + kheight - 1 ]
     weights : [ out_channels ][ in_channels ][ ktime ][ kheight ]
     out     : [ out_channels ][ hstride ]; only the first `height` entries
               of each row are written.
   The whole ktime axis is consumed, so the output has length 1 along time
   (one frame at a time); the height axis is convolved without padding. */
static void conv2d_float(float *out, const float *weights, int in_channels, int out_channels, int ktime, int kheight, const float *in, int height, int hstride)
{
   int oc;
   const int row_len = height + kheight - 1;
   for (oc = 0; oc < out_channels; oc++) {
      float *dst = out + oc*hstride;
      int ic, x;
      for (x = 0; x < height; x++)
         dst[x] = 0;
      for (ic = 0; ic < in_channels; ic++) {
         int dt;
         for (dt = 0; dt < ktime; dt++) {
            /* Kernel taps for (oc, ic, dt) and the matching input row. */
            const float *w = weights + ((oc*in_channels + ic)*ktime + dt)*kheight;
            const float *src = in + (dt*in_channels + ic)*row_len;
            int dh;
            for (dh = 0; dh < kheight; dh++) {
               for (x = 0; x < height; x++)
                  dst[x] += w[dh] * src[x + dh];
            }
         }
      }
   }
}
182*a58d3d2aSXin Li 
/* Specialization of conv2d_float() for ktime == kheight == 3, with the nine
   kernel taps written out by hand. Same layouts as conv2d_float(): input rows
   are height + 2 floats long, and each of the 3 time frames holds
   in_channels such rows. Like conv2d_float(), this uses no intrinsics: gcc's
   auto-vectorizer (and hopefully other compilers') produces good code from
   these plain loops, given the compile flags. */
static void conv2d_3x3_float(float *out, const float *weights, int in_channels, int out_channels, const float *in, int height, int hstride)
{
   int oc;
   const int row_len = height + 2;
   const int frame_len = in_channels*row_len;
   for (oc = 0; oc < out_channels; oc++) {
      float *dst = out + oc*hstride;
      int ic, x;
      for (x = 0; x < height; x++)
         dst[x] = 0;
      for (ic = 0; ic < in_channels; ic++) {
         /* 3x3 taps for (oc, ic), stored time-major; r0..r2 are this input
            channel's rows in time frames 0..2. */
         const float *w = weights + (oc*in_channels + ic)*9;
         const float *r0 = in + ic*row_len;
         const float *r1 = r0 + frame_len;
         const float *r2 = r1 + frame_len;
         for (x = 0; x < height; x++) {
            dst[x] += w[0]*r0[x] + w[1]*r0[x+1] + w[2]*r0[x+2]
                    + w[3]*r1[x] + w[4]*r1[x+1] + w[5]*r1[x+2]
                    + w[6]*r2[x] + w[7]*r2[x+1] + w[8]*r2[x+2];
         }
      }
   }
}
212*a58d3d2aSXin Li 
213*a58d3d2aSXin Li #define MAX_CONV2D_INPUTS 8192
214*a58d3d2aSXin Li 
RTCD_SUF(compute_conv2d_)215*a58d3d2aSXin Li void RTCD_SUF(compute_conv2d_)(const Conv2dLayer *conv, float *out, float *mem, const float *in, int height, int hstride, int activation)
216*a58d3d2aSXin Li {
217*a58d3d2aSXin Li    int i;
218*a58d3d2aSXin Li    const float *bias;
219*a58d3d2aSXin Li    float in_buf[MAX_CONV2D_INPUTS];
220*a58d3d2aSXin Li    int time_stride;
221*a58d3d2aSXin Li    celt_assert(in != out);
222*a58d3d2aSXin Li    time_stride = conv->in_channels*(height+conv->kheight-1);
223*a58d3d2aSXin Li    celt_assert(conv->ktime*time_stride <= MAX_CONV2D_INPUTS);
224*a58d3d2aSXin Li    OPUS_COPY(in_buf, mem, (conv->ktime-1)*time_stride);
225*a58d3d2aSXin Li    OPUS_COPY(&in_buf[(conv->ktime-1)*time_stride], in, time_stride);
226*a58d3d2aSXin Li    OPUS_COPY(mem, &in_buf[time_stride], (conv->ktime-1)*time_stride);
227*a58d3d2aSXin Li    bias = conv->bias;
228*a58d3d2aSXin Li    if (conv->kheight == 3 && conv->ktime == 3)
229*a58d3d2aSXin Li      conv2d_3x3_float(out, conv->float_weights, conv->in_channels, conv->out_channels, in_buf, height, hstride);
230*a58d3d2aSXin Li    else
231*a58d3d2aSXin Li      conv2d_float(out, conv->float_weights, conv->in_channels, conv->out_channels, conv->ktime, conv->kheight, in_buf, height, hstride);
232*a58d3d2aSXin Li    if (bias != NULL) {
233*a58d3d2aSXin Li      for (i=0;i<conv->out_channels;i++) {
234*a58d3d2aSXin Li        int j;
235*a58d3d2aSXin Li        for (j=0;j<height;j++) out[i*hstride+j] += bias[i];
236*a58d3d2aSXin Li      }
237*a58d3d2aSXin Li    }
238*a58d3d2aSXin Li    for (i=0;i<conv->out_channels;i++) {
239*a58d3d2aSXin Li      RTCD_SUF(compute_activation_)(&out[i*hstride], &out[i*hstride], height, activation);
240*a58d3d2aSXin Li    }
241*a58d3d2aSXin Li }
242*a58d3d2aSXin Li 
/* Restore the optimization options saved by the "#pragma GCC push_options"
   near the top of this header; GCC_POP_OPTIONS is only defined when that
   push was actually issued. */
#ifdef GCC_POP_OPTIONS
#pragma GCC pop_options
#endif

#endif /* NNET_ARCH_H */
248