xref: /aosp_15_r20/external/libopus/dnn/lpcnet_plc.c (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1 /* Copyright (c) 2021 Amazon */
2 /*
3    Redistribution and use in source and binary forms, with or without
4    modification, are permitted provided that the following conditions
5    are met:
6 
7    - Redistributions of source code must retain the above copyright
8    notice, this list of conditions and the following disclaimer.
9 
10    - Redistributions in binary form must reproduce the above copyright
11    notice, this list of conditions and the following disclaimer in the
12    documentation and/or other materials provided with the distribution.
13 
14    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
15    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
16    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
17    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
18    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
21    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
22    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
23    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26 
27 #ifdef HAVE_CONFIG_H
28 #include "config.h"
29 #endif
30 
31 #include "lpcnet_private.h"
32 #include "lpcnet.h"
33 #include "plc_data.h"
34 #include "os_support.h"
35 #include "common.h"
36 #include "cpu_support.h"
37 
38 #ifndef M_PI
39 #define M_PI 3.141592653
40 #endif
41 
42 /* Comment this out to have LPCNet update its state on every good packet (slow). */
43 #define PLC_SKIP_UPDATES
44 
/* Return the PLC state to its post-initialisation condition.

   Every byte of the structure from the LPCNET_PLC_RESET_START member to the
   end of LPCNetPLCState is zeroed; members located before that marker are
   left untouched. The feature encoder is then re-initialised, the PCM history
   is cleared, and the bookkeeping positions are set so that the history
   buffer is treated as containing no analysed audio. */
void lpcnet_plc_reset(LPCNetPLCState *st) {
  char *reset_begin = (char*)&st->LPCNET_PLC_RESET_START;
  /* Byte count = total struct size minus the offset of the first reset member. */
  OPUS_CLEAR(reset_begin, sizeof(LPCNetPLCState) - (reset_begin - (char*)st));
  lpcnet_encoder_init(&st->enc);
  OPUS_CLEAR(st->pcm, PLC_BUF_SIZE);
  /* Both cursors start at the end of the history buffer. */
  st->analysis_pos = PLC_BUF_SIZE;
  st->predict_pos = PLC_BUF_SIZE;
  st->analysis_gap = 1;
  st->loss_count = 0;
  st->blend = 0;
}
57 
/* Initialise a PLC state from scratch.

   Selects the CPU architecture, initialises the FARGAN synthesiser and the
   feature encoder, and - unless USE_WEIGHTS_FILE is defined - loads the
   built-in PLC model weights (plcmodel_arrays). st->loaded is set to 1 only
   when the built-in weights were loaded successfully; with USE_WEIGHTS_FILE
   the state stays unloaded until lpcnet_plc_load_model() succeeds.

   Returns 0 on success, otherwise the non-zero value from init_plcmodel(). */
int lpcnet_plc_init(LPCNetPLCState *st) {
  int ret = 0;
  st->arch = opus_select_arch();
  fargan_init(&st->fargan);
  lpcnet_encoder_init(&st->enc);
  st->loaded = 0;
#ifndef USE_WEIGHTS_FILE
  ret = init_plcmodel(&st->model, plcmodel_arrays);
  if (ret == 0) {
    st->loaded = 1;
  }
#endif
  celt_assert(ret == 0);
  lpcnet_plc_reset(st);
  return ret;
}
74 
/* Load the PLC, feature-encoder and FARGAN weights from a binary blob.

   st   : PLC state to receive the weights.
   data : pointer to the serialized weights.
   len  : size of `data` in bytes.

   Returns 0 on success and a non-zero value on failure. st->loaded is set to
   1 only when all three models were initialised successfully; it is not
   modified on failure. */
int lpcnet_plc_load_model(LPCNetPLCState *st, const void *data, int len) {
  WeightArray *list = NULL;
  int ret;
  /* A malformed or truncated blob must not reach init_plcmodel(): without a
     valid array list there is nothing it could safely walk.
     NOTE(review): this assumes parse_weights() returns a negative value
     and/or leaves `list` NULL on failure, and releases any partial list
     itself - confirm against its implementation. */
  if (parse_weights(&list, data, len) < 0 || list == NULL) {
    return -1;
  }
  ret = init_plcmodel(&st->model, list);
  opus_free(list);
  if (ret == 0) {
    ret = lpcnet_encoder_load_model(&st->enc, data, len);
  }
  if (ret == 0) {
    ret = fargan_load_model(&st->fargan, data, len);
  }
  if (ret == 0) st->loaded = 1;
  return ret;
}
90 
/* Append one FEC feature vector to the queue consumed by get_fec_or_pred().

   st       : PLC state.
   features : NB_FEATURES floats to enqueue, or NULL to record that a frame
              has no FEC data (increments st->fec_skip instead of queueing).

   When the write cursor reaches PLC_MAX_FEC, the entries that have not been
   read yet are moved to the front of the buffer to make room. If the buffer
   is still full after that (nothing has been consumed), the new vector is
   dropped rather than written past the end of st->fec. */
void lpcnet_plc_fec_add(LPCNetPLCState *st, const float *features) {
  if (features == NULL) {
    st->fec_skip++;
    return;
  }
  if (st->fec_fill_pos == PLC_MAX_FEC) {
    int pending = st->fec_fill_pos - st->fec_read_pos;
    /* Compact: slide the unread vectors down to index 0. */
    OPUS_MOVE(&st->fec[0][0], &st->fec[st->fec_read_pos][0], pending*NB_FEATURES);
    st->fec_fill_pos = pending;
    st->fec_read_pos = 0;
    /* With fec_read_pos == 0 the compaction frees nothing; writing would
       land on st->fec[PLC_MAX_FEC], one row out of bounds. Keep the older
       (earlier-in-time) vectors and discard the incoming one. */
    if (st->fec_fill_pos >= PLC_MAX_FEC) {
      return;
    }
  }
  OPUS_COPY(&st->fec[st->fec_fill_pos][0], features, NB_FEATURES);
  st->fec_fill_pos++;
}
104 
/* Empty the FEC queue: reset the skip counter and both queue cursors to 0. */
void lpcnet_plc_fec_clear(LPCNetPLCState *st) {
  st->fec_skip = 0;
  st->fec_fill_pos = 0;
  st->fec_read_pos = 0;
}
108 
109 
/* Run one step of the PLC prediction network.

   st  : PLC state; the recurrent GRU states in st->plc_net are updated in
         place. The model must already be loaded (asserted).
   out : receives the output of the final linear dense layer. Callers in this
         file pass buffers of at least NB_FEATURES floats.
   in  : network input. Callers in this file pass 2*NB_BANDS+NB_FEATURES+1
         floats.

   Pipeline: dense (tanh) -> GRU 1 -> GRU 2 -> dense (linear). */
static void compute_plc_pred(LPCNetPLCState *st, float *out, const float *in) {
  float dense_act[PLC_DENSE_IN_OUT_SIZE];
  PLCNetState *state = &st->plc_net;
  PLCModel *weights = &st->model;
  celt_assert(st->loaded);
  compute_generic_dense(&weights->plc_dense_in, dense_act, in, ACTIVATION_TANH, st->arch);
  /* The second GRU is fed with the freshly updated state of the first. */
  compute_generic_gru(&weights->plc_gru1_input, &weights->plc_gru1_recurrent, state->gru1_state, dense_act, st->arch);
  compute_generic_gru(&weights->plc_gru2_input, &weights->plc_gru2_recurrent, state->gru2_state, state->gru1_state, st->arch);
  compute_generic_dense(&weights->plc_dense_out, out, state->gru2_state, ACTIVATION_LINEAR, st->arch);
}
120 
/* Produce the next feature vector for a lost frame.

   If the FEC queue holds an unread vector and no skip is pending, that vector
   is copied to `out` and also pushed through the prediction network (output
   discarded) so the network state keeps tracking the signal. Otherwise the
   network is run on an all-zero input and its prediction is written to
   `out`; a pending skip, if any, is consumed.

   Returns 1 when `out` came from FEC, 0 when it came from the predictor. */
static int get_fec_or_pred(LPCNetPLCState *st, float *out) {
  float net_in[2*NB_BANDS+NB_FEATURES+1] = {0};
  float discard[NB_FEATURES];
  if (st->fec_read_pos == st->fec_fill_pos || st->fec_skip != 0) {
    /* No usable FEC: predict from an all-zero input. */
    compute_plc_pred(st, out, net_in);
    if (st->fec_skip > 0) st->fec_skip--;
    return 0;
  }
  OPUS_COPY(out, &st->fec[st->fec_read_pos][0], NB_FEATURES);
  st->fec_read_pos++;
  /* Update PLC state using FEC, so without Burg features: the first
     2*NB_BANDS inputs stay zero and the trailing flag is set to -1. */
  OPUS_COPY(&net_in[2*NB_BANDS], out, NB_FEATURES);
  net_in[2*NB_BANDS+NB_FEATURES] = -1;
  compute_plc_pred(st, discard, net_in);
  return 1;
}
139 
/* Push one feature vector into the st->cont_features history.

   The history holds CONT_VECTORS vectors of NB_FEATURES floats each; the
   oldest vector is discarded and `features` becomes the newest (last) one. */
static void queue_features(LPCNetPLCState *st, const float *features) {
  float *history = st->cont_features;
  const int kept = (CONT_VECTORS-1)*NB_FEATURES;
  /* Shift everything one vector towards the front, then append. */
  OPUS_MOVE(history, history+NB_FEATURES, kept);
  OPUS_COPY(history+kept, features, NB_FEATURES);
}
144 
145 /* In this causal version of the code, the DNN model implemented by compute_plc_pred()
146    needs to generate two feature vectors to conceal the first lost packet.*/
147 
/* Feed one correctly received frame into the PLC history.

   st  : PLC state.
   pcm : FRAME_SIZE 16-bit samples of decoded audio (read only here).

   The PCM history is shifted left by one frame and the new samples, scaled
   by 1/32768, are appended. The analysis and prediction cursors follow
   the shift; if the analysis cursor cannot move back a full frame, a gap in
   the analysis is flagged. Loss tracking (loss_count, blend) is cleared.

   Always returns 0. */
int lpcnet_plc_update(LPCNetPLCState *st, opus_int16 *pcm) {
  int i;
  const int tail = PLC_BUF_SIZE-FRAME_SIZE;
  if (st->analysis_pos >= FRAME_SIZE) {
    st->analysis_pos -= FRAME_SIZE;
  } else {
    st->analysis_gap = 1;
  }
  if (st->predict_pos >= FRAME_SIZE) {
    st->predict_pos -= FRAME_SIZE;
  }
  OPUS_MOVE(st->pcm, &st->pcm[FRAME_SIZE], tail);
  for (i=0;i<FRAME_SIZE;i++) {
    st->pcm[tail+i] = (1.f/32768.f)*pcm[i];
  }
  st->blend = 0;
  st->loss_count = 0;
  return 0;
}
159 
/* Offset added to features[0] as a function of the number of consecutive
   predicted (non-FEC) frames; indexed by st->loss_count for counts below 10. */
static const float att_table[10] = {0, 0,  -.2, -.2,  -.4, -.4,  -.8, -.8, -1.6, -1.6};

/* Rotate the two-deep backup of the prediction network state: the previous
   snapshot moves to slot 0 and the current st->plc_net is saved in slot 1. */
static void plc_rotate_backup(LPCNetPLCState *st) {
  st->plc_bak[0] = st->plc_bak[1];
  st->plc_bak[1] = st->plc_net;
}

/* Synthesize one frame of concealment audio for a lost packet.

   st  : PLC state; the model must be loaded (asserted).
   pcm : buffer handed to fargan_synthesize_int() for the concealed frame;
         FRAME_SIZE samples are then read back from it into the PCM history.

   On the first loss after good audio (st->blend == 0) the function first
   restores the older network snapshot, analyses every not-yet-analysed frame
   of the PCM history to bring the predictor up to date, generates two
   feature vectors (FEC or prediction), and primes FARGAN with the tail of
   the history plus the queued features. It then produces the feature vector
   for the frame to conceal, applies the loss-dependent offset to
   features[0], synthesizes the frame and appends it to the history.

   Always returns 0. */
int lpcnet_plc_conceal(LPCNetPLCState *st, opus_int16 *pcm) {
  int i;
  celt_assert(st->loaded);
  if (st->blend == 0) {
    int k;
    int count = 0;
    st->plc_net = st->plc_bak[0];
    while (st->analysis_pos + FRAME_SIZE <= PLC_BUF_SIZE) {
      float frame[FRAME_SIZE];
      float net_in[2*NB_BANDS+NB_FEATURES+1];
      celt_assert(st->analysis_pos >= 0);
      /* History is stored scaled by 1/32768; undo that for the analysis. */
      for (i=0;i<FRAME_SIZE;i++) {
        frame[i] = 32768.f*st->pcm[st->analysis_pos+i];
      }
      burg_cepstral_analysis(net_in, frame);
      lpcnet_compute_single_frame_features_float(&st->enc, frame, st->features, st->arch);
      /* Skip the first frame after an analysis gap, and any frame located
         before the prediction cursor. */
      if ((!st->analysis_gap || count>0) && st->analysis_pos >= st->predict_pos) {
        queue_features(st, st->features);
        OPUS_COPY(&net_in[2*NB_BANDS], st->features, NB_FEATURES);
        net_in[2*NB_BANDS+NB_FEATURES] = 1;
        plc_rotate_backup(st);
        compute_plc_pred(st, st->features, net_in);
      }
      st->analysis_pos += FRAME_SIZE;
      count++;
    }
    /* Two feature vectors are generated before the first concealed frame. */
    for (k=0;k<2;k++) {
      plc_rotate_backup(st);
      get_fec_or_pred(st, st->features);
      queue_features(st, st->features);
    }
    fargan_cont(&st->fargan, &st->pcm[PLC_BUF_SIZE-FARGAN_CONT_SAMPLES], st->cont_features);
    st->analysis_gap = 0;
  }
  plc_rotate_backup(st);
  if (get_fec_or_pred(st, st->features)) {
    st->loss_count = 0;
  } else {
    st->loss_count++;
  }
  /* Beyond the table, the offset keeps decreasing by 2 per additional frame;
     features[0] is never allowed below -10. */
  if (st->loss_count >= 10) {
    st->features[0] = MAX16(-10, st->features[0]+att_table[9] - 2*(st->loss_count-9));
  } else {
    st->features[0] = MAX16(-10, st->features[0]+att_table[st->loss_count]);
  }
  fargan_synthesize_int(&st->fargan, pcm, &st->features[0]);
  queue_features(st, st->features);
  if (st->analysis_pos - FRAME_SIZE >= 0) {
    st->analysis_pos -= FRAME_SIZE;
  } else {
    st->analysis_gap = 1;
  }
  st->predict_pos = PLC_BUF_SIZE;
  /* Append the concealed frame to the PCM history. */
  OPUS_MOVE(st->pcm, &st->pcm[FRAME_SIZE], PLC_BUF_SIZE-FRAME_SIZE);
  for (i=0;i<FRAME_SIZE;i++) {
    st->pcm[PLC_BUF_SIZE-FRAME_SIZE+i] = (1.f/32768.f)*pcm[i];
  }
  st->blend = 1;
  return 0;
}
212