1*a58d3d2aSXin Li /* Copyright (c) 2008-2011 Octasic Inc.
2*a58d3d2aSXin Li 2012-2017 Jean-Marc Valin */
3*a58d3d2aSXin Li /*
4*a58d3d2aSXin Li Redistribution and use in source and binary forms, with or without
5*a58d3d2aSXin Li modification, are permitted provided that the following conditions
6*a58d3d2aSXin Li are met:
7*a58d3d2aSXin Li
8*a58d3d2aSXin Li - Redistributions of source code must retain the above copyright
9*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer.
10*a58d3d2aSXin Li
11*a58d3d2aSXin Li - Redistributions in binary form must reproduce the above copyright
12*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer in the
13*a58d3d2aSXin Li documentation and/or other materials provided with the distribution.
14*a58d3d2aSXin Li
15*a58d3d2aSXin Li THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16*a58d3d2aSXin Li ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17*a58d3d2aSXin Li LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18*a58d3d2aSXin Li A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
19*a58d3d2aSXin Li CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20*a58d3d2aSXin Li EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21*a58d3d2aSXin Li PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22*a58d3d2aSXin Li PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23*a58d3d2aSXin Li LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24*a58d3d2aSXin Li NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25*a58d3d2aSXin Li SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*a58d3d2aSXin Li */
27*a58d3d2aSXin Li
28*a58d3d2aSXin Li #ifdef HAVE_CONFIG_H
29*a58d3d2aSXin Li #include "config.h"
30*a58d3d2aSXin Li #endif
31*a58d3d2aSXin Li
32*a58d3d2aSXin Li #include <math.h>
33*a58d3d2aSXin Li #include "opus_types.h"
34*a58d3d2aSXin Li #include "opus_defines.h"
35*a58d3d2aSXin Li #include "arch.h"
36*a58d3d2aSXin Li #include "mlp.h"
37*a58d3d2aSXin Li
38*a58d3d2aSXin Li #define fmadd(a, b, c) ((a)*(b)+(c))
tansig_approx(float x)39*a58d3d2aSXin Li static OPUS_INLINE float tansig_approx(float x)
40*a58d3d2aSXin Li {
41*a58d3d2aSXin Li const float N0 = 952.52801514f;
42*a58d3d2aSXin Li const float N1 = 96.39235687f;
43*a58d3d2aSXin Li const float N2 = 0.60863042f;
44*a58d3d2aSXin Li const float D0 = 952.72399902f;
45*a58d3d2aSXin Li const float D1 = 413.36801147f;
46*a58d3d2aSXin Li const float D2 = 11.88600922f;
47*a58d3d2aSXin Li float X2, num, den;
48*a58d3d2aSXin Li X2 = x*x;
49*a58d3d2aSXin Li num = fmadd(fmadd(N2, X2, N1), X2, N0);
50*a58d3d2aSXin Li den = fmadd(fmadd(D2, X2, D1), X2, D0);
51*a58d3d2aSXin Li num = num*x/den;
52*a58d3d2aSXin Li return MAX32(-1.f, MIN32(1.f, num));
53*a58d3d2aSXin Li }
54*a58d3d2aSXin Li
sigmoid_approx(float x)55*a58d3d2aSXin Li static OPUS_INLINE float sigmoid_approx(float x)
56*a58d3d2aSXin Li {
57*a58d3d2aSXin Li return .5f + .5f*tansig_approx(.5f*x);
58*a58d3d2aSXin Li }
59*a58d3d2aSXin Li
/* Accumulates an int8 matrix-vector product into out (it does not clear out,
 * so callers seed it, e.g. with biases):
 *
 *    out[r] += sum over c of weights[r + c*col_stride] * x[c],
 *    for r in [0, rows), c in [0, cols).
 *
 * The weight for output r and input c sits at offset c*col_stride + r, so the
 * weights belonging to one input are contiguous across outputs. Each out[r]
 * accumulates its terms in increasing c order.
 */
