1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import numpy as np
import scipy.signal


def compute_vad_mask(x, fs, stop_db=-70):
    """Compute a soft voice-activity mask from smoothed frame energies.

    Frames whose smoothed energy falls below a threshold derived from the
    peak frame energy and stop_db are marked inactive; the per-sample
    activity signal is then smoothed with a normalized sine window.
    Returns the trimmed signal and the per-sample mask.
    """
    frame_length = (fs + 49) // 50  # 20 ms frames, rounded up
    x = x[: frame_length * (len(x) // frame_length)]

    frames = x.reshape(-1, frame_length)
    frame_energy = np.sum(frames ** 2, axis=1)
    frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')

    max_threshold = frame_energy.max() * 10 ** (stop_db / 20)
    vactive = np.ones_like(frames)
    vactive[frame_energy_smooth < max_threshold, :] = 0
    vactive = vactive.reshape(-1)

    # smooth the binary activity decision with a normalized sine window
    win = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    win = win / win.sum()

    mask = np.convolve(vactive, win, mode='same')

    return x, mask


def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Convert a per-sample mask to one value per analysis frame by averaging."""
    num_samples = frame_size + (num_frames - 1) * hop_size
    if len(mask) < num_samples:
        mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
    else:
        mask = mask[:num_samples]

    new_mask = np.array([np.mean(mask[i * hop_size : i * hop_size + frame_size]) for i in range(num_frames)])

    return new_mask


def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Compute per-frame power spectra (non-negative frequency bins only)."""
    num_spectra = (len(x) - window_size - hop_size) // hop_size
    window = scipy.signal.get_window(window, window_size)
    N = window_size // 2

    frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window
    psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2

    return psd


def frequency_mask(num_bands, up_factor, down_factor):
    """Build a band-to-band spreading matrix for frequency masking.

    up_factor weights how strongly a band masks higher bands and
    down_factor how strongly it masks lower bands.
    """
    up_mask = np.zeros((num_bands, num_bands))
    down_mask = np.zeros((num_bands, num_bands))

    for i in range(num_bands):
        up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1)
        down_mask[i, i:] = down_factor ** np.arange(num_bands - i)

    return down_mask @ up_mask
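# Illustrative note (not from the original file): for a hypothetical 3-band setup,
# frequency_mask(3, 0.1, 0.03) has middle row approximately [0.1003, 1.003, 0.03],
# i.e. band 1 is masked mostly by itself, by 0.1 of band 0 (upward spread) and by
# 0.03 of band 2 (downward spread), plus tiny second-order cross terms.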


def rect_fb(band_limits, num_bins=None):
    """Create a rectangular filter bank with one 0/1 row per band."""
    num_bands = len(band_limits) - 1
    if num_bins is None:
        num_bins = band_limits[-1]

    fb = np.zeros((num_bands, num_bins))
    for i in range(num_bands):
        fb[i, band_limits[i] : band_limits[i + 1]] = 1

    return fb
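# Illustrative note (not from the original file): rect_fb([0, 2, 5], num_bins=6)
# returns the 2 x 6 matrix
#   [[1., 1., 0., 0., 0., 0.],
#    [0., 0., 1., 1., 1., 0.]]
# so band 0 covers bins 0-1 and band 1 covers bins 2-4; bin 5 is unused.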


def compare(x, y, apply_vad=False):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, frames that an energy-based VAD on the
            reference marks as inactive are excluded from the error

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=81)

    # trim signals to a common length and scale to 16-bit range
    num_samples = min(len(x), len(y))
    x = x[:num_samples] * 2**15
    y = y[:num_samples] * 2**15

    psd_x = power_spectrum(x) + 100000
    psd_y = power_spectrum(y) + 100000

    num_frames = psd_x.shape[0]

    # average band energies
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)

    # frequency masking
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking
    for i in range(1, num_frames):
        mask_x[i, :] += 0.5 * mask_x[i - 1, :]

    # apply mask
    masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
    masked_psd_y = psd_y + 0.1 * (mask_x @ fb)

    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric
    re = masked_psd_y / masked_psd_x
    im = np.log(re) ** 2
    Eb = (im @ fb.T) / np.sum(fb, axis=1)
    Ef = np.mean(Eb, axis=1)

    if apply_vad:
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)

    return float(err)

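# Example usage (illustrative sketch, not from the original file; the signals
# below are synthetic placeholders that only demonstrate the expected input
# scaling for compare()):
#
#   fs = 16000
#   t = np.arange(fs) / fs                      # 1 s at 16 kHz
#   ref = 0.5 * np.sin(2 * np.pi * 440 * t)     # reference in [-1, 1]
#   deg = ref + 0.01 * np.random.randn(len(t))  # noisy test signal
#   moc = compare(ref, deg, apply_vad=False)    # perceptually weighted error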
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('encountered sampling frequency different from 16 kHz')

    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")