import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np


which_norm = weight_norm

#################### Definition of basic model components ####################

# Convolutional layer with 1 frame look-ahead (used for the feature PreCondNet)
class ConvLookahead(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
        super(ConvLookahead, self).__init__()
        torch.manual_seed(5)

        # Asymmetric padding: (kernel_size - 2) frames of history, 1 frame of look-ahead.
        self.padding_left = (kernel_size - 2) * dilation
        self.padding_right = 1 * dilation

        self.conv = which_norm(nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation, groups=groups, bias=bias))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        x = F.pad(x, (self.padding_left, self.padding_right))
        conv_out = self.conv(x)
        return conv_out


# (Modified) GLU activation layer definition
class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()
        torch.manual_seed(5)

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))
        return out


# GRU layer definition
class ContForwardGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()
        torch.manual_seed(5)

        self.hidden_size = hidden_size

        # Initializes the GRU state from history audio samples for continuation.
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                          batch_first=True, bias=False)

        self.nl = GLU(self.hidden_size)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        self.gru.flatten_parameters()

        h0 = self.cont_fc(x0).unsqueeze(0)

        output, h0 = self.gru(x, h0)

        return self.nl(output)


# Frame-wise convolution layer definition
class ContFramewiseConv(nn.Module):

    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
        super(ContFramewiseConv, self).__init__()
        torch.manual_seed(5)

        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if causal or (self.frame_kernel_size == 2):
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0

            # Initializes the left padding from history audio samples for continuation.
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.required_pad_left, bias=False)),
                                         nn.Tanh())
        else:
            # Non-causal frame-wise convolution. We don't use it at the moment.
            self.required_pad_left = (self.frame_kernel_size - 1) // 2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1) // 2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        if act == 'glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim))
        elif act == 'tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh())

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        if self.frame_kernel_size == 1:
            return self.fc(x)

        # Flatten the frames, prepend the learned continuation padding, and unfold
        # overlapping windows of frame_kernel_size frames so each output frame sees its past frames.
        x_flat = x.reshape(x.size(0), 1, -1)
        pad = self.cont_fc(x0).view(x0.size(0), 1, -1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)

        x_flat_padded_unfolded = F.unfold(x_flat_padded, kernel_size=(1, self.fc_input_dim),
                                          stride=self.frame_len).permute(0, 2, 1).contiguous()

        out = self.fc(x_flat_padded_unfolded)
        return out


########################### The complete model definition #################################

class FWGAN500Cont(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)

        # PreCondNet:
        self.bfcc_with_corr_upsampler = nn.Sequential(nn.ConvTranspose1d(19, 64, kernel_size=5, stride=5,
                                                                         padding=0, bias=False),
                                                      nn.Tanh())

        self.feat_in_conv = ConvLookahead(128, 256, kernel_size=5)
        self.feat_in_nl = GLU(256)

        # GRU:
        self.rnn = ContForwardGRU(256, 256)

        # Frame-wise convolution stack:
        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 32)
        self.fwc7 = ContFramewiseConv(32, 32, act='tanh')

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
        # Builds per-subframe sin/cos phase embeddings (32 + 32 = 64 values per subframe)
        # from the pitch period of each frame, keeping the phase continuous across frames.
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0), -1, 32)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0), -1, 32)

            chunk = torch.cat((chunk_sin, chunk_cos), dim=-1)

            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)

        return phase_signals

    def gain_multiply(self, x, c0):
        # Converts the first cepstral coefficient c0 into a gain, repeated over the 160 samples of each frame.
        gain = 10 ** (0.5 * c0 / np.sqrt(18.0))
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0), 1, -1).squeeze(1)

        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        # This should create a latent representation of shape [Batch_dim, 500 frames, 256 elements per frame].
        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0, 2, 1).contiguous())
        feat_in = torch.cat((p_embed, envelope), dim=1)
        wav_latent = self.feat_in_nl(self.feat_in_conv(feat_in).permute(0, 2, 1).contiguous())

        # Generation with continuation using history samples x0 starts from here:
        rnn_out = self.rnn(wav_latent, x0)

        fwc1_out = self.fwc1(rnn_out, x0)
        fwc2_out = self.fwc2(fwc1_out, x0)
        fwc3_out = self.fwc3(fwc2_out, x0)
        fwc4_out = self.fwc4(fwc3_out, x0)
        fwc5_out = self.fwc5(fwc4_out, x0)
        fwc6_out = self.fwc6(fwc5_out, x0)
        fwc7_out = self.fwc7(fwc6_out, x0)

        waveform_unscaled = fwc7_out.reshape(fwc7_out.size(0), 1, -1).squeeze(1)
        waveform = self.gain_multiply(waveform_unscaled, bfcc_with_corr[:, :, :1])

        return waveform
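

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original model code). The tensor
# shapes below are assumptions inferred from the layer definitions above:
#   pitch_period  : (batch, n_frames)       pitch period in samples per frame
#   bfcc_with_corr: (batch, n_frames, 19)   BFCC (+ correlation) features, c0 first
#   x0            : (batch, 320)            history audio samples for continuation
# Each frame contributes 5 subframes of 32 samples, i.e. 160 output samples per frame.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = FWGAN500Cont()
    batch, n_frames = 2, 100  # 100 frames -> 500 subframes -> 16000 output samples (assumed)

    pitch_period = torch.randint(32, 256, (batch, n_frames)).float()
    bfcc_with_corr = torch.randn(batch, n_frames, 19)
    x0 = torch.randn(batch, 320)

    with torch.no_grad():
        waveform = model(pitch_period, bfcc_with_corr, x0)

    print(waveform.shape)  # expected: torch.Size([2, 16000])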