xref: /aosp_15_r20/external/libopus/dnn/training_tf2/lpcnet_plc.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li#!/usr/bin/python3
2*a58d3d2aSXin Li'''Copyright (c) 2021-2022 Amazon
3*a58d3d2aSXin Li   Copyright (c) 2018-2019 Mozilla
4*a58d3d2aSXin Li
5*a58d3d2aSXin Li   Redistribution and use in source and binary forms, with or without
6*a58d3d2aSXin Li   modification, are permitted provided that the following conditions
7*a58d3d2aSXin Li   are met:
8*a58d3d2aSXin Li
9*a58d3d2aSXin Li   - Redistributions of source code must retain the above copyright
10*a58d3d2aSXin Li   notice, this list of conditions and the following disclaimer.
11*a58d3d2aSXin Li
12*a58d3d2aSXin Li   - Redistributions in binary form must reproduce the above copyright
13*a58d3d2aSXin Li   notice, this list of conditions and the following disclaimer in the
14*a58d3d2aSXin Li   documentation and/or other materials provided with the distribution.
15*a58d3d2aSXin Li
16*a58d3d2aSXin Li   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17*a58d3d2aSXin Li   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18*a58d3d2aSXin Li   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19*a58d3d2aSXin Li   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
20*a58d3d2aSXin Li   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21*a58d3d2aSXin Li   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22*a58d3d2aSXin Li   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23*a58d3d2aSXin Li   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24*a58d3d2aSXin Li   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25*a58d3d2aSXin Li   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26*a58d3d2aSXin Li   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*a58d3d2aSXin Li'''
28*a58d3d2aSXin Li
29*a58d3d2aSXin Liimport math
30*a58d3d2aSXin Liimport tensorflow as tf
31*a58d3d2aSXin Lifrom tensorflow.keras.models import Model
32*a58d3d2aSXin Lifrom tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise
33*a58d3d2aSXin Lifrom tensorflow.compat.v1.keras.layers import CuDNNGRU
34*a58d3d2aSXin Lifrom tensorflow.keras import backend as K
35*a58d3d2aSXin Lifrom tensorflow.keras.constraints import Constraint
36*a58d3d2aSXin Lifrom tensorflow.keras.initializers import Initializer
37*a58d3d2aSXin Lifrom tensorflow.keras.callbacks import Callback
38*a58d3d2aSXin Liimport numpy as np
39*a58d3d2aSXin Li
def quant_regularizer(x):
    """Regularization term that pushes weights toward multiples of 1/Q.

    Used as a kernel/recurrent regularizer when `quantize=True`, so the
    trained weights land close to the Q=128 quantization grid used at
    inference time.

    Fixes vs. original: removed the unused local `Q_1` and replaced the
    hard-coded pi literal with `math.pi` (imported at file top).

    Args:
        x: weight tensor being regularized.
    Returns:
        Scalar tensor penalty (mean over all weights).
    """
    Q = 128
    # 1 - cos(2*pi*frac) is 0 exactly on the quantization grid; the double
    # sqrt (4th root) sharpens the penalty near the grid, and the 1.0001
    # offset keeps the argument strictly positive so the gradient of sqrt
    # stays finite at grid points.
    return .01 * tf.reduce_mean(K.sqrt(K.sqrt(1.0001 - tf.math.cos(2*math.pi*(Q*x-tf.round(Q*x))))))
45*a58d3d2aSXin Li
46*a58d3d2aSXin Li
class WeightClip(Constraint):
    """Keras weight constraint bounding adjacent-column magnitude pairs.

    Rather than clipping each weight independently, it rescales weights so
    that |w[:, 2k]| + |w[:, 2k+1]| stays within `c` for every column pair.
    """
    def __init__(self, c=2):
        self.c = c

    def __call__(self, p):
        # The sum of absolute values of each (even, odd) column pair must not
        # exceed c; otherwise 8-bit dot products implemented with SSSE3/AVX2
        # risk saturating during pairwise accumulation.
        pair_mag = tf.abs(p[:, 1::2]) + tf.abs(p[:, 0::2])
        # Broadcast each pair's magnitude back over both of its columns and
        # scale down only the pairs that exceed the bound.
        limit = tf.maximum(self.c, tf.repeat(pair_mag, 2, axis=1))
        return self.c * p / limit

    def get_config(self):
        return {'name': self.__class__.__name__,
            'c': self.c}
