import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import filters
from torch.nn.utils import weight_norm
#from convert_lsp import lpc_to_lsp, lsp_to_lpc
from rc import lpc2rc, rc2lpc

# Sampling rate in Hz.
Fs = 16000

# Cache of open file handles for dump_signal(), keyed by filename.
fid_dict = {}
def dump_signal(x, filename):
    """Append tensor x to a raw float32 debug dump file.

    Debug dumps are currently disabled by the early return below;
    delete it to re-enable them.
    """
    return
    if filename in fid_dict:
        fid = fid_dict[filename]
    else:
        fid = open(filename, "w")
        fid_dict[filename] = fid
    x = x.detach().numpy().astype('float32')
    x.tofile(fid)


def sig_l1(y_true, y_pred):
    """Mean L1 error normalized by the mean magnitude of the target."""
    return torch.mean(abs(y_true-y_pred))/torch.mean(abs(y_true))

def sig_loss(y_true, y_pred):
    """Cosine-distance loss along the last dim: mean of 1 - <t, p>
    after L2-normalizing both inputs (epsilon guards zero-norm rows)."""
    t = y_true/(1e-15+torch.norm(y_true, dim=-1, p=2, keepdim=True))
    p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True))
    return torch.mean(1.-torch.sum(p*t, dim=-1))

