#!/usr/bin/python3
'''Copyright (c) 2021-2022 Amazon
   Copyright (c) 2018-2019 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

# Train a neural PLC (packet loss concealment) model

import argparse
from plc_loader import PLCLoader

parser = argparse.ArgumentParser(description='Train a PLC model')

parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
parser.add_argument('lost_file', metavar='<packet loss file>', help='packet loss traces (int8)')
parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
parser.add_argument('--model', metavar='<model>', default='lpcnet_plc', help='PLC model python definition (without .py)')
group1 = parser.add_mutually_exclusive_group()
group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
parser.add_argument('--gru-size', metavar='<units>', default=256, type=int, help='number of units in GRU (default 256)')
parser.add_argument('--cond-size', metavar='<units>', default=128, type=int, help='number of units in conditioning network (default 128)')
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
parser.add_argument('--band-loss', metavar='<weight>', default=1.0, type=float, help='weight of band loss (default 1.0)')
parser.add_argument('--loss-bias', metavar='<bias>', default=0.0, type=float, help='loss bias towards low energy (default 0.0)')
parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')


args = parser.parse_args()
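
# Example invocation (hypothetical file names, flags as defined above):
#   ./train_plc.py features.f32 loss_trace.s8 plc_model --gru-size 256 --epochs 120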

import importlib
lpcnet = importlib.import_module(args.model)

import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
import tensorflow.keras.backend as K
import h5py

import tensorflow as tf
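# Uncomment the block below to cap TF's GPU memory use (memory_limit is in MB):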
#gpus = tf.config.experimental.list_physical_devices('GPU')
#if gpus:
#  try:
#    tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
#  except RuntimeError as e:
#    print(e)

nb_epochs = args.epochs

# Try reducing batch_size if you run out of memory on your GPU
batch_size = args.batch_size

quantize = args.quantize is not None
retrain = args.retrain is not None

if quantize:
    lr = 0.00003
    decay = 0
    input_model = args.quantize
else:
    lr = 0.001
    decay = 2.5e-5

if args.lr is not None:
    lr = args.lr

if args.decay is not None:
    decay = args.decay

if retrain:
    input_model = args.retrain

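# NOTE (assumed layout): each y_true frame fed to the losses below is the
# feature vector, with pitch and pitch correlation as the last two entries,
# plus one extra trailing channel holding the loss mask.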
def plc_loss(alpha=1.0, bias=0.):
    def loss(y_true,y_pred):
        # The last channel of y_true is a mask that zeroes the error on
        # frames where the model is not expected to predict
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        # Cepstral error (all but the last two features) mapped back to
        # spectral bands with an inverse DCT
        e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
        bias_mask = K.minimum(1., K.maximum(0., 4*y_true[:,:,-1:]))
        l1_loss = (K.mean(K.abs(e))
                   + 0.1*K.mean(K.maximum(0., -e[:,:,-1:]))
                   + alpha*K.mean(K.abs(e_bands) + bias*bias_mask*K.maximum(0., e_bands))
                   + K.mean(K.minimum(K.abs(e[:,:,18:19]),1.))
                   + 8*K.mean(K.minimum(K.abs(e[:,:,18:19]),.4)))
        return l1_loss
    return loss

def plc_l1_loss():
    def L1_loss(y_true,y_pred):
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        l1_loss = K.mean(K.abs(e))
        return l1_loss
    return L1_loss

def plc_ceps_loss():
    def ceps_loss(y_true,y_pred):
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        l1_loss = K.mean(K.abs(e[:,:,:-2]))
        return l1_loss
    return ceps_loss

def plc_band_loss():
    def L1_band_loss(y_true,y_pred):
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
        l1_loss = K.mean(K.abs(e_bands))
        return l1_loss
    return L1_band_loss

def plc_pitch_loss():
    def pitch_loss(y_true,y_pred):
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        l1_loss = K.mean(K.minimum(K.abs(e[:,:,18:19]),.4))
        return l1_loss
    return pitch_loss

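# Quick eager-mode sanity check for the loss slicing (hypothetical shapes,
# assuming 20 output features plus the trailing mask channel); uncomment to run:
#y_t = tf.random.uniform((2, 100, 21))
#y_p = tf.random.uniform((2, 100, 20))
#print(plc_loss()(y_t, y_p).numpy(), plc_band_loss()(y_t, y_p).numpy())
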
opt = Adam(lr, decay=decay, beta_2=0.99)
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = lpcnet.new_lpcnet_plc_model(rnn_units=args.gru_size, batch_size=batch_size, training=True, quantize=quantize, cond_size=args.cond_size)
    model.compile(optimizer=opt, loss=plc_loss(alpha=args.band_loss, bias=args.loss_bias), metrics=[plc_l1_loss(), plc_ceps_loss(), plc_band_loss(), plc_pitch_loss()])
    model.summary()

lpc_order = 16

feature_file = args.features
# Each frame stores the used features, the Burg features and lpc_order LPC
# coefficients; the LPC coefficients are dropped after reshaping below
nb_features = model.nb_used_features + lpc_order + model.nb_burg_features
nb_used_features = model.nb_used_features
nb_burg_features = model.nb_burg_features
sequence_size = args.seq_length

# Load the float32 features and reshape them into fixed-length sequences

features = np.memmap(feature_file, dtype='float32', mode='r')
# Round down to a whole number of batches of full sequences
nb_sequences = len(features)//(nb_features*sequence_size)//batch_size*batch_size
features = features[:nb_sequences*sequence_size*nb_features]

features = np.reshape(features, (nb_sequences, sequence_size, nb_features))

# Keep the used features and the Burg features; drop the LPC coefficients
features = features[:, :, :nb_used_features+model.nb_burg_features]

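# The trace file is a flat array of int8 flags marking per-packet loss;
# PLCLoader (assumed behaviour) combines it with the features so that lost
# frames are simulated during training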
lost = np.memmap(args.lost_file, dtype='int8', mode='r')

# dump models to disk as we go
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.gru_size, '{epoch:02d}'))
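# e.g. '<output>_256_01.h5' after the first epoch with the default GRU size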

if quantize or retrain:
    # Adapting from an existing model: start from its weights
    # (input_model is args.quantize or args.retrain, respectively)
    model.load_weights(input_model)

model.save_weights('{}_{}_initial.h5'.format(args.output, args.gru_size))

loader = PLCLoader(features, lost, nb_burg_features, batch_size)

callbacks = [checkpoint]
if args.logdir is not None:
    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.gru_size)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
    callbacks.append(tensorboard_callback)

model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)