import os
import argparse
import random
import numpy as np
import sys
import math as m

import torch
from torch import nn
import torch.nn.functional as F
import tqdm

import fargan
from dataset import FARGANDataset
from stft_loss import *

# make the OSCE models (discriminators) importable from the sibling osce/ tree
source_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.join(source_dir, "../osce/"))

import models as osce_models


def fmap_loss(scores_real, scores_gen):
    """Feature-matching loss between real and generated discriminator features.

    Args:
        scores_real: list with one entry per discriminator; each entry is a
            list of feature maps whose last element is the final score and is
            excluded from the loss.
        scores_gen: same structure, computed on the generated signal.

    Returns:
        Scalar tensor: L1 distance between generated and (detached) real
        feature maps, averaged over discriminators and layers, scaled by 4
        to keep the historical magnitude.
    """
    num_discs = len(scores_real)
    loss_feat = 0
    for real_maps, gen_maps in zip(scores_real, scores_gen):
        num_layers = len(gen_maps) - 1  # last entry is the score, not a feature map
        weight = 4 / num_discs / num_layers
        for gen_map, real_map in zip(gen_maps[:num_layers], real_maps):
            loss_feat += weight * F.l1_loss(gen_map, real_map.detach())

    return loss_feat

parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
parser.add_argument('output', type=str, help='path to output folder')

parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128)
training_group.add_argument('--lr', type=float, help='learning rate, default: 5e-4', default=5e-4)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 50', default=50)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 60', default=60)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 0.0', default=0.0)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
training_group.add_argument('--reg-weight', type=float, help='regression loss weight, default: 1.0', default=1.0)
training_group.add_argument('--fmap-weight', type=float, help='feature matching loss weight, default: 1.0', default=1.)

args = parser.parse_args()

if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)


# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay

adam_betas = [0.8, 0.99]
adam_eps = 1e-8
features_file = args.features
signal_file = args.signal

# model parameters
cond_size = args.cond_size


checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma}
print(checkpoint['model_kwargs'])
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])


# discriminator: frequency-domain multi-resolution discriminator from the OSCE models
disc_name = 'fdmresdisc'
disc = osce_models.model_dict[disc_name](
    architecture='free',
    design='f_down',
    fft_sizes_16k=[2**n for n in range(6, 12)],
    freq_roi=[0, 7400],
    max_channels=256,
    noise_gain=0.0
)

if args.initial_checkpoint is not None:
    # NOTE(review): this replaces the checkpoint dict built above (including
    # model_args/model_kwargs) with the loaded one; only 'state_dict' is
    # refreshed below before saving.
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict'] = model.state_dict()


dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
optimizer_disc = torch.optim.AdamW([p for p in disc.parameters() if p.requires_grad], lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate scheduler (stepped once per batch, not per epoch)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
scheduler_disc = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer_disc, lr_lambda=lambda x : 1 / (1 + lr_decay * x))

states = None

spect_loss = MultiResolutionSTFTLoss(device).to(device)

# freeze the generator initially; it is unfrozen after a short discriminator
# warm-up (400 batches of the first epoch, see below)
for param in model.parameters():
    param.requires_grad = False

