"""Packet-loss-concealment (PLC) model, training loss and helper layers."""
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm

# Cache of open dump files, keyed by filename, so that successive calls to
# dump_signal() append to the same stream instead of truncating it.
# Handles are intentionally kept open for the lifetime of the process.
fid_dict = {}


def dump_signal(x, filename, enabled=False):
    """Append tensor ``x`` to ``filename`` as raw float32 values (debug aid).

    Dumping is disabled by default, in which case the call is a no-op and
    no file is created.

    Args:
        x: tensor to dump; it is detached and copied to the CPU first.
        filename: path of the dump file. The file is opened (and truncated)
            on first use and then reused through ``fid_dict``.
        enabled: set to True to actually write the data.

    Returns:
        None.
    """
    if not enabled:
        return
    fid = fid_dict.get(filename)
    if fid is None:
        # Raw binary samples are written, so the file must be in binary mode.
        fid = open(filename, "wb")
        fid_dict[filename] = fid
    # .cpu() is required before .numpy() when x lives on an accelerator.
    x.detach().cpu().numpy().astype('float32').tofile(fid)


class IDCT(nn.Module):
    """Fixed (non-trainable) linear transform built from a cosine table.

    ``table[n, k] = sqrt(2/N) * c_k * cos(pi/N * (n + 0.5) * k)`` with
    ``c_0 = sqrt(0.5)`` and ``c_k = 1`` otherwise, which makes ``table`` an
    orthonormal matrix.

    Args:
        N: transform size (size of the last input dimension).
        device: device on which the table is created.
    """

    def __init__(self, N, device=None):
        super().__init__()

        self.N = N
        out_idx = torch.arange(N, device=device)
        coef_idx = torch.arange(N, device=device)
        table = torch.cos(torch.pi/N * (out_idx[:, None]+.5) * coef_idx[None, :])
        table[:, 0] = table[:, 0] * math.sqrt(.5)
        table = table / math.sqrt(N/2)
        # Registered as a buffer so that .to()/.cuda()/.double() on the module
        # also move/cast the table; non-persistent so state_dict() keys are
        # the same as when the table was a plain attribute.
        self.register_buffer("table", table, persistent=False)

    def forward(self, x):
        """Apply the transform along the last dimension of ``x``."""
        return F.linear(x, self.table, None)


def plc_loss(N, device=None, alpha=1.0, bias=1.):
    """Build the PLC training loss.

    Args:
        N: currently unused; the transform size is fixed to 18.
            NOTE(review): presumably meant to be the number of bands --
            confirm with callers before wiring it in.
        device: device on which the internal IDCT table is created.
        alpha: weight of the band-domain terms in the total loss.
        bias: weight of the one-sided (positive-error) band term.

    Returns:
        A function ``loss(y_true, y_pred)`` returning the tuple
        ``(total, l1_loss, ceps_loss, band_loss, pitch_loss)``.
    """
    idct = IDCT(18, device=device)

    def loss(y_true, y_pred):
        """Compute the loss terms.

        The last channel of ``y_true`` is a multiplicative mask applied to the
        error; the remaining channels are the targets compared to ``y_pred``.
        The error must have 20 channels so that ``e[:, :, :-2]`` matches the
        18-point transform.
        NOTE(review): channel 18 and the last channel look like pitch and
        voicing features (judging by the variable names) -- confirm.
        """
        mask = y_true[:, :, -1:]
        target = y_true[:, :, :-1]
        e = (y_pred - target)*mask
        # All channels but the last two, mapped through the fixed transform.
        e_bands = idct(e[:, :, :-2])
        # Gate for the one-sided term, derived from the last target channel.
        bias_mask = torch.clamp(4*target[:, :, -1:], min=0., max=1.)

        l1_loss = torch.mean(torch.abs(e))
        ceps_loss = torch.mean(torch.abs(e[:, :, :-2]))
        band_loss = torch.mean(torch.abs(e_bands))
        # Only positive band errors are penalised here.
        biased_loss = torch.mean(bias_mask*torch.clamp(e_bands, min=0.))
        # Two saturating penalties on channel 18, capped at 1.0 and 0.4.
        pitch_loss1 = torch.mean(torch.clamp(torch.abs(e[:, :, 18:19]), max=1.))
        pitch_loss = torch.mean(torch.clamp(torch.abs(e[:, :, 18:19]), max=.4))
        # Only negative errors on the last channel are penalised here.
        voice_bias = torch.mean(torch.clamp(-e[:, :, -1:], min=0.))

        tot = l1_loss + 0.1*voice_bias + alpha*(band_loss + bias*biased_loss) + pitch_loss1 + 8*pitch_loss
        return tot, l1_loss, ceps_loss, band_loss, pitch_loss

    return loss


