"""Replace the pitch and pitch-correlation slots of dump_data features with
values derived from a neural pitch estimator (PitchDNN / PitchDNNIF /
PitchDNNXcorr)."""
import argparse


def _str2bool(value):
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', 'on'/'off').

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True, so
    every non-empty value would switch the flag on.
    """
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='Features generated from dump_data')
parser.add_argument('data', type=str, help='Data generated from dump_data (offset by 5ms)')
parser.add_argument('output', type=str, help='output .f32 feature file with replaced neural pitch')
parser.add_argument('checkpoint', type=str, help='model checkpoint file')
parser.add_argument('path_lpcnet_extractor', type=str, help='path to LPCNet extractor object file (generated on compilation)')
parser.add_argument('--device', type=str, help='compute device', default=None, required=False)
# NOTE: this flag is parsed but not read anywhere in this script.
parser.add_argument('--replace_xcorr', type=_str2bool, default=False, help='Replace LPCNet xcorr with updated one')

# Parsed before the heavy imports below so `--help` and argument errors are fast.
args = parser.parse_args()

import os

# OMP_NUM_THREADS is read when the OpenMP/BLAS runtimes initialise, which can
# happen as soon as numpy/torch (or utils, which may import them) are imported,
# so it has to be set first. The extractor subprocess inherits it as well.
os.environ["OMP_NUM_THREADS"] = "16"

from utils import stft, random_filter
import subprocess
import tempfile
import numpy as np
import json
import torch
import tqdm

from models import PitchDNNIF, PitchDNNXcorr, PitchDNN

# Default to CUDA when available; an explicit --device always takes precedence.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.device is not None:
    device = torch.device(args.device)

# Loading the appropriate model.
# NOTE(review): torch.load unpickles the checkpoint - only load trusted files.
checkpoint = torch.load(args.checkpoint, map_location='cpu')
dict_params = checkpoint['config']
data_format = dict_params['data_format']

if data_format == 'if':
    pitch_nn = PitchDNNIF(dict_params['freq_keep'] * 3, dict_params['gru_dim'], dict_params['output_dim'])
elif data_format == 'xcorr':
    pitch_nn = PitchDNNXcorr(dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])
else:
    pitch_nn = PitchDNN(dict_params['freq_keep'] * 3, dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])

pitch_nn.load_state_dict(checkpoint['state_dict'])
pitch_nn = pitch_nn.to(device)
# Inference only: put modules such as dropout/batch-norm (if the model has any)
# into evaluation behaviour.
pitch_nn.eval()

N = dict_params['window_size']
H = dict_params['hop_factor']
freq_keep = dict_params['freq_keep']


def run_lpc(signal, lpcs, frame_length=160):
    """Run frame-wise LPC prediction over ``signal``.

    Args:
        signal: 1-D sample array. For the subtraction below to line up it must
            hold exactly ``num_frames * frame_length + lpc_order`` samples.
        lpcs: (num_frames, lpc_order) array; row ``i`` is the filter applied to
            frame ``i``.
        frame_length: samples per frame.

    Returns:
        ``(prediction, error)`` where ``prediction`` has
        ``num_frames * frame_length`` samples and
        ``error = signal[lpc_order:] - prediction``.
    """
    num_frames, lpc_order = lpcs.shape

    # A 'valid' convolution of a (frame_length + lpc_order - 1)-sample window
    # with lpc_order taps yields exactly frame_length predicted samples.
    prediction = np.concatenate(
        [-np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid')
         for i in range(num_frames)]
    )
    error = signal[lpc_order:] - prediction

    return prediction, error


def _if_features(chunk_la):
    """STFT features: log-magnitude plus the real and imaginary parts of the
    unit-normalised phase difference, keeping ``freq_keep`` bins of each part.

    Returns a (frames, 3 * freq_keep) array.
    """
    # 80 leading zeros, then the int16 samples scaled by 1 / (2**15 - 1).
    spec = stft(x=np.concatenate([np.zeros(80), chunk_la / (2**15 - 1)]), w='boxcar', N=N, H=H).T
    # Assuming stft returns (frames, bins), spec is (bins, frames) here, so the
    # roll compares each frame with its predecessor.
    phase_diff = spec * np.conj(np.roll(spec, 1, axis=-1))
    phase_diff = phase_diff / (np.abs(phase_diff) + 1.0e-8)
    feature = np.concatenate(
        [np.log(np.abs(spec) + 1.0e-8), np.real(phase_diff), np.imag(phase_diff)], axis=0
    ).T
    # Offsets assume N // 2 + 1 bins per part; keep the first freq_keep of each.
    idx_save = np.concatenate([
        np.arange(freq_keep),
        (N // 2 + 1) + np.arange(freq_keep),
        2 * (N // 2 + 1) + np.arange(freq_keep),
    ])
    return feature[:, idx_save]


def _xcorr_features(chunk_la, tmp_dir):
    """Cross-correlation features produced by the external LPCNet extractor.

    The extractor receives ``chunk_la[80:]`` as raw int16 samples and its
    output is read back as float32 rows of 256 values. Rows are flipped along
    the second axis and a leading column of ones (zero-lag entry) is prepended,
    giving a (frames, 257) array.
    """
    raw_path = os.path.join(tmp_dir, 'temp_featcompute_' + data_format + '_.raw')
    xcorr_path = os.path.join(tmp_dir, 'temp_featcompute_xcorr_' + data_format + '_.raw')

    chunk_la[80:].astype(np.int16).tofile(raw_path)
    # check=True: a failing extractor must not be mistaken for valid output.
    subprocess.run([args.path_lpcnet_extractor, raw_path, xcorr_path], check=True)
    feature_xcorr = np.flip(np.fromfile(xcorr_path, dtype='float32').reshape((-1, 256), order='C'), axis=1)
    # Remove per chunk so a later chunk can never read this chunk's output.
    os.remove(raw_path)
    os.remove(xcorr_path)

    ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]), -1)
    return np.concatenate([ones_zero_lag, feature_xcorr], axis=-1)


def _model_input(chunk_la, tmp_dir):
    """Build the network input for one chunk according to ``data_format``.

    'if' -> STFT features, 'xcorr' -> extractor features, anything else ->
    both, concatenated xcorr-first and truncated to the shorter frame count.
    Only the features the model actually uses are computed.
    """
    if data_format == 'if':
        return _if_features(chunk_la)
    if data_format == 'xcorr':
        return _xcorr_features(chunk_la, tmp_dir)
    feature_if = _if_features(chunk_la)
    feature_xcorr = _xcorr_features(chunk_la, tmp_dir)
    indmin = min(feature_if.shape[0], feature_xcorr.shape[0])
    return np.concatenate([feature_xcorr[:indmin, :], feature_if[:indmin, :]], -1)


def _normalized_xcorr(error, periods, offset, frame_length, count):
    """Normalised correlation of each frame of ``error`` with itself delayed by
    that frame's period.

    Frame ``k`` starts at ``error[offset + k * frame_length]``. Returns an array
    of ``count`` values; frames whose period is not positive stay 0.
    """
    xcorr = np.zeros(count)
    # Python ints: with NumPy 2 promotion rules an int16 period would force the
    # index arithmetic into int16, which overflows for frames past ~200.
    for k, period in enumerate(periods.astype(np.int64).tolist()):
        if period > 0:
            start = offset + k * frame_length
            f1 = error[start : start + frame_length]
            f2 = error[start - period : start + frame_length - period]
            xcorr[k] = np.dot(f1, f2) / np.sqrt(np.dot(f1, f1) * np.dot(f2, f2) + 1e-6)
    return xcorr


if __name__ == "__main__":
    features = np.memmap(args.features, dtype=np.float32, mode='r').reshape((-1, 36))
    data = np.memmap(args.data, dtype=np.int16, mode='r').reshape((-1, 2))

    num_feature_frames, feature_dim = features.shape

    # Start from a copy of the input features; only the pitch and xcorr slots
    # are overwritten below.
    output = np.memmap(args.output, dtype=np.float32, shape=(num_feature_frames, feature_dim), mode='w+')
    output[:, :36] = features

    # lpc coefficients and signal
    lpcs = features[:, 20:36]
    sig = data[:, 1]

    # constants
    pitch_min = 32
    pitch_max = 256
    lpc_order = 16
    fs = 16000
    frame_length = 160
    overlap_frames = 100
    chunk_size = 10000
    history_length = frame_length * overlap_frames
    pitch_position = 18
    xcorr_position = 19

    # The -1 keeps the 80-sample look-ahead read (signal_stop + 80) inside the
    # signal; min() keeps every write inside the output array and LPC rows.
    num_frames = min(len(sig) // frame_length - 1, num_feature_frames)

    # Ceiling division: the final (possibly partial) chunk must be processed too.
    niters = (num_frames + chunk_size - 1) // chunk_size

    # ceil(pitch_max / frame_length) extra frames of residual so the
    # period-delayed window (up to pitch_max samples back) stays in range.
    frame_offset = (pitch_max + frame_length - 1) // frame_length
    offset = frame_offset * frame_length
    padding = lpc_order

    history = np.zeros(history_length, dtype=np.int16)
    frame_start = 0
    frame_stop = min(frame_start + chunk_size, num_frames)
    signal_start = 0
    signal_stop = frame_stop * frame_length

    # Scratch files for the extractor live in a private directory so concurrent
    # runs cannot clobber each other and nothing is left behind on error.
    with tempfile.TemporaryDirectory(prefix='featcompute_') as tmp_dir:
        for _ in tqdm.trange(niters):
            chunk = np.concatenate((history, sig[signal_start:signal_stop]))
            # Same chunk plus 80 samples of look-ahead.
            chunk_la = np.concatenate((history, sig[signal_start:signal_stop + 80]))

            feature = _model_input(chunk_la, tmp_dir)

            # Compute pitch with the model. np.copy: the flipped xcorr features
            # are a negative-stride view, which torch.from_numpy rejects.
            with torch.no_grad():
                model_out = pitch_nn(torch.from_numpy(np.copy(np.expand_dims(feature, 0))).float().to(device))
            # Bin index -> cents (20 cents per bin) -> Hz relative to 62.5 Hz.
            model_cents = 20 * model_out.argmax(dim=1).cpu().squeeze().numpy()
            frequency = 62.5 * 2 ** (model_cents / 1200)

            # Drop the frames that belong to the history prefix.
            frequency = frequency[overlap_frames : overlap_frames + frame_stop - frame_start]

            # convert frequencies to periods
            periods = np.round(fs / frequency)
            periods = np.clip(periods, pitch_min, pitch_max)

            # NOTE(review): (period - 100) / 50 appears to be the normalisation
            # used for this feature slot - confirm against dump_data.
            output[frame_start:frame_stop, pitch_position] = (periods - 100) / 50

            if frame_start < frame_offset:
                # Not enough past LPC rows yet: pad the front with zero filters.
                lpc_coeffs = np.concatenate((np.zeros((frame_offset - frame_start, lpc_order), dtype=np.float32), lpcs[:frame_stop]))
            else:
                lpc_coeffs = lpcs[frame_start - frame_offset : frame_stop]

            _, error = run_lpc(chunk[history_length - offset - padding :], lpc_coeffs, frame_length=frame_length)

            xcorr = _normalized_xcorr(error, periods, offset, frame_length, frame_stop - frame_start)
            output[frame_start:frame_stop, xcorr_position] = xcorr - 0.5

            # update buffers and indices
            history = chunk[-history_length:]

            frame_start += chunk_size
            frame_stop = min(frame_stop + chunk_size, num_frames)

            signal_start = frame_start * frame_length
            signal_stop = frame_stop * frame_length

    output.flush()