batch_count = 0
if __name__ == '__main__':
    model.to(device)
    disc.to(device)

    for epoch in range(1, epochs + 1):

        # running means/stds of discriminator scores on real (r) and fake (f)
        # data, used only to estimate the generator's "winning chance"
        m_r = 0
        m_f = 0
        s_r = 1
        s_f = 1

        running_cont_loss = 0
        running_disc_loss = 0
        running_gen_loss = 0
        running_fmap_loss = 0
        running_reg_loss = 0
        running_wc = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, periods, target, lpc) in enumerate(tepoch):
                # end of discriminator warm-up: unfreeze the generator except
                # for the conditioning network and the gain dense layer
                if epoch == 1 and i == 400:
                    for param in model.parameters():
                        param.requires_grad = True
                    for param in model.cond_net.parameters():
                        param.requires_grad = False
                    for param in model.sig_net.cond_gain_dense.parameters():
                        param.requires_grad = False

                optimizer.zero_grad()
                features = features.to(device)
                periods = periods.to(device)

                # trim to the configured sequence length (160 samples per
                # frame; features/periods carry 4 extra lookahead frames)
                target = target[:, :sequence_length*160]
                features = features[:, :sequence_length+4, :]
                periods = periods[:, :sequence_length+4]
                target = target.to(device)

                # teacher-force the first nb_pre frames, generate the rest
                nb_pre = 2
                pre = target[:, :nb_pre*160]
                output, _ = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                output = torch.cat([pre, output], -1)


                # discriminator update (generator output detached)
                scores_gen = disc(output.detach().unsqueeze(1))
                scores_real = disc(target.unsqueeze(1))

                # LSGAN-style losses: fake scores pushed to 0, real scores to 1
                disc_loss = 0
                for scale in scores_gen:
                    disc_loss += ((scale[-1]) ** 2).mean()
                    m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
                    s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()

                for scale in scores_real:
                    disc_loss += ((1 - scale[-1]) ** 2).mean()
                    m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
                    s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()

                disc_loss = 0.5 * disc_loss / len(scores_gen)
                # probability that a fake score exceeds a real one, assuming
                # both are Gaussian with the running means/stds above
                winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
                running_wc += winning_chance

                disc.zero_grad()
                disc_loss.backward()
                optimizer_disc.step()

                # generator update (scores recomputed with grad through output;
                # scores_real from the pre-update discriminator is reused for
                # the feature-matching loss, where it is detached anyway)
                # TODO: check whether recomputing scores_real after the
                # discriminator step makes a difference
                scores_gen = disc(output.unsqueeze(1))

                # continuity loss on the first generated subframe, and
                # multi-resolution spectral loss on the whole signal
                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80])
                specc_loss = spect_loss(output, target.detach())
                # cont_loss is currently zero-weighted but still logged below
                reg_loss = (.00*cont_loss + specc_loss)

                loss_gen = 0
                for scale in scores_gen:
                    loss_gen += ((1 - scale[-1]) ** 2).mean() / len(scores_gen)

                feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen)

                reg_weight = args.reg_weight
                gen_loss = reg_weight * reg_loss + feat_loss + loss_gen

                model.zero_grad()

                gen_loss.backward()
                optimizer.step()

                scheduler.step()
                scheduler_disc.step()

                running_cont_loss += cont_loss.detach().cpu().item()
                running_gen_loss += loss_gen.detach().cpu().item()
                running_disc_loss += disc_loss.detach().cpu().item()
                running_fmap_loss += feat_loss.detach().cpu().item()
                running_reg_loss += reg_loss.detach().cpu().item()

                tepoch.set_postfix(cont_loss=f"{running_cont_loss/(i+1):8.5f}",
                                   reg_weight=f"{reg_weight:8.5f}",
                                   gen_loss=f"{running_gen_loss/(i+1):8.5f}",
                                   disc_loss=f"{running_disc_loss/(i+1):8.5f}",
                                   fmap_loss=f"{running_fmap_loss/(i+1):8.5f}",
                                   reg_loss=f"{running_reg_loss/(i+1):8.5f}",
                                   wc = f"{running_wc/(i+1):8.5f}",
                                   )
                batch_count = batch_count + 1

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_adv_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['disc_state_dict'] = disc.state_dict()  # fixed key name (was 'disc_sate_dict')
        checkpoint['loss'] = {
            'cont': running_cont_loss / len(dataloader),
            'gen': running_gen_loss / len(dataloader),
            'disc': running_disc_loss / len(dataloader),
            'fmap': running_fmap_loss / len(dataloader),
            'reg': running_reg_loss / len(dataloader)
        }
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)