"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch
from torch import nn
import numpy as np

from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
from utils.layers import DualFC
from utils.misc import get_pdf_from_tree


class LPCNet(nn.Module):
    """Two-stage network: a frame rate conditioning network feeding a sample
    rate network made of two GRUs and a DualFC output layer.

    The model is configured from a dict-like ``config``. Required keys are
    ``input_layout``, ``feature_history``, ``feature_lookahead``,
    ``feature_dimension``, ``period_embedding_dim``, ``period_levels``,
    ``feature_conditioning_dim``, ``feature_conv_kernel_size``,
    ``frame_size``, ``signal_levels``, ``signal_embedding_dim``,
    ``gru_a_units``, ``gru_b_units`` and ``output_levels``. Optional keys are
    ``hsampling`` (default ``False``), ``lpc_gamma`` (default ``1``) and
    ``sparsification`` (treated as empty when missing or ``None``).
    """

    def __init__(self, config):
        super().__init__()

        # input layout and feature context
        self.input_layout = config['input_layout']
        self.feature_history = config['feature_history']
        self.feature_lookahead = config['feature_lookahead']

        # frame rate network parameters
        self.feature_dimension = config['feature_dimension']
        self.period_embedding_dim = config['period_embedding_dim']
        self.period_levels = config['period_levels']
        self.feature_channels = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim = config['feature_conditioning_dim']
        self.feature_conv_kernel_size = config['feature_conv_kernel_size']

        # frame rate network layers
        # NOTE: layers are created in the same order as before so that
        # parameter ordering and seeded initialisation are unchanged.
        self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1 = nn.Conv1d(
            self.feature_channels,
            self.feature_conditioning_dim,
            self.feature_conv_kernel_size,
            padding='valid',
        )
        self.feature_conv2 = nn.Conv1d(
            self.feature_conditioning_dim,
            self.feature_conditioning_dim,
            self.feature_conv_kernel_size,
            padding='valid',
        )
        self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)

        # sample rate network parameters
        self.frame_size = config['frame_size']
        self.signal_levels = config['signal_levels']
        self.signal_embedding_dim = config['signal_embedding_dim']
        self.gru_a_units = config['gru_a_units']
        self.gru_b_units = config['gru_b_units']
        self.output_levels = config['output_levels']
        self.hsampling = config.get('hsampling', False)

        # GRU A sees one embedding per input signal plus the conditioning
        # vector; GRU B sees the output of GRU A plus the conditioning vector.
        self.gru_a_input_dim = (
            len(self.input_layout['signals']) * self.signal_embedding_dim
            + self.feature_conditioning_dim
        )
        self.gru_b_input_dim = self.gru_a_units + self.feature_conditioning_dim

        # sample rate network layers
        self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
        self.dual_fc = DualFC(self.gru_b_units, self.output_levels)

        # sparsification; a missing (or None) section means "no sparsification"
        # instead of raising KeyError
        sparsification = config.get('sparsification') or {}
        self.sparsifier = []

        # drop_input=True is only passed for GRU A, exactly as before
        self.gru_a_flops_per_step = self._attach_gru_sparsifier(
            self.gru_a, 'gru_a', sparsification, drop_input=True
        )
        self.gru_b_flops_per_step = self._attach_gru_sparsifier(
            self.gru_b, 'gru_b', sparsification
        )

        # inference parameters
        self.lpc_gamma = config.get('lpc_gamma', 1)

    def _attach_gru_sparsifier(self, gru, key, sparsification, **flops_kwargs):
        """Register a sparsifier for ``gru`` if configured; return its FLOPs per step.

        Args:
            gru: the GRU module to (optionally) sparsify.
            key: name of the GRU's entry inside ``sparsification``.
            sparsification: mapping of GRU name to a config with the keys
                ``params``, ``start``, ``stop``, ``interval`` and ``exponent``.
            **flops_kwargs: extra keyword arguments forwarded unchanged to
                ``calculate_gru_flops_per_step``.

        Returns:
            Whatever ``calculate_gru_flops_per_step`` returns for this GRU.
        """
        if key not in sparsification:
            return calculate_gru_flops_per_step(gru, **flops_kwargs)

        gru_config = sparsification[key]
        self.sparsifier.append(
            GRUSparsifier(
                [(gru, gru_config['params'])],
                gru_config['start'],
                gru_config['stop'],
                gru_config['interval'],
                gru_config['exponent'],
            )
        )
        return calculate_gru_flops_per_step(gru, gru_config['params'], **flops_kwargs)

    def sparsify(self):
        """Advance every registered sparsifier by one step."""
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def get_gflops(self, fs, verbose=False):
        """Estimate the model complexity in GFLOPS at sampling rate ``fs``.

        Args:
            fs: sampling rate; the GRUs and the DualFC layer are counted once
                per sample, the frame rate network once per ``frame_size``
                samples.
            verbose: when true, print the contribution of each component and
                the total.

        Returns:
            Sum of the component estimates (frame rate network, GRU A, GRU B,
            DualFC).
        """
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        kernel_size = self.feature_conv_kernel_size

        # Multiply-accumulates per frame:
        #   conv1: kernel_size * feature_channels * conditioning_dim
        #   conv2: kernel_size * conditioning_dim ** 2
        #   two dense layers: 2 * conditioning_dim ** 2
        # For kernel_size == 3 this equals the formerly hard-coded
        # (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim.
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = (
            1e-9 * 2
            * ((kernel_size + 2) * conditioning_dim + kernel_size * feature_channels)
            * conditioning_dim * frame_rate
        )

        # both GRUs run once per sample
        gru_a_complexity = 1e-9 * fs * self.gru_a_flops_per_step
        gru_b_complexity = 1e-9 * fs * self.gru_b_flops_per_step

        # dual fc, also once per sample
        input_size = self.dual_fc.dense1.in_features
        output_size = self.dual_fc.dense1.out_features
        dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * fs
        if self.hsampling:
            # NOTE(review): the factor 8 presumably reflects the reduced work of
            # hierarchical sampling -- confirm against the DualFC implementation.
            dual_fc_complexity /= 8

        components = (
            ("frame rate network", frame_rate_network_complexity),
            ("gru A", gru_a_complexity),
            ("gru B", gru_b_complexity),
            ("dual_fc", dual_fc_complexity),
        )

        gflops = 0
        for label, complexity in components:
            if verbose:
                print(f"{label}: {complexity} GFLOPS")
            gflops += complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops

    def frame_rate_network(self, features, periods):
        """Compute the per-frame conditioning vectors.

        Args:
            features: float tensor, channels last (batch, frames, features).
            periods: integer tensor of period indices; its embedding is
                flattened over dims 2 and 3, so it must be 3-dimensional.

        Returns:
            Conditioning tensor, channels last. Both convolutions use
            ``padding='valid'``, so the frame axis is shorter than the input.
        """
        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # Conv1d expects channels first
        c = torch.permute(features, [0, 2, 1])
        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))

        # back to channels last for the dense layers
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def _run_sample_rate_layers(self, signals, c, gru_states):
        """Run embedding, GRU A, GRU B and DualFC on sample-aligned conditioning.

        ``c`` must already have one conditioning vector per step of ``signals``.

        Returns:
            Tuple of the raw DualFC output and ``(gru_a_state, gru_b_state)``.
        """
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        return self.dual_fc(y), (gru_a_state, gru_b_state)

    def sample_rate_network(self, signals, c, gru_states):
        """Run the sample rate network on frame-rate conditioning (training path).

        Each conditioning vector in ``c`` is repeated ``frame_size`` times along
        dim 1 before being combined with the embedded ``signals``.

        Returns:
            Tuple ``(log_probs, (gru_a_state, gru_b_state))``.
        """
        c_upsampled = torch.repeat_interleave(c, self.frame_size, dim=1)
        y, new_states = self._run_sample_rate_layers(signals, c_upsampled, gru_states)

        if self.hsampling:
            y = torch.sigmoid(y)
            # small offset keeps the logarithm finite for zero probabilities
            log_probs = torch.log(get_pdf_from_tree(y) + 1e-6)
        else:
            log_probs = torch.log_softmax(y, dim=-1)

        return log_probs, new_states

    def decoder(self, signals, c, gru_states):
        """Run the sample rate network without upsampling ``c`` (inference path).

        Unlike ``sample_rate_network`` this returns probabilities rather than
        log-probabilities and uses ``c`` as given.

        Returns:
            Tuple ``(probs, (gru_a_state, gru_b_state))``.
        """
        y, new_states = self._run_sample_rate_layers(signals, c, gru_states)

        if self.hsampling:
            y = torch.sigmoid(y)
            probs = get_pdf_from_tree(y)
        else:
            probs = torch.softmax(y, dim=-1)

        return probs, new_states

    def forward(self, features, periods, signals, gru_states):
        """Return log-probabilities for ``signals`` given features and periods.

        The updated GRU states are discarded.
        """
        c = self.frame_rate_network(features, periods)
        log_probs, _ = self.sample_rate_network(signals, c, gru_states)

        return log_probs

    def generate(self, features, periods, lpcs):
        """Synthesise a signal sample by sample, without gradients.

        Args:
            features: unbatched 2-D tensor indexed as (frame, feature);
                ``feature_history + feature_lookahead`` of its frames are
                context and produce no output.
            periods: unbatched period indices matching ``features``.
            lpcs: per-frame LPC coefficients; the last dimension is the LPC
                order.

        Returns:
            int16 CPU tensor with ``num_frames * frame_size`` samples.

        All working buffers are allocated on the device of the model
        parameters, so the inputs, buffers and model never sit on different
        devices. NOTE(review): running on a non-CPU device additionally
        requires the helpers in ``utils`` to be device-agnostic -- confirm.
        """
        with torch.no_grad():
            device = next(self.parameters()).device

            num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
            num_samples = num_frames * self.frame_size
            lpc_order = lpcs.shape[-1]
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # resolve the optional input signal slots once, outside the loops
            signal_layout = self.input_layout['signals']
            prediction_slot = signal_layout['prediction'] if 'prediction' in signal_layout else None
            last_signal_slot = signal_layout['last_signal'] if 'last_signal' in signal_layout else None
            last_error_slot = signal_layout['last_error'] if 'last_error' in signal_layout else None

            # signal buffers; pcm is preceded by lpc_order zeros of history
            pcm = torch.zeros(num_samples + lpc_order, device=device)
            output = torch.zeros(num_samples, dtype=torch.int16)
            mem = 0

            # state buffers
            gru_states = [
                torch.zeros((1, 1, self.gru_a_units), device=device),
                torch.zeros((1, 1, self.gru_b_units), device=device),
            ]

            # every input signal starts at 128 (presumably the mu-law code for
            # zero -- confirm against utils.ulaw)
            input_signals = torch.full(
                (1, 1, len(signal_layout)), 128, dtype=torch.long, device=device
            )

            # push data to device
            features = features.to(device)
            periods = periods.to(device)
            lpcs = lpcs.to(device)

            # lpc weighting: coefficient i is scaled by lpc_gamma ** (i + 1)
            weights = torch.tensor(
                [self.lpc_gamma ** (i + 1) for i in range(lpc_order)],
                dtype=torch.float32,
                device=device,
            )
            lpcs = lpcs * weights

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                feature_index = frame_index + self.feature_history
                pitch_corr = features[feature_index, pitch_corr_position]
                # negated and reversed so it lines up with the oldest-first
                # slice of pcm history below
                a = - torch.flip(lpcs[feature_index], [0])
                current_c = c[:, frame_index : frame_index + 1, :]

                for i in range(self.frame_size):
                    output_position = frame_start + i
                    pcm_position = output_position + lpc_order

                    # prepare input: linear prediction from the last
                    # lpc_order generated samples
                    pred = torch.sum(pcm[pcm_position - lpc_order : pcm_position] * a)
                    if prediction_slot is not None:
                        input_signals[0, 0, prediction_slot] = lin2ulawq(pred)

                    # run single step of sample rate network
                    probs, gru_states = self.decoder(input_signals, current_c, gru_states)

                    # sample from output
                    exc_ulaw = sample_excitation(probs, pitch_corr)

                    # signal generation
                    exc = ulaw2lin(exc_ulaw)
                    sig = exc + pred
                    pcm[pcm_position] = sig
                    # one-pole recursive filter on the output (presumably
                    # de-emphasis -- confirm)
                    mem = 0.85 * mem + float(sig)
                    output[output_position] = clip_to_int16(round(mem))

                    # buffer update
                    if last_signal_slot is not None:
                        input_signals[0, 0, last_signal_slot] = lin2ulawq(sig)

                    if last_error_slot is not None:
                        input_signals[0, 0, last_error_slot] = lin2ulawq(exc)

            return output