xref: /aosp_15_r20/external/libopus/training/rnn_train.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li#!/usr/bin/python3
2*a58d3d2aSXin Li
3*a58d3d2aSXin Lifrom __future__ import print_function
4*a58d3d2aSXin Li
5*a58d3d2aSXin Lifrom keras.models import Sequential
6*a58d3d2aSXin Lifrom keras.models import Model
7*a58d3d2aSXin Lifrom keras.layers import Input
8*a58d3d2aSXin Lifrom keras.layers import Dense
9*a58d3d2aSXin Lifrom keras.layers import LSTM
10*a58d3d2aSXin Lifrom keras.layers import GRU
11*a58d3d2aSXin Lifrom keras.layers import CuDNNGRU
12*a58d3d2aSXin Lifrom keras.layers import SimpleRNN
13*a58d3d2aSXin Lifrom keras.layers import Dropout
14*a58d3d2aSXin Lifrom keras import losses
15*a58d3d2aSXin Liimport h5py
16*a58d3d2aSXin Lifrom keras.optimizers import Adam
17*a58d3d2aSXin Li
18*a58d3d2aSXin Lifrom keras.constraints import Constraint
19*a58d3d2aSXin Lifrom keras import backend as K
20*a58d3d2aSXin Liimport numpy as np
21*a58d3d2aSXin Li
22*a58d3d2aSXin Liimport tensorflow as tf
23*a58d3d2aSXin Lifrom keras.backend.tensorflow_backend import set_session
# Install a TensorFlow session whose per-process GPU memory fraction is set to
# 0.44, and make Keras use it for everything that follows.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.44)
config = tf.ConfigProto(gpu_options=gpu_options)
set_session(tf.Session(config=config))
27*a58d3d2aSXin Li
def binary_crossentrop2(y_true, y_pred):
    """Binary cross-entropy weighted by each target's distance from 0.5.

    Every element's cross-entropy is multiplied by ``2 * |y_true - 0.5|``:
    the weight is 1 for targets of 0 or 1 and 0 for targets of exactly 0.5,
    so elements labelled 0.5 do not contribute to the loss.

    Returns the weighted cross-entropy averaged over the last axis.
    """
    target_weight = 2 * K.abs(y_true - 0.5)
    cross_entropy = K.binary_crossentropy(y_true, y_pred)
    return K.mean(target_weight * cross_entropy, axis=-1)
30*a58d3d2aSXin Li
def binary_accuracy2(y_true, y_pred):
    """Accuracy metric that scores targets of exactly 0.5 as correct.

    An element scores 1 when the rounded prediction equals its target, plus 1
    when the target is exactly 0.5. The two terms are summed and the result is
    averaged over the last axis.
    """
    prediction_matches = K.cast(K.equal(y_true, K.round(y_pred)), 'float32')
    target_is_half = K.cast(K.equal(y_true, 0.5), 'float32')
    return K.mean(prediction_matches + target_is_half, axis=-1)
33*a58d3d2aSXin Li
def quant_model(model):
    """Snap every weight of *model* onto a 1/128-step, signed 8-bit grid.

    Each weight array is scaled by 128, rounded to the nearest integer,
    clamped to [-128, 127] and scaled back by 1/128, so the stored values lie
    in [-1.0, 127/128].

    Args:
        model: object exposing ``get_weights()`` (returning a list of numpy
            arrays) and ``set_weights(list)``.

    Returns:
        None. The model's weights are replaced in place via ``set_weights``.
    """
    step = 0.0078125  # exactly 1/128
    quantized = []
    for w in model.get_weights():
        # Clamp in the integer domain, *before* scaling back. Clamping the
        # already rescaled value to [-128, 127] never triggers for weights
        # near 1, which let 128/128 = 1.0 slip through although it has no
        # signed 8-bit representation.
        levels = np.maximum(-128, np.minimum(127, np.round(128 * w)))
        quantized.append(levels * step)
    model.set_weights(quantized)
