xref: /aosp_15_r20/external/libopus/dnn/training_tf2/tf_funcs.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
"""
TensorFlow/Keras helper functions to do the following:
    1. mu-law <-> linear domain conversion
    2. Differentiable prediction from the input signal and LP coefficients
    3. Differentiable transformations between Reflection Coefficients (RCs) and LP Coefficients
"""
7*a58d3d2aSXin Lifrom tensorflow.keras.layers import Lambda, Multiply, Layer, Concatenate
8*a58d3d2aSXin Lifrom tensorflow.keras import backend as K
9*a58d3d2aSXin Liimport tensorflow as tf
10*a58d3d2aSXin Li
# mu-law <-> linear conversion constants.
# `scale` maps a linear amplitude on a 16-bit full scale (32768) into the
# 255-step mu-law domain; `scale_1` is its reciprocal for the way back.
scale = 255.0 / 32768.0
scale_1 = 32768.0 / 255.0
def tf_l2u(x):
    """Convert linear-domain samples to continuous mu-law values.

    No rounding is applied (keeping the op differentiable), so the result is a
    float tensor centred on 128 and clipped to [0, 255].

    Args:
        x: tensor of linear samples (presumably on a 16-bit scale, given that
            `scale` is 255/32768 -- TODO confirm against callers).

    Returns:
        Tensor of the same shape as `x` with values in [0, 255].
    """
    sign = K.sign(x)
    magnitude = K.abs(x)
    # 128 * log_256(1 + scale*|x|), signed, then shifted to the 128 midpoint.
    log_256 = K.log(256.0)
    companded = sign * (128 * K.log(1 + scale * magnitude) / log_256)
    return K.clip(companded + 128, 0, 255)
20*a58d3d2aSXin Li
def tf_u2l(u):
    """Convert mu-law values back to the linear domain.

    Computes sign(u - 128) * scale_1 * (256 ** (|u - 128| / 128) - 1), with the
    power written as exp(. * log(256)). Mathematically this inverts `tf_l2u`
    for inputs that were not clipped.

    Args:
        u: mu-law values of any numeric dtype; cast to float32 first.

    Returns:
        float32 tensor of linear-domain samples, same shape as `u`.
    """
    centred = tf.cast(u, "float32") - 128.0
    magnitude = K.abs(centred)
    exponent = magnitude / 128. * K.log(256.0)
    return K.sign(centred) * scale_1 * (K.exp(exponent) - 1)
27*a58d3d2aSXin Li
# Differentiable Prediction Layer
# Computes the LP prediction from the input lag signal and the LP coefficients.
# The inputs xt and lpc follow the shapes used in lpcnet.py (which uses 2400
# samples per sequence); the prediction window follows the actual time length
# of xt, so other sequence lengths work too.
class diff_pred(Layer):
    """Differentiable LP prediction layer.

    For every time step t it computes
        pred[t] = -sum_{i=0}^{lpcoeffs_N-1} lpc[t // frame_size, i] * xt[t - i]
    where samples before the start of the sequence are treated as zero.
    """

    def call(self, inputs, lpcoeffs_N = 16, frame_size = 160):
        """Compute the LP prediction.

        Args:
            inputs: pair [xt, lpc]. xt is indexed as (batch, time, channels)
                and must have a single channel so that its lpcoeffs_N delay
                taps line up with lpc's columns (presumably the lag signal
                from lpcnet.py -- TODO confirm). lpc is (batch, frames,
                lpcoeffs_N) with frames * frame_size equal to xt's time length.
            lpcoeffs_N: prediction order (number of LP coefficients).
            frame_size: samples per frame; each frame's coefficients are
                repeated this many times along the time axis.

        Returns:
            Tensor of shape (batch, time, 1) holding the prediction.
        """
        xt = inputs[0]
        lpc = inputs[1]

        # Upsample per-frame coefficients to per-sample resolution.
        lpc_per_sample = K.repeat_elements(lpc, frame_size, 1)

        # Left-pad with lpcoeffs_N zeros so the first samples see zero history.
        # tf.pad (rather than 0*x[...]) keeps the padding exactly zero even if xt
        # holds inf/NaN, and works when the sequence is shorter than lpcoeffs_N.
        padded = tf.pad(xt, [[0, 0], [lpcoeffs_N, 0], [0, 0]])

        # Column i is xt delayed by i samples. Each slice spans exactly the
        # original time length (instead of a hard-coded 2400); an end index of
        # None is used for i == 0 because -0 would select an empty slice.
        delayed = K.concatenate(
            [padded[:, (lpcoeffs_N - i):(-i if i else None), :]
             for i in range(lpcoeffs_N)],
            axis = 2)

        pred = -(lpc_per_sample * delayed)
        return K.sum(pred, axis = 2, keepdims = True)
43*a58d3d2aSXin Li
# Differentiable Transformations (RC <-> LPC) computed using the Levinson-Durbin recursion
class diff_rc2lpc(Layer):
    """Convert reflection coefficients to LP coefficients (step-up recursion).

    Starting from a_1 = k_1, each step m = 2..lpcoeffs_N updates
        a_j <- a_j + k_m * a_{m-j}   for j = 1..m-1,   then a_m <- k_m.
    """

    def call(self, inputs, lpcoeffs_N = 16):
        """Run the step-up recursion.

        Args:
            inputs: tensor indexed as (batch, time, coeff); only the first
                lpcoeffs_N entries of the last axis are used.
            lpcoeffs_N: number of reflection / LP coefficients.

        Returns:
            Tensor with lpcoeffs_N LP coefficients on the last axis.
        """
        rcs = inputs[:, :, :lpcoeffs_N]
        lpc = rcs
        for order in range(1, lpcoeffs_N):
            prev = lpc[:, :, :order]
            k = K.expand_dims(rcs[:, :, order], axis = -1)
            # k has a trailing axis of size 1; widen it to match prev.
            k_wide = K.repeat_elements(k, prev.shape[2], 2)
            updated = prev + k_wide * K.reverse(prev, axes = 2)
            # The new highest-order coefficient is the reflection coefficient itself.
            lpc = K.concatenate([updated, k], axis = 2)
        return lpc
57*a58d3d2aSXin Li
class diff_lpc2rc(Layer):
    """Convert LP coefficients to reflection coefficients (step-down recursion).

    Working from the highest order m = lpcoeffs_N down to 2, k_m = a_m and
        a_j <- (a_j - k_m * a_{m-j}) / (1 - k_m^2)   for j = 1..m-1.
    Each recovered k_m stays in a_m's slot, so the first lpcoeffs_N entries of
    the result hold k_1..k_N.
    """

    def call(self, inputs, lpcoeffs_N = 16):
        """Run the step-down recursion.

        Args:
            inputs: tensor indexed as (batch, time, coeff) with at least
                lpcoeffs_N entries on the last axis. It is not truncated:
                entries past lpcoeffs_N are carried through unchanged.
            lpcoeffs_N: number of LP / reflection coefficients.

        Returns:
            Tensor with the same last-axis width as `inputs`.
        """
        rc = inputs
        for split in range(lpcoeffs_N - 1, 0, -1):
            head = rc[:, :, :split]
            tail = rc[:, :, split:]
            # The first tail entry is the current reflection coefficient,
            # widened to match head. NOTE: divides by zero if |k| == 1.
            k = K.repeat_elements(K.expand_dims(tail[:, :, 0], axis = -1), head.shape[2], 2)
            head = (head - k * K.reverse(head, axes = 2)) / (1 - k * k)
            rc = K.concatenate([head, tail], axis = 2)
        return rc
71