"""
Training the neural pitch estimator.

Flat training script: parses CLI arguments, builds one of three pitch-DNN
variants (IF features, cross-correlation features, or both), trains it with a
confidence-weighted cross-entropy loss, and saves a checkpoint (weights +
config) into the requested output folder.
"""

import os
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='.f32 IF Features for training (generated by augmentation script)')
parser.add_argument('features_pitch', type=str, help='.npy Pitch file for training (generated by augmentation script)')
parser.add_argument('output_folder', type=str, help='Output directory to store the model weights and config')
parser.add_argument('data_format', type=str, help='Choice of Input Data',choices=['if','xcorr','both'])
parser.add_argument('--gpu_index', type=int, help='GPU index to use if multiple GPUs',default = 0,required = False)
parser.add_argument('--confidence_threshold', type=float, help='Confidence value below which pitch will be neglected during training',default = 0.4,required = False)
parser.add_argument('--context', type=int, help='Sequence length during training',default = 100,required = False)
parser.add_argument('--N', type=int, help='STFT window size',default = 320,required = False)
parser.add_argument('--H', type=int, help='STFT Hop size',default = 160,required = False)
parser.add_argument('--xcorr_dimension', type=int, help='Dimension of Input cross-correlation',default = 257,required = False)
parser.add_argument('--freq_keep', type=int, help='Number of Frequencies to keep',default = 30,required = False)
parser.add_argument('--gru_dim', type=int, help='GRU Dimension',default = 64,required = False)
parser.add_argument('--output_dim', type=int, help='Output dimension',default = 192,required = False)
parser.add_argument('--learning_rate', type=float, help='Learning Rate',default = 1.0e-3,required = False)
parser.add_argument('--epochs', type=int, help='Number of training epochs',default = 50,required = False)
parser.add_argument('--choice_cel', type=str, help='Choice of Cross Entropy Loss (default or robust)',choices=['default','robust'],default = 'default',required = False)
parser.add_argument('--prefix', type=str, help="prefix for model export, default: model", default='model')
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)

args = parser.parse_args()

# Seeds are taken from the wall clock (so each run differs) but are recorded
# in the saved config below, so a given run can still be reproduced later.
# Seeding happens before the model modules are imported/constructed so that
# weight initialization is covered by the recorded seeds.
import time
np_seed = int(time.time())
torch_seed = int(time.time())

import torch
torch.manual_seed(torch_seed)
import numpy as np
np.random.seed(np_seed)
from utils import count_parameters
import tqdm
from models import PitchDNN, PitchDNNIF, PitchDNNXcorr, PitchDNNDataloader

# Honor --gpu_index when several GPUs are present (the argument was parsed
# but previously never used; a commented-out CUDA_VISIBLE_DEVICES hack hinted
# at the intent).
device = torch.device(f"cuda:{args.gpu_index}" if torch.cuda.is_available() else "cpu")


# IF-branch input dimension is 3 features per kept frequency, minus 2
# (matches dim_input recorded in the config below).
if args.data_format == 'if':
    pitch_nn = PitchDNNIF(3 * args.freq_keep - 2, args.gru_dim, args.output_dim)
elif args.data_format == 'xcorr':
    pitch_nn = PitchDNNXcorr(args.xcorr_dimension, args.gru_dim, args.output_dim)
else:
    pitch_nn = PitchDNN(3 * args.freq_keep - 2, 224, args.gru_dim, args.output_dim)

if args.initial_checkpoint is not None:
    # strict=False: allow warm-starting from a checkpoint whose architecture
    # only partially matches the current model.
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    pitch_nn.load_state_dict(checkpoint['state_dict'], strict=False)


dataset_training = PitchDNNDataloader(args.features, args.features_pitch, args.confidence_threshold, args.context, args.data_format)


def loss_custom(logits, labels, confidence, choice='default', nmax=192, q=0.7):
    """Confidence-weighted cross-entropy over pitch classes.

    Args:
        logits:     raw network outputs with the class axis at dim 1
                    (permuted to (..., time, nmax) internally).
        labels:     integer pitch-class targets, cast to long for one-hot.
        confidence: per-frame weight; low-confidence frames contribute less.
        choice:     'default' -> categorical CE (mean-reduced);
                    anything else -> robust/generalized CE (sum-reduced)
                    with exponent q, less sensitive to label noise.
        nmax:       number of pitch classes.
        q:          exponent of the robust loss (only used when
                    choice != 'default').

    Returns:
        Scalar loss tensor.
    """
    logits_softmax = torch.nn.Softmax(dim=1)(logits).permute(0, 2, 1)
    labels_one_hot = torch.nn.functional.one_hot(labels.long(), nmax)

    if choice == 'default':
        # Categorical cross entropy; the 1e-6 floor avoids log(0).
        CE = -torch.sum(torch.log(logits_softmax*labels_one_hot + 1.0e-6)*labels_one_hot, dim=-1)
        CE = torch.mean(confidence*CE)
    else:
        # Robust (generalized) cross entropy.
        CE = (1.0/q)*(1 - torch.sum(torch.pow(logits_softmax*labels_one_hot + 1.0e-7, q), dim=-1))
        CE = torch.sum(confidence*CE)

    return CE


