xref: /aosp_15_r20/external/libopus/dnn/torch/fwgan/inference.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Liimport os
2*a58d3d2aSXin Liimport time
3*a58d3d2aSXin Liimport torch
4*a58d3d2aSXin Liimport numpy as np
5*a58d3d2aSXin Lifrom scipy import signal as si
6*a58d3d2aSXin Lifrom scipy.io import wavfile
7*a58d3d2aSXin Liimport argparse
8*a58d3d2aSXin Li
9*a58d3d2aSXin Lifrom models import model_dict
10*a58d3d2aSXin Li
# Command-line interface: model variant, trained weight file, and an
# input/output pair (either one feature file -> one wav file, or a folder
# of feature files -> a folder of wav files).
parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file')
parser.add_argument('input', type=str, help='input: feature file or folder with feature files')
parser.add_argument('output', type=str, help='output: wav file name or folder name, depending on input')
16*a58d3d2aSXin Li
17*a58d3d2aSXin Li
18*a58d3d2aSXin Li########################### Signal Processing Layers ###########################
19*a58d3d2aSXin Li
def preemphasis(x, coef= -0.85):
    """Apply a first-order pre-emphasis FIR filter.

    Computes y[n] = x[n] + coef * x[n-1] and returns the result as float32.
    """
    numerator = np.array([1.0, coef])
    denominator = np.array([1.0])
    filtered = si.lfilter(numerator, denominator, x)
    return filtered.astype('float32')
23*a58d3d2aSXin Li
def deemphasis(x, coef= -0.85):
    """Invert `preemphasis`: first-order IIR filter y[n] = x[n] - coef * y[n-1].

    Returns the filtered signal as float32.
    """
    numerator = np.array([1.0])
    denominator = np.array([1.0, coef])
    filtered = si.lfilter(numerator, denominator, x)
    return filtered.astype('float32')
27*a58d3d2aSXin Li
# Bandwidth-expansion factor for the weighting filter W(z/gamma): LPC tap i is
# scaled by gamma**i inside the synthesis loop.
gamma = 0.92
# gamma**16 ... gamma**1; index 0 (weight gamma**16) lines up with the oldest
# sample in the 16-tap synthesis buffer, index -1 (gamma**1) with the newest.
weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
30*a58d3d2aSXin Li
31*a58d3d2aSXin Li
def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
    """Run one frame of excitation through a 16-tap all-pole (LPC) synthesis filter.

    frame            -- 1-D excitation samples for this frame
    filt             -- 16 LPC coefficients, newest-lag coefficient first
    buffer           -- 16 previous output samples; buffer[-1] is the most recent
    weighting_vector -- per-tap gain applied to `buffer` before the dot product
                        (default: all ones, i.e. no weighting; the default array
                        is created once at definition time but only ever read)

    Returns a new array of synthesized samples, same length as `frame`.

    NOTE(review): `buffer[0] = s` writes into the caller's array on the first
    iteration only — `np.roll` returns a copy, so every later iteration mutates
    a local array. The caller in this file overwrites its buffer after each
    call, so this partial aliasing is harmless here, but beware when reusing.
    """
    out = np.zeros_like(frame)

    # Reverse the coefficients so np.dot pairs the newest buffered sample
    # (buffer[-1]) with the first coefficient of the original filt.
    filt = np.flip(filt)

    inp = frame[:]


    for i in range(0, inp.shape[0]):

        # Direct-form all-pole update: subtract the weighted prediction.
        s = inp[i] - np.dot(buffer*weighting_vector, filt)

        buffer[0] = s

        # Shift history left by one; the new sample s ends up at buffer[-1].
        buffer = np.roll(buffer, -1)

        out[i] = s

    return out
52*a58d3d2aSXin Li
def inverse_perceptual_weighting(pw_signal, filters, weighting_vector):
    """Undo the perceptual weighting applied to the vocoder output.

    The inverse weighting filter is H_preemph / W(z/gamma): the signal is
    pre-emphasized, then synthesized frame by frame (160 samples per frame)
    through the bandwidth-expanded LPC filter.

    Args:
        pw_signal: 1-D array of perceptually weighted samples; its length must
            be a whole number of 160-sample frames.
        filters: (num_frames, 16) array of LPC coefficients, one row per frame.
        weighting_vector: 16 per-tap weights forwarded to the synthesis loop.

    Returns:
        1-D float32 array with the reconstructed signal, same length as input.

    Raises:
        ValueError: if the number of 160-sample frames in `pw_signal` does not
            match the number of filter rows.
    """
    pw_signal = preemphasis(pw_signal)

    signal = np.zeros_like(pw_signal)
    buffer = np.zeros(16)
    num_frames = pw_signal.shape[0] // 160
    # An assert would be stripped under `python -O`; validate explicitly.
    if num_frames != filters.shape[0]:
        raise ValueError(
            f"expected {num_frames} LPC filter frames, got {filters.shape[0]}")

    for frame_idx in range(0, num_frames):

        in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
        out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
        signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
        # Seed the next frame's filter memory with the last 16 output samples.
        buffer[:] = out_sig_frame[-16:]

    return signal
