#!/usr/bin/python3
'''Copyright (c) 2018 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

import math
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise
from tensorflow.compat.v1.keras.layers import CuDNNGRU
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.callbacks import Callback
from mdense import MDense
import numpy as np
import h5py
import sys
from tf_funcs import *
from diffembed import diff_Embed
from parameters import set_parameter

frame_size = 160            # samples per 10-ms frame at 16 kHz
pcm_bits = 8                # mu-law sample resolution
embed_size = 128            # width of the mu-law sample embedding
pcm_levels = 2**pcm_bits    # 256 mu-law levels

def interleave(p, samples):
    # At tree depth d we have 2^d sigmoid outputs; expand them to the full
    # 256-entry axis by pairing each node probability p with 1-p for its two
    # children and repeating across the range of levels each node covers.
    p2 = tf.expand_dims(p, 3)
    nb_repeats = pcm_levels//(2*p.shape[2])
    p3 = tf.reshape(tf.repeat(tf.concat([1-p2, p2], 3), nb_repeats), (-1, samples, pcm_levels))
    return p3

def tree_to_pdf(p, samples):
    # Multiply the branch probabilities of the depth-8 binary tree (slices
    # 2^d:2^(d+1) of the dual-FC output) into a full 256-way pdf.
    return interleave(p[:,:,1:2], samples) * interleave(p[:,:,2:4], samples) * interleave(p[:,:,4:8], samples) * interleave(p[:,:,8:16], samples) \
         * interleave(p[:,:,16:32], samples) * interleave(p[:,:,32:64], samples) * interleave(p[:,:,64:128], samples) * interleave(p[:,:,128:256], samples)

def tree_to_pdf_train(p):
    #FIXME: try not to hardcode the 2400 samples (15 frames * 160 samples/frame)
    return tree_to_pdf(p, 2400)

def tree_to_pdf_infer(p):
    return tree_to_pdf(p, 1)
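
# Illustrative note (not used by the training graph): the dual-FC layer emits
# one sigmoid per node of a depth-8 binary tree over the 256 mu-law levels,
# and tree_to_pdf() multiplies branch probabilities level by level to recover
# the full 256-way pdf. A loop form of the same product, assuming a tensor p
# of shape (batch, samples, 256), would be:
#
#   prob = interleave(p[:, :, 1:2], samples)
#   for d in range(1, 8):                          # remaining tree depths
#       prob = prob * interleave(p[:, :, 2**d:2**(d+1)], samples)
#
# which unrolls to exactly the eight factors written out above.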

def quant_regularizer(x):
    # Penalize weights that sit far from the nearest multiple of 1/Q, so the
    # matrix can later be quantized to 8 bits with little accuracy loss.
    Q = 128
    #return .01 * tf.reduce_mean(1 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))
    return .01 * tf.reduce_mean(K.sqrt(K.sqrt(1.0001 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))))
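
# Rough intuition, with illustrative numbers: Q*x - round(Q*x) measures how
# far each weight is from the nearest multiple of 1/128, so the cosine term
# is ~1 on the quantization grid and -1 halfway between grid points. For a
# scalar weight w:
#
#   w = 3.0/128   ->  1.0001 - cos(0)  = 1e-4   (near-zero penalty)
#   w = 3.5/128   ->  1.0001 - cos(pi) ~ 2      (maximal penalty)
#
# The double square root flattens the curve so the pull toward the grid does
# not vanish for weights that are already close to it.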

class Sparsify(Callback):
    '''Progressively prunes, and optionally quantizes, the recurrent weights
    of gru_a to a per-gate target density using 4x8 block sparsity.
    '''
    def __init__(self, t_start, t_end, interval, density, quantize=False):
        super(Sparsify, self).__init__()
        self.batch = 0
        self.t_start = t_start
        self.t_end = t_end
        self.interval = interval
        self.final_density = density
        self.quantize = quantize

    def on_batch_end(self, batch, logs=None):
        self.batch += 1
        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
            layer = self.model.get_layer('gru_a')
            w = layer.get_weights()
            p = w[1]
            nb = p.shape[1]//p.shape[0]
            N = p.shape[0]
            for k in range(nb):
                density = self.final_density[k]
                if self.batch < self.t_end and not self.quantize:
                    # Ramp the density from 1 down to the final target with a
                    # cubic schedule between t_start and t_end.
                    r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                    density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                A = p[:, k*N:(k+1)*N]
                A = A - np.diag(np.diag(A))
                #This is needed because of the CuDNNGRU strange weight ordering
                A = np.transpose(A, (1, 0))
                # Score 4x8 blocks by their energy and keep only the strongest.
                L = np.reshape(A, (N//4, 4, N//8, 8))
                S = np.sum(L*L, axis=-1)
                S = np.sum(S, axis=1)
                SS = np.sort(np.reshape(S, (-1,)))
                thresh = SS[round(N*N//32*(1-density))]
                mask = (S>=thresh).astype('float32')
                mask = np.repeat(mask, 4, axis=0)
                mask = np.repeat(mask, 8, axis=1)
                # Always keep the diagonal.
                mask = np.minimum(1, mask + np.diag(np.ones((N,))))
                #This is needed because of the CuDNNGRU strange weight ordering
                mask = np.transpose(mask, (1, 0))
                p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask
            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
                if self.batch < self.t_end:
                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
                else:
                    threshold = .5
                # Snap weights that are already within `threshold` of a 1/128
                # grid point onto the grid; the capture range widens over time.
                quant = np.round(p*128.)
                res = p*128.-quant
                mask = (np.abs(res) <= threshold).astype('float32')
                p = mask/128.*quant + (1-mask)*p

            w[1] = p
            layer.set_weights(w)
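
# Typical scheduling sketch (the batch counts and densities are training
# hyperparameters, not fixed by this file; these values mirror the LPCNet
# training scripts):
#
#   sparsify = Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
#   model.fit(..., callbacks=[sparsify])
#
# i.e. start pruning gru_a's recurrent matrix after 2000 batches, reapply the
# mask every 400 batches, and ramp each gate down to its target density
# (update, reset, state) = (0.05, 0.05, 0.2) by batch 40000.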

class SparsifyGRUB(Callback):
    '''Same block-sparsity schedule as Sparsify, applied to the input weights
    of gru_b that act on gru_a's output (the first grua_units rows).
    '''
    def __init__(self, t_start, t_end, interval, grua_units, density, quantize=False):
        super(SparsifyGRUB, self).__init__()
        self.batch = 0
        self.t_start = t_start
        self.t_end = t_end
        self.interval = interval
        self.final_density = density
        self.grua_units = grua_units
        self.quantize = quantize

    def on_batch_end(self, batch, logs=None):
        self.batch += 1
        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
            layer = self.model.get_layer('gru_b')
            w = layer.get_weights()
            p = w[0]
            N = p.shape[0]
            M = p.shape[1]//3
            for k in range(3):
                density = self.final_density[k]
                if self.batch < self.t_end and not self.quantize:
                    # Cubic density ramp, as in Sparsify.
                    r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                    density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                A = p[:, k*M:(k+1)*M]
                #This is needed because of the CuDNNGRU strange weight ordering
                A = np.reshape(A, (M, N))
                A = np.transpose(A, (1, 0))
                # Only prune the rows fed by gru_a's output.
                N2 = self.grua_units
                A2 = A[:N2, :]
                L = np.reshape(A2, (N2//4, 4, M//8, 8))
                S = np.sum(L*L, axis=-1)
                S = np.sum(S, axis=1)
                SS = np.sort(np.reshape(S, (-1,)))
                thresh = SS[round(M*N2//32*(1-density))]
                mask = (S>=thresh).astype('float32')
                mask = np.repeat(mask, 4, axis=0)
                mask = np.repeat(mask, 8, axis=1)
                A = np.concatenate([A2*mask, A[N2:,:]], axis=0)
                #This is needed because of the CuDNNGRU strange weight ordering
                A = np.transpose(A, (1, 0))
                A = np.reshape(A, (N, M))
                p[:, k*M:(k+1)*M] = A
            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
                if self.batch < self.t_end:
                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
                else:
                    threshold = .5
                # Same progressive snap-to-grid quantization as in Sparsify.
                quant = np.round(p*128.)
                res = p*128.-quant
                mask = (np.abs(res) <= threshold).astype('float32')
                p = mask/128.*quant + (1-mask)*p

            w[0] = p
            layer.set_weights(w)
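
# Companion sketch for gru_b: here the kernel rows driven by gru_a's output
# are pruned, so the callback also needs gru_a's width (384 by default):
#
#   sparsify_b = SparsifyGRUB(2000, 40000, 400, 384, (0.05, 0.05, 0.2))
#   model.fit(..., callbacks=[sparsify, sparsify_b])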


class PCMInit(Initializer):
    '''Initializer for the signal embedding: uniform noise plus a per-row
    linear ramp, so rows start out roughly ordered by the mu-law level they
    represent.
    '''
    def __init__(self, gain=.1, seed=None):
        self.gain = gain
        self.seed = seed

    def __call__(self, shape, dtype=None):
        num_rows = 1
        for dim in shape[:-1]:
            num_rows *= dim
        num_cols = shape[-1]
        flat_shape = (num_rows, num_cols)
        if self.seed is not None:
            np.random.seed(self.seed)
        a = np.random.uniform(-1.7321, 1.7321, flat_shape)
        #a[:,0] = math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows
        #a[:,1] = .5*a[:,0]*a[:,0]*a[:,0]
        a = a + np.reshape(math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows, (num_rows, 1))
        return self.gain * a.astype("float32")

    def get_config(self):
        return {
            'gain': self.gain,
            'seed': self.seed
        }
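
# In rough terms this yields unit-variance uniform noise per entry plus a
# per-row linear ramp of comparable magnitude. A quick illustrative check:
#
#   init = PCMInit(seed=42)
#   w = init((256, embed_size))   # the shape used for the signal embedding
#   # np.corrcoef(np.arange(256), w[:, 0])[0, 1] comes out clearly positive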

class WeightClip(Constraint):
    '''Clips the weights incident to each hidden unit to be inside a range.
    '''
    def __init__(self, c=2):
        self.c = c

    def __call__(self, p):
        # Ensure that the absolute values of adjacent weight pairs don't sum
        # to more than 127 once quantized. Otherwise there's a risk of
        # saturation when implementing dot products with SSSE3 or AVX2.
        return self.c*p/tf.maximum(self.c, tf.repeat(tf.abs(p[:, 1::2])+tf.abs(p[:, 0::2]), 2, axis=1))
        #return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {'name': self.__class__.__name__,
            'c': self.c}

constraint = WeightClip(0.992)
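# Why c = 0.992: weights are later quantized in steps of 1/128, so keeping
# each adjacent pair at |w0| + |w1| <= 0.992 bounds the quantized pair sum by
# 127, which is what the saturation comment above relies on. An illustrative
# check with made-up values:
#
#   w = constraint(tf.constant([[0.9, 0.8]]))   # pair sum 1.7 > 0.992
#   # -> the pair is rescaled so |w[0, 0]| + |w[0, 1]| == 0.992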

def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_size=128, training=False, adaptation=False, quantize=False, flag_e2e=False, cond_size=128, lpc_order=16, lpc_gamma=1., lookahead=2):
    pcm = Input(shape=(None, 1), batch_size=batch_size)
    dpcm = Input(shape=(None, 3), batch_size=batch_size)
    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
    pitch = Input(shape=(None, 1), batch_size=batch_size)
    dec_feat = Input(shape=(None, cond_size))
    dec_state1 = Input(shape=(rnn_units1,))
    dec_state2 = Input(shape=(rnn_units2,))

    padding = 'valid' if training else 'same'
    fconv1 = Conv1D(cond_size, 3, padding=padding, activation='tanh', name='feature_conv1')
    fconv2 = Conv1D(cond_size, 3, padding=padding, activation='tanh', name='feature_conv2')
    pembed = Embedding(256, 64, name='embed_pitch')
    cat_feat = Concatenate()([feat, Reshape((-1, 64))(pembed(pitch))])

    cfeat = fconv2(fconv1(cat_feat))

    fdense1 = Dense(cond_size, activation='tanh', name='feature_dense1')
    fdense2 = Dense(cond_size, activation='tanh', name='feature_dense2')

    if flag_e2e and quantize:
        fconv1.trainable = False
        fconv2.trainable = False
        fdense1.trainable = False
        fdense2.trainable = False

    cfeat = fdense2(fdense1(cfeat))

    error_calc = Lambda(lambda x: tf_l2u(x[0] - tf.roll(x[1], 1, axis=1)))
    if flag_e2e:
        # End-to-end mode: derive the LPCs from the conditioning features.
        lpcoeffs = diff_rc2lpc(name="rc2lpc")(cfeat)
    else:
        lpcoeffs = Input(shape=(None, lpc_order), batch_size=batch_size)

    real_preds = diff_pred(name="real_lpc2preds")([pcm, lpcoeffs])
    # Bandwidth-expansion weighting: scale the k-th LPC tap by gamma^k.
    weighting = lpc_gamma ** np.arange(1, lpc_order + 1).astype('float32')
    weighted_lpcoeffs = Lambda(lambda x: x[0]*x[1])([lpcoeffs, weighting])
    tensor_preds = diff_pred(name="lpc2preds")([pcm, weighted_lpcoeffs])
    past_errors = error_calc([pcm, tensor_preds])

    embed = diff_Embed(name='embed_sig', initializer=PCMInit())
    cpcm = Concatenate()([tf_l2u(pcm), tf_l2u(tensor_preds), past_errors])
    # Noise injection on the mu-law signal inputs.
    cpcm = GaussianNoise(.3)(cpcm)
    cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
    cpcm_decoder = Reshape((-1, embed_size*3))(embed(dpcm))


    rep = Lambda(lambda x: K.repeat_elements(x, frame_size, 1))

    quant = quant_regularizer if quantize else None

    if training:
        rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a', stateful=True,
              recurrent_constraint=constraint, recurrent_regularizer=quant)
        rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b', stateful=True,
               kernel_constraint=constraint, recurrent_constraint=constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
    else:
        # Plain GRUs configured (sigmoid recurrent activation, reset_after) to
        # be weight-compatible with the CuDNNGRU layers used for training.
        rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after=True, name='gru_a', stateful=True,
              recurrent_constraint=constraint, recurrent_regularizer=quant)
        rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after=True, name='gru_b', stateful=True,
               kernel_constraint=constraint, recurrent_constraint=constraint, kernel_regularizer=quant, recurrent_regularizer=quant)

    rnn_in = Concatenate()([cpcm, rep(cfeat)])
    md = MDense(pcm_levels, activation='sigmoid', name='dual_fc')
    gru_out1, _ = rnn(rnn_in)
    gru_out1 = GaussianNoise(.005)(gru_out1)
    gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
    ulaw_prob = Lambda(tree_to_pdf_train)(md(gru_out2))

    if adaptation:
        rnn.trainable = False
        rnn2.trainable = False
        md.trainable = False
        embed.trainable = False

    m_out = Concatenate(name='pdf')([tensor_preds, real_preds, ulaw_prob])
    if not flag_e2e:
        model = Model([pcm, feat, pitch, lpcoeffs], m_out)
    else:
        model = Model([pcm, feat, pitch], [m_out, cfeat])
    model.rnn_units1 = rnn_units1
    model.rnn_units2 = rnn_units2
    model.nb_used_features = nb_used_features
    model.frame_size = frame_size

    if not flag_e2e:
        encoder = Model([feat, pitch], cfeat)
    else:
        encoder = Model([feat, pitch], [cfeat, lpcoeffs])
    dec_rnn_in = Concatenate()([cpcm_decoder, dec_feat])
    dec_gru_out1, state1 = rnn(dec_rnn_in, initial_state=dec_state1)
    dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2)
    dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2))

    # The decoder graph is the same with and without e2e, so build it once.
    decoder = Model([dpcm, dec_feat, dec_state1, dec_state2], [dec_ulaw_prob, state1, state2])

    # add parameters to model
    set_parameter(model, 'lpc_gamma', lpc_gamma, dtype='float64')
    set_parameter(model, 'flag_e2e', flag_e2e, dtype='bool')
    set_parameter(model, 'lookahead', lookahead, dtype='int32')

    return model, encoder, decoder
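
# Minimal construction sketch (the compile arguments are placeholders; the
# real training scripts wire up custom losses and data loaders, and
# 'lpcnet_demo.h5' is a hypothetical file name):
#
#   model, encoder, decoder = new_lpcnet_model(training=True, batch_size=64)
#   model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
#   model.summary()
#   model.save_weights('lpcnet_demo.h5')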