static void gemm_accum(float *out, const opus_int8 *weights, int rows, int cols, int col_stride, const float *x)
{
   int r, c;
   for (r = 0; r < rows; r++) {
      for (c = 0; c < cols; c++) {
         out[r] += weights[r + c*col_stride]*x[c];
      }
   }
}
69*a58d3d2aSXin Li
analysis_compute_dense(const AnalysisDenseLayer * layer,float * output,const float * input)70*a58d3d2aSXin Li void analysis_compute_dense(const AnalysisDenseLayer *layer, float *output, const float *input)
71*a58d3d2aSXin Li {
72*a58d3d2aSXin Li int i;
73*a58d3d2aSXin Li int N, M;
74*a58d3d2aSXin Li int stride;
75*a58d3d2aSXin Li M = layer->nb_inputs;
76*a58d3d2aSXin Li N = layer->nb_neurons;
77*a58d3d2aSXin Li stride = N;
78*a58d3d2aSXin Li for (i=0;i<N;i++)
79*a58d3d2aSXin Li output[i] = layer->bias[i];
80*a58d3d2aSXin Li gemm_accum(output, layer->input_weights, N, M, stride, input);
81*a58d3d2aSXin Li for (i=0;i<N;i++)
82*a58d3d2aSXin Li output[i] *= WEIGHTS_SCALE;
83*a58d3d2aSXin Li if (layer->sigmoid) {
84*a58d3d2aSXin Li for (i=0;i<N;i++)
85*a58d3d2aSXin Li output[i] = sigmoid_approx(output[i]);
86*a58d3d2aSXin Li } else {
87*a58d3d2aSXin Li for (i=0;i<N;i++)
88*a58d3d2aSXin Li output[i] = tansig_approx(output[i]);
89*a58d3d2aSXin Li }
90*a58d3d2aSXin Li }
91*a58d3d2aSXin Li
analysis_compute_gru(const AnalysisGRULayer * gru,float * state,const float * input)92*a58d3d2aSXin Li void analysis_compute_gru(const AnalysisGRULayer *gru, float *state, const float *input)
93*a58d3d2aSXin Li {
94*a58d3d2aSXin Li int i;
95*a58d3d2aSXin Li int N, M;
96*a58d3d2aSXin Li int stride;
97*a58d3d2aSXin Li float tmp[MAX_NEURONS];
98*a58d3d2aSXin Li float z[MAX_NEURONS];
99*a58d3d2aSXin Li float r[MAX_NEURONS];
100*a58d3d2aSXin Li float h[MAX_NEURONS];
101*a58d3d2aSXin Li M = gru->nb_inputs;
102*a58d3d2aSXin Li N = gru->nb_neurons;
103*a58d3d2aSXin Li stride = 3*N;
104*a58d3d2aSXin Li /* Compute update gate. */
105*a58d3d2aSXin Li for (i=0;i<N;i++)
106*a58d3d2aSXin Li z[i] = gru->bias[i];
107*a58d3d2aSXin Li gemm_accum(z, gru->input_weights, N, M, stride, input);
108*a58d3d2aSXin Li gemm_accum(z, gru->recurrent_weights, N, N, stride, state);
109*a58d3d2aSXin Li for (i=0;i<N;i++)
110*a58d3d2aSXin Li z[i] = sigmoid_approx(WEIGHTS_SCALE*z[i]);
111*a58d3d2aSXin Li
112*a58d3d2aSXin Li /* Compute reset gate. */
113*a58d3d2aSXin Li for (i=0;i<N;i++)
114*a58d3d2aSXin Li r[i] = gru->bias[N + i];
115*a58d3d2aSXin Li gemm_accum(r, &gru->input_weights[N], N, M, stride, input);
116*a58d3d2aSXin Li gemm_accum(r, &gru->recurrent_weights[N], N, N, stride, state);
117*a58d3d2aSXin Li for (i=0;i<N;i++)
118*a58d3d2aSXin Li r[i] = sigmoid_approx(WEIGHTS_SCALE*r[i]);
119*a58d3d2aSXin Li
120*a58d3d2aSXin Li /* Compute output. */
121*a58d3d2aSXin Li for (i=0;i<N;i++)
122*a58d3d2aSXin Li h[i] = gru->bias[2*N + i];
123*a58d3d2aSXin Li for (i=0;i<N;i++)
124*a58d3d2aSXin Li tmp[i] = state[i] * r[i];
125*a58d3d2aSXin Li gemm_accum(h, &gru->input_weights[2*N], N, M, stride, input);
126*a58d3d2aSXin Li gemm_accum(h, &gru->recurrent_weights[2*N], N, N, stride, tmp);
127*a58d3d2aSXin Li for (i=0;i<N;i++)
128*a58d3d2aSXin Li h[i] = z[i]*state[i] + (1-z[i])*tansig_approx(WEIGHTS_SCALE*h[i]);
129*a58d3d2aSXin Li for (i=0;i<N;i++)
130*a58d3d2aSXin Li state[i] = h[i];
131*a58d3d2aSXin Li }
132*a58d3d2aSXin Li
133