def interp_lpc(lpc, factor):
    """Upsample LPC coefficients in time by `factor` using linear
    interpolation in a warped domain.

    lpc: (batch, nb_frames, order) filter coefficients.
    Returns (batch, nb_frames*factor, order) coefficients.

    NOTE(review): despite the `lsp` variable names, this interpolates
    atanh of the output of lpc2rc (reflection coefficients, judging by
    the rc module name), not line spectral pairs — confirm against rc.py.
    """
    #print(lpc.shape)
    #f = (np.arange(factor)+.5*((factor+1)%2))/factor
    lsp = torch.atanh(lpc2rc(lpc))
    #print("lsp0:")
    #print(lsp)
    shape = lsp.shape
    #print("shape is", shape)
    shape = (shape[0], shape[1]*factor, shape[2])
    interp_lsp = torch.zeros(shape, device=lpc.device)
    # Interior positions: cross-fade between consecutive frames; the
    # factor//2 offset centers the interpolated grid on the frames.
    for k in range(factor):
        f = (k+.5*((factor+1)%2))/factor
        interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:]
        interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp
    # Edge positions have no neighbor to interpolate with: replicate
    # the first/last interpolated entries.
    for k in range(factor//2):
        interp_lsp[:,k,:] = interp_lsp[:,factor//2,:]
    for k in range((factor+1)//2):
        interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:]
    #print("lsp:")
    #print(interp_lsp)
    return rc2lpc(torch.tanh(interp_lsp))

def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
    """Run the LPC analysis (whitening) filter A(z) over x, one
    subframe at a time, and return the prediction residual.

    x:   (batch, nb_frames*nb_subframes*subframe_size) signal.
    lpc: (batch, nb_frames, P) prediction coefficients — P appears to
         be 16, implied by the 16-sample history buffer below; confirm.
    gamma is currently unused: the bandwidth-expansion code that
    consumed it is commented out below.
    """
    device = x.device
    batch_size = lpc.size(0)

    nb_frames = lpc.shape[1]


    # Rolling buffer: 16 samples of filter history + current subframe.
    sig = torch.zeros(batch_size, subframe_size+16, device=device)
    x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
    out = torch.zeros((batch_size, 0), device=device)

    #if gamma is not None:
    #    bw = gamma**(torch.arange(1, 17, device=device))
    #    lpc = lpc*bw[None,None,:]
    # Build A(z) = [1, lpc...] zero-padded to subframe length, then a
    # Toeplitz matrix so FIR filtering becomes a batched matmul.
    ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
    zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
    a = torch.cat([ones, lpc], -1)
    a_big = torch.cat([a, zeros], -1)
    fir_mat_big = filters.toeplitz_from_filter(a_big)

    #print(a_big[:,0,:])
    for n in range(nb_frames):
        for k in range(nb_subframes):

            # Shift the next subframe into the buffer, filter it with
            # this frame's A(z), and keep only the new output samples.
            sig = torch.cat([sig[:,subframe_size:], x[:,n*nb_subframes + k, :]], 1)
            exc = torch.bmm(fir_mat_big[:,n,:,:], sig[:,:,None])
            out = torch.cat([out, exc[:,-subframe_size:,0]], 1)

    return out
# weight initialization and clipping
def init_weights(module):
    """Orthogonally initialize the recurrent weights of GRU modules.

    Intended for use with nn.Module.apply(); any module that is not an
    nn.GRU is left untouched.
    """
    if not isinstance(module, nn.GRU):
        return
    for name, param in module.named_parameters():
        # Only the hidden-to-hidden matrices get the orthogonal init.
        if name.startswith('weight_hh_'):
            nn.init.orthogonal_(param)

def gen_phase_embedding(periods, frame_size):
    """Build per-sample cos/sin phase embeddings from pitch periods.

    periods: (batch, nb_frames) pitch period per frame, in samples.
    Returns two (batch, nb_frames*frame_size) tensors holding the
    cosine and sine of the accumulated phase. The starting phase is
    randomized, so every call begins at a different offset.
    """
    device = periods.device
    batch_size, nb_frames = periods.size(0), periods.size(1)
    # Fundamental angular frequency of each frame.
    w0 = 2*torch.pi/periods
    # Random initial phase, then each frame accumulates at the
    # *previous* frame's frequency, giving the phase at frame starts.
    w0_shift = torch.cat(
        [2*torch.pi*torch.rand((batch_size, 1), device=device)/frame_size,
         w0[:, :-1]], 1)
    cum_phase = frame_size*torch.cumsum(w0_shift, 1)
    # Within-frame linear phase ramp at the current frame's frequency.
    sample_idx = torch.broadcast_to(
        torch.arange(frame_size, device=device),
        (batch_size, nb_frames, frame_size))
    fine_phase = w0[:, :, None]*sample_idx
    embed = torch.reshape(cum_phase[:, :, None] + fine_phase, (batch_size, -1))
    return torch.cos(embed), torch.sin(embed)

class GLU(nn.Module):
    """Gated linear unit: y = x * sigmoid(W x), with a bias-free,
    weight-normalized, orthogonally initialized gate."""

    def __init__(self, feat_size):
        super(GLU, self).__init__()

        # Fixed seed so the gate weights are reproducible.
        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        # Orthogonal init for every dense/conv/embedding submodule.
        targets = (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)
        for m in self.modules():
            if isinstance(m, targets):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        gate = torch.sigmoid(self.gate(x))
        return x * gate

class FWConv(nn.Module):
    """Framewise "convolution": a dense layer applied to the current
    input concatenated with the previous (kernel_size-1) inputs, then
    tanh and a GLU. The history is threaded explicitly as a state
    tensor instead of being buffered internally."""

    def __init__(self, in_size, out_size, kernel_size=2):
        super(FWConv, self).__init__()

        # Fixed seed so the weights are reproducible.
        torch.manual_seed(5)

        self.in_size = in_size
        self.kernel_size = kernel_size
        self.conv = weight_norm(
            nn.Linear(in_size*self.kernel_size, out_size, bias=False))
        self.glu = GLU(out_size)

        self.init_weights()

    def init_weights(self):
        # Orthogonal init for every dense/conv/embedding submodule
        # (this also re-initializes the GLU's gate).
        targets = (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)
        for m in self.modules():
            if isinstance(m, targets):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, state):
        # state carries the previous inputs; prepend them to x.
        window = torch.cat((state, x), -1)
        out = self.glu(torch.tanh(self.conv(window)))
        # New state: drop the oldest in_size entries of the window.
        return out, window[:, self.in_size:]

def n(x):
    """Simulate quantization noise: add uniform noise of one 8-bit
    step (1/127 peak-to-peak) and clamp to [-1, 1]."""
    noise = (torch.rand_like(x) - .5)*(1./127.)
    return torch.clamp(x + noise, min=-1., max=1.)

class FARGANCond(nn.Module):
    """Frame-rate conditioning network.

    Embeds the pitch period, concatenates it with the acoustic
    features, and runs dense -> Conv1d (kernel 3, 'valid') -> dense to
    produce 80*4 conditioning values per frame (80 per subframe).
    """
    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
        super(FARGANCond, self).__init__()

        self.feature_dim = feature_dim
        self.cond_size = cond_size

        # 224 embedding entries and the -32 offset in forward() imply
        # pitch periods in [32, 255] — confirm with the caller.
        self.pembed = nn.Embedding(224, pembed_dims)
        self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
        self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
        self.fdense2 = nn.Linear(128, 80*4, bias=False)

        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"cond model: {nb_params} weights")

    def forward(self, features, period):
        """features: (batch, frames, feature_dim); period: (batch, frames) int.

        The first two frames are dropped and the 'valid' kernel-3 conv
        trims two more, so the output has fewer frames than the input —
        presumably matching the 3+n indexing in FARGAN.forward; confirm.
        """
        features = features[:,2:,:]
        period = period[:,2:]
        p = self.pembed(period-32)
        features = torch.cat((features, p), -1)
        tmp = torch.tanh(self.fdense1(features))
        # Conv1d wants (batch, channels, time).
        tmp = tmp.permute(0, 2, 1)
        tmp = torch.tanh(self.fconv1(tmp))
        tmp = tmp.permute(0, 2, 1)
        tmp = torch.tanh(self.fdense2(tmp))
        #tmp = torch.tanh(self.fdense2(tmp))
        return tmp

