1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30 31import torch 32from torch import nn 33import torch.nn.functional as F 34 35import numpy as np 36 37from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d 38from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d 39from utils.layers.td_shaper import TDShaper 40from utils.layers.noise_shaper import NoiseShaper 41from utils.complexity import _conv1d_flop_count 42from utils.endoscopy import write_data 43 44from models.nns_base import NNSBase 45from models.lpcnet_feature_net import LPCNetFeatureNet 46from .scale_embedding import ScaleEmbedding 47 48class LaVoce400(nn.Module): 49 """ Linear-Adaptive VOCodEr """ 50 FEATURE_FRAME_SIZE=160 51 FRAME_SIZE=40 52 53 def __init__(self, 54 num_features=20, 55 pitch_embedding_dim=64, 56 cond_dim=256, 57 pitch_max=300, 58 kernel_size=15, 59 preemph=0.85, 60 comb_gain_limit_db=-6, 61 global_gain_limits_db=[-6, 6], 62 conv_gain_limits_db=[-6, 6], 63 norm_p=2, 64 avg_pool_k=4, 65 pulses=False): 66 67 super().__init__() 68 69 70 self.num_features = num_features 71 self.cond_dim = cond_dim 72 self.pitch_max = pitch_max 73 self.pitch_embedding_dim = pitch_embedding_dim 74 self.kernel_size = kernel_size 75 self.preemph = preemph 76 self.pulses = pulses 77 78 assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0 79 self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE 80 81 # pitch embedding 82 self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim) 83 84 # feature net 85 self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor) 86 87 # noise shaper 88 self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE) 89 90 # comb filters 91 left_pad = self.kernel_size // 2 92 right_pad = self.kernel_size - 1 - left_pad 93 self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) 94 self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) 95 96 97 self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) 98 self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) 99 100 # spectral shaping 101 self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) 102 103 # non-linear transforms 104 self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=True) 105 self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k) 106 self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k) 107 108 # combinators 109 self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) 110 self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) 111 self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) 112 113 # feature transforms 114 self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2) 115 self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2) 116 self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2) 117 self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2) 118 self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2) 119 120 121 def create_phase_signals(self, periods): 122 123 batch_size = periods.size(0) 124 progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1)) 125 progression = torch.repeat_interleave(progression, batch_size, 0) 126 127 phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1) 128 chunks = [] 129 for sframe in range(periods.size(1)): 130 f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1) 131 132 if self.pulses: 133 alpha = torch.cos(f).view(batch_size, 1, 1) 134 chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE) 135 pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha) 136 pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha) 137 138 chunk = torch.cat((pulse_a, pulse_b), dim = 1) 139 else: 140 chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE) 141 chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE) 142 143 chunk = torch.cat((chunk_sin, chunk_cos), dim = 1) 144 145 phase0 = phase0 + self.FRAME_SIZE * f 146 147 chunks.append(chunk) 148 149 phase_signals = torch.cat(chunks, dim=-1) 150 151 return phase_signals 152 153 def flop_count(self, rate=16000, verbose=False): 154 155 frame_rate = rate / self.FRAME_SIZE 156 157 # feature net 158 feature_net_flops = self.feature_net.flop_count(frame_rate) 159 comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate) 160 af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate) 161 feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate) 162 + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate)) 163 164 if verbose: 165 print(f"feature net: {feature_net_flops / 1e6} MFLOPS") 166 print(f"comb filters: {comb_flops / 1e6} MFLOPS") 167 print(f"adaptive conv: {af_flops / 1e6} MFLOPS") 168 print(f"feature transforms: {feature_flops / 1e6} MFLOPS") 169 170 return feature_net_flops + comb_flops + af_flops + feature_flops 171 172 def feature_transform(self, f, layer): 173 f = f.permute(0, 2, 1) 174 f = F.pad(f, [1, 0]) 175 f = torch.tanh(layer(f)) 176 return f.permute(0, 2, 1) 177 178 def forward(self, features, periods, debug=False): 179 180 periods = periods.squeeze(-1) 181 pitch_embedding = self.pitch_embedding(periods) 182 183 full_features = torch.cat((features, pitch_embedding), dim=-1) 184 cf = self.feature_net(full_features) 185 186 # upsample periods 187 periods = torch.repeat_interleave(periods, self.upsamp_factor, 1) 188 189 # pre-net 190 ref_phase = torch.tanh(self.create_phase_signals(periods)) 191 x = self.af_prescale(ref_phase, cf) 192 noise = self.noise_shaper(cf) 193 y = self.af_mix(torch.cat((x, noise), dim=1), cf) 194 195 if debug: 196 ch0 = y[0,0,:].detach().cpu().numpy() 197 ch1 = y[0,1,:].detach().cpu().numpy() 198 ch0 = (2**15 * ch0 / np.max(ch0)).astype(np.int16) 199 ch1 = (2**15 * ch1 / np.max(ch1)).astype(np.int16) 200 write_data('prior_channel0', ch0, 16000) 201 write_data('prior_channel1', ch1, 16000) 202 203 # temporal shaping + innovating 204 y1 = y[:, 0:1, :] 205 y2 = self.tdshape1(y[:, 1:2, :], cf) 206 y = torch.cat((y1, y2), dim=1) 207 y = self.af2(y, cf, debug=debug) 208 cf = self.feature_transform(cf, self.post_af2) 209 210 y1 = y[:, 0:1, :] 211 y2 = self.tdshape2(y[:, 1:2, :], cf) 212 y = torch.cat((y1, y2), dim=1) 213 y = self.af3(y, cf, debug=debug) 214 cf = self.feature_transform(cf, self.post_af3) 215 216 # spectral shaping 217 y = self.cf1(y, cf, periods, debug=debug) 218 cf = self.feature_transform(cf, self.post_cf1) 219 220 y = self.cf2(y, cf, periods, debug=debug) 221 cf = self.feature_transform(cf, self.post_cf2) 222 223 y = self.af1(y, cf, debug=debug) 224 cf = self.feature_transform(cf, self.post_af1) 225 226 # final temporal env adjustment 227 y1 = y[:, 0:1, :] 228 y2 = self.tdshape3(y[:, 1:2, :], cf) 229 y = torch.cat((y1, y2), dim=1) 230 y = self.af4(y, cf, debug=debug) 231 232 return y 233 234 def process(self, features, periods, debug=False): 235 236 self.eval() 237 device = next(iter(self.parameters())).device 238 with torch.no_grad(): 239 240 # run model 241 f = features.unsqueeze(0).to(device) 242 p = periods.unsqueeze(0).to(device) 243 244 y = self.forward(f, p, debug=debug).squeeze() 245 246 # deemphasis 247 if self.preemph > 0: 248 for i in range(len(y) - 1): 249 y[i + 1] += self.preemph * y[i] 250 251 # clip to valid range 252 out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short() 253 254 return out