xref: /aosp_15_r20/external/libopus/dnn/training_tf2/dump_rdovae.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2022 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30
31import argparse
32from ftplib import parse150
33import os
34
# Hide every CUDA device so that TensorFlow (imported further down) runs on the
# CPU; the variable has to be in place before TensorFlow is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = ""

parser = argparse.ArgumentParser()

# Defaults are rendered into the help text via argparse's %(default)s
# placeholder so the documented value can never disagree with the real one.
parser.add_argument('weights', metavar="<weight file>", type=str, help='model weight file in hdf5 format')
parser.add_argument('--cond-size', type=int, help="conditioning size (default: %(default)s)", default=256)
parser.add_argument('--latent-dim', type=int, help="dimension of latent space (default: %(default)s)", default=80)
parser.add_argument('--quant-levels', type=int, help="number of quantization steps (default: %(default)s)", default=16)

# Parsed at import time, before the heavy imports below, so --help and usage
# errors return without loading TensorFlow.
args = parser.parse_args()
45
# Heavy dependencies are imported only after argument parsing, so `--help` and
# usage errors return without loading TensorFlow.
47import tensorflow as tf
48import numpy as np
49from keraslayerdump import dump_conv1d_layer, dump_dense_layer, dump_gru_layer, printVector
50from rdovae import new_rdovae_model
51
def start_header(header_fid, header_name):
    """Write the opening include guard of a generated C header.

    The guard macro is derived from the file name of ``header_name``:
    directory and extension are dropped, every character that is not an ASCII
    letter or digit is replaced by ``_``, the result is upper-cased and
    ``_H`` is appended, e.g. ``dred_rdovae_enc_data.h`` ->
    ``DRED_RDOVAE_ENC_DATA_H``.

    Args:
        header_fid: writable text stream that receives the header text.
        header_name: name (or path) of the header file being generated.
    """
    # splitext instead of slicing off two characters, so ".hpp" or
    # extension-less names do not leak a partial extension into the guard.
    stem = os.path.splitext(os.path.basename(header_name))[0]
    # A preprocessor macro must be a valid identifier: map "-", ".", etc. to "_".
    stem = "".join(c if (ord(c) < 128 and c.isalnum()) else "_" for c in stem)
    header_guard = stem.upper() + "_H"
    header_fid.write(
f"""
#ifndef {header_guard}
#define {header_guard}

"""
    )
61
def finish_header(header_fid):
    """Terminate a generated C header.

    Appends the ``#endif`` that closes the include guard written by
    start_header(), surrounded by blank lines.

    Args:
        header_fid: writable text stream of the header being generated.
    """
    trailer = "\n#endif\n\n"
    header_fid.write(trailer)
69
def start_source(source_fid, header_name, weight_file):
    """Write the preamble of a generated C source file.

    The preamble is a provenance comment naming ``weight_file``, the optional
    ``config.h`` include guarded by ``HAVE_CONFIG_H``, and an include of
    ``header_name``.

    Args:
        source_fid: writable text stream of the source being generated.
        header_name: header to ``#include`` from the generated source.
        weight_file: weight file name recorded in the provenance comment.
    """
    preamble = [
        "",
        f"/* this source file was automatically generated from weight file {weight_file} */",
        "",
        "#ifdef HAVE_CONFIG_H",
        '#include "config.h"',
        "#endif",
        "",
        f'#include "{header_name}"',
        "",
        "",
    ]
    source_fid.write("\n".join(preamble))
83
def finish_source(source_fid):
    """Finish a generated C source file.

    Nothing has to follow the generated tables, so this writes nothing and
    leaves ``source_fid`` untouched; it mirrors finish_header() so every
    generated file goes through matching start/finish calls.
    """
    return None
