1 /* Copyright (c) 2022 Amazon
2 Written by Jan Buethe */
3 /*
4 Redistribution and use in source and binary forms, with or without
5 modification, are permitted provided that the following conditions
6 are met:
7
8 - Redistributions of source code must retain the above copyright
9 notice, this list of conditions and the following disclaimer.
10
11 - Redistributions in binary form must reproduce the above copyright
12 notice, this list of conditions and the following disclaimer in the
13 documentation and/or other materials provided with the distribution.
14
15 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
19 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31
32 #include "dred_rdovae_dec.h"
33 #include "dred_rdovae_constants.h"
34 #include "os_support.h"
35
/* One-shot clearing of a convolution history buffer.
 *
 * mem      : start of the history buffer
 * len      : number of floats in one chunk
 * dilation : number of consecutive chunks to clear
 * init     : io flag; when it is 0 on entry the first dilation*len floats of
 *            mem are zeroed chunk by chunk. It is always 1 on return, so a
 *            second call with the same flag clears nothing.
 */
static void conv1_cond_init(float *mem, int len, int dilation, int *init)
{
    if (*init == 0) {
        float *chunk = mem;
        int remaining;
        for (remaining = dilation; remaining > 0; remaining--) {
            OPUS_CLEAR(chunk, len);
            chunk += len;
        }
    }
    *init = 1;
}
44
DRED_rdovae_decode_all(const RDOVAEDec * model,float * features,const float * state,const float * latents,int nb_latents,int arch)45 void DRED_rdovae_decode_all(const RDOVAEDec *model, float *features, const float *state, const float *latents, int nb_latents, int arch)
46 {
47 int i;
48 RDOVAEDecState dec;
49 memset(&dec, 0, sizeof(dec));
50 dred_rdovae_dec_init_states(&dec, model, state, arch);
51 for (i = 0; i < 2*nb_latents; i += 2)
52 {
53 dred_rdovae_decode_qframe(
54 &dec,
55 model,
56 &features[2*i*DRED_NUM_FEATURES],
57 &latents[(i/2)*DRED_LATENT_DIM],
58 arch);
59 }
60 }
61
/* Derive the five GRU states from an initial state vector.
 *
 * Two tanh dense layers are applied in sequence (dec_hidden_init, then
 * dec_gru_init). The second layer's output is treated as the concatenation
 * gru1 | gru2 | gru3 | gru4 | gru5 and is copied piecewise into h. The
 * `initialized` flag is reset to 0 so that the next call to
 * dred_rdovae_decode_qframe() is treated as a first frame.
 */
void dred_rdovae_dec_init_states(
    RDOVAEDecState *h,            /* io: state buffer handle */
    const RDOVAEDec *model,
    const float *initial_state,   /* i: initial state */
    int arch
    )
{
    float hidden[DEC_HIDDEN_INIT_OUT_SIZE];
    float state_init[DEC_GRU1_STATE_SIZE+DEC_GRU2_STATE_SIZE+DEC_GRU3_STATE_SIZE+DEC_GRU4_STATE_SIZE+DEC_GRU5_STATE_SIZE];
    const float *src = state_init;

    compute_generic_dense(&model->dec_hidden_init, hidden, initial_state, ACTIVATION_TANH, arch);
    compute_generic_dense(&model->dec_gru_init, state_init, hidden, ACTIVATION_TANH, arch);

    /* Walk through state_init, handing each GRU its own slice in order. */
    OPUS_COPY(h->gru1_state, src, DEC_GRU1_STATE_SIZE);
    src += DEC_GRU1_STATE_SIZE;

    OPUS_COPY(h->gru2_state, src, DEC_GRU2_STATE_SIZE);
    src += DEC_GRU2_STATE_SIZE;

    OPUS_COPY(h->gru3_state, src, DEC_GRU3_STATE_SIZE);
    src += DEC_GRU3_STATE_SIZE;

    OPUS_COPY(h->gru4_state, src, DEC_GRU4_STATE_SIZE);
    src += DEC_GRU4_STATE_SIZE;

    OPUS_COPY(h->gru5_state, src, DEC_GRU5_STATE_SIZE);

    h->initialized = 0;
}
85
86
/* Decode one latent vector into one quadruple feature frame.
 *
 * The layers run in a fixed order: dense1, then five (GRU, GLU, conv) stages,
 * then the output dense layer. Every layer output is appended to `buffer`,
 * and each following layer is handed the buffer filled so far.
 *
 * On the first call after dred_rdovae_dec_init_states() (initialized == 0)
 * every convolution state is cleared before it is used. The flag is 1 on
 * return.
 */
