#!/usr/bin/python3
'''Copyright (c) 2022 Amazon

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

import math
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise, AveragePooling1D, RepeatVector
from tensorflow.compat.v1.keras.layers import CuDNNGRU
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.regularizers import l1
import numpy as np
import h5py
from uniform_noise import UniformNoise

class WeightClip(Constraint):
    '''Clips the weights incident to each hidden unit to be inside a range
    '''
    def __init__(self, c=2):
        self.c = c

    def __call__(self, p):
        # Rescale each adjacent weight pair so that |w0|+|w1| <= c. With
        # c = 0.496, pair sums in Q8 fixed point stay <= 127 (0.496*256 ~ 127),
        # avoiding saturation when implementing dot products with SSSE3 or AVX2.
        return self.c*p/tf.maximum(self.c, tf.repeat(tf.abs(p[:, 1::2])+tf.abs(p[:, 0::2]), 2, axis=1))
        #return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {'name': self.__class__.__name__,
            'c': self.c}

constraint = WeightClip(0.496)
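
# Sketch (editorial illustration, not part of the original pipeline): the
# constraint rescales whole weight pairs rather than clipping individual
# entries, so the pairwise L1 bound above holds exactly.
def _example_weight_clip():
    w = tf.constant([[3.0, -2.0, 0.1, 0.2]])
    clipped = WeightClip(0.496)(w)
    # The first pair (|3.0| + |-2.0| = 5.0) is scaled down to sum to 0.496,
    # giving [0.2976, -0.1984]; the second pair already satisfies the bound
    # and passes through unchanged.
    return clipped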

def soft_quantize(x):
    # Currently the identity; the smooth sine-based soft quantizer below is
    # disabled.
    #x = 4*x
    #x = x - (.25/math.pi)*tf.math.sin(2*math.pi*x)
    #x = x - (.25/math.pi)*tf.math.sin(2*math.pi*x)
    #x = x - (.25/math.pi)*tf.math.sin(2*math.pi*x)
    return x

def noise_quantize(x):
    # Additive uniform noise in [-.5, .5) as a differentiable stand-in for
    # rounding; note the (batch, time, dims) noise shape is hard-coded.
    return soft_quantize(x + (K.random_uniform((128, 16, 80))-.5) )

def hard_quantize(x):
    # Straight-through estimator: round in the forward pass, identity in
    # the backward pass.
    x = soft_quantize(x)
    quantized = tf.round(x)
    return x + tf.stop_gradient(quantized - x)

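# Sketch (editorial illustration): hard_quantize() rounds in the forward pass
# while the backward pass treats the rounding as the identity, so gradients
# still reach the encoder.
def _example_hard_quantize_gradient():
    x = tf.Variable([[0.2, 1.7, -0.6]])
    with tf.GradientTape() as tape:
        y = hard_quantize(x)          # forward: [[0., 2., -1.]]
        loss = tf.reduce_sum(tf.square(y))
    # Backward: d(loss)/dx == 2*y, i.e. the round() contributes no gradient.
    return tape.gradient(loss, x)
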
def apply_dead_zone(x):
    # Shrink values toward zero before rounding; the learned width d makes
    # the quantizer's zero bin wider than the others.
    d = x[1]*.05
    x = x[0]
    y = x - d*tf.math.tanh(x/(.1+d))
    return y

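# Sketch (editorial illustration): with d = 0.25 the dead zone pulls x = 0.6
# down to about 0.37, so a subsequent round() yields 0 instead of 1, while
# larger inputs are barely affected because tanh saturates at 1.
def _example_dead_zone():
    x = tf.constant([0.6, 4.0])
    d = tf.constant([5.0, 5.0])       # apply_dead_zone() scales this by .05
    return apply_dead_zone([x, d])    # ~[0.37, 3.75]
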
def rate_loss(y_true,y_pred):
    log2_e = 1.4427
    n = y_pred.shape[-1]
    C = n - log2_e*math.log(math.gamma(n))
    k = K.sum(K.abs(y_pred), axis=-1)
    p = 1.5
    #rate = C + (n-1)*log2_e*tf.math.log((k**p + (n/5)**p)**(1/p))
    rate = C + (n-1)*log2_e*tf.math.log(k + .112*n**2/(n/1.8+k) )
    return K.mean(rate)

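# Editorial note: C + (n-1)*log2(k) with C = n - log2(Gamma(n)) is the
# leading-order size, in bits, of the set of n-dim integer vectors with about
# k unit pulses, so rate_loss() penalizes the L1 norm of the latents roughly
# like a PVQ codebook rate; the extra .112*n^2/(n/1.8+k) term keeps the
# estimate well-behaved as k approaches 0.
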
eps=1e-6
def safelog2(x):
    log2_e = 1.4427
    return log2_e*tf.math.log(eps+x)

def feat_dist_loss(y_true,y_pred):
    # y_pred is the decoder output stacked by tensor_concat() with the
    # rate-control lambda appended as a final channel; y_true holds the
    # 20 target features: 18 cepstra, pitch, and pitch correlation. The
    # distortion is weighted by 1/sqrt(lambda) to trade off against the
    # lambda-weighted rate losses.
    lambda_1 = 1./K.sqrt(y_pred[:,:,:,-1])
    y_pred = y_pred[:,:,:,:-1]
    ceps = y_pred[:,:,:,:18] - y_true[:,:,:18]
    pitch = 2*(y_pred[:,:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
    corr = y_pred[:,:,:,19:] - y_true[:,:,19:]
    pitch_weight = K.square(K.maximum(0., y_true[:,:,19:]+.5))
    return K.mean(lambda_1*K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr), axis=-1))

def sq1_rate_loss(y_true,y_pred):
    # y_pred packs, per step: n soft latent values, n parameters shaping the
    # zero-symbol probability p0, n geometric decay factors r, then lambda.
    lambda_val = K.sqrt(y_pred[:,:,-1])
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = (y_pred[:,:,2*n:])
    p0 = (y_pred[:,:,n:2*n])
    p0 = 1-r**(.5+.5*p0)
    y_pred = y_pred[:,:,:n]
    y_pred = soft_quantize(y_pred)

    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    # Earlier rate models, superseded by the final assignment below:
    #rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    #rate = -safelog2(-.5*tf.math.log(r)*r**K.abs(y_pred))
    rate = -safelog2((1-r)/(1+r)*r**K.abs(y_pred))
    #rate = -safelog2(- tf.math.sinh(.5*tf.math.log(r))* r**K.abs(y_pred) - tf.math.cosh(K.maximum(0., .5 - K.abs(y_pred))*tf.math.log(r)) + 1)
    rate = lambda_val*K.sum(rate, axis=-1)
    return K.mean(rate)

def sq2_rate_loss(y_true,y_pred):
    # Same packed layout as sq1_rate_loss, but the two-sided geometric rate
    # model is evaluated on hard-rounded latents.
    lambda_val = K.sqrt(y_pred[:,:,-1])
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = y_pred[:,:,2*n:]
    p0 = y_pred[:,:,n:2*n]
    p0 = 1-r**(.5+.5*p0)
    #theta = K.minimum(1., .5 + 0*p0 - 0.04*tf.math.log(r))
    #p0 = 1-r**theta
    y_pred = tf.round(y_pred[:,:,:n])
    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    rate = lambda_val*K.sum(rate, axis=-1)
    return K.mean(rate)

def sq_rate_metric(y_true,y_pred, reduce=True):
    # sq2_rate_loss without the lambda weighting: reports the estimated
    # rate in bits per step.
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = y_pred[:,:,2*n:]
    p0 = y_pred[:,:,n:2*n]
    p0 = 1-r**(.5+.5*p0)
    #theta = K.minimum(1., .5 + 0*p0 - 0.04*tf.math.log(r))
    #p0 = 1-r**theta
    y_pred = tf.round(y_pred[:,:,:n])
    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    rate = K.sum(rate, axis=-1)
    if reduce:
        rate = K.mean(rate)
    return rate

def pvq_quant_search(x, k):
    # Iteratively find an integer vector y with sum(|y_i|) == k pointing in
    # the direction of x: scale x and round, then grow or shrink the scale
    # depending on whether the rounded vector has too few or too many unit
    # pulses.
    x = x/tf.reduce_sum(tf.abs(x), axis=-1, keepdims=True)
    kx = k*x
    y = tf.round(kx)
    newk = k

    for j in range(10):
        abs_y = tf.abs(y)
        abs_kx = tf.abs(kx)
        kk = tf.reduce_sum(abs_y, axis=-1)
        # Smallest upscale that adds a pulse / largest downscale that
        # removes one.
        plus = 1.000001*tf.reduce_min((abs_y+.5)/(abs_kx+1e-15), axis=-1)
        minus = .999999*tf.reduce_max((abs_y-.5)/(abs_kx+1e-15), axis=-1)
        factor = tf.where(kk>k, minus, plus)
        factor = tf.where(kk==k, tf.ones_like(factor), factor)
        factor = tf.expand_dims(factor, axis=-1)
        #newk = newk * (k/kk)**.2
        newk = newk*factor
        kx = newk*x
        y = tf.round(kx)

    return y

def pvq_quantize(x, k):
    # Normalize, quantize to an integer PVQ codeword, renormalize to the
    # unit sphere, and pass gradients straight through.
    x = x/(1e-15+tf.norm(x, axis=-1,keepdims=True))
    quantized = pvq_quant_search(x, k)
    quantized = quantized/(1e-15+tf.norm(quantized, axis=-1,keepdims=True))
    return x + tf.stop_gradient(quantized - x)

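# Sketch (editorial illustration): pvq_quantize() keeps the decoder state on
# the unit sphere -- normalize, search for an integer codeword with
# (approximately) k unit pulses, renormalize, straight-through gradient.
def _example_pvq_quantize():
    state = tf.random.normal((1, 3, 24))   # 24 matches nb_state_dim below
    q = pvq_quantize(state, 82)
    return tf.norm(q, axis=-1)             # ~1.0 everywhere
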

def var_repeat(x):
    # Repeat x[0] along time to match the (dynamic) sequence length of x[1].
    return tf.repeat(tf.expand_dims(x[0], 1), K.shape(x[1])[1], axis=1)

nb_state_dim = 24

def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)

    gru = CuDNNGRU if training else GRU
    enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
    enc_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
    enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
    enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
    enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
    enc_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
    enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
    enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')

    #bits_dense = Dense(nb_bits, activation='linear', name='bits_dense')
    bits_dense = Conv1D(nb_bits, 4, padding='causal', activation='linear', name='bits_dense')

    # Pair up consecutive feature frames: each encoder step sees two frames.
    inputs = Reshape((-1, 2*nb_used_features))(feat)
    d1 = enc_dense1(inputs)
    d2 = enc_dense2(d1)
    d3 = enc_dense3(d2)
    d4 = enc_dense4(d3)
    d5 = enc_dense5(d4)
    d6 = enc_dense6(d5)
    d7 = enc_dense7(d6)
    d8 = enc_dense8(d7)
    pre_out = Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])
    enc_out = bits_dense(pre_out)
    global_dense1 = Dense(128, activation='tanh', name='gdense1')
    global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
    global_bits = global_dense2(global_dense1(pre_out))

    encoder = Model([feat], [enc_out, global_bits], name='encoder')
    return encoder

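# Sketch of a forward pass (editorial illustration; shapes follow from the
# defaults above):
#   enc = new_rdovae_encoder(batch_size=1)
#   feats = np.zeros((1, 400, 20), dtype='float32')   # 400 feature frames
#   z, state = enc.predict(feats)
#   # z:     (1, 200, 17) -- one latent per pair of frames (the Reshape
#   #                        halves the time resolution)
#   # state: (1, 200, 24) -- candidate decoder-initialization state per step
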
def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
    bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
    gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")

    gru = CuDNNGRU if training else GRU
    dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
    dec_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense2')
    dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
    dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
    dec_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense5')
    dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
    dec_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
    dec_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense8')

    dec_final = Dense(bunch*nb_used_features, activation='linear', name='dec_final')

    # The decoder runs backwards in time: the input is flipped before the
    # GRUs and the output flipped back, so the initial GRU states describe
    # the end of the chunk being decoded.
    time_reverse = Lambda(lambda x: K.reverse(x, 1))
    #time_reverse = Lambda(lambda x: x)
    #gru_state_rep = RepeatVector(64//bunch)(gru_state_input)

    #gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
    gru_state1 = Dense(cond_size, name="state1", activation='tanh')(gru_state_input)
    gru_state2 = Dense(cond_size, name="state2", activation='tanh')(gru_state_input)
    gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)

    dec1 = dec_dense1(time_reverse(bits_input))
    dec2 = dec_dense2(dec1, initial_state=gru_state1)
    dec3 = dec_dense3(dec2)
    dec4 = dec_dense4(dec3, initial_state=gru_state2)
    dec5 = dec_dense5(dec4)
    dec6 = dec_dense6(dec5, initial_state=gru_state3)
    dec7 = dec_dense7(dec6)
    dec8 = dec_dense8(dec7)
    output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))
    decoder = Model([bits_input, gru_state_input], time_reverse(output), name='decoder')
    decoder.nb_bits = nb_bits
    decoder.bunch = bunch
    return decoder

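# Sketch of a forward pass (editorial illustration): each decoder step emits
# bunch (= 4) feature frames, and the 24-dim state seeds the GRUs through
# the state1..state3 projections.
#   dec = new_rdovae_decoder(batch_size=1)
#   z = np.zeros((1, 50, 17), dtype='float32')
#   state = np.zeros((1, 24), dtype='float32')
#   feats = dec.predict([z, state])                   # (1, 200, 20)
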
def new_split_decoder(decoder):
    # Decode the sequence in fixed-size chunks, each initialized from the
    # (quantized) state taken at the end of that chunk, mimicking a receiver
    # that restarts decoding from an arbitrary point in the stream.
    nb_bits = decoder.nb_bits
    bunch = decoder.bunch
    bits_input = Input(shape=(None, nb_bits), name="split_bits")
    gru_state_input = Input(shape=(None,nb_state_dim), name="split_state")

    range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
    elem_select = Lambda(lambda x: x[0][:,x[1],:])
    points = [0, 100, 200, 300, 400]
    outputs = []
    for i in range(len(points)-1):
        begin = points[i]//bunch
        end = points[i+1]//bunch
        state = elem_select([gru_state_input, end-1])
        bits = range_select([bits_input, begin, end])
        outputs.append(decoder([bits, state]))
    output = Concatenate(axis=1)(outputs)
    split = Model([bits_input, gru_state_input], output, name="split")
    return split

def tensor_concat(x):
    # Align the outputs decoded at the different latent offsets and stack
    # them along a new leading axis (over which feat_dist_loss() averages).
    #n = x[1]//2
    #x = x[0]
    n=2
    y = []
    for i in range(n-1):
        offset = 2 * (n-1-i)
        tmp = K.concatenate([x[i][:, offset:, :], x[-1][:, -offset:, :]], axis=-2)
        y.append(tf.expand_dims(tmp, axis=0))
    y.append(tf.expand_dims(x[-1], axis=0))
    return Concatenate(axis=0)(y)

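# Editorial note: with n = 2, tensor_concat() receives the decodes produced
# from the even and odd latent offsets; each offset shifts the output by two
# frames, so the offset-0 decode is trimmed by two frames (its tail padded
# from the other decode) before the two are stacked on the leading axis.
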
def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):

    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
    quant_id = Input(shape=(None,), batch_size=batch_size)
    lambda_val = Input(shape=(None, 1), batch_size=batch_size)
    lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
    lambda_up = Lambda(lambda x: K.repeat_elements(x, 2, axis=-2))(lambda_val)

    # Per-quantizer embedding: nb_bits scale factors, nb_bits dead-zone
    # widths, then 2*nb_bits parameters each for the soft and hard rate
    # models.
    qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
    quant_embed_dec = qembedding(quant_id)
    quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed_dec))

    encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
    ze, gru_state_dec = encoder([feat])
    ze = Multiply()([ze, quant_scale])

    decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
    split_decoder = new_split_decoder(decoder)

    dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
    soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
    hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))

    noisequant = UniformNoise()
    hardquant = Lambda(hard_quantize)
    dzone = Lambda(apply_dead_zone)
    dze = dzone([ze,dead_zone])
    ndze = noisequant(dze)
    dze_quant = hardquant(dze)

    div = Lambda(lambda x: x[0]/x[1])
    dze_quant = div([dze_quant,quant_scale])
    ndze_unquant = div([ndze,quant_scale])

    # Decode from every (bunch//2)-step latent offset so the model learns to
    # start from any transmitted state; the state itself is PVQ-quantized.
    mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
    gru_state_dec = Lambda(lambda x: pvq_quantize(x, 82))(gru_state_dec)
    combined_output = []
    unquantized_output = []
    cat = Concatenate(name="out_cat")
    for i in range(bunch//2):
        dze_select = mod_select([dze_quant, i])
        ndze_select = mod_select([ndze_unquant, i])
        state_select = mod_select([gru_state_dec, i])

        tmp = split_decoder([dze_select, state_select])
        tmp = cat([tmp, lambda_up])
        combined_output.append(tmp)

        tmp = split_decoder([ndze_select, state_select])
        tmp = cat([tmp, lambda_up])
        unquantized_output.append(tmp)

    concat = Lambda(tensor_concat, name="output")
    combined_output = concat(combined_output)
    unquantized_output = concat(unquantized_output)

    e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_val])
    e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_val])

    model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, e, e2], name="end2end")
    model.nb_used_features = nb_used_features

    return model, encoder, decoder, qembedding
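
# Sketch of how the pieces might be wired up for training (the optimizer,
# loss weights, and loss assignment here are editorial assumptions, not
# taken from the accompanying training script):
#   model, encoder, decoder, qembedding = new_rdovae_model(
#       batch_size=32, training=True)
#   model.compile(
#       optimizer='adam',
#       loss=[feat_dist_loss, feat_dist_loss, sq1_rate_loss, sq2_rate_loss],
#       loss_weights=[0.5, 0.5, 1.0, 0.15],
#       metrics={'hard_bits': sq_rate_metric})
#   # inputs: [features, quantizer id, lambda]; the four outputs pair the
#   # hard- and noise-quantized decodes with the soft/hard rate estimates.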