xref: /aosp_15_r20/external/libopus/dnn/pitchdnn.c (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1 #ifdef HAVE_CONFIG_H
2 #include "config.h"
3 #endif
4 
5 #include <math.h>
6 #include "pitchdnn.h"
7 #include "os_support.h"
8 #include "nnet.h"
9 #include "lpcnet_private.h"
10 
11 
/* Runs one frame of the pitch DNN and returns a pitch estimate.
 *
 * Network evaluation order:
 *  - if_features pass through two TANH dense layers (dense_if_upsampler_1,
 *    dense_if_upsampler_2). The result is stored in downsampler_in starting
 *    at index NB_XCORR_FEATURES.
 *  - xcorr_features pass through two TANH 2-D convolutions (conv2d_1,
 *    conv2d_2). Their history is kept in st->xcorr_mem1 / st->xcorr_mem2.
 *    The result fills the front of downsampler_in.
 *  - A TANH dense downsampler and a GRU follow; the GRU state lives in
 *    st->gru_state and is updated in place.
 *  - A final linear dense layer turns the GRU state into one score per
 *    pitch bin.
 *
 * The estimate is the exp()-weighted mean bin index. It is taken over the
 * highest-scoring bin and up to two neighbours on each side, then mapped as
 * bin/60 - 1.5.
 * NOTE(review): 60 bins per step suggests the result is in octaves relative
 * to a reference pitch. Bin i presumably corresponds to a period of
 * 256/2^(i/60) samples — confirm against the training code.
 *
 * st:             model weights plus recurrent memories (updated in place).
 * if_features:    input vector for dense_if_upsampler_1.
 * xcorr_features: NB_XCORR_FEATURES values fed to the convolution stack.
 * arch:           architecture selector forwarded to the nnet kernels.
 */
float compute_pitchdnn(
    PitchDNNState *st,
    const float *if_features,
    const float *xcorr_features,
    int arch
    )
{
  /* Number of pitch bins decoded from the final layer. It is clamped to the
     layer size so the scans below can never read past the end of output[]. */
  const int nb_bins = IMIN(180, DENSE_FINAL_UPSAMPLER_OUT_SIZE);
  float if1_out[DENSE_IF_UPSAMPLER_1_OUT_SIZE];
  float downsampler_in[NB_XCORR_FEATURES + DENSE_IF_UPSAMPLER_2_OUT_SIZE];
  float downsampler_out[DENSE_DOWNSAMPLER_OUT_SIZE];
  /* Zero-initialised so the entries the copies below do not touch act as
     padding (presumably one zero on each edge of the lag axis). */
  float conv1_tmp1[(NB_XCORR_FEATURES + 2)*8] = {0};
  float conv1_tmp2[(NB_XCORR_FEATURES + 2)*8] = {0};
  float output[DENSE_FINAL_UPSAMPLER_OUT_SIZE];
  int i;
  int pos;
  int lo;
  int hi;
  float maxval;
  float weighted_sum = 0;
  float weight_total = 0;
  PitchDNN *model = &st->model;

  /* Instantaneous-frequency branch. */
  compute_generic_dense(&model->dense_if_upsampler_1, if1_out, if_features, ACTIVATION_TANH, arch);
  compute_generic_dense(&model->dense_if_upsampler_2, &downsampler_in[NB_XCORR_FEATURES], if1_out, ACTIVATION_TANH, arch);

  /* Cross-correlation branch. */
  OPUS_COPY(&conv1_tmp1[1], xcorr_features, NB_XCORR_FEATURES);
  compute_conv2d(&model->conv2d_1, &conv1_tmp2[1], st->xcorr_mem1, conv1_tmp1, NB_XCORR_FEATURES, NB_XCORR_FEATURES+2, ACTIVATION_TANH, arch);
  compute_conv2d(&model->conv2d_2, downsampler_in, st->xcorr_mem2, conv1_tmp2, NB_XCORR_FEATURES, NB_XCORR_FEATURES, ACTIVATION_TANH, arch);

  /* Shared trunk: downsampler -> GRU -> per-bin scores. */
  compute_generic_dense(&model->dense_downsampler, downsampler_out, downsampler_in, ACTIVATION_TANH, arch);
  compute_generic_gru(&model->gru_1_input, &model->gru_1_recurrent, st->gru_state, downsampler_out, arch);
  compute_generic_dense(&model->dense_final_upsampler, output, st->gru_state, ACTIVATION_LINEAR, arch);

  /* Argmax over the bins. Seeding with output[0] rather than a fixed
     threshold keeps the result correct when every score is negative, which
     is possible because the final layer is linear. Ties keep the lowest
     index. */
  pos = 0;
  maxval = output[0];
  for (i=1;i<nb_bins;i++) {
    if (output[i] > maxval) {
      pos = i;
      maxval = output[i];
    }
  }

  /* Refine with an exp()-weighted mean over the peak and up to two
     neighbours on each side. */
  lo = IMAX(0, pos-2);
  hi = IMIN(nb_bins-1, pos+2);
  for (i=lo; i<=hi; i++) {
    /* Subtracting the peak leaves the ratio below unchanged. It also means
       exp() cannot overflow to inf (inf/inf = NaN) or underflow every
       weight to zero (0/0 = NaN): the peak bin always gets weight 1. */
    float p = (float)exp(output[i] - maxval);
    weighted_sum += p*i;
    weight_total += p;
  }
  return (1.f/60.f)*(weighted_sum/weight_total) - 1.5f;
}
57 
58 
/* Resets a PitchDNNState to all zeros. In builds with compiled-in weights
 * (USE_WEIGHTS_FILE not defined), it also binds st->model to the built-in
 * pitchdnn_arrays table.
 *
 * With USE_WEIGHTS_FILE defined, no weights are bound here. Presumably the
 * caller then loads them with pitchdnn_load_model() — confirm.
 * A failure reported by init_pitchdnn() is only checked via celt_assert().
 */
void pitchdnn_init(PitchDNNState *st)
{
  int init_status = 0;  /* stays 0 when weights come from a file */

  OPUS_CLEAR(st, 1);
#ifndef USE_WEIGHTS_FILE
  init_status = init_pitchdnn(&st->model, pitchdnn_arrays);
#endif
  celt_assert(init_status == 0);
  (void)init_status;  /* avoids an unused-variable warning if celt_assert() expands to nothing */
}
70 
/* Loads pitch DNN weights from a serialized blob into st->model.
 *
 * data/len: the weight blob and its size, passed unchanged to
 *           parse_weights() (presumably len is in bytes — confirm).
 * Returns 0 on success. Returns -1 if the blob cannot be parsed or if
 * init_pitchdnn() rejects the resulting weight list.
 * Only st->model is handed to init_pitchdnn(); the rest of st is not
 * touched here.
 */
int pitchdnn_load_model(PitchDNNState *st, const void *data, int len)
{
  WeightArray *list = NULL;
  int ret;

  /* Do not hand a failed parse to init_pitchdnn(): it would walk an invalid
     or NULL list.
     NOTE(review): assumes parse_weights() returns a negative value on
     failure. The list is freed on every path, matching the success path. */
  if (parse_weights(&list, data, len) < 0 || list == NULL) {
    opus_free(list);
    return -1;
  }
  ret = init_pitchdnn(&st->model, list);
  opus_free(list);
  return ret == 0 ? 0 : -1;
}
80