# external/libopus/dnn/torch/fwgan/models/fwgan500.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np


which_norm = weight_norm

#################### Definition of basic model components ####################

# Convolutional layer with 1 frame of look-ahead (used for the feature PrecondNet)
class ConvLookahead(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
        super(ConvLookahead, self).__init__()
        torch.manual_seed(5)  # fixed seed for reproducible initialization

        # Asymmetric padding: (kernel_size - 2) steps of history on the left and
        # one step of look-ahead on the right, so the output length matches the input.
        self.padding_left = (kernel_size - 2) * dilation
        self.padding_right = 1 * dilation

        self.conv = which_norm(nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation, groups=groups, bias=bias))

        self.init_weights()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):

        x = F.pad(x, (self.padding_left, self.padding_right))
        conv_out = self.conv(x)
        return conv_out

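# A minimal shape-check sketch (added, not part of the original file): with
# kernel_size=5 the padding (3 left, 1 right) keeps the output length equal
# to the input length while looking one step ahead.
def _demo_conv_lookahead():
    layer = ConvLookahead(128, 256, kernel_size=5)
    x = torch.randn(2, 128, 500)      # (batch, channels, time steps)
    y = layer(x)
    assert y.shape == (2, 256, 500)   # length preserved
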
# (modified) GLU activation layer definition
class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        torch.manual_seed(5)

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):

        # Modified GLU: the value path goes through tanh instead of being
        # passed straight through; the gate is a linear map of the input.
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))

        return out

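# A minimal usage sketch (added, not part of the original file): the layer is
# shape-preserving and gates along the last (feature) dimension.
def _demo_glu():
    glu = GLU(256)
    x = torch.randn(2, 500, 256)      # (batch, frames, features)
    assert glu(x).shape == x.shape
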
# GRU layer definition
class ContForwardGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()

        torch.manual_seed(5)

        self.hidden_size = hidden_size

        # This initializes the layer with history audio samples for continuation.
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
                          bias=False)

        self.nl = GLU(self.hidden_size)

        self.init_weights()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):

        self.gru.flatten_parameters()

        # Derive the initial hidden state from the history samples x0.
        h0 = self.cont_fc(x0).unsqueeze(0)

        output, h0 = self.gru(x, h0)

        return self.nl(output)

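# A minimal usage sketch (added, not part of the original file); x0 is assumed
# to hold the 320 most recent waveform samples, i.e. two 160-sample frames.
def _demo_cont_forward_gru():
    rnn = ContForwardGRU(256, 256)
    x = torch.randn(2, 500, 256)      # (batch, frames, features)
    x0 = torch.randn(2, 320)          # history samples
    assert rnn(x, x0).shape == (2, 500, 256)
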
# Framewise convolution layer definition
class ContFramewiseConv(torch.nn.Module):

    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):

        super(ContFramewiseConv, self).__init__()
        torch.manual_seed(5)

        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if causal or (self.frame_kernel_size == 2):

            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0

            # This initializes the layer with history audio samples for continuation.
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.required_pad_left, bias=False)),
                                         nn.Tanh()
                                        )

        else:
            # Non-causal frame-wise convolution. It is not used at the moment;
            # note that this branch defines no cont_fc, so forward() would fail.
            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        if act == 'glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim)
                                   )
        elif act == 'tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh()
                                   )

        self.init_weights()


    def init_weights(self):

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):

        if self.frame_kernel_size == 1:
            return self.fc(x)

        # Flatten the frames, prepend history-derived padding from x0, and
        # gather overlapping windows of frame_kernel_size frames with F.unfold.
        x_flat = x.reshape(x.size(0), 1, -1)
        pad = self.cont_fc(x0).view(x0.size(0), 1, -1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)

        x_flat_padded_unfolded = F.unfold(x_flat_padded,
                    kernel_size=(1, self.fc_input_dim), stride=self.frame_len).permute(0, 2, 1).contiguous()

        out = self.fc(x_flat_padded_unfolded)
        return out

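# A minimal usage sketch (added, not part of the original file): with the
# default frame_kernel_size=3, each output frame sees the current and the two
# previous input frames; the two left-padding frames are derived from x0.
def _demo_cont_framewise_conv():
    fwc = ContFramewiseConv(256, 128)
    x = torch.randn(2, 500, 256)      # (batch, frames, frame_len)
    x0 = torch.randn(2, 320)          # history samples
    assert fwc(x, x0).shape == (2, 500, 128)
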
########################### The complete model definition #################################

class FWGAN500Cont(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)

        # PrecondNet: upsamples the 19-dim BFCC(+correlation) features 5x to
        # 64 channels, one vector per 32-sample subframe.
        self.bfcc_with_corr_upsampler = nn.Sequential(nn.ConvTranspose1d(19, 64, kernel_size=5, stride=5, padding=0,
                                                      bias=False),
                                                      nn.Tanh())

        self.feat_in_conv = ConvLookahead(128, 256, kernel_size=5)
        self.feat_in_nl = GLU(256)

        # GRU:
        self.rnn = ContForwardGRU(256, 256)

        # Frame-wise convolution stack:
        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 32)
        self.fwc7 = ContFramewiseConv(32, 32, act='tanh')

        self.init_weights()
        self.count_parameters()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):

        # Build sin/cos phase embeddings from the pitch period of each
        # 160-sample frame; each frame is split into five 32-sample subframes.
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0), -1, 32)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0), -1, 32)

            chunk = torch.cat((chunk_sin, chunk_cos), dim=-1)

            # Carry the phase across frame boundaries for continuity.
            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)

        return phase_signals

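    # Worked example (added note, not in the original file): for a constant
    # period of 80 samples, f = 2*pi/80 rad per sample, so one 160-sample
    # frame advances the phase by 160*f = 4*pi, i.e. exactly two pitch
    # cycles, and phase0 stays continuous at the next frame boundary.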

    def gain_multiply(self, x, c0):

        # c0 is the first BFCC coefficient (a log-energy term); map it back to
        # a linear gain and repeat it across the 160 samples of each frame.
        gain = 10**(0.5*c0/np.sqrt(18.0))
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0), 1, -1).squeeze(1)

        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):

        # This creates a latent representation of shape [batch, 500 frames, 256 elements per frame]
        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0, 2, 1).contiguous())
        feat_in = torch.cat((p_embed, envelope), dim=1)
        wav_latent = self.feat_in_nl(self.feat_in_conv(feat_in).permute(0, 2, 1).contiguous())

        # Generation with continuation using the history samples x0 starts from here:

        rnn_out = self.rnn(wav_latent, x0)

        fwc1_out = self.fwc1(rnn_out, x0)
        fwc2_out = self.fwc2(fwc1_out, x0)
        fwc3_out = self.fwc3(fwc2_out, x0)
        fwc4_out = self.fwc4(fwc3_out, x0)
        fwc5_out = self.fwc5(fwc4_out, x0)
        fwc6_out = self.fwc6(fwc5_out, x0)
        fwc7_out = self.fwc7(fwc6_out, x0)

        # Flatten the 500 sub-frames of 32 samples into one waveform, then apply the gain.
        waveform_unscaled = fwc7_out.reshape(fwc7_out.size(0), 1, -1).squeeze(1)
        waveform = self.gain_multiply(waveform_unscaled, bfcc_with_corr[:, :, :1])

        return waveform
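
# A minimal end-to-end sketch (added, not part of the original file). The
# input shapes are assumptions inferred from the layer definitions: 100 pitch
# periods and 100 frames of 19-dim BFCC(+correlation) features yield
# 100 * 160 = 16000 output samples.
if __name__ == "__main__":
    _demo_conv_lookahead()
    _demo_glu()
    _demo_cont_forward_gru()
    _demo_cont_framewise_conv()

    model = FWGAN500Cont()                      # prints the parameter count
    pitch_period = torch.full((1, 100), 80.0)   # pitch period per 160-sample frame
    bfcc_with_corr = torch.randn(1, 100, 19)    # features, channels last
    x0 = torch.randn(1, 320)                    # 320 history waveform samples
    waveform = model(pitch_period, bfcc_with_corr, x0)
    assert waveform.shape == (1, 16000)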
261