1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30 31import torch 32from torch import nn 33import torch.nn.functional as F 34 35import numpy as np 36 37from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d 38from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d 39from utils.layers.td_shaper import TDShaper 40from utils.layers.noise_shaper import NoiseShaper 41from utils.complexity import _conv1d_flop_count 42from utils.endoscopy import write_data 43 44from models.nns_base import NNSBase 45from models.lpcnet_feature_net import LPCNetFeatureNet 46from .scale_embedding import ScaleEmbedding 47 48def print_channels(y, prefix="", name="", rate=16000): 49 num_channels = y.size(1) 50 for i in range(num_channels): 51 channel_name = f"{prefix}_c{i:02d}" 52 if len(name) > 0: channel_name += "_" + name 53 ch = y[0,i,:].detach().cpu().numpy() 54 ch = ((2**14) * ch / np.max(ch)).astype(np.int16) 55 write_data(channel_name, ch, rate) 56 57 58 59class LaVoce(nn.Module): 60 """ Linear-Adaptive VOCodEr """ 61 FEATURE_FRAME_SIZE=160 62 FRAME_SIZE=80 63 64 def __init__(self, 65 num_features=20, 66 pitch_embedding_dim=64, 67 cond_dim=256, 68 pitch_max=300, 69 kernel_size=15, 70 preemph=0.85, 71 comb_gain_limit_db=-6, 72 global_gain_limits_db=[-6, 6], 73 conv_gain_limits_db=[-6, 6], 74 norm_p=2, 75 avg_pool_k=4, 76 pulses=False, 77 innovate1=True, 78 innovate2=False, 79 innovate3=False, 80 ftrans_k=2): 81 82 super().__init__() 83 84 85 self.num_features = num_features 86 self.cond_dim = cond_dim 87 self.pitch_max = pitch_max 88 self.pitch_embedding_dim = pitch_embedding_dim 89 self.kernel_size = kernel_size 90 self.preemph = preemph 91 self.pulses = pulses 92 self.ftrans_k = ftrans_k 93 94 assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0 95 self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE 96 97 # pitch embedding 98 self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim) 99 100 # feature net 101 self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor) 102 103 # noise shaper 104 self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE) 105 106 # comb filters 107 left_pad = self.kernel_size // 2 108 
        # comb filters
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)

        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # spectral shaping
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # non-linear transforms
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate1)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate2)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate3)

        # combinators
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # feature transforms
        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)

    def create_phase_signals(self, periods):

        batch_size = periods.size(0)
        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            if self.pulses:
                alpha = torch.cos(f).view(batch_size, 1, 1)
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)

                chunk = torch.cat((pulse_a, pulse_b), dim=1)
            else:
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)

                chunk = torch.cat((chunk_sin, chunk_cos), dim=1)

            phase0 = phase0 + self.FRAME_SIZE * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=-1)

        return phase_signals
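    # Note on create_phase_signals: the running offset phase0 advances by
    # FRAME_SIZE * f after every subframe, so the sinusoids (or pulse pairs)
    # remain phase-continuous across subframe boundaries even when the pitch
    # period changes from one subframe to the next.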
    def flop_count(self, rate=16000, verbose=False):

        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops

    def feature_transform(self, f, layer):
        f = f.permute(0, 2, 1)
        f = F.pad(f, [self.ftrans_k - 1, 0])
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def forward(self, features, periods, debug=False):

        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)

        full_features = torch.cat((features, pitch_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # upsample periods
        periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)

        # pre-net
        ref_phase = torch.tanh(self.create_phase_signals(periods))
        if debug: print_channels(ref_phase, prefix="lavoce_01", name="pulse")
        x = self.af_prescale(ref_phase, cf)
        noise = self.noise_shaper(cf)
        if debug: print_channels(torch.cat((x, noise), dim=1), prefix="lavoce_02", name="inputs")
        y = self.af_mix(torch.cat((x, noise), dim=1), cf)
        if debug: print_channels(y, prefix="lavoce_03", name="postselect1")

        # temporal shaping + innovating
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_04", name="postshape1")
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_05", name="postselect2")
        cf = self.feature_transform(cf, self.post_af2)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_06", name="postshape2")
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_07", name="postmix1")
        cf = self.feature_transform(cf, self.post_af3)

        # spectral shaping
        y = self.cf1(y, cf, periods, debug=debug)
        if debug: print_channels(y, prefix="lavoce_08", name="postcomb1")
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        if debug: print_channels(y, prefix="lavoce_09", name="postcomb2")
        cf = self.feature_transform(cf, self.post_cf2)

        y = self.af1(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_10", name="postselect3")
        cf = self.feature_transform(cf, self.post_af1)

        # final temporal env adjustment
        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_11", name="postshape3")
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_12", name="postmix2")

        return y
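    # Note on forward(): the signal path runs harmonic/noise generation
    # (af_prescale, noise_shaper, af_mix), two temporal shaping + mixing stages
    # (tdshape1/af2, tdshape2/af3), spectral shaping (cf1, cf2, af1), and a
    # final temporal envelope stage (tdshape3/af4); the conditioning cf is
    # refreshed after each stage by the corresponding post_* feature transform.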
    def process(self, features, periods, debug=False):

        self.eval()
        device = next(iter(self.parameters())).device
        with torch.no_grad():

            # run model
            f = features.unsqueeze(0).to(device)
            p = periods.unsqueeze(0).to(device)

            y = self.forward(f, p, debug=debug).squeeze()

            # deemphasis
            if self.preemph > 0:
                for i in range(len(y) - 1):
                    y[i + 1] += self.preemph * y[i]

            # clip to valid range
            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()

        return out
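
# Minimal smoke test, a sketch rather than part of the model code: the shapes
# and value ranges below are assumptions based on the constructor defaults
# (20 input features, pitch period indices up to pitch_max=300); random
# features only exercise the plumbing, not synthesis quality.
if __name__ == "__main__":
    model = LaVoce()
    model.flop_count(verbose=True)

    num_frames = 10  # 10 feature frames -> 10 * 160 = 1600 output samples
    features = torch.randn(num_frames, 20)
    periods = torch.randint(32, 300, (num_frames, 1))

    out = model.process(features, periods)
    print(f"output: {tuple(out.shape)}, dtype {out.dtype}")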