1*a58d3d2aSXin Li""" 2*a58d3d2aSXin LiCustom Loss functions and metrics for training/analysis 3*a58d3d2aSXin Li""" 4*a58d3d2aSXin Li 5*a58d3d2aSXin Lifrom tf_funcs import * 6*a58d3d2aSXin Liimport tensorflow as tf 7*a58d3d2aSXin Li 8*a58d3d2aSXin Li# The following loss functions all expect the lpcnet model to output the lpc prediction 9*a58d3d2aSXin Li 10*a58d3d2aSXin Li# Computing the excitation by subtracting the lpc prediction from the target, followed by minimizing the cross entropy 11*a58d3d2aSXin Lidef res_from_sigloss(): 12*a58d3d2aSXin Li def loss(y_true,y_pred): 13*a58d3d2aSXin Li p = y_pred[:,:,0:1] 14*a58d3d2aSXin Li model_out = y_pred[:,:,2:] 15*a58d3d2aSXin Li e_gt = tf_l2u(y_true - p) 16*a58d3d2aSXin Li e_gt = tf.round(e_gt) 17*a58d3d2aSXin Li e_gt = tf.cast(e_gt,'int32') 18*a58d3d2aSXin Li sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,model_out) 19*a58d3d2aSXin Li return sparse_cel 20*a58d3d2aSXin Li return loss 21*a58d3d2aSXin Li 22*a58d3d2aSXin Li# Interpolated and Compensated Loss (In case of end to end lpcnet) 23*a58d3d2aSXin Li# Interpolates between adjacent embeddings based on the fractional value of the excitation computed (similar to the embedding interpolation) 24*a58d3d2aSXin Li# Also adds a probability compensation (to account for matching cross entropy in the linear domain), weighted by gamma 25*a58d3d2aSXin Lidef interp_mulaw(gamma = 1): 26*a58d3d2aSXin Li def loss(y_true,y_pred): 27*a58d3d2aSXin Li y_true = tf.cast(y_true, 'float32') 28*a58d3d2aSXin Li p = y_pred[:,:,0:1] 29*a58d3d2aSXin Li real_p = y_pred[:,:,1:2] 30*a58d3d2aSXin Li model_out = y_pred[:,:,2:] 31*a58d3d2aSXin Li e_gt = tf_l2u(y_true - p) 32*a58d3d2aSXin Li exc_gt = tf_l2u(y_true - real_p) 33*a58d3d2aSXin Li prob_compensation = tf.squeeze((K.abs(e_gt - 128)/128.0)*K.log(256.0)) 34*a58d3d2aSXin Li regularization = tf.squeeze((K.abs(exc_gt - 128)/128.0)*K.log(256.0)) 35*a58d3d2aSXin Li alpha = e_gt - tf.math.floor(e_gt) 36*a58d3d2aSXin Li alpha = tf.tile(alpha,[1,1,256]) 37*a58d3d2aSXin Li e_gt = tf.cast(e_gt,'int32') 38*a58d3d2aSXin Li e_gt = tf.clip_by_value(e_gt,0,254) 39*a58d3d2aSXin Li interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1) 40*a58d3d2aSXin Li sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab) 41*a58d3d2aSXin Li loss_mod = sparse_cel + prob_compensation + gamma*regularization 42*a58d3d2aSXin Li return loss_mod 43*a58d3d2aSXin Li return loss 44*a58d3d2aSXin Li 45*a58d3d2aSXin Li# Same as above, except a metric 46*a58d3d2aSXin Lidef metric_oginterploss(y_true,y_pred): 47*a58d3d2aSXin Li p = y_pred[:,:,0:1] 48*a58d3d2aSXin Li model_out = y_pred[:,:,2:] 49*a58d3d2aSXin Li e_gt = tf_l2u(y_true - p) 50*a58d3d2aSXin Li prob_compensation = tf.squeeze((K.abs(e_gt - 128)/128.0)*K.log(256.0)) 51*a58d3d2aSXin Li alpha = e_gt - tf.math.floor(e_gt) 52*a58d3d2aSXin Li alpha = tf.tile(alpha,[1,1,256]) 53*a58d3d2aSXin Li e_gt = tf.cast(e_gt,'int32') 54*a58d3d2aSXin Li e_gt = tf.clip_by_value(e_gt,0,254) 55*a58d3d2aSXin Li interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1) 56*a58d3d2aSXin Li sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab) 57*a58d3d2aSXin Li loss_mod = sparse_cel + prob_compensation 58*a58d3d2aSXin Li return loss_mod 59*a58d3d2aSXin Li 60*a58d3d2aSXin Li# Interpolated cross entropy loss metric 61*a58d3d2aSXin Lidef 
metric_icel(y_true, y_pred): 62*a58d3d2aSXin Li p = y_pred[:,:,0:1] 63*a58d3d2aSXin Li model_out = y_pred[:,:,2:] 64*a58d3d2aSXin Li e_gt = tf_l2u(y_true - p) 65*a58d3d2aSXin Li alpha = e_gt - tf.math.floor(e_gt) 66*a58d3d2aSXin Li alpha = tf.tile(alpha,[1,1,256]) 67*a58d3d2aSXin Li e_gt = tf.cast(e_gt,'int32') 68*a58d3d2aSXin Li e_gt = tf.clip_by_value(e_gt,0,254) #Check direction 69*a58d3d2aSXin Li interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1) 70*a58d3d2aSXin Li sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab) 71*a58d3d2aSXin Li return sparse_cel 72*a58d3d2aSXin Li 73*a58d3d2aSXin Li# Non-interpolated (rounded) cross entropy loss metric 74*a58d3d2aSXin Lidef metric_cel(y_true, y_pred): 75*a58d3d2aSXin Li y_true = tf.cast(y_true, 'float32') 76*a58d3d2aSXin Li p = y_pred[:,:,0:1] 77*a58d3d2aSXin Li model_out = y_pred[:,:,2:] 78*a58d3d2aSXin Li e_gt = tf_l2u(y_true - p) 79*a58d3d2aSXin Li e_gt = tf.round(e_gt) 80*a58d3d2aSXin Li e_gt = tf.cast(e_gt,'int32') 81*a58d3d2aSXin Li e_gt = tf.clip_by_value(e_gt,0,255) 82*a58d3d2aSXin Li sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,model_out) 83*a58d3d2aSXin Li return sparse_cel 84*a58d3d2aSXin Li 85*a58d3d2aSXin Li# Variance metric of the output excitation 86*a58d3d2aSXin Lidef metric_exc_sd(y_true,y_pred): 87*a58d3d2aSXin Li p = y_pred[:,:,0:1] 88*a58d3d2aSXin Li e_gt = tf_l2u(y_true - p) 89*a58d3d2aSXin Li sd_egt = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)(e_gt,128) 90*a58d3d2aSXin Li return sd_egt 91*a58d3d2aSXin Li 92*a58d3d2aSXin Lidef loss_matchlar(): 93*a58d3d2aSXin Li def loss(y_true,y_pred): 94*a58d3d2aSXin Li model_rc = y_pred[:,:,:16] 95*a58d3d2aSXin Li #y_true = lpc2rc(y_true) 96*a58d3d2aSXin Li loss_lar_diff = K.log((1.01 + model_rc)/(1.01 - model_rc)) - K.log((1.01 + y_true)/(1.01 - y_true)) 97*a58d3d2aSXin Li loss_lar_diff = tf.square(loss_lar_diff) 98*a58d3d2aSXin Li return tf.reduce_mean(loss_lar_diff, axis=-1) 99*a58d3d2aSXin Li return loss 100
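

# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original training
# scripts): builds random tensors shaped like the y_pred layout described at
# the top of this file -- channel 0 the LPC prediction p, channel 1 the "real"
# prediction, channels 2..257 a 256-way distribution over mu-law excitation
# values -- and evaluates the losses/metrics on them. The batch/frame sizes
# and the int16-ish signal range are assumptions for the demo only.
if __name__ == "__main__":
    B, T = 2, 160
    y_true = tf.random.uniform((B, T, 1), -1000.0, 1000.0)  # fake linear-domain target
    preds = tf.random.uniform((B, T, 2), -1000.0, 1000.0)   # fake [p, real_p]
    probs = tf.nn.softmax(tf.random.normal((B, T, 256)), axis=-1)
    y_pred = tf.concat([preds, probs], axis=-1)

    print("interp_mulaw  :", float(tf.reduce_mean(interp_mulaw(gamma=2.0)(y_true, y_pred))))
    print("metric_cel    :", float(tf.reduce_mean(metric_cel(y_true, y_pred))))
    print("metric_icel   :", float(tf.reduce_mean(metric_icel(y_true, y_pred))))
    print("metric_exc_sd :", float(tf.reduce_mean(metric_exc_sd(y_true, y_pred))))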