"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

"""STFT-based loss modules."""

import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import torchaudio


def get_window(win_name, win_length, *args, **kwargs):
    window_dict = {
        'bartlett_window': torch.bartlett_window,
        'blackman_window': torch.blackman_window,
        'hamming_window': torch.hamming_window,
        'hann_window': torch.hann_window,
        'kaiser_window': torch.kaiser_window
    }

    if win_name not in window_dict:
        raise ValueError(f"unknown window function: {win_name}")

    return window_dict[win_name](win_length, *args, **kwargs)


def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type.

    Returns:
        Tensor: Magnitude spectrogram (B, fft_size // 2 + 1, #frames).
    """

    win = get_window(window, win_length).to(x.device)
    x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)

    # clamp away from zero so downstream log/division terms stay finite
    return torch.clamp(torch.abs(x_stft), min=1e-7)
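

# Illustrative usage sketch (comments only, not executed): how `stft` above is
# meant to be called. The 512/256 settings and the 1 s signal length are
# arbitrary example values, not values taken from this module.
#
#     x = torch.randn(4, 16000)   # (B, T) batch of waveforms
#     X = stft(x, fft_size=512, hop_size=256, win_length=512, window='hann_window')
#     X.shape                     # (4, 257, num_frames) magnitude spectrogram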


def spectral_convergence_loss(Y_true, Y_pred):
    # Frobenius norm of the magnitude difference, normalized per example by ||Y_pred||
    dims = list(range(1, len(Y_pred.shape)))
    return torch.mean(torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=dims) / (torch.norm(Y_pred, p="fro", dim=dims) + 1e-6))


def log_magnitude_loss(Y_true, Y_pred):
    # mean absolute error between log-magnitude spectrograms
    Y_true_log_abs = torch.log(torch.abs(Y_true) + 1e-15)
    Y_pred_log_abs = torch.log(torch.abs(Y_pred) + 1e-15)

    return torch.mean(torch.abs(Y_true_log_abs - Y_pred_log_abs))


def spectral_xcorr_loss(Y_true, Y_pred):
    # one minus the normalized cross-correlation of the magnitude spectrograms
    Y_true = Y_true.abs()
    Y_pred = Y_pred.abs()
    dims = list(range(1, len(Y_pred.shape)))
    xcorr = torch.sum(Y_true * Y_pred, dim=dims) / torch.sqrt(torch.sum(Y_true ** 2, dim=dims) * torch.sum(Y_pred ** 2, dim=dims) + 1e-9)

    return 1 - xcorr.mean()


class MRLogMelLoss(nn.Module):
    """Multi-resolution log-mel spectrogram loss."""

    def __init__(self,
                 fft_sizes=[512, 256, 128, 64],
                 overlap=0.5,
                 fs=16000,
                 n_mels=18
                 ):

        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.fs = fs
        self.n_mels = n_mels

        self.mel_specs = []
        for fft_size in fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))

            # use fewer mel bands for the shortest windows
            n_mels = self.n_mels
            if fft_size < 128:
                n_mels //= 2

            self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))

        # register the transforms so their buffers follow .to(device)
        for i, mel_spec in enumerate(self.mel_specs):
            self.add_module(f'mel_spec_{i+1}', mel_spec)

    def forward(self, y_true, y_pred):

        loss = torch.zeros(1, device=y_true.device)

        for mel_spec in self.mel_specs:
            Y_true = mel_spec(y_true)
            Y_pred = mel_spec(y_pred)
            loss = loss + log_magnitude_loss(Y_true, Y_pred)

        loss = loss / len(self.mel_specs)

        return loss


def create_weight_matrix(num_bins, bins_per_band=10):
    # averaging matrix: row i spreads weight 1/bins_per_band over roughly
    # bins_per_band bins centred on bin i, with the excess folded back at the edges
    m = torch.zeros((num_bins, num_bins), dtype=torch.float32)

    r0 = bins_per_band // 2
    r1 = bins_per_band - r0

    for i in range(num_bins):
        i0 = max(i - r0, 0)
        j0 = min(i + r1, num_bins)

        m[i, i0: j0] += 1

        if i < r0:
            m[i, :r0 - i] += 1

        if i > num_bins - r1:
            m[i, num_bins - r1 - i:] += 1

    return m / bins_per_band


def weighted_spectral_convergence(Y_true, Y_pred, w):

    # spectral flatness measure (SFM) per band: exp(mean of log magnitude) / mean magnitude
    logY = torch.log(torch.abs(Y_true) + 1e-9)
    Y = torch.abs(Y_true)

    avg_logY = torch.matmul(logY.transpose(1, 2), w)
    avg_Y = torch.matmul(Y.transpose(1, 2), w)

    sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)

    # tonal (low-SFM) regions get larger weights than noise-like (high-SFM) regions
    weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)

    loss = torch.mean(
        torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
        / (torch.mean(weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
    )

    return loss


def gen_filterbank(N, Fs=16000):
    in_freq = (np.arange(N + 1, dtype='float32') / N * Fs / 2)[None, :]
    out_freq = (np.arange(N, dtype='float32') / N * Fs / 2)[:, None]
    # ERB from B.C.J. Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    ERB_N = 24.7 + .108 * in_freq
    delta = np.abs(in_freq - out_freq) / ERB_N
    center = (delta < .5).astype('float32')
    R = -12 * center * delta ** 2 + (1 - center) * (3 - 12 * delta)
    RE = 10. ** (R / 10.)
    norm = np.sum(RE, axis=1)
    RE = RE / norm[:, np.newaxis]
    return torch.from_numpy(RE)
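

# Illustrative note (comments only, not executed): each row of the filterbank is
# a smoothing kernel over STFT bins, normalized to sum to one, so multiplying a
# magnitude spectrogram by it yields an ERB-style smoothed spectrum. The sizes
# below are arbitrary example values (matching fft_size = 512).
#
#     fb = gen_filterbank(256)          # shape (256, 257)
#     fb.sum(dim=1)                     # ~1.0 for every row
#     smooth = torch.matmul(fb, mag)    # mag: (B, 257, num_frames) -> (B, 256, num_frames)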


def smooth_log_mag(Y_true, Y_pred, filterbank):
    # log-magnitude error after smoothing along the frequency axis with an ERB-style filterbank
    Y_true_smooth = torch.matmul(filterbank, torch.abs(Y_true))
    Y_pred_smooth = torch.matmul(filterbank, torch.abs(Y_pred))

    loss = torch.abs(
        torch.log(Y_true_smooth + 1e-9) - torch.log(Y_pred_smooth + 1e-9)
    )

    loss = loss.mean()

    return loss


class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss combining several spectral distance terms."""

    def __init__(self,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 overlap=0.5,
                 window='hann_window',
                 fs=16000,
                 log_mag_weight=1,
                 sc_weight=0,
                 wsc_weight=0,
                 smooth_log_mag_weight=0,
                 sxcorr_weight=0):
        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.window = window
        self.log_mag_weight = log_mag_weight
        self.sc_weight = sc_weight
        self.wsc_weight = wsc_weight
        self.smooth_log_mag_weight = smooth_log_mag_weight
        self.sxcorr_weight = sxcorr_weight
        self.fs = fs

        # weights for SFM-weighted spectral convergence loss
        self.wsc_weights = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            width = min(11, int(1000 * fft_size / self.fs + .5))
            width += width % 2
            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
                create_weight_matrix(fft_size // 2 + 1, width),
                requires_grad=False
            )

        # filterbanks for smooth log-magnitude loss
        self.filterbanks = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
                gen_filterbank(fft_size // 2),
                requires_grad=False
            )

    def forward(self, y_true, y_pred):

        lm_loss = torch.zeros(1, device=y_true.device)
        sc_loss = torch.zeros(1, device=y_true.device)
        wsc_loss = torch.zeros(1, device=y_true.device)
        slm_loss = torch.zeros(1, device=y_true.device)
        sxcorr_loss = torch.zeros(1, device=y_true.device)

        for fft_size in self.fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            win_size = fft_size

            Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
            Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)

            if self.log_mag_weight > 0:
                lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)

            if self.sc_weight > 0:
                sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)

            if self.wsc_weight > 0:
                wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])

            if self.smooth_log_mag_weight > 0:
                slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])

            if self.sxcorr_weight > 0:
                sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)

        total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
                      + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
                      + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)

        return total_loss
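

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the original module: compare a
    # clean random signal against a slightly noisy copy. Signal length, noise
    # level and enabling every loss term are arbitrary illustration choices.
    torch.manual_seed(0)
    y_true = torch.randn(2, 16000)
    y_pred = y_true + 0.01 * torch.randn_like(y_true)

    stft_loss = MRSTFTLoss(log_mag_weight=1, sc_weight=1, wsc_weight=1,
                           smooth_log_mag_weight=1, sxcorr_weight=1)
    print("MRSTFT loss:    ", stft_loss(y_true, y_pred).item())

    mel_loss = MRLogMelLoss()
    print("MR log-mel loss:", mel_loss(y_true, y_pred).item())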