1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
import torch
from torch import nn
import torch.nn.functional as F

from utils.endoscopy import write_data
from utils.softquant import soft_quant

class LimitedAdaptiveComb1d(nn.Module):
    COUNTER = 1

    def __init__(self,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 max_lag=256,
                 name=None,
                 gain_limit_db=10,
                 global_gain_limits_db=[-6, 6],
                 norm_p=2,
                 softquant=False,
                 apply_weight_norm=False,
                 **kwargs):
        """

        Parameters:
        -----------

        kernel_size : int
            kernel size of the adaptive comb filter

        feature_dim : int
            dimension of features from which kernels and gains are computed

        frame_size : int, optional
            frame size, defaults to 160

        overlap_size : int, optional
            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40

        padding : List[int, int], optional
            left and right padding, defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]

        max_lag : int, optional
            maximal pitch lag, defaults to 256

        name : str or None, optional
            specifies a name attribute for the module. If None, the name is auto-generated as limited_adaptive_comb1d_COUNTER, where COUNTER is an instance counter for LimitedAdaptiveComb1d

        gain_limit_db : float, optional
            upper limit for the adaptive comb filter gain in dB, defaults to 10

        global_gain_limits_db : List[float, float], optional
            lower and upper limit for the global output gain in dB, defaults to [-6, 6]

        norm_p : int, optional
            p of the p-norm used for normalizing the comb filter kernels, defaults to 2

        softquant : bool, optional
            if true, soft quantization is applied to the layer that generates the filter kernels, defaults to False

        apply_weight_norm : bool, optional
            if true, weight normalization is applied to the linear layers, defaults to False

        """

        super(LimitedAdaptiveComb1d, self).__init__()

        self.in_channels   = 1
        self.out_channels  = 1
        self.feature_dim   = feature_dim
        self.kernel_size   = kernel_size
        self.frame_size    = frame_size
        self.overlap_size  = overlap_size
        self.max_lag       = max_lag
        self.limit_db      = gain_limit_db
        self.norm_p        = norm_p

        if name is None:
            self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
            LimitedAdaptiveComb1d.COUNTER += 1
        else:
            self.name = name

        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # network for generating convolution weights
        self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size))

        if softquant:
            self.conv_kernel = soft_quant(self.conv_kernel)


        # comb filter gain
        self.filter_gain = norm(nn.Linear(feature_dim, 1))
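        # 0.11512925464970229 = log(10) / 20, converts the dB limit into a limit on the natural-log gain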
        self.log_gain_limit = gain_limit_db * 0.11512925464970229
        with torch.no_grad():
            self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)

        self.global_filter_gain = norm(nn.Linear(feature_dim, 1))
        log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2

        if padding is None:
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

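        # half-cosine fade-out window; forward() flips it to obtain the fade-in applied to each new frame during overlap-add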
        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)

    def forward(self, x, features, lags, debug=False):
        """ adaptive comb filtering with frame-wise cross-fade


        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        lags : torch.LongTensor
            frame-wise lags for comb-filtering

        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size
        overlap_size = self.overlap_size
        kernel_size = self.kernel_size
        win1 = torch.flip(self.overlap_win, [0])
        win2 = self.overlap_win

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in LimitedAdaptiveComb1d.forward')

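        # per-frame comb filter kernels, normalized to unit p-norm so that the filter magnitude is controlled by the gains below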
        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))

        # frame-wise comb filter gains (limited to gain_limit_db) and global output gains
        conv_gains        = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
        global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)

        if debug and batch_size == 1:
            key = self.name + "_gains"
            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_lags"
            write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_global_conv_gains"
            write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)


        # frame-wise convolution with overlap-add
        output_frames = []
        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
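        # pad with the filter delay on both sides, plus max_lag on the left and overlap_size on the right,
        # so that the lag-shifted gathers below always stay within bounds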
        x = F.pad(x, self.padding)
        x = F.pad(x, [self.max_lag, self.overlap_size])

        idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
        idx = torch.repeat_interleave(idx, batch_size, 0)
        idx = torch.repeat_interleave(idx, self.in_channels, 1)


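        # for every frame: gather the input delayed by the frame's lag, filter it with the frame's kernel,
        # add the undelayed input, apply the per-frame gains and cross-fade with the previous frame via overlap-add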
        for i in range(num_frames):

            cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
            xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))

            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)

            offset = self.max_lag + self.padding[0]
            new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])

            # overlapping part
            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)

            # non-overlapping part
            output_frames.append(new_chunk[:, :, overlap_size : frame_size])

            # mem for next frame
            overlap_mem = new_chunk[:, :, frame_size :]

        # concatenate chunks
        output = torch.cat(output_frames, dim=-1)

        return output

    def flop_count(self, rate):
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        overhead = overlap / self.frame_size

        count = 0

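        # convention: a multiply-accumulate counts as 2 flops; the (1 + overhead) factor accounts for the extra overlap_size samples processed per frame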
        # kernel computation and filtering
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # global gain computation
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += overlap * frame_rate * 3 * self.out_channels

        return count
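

# Minimal usage sketch (illustrative only, not part of the upstream module). The values below are
# assumptions: 160-sample frames, 64-dimensional per-frame features, and pitch lags below the
# default max_lag of 256. It only demonstrates the tensor shapes that forward() expects.
if __name__ == "__main__":
    comb = LimitedAdaptiveComb1d(kernel_size=15, feature_dim=64, frame_size=160, overlap_size=40)

    batch_size, num_frames = 2, 8
    x        = torch.randn(batch_size, 1, num_frames * 160)       # (batch, channels, samples)
    features = torch.randn(batch_size, num_frames, 64)            # (batch, frames, feature_dim)
    lags     = torch.randint(32, 256, (batch_size, num_frames))   # frame-wise pitch lags

    y = comb(x, features, lags)
    print(y.shape)  # torch.Size([2, 1, 1280]), same shape as x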