class FARGANSub(nn.Module):
    """Autoregressive subframe synthesis network.

    Generates one subframe of excitation at a time from the
    conditioning vector, a pitch-delayed prediction taken from the
    excitation memory, and the previous subframe, using a FWConv stage
    followed by three GRU cells with GLU outputs and a skip connection.
    """
    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256):
        super(FARGANSub, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.cond_size = cond_size
        # Predicts a scalar gain from the 80-dim conditioning vector.
        self.cond_gain_dense = nn.Linear(80, 1)


        #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
        # Input: 80 cond + (subframe_size+4) pitch pred + subframe_size prev.
        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192)
        self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
        self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
        self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)

        self.gru1_glu = GLU(160)
        self.gru2_glu = GLU(128)
        self.gru3_glu = GLU(128)
        self.skip_glu = GLU(128)
        #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)

        self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
        self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
        # One pitch gain per network stage (gru1/gru2/gru3/skip).
        self.gain_dense_out = nn.Linear(192, 4)


        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"subframe model: {nb_params} weights")

    def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):
        """Synthesize one subframe.

        cond:     (batch, 80) conditioning slice for this subframe.
        exc_mem:  (batch, 256) rolling excitation memory.
        period:   (batch,) integer pitch period.
        states:   tuple (gru1, gru2, gru3, fwc0) recurrent states.
        gain:     only used for the (disabled) debug dump; the gain
                  actually applied is recomputed from cond below.
        Returns (sig_out, exc_mem, prev_pred, new_states).
        """
        device = exc_mem.device
        #print(cond.shape, prev.shape)

        cond = n(cond)
        dump_signal(gain, 'gain0.f32')
        # Gain is predicted from the conditioning, not taken from the
        # caller's gain argument.
        gain = torch.exp(self.cond_gain_dense(cond))
        dump_signal(gain, 'gain1.f32')
        # Gather a pitch-delayed window of subframe_size+4 samples
        # (2 extra samples on each side) ending at the newest samples.
        idx = 256-period[:,None]
        rng = torch.arange(self.subframe_size+4, device=device)
        idx = idx + rng[None,:] - 2
        # Indices past the end of memory wrap back by one more period
        # (repeats the pitch cycle when period < subframe_size+2).
        mask = idx >= 256
        idx = idx - mask*period[:,None]
        pred = torch.gather(exc_mem, 1, idx)
        # Normalize by the predicted gain, then quantize-noise it.
        pred = n(pred/(1e-5+gain))

        prev = exc_mem[:,-self.subframe_size:]
        dump_signal(prev, 'prev_in.f32')
        prev = n(prev/(1e-5+gain))
        dump_signal(prev, 'pitch_exc.f32')
        dump_signal(exc_mem, 'exc_mem.f32')

        tmp = torch.cat((cond, pred, prev), 1)
        #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
        # Center subframe_size samples of the pitch window (the learned
        # 5-tap filter above is disabled).
        fpitch = pred[:,2:-2]

        #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
        fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
        fwc0_out = n(fwc0_out)
        # Per-stage gains applied to the pitch prediction fed to each
        # recurrent stage.
        pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out))

        gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
        gru1_out = self.gru1_glu(n(gru1_state))
        gru1_out = n(gru1_out)
        gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
        gru2_out = self.gru2_glu(n(gru2_state))
        gru2_out = n(gru2_out)
        gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
        gru3_out = self.gru3_glu(n(gru3_state))
        gru3_out = n(gru3_out)
        # Skip connection over all stage outputs.
        gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1)
        skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
        skip_out = self.skip_glu(n(skip_out))
        sig_out = torch.tanh(self.sig_dense_out(skip_out))
        dump_signal(sig_out, 'exc_out.f32')
        #taps = self.ptaps_dense(gru3_out)
        #taps = .2*taps + torch.exp(taps)
        #taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
        #dump_signal(taps, 'taps.f32')

        dump_signal(pitch_gain, 'pgain.f32')
        #sig_out = (sig_out + pitch_gain*fpitch) * gain
        # Restore the gain that the inputs were normalized by.
        sig_out = sig_out * gain
        # Shift the new subframe into the rolling buffers.
        exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
        prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1)
        dump_signal(sig_out, 'sig_out.f32')
        return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)

