import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import filters
from torch.nn.utils import weight_norm
#from convert_lsp import lpc_to_lsp, lsp_to_lpc
from rc import lpc2rc, rc2lpc

Fs = 16000

# Debug switch for dump_signal(). When False (the default) dump_signal() is a
# no-op, so the model performs no file I/O. Set to True to append raw float32
# dumps of intermediate tensors to files in the working directory.
DUMP_SIGNALS = False

# Open dump-file handles, keyed by filename (only populated when DUMP_SIGNALS).
fid_dict = {}


def dump_signal(x, filename):
    """Append the raw float32 contents of tensor ``x`` to ``filename``.

    Does nothing unless ``DUMP_SIGNALS`` is True or when ``x`` is None.
    One file handle per filename is opened lazily, cached in ``fid_dict`` and
    reused for later calls; handles are never closed explicitly.
    """
    if not DUMP_SIGNALS or x is None:
        return
    fid = fid_dict.get(filename)
    if fid is None:
        # Binary mode: ndarray.tofile() writes raw bytes, not text.
        fid = open(filename, "wb")
        fid_dict[filename] = fid
    x.detach().cpu().numpy().astype('float32').tofile(fid)


def sig_l1(y_true, y_pred):
    """Relative L1 error: mean(|y_true - y_pred|) / mean(|y_true|)."""
    return torch.mean(abs(y_true-y_pred))/torch.mean(abs(y_true))


def sig_loss(y_true, y_pred):
    """Return 1 - cosine similarity along the last dim, averaged over the rest.

    Both inputs are L2-normalized over the last dimension (with a 1e-15
    guard against division by zero) before their dot product is taken.
    """
    t = y_true/(1e-15+torch.norm(y_true, dim=-1, p=2, keepdim=True))
    p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True))
    return torch.mean(1.-torch.sum(p*t, dim=-1))


def interp_lpc(lpc, factor):
    """Upsample a sequence of LPC vectors in time by ``factor``.

    The LPCs are converted to reflection coefficients (``lpc2rc``) and mapped
    through atanh. Consecutive frames are linearly interpolated in that domain,
    the edge positions are filled by replication, and the result is mapped back
    with tanh / ``rc2lpc``.

    Args:
        lpc: rank-3 tensor, presumably (batch, frames, order) -- TODO confirm.
        factor: upsampling factor. NOTE(review): as written, the strided slice
            assignment below only matches the shape of ``interp`` for even
            ``factor``; odd values (including 1) raise a shape mismatch.

    Returns:
        LPC tensor with ``frames * factor`` frames.
    """
    # Despite the name, this is atanh(reflection coefficients), not true LSPs.
    lsp = torch.atanh(lpc2rc(lpc))
    shape = lsp.shape
    shape = (shape[0], shape[1]*factor, shape[2])
    interp_lsp = torch.zeros(shape, device=lpc.device)
    for k in range(factor):
        f = (k+.5*((factor+1)%2))/factor
        interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:]
        interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp
    # Replicate the first interpolated frame into the leading positions.
    for k in range(factor//2):
        interp_lsp[:,k,:] = interp_lsp[:,factor//2,:]
    # NOTE(review): unary minus binds tighter than //, so this index is
    # (-(factor+3))//2, e.g. -4 for factor=4. The last interpolated frame is
    # at -((factor+3)//2) = -3, so this may be a precedence slip. It is left
    # unchanged to preserve existing outputs -- confirm intent.
    for k in range((factor+1)//2):
        interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:]
    return rc2lpc(torch.tanh(interp_lsp))


def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
    """Filter ``x`` subframe by subframe with the per-frame filter [1, lpc].

    For each frame, ``filters.toeplitz_from_filter`` is applied to the
    zero-padded coefficients [1, lpc_1..lpc_order]. Presumably this yields the
    FIR convolution matrix (TODO confirm). Each subframe is multiplied by that
    matrix together with the preceding ``order`` input samples of history.

    Args:
        x: signal reshaped to (batch, nb_frames*nb_subframes, subframe_size).
        lpc: (batch, nb_frames, order) coefficients. The LPC order is taken
            from the last dimension (it was previously hard-coded to 16).
        nb_subframes: subframes per LPC frame.
        subframe_size: samples per subframe.
        gamma: currently ignored (bandwidth expansion is disabled).

    Returns:
        (batch, nb_frames*nb_subframes*subframe_size) filtered signal.
    """
    device = x.device
    batch_size = lpc.size(0)
    nb_frames = lpc.shape[1]
    order = lpc.shape[-1]

    # Sliding window: `order` samples of history followed by the current subframe.
    sig = torch.zeros(batch_size, subframe_size+order, device=device)
    x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))

    #if gamma is not None:
    #    bw = gamma**(torch.arange(1, order+1, device=device))
    #    lpc = lpc*bw[None,None,:]
    ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
    zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
    a = torch.cat([ones, lpc], -1)
    a_big = torch.cat([a, zeros], -1)
    fir_mat_big = filters.toeplitz_from_filter(a_big)

    # Collect chunks and concatenate once, instead of re-concatenating a
    # growing tensor each subframe. The initial empty chunk keeps the
    # nb_frames == 0 result at shape (batch, 0).
    chunks = [torch.zeros((batch_size, 0), device=device)]
    for frame in range(nb_frames):
        for k in range(nb_subframes):
            sig = torch.cat([sig[:,subframe_size:], x[:,frame*nb_subframes + k, :]], 1)
            exc = torch.bmm(fir_mat_big[:,frame,:,:], sig[:,:,None])
            chunks.append(exc[:,-subframe_size:,0])
    return torch.cat(chunks, 1)