void dred_rdovae_decode_qframe(
    RDOVAEDecState *dec_state,       /* io: state buffer handle */
    const RDOVAEDec *model,
    float *qframe,                   /* o: quadruple feature frame (four concatenated frames in reverse order) */
    const float *input,              /* i: latent vector */
    int arch
    )
{
    float buffer[DEC_DENSE1_OUT_SIZE + DEC_GRU1_OUT_SIZE + DEC_GRU2_OUT_SIZE + DEC_GRU3_OUT_SIZE + DEC_GRU4_OUT_SIZE + DEC_GRU5_OUT_SIZE
                 + DEC_CONV1_OUT_SIZE + DEC_CONV2_OUT_SIZE + DEC_CONV3_OUT_SIZE + DEC_CONV4_OUT_SIZE + DEC_CONV5_OUT_SIZE];
    int output_index = 0;
    /* conv1_cond_init() sets the flag it is given to 1. Passing
       &dec_state->initialized to all five calls therefore let only the conv1
       call clear anything; conv2..conv5 kept whatever their state held.
       Snapshot the flag once and give every call its own copy instead.
       NOTE(review): this makes the conv2..conv5 clears live, each writing
       output_index floats; assumes the convN_state arrays are at least that
       large, as the existing call arguments imply - confirm against
       RDOVAEDecState. */
    const int was_initialized = dec_state->initialized;
    int init_flag;

    /* run decoder stack and concatenate output in buffer */
    compute_generic_dense(&model->dec_dense1, &buffer[output_index], input, ACTIVATION_TANH, arch);
    output_index += DEC_DENSE1_OUT_SIZE;

    /* stage 1 */
    compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer, arch);
    compute_glu(&model->dec_glu1, &buffer[output_index], dec_state->gru1_state, arch);
    output_index += DEC_GRU1_OUT_SIZE;
    init_flag = was_initialized;
    conv1_cond_init(dec_state->conv1_state, output_index, 1, &init_flag);
    compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += DEC_CONV1_OUT_SIZE;

    /* stage 2 */
    compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer, arch);
    compute_glu(&model->dec_glu2, &buffer[output_index], dec_state->gru2_state, arch);
    output_index += DEC_GRU2_OUT_SIZE;
    init_flag = was_initialized;
    conv1_cond_init(dec_state->conv2_state, output_index, 1, &init_flag);
    compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += DEC_CONV2_OUT_SIZE;

    /* stage 3 */
    compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer, arch);
    compute_glu(&model->dec_glu3, &buffer[output_index], dec_state->gru3_state, arch);
    output_index += DEC_GRU3_OUT_SIZE;
    init_flag = was_initialized;
    conv1_cond_init(dec_state->conv3_state, output_index, 1, &init_flag);
    compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += DEC_CONV3_OUT_SIZE;

    /* stage 4 */
    compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer, arch);
    compute_glu(&model->dec_glu4, &buffer[output_index], dec_state->gru4_state, arch);
    output_index += DEC_GRU4_OUT_SIZE;
    init_flag = was_initialized;
    conv1_cond_init(dec_state->conv4_state, output_index, 1, &init_flag);
    compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += DEC_CONV4_OUT_SIZE;

    /* stage 5 */
    compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer, arch);
    compute_glu(&model->dec_glu5, &buffer[output_index], dec_state->gru5_state, arch);
    output_index += DEC_GRU5_OUT_SIZE;
    init_flag = was_initialized;
    conv1_cond_init(dec_state->conv5_state, output_index, 1, &init_flag);
    compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += DEC_CONV5_OUT_SIZE;

    /* All conv states have had their first-frame treatment. */
    dec_state->initialized = 1;

    compute_generic_dense(&model->dec_output, qframe, buffer, ACTIVATION_LINEAR, arch);
}
140