#!/usr/bin/python3
'''Copyright (c) 2021-2022 Amazon
   Copyright (c) 2018-2019 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

# Train a PLC (packet loss concealment) model

import argparse
from plc_loader import PLCLoader

parser = argparse.ArgumentParser(description='Train a PLC model')

parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
parser.add_argument('lost_file', metavar='<packet loss file>', help='packet loss traces (int8)')
parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
parser.add_argument('--model', metavar='<model>', default='lpcnet_plc', help='PLC model python definition (without .py)')
group1 = parser.add_mutually_exclusive_group()
group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
parser.add_argument('--gru-size', metavar='<units>', default=256, type=int, help='number of units in GRU (default 256)')
parser.add_argument('--cond-size', metavar='<units>', default=128, type=int, help='number of units in conditioning network (default 128)')
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
parser.add_argument('--band-loss', metavar='<weight>', default=1.0, type=float, help='weight of band loss (default 1.0)')
parser.add_argument('--loss-bias', metavar='<bias>', default=0.0, type=float, help='loss bias towards low energy (default 0.0)')
parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')


args = parser.parse_args()

# Import the model definition module selected with --model
import importlib
lpcnet = importlib.import_module(args.model)

import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
import tensorflow.keras.backend as K
import h5py

import tensorflow as tf
# Optionally cap GPU memory usage:
#gpus = tf.config.experimental.list_physical_devices('GPU')
#if gpus:
#    try:
#        tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
#    except RuntimeError as e:
#        print(e)

nb_epochs = args.epochs

# Try reducing batch_size if you run out of memory on your GPU
batch_size = args.batch_size
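
# Note (illustrative, not from the original script): the loss-trace file is
# memory-mapped below as one int8 flag per frame. For a quick smoke test, a
# synthetic trace with roughly 10% random loss can be generated as shown here;
# the exact 0/1 convention is defined by plc_loader.py:
#
#   import numpy as np
#   (np.random.rand(10**7) > 0.1).astype('int8').tofile('loss.s8')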
quantize = args.quantize is not None
retrain = args.retrain is not None

if quantize:
    lr = 0.00003
    decay = 0
    input_model = args.quantize
else:
    lr = 0.001
    decay = 2.5e-5

if args.lr is not None:
    lr = args.lr

if args.decay is not None:
    decay = args.decay

if retrain:
    input_model = args.retrain

# Training loss: masked L1 over all features, a band-domain term obtained by
# an inverse DCT of the cepstral part of the error, and clamped terms that
# emphasize feature column 18 (pitch). The last channel of y_true is a
# per-frame 0/1 validity mask.
def plc_loss(alpha=1.0, bias=0.):
    def loss(y_true, y_pred):
        mask = y_true[:, :, -1:]
        y_true = y_true[:, :, :-1]
        e = (y_pred - y_true)*mask
        e_bands = tf.signal.idct(e[:, :, :-2], norm='ortho')
        bias_mask = K.minimum(1., K.maximum(0., 4*y_true[:, :, -1:]))
        l1_loss = K.mean(K.abs(e)) + 0.1*K.mean(K.maximum(0., -e[:, :, -1:])) + alpha*K.mean(K.abs(e_bands) + bias*bias_mask*K.maximum(0., e_bands)) + K.mean(K.minimum(K.abs(e[:, :, 18:19]), 1.)) + 8*K.mean(K.minimum(K.abs(e[:, :, 18:19]), .4))
        return l1_loss
    return loss

# The losses below are used as metrics only; each reports one component of
# the combined loss above.
def plc_l1_loss():
    def L1_loss(y_true, y_pred):
        mask = y_true[:, :, -1:]
        y_true = y_true[:, :, :-1]
        e = (y_pred - y_true)*mask
        l1_loss = K.mean(K.abs(e))
        return l1_loss
    return L1_loss

def plc_ceps_loss():
    def ceps_loss(y_true, y_pred):
        mask = y_true[:, :, -1:]
        y_true = y_true[:, :, :-1]
        e = (y_pred - y_true)*mask
        l1_loss = K.mean(K.abs(e[:, :, :-2]))
        return l1_loss
    return ceps_loss

def plc_band_loss():
    def L1_band_loss(y_true, y_pred):
        mask = y_true[:, :, -1:]
        y_true = y_true[:, :, :-1]
        e = (y_pred - y_true)*mask
        e_bands = tf.signal.idct(e[:, :, :-2], norm='ortho')
        l1_loss = K.mean(K.abs(e_bands))
        return l1_loss
    return L1_band_loss

def plc_pitch_loss():
    def pitch_loss(y_true, y_pred):
        mask = y_true[:, :, -1:]
        y_true = y_true[:, :, :-1]
        e = (y_pred - y_true)*mask
        l1_loss = K.mean(K.minimum(K.abs(e[:, :, 18:19]), .4))
        return l1_loss
    return pitch_loss

opt = Adam(lr, decay=decay, beta_2=0.99)
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = lpcnet.new_lpcnet_plc_model(rnn_units=args.gru_size, batch_size=batch_size, training=True, quantize=quantize, cond_size=args.cond_size)
    model.compile(optimizer=opt, loss=plc_loss(alpha=args.band_loss, bias=args.loss_bias), metrics=[plc_l1_loss(), plc_ceps_loss(), plc_band_loss(), plc_pitch_loss()])
    model.summary()
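
# Illustrative sanity check (not from the original script): the losses above
# treat the last channel of y_true as a 0/1 validity mask, so a fully masked
# frame contributes nothing. Uncomment to verify (F is a hypothetical feature
# count):
#
#   F = 20
#   dummy_pred = tf.zeros((1, 1, F))
#   dummy_true = tf.concat([tf.ones((1, 1, F)), tf.zeros((1, 1, 1))], axis=-1)
#   assert float(plc_l1_loss()(dummy_true, dummy_pred)) == 0.0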

lpc_order = 16

feature_file = args.features
nb_features = model.nb_used_features + lpc_order + model.nb_burg_features
nb_used_features = model.nb_used_features
nb_burg_features = model.nb_burg_features
sequence_size = args.seq_length

# Memory-map the feature file, keep a whole number of batches, and reshape
# into fixed-length sequences
features = np.memmap(feature_file, dtype='float32', mode='r')
nb_sequences = len(features)//(nb_features*sequence_size)//batch_size*batch_size
features = features[:nb_sequences*sequence_size*nb_features]

features = np.reshape(features, (nb_sequences, sequence_size, nb_features))

# Drop the trailing lpc_order columns (LPC coefficients), which the PLC model
# does not use
features = features[:, :, :nb_used_features+model.nb_burg_features]

lost = np.memmap(args.lost_file, dtype='int8', mode='r')

# dump models to disk as we go
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.gru_size, '{epoch:02d}'))

if quantize or retrain:
    # Adapting from an existing model (input_model is set from --quantize or
    # --retrain above)
    model.load_weights(input_model)

model.save_weights('{}_{}_initial.h5'.format(args.output, args.gru_size))

loader = PLCLoader(features, lost, nb_burg_features, batch_size)

callbacks = [checkpoint]
if args.logdir is not None:
    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.gru_size)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
    callbacks.append(tensorboard_callback)

model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
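
# Example invocation (script and file names are illustrative):
#   python3 train_plc.py features.f32 loss.s8 plc_model \
#       --gru-size 256 --batch-size 128 --epochs 120 --logdir logs
# Checkpoints are written after every epoch as <output>_<gru-size>_<epoch>.h5,
# e.g. plc_model_256_01.h5.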