xref: /aosp_15_r20/external/libopus/dnn/torch/osce/losses/stft_loss.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30"""STFT-based Loss modules."""
31
32import torch
33import torch.nn.functional as F
34from torch import nn
35import numpy as np
36import torchaudio
37
38
def get_window(win_name, win_length, *args, **kwargs):
    """Build a window tensor from the name of a torch window factory.

    Args:
        win_name (str): One of 'bartlett_window', 'blackman_window',
            'hamming_window', 'hann_window' or 'kaiser_window'.
        win_length (int): Window length, passed as first argument to the factory.
        *args: Extra positional arguments forwarded to the factory.
        **kwargs: Extra keyword arguments forwarded to the factory.
    Returns:
        Tensor: The window produced by the selected torch factory.
    Raises:
        ValueError: If `win_name` is not one of the supported names.
    """
    window_factories = {
        'bartlett_window'   : torch.bartlett_window,
        'blackman_window'   : torch.blackman_window,
        'hamming_window'    : torch.hamming_window,
        'hann_window'       : torch.hann_window,
        'kaiser_window'     : torch.kaiser_window
    }

    if win_name not in window_factories:
        # name the offending value and the valid choices so that a typo in a
        # config file can be diagnosed from the traceback alone
        raise ValueError(
            f"unknown window {win_name!r}, expected one of {sorted(window_factories)}"
        )

    return window_factories[win_name](win_length, *args, **kwargs)