# weight initialization and clipping
def init_weights(module):
    """Orthogonally initialise the recurrent (``weight_hh_*``) GRU weights.

    Intended for use with ``nn.Module.apply``; modules that are not
    ``nn.GRU`` are left untouched.
    """
    if isinstance(module, nn.GRU):
        for name, param in module.named_parameters():
            if name.startswith('weight_hh_'):
                nn.init.orthogonal_(param)


# Layer types whose ``weight`` is orthogonally initialised by GLU and FWConv.
_ORTHOGONAL_INIT_TYPES = (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)


def _orthogonal_init_submodules(root):
    """Apply orthogonal init to ``weight`` of every matching submodule."""
    # NOTE(review): for layers wrapped in weight_norm the effective weight is
    # normally recomputed from weight_g/weight_v, so this may not persist --
    # confirm whether that is intended.
    for m in root.modules():
        if isinstance(m, _ORTHOGONAL_INIT_TYPES):
            nn.init.orthogonal_(m.weight.data)


class GLU(nn.Module):
    """Gated unit: ``x * sigmoid(gate(x))`` with a weight-normalised gate.

    Args:
        feat_size: input and output feature size of the bias-free gate.
    """

    def __init__(self, feat_size):
        super().__init__()

        # Resets the *global* torch RNG so the initial weights are
        # reproducible; note this affects all later random draws.
        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        """Orthogonally initialise the weights of the contained layers."""
        _orthogonal_init_submodules(self)

    def forward(self, x):
        """Return ``x`` scaled element-wise by its sigmoid gate."""
        return x * torch.sigmoid(self.gate(x))


class FWConv(nn.Module):
    """Linear layer over the current input concatenated with a state.

    The layer is followed by ``tanh`` and a ``GLU``.

    Args:
        in_size: feature size of one input frame.
        out_size: output feature size.
        kernel_size: the linear layer takes ``in_size * kernel_size`` inputs.
    """

    def __init__(self, in_size, out_size, kernel_size=2):
        super().__init__()

        # Resets the *global* torch RNG (see GLU) for reproducible init.
        torch.manual_seed(5)

        self.in_size = in_size
        self.kernel_size = kernel_size
        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
        self.glu = GLU(out_size)

        self.init_weights()

    def init_weights(self):
        """Orthogonally initialise the weights of the contained layers."""
        _orthogonal_init_submodules(self)

    def forward(self, x, state):
        """Process one step.

        Args:
            x: current input; concatenated after ``state`` on the last axis.
            state: history carried over from the previous call.

        Returns:
            ``(out, new_state)`` where ``new_state`` is the concatenated
            input with its first ``in_size`` columns dropped.
        """
        xcat = torch.cat((state, x), -1)
        out = self.glu(torch.tanh(self.conv(xcat)))
        return out, xcat[:, self.in_size:]


def n(x):
    """Add uniform noise in ``+/- 0.5/127`` to ``x`` and clamp to ``[-1, 1]``.

    NOTE(review): presumably emulates 8-bit quantisation noise -- confirm.
    """
    return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)


class PLC(nn.Module):
    """Dense -> GRU -> GRU -> dense predictor.

    Args:
        features_in: size of the concatenated ``[features, lost]`` input.
        features_out: size of the predicted feature vector.
        cond_size: output size of the input dense layer.
        gru_size: hidden size of both GRUs.
    """

    def __init__(self, features_in=57, features_out=20, cond_size=128, gru_size=128):
        super().__init__()

        self.features_in = features_in
        self.features_out = features_out
        self.cond_size = cond_size
        self.gru_size = gru_size

        self.dense_in = nn.Linear(self.features_in, self.cond_size)
        self.gru1 = nn.GRU(self.cond_size, self.gru_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru_size, features_out)

        # Orthogonal init of the recurrent GRU weights.
        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"plc model: {nb_params} weights")

    def forward(self, features, lost, states=None):
        """Run the model on a batch-first sequence.

        Args:
            features: input features, batch on the first axis.
            lost: tensor concatenated to ``features`` on the last axis;
                together they must have ``features_in`` channels.
            states: optional ``[gru1_state, gru2_state]``; zero states are
                used when omitted.

        Returns:
            ``(prediction, [gru1_state, gru2_state])``.
        """
        device = features.device
        batch_size = features.size(0)
        if states is None:
            gru1_state = torch.zeros((1, batch_size, self.gru_size), device=device)
            gru2_state = torch.zeros((1, batch_size, self.gru_size), device=device)
        else:
            gru1_state = states[0]
            gru2_state = states[1]
        x = torch.cat([features, lost], dim=-1)
        x = torch.tanh(self.dense_in(x))
        gru1_out, gru1_state = self.gru1(x, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        return self.dense_out(gru2_out), [gru1_state, gru2_state]