86
87
def _to_uint16_fixed_point(values, frac_bits, name):
    """Round ``values`` to unsigned fixed point with ``frac_bits`` fractional bits.

    Raises:
        ValueError: if any scaled value is non-finite or outside the uint16
            range; a plain ``astype(np.uint16)`` would silently emit garbage.
    """
    scaled = np.round(np.asarray(values) * 2**frac_bits)
    if (not np.all(np.isfinite(scaled))
            or np.any(scaled < 0)
            or np.any(scaled > np.iinfo(np.uint16).max)):
        raise ValueError(
            f"{name}: values do not fit into unsigned 16-bit Q{frac_bits} fixed point"
        )
    return scaled.astype(np.uint16)


def dump_statistical_model(qembedding, f, fh):
    """Dump the quantizer statistical model as opus_uint16 C tables.

    The first weight of ``qembedding`` is read as a ``(levels, dim)`` matrix
    whose columns are split into six groups of ``N = dim // 6`` columns.
    Groups 0, 1, 4 and 5 are used (groups 2 and 3 are not read here):

    * ``dred_quant_scales_q8`` = softplus(group 0), Q8
    * ``dred_dead_zone_q10``   = 0.05 * softplus(group 1), Q10
    * ``dred_r_q15``           = sigmoid(group 5), Q15
    * ``dred_p0_q15``          = 1 - r ** (0.5 + 0.5 * sigmoid(group 4)), Q15

    Args:
        qembedding: layer whose ``weights[0]`` holds the statistical model.
        f: writable stream of the generated C source (table definitions).
        fh: writable stream of the generated C header (extern declarations).

    Raises:
        ValueError: if a table does not fit into its uint16 fixed-point format.
    """
    w = qembedding.weights[0].numpy()
    levels, dim = w.shape
    N = dim // 6

    print("dumping statistical model")
    quant_scales = tf.math.softplus(w[:, :N]).numpy()
    dead_zone = 0.05 * tf.math.softplus(w[:, N:2 * N]).numpy()
    r = tf.math.sigmoid(w[:, 5 * N:6 * N]).numpy()
    # sigmoid-activated group 4; only used to derive p0 below
    p0_raw = tf.math.sigmoid(w[:, 4 * N:5 * N]).numpy()
    p0 = 1 - r ** (0.5 + 0.5 * p0_raw)

    tables = [
        ('dred_quant_scales_q8', quant_scales, 8),
        ('dred_dead_zone_q10', dead_zone, 10),
        ('dred_r_q15', r, 15),
        ('dred_p0_q15', p0, 15),
    ]
    # Quantize (and validate) everything before writing, so a range error
    # does not leave a half-written source file behind.
    quantized = [
        (name, _to_uint16_fixed_point(values, frac_bits, name))
        for name, values, frac_bits in tables
    ]

    for name, table in quantized:
        printVector(f, table, name, dtype='opus_uint16', static=False)

    declarations = "".join(
        f"extern const opus_uint16 {name}[{levels * N}];\n" for name, _ in quantized
    )
    fh.write(f"\n{declarations}\n")
