"""STFT-based Loss modules."""

import torch
import torch.nn.functional as F
import numpy as np  # noqa: F401  (not used by the live code; import kept as-is)

try:
    # torchaudio was only needed by archived dB-scale experiments
    # (amplitude_to_DB). Nothing live in this module calls it, so a missing
    # install must not prevent the losses from being imported.
    import torchaudio  # noqa: F401
except ImportError:
    torchaudio = None


def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to a clamped magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor handed straight to ``torch.stft``
            (for example the output of ``torch.hann_window(win_length)``).

    Returns:
        Tensor: Magnitude spectrogram. No transpose is applied, so for a
        real (B, T) input the layout is the one ``torch.stft`` produces:
        (B, fft_size // 2 + 1, #frames).
    """
    x_stft = torch.stft(
        x,
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    # (kan-bayashi): clamp is needed to avoid nan or inf
    return torch.clamp(torch.abs(x_stft), min=1e-7)


class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Both inputs are square-root compressed first; the result is the
        L1 norm of their difference divided by the L1 norm of the
        compressed ground truth. The norms run over every element, so the
        axis order of the inputs does not affect the value.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal.
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (same shape as ``x_mag``).

        Returns:
            Tensor: Spectral convergence loss value (scalar).
        """
        x_mag = torch.sqrt(x_mag)
        y_mag = torch.sqrt(y_mag)
        return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)


class LogSTFTMagnitudeLoss(torch.nn.Module):
    """STFT magnitude loss module.

    The class name is historical: the current implementation compares the
    magnitudes directly with an L1 loss and applies no logarithm.
    """

    def __init__(self):
        """Initialize STFT magnitude loss module."""
        super().__init__()

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Magnitude spectrogram of predicted signal.
            y (Tensor): Magnitude spectrogram of groundtruth signal
                (same shape as ``x``).

        Returns:
            Tensor: Mean absolute error between ``y`` and ``x`` (scalar).
        """
        # NOTE: earlier experiments (log / sqrt / dB compressed L1 terms and
        # a concordance-correlation term) were disabled; only the plain L1
        # term is active.
        return F.l1_loss(y, x)


class STFTLoss(torch.nn.Module):
    """STFT loss module (single resolution)."""

    def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
        """Initialize STFT loss module.

        Args:
            device: Device on which the analysis window is created.
            fft_size (int): FFT size.
            shift_size (int): Hop size between frames.
            win_length (int): Window length.
            window (str): Name of a ``torch`` window factory, looked up
                with ``getattr(torch, window)`` and called with
                ``win_length``.
        """
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # Registered as a buffer so the window follows the module through
        # ``.to()`` / ``.cuda()``; non-persistent so ``state_dict()`` keeps
        # exactly the keys it had when the window was a plain attribute.
        self.register_buffer(
            "window",
            getattr(torch, window)(win_length).to(device),
            persistent=False,
        )
        # The attribute name keeps its original spelling because it is
        # part of the module's public surface (child-module name).
        self.spectral_convergenge_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: STFT magnitude (L1) loss value.
        """
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, mag_loss


class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    # Resolution sets tried previously (fft_sizes / hop_sizes / win_lengths):
    #   [2048, 1024, 512, 256, 128, 64] / [512, 256, 128, 64, 32, 16] / [2048, 1024, 512, 256, 128, 64]
    #   [2048, 1024, 512, 256, 128, 64] / [256, 128, 64, 32, 16, 8]   / [1024, 512, 256, 128, 64, 32]

    def __init__(self,
                 device,
                 fft_sizes=(2560, 1280, 640, 320, 160, 80),
                 hop_sizes=(640, 320, 160, 80, 40, 20),
                 win_lengths=(2560, 1280, 640, 320, 160, 80),
                 window="hann_window"):
        """Initialize multi resolution STFT loss module.

        Args:
            device: Device on which the analysis windows are created.
            fft_sizes (sequence of int): FFT size per resolution.
            hop_sizes (sequence of int): Hop size per resolution.
            win_lengths (sequence of int): Window length per resolution.
            window (str): Name of a ``torch`` window factory.
        """
        super().__init__()
        # Defaults are tuples (not lists) so they cannot be mutated across
        # instances; any same-length sequences are accepted.
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths), \
            "fft_sizes, hop_sizes and win_lengths must have the same length"
        self.stft_losses = torch.nn.ModuleList([
            STFTLoss(device, fs, ss, wl, window)
            for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths)
        ])

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss averaged over all
            resolutions. The per-resolution magnitude loss is computed by
            each ``STFTLoss`` but is not part of the returned value.
        """
        sc_loss = 0.0
        for stft_loss in self.stft_losses:
            sc_l, _ = stft_loss(x, y)
            sc_loss = sc_loss + sc_l
        return sc_loss / len(self.stft_losses)