xref: /aosp_15_r20/external/libopus/scripts/rnn_train.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li#!/usr/bin/python
2*a58d3d2aSXin Li
3*a58d3d2aSXin Lifrom __future__ import print_function
4*a58d3d2aSXin Li
5*a58d3d2aSXin Lifrom keras.models import Sequential
6*a58d3d2aSXin Lifrom keras.models import Model
7*a58d3d2aSXin Lifrom keras.layers import Input
8*a58d3d2aSXin Lifrom keras.layers import Dense
9*a58d3d2aSXin Lifrom keras.layers import LSTM
10*a58d3d2aSXin Lifrom keras.layers import GRU
11*a58d3d2aSXin Lifrom keras.layers import SimpleRNN
12*a58d3d2aSXin Lifrom keras.layers import Dropout
13*a58d3d2aSXin Lifrom keras import losses
14*a58d3d2aSXin Liimport h5py
15*a58d3d2aSXin Li
16*a58d3d2aSXin Lifrom keras import backend as K
17*a58d3d2aSXin Liimport numpy as np
18*a58d3d2aSXin Li
def binary_crossentrop2(y_true, y_pred):
    """Return binary cross-entropy weighted by target confidence.

    Each element's cross-entropy is scaled by ``2 * |y_true - 0.5|``, so a
    target of exactly 0.5 contributes nothing to the loss while targets of
    0 or 1 contribute with full weight. The weighted values are averaged
    over the last axis.

    Args:
        y_true: Tensor of target values.
        y_pred: Tensor of predicted probabilities, same shape as ``y_true``.

    Returns:
        Tensor of weighted cross-entropy averaged over the last axis.
    """
    weight = 2 * K.abs(y_true - 0.5)
    # Pass the arguments by keyword: the positional order of
    # K.binary_crossentropy changed from (output, target) to
    # (target, output) between Keras releases, so a positional call
    # silently swaps labels and predictions on one side of that change.
    crossentropy = K.binary_crossentropy(target=y_true, output=y_pred)
    return K.mean(weight * crossentropy, axis=-1)
21*a58d3d2aSXin Li
print('Build model...')
# Earlier Sequential-API variant (no dropout), kept for reference:
#model = Sequential()
#model.add(Dense(16, activation='tanh', input_shape=(None, 25)))
#model.add(GRU(12, dropout=0.0, recurrent_dropout=0.0, activation='tanh', recurrent_activation='sigmoid', return_sequences=True))
#model.add(Dense(2, activation='sigmoid'))

# Functional-API network: variable-length sequences of 25 features per step
# go through a 16-unit tanh layer, a 12-unit GRU that emits an output at
# every step, and a 2-unit sigmoid output layer.
main_input = Input(shape=(None, 25), name='main_input')
input_dense = Dense(16, activation='tanh')
recurrent = GRU(
    12,
    dropout=0.1,
    recurrent_dropout=0.1,
    activation='tanh',
    recurrent_activation='sigmoid',
    return_sequences=True,
)
output_dense = Dense(2, activation='sigmoid')

x = output_dense(recurrent(input_dense(main_input)))
model = Model(inputs=main_input, outputs=x)
33*a58d3d2aSXin Li
batch_size = 64

print('Loading data...')
# Read the whole 'features' dataset into memory as a NumPy array.
with h5py.File('features.h5', 'r') as hf:
    all_data = hf['features'][:]
print('done.')

# Number of consecutive rows grouped into one training sequence.
window_size = 1500

# Floor division keeps nb_sequences an int on both Python 2 and Python 3;
# true division would give a float, which is rejected as a slice bound and
# as a reshape dimension. Rows that do not fill a whole window are dropped.
nb_sequences = len(all_data) // window_size
print(nb_sequences, ' sequences')

# All columns except the last two are the inputs (25 per step).
x_train = all_data[:nb_sequences*window_size, :-2]
x_train = np.reshape(x_train, (nb_sequences, window_size, 25))

# The last two columns are the targets.
y_train = np.copy(all_data[:nb_sequences*window_size, -2:])
y_train = np.reshape(y_train, (nb_sequences, window_size, 2))

# Rebind the name so the raw array is no longer referenced from here.
all_data = 0
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')

print(len(x_train), 'train sequences. x shape =', x_train.shape, 'y shape = ', y_train.shape)
56*a58d3d2aSXin Li
# The optimizer (and its configuration) is a knob worth experimenting with.
model.compile(
    optimizer='adam',
    loss=binary_crossentrop2,
    metrics=['binary_accuracy'],
)

print('Train...')
# NOTE: validation_data is the training set itself, so the reported
# validation figures are not a held-out estimate.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=200,
    validation_data=(x_train, y_train),
)
model.save("newweights.hdf5")
68