xref: /aosp_15_r20/external/libopus/dnn/training_tf2/pade.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
# Fit the coefficients of a rational function so that it approximates tanh()
2
3import numpy as np
4import tensorflow as tf
5from tensorflow.keras.models import Model
6from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
7import tensorflow.keras.backend as K
8from tensorflow.keras.optimizers import Adam, SGD
9
def my_loss1(y_true, y_pred):
    """Loss combining mean and worst-case squared error with equal weight.

    The squared error is averaged over every element for the first term,
    while the second term takes its maximum along axis 1 only. Both terms
    are weighted by 1 and added together (broadcasting the scalar mean).
    """
    sq_err = K.square(y_true - y_pred)
    mean_term = 1 * K.mean(sq_err)
    peak_term = 1 * K.max(sq_err, axis=1)
    return mean_term + peak_term
12
def my_loss2(y_true, y_pred):
    """Loss that weights the mean squared error by 0.1 and the peak by 1.

    The squared error is averaged over every element for the first term,
    while the second term takes its maximum along axis 1 only.
    """
    sq_err = K.square(y_true - y_pred)
    mean_term = .1 * K.mean(sq_err)
    peak_term = 1 * K.max(sq_err, axis=1)
    return mean_term + peak_term
15
def my_loss3(y_true, y_pred):
    """Loss that weights the mean squared error by 0.01 and the peak by 1.

    The squared error is averaged over every element for the first term,
    while the second term takes its maximum along axis 1 only.
    """
    sq_err = K.square(y_true - y_pred)
    mean_term = .01 * K.mean(sq_err)
    peak_term = 1 * K.max(sq_err, axis=1)
    return mean_term + peak_term
18
19# Using these initializers to seed the approximation
20# with a reasonable starting point
def num_init(shape, dtype=None):
    """Kernel initializer holding the starting numerator coefficients.

    ``shape`` is accepted for the initializer signature but is not used:
    the result is always the fixed 3x1 column [945, 105, 1], built with
    the requested ``dtype``. The tensor is printed before being returned.
    """
    start_coeffs = [[945], [105], [1]]
    kernel = tf.constant(start_coeffs, dtype=dtype)
    # Alternative starting point kept for reference:
    #kernel = tf.constant([[946.56757], [98.01368], [0.66841]], dtype=dtype)
    print(kernel)
    return kernel
26
def den_init(shape, dtype=None):
    """Kernel initializer holding the starting denominator coefficients.

    ``shape`` is accepted for the initializer signature but is not used:
    the result is always the fixed 3x1 column [945, 420, 15], built with
    the requested ``dtype``. The tensor is printed before being returned.
    """
    start_coeffs = [[945], [420], [15]]
    kernel = tf.constant(start_coeffs, dtype=dtype)
    # Alternative starting point kept for reference:
    #kernel = tf.constant([[946.604], [413.342], [12.465]], dtype=dtype)
    print(kernel)
    return kernel
32
33
# Sample grid on [-10, 10) in steps of 0.01, laid out as a single
# sequence with one feature per step: shape (1, N, 1).
x = np.arange(-10, 10, .01).reshape(1, -1, 1)
N = x.shape[1]
x2 = np.square(x)

# Polynomial features [1, x^2, x^4] stacked along the last axis.
x2in = np.concatenate([np.ones_like(x2), x2, np.square(x2)], axis=-1)
# Training target: the exact tanh() of every sample.
yout = np.tanh(x)
41
42
# Symbolic inputs: one feature per step for the raw samples and three
# features per step for the polynomial terms; the step count is left open.
model_x = Input(shape=(None, 1))
model_x2 = Input(shape=(None, 3))

# Bias-free single-unit layers: each one is a weighted sum of the three
# polynomial features, seeded by the matching initializer.
num = Dense(1, use_bias=False, kernel_initializer=num_init, name='num')
den = Dense(1, use_bias=False, kernel_initializer=den_init, name='den')
48
def ratio(x):
    """Return x[0] * x[1] / x[2], clamped to the range [-1, 1].

    ``x`` is indexed as three tensors: the value to scale, the numerator
    and the denominator, in that order.
    """
    scaled = x[0] * x[1]
    quotient = scaled / x[2]
    floored = tf.maximum(-1., quotient)
    return tf.minimum(1., floored)
51
# Combine the raw input with the numerator and denominator layers through
# the clamped ratio, then wrap everything in a two-input model.
out_layer = Lambda(ratio)
output = out_layer([
    model_x,
    num(model_x2),
    den(model_x2),
])

model = Model(inputs=[model_x, model_x2], outputs=output)
model.summary()
57
# Training schedule, run in order on the same model. Each entry is
# (learning rate, extra Adam arguments, loss, number of epochs): a long
# mean-squared-error stage first, then three shorter stages with the
# custom losses and progressively smaller learning rates.
training_stages = [
    (0.05, dict(beta_1=0.9, beta_2=0.9, decay=2e-5), 'mean_squared_error', 500000),
    (0.001, dict(beta_2=0.9, decay=1e-4), my_loss1, 50000),
    (0.0001, dict(beta_2=0.9, decay=1e-4), my_loss2, 50000),
    (0.00001, dict(beta_2=0.9, decay=1e-4), my_loss3, 50000),
]

for stage_lr, stage_adam_args, stage_loss, stage_epochs in training_stages:
    # A fresh optimizer is built for every stage, right before compiling.
    model.compile(Adam(stage_lr, **stage_adam_args), loss=stage_loss)
    model.fit([x, x2in], yout, batch_size=1, epochs=stage_epochs, validation_split=0.0)

# Persist the fitted numerator/denominator weights.
model.save_weights('tanh.h5')
71