52
53
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function name as accepted by get_window.
    Returns:
        Tensor: Magnitude spectrogram, floored at 1e-7. torch.stft puts the
            frequency axis before the frame axis, so for (B, T) input the
            result is (B, fft_size // 2 + 1, #frames).
    """

    # the window is created with torch defaults and then moved next to the signal
    analysis_window = get_window(window, win_length).to(x.device)
    spectrum = torch.stft(x, fft_size, hop_size, win_length, analysis_window, return_complex=True)

    # floor the magnitudes so they stay strictly positive
    magnitude = torch.abs(spectrum)
    return torch.clamp(magnitude, min=1e-7)
71
def spectral_convergence_loss(Y_true, Y_pred):
    """Batch mean of the relative Frobenius error between two magnitude spectra.

    Norms are taken over every dimension except the first. Note that the
    normaliser is the norm of `Y_pred` (plus 1e-6), not of `Y_true`.

    Args:
        Y_true (Tensor): Reference spectrogram, first dimension is the batch.
        Y_pred (Tensor): Predicted spectrogram of the same shape.
    Returns:
        Tensor: Scalar loss.
    """
    reduce_dims = list(range(1, Y_pred.ndim))

    magnitude_error = torch.abs(Y_true) - torch.abs(Y_pred)
    error_norm = torch.norm(magnitude_error, p="fro", dim=reduce_dims)
    pred_norm = torch.norm(Y_pred, p="fro", dim=reduce_dims)

    per_item = error_norm / (pred_norm + 1e-6)
    return torch.mean(per_item)
75
76
def log_magnitude_loss(Y_true, Y_pred):
    """Mean absolute difference between the log magnitudes of two spectra.

    Args:
        Y_true (Tensor): Reference spectrogram.
        Y_pred (Tensor): Predicted spectrogram, broadcastable against Y_true.
    Returns:
        Tensor: Scalar loss.
    """
    # 1e-15 keeps the logarithm finite for zero-valued bins
    log_true, log_pred = (
        torch.log(torch.abs(spec) + 1e-15) for spec in (Y_true, Y_pred)
    )

    return (log_true - log_pred).abs().mean()
82
def spectral_xcorr_loss(Y_true, Y_pred):
    """One minus the batch-mean normalised correlation of two magnitude spectra.

    Sums run over every dimension except the first.

    Args:
        Y_true (Tensor): Reference spectrogram, first dimension is the batch.
        Y_pred (Tensor): Predicted spectrogram of the same shape.
    Returns:
        Tensor: Scalar loss.
    """
    mag_true = torch.abs(Y_true)
    mag_pred = torch.abs(Y_pred)
    reduce_dims = list(range(1, mag_pred.ndim))

    inner_product = (mag_true * mag_pred).sum(dim=reduce_dims)
    energy_true = (mag_true ** 2).sum(dim=reduce_dims)
    energy_pred = (mag_pred ** 2).sum(dim=reduce_dims)

    # 1e-9 under the root avoids a division by zero for silent items
    correlation = inner_product / torch.sqrt(energy_true * energy_pred + 1e-9)

    return 1 - correlation.mean()
90
91
92
class MRLogMelLoss(nn.Module):
    """Multi-resolution log-mel loss.

    One torchaudio MelSpectrogram transform is built per FFT size; the loss is
    log_magnitude_loss between the transforms of both signals, averaged over
    all resolutions.

    Args:
        fft_sizes (list of int): FFT sizes, one mel transform per entry.
        overlap (float): Fractional frame overlap; hop = round(fft_size * (1 - overlap)).
        fs (int): Sampling rate, passed as first argument to MelSpectrogram.
        n_mels (int): Number of mel bands; halved (integer division) for FFT sizes below 128.
    """

    def __init__(self,
                 fft_sizes=[512, 256, 128, 64],
                 overlap=0.5,
                 fs=16000,
                 n_mels=18
                 ):

        # nn.Module must be initialised before attributes are assigned on it
        super().__init__()

        # copy, so the instance never aliases the shared default list
        self.fft_sizes  = list(fft_sizes)
        self.overlap    = overlap
        self.fs         = fs
        self.n_mels     = n_mels

        # plain list for iteration; each transform is also registered under
        # 'mel_spec_<k>' so that it follows .to()/.cuda() and appears in state_dict
        self.mel_specs = []
        for index, fft_size in enumerate(self.fft_sizes):
            hop_size = int(round(fft_size * (1 - self.overlap)))

            # short transforms get half the number of mel bands
            bands = self.n_mels // 2 if fft_size < 128 else self.n_mels

            mel_spec = torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=bands)
            self.mel_specs.append(mel_spec)
            self.add_module(f'mel_spec_{index+1}', mel_spec)

    def forward(self, y_true, y_pred):
        """Return the log-mel loss averaged over all resolutions.

        Args:
            y_true (Tensor): Reference signal as accepted by MelSpectrogram.
            y_pred (Tensor): Predicted signal of the same shape.
        Returns:
            Tensor: Loss tensor of shape (1,).
        """

        loss = torch.zeros(1, device=y_true.device)

        for mel_spec in self.mel_specs:
            Y_true = mel_spec(y_true)
            Y_pred = mel_spec(y_pred)
            loss = loss + log_magnitude_loss(Y_true, Y_pred)

        loss = loss / len(self.mel_specs)

        return loss
133
def create_weight_matrix(num_bins, bins_per_band=10):
    """Build a (num_bins, num_bins) band-averaging matrix.

    Row i puts weight 1 / bins_per_band on the bins in
    [i - bins_per_band // 2, i + bins_per_band - bins_per_band // 2). Where
    that window sticks out past either end of the bin range, the overhang is
    added to the same number of bins at that end.

    Args:
        num_bins (int): Number of frequency bins.
        bins_per_band (int): Width of the averaging window in bins.
    Returns:
        Tensor: float32 matrix of shape (num_bins, num_bins).
    """
    reach_low = bins_per_band // 2
    reach_high = bins_per_band - reach_low

    weights = torch.zeros((num_bins, num_bins), dtype=torch.float32)

    for row in range(num_bins):
        # part of the window that lies inside the valid bin range
        start = max(row - reach_low, 0)
        stop = min(row + reach_high, num_bins)
        weights[row, start:stop] += 1

        # bins the window would need below bin 0
        overhang_low = reach_low - row
        if overhang_low > 0:
            weights[row, :overhang_low] += 1

        # bins the window would need past the last bin
        overhang_high = row + reach_high - num_bins
        if overhang_high > 0:
            weights[row, -overhang_high:] += 1

    return weights / bins_per_band
153
def weighted_spectral_convergence(Y_true, Y_pred, w):
    """Spectral convergence weighted by a flatness measure of the reference.

    A spectral-flatness-like ratio (exp of averaged log magnitude over
    averaged magnitude) is computed from `Y_true` using `w`; bins are weighted
    by sqrt(relu(1 - ratio)).

    Args:
        Y_true (Tensor): Reference spectrogram; dims 1 and 2 are transposed
            before multiplying with `w`, so dim 1 must match w's first dimension.
        Y_pred (Tensor): Predicted spectrogram of the same shape.
        w (Tensor): Averaging matrix applied along dim 1 of the spectrograms.
    Returns:
        Tensor: Scalar loss.
    """
    mag_true = torch.abs(Y_true)
    mag_pred = torch.abs(Y_pred)

    # flatness ratio from band-averaged log magnitude and magnitude
    log_mag_true = torch.log(mag_true + 1e-9)
    band_log_mean = torch.matmul(log_mag_true.transpose(1, 2), w)
    band_mean = torch.matmul(mag_true.transpose(1, 2), w)
    sfm = torch.exp(band_log_mean) / (band_mean + 1e-9)

    # transpose back to the layout of the spectrograms
    weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)

    weighted_error = torch.mean(weight * torch.abs(mag_true - mag_pred), dim=[1, 2])
    weighted_reference = torch.mean(weight * mag_true, dim=[1, 2]) + 1e-9

    return torch.mean(weighted_error / weighted_reference)
173
def gen_filterbank(N, Fs=16000):
    """Build an (N, N + 1) smoothing filterbank with ERB-scaled widths.

    Column frequencies are N + 1 points k / N * Fs / 2, row frequencies are
    the first N of them. Every row is normalised to sum to one.

    Args:
        N (int): Number of output bands; input has N + 1 bins.
        Fs (int): Sampling rate in Hz.
    Returns:
        Tensor: Filterbank matrix of shape (N, N + 1).
    """
    nyquist_scale = Fs / 2
    in_freq = np.arange(N + 1, dtype='float32') / N * Fs / 2
    out_freq = np.arange(N, dtype='float32') / N * Fs / 2
    in_freq = in_freq[None, :]
    out_freq = out_freq[:, None]

    #ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    erb_width = 24.7 + .108 * in_freq

    # distance between row and column frequency in units of the ERB width
    delta = np.abs(in_freq - out_freq) / erb_width

    # response in dB: quadratic inside half a width, linear roll-off outside
    is_center = (delta < .5).astype('float32')
    response_db = -12 * is_center * delta ** 2 + (1 - is_center) * (3 - 12 * delta)

    response = 10. ** (response_db / 10.)
    row_sums = np.sum(response, axis=1)
    response = response / row_sums[:, np.newaxis]

    return torch.from_numpy(response)
186
def smooth_log_mag(Y_true, Y_pred, filterbank):
    """Mean absolute log difference of two filterbank-smoothed magnitude spectra.

    Args:
        Y_true (Tensor): Reference spectrogram; left-multiplied by `filterbank`.
        Y_pred (Tensor): Predicted spectrogram of the same shape.
        filterbank (Tensor): Smoothing matrix applied via torch.matmul.
    Returns:
        Tensor: Scalar loss.
    """
    smoothed_true, smoothed_pred = (
        torch.matmul(filterbank, torch.abs(spec)) for spec in (Y_true, Y_pred)
    )

    # 1e-9 keeps the logarithm finite for empty bands
    log_difference = torch.log(smoothed_true + 1e-9) - torch.log(smoothed_pred + 1e-9)

    return torch.abs(log_difference).mean()
198
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss.

    For every FFT size the magnitude spectrograms of both signals are compared
    with up to five terms (log magnitude, spectral convergence, SFM-weighted
    spectral convergence, smoothed log magnitude, spectral cross-correlation).
    A term is only evaluated when its weight is positive. The weighted sum is
    divided by the number of FFT sizes.

    Args:
        fft_sizes (list of int): FFT sizes; the window length equals the FFT size.
        overlap (float): Fractional frame overlap; hop = round(fft_size * (1 - overlap)).
        window (str): Window name understood by get_window.
        fs (int): Sampling rate, used to size the SFM averaging bands.
        log_mag_weight (float): Weight of log_magnitude_loss.
        sc_weight (float): Weight of spectral_convergence_loss.
        wsc_weight (float): Weight of weighted_spectral_convergence.
        smooth_log_mag_weight (float): Weight of smooth_log_mag.
        sxcorr_weight (float): Weight of spectral_xcorr_loss.
    """

    def __init__(self,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 overlap=0.5,
                 window='hann_window',
                 fs=16000,
                 log_mag_weight=1,
                 sc_weight=0,
                 wsc_weight=0,
                 smooth_log_mag_weight=0,
                 sxcorr_weight=0):
        super().__init__()

        # copy, so the instance never aliases the shared default list
        self.fft_sizes = list(fft_sizes)
        self.overlap = overlap
        self.window = window
        self.log_mag_weight = log_mag_weight
        self.sc_weight = sc_weight
        self.wsc_weight = wsc_weight
        self.smooth_log_mag_weight = smooth_log_mag_weight
        self.sxcorr_weight = sxcorr_weight
        self.fs = fs

        # weights for SFM weighted spectral convergence loss
        self.wsc_weights = torch.nn.ParameterDict()
        for fft_size in self.fft_sizes:
            # number of bins spanning about 1 kHz (bin spacing is fs / fft_size),
            # capped at 11 and then rounded up to an even count
            width = min(11, int(1000 * fft_size / self.fs + .5))
            width += width % 2
            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
                create_weight_matrix(fft_size // 2 + 1, width),
                requires_grad=False
            )

        # filterbanks for smooth log magnitude loss
        self.filterbanks = torch.nn.ParameterDict()
        for fft_size in self.fft_sizes:
            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
                gen_filterbank(fft_size//2),
                requires_grad=False
            )


    def forward(self, y_true, y_pred):
        """Return the combined loss averaged over all FFT sizes.

        Implemented as forward (rather than overriding __call__) so that
        nn.Module's call machinery, including registered hooks, stays active;
        instances are still invoked as loss(y_true, y_pred).

        Args:
            y_true (Tensor): Reference signal tensor (B, T).
            y_pred (Tensor): Predicted signal tensor (B, T).
        Returns:
            Tensor: Loss tensor of shape (1,).
        """

        device = y_true.device
        lm_loss = torch.zeros(1, device=device)
        sc_loss = torch.zeros(1, device=device)
        wsc_loss = torch.zeros(1, device=device)
        slm_loss = torch.zeros(1, device=device)
        sxcorr_loss = torch.zeros(1, device=device)

        for fft_size in self.fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            win_size = fft_size

            Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
            Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)

            # terms with non-positive weight are skipped entirely
            if self.log_mag_weight > 0:
                lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)

            if self.sc_weight > 0:
                sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)

            if self.wsc_weight > 0:
                wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])

            if self.smooth_log_mag_weight > 0:
                slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])

            if self.sxcorr_weight > 0:
                sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)


        total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
                + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
                + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)

        return total_loss
277        return total_loss