62*a58d3d2aSXin Li
63*a58d3d2aSXin Liconstraint = WeightClip(0.992)
64*a58d3d2aSXin Li
def new_lpcnet_plc_model(rnn_units=256, nb_used_features=20, nb_burg_features=36, batch_size=128, training=False, adaptation=False, quantize=False, cond_size=128):
    """Build the LPCNet packet-loss-concealment (PLC) model.

    The network takes per-frame features concatenated with a per-frame
    "lost" flag, conditions them through a dense layer, runs two stacked
    stateful GRUs (with Gaussian noise injected between them), and predicts
    the used features for the next frame.

    Args:
        rnn_units: width of each of the two GRU layers.
        nb_used_features: number of output features predicted per frame.
        nb_burg_features: extra Burg-analysis features appended to the input.
        batch_size: fixed batch size (required because the GRUs are stateful).
        training: if True, use CuDNNGRU (GPU training kernel); otherwise the
            generic GRU layer with CuDNN-compatible settings.
        adaptation: currently unused here — kept for signature compatibility
            with callers (NOTE(review): confirm whether it is consumed by a
            training script).
        quantize: if True, apply `quant_regularizer` to GRU weights so they
            train toward the Q=128 quantization grid.
        cond_size: width of the conditioning dense layer.

    Returns:
        A Keras `Model` mapping [feat, lost] -> predicted features, with
        `rnn_units`, `cond_size`, `nb_used_features` and `nb_burg_features`
        attached as attributes.
    """
    feat = Input(shape=(None, nb_used_features+nb_burg_features), batch_size=batch_size)
    lost = Input(shape=(None, 1), batch_size=batch_size)

    fdense1 = Dense(cond_size, activation='tanh', name='plc_dense1')

    # Condition on both the features and the loss indicator.
    cfeat = Concatenate()([feat, lost])
    cfeat = fdense1(cfeat)

    quant = quant_regularizer if quantize else None

    if training:
        rnn = CuDNNGRU(rnn_units, return_sequences=True, return_state=True, name='plc_gru1', stateful=True,
              kernel_constraint=constraint, recurrent_constraint=constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
        rnn2 = CuDNNGRU(rnn_units, return_sequences=True, return_state=True, name='plc_gru2', stateful=True,
              kernel_constraint=constraint, recurrent_constraint=constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
    else:
        # reset_after=True (fixed from the string 'true', which only worked
        # because any non-empty string is truthy) plus sigmoid recurrent
        # activation makes this GRU weight-compatible with CuDNNGRU.
        rnn = GRU(rnn_units, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after=True, name='plc_gru1', stateful=True,
              kernel_constraint=constraint, recurrent_constraint=constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
        rnn2 = GRU(rnn_units, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after=True, name='plc_gru2', stateful=True,
              kernel_constraint=constraint, recurrent_constraint=constraint, kernel_regularizer=quant, recurrent_regularizer=quant)

    # return_state=True gives (sequence, final_state); only the sequence is
    # used at build time — the state stays inside the stateful layer.
    gru_out1, _ = rnn(cfeat)
    # Small noise between the GRUs regularizes training (active only in
    # training mode; GaussianNoise is identity at inference).
    gru_out1 = GaussianNoise(.005)(gru_out1)
    gru_out2, _ = rnn2(gru_out1)

    out_dense = Dense(nb_used_features, activation='linear', name='plc_out')
    plc_out = out_dense(gru_out2)

    model = Model([feat, lost], plc_out)
    # Stash the hyperparameters on the model for downstream scripts.
    model.rnn_units = rnn_units
    model.cond_size = cond_size
    model.nb_used_features = nb_used_features
    model.nb_burg_features = nb_burg_features

    return model
102