import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math

# Debug helper: dumps tensors to raw float32 files. Intentionally disabled by
# the early return; remove it to enable dumping.
fid_dict = {}
def dump_signal(x, filename):
    return
    if filename in fid_dict:
        fid = fid_dict[filename]
    else:
        fid = open(filename, "wb")
        fid_dict[filename] = fid
    x = x.detach().cpu().numpy().astype('float32')
    x.tofile(fid)


class IDCT(nn.Module):
    """Orthonormal inverse DCT-II (i.e. DCT-III) of size N, applied as a linear layer."""
    def __init__(self, N, device=None):
        super(IDCT, self).__init__()

        self.N = N
        n = torch.arange(N, device=device)
        k = torch.arange(N, device=device)
        table = torch.cos(torch.pi/N * (n[:,None]+.5) * k[None,:])
        table[:,0] = table[:,0] * math.sqrt(.5)
        table = table / math.sqrt(N/2)
        # Register as a buffer so the table follows the module across devices.
        self.register_buffer('table', table)

    def forward(self, x):
        return F.linear(x, self.table, None)

def plc_loss(N, device=None, alpha=1.0, bias=1.):
    # The IDCT spans the 18 cepstral coefficients; N is currently unused.
    idct = IDCT(18, device=device)
    def loss(y_true, y_pred):
        # Last channel of y_true is a validity mask; the remaining channels are
        # 18 cepstral coefficients, pitch (index 18), and voicing (index 19).
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        # Cepstral error mapped back to the band domain.
        e_bands = idct(e[:,:,:-2])
        # Weight the one-sided band penalty by (clamped) voicing.
        bias_mask = torch.clamp(4*y_true[:,:,-1:], min=0., max=1.)
        l1_loss = torch.mean(torch.abs(e))
        ceps_loss = torch.mean(torch.abs(e[:,:,:-2]))
        band_loss = torch.mean(torch.abs(e_bands))
        biased_loss = torch.mean(bias_mask*torch.clamp(e_bands, min=0.))
        # Two saturating pitch terms: a loose one and a tighter one that
        # dominates for small errors.
        pitch_loss1 = torch.mean(torch.clamp(torch.abs(e[:,:,18:19]), max=1.))
        pitch_loss = torch.mean(torch.clamp(torch.abs(e[:,:,18:19]), max=.4))
        # One-sided penalty on under-estimating the voicing parameter.
        voice_bias = torch.mean(torch.clamp(-e[:,:,-1:], min=0.))
        tot = l1_loss + 0.1*voice_bias + alpha*(band_loss + bias*biased_loss) + pitch_loss1 + 8*pitch_loss
        return tot, l1_loss, ceps_loss, band_loss, pitch_loss
    return loss


# orthogonal initialization of the GRU recurrent weights
def init_weights(module):
    if isinstance(module, nn.GRU):
        for name, p in module.named_parameters():
            if name.startswith('weight_hh_'):
                nn.init.orthogonal_(p)


class GLU(nn.Module):
    """Gated linear unit: x * sigmoid(W x), with a weight-normalized gate."""
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        # fixed seed for reproducible initialization
        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        # Note: weight_norm recomputes `weight` from weight_v/weight_g on each
        # forward pass, so this orthogonal init is largely overridden.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        return x * torch.sigmoid(self.gate(x))


class FWConv(nn.Module):
    """Framewise causal convolution: a linear layer applied to the current
    frame concatenated with the previous kernel_size-1 frames (the state)."""
    def __init__(self, in_size, out_size, kernel_size=2):
        super(FWConv, self).__init__()

        # fixed seed for reproducible initialization
        torch.manual_seed(5)

        self.in_size = in_size
        self.kernel_size = kernel_size
        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
        self.glu = GLU(out_size)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, state):
        xcat = torch.cat((state, x), -1)
        out = self.glu(torch.tanh(self.conv(xcat)))
        # New state: drop the oldest frame from the concatenated history.
        return out, xcat[:,self.in_size:]

def n(x):
    """Simulate quantization noise: add uniform dither of width 1/127, clamp to [-1, 1]."""
    return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)

class PLC(nn.Module):
    def __init__(self, features_in=57, features_out=20, cond_size=128, gru_size=128):
        super(PLC, self).__init__()

        self.features_in = features_in
        self.features_out = features_out
        self.cond_size = cond_size
        self.gru_size = gru_size

        self.dense_in = nn.Linear(self.features_in, self.cond_size)
        self.gru1 = nn.GRU(self.cond_size, self.gru_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru_size, features_out)

        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"plc model: {nb_params} weights")

    def forward(self, features, lost, states=None):
        device = features.device
        batch_size = features.size(0)
        if states is None:
            gru1_state = torch.zeros((1, batch_size, self.gru_size), device=device)
            gru2_state = torch.zeros((1, batch_size, self.gru_size), device=device)
        else:
            gru1_state = states[0]
            gru2_state = states[1]
        # Condition on the features plus the lost-frame flag.
        x = torch.cat([features, lost], dim=-1)
        x = torch.tanh(self.dense_in(x))
        gru1_out, gru1_state = self.gru1(x, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        return self.dense_out(gru2_out), [gru1_state, gru2_state]
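
# --- Minimal usage sketch (illustrative, not part of the training pipeline) ---
# Shapes below are assumptions inferred from the defaults above: the model input
# is a 56-dim feature vector plus a 1-dim lost-frame flag (57 = features_in),
# and y_true carries a trailing validity-mask channel on top of the 20 output
# features. Random data only; a real run would use extracted speech features.
if __name__ == '__main__':
    batch, seq = 4, 100
    model = PLC()  # features_in=57, features_out=20

    features = torch.randn(batch, seq, 56)
    # assumed convention: flag is 1 where the frame is lost
    lost = torch.randint(0, 2, (batch, seq, 1)).float()

    y_pred, states = model(features, lost)
    print(y_pred.shape)  # torch.Size([4, 100, 20])

    # N is unused by the current plc_loss implementation (IDCT is fixed at 18).
    loss_fn = plc_loss(20)
    mask = torch.ones(batch, seq, 1)
    y_true = torch.cat([torch.randn(batch, seq, 20), mask], dim=-1)
    tot, l1, ceps, band, pitch = loss_fn(y_true, y_pred)
    print(float(tot))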