"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import numbers

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm


import numpy as np

from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.complexity import _conv1d_flop_count

from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding

import sys
sys.path.append('../dnntools')
from dnntools.quantization import soft_quant
from dnntools.sparsification import create_sparsifier, mark_for_sparsification


class NoLACE(NNSBase):
    """ Non-Linear Adaptive Coding Enhancer

    Signal path (see ``forward``):

    1. two adaptive comb filters (``cf1``, ``cf2``) driven by the pitch lag,
    2. an adaptive conv ``af1`` that maps the 1-channel signal to 2 channels,
    3. three rounds of "time-domain shaping of channel 1 (``tdshape1..3``)
       followed by an adaptive 2-channel conv (``af2``, ``af3``, ``af4``)";
       ``af4`` reduces the result back to a single channel.

    Every adaptive layer is conditioned on per-frame features ``cf`` produced
    by ``feature_net``; after each stage ``cf`` is refined by a small causal
    conv (``post_cf1``, ``post_cf2``, ``post_af1..3``).
    """
    FRAME_SIZE = 80  # samples per conditioning frame, passed to every adaptive layer

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=91,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 avg_pool_k=4,
                 pool_after=False,
                 softquant=False,
                 sparsify=False,
                 sparsification_schedule=[100, 1000, 100],
                 sparsification_density=0.5,
                 apply_weight_norm=False):
        """Build the enhancer.

        Args:
            num_features: size of the per-frame input feature vector.
            pitch_embedding_dim: size of the learned pitch-lag embedding.
            cond_dim: size of the conditioning vector fed to all adaptive layers.
            pitch_max: largest pitch lag; the embedding has ``pitch_max + 1``
                rows and the comb filters use ``max_lag=pitch_max + 1``.
            kernel_size: kernel length of the comb filters and adaptive convs.
            preemph, skip: forwarded to ``NNSBase`` and also stored on self.
            comb_gain_limit_db: per-filter gain limit of the comb filters.
            global_gain_limits_db: global gain limits of the comb filters.
            conv_gain_limits_db: gain limits of the adaptive convs ``af1..af4``.
            numbits_range: (low, high) range for the log-scale bit-rate embedding.
            numbits_embedding_dim: size of the bit-rate embedding.
            hidden_feature_dim: hidden size of the partial-lookahead feature net.
            partial_lookahead: use ``SilkFeatureNetPL`` if True, else ``SilkFeatureNet``.
            norm_p: norm order forwarded to the adaptive layers.
            avg_pool_k, pool_after: forwarded to the ``TDShaper`` layers.
            softquant: wrap layers with soft quantization.
            sparsify: mark the ``post_*`` convs for sparsification and create
                ``self.sparsifier``.
            sparsification_schedule: positional args for ``create_sparsifier``.
            sparsification_density: a single density, or a per-layer sequence
                (indices 4..8 are used here for the ``post_*`` convs).
            apply_weight_norm: apply weight normalization to the layers.
        """

        super().__init__(skip=skip, preemph=preemph)

        # Default-argument lists are created once and shared by every call.
        # Copy the one we keep on the instance so that an in-place edit of
        # ``model.numbits_range`` cannot alter the default for later models.
        numbits_range = list(numbits_range)

        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead

        # A scalar density is broadcast to one entry per sparsifiable layer.
        # Indices 4..8 are consumed below; the lower ones presumably go to the
        # feature net, which receives the whole list (TODO confirm).
        if isinstance(sparsification_density, numbers.Number):
            sparsification_density = 10 * [sparsification_density]

        # identity stand-in with the same call signature as weight_norm
        norm = weight_norm if apply_weight_norm else lambda x, name=None: x

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net: input = features + pitch embedding + flattened numbits
        # embedding (2 * numbits_embedding_dim wide, see forward)
        feature_net_input_dim = num_features + pitch_embedding_dim + 2 * numbits_embedding_dim
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(feature_net_input_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
        else:
            self.feature_net = SilkFeatureNet(feature_net_input_dim, cond_dim)

        # comb filters (centered padding)
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # spectral shaping: 1 -> 2 channels, left-only (causal) padding
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # non-linear transforms (applied to channel 1 only, see forward)
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # combinators: 2 -> 2, 2 -> 2, and finally 2 -> 1 channels
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # feature transforms: kernel-size-2 convs, made causal in feature_transform
        self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2))

        if softquant:
            self.post_cf1 = soft_quant(self.post_cf1)
            self.post_cf2 = soft_quant(self.post_cf2)
            self.post_af1 = soft_quant(self.post_af1)
            self.post_af2 = soft_quant(self.post_af2)
            self.post_af3 = soft_quant(self.post_af3)

        if sparsify:
            mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
            mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
            mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
            mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
            mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))

            self.sparsifier = create_sparsifier(self, *sparsification_schedule)

    def flop_count(self, rate=16000, verbose=False):
        """Estimate the compute cost of the model.

        Sample-rate components (comb filters, adaptive convs, TD shapers) are
        evaluated at ``rate``. Frame-rate components (feature net, ``post_*``
        feature transforms) are evaluated at ``rate / FRAME_SIZE``.

        Args:
            rate: audio sampling rate in Hz.
            verbose: if True, print a per-component breakdown in MFLOPS.

        Returns:
            The sum of all per-component counts.
        """
        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)

        # sample-rate signal path
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = (self.af1.flop_count(rate) + self.af2.flop_count(rate)
                    + self.af3.flop_count(rate) + self.af4.flop_count(rate))
        shape_flops = (self.tdshape1.flop_count(rate) + self.tdshape2.flop_count(rate)
                       + self.tdshape3.flop_count(rate))

        # frame-rate feature transforms
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            # previously missing from the breakdown although part of the total
            print(f"TD shapers: {shape_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops

    def feature_transform(self, f, layer):
        """Refine conditioning features with a causal, tanh-activated conv.

        ``f`` is channels-last (batch, frames, cond_dim). It is permuted to
        channels-first for ``layer`` (one of the kernel-size-2 ``post_*``
        convs). The single frame of left padding keeps the conv causal and
        the frame count unchanged. The result is returned channels-last.
        """
        f0 = f.permute(0, 2, 1)
        f = F.pad(f0, [1, 0])
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def forward(self, x, features, periods, numbits, debug=False):
        """Enhance ``x`` using per-frame side information.

        Args:
            x: input signal for ``cf1``; presumably (batch, 1, samples) —
                TODO confirm against the caller.
            features: per-frame features, concatenated with the embeddings on
                the last dim, so presumably (batch, frames, num_features).
            periods: integer pitch lags in ``[0, pitch_max]`` (embedding
                indices); a trailing singleton dim is squeezed.
            numbits: bit-rate side information for ``ScaleEmbedding``. Its
                flattened embedding must be ``2 * numbits_embedding_dim``
                wide, which presumably means two values per frame — verify.
            debug: forwarded to the comb filters and adaptive convs.

        Returns:
            Output of ``af4`` (single channel).
        """
        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # pitch-driven comb filtering
        y = self.cf1(x, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf2)

        # split into two channels
        y = self.af1(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af1)

        # shape channel 1, keep channel 0 untouched, then recombine
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af2)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af3)

        # last round reduces 2 -> 1 channels
        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)

        return y