xref: /aosp_15_r20/external/libopus/dnn/training_tf2/test_plc.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
#!/usr/bin/python3
'''Copyright (c) 2021-2022 Amazon
   Copyright (c) 2018-2019 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

# Test a PLC model: run it over a features file and write the reconstructed features

import argparse
from plc_loader import PLCLoader

# Command-line interface of the test script.
parser = argparse.ArgumentParser(description='Test a PLC model')

parser.add_argument('weights', metavar='<weights file>', help='weights file (.h5)')
# The features file is read with np.loadtxt and the result is written with
# np.savetxt, so both are plain-text matrices rather than raw float32 dumps.
parser.add_argument('features', metavar='<features file>',
                    help='features file (text, one frame per row, last column is the lost flag)')
parser.add_argument('output', metavar='<output>',
                    help='reconstructed features file (text, one frame per row)')
parser.add_argument('--model', metavar='<model>', default='lpcnet_plc',
                    help='PLC model python definition (without .py)')

# Layer sizes handed to the model constructor.
parser.add_argument('--gru-size', metavar='<units>', default=256, type=int,
                    help='number of units in GRU (default 256)')
parser.add_argument('--cond-size', metavar='<units>', default=128, type=int,
                    help='number of units in conditioning network (default 128)')

args = parser.parse_args()

# The model definition is chosen at run time by module name (--model), so it
# has to be imported dynamically instead of with a plain import statement.
import importlib
lpcnet = importlib.import_module(args.model)

import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
import tensorflow.keras.backend as K
import h5py

import tensorflow as tf
# Disabled snippet: when enabled it limits the first GPU to a 5120 MB virtual device.
#gpus = tf.config.experimental.list_physical_devices('GPU')
#if gpus:
#  try:
#    tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
#  except RuntimeError as e:
#    print(e)

# Build the PLC network in inference configuration (batch of one sequence,
# training and quantization switched off) with the sizes given on the command line.
model = lpcnet.new_lpcnet_plc_model(rnn_units=args.gru_size, batch_size=1, training=False, quantize=False, cond_size=args.cond_size)
model.compile()

feature_file = args.features
# Number of leading columns of every frame that are fed to the model.
nb_used_features = model.nb_used_features

# Load the feature matrix from a text file, one frame per row.  ndmin=2 keeps
# a file holding a single frame two-dimensional; without it np.loadtxt returns
# a 1-D array and the column slicing below fails.
features = np.loadtxt(feature_file, ndmin=2)
print(features.shape)
sequence_size, nb_columns = features.shape
if nb_columns < nb_used_features:
    # Fail with an explicit message instead of an opaque reshape error.
    raise ValueError('{}: expected at least {} columns per frame, found {}'.format(
        feature_file, nb_used_features, nb_columns))

# The last column of each row is the per-frame flag stored in `lost`; the
# first nb_used_features columns are the model inputs.  Both are reshaped to
# a batch holding one sequence.
lost = np.reshape(features[:, -1:], (1, sequence_size, 1))
features = np.reshape(features[:, :nb_used_features], (1, sequence_size, nb_used_features))

model.load_weights(args.weights)

# Frames whose flag is 0 are zeroed before being shown to the model
# (assumes the flag column only contains 0 or 1).
features = features*lost
out = model.predict([features, lost])

# Keep the input frames where the flag is 1 and take the model's prediction
# where it is 0.
out = features + (1-lost)*out

# Write the single reconstructed sequence as text, one frame per row.
np.savetxt(args.output, out[0,:,:])
93