xref: /aosp_15_r20/external/libopus/dnn/training_tf2/train_lpcnet.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1#!/usr/bin/python3
2'''Copyright (c) 2018 Mozilla
3
4   Redistribution and use in source and binary forms, with or without
5   modification, are permitted provided that the following conditions
6   are met:
7
8   - Redistributions of source code must retain the above copyright
9   notice, this list of conditions and the following disclaimer.
10
11   - Redistributions in binary form must reproduce the above copyright
12   notice, this list of conditions and the following disclaimer in the
13   documentation and/or other materials provided with the distribution.
14
15   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26'''
27
28# Train an LPCNet model
29
30import argparse
31import os
32
33from dataloader import LPCNetLoader
34
parser = argparse.ArgumentParser(description='Train an LPCNet model')

# Positional inputs and the output name prefix.
parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
parser.add_argument('data', metavar='<audio data file>', help='binary audio data file (int16)')
parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
parser.add_argument('--model', metavar='<model>', default='lpcnet', help='LPCNet model python definition (without .py)')
# Both options start from an existing weights file; only one may be given.
group1 = parser.add_mutually_exclusive_group()
group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
# Sparsity targets. When a *-split option is given it takes precedence over the
# corresponding global density.
parser.add_argument('--density', metavar='<global density>', type=float, help='average density of the recurrent weights (default 0.1)')
parser.add_argument('--density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each recurrent gate (default 0.05, 0.05, 0.2)')
parser.add_argument('--grub-density', metavar='<global GRU B density>', type=float, help='average density of the GRU B input weights (default 1.0)')
parser.add_argument('--grub-density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each GRU B input gate (default 1.0, 1.0, 1.0)')
# Network sizes.
parser.add_argument('--grua-size', metavar='<units>', default=384, type=int, help='number of units in GRU A (default 384)')
parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help='number of units in GRU B (default 16)')
parser.add_argument('--cond-size', metavar='<units>', default=128, type=int, help='number of units in conditioning network, aka frame rate network (default 128)')
# Training schedule and loss options.
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation)')
parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
parser.add_argument('--gamma', metavar='<gamma>', type=float, help='adjust u-law compensation (default 2.0, should not be less than 1.0)')
parser.add_argument('--lookahead', metavar='<nb frames>', default=2, type=int, help='Number of look-ahead frames (default 2)')
# Logging and device selection.
parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')
parser.add_argument('--lpc-gamma', type=float, default=1, help='gamma for LPC weighting')
parser.add_argument('--cuda-devices', metavar='<cuda devices>', type=str, default=None, help='string with comma separated cuda device ids')

args = parser.parse_args()
63
# Select the visible GPUs. This runs before TensorFlow is imported further
# down, presumably so the CUDA runtime sees the restriction when it starts.
if args.cuda_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_devices

# Per-gate target densities (update, reset, state) for the recurrent weights,
# later passed to lpcnet.Sparsify. An explicit --density-split wins; otherwise
# a global --density d is spread as (d/2, d/2, 2d).
density = (0.05, 0.05, 0.2)
if args.density_split is not None:
    density = args.density_split
elif args.density is not None:
    density = [0.5*args.density, 0.5*args.density, 2.0*args.density]

# Same scheme for the GRU B input weights, later passed to lpcnet.SparsifyGRUB.
# NOTE(review): a global --grub-density d is also spread as (d/2, d/2, 2d), so
# the documented default of 1.0 would give a state-gate density of 2.0, unlike
# the built-in default (1, 1, 1). Confirm this mapping is intended.
grub_density = (1., 1., 1.)
if args.grub_density_split is not None:
    grub_density = args.grub_density_split
elif args.grub_density is not None:
    grub_density = [0.5*args.grub_density, 0.5*args.grub_density, 2.0*args.grub_density]

# u-law compensation factor, used by the end-to-end loss interp_mulaw.
gamma = 2.0 if args.gamma is None else args.gamma

# Load the model definition module named by --model (a .py file without extension).
import importlib
lpcnet = importlib.import_module(args.model)
84
85import sys
86import numpy as np
87from tensorflow.keras.optimizers import Adam
88from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
89from ulaw import ulaw2lin, lin2ulaw
90import tensorflow.keras.backend as K
91import h5py
92
93import tensorflow as tf
94from tf_funcs import *
95from lossfuncs import *
96#gpus = tf.config.experimental.list_physical_devices('GPU')
97#if gpus:
98#  try:
99#    tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
100#  except RuntimeError as e:
101#    print(e)
102
# Training-loop sizes come straight from the command line.
nb_epochs = args.epochs
# Try reducing batch_size if you run out of memory on your GPU
batch_size = args.batch_size

# Run mode: --quantize and --retrain are mutually exclusive and both start
# from an existing weights file.
quantize = args.quantize is not None
retrain = args.retrain is not None

# Number of LPC coefficients per feature frame (added to the model's used
# features to get the full per-frame feature count).
lpc_order = 16

# Quantization fine-tunes previously trained weights, so it defaults to a much
# smaller learning rate and no decay. An explicit --lr / --decay overrides
# whichever default applies.
lr, decay = (0.00003, 0) if quantize else (0.001, 5e-5)
if args.lr is not None:
    lr = args.lr
if args.decay is not None:
    decay = args.decay

# Weights file to start from; only defined when adapting an existing model.
if quantize:
    input_model = args.quantize
if retrain:
    input_model = args.retrain

