#!/usr/bin/python3
'''Copyright (c) 2022 Amazon

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

import math
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise, AveragePooling1D, RepeatVector
from tensorflow.compat.v1.keras.layers import CuDNNGRU
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.regularizers import l1
import numpy as np
import h5py
from uniform_noise import UniformNoise

class WeightClip(Constraint):
    '''Clips the weights incident to each hidden unit to be inside a range
    '''
    def __init__(self, c=2):
        self.c = c

    def __call__(self, p):
        # Ensure that abs of adjacent weights don't sum to more than 127. Otherwise there's a risk of
        # saturation when implementing dot products with SSSE3 or AVX2.
        return self.c*p/tf.maximum(self.c, tf.repeat(tf.abs(p[:, 1::2])+tf.abs(p[:, 0::2]), 2, axis=1))
        #return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {'name': self.__class__.__name__,
                'c': self.c}

constraint = WeightClip(0.496)
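
# Illustration only (a minimal sketch, assuming eager TensorFlow): any adjacent
# weight pair whose magnitudes sum to more than c gets rescaled onto that
# bound, so |w[2i]| + |w[2i+1]| <= c always holds after the constraint:
#
#   w = tf.random.uniform((8, 6), -1., 1.)
#   clipped = WeightClip(0.496)(w)
#   pair_sums = tf.abs(clipped[:, 0::2]) + tf.abs(clipped[:, 1::2])
#   assert float(tf.reduce_max(pair_sums)) <= 0.496 + 1e-6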

def soft_quantize(x):
    #x = 4*x
    #x = x - (.25/np.math.pi)*tf.math.sin(2*np.math.pi*x)
    #x = x - (.25/np.math.pi)*tf.math.sin(2*np.math.pi*x)
    #x = x - (.25/np.math.pi)*tf.math.sin(2*np.math.pi*x)
    return x

def noise_quantize(x):
    return soft_quantize(x + (K.random_uniform((128, 16, 80))-.5) )

def hard_quantize(x):
    x = soft_quantize(x)
    quantized = tf.round(x)
    return x + tf.stop_gradient(quantized - x)
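
# hard_quantize() is the usual straight-through estimator: the forward pass
# emits tf.round(x), while tf.stop_gradient hides the rounding from
# backpropagation, so the gradient is that of the identity. Sketch:
#
#   x = tf.constant([0.2, 1.7, -0.6])
#   hard_quantize(x)   # forward value: [0., 2., -1.]
#   # d hard_quantize / dx == 1 everywhere, so training "sees through" round()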

def apply_dead_zone(x):
    d = x[1]*.05
    x = x[0]
    y = x - d*tf.math.tanh(x/(.1+d))
    return y

def rate_loss(y_true,y_pred):
    log2_e = 1.4427
    n = y_pred.shape[-1]
    # math.lgamma(n) == log(gamma(n)); avoids the removed np.math alias and
    # overflow in gamma() for large n.
    C = n - log2_e*math.lgamma(n)
    k = K.sum(K.abs(y_pred), axis=-1)
    p = 1.5
    #rate = C + (n-1)*log2_e*tf.math.log((k**p + (n/5)**p)**(1/p))
    rate = C + (n-1)*log2_e*tf.math.log(k + .112*n**2/(n/1.8+k) )
    return K.mean(rate)

eps=1e-6
def safelog2(x):
    log2_e = 1.4427
    return log2_e*tf.math.log(eps+x)

def feat_dist_loss(y_true,y_pred):
    lambda_1 = 1./K.sqrt(y_pred[:,:,:,-1])
    y_pred = y_pred[:,:,:,:-1]
    ceps = y_pred[:,:,:,:18] - y_true[:,:,:18]
    pitch = 2*(y_pred[:,:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
    corr = y_pred[:,:,:,19:] - y_true[:,:,19:]
    pitch_weight = K.square(K.maximum(0., y_true[:,:,19:]+.5))
    return K.mean(lambda_1*K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr), axis=-1))

def sq1_rate_loss(y_true,y_pred):
    lambda_val = K.sqrt(y_pred[:,:,-1])
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = y_pred[:,:,2*n:]
    p0 = y_pred[:,:,n:2*n]
    p0 = 1-r**(.5+.5*p0)
    y_pred = y_pred[:,:,:n]
    y_pred = soft_quantize(y_pred)

    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    # Only the last (uncommented) rate expression is active; the earlier
    # variants, which silently overwrote each other in sequence, are kept
    # commented out for reference. y0 and p0 are only needed by the first one.
    #rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    #rate = -safelog2(-.5*tf.math.log(r)*r**K.abs(y_pred))
    rate = -safelog2((1-r)/(1+r)*r**K.abs(y_pred))
    #rate = -safelog2(- tf.math.sinh(.5*tf.math.log(r))* r**K.abs(y_pred) - tf.math.cosh(K.maximum(0., .5 - K.abs(y_pred))*tf.math.log(r)) + 1)
    rate = lambda_val*K.sum(rate, axis=-1)
    return K.mean(rate)
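
# The active expression above is the negative log2-probability of a two-sided
# geometric (discretized Laplace) distribution over the integers,
# P(x) = (1-r)/(1+r) * r**|x|, which is properly normalized. Quick check
# (illustration only):
#
#   r = 0.7
#   total = sum((1 - r)/(1 + r) * r**abs(x) for x in range(-100, 101))
#   # total ~= 1.0, and P(0) ~= 0.176, i.e. about 2.5 bits to code a zero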

def sq2_rate_loss(y_true,y_pred):
    lambda_val = K.sqrt(y_pred[:,:,-1])
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = y_pred[:,:,2*n:]
    p0 = y_pred[:,:,n:2*n]
    p0 = 1-r**(.5+.5*p0)
    #theta = K.minimum(1., .5 + 0*p0 - 0.04*tf.math.log(r))
    #p0 = 1-r**theta
    y_pred = tf.round(y_pred[:,:,:n])
    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    rate = lambda_val*K.sum(rate, axis=-1)
    return K.mean(rate)

def sq_rate_metric(y_true,y_pred, reduce=True):
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = y_pred[:,:,2*n:]
    p0 = y_pred[:,:,n:2*n]
    p0 = 1-r**(.5+.5*p0)
    #theta = K.minimum(1., .5 + 0*p0 - 0.04*tf.math.log(r))
    #p0 = 1-r**theta
    y_pred = tf.round(y_pred[:,:,:n])
    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    rate = K.sum(rate, axis=-1)
    if reduce:
        rate = K.mean(rate)
    return rate

def pvq_quant_search(x, k):
    x = x/tf.reduce_sum(tf.abs(x), axis=-1, keepdims=True)
    kx = k*x
    y = tf.round(kx)
    newk = k

    for j in range(10):
        #print("y = ", y)
        #print("iteration ", j)
        abs_y = tf.abs(y)
        abs_kx = tf.abs(kx)
        kk = tf.reduce_sum(abs_y, axis=-1)
        #print("sums = ", kk)
        plus = 1.000001*tf.reduce_min((abs_y+.5)/(abs_kx+1e-15), axis=-1)
        minus = .999999*tf.reduce_max((abs_y-.5)/(abs_kx+1e-15), axis=-1)
        #print("plus = ", plus)
        #print("minus = ", minus)
        factor = tf.where(kk>k, minus, plus)
        factor = tf.where(kk==k, tf.ones_like(factor), factor)
        #print("scale = ", factor)
        factor = tf.expand_dims(factor, axis=-1)
        #newk = newk * (k/kk)**.2
        newk = newk*factor
        kx = newk*x
        #print("newk = ", newk)
        #print("unquantized = ", newk*x)
        y = tf.round(kx)

    #print(y)
    #print(K.mean(K.sum(K.abs(y), axis=-1)))
    return y

def pvq_quantize(x, k):
    x = x/(1e-15+tf.norm(x, axis=-1,keepdims=True))
    quantized = pvq_quant_search(x, k)
    quantized = quantized/(1e-15+tf.norm(quantized, axis=-1,keepdims=True))
    return x + tf.stop_gradient(quantized - x)
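
# pvq_quant_search() looks for an integer codevector y with exactly k unit
# pulses (sum(|y|) == k) pointing in the same direction as x, scaling x up or
# down until the pulse budget is met. A minimal sketch; in this example the
# search converges on the first iteration:
#
#   x = tf.constant([[0.6, -0.3, 0.1]])
#   y = pvq_quant_search(x, 8)
#   # y == [[5., -2., 1.]] and tf.reduce_sum(tf.abs(y), axis=-1) == 8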

def var_repeat(x):
    return tf.repeat(tf.expand_dims(x[0], 1), K.shape(x[1])[1], axis=1)

nb_state_dim = 24

def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)

    gru = CuDNNGRU if training else GRU
    enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
    enc_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
    enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
    enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
    enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
    enc_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
    enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
    enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')

    #bits_dense = Dense(nb_bits, activation='linear', name='bits_dense')
    bits_dense = Conv1D(nb_bits, 4, padding='causal', activation='linear', name='bits_dense')

    zero_out = Lambda(lambda x: 0*x)
    inputs = Reshape((-1, 2*nb_used_features))(feat)
    d1 = enc_dense1(inputs)
    d2 = enc_dense2(d1)
    d3 = enc_dense3(d2)
    d4 = enc_dense4(d3)
    d5 = enc_dense5(d4)
    d6 = enc_dense6(d5)
    d7 = enc_dense7(d6)
    d8 = enc_dense8(d7)
    pre_out = Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])
    enc_out = bits_dense(pre_out)
    global_dense1 = Dense(128, activation='tanh', name='gdense1')
    global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
    global_bits = global_dense2(global_dense1(pre_out))

    encoder = Model([feat], [enc_out, global_bits], name='encoder')
    return encoder
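
# Shape sketch (illustration only): the encoder pairs up consecutive feature
# frames, so a (batch, 2*T, nb_used_features) input yields latent vectors of
# shape (batch, T, nb_bits) plus a per-step decoder-state prediction of shape
# (batch, T, nb_state_dim):
#
#   enc = new_rdovae_encoder(batch_size=1)
#   z, state = enc(tf.zeros((1, 400, 20)))
#   # z.shape == (1, 200, 17), state.shape == (1, 200, 24)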

def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
    bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
    gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")

    gru = CuDNNGRU if training else GRU
    dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
    dec_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense2')
    dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
    dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
    dec_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense5')
    dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
    dec_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
    dec_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense8')

    dec_final = Dense(bunch*nb_used_features, activation='linear', name='dec_final')

    time_reverse = Lambda(lambda x: K.reverse(x, 1))
    #time_reverse = Lambda(lambda x: x)
    #gru_state_rep = RepeatVector(64//bunch)(gru_state_input)
    #gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
    gru_state1 = Dense(cond_size, name="state1", activation='tanh')(gru_state_input)
    gru_state2 = Dense(cond_size, name="state2", activation='tanh')(gru_state_input)
    gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)

    dec1 = dec_dense1(time_reverse(bits_input))
    dec2 = dec_dense2(dec1, initial_state=gru_state1)
    dec3 = dec_dense3(dec2)
    dec4 = dec_dense4(dec3, initial_state=gru_state2)
    dec5 = dec_dense5(dec4)
    dec6 = dec_dense6(dec5, initial_state=gru_state3)
    dec7 = dec_dense7(dec6)
    dec8 = dec_dense8(dec7)
    output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))
    decoder = Model([bits_input, gru_state_input], time_reverse(output), name='decoder')
    decoder.nb_bits = nb_bits
    decoder.bunch = bunch
    return decoder

def new_split_decoder(decoder):
    nb_bits = decoder.nb_bits
    bunch = decoder.bunch
    bits_input = Input(shape=(None, nb_bits), name="split_bits")
    gru_state_input = Input(shape=(None, nb_state_dim), name="split_state")

    range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
    elem_select = Lambda(lambda x: x[0][:,x[1],:])
    points = [0, 100, 200, 300, 400]
    outputs = []
    for i in range(len(points)-1):
        begin = points[i]//bunch
        end = points[i+1]//bunch
        state = elem_select([gru_state_input, end-1])
        bits = range_select([bits_input, begin, end])
        outputs.append(decoder([bits, state]))
    output = Concatenate(axis=1)(outputs)
    split = Model([bits_input, gru_state_input], output, name="split")
    return split

def tensor_concat(x):
    #n = x[1]//2
    #x = x[0]
    n = 2
    y = []
    for i in range(n-1):
        offset = 2 * (n-1-i)
        tmp = K.concatenate([x[i][:, offset:, :], x[-1][:, -offset:, :]], axis=-2)
        y.append(tf.expand_dims(tmp, axis=0))
    y.append(tf.expand_dims(x[-1], axis=0))
    return Concatenate(axis=0)(y)
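
# new_split_decoder() decodes a 400-frame training sequence as four
# independent 100-frame chunks (25 latent steps each with bunch == 4). Each
# chunk is seeded with the predicted decoder state at the *end* of the chunk,
# because the decoder runs time-reversed; this emulates starting the decoder
# from an arbitrary entry point in the bitstream. Shape sketch (illustration
# only):
#
#   dec = new_rdovae_decoder(batch_size=1)
#   split = new_split_decoder(dec)
#   feats = split([tf.zeros((1, 100, 17)), tf.zeros((1, 100, 24))])
#   # feats.shape == (1, 400, 20)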

def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):

    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
    quant_id = Input(shape=(None,), batch_size=batch_size)
    lambda_val = Input(shape=(None, 1), batch_size=batch_size)
    lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
    lambda_up = Lambda(lambda x: K.repeat_elements(x, 2, axis=-2))(lambda_val)

    qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
    quant_embed_dec = qembedding(quant_id)
    quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed_dec))

    encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
    ze, gru_state_dec = encoder([feat])
    ze = Multiply()([ze, quant_scale])

    decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
    split_decoder = new_split_decoder(decoder)

    dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
    soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
    hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))

    noisequant = UniformNoise()
    hardquant = Lambda(hard_quantize)
    dzone = Lambda(apply_dead_zone)
    dze = dzone([ze,dead_zone])
    ndze = noisequant(dze)
    dze_quant = hardquant(dze)

    div = Lambda(lambda x: x[0]/x[1])
    dze_quant = div([dze_quant,quant_scale])
    ndze_unquant = div([ndze,quant_scale])

    mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
    gru_state_dec = Lambda(lambda x: pvq_quantize(x, 82))(gru_state_dec)
    combined_output = []
    unquantized_output = []
    cat = Concatenate(name="out_cat")
    for i in range(bunch//2):
        dze_select = mod_select([dze_quant, i])
        ndze_select = mod_select([ndze_unquant, i])
        state_select = mod_select([gru_state_dec, i])

        tmp = split_decoder([dze_select, state_select])
        tmp = cat([tmp, lambda_up])
        combined_output.append(tmp)

        tmp = split_decoder([ndze_select, state_select])
        tmp = cat([tmp, lambda_up])
        unquantized_output.append(tmp)

    concat = Lambda(tensor_concat, name="output")
    combined_output = concat(combined_output)
    unquantized_output = concat(unquantized_output)

    e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_val])
    e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_val])

    model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, e, e2], name="end2end")
    model.nb_used_features = nb_used_features

    return model, encoder, decoder, qembedding
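
# Usage sketch (illustration only; the optimizer, lambda value, and loss
# wiring below are placeholders, not the project's training recipe). For 400
# feature frames there are 200 latent steps, so quant_id and lambda_val have
# half the time resolution of feat; the four outputs are the hard-quantized
# and noise-quantized reconstructions plus the soft/hard rate tensors:
#
#   model, encoder, decoder, qembedding = new_rdovae_model(batch_size=1)
#   feats = tf.zeros((1, 400, 20))
#   qid = tf.zeros((1, 200))           # quantizer index in [0, nb_quant)
#   lam = 0.01*tf.ones((1, 200, 1))    # rate/distortion trade-off weight
#   hard_out, soft_out, soft_bits, hard_bits = model([feats, qid, lam])
#   model.compile(optimizer='adam',
#                 loss=[feat_dist_loss, feat_dist_loss, sq1_rate_loss, sq2_rate_loss])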