import os
import time
import torch
import numpy as np
from scipy import signal as si
from scipy.io import wavfile
import argparse

from models import model_dict

parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file')
parser.add_argument('input', type=str, help='input: feature file or folder with feature files')
parser.add_argument('output', type=str, help='output: wav file name or folder name, depending on input')


########################### Signal Processing Layers ###########################

def preemphasis(x, coef=-0.85):
    """Apply the pre-emphasis FIR filter 1 + coef*z^-1 to *x*.

    Returns a float32 array of the same length as *x*.
    """
    return si.lfilter(np.array([1.0, coef]), np.array([1.0]), x).astype('float32')


def deemphasis(x, coef=-0.85):
    """Apply the de-emphasis IIR filter 1 / (1 + coef*z^-1), inverse of preemphasis().

    Returns a float32 array of the same length as *x*.
    """
    return si.lfilter(np.array([1.0]), np.array([1.0, coef]), x).astype('float32')


# Bandwidth-expansion weights gamma^16 .. gamma^1 applied to the 16-tap LPC
# synthesis history (ordered oldest sample first, matching the buffer layout).
gamma = 0.92
weighting_vector = np.array([gamma**i for i in range(16, 0, -1)])


def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
    """Run one frame of all-pole (LPC) synthesis filtering.

    frame            -- excitation samples for this frame (1-D array)
    filt             -- 16 LPC coefficients, newest-lag first (flipped below so
                        they line up with the oldest-first history buffer)
    buffer           -- 16-sample synthesis history carried across frames;
                        the caller refreshes it from the output after each frame
    weighting_vector -- per-tap weights (gamma^k) for bandwidth expansion

    Returns the synthesized frame (same shape as *frame*).
    """
    out = np.zeros_like(frame)

    # Reverse coefficient order so filt[k] multiplies the history sample with
    # the matching lag in the oldest-first buffer.
    filt = np.flip(filt)

    inp = frame[:]

    for i in range(0, inp.shape[0]):
        # Direct-form recursion: s[n] = e[n] - sum_k a_k * gamma^k * s[n-k]
        s = inp[i] - np.dot(buffer * weighting_vector, filt)

        # Insert the new sample and shift the history left by one position.
        # NOTE(review): writing at index 0 and then rolling by -1 leaves the
        # newest sample at buffer[-1]; np.roll also rebinds *buffer* to a new
        # array, so the caller's array is not updated here — the caller resets
        # it from the output frame instead. Kept exactly as in the original.
        buffer[0] = s
        buffer = np.roll(buffer, -1)

        out[i] = s

    return out


def inverse_perceptual_weighting(pw_signal, filters, weighting_vector):
    """Undo the perceptual weighting of a generated signal.

    Inverse perceptual weighting = H_preemph / W(z/gamma): pre-emphasize the
    signal, then run frame-by-frame LPC synthesis with bandwidth-expanded
    coefficients (one 16-tap filter per 160-sample frame).

    pw_signal        -- perceptually weighted signal, length = num_frames * 160
    filters          -- (num_frames, 16) LPC coefficients, one row per frame
    weighting_vector -- gamma^k tap weights shared by all frames
    """
    pw_signal = preemphasis(pw_signal)

    signal = np.zeros_like(pw_signal)
    buffer = np.zeros(16)
    num_frames = pw_signal.shape[0] // 160
    assert num_frames == filters.shape[0]

    for frame_idx in range(0, num_frames):
        in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
        out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
        signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
        # Carry the last 16 synthesized samples into the next frame's history.
        buffer[:] = out_sig_frame[-16:]

    return signal


def process_item(generator, feature_filename, output_filename, verbose=False):
    """Synthesize one 16 kHz wav file from an LPCNet-style feature file.

    The raw file is read as float32 and reshaped to (frames, 36); from each
    row the code uses columns 0-17 (BFCC), 18 (pitch period code), 19 (pitch
    correlation) and the last 16 (LPC coefficients).
    """
    feat = np.memmap(feature_filename, dtype='float32', mode='r')

    num_feat_frames = len(feat) // 36
    feat = np.reshape(feat, (num_feat_frames, 36))

    bfcc = np.copy(feat[:, :18])
    corr = np.copy(feat[:, 19:20]) + 0.5
    bfcc_with_corr = torch.from_numpy(np.hstack((bfcc, corr))).type(torch.FloatTensor).unsqueeze(0)#.to(device)

    # Decode the pitch period from its normalized feature representation.
    period = torch.from_numpy((0.1 + 50 * np.copy(feat[:, 18:19]) + 100)\
        .astype('int32')).type(torch.long).view(1, -1)#.to(device)

    lpc_filters = np.copy(feat[:, -16:])

    start_time = time.time()
    x1 = generator(period, bfcc_with_corr, torch.zeros(1, 320)) #this means the vocoder runs in complete synthesis mode with zero history audio frames
    end_time = time.time()
    total_time = end_time - start_time
    x1 = x1.squeeze(1).squeeze(0).detach().cpu().numpy()
    gen_seconds = len(x1)/16000
    out = deemphasis(inverse_perceptual_weighting(x1, lpc_filters, weighting_vector))
    if verbose:
        print(f"Took {total_time:.3f}s to generate {len(x1)} samples ({gen_seconds}s) -> {gen_seconds/total_time:.2f}x real time")

    # Scale to 16-bit PCM with saturation and write a 16 kHz mono wav.
    out = np.clip(np.round(2**15 * out), -2**15, 2**15 - 1).astype(np.int16)
    wavfile.write(output_filename, 16000, out)


########################### The inference loop over folder containing lpcnet feature files #################################
if __name__ == "__main__":

    args = parser.parse_args()

    generator = model_dict[args.model]()

    #Load the FWGAN500Hz Checkpoint
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    generator.load_state_dict(saved_gen)

    #this is just to remove the weight_norm from the model layers as it's no longer needed
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    generator.apply(_remove_weight_norm)

    #enable inference mode
    generator = generator.eval()

    print('Successfully loaded the generator model ... start generation:')

    if os.path.isdir(args.input):
        # Folder mode: synthesize one wav per feature file into args.output.
        os.makedirs(args.output, exist_ok=True)

        for fn in os.listdir(args.input):
            print(f"processing input {fn}...")
            feature_filename = os.path.join(args.input, fn)
            output_filename = os.path.join(args.output, os.path.splitext(fn)[0] + f"_{args.model}.wav")
            process_item(generator, feature_filename, output_filename)
    else:
        # Single-file mode: args.input is a feature file, args.output a wav path.
        process_item(generator, args.input, args.output)

    print("Finished!")