#!/usr/bin/python3
'''Copyright (c) 2022 Amazon

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

import math
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise, AveragePooling1D, RepeatVector
from tensorflow.compat.v1.keras.layers import CuDNNGRU
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.regularizers import l1
import numpy as np
import h5py
from uniform_noise import UniformNoise

class WeightClip(Constraint):
    '''Scales weights so that the absolute values of each pair of adjacent weights
    (along the last axis) sum to at most c.
    '''
    def __init__(self, c=2):
        self.c = c

    def __call__(self, p):
        # Ensure that abs of adjacent weights don't sum to more than 127. Otherwise there's a risk of
        # saturation when implementing dot products with SSSE3 or AVX2.
        return self.c*p/tf.maximum(self.c, tf.repeat(tf.abs(p[:, 1::2])+tf.abs(p[:, 0::2]), 2, axis=1))
        #return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {'name': self.__class__.__name__,
            'c': self.c}

constraint = WeightClip(0.496)
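# With c = 0.496, each adjacent pair of weights sums to at most 0.992 in absolute
# value. Assuming the weights are later quantized to 8 bits with a scale of 128
# (an assumption suggested by the saturation comment in __call__, not spelled out
# here), a pair then contributes at most about 127, which keeps the 16-bit
# accumulators of SSSE3/AVX2 dot products from saturating.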

def soft_quantize(x):
    # Currently a pass-through; the commented-out lines below are an experimental
    # smooth (sine-based) quantization curve.
    #x = 4*x
    #x = x - (.25/np.math.pi)*tf.math.sin(2*np.math.pi*x)
    #x = x - (.25/np.math.pi)*tf.math.sin(2*np.math.pi*x)
    #x = x - (.25/np.math.pi)*tf.math.sin(2*np.math.pi*x)
    return x

def noise_quantize(x):
    # Simulates quantization by adding uniform noise in [-0.5, 0.5) (note the hard-coded shape).
    return soft_quantize(x + (K.random_uniform((128, 16, 80))-.5) )

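# Straight-through rounding: the forward pass returns tf.round(x), but the rounding
# step is hidden from the backward pass by tf.stop_gradient, so gradients flow
# through as if the quantizer were the identity. A minimal sketch of the idea
# (hypothetical values, not part of this file):
#
#   z = tf.constant([0.2, 1.7, -0.6])
#   q = hard_quantize(z)     # forward value: [0., 2., -1.]
#   # dq/dz == 1 everywhere, because only (quantized - x) is stopped.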
def hard_quantize(x):
    x = soft_quantize(x)
    quantized = tf.round(x)
    return x + tf.stop_gradient(quantized - x)

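# Dead zone applied before quantization: x[0] is the latent vector and x[1] the
# learned dead-zone width (scaled by .05). Values much smaller than d are pulled
# toward zero (tanh(x/(.1+d)) is roughly x/(.1+d) there), while large values are
# only shifted by roughly the constant d, so more small latents round to zero.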
def apply_dead_zone(x):
    d = x[1]*.05
    x = x[0]
    y = x - d*tf.math.tanh(x/(.1+d))
    return y

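# Approximate rate (in bits) of an n-dimensional integer vector with L1 norm k:
# the number of such vectors grows roughly like 2^n * k^(n-1) / (n-1)!, whose log2
# is n - log2(e)*ln(Gamma(n)) + (n-1)*log2(k), matching C and the k term below.
# The extra .112*n^2/(n/1.8+k) term appears to keep the estimate finite and smooth
# as k approaches zero (a hedged reading of the formula, not documented upstream).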
def rate_loss(y_true,y_pred):
    log2_e = 1.4427
    n = y_pred.shape[-1]
    C = n - log2_e*math.lgamma(n)
    k = K.sum(K.abs(y_pred), axis=-1)
    p = 1.5
    #rate = C + (n-1)*log2_e*tf.math.log((k**p + (n/5)**p)**(1/p))
    rate = C + (n-1)*log2_e*tf.math.log(k + .112*n**2/(n/1.8+k) )
    return K.mean(rate)

eps=1e-6
def safelog2(x):
    log2_e = 1.4427
    return log2_e*tf.math.log(eps+x)

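# Weighted feature distortion. y_pred packs the decoded features plus a trailing
# lambda channel; the first 18 outputs appear to be cepstral coefficients, with
# the pitch period and pitch correlation in positions 18 and 19 (the usual LPCNet
# feature layout). The pitch error is normalized by the true period and weighted
# by the clamped correlation, and the whole loss is divided by sqrt(lambda), so a
# larger lambda shifts the rate/distortion trade-off toward lower rate.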
def feat_dist_loss(y_true,y_pred):
    lambda_1 = 1./K.sqrt(y_pred[:,:,:,-1])
    y_pred = y_pred[:,:,:,:-1]
    ceps = y_pred[:,:,:,:18] - y_true[:,:,:18]
    pitch = 2*(y_pred[:,:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
    corr = y_pred[:,:,:,19:] - y_true[:,:,19:]
    pitch_weight = K.square(K.maximum(0., y_true[:,:,19:]+.5))
    return K.mean(lambda_1*K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr), axis=-1))

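# Rate losses for the scalar-quantized latents. Ignoring the trailing lambda
# channel, y_pred packs n latent values followed by n "p0" and n "r" parameters
# taken from the sigmoid-activated quantizer embedding. The probability model is
# roughly a two-sided geometric distribution in the rounded magnitude, with
# per-dimension decay r and a separately derived probability of zero; the losses
# are the estimated -log2 probability (i.e. bits), weighted by sqrt(lambda).
# sq1_rate_loss operates on the unrounded latents, while sq2_rate_loss and
# sq_rate_metric evaluate the same model on hard-rounded values.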
def sq1_rate_loss(y_true,y_pred):
    lambda_val = K.sqrt(y_pred[:,:,-1])
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = (y_pred[:,:,2*n:])
    p0 = (y_pred[:,:,n:2*n])
    p0 = 1-r**(.5+.5*p0)
    y_pred = y_pred[:,:,:n]
    y_pred = soft_quantize(y_pred)

    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    # Only the last assignment to rate below is used; the earlier expressions are
    # kept as overridden experiments.
    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    rate = -safelog2(-.5*tf.math.log(r)*r**K.abs(y_pred))
    rate = -safelog2((1-r)/(1+r)*r**K.abs(y_pred))
    #rate = -safelog2(- tf.math.sinh(.5*tf.math.log(r))* r**K.abs(y_pred) - tf.math.cosh(K.maximum(0., .5 - K.abs(y_pred))*tf.math.log(r)) + 1)
    rate = lambda_val*K.sum(rate, axis=-1)
    return K.mean(rate)

def sq2_rate_loss(y_true,y_pred):
    lambda_val = K.sqrt(y_pred[:,:,-1])
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = y_pred[:,:,2*n:]
    p0 = y_pred[:,:,n:2*n]
    p0 = 1-r**(.5+.5*p0)
    #theta = K.minimum(1., .5 + 0*p0 - 0.04*tf.math.log(r))
    #p0 = 1-r**theta
    y_pred = tf.round(y_pred[:,:,:n])
    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    rate = lambda_val*K.sum(rate, axis=-1)
    return K.mean(rate)

def sq_rate_metric(y_true,y_pred, reduce=True):
    y_pred = y_pred[:,:,:-1]
    log2_e = 1.4427
    n = y_pred.shape[-1]//3
    r = y_pred[:,:,2*n:]
    p0 = y_pred[:,:,n:2*n]
    p0 = 1-r**(.5+.5*p0)
    #theta = K.minimum(1., .5 + 0*p0 - 0.04*tf.math.log(r))
    #p0 = 1-r**theta
    y_pred = tf.round(y_pred[:,:,:n])
    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
    rate = K.sum(rate, axis=-1)
    if reduce:
        rate = K.mean(rate)
    return rate

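# PVQ-style search: the input is scaled onto the unit L1 ball, multiplied by k and
# rounded; the gain is then adjusted over a fixed 10 iterations until the rounded
# vector carries exactly k unit pulses (its absolute values sum to k). 'plus' is
# the smallest up-scaling that adds a pulse and 'minus' the scaling closest to one
# that removes a pulse.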
def pvq_quant_search(x, k):
    x = x/tf.reduce_sum(tf.abs(x), axis=-1, keepdims=True)
    kx = k*x
    y = tf.round(kx)
    newk = k

    for j in range(10):
        #print("y = ", y)
        #print("iteration ", j)
        abs_y = tf.abs(y)
        abs_kx = tf.abs(kx)
        kk=tf.reduce_sum(abs_y, axis=-1)
        #print("sums = ", kk)
        plus = 1.000001*tf.reduce_min((abs_y+.5)/(abs_kx+1e-15), axis=-1)
        minus = .999999*tf.reduce_max((abs_y-.5)/(abs_kx+1e-15), axis=-1)
        #print("plus = ", plus)
        #print("minus = ", minus)
        factor = tf.where(kk>k, minus, plus)
        factor = tf.where(kk==k, tf.ones_like(factor), factor)
        #print("scale = ", factor)
        factor = tf.expand_dims(factor, axis=-1)
        #newk = newk * (k/kk)**.2
        newk = newk*factor
        kx = newk*x
        #print("newk = ", newk)
        #print("unquantized = ", newk*x)
        y = tf.round(kx)

    #print(y)
    #print(K.mean(K.sum(K.abs(y), axis=-1)))
    return y

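# PVQ quantization with a straight-through gradient: the input is projected onto
# the unit sphere, quantized to k pulses with pvq_quant_search(), renormalized,
# and the quantization error is hidden from the backward pass. A minimal sketch
# (hypothetical values, not part of this file):
#
#   state = tf.random.normal((1, 24))
#   q = pvq_quantize(state, 82)   # unit-norm output, gradients flow as identity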
def pvq_quantize(x, k):
    x = x/(1e-15+tf.norm(x, axis=-1,keepdims=True))
    quantized = pvq_quant_search(x, k)
    quantized = quantized/(1e-15+tf.norm(quantized, axis=-1,keepdims=True))
    return x + tf.stop_gradient(quantized - x)


def var_repeat(x):
    return tf.repeat(tf.expand_dims(x[0], 1), K.shape(x[1])[1], axis=1)

nb_state_dim = 24

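# Encoder: pairs of input feature frames are processed by an alternating stack of
# Dense and GRU layers, all intermediate outputs are concatenated, and a causal
# Conv1D produces the nb_bits-dimensional latent per (two-frame) step. A second
# Dense head produces the nb_state_dim decoder state per step. A rough usage
# sketch (shapes assume the default arguments; illustrative only):
#
#   enc = new_rdovae_encoder(training=False)
#   feats = np.zeros((128, 400, 20), dtype='float32')
#   latents, dec_state = enc.predict(feats)   # (128, 200, 17) and (128, 200, 24)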
def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)

    gru = CuDNNGRU if training else GRU
    enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
    enc_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
    enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
    enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
    enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
    enc_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
    enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
    enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')

    #bits_dense = Dense(nb_bits, activation='linear', name='bits_dense')
    bits_dense = Conv1D(nb_bits, 4, padding='causal', activation='linear', name='bits_dense')

    zero_out = Lambda(lambda x: 0*x)
    inputs = Reshape((-1, 2*nb_used_features))(feat)
    d1 = enc_dense1(inputs)
    d2 = enc_dense2(d1)
    d3 = enc_dense3(d2)
    d4 = enc_dense4(d3)
    d5 = enc_dense5(d4)
    d6 = enc_dense6(d5)
    d7 = enc_dense7(d6)
    d8 = enc_dense8(d7)
    pre_out = Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])
    enc_out = bits_dense(pre_out)
    global_dense1 = Dense(128, activation='tanh', name='gdense1')
    global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
    global_bits = global_dense2(global_dense1(pre_out))

    encoder = Model([feat], [enc_out, global_bits], name='encoder')
    return encoder

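# Decoder: mirrors the encoder stack but operates on the time-reversed latent
# sequence (the output is reversed back at the end), with each of the three GRUs
# initialized from a Dense projection of the nb_state_dim decoder state.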
def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
    bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
    gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")


    gru = CuDNNGRU if training else GRU
    dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
    dec_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense2')
    dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
    dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
    dec_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense5')
    dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
    dec_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
    dec_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense8')

    dec_final = Dense(bunch*nb_used_features, activation='linear', name='dec_final')

    time_reverse = Lambda(lambda x: K.reverse(x, 1))
    #time_reverse = Lambda(lambda x: x)
    #gru_state_rep = RepeatVector(64//bunch)(gru_state_input)

    #gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
    gru_state1 = Dense(cond_size, name="state1", activation='tanh')(gru_state_input)
    gru_state2 = Dense(cond_size, name="state2", activation='tanh')(gru_state_input)
    gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)

    dec1 = dec_dense1(time_reverse(bits_input))
    dec2 = dec_dense2(dec1, initial_state=gru_state1)
    dec3 = dec_dense3(dec2)
    dec4 = dec_dense4(dec3, initial_state=gru_state2)
    dec5 = dec_dense5(dec4)
    dec6 = dec_dense6(dec5, initial_state=gru_state3)
    dec7 = dec_dense7(dec6)
    dec8 = dec_dense8(dec7)
    output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))
    decoder = Model([bits_input, gru_state_input], time_reverse(output), name='decoder')
    decoder.nb_bits = nb_bits
    decoder.bunch = bunch
    return decoder

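# Wraps the decoder so that a long sequence is decoded as four independent chunks
# (boundaries at steps 0, 100, 200, 300, 400 divided by 'bunch'); each chunk is
# decoded from the state vector taken at its final step, presumably so the model
# learns to reconstruct from periodic state snapshots rather than from a single
# initial state.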
def new_split_decoder(decoder):
    nb_bits = decoder.nb_bits
    bunch = decoder.bunch
    bits_input = Input(shape=(None, nb_bits), name="split_bits")
    gru_state_input = Input(shape=(None,nb_state_dim), name="split_state")

    range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
    elem_select = Lambda(lambda x: x[0][:,x[1],:])
    points = [0, 100, 200, 300, 400]
    outputs = []
    for i in range(len(points)-1):
        begin = points[i]//bunch
        end = points[i+1]//bunch
        state = elem_select([gru_state_input, end-1])
        bits = range_select([bits_input, begin, end])
        outputs.append(decoder([bits, state]))
    output = Concatenate(axis=1)(outputs)
    split = Model([bits_input, gru_state_input], output, name="split")
    return split

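# Stacks the outputs of the decoders run at different time offsets along a new
# leading axis (n is hard-coded to 2, matching the bunch//2 offsets in the model
# below). The earlier output is shifted by its offset and padded with the tail of
# the last output so all stacked tensors share the same length; this is why
# feat_dist_loss indexes y_pred with four dimensions.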
def tensor_concat(x):
    #n = x[1]//2
    #x = x[0]
    n=2
    y = []
    for i in range(n-1):
        offset = 2 * (n-1-i)
        tmp = K.concatenate([x[i][:, offset:, :], x[-1][:, -offset:, :]], axis=-2)
        y.append(tf.expand_dims(tmp, axis=0))
    y.append(tf.expand_dims(x[-1], axis=0))
    return Concatenate(axis=0)(y)


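# End-to-end rate-distortion model. The quant_id input selects a row of the
# 'quant_embed' embedding, which is split into a per-latent scale, a dead-zone
# width and the soft/hard distribution parameters used by the rate losses. The
# scaled latents are quantized both with hard rounding and with uniform noise,
# decoded at bunch//2 time offsets through the split decoder, and the four model
# outputs are (hard-quantized reconstruction, noise-quantized reconstruction,
# soft rate tensor, hard rate tensor), each with lambda appended for the losses.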
def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):

    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
    quant_id = Input(shape=(None,), batch_size=batch_size)
    lambda_val = Input(shape=(None, 1), batch_size=batch_size)
    lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
    lambda_up = Lambda(lambda x: K.repeat_elements(x, 2, axis=-2))(lambda_val)

    qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
    quant_embed_dec = qembedding(quant_id)
    quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed_dec))

    encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
    ze, gru_state_dec = encoder([feat])
    ze = Multiply()([ze, quant_scale])

    decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
    split_decoder = new_split_decoder(decoder)

    dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
    soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
    hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))

    noisequant = UniformNoise()
    hardquant = Lambda(hard_quantize)
    dzone = Lambda(apply_dead_zone)
    dze = dzone([ze,dead_zone])
    ndze = noisequant(dze)
    dze_quant = hardquant(dze)

    div = Lambda(lambda x: x[0]/x[1])
    dze_quant = div([dze_quant,quant_scale])
    ndze_unquant = div([ndze,quant_scale])

    mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
    gru_state_dec = Lambda(lambda x: pvq_quantize(x, 82))(gru_state_dec)
    combined_output = []
    unquantized_output = []
    cat = Concatenate(name="out_cat")
    for i in range(bunch//2):
        dze_select = mod_select([dze_quant, i])
        ndze_select = mod_select([ndze_unquant, i])
        state_select = mod_select([gru_state_dec, i])

        tmp = split_decoder([dze_select, state_select])
        tmp = cat([tmp, lambda_up])
        combined_output.append(tmp)

        tmp = split_decoder([ndze_select, state_select])
        tmp = cat([tmp, lambda_up])
        unquantized_output.append(tmp)

    concat = Lambda(tensor_concat, name="output")
    combined_output = concat(combined_output)
    unquantized_output = concat(unquantized_output)

    e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_val])
    e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_val])


    model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, e, e2], name="end2end")
    model.nb_used_features = nb_used_features

    return model, encoder, decoder, qembedding

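# A rough sketch of how these pieces can be wired up for training (the optimizer,
# loss weights and metric choice below are placeholders; the accompanying training
# script sets the values actually used):
#
#   model, encoder, decoder, qembedding = new_rdovae_model(training=True)
#   model.compile(optimizer='adam',
#                 loss=[feat_dist_loss, feat_dist_loss, sq1_rate_loss, sq2_rate_loss],
#                 loss_weights=[0.5, 0.5, 1.0, 0.1],
#                 metrics={'hard_bits': sq_rate_metric})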