def accuracy(logits, labels, confidence, choice='default', nmax=192, q=0.7):
    """Confidence-weighted classification accuracy.

    choice/nmax/q are accepted only to keep the signature parallel to
    loss_custom; they do not affect the computation.
    """
    logits_softmax = torch.nn.Softmax(dim=1)(logits).permute(0, 2, 1)
    pred_pitch = torch.argmax(logits_softmax, 2)
    # Renamed from 'accuracy' to avoid shadowing this function's own name.
    errors = (pred_pitch != labels.long())*1.
    return 1. - torch.mean(confidence*errors)


# Fixed 95/5 train/test split; the generator is seeded with the recorded
# torch seed so the split is reproducible from the saved config.
train_dataset, test_dataset = torch.utils.data.random_split(dataset_training, [0.95, 0.05], generator=torch.Generator().manual_seed(torch_seed))

batch_size = 256
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)

pitch_nn = pitch_nn.to(device)
num_params = count_parameters(pitch_nn)
learning_rate = args.learning_rate
model_opt = torch.optim.Adam(pitch_nn.parameters(), lr=learning_rate)

num_epochs = args.epochs

for epoch in range(num_epochs):
    losses = []
    accs = []
    pitch_nn.train()
    with tqdm.tqdm(train_dataloader) as train_epoch:
        for xi, yi, ci in train_epoch:
            yi, xi, ci = yi.to(device, non_blocking=True), xi.to(device, non_blocking=True), ci.to(device, non_blocking=True)
            pi = pitch_nn(xi.float())
            loss = loss_custom(logits=pi, labels=yi, confidence=ci, choice=args.choice_cel, nmax=args.output_dim)
            # Accuracy is monitoring only; detach so it never holds a graph.
            acc = accuracy(logits=pi, labels=yi, confidence=ci, choice=args.choice_cel, nmax=args.output_dim).detach()

            model_opt.zero_grad()
            loss.backward()
            model_opt.step()

            losses.append(loss.item())
            accs.append(acc.item())
            # Running means over the epoch so far, shown on the progress bar.
            train_epoch.set_postfix({"Train Epoch": epoch, "Train Loss": np.mean(losses), "acc": np.mean(accs)})

    # Evaluate on the held-out split every 5th epoch.
    if epoch % 5 == 0:
        pitch_nn.eval()
        losses = []
        # no_grad: evaluation must not build autograd graphs (saves memory).
        with torch.no_grad(), tqdm.tqdm(test_dataloader) as test_epoch:
            for xi, yi, ci in test_epoch:
                yi, xi, ci = yi.to(device, non_blocking=True), xi.to(device, non_blocking=True), ci.to(device, non_blocking=True)
                pi = pitch_nn(xi.float())
                loss = loss_custom(logits=pi, labels=yi, confidence=ci, choice=args.choice_cel, nmax=args.output_dim)
                losses.append(loss.item())
                test_epoch.set_postfix({"Epoch": epoch, "Test Loss": np.mean(losses)})

pitch_nn.eval()

# Everything needed to rebuild and audit this run is stored alongside the
# weights (including the wall-clock seeds used above).
config = dict(
    data_format=args.data_format,
    epochs=num_epochs,
    window_size= args.N,
    hop_factor= args.H,
    freq_keep=args.freq_keep,
    batch_size=batch_size,
    learning_rate=learning_rate,
    confidence_threshold=args.confidence_threshold,
    model_parameters=num_params,
    np_seed=np_seed,
    torch_seed=torch_seed,
    xcorr_dim=args.xcorr_dimension,
    dim_input=3*args.freq_keep - 2,
    gru_dim=args.gru_dim,
    output_dim=args.output_dim,
    choice_cel=args.choice_cel,
    context=args.context,
)

# Create the output directory if needed so torch.save cannot fail on a
# missing path at the very end of a long training run.
os.makedirs(args.output_folder, exist_ok=True)
model_save_path = os.path.join(args.output_folder, f"{args.prefix}_{args.data_format}.pth")
checkpoint = {
    'state_dict': pitch_nn.state_dict(),
    'config': config
}
torch.save(checkpoint, model_save_path)