# xref: /aosp_15_r20/external/libopus/dnn/torch/neural-pitch/training.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
"""
Training the neural pitch estimator.
"""
5*a58d3d2aSXin Li
import os
import argparse

# Command-line interface: dataset paths, input representation, model
# hyper-parameters, and training options.
parser = argparse.ArgumentParser()

# Required positional arguments.
parser.add_argument('features', type=str, help='.f32 IF Features for training (generated by augmentation script)')
parser.add_argument('features_pitch', type=str, help='.npy Pitch file for training (generated by augmentation script)')
parser.add_argument('output_folder', type=str, help='Output directory to store the model weights and config')
parser.add_argument('data_format', type=str, choices=['if', 'xcorr', 'both'], help='Choice of Input Data')

# Optional overrides (all have sensible defaults).
parser.add_argument('--gpu_index', type=int, default=0, help='GPU index to use if multiple GPUs')
parser.add_argument('--confidence_threshold', type=float, default=0.4, help='Confidence value below which pitch will be neglected during training')
parser.add_argument('--context', type=int, default=100, help='Sequence length during training')
parser.add_argument('--N', type=int, default=320, help='STFT window size')
parser.add_argument('--H', type=int, default=160, help='STFT Hop size')
parser.add_argument('--xcorr_dimension', type=int, default=257, help='Dimension of Input cross-correlation')
parser.add_argument('--freq_keep', type=int, default=30, help='Number of Frequencies to keep')
parser.add_argument('--gru_dim', type=int, default=64, help='GRU Dimension')
parser.add_argument('--output_dim', type=int, default=192, help='Output dimension')
parser.add_argument('--learning_rate', type=float, default=1.0e-3, help='Learning Rate')
parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs')
parser.add_argument('--choice_cel', type=str, choices=['default', 'robust'], default='default', help='Choice of Cross Entropy Loss (default or robust)')
parser.add_argument('--prefix', type=str, default='model', help="prefix for model export, default: model")
parser.add_argument('--initial-checkpoint', type=str, default=None, help='initial checkpoint to start training from, default: None')


args = parser.parse_args()
31*a58d3d2aSXin Li
# Kept for reference: pinning the process to a single GPU via environment
# variables (currently disabled; device selection happens below instead).
# import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_index)

# Fixing the seeds for reproducibility: both seeds come from the wall clock
# and are stored in the exported config so a run can be repeated later.
import time
np_seed = int(time.time())
torch_seed = int(time.time())

# Each RNG is seeded immediately after its library is imported, before the
# project modules (utils, models) are pulled in.
import torch
torch.manual_seed(torch_seed)
import numpy as np
np.random.seed(np_seed)
from utils import count_parameters
import tqdm
from models import PitchDNN, PitchDNNIF, PitchDNNXcorr, PitchDNNDataloader

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50*a58d3d2aSXin Li
51*a58d3d2aSXin Li
# Select the architecture matching the requested input representation:
# instantaneous-frequency features, cross-correlation features, or both.
if args.data_format == 'if':
    pitch_nn = PitchDNNIF(3 * args.freq_keep - 2, args.gru_dim, args.output_dim)
elif args.data_format == 'xcorr':
    pitch_nn = PitchDNNXcorr(args.xcorr_dimension, args.gru_dim, args.output_dim)
else:
    # NOTE(review): 224 looks like the xcorr feature width used by the
    # combined model — confirm against models.PitchDNN.
    pitch_nn = PitchDNN(3 * args.freq_keep - 2, 224, args.gru_dim, args.output_dim)

# Optionally warm-start from a previous checkpoint. strict=False tolerates
# missing/extra keys between the checkpoint and the current architecture.
# Fix: use the idiomatic `is not None` instead of `type(x) != type(None)`.
if args.initial_checkpoint is not None:
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    pitch_nn.load_state_dict(checkpoint['state_dict'], strict=False)


# Dataset yielding (features, pitch, confidence) sequences of length
# args.context; frames below the confidence threshold are down-weighted.
dataset_training = PitchDNNDataloader(args.features, args.features_pitch, args.confidence_threshold, args.context, args.data_format)
65*a58d3d2aSXin Li
66*a58d3d2aSXin Lidef loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
67*a58d3d2aSXin Li    logits_softmax = torch.nn.Softmax(dim = 1)(logits).permute(0,2,1)
68*a58d3d2aSXin Li    labels_one_hot = torch.nn.functional.one_hot(labels.long(),nmax)
69*a58d3d2aSXin Li
70*a58d3d2aSXin Li    if choice == 'default':
71*a58d3d2aSXin Li        # Categorical Cross Entropy
72*a58d3d2aSXin Li        CE = -torch.sum(torch.log(logits_softmax*labels_one_hot + 1.0e-6)*labels_one_hot,dim=-1)
73*a58d3d2aSXin Li        CE = torch.mean(confidence*CE)
74*a58d3d2aSXin Li
75*a58d3d2aSXin Li    else:
76*a58d3d2aSXin Li        # Robust Cross Entropy
77*a58d3d2aSXin Li        CE = (1.0/q)*(1 - torch.sum(torch.pow(logits_softmax*labels_one_hot + 1.0e-7,q),dim=-1) )
78*a58d3d2aSXin Li        CE = torch.sum(confidence*CE)
79*a58d3d2aSXin Li
80*a58d3d2aSXin Li    return CE
81*a58d3d2aSXin Li
def accuracy(logits, labels, confidence, choice='default', nmax=192, q=0.7):
    """Confidence-weighted frame accuracy: 1 minus the weighted error rate.

    choice, nmax and q are accepted only for signature parity with
    loss_custom; they are unused here.
    """
    probs = torch.nn.functional.softmax(logits, dim=1).permute(0, 2, 1)
    predicted = torch.argmax(probs, 2)
    errors = (predicted != labels.long()).float()
    return 1.0 - torch.mean(confidence * errors)
