xref: /aosp_15_r20/external/rnnoise/src/rnn_train.py (revision 1295d6828459cc82c3c29cc5d7d297215250a74b)
1*1295d682SXin Li#!/usr/bin/python
2*1295d682SXin Li
3*1295d682SXin Lifrom __future__ import print_function
4*1295d682SXin Li
5*1295d682SXin Lifrom keras.models import Sequential
6*1295d682SXin Lifrom keras.models import Model
7*1295d682SXin Lifrom keras.layers import Input
8*1295d682SXin Lifrom keras.layers import Dense
9*1295d682SXin Lifrom keras.layers import LSTM
10*1295d682SXin Lifrom keras.layers import GRU
11*1295d682SXin Lifrom keras.layers import SimpleRNN
12*1295d682SXin Lifrom keras.layers import Dropout
13*1295d682SXin Lifrom keras import losses
14*1295d682SXin Liimport h5py
15*1295d682SXin Li
16*1295d682SXin Lifrom keras import backend as K
17*1295d682SXin Liimport numpy as np
18*1295d682SXin Li
print('Build model...')
# Network layout: variable-length sequences with 22 features per time step
# feed a single 128-unit GRU that emits an output at every step; a Dense
# layer with a sigmoid activation then maps each step to 22 outputs.
# Alternatives that were left disabled in this script: a Dense(44, relu)
# front end, additional GRU layers, and a softplus output layer.
main_input = Input(shape=(None, 22), name='main_input')
recurrent_out = GRU(128,
                    activation='tanh',
                    recurrent_activation='sigmoid',
                    return_sequences=True)(main_input)
outputs = Dense(22, activation='sigmoid')(recurrent_out)
model = Model(inputs=main_input, outputs=outputs)
30*1295d682SXin Li
# Mini-batch size used by the training call later in this script.
batch_size = 32

print('Loading data...')
# Pull the whole 'denoise_data' dataset into memory in one read.
with h5py.File('denoise_data.h5', 'r') as hf:
    all_data = hf['denoise_data'][:]
print('done.')

# Number of consecutive rows grouped into one training sequence.
window_size = 500

# Only complete windows are used; any leftover rows at the end are dropped.
nb_sequences = len(all_data)//window_size
print(nb_sequences, ' sequences')
used_rows = nb_sequences*window_size
sequence_shape = (nb_sequences, window_size, 22)

# Every row stores the network inputs followed by the 22 target values.
# The reshape below only succeeds if the input part is also 22 columns wide.
x_train = np.reshape(all_data[:used_rows, :-22], sequence_shape)
y_train = np.reshape(np.copy(all_data[:used_rows, -22:]), sequence_shape)

# Drop this name's reference to the raw array before the float32 copies
# are made.
all_data = 0
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')

print(len(x_train), 'train sequences. x shape =', x_train.shape,
      'y shape = ', y_train.shape)
55*1295d682SXin Li
# Mean-squared-error objective optimised with Adam (other optimizers and
# optimizer settings can be swapped in here); binary accuracy is reported
# alongside the loss.
model.compile(optimizer='adam',
              loss='mean_squared_error',
              metrics=['binary_accuracy'])

print('Train...')
# NOTE(review): the validation set is the training set itself, so the
# val_* figures are not a held-out estimate -- confirm this is intentional.
fit_options = dict(batch_size=batch_size,
                   epochs=200,
                   validation_data=(x_train, y_train))
model.fit(x_train, y_train, **fit_options)

# Save the trained model to an HDF5 file in the current working directory.
model.save("newweights.hdf5")
67