xref: /aosp_15_r20/external/libopus/dnn/torch/fargan/fargan.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Liimport numpy as np
2*a58d3d2aSXin Liimport torch
3*a58d3d2aSXin Lifrom torch import nn
4*a58d3d2aSXin Liimport torch.nn.functional as F
5*a58d3d2aSXin Liimport filters
6*a58d3d2aSXin Lifrom torch.nn.utils import weight_norm
7*a58d3d2aSXin Li#from convert_lsp import lpc_to_lsp, lsp_to_lpc
8*a58d3d2aSXin Lifrom rc import lpc2rc, rc2lpc
9*a58d3d2aSXin Li
Fs = 16000  # sample rate (Hz) used by this model

# Cache of open file handles for debug dumps, keyed by filename, so each
# file is opened once and appended to across calls.
fid_dict = {}
def dump_signal(x, filename):
    """Append tensor ``x`` to ``filename`` as raw float32 samples (debug aid).

    NOTE: dumping is currently disabled via the early ``return`` below;
    delete that line to re-enable the debug dumps.
    """
    return
    if filename in fid_dict:
        fid = fid_dict[filename]
    else:
        fid = open(filename, "w")
        fid_dict[filename] = fid
    x = x.detach().numpy().astype('float32')
    x.tofile(fid)
22*a58d3d2aSXin Li
23*a58d3d2aSXin Li
def sig_l1(y_true, y_pred):
    """Relative L1 error: mean |y_true - y_pred| normalized by mean |y_true|."""
    err = torch.mean(torch.abs(y_true - y_pred))
    ref = torch.mean(torch.abs(y_true))
    return err / ref
26*a58d3d2aSXin Li
def sig_loss(y_true, y_pred):
    """Mean of (1 - cosine similarity) between y_true and y_pred along the last dim."""
    eps = 1e-15
    t_unit = y_true / (eps + torch.norm(y_true, dim=-1, p=2, keepdim=True))
    p_unit = y_pred / (eps + torch.norm(y_pred, dim=-1, p=2, keepdim=True))
    cos_sim = torch.sum(p_unit * t_unit, dim=-1)
    return torch.mean(1. - cos_sim)
