#!/usr/bin/python3
'''Copyright (c) 2018 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

# Train an LPCNet model

import argparse
import os

from dataloader import LPCNetLoader

parser = argparse.ArgumentParser(description='Train an LPCNet model')

parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
parser.add_argument('data', metavar='<audio data file>', help='binary audio data file (int16)')
parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
parser.add_argument('--model', metavar='<model>', default='lpcnet', help='LPCNet model python definition (without .py)')
group1 = parser.add_mutually_exclusive_group()
group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
parser.add_argument('--density', metavar='<global density>', type=float, help='average density of the recurrent weights (default 0.1)')
parser.add_argument('--density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each recurrent gate (default 0.05, 0.05, 0.2)')
parser.add_argument('--grub-density', metavar='<global GRU B density>', type=float, help='average density of the GRU B input weights (default 1.0)')
parser.add_argument('--grub-density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each GRU B input gate (default 1.0, 1.0, 1.0)')
parser.add_argument('--grua-size', metavar='<units>', default=384, type=int, help='number of units in GRU A (default 384)')
parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help='number of units in GRU B (default 16)')
parser.add_argument('--cond-size', metavar='<units>', default=128, type=int, help='number of units in conditioning network, aka frame rate network (default 128)')
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='enable end-to-end training (with differentiable LPC computation)')
parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
parser.add_argument('--gamma', metavar='<gamma>', type=float, help='adjust u-law compensation (default 2.0, should not be less than 1.0)')
parser.add_argument('--lookahead', metavar='<nb frames>', default=2, type=int, help='number of look-ahead frames (default 2)')
parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')
parser.add_argument('--lpc-gamma', type=float, default=1, help='gamma for LPC weighting')
parser.add_argument('--cuda-devices', metavar='<cuda devices>', type=str, default=None, help='string with comma separated cuda device ids')

args = parser.parse_args()
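# Example invocations (a sketch only; the file names and the checkpoint
# passed to --quantize are illustrative, not produced by this comment):
#   ./train_lpcnet.py features.f32 data.s16 mymodel
#   ./train_lpcnet.py --end2end features.f32 data.s16 mymodel_e2e
#   ./train_lpcnet.py --quantize mymodel_384_120.h5 features.f32 data.s16 mymodel_q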
# set visible cuda devices
if args.cuda_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_devices

# Split a global density across the update/reset/state gates so that the
# average matches the requested value.
density = (0.05, 0.05, 0.2)
if args.density_split is not None:
    density = args.density_split
elif args.density is not None:
    density = [0.5*args.density, 0.5*args.density, 2.0*args.density]

grub_density = (1., 1., 1.)
if args.grub_density_split is not None:
    grub_density = args.grub_density_split
elif args.grub_density is not None:
    grub_density = [0.5*args.grub_density, 0.5*args.grub_density, 2.0*args.grub_density]

gamma = 2.0 if args.gamma is None else args.gamma

import importlib
lpcnet = importlib.import_module(args.model)

import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from ulaw import ulaw2lin, lin2ulaw
import tensorflow.keras.backend as K
import h5py

import tensorflow as tf
from tf_funcs import *
from lossfuncs import *
#gpus = tf.config.experimental.list_physical_devices('GPU')
#if gpus:
#    try:
#        tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
#    except RuntimeError as e:
#        print(e)

nb_epochs = args.epochs

# Try reducing batch_size if you run out of memory on your GPU
batch_size = args.batch_size

quantize = args.quantize is not None
retrain = args.retrain is not None

lpc_order = 16

# Quantization fine-tuning uses a much smaller learning rate and no decay.
if quantize:
    lr = 0.00003
    decay = 0
    input_model = args.quantize
else:
    lr = 0.001
    decay = 5e-5

if args.lr is not None:
    lr = args.lr

if args.decay is not None:
    decay = args.decay

if retrain:
    input_model = args.retrain

flag_e2e = args.flag_e2e
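# The legacy Keras 'decay' argument passed to Adam below applies an
# inverse-time schedule, roughly lr_t = lr / (1 + decay*t) with t the update
# step. A quick sanity check, assuming the default lr=0.001 and decay=5e-5:
#   t = 0      -> 0.001
#   t = 20000  -> 0.001/2 = 0.0005
#   t = 100000 -> 0.001/6 ~= 0.000167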
opt = Adam(lr, decay=decay, beta_1=0.5, beta_2=0.8)
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size,
                                          rnn_units2=args.grub_size,
                                          batch_size=batch_size, training=True,
                                          quantize=quantize,
                                          flag_e2e=flag_e2e,
                                          cond_size=args.cond_size,
                                          lpc_gamma=args.lpc_gamma,
                                          lookahead=args.lookahead)
    if not flag_e2e:
        model.compile(optimizer=opt, loss=metric_cel, metrics=metric_cel)
    else:
        model.compile(optimizer=opt,
                      loss=[interp_mulaw(gamma=gamma), loss_matchlar()],
                      loss_weights=[1.0, 2.0],
                      metrics={'pdf': [metric_cel, metric_icel, metric_exc_sd, metric_oginterploss]})
    model.summary()

feature_file = args.features
pcm_file = args.data     # 16-bit signed PCM samples
frame_size = model.frame_size
nb_features = model.nb_used_features + lpc_order
nb_used_features = model.nb_used_features
feature_chunk_size = 15
pcm_chunk_size = frame_size*feature_chunk_size

# load 16-bit PCM samples and convert to mu-law

data = np.memmap(pcm_file, dtype='int16', mode='r')
# The data file stores two int16 values per sample, hence the factor of 2;
# round the frame count down to a whole number of batches.
nb_frames = (len(data)//(2*pcm_chunk_size)-1)//batch_size*batch_size

features = np.memmap(feature_file, dtype='float32', mode='r')

# align audio with the features given the lookahead, then limit to a
# discrete number of frames
data = data[(4-args.lookahead)*2*frame_size:]
data = data[:nb_frames*2*pcm_chunk_size]

data = np.reshape(data, (nb_frames, pcm_chunk_size, 2))

# Overlap consecutive feature chunks by 4 frames using a stride trick
# instead of copying the memmapped array.
sizeof = features.strides[-1]
features = np.lib.stride_tricks.as_strided(features, shape=(nb_frames, feature_chunk_size+4, nb_features),
                                           strides=(feature_chunk_size*nb_features*sizeof, nb_features*sizeof, sizeof))
#features = features[:, :, :nb_used_features]

periods = (.1 + 50*features[:, :, nb_used_features-2:nb_used_features-1]+100).astype('int16')
#periods = np.minimum(periods, 255)

# dump models to disk as we go
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.grua_size, '{epoch:02d}'))

if quantize or retrain:
    # Adapting from an existing model (covers both --quantize and --retrain)
    model.load_weights(input_model)
    if quantize:
        sparsify = lpcnet.Sparsify(10000, 30000, 100, density, quantize=True)
        grub_sparsify = lpcnet.SparsifyGRUB(10000, 30000, 100, args.grua_size, grub_density, quantize=True)
    else:
        sparsify = lpcnet.Sparsify(0, 0, 1, density)
        grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
else:
    # Training from scratch
    sparsify = lpcnet.Sparsify(2000, 20000, 400, density)
    grub_sparsify = lpcnet.SparsifyGRUB(2000, 40000, 400, args.grua_size, grub_density)

model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))

loader = LPCNetLoader(data, features, periods, batch_size, e2e=flag_e2e, lookahead=args.lookahead)

callbacks = [checkpoint, sparsify, grub_sparsify]
if args.logdir is not None:
    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.grua_size)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
    callbacks.append(tensorboard_callback)

model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
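# The run leaves behind '<output>_<grua-size>_initial.h5' plus one
# '<output>_<grua-size>_<epoch>.h5' checkpoint per epoch (see the
# ModelCheckpoint pattern above). A minimal sketch for loading a checkpoint
# back afterwards, assuming the same lpcnet module and the same sizes that
# were used for training:
#   model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=384, rnn_units2=16,
#                                         batch_size=1, training=False)
#   model.load_weights('mymodel_384_120.h5')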