#!/usr/bin/python3
'''Copyright (c) 2018 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

import math
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise
from tensorflow.compat.v1.keras.layers import CuDNNGRU
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.callbacks import Callback
from mdense import MDense
import numpy as np
import h5py
import sys
from tf_funcs import *
from diffembed import diff_Embed
from parameters import set_parameter

frame_size = 160
pcm_bits = 8
embed_size = 128
pcm_levels = 2**pcm_bits

def interleave(p, samples):
    p2 = tf.expand_dims(p, 3)
    nb_repeats = pcm_levels//(2*p.shape[2])
    p3 = tf.reshape(tf.repeat(tf.concat([1-p2, p2], 3), nb_repeats), (-1, samples, pcm_levels))
    return p3

def tree_to_pdf(p, samples):
    return interleave(p[:,:,1:2], samples) * interleave(p[:,:,2:4], samples) * interleave(p[:,:,4:8], samples) * interleave(p[:,:,8:16], samples) \
         * interleave(p[:,:,16:32], samples) * interleave(p[:,:,32:64], samples) * interleave(p[:,:,64:128], samples) * interleave(p[:,:,128:256], samples)

def tree_to_pdf_train(p):
    #FIXME: try not to hardcode the 2400 samples (15 frames * 160 samples/frame)
    return tree_to_pdf(p, 2400)

def tree_to_pdf_infer(p):
    return tree_to_pdf(p, 1)

def quant_regularizer(x):
    Q = 128
    Q_1 = 1./Q
    #return .01 * tf.reduce_mean(1 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))
    return .01 * tf.reduce_mean(K.sqrt(K.sqrt(1.0001 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))))
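
# Sanity-check sketch for the binary probability tree above. Each slice
# p[:,:,2**k:2**(k+1)] holds the conditional probabilities of one bit of the
# u-law sample, interleave() broadcasts each node over the leaves below it as a
# (1-p, p) pair, and the product of the eight factors is a distribution over
# the 256 u-law levels, so each row of the result should sum to (numerically) 1.
# This helper is illustrative only and is not called by the training code.
def _check_tree_pdf_sums_to_one(batch=2, samples=3):
    p = tf.random.uniform((batch, samples, pcm_levels), minval=.1, maxval=.9)
    sums = tf.reduce_sum(tree_to_pdf(p, samples), axis=-1)
    return bool(tf.reduce_all(tf.abs(sums - 1.) < 1e-4))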

class Sparsify(Callback):
    def __init__(self, t_start, t_end, interval, density, quantize=False):
        super(Sparsify, self).__init__()
        self.batch = 0
        self.t_start = t_start
        self.t_end = t_end
        self.interval = interval
        self.final_density = density
        self.quantize = quantize

    def on_batch_end(self, batch, logs=None):
        #print("batch number", self.batch)
        self.batch += 1
        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
            #print("constrain");
            layer = self.model.get_layer('gru_a')
            w = layer.get_weights()
            p = w[1]
            nb = p.shape[1]//p.shape[0]
            N = p.shape[0]
            #print("nb = ", nb, ", N = ", N)
            #print(p.shape)
            for k in range(nb):
                density = self.final_density[k]
                if self.batch < self.t_end and not self.quantize:
                    r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                    density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                A = p[:, k*N:(k+1)*N]
                A = A - np.diag(np.diag(A))
                #This is needed because of the CuDNNGRU strange weight ordering
                A = np.transpose(A, (1, 0))
                L = np.reshape(A, (N//4, 4, N//8, 8))
                S = np.sum(L*L, axis=-1)
                S = np.sum(S, axis=1)
                SS = np.sort(np.reshape(S, (-1,)))
                thresh = SS[round(N*N//32*(1-density))]
                mask = (S >= thresh).astype('float32')
                mask = np.repeat(mask, 4, axis=0)
                mask = np.repeat(mask, 8, axis=1)
                mask = np.minimum(1, mask + np.diag(np.ones((N,))))
                #This is needed because of the CuDNNGRU strange weight ordering
                mask = np.transpose(mask, (1, 0))
                p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask
                #print(thresh, np.mean(mask))
            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
                if self.batch < self.t_end:
                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
                else:
                    threshold = .5
                quant = np.round(p*128.)
                res = p*128.-quant
                mask = (np.abs(res) <= threshold).astype('float32')
                p = mask/128.*quant + (1-mask)*p

            w[1] = p
            layer.set_weights(w)
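
# Usage sketch: during training, progressively sparsify gru_a's recurrent
# matrix between two batch counts, reapplying the mask at a fixed interval,
# with one target density per gate (update/reset/state in Keras gate order).
# The schedule interpolates densities with a cubic in r so most of the pruning
# happens early. The numbers below are illustrative, not a recommendation:
#
#   sparsify_cb = Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
#   model.fit(..., callbacks=[sparsify_cb])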

class SparsifyGRUB(Callback):
    def __init__(self, t_start, t_end, interval, grua_units, density, quantize=False):
        super(SparsifyGRUB, self).__init__()
        self.batch = 0
        self.t_start = t_start
        self.t_end = t_end
        self.interval = interval
        self.final_density = density
        self.grua_units = grua_units
        self.quantize = quantize

    def on_batch_end(self, batch, logs=None):
        #print("batch number", self.batch)
        self.batch += 1
        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
            #print("constrain");
            layer = self.model.get_layer('gru_b')
            w = layer.get_weights()
            p = w[0]
            N = p.shape[0]
            M = p.shape[1]//3
            for k in range(3):
                density = self.final_density[k]
                if self.batch < self.t_end and not self.quantize:
                    r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                    density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                A = p[:, k*M:(k+1)*M]
                #This is needed because of the CuDNNGRU strange weight ordering
                A = np.reshape(A, (M, N))
                A = np.transpose(A, (1, 0))
                N2 = self.grua_units
                A2 = A[:N2, :]
                L = np.reshape(A2, (N2//4, 4, M//8, 8))
                S = np.sum(L*L, axis=-1)
                S = np.sum(S, axis=1)
                SS = np.sort(np.reshape(S, (-1,)))
                thresh = SS[round(M*N2//32*(1-density))]
                mask = (S >= thresh).astype('float32')
                mask = np.repeat(mask, 4, axis=0)
                mask = np.repeat(mask, 8, axis=1)
                A = np.concatenate([A2*mask, A[N2:,:]], axis=0)
                #This is needed because of the CuDNNGRU strange weight ordering
                A = np.transpose(A, (1, 0))
                A = np.reshape(A, (N, M))
                p[:, k*M:(k+1)*M] = A
                #print(thresh, np.mean(mask))
            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
                if self.batch < self.t_end:
                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
                else:
                    threshold = .5
                quant = np.round(p*128.)
                res = p*128.-quant
                mask = (np.abs(res) <= threshold).astype('float32')
                p = mask/128.*quant + (1-mask)*p

            w[0] = p
            layer.set_weights(w)
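
# Both callbacks prune with the same block-magnitude rule: the weight matrix is
# tiled into 4x8 blocks (the block shape the SIMD dot-product kernels in the C
# inference code operate on), blocks are ranked by their energy (sum of squared
# weights), and the weakest blocks are zeroed until only the requested density
# survives. A minimal standalone sketch of that rule, assuming the matrix
# dimensions are multiples of 4 and 8 (illustrative helper, not called above):
def _block_prune_mask(A, density, rows=4, cols=8):
    N, M = A.shape
    blocks = np.reshape(A, (N//rows, rows, M//cols, cols))
    S = np.sum(blocks*blocks, axis=(1, 3))     # energy of each block
    thresh = np.sort(S.reshape(-1))[round(S.size*(1-density))]
    mask = (S >= thresh).astype('float32')     # keep the top `density` fraction
    return np.repeat(np.repeat(mask, rows, axis=0), cols, axis=1)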

class PCMInit(Initializer):
    def __init__(self, gain=.1, seed=None):
        self.gain = gain
        self.seed = seed

    def __call__(self, shape, dtype=None):
        num_rows = 1
        for dim in shape[:-1]:
            num_rows *= dim
        num_cols = shape[-1]
        flat_shape = (num_rows, num_cols)
        if self.seed is not None:
            np.random.seed(self.seed)
        a = np.random.uniform(-1.7321, 1.7321, flat_shape)
        #a[:,0] = math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows
        #a[:,1] = .5*a[:,0]*a[:,0]*a[:,0]
        a = a + np.reshape(math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows, (num_rows, 1))
        return self.gain * a.astype("float32")

    def get_config(self):
        return {
            'gain': self.gain,
            'seed': self.seed
        }

class WeightClip(Constraint):
    '''Clips the weights incident to each hidden unit to be inside a range
    '''
    def __init__(self, c=2):
        self.c = c

    def __call__(self, p):
        # Ensure that the absolute values of adjacent weights don't sum to more than 127.
        # Otherwise there's a risk of saturation when implementing dot products with SSSE3 or AVX2.
        return self.c*p/tf.maximum(self.c, tf.repeat(tf.abs(p[:, 1::2])+tf.abs(p[:, 0::2]), 2, axis=1))
        #return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {'name': self.__class__.__name__,
                'c': self.c}

constraint = WeightClip(0.992)
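
# With c = 0.992, the constraint rescales each pair of adjacent columns so that
# |w[:, 2i]| + |w[:, 2i+1]| <= c. After quantizing to Q7 (x128) the two weights
# of a pair then sum to at most 127, which is what prevents saturation in the
# int8 SSSE3/AVX2 dot products mentioned above. A quick numerical check
# (illustrative helper, not used elsewhere):
def _check_weight_clip(rows=8, cols=16, c=0.992):
    w = WeightClip(c)(tf.random.uniform((rows, cols), minval=-2., maxval=2.))
    pair_sums = tf.abs(w[:, 0::2]) + tf.abs(w[:, 1::2])
    return bool(tf.reduce_all(pair_sums <= c + 1e-6))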

def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_size=128, training=False, adaptation=False, quantize=False, flag_e2e=False, cond_size=128, lpc_order=16, lpc_gamma=1., lookahead=2):
    pcm = Input(shape=(None, 1), batch_size=batch_size)
    dpcm = Input(shape=(None, 3), batch_size=batch_size)
    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
    pitch = Input(shape=(None, 1), batch_size=batch_size)
    dec_feat = Input(shape=(None, cond_size))
    dec_state1 = Input(shape=(rnn_units1,))
    dec_state2 = Input(shape=(rnn_units2,))

    padding = 'valid' if training else 'same'
    fconv1 = Conv1D(cond_size, 3, padding=padding, activation='tanh', name='feature_conv1')
    fconv2 = Conv1D(cond_size, 3, padding=padding, activation='tanh', name='feature_conv2')
    pembed = Embedding(256, 64, name='embed_pitch')
    cat_feat = Concatenate()([feat, Reshape((-1, 64))(pembed(pitch))])

    cfeat = fconv2(fconv1(cat_feat))

    fdense1 = Dense(cond_size, activation='tanh', name='feature_dense1')
    fdense2 = Dense(cond_size, activation='tanh', name='feature_dense2')

    if flag_e2e and quantize:
        fconv1.trainable = False
        fconv2.trainable = False
        fdense1.trainable = False
        fdense2.trainable = False

    cfeat = fdense2(fdense1(cfeat))

    error_calc = Lambda(lambda x: tf_l2u(x[0] - tf.roll(x[1], 1, axis=1)))
    if flag_e2e:
        lpcoeffs = diff_rc2lpc(name="rc2lpc")(cfeat)
    else:
        lpcoeffs = Input(shape=(None, lpc_order), batch_size=batch_size)

    real_preds = diff_pred(name="real_lpc2preds")([pcm, lpcoeffs])
    weighting = lpc_gamma ** np.arange(1, lpc_order + 1).astype('float32')
    weighted_lpcoeffs = Lambda(lambda x: x[0]*x[1])([lpcoeffs, weighting])
    tensor_preds = diff_pred(name="lpc2preds")([pcm, weighted_lpcoeffs])
    past_errors = error_calc([pcm, tensor_preds])

    embed = diff_Embed(name='embed_sig', initializer=PCMInit())
    cpcm = Concatenate()([tf_l2u(pcm), tf_l2u(tensor_preds), past_errors])
    cpcm = GaussianNoise(.3)(cpcm)
    cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
    cpcm_decoder = Reshape((-1, embed_size*3))(embed(dpcm))

    rep = Lambda(lambda x: K.repeat_elements(x, frame_size, 1))

    quant = quant_regularizer if quantize else None
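
    # Two interchangeable GRU stacks: CuDNNGRU for fast GPU training, and a
    # plain GRU with sigmoid recurrent activation and reset_after=True (the
    # CuDNN gate formulation) for inference, so that weights trained in one
    # mode load in the other. The Sparsify/SparsifyGRUB callbacks above depend
    # on this CuDNN weight ordering when they transpose the matrices.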
    if training:
        rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a', stateful=True,
                       recurrent_constraint=constraint, recurrent_regularizer=quant)
        rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b', stateful=True,
                        kernel_constraint=constraint, recurrent_constraint=constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
    else:
        rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after=True, name='gru_a', stateful=True,
                  recurrent_constraint=constraint, recurrent_regularizer=quant)
        rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after=True, name='gru_b', stateful=True,
                   kernel_constraint=constraint, recurrent_constraint=constraint, kernel_regularizer=quant, recurrent_regularizer=quant)

    rnn_in = Concatenate()([cpcm, rep(cfeat)])
    md = MDense(pcm_levels, activation='sigmoid', name='dual_fc')
    gru_out1, _ = rnn(rnn_in)
    gru_out1 = GaussianNoise(.005)(gru_out1)
    gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
    ulaw_prob = Lambda(tree_to_pdf_train)(md(gru_out2))

    if adaptation:
        rnn.trainable = False
        rnn2.trainable = False
        md.trainable = False
        embed.trainable = False

    m_out = Concatenate(name='pdf')([tensor_preds, real_preds, ulaw_prob])
    if not flag_e2e:
        model = Model([pcm, feat, pitch, lpcoeffs], m_out)
    else:
        model = Model([pcm, feat, pitch], [m_out, cfeat])
    model.rnn_units1 = rnn_units1
    model.rnn_units2 = rnn_units2
    model.nb_used_features = nb_used_features
    model.frame_size = frame_size

    if not flag_e2e:
        encoder = Model([feat, pitch], cfeat)
    else:
        encoder = Model([feat, pitch], [cfeat, lpcoeffs])
    dec_rnn_in = Concatenate()([cpcm_decoder, dec_feat])
    dec_gru_out1, state1 = rnn(dec_rnn_in, initial_state=dec_state1)
    dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2)
    dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2))

    # The decoder has the same structure whether or not the model is end-to-end.
    decoder = Model([dpcm, dec_feat, dec_state1, dec_state2], [dec_ulaw_prob, state1, state2])

    # add parameters to model
    set_parameter(model, 'lpc_gamma', lpc_gamma, dtype='float64')
    set_parameter(model, 'flag_e2e', flag_e2e, dtype='bool')
    set_parameter(model, 'lookahead', lookahead, dtype='int32')

    return model, encoder, decoder
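
if __name__ == '__main__':
    # Minimal smoke test, assuming the companion LPCNet modules (mdense,
    # tf_funcs, diffembed, parameters) are on the path: build the
    # inference-mode model and print its layers.
    model, encoder, decoder = new_lpcnet_model(training=False, batch_size=1)
    model.summary()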