"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""


import argparse
import os

os.environ['CUDA_VISIBLE_DEVICES'] = ""  # hide GPUs so TensorFlow runs on the CPU for the dump

parser = argparse.ArgumentParser()

parser.add_argument('weights', metavar="<weight file>", type=str, help='model weight file in hdf5 format')
parser.add_argument('--cond-size', type=int, help="conditioning size (default: 256)", default=256)
parser.add_argument('--latent-dim', type=int, help="dimension of latent space (default: 80)", default=80)
parser.add_argument('--quant-levels', type=int, help="number of quantization steps (default: 16)", default=16)

args = parser.parse_args()
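
# Example invocation (the weight file name is illustrative):
#   python dump_rdovae.py rdovae_weights.h5 --cond-size 256 --latent-dim 80 --quant-levels 16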

# now import the heavy stuff
import tensorflow as tf
import numpy as np
from keraslayerdump import dump_conv1d_layer, dump_dense_layer, dump_gru_layer, printVector
from rdovae import new_rdovae_model

def start_header(header_fid, header_name):
    header_guard = os.path.basename(header_name)[:-2].upper() + "_H"
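    # e.g. "dred_rdovae_enc_data.h" -> include guard "DRED_RDOVAE_ENC_DATA_H"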
    header_fid.write(
f"""
#ifndef {header_guard}
#define {header_guard}

"""
    )

def finish_header(header_fid):
    header_fid.write(
"""
#endif

"""
    )

def start_source(source_fid, header_name, weight_file):
    source_fid.write(
f"""
/* this source file was automatically generated from weight file {weight_file} */

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include "{header_name}"

"""
    )

def finish_source(source_fid):
    pass


def dump_statistical_model(qembedding, f, fh):
    w = qembedding.weights[0].numpy()
    levels, dim = w.shape
    N = dim // 6  # the embedding packs six parameter vectors of size N per quantization level

    print("dumping statistical model")
    quant_scales    = tf.math.softplus(w[:, : N]).numpy()
    dead_zone       = 0.05 * tf.math.softplus(w[:, N : 2 * N]).numpy()
    r               = tf.math.sigmoid(w[:, 5 * N : 6 * N]).numpy()
    p0              = tf.math.sigmoid(w[:, 4 * N : 5 * N]).numpy()
    p0              = 1 - r ** (0.5 + 0.5 * p0)

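    # convert to unsigned fixed point for the C tables: Q8 for the quantizer
    # scales, Q10 for the dead zone and Q15 for r and p0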
    quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
    dead_zone_q10   = np.round(dead_zone * 2**10).astype(np.uint16)
    r_q15           = np.round(r * 2**15).astype(np.uint16)
    p0_q15          = np.round(p0 * 2**15).astype(np.uint16)

    printVector(f, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
    printVector(f, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
    printVector(f, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False)
    printVector(f, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False)

    fh.write(
f"""
extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
extern const opus_uint16 dred_r_q15[{levels * N}];
extern const opus_uint16 dred_p0_q15[{levels * N}];

"""
    )

if __name__ == "__main__":

    model, encoder, decoder, qembedding = new_rdovae_model(20, args.latent_dim, cond_size=args.cond_size, nb_quant=args.quant_levels)
    model.load_weights(args.weights)
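    # new_rdovae_model is called with 20 features per frame; this matches the
    # DRED_NUM_FEATURES constant written to dred_rdovae_constants.h below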
    # encoder
    encoder_dense_names = [
        'enc_dense1',
        'enc_dense3',
        'enc_dense5',
        'enc_dense7',
        'enc_dense8',
        'gdense1',
        'gdense2'
    ]

    encoder_gru_names = [
        'enc_dense2',
        'enc_dense4',
        'enc_dense6'
    ]

    encoder_conv1d_names = [
        'bits_dense'
    ]
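
    # the names above must match the layer names assigned in new_rdovae_model
    # (rdovae.py); get_layer() raises a ValueError for any name that is missing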

    source_fid = open("dred_rdovae_enc_data.c", 'w')
    header_fid = open("dred_rdovae_enc_data.h", 'w')

    start_header(header_fid, "dred_rdovae_enc_data.h")
    start_source(source_fid, "dred_rdovae_enc_data.h", os.path.basename(args.weights))

    header_fid.write(
f"""
#include "dred_rdovae_constants.h"

#include "nnet.h"
"""
    )

    # dump GRUs
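    # dump_gru_layer (keraslayerdump) writes the GRU weights in sparse,
    # dot-product form and returns the layer's neuron count; the maximum over
    # the encoder GRUs becomes DRED_ENC_MAX_RNN_NEURONS below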
    max_rnn_neurons_enc = max(
        [
            dump_gru_layer(encoder.get_layer(name), source_fid, header_fid, dotp=True, sparse=True)
            for name in encoder_gru_names
        ]
    )

    # dump conv layers
    max_conv_inputs = max(
        [
            dump_conv1d_layer(encoder.get_layer(name), source_fid, header_fid)
            for name in encoder_conv1d_names
        ]
    )

    # dump Dense layers
    for name in encoder_dense_names:
        layer = encoder.get_layer(name)
        dump_dense_layer(layer, source_fid, header_fid)

    # some global constants
    header_fid.write(
f"""

#define DRED_ENC_MAX_RNN_NEURONS {max_rnn_neurons_enc}

#define DRED_ENC_MAX_CONV_INPUTS {max_conv_inputs}

"""
    )

    finish_header(header_fid)
    finish_source(source_fid)

    header_fid.close()
    source_fid.close()

    # statistical model
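    # the quantizer embedding is exported as four opus_uint16 tables (quant
    # scales, dead zone, r, p0); these presumably parameterize the per-level
    # latent quantizer and its entropy model on the C side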
    source_fid = open("dred_rdovae_stats_data.c", 'w')
    header_fid = open("dred_rdovae_stats_data.h", 'w')

    start_header(header_fid, "dred_rdovae_stats_data.h")
    start_source(source_fid, "dred_rdovae_stats_data.h", os.path.basename(args.weights))

    header_fid.write(
"""

#include "opus_types.h"

"""
    )

    dump_statistical_model(qembedding, source_fid, header_fid)

    finish_header(header_fid)
    finish_source(source_fid)

    header_fid.close()
    source_fid.close()

    # decoder
    decoder_dense_names = [
        'state1',
        'state2',
        'state3',
        'dec_dense1',
        'dec_dense3',
        'dec_dense5',
        'dec_dense7',
        'dec_dense8',
        'dec_final'
    ]

    decoder_gru_names = [
        'dec_dense2',
        'dec_dense4',
        'dec_dense6'
    ]

    source_fid = open("dred_rdovae_dec_data.c", 'w')
    header_fid = open("dred_rdovae_dec_data.h", 'w')

    start_header(header_fid, "dred_rdovae_dec_data.h")
    start_source(source_fid, "dred_rdovae_dec_data.h", os.path.basename(args.weights))

    header_fid.write(
f"""
#include "dred_rdovae_constants.h"

#include "nnet.h"
"""
    )


    # dump GRUs
    max_rnn_neurons_dec = max(
        [
            dump_gru_layer(decoder.get_layer(name), source_fid, header_fid, dotp=True, sparse=True)
            for name in decoder_gru_names
        ]
    )

    # dump Dense layers
    for name in decoder_dense_names:
        layer = decoder.get_layer(name)
        dump_dense_layer(layer, source_fid, header_fid)

    # some global constants
    header_fid.write(
f"""

#define DRED_DEC_MAX_RNN_NEURONS {max_rnn_neurons_dec}

"""
    )

    finish_header(header_fid)
    finish_source(source_fid)

    header_fid.close()
    source_fid.close()

    # common constants
    header_fid = open("dred_rdovae_constants.h", 'w')
    start_header(header_fid, "dred_rdovae_constants.h")

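    # note: DRED_STATE_DIM is hard-coded to 24 below rather than read from the
    # model; presumably it has to match the decoder's initial-state size in rdovae.py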
    header_fid.write(
f"""
#define DRED_NUM_FEATURES 20

#define DRED_LATENT_DIM {args.latent_dim}

#define DRED_STATE_DIM 24

#define DRED_NUM_QUANTIZATION_LEVELS {qembedding.weights[0].shape[0]}

#define DRED_MAX_RNN_NEURONS {max(max_rnn_neurons_enc, max_rnn_neurons_dec)}

#define DRED_MAX_CONV_INPUTS {max_conv_inputs}
"""
    )

    finish_header(header_fid)
    header_fid.close()