39*a58d3d2aSXin Li
class WeightClip(Constraint):
    """Weight constraint that clips every element to the range ``[-c, c]``.

    Args:
        c: non-negative bound; weights are limited to ``[-c, c]``.
    """

    def __init__(self, c=2):
        self.c = c

    def __call__(self, p):
        """Return *p* with every element clipped to ``[-self.c, self.c]``."""
        return K.clip(p, -self.c, self.c)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this constraint.

        Only ``c`` is reported. The dict is handed back to ``__init__`` as
        keyword arguments when a saved model is reloaded, so an extra
        ``'name'`` entry (which ``__init__`` does not accept) would make that
        reload raise ``TypeError``. The class name is recorded separately by
        the Keras serializer.
        """
        return {'c': self.c}
52*a58d3d2aSXin Li
# Not referenced anywhere else in the visible part of this script.
reg = 0.000001
# One constraint object shared by all layers: clips weights to [-0.998, 0.998].
constraint = WeightClip(.998)

print('Build model...')

# Every layer constrains both its kernel and its bias in the same way.
clip_kernel_and_bias = dict(kernel_constraint=constraint,
                            bias_constraint=constraint)

# Variable-length sequences of 25 input values per step.
main_input = Input(shape=(None, 25), name='main_input')
dense_out = Dense(32, activation='tanh', **clip_kernel_and_bias)(main_input)
# Disabled alternative for the recurrent layer:
#   CuDNNGRU(24, return_sequences=True, <same three constraints>)
gru_out = GRU(24,
              recurrent_activation='sigmoid',
              activation='tanh',
              return_sequences=True,
              recurrent_constraint=constraint,
              **clip_kernel_and_bias)(dense_out)
# Two sigmoid outputs per step, matching the two target columns.
x = Dense(2, activation='sigmoid', **clip_kernel_and_bias)(gru_out)
model = Model(inputs=main_input, outputs=x)

batch_size = 2048
66*a58d3d2aSXin Li
print('Loading data...')
# Read the entire 'data' dataset into memory; the file is closed on leaving
# the `with` block.
with h5py.File('features10b.h5', 'r') as hf:
    all_data = hf['data'][:]
print('done.')

# Number of consecutive rows that make up one training sequence.
window_size = 1500

# Only whole windows are used; trailing rows that do not fill one are dropped.
nb_sequences = len(all_data)//window_size
print(nb_sequences, ' sequences')
usable_rows = nb_sequences*window_size

# Every column except the last two is an input (25 per row).
x_train = np.reshape(all_data[:usable_rows, :-2],
                     (nb_sequences, window_size, 25))

# The last two columns are the targets. They are copied because the masking
# step below writes into them.
y_train = np.reshape(np.copy(all_data[:usable_rows, -2:]),
                     (nb_sequences, window_size, 2))

print("Marking ignores")
# Within each sequence, every row that comes before the first row whose second
# target is >= 1 gets its first target set to 0.5 (all rows, if no such row
# exists). binary_crossentrop2 gives a target of 0.5 zero weight, so those
# entries are ignored by the loss.
# `reached` is True from the first qualifying row onward; the running maximum
# replaces a Python-level loop over every row of every sequence.
reached = np.maximum.accumulate(y_train[:, :, 1] >= 1, axis=1)
y_train[~reached, 0] = 0.5

# Drop this name's reference to the raw array (the name stays bound, to 0).
all_data = 0
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')

print(len(x_train), 'train sequences. x shape =', x_train.shape, 'y shape = ', y_train.shape)
94*a58d3d2aSXin Li
# Start from previously saved weights.
model.load_weights('newweights10a1b_ep206.hdf5')


def _compile_model():
    """(Re)compile the model with the custom loss/metric and a new Adam(0.0001)."""
    # try using different optimizers and different optimizer configs
    model.compile(loss=binary_crossentrop2,
                  optimizer=Adam(0.0001),
                  metrics=[binary_accuracy2])


_compile_model()

print('Train...')

# One entry per training stage:
#   (epoch at which the stage stops,
#    whether to also evaluate on the training set as validation data,
#    whether to recompile -- creating a new optimizer instance -- beforehand)
# Each stage resumes at the epoch where the previous one stopped.
TRAINING_STAGES = (
    (10, True, False),
    (50, False, False),
    (100, False, True),
    (150, False, False),
    (200, False, False),
    (201, False, False),
    (202, True, False),
    (203, True, False),
    (204, True, False),
    (205, True, False),
    (206, True, False),
)

start_epoch = 0
for stop_epoch, validate, recompile in TRAINING_STAGES:
    if recompile:
        _compile_model()
    # Re-quantize the weights at the start of every stage, then continue
    # training from the quantized values.
    quant_model(model)
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=stop_epoch, initial_epoch=start_epoch,
              validation_data=(x_train, y_train) if validate else None)
    # Checkpoint named after the epoch just reached.
    model.save("newweights10a1c_ep%d.hdf5" % stop_epoch)
    start_epoch = stop_epoch
177*a58d3d2aSXin Li
178