"""
Custom loss functions and metrics for training/analysis
"""

from tf_funcs import *
import tensorflow as tf

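# tf_l2u (linear -> mu-law) and the Keras backend alias K come in through the
# wildcard import above. For orientation only, here is a minimal sketch of the
# mu-law companding tf_l2u is assumed to perform (16-bit linear samples mapped
# onto a continuous [0, 255] scale); the version in tf_funcs.py is the one
# actually used:
def _tf_l2u_sketch(x):
    s = K.sign(x)
    u = s * (128.0 * K.log(1.0 + (255.0 / 32768.0) * K.abs(x)) / K.log(256.0))
    return K.clip(128.0 + u, 0.0, 255.0)  # mu-law zero maps to 128
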
# The following loss functions all expect the LPCNet model to pack its outputs
# along the last axis: channel 0 is the LPC prediction, channel 1 is the
# prediction from the "real" (ground-truth) LPC, and channels 2: hold the
# 256-way softmax over mu-law excitation values

# Compute the excitation by subtracting the LPC prediction from the target,
# then minimize the cross entropy of the resulting mu-law index
def res_from_sigloss():
    def loss(y_true, y_pred):
        p = y_pred[:, :, 0:1]         # LPC prediction
        model_out = y_pred[:, :, 2:]  # 256-way excitation distribution
        e_gt = tf_l2u(y_true - p)     # ground-truth excitation on the mu-law scale
        e_gt = tf.round(e_gt)
        e_gt = tf.cast(e_gt, 'int32')
        sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt, model_out)
        return sparse_cel
    return loss

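# Shape sanity check for the packed-output convention (a standalone sketch, not
# part of training): with a uniform 256-way distribution the cross entropy is
# log(256) per sample, independent of the target index.
def _packed_output_demo():
    batch, time = 2, 5
    p = tf.zeros((batch, time, 1))                    # LPC prediction
    real_p = tf.zeros((batch, time, 1))               # ground-truth-LPC prediction
    probs = tf.fill((batch, time, 256), 1.0 / 256.0)  # uniform excitation distribution
    y_pred = tf.concat([p, real_p, probs], axis=-1)   # shape (2, 5, 258)
    y_true = tf.zeros((batch, time, 1))
    return res_from_sigloss()(y_true, y_pred)         # ~log(256) ~= 5.545 everywhere
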
# Interpolated and compensated loss (for end-to-end LPCNet)
# Interpolates between adjacent probability bins based on the fractional part of
# the computed excitation (mirroring the fractional embedding interpolation).
# Also adds a probability compensation term (to account for matching the cross
# entropy in the linear domain) and a regularization term weighted by gamma.
def interp_mulaw(gamma=1):
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, 'float32')
        p = y_pred[:, :, 0:1]         # prediction from the learned LPC
        real_p = y_pred[:, :, 1:2]    # prediction from the ground-truth LPC
        model_out = y_pred[:, :, 2:]  # 256-way excitation distribution
        e_gt = tf_l2u(y_true - p)
        exc_gt = tf_l2u(y_true - real_p)
        # Both terms grow with the distance of the mu-law excitation from its
        # zero point (128)
        prob_compensation = tf.squeeze((K.abs(e_gt - 128) / 128.0) * K.log(256.0))
        regularization = tf.squeeze((K.abs(exc_gt - 128) / 128.0) * K.log(256.0))
        alpha = e_gt - tf.math.floor(e_gt)  # fractional part used for interpolation
        alpha = tf.tile(alpha, [1, 1, 256])
        e_gt = tf.cast(e_gt, 'int32')
        e_gt = tf.clip_by_value(e_gt, 0, 254)  # clip to 254 so bin e_gt + 1 exists
        interp_probab = (1 - alpha) * model_out + alpha * tf.roll(model_out, shift=-1, axis=-1)
        sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt, interp_probab)
        loss_mod = sparse_cel + prob_compensation + gamma * regularization
        return loss_mod
    return loss

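# Worked example of the bin interpolation above (standalone sketch): with a
# 4-bin distribution and a fractional target e = 1.25, index 1 of the blended
# distribution reads 0.75*P[1] + 0.25*P[2], i.e. the linear interpolation
# between the two neighboring bins.
def _interp_probab_example():
    probs = tf.constant([[[0.1, 0.2, 0.3, 0.4]]])     # (1, 1, 4) toy distribution
    e = tf.constant([[[1.25]]])                       # fractional target index
    alpha = tf.tile(e - tf.math.floor(e), [1, 1, 4])  # fractional part, 0.25
    blended = (1 - alpha) * probs + alpha * tf.roll(probs, shift=-1, axis=-1)
    return blended                                    # blended[0, 0, 1] == 0.225
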
# Metric version of interp_mulaw, without the gamma-weighted regularization term
def metric_oginterploss(y_true, y_pred):
    p = y_pred[:, :, 0:1]
    model_out = y_pred[:, :, 2:]
    e_gt = tf_l2u(y_true - p)
    prob_compensation = tf.squeeze((K.abs(e_gt - 128) / 128.0) * K.log(256.0))
    alpha = e_gt - tf.math.floor(e_gt)
    alpha = tf.tile(alpha, [1, 1, 256])
    e_gt = tf.cast(e_gt, 'int32')
    e_gt = tf.clip_by_value(e_gt, 0, 254)
    interp_probab = (1 - alpha) * model_out + alpha * tf.roll(model_out, shift=-1, axis=-1)
    sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt, interp_probab)
    loss_mod = sparse_cel + prob_compensation
    return loss_mod

# Interpolated cross entropy loss metric
def metric_icel(y_true, y_pred):
    p = y_pred[:, :, 0:1]
    model_out = y_pred[:, :, 2:]
    e_gt = tf_l2u(y_true - p)
    alpha = e_gt - tf.math.floor(e_gt)
    alpha = tf.tile(alpha, [1, 1, 256])
    e_gt = tf.cast(e_gt, 'int32')
    e_gt = tf.clip_by_value(e_gt, 0, 254)  # clip to 254 so bin e_gt + 1 exists
    interp_probab = (1 - alpha) * model_out + alpha * tf.roll(model_out, shift=-1, axis=-1)
    sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt, interp_probab)
    return sparse_cel

# Non-interpolated (rounded) cross entropy loss metric
def metric_cel(y_true, y_pred):
    y_true = tf.cast(y_true, 'float32')
    p = y_pred[:, :, 0:1]
    model_out = y_pred[:, :, 2:]
    e_gt = tf_l2u(y_true - p)
    e_gt = tf.round(e_gt)
    e_gt = tf.cast(e_gt, 'int32')
    e_gt = tf.clip_by_value(e_gt, 0, 255)
    sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt, model_out)
    return sparse_cel

# Spread of the output excitation: mean squared deviation of the mu-law
# excitation from its zero point (128), i.e. a variance proxy
def metric_exc_sd(y_true, y_pred):
    p = y_pred[:, :, 0:1]
    e_gt = tf_l2u(y_true - p)
    sd_egt = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)(e_gt, 128)
    return sd_egt

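# Usage sketch (the model object here is hypothetical): the losses and metrics
# above are plain callables, so they plug straight into Keras, e.g.
#
#   model.compile(optimizer='adam', loss=interp_mulaw(gamma=2),
#                 metrics=[metric_cel, metric_icel, metric_exc_sd])
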
# Log-area-ratio (LAR) distance between the predicted and target reflection
# coefficients; the 1.01 keeps the log argument positive as |rc| approaches 1
def loss_matchlar():
    def loss(y_true, y_pred):
        model_rc = y_pred[:, :, :16]
        #y_true = lpc2rc(y_true)
        loss_lar_diff = K.log((1.01 + model_rc) / (1.01 - model_rc)) - K.log((1.01 + y_true) / (1.01 - y_true))
        loss_lar_diff = tf.square(loss_lar_diff)
        return tf.reduce_mean(loss_lar_diff, axis=-1)
    return loss
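
# lpc2rc (commented out above) is not defined in this file. For reference, a
# minimal NumPy sketch of the standard step-down (reverse Levinson-Durbin)
# recursion that converts LPC coefficients to reflection coefficients; sign
# conventions vary, and the training pipeline presumably supplies its own
# (differentiable) version:
import numpy as np

def _lpc2rc_sketch(lpc):
    a = np.asarray(lpc, dtype=np.float64).copy()  # prediction coefficients a_1..a_p
    rc = np.zeros_like(a)
    for p in range(len(a), 0, -1):
        k = a[p - 1]  # reflection coefficient of order p
        rc[p - 1] = k
        if p > 1:
            # step down one order: a_i <- (a_i - k * a_{p-i}) / (1 - k^2)
            a = (a[:p - 1] - k * a[p - 2::-1]) / (1.0 - k * k)
    return rc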