1import numpy as np 2import scipy.signal 3 4def compute_vad_mask(x, fs, stop_db=-70): 5 6 frame_length = (fs + 49) // 50 7 x = x[: frame_length * (len(x) // frame_length)] 8 9 frames = x.reshape(-1, frame_length) 10 frame_energy = np.sum(frames ** 2, axis=1) 11 frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same') 12 13 max_threshold = frame_energy.max() * 10 ** (stop_db/20) 14 vactive = np.ones_like(frames) 15 vactive[frame_energy_smooth < max_threshold, :] = 0 16 vactive = vactive.reshape(-1) 17 18 filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1)) 19 filter = filter / filter.sum() 20 21 mask = np.convolve(vactive, filter, mode='same') 22 23 return x, mask 24 25def convert_mask(mask, num_frames, frame_size=160, hop_size=40): 26 num_samples = frame_size + (num_frames - 1) * hop_size 27 if len(mask) < num_samples: 28 mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype) 29 else: 30 mask = mask[:num_samples] 31 32 new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)]) 33 34 return new_mask 35 36def power_spectrum(x, window_size=160, hop_size=40, window='hamming'): 37 num_spectra = (len(x) - window_size - hop_size) // hop_size 38 window = scipy.signal.get_window(window, window_size) 39 N = window_size // 2 40 41 frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window 42 psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2 43 44 return psd 45 46 47def frequency_mask(num_bands, up_factor, down_factor): 48 49 up_mask = np.zeros((num_bands, num_bands)) 50 down_mask = np.zeros((num_bands, num_bands)) 51 52 for i in range(num_bands): 53 up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1) 54 down_mask[i, i :] = down_factor ** np.arange(num_bands - i) 55 56 return down_mask @ up_mask 57 58 59def rect_fb(band_limits, num_bins=None): 60 num_bands = len(band_limits) - 1 61 if num_bins is None: 62 num_bins = band_limits[-1] 63 64 fb = np.zeros((num_bands, num_bins)) 65 for i in range(num_bands): 66 fb[i, band_limits[i]:band_limits[i+1]] = 1 67 68 return fb 69 70 71def compare(x, y, apply_vad=False): 72 """ Modified version of opus_compare for 16 kHz mono signals 73 74 Args: 75 x (np.ndarray): reference input signal scaled to [-1, 1] 76 y (np.ndarray): test signal scaled to [-1, 1] 77 78 Returns: 79 float: perceptually weighted error 80 """ 81 # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz 82 band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75] 83 num_bands = len(band_limits) - 1 84 fb = rect_fb(band_limits, num_bins=81) 85 86 # trim samples to same size 87 num_samples = min(len(x), len(y)) 88 x = x[:num_samples] * 2**15 89 y = y[:num_samples] * 2**15 90 91 psd_x = power_spectrum(x) + 100000 92 psd_y = power_spectrum(y) + 100000 93 94 num_frames = psd_x.shape[0] 95 96 # average band energies 97 be_x = (psd_x @ fb.T) / np.sum(fb, axis=1) 98 99 # frequecy masking 100 f_mask = frequency_mask(num_bands, 0.1, 0.03) 101 mask_x = be_x @ f_mask.T 102 103 # temporal masking 104 for i in range(1, num_frames): 105 mask_x[i, :] += 0.5 * mask_x[i-1, :] 106 107 # apply mask 108 masked_psd_x = psd_x + 0.1 * (mask_x @ fb) 109 masked_psd_y = psd_y + 0.1 * (mask_x @ fb) 110 111 # 2-frame average 112 masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1] 113 masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1] 114 115 # distortion metric 116 re = masked_psd_y / masked_psd_x 117 im = np.log(re) ** 2 118 Eb = ((im @ fb.T) / np.sum(fb, axis=1)) 119 Ef = np.mean(Eb , axis=1) 120 121 if apply_vad: 122 _, mask = compute_vad_mask(x, 16000) 123 mask = convert_mask(mask, Ef.shape[0]) 124 else: 125 mask = np.ones_like(Ef) 126 127 err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6) 128 129 return float(err) 130 131if __name__ == "__main__": 132 import argparse 133 from scipy.io import wavfile 134 135 parser = argparse.ArgumentParser() 136 parser.add_argument('ref', type=str, help='reference wav file') 137 parser.add_argument('deg', type=str, help='degraded wav file') 138 parser.add_argument('--apply-vad', action='store_true') 139 args = parser.parse_args() 140 141 142 fs1, x = wavfile.read(args.ref) 143 fs2, y = wavfile.read(args.deg) 144 145 if max(fs1, fs2) != 16000: 146 raise ValueError('error: encountered sampling frequency diffrent from 16kHz') 147 148 x = x.astype(np.float32) / 2**15 149 y = y.astype(np.float32) / 2**15 150 151 err = compare(x, y, apply_vad=args.apply_vad) 152 153 print(f"MOC: {err}") 154