xref: /aosp_15_r20/external/libopus/dnn/torch/fargan/adv_train_fargan.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1import os
2import argparse
3import random
4import numpy as np
5import sys
6import math as m
7
8import torch
9from torch import nn
10import torch.nn.functional as F
11import tqdm
12
13import fargan
14from dataset import FARGANDataset
15from stft_loss import *
16
17source_dir = os.path.split(os.path.abspath(__file__))[0]
18sys.path.append(os.path.join(source_dir, "../osce/"))
19
20import models as osce_models
21
22
def fmap_loss(scores_real, scores_gen):
    """Feature-matching loss between real and generated discriminator outputs.

    Args:
        scores_real: list (one entry per sub-discriminator) of lists of
            tensors; the last element of each inner list is the final score,
            the preceding elements are intermediate feature maps.
        scores_gen: same structure, computed on the generated signal.

    Returns:
        Scalar tensor (or 0 if there is nothing to match): L1 distance
        between corresponding intermediate feature maps, normalized by the
        number of discriminators and layers and scaled by 4.
    """
    num_discs = len(scores_real)
    loss_feat = 0
    for disc_real, disc_gen in zip(scores_real, scores_gen):
        # last entry is the discriminator score, not a feature map
        num_layers = len(disc_gen) - 1
        if num_layers <= 0:
            # guard: a sub-discriminator with no intermediate maps would
            # otherwise cause a division by zero below
            continue
        f = 4 / num_discs / num_layers
        for real_fm, gen_fm in zip(disc_real[:num_layers], disc_gen[:num_layers]):
            # detach the real features so no gradient flows into the discriminator
            loss_feat += f * F.l1_loss(gen_fm, real_fm.detach())

    return loss_feat
33
# Command-line interface: positional data paths plus model/training options.
parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
parser.add_argument('output', type=str, help='path to output folder')

parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128)
training_group.add_argument('--lr', type=float, help='learning rate, default: 5e-4', default=5e-4)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 50', default=50)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 60', default=60)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 0.0', default=0.0)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
training_group.add_argument('--reg-weight', type=float, help='regression loss weight, default: 1.0', default=1.0)
training_group.add_argument('--fmap-weight', type=float, help='feature matching loss weight, default: 1.0', default=1.)

args = parser.parse_args()
59
# Restrict visible GPUs; must be set before torch initializes CUDA.
if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)


# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay

adam_betas = [0.8, 0.99]
adam_eps = 1e-8
features_file = args.features
signal_file = args.signal

# model parameters
cond_size  = args.cond_size


# Record hyperparameters in the checkpoint dict so a run can be reproduced
# from the saved file alone.
checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Store construction arguments alongside the weights so the model can be
# rebuilt directly from the checkpoint.
checkpoint['model_args']    = ()
checkpoint['model_kwargs']  = {'cond_size': cond_size, 'gamma': args.gamma}
print(checkpoint['model_kwargs'])
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
99
100
# discriminator: frequency-domain multi-resolution discriminator from OSCE
disc_name = 'fdmresdisc'
disc = osce_models.model_dict[disc_name](
    architecture='free',
    design='f_down',
    fft_sizes_16k=[2**n for n in range(6, 12)],
    freq_roi=[0, 7400],
    max_channels=256,
    noise_gain=0.0
)

# Optionally warm-start the generator from an existing checkpoint.
# NOTE: this replaces the checkpoint dict built above with the loaded one,
# so the loaded file is expected to carry the same metadata keys
# (model_args, model_kwargs, ...).
if args.initial_checkpoint is not None:
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict']    = model.state_dict()


dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
optimizer_disc = torch.optim.AdamW([p for p in disc.parameters() if p.requires_grad], lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate schedulers (stepped once per batch inside the training loop)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
scheduler_disc = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer_disc, lr_lambda=lambda x : 1 / (1 + lr_decay * x))

states = None

spect_loss =  MultiResolutionSTFTLoss(device).to(device)

# Freeze the generator initially; it is unfrozen after a discriminator
# warm-up period inside the training loop below.
for param in model.parameters():
    param.requires_grad = False