87*a58d3d2aSXin Li
# 95/5 train/test split, seeded with the recorded torch_seed so the split is
# reproducible from the saved config.
split_generator = torch.Generator().manual_seed(torch_seed)
train_dataset, test_dataset = torch.utils.data.random_split(dataset_training, [0.95, 0.05], generator=split_generator)

batch_size = 256
loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, **loader_kwargs)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, **loader_kwargs)

# Move the model to the selected device and set up the optimizer.
pitch_nn = pitch_nn.to(device)
num_params = count_parameters(pitch_nn)
learning_rate = args.learning_rate
model_opt = torch.optim.Adam(pitch_nn.parameters(), lr=learning_rate)

num_epochs = args.epochs
100*a58d3d2aSXin Li
# Main optimization loop: one pass over the training data per epoch with the
# confidence-weighted cross-entropy objective, plus a held-out evaluation
# every 5 epochs.
for epoch in range(num_epochs):
    losses = []
    accs = []
    pitch_nn.train()
    with tqdm.tqdm(train_dataloader) as train_epoch:
        # xi: input features, yi: pitch-class targets, ci: per-frame confidence.
        for xi, yi, ci in train_epoch:
            yi, xi, ci = yi.to(device, non_blocking=True), xi.to(device, non_blocking=True), ci.to(device, non_blocking=True)
            pi = pitch_nn(xi.float())
            loss = loss_custom(logits=pi, labels=yi, confidence=ci, choice=args.choice_cel, nmax=args.output_dim)
            acc = accuracy(logits=pi, labels=yi, confidence=ci, choice=args.choice_cel, nmax=args.output_dim)
            acc = acc.detach()

            model_opt.zero_grad()
            loss.backward()
            model_opt.step()

            # Running averages over the epoch for the progress bar.
            losses.append(loss.item())
            accs.append(acc.item())
            avg_loss = np.mean(losses)
            avg_acc = np.mean(accs)
            train_epoch.set_postfix({"Train Epoch" : epoch, "Train Loss":avg_loss, "acc" : avg_acc.item()})

    if epoch % 5 == 0:
        pitch_nn.eval()
        losses = []
        # Fix: run evaluation under no_grad — the original built autograd
        # graphs during validation, wasting memory and compute.
        with torch.no_grad(), tqdm.tqdm(test_dataloader) as test_epoch:
            for xi, yi, ci in test_epoch:
                yi, xi, ci = yi.to(device, non_blocking=True), xi.to(device, non_blocking=True), ci.to(device, non_blocking=True)
                pi = pitch_nn(xi.float())
                loss = loss_custom(logits=pi, labels=yi, confidence=ci, choice=args.choice_cel, nmax=args.output_dim)
                losses.append(loss.item())
                avg_loss = np.mean(losses)
                test_epoch.set_postfix({"Epoch" : epoch, "Test Loss":avg_loss})
134*a58d3d2aSXin Li
pitch_nn.eval()

# Hyper-parameters and environment info are saved alongside the weights so a
# run can be reconstructed from the checkpoint alone (including both RNG seeds).
config = dict(
    data_format=args.data_format,
    epochs=num_epochs,
    window_size= args.N,
    hop_factor= args.H,
    freq_keep=args.freq_keep,
    batch_size=batch_size,
    learning_rate=learning_rate,
    confidence_threshold=args.confidence_threshold,
    model_parameters=num_params,
    np_seed=np_seed,
    torch_seed=torch_seed,
    xcorr_dim=args.xcorr_dimension,
    dim_input=3*args.freq_keep - 2,
    gru_dim=args.gru_dim,
    output_dim=args.output_dim,
    choice_cel=args.choice_cel,
    context=args.context,
)

# Fix: create the output directory if it does not exist yet, otherwise
# torch.save fails on a fresh run.
os.makedirs(args.output_folder, exist_ok=True)

model_save_path = os.path.join(args.output_folder, f"{args.prefix}_{args.data_format}.pth")
checkpoint = {
    'state_dict': pitch_nn.state_dict(),
    'config': config
}
torch.save(checkpoint, model_save_path)
163