1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import os
import argparse
import sys
import math as m
import random

import yaml

from tqdm import tqdm

# gitpython is optional; it is only used to record repository info in the setup dump
try:
    import git
    has_git = True
except ImportError:
    has_git = False
import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F

from scipy.io import wavfile
import numpy as np
import pesq

from data import LPCNetVocodingDataset
from models import model_dict


from utils.lpcnet_features import load_lpcnet_features
from utils.misc import count_parameters

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss


parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()


torch.set_num_threads(4)

with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'


# check model
if 'name' not in setup['model']:
    print('warning: did not find a model name in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)


ref = None
# prepare inference test if requested
inference_test = False
if args.test_features is not None:
    test_features = load_lpcnet_features(args.test_features)
    features = test_features['features']
    periods = test_features['periods']
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)
    inference_test = True


# training parameters
batch_size      = setup['training']['batch_size']
epochs          = setup['training']['epochs']
lr              = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
lr_gen          = lr * setup['training']['gen_lr_reduction']
lambda_feat     = setup['training']['lambda_feat']
lambda_reg      = setup['training']['lambda_reg']
adv_target      = setup['training'].get('adv_target', 'target')


# load training dataset
data_config = setup['data']
data = LPCNetVocodingDataset(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])


# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
    *setup['discriminator']['args'], **setup['discriminator']['kwargs']
)



# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)



# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# generator optimizer on trainable parameters (learning rate lr_gen = lr * gen_lr_reduction)
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)

# discriminator optimizer
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])

# learning rate scheduler (decays lr as 1 / (1 + lr_decay_factor * step); it is checkpointed below, but scheduler.step() is never called in this script)
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])

# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)

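# Time-domain loss helpers used by criterion() below:
#   xcorr_loss  - one minus the normalized cross-correlation between target and prediction
#   td_l2_norm  - mean squared error normalized by the RMS of the prediction
#   td_l1       - mean absolute error, optionally normalized by the mean absolute prediction (pow=1)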
def xcorr_loss(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)

    return torch.mean(loss)

def td_l2_norm(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)

    return loss.mean()

def td_l1(y_true, y_pred, pow=0):
    dims = list(range(1, len(y_true.shape)))
    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)

    return torch.mean(tmp)

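# criterion() combines the regression losses: the weighted L1, the multi-resolution
# STFT loss (constructed above with w_sc, w_lm, w_wsc, w_slm and w_sxcorr), the
# log-mel loss, the cross-correlation loss and the normalized L2 loss, and divides
# by the sum of all weights so the result is a weighted average.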
def criterion(x, y):

    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum


# model checkpoint
checkpoint = {
    'setup'         : setup,
    'state_dict'    : model.state_dict(),
    'loss'          : -1
}


if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")


print("summary:")

print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")

# ref is never assigned in this script (it stays None), so this PESQ baseline block is effectively inactive
if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")

best_loss = 1e9
log_interval = 10


# exponential moving averages of the mean (m_*) and standard deviation (s_*) of the
# discriminator's final scores for real (_r) and generated (_f) signals; used below
# to estimate the generator's "winning chance"
m_r = 0
m_f = 0
s_r = 1
s_f = 1

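# Moves all optimizer state tensors (e.g. Adam moment buffers) to the given device so
# that optimizer updates run on the same device as the model parameters.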
def optimizer_to(optim, device):
    for param in optim.state.values():
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)


for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")

    model.to(device)
    disc.to(device)
    model.train()
    disc.train()

    running_disc_loss = 0
    running_adv_loss = 0
    running_feature_loss = 0
    running_reg_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target'].to(device)
            disc_target = batch[adv_target].to(device)

            # calculate model output
            output = model(batch['features'], batch['periods'])

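            # Each discriminator scale appears to return a list of intermediate feature
            # maps with the final score map as the last element (scale[-1]). The losses
            # below follow the least-squares GAN formulation: fake scores are driven
            # towards 0 and real scores towards 1 for the discriminator, and fake scores
            # towards 1 for the generator.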
            # discriminator update
            scores_gen = disc(output.detach())
            scores_real = disc(disc_target.unsqueeze(1))

            disc_loss = 0
            for scale in scores_gen:
                disc_loss += ((scale[-1]) ** 2).mean()
                m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
                s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()

            for scale in scores_real:
                disc_loss += ((1 - scale[-1]) ** 2).mean()
                m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
                s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()

            disc_loss = 0.5 * disc_loss / len(scores_gen)
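            # Heuristic "winning chance": assuming the real and fake score distributions
            # are Gaussian with the EMA means/stds tracked above,
            # 0.5 * erfc((m_r - m_f) / sqrt(2 * (s_f^2 + s_r^2))) is the probability that a
            # fake sample scores higher than a real one; values near 50% indicate the
            # discriminator can no longer separate real from generated audio.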
            winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )

            disc.zero_grad()
            disc_loss.backward()
            optimizer_disc.step()

            # generator update
            scores_gen = disc(output)


            # calculate loss
            loss_reg = criterion(output.squeeze(1), target)

            num_discs = len(scores_gen)
            loss_gen = 0
            for scale in scores_gen:
                loss_gen += ((1 - scale[-1]) ** 2).mean() / num_discs

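            # Feature-matching loss: L1 distance between the discriminator's intermediate
            # activations for the generated and the real signal (the final score map is
            # excluded), averaged over discriminators and layers and scaled by a factor of 4.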
            loss_feat = 0
            for k in range(num_discs):
                num_layers = len(scores_gen[k]) - 1
                f = 4 / num_discs / num_layers
                for l in range(num_layers):
                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())

            model.zero_grad()

            (loss_gen + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()

            optimizer.step()

            running_adv_loss += loss_gen.detach().cpu().item()
            running_disc_loss += disc_loss.detach().cpu().item()
            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
                                   disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
                                   feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
                                   reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
                                   wc=f"{100*winning_chance:5.2f}%")


    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['disc_state_dict'] = disc.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['torch_rng_state'] = torch.get_rng_state()
    checkpoint['numpy_rng_state'] = np.random.get_state()
    checkpoint['python_rng_state'] = random.getstate()
    checkpoint['adv_loss']   = running_adv_loss/(i + 1)
    checkpoint['disc_loss']  = running_disc_loss/(i + 1)
    checkpoint['feature_loss'] = running_feature_loss/(i + 1)
    checkpoint['reg_loss'] = running_reg_loss/(i + 1)


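    # Optional per-epoch inference test: run the vocoder on the pre-loaded LPCNet test
    # features and write the result as a 16 kHz wav file; the PESQ score is only computed
    # if a reference signal were available (ref stays None in this script).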
    if inference_test:
        print("running inference test...")
        out = model.process(features, periods).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")


    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))


    print()

print('Done')