137
batch_count = 0
if __name__ == '__main__':
    model.to(device)
    disc.to(device)

    for epoch in range(1, epochs + 1):

        # Exponential moving averages of the discriminator score mean/std on
        # real (m_r/s_r) and generated (m_f/s_f) audio; used only for the
        # "winning chance" progress metric below.
        m_r = 0
        m_f = 0
        s_r = 1
        s_f = 1

        # per-epoch accumulators, reported as running means in the tqdm bar
        running_cont_loss = 0
        running_disc_loss = 0
        running_gen_loss = 0
        running_fmap_loss = 0
        running_reg_loss = 0
        running_wc = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, periods, target, lpc) in enumerate(tepoch):
                # Warm-up: for the first 400 batches of epoch 1 only the
                # discriminator trains (generator weights were frozen at
                # module level). Afterwards unfreeze the generator, except
                # for the conditioning network and the conditioning gain.
                if epoch == 1 and i == 400:
                    for param in model.parameters():
                        param.requires_grad = True
                    for param in model.cond_net.parameters():
                        param.requires_grad = False
                    for param in model.sig_net.cond_gain_dense.parameters():
                        param.requires_grad = False

                optimizer.zero_grad()
                features = features.to(device)
                #lpc = lpc.to(device)
                #lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
                #lpc = fargan.interp_lpc(lpc, 4)
                periods = periods.to(device)
                # Truncate to the configured sequence length: 160 samples per
                # frame, with 4 extra conditioning frames on features/periods.
                if True:
                    target = target[:, :sequence_length*160]
                    #lpc = lpc[:,:sequence_length*4,:]
                    features = features[:,:sequence_length+4,:]
                    periods = periods[:,:sequence_length+4]
                else:
                    target=target[::2, :]
                    #lpc=lpc[::2,:]
                    features=features[::2,:]
                    periods=periods[::2,:]
                target = target.to(device)
                #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)

                # Teacher forcing: feed the first nb_pre ground-truth frames
                # as the prefix and let the model synthesize the remainder.
                #nb_pre = random.randrange(1, 6)
                nb_pre = 2
                pre = target[:, :nb_pre*160]
                output, _ = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                output = torch.cat([pre, output], -1)


                # discriminator update (LSGAN targets: real -> 1, fake -> 0);
                # output is detached so generator gradients are not computed here
                scores_gen = disc(output.detach().unsqueeze(1))
                scores_real = disc(target.unsqueeze(1))

                disc_loss = 0
                for scale in scores_gen:
                    # scale[-1] is the final score; earlier entries are feature maps
                    disc_loss += ((scale[-1]) ** 2).mean()
                    m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
                    s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()

                for scale in scores_real:
                    disc_loss += ((1 - scale[-1]) ** 2).mean()
                    m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
                    s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()

                disc_loss = 0.5 * disc_loss / len(scores_gen)
                # Probability that a fake score exceeds a real score under a
                # Gaussian model of the EMA statistics above; ~0.5 means the
                # discriminator cannot tell real from fake.
                winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
                running_wc += winning_chance

                disc.zero_grad()
                disc_loss.backward()
                optimizer_disc.step()

                # generator update: re-score the non-detached output so
                # gradients flow back through the model
                scores_gen = disc(output.unsqueeze(1))
                if False: # todo: check whether that makes a difference
                    with torch.no_grad():
                        scores_real = disc(target.unsqueeze(1))

                # continuity loss on the first half-frame after the
                # teacher-forced prefix (currently weighted 0 in reg_loss,
                # but still tracked for monitoring)
                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80])
                specc_loss = spect_loss(output, target.detach())
                reg_loss = (.00*cont_loss + specc_loss)

                # adversarial loss for the generator (LSGAN: fake -> 1)
                loss_gen = 0
                for scale in scores_gen:
                    loss_gen += ((1 - scale[-1]) ** 2).mean() / len(scores_gen)

                feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen)

                reg_weight = args.reg_weight# + 15./(1 + (batch_count/7600.))
                gen_loss = reg_weight * reg_loss +  feat_loss + loss_gen

                model.zero_grad()


                gen_loss.backward()
                optimizer.step()

                #model.clip_weights()

                # schedulers are stepped once per batch, not per epoch
                scheduler.step()
                scheduler_disc.step()

                running_cont_loss += cont_loss.detach().cpu().item()
                running_gen_loss += loss_gen.detach().cpu().item()
                running_disc_loss += disc_loss.detach().cpu().item()
                running_fmap_loss += feat_loss.detach().cpu().item()
                running_reg_loss += reg_loss.detach().cpu().item()



                tepoch.set_postfix(cont_loss=f"{running_cont_loss/(i+1):8.5f}",
                                   reg_weight=f"{reg_weight:8.5f}",
                                   gen_loss=f"{running_gen_loss/(i+1):8.5f}",
                                   disc_loss=f"{running_disc_loss/(i+1):8.5f}",
                                   fmap_loss=f"{running_fmap_loss/(i+1):8.5f}",
                                   reg_loss=f"{running_reg_loss/(i+1):8.5f}",
                                   wc = f"{running_wc/(i+1):8.5f}",
                                   )
                batch_count = batch_count + 1

        # save a full checkpoint (weights + hyperparameters + losses) per epoch
        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_adv_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        # NOTE(review): 'disc_sate_dict' looks like a typo of 'disc_state_dict';
        # kept as-is since existing checkpoint consumers may rely on this key.
        checkpoint['disc_sate_dict'] = disc.state_dict()
        checkpoint['loss'] = {
            'cont': running_cont_loss / len(dataloader),
            'gen': running_gen_loss / len(dataloader),
            'disc': running_disc_loss / len(dataloader),
            'fmap': running_fmap_loss / len(dataloader),
            'reg': running_reg_loss / len(dataloader)
        }
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)
278