119
if __name__ == "__main__":

    # Number of input features per frame. It is passed to the model constructor
    # and emitted as DRED_NUM_FEATURES, so both are kept in one place.
    num_features = 20
    weight_file = os.path.basename(args.weights)

    model, encoder, decoder, qembedding = new_rdovae_model(num_features, args.latent_dim, cond_size=args.cond_size, nb_quant=args.quant_levels)
    model.load_weights(args.weights)

    # encoder
    encoder_dense_names = [
        'enc_dense1',
        'enc_dense3',
        'enc_dense5',
        'enc_dense7',
        'enc_dense8',
        'gdense1',
        'gdense2'
    ]

    encoder_gru_names = [
        'enc_dense2',
        'enc_dense4',
        'enc_dense6'
    ]

    encoder_conv1d_names = [
        'bits_dense'
    ]

    # Every generated file is written inside a `with` block so it is flushed
    # and closed even if a layer dump raises part-way through.
    with open("dred_rdovae_enc_data.c", 'w') as source_fid, \
            open("dred_rdovae_enc_data.h", 'w') as header_fid:
        start_header(header_fid, "dred_rdovae_enc_data.h")
        start_source(source_fid, "dred_rdovae_enc_data.h", weight_file)

        header_fid.write(
"""
#include "dred_rdovae_constants.h"

#include "nnet.h"
"""
        )

        # dump GRUs; the largest value returned by the dump helper (presumably
        # the layer's neuron count) becomes DRED_ENC_MAX_RNN_NEURONS
        max_rnn_neurons_enc = max(
            dump_gru_layer(encoder.get_layer(name), source_fid, header_fid, dotp=True, sparse=True)
            for name in encoder_gru_names
        )

        # dump conv layers
        max_conv_inputs = max(
            dump_conv1d_layer(encoder.get_layer(name), source_fid, header_fid)
            for name in encoder_conv1d_names
        )

        # dump Dense layers
        for name in encoder_dense_names:
            dump_dense_layer(encoder.get_layer(name), source_fid, header_fid)

        # some global constants
        header_fid.write(
f"""

#define DRED_ENC_MAX_RNN_NEURONS {max_rnn_neurons_enc}

#define DRED_ENC_MAX_CONV_INPUTS {max_conv_inputs}

"""
        )

        finish_header(header_fid)
        finish_source(source_fid)

    # statistical model
    with open("dred_rdovae_stats_data.c", 'w') as source_fid, \
            open("dred_rdovae_stats_data.h", 'w') as header_fid:
        start_header(header_fid, "dred_rdovae_stats_data.h")
        start_source(source_fid, "dred_rdovae_stats_data.h", weight_file)

        header_fid.write(
"""

#include "opus_types.h"

"""
        )

        dump_statistical_model(qembedding, source_fid, header_fid)

        finish_header(header_fid)
        finish_source(source_fid)

    # decoder
    decoder_dense_names = [
        'state1',
        'state2',
        'state3',
        'dec_dense1',
        'dec_dense3',
        'dec_dense5',
        'dec_dense7',
        'dec_dense8',
        'dec_final'
    ]

    decoder_gru_names = [
        'dec_dense2',
        'dec_dense4',
        'dec_dense6'
    ]

    with open("dred_rdovae_dec_data.c", 'w') as source_fid, \
            open("dred_rdovae_dec_data.h", 'w') as header_fid:
        start_header(header_fid, "dred_rdovae_dec_data.h")
        start_source(source_fid, "dred_rdovae_dec_data.h", weight_file)

        header_fid.write(
"""
#include "dred_rdovae_constants.h"

#include "nnet.h"
"""
        )

        # dump GRUs
        max_rnn_neurons_dec = max(
            dump_gru_layer(decoder.get_layer(name), source_fid, header_fid, dotp=True, sparse=True)
            for name in decoder_gru_names
        )

        # dump Dense layers
        for name in decoder_dense_names:
            dump_dense_layer(decoder.get_layer(name), source_fid, header_fid)

        # some global constants
        header_fid.write(
f"""

#define DRED_DEC_MAX_RNN_NEURONS {max_rnn_neurons_dec}

"""
        )

        finish_header(header_fid)
        finish_source(source_fid)

    # common constants (previously this handle was never closed explicitly)
    # NOTE(review): DRED_STATE_DIM is hard-coded to 24 here; confirm it matches
    # the state size built by new_rdovae_model.
    with open("dred_rdovae_constants.h", 'w') as header_fid:
        start_header(header_fid, "dred_rdovae_constants.h")

        header_fid.write(
f"""
#define DRED_NUM_FEATURES {num_features}

#define DRED_LATENT_DIM {args.latent_dim}

#define DRED_STATE_DIM {24}

#define DRED_NUM_QUANTIZATION_LEVELS {qembedding.weights[0].shape[0]}

#define DRED_MAX_RNN_NEURONS {max(max_rnn_neurons_enc, max_rnn_neurons_dec)}

#define DRED_MAX_CONV_INPUTS {max_conv_inputs}
"""
        )

        finish_header(header_fid)