"""Synthesize a waveform from a feature file with a trained FARGAN model.

Usage: <script> model features output [--cuda-visible-devices IDS] [--cond-size N]

The feature file is raw float32 data reshaped to (1, frames, 36). The first
20 columns are passed to the model as conditioning features; the remaining
16 columns are read as LPC coefficients. The model output is filtered with
1 / (1 - 0.85 z^-1) and written to ``output`` as raw 16-bit PCM.

NOTE: ``--cond-size`` is parsed but not used here; the model is built from
the ``model_args`` / ``model_kwargs`` stored in the checkpoint.
"""
import os
import argparse
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import tqdm
from scipy.signal import lfilter

import fargan
from dataset import FARGANDataset

nb_features = 36        # float32 values per frame in the feature file
nb_used_features = 20   # leading columns fed to the model; the rest are LPCs

parser = argparse.ArgumentParser()

parser.add_argument('model', type=str, help='CELPNet model')
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('output', type=str, help='path to output file (16-bit PCM)')

parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)

args = parser.parse_args()

# Set before the first CUDA query below (torch.cuda.is_available) so it takes effect.
if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices


features_file = args.features
signal_file = args.output

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
checkpoint = torch.load(args.model, map_location='cpu')

model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])

# strict=False: missing/unexpected keys in the checkpoint are silently ignored.
model.load_state_dict(checkpoint['state_dict'], strict=False)

features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
# LPC columns, taken from frame 3 up to (but excluding) the last frame.
lpc = features[:, 4-1:-1, nb_used_features:]
features = features[:, :, :nb_used_features]
# Column 18 is mapped to an integer period clipped to [32, 255]
# (presumably a log-scale pitch feature -> pitch period in samples; TODO confirm).
periods = np.round(np.clip(256./2**(features[:, :, nb_used_features-2]+1.5), 32, 255)).astype('int')


nb_frames = features.shape[1]
gamma = checkpoint['model_kwargs']['gamma']


def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=None):
    """Run an all-pole LPC synthesis filter over one frame.

    With p = len(buffer), each output sample is
        y[n] = x[n] - sum_{k=1..p} filt[k-1] * weighting_vector[p-k] * y[n-k]
    i.e. ``filt`` holds the coefficients a_1..a_p of 1 / (1 + sum_k a_k z^-k)
    and ``weighting_vector`` scales the history buffer element-wise.

    Args:
        frame: 1-D array of input samples.
        filt: 1-D array of p filter coefficients.
        buffer: 1-D array of the p previous outputs, oldest first
            (buffer[-1] is y[n-1] for the first sample of ``frame``).
        weighting_vector: per-element weights for ``buffer``; defaults to all
            ones (unweighted filter).

    Returns:
        Array shaped like ``frame`` holding the filtered samples.

    The caller's ``buffer`` is not modified; callers carry filter state
    across frames by passing the last p outputs as the next buffer.
    """
    history = np.array(buffer)  # private copy: never write into the caller's array
    if weighting_vector is None:
        weighting_vector = np.ones(history.shape[0])
    out = np.zeros_like(frame)
    coeffs = np.flip(filt)  # coeffs[j] pairs with history[j]; coeffs[-1] is a_1

    for i in range(frame.shape[0]):
        s = frame[i] - np.dot(history*weighting_vector, coeffs)
        # Overwrite the oldest sample, then rotate it to the end as the newest.
        history[0] = s
        history = np.roll(history, -1)
        out[i] = s

    return out


def _inverse_weighting_framewise(pw_signal, filters, frame_size, weighting_vector=None):
    """Filter ``pw_signal`` frame by frame, using ``filters[i]`` for frame ``i``.

    The last ``filters.shape[1]`` outputs of each frame seed the filter
    history of the next frame (assumes frame_size >= filter order). Samples
    past the last complete frame are left at zero in the output.
    """
    order = filters.shape[1]
    signal = np.zeros_like(pw_signal)
    buffer = np.zeros(order)
    num_frames = pw_signal.shape[0] // frame_size
    assert num_frames == filters.shape[0], "expected one filter per complete frame"
    for frame_idx in range(num_frames):
        start, stop = frame_idx*frame_size, (frame_idx+1)*frame_size
        out_sig_frame = lpc_synthesis_one_frame(pw_signal[start:stop], filters[frame_idx, :], buffer, weighting_vector)
        signal[start:stop] = out_sig_frame
        buffer[:] = out_sig_frame[-order:]
    return signal


def inverse_perceptual_weighting(pw_signal, filters, weighting_vector):
    """Apply per-frame weighted LPC synthesis over 160-sample frames.

    Author's note: intended as the inverse perceptual weighting
    H_preemph / W(z/gamma).

    Args:
        pw_signal: 1-D signal; its length // 160 must equal filters.shape[0].
        filters: 2-D array, one row of LPC coefficients per 160-sample frame.
        weighting_vector: element-wise weights for the filter history.

    Returns:
        Filtered signal, same shape and dtype as ``pw_signal``.
    """
    return _inverse_weighting_framewise(pw_signal, filters, 160, weighting_vector)


def inverse_perceptual_weighting40(pw_signal, filters):
    """Apply per-frame unweighted LPC synthesis over 40-sample frames.

    Author's note: intended as the inverse perceptual weighting
    H_preemph / W(z/gamma).

    Args:
        pw_signal: 1-D signal; its length // 40 must equal filters.shape[0].
        filters: 2-D array, one row of LPC coefficients per 40-sample frame.

    Returns:
        Filtered signal, same shape and dtype as ``pw_signal``.
    """
    return _inverse_weighting_framewise(pw_signal, filters, 40)


if __name__ == '__main__':
    model.to(device)
    model.eval()  # inference: disable training-only behaviour (e.g. dropout), if any
    features = torch.tensor(features).to(device)
    periods = torch.tensor(periods).to(device)

    # LPCs scaled as a_k * gamma**k, then passed to fargan.interp_lpc(..., 4)
    # (presumably 4 subframes per frame, matching inverse_perceptual_weighting40;
    # TODO confirm). Only consumed by the disabled inverse-weighting path below.
    weighting = gamma**np.arange(1, 17)
    lpc = lpc*weighting
    lpc = fargan.interp_lpc(torch.tensor(lpc), 4).numpy()

    with torch.no_grad():
        sig, _ = model(features, periods, nb_frames - 4)
    #weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
    # .cpu() is required before .numpy() when the model ran on a CUDA device.
    sig = sig.detach().cpu().numpy().flatten()
    # First-order IIR filter 1 / (1 - 0.85 z^-1).
    sig = lfilter(np.array([1.]), np.array([1., -.85]), sig)
    #sig = inverse_perceptual_weighting40(sig, lpc[0,:,:])

    pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
    pcm.tofile(signal_file)