flag_e2e = args.flag_e2e
131
# Data-parallel training across workers. Everything that creates variables
# (the optimizer, the model and compile()) is done inside the strategy scope,
# so those variables are managed by the strategy.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    # Adam with the lr/decay chosen above. Note that beta_1/beta_2 are far
    # lower than the Keras defaults (0.9/0.999).
    opt = Adam(lr, decay=decay, beta_1=0.5, beta_2=0.8)
    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size,
                                          rnn_units2=args.grub_size,
                                          batch_size=batch_size, training=True,
                                          quantize=quantize,
                                          flag_e2e=flag_e2e,
                                          cond_size=args.cond_size,
                                          lpc_gamma=args.lpc_gamma,
                                          lookahead=args.lookahead
                                          )
    if not flag_e2e:
        # Single loss: metric_cel is used as both the loss and the reported metric.
        model.compile(optimizer=opt, loss=metric_cel, metrics=metric_cel)
    else:
        # End-to-end mode: two losses weighted 1:2 (presumably one per model
        # output; confirm in the model module), with extra metrics reported
        # for the output named 'pdf'.
        model.compile(optimizer=opt, loss = [interp_mulaw(gamma=gamma), loss_matchlar()], loss_weights = [1.0, 2.0], metrics={'pdf':[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss]})
    model.summary()
150
feature_file = args.features
pcm_file = args.data     # int16 samples, two interleaved values per time step
frame_size = model.frame_size
nb_features = model.nb_used_features + lpc_order
nb_used_features = model.nb_used_features
# Each training sequence spans feature_chunk_size feature frames, i.e.
# pcm_chunk_size audio time steps.
feature_chunk_size = 15
pcm_chunk_size = frame_size*feature_chunk_size

# Memory-map the int16 audio data (two values per time step, see the reshape
# below) instead of reading it into RAM.
data = np.memmap(pcm_file, dtype='int16', mode='r')
# Number of whole sequences, rounded down to a multiple of batch_size. The -1
# appears to reserve one sequence's worth of samples for the offset trimmed
# below.
nb_frames = (len(data)//(2*pcm_chunk_size)-1)//batch_size*batch_size
if nb_frames <= 0:
    raise ValueError('audio data file {} is too short for a single batch of {} sequences'
                     .format(pcm_file, batch_size))

features = np.memmap(feature_file, dtype='float32', mode='r')

# Drop the first (4 - lookahead) frames of audio. Presumably this aligns the
# samples with the extra context frames in each feature window below; TODO
# confirm. A lookahead above 4 would make the offset negative and silently
# slice from the END of the file, so reject it.
pcm_offset = (4-args.lookahead)*2*frame_size
if pcm_offset < 0:
    raise ValueError('--lookahead must not exceed 4 (got {})'.format(args.lookahead))
data = data[pcm_offset:]
# Limit to a whole number of sequences and split the data into them.
data = data[:nb_frames*2*pcm_chunk_size]
data = np.reshape(data, (nb_frames, pcm_chunk_size, 2))

# View the flat feature array as overlapping windows without copying: window i
# starts at frame i*feature_chunk_size and covers feature_chunk_size+4 frames.
# as_strided does no bounds checking, so make sure the file holds every frame
# the last window touches; otherwise the view would read past the mapping.
sizeof = features.strides[-1]
needed = (nb_frames*feature_chunk_size + 4)*nb_features
if len(features) < needed:
    raise ValueError('features file {} holds {} values but {} are needed to match the audio data'
                     .format(feature_file, len(features), needed))
features = np.lib.stride_tricks.as_strided(features, shape=(nb_frames, feature_chunk_size+4, nb_features),
                                           strides=(feature_chunk_size*nb_features*sizeof, nb_features*sizeof, sizeof))

# Integer period per frame, computed from feature column nb_used_features-2
# (presumably the pitch feature; confirm) as 50*x + 100. The extra .1 keeps
# values that land just below an integer from being truncated down by astype.
periods = (.1 + 50*features[:,:,nb_used_features-2:nb_used_features-1]+100).astype('int16')
183
# Dump the model to disk after every epoch as <output>_<grua_size>_<epoch>.h5.
# '{epoch:02d}' is left in the name for ModelCheckpoint to fill in.
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.grua_size, '{epoch:02d}'))

# Sparsification schedules for the recurrent weights (density) and the GRU B
# weights (grub_density). The leading integer arguments are defined by
# lpcnet.Sparsify / lpcnet.SparsifyGRUB in the --model module; they presumably
# are (start step, end step, interval). Confirm there.
if quantize or retrain:
    # Adapting an existing model: load its weights exactly once. input_model
    # is args.quantize or args.retrain, whichever was given.
    model.load_weights(input_model)
    if quantize:
        sparsify = lpcnet.Sparsify(10000, 30000, 100, density, quantize=True)
        grub_sparsify = lpcnet.SparsifyGRUB(10000, 30000, 100, args.grua_size, grub_density, quantize=True)
    else:
        sparsify = lpcnet.Sparsify(0, 0, 1, density)
        grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
else:
    # Training from scratch
    sparsify = lpcnet.Sparsify(2000, 20000, 400, density)
    grub_sparsify = lpcnet.SparsifyGRUB(2000, 40000, 400, args.grua_size, grub_density)

# Keep a copy of the starting weights next to the per-epoch checkpoints.
model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
204model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
205
206loader = LPCNetLoader(data, features, periods, batch_size, e2e=flag_e2e, lookahead=args.lookahead)
207
208callbacks = [checkpoint, sparsify, grub_sparsify]
209if args.logdir is not None:
210    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.grua_size)
211    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
212    callbacks.append(tensorboard_callback)
213
214model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
215