xref: /aosp_15_r20/external/libopus/dnn/dred_rdovae_dec.c (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li /* Copyright (c) 2022 Amazon
2*a58d3d2aSXin Li    Written by Jan Buethe */
3*a58d3d2aSXin Li /*
4*a58d3d2aSXin Li    Redistribution and use in source and binary forms, with or without
5*a58d3d2aSXin Li    modification, are permitted provided that the following conditions
6*a58d3d2aSXin Li    are met:
7*a58d3d2aSXin Li 
8*a58d3d2aSXin Li    - Redistributions of source code must retain the above copyright
9*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer.
10*a58d3d2aSXin Li 
11*a58d3d2aSXin Li    - Redistributions in binary form must reproduce the above copyright
12*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer in the
13*a58d3d2aSXin Li    documentation and/or other materials provided with the distribution.
14*a58d3d2aSXin Li 
15*a58d3d2aSXin Li    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16*a58d3d2aSXin Li    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17*a58d3d2aSXin Li    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18*a58d3d2aSXin Li    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
19*a58d3d2aSXin Li    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20*a58d3d2aSXin Li    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21*a58d3d2aSXin Li    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22*a58d3d2aSXin Li    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23*a58d3d2aSXin Li    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24*a58d3d2aSXin Li    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25*a58d3d2aSXin Li    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*a58d3d2aSXin Li */
27*a58d3d2aSXin Li 
28*a58d3d2aSXin Li #ifdef HAVE_CONFIG_H
29*a58d3d2aSXin Li #include "config.h"
30*a58d3d2aSXin Li #endif
31*a58d3d2aSXin Li 
32*a58d3d2aSXin Li #include "dred_rdovae_dec.h"
33*a58d3d2aSXin Li #include "dred_rdovae_constants.h"
34*a58d3d2aSXin Li #include "os_support.h"
35*a58d3d2aSXin Li 
/* Zero a conv1d memory buffer the first time it is used.
 *
 * If *init is 0, clears `dilation` consecutive runs of `len` floats starting
 * at mem (dilation*len floats in total); otherwise mem is left untouched.
 * *init is set to 1 on every return, so calling again with the same flag is
 * a no-op. Because of that, a single flag shared between several buffers
 * only lets the first buffer be cleared.
 *
 * mem:      io: conv memory; assumed to hold at least dilation*len floats
 * len:      floats per run
 * dilation: number of runs to clear
 * init:     io: "already initialized" flag, always left at 1
 */
static void conv1_cond_init(float *mem, int len, int dilation, int *init)
{
    if (*init == 0) {
        int k;
        for (k = 0; k < dilation; k++) {
            OPUS_CLEAR(mem + k*len, len);
        }
    }
    *init = 1;
}
44*a58d3d2aSXin Li 
DRED_rdovae_decode_all(const RDOVAEDec * model,float * features,const float * state,const float * latents,int nb_latents,int arch)45*a58d3d2aSXin Li void DRED_rdovae_decode_all(const RDOVAEDec *model, float *features, const float *state, const float *latents, int nb_latents, int arch)
46*a58d3d2aSXin Li {
47*a58d3d2aSXin Li     int i;
48*a58d3d2aSXin Li     RDOVAEDecState dec;
49*a58d3d2aSXin Li     memset(&dec, 0, sizeof(dec));
50*a58d3d2aSXin Li     dred_rdovae_dec_init_states(&dec, model, state, arch);
51*a58d3d2aSXin Li     for (i = 0; i < 2*nb_latents; i += 2)
52*a58d3d2aSXin Li     {
53*a58d3d2aSXin Li         dred_rdovae_decode_qframe(
54*a58d3d2aSXin Li             &dec,
55*a58d3d2aSXin Li             model,
56*a58d3d2aSXin Li             &features[2*i*DRED_NUM_FEATURES],
57*a58d3d2aSXin Li             &latents[(i/2)*DRED_LATENT_DIM],
58*a58d3d2aSXin Li             arch);
59*a58d3d2aSXin Li     }
60*a58d3d2aSXin Li }
61*a58d3d2aSXin Li 
/* Seed the decoder's five GRU states from an initial-state vector.
 *
 * initial_state is passed through dec_hidden_init and then dec_gru_init
 * (both tanh dense layers). The result holds the GRU states laid end to end
 * (gru1 first, gru5 last), and each slice is copied into the matching
 * h->gruN_state. h->initialized is reset to 0. The conv memories in h are
 * NOT written here.
 */
void dred_rdovae_dec_init_states(
    RDOVAEDecState *h,            /* io: state buffer handle */
    const RDOVAEDec *model,
    const float *initial_state,  /* i: initial state */
    int arch
    )
{
    float hidden[DEC_HIDDEN_INIT_OUT_SIZE];
    float state_init[DEC_GRU1_STATE_SIZE+DEC_GRU2_STATE_SIZE+DEC_GRU3_STATE_SIZE+DEC_GRU4_STATE_SIZE+DEC_GRU5_STATE_SIZE];
    const float *src = state_init;

    compute_generic_dense(&model->dec_hidden_init, hidden, initial_state, ACTIVATION_TANH, arch);
    compute_generic_dense(&model->dec_gru_init, state_init, hidden, ACTIVATION_TANH, arch);

    /* Walk a read cursor through state_init, handing each GRU its slice. */
    OPUS_COPY(h->gru1_state, src, DEC_GRU1_STATE_SIZE);
    src += DEC_GRU1_STATE_SIZE;
    OPUS_COPY(h->gru2_state, src, DEC_GRU2_STATE_SIZE);
    src += DEC_GRU2_STATE_SIZE;
    OPUS_COPY(h->gru3_state, src, DEC_GRU3_STATE_SIZE);
    src += DEC_GRU3_STATE_SIZE;
    OPUS_COPY(h->gru4_state, src, DEC_GRU4_STATE_SIZE);
    src += DEC_GRU4_STATE_SIZE;
    OPUS_COPY(h->gru5_state, src, DEC_GRU5_STATE_SIZE);

    h->initialized = 0;
}
85*a58d3d2aSXin Li 
86*a58d3d2aSXin Li 
/* Decode one latent vector into a quadruple feature frame, advancing dec_state.
 *
 * Layer order: dense1, then for k = 1..5: gru_k -> glu_k -> conv_k, then
 * dec_output. Hidden activations are tanh and the output layer is linear.
 * Every stage appends its output to `buffer`, so later layers are fed the
 * concatenation of all earlier outputs. conv_k is explicitly given
 * output_index floats of buffer. The GRUs and dec_output are handed `buffer`
 * and presumably read the whole prefix; their input sizes come from the
 * model — confirm.
 *
 * On the first call after dred_rdovae_dec_init_states() (initialized == 0)
 * each conv memory is zeroed before use. initialized is then set to 1, so
 * later calls keep the conv memories from previous frames.
 */
void dred_rdovae_decode_qframe(
    RDOVAEDecState *dec_state,       /* io: state buffer handle */
    const RDOVAEDec *model,
    float *qframe,              /* o: quadruple feature frame (four concatenated frames in reverse order) */
    const float *input,          /* i: latent vector */
    int arch
    )
{
    float buffer[DEC_DENSE1_OUT_SIZE + DEC_GRU1_OUT_SIZE + DEC_GRU2_OUT_SIZE + DEC_GRU3_OUT_SIZE + DEC_GRU4_OUT_SIZE + DEC_GRU5_OUT_SIZE
                 + DEC_CONV1_OUT_SIZE + DEC_CONV2_OUT_SIZE + DEC_CONV3_OUT_SIZE + DEC_CONV4_OUT_SIZE + DEC_CONV5_OUT_SIZE];
    int output_index = 0;
    /* conv1_cond_init() always sets the flag it is handed to 1. If all five
       conv layers shared &dec_state->initialized, conv1 would consume it and
       the conv2..conv5 memories would never be cleared on the first frame.
       Each layer therefore gets a fresh copy of the pre-frame flag, and the
       state's flag is committed once after the whole stack has run. */
    const int was_initialized = dec_state->initialized;
    int conv_init;

    /* run decoder stack and concatenate output in buffer */
    compute_generic_dense(&model->dec_dense1, &buffer[output_index], input, ACTIVATION_TANH, arch);
    output_index += DEC_DENSE1_OUT_SIZE;

    compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer, arch);
    compute_glu(&model->dec_glu1, &buffer[output_index], dec_state->gru1_state, arch);
    output_index += DEC_GRU1_OUT_SIZE;
    conv_init = was_initialized;
    conv1_cond_init(dec_state->conv1_state, output_index, 1, &conv_init);
    compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += DEC_CONV1_OUT_SIZE;

    compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer, arch);
    compute_glu(&model->dec_glu2, &buffer[output_index], dec_state->gru2_state, arch);
    output_index += DEC_GRU2_OUT_SIZE;
    conv_init = was_initialized;
    conv1_cond_init(dec_state->conv2_state, output_index, 1, &conv_init);
    compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += DEC_CONV2_OUT_SIZE;

    compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer, arch);
    compute_glu(&model->dec_glu3, &buffer[output_index], dec_state->gru3_state, arch);
    output_index += DEC_GRU3_OUT_SIZE;
    conv_init = was_initialized;
    conv1_cond_init(dec_state->conv3_state, output_index, 1, &conv_init);
    compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += DEC_CONV3_OUT_SIZE;

    compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer, arch);
    compute_glu(&model->dec_glu4, &buffer[output_index], dec_state->gru4_state, arch);
    output_index += DEC_GRU4_OUT_SIZE;
    conv_init = was_initialized;
    conv1_cond_init(dec_state->conv4_state, output_index, 1, &conv_init);
    compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += DEC_CONV4_OUT_SIZE;

    compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer, arch);
    compute_glu(&model->dec_glu5, &buffer[output_index], dec_state->gru5_state, arch);
    output_index += DEC_GRU5_OUT_SIZE;
    conv_init = was_initialized;
    conv1_cond_init(dec_state->conv5_state, output_index, 1, &conv_init);
    compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += DEC_CONV5_OUT_SIZE;

    /* All conv memories have now been (conditionally) cleared and used. */
    dec_state->initialized = 1;

    compute_generic_dense(&model->dec_output, qframe, buffer, ACTIVATION_LINEAR, arch);
}
140