1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30import numpy as np 31import scipy.signal 32 33def compute_vad_mask(x, fs, stop_db=-70): 34 35 frame_length = (fs + 49) // 50 36 x = x[: frame_length * (len(x) // frame_length)] 37 38 frames = x.reshape(-1, frame_length) 39 frame_energy = np.sum(frames ** 2, axis=1) 40 frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same') 41 42 max_threshold = frame_energy.max() * 10 ** (stop_db/20) 43 vactive = np.ones_like(frames) 44 vactive[frame_energy_smooth < max_threshold, :] = 0 45 vactive = vactive.reshape(-1) 46 47 filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1)) 48 filter = filter / filter.sum() 49 50 mask = np.convolve(vactive, filter, mode='same') 51 52 return x, mask 53 54def convert_mask(mask, num_frames, frame_size=160, hop_size=40): 55 num_samples = frame_size + (num_frames - 1) * hop_size 56 if len(mask) < num_samples: 57 mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype) 58 else: 59 mask = mask[:num_samples] 60 61 new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)]) 62 63 return new_mask 64 65def power_spectrum(x, window_size=160, hop_size=40, window='hamming'): 66 num_spectra = (len(x) - window_size - hop_size) // hop_size 67 window = scipy.signal.get_window(window, window_size) 68 N = window_size // 2 69 70 frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window 71 psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2 72 73 return psd 74 75 76def frequency_mask(num_bands, up_factor, down_factor): 77 78 up_mask = np.zeros((num_bands, num_bands)) 79 down_mask = np.zeros((num_bands, num_bands)) 80 81 for i in range(num_bands): 82 up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1) 83 down_mask[i, i :] = down_factor ** np.arange(num_bands - i) 84 85 return down_mask @ up_mask 86 87 88def rect_fb(band_limits, num_bins=None): 89 num_bands = len(band_limits) - 1 90 if num_bins is None: 91 num_bins = band_limits[-1] 92 93 fb = np.zeros((num_bands, num_bins)) 94 for i in range(num_bands): 95 fb[i, band_limits[i]:band_limits[i+1]] = 1 96 97 return fb 98 99 100def compare(x, y, apply_vad=False): 101 """ Modified version of opus_compare for 16 kHz mono signals 102 103 Args: 104 x (np.ndarray): reference input signal scaled to [-1, 1] 105 y (np.ndarray): test signal scaled to [-1, 1] 106 107 Returns: 108 float: perceptually weighted error 109 """ 110 # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz 111 band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75] 112 num_bands = len(band_limits) - 1 113 fb = rect_fb(band_limits, num_bins=81) 114 115 # trim samples to same size 116 num_samples = min(len(x), len(y)) 117 x = x[:num_samples] * 2**15 118 y = y[:num_samples] * 2**15 119 120 psd_x = power_spectrum(x) + 100000 121 psd_y = power_spectrum(y) + 100000 122 123 num_frames = psd_x.shape[0] 124 125 # average band energies 126 be_x = (psd_x @ fb.T) / np.sum(fb, axis=1) 127 128 # frequecy masking 129 f_mask = frequency_mask(num_bands, 0.1, 0.03) 130 mask_x = be_x @ f_mask.T 131 132 # temporal masking 133 for i in range(1, num_frames): 134 mask_x[i, :] += 0.5 * mask_x[i-1, :] 135 136 # apply mask 137 masked_psd_x = psd_x + 0.1 * (mask_x @ fb) 138 masked_psd_y = psd_y + 0.1 * (mask_x @ fb) 139 140 # 2-frame average 141 masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1] 142 masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1] 143 144 # distortion metric 145 re = masked_psd_y / masked_psd_x 146 im = np.log(re) ** 2 147 Eb = ((im @ fb.T) / np.sum(fb, axis=1)) 148 Ef = np.mean(Eb , axis=1) 149 150 if apply_vad: 151 _, mask = compute_vad_mask(x, 16000) 152 mask = convert_mask(mask, Ef.shape[0]) 153 else: 154 mask = np.ones_like(Ef) 155 156 err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6) 157 158 return float(err) 159 160if __name__ == "__main__": 161 import argparse 162 from scipy.io import wavfile 163 164 parser = argparse.ArgumentParser() 165 parser.add_argument('ref', type=str, help='reference wav file') 166 parser.add_argument('deg', type=str, help='degraded wav file') 167 parser.add_argument('--apply-vad', action='store_true') 168 args = parser.parse_args() 169 170 171 fs1, x = wavfile.read(args.ref) 172 fs2, y = wavfile.read(args.deg) 173 174 if max(fs1, fs2) != 16000: 175 raise ValueError('error: encountered sampling frequency diffrent from 16kHz') 176 177 x = x.astype(np.float32) / 2**15 178 y = y.astype(np.float32) / 2**15 179 180 err = compare(x, y, apply_vad=args.apply_vad) 181 182 print(f"MOC: {err}") 183