xref: /aosp_15_r20/external/libopus/dnn/torch/fargan/test_fargan.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1import os
2import argparse
3import numpy as np
4
5import torch
6from torch import nn
7import torch.nn.functional as F
8import tqdm
9
10import fargan
11from dataset import FARGANDataset
12
nb_features = 36        # float32 values per frame in the feature file
nb_used_features = 20   # leading features passed to the model; the remaining columns become `lpc`

parser = argparse.ArgumentParser()

parser.add_argument('model', type=str, help='FARGAN model checkpoint')
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('output', type=str, help='path to output file (16-bit PCM)')

parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
# NOTE(review): --cond-size is parsed but never read in this script; the model
# is constructed from the hyperparameters stored in the checkpoint instead.
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)

args = parser.parse_args()

# Set before the first CUDA query below so the device restriction can apply.
if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices


features_file = args.features
signal_file = args.output


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# torch.load unpickles the checkpoint file: only load checkpoints you trust.
checkpoint = torch.load(args.model, map_location='cpu')

model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])

# strict=False: missing or unexpected state-dict keys are ignored without warning.
model.load_state_dict(checkpoint['state_dict'], strict=False)

# Flat float32 stream viewed as (1, n_frames, nb_features).
features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
# Last nb_features - nb_used_features (16) columns, frames 3 .. second-to-last.
# NOTE(review): the 3-frame offset presumably aligns with the model output — confirm.
lpc = features[:,4-1:-1,nb_used_features:]
features = features[:, :, :nb_used_features]
# Feature column nb_used_features-2 mapped to an integer period clipped to [32, 255]
# (presumably a log-scaled pitch feature — confirm against the feature extractor).
periods = np.round(np.clip(256./2**(features[:,:,nb_used_features-2]+1.5), 32, 255)).astype('int')


nb_frames = features.shape[1]
# Weighting factor from the checkpoint's model kwargs; used as gamma**k on the LPC taps.
gamma = checkpoint['model_kwargs']['gamma']
58
def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=None):
    """Run an all-pole LPC synthesis filter over one frame of samples.

    For every sample n of ``frame`` this computes::

        out[n] = frame[n] - sum_k filt[k] * w[p-1-k] * y[n-1-k]

    where ``p = len(filt)``, ``y`` is the filter output (the samples held in
    ``buffer`` followed by ``out`` itself) and ``w`` is ``weighting_vector``
    (no weighting when it is ``None``).

    Args:
        frame: 1-D array of input samples.
        filt: 1-D array of ``p`` coefficients; ``filt[k]`` multiplies the
            output ``k + 1`` samples in the past.
        buffer: 1-D array of the ``p`` most recent past outputs, oldest first
            (``buffer[-1]`` is the sample immediately before ``frame``).
            It is read only; the caller's array is never modified.
        weighting_vector: optional per-tap multipliers aligned with
            ``buffer`` (oldest first). ``None`` is equivalent to all ones.

    Returns:
        Array with the shape and dtype of ``frame`` holding the filter output.
    """
    out = np.zeros_like(frame)
    # Reverse the taps so a dot product with the oldest-first history pairs
    # filt[0] with the most recent output sample.
    taps = np.flip(filt)
    # Work on a private copy so the caller's history array stays untouched.
    history = np.copy(buffer)

    for n in range(frame.shape[0]):
        past = history if weighting_vector is None else history * weighting_vector
        s = frame[n] - np.dot(past, taps)
        # Shift the history left by one sample and append the newest output.
        history[:-1] = history[1:]
        history[-1] = s
        out[n] = s

    return out
78
def inverse_perceptual_weighting (pw_signal, filters, weighting_vector, frame_size=160):
    """Filter ``pw_signal`` frame by frame with per-frame LPC synthesis.

    Intended as the inverse perceptual weighting (author's note:
    H_preemph / W(z/gamma)); this function itself applies only the
    per-frame all-pole synthesis via ``lpc_synthesis_one_frame``.

    Args:
        pw_signal: 1-D array of samples.
        filters: 2-D array with one row of LPC coefficients per frame; the
            filter order is taken from ``filters.shape[1]``.
        weighting_vector: per-tap multipliers forwarded to
            ``lpc_synthesis_one_frame`` (length ``filters.shape[1]``).
        frame_size: samples per frame (default 160).

    Returns:
        Array shaped like ``pw_signal``. Trailing samples that do not fill a
        whole frame are left as zeros.

    Raises:
        ValueError: if the number of whole frames differs from the number of
            filter rows.
    """
    num_frames = pw_signal.shape[0] // frame_size
    if num_frames != filters.shape[0]:
        raise ValueError(
            f"signal has {num_frames} frames of {frame_size} samples but "
            f"{filters.shape[0]} filter rows were given")

    order = filters.shape[1]
    signal = np.zeros_like(pw_signal)
    buffer = np.zeros(order)
    for frame_idx in range(num_frames):
        start = frame_idx * frame_size
        stop = start + frame_size
        out_sig_frame = lpc_synthesis_one_frame(pw_signal[start:stop], filters[frame_idx, :], buffer, weighting_vector)
        signal[start:stop] = out_sig_frame
        # Keep the last `order` outputs as history; concatenating keeps this
        # correct even when a frame is shorter than the filter order.
        buffer = np.concatenate((buffer, out_sig_frame))[-order:]
    return signal
94
def inverse_perceptual_weighting40 (pw_signal, filters):
    """Apply per-frame LPC synthesis to ``pw_signal`` in 40-sample frames.

    Frame i (samples 40*i .. 40*i+39) is filtered with ``filters[i, :]`` by
    ``lpc_synthesis_one_frame`` without tap weighting. The last 16 output
    samples of each frame seed the filter history for the next frame.
    Samples past the last whole frame are left as zeros. The number of
    whole frames must equal ``filters.shape[0]`` (checked with ``assert``).
    """
    frame_len = 40
    frame_count = pw_signal.shape[0] // frame_len
    assert frame_count == filters.shape[0]

    result = np.zeros_like(pw_signal)
    history = np.zeros(16)
    for frame_idx, start in enumerate(range(0, frame_count * frame_len, frame_len)):
        stop = start + frame_len
        synthesized = lpc_synthesis_one_frame(pw_signal[start:stop], filters[frame_idx, :], history)
        result[start:stop] = synthesized
        history[:] = synthesized[-16:]
    return result
109
110from scipy.signal import lfilter
111
if __name__ == '__main__':
    model.to(device)
    features = torch.tensor(features).to(device)
    periods = torch.tensor(periods).to(device)

    # LPC taps scaled by gamma**k (k = 1..16), then passed through
    # fargan.interp_lpc(..., 4). Only consumed by the disabled inverse
    # weighting step below.
    weighting = gamma**np.arange(1, 17)
    lpc = lpc*weighting
    lpc = fargan.interp_lpc(torch.tensor(lpc), 4).numpy()

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        sig, _ = model(features, periods, nb_frames - 4)
    # .numpy() only works on CPU tensors, and the model may run on CUDA.
    sig = sig.detach().cpu().numpy().flatten()
    # First-order IIR 1 / (1 - 0.85 z^-1) (de-emphasis).
    sig = lfilter(np.array([1.]), np.array([1., -.85]), sig)
    # NOTE: inverse perceptual weighting is currently disabled; to enable it:
    #     sig = inverse_perceptual_weighting40(sig, lpc[0,:,:])

    # Clip to +/-0.99 of full scale and write raw 16-bit PCM.
    pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
    pcm.tofile(signal_file)
129