xref: /aosp_15_r20/external/libopus/dnn/torch/osce/utils/moc.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1import numpy as np
2import scipy.signal
3
def compute_vad_mask(x, fs, stop_db=-70):
    """Compute a smooth per-sample voice-activity mask.

    The signal is split into frames of ``(fs + 49) // 50`` samples (20 ms at
    16 kHz); a trailing partial frame is dropped. A frame is marked active when
    its centred 5-frame moving-average energy is not below
    ``max_frame_energy * 10 ** (stop_db / 20)``. The resulting per-sample 0/1
    activity is convolved with a unit-sum half-sine window one frame long, so
    the mask ramps between active and inactive regions.

    Args:
        x (np.ndarray): 1-D input signal.
        fs (int): sampling rate in Hz (sets the frame length).
        stop_db (float): activity threshold in dB relative to the most
            energetic frame.

    Returns:
        tuple: ``(x_trimmed, mask)`` where ``x_trimmed`` is ``x`` cut to a whole
        number of frames and ``mask`` has one value per sample of
        ``x_trimmed``. Both are empty if ``x`` is shorter than one frame.
    """
    frame_length = (fs + 49) // 50
    x = x[: frame_length * (len(x) // frame_length)]

    if len(x) == 0:
        # not even one complete frame: nothing to classify
        return x, np.zeros(0)

    frames = x.reshape(-1, frame_length)
    if np.issubdtype(frames.dtype, np.integer):
        # squaring integer PCM (e.g. int16) would overflow
        frames = frames.astype(np.float64)
    num_frames = frames.shape[0]
    frame_energy = np.sum(frames ** 2, axis=1)

    # Centred 5-tap moving average. np.convolve(mode='same') returns
    # max(len(a), len(v)) samples, which would be 5 values for fewer than 5
    # frames; slicing the full convolution always yields one value per frame
    # and equals mode='same' whenever num_frames >= 5.
    smooth_kernel = np.ones(5) / 5
    offset = (len(smooth_kernel) - 1) // 2
    frame_energy_smooth = np.convolve(frame_energy, smooth_kernel, mode='full')[offset : offset + num_frames]

    # NOTE(review): frame_energy is an energy (squared amplitude) but stop_db is
    # converted with /20 (amplitude convention), so the effective energy
    # threshold is stop_db / 2 dB. Kept as-is to preserve calibrated behaviour.
    max_threshold = frame_energy.max() * 10 ** (stop_db / 20)
    vactive = np.ones_like(frames)
    vactive[frame_energy_smooth < max_threshold, :] = 0
    vactive = vactive.reshape(-1)

    # unit-sum half-sine smoothing window, one frame long
    taper = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    taper = taper / taper.sum()

    mask = np.convolve(vactive, taper, mode='same')

    return x, mask
24
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Resample a per-sample mask to a per-frame mask by window averaging.

    Frame ``i`` covers samples ``[i * hop_size, i * hop_size + frame_size)``.
    The mask is zero-padded (or truncated) to exactly the number of samples
    spanned by ``num_frames`` frames before averaging.

    Args:
        mask (array-like): per-sample mask; float, int or bool values.
        num_frames (int): number of output frames.
        frame_size (int): window length in samples.
        hop_size (int): hop between consecutive windows in samples.

    Returns:
        np.ndarray: array of length ``num_frames`` with the mean mask value in
        each window (empty if ``num_frames <= 0``).
    """
    mask = np.asarray(mask)
    if num_frames <= 0:
        return np.zeros(0)

    num_samples = frame_size + (num_frames - 1) * hop_size

    # Pad by filling a zero buffer of the mask's own dtype. Concatenating
    # float64 zeros with dtype=mask.dtype raised TypeError for bool/int masks.
    padded = np.zeros(num_samples, dtype=mask.dtype)
    n = min(len(mask), num_samples)
    padded[:n] = mask[:n]

    new_mask = np.array([np.mean(padded[i * hop_size : i * hop_size + frame_size]) for i in range(num_frames)])

    return new_mask
35
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Short-time power spectrum of a 1-D signal.

    Every complete window ``x[i * hop_size : i * hop_size + window_size]`` is
    multiplied by ``window`` and transformed with an FFT. The squared magnitude
    of the first ``window_size // 2 + 1`` bins is kept.

    The number of frames is ``(len(x) - window_size) // hop_size + 1``, i.e. all
    complete windows. The previous ``(len(x) - window_size - hop_size) // hop_size``
    dropped the last two complete windows and crashed on short inputs.

    Args:
        x (np.ndarray): 1-D input signal.
        window_size (int): analysis window length in samples.
        hop_size (int): hop between consecutive windows in samples.
        window: window specification passed to ``scipy.signal.get_window``.

    Returns:
        np.ndarray: array of shape ``(num_frames, window_size // 2 + 1)``;
        ``num_frames`` is 0 when ``x`` is shorter than one window.
    """
    x = np.asarray(x)
    N = window_size // 2

    if len(x) < window_size:
        return np.zeros((0, N + 1))

    num_spectra = (len(x) - window_size) // hop_size + 1
    win = scipy.signal.get_window(window, window_size)

    frames = np.stack([x[i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * win
    psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2

    return psd
45
46
def frequency_mask(num_bands, up_factor, down_factor):
    """Build the cross-band spreading matrix used for frequency masking.

    With ``d = row - col``, an "up" matrix holds ``up_factor ** d`` on and
    below the diagonal, and a "down" matrix holds ``down_factor ** -d`` on and
    above it. Both are zero elsewhere. The result is ``down @ up``, of shape
    ``(num_bands, num_bands)``.
    """
    band_idx = np.arange(num_bands)
    offset = band_idx[:, np.newaxis] - band_idx[np.newaxis, :]
    distance = np.abs(offset)

    lower = np.where(offset >= 0, up_factor ** distance, 0.0)
    upper = np.where(offset <= 0, down_factor ** distance, 0.0)

    return upper @ lower
57
58
def rect_fb(band_limits, num_bins=None):
    """Return a rectangular (0/1) filter-bank matrix.

    Row ``k`` is 1 on bins ``band_limits[k]`` (inclusive) up to
    ``band_limits[k + 1]`` (exclusive) and 0 elsewhere.

    Args:
        band_limits: sequence of bin edges; yields ``len(band_limits) - 1`` bands.
        num_bins (int, optional): number of columns; defaults to the last edge.

    Returns:
        np.ndarray: float array of shape ``(len(band_limits) - 1, num_bins)``.
    """
    if num_bins is None:
        num_bins = band_limits[-1]

    bank = np.zeros((len(band_limits) - 1, num_bins))
    for row, start, stop in zip(bank, band_limits[:-1], band_limits[1:]):
        row[start:stop] = 1

    return bank
69
70
def compare(x, y, apply_vad=False):
    """ Modified version of opus_compare for 16 kHz mono signals

    Both signals are trimmed to the same length and rescaled to the 16-bit
    integer range. Short-time power spectra get a constant noise floor. A
    masking threshold derived from the reference only (frequency spreading
    across bark-like bands plus a first-order temporal decay) is added to both
    spectra. The per-band squared log ratio of the masked spectra is then
    averaged per frame and pooled over frames.

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, only frames marked active by a VAD run on
            the reference (assuming 16 kHz) contribute to the error.

    Returns:
        float: perceptually weighted error. This is nan if ``apply_vad`` is set
        and no frame passes the VAD (mean over an empty selection).
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=81)
    band_sizes = np.sum(fb, axis=1)

    # trim samples to same size and rescale to the int16 range
    num_samples = min(len(x), len(y))
    x = np.asarray(x)[:num_samples] * 2**15
    y = np.asarray(y)[:num_samples] * 2**15

    # constant offset acts as a noise floor and keeps the log ratio finite
    psd_x = power_spectrum(x) + 100000
    psd_y = power_spectrum(y) + 100000

    num_frames = psd_x.shape[0]

    # average band energies of the reference
    be_x = (psd_x @ fb.T) / band_sizes

    # frequency masking
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking: first-order recursive decay across frames
    for i in range(1, num_frames):
        mask_x[i, :] += 0.5 * mask_x[i - 1, :]

    # apply the reference-derived mask to both signals
    masking = 0.1 * (mask_x @ fb)
    masked_psd_x = psd_x + masking
    masked_psd_y = psd_y + masking

    # combine adjacent frames (a sum; the factor 2 cancels in the ratio below)
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log spectral ratio, averaged per band then per frame
    ratio = masked_psd_y / masked_psd_x
    log_ratio_sq = np.log(ratio) ** 2
    Eb = (log_ratio_sq @ fb.T) / band_sizes
    Ef = np.mean(Eb, axis=1)

    if apply_vad:
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)

    return float(err)
130
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    def _load_wav(path):
        """Read a 16 kHz mono wav file and return float32 samples in [-1, 1].

        Raises:
            ValueError: if the file is not sampled at 16 kHz or is not mono.
        """
        fs, data = wavfile.read(path)
        # every file must be 16 kHz; checking max(fs1, fs2) let an 8 kHz file through
        if fs != 16000:
            raise ValueError(f'error: {path} has sampling frequency {fs} Hz, expected 16kHz')
        if data.ndim != 1:
            raise ValueError(f'error: {path} has {data.shape[1]} channels, expected mono')
        if np.issubdtype(data.dtype, np.floating):
            # floating-point wav data is already nominally in [-1, 1]
            return data.astype(np.float32)
        if data.dtype == np.uint8:
            # 8-bit PCM is unsigned and centred at 128
            return (data.astype(np.float32) - 128) / 128
        # signed integer PCM: divide by the full-scale value of the storage type
        # (2**15 for int16, matching the previous behaviour for 16-bit files)
        return data.astype(np.float32) / (np.iinfo(data.dtype).max + 1)

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true', help='only score frames marked active by a VAD on the reference')
    args = parser.parse_args()

    x = _load_wav(args.ref)
    y = _load_wav(args.deg)

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")
154