# xref: /aosp_15_r20/external/libopus/dnn/torch/neural-pitch/neural_pitch_update.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
import argparse


def _str2bool(value):
    """Convert a command-line string such as 'true' or '0' to a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, so every non-empty value (including "False") becomes True.
    An empty string still maps to False, as it did with ``bool``.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# The command line is parsed before the heavy imports below so that usage
# errors and --help are reported without loading them first.
parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='Features generated from dump_data')
parser.add_argument('data', type=str, help='Data generated from dump_data (offset by 5ms)')
parser.add_argument('output', type=str, help='output .f32 feature file with replaced neural pitch')
parser.add_argument('checkpoint', type=str, help='model checkpoint file')
parser.add_argument('path_lpcnet_extractor', type=str, help='path to LPCNet extractor object file (generated on compilation)')
parser.add_argument('--device', type=str, help='compute device', default=None, required=False)
parser.add_argument('--replace_xcorr', type=_str2bool, default=False, help='Replace LPCNet xcorr with updated one')

args = parser.parse_args()
13
import json
import os
import subprocess
import tempfile

import numpy as np
import torch
import tqdm

from models import PitchDNNIF, PitchDNNXcorr, PitchDNN
from utils import stft, random_filter
24
# Select the compute device: an explicit --device wins; otherwise use CUDA
# when it is available and fall back to the CPU. (The check must be on
# args.device: torch.device(None) raises when --device is not given.)
if args.device is not None:
    device = torch.device(args.device)
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loading the appropriate model
# The checkpoint is first loaded on the CPU and moved to `device` afterwards.
checkpoint = torch.load(args.checkpoint, map_location='cpu')
dict_params = checkpoint['config']

# 'data_format' selects which input features the network consumes:
# 'if' -> PitchDNNIF, 'xcorr' -> PitchDNNXcorr, anything else -> PitchDNN
# (which takes both the freq_keep*3 and the xcorr_dim inputs).
if dict_params['data_format'] == 'if':
    pitch_nn = PitchDNNIF(dict_params['freq_keep']*3, dict_params['gru_dim'], dict_params['output_dim'])
elif dict_params['data_format'] == 'xcorr':
    pitch_nn = PitchDNNXcorr(dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])
else:
    pitch_nn = PitchDNN(dict_params['freq_keep']*3, dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])

pitch_nn.load_state_dict(checkpoint['state_dict'])
pitch_nn = pitch_nn.to(device)
# The model is only used for inference in this script.
pitch_nn.eval()

# STFT parameters stored with the checkpoint.
N = dict_params['window_size']
H = dict_params['hop_factor']
freq_keep = dict_params['freq_keep']

os.environ["OMP_NUM_THREADS"] = "16"
48
49
def run_lpc(signal, lpcs, frame_length=160):
    """Run frame-wise linear prediction over ``signal``.

    Args:
        signal: 1-D array of samples. For the final subtraction to work its
            length must be ``num_frames * frame_length + lpc_order``.
        lpcs: 2-D array of shape ``(num_frames, lpc_order)`` holding one set
            of coefficients per frame.
        frame_length: number of predicted samples per frame.

    Returns:
        Tuple ``(prediction, error)``: the prediction is the negated 'valid'
        convolution of each frame window with that frame's coefficients,
        concatenated over frames; the error is ``signal[lpc_order:]`` minus
        the prediction.
    """
    num_frames, lpc_order = lpcs.shape

    # Each frame needs lpc_order - 1 extra samples so that the 'valid'
    # convolution yields exactly frame_length outputs.
    window = frame_length + lpc_order - 1

    per_frame = []
    for frame_idx in range(num_frames):
        begin = frame_idx * frame_length
        segment = signal[begin : begin + window]
        per_frame.append(-np.convolve(segment, lpcs[frame_idx], mode='valid'))

    prediction = np.concatenate(per_frame)
    error = signal[lpc_order :] - prediction

    return prediction, error
59
60
if __name__ == "__main__":
    args = parser.parse_args()

    # Inputs: 36 float32 values per feature frame, and int16 data with two
    # values per sample (column 1 is used as the signal below).
    features = np.memmap(args.features, dtype=np.float32, mode='r').reshape((-1, 36))
    data = np.memmap(args.data, dtype=np.int16, mode='r').reshape((-1, 2))

    num_feature_frames, feature_dim = features.shape

    # The output starts as a copy of the input features; only the pitch and
    # xcorr columns are overwritten further down.
    output = np.memmap(args.output, dtype=np.float32, shape=(num_feature_frames, feature_dim), mode='w+')
    output[:, :36] = features

    # lpc coefficients and signal
    lpcs = features[:, 20:36]
    sig = data[:, 1]

    # constants
    pitch_min = 32
    pitch_max = 256
    lpc_order = 16
    fs = 16000
    frame_length = 160
    overlap_frames = 100
    chunk_size = 10000
    history_length = frame_length * overlap_frames
    history = np.zeros(history_length, dtype=np.int16)
    pitch_position = 18
    xcorr_position = 19

    num_frames = len(sig) // frame_length - 1

    data_format = dict_params['data_format']

    # Loop invariants.
    # Columns kept from the stacked [log-magnitude, real, imag] feature:
    # the first `freq_keep` bins of each of the three parts.
    idx_save = np.concatenate([np.arange(freq_keep), (N//2 + 1) + np.arange(freq_keep), 2*(N//2 + 1) + np.arange(freq_keep)])
    # Whole frames of extra context so that a lag of pitch_max samples stays
    # inside the LPC residual when computing the cross-correlation.
    frame_offset = (pitch_max + frame_length - 1) // frame_length
    offset = frame_offset * frame_length
    padding = lpc_order

    # The extractor exchanges data through files; keep them in a private
    # temporary directory so concurrent runs cannot clobber each other.
    with tempfile.TemporaryDirectory() as tmp_dir:
        raw_path = os.path.join(tmp_dir, 'temp_featcompute_' + data_format + '_.raw')
        xcorr_path = os.path.join(tmp_dir, 'temp_featcompute_xcorr_' + data_format + '_.raw')

        # Iterate over every chunk, including the trailing partial one (a
        # floor-based iteration count would leave the tail, or a whole short
        # file, unprocessed).
        for frame_start in tqdm.trange(0, num_frames, chunk_size):
            frame_stop = min(frame_start + chunk_size, num_frames)
            signal_start = frame_start * frame_length
            signal_stop = frame_stop * frame_length

            chunk = np.concatenate((history, sig[signal_start:signal_stop]))
            # Same chunk with 80 extra look-ahead samples.
            chunk_la = np.concatenate((history, sig[signal_start:signal_stop + 80]))

            # Feature computation
            spec = stft(x=np.concatenate([np.zeros(80), chunk_la/(2**15 - 1)]), w='boxcar', N=N, H=H).T
            phase_diff = spec*np.conj(np.roll(spec, 1, axis=-1))
            phase_diff = phase_diff/(np.abs(phase_diff) + 1.0e-8)
            feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8), np.real(phase_diff), np.imag(phase_diff)], axis=0).T
            feature_if = feature[:, idx_save]

            # Run the external extractor on the look-ahead-shifted chunk and
            # read back its output as rows of 256 float32 values.
            chunk_la[80:].astype(np.int16).tofile(raw_path)
            subprocess.run([args.path_lpcnet_extractor, raw_path, xcorr_path], check=True)
            feature_xcorr = np.flip(np.fromfile(xcorr_path, dtype='float32').reshape((-1, 256), order='C'), axis=1)
            ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]), -1)
            feature_xcorr = np.concatenate([ones_zero_lag, feature_xcorr], axis=-1)

            os.remove(raw_path)
            os.remove(xcorr_path)

            if data_format == 'if':
                feature = feature_if
            elif data_format == 'xcorr':
                feature = feature_xcorr
            else:
                indmin = min(feature_if.shape[0], feature_xcorr.shape[0])
                feature = np.concatenate([feature_xcorr[:indmin, :], feature_if[:indmin, :]], -1)

            # Compute pitch with the model (inference only, so no autograd graph)
            with torch.no_grad():
                model_input = torch.from_numpy(np.copy(np.expand_dims(feature, 0))).float().to(device)
                model_cents = pitch_nn(model_input)
            # argmax index -> cents in steps of 20, relative to 62.5 Hz
            model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()
            frequency = 62.5*2**(model_cents/1200)

            # Drop the frames that belong to the prepended history.
            frequency = frequency[overlap_frames : overlap_frames + frame_stop - frame_start]

            # convert frequencies to periods
            periods = np.round(fs / frequency)
            periods = np.clip(periods, pitch_min, pitch_max)

            output[frame_start:frame_stop, pitch_position] = (periods - 100) / 50

            if frame_start < frame_offset:
                lpc_coeffs = np.concatenate((np.zeros((frame_offset - frame_start, lpc_order), dtype=np.float32), lpcs[:frame_stop]))
            else:
                lpc_coeffs = lpcs[frame_start - frame_offset : frame_stop]

            _, error = run_lpc(chunk[history_length - offset - padding :], lpc_coeffs, frame_length=frame_length)

            # Normalized cross-correlation of the residual at the pitch lag.
            # Periods are converted to Python ints: mixing np.int16 scalars
            # with large Python ints in the slice arithmetic overflows under
            # NumPy 2 promotion rules.
            xcorr = np.zeros(frame_stop - frame_start)
            for k, period in enumerate(periods.astype(np.int64).tolist()):
                if period > 0:
                    start = offset + k * frame_length
                    f1 = error[start : start + frame_length]
                    f2 = error[start - period : start + frame_length - period]
                    xcorr[k] = np.dot(f1, f2) / np.sqrt(np.dot(f1, f1) * np.dot(f2, f2) + 1e-6)

            output[frame_start:frame_stop, xcorr_position] = xcorr - 0.5

            # update buffers
            history = chunk[-history_length :]

    # Make sure everything written through the memmap reaches the file.
    output.flush()
180