72*a58d3d2aSXin Li
73*a58d3d2aSXin Li
def process_item(generator, feature_filename, output_filename, verbose=False):
    """Vocode one LPCNet feature file and write the result as a 16 kHz wav.

    Args:
        generator: callable model taking (period, bfcc_with_corr, history) and
            returning the perceptually weighted signal as a (1, 1, T) tensor.
        feature_filename: path to a raw float32 file of 36-dim feature frames.
        output_filename: path of the 16-bit PCM wav file to write.
        verbose: if True, print generation speed relative to real time.
    """
    feat = np.memmap(feature_filename, dtype='float32', mode='r')

    # Features are stored flat; reshape into rows of 36 values per frame.
    num_feat_frames = len(feat) // 36
    feat = np.reshape(feat, (num_feat_frames, 36))

    bfcc = np.copy(feat[:, :18])           # 18 cepstral coefficients
    corr = np.copy(feat[:, 19:20]) + 0.5   # pitch correlation, shifted positive
    bfcc_with_corr =  torch.from_numpy(np.hstack((bfcc, corr))).type(torch.FloatTensor).unsqueeze(0)#.to(device)

    # Map the normalized pitch feature (column 18) to an integer period in
    # samples; presumably spans the codec's pitch range — TODO confirm scaling.
    period = torch.from_numpy((0.1 + 50 * np.copy(feat[:, 18:19]) + 100)\
                            .astype('int32')).type(torch.long).view(1,-1)#.to(device)

    lpc_filters = np.copy(feat[:, -16:])

    start_time = time.time()
    # Inference only — skip autograd bookkeeping; the output was being
    # detached anyway, so results are unchanged.
    with torch.no_grad():
        x1 = generator(period, bfcc_with_corr, torch.zeros(1,320)) #this means the vocoder runs in complete synthesis mode with zero history audio frames
    end_time = time.time()
    total_time = end_time - start_time
    x1 = x1.squeeze(1).squeeze(0).detach().cpu().numpy()
    gen_seconds = len(x1)/16000
    # Undo perceptual weighting and pre-emphasis to get the audible signal.
    out = deemphasis(inverse_perceptual_weighting(x1, lpc_filters, weighting_vector))
    if verbose:
        print(f"Took {total_time:.3f}s to generate {len(x1)}  samples ({gen_seconds}s) -> {gen_seconds/total_time:.2f}x real time")

    # Scale to 16-bit PCM with saturation and write the wav.
    out = np.clip(np.round(2**15 * out), -2**15, 2**15 -1).astype(np.int16)
    wavfile.write(output_filename, 16000, out)
102*a58d3d2aSXin Li
103*a58d3d2aSXin Li
104*a58d3d2aSXin Li########################### The inference loop over folder containing lpcnet feature files #################################
if __name__ == "__main__":

    args = parser.parse_args()

    # Instantiate the requested generator architecture.
    generator = model_dict[args.model]()

    # Load the trained checkpoint; CPU is sufficient for inference.
    checkpoint = torch.load(args.weightfile, map_location='cpu')
    generator.load_state_dict(checkpoint)

    # Weight norm is a training-time reparameterization; strip it for inference.
    def _remove_weight_norm(module):
        try:
            torch.nn.utils.remove_weight_norm(module)
        except ValueError:
            # This module was never weight-normed; nothing to do.
            pass
    generator.apply(_remove_weight_norm)

    # Switch to evaluation mode.
    generator = generator.eval()

    print('Successfully loaded the generator model ... start generation:')

    if not os.path.isdir(args.input):
        # Single feature file -> single wav file.
        process_item(generator, args.input, args.output)
    else:
        # Folder mode: vocode every feature file into the output folder.
        os.makedirs(args.output, exist_ok=True)

        for fn in os.listdir(args.input):
            print(f"processing input {fn}...")
            feature_filename = os.path.join(args.input, fn)
            output_filename = os.path.join(args.output, os.path.splitext(fn)[0] + f"_{args.model}.wav")
            process_item(generator, feature_filename, output_filename)

    print("Finished!")