31*a58d3d2aSXin Li
def interp_lpc(lpc, factor):
    """Temporally upsample LPC coefficients by an integer ``factor``.

    The LPCs are converted to reflection coefficients (lpc2rc) and mapped
    through atanh to an unbounded domain, linearly interpolated between
    adjacent frames there, then mapped back with tanh/rc2lpc -- so the
    interpolated values always map back to valid reflection coefficients
    with magnitude below 1.

    lpc:    (batch, nb_frames, order) tensor of LPC coefficients
    factor: integer upsampling factor
    returns a (batch, nb_frames*factor, order) tensor
    """
    #print(lpc.shape)
    #f = (np.arange(factor)+.5*((factor+1)%2))/factor
    # NOTE(review): despite the name, `lsp` holds atanh of reflection
    # coefficients (log-area-ratio-like values), not line spectral pairs.
    lsp = torch.atanh(lpc2rc(lpc))
    #print("lsp0:")
    #print(lsp)
    shape = lsp.shape
    #print("shape is", shape)
    shape = (shape[0], shape[1]*factor, shape[2])
    interp_lsp = torch.zeros(shape, device=lpc.device)
    for k in range(factor):
        # fractional position of the k-th interpolated point between frames
        f = (k+.5*((factor+1)%2))/factor
        interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:]
        # write every factor-th output frame, offset so interpolated values
        # are centered between the original frames
        interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp
    # replicate the first/last interpolated frames to fill the edges
    for k in range(factor//2):
        interp_lsp[:,k,:] = interp_lsp[:,factor//2,:]
    for k in range((factor+1)//2):
        interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:]
    #print("lsp:")
    #print(interp_lsp)
    return rc2lpc(torch.tanh(interp_lsp))
53*a58d3d2aSXin Li
def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
    """Run the LPC analysis (whitening) filter A(z) over the signal ``x``.

    x:   (batch, nb_frames*nb_subframes*subframe_size) input signal
    lpc: (batch, nb_frames, order) LPC coefficients, one set per frame;
         the 16-sample history buffer below implies order 16 -- TODO confirm
    gamma: currently unused; the bandwidth-expansion code that would
         consume it is commented out below.
    Returns the prediction residual (excitation), same length as ``x``.
    """
    device = x.device
    batch_size = lpc.size(0)

    nb_frames = lpc.shape[1]


    # rolling window: 16 samples of filter history plus one new subframe
    sig = torch.zeros(batch_size, subframe_size+16, device=device)
    x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
    out = torch.zeros((batch_size, 0), device=device)

    #if gamma is not None:
    #    bw = gamma**(torch.arange(1, 17, device=device))
    #    lpc = lpc*bw[None,None,:]
    # Build A(z) = [1, a_1..a_order], zero-pad it, and turn it into a
    # banded (Toeplitz) FIR matrix so filtering is a single matmul.
    ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
    zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
    a = torch.cat([ones, lpc], -1)
    a_big = torch.cat([a, zeros], -1)
    fir_mat_big = filters.toeplitz_from_filter(a_big)

    #print(a_big[:,0,:])
    for n in range(nb_frames):
        for k in range(nb_subframes):

            # shift in the next subframe, filter, keep only the new samples
            sig = torch.cat([sig[:,subframe_size:], x[:,n*nb_subframes + k, :]], 1)
            exc = torch.bmm(fir_mat_big[:,n,:,:], sig[:,:,None])
            out = torch.cat([out, exc[:,-subframe_size:,0]], 1)

    return out
83*a58d3d2aSXin Li
84*a58d3d2aSXin Li
# weight initialization and clipping
def init_weights(module):
    """Orthogonally initialize the hidden-to-hidden weights of GRU modules."""
    if not isinstance(module, nn.GRU):
        return
    for name, param in module.named_parameters():
        if name.startswith('weight_hh_'):
            nn.init.orthogonal_(param)
91*a58d3d2aSXin Li
def gen_phase_embedding(periods, frame_size):
    """Build cos/sin pitch-phase embeddings from per-frame periods.

    periods:    (batch, nb_frames) pitch period in samples
    frame_size: number of samples per frame
    Returns (cos_embed, sin_embed), each (batch, nb_frames*frame_size),
    with a random initial phase drawn per sequence.
    """
    device = periods.device
    batch_size, nb_frames = periods.size(0), periods.size(1)
    # per-frame angular frequency
    w0 = 2 * torch.pi / periods
    # random start phase, then accumulate phase at each frame boundary
    rand_phase = 2 * torch.pi * torch.rand((batch_size, 1), device=device) / frame_size
    w0_shift = torch.cat([rand_phase, w0[:, :-1]], 1)
    cum_phase = frame_size * torch.cumsum(w0_shift, 1)
    # linear within-frame phase ramp at each frame's own frequency
    ramp = torch.arange(frame_size, device=device)
    fine_phase = w0[:, :, None] * ramp[None, None, :].expand(batch_size, nb_frames, frame_size)
    phase = cum_phase[:, :, None] + fine_phase
    phase = phase.reshape(batch_size, -1)
    return torch.cos(phase), torch.sin(phase)
103*a58d3d2aSXin Li
class GLU(nn.Module):
    """Gated linear unit: ``x * sigmoid(W x)`` with a bias-free,
    weight-normalized gate whose weights are orthogonally initialized."""

    def __init__(self, feat_size):
        super(GLU, self).__init__()
        # fixed seed so the gate weights are reproducible across runs
        torch.manual_seed(5)
        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
        self.init_weights()

    def init_weights(self):
        # orthogonal init for every linear/conv/embedding submodule
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        return x * torch.sigmoid(self.gate(x))
126*a58d3d2aSXin Li
class FWConv(nn.Module):
    """Stateful frame-wise "convolution": concatenates the previous
    ``kernel_size - 1`` input frames (carried in ``state``) with the
    current frame, applies a weight-normalized linear layer + tanh,
    then a GLU gate."""

    def __init__(self, in_size, out_size, kernel_size=2):
        super(FWConv, self).__init__()
        # fixed seed so weights are reproducible across runs
        torch.manual_seed(5)
        self.in_size = in_size
        self.kernel_size = kernel_size
        self.conv = weight_norm(nn.Linear(in_size * self.kernel_size, out_size, bias=False))
        self.glu = GLU(out_size)
        self.init_weights()

    def init_weights(self):
        # orthogonal init for every linear/conv/embedding submodule
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, state):
        history = torch.cat((state, x), -1)
        #print(x.shape, state.shape, xcat.shape, self.in_size, self.kernel_size)
        out = self.glu(torch.tanh(self.conv(history)))
        # next state: drop the oldest in_size samples from the window
        return out, history[:, self.in_size:]
152*a58d3d2aSXin Li
def n(x):
    """Add uniform dither in [-1/254, 1/254] to x and clamp to [-1, 1]
    (presumably simulating 8-bit quantization noise -- TODO confirm)."""
    noise = (torch.rand_like(x) - .5) * (1. / 127.)
    return torch.clamp(x + noise, min=-1., max=1.)
155*a58d3d2aSXin Li
class FARGANCond(nn.Module):
    """Conditioning network: maps acoustic features plus a pitch-period
    embedding to per-subframe conditioning vectors (80 values for each
    of the 4 subframes per frame)."""

    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
        super(FARGANCond, self).__init__()

        self.feature_dim = feature_dim
        self.cond_size = cond_size

        # pitch embedding; forward() subtracts 32 from the period, so
        # valid periods are 32..255 (224 entries)
        self.pembed = nn.Embedding(224, pembed_dims)
        self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
        self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
        self.fdense2 = nn.Linear(128, 80*4, bias=False)

        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"cond model: {nb_params} weights")

    def forward(self, features, period):
        # drop the first two frames (alignment with the 'valid'-padded
        # conv below -- TODO confirm against caller)
        features = features[:, 2:, :]
        period = period[:, 2:]
        pitch_embedding = self.pembed(period - 32)
        tmp = torch.cat((features, pitch_embedding), -1)
        tmp = torch.tanh(self.fdense1(tmp))
        # Conv1d wants (batch, channels, time)
        tmp = torch.tanh(self.fconv1(tmp.permute(0, 2, 1))).permute(0, 2, 1)
        #tmp = torch.tanh(self.fdense2(tmp))
        return torch.tanh(self.fdense2(tmp))
