1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
import os
import argparse
import sys
import math as m
import random

import yaml

from tqdm import tqdm

try:
    import git
    has_git = True
except ImportError:
    has_git = False

import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F

from scipy.io import wavfile
import numpy as np
import pesq

from data import SilkEnhancementSet
from models import model_dict


from utils.silk_features import load_inference_data
from utils.misc import count_parameters, retain_grads, get_grad_norm, create_weights

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss


parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()


torch.set_num_threads(4)

with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'


# check model
if 'name' not in setup['model']:
    print('warning: did not find model name in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        commit_hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = commit_hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)


ref = None
if args.testdata is not None:

    testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])

    inference_test = True
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)

    try:
        ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
    except OSError:
        pass
else:
    inference_test = False

# training parameters
batch_size      = setup['training']['batch_size']
epochs          = setup['training']['epochs']
lr              = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
lr_gen          = lr * setup['training']['gen_lr_reduction']
lambda_feat     = setup['training']['lambda_feat']
lambda_reg      = setup['training']['lambda_reg']
adv_target      = setup['training'].get('adv_target', 'target')  # signal the discriminator treats as real

# load training dataset
data_config = setup['data']
data = SilkEnhancementSet(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
    *setup['discriminator']['args'], **setup['discriminator']['kwargs']
)

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# generator optimizer over trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)

# discriminator optimizer
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])

# learning rate scheduler: inverse-time decay, lr(t) = lr_gen / (1 + lr_decay_factor * t)
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])

# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
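# w_sum normalizes the combined regression loss in criterion() below; the
# stft-related weights (w_sc, w_lm, w_wsc, w_slm, w_sxcorr) are applied inside MRSTFTLoss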

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)

def xcorr_loss(y_true, y_pred):
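    """Normalized cross-correlation loss: 1 - <y_true, y_pred> / (||y_true|| ||y_pred||), averaged over the batch."""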
    dims = list(range(1, len(y_true.shape)))

    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)

    return torch.mean(loss)

def td_l2_norm(y_true, y_pred):
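    """Time-domain squared error, normalized by the RMS of the prediction and averaged over the batch."""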
    dims = list(range(1, len(y_true.shape)))

    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)

    return loss.mean()

def td_l1(y_true, y_pred, pow=0):
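    """Time-domain L1 error, optionally normalized by the mean absolute level of the prediction (pow=1), averaged over the batch."""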
    dims = list(range(1, len(y_true.shape)))
    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)

    return torch.mean(tmp)

def criterion(x, y):
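    """Combined regression loss: weighted time-domain, spectral and cross-correlation terms, normalized by w_sum."""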

    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum


# model checkpoint
checkpoint = {
    'setup'         : setup,
    'state_dict'    : model.state_dict(),
    'loss'          : -1
}


if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")


print("summary:")

print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")

if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")

best_loss = 1e9
log_interval = 10


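# exponential moving averages of the discriminator score statistics on real (m_r, s_r)
# and generated (m_f, s_f) signals, used below to estimate the generator's winning chance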
m_r = 0
m_f = 0
s_r = 1
s_f = 1

def optimizer_to(optim, device):
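    """Move all optimizer state tensors to the given device, e.g. after loading a checkpoint saved on a different device."""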
    for param in optim.state.values():
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)

retain_grads(model)
retain_grads(disc)

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")

    model.to(device)
    disc.to(device)
    model.train()
    disc.train()

    running_disc_loss = 0
    running_adv_loss = 0
    running_feature_loss = 0
    running_reg_loss = 0
    running_disc_grad_norm = 0
    running_model_grad_norm = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target'].to(device)
            disc_target = batch[adv_target].to(device)

            # calculate model output
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # discriminator update
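            # least-squares GAN loss: scores on real signals are pushed towards 1,
            # scores on the (detached) generator output towards 0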
            scores_gen = disc(output.detach())
            scores_real = disc(disc_target.unsqueeze(1))

            disc_loss = 0
            for score in scores_gen:
                disc_loss += (score[-1] ** 2).mean()
                m_f = 0.9 * m_f + 0.1 * score[-1].detach().mean().cpu().item()
                s_f = 0.9 * s_f + 0.1 * score[-1].detach().std().cpu().item()

            for score in scores_real:
                disc_loss += ((1 - score[-1]) ** 2).mean()
                m_r = 0.9 * m_r + 0.1 * score[-1].detach().mean().cpu().item()
                s_r = 0.9 * s_r + 0.1 * score[-1].detach().std().cpu().item()

            disc_loss = 0.5 * disc_loss / len(scores_gen)
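            # rough chance that the generator fools the discriminator: modelling the scores
            # as Gaussians F ~ N(m_f, s_f^2) and R ~ N(m_r, s_r^2) gives
            # P(F > R) = 0.5 * erfc((m_r - m_f) / sqrt(2 * (s_f^2 + s_r^2)))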
            winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )

            disc.zero_grad()
            disc_loss.backward()

            running_disc_grad_norm += get_grad_norm(disc).detach().cpu().item()

            optimizer_disc.step()

            # generator update
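            # second discriminator pass on the non-detached output so gradients reach the generator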
            scores_gen = disc(output)

            # calculate loss
            loss_reg = criterion(output.squeeze(1), target)

            num_discs = len(scores_gen)
            gen_loss = 0
            for score in scores_gen:
                gen_loss += ((1 - score[-1]) ** 2).mean() / num_discs

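            # feature-matching loss: L1 distance between the discriminator's intermediate
            # feature maps on generated and real signals, averaged over discriminators and layers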
            loss_feat = 0
            for k in range(num_discs):
                num_layers = len(scores_gen[k]) - 1
                f = 4 / num_discs / num_layers
                for l in range(num_layers):
                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())

            model.zero_grad()

            (gen_loss + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()

            optimizer.step()

            # sparsification
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
            running_adv_loss += gen_loss.detach().cpu().item()
            running_disc_loss += disc_loss.detach().cpu().item()
            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
                                   disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
                                   feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
                                   reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
                                   model_gradnorm=f"{running_model_grad_norm/(i+1):8.7f}",
                                   disc_gradnorm=f"{running_disc_grad_norm/(i+1):8.7f}",
                                   wc=f"{100*winning_chance:5.2f}%")


    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['disc_state_dict'] = disc.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['torch_rng_state'] = torch.get_rng_state()
    checkpoint['numpy_rng_state'] = np.random.get_state()
    checkpoint['python_rng_state'] = random.getstate()
    checkpoint['adv_loss'] = running_adv_loss / (i + 1)
    checkpoint['disc_loss'] = running_disc_loss / (i + 1)
    checkpoint['feature_loss'] = running_feature_loss / (i + 1)
    checkpoint['reg_loss'] = running_reg_loss / (i + 1)


    if inference_test:
        print("running inference test...")
        out = model.process(testsignal, features, periods, numbits).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")


    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))


    print()

print('Done')
463