1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30import os 31import argparse 32import sys 33import math as m 34import random 35 36import yaml 37 38from tqdm import tqdm 39 40try: 41 import git 42 has_git = True 43except: 44 has_git = False 45 46import torch 47from torch.optim.lr_scheduler import LambdaLR 48import torch.nn.functional as F 49 50from scipy.io import wavfile 51import numpy as np 52import pesq 53 54from data import SilkEnhancementSet 55from models import model_dict 56 57 58from utils.silk_features import load_inference_data 59from utils.misc import count_parameters, retain_grads, get_grad_norm, create_weights 60 61from losses.stft_loss import MRSTFTLoss, MRLogMelLoss 62 63 64parser = argparse.ArgumentParser() 65 66parser.add_argument('setup', type=str, help='setup yaml file') 67parser.add_argument('output', type=str, help='output path') 68parser.add_argument('--device', type=str, help='compute device', default=None) 69parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None) 70parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None) 71parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout') 72 73args = parser.parse_args() 74 75 76torch.set_num_threads(4) 77 78with open(args.setup, 'r') as f: 79 setup = yaml.load(f.read(), yaml.FullLoader) 80 81checkpoint_prefix = 'checkpoint' 82output_prefix = 'output' 83setup_name = 'setup.yml' 84output_file='out.txt' 85 86 87# check model 88if not 'name' in setup['model']: 89 print(f'warning: did not find model entry in setup, using default PitchPostFilter') 90 model_name = 'pitchpostfilter' 91else: 92 model_name = setup['model']['name'] 93 94# prepare output folder 95if os.path.exists(args.output): 96 print("warning: output folder exists") 97 98 reply = input('continue? (y/n): ') 99 while reply not in {'y', 'n'}: 100 reply = input('continue? 

# check model
if 'name' not in setup['model']:
    print('warning: did not find a model name in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        sys.exit()
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)


ref = None
if args.testdata is not None:

    testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])

    inference_test = True
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)

    try:
        ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
    except FileNotFoundError:
        pass
else:
    inference_test = False

# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
lr_gen = lr * setup['training']['gen_lr_reduction']
lambda_feat = setup['training']['lambda_feat']
lambda_reg = setup['training']['lambda_reg']
adv_target = setup['training'].get('adv_target', 'target')

# load training dataset
data_config = setup['data']
data = SilkEnhancementSet(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
    *setup['discriminator']['args'], **setup['discriminator']['kwargs']
)

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# optimizer for the trainable generator parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)

# discriminator optimizer
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=(0.5, 0.9))

# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))
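
# Note: LambdaLR multiplies the base learning rate by the lambda above, so after
# k scheduler steps the generator learning rate is
#
#     lr_gen / (1 + lr_decay_factor * k),
#
# i.e. a hyperbolic decay. With lr_decay_factor = 1e-4 (an illustrative value;
# the actual one comes from the setup file) the rate halves after 10000 steps.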

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])

# loss weights
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)

def xcorr_loss(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)

    return torch.mean(loss)

def td_l2_norm(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)

    return loss.mean()

def td_l1(y_true, y_pred, pow=0):
    dims = list(range(1, len(y_true.shape)))
    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)

    return torch.mean(tmp)

def criterion(x, y):

    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
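
# For reference: xcorr_loss is one minus the normalized cross-correlation,
#
#     1 - <y_true, y_pred> / sqrt(sum(y_true^2) * sum(y_pred^2) + eps),
#
# which is (near) zero for identical signals and 2 for anti-correlated ones;
# td_l2_norm is the mean squared error normalized by the RMS of the prediction.
# criterion divides by w_sum, so the total is a weighted average of the
# individual losses (the STFT terms carry their weights inside MRSTFTLoss).
# A quick sanity check (illustrative; assumes the losses live on the CPU):
#
#     >>> x = torch.randn(4, 16000)
#     >>> float(criterion(x, x))   # expect a value close to zero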

# model checkpoint
checkpoint = {
    'setup': setup,
    'state_dict': model.state_dict(),
    'loss': -1
}


if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")


print("summary:")

print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")

if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")

best_loss = 1e9
log_interval = 10


# running estimates of the mean and standard deviation of the real and fake
# discriminator scores
m_r = 0
m_f = 0
s_r = 1
s_f = 1

def optimizer_to(optim, device):
    """Moves all optimizer state tensors to the given device."""
    for param in optim.state.values():
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)

retain_grads(model)
retain_grads(disc)
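
# The adversarial setup below is a least-squares GAN: the discriminator is
# trained towards a score of 1 on real signals and 0 on generated ones by
# minimizing
#
#     E[D(x_fake)^2] + E[(1 - D(x_real))^2],
#
# while the generator minimizes E[(1 - D(x_fake))^2] plus the feature-matching
# and regression losses. Each sub-discriminator is expected to return a list of
# intermediate feature maps with the final score last. m_r/m_f and s_r/s_f are
# exponential moving averages of the real/fake score statistics; modelling both
# scores as independent Gaussians, the probability of a fake sample out-scoring
# a real one is
#
#     0.5 * erfc((m_r - m_f) / sqrt(2 * (s_r^2 + s_f^2))),
#
# reported as wc ("winning chance") in the status bar.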

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")

    model.to(device)
    disc.to(device)
    model.train()
    disc.train()

    running_disc_loss = 0
    running_adv_loss = 0
    running_feature_loss = 0
    running_reg_loss = 0
    running_disc_grad_norm = 0
    running_model_grad_norm = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']
            disc_target = batch[adv_target]

            # calculate model output
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # discriminator update
            scores_gen = disc(output.detach())
            scores_real = disc(disc_target.unsqueeze(1))

            disc_loss = 0
            for score in scores_gen:
                disc_loss += (score[-1] ** 2).mean()
                m_f = 0.9 * m_f + 0.1 * score[-1].detach().mean().cpu().item()
                s_f = 0.9 * s_f + 0.1 * score[-1].detach().std().cpu().item()

            for score in scores_real:
                disc_loss += ((1 - score[-1]) ** 2).mean()
                m_r = 0.9 * m_r + 0.1 * score[-1].detach().mean().cpu().item()
                s_r = 0.9 * s_r + 0.1 * score[-1].detach().std().cpu().item()

            disc_loss = 0.5 * disc_loss / len(scores_gen)
            winning_chance = 0.5 * m.erfc((m_r - m_f) / m.sqrt(2 * (s_f ** 2 + s_r ** 2)))

            disc.zero_grad()
            disc_loss.backward()

            running_disc_grad_norm += get_grad_norm(disc).detach().cpu().item()

            optimizer_disc.step()

            # generator update
            scores_gen = disc(output)

            # calculate loss
            loss_reg = criterion(output.squeeze(1), target)

            num_discs = len(scores_gen)
            gen_loss = 0
            for score in scores_gen:
                gen_loss += ((1 - score[-1]) ** 2).mean() / num_discs

            # feature-matching loss over all layers except the final score,
            # using the real activations computed during the discriminator update
            loss_feat = 0
            for k in range(num_discs):
                num_layers = len(scores_gen[k]) - 1
                f = 4 / num_discs / num_layers
                for l in range(num_layers):
                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())

            model.zero_grad()

            (gen_loss + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()

            optimizer.step()

            # sparsification
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
            running_adv_loss += gen_loss.detach().cpu().item()
            running_disc_loss += disc_loss.detach().cpu().item()
            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(adv_loss=f"{running_adv_loss / (i + 1):8.7f}",
                                   disc_loss=f"{running_disc_loss / (i + 1):8.7f}",
                                   feat_loss=f"{running_feature_loss / (i + 1):8.7f}",
                                   reg_loss=f"{running_reg_loss / (i + 1):8.7f}",
                                   model_gradnorm=f"{running_model_grad_norm / (i + 1):8.7f}",
                                   disc_gradnorm=f"{running_disc_grad_norm / (i + 1):8.7f}",
                                   wc=f"{100 * winning_chance:5.2f}%")


    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['disc_state_dict'] = disc.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['torch_rng_state'] = torch.get_rng_state()
    checkpoint['numpy_rng_state'] = np.random.get_state()
    checkpoint['python_rng_state'] = random.getstate()
    checkpoint['adv_loss'] = running_adv_loss / (i + 1)
    checkpoint['disc_loss'] = running_disc_loss / (i + 1)
    checkpoint['feature_loss'] = running_feature_loss / (i + 1)
    checkpoint['reg_loss'] = running_reg_loss / (i + 1)


    if inference_test:
        print("running inference test...")
        out = model.process(testsignal, features, periods, numbits).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")


    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))


    print()

print('Done')
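
# Illustrative usage (script name and paths are placeholders):
#
#     python train_model.py setup.yml output_dir --device cuda:0 --testdata /path/to/testitem
#
# A run can be resumed with --initial-checkpoint pointing at any checkpoint
# saved above, e.g. output_dir/checkpoints/checkpoint_last.pth; the model,
# discriminator, optimizer, scheduler and numpy/Python RNG states stored in the
# checkpoint are then restored by the loading code near the top of the script.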