#!/usr/bin/python3
'''Copyright (c) 2018 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

import math
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise
from tensorflow.compat.v1.keras.layers import CuDNNGRU
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.callbacks import Callback
from mdense import MDense
import numpy as np
import h5py
import sys
from tf_funcs import *
from diffembed import diff_Embed
from parameters import set_parameter

frame_size = 160          # samples per frame (10 ms at 16 kHz)
pcm_bits = 8              # mu-law sample resolution
embed_size = 128
pcm_levels = 2**pcm_bits  # 256 mu-law levels

def interleave(p, samples):
    p2 = tf.expand_dims(p, 3)
    nb_repeats = pcm_levels//(2*p.shape[2])
    p3 = tf.reshape(tf.repeat(tf.concat([1-p2, p2], 3), nb_repeats), (-1, samples, pcm_levels))
    return p3

def tree_to_pdf(p, samples):
    return interleave(p[:,:,1:2], samples) * interleave(p[:,:,2:4], samples) * interleave(p[:,:,4:8], samples) * interleave(p[:,:,8:16], samples) \
         * interleave(p[:,:,16:32], samples) * interleave(p[:,:,32:64], samples) * interleave(p[:,:,64:128], samples) * interleave(p[:,:,128:256], samples)

def tree_to_pdf_train(p):
    #FIXME: try not to hardcode the 2400 samples (15 frames * 160 samples/frame)
    return tree_to_pdf(p, 2400)

def tree_to_pdf_infer(p):
    return tree_to_pdf(p, 1)
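
# The functions above turn the 256 sigmoid outputs of the 'dual_fc' layer into
# a pdf over the 256 mu-law levels via a balanced binary tree: entries 1..255,
# in breadth-first (heap) order, give the probability of taking the "one"
# branch at each internal node (entry 0 is unused), and multiplying the eight
# interleave() factors (tree depth 8 for pcm_bits=8) yields the product of
# branch probabilities along the path to each leaf. A minimal sketch of the
# idea for a depth-2 tree over 4 levels, with made-up numbers:
#
#   p(one at root) = 0.8
#   p(one | left subtree) = 0.3, p(one | right subtree) = 0.6
#   leaf pdf = [0.2*0.7, 0.2*0.3, 0.8*0.4, 0.8*0.6]
#            = [0.14, 0.06, 0.32, 0.48]   (sums to 1 by construction)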

def quant_regularizer(x):
    Q = 128
    Q_1 = 1./Q
    #return .01 * tf.reduce_mean(1 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))
    return .01 * tf.reduce_mean(K.sqrt(K.sqrt(1.0001 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))))
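
# quant_regularizer() penalizes weights for sitting between quantization
# levels: Q*x - round(Q*x) is the distance to the nearest multiple of 1/128
# (the grid used for 8-bit weights), the cosine maps that distance to a value
# that is zero exactly on the grid, and the double square root sharpens the
# well so the pull is strong even for small deviations. For example,
# x = 0.5 (on the grid) contributes .01*(0.0001)**0.25 ≈ 0.001, while
# x = 0.5 + 1/256 (mid-gap) contributes .01*(2.0001)**0.25 ≈ 0.012.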

class Sparsify(Callback):
    def __init__(self, t_start, t_end, interval, density, quantize=False):
        super(Sparsify, self).__init__()
        self.batch = 0
        self.t_start = t_start
        self.t_end = t_end
        self.interval = interval
        self.final_density = density
        self.quantize = quantize

    def on_batch_end(self, batch, logs=None):
        #print("batch number", self.batch)
        self.batch += 1
        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
            #print("constrain");
            layer = self.model.get_layer('gru_a')
            w = layer.get_weights()
            p = w[1]
            nb = p.shape[1]//p.shape[0]
            N = p.shape[0]
            #print("nb = ", nb, ", N = ", N);
            #print(p.shape)
            #print ("density = ", density)
            for k in range(nb):
                density = self.final_density[k]
                if self.batch < self.t_end and not self.quantize:
                    # Interpolate the density from 1 down to its final value
                    # with a cubic schedule between t_start and t_end.
                    r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                    density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                A = p[:, k*N:(k+1)*N]
                A = A - np.diag(np.diag(A))
                #This is needed because of the CuDNNGRU strange weight ordering
                A = np.transpose(A, (1, 0))
                # Score each 4x8 block by its energy and keep only enough of
                # the highest-energy blocks to match the target density.
                L = np.reshape(A, (N//4, 4, N//8, 8))
                S = np.sum(L*L, axis=-1)
                S = np.sum(S, axis=1)
                SS = np.sort(np.reshape(S, (-1,)))
                thresh = SS[round(N*N//32*(1-density))]
                mask = (S>=thresh).astype('float32')
                mask = np.repeat(mask, 4, axis=0)
                mask = np.repeat(mask, 8, axis=1)
                # Always keep the diagonal.
                mask = np.minimum(1, mask + np.diag(np.ones((N,))))
                #This is needed because of the CuDNNGRU strange weight ordering
                mask = np.transpose(mask, (1, 0))
                p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask
                #print(thresh, np.mean(mask))
            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
                if self.batch < self.t_end:
                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
                else:
                    threshold = .5
                # Snap weights that are already close to the 1/128 grid; the
                # capture threshold widens over time until all weights snap.
                quant = np.round(p*128.)
                res = p*128.-quant
                mask = (np.abs(res) <= threshold).astype('float32')
                p = mask/128.*quant + (1-mask)*p

            w[1] = p
            layer.set_weights(w)
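
# A worked example of the density schedule above, assuming hypothetical values
# t_start=2000, t_end=40000 and a final density of 0.1: at batch 21000,
# r = 1 - 19000/38000 = 0.5, so density = 1 - 0.9*(1 - 0.125) = 0.2125; the
# mask therefore still keeps ~21% of the 4x8 blocks at that point, and
# tightens toward 10% as training proceeds.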

class SparsifyGRUB(Callback):
    def __init__(self, t_start, t_end, interval, grua_units, density, quantize=False):
        super(SparsifyGRUB, self).__init__()
        self.batch = 0
        self.t_start = t_start
        self.t_end = t_end
        self.interval = interval
        self.final_density = density
        self.grua_units = grua_units
        self.quantize = quantize

    def on_batch_end(self, batch, logs=None):
        #print("batch number", self.batch)
        self.batch += 1
        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
            #print("constrain");
            layer = self.model.get_layer('gru_b')
            w = layer.get_weights()
            p = w[0]
            N = p.shape[0]
            M = p.shape[1]//3
            for k in range(3):
                density = self.final_density[k]
                if self.batch < self.t_end and not self.quantize:
                    r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                    density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                A = p[:, k*M:(k+1)*M]
                #This is needed because of the CuDNNGRU strange weight ordering
                A = np.reshape(A, (M, N))
                A = np.transpose(A, (1, 0))
                # Only sparsify the rows fed by gru_a's output; the remaining
                # rows (the conditioning features) stay dense.
                N2 = self.grua_units
                A2 = A[:N2, :]
                L = np.reshape(A2, (N2//4, 4, M//8, 8))
                S = np.sum(L*L, axis=-1)
                S = np.sum(S, axis=1)
                SS = np.sort(np.reshape(S, (-1,)))
                thresh = SS[round(M*N2//32*(1-density))]
                mask = (S>=thresh).astype('float32')
                mask = np.repeat(mask, 4, axis=0)
                mask = np.repeat(mask, 8, axis=1)
                A = np.concatenate([A2*mask, A[N2:,:]], axis=0)
                #This is needed because of the CuDNNGRU strange weight ordering
                A = np.transpose(A, (1, 0))
                A = np.reshape(A, (N, M))
                p[:, k*M:(k+1)*M] = A
                #print(thresh, np.mean(mask))
            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
                if self.batch < self.t_end:
                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
                else:
                    threshold = .5
                quant = np.round(p*128.)
                res = p*128.-quant
                mask = (np.abs(res) <= threshold).astype('float32')
                p = mask/128.*quant + (1-mask)*p

            w[0] = p
            layer.set_weights(w)
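
# A minimal usage sketch combining both callbacks (the schedule numbers and
# the one-density-per-gate tuples are hypothetical; a training script would
# supply its own):
#
#   sparsify = Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
#   grub_sparsify = SparsifyGRUB(2000, 40000, 400, 384, (0.05, 0.05, 0.2))
#   model.fit(..., callbacks=[sparsify, grub_sparsify])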


class PCMInit(Initializer):
    def __init__(self, gain=.1, seed=None):
        self.gain = gain
        self.seed = seed

    def __call__(self, shape, dtype=None):
        num_rows = 1
        for dim in shape[:-1]:
            num_rows *= dim
        num_cols = shape[-1]
        flat_shape = (num_rows, num_cols)
        if self.seed is not None:
            np.random.seed(self.seed)
        a = np.random.uniform(-1.7321, 1.7321, flat_shape)
        #a[:,0] = math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows
        #a[:,1] = .5*a[:,0]*a[:,0]*a[:,0]
        a = a + np.reshape(math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows, (num_rows, 1))
        return self.gain * a.astype("float32")

    def get_config(self):
        return {
            'gain': self.gain,
            'seed': self.seed
        }
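
# PCMInit draws each embedding entry uniformly in [-sqrt(3), sqrt(3)] (unit
# variance), then adds a per-row linear ramp across the rows so that nearby
# mu-law sample values start with nearby embeddings. The ramp spans roughly
# the same range: for num_rows=256, row 0 gets an offset of
# sqrt(12)*(-127.5)/256 ≈ -1.725 and the last row gets ≈ +1.725.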

class WeightClip(Constraint):
    '''Clips the weights incident to each hidden unit to be inside a range
    '''
    def __init__(self, c=2):
        self.c = c

    def __call__(self, p):
        # Ensure that the absolute values of adjacent weight pairs don't sum to more
        # than 127 (in 1/128 quantized units). Otherwise there's a risk of saturation
        # when implementing dot products with SSSE3 or AVX2.
        return self.c*p/tf.maximum(self.c, tf.repeat(tf.abs(p[:, 1::2])+tf.abs(p[:, 0::2]), 2, axis=1))
        #return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {'name': self.__class__.__name__,
            'c': self.c}

constraint = WeightClip(0.992)
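
# Why c = 0.992: with 8-bit weights on a 1/128 grid, 0.992*128 ≈ 127, so the
# constraint keeps |w_even| + |w_odd| within int8 range for each adjacent
# pair. SIMD dot products that sum two adjacent byte products into a 16-bit
# lane (pmaddubsw-style) are then safe from overflow. This rationale is
# inferred from the comment above, not spelled out in this file.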

def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_size=128, training=False, adaptation=False, quantize=False, flag_e2e=False, cond_size=128, lpc_order=16, lpc_gamma=1., lookahead=2):
    pcm = Input(shape=(None, 1), batch_size=batch_size)
    dpcm = Input(shape=(None, 3), batch_size=batch_size)
    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
    pitch = Input(shape=(None, 1), batch_size=batch_size)
    dec_feat = Input(shape=(None, cond_size))
    dec_state1 = Input(shape=(rnn_units1,))
    dec_state2 = Input(shape=(rnn_units2,))

    padding = 'valid' if training else 'same'
    fconv1 = Conv1D(cond_size, 3, padding=padding, activation='tanh', name='feature_conv1')
    fconv2 = Conv1D(cond_size, 3, padding=padding, activation='tanh', name='feature_conv2')
    pembed = Embedding(256, 64, name='embed_pitch')
    cat_feat = Concatenate()([feat, Reshape((-1, 64))(pembed(pitch))])

    cfeat = fconv2(fconv1(cat_feat))

    fdense1 = Dense(cond_size, activation='tanh', name='feature_dense1')
    fdense2 = Dense(cond_size, activation='tanh', name='feature_dense2')

    if flag_e2e and quantize:
        fconv1.trainable = False
        fconv2.trainable = False
        fdense1.trainable = False
        fdense2.trainable = False

    cfeat = fdense2(fdense1(cfeat))

    error_calc = Lambda(lambda x: tf_l2u(x[0] - tf.roll(x[1], 1, axis=1)))
    if flag_e2e:
        lpcoeffs = diff_rc2lpc(name="rc2lpc")(cfeat)
    else:
        lpcoeffs = Input(shape=(None, lpc_order), batch_size=batch_size)

    real_preds = diff_pred(name="real_lpc2preds")([pcm, lpcoeffs])
    # Bandwidth expansion: scale the i-th LPC tap by gamma^i.
    weighting = lpc_gamma ** np.arange(1, lpc_order+1).astype('float32')
    weighted_lpcoeffs = Lambda(lambda x: x[0]*x[1])([lpcoeffs, weighting])
    tensor_preds = diff_pred(name="lpc2preds")([pcm, weighted_lpcoeffs])
    past_errors = error_calc([pcm, tensor_preds])

    embed = diff_Embed(name='embed_sig', initializer=PCMInit())
    cpcm = Concatenate()([tf_l2u(pcm), tf_l2u(tensor_preds), past_errors])
    cpcm = GaussianNoise(.3)(cpcm)
    cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
    cpcm_decoder = Reshape((-1, embed_size*3))(embed(dpcm))


    rep = Lambda(lambda x: K.repeat_elements(x, frame_size, 1))

    quant = quant_regularizer if quantize else None

    if training:
        rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a', stateful=True,
              recurrent_constraint=constraint, recurrent_regularizer=quant)
        rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b', stateful=True,
               kernel_constraint=constraint, recurrent_constraint=constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
    else:
        # A plain GRU with sigmoid recurrent activation and reset_after=True is
        # weight-compatible with CuDNNGRU, so checkpoints transfer between them.
        rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after=True, name='gru_a', stateful=True,
              recurrent_constraint=constraint, recurrent_regularizer=quant)
        rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after=True, name='gru_b', stateful=True,
               kernel_constraint=constraint, recurrent_constraint=constraint, kernel_regularizer=quant, recurrent_regularizer=quant)

    rnn_in = Concatenate()([cpcm, rep(cfeat)])
    md = MDense(pcm_levels, activation='sigmoid', name='dual_fc')
    gru_out1, _ = rnn(rnn_in)
    gru_out1 = GaussianNoise(.005)(gru_out1)
    gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
    ulaw_prob = Lambda(tree_to_pdf_train)(md(gru_out2))

    if adaptation:
        rnn.trainable = False
        rnn2.trainable = False
        md.trainable = False
        embed.trainable = False

    m_out = Concatenate(name='pdf')([tensor_preds, real_preds, ulaw_prob])
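    # The 'pdf' output packs [tensor_preds, real_preds, ulaw_prob] along the
    # last axis (1 + 1 + 256 channels). The packing is inferred from the
    # concatenation above; the code that slices these apart lives in the loss
    # functions of the training scripts, not in this file.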
    if not flag_e2e:
        model = Model([pcm, feat, pitch, lpcoeffs], m_out)
    else:
        model = Model([pcm, feat, pitch], [m_out, cfeat])
    model.rnn_units1 = rnn_units1
    model.rnn_units2 = rnn_units2
    model.nb_used_features = nb_used_features
    model.frame_size = frame_size

    if not flag_e2e:
        encoder = Model([feat, pitch], cfeat)
    else:
        encoder = Model([feat, pitch], [cfeat, lpcoeffs])
    dec_rnn_in = Concatenate()([cpcm_decoder, dec_feat])
    dec_gru_out1, state1 = rnn(dec_rnn_in, initial_state=dec_state1)
    dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2)
    dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2))

    decoder = Model([dpcm, dec_feat, dec_state1, dec_state2], [dec_ulaw_prob, state1, state2])

    # add parameters to model
    set_parameter(model, 'lpc_gamma', lpc_gamma, dtype='float64')
    set_parameter(model, 'flag_e2e', flag_e2e, dtype='bool')
    set_parameter(model, 'lookahead', lookahead, dtype='int32')

    return model, encoder, decoder
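
# A minimal construction sketch (hypothetical settings; the real data loading,
# loss functions, and training loop live in the accompanying training scripts,
# not in this file):
#
#   model, encoder, decoder = new_lpcnet_model(training=True, batch_size=128)
#   model.summary()
#
# In the default (non-end-to-end) configuration the model takes
# [pcm, features, pitch, lpcoeffs] as inputs and produces the packed 'pdf'
# output described above; encoder and decoder reuse the same layers, split
# into the frame-rate and sample-rate halves for inference.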