# weight initialization and clipping
def init_weights(module):
    """Orthogonally initialize the recurrent (hidden-to-hidden) weights of nn.GRU modules."""
    if isinstance(module, nn.GRU):
        for p in module.named_parameters():
            if p[0].startswith('weight_hh_'):
                nn.init.orthogonal_(p[1])


def gen_phase_embedding(periods, frame_size):
    """Build a cos/sin phase signal that follows the per-frame pitch period.

    Args:
        periods: (batch, nb_frames) period per frame, in samples.
        frame_size: samples per frame.

    Returns:
        Tuple (cos, sin), each (batch, nb_frames*frame_size). Phase advances by
        2*pi/period per sample and is continuous across frames, starting from a
        random initial offset in [0, 2*pi) per batch item.
    """
    device = periods.device
    batch_size = periods.size(0)
    nb_frames = periods.size(1)
    w0 = 2*torch.pi/periods
    # First entry carries the random start phase; the rest are the previous
    # frame's increments, so cumsum yields each frame's starting phase.
    w0_shift = torch.cat([2*torch.pi*torch.rand((batch_size, 1), device=device)/frame_size, w0[:,:-1]], 1)
    cum_phase = frame_size*torch.cumsum(w0_shift, 1)
    fine_phase = w0[:,:,None]*torch.broadcast_to(torch.arange(frame_size, device=device), (batch_size, nb_frames, frame_size))
    embed = torch.unsqueeze(cum_phase, 2) + fine_phase
    embed = torch.reshape(embed, (batch_size, -1))
    return torch.cos(embed), torch.sin(embed)


class GLU(nn.Module):
    """Gated linear unit: x * sigmoid(W x) with a weight-normalized, bias-free W."""

    def __init__(self, feat_size):
        super(GLU, self).__init__()

        # NOTE(review): reseeds the *global* torch RNG on every construction.
        # This makes initialization deterministic, but also resets the RNG for
        # everything that runs afterwards.
        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        # NOTE(review): with weight_norm, `weight` is recomputed from
        # weight_g/weight_v before every forward pass. Writing to
        # m.weight.data therefore appears to have no lasting effect -- confirm.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
                or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        return x * torch.sigmoid(self.gate(x))


