1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30import torch 31from torch import nn 32import torch.nn.functional as F 33 34from utils.endoscopy import write_data 35from utils.softquant import soft_quant 36 37class LimitedAdaptiveComb1d(nn.Module): 38 COUNTER = 1 39 40 def __init__(self, 41 kernel_size, 42 feature_dim, 43 frame_size=160, 44 overlap_size=40, 45 padding=None, 46 max_lag=256, 47 name=None, 48 gain_limit_db=10, 49 global_gain_limits_db=[-6, 6], 50 norm_p=2, 51 softquant=False, 52 apply_weight_norm=False, 53 **kwargs): 54 """ 55 56 Parameters: 57 ----------- 58 59 feature_dim : int 60 dimension of features from which kernels, biases and gains are computed 61 62 frame_size : int, optional 63 frame size, defaults to 160 64 65 overlap_size : int, optional 66 overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40 67 68 use_bias : bool, optional 69 if true, biases will be added to output channels. Defaults to True 70 71 padding : List[int, int], optional 72 left and right padding. Defaults to [(kernel_size - 1) // 2, kernel_size - 1 - (kernel_size - 1) // 2] 73 74 max_lag : int, optional 75 maximal pitch lag, defaults to 256 76 77 have_a0 : bool, optional 78 If true, the filter coefficient a0 will be learned as a positive gain (requires in_channels == out_channels). Otherwise, a0 is set to 0. Defaults to False 79 80 name: str or None, optional 81 specifies a name attribute for the module. If None the name is auto generated as comb_1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d 82 83 """ 84 85 super(LimitedAdaptiveComb1d, self).__init__() 86 87 self.in_channels = 1 88 self.out_channels = 1 89 self.feature_dim = feature_dim 90 self.kernel_size = kernel_size 91 self.frame_size = frame_size 92 self.overlap_size = overlap_size 93 self.max_lag = max_lag 94 self.limit_db = gain_limit_db 95 self.norm_p = norm_p 96 97 if name is None: 98 self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER) 99 LimitedAdaptiveComb1d.COUNTER += 1 100 else: 101 self.name = name 102 103 norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x 104 105 # network for generating convolution weights 106 self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size)) 107 108 if softquant: 109 self.conv_kernel = soft_quant(self.conv_kernel) 110 111 112 # comb filter gain 113 self.filter_gain = norm(nn.Linear(feature_dim, 1)) 114 self.log_gain_limit = gain_limit_db * 0.11512925464970229 115 with torch.no_grad(): 116 self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit) 117 118 self.global_filter_gain = norm(nn.Linear(feature_dim, 1)) 119 log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229 120 self.filter_gain_a = (log_max - log_min) / 2 121 self.filter_gain_b = (log_max + log_min) / 2 122 123 if type(padding) == type(None): 124 self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2] 125 else: 126 self.padding = padding 127 128 self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False) 129 130 def forward(self, x, features, lags, debug=False): 131 """ adaptive 1d convolution 132 133 134 Parameters: 135 ----------- 136 x : torch.tensor 137 input signal of shape (batch_size, in_channels, num_samples) 138 139 feathres : torch.tensor 140 frame-wise features of shape (batch_size, num_frames, feature_dim) 141 142 lags: torch.LongTensor 143 frame-wise lags for comb-filtering 144 145 """ 146 147 batch_size = x.size(0) 148 num_frames = features.size(1) 149 num_samples = x.size(2) 150 frame_size = self.frame_size 151 overlap_size = self.overlap_size 152 kernel_size = self.kernel_size 153 win1 = torch.flip(self.overlap_win, [0]) 154 win2 = self.overlap_win 155 156 if num_samples // self.frame_size != num_frames: 157 raise ValueError('non matching sizes in AdaptiveConv1d.forward') 158 159 conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size)) 160 conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True)) 161 162 conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit) 163 # calculate gains 164 global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b) 165 166 if debug and batch_size == 1: 167 key = self.name + "_gains" 168 write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size) 169 key = self.name + "_kernels" 170 write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size) 171 key = self.name + "_lags" 172 write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size) 173 key = self.name + "_global_conv_gains" 174 write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size) 175 176 177 # frame-wise convolution with overlap-add 178 output_frames = [] 179 overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device) 180 x = F.pad(x, self.padding) 181 x = F.pad(x, [self.max_lag, self.overlap_size]) 182 183 idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1) 184 idx = torch.repeat_interleave(idx, batch_size, 0) 185 idx = torch.repeat_interleave(idx, self.in_channels, 1) 186 187 188 for i in range(num_frames): 189 190 cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1) 191 xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1)) 192 193 new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1) 194 195 offset = self.max_lag + self.padding[0] 196 new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size]) 197 198 # overlapping part 199 output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2) 200 201 # non-overlapping part 202 output_frames.append(new_chunk[:, :, overlap_size : frame_size]) 203 204 # mem for next frame 205 overlap_mem = new_chunk[:, :, frame_size :] 206 207 # concatenate chunks 208 output = torch.cat(output_frames, dim=-1) 209 210 return output 211 212 def flop_count(self, rate): 213 frame_rate = rate / self.frame_size 214 overlap = self.overlap_size 215 overhead = overlap / self.frame_size 216 217 count = 0 218 219 # kernel computation and filtering 220 count += 2 * (frame_rate * self.feature_dim * self.kernel_size) 221 count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate) 222 count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels 223 224 # a0 computation 225 count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels 226 227 # windowing 228 count += overlap * frame_rate * 3 * self.out_channels 229 230 return count 231