# external/libopus/dnn/torch/neural-pitch/evaluation.py
"""
Evaluation script to compute the Raw Pitch Accuracy (RPA).
Procedure:
    - Look at all voiced frames in each file
    - Count the pitch estimates in those frames that lie within a 50 cent threshold of the reference
    RPA = (total number of estimates within threshold, summed across all files)/(total number of voiced frames, summed across all files)
"""
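
# Worked example (illustrative): at a 50 cent threshold, a 200 Hz reference
# against a 205 Hz estimate differs by 1200*log2(205/200) ~= 42.7 cents (a
# hit), while a 210 Hz estimate differs by ~84.5 cents (a miss).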

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from prettytable import PrettyTable
import numpy as np
import glob
import random
import tqdm
import torch
import librosa
import json
from utils import stft, random_filter, feature_xform
import subprocess
import crepe

from models import PitchDNN, PitchDNNIF, PitchDNNXcorr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

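# Count the voiced frames whose estimate lies within `thresh` cents of the
# reference; this is the numerator of the RPA ratio computed below.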
def rca(reference, estimate, voicing, thresh=25):
    idx_voiced = np.where(voicing != 0)[0]
    acc = np.where(np.abs(reference - estimate)[idx_voiced] < thresh)[0]
    return acc.shape[0]

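# Robustness variant: slide the estimate by up to +/-10 frames and keep the
# best count, compensating for a possible alignment offset between the
# reference and estimated tracks.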
def sweep_rca(reference, estimate, voicing, thresh=25, ind_arr=np.arange(-10, 10)):
    scores = []
    for i in ind_arr:
        scores.append(rca(reference, np.roll(estimate, i), voicing, thresh))

    return np.max(np.array(scores))

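# Compute clean-speech RPA for a trained model over a random subset of the
# speech dataset (site-specific paths below; adjust to the local layout).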
def rpa(model, device='cpu', data_format='if'):
    list_files = glob.glob('/home/ubuntu/Code/Datasets/SPEECH DATA/combined_mic_16k_raw/*.raw')
    dir_f0 = '/home/ubuntu/Code/Datasets/SPEECH DATA/combine_f0_ptdb/'
    # random_shuffle = list(np.random.permutation(len(list_files)))
    random.shuffle(list_files)
    list_files = list_files[:1000]

    C_all = 0
    C_all_m = 0
    C_all_f = 0
    list_rca_model_all = []
    list_rca_male_all = []
    list_rca_female_all = []

    thresh = 50
    N = 320
    H = 160
    freq_keep = 30

    for idx in tqdm.trange(len(list_files)):
        audio_file = list_files[idx]
        file_name = os.path.basename(list_files[idx])[:-4]

        audio = np.memmap(list_files[idx], dtype=np.int16)/(2**15 - 1)
        offset = 432
        audio = audio[offset:]
        rmse = np.squeeze(librosa.feature.rms(y=audio, frame_length=320, hop_length=160))

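        # Instantaneous-frequency (IF) features: log magnitude plus the real
        # and imaginary parts of the normalized frame-to-frame phase
        # difference, keeping only the lowest `freq_keep` bins of each part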
        spec = stft(x=np.concatenate([np.zeros(160), audio]), w='boxcar', N=N, H=H).T
        phase_diff = spec*np.conj(np.roll(spec, 1, axis=-1))
        phase_diff = phase_diff/(np.abs(phase_diff) + 1.0e-8)
        idx_save = np.concatenate([np.arange(freq_keep), (N//2 + 1) + np.arange(freq_keep), 2*(N//2 + 1) + np.arange(freq_keep)])
        feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8), np.real(phase_diff), np.imag(phase_diff)], axis=0).T
        feature_if = feature[:, idx_save]

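        # Cross-correlation features: write the peak-normalized audio to a
        # temporary 16-bit file, run the LPCNet feature extractor on it, then
        # reverse the lag axis and prepend a unit zero-lag column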
        data_temp = np.memmap('./temp.raw', dtype=np.int16, shape=(audio.shape[0]), mode='w+')
        data_temp[:audio.shape[0]] = (audio/(np.max(np.abs(audio)))*(2**15 - 1)).astype(np.int16)

        subprocess.run(["../../../lpcnet_xcorr_extractor", './temp.raw', './temp_xcorr.f32'])
        feature_xcorr = np.flip(np.fromfile('./temp_xcorr.f32', dtype='float32').reshape((-1, 256), order='C'), axis=1)
        ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]), -1)
        feature_xcorr = np.concatenate([ones_zero_lag, feature_xcorr], axis=-1)
        # feature_xcorr = feature_xform(feature_xcorr)

        os.remove('./temp.raw')
        os.remove('./temp_xcorr.f32')

        if data_format == 'if':
            feature = feature_if
        elif data_format == 'xcorr':
            feature = feature_xcorr
        else:
            indmin = min(feature_if.shape[0], feature_xcorr.shape[0])
            feature = np.concatenate([feature_xcorr[:indmin, :], feature_if[:indmin, :]], -1)

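        # Load the reference pitch and voicing flags, truncate everything to a
        # common length, gate voicing by frame energy, and (for female-speaker
        # files) zero out frames whose reference pitch falls below 125 Hz,
        # which are presumably reference tracking errors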
        pitch_file_name = dir_f0 + "ref" + os.path.basename(list_files[idx])[3:-4] + ".f0"
        pitch_data = np.loadtxt(pitch_file_name)
        pitch = pitch_data[:, 0]
        voicing = pitch_data[:, 1]
        indmin = min(voicing.shape[0], rmse.shape[0], pitch.shape[0])
        pitch = pitch[:indmin]
        voicing = voicing[:indmin]
        rmse = rmse[:indmin]
        voicing = voicing*(rmse > 0.05*np.max(rmse))
        if "mic_F" in audio_file:
            idx_correct = np.where(pitch < 125)
            voicing[idx_correct] = 0

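        # Convert the reference pitch to integer cents relative to 62.5 Hz
        # (16000/256); the model emits 20 cent bins, so 20*argmax is comparable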
        cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch != 0) + 1.0e-8)).astype('int')

        model_cents = model(torch.from_numpy(np.copy(np.expand_dims(feature, 0))).float().to(device))
        model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()

        num_frames = min(cent.shape[0], model_cents.shape[0])
        pitch = pitch[:num_frames]
        cent = cent[:num_frames]
        voicing = voicing[:num_frames]
        model_cents = model_cents[:num_frames]

        voicing_all = np.copy(voicing)
        # Force frames with reference pitch below 65 Hz or above 500 Hz to
        # unvoiced, so the comparison stays within the model's pitch range
        force_out_of_pitch = np.where(np.logical_or(pitch < 65, pitch > 500))
        voicing_all[force_out_of_pitch] = 0
        C_all = C_all + np.where(voicing_all != 0)[0].shape[0]

        list_rca_model_all.append(rca(cent, model_cents, voicing_all, thresh))

        if "mic_M" in audio_file:
            list_rca_male_all.append(rca(cent, model_cents, voicing_all, thresh))
            C_all_m = C_all_m + np.where(voicing_all != 0)[0].shape[0]
        else:
            list_rca_female_all.append(rca(cent, model_cents, voicing_all, thresh))
            C_all_f = C_all_f + np.where(voicing_all != 0)[0].shape[0]

    list_rca_model_all = np.array(list_rca_model_all)
    list_rca_male_all = np.array(list_rca_male_all)
    list_rca_female_all = np.array(list_rca_female_all)

    x = PrettyTable()
    x.field_names = ["Experiment", "Mean RPA"]
    x.add_row(["Both all pitches", np.sum(list_rca_model_all)/C_all])
    x.add_row(["Male all pitches", np.sum(list_rca_male_all)/C_all_m])
    x.add_row(["Female all pitches", np.sum(list_rca_female_all)/C_all_f])

    print(x)

    return None

def cycle_eval(checkpoint_list, noise_type='synthetic', noise_dataset=None, list_snr=[-20, -15, -10, -5, 0, 5, 10, 15, 20], ptdb_dataset_path=None, fraction=0.1, thresh=50):
    """
    Cycle through SNR evaluation for a list of checkpoints
    """
    list_files = glob.glob(ptdb_dataset_path + 'combined_mic_16k/*.raw')
    dir_f0 = ptdb_dataset_path + 'combined_reference_f0/'
    random.shuffle(list_files)
    list_files = list_files[:int(fraction*len(list_files))]

    dict_models = {}
    # Copy before appending so repeated calls do not keep growing the mutable
    # default argument
    list_snr = list(list_snr) + [np.inf]

    for f in checkpoint_list:
        fname = f
        if (f != 'crepe') and (f != 'lpcnet'):

            checkpoint = torch.load(f, map_location='cpu')
            dict_params = checkpoint['config']
            if dict_params['data_format'] == 'if':
                pitch_nn = PitchDNNIF(dict_params['freq_keep']*3, dict_params['gru_dim'], dict_params['output_dim'])
            elif dict_params['data_format'] == 'xcorr':
                pitch_nn = PitchDNNXcorr(dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])
            else:
                pitch_nn = PitchDNN(dict_params['freq_keep']*3, dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])

            pitch_nn.load_state_dict(checkpoint['state_dict'])
            # The input tensors below are moved to `device`, so the model must
            # live on the same device
            pitch_nn = pitch_nn.to(device).eval()

            N = dict_params['window_size']
            H = dict_params['hop_factor']
            freq_keep = dict_params['freq_keep']

            list_mean = []
            for snr_dB in list_snr:
                C_all = 0
                C_correct = 0
                for idx in tqdm.trange(len(list_files)):
                    audio_file = list_files[idx]
                    file_name = os.path.basename(list_files[idx])[:-4]

                    audio = np.memmap(list_files[idx], dtype=np.int16)/(2**15 - 1)
                    offset = 432
                    audio = audio[offset:]
                    rmse = np.squeeze(librosa.feature.rms(y=audio, frame_length=N, hop_length=H))

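                    # Mix in either a random segment of a real noise recording
                    # or synthetic (randomly filtered white) noise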
                    if noise_type != 'synthetic':
                        list_noisefiles = noise_dataset + '*.wav'
                        noise_file = random.choice(glob.glob(list_noisefiles))
                        n = np.memmap(noise_file, dtype=np.int16, mode='r')/(2**15 - 1)
                        rand_range = np.random.randint(low=0, high=(16000*60*5 - audio.shape[0])) # Noise segment drawn from within the first 5 minutes of the file
                        n = n[rand_range:rand_range + audio.shape[0]]
                    else:
                        n = np.random.randn(audio.shape[0])
                        n = random_filter(n)

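                    # Scale the noise so the mixture hits the requested SNR:
                    # multiplier = sqrt((E_signal/E_noise) * 10^(-SNR_dB/10))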
                    snr_multiplier = np.sqrt((np.sum(np.abs(audio)**2)/np.sum(np.abs(n)**2))*10**(-snr_dB/10))
                    audio = audio + snr_multiplier*n

                    spec = stft(x=np.concatenate([np.zeros(160), audio]), w='boxcar', N=N, H=H).T
                    phase_diff = spec*np.conj(np.roll(spec, 1, axis=-1))
                    phase_diff = phase_diff/(np.abs(phase_diff) + 1.0e-8)
                    idx_save = np.concatenate([np.arange(freq_keep), (N//2 + 1) + np.arange(freq_keep), 2*(N//2 + 1) + np.arange(freq_keep)])
                    feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8), np.real(phase_diff), np.imag(phase_diff)], axis=0).T
                    feature_if = feature[:, idx_save]

                    data_temp = np.memmap('./temp.raw', dtype=np.int16, shape=(audio.shape[0]), mode='w+')
                    # Unlike in rpa(), the noisy mixture is written without peak normalization
                    data_temp[:audio.shape[0]] = (audio*(2**15 - 1)).astype(np.int16)

                    subprocess.run(["../../../lpcnet_xcorr_extractor", './temp.raw', './temp_xcorr.f32'])
                    feature_xcorr = np.flip(np.fromfile('./temp_xcorr.f32', dtype='float32').reshape((-1, 256), order='C'), axis=1)
                    ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]), -1)
                    feature_xcorr = np.concatenate([ones_zero_lag, feature_xcorr], axis=-1)

                    os.remove('./temp.raw')
                    os.remove('./temp_xcorr.f32')

                    if dict_params['data_format'] == 'if':
                        feature = feature_if
                    elif dict_params['data_format'] == 'xcorr':
                        feature = feature_xcorr
                    else:
                        indmin = min(feature_if.shape[0], feature_xcorr.shape[0])
                        feature = np.concatenate([feature_xcorr[:indmin, :], feature_if[:indmin, :]], -1)

                    pitch_file_name = dir_f0 + "ref" + os.path.basename(list_files[idx])[3:-4] + ".f0"
                    pitch_data = np.loadtxt(pitch_file_name)
                    pitch = pitch_data[:, 0]
                    voicing = pitch_data[:, 1]
                    indmin = min(voicing.shape[0], rmse.shape[0], pitch.shape[0])
                    pitch = pitch[:indmin]
                    voicing = voicing[:indmin]
                    rmse = rmse[:indmin]
                    voicing = voicing*(rmse > 0.05*np.max(rmse))
                    if "mic_F" in audio_file:
                        idx_correct = np.where(pitch < 125)
                        voicing[idx_correct] = 0

                    cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch != 0) + 1.0e-8)).astype('int')

                    model_cents = pitch_nn(torch.from_numpy(np.copy(np.expand_dims(feature, 0))).float().to(device))
                    model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()

                    num_frames = min(cent.shape[0], model_cents.shape[0])
                    pitch = pitch[:num_frames]
                    cent = cent[:num_frames]
                    voicing = voicing[:num_frames]
                    model_cents = model_cents[:num_frames]

                    voicing_all = np.copy(voicing)
                    # Force frames with reference pitch below 65 Hz or above 500 Hz to
                    # unvoiced, so the comparison stays within the model's pitch range
                    force_out_of_pitch = np.where(np.logical_or(pitch < 65, pitch > 500))
                    voicing_all[force_out_of_pitch] = 0
                    C_all = C_all + np.where(voicing_all != 0)[0].shape[0]

                    C_correct = C_correct + rca(cent, model_cents, voicing_all, thresh)
                list_mean.append(C_correct/C_all)
        else:
            list_mean = []
            for snr_dB in list_snr:
                C_all = 0
                C_correct = 0
                for idx in tqdm.trange(len(list_files)):
                    audio_file = list_files[idx]
                    file_name = os.path.basename(list_files[idx])[:-4]

                    audio = np.memmap(list_files[idx], dtype=np.int16)/(2**15 - 1)
                    offset = 432
                    audio = audio[offset:]
                    rmse = np.squeeze(librosa.feature.rms(y=audio, frame_length=320, hop_length=160))

                    if noise_type != 'synthetic':
                        list_noisefiles = noise_dataset + '*.wav'
                        noise_file = random.choice(glob.glob(list_noisefiles))
                        n = np.memmap(noise_file, dtype=np.int16, mode='r')/(2**15 - 1)
                        rand_range = np.random.randint(low=0, high=(16000*60*5 - audio.shape[0])) # Noise segment drawn from within the first 5 minutes of the file
                        n = n[rand_range:rand_range + audio.shape[0]]
                    else:
                        n = np.random.randn(audio.shape[0])
                        n = random_filter(n)

                    snr_multiplier = np.sqrt((np.sum(np.abs(audio)**2)/np.sum(np.abs(n)**2))*10**(-snr_dB/10))
                    audio = audio + snr_multiplier*n

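                    # Baseline estimates: CREPE (with Viterbi decoding) predicts
                    # frequency in Hz, while the LPCNet extractor emits a pitch
                    # period in samples; both are converted to cents re 62.5 Hz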
                    if f == 'crepe':
                        _, model_frequency, _, _ = crepe.predict(np.concatenate([np.zeros(80), audio]), 16000, viterbi=True, center=True, verbose=0)
                        model_cents = 1200*np.log2(model_frequency/(16000/256) + 1.0e-8)
                    else:
                        data_temp = np.memmap('./temp.raw', dtype=np.int16, shape=(audio.shape[0]), mode='w+')
                        # As above, the noisy mixture is written without peak normalization
                        data_temp[:audio.shape[0]] = (audio*(2**15 - 1)).astype(np.int16)

                        subprocess.run(["../../../lpcnet_xcorr_extractor", './temp.raw', './temp_xcorr.f32', './temp_period.f32'])
                        feature_xcorr = np.fromfile('./temp_period.f32', dtype='float32')
                        # Guard the division against zero periods on unvoiced frames
                        model_cents = 1200*np.log2(256/(feature_xcorr + 1.0e-8) + 1.0e-8)

                        os.remove('./temp.raw')
                        os.remove('./temp_xcorr.f32')
                        os.remove('./temp_period.f32')

                    pitch_file_name = dir_f0 + "ref" + os.path.basename(list_files[idx])[3:-4] + ".f0"
                    pitch_data = np.loadtxt(pitch_file_name)
                    pitch = pitch_data[:, 0]
                    voicing = pitch_data[:, 1]
                    indmin = min(voicing.shape[0], rmse.shape[0], pitch.shape[0])
                    pitch = pitch[:indmin]
                    voicing = voicing[:indmin]
                    rmse = rmse[:indmin]
                    voicing = voicing*(rmse > 0.05*np.max(rmse))
                    if "mic_F" in audio_file:
                        idx_correct = np.where(pitch < 125)
                        voicing[idx_correct] = 0

                    cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch != 0) + 1.0e-8)).astype('int')
                    num_frames = min(cent.shape[0], model_cents.shape[0])
                    pitch = pitch[:num_frames]
                    cent = cent[:num_frames]
                    voicing = voicing[:num_frames]
                    model_cents = model_cents[:num_frames]

                    voicing_all = np.copy(voicing)
                    # Force frames with reference pitch below 65 Hz or above 500 Hz to
                    # unvoiced, so the comparison stays within the model's pitch range
                    force_out_of_pitch = np.where(np.logical_or(pitch < 65, pitch > 500))
                    voicing_all[force_out_of_pitch] = 0
                    C_all = C_all + np.where(voicing_all != 0)[0].shape[0]

                    C_correct = C_correct + rca(cent, model_cents, voicing_all, thresh)
                list_mean.append(C_correct/C_all)
        dict_models[fname] = {}
        dict_models[fname]['list_SNR'] = list_mean[:-1]
        dict_models[fname]['inf'] = list_mean[-1]

    return dict_models
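
if __name__ == "__main__":
    # Hypothetical usage sketch -- the checkpoint name and dataset path are
    # placeholders, not part of the original script: sweep one trained model
    # and the LPCNet baseline across SNRs, then dump the RPA curves to JSON.
    results = cycle_eval(
        ["checkpoints/pitch_if.pth", "lpcnet"],
        noise_type="synthetic",
        list_snr=[-10, 0, 10],
        ptdb_dataset_path="/path/to/ptdb/",
        fraction=0.1,
        thresh=50,
    )
    with open("rpa_vs_snr.json", "w") as fp:
        json.dump(results, fp, indent=2)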