# Optimizing a rational function to approximate tanh()
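#
# The model fits tanh(x) ~ clip(x*P(x^2)/Q(x^2), -1, 1), where P and Q are
# the quadratics (in x^2) learned by the 'num' and 'den' Dense layers below.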

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Lambda
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adam

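# Each loss mixes the mean squared error with the maximum squared error over
# the time axis; successive stages scale down the mean term (1 -> .1 -> .01),
# shifting the fit toward minimizing the worst-case error.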
def my_loss1(y_true, y_pred):
    return 1*K.mean(K.square(y_true-y_pred)) + 1*K.max(K.square(y_true-y_pred), axis=1)

def my_loss2(y_true, y_pred):
    return .1*K.mean(K.square(y_true-y_pred)) + 1*K.max(K.square(y_true-y_pred), axis=1)

def my_loss3(y_true, y_pred):
    return .01*K.mean(K.square(y_true-y_pred)) + 1*K.max(K.square(y_true-y_pred), axis=1)

# These initializers seed the approximation with a reasonable starting point:
# the coefficients of the order-[5/4] Padé approximant of tanh,
#   tanh(x) ~ x*(945 + 105*x^2 + x^4) / (945 + 420*x^2 + 15*x^4)
def num_init(shape, dtype=None):
    rr = tf.constant([[945], [105], [1]], dtype=dtype)
    #rr = tf.constant([[946.56757], [98.01368], [0.66841]], dtype=dtype)
    print(rr)
    return rr

def den_init(shape, dtype=None):
    rr = tf.constant([[945], [420], [15]], dtype=dtype)
    #rr = tf.constant([[946.604], [413.342], [12.465]], dtype=dtype)
    print(rr)
    return rr

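# Quick sanity check (illustrative addition, not part of the training flow):
# with the seed coefficients above, the approximation already tracks tanh()
# closely for small inputs.
xs = np.linspace(-2, 2, 9)
seed = xs*(945 + 105*xs**2 + xs**4)/(945 + 420*xs**2 + 15*xs**4)
print('seed max error on [-2, 2]:', np.abs(seed - np.tanh(xs)).max())
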
# Training data: a dense grid on [-10, 10), shaped (batch, time, features)
x = np.arange(-10, 10, .01)
N = len(x)
x = np.reshape(x, (1, -1, 1))
x2 = x*x

# Input features [1, x^2, x^4]: the even-power basis shared by P and Q
x2in = np.concatenate([x2*0 + 1, x2, x2*x2], axis=2)
yout = np.tanh(x)


model_x = Input(shape=(None, 1,))
model_x2 = Input(shape=(None, 3,))

num = Dense(1, name='num', use_bias=False, kernel_initializer=num_init)
den = Dense(1, name='den', use_bias=False, kernel_initializer=den_init)

def ratio(x):
    # x * P(x^2)/Q(x^2), clamped to tanh()'s range [-1, 1]
    return tf.minimum(1., tf.maximum(-1., x[0]*x[1]/x[2]))

out_layer = Lambda(ratio)
output = out_layer([model_x, num(model_x2), den(model_x2)])

model = Model([model_x, model_x2], output)
model.summary()

# Stage 1: plain MSE, starting from the Padé seed
model.compile(Adam(0.05, beta_1=0.9, beta_2=0.9, decay=2e-5), loss='mean_squared_error')
model.fit([x, x2in], yout, batch_size=1, epochs=500000, validation_split=0.0)

# Stages 2-4: progressively emphasize the max error, with shrinking learning rates
model.compile(Adam(0.001, beta_2=0.9, decay=1e-4), loss=my_loss1)
model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)

model.compile(Adam(0.0001, beta_2=0.9, decay=1e-4), loss=my_loss2)
model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)

model.compile(Adam(0.00001, beta_2=0.9, decay=1e-4), loss=my_loss3)
model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)

model.save_weights('tanh.h5')
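
# Illustrative post-training check (not in the original script): read back the
# fitted coefficients from the 'num' and 'den' layers and report the largest
# deviation from tanh() over the training range.
p = model.get_layer('num').get_weights()[0][:, 0]
q = model.get_layer('den').get_weights()[0][:, 0]
xt = np.linspace(-10, 10, 2001)
approx = np.clip(xt*(p[0] + p[1]*xt**2 + p[2]*xt**4)
                 / (q[0] + q[1]*xt**2 + q[2]*xt**4), -1., 1.)
print('max abs error vs tanh:', np.abs(approx - np.tanh(xt)).max())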