"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

# Export the weights of a trained RDO-VAE model (encoder, quantization
# embedding, decoder) into an "exchange" folder, one entry per layer, named
# according to the `exchange_name` table below.
#
# Usage: <script> <weight file> <output folder>
#            [--cond-size N] [--latent-dim N] [--quant-levels N]


import argparse
import os
import sys  # NOTE(review): unused in this script; kept in case of external reliance

# Hide all CUDA devices. This is set before the heavy imports below so it is
# already in effect when the model code is loaded.
os.environ['CUDA_VISIBLE_DEVICES'] = ""

parser = argparse.ArgumentParser()

parser.add_argument('weights', metavar="<weight file>", type=str, help='model weight file in hdf5 format')
parser.add_argument('output', metavar="<output folder>", type=str, help='output exchange folder')
parser.add_argument('--cond-size', type=int, help="conditioning size (default: 256)", default=256)
parser.add_argument('--latent-dim', type=int, help="dimension of latent space (default: 80)", default=80)
parser.add_argument('--quant-levels', type=int, help="number of quantization steps (default: 16)", default=16)

# Arguments are parsed before the heavy imports, so `--help` and usage errors
# exit without first loading the model/exchange modules.
args = parser.parse_args()

# now import the heavy stuff
from rdovae import new_rdovae_model
# NOTE(review): load_tf_weights is unused in this script.
from wexchange.tf import dump_tf_weights, load_tf_weights


# Maps the layer name inside the Keras sub-models (looked up via
# `get_layer`) to the file/entry name written into the output folder.
exchange_name = {
    'enc_dense1' : 'encoder_stack_layer1_dense',
    'enc_dense3' : 'encoder_stack_layer3_dense',
    'enc_dense5' : 'encoder_stack_layer5_dense',
    'enc_dense7' : 'encoder_stack_layer7_dense',
    'enc_dense8' : 'encoder_stack_layer8_dense',
    'gdense1'    : 'encoder_state_layer1_dense',
    'gdense2'    : 'encoder_state_layer2_dense',
    'enc_dense2' : 'encoder_stack_layer2_gru',
    'enc_dense4' : 'encoder_stack_layer4_gru',
    'enc_dense6' : 'encoder_stack_layer6_gru',
    'bits_dense' : 'encoder_stack_layer9_conv',
    'qembedding' : 'statistical_model_embedding',
    'state1'     : 'decoder_state1_dense',
    'state2'     : 'decoder_state2_dense',
    'state3'     : 'decoder_state3_dense',
    'dec_dense1' : 'decoder_stack_layer1_dense',
    'dec_dense3' : 'decoder_stack_layer3_dense',
    'dec_dense5' : 'decoder_stack_layer5_dense',
    'dec_dense7' : 'decoder_stack_layer7_dense',
    'dec_dense8' : 'decoder_stack_layer8_dense',
    'dec_final'  : 'decoder_stack_layer9_dense',
    'dec_dense2' : 'decoder_stack_layer2_gru',
    'dec_dense4' : 'decoder_stack_layer4_gru',
    'dec_dense6' : 'decoder_stack_layer6_gru'
}


def _dump_layers(model, names, output_dir):
    """Dump each named layer of ``model`` into ``output_dir``.

    For every ``name`` in ``names``, ``model.get_layer(name)`` is written with
    ``dump_tf_weights`` to ``os.path.join(output_dir, exchange_name[name])``.
    A progress line is printed before each layer is written.

    Raises:
        KeyError: if a name is missing from ``exchange_name``.
    """
    for name in names:
        target = exchange_name[name]
        print(f"writing layer {target}...")
        dump_tf_weights(os.path.join(output_dir, target), model.get_layer(name))


if __name__ == "__main__":

    # The literal 20 is the first positional argument of new_rdovae_model;
    # presumably the number of input features -- TODO confirm in rdovae.py.
    model, encoder, decoder, qembedding = new_rdovae_model(20, args.latent_dim, cond_size=args.cond_size, nb_quant=args.quant_levels)
    model.load_weights(args.weights)

    os.makedirs(args.output, exist_ok=True)

    # encoder -- names are grouped by the layer kind that their exchange
    # names carry as a suffix (_dense / _gru / _conv).
    encoder_dense_names = [
        'enc_dense1',
        'enc_dense3',
        'enc_dense5',
        'enc_dense7',
        'enc_dense8',
        'gdense1',
        'gdense2'
    ]

    encoder_gru_names = [
        'enc_dense2',
        'enc_dense4',
        'enc_dense6'
    ]

    encoder_conv1d_names = [
        'bits_dense'
    ]

    _dump_layers(encoder, encoder_dense_names + encoder_gru_names + encoder_conv1d_names, args.output)

    # qembedding -- returned directly by new_rdovae_model rather than looked
    # up by name, so it is dumped on its own.
    print(f"writing layer {exchange_name['qembedding']}...")
    dump_tf_weights(os.path.join(args.output, exchange_name['qembedding']), qembedding)

    # decoder
    decoder_dense_names = [
        'state1',
        'state2',
        'state3',
        'dec_dense1',
        'dec_dense3',
        'dec_dense5',
        'dec_dense7',
        'dec_dense8',
        'dec_final'
    ]

    decoder_gru_names = [
        'dec_dense2',
        'dec_dense4',
        'dec_dense6'
    ]

    _dump_layers(decoder, decoder_dense_names + decoder_gru_names, args.output)