class FWConv(nn.Module):
    """Frame-wise causal "convolution" over a sequence of feature vectors.

    Each call concatenates the carried ``state`` (earlier inputs) with the new
    input ``x``. It then applies a weight-normalized linear layer, tanh and a
    GLU. The returned state drops the oldest ``in_size`` values, so it holds
    the latest (kernel_size-1) inputs when sized in_size*(kernel_size-1).
    """

    def __init__(self, in_size, out_size, kernel_size=2):
        super(FWConv, self).__init__()

        # NOTE(review): global RNG reseed, see GLU.__init__.
        torch.manual_seed(5)

        self.in_size = in_size
        self.kernel_size = kernel_size
        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
        self.glu = GLU(out_size)

        self.init_weights()

    def init_weights(self):
        # NOTE(review): see GLU.init_weights regarding weight_norm.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
                or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, state):
        xcat = torch.cat((state, x), -1)
        out = self.glu(torch.tanh(self.conv(xcat)))
        return out, xcat[:,self.in_size:]


def n(x):
    """Add uniform noise in [-0.5/127, 0.5/127) and clamp to [-1, 1].

    Presumably this emulates 8-bit quantization of activations during
    training -- TODO confirm.
    """
    return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)


class FARGANCond(nn.Module):
    """Conditioning network: per-frame features + period -> per-frame conditioning.

    ``period - 32`` indexes a 224-entry embedding, so periods are presumably in
    [32, 255]. The first 2 frames are dropped, and the kernel-3 'valid' conv
    removes 2 more. For (batch, T, feature_dim) features the output is
    therefore (batch, T-4, 320), i.e. 4 x 80 values per frame (one 80-dim
    slice per subframe as consumed by FARGAN).
    """

    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
        super(FARGANCond, self).__init__()

        self.feature_dim = feature_dim
        self.cond_size = cond_size  # stored but not used by the layers below

        self.pembed = nn.Embedding(224, pembed_dims)
        self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
        self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
        self.fdense2 = nn.Linear(128, 80*4, bias=False)

        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"cond model: {nb_params} weights")

    def forward(self, features, period):
        features = features[:,2:,:]
        period = period[:,2:]
        p = self.pembed(period-32)
        features = torch.cat((features, p), -1)
        tmp = torch.tanh(self.fdense1(features))
        tmp = tmp.permute(0, 2, 1)  # Conv1d expects (batch, channels, time)
        tmp = torch.tanh(self.fconv1(tmp))
        tmp = tmp.permute(0, 2, 1)
        tmp = torch.tanh(self.fdense2(tmp))
        return tmp


