# external/libopus/dnn/torch/plc/plc.py
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math

fid_dict = {}
def dump_signal(x, filename):
    # Debug dumps are disabled; remove this return to write signals to file.
    return
    if filename in fid_dict:
        fid = fid_dict[filename]
    else:
        fid = open(filename, "w")
        fid_dict[filename] = fid
    x = x.detach().numpy().astype('float32')
    x.tofile(fid)


class IDCT(nn.Module):
    def __init__(self, N, device=None):
        super(IDCT, self).__init__()

        self.N = N
        n = torch.arange(N, device=device)
        k = torch.arange(N, device=device)
        # Orthonormal inverse DCT (DCT-III) basis: scale the DC column by
        # sqrt(1/2) and normalize so the transform preserves energy.
        self.table = torch.cos(torch.pi/N * (n[:,None]+.5) * k[None,:])
        self.table[:,0] = self.table[:,0] * math.sqrt(.5)
        self.table = self.table / math.sqrt(N/2)

    def forward(self, x):
        return F.linear(x, self.table, None)
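
# A minimal usage sketch (shapes are illustrative): plc_loss below applies an
# N-point IDCT along the last axis of the cepstral error.
#
#   idct = IDCT(18)
#   bands = idct(torch.zeros(4, 100, 18))   # -> (4, 100, 18)
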
def plc_loss(N, device=None, alpha=1.0, bias=1.):
    idct = IDCT(N, device=device)
    def loss(y_true, y_pred):
        # The last channel of y_true carries a mask that zeroes out the error
        # on frames that should not contribute to the loss.
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        # All channels except the last two are cepstral coefficients; convert
        # their error to band energies with the inverse DCT.
        e_bands = idct(e[:,:,:-2])
        bias_mask = torch.clamp(4*y_true[:,:,-1:], min=0., max=1.)
        l1_loss = torch.mean(torch.abs(e))
        ceps_loss = torch.mean(torch.abs(e[:,:,:-2]))
        band_loss = torch.mean(torch.abs(e_bands))
        # Penalize over-estimated band energy more heavily on voiced frames.
        biased_loss = torch.mean(bias_mask*torch.clamp(e_bands, min=0.))
        pitch_loss1 = torch.mean(torch.clamp(torch.abs(e[:,:,18:19]), max=1.))
        pitch_loss = torch.mean(torch.clamp(torch.abs(e[:,:,18:19]), max=.4))
        voice_bias = torch.mean(torch.clamp(-e[:,:,-1:], min=0.))
        tot = l1_loss + 0.1*voice_bias + alpha*(band_loss + bias*biased_loss) + pitch_loss1 + 8*pitch_loss
        return tot, l1_loss, ceps_loss, band_loss, pitch_loss
    return loss

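# A minimal sketch of how the loss factory might be used (the channel layout
# is an assumption based on the slicing above: 18 cepstral coefficients,
# pitch, voicing, plus a trailing mask channel on the targets):
#
#   loss_fn = plc_loss(18)
#   y_true = torch.zeros(2, 100, 21)   # 20 feature channels + 1 mask channel
#   y_pred = torch.zeros(2, 100, 20)
#   tot, l1, ceps, band, pitch = loss_fn(y_true, y_pred)
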
# weight initialization and clipping
def init_weights(module):
    if isinstance(module, nn.GRU):
        for name, param in module.named_parameters():
            if name.startswith('weight_hh_'):
                nn.init.orthogonal_(param)


class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        # Gated linear unit: modulate the input by a sigmoid of a learned
        # linear projection of that same input.
        out = x * torch.sigmoid(self.gate(x))

        return out
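
# Usage sketch: GLU preserves the feature size.
#
#   glu = GLU(128)
#   y = glu(torch.zeros(1, 128))   # -> (1, 128)
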
class FWConv(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=2):
        super(FWConv, self).__init__()

        torch.manual_seed(5)

        self.in_size = in_size
        self.kernel_size = kernel_size
        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
        self.glu = GLU(out_size)

        self.init_weights()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, state):
        # Concatenate the saved history with the current frame, apply the
        # linear "convolution", and keep the most recent (kernel_size - 1)
        # frames as the new state.
        xcat = torch.cat((state, x), -1)
        out = self.glu(torch.tanh(self.conv(xcat)))
        return out, xcat[:,self.in_size:]
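
# Usage sketch (shapes are illustrative): FWConv behaves like a causal
# convolution with the given kernel size, run one frame at a time with an
# explicit state of (kernel_size - 1) * in_size past values.
#
#   fwc = FWConv(64, 128)
#   state = torch.zeros(1, 64)
#   y, state = fwc(torch.zeros(1, 64), state)   # y: (1, 128)
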
def n(x):
    # Dither: add uniform noise in [-0.5/127, 0.5/127] and clamp to [-1, 1].
    return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)

class PLC(nn.Module):
    def __init__(self, features_in=57, features_out=20, cond_size=128, gru_size=128):
        super(PLC, self).__init__()

        self.features_in = features_in
        self.features_out = features_out
        self.cond_size = cond_size
        self.gru_size = gru_size

        self.dense_in = nn.Linear(self.features_in, self.cond_size)
        self.gru1 = nn.GRU(self.cond_size, self.gru_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru_size, self.features_out)

        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"plc model: {nb_params} weights")

    def forward(self, features, lost, states=None):
        device = features.device
        batch_size = features.size(0)
        if states is None:
            gru1_state = torch.zeros((1, batch_size, self.gru_size), device=device)
            gru2_state = torch.zeros((1, batch_size, self.gru_size), device=device)
        else:
            gru1_state = states[0]
            gru2_state = states[1]
        # The frame-loss indicator is appended to the features before conditioning.
        x = torch.cat([features, lost], dim=-1)
        x = torch.tanh(self.dense_in(x))
        gru1_out, gru1_state = self.gru1(x, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        return self.dense_out(gru2_out), [gru1_state, gru2_state]
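
# A minimal sketch of running the model (feature dimensions are assumptions:
# features_in=57 corresponds to a 56-dim feature vector plus the 1-dim lost flag):
#
#   model = PLC()
#   feats = torch.zeros(2, 100, 56)
#   lost = torch.zeros(2, 100, 1)
#   out, states = model(feats, lost)            # fresh GRU states
#   out, states = model(feats, lost, states)    # streaming continuation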