class FARGAN(nn.Module):
    """Full FARGAN vocoder: conditioning network + autoregressive
    subframe network, with optional teacher forcing from `pre`."""
    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
        # NOTE(review): passthrough_size, has_gain and gamma are
        # accepted but unused in this class — presumably kept for
        # checkpoint/caller compatibility; confirm before removing.
        super(FARGAN, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.frame_size = self.subframe_size*self.nb_subframes
        self.feature_dim = feature_dim
        self.cond_size = cond_size

        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size)

    def forward(self, features, period, nb_frames, pre=None, states=None):
        """Generate nb_frames frames of signal.

        pre: optional teacher-forcing signal; its frames are fed into
        the excitation memory instead of the network's own output and
        are excluded from the returned signal.
        NOTE(review): the incoming `states` argument is ignored —
        states are unconditionally reset to zeros below.
        """
        device = features.device
        batch_size = features.size(0)

        prev = torch.zeros(batch_size, 256, device=device)
        exc_mem = torch.zeros(batch_size, 256, device=device)
        nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0

        # (gru1, gru2, gru3, fwc0) zero states; fwc0 state size matches
        # FWConv's input size (kernel_size-1 == 1 past frame).
        states = (
            torch.zeros(batch_size, 160, device=device),
            torch.zeros(batch_size, 128, device=device),
            torch.zeros(batch_size, 128, device=device),
            torch.zeros(batch_size, (2*self.subframe_size+80+4)*1, device=device)
        )

        sig = torch.zeros((batch_size, 0), device=device)
        cond = self.cond_net(features, period)
        if pre is not None:
            # Seed the excitation memory with the first pre frame.
            exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
        start = 1 if nb_pre_frames>0 else 0
        # NOTE(review): loop variable n shadows the module-level noise
        # helper n(); harmless here because n() is not called in this
        # loop, but fragile.
        for n in range(start, nb_frames+nb_pre_frames):
            for k in range(self.nb_subframes):
                pos = n*self.frame_size + k*self.subframe_size
                #print("now: ", preal.shape, prev.shape, sig_in.shape)
                # The 3+n offset presumably compensates for the frames
                # trimmed inside FARGANCond — confirm alignment.
                pitch = period[:, 3+n]
                # Gain derived from feature 0 (looks like a log-energy
                # term — confirm against the feature extractor).
                gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
                #gain = gain[:,:,None]
                out, exc_mem, prev, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain)

                if n < nb_pre_frames:
                    # Teacher forcing: overwrite the generated subframe
                    # with ground truth in the memory; don't emit it.
                    out = pre[:, pos:pos+self.subframe_size]
                    exc_mem[:,-self.subframe_size:] = out
                else:
                    sig = torch.cat([sig, out], 1)

        states = [s.detach() for s in states]
        return sig, states