#!/usr/bin/python3
'''Copyright (c) 2022 Amazon

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

import math
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise, AveragePooling1D, RepeatVector
from tensorflow.compat.v1.keras.layers import CuDNNGRU
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.regularizers import l1
import numpy as np
import h5py
from uniform_noise import UniformNoise

class WeightClip(Constraint):
    '''Clips the weights incident to each hidden unit to be inside a range
    '''
    def __init__(self, c=2):
        self.c = c

    def __call__(self, p):
        # Ensure that the absolute values of adjacent weights don't sum to
        # more than 127. Otherwise there's a risk of saturation when
        # implementing dot products with SSSE3 or AVX2.
        return self.c*p/tf.maximum(self.c, tf.repeat(tf.abs(p[:, 1::2])+tf.abs(p[:, 0::2]), 2, axis=1))
        #return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {'name': self.__class__.__name__,
                'c': self.c}

constraint = WeightClip(0.496)
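# Worked example of the constraint above (numbers for illustration only):
# with c = 0.496, an adjacent weight pair with |w0| + |w1| = 0.8 > c gets
# scaled by c/0.8 = 0.62, leaving |w0| + |w1| = 0.496, while pairs already
# summing to at most c pass through unchanged (the tf.maximum keeps the
# divisor at c).  SSSE3's pmaddubsw (and its AVX2 counterpart) multiplies
# byte pairs and sums adjacent products into a saturating 16-bit result,
# which is why the bound is applied to pairs rather than individual weights.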
def soft_quantize(x):
    #x = 4*x
    #x = x - (.25/math.pi)*tf.math.sin(2*math.pi*x)
    #x = x - (.25/math.pi)*tf.math.sin(2*math.pi*x)
    #x = x - (.25/math.pi)*tf.math.sin(2*math.pi*x)
    return x

def noise_quantize(x):
    return soft_quantize(x + (K.random_uniform((128, 16, 80))-.5))

def hard_quantize(x):
    # Round to the nearest integer, but let gradients flow through unchanged
    # (straight-through estimator).
    x = soft_quantize(x)
    quantized = tf.round(x)
    return x + tf.stop_gradient(quantized - x)

def apply_dead_zone(x):
    # Shrink values toward zero by up to d, widening the quantizer's zero bin.
    d = x[1]*.05
    x = x[0]
    y = x - d*tf.math.tanh(x/(.1+d))
    return y

def rate_loss(y_true, y_pred):
    # Smooth approximation of the rate (in bits) as a function of the L1 norm
    # k of an n-dimensional latent vector.
    log2_e = 1.4427
    n = y_pred.shape[-1]
    C = n - log2_e*math.lgamma(n)  # log(gamma(n)) via lgamma avoids overflow for large n
    k = K.sum(K.abs(y_pred), axis=-1)
    p = 1.5
    #rate = C + (n-1)*log2_e*tf.math.log((k**p + (n/5)**p)**(1/p))
    rate = C + (n-1)*log2_e*tf.math.log(k + .112*n**2/(n/1.8+k))
    return K.mean(rate)

eps = 1e-6
def safelog2(x):
    log2_e = 1.4427
    return log2_e*tf.math.log(eps+x)

def feat_dist_loss(y_true, y_pred):
    # y_pred is stacked along a leading axis by tensor_concat and carries the
    # rate-distortion trade-off lambda in its last channel; the features are
    # 18 cepstral coefficients, pitch (index 18) and pitch correlation (19).
    lambda_1 = 1./K.sqrt(y_pred[:,:,:,-1])
    y_pred = y_pred[:,:,:,:-1]
    ceps = y_pred[:,:,:,:18] - y_true[:,:,:18]
    pitch = 2*(y_pred[:,:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
    corr = y_pred[:,:,:,19:] - y_true[:,:,19:]
    pitch_weight = K.square(K.maximum(0., y_true[:,:,19:]+.5))
    return K.mean(lambda_1*K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr), axis=-1))

def sq1_rate_loss(y_true, y_pred):
    lambda_val = K.sqrt(y_pred[:,:,-1])
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = y_pred[:,:,2*n:]
    p0 = y_pred[:,:,n:2*n]
    p0 = 1-r**(.5+.5*p0)
    y_pred = y_pred[:,:,:n]
    y_pred = soft_quantize(y_pred)

    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    # Only the last of these candidate rate expressions is live; the earlier
    # variants are kept commented out as documented experiments.
    #rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    #rate = -safelog2(-.5*tf.math.log(r)*r**K.abs(y_pred))
    rate = -safelog2((1-r)/(1+r)*r**K.abs(y_pred))
    #rate = -safelog2(- tf.math.sinh(.5*tf.math.log(r))* r**K.abs(y_pred) - tf.math.cosh(K.maximum(0., .5 - K.abs(y_pred))*tf.math.log(r)) + 1)
    rate = lambda_val*K.sum(rate, axis=-1)
    return K.mean(rate)

def sq2_rate_loss(y_true, y_pred):
    lambda_val = K.sqrt(y_pred[:,:,-1])
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = y_pred[:,:,2*n:]
    p0 = y_pred[:,:,n:2*n]
    p0 = 1-r**(.5+.5*p0)
    #theta = K.minimum(1., .5 + 0*p0 - 0.04*tf.math.log(r))
    #p0 = 1-r**theta
    y_pred = tf.round(y_pred[:,:,:n])
    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    rate = lambda_val*K.sum(rate, axis=-1)
    return K.mean(rate)
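# The probability model shared by the rate losses above and the metric below
# (a sketch of the discrete Laplace-style model implied by the code): with
# decay r in (0, 1) and zero probability p0,
#   P(q = 0)    = p0
#   P(q = +/-k) = .5*(1-p0)*(1-r)*r**(k-1),   k >= 1
# which sums to one over all integer q.  The two-branch expression selected
# by the y0 mask is then just -log2 P(q), for q == 0 and |q| >= 1 respectively.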
def sq_rate_metric(y_true, y_pred, reduce=True):
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = y_pred[:,:,2*n:]
    p0 = y_pred[:,:,n:2*n]
    p0 = 1-r**(.5+.5*p0)
    #theta = K.minimum(1., .5 + 0*p0 - 0.04*tf.math.log(r))
    #p0 = 1-r**theta
    y_pred = tf.round(y_pred[:,:,:n])
    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    rate = K.sum(rate, axis=-1)
    if reduce:
        rate = K.mean(rate)
    return rate

def pvq_quant_search(x, k):
    # Project x onto the unit L1 ball, then iteratively search for a scale
    # such that rounding yields exactly k unit pulses (a PVQ codebook search).
    x = x/tf.reduce_sum(tf.abs(x), axis=-1, keepdims=True)
    kx = k*x
    y = tf.round(kx)
    newk = k

    for j in range(10):
        abs_y = tf.abs(y)
        abs_kx = tf.abs(kx)
        kk = tf.reduce_sum(abs_y, axis=-1)
        # Smallest upscale that adds a pulse and largest downscale that
        # removes one, nudged slightly past the rounding threshold.
        plus = 1.000001*tf.reduce_min((abs_y+.5)/(abs_kx+1e-15), axis=-1)
        minus = .999999*tf.reduce_max((abs_y-.5)/(abs_kx+1e-15), axis=-1)
        factor = tf.where(kk>k, minus, plus)
        factor = tf.where(kk==k, tf.ones_like(factor), factor)
        factor = tf.expand_dims(factor, axis=-1)
        #newk = newk * (k/kk)**.2
        newk = newk*factor
        kx = newk*x
        y = tf.round(kx)

    return y

def pvq_quantize(x, k):
    # Quantize x to a unit vector derived from a k-pulse PVQ codeword, with a
    # straight-through gradient.
    x = x/(1e-15+tf.norm(x, axis=-1, keepdims=True))
    quantized = pvq_quant_search(x, k)
    quantized = quantized/(1e-15+tf.norm(quantized, axis=-1, keepdims=True))
    return x + tf.stop_gradient(quantized - x)


def var_repeat(x):
    return tf.repeat(tf.expand_dims(x[0], 1), K.shape(x[1])[1], axis=1)

nb_state_dim = 24

def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)

    # CuDNNGRU trains much faster on GPU; the plain GRU is used for inference.
    gru = CuDNNGRU if training else GRU
    enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
    enc_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
    enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
    enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
    enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
    enc_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
    enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
    enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')

    #bits_dense = Dense(nb_bits, activation='linear', name='bits_dense')
    bits_dense = Conv1D(nb_bits, 4, padding='causal', activation='linear', name='bits_dense')

    zero_out = Lambda(lambda x: 0*x)
    # Process the features two frames at a time.
    inputs = Reshape((-1, 2*nb_used_features))(feat)
    d1 = enc_dense1(inputs)
    d2 = enc_dense2(d1)
    d3 = enc_dense3(d2)
    d4 = enc_dense4(d3)
    d5 = enc_dense5(d4)
    d6 = enc_dense6(d5)
    d7 = enc_dense7(d6)
    d8 = enc_dense8(d7)
    pre_out = Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])
    enc_out = bits_dense(pre_out)
    global_dense1 = Dense(128, activation='tanh', name='gdense1')
    global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
    global_bits = global_dense2(global_dense1(pre_out))

    encoder = Model([feat], [enc_out, global_bits], name='encoder')
    return encoder
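# Usage sketch (shapes for illustration, using the defaults above): the
# encoder consumes 20-dim feature frames in pairs and emits one 17-dim latent
# vector plus one 24-dim decoder-state vector per pair of frames:
#   enc = new_rdovae_encoder(batch_size=1)
#   z, state = enc(np.zeros((1, 400, 20), dtype='float32'))
#   # z: (1, 200, 17), state: (1, 200, 24)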
def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
    bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
    gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")

    gru = CuDNNGRU if training else GRU
    dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
    dec_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense2')
    dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
    dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
    dec_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense5')
    dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
    dec_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
    dec_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense8')

    dec_final = Dense(bunch*nb_used_features, activation='linear', name='dec_final')

    # The decoder runs backwards in time, so its initial GRU states line up
    # with the end of the chunk being decoded.
    time_reverse = Lambda(lambda x: K.reverse(x, 1))
    #time_reverse = Lambda(lambda x: x)
    #gru_state_rep = RepeatVector(64//bunch)(gru_state_input)
    #gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
    gru_state1 = Dense(cond_size, name="state1", activation='tanh')(gru_state_input)
    gru_state2 = Dense(cond_size, name="state2", activation='tanh')(gru_state_input)
    gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)

    dec1 = dec_dense1(time_reverse(bits_input))
    dec2 = dec_dense2(dec1, initial_state=gru_state1)
    dec3 = dec_dense3(dec2)
    dec4 = dec_dense4(dec3, initial_state=gru_state2)
    dec5 = dec_dense5(dec4)
    dec6 = dec_dense6(dec5, initial_state=gru_state3)
    dec7 = dec_dense7(dec6)
    dec8 = dec_dense8(dec7)
    output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))
    decoder = Model([bits_input, gru_state_input], time_reverse(output), name='decoder')
    decoder.nb_bits = nb_bits
    decoder.bunch = bunch
    return decoder

def new_split_decoder(decoder):
    # Decode four independent chunks, each initialized from the encoder state
    # taken at the end of that chunk.
    nb_bits = decoder.nb_bits
    bunch = decoder.bunch
    bits_input = Input(shape=(None, nb_bits), name="split_bits")
    gru_state_input = Input(shape=(None, nb_state_dim), name="split_state")

    range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
    elem_select = Lambda(lambda x: x[0][:,x[1],:])
    points = [0, 100, 200, 300, 400]
    outputs = []
    for i in range(len(points)-1):
        begin = points[i]//bunch
        end = points[i+1]//bunch
        state = elem_select([gru_state_input, end-1])
        bits = range_select([bits_input, begin, end])
        outputs.append(decoder([bits, state]))
    output = Concatenate(axis=1)(outputs)
    split = Model([bits_input, gru_state_input], output, name="split")
    return split
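# Chunk arithmetic (derived from the code above, for illustration): with
# bunch = 4, the boundaries [0, 100, 200, 300, 400] in output frames map to
# latent steps [0, 25, 50, 75, 100], so chunk i is decoded from
# bits[:, 25*i:25*(i+1), :] with the state taken at the chunk's last latent
# step, matching the decoder's time-reversed processing.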
def tensor_concat(x):
    # Align the outputs decoded at different latent offsets in time and stack
    # them along a new leading axis.
    #n = x[1]//2
    #x = x[0]
    n = 2
    y = []
    for i in range(n-1):
        offset = 2 * (n-1-i)
        tmp = K.concatenate([x[i][:, offset:, :], x[-1][:, -offset:, :]], axis=-2)
        y.append(tf.expand_dims(tmp, axis=0))

    y.append(tf.expand_dims(x[-1], axis=0))
    return Concatenate(axis=0)(y)


def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):

    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
    quant_id = Input(shape=(None,), batch_size=batch_size)
    lambda_val = Input(shape=(None, 1), batch_size=batch_size)
    lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
    lambda_up = Lambda(lambda x: K.repeat_elements(x, 2, axis=-2))(lambda_val)

    # Per-quantizer embeddings indexed by quant_id: scale and dead zone
    # (nb_bits each), plus the (r, p0) parameters of the soft and hard rate
    # models (2*nb_bits each).
    qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
    quant_embed_dec = qembedding(quant_id)
    quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed_dec))

    encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
    ze, gru_state_dec = encoder([feat])
    ze = Multiply()([ze, quant_scale])

    decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
    split_decoder = new_split_decoder(decoder)

    dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
    soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
    hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))

    # Additive uniform noise stands in for quantization during training;
    # hard_quantize applies real rounding with a straight-through gradient.
    noisequant = UniformNoise()
    hardquant = Lambda(hard_quantize)
    dzone = Lambda(apply_dead_zone)
    dze = dzone([ze, dead_zone])
    ndze = noisequant(dze)
    dze_quant = hardquant(dze)

    div = Lambda(lambda x: x[0]/x[1])
    dze_quant = div([dze_quant, quant_scale])
    ndze_unquant = div([ndze, quant_scale])

    mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
    # Quantize the decoder's initial state with an 82-pulse PVQ quantizer.
    gru_state_dec = Lambda(lambda x: pvq_quantize(x, 82))(gru_state_dec)
    combined_output = []
    unquantized_output = []
    cat = Concatenate(name="out_cat")
    for i in range(bunch//2):
        dze_select = mod_select([dze_quant, i])
        ndze_select = mod_select([ndze_unquant, i])
        state_select = mod_select([gru_state_dec, i])

        tmp = split_decoder([dze_select, state_select])
        tmp = cat([tmp, lambda_up])
        combined_output.append(tmp)

        tmp = split_decoder([ndze_select, state_select])
        tmp = cat([tmp, lambda_up])
        unquantized_output.append(tmp)

    concat = Lambda(tensor_concat, name="output")
    combined_output = concat(combined_output)
    unquantized_output = concat(unquantized_output)

    e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_val])
    e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_val])

    model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, e, e2], name="end2end")
    model.nb_used_features = nb_used_features

    return model, encoder, decoder, qembedding
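# Minimal smoke test (a sketch, not part of the training pipeline; assumes
# the companion uniform_noise module is importable and that the default layer
# sizes fit in memory): build the end-to-end model and print its summary.
if __name__ == '__main__':
    model, encoder, decoder, qembedding = new_rdovae_model(batch_size=1)
    model.summary()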