184*a58d3d2aSXin Li
class FARGANSub(nn.Module):
    """Autoregressive subframe synthesis network.

    Generates one subframe of signal at a time from an 80-dim conditioning
    vector, a long-term (pitch) prediction gathered from its own past
    output, and a stack of GRU cells with GLU gating.
    """

    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256):
        super(FARGANSub, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.cond_size = cond_size
        # predicts a (log-domain) gain from the conditioning vector
        self.cond_gain_dense = nn.Linear(80, 1)

        #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
        # input = cond (80) + pitch prediction (subframe+4) + prev (subframe)
        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192)
        self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
        self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
        self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)

        self.gru1_glu = GLU(160)
        self.gru2_glu = GLU(128)
        self.gru3_glu = GLU(128)
        self.skip_glu = GLU(128)
        #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)

        # skip connection over the concatenated fwc0/gru outputs
        self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
        self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
        # 4 pitch gains: one per injection point (gru1..3 + skip)
        self.gain_dense_out = nn.Linear(192, 4)


        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"subframe model: {nb_params} weights")

    def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):
        """Generate one subframe.

        cond:      (batch, 80) conditioning vector for this subframe
        prev_pred: (batch, 256) history of past pitch predictions
        exc_mem:   (batch, 256) ring of the most recent output samples
        period:    (batch,) integer pitch period in samples
        states:    (gru1, gru2, gru3, fwc0) recurrent-state tuple
        gain:      only dumped for debugging; the gain actually applied
                   is re-derived from ``cond`` below
        Returns (sig_out, exc_mem, prev_pred, new_states).
        """
        device = exc_mem.device
        #print(cond.shape, prev.shape)

        cond = n(cond)
        dump_signal(gain, 'gain0.f32')
        # NOTE: the caller-supplied gain is discarded here and replaced by
        # a gain predicted from the conditioning vector.
        gain = torch.exp(self.cond_gain_dense(cond))
        dump_signal(gain, 'gain1.f32')
        # Gather the pitch-delayed excitation: subframe_size+4 samples
        # starting one period back, with 2 samples of context each side.
        idx = 256-period[:,None]
        rng = torch.arange(self.subframe_size+4, device=device)
        idx = idx + rng[None,:] - 2
        # Where the window would run past the buffer end (period shorter
        # than the subframe), wrap back by one extra period.
        mask = idx >= 256
        idx = idx - mask*period[:,None]
        pred = torch.gather(exc_mem, 1, idx)
        # normalize by the gain and dither
        pred = n(pred/(1e-5+gain))

        # previous subframe of output, gain-normalized
        prev = exc_mem[:,-self.subframe_size:]
        dump_signal(prev, 'prev_in.f32')
        prev = n(prev/(1e-5+gain))
        dump_signal(prev, 'pitch_exc.f32')
        dump_signal(exc_mem, 'exc_mem.f32')

        tmp = torch.cat((cond, pred, prev), 1)
        #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
        # pitch prediction with the 2-sample context trimmed off
        fpitch = pred[:,2:-2]

        #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
        fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
        fwc0_out = n(fwc0_out)
        # one sigmoid pitch gain per injection point below
        pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out))

        # GRU stack; the scaled pitch prediction and prev subframe are
        # re-injected at every layer
        gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
        gru1_out = self.gru1_glu(n(gru1_state))
        gru1_out = n(gru1_out)
        gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
        gru2_out = self.gru2_glu(n(gru2_state))
        gru2_out = n(gru2_out)
        gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
        gru3_out = self.gru3_glu(n(gru3_state))
        gru3_out = n(gru3_out)
        # concatenate all layer outputs for the skip connection
        gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1)
        skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
        skip_out = self.skip_glu(n(skip_out))
        sig_out = torch.tanh(self.sig_dense_out(skip_out))
        dump_signal(sig_out, 'exc_out.f32')
        #taps = self.ptaps_dense(gru3_out)
        #taps = .2*taps + torch.exp(taps)
        #taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
        #dump_signal(taps, 'taps.f32')

        dump_signal(pitch_gain, 'pgain.f32')
        #sig_out = (sig_out + pitch_gain*fpitch) * gain
        # undo the gain normalization on the way out
        sig_out = sig_out * gain
        # shift the new subframe into the output/prediction histories
        exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
        prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1)
        dump_signal(sig_out, 'sig_out.f32')
        return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
