xref: /aosp_15_r20/external/libopus/dnn/nndsp.c (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li /* Copyright (c) 2023 Amazon
2*a58d3d2aSXin Li    Written by Jan Buethe */
3*a58d3d2aSXin Li /*
4*a58d3d2aSXin Li    Redistribution and use in source and binary forms, with or without
5*a58d3d2aSXin Li    modification, are permitted provided that the following conditions
6*a58d3d2aSXin Li    are met:
7*a58d3d2aSXin Li 
8*a58d3d2aSXin Li    - Redistributions of source code must retain the above copyright
9*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer.
10*a58d3d2aSXin Li 
11*a58d3d2aSXin Li    - Redistributions in binary form must reproduce the above copyright
12*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer in the
13*a58d3d2aSXin Li    documentation and/or other materials provided with the distribution.
14*a58d3d2aSXin Li 
15*a58d3d2aSXin Li    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16*a58d3d2aSXin Li    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17*a58d3d2aSXin Li    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18*a58d3d2aSXin Li    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
19*a58d3d2aSXin Li    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20*a58d3d2aSXin Li    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21*a58d3d2aSXin Li    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22*a58d3d2aSXin Li    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23*a58d3d2aSXin Li    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24*a58d3d2aSXin Li    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25*a58d3d2aSXin Li    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*a58d3d2aSXin Li */
27*a58d3d2aSXin Li 
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif


#include "nndsp.h"
#include "arch.h"
#include "nnet.h"
#include "os_support.h"
#include "pitch.h"

#include <math.h>

#ifdef DEBUG_NNDSP
#include <stdio.h> /* printf() used by the DEBUG_NNDSP diagnostics */
#endif
40*a58d3d2aSXin Li 
/* M_PI is not part of ISO C; provide it when <math.h> does not. */
#ifndef M_PI
#define M_PI 3.141592653589793f
#endif

/* Flat offset of element (i_out_channels, i_in_channels, i_kernel) in a kernel
   stored as [out_channels][in_channels][kernel_size].
   NOTE: the expansion refers to the identifiers `in_channels` and `kernel_size`,
   which must be in scope wherever the macro is used. */
#define KERNEL_INDEX(i_out_channels, i_in_channels, i_kernel) ((i_kernel) + kernel_size * ((i_in_channels) + in_channels * (i_out_channels)))
46*a58d3d2aSXin Li 
/* Reset an AdaConv filter state: every field, including the per-channel input
   history and the previous frame's kernel, is set to zero. */
void init_adaconv_state(AdaConvState *hAdaConv)
{
    AdaConvState *const state = hAdaConv;
    OPUS_CLEAR(state, 1);
}
51*a58d3d2aSXin Li 
/* Reset an AdaComb filter state: every field is zeroed, so a fresh state has an
   all-zero history, a zero previous kernel, previous pitch lag 0 and previous
   global gain 0. */
void init_adacomb_state(AdaCombState *hAdaComb)
{
    AdaCombState *const state = hAdaComb;
    OPUS_CLEAR(state, 1);
}
56*a58d3d2aSXin Li 
/* Reset an AdaShape state: every field, including the memories of its three
   1-D convolutions, is set to zero. */
void init_adashape_state(AdaShapeState *hAdaShape)
{
    AdaShapeState *const state = hAdaShape;
    OPUS_CLEAR(state, 1);
}
61*a58d3d2aSXin Li 
compute_overlap_window(float * window,int overlap_size)62*a58d3d2aSXin Li void compute_overlap_window(float *window, int overlap_size)
63*a58d3d2aSXin Li {
64*a58d3d2aSXin Li     int i_sample;
65*a58d3d2aSXin Li     for (i_sample=0; i_sample < overlap_size; i_sample++)
66*a58d3d2aSXin Li     {
67*a58d3d2aSXin Li         window[i_sample] = 0.5f + 0.5f * cos(M_PI * (i_sample + 0.5f) / overlap_size);
68*a58d3d2aSXin Li     }
69*a58d3d2aSXin Li }
70*a58d3d2aSXin Li 
#ifdef DEBUG_NNDSP
/* Debug helper: print every element of vec as "name[i]: value", one per line.
   length is the number of elements to print. */
void print_float_vector(const char* name, const float *vec, int length)
{
    int i; /* declared at block start, matching the rest of this file */

    for (i = 0; i < length; i++)
    {
        printf("%s[%d]: %f\n", name, i, vec[i]);
    }
}
#endif
80*a58d3d2aSXin Li 
/* Normalize each output channel's kernel to unit L2 norm and scale it by that
   channel's gain, in place.

   kernel is stored as [out_channels][in_channels][kernel_size]. For every output
   channel o, all in_channels * kernel_size coefficients of o are multiplied by
       gain[o] / (1e-6 + sqrt(sum of their squares)),
   where the 1e-6 term keeps the factor finite for an all-zero kernel. */
static void scale_kernel(
    float *kernel,
    int in_channels,
    int out_channels,
    int kernel_size,
    float *gain
)
{
    int out_ch;

    for (out_ch = 0; out_ch < out_channels; out_ch++)
    {
        /* first coefficient belonging to output channel out_ch */
        const int base = out_ch * in_channels * kernel_size;
        float energy = 0;
        float inv_norm;
        int in_ch, tap;

        /* squared L2 norm, accumulated in memory order */
        for (in_ch = 0; in_ch < in_channels; in_ch++)
        {
            const int offset = base + in_ch * kernel_size;
            for (tap = 0; tap < kernel_size; tap++)
            {
                const float coeff = kernel[offset + tap];
                energy += coeff * coeff;
            }
        }
#ifdef DEBUG_NNDSP
        printf("kernel norm: %f, %f\n", energy, sqrt(energy));
#endif
        inv_norm = 1.f / (1e-6f + sqrt(energy));

        for (in_ch = 0; in_ch < in_channels; in_ch++)
        {
            const int offset = base + in_ch * kernel_size;
            for (tap = 0; tap < kernel_size; tap++)
            {
                kernel[offset + tap] *= inv_norm * gain[out_ch];
            }
        }
    }
}
117*a58d3d2aSXin Li 
transform_gains(float * gains,int num_gains,float filter_gain_a,float filter_gain_b)118*a58d3d2aSXin Li static void transform_gains(
119*a58d3d2aSXin Li     float *gains,
120*a58d3d2aSXin Li     int num_gains,
121*a58d3d2aSXin Li     float filter_gain_a,
122*a58d3d2aSXin Li     float filter_gain_b
123*a58d3d2aSXin Li )
124*a58d3d2aSXin Li {
125*a58d3d2aSXin Li     int i;
126*a58d3d2aSXin Li     for (i = 0; i < num_gains; i++)
127*a58d3d2aSXin Li     {
128*a58d3d2aSXin Li         gains[i] = exp(filter_gain_a * gains[i] + filter_gain_b);
129*a58d3d2aSXin Li     }
130*a58d3d2aSXin Li }
131*a58d3d2aSXin Li 
/* Adaptive convolution of one frame.

   Filters a multi-channel frame with a kernel predicted from `features`. Over
   the first overlap_size samples the output crossfades from the kernel used in
   the previous frame (weight window[n]) to the new kernel (weight 1 - window[n]).
   The remaining samples use the new kernel only.

   Parameters:
     hAdaConv       filter state: per-channel input history (kernel_size samples
                    per channel) and the previous kernel; updated on return.
     x_out          output, out_channels * frame_size samples, channel-major
                    (sample n of channel o at x_out[o * frame_size + n]).
     x_in           input, in_channels * frame_size samples, channel-major.
     features       conditioning vector passed to kernel_layer and gain_layer.
     kernel_layer   dense layer whose output is read as a kernel of shape
                    [out_channels][in_channels][kernel_size]; each output
                    channel's kernel is normalized to unit L2 norm.
     gain_layer     dense layer (tanh) giving one raw gain per output channel,
                    mapped to exp(filter_gain_a * g + filter_gain_b).
     feature_dim    length of features; only used by the debug output.
     frame_size     samples per channel in this frame.
     overlap_size   crossfade length; at most frame_size.
     in_channels, out_channels, kernel_size   kernel dimensions.
     left_padding   must equal kernel_size - 1 (only causal filtering is
                    implemented).
     shape_gain     must be 1.
     window         overlap_size crossfade weights.
     arch           forwarded unchanged to compute_generic_dense and
                    celt_pitch_xcorr. */
void adaconv_process_frame(
    AdaConvState* hAdaConv,
    float *x_out,
    const float *x_in,
    const float *features,
    const LinearLayer *kernel_layer,
    const LinearLayer *gain_layer,
    int feature_dim,
    int frame_size,
    int overlap_size,
    int in_channels,
    int out_channels,
    int kernel_size,
    int left_padding,
    float filter_gain_a,
    float filter_gain_b,
    float shape_gain,
    float *window,
    int arch
)
{
    float output_buffer[ADACONV_MAX_FRAME_SIZE * ADACONV_MAX_OUTPUT_CHANNELS];
    float kernel_buffer[ADACONV_MAX_KERNEL_SIZE * ADACONV_MAX_INPUT_CHANNELS * ADACONV_MAX_OUTPUT_CHANNELS];
    float input_buffer[ADACONV_MAX_INPUT_CHANNELS * (ADACONV_MAX_FRAME_SIZE + ADACONV_MAX_KERNEL_SIZE)];
    float kernel0[ADACONV_MAX_KERNEL_SIZE];
    float kernel1[ADACONV_MAX_KERNEL_SIZE];
    float channel_buffer0[ADACONV_MAX_OVERLAP_SIZE];
    float channel_buffer1[ADACONV_MAX_FRAME_SIZE];
    float gain_buffer[ADACONV_MAX_OUTPUT_CHANNELS];
    float *p_input;
    int i_in_channels, i_out_channels, i_sample;

    (void) feature_dim; /* ToDo: figure out whether we might need this information */

    celt_assert(shape_gain == 1);
    celt_assert(left_padding == kernel_size - 1); /* currently only supports causal version. Non-causal version not difficult to implement but will require third loop */
    celt_assert(kernel_size < frame_size);
    /* All scratch arrays above are sized for the ADACONV_MAX_* limits; larger
       dimensions would index past them. overlap_size must not exceed
       frame_size, because the crossfade reads channel_buffer1[0..overlap_size-1]
       and writes output rows that are frame_size long. */
    celt_assert(in_channels <= ADACONV_MAX_INPUT_CHANNELS);
    celt_assert(out_channels <= ADACONV_MAX_OUTPUT_CHANNELS);
    celt_assert(kernel_size <= ADACONV_MAX_KERNEL_SIZE);
    celt_assert(frame_size <= ADACONV_MAX_FRAME_SIZE);
    celt_assert(overlap_size <= frame_size);
    celt_assert(overlap_size <= ADACONV_MAX_OVERLAP_SIZE);

    OPUS_CLEAR(output_buffer, ADACONV_MAX_FRAME_SIZE * ADACONV_MAX_OUTPUT_CHANNELS);
    OPUS_CLEAR(kernel_buffer, ADACONV_MAX_KERNEL_SIZE * ADACONV_MAX_INPUT_CHANNELS * ADACONV_MAX_OUTPUT_CHANNELS);
    OPUS_CLEAR(input_buffer, ADACONV_MAX_INPUT_CHANNELS * (ADACONV_MAX_FRAME_SIZE + ADACONV_MAX_KERNEL_SIZE));

#ifdef DEBUG_NNDSP
    print_float_vector("x_in", x_in, in_channels * frame_size);
#endif

    /* prepare input: each channel row is [kernel_size history | frame_size new samples] */
    for (i_in_channels=0; i_in_channels < in_channels; i_in_channels ++)
    {
        OPUS_COPY(input_buffer + i_in_channels * (kernel_size + frame_size), hAdaConv->history + i_in_channels * kernel_size, kernel_size);
        OPUS_COPY(input_buffer + kernel_size + i_in_channels * (kernel_size + frame_size), x_in + frame_size * i_in_channels, frame_size);
    }
    p_input = input_buffer + kernel_size;


    /* calculate new kernel and new gain */
    compute_generic_dense(kernel_layer, kernel_buffer, features, ACTIVATION_LINEAR, arch);
    compute_generic_dense(gain_layer, gain_buffer, features, ACTIVATION_TANH, arch);
#ifdef DEBUG_NNDSP
    print_float_vector("features", features, feature_dim);
    print_float_vector("adaconv_kernel_raw", kernel_buffer, in_channels * out_channels * kernel_size);
    print_float_vector("adaconv_gain_raw", gain_buffer, out_channels);
#endif
    transform_gains(gain_buffer, out_channels, filter_gain_a, filter_gain_b);
    scale_kernel(kernel_buffer, in_channels, out_channels, kernel_size, gain_buffer);

#ifdef DEBUG_NNDSP
    print_float_vector("adaconv_kernel", kernel_buffer, in_channels * out_channels * kernel_size);
    print_float_vector("adaconv_gain", gain_buffer, out_channels);
#endif

    /* filter with the previous kernel (overlap region only) and the new kernel
       (whole frame), crossfade, and accumulate over input channels */
    for (i_out_channels = 0; i_out_channels < out_channels; i_out_channels++)
    {
        for (i_in_channels = 0; i_in_channels < in_channels; i_in_channels++)
        {
            /* zero-pad to ADACONV_MAX_KERNEL_SIZE, the length celt_pitch_xcorr is given */
            OPUS_CLEAR(kernel0, ADACONV_MAX_KERNEL_SIZE);
            OPUS_CLEAR(kernel1, ADACONV_MAX_KERNEL_SIZE);

            OPUS_COPY(kernel0, hAdaConv->last_kernel + KERNEL_INDEX(i_out_channels, i_in_channels, 0), kernel_size);
            OPUS_COPY(kernel1, kernel_buffer + KERNEL_INDEX(i_out_channels, i_in_channels, 0), kernel_size);
            celt_pitch_xcorr(kernel0, p_input + i_in_channels * (frame_size + kernel_size) - left_padding, channel_buffer0, ADACONV_MAX_KERNEL_SIZE, overlap_size, arch);
            celt_pitch_xcorr(kernel1, p_input + i_in_channels * (frame_size + kernel_size) - left_padding, channel_buffer1, ADACONV_MAX_KERNEL_SIZE, frame_size, arch);
            for (i_sample = 0; i_sample < overlap_size; i_sample++)
            {
                output_buffer[i_sample + i_out_channels * frame_size] +=  window[i_sample] * channel_buffer0[i_sample];
                output_buffer[i_sample + i_out_channels * frame_size] += (1.f - window[i_sample]) * channel_buffer1[i_sample];
            }
            for (i_sample = overlap_size; i_sample < frame_size; i_sample++)
            {
                output_buffer[i_sample + i_out_channels * frame_size] += channel_buffer1[i_sample];
            }
        }
    }

    OPUS_COPY(x_out, output_buffer, out_channels * frame_size);

#ifdef DEBUG_NNDSP
    print_float_vector("x_out", x_out, out_channels * frame_size);
#endif

    /* buffer update: keep the last kernel_size input samples of each channel
       and the kernel used for this frame */
    for (i_in_channels=0; i_in_channels < in_channels; i_in_channels ++)
    {
        OPUS_COPY(hAdaConv->history + i_in_channels * kernel_size, p_input + i_in_channels * (frame_size + kernel_size) + frame_size - kernel_size, kernel_size);
    }
    OPUS_COPY(hAdaConv->last_kernel, kernel_buffer, kernel_size * in_channels * out_channels);
}
241*a58d3d2aSXin Li 
/* Adaptive comb filter for one single-channel frame.

   Ignoring the crossfade, the output is
       x_out[n] = g * (x_in[n] + sum_j K[j] * x[n - left_padding - pitch_lag + j])
   where x is the input extended backwards by the stored history. K is a kernel
   predicted from `features`, normalized to unit L2 norm and scaled by
   exp(log_gain_limit - relu(gain_layer)). g = exp(filter_gain_a *
   tanh(global_gain_layer) + filter_gain_b).
   Over the first overlap_size samples, the previous frame's kernel, lag and
   global gain (weight window[n]) are crossfaded with the new ones
   (weight 1 - window[n]).

   Parameters:
     hAdaComb      state: kernel_size + ADACOMB_MAX_LAG history samples, the
                   previous kernel, pitch lag and global gain; updated on return.
     x_out, x_in   frame_size output / input samples.
     pitch_lag     comb delay in samples for this frame.
     feature_dim   length of features; only used by the debug output.
     overlap_size  crossfade length; at most frame_size.
     kernel_size   number of kernel taps; at most ADACOMB_MAX_KERNEL_SIZE.
     left_padding  taps placed before the lagged sample (non-causal part).
     window        overlap_size crossfade weights.
     arch          forwarded unchanged to compute_generic_dense and
                   celt_pitch_xcorr. */
void adacomb_process_frame(
    AdaCombState* hAdaComb,
    float *x_out,
    const float *x_in,
    const float *features,
    const LinearLayer *kernel_layer,
    const LinearLayer *gain_layer,
    const LinearLayer *global_gain_layer,
    int pitch_lag,
    int feature_dim,
    int frame_size,
    int overlap_size,
    int kernel_size,
    int left_padding,
    float filter_gain_a,
    float filter_gain_b,
    float log_gain_limit,
    float *window,
    int arch
)
{
    float output_buffer[ADACOMB_MAX_FRAME_SIZE];
    float output_buffer_last[ADACOMB_MAX_FRAME_SIZE];
    float kernel_buffer[ADACOMB_MAX_KERNEL_SIZE];
    float input_buffer[ADACOMB_MAX_FRAME_SIZE + ADACOMB_MAX_LAG + ADACOMB_MAX_KERNEL_SIZE];
    float gain, global_gain;
    float *p_input;
    int i_sample;
    /* Zero-padded kernels handed to celt_pitch_xcorr, which reads exactly
       ADACOMB_MAX_KERNEL_SIZE taps. Size them from that constant so the buffers
       cannot become shorter than what is cleared and read below. */
    float kernel[ADACOMB_MAX_KERNEL_SIZE];
    float last_kernel[ADACOMB_MAX_KERNEL_SIZE];

    (void) feature_dim; /* ToDo: figure out whether we might need this information */

    /* Keep every access inside the fixed-size stack buffers. The lagged read
       starts at p_input - left_padding - pitch_lag, and at most
       kernel_size + ADACOMB_MAX_LAG history samples lie before p_input. */
    celt_assert(kernel_size <= ADACOMB_MAX_KERNEL_SIZE);
    celt_assert(frame_size <= ADACOMB_MAX_FRAME_SIZE);
    celt_assert(overlap_size <= frame_size);
    celt_assert(left_padding + pitch_lag <= kernel_size + ADACOMB_MAX_LAG);

    OPUS_CLEAR(output_buffer, ADACOMB_MAX_FRAME_SIZE);
    OPUS_CLEAR(kernel_buffer, ADACOMB_MAX_KERNEL_SIZE);
    OPUS_CLEAR(input_buffer, ADACOMB_MAX_FRAME_SIZE + ADACOMB_MAX_LAG + ADACOMB_MAX_KERNEL_SIZE);

    /* input_buffer = [kernel_size + ADACOMB_MAX_LAG history | frame_size new samples] */
    OPUS_COPY(input_buffer, hAdaComb->history, kernel_size + ADACOMB_MAX_LAG);
    OPUS_COPY(input_buffer + kernel_size + ADACOMB_MAX_LAG, x_in, frame_size);
    p_input = input_buffer + kernel_size + ADACOMB_MAX_LAG;

    /* calculate new kernel and new gain */
    compute_generic_dense(kernel_layer, kernel_buffer, features, ACTIVATION_LINEAR, arch);
    compute_generic_dense(gain_layer, &gain, features, ACTIVATION_RELU, arch);
    compute_generic_dense(global_gain_layer, &global_gain, features, ACTIVATION_TANH, arch);
#ifdef DEBUG_NNDSP
    print_float_vector("features", features, feature_dim);
    print_float_vector("adacomb_kernel_raw", kernel_buffer, kernel_size);
    print_float_vector("adacomb_gain_raw", &gain, 1);
    print_float_vector("adacomb_global_gain_raw", &global_gain, 1);
#endif
    /* relu output >= 0, so the kernel gain never exceeds exp(log_gain_limit) */
    gain = exp(log_gain_limit - gain);
    global_gain = exp(filter_gain_a * global_gain + filter_gain_b);
    scale_kernel(kernel_buffer, 1, 1, kernel_size, &gain);

#ifdef DEBUG_NNDSP
    print_float_vector("adacomb_kernel", kernel_buffer, kernel_size);
    print_float_vector("adacomb_gain", &gain, 1);
#endif

    OPUS_CLEAR(kernel, ADACOMB_MAX_KERNEL_SIZE);
    OPUS_CLEAR(last_kernel, ADACOMB_MAX_KERNEL_SIZE);
    OPUS_COPY(kernel, kernel_buffer, kernel_size);
    OPUS_COPY(last_kernel, hAdaComb->last_kernel, kernel_size);

    /* previous kernel at the previous lag, only needed over the overlap */
    celt_pitch_xcorr(last_kernel, &p_input[- left_padding - hAdaComb->last_pitch_lag], output_buffer_last, ADACOMB_MAX_KERNEL_SIZE, overlap_size, arch);

    /* new kernel at the new lag, whole frame */
    celt_pitch_xcorr(kernel, &p_input[- left_padding - pitch_lag], output_buffer, ADACOMB_MAX_KERNEL_SIZE, frame_size, arch);
    for (i_sample = 0; i_sample < overlap_size; i_sample++)
    {
      output_buffer[i_sample] = hAdaComb->last_global_gain * window[i_sample] * output_buffer_last[i_sample] + global_gain * (1.f - window[i_sample]) * output_buffer[i_sample];
    }

    /* add the direct path, with the crossfaded global gain over the overlap */
    for (i_sample = 0; i_sample < overlap_size; i_sample++)
    {
      output_buffer[i_sample] += (window[i_sample] * hAdaComb->last_global_gain + (1.f - window[i_sample]) * global_gain) * p_input[i_sample];
    }

    for (i_sample = overlap_size; i_sample < frame_size; i_sample++)
    {
      output_buffer[i_sample] = global_gain * (output_buffer[i_sample] + p_input[i_sample]);
    }
    OPUS_COPY(x_out, output_buffer, frame_size);

#ifdef DEBUG_NNDSP
    print_float_vector("x_out", x_out, frame_size);
#endif

    /* buffer update */
    OPUS_COPY(hAdaComb->last_kernel, kernel_buffer, kernel_size);
    OPUS_COPY(hAdaComb->history, p_input + frame_size - kernel_size - ADACOMB_MAX_LAG, kernel_size + ADACOMB_MAX_LAG);
    hAdaComb->last_pitch_lag = pitch_lag;
    hAdaComb->last_global_gain = global_gain;
}
336*a58d3d2aSXin Li 
337*a58d3d2aSXin Li 
/* Adaptive temporal shaping of one frame: x_out[n] = exp(a[n]) * x_in[n].

   The log-gains a[n] are computed as follows:
     1. A temporal envelope tenv of frame_size / avg_pool_k entries is built.
        Each entry is log(mean |x_in| over avg_pool_k samples + 2^-16). The
        envelope mean is then subtracted, and the mean is appended as one
        extra entry.
     2. alpha1f(features) + alpha1t(tenv) is passed through a leaky ReLU with
        slope 0.2.
     3. alpha2 of that result gives a[0..frame_size-1].
   hAdaShape's three conv states are passed to compute_generic_conv1d as its
   memory argument (presumably updated there; see nnet).

   Requirements: avg_pool_k > 0; 0 < frame_size <= ADASHAPE_MAX_FRAME_SIZE;
   frame_size divisible by avg_pool_k; and
   feature_dim + frame_size / avg_pool_k + 1 < ADASHAPE_MAX_INPUT_DIM.
   arch is forwarded unchanged to compute_generic_conv1d. */
void adashape_process_frame(
    AdaShapeState *hAdaShape,
    float *x_out,
    const float *x_in,
    const float *features,
    const LinearLayer *alpha1f,
    const LinearLayer *alpha1t,
    const LinearLayer *alpha2,
    int feature_dim,
    int frame_size,
    int avg_pool_k,
    int arch
)
{
    float in_buffer[ADASHAPE_MAX_INPUT_DIM + ADASHAPE_MAX_FRAME_SIZE];
    float out_buffer[ADASHAPE_MAX_FRAME_SIZE];
    float tmp_buffer[ADASHAPE_MAX_FRAME_SIZE];
    int i, k;
    int tenv_size;
    float mean;
    float *tenv;

    /* avg_pool_k must be checked first: the modulo below is undefined for 0.
       frame_size > 0 guarantees tenv_size >= 1, so the mean division is safe. */
    celt_assert(avg_pool_k > 0);
    celt_assert(frame_size > 0 && frame_size <= ADASHAPE_MAX_FRAME_SIZE);
    celt_assert(frame_size % avg_pool_k == 0);
    celt_assert(feature_dim + frame_size / avg_pool_k + 1 < ADASHAPE_MAX_INPUT_DIM);

    /* in_buffer = [features (feature_dim) | tenv (tenv_size) | mean (1)] */
    tenv_size = frame_size / avg_pool_k;
    tenv = in_buffer + feature_dim;
    OPUS_CLEAR(tenv, tenv_size + 1);

    OPUS_COPY(in_buffer, features, feature_dim);

    /* calculate temporal envelope; 1.52587890625e-05 = 2^-16 keeps log() finite on silence */
    mean = 0;
    for (i = 0; i < tenv_size; i++)
    {
        for (k = 0; k < avg_pool_k; k++)
        {
            tenv[i] += fabs(x_in[i * avg_pool_k + k]);
        }
        tenv[i] = log(tenv[i] / avg_pool_k + 1.52587890625e-05f);
        mean += tenv[i];
    }
    mean /= tenv_size;
    for (i = 0; i < tenv_size; i++)
    {
        tenv[i] -= mean;
    }
    tenv[tenv_size] = mean;
#ifdef DEBUG_NNDSP
    print_float_vector("tenv", tenv, tenv_size + 1);
#endif

    /* calculate temporal weights */
#ifdef DEBUG_NNDSP
    print_float_vector("alpha1_in", in_buffer, feature_dim + tenv_size + 1);
#endif
    compute_generic_conv1d(alpha1f, out_buffer, hAdaShape->conv_alpha1f_state, in_buffer, feature_dim, ACTIVATION_LINEAR, arch);
    compute_generic_conv1d(alpha1t, tmp_buffer, hAdaShape->conv_alpha1t_state, tenv, tenv_size + 1, ACTIVATION_LINEAR, arch);
#ifdef DEBUG_NNDSP
    print_float_vector("alpha1_out", out_buffer, frame_size);
#endif
    /* compute leaky ReLU by hand. ToDo: try tanh activation */
    for (i = 0; i < frame_size; i ++)
    {
        float tmp = out_buffer[i] + tmp_buffer[i];
        in_buffer[i] = tmp >= 0 ? tmp : 0.2 * tmp;
    }
#ifdef DEBUG_NNDSP
    print_float_vector("post_alpha1", in_buffer, frame_size);
#endif
    compute_generic_conv1d(alpha2, out_buffer, hAdaShape->conv_alpha2_state, in_buffer, frame_size, ACTIVATION_LINEAR, arch);

    /* shape signal */
    for (i = 0; i < frame_size; i ++)
    {
        x_out[i] = exp(out_buffer[i]) * x_in[i];
    }

}
417