class FARGANSub(nn.Module):
    """Generates one subframe of excitation from conditioning and past excitation.

    The excitation history ``exc_mem`` is 256 samples long (hard-coded).
    """

    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256):
        super(FARGANSub, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.cond_size = cond_size
        self.cond_gain_dense = nn.Linear(80, 1)

        # Input: 80 cond + (subframe_size+4) pitch prediction + subframe_size prev.
        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192)
        self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
        self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
        self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)

        self.gru1_glu = GLU(160)
        self.gru2_glu = GLU(128)
        self.gru3_glu = GLU(128)
        self.skip_glu = GLU(128)
        #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)

        self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
        self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
        self.gain_dense_out = nn.Linear(192, 4)

        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"subframe model: {nb_params} weights")

    def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):
        """Run one subframe step.

        Args:
            cond: (batch, 80) conditioning for this subframe.
            prev_pred: running buffer of past pitch predictions, shifted left
                by subframe_size and extended with this step's prediction.
            exc_mem: (batch, 256) past excitation samples.
            period: (batch,) integer pitch period. NOTE(review): period == 255
                makes the gather index -1 -- confirm the caller never passes it.
            states: (gru1, gru2, gru3, fwc0) recurrent states.
            gain: only dumped for debugging; the gain actually used is
                predicted from ``cond``.

        Returns:
            (sig_out, exc_mem, prev_pred, states), with sig_out (batch, subframe_size).
        """
        device = exc_mem.device

        cond = n(cond)
        dump_signal(gain, 'gain0.f32')
        gain = torch.exp(self.cond_gain_dense(cond))
        dump_signal(gain, 'gain1.f32')

        # Gather subframe_size+4 samples starting one period back (minus 2 for
        # the +/-2 context). Indices that run past the end of the buffer are
        # wrapped back by one more period, repeating the last cycle.
        idx = 256-period[:,None]
        rng = torch.arange(self.subframe_size+4, device=device)
        idx = idx + rng[None,:] - 2
        mask = idx >= 256
        idx = idx - mask*period[:,None]
        pred = torch.gather(exc_mem, 1, idx)
        pred = n(pred/(1e-5+gain))

        prev = exc_mem[:,-self.subframe_size:]
        dump_signal(prev, 'prev_in.f32')
        prev = n(prev/(1e-5+gain))
        dump_signal(prev, 'pitch_exc.f32')
        dump_signal(exc_mem, 'exc_mem.f32')

        tmp = torch.cat((cond, pred, prev), 1)
        # Centered pitch prediction (drops the 2-sample context on each side).
        fpitch = pred[:,2:-2]

        fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
        fwc0_out = n(fwc0_out)
        # Four gains in (0, 1), one per consumer of fpitch below.
        pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out))

        gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
        gru1_out = self.gru1_glu(n(gru1_state))
        gru1_out = n(gru1_out)
        gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
        gru2_out = self.gru2_glu(n(gru2_state))
        gru2_out = n(gru2_out)
        gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
        gru3_out = self.gru3_glu(n(gru3_state))
        gru3_out = n(gru3_out)

        # Skip connection: combine every layer's output.
        all_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1)
        skip_out = torch.tanh(self.skip_dense(torch.cat([all_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
        skip_out = self.skip_glu(n(skip_out))
        sig_out = torch.tanh(self.sig_dense_out(skip_out))
        dump_signal(sig_out, 'exc_out.f32')
        dump_signal(pitch_gain, 'pgain.f32')

        sig_out = sig_out * gain
        exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
        prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1)
        dump_signal(sig_out, 'sig_out.f32')
        return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)


class FARGAN(nn.Module):
    """Full model: FARGANCond conditioning + FARGANSub run subframe by subframe.

    ``passthrough_size``, ``has_gain`` and ``gamma`` are accepted for interface
    compatibility but are not used.
    """

    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
        super(FARGAN, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.frame_size = self.subframe_size*self.nb_subframes
        self.feature_dim = feature_dim
        self.cond_size = cond_size

        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size)

    def forward(self, features, period, nb_frames, pre=None, states=None):
        """Synthesize ``nb_frames`` frames of signal.

        Args:
            features: (batch, T, feature_dim) features. Frame ``fr`` reads
                features[:, 3+fr] for its gain.
            period: (batch, T) integer pitch periods.
            nb_frames: number of frames to generate after any ``pre`` frames.
            pre: optional (batch, samples) known signal. Its whole frames are
                teacher-forced into the excitation memory instead of emitted.
            states: ignored; the recurrent states always start at zero.

        Returns:
            (sig, states): sig is (batch, generated_samples), and states holds
            the final recurrent states, detached from the graph.
        """
        device = features.device
        batch_size = features.size(0)

        prev = torch.zeros(batch_size, 256, device=device)
        exc_mem = torch.zeros(batch_size, 256, device=device)
        nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0

        states = (
            torch.zeros(batch_size, 160, device=device),
            torch.zeros(batch_size, 128, device=device),
            torch.zeros(batch_size, 128, device=device),
            torch.zeros(batch_size, (2*self.subframe_size+80+4)*1, device=device)
        )

        # Collect generated subframes and concatenate once at the end.
        sig_chunks = [torch.zeros((batch_size, 0), device=device)]
        cond = self.cond_net(features, period)
        if pre is not None:
            exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
        # The first pre frame was copied into exc_mem above, so skip it.
        start = 1 if nb_pre_frames>0 else 0
        for fr in range(start, nb_frames+nb_pre_frames):
            for k in range(self.nb_subframes):
                pos = fr*self.frame_size + k*self.subframe_size
                pitch = period[:, 3+fr]
                # NOTE(review): feature 0 is presumably a log-energy term -- TODO confirm.
                gain = .03*10**(0.5*features[:, 3+fr, 0:1]/np.sqrt(18.0))
                out, exc_mem, prev, states = self.sig_net(cond[:, fr, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain)

                if fr < nb_pre_frames:
                    # Teacher forcing: replace the generated subframe with the known signal.
                    out = pre[:, pos:pos+self.subframe_size]
                    exc_mem[:,-self.subframe_size:] = out
                else:
                    sig_chunks.append(out)

        sig = torch.cat(sig_chunks, 1)
        states = [s.detach() for s in states]
        return sig, states