272*a58d3d2aSXin Li
class FARGAN(nn.Module):
    """Full FARGAN vocoder: a conditioning network (FARGANCond) driving an
    autoregressive per-subframe synthesis network (FARGANSub)."""

    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
        super(FARGAN, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.frame_size = self.subframe_size*self.nb_subframes
        self.feature_dim = feature_dim
        self.cond_size = cond_size

        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size)

    def forward(self, features, period, nb_frames, pre=None, states=None):
        """Synthesize ``nb_frames`` frames of signal.

        features: (batch, frames, feature_dim) acoustic features; indexed
                  at ``3 + frame`` below, so it must carry extra context
                  frames -- TODO confirm exact offset against caller
        period:   (batch, frames) integer pitch periods, same layout
        nb_frames: number of frames to synthesize
        pre:      optional ground-truth prefix signal for teacher forcing
        states:   optional recurrent-state tuple from a previous call;
                  freshly zero-initialized when None
        Returns (signal, detached_states).
        """
        device = features.device
        batch_size = features.size(0)

        prev = torch.zeros(batch_size, 256, device=device)
        exc_mem = torch.zeros(batch_size, 256, device=device)
        nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0

        # BUG FIX: `states` used to be unconditionally re-initialized,
        # silently discarding any state the caller passed in. Only build
        # fresh zero states when none are provided, so synthesis can be
        # continued across calls.
        if states is None:
            states = (
                torch.zeros(batch_size, 160, device=device),
                torch.zeros(batch_size, 128, device=device),
                torch.zeros(batch_size, 128, device=device),
                # fwc0 state: one previous input frame of the FWConv
                torch.zeros(batch_size, (2*self.subframe_size+80+4)*1, device=device)
            )

        sig = torch.zeros((batch_size, 0), device=device)
        cond = self.cond_net(features, period)
        if pre is not None:
            # seed the excitation memory with the first prefix frame
            exc_mem[:, -self.frame_size:] = pre[:, :self.frame_size]
        start = 1 if nb_pre_frames > 0 else 0
        # NOTE: loop variable renamed from `n` to avoid shadowing the
        # module-level noise function n().
        for frame in range(start, nb_frames + nb_pre_frames):
            for k in range(self.nb_subframes):
                pos = frame*self.frame_size + k*self.subframe_size
                pitch = period[:, 3+frame]
                # map the log-domain energy feature to a linear gain
                gain = .03*10**(0.5*features[:, 3+frame, 0:1]/np.sqrt(18.0))
                out, exc_mem, prev, states = self.sig_net(cond[:, frame, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain)

                if frame < nb_pre_frames:
                    # teacher forcing: feed ground truth back instead of
                    # the network's own output
                    out = pre[:, pos:pos+self.subframe_size]
                    exc_mem[:, -self.subframe_size:] = out
                else:
                    sig = torch.cat([sig, out], 1)

        # detach so returned states can seed the next call without
        # extending the autograd graph
        states = [s.detach() for s in states]
        return sig, states
323