"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""


import argparse
import os
import sys

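# hide all GPUs from TensorFlow before it is imported below; exporting weights
# is a lightweight, CPU-only task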
os.environ['CUDA_VISIBLE_DEVICES'] = ""

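# typical invocation (paths are illustrative only):
#   python rdovae_exchange.py rdovae_weights.hdf5 exchange_output_dir --latent-dim 80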
parser = argparse.ArgumentParser()

parser.add_argument('weights', metavar="<weight file>", type=str, help='model weight file in hdf5 format')
parser.add_argument('output', metavar="<output folder>", type=str, help='output exchange folder')
parser.add_argument('--cond-size', type=int, help="conditioning size (default: 256)", default=256)
parser.add_argument('--latent-dim', type=int, help="dimension of latent space (default: 80)", default=80)
parser.add_argument('--quant-levels', type=int, help="number of quantization steps (default: 16)", default=16)

args = parser.parse_args()

# import the heavy TensorFlow-dependent modules only after argument parsing
# (and after CUDA has been disabled above)
from rdovae import new_rdovae_model
from wexchange.tf import dump_tf_weights, load_tf_weights


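# maps Keras layer names in the rdovae model to the names used for the
# exported weights inside the exchange folder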
exchange_name = {
    'enc_dense1'    : 'encoder_stack_layer1_dense',
    'enc_dense3'    : 'encoder_stack_layer3_dense',
    'enc_dense5'    : 'encoder_stack_layer5_dense',
    'enc_dense7'    : 'encoder_stack_layer7_dense',
    'enc_dense8'    : 'encoder_stack_layer8_dense',
    'gdense1'       : 'encoder_state_layer1_dense',
    'gdense2'       : 'encoder_state_layer2_dense',
    'enc_dense2'    : 'encoder_stack_layer2_gru',
    'enc_dense4'    : 'encoder_stack_layer4_gru',
    'enc_dense6'    : 'encoder_stack_layer6_gru',
    'bits_dense'    : 'encoder_stack_layer9_conv',
    'qembedding'    : 'statistical_model_embedding',
    'state1'        : 'decoder_state1_dense',
    'state2'        : 'decoder_state2_dense',
    'state3'        : 'decoder_state3_dense',
    'dec_dense1'    : 'decoder_stack_layer1_dense',
    'dec_dense3'    : 'decoder_stack_layer3_dense',
    'dec_dense5'    : 'decoder_stack_layer5_dense',
    'dec_dense7'    : 'decoder_stack_layer7_dense',
    'dec_dense8'    : 'decoder_stack_layer8_dense',
    'dec_final'     : 'decoder_stack_layer9_dense',
    'dec_dense2'    : 'decoder_stack_layer2_gru',
    'dec_dense4'    : 'decoder_stack_layer4_gru',
    'dec_dense6'    : 'decoder_stack_layer6_gru'
}


if __name__ == "__main__":

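    # note: the hard-coded 20 is presumably the per-frame feature count the
    # model was trained with; the remaining hyper-parameters come from the CLI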
    model, encoder, decoder, qembedding = new_rdovae_model(20, args.latent_dim, cond_size=args.cond_size, nb_quant=args.quant_levels)
    model.load_weights(args.weights)

    os.makedirs(args.output, exist_ok=True)

    # encoder
    encoder_dense_names = [
        'enc_dense1',
        'enc_dense3',
        'enc_dense5',
        'enc_dense7',
        'enc_dense8',
        'gdense1',
        'gdense2'
    ]

    encoder_gru_names = [
        'enc_dense2',
        'enc_dense4',
        'enc_dense6'
    ]

    encoder_conv1d_names = [
        'bits_dense'
    ]


    for name in encoder_dense_names + encoder_gru_names + encoder_conv1d_names:
        print(f"writing layer {exchange_name[name]}...")
        dump_tf_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name))

    # qembedding
    print(f"writing layer {exchange_name['qembedding']}...")
    dump_tf_weights(os.path.join(args.output, exchange_name['qembedding']), qembedding)

    # decoder
    decoder_dense_names = [
        'state1',
        'state2',
        'state3',
        'dec_dense1',
        'dec_dense3',
        'dec_dense5',
        'dec_dense7',
        'dec_dense8',
        'dec_final'
    ]

    decoder_gru_names = [
        'dec_dense2',
        'dec_dense4',
        'dec_dense6'
    ]

    for name in decoder_dense_names + decoder_gru_names:
        print(f"writing layer {exchange_name[name]}...")
        dump_tf_weights(os.path.join(args.output, exchange_name[name]), decoder.get_layer(name))
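
    # optional sanity check (a sketch, not part of the original export flow):
    # this assumes dump_tf_weights creates a file or directory at exactly the
    # path it is given, so every value in exchange_name should now exist under
    # the output folder
    missing = [v for v in exchange_name.values()
               if not os.path.exists(os.path.join(args.output, v))]
    if missing:
        print("warning: missing exchange entries: " + ", ".join(missing))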