xref: /aosp_15_r20/external/libopus/dnn/torch/osce/models/no_lace.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import numbers
31
32import torch
33from torch import nn
34import torch.nn.functional as F
35from torch.nn.utils import weight_norm
36
37
38import numpy as np
39
40from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
41from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
42from utils.layers.td_shaper import TDShaper
43from utils.complexity import _conv1d_flop_count
44
45from models.nns_base import NNSBase
46from models.silk_feature_net_pl import SilkFeatureNetPL
47from models.silk_feature_net import SilkFeatureNet
48from .scale_embedding import ScaleEmbedding
49
50import sys
51sys.path.append('../dnntools')
52from dnntools.quantization import soft_quant
53from dnntools.sparsification import create_sparsifier, mark_for_sparsification
54
class NoLACE(NNSBase):
    """ Non-Linear Adaptive Coding Enhancer

    Processing chain (see ``forward``):
      * two adaptive comb filters driven by the per-frame periods (cf1, cf2),
      * an adaptive conv constructed as ``(1, 2, ...)`` whose output is split
        into channels 0 and 1 (af1),
      * three rounds of "time-domain shaping of channel 1, channel 0 passed
        through, then adaptive combination" (tdshape1..3 with af2..af4).

    The conditioning features produced by the feature net are refreshed after
    every stage except the last. Each refresh is a left-padded kernel-2 conv
    followed by tanh (post_cf1, post_cf2, post_af1..post_af3).

    Args:
        num_features: number of per-frame input features in ``features``.
        pitch_embedding_dim: size of the learned embedding of each period index.
        cond_dim: width of the conditioning vector produced by the feature net
            and consumed by every adaptive layer.
        pitch_max: largest embeddable period index. The embedding table has
            ``pitch_max + 1`` rows, and the comb filters get
            ``max_lag = pitch_max + 1``.
        kernel_size: kernel length of the comb and adaptive conv filters.
        preemph, skip: forwarded to ``NNSBase`` and also stored on the instance.
        comb_gain_limit_db: forwarded to the comb filters as ``gain_limit_db``.
        global_gain_limits_db: forwarded to the comb filters.
        conv_gain_limits_db: forwarded to af1..af4 as ``gain_limits_db``.
        numbits_range: (min, max) unpacked into ``ScaleEmbedding`` with
            ``logscale=True``.
        numbits_embedding_dim: output size of the numbits embedding. The
            feature-net input reserves ``2 * numbits_embedding_dim`` for it
            (presumably two numbits values per frame -- TODO confirm).
        hidden_feature_dim: forwarded to ``SilkFeatureNetPL`` only.
        partial_lookahead: use ``SilkFeatureNetPL`` if True, else
            ``SilkFeatureNet``. The latter receives none of the softquant /
            sparsify / weight-norm options.
        norm_p: forwarded to the comb and adaptive conv layers.
        avg_pool_k, pool_after: forwarded to every ``TDShaper``.
        softquant: forwarded to sub-layers; also wraps the post_* convs with
            ``soft_quant``.
        sparsify: mark the post_* convs for sparsification and create
            ``self.sparsifier``.
        sparsification_schedule: unpacked into ``create_sparsifier``.
        sparsification_density: a scalar is broadcast to a list of 10 entries.
            Entries 4..8 are used for post_cf1, post_cf2, post_af1..post_af3;
            the whole list is passed to ``SilkFeatureNetPL``.
        apply_weight_norm: wrap the post_* convs in ``weight_norm``; also
            forwarded to sub-layers.
    """
    # Samples per conditioning frame; flop_count uses rate / FRAME_SIZE as the frame rate.
    FRAME_SIZE = 80

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=91,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 avg_pool_k=4,
                 pool_after=False,
                 softquant=False,
                 sparsify=False,
                 sparsification_schedule=[100, 1000, 100],
                 sparsification_density=0.5,
                 apply_weight_norm=False):
        # NOTE: the list defaults above are evaluated once and shared by every
        # call. They are kept for interface compatibility; the one stored on
        # the instance (numbits_range) is copied below so mutating it cannot
        # leak into later instances.

        super().__init__(skip=skip, preemph=preemph)

        self.num_features           = num_features
        self.cond_dim               = cond_dim
        self.pitch_max              = pitch_max
        self.pitch_embedding_dim    = pitch_embedding_dim
        self.kernel_size            = kernel_size
        self.preemph                = preemph
        self.skip                   = skip
        self.numbits_range          = list(numbits_range)
        self.numbits_embedding_dim  = numbits_embedding_dim
        self.hidden_feature_dim     = hidden_feature_dim
        self.partial_lookahead      = partial_lookahead

        if isinstance(sparsification_density, numbers.Number):
            sparsification_density = 10 * [sparsification_density]

        # identity stand-in with weight_norm's call signature when weight norm is off
        norm = weight_norm if apply_weight_norm else lambda x, name=None: x

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net
        feature_net_input_dim = num_features + pitch_embedding_dim + 2 * numbits_embedding_dim
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(feature_net_input_dim, cond_dim, hidden_feature_dim,
                                                softquant=softquant, sparsify=sparsify,
                                                sparsification_density=sparsification_density,
                                                apply_weight_norm=apply_weight_norm)
        else:
            self.feature_net = SilkFeatureNet(feature_net_input_dim, cond_dim)

        # Local factories for the repeated layers. Each call builds a fresh
        # instance (and a fresh padding list), and the calls below keep the
        # original construction order, so parameter init and state-dict keys
        # are unchanged.
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad

        def make_comb():
            return LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE,
                                         overlap_size=40, padding=[left_pad, right_pad],
                                         max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db,
                                         global_gain_limits_db=global_gain_limits_db, norm_p=norm_p,
                                         softquant=softquant, apply_weight_norm=apply_weight_norm)

        def make_adaptive_conv(in_ch, out_ch):
            # (in, out) channel counts, judging by the channel slicing in forward;
            # left padding of kernel_size - 1 only.
            return LimitedAdaptiveConv1d(in_ch, out_ch, self.kernel_size, cond_dim,
                                         frame_size=self.FRAME_SIZE, use_bias=False,
                                         padding=[self.kernel_size - 1, 0],
                                         gain_limits_db=conv_gain_limits_db, norm_p=norm_p,
                                         softquant=softquant, apply_weight_norm=apply_weight_norm)

        def make_shaper():
            return TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k,
                            pool_after=pool_after, softquant=softquant,
                            apply_weight_norm=apply_weight_norm)

        # comb filters
        self.cf1 = make_comb()
        self.cf2 = make_comb()

        # spectral shaping
        self.af1 = make_adaptive_conv(1, 2)

        # non-linear transforms
        self.tdshape1 = make_shaper()
        self.tdshape2 = make_shaper()
        self.tdshape3 = make_shaper()

        # combinators
        self.af2 = make_adaptive_conv(2, 2)
        self.af3 = make_adaptive_conv(2, 2)
        self.af4 = make_adaptive_conv(2, 1)

        # feature transforms (kernel size 2; feature_transform left-pads by one frame)
        self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2))

        if softquant:
            self.post_cf1 = soft_quant(self.post_cf1)
            self.post_cf2 = soft_quant(self.post_cf2)
            self.post_af1 = soft_quant(self.post_af1)
            self.post_af2 = soft_quant(self.post_af2)
            self.post_af3 = soft_quant(self.post_af3)

        if sparsify:
            # density entries 4..8 belong to the post_* convs; [8, 4] is the block shape
            # passed to mark_for_sparsification
            mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
            mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
            mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
            mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
            mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))

            self.sparsifier = create_sparsifier(self, *sparsification_schedule)

    def flop_count(self, rate=16000, verbose=False):
        """Return the summed complexity estimate of all sub-layers at sample rate ``rate``.

        Sample-rate layers are queried with ``rate``. The feature net and the
        post_* convs run once per frame and are queried with
        ``rate / FRAME_SIZE``. The unit is presumably FLOPs per second, given
        the MFLOPS labels printed when ``verbose`` is True.
        """
        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate)
        shape_flops = self.tdshape1.flop_count(rate) + self.tdshape2.flop_count(rate) + self.tdshape3.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            # previously missing, although shape_flops is part of the returned total
            print(f"td shapers: {shape_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops

    def feature_transform(self, f, layer):
        """Refresh conditioning features with ``layer`` followed by tanh.

        ``f`` is permuted (0, 2, 1) before and after ``layer``, i.e. ``layer``
        sees dims 1 and 2 swapped (presumably (batch, frames, cond_dim) ->
        (batch, cond_dim, frames) -- TODO confirm). The last dim is left-padded
        by one step, so a kernel-2 conv keeps the length unchanged and output
        step t depends only on input steps t - 1 and t.
        """
        f = F.pad(f.permute(0, 2, 1), [1, 0])
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def _shape_and_combine(self, y, cf, shaper, combinator, debug=False):
        """Shape channel 1 of ``y`` with ``shaper``, pass channel 0 through, and mix both with ``combinator``."""
        passthrough = y[:, 0:1, :]
        shaped = shaper(y[:, 1:2, :], cf)
        return combinator(torch.cat((passthrough, shaped), dim=1), cf, debug=debug)

    def forward(self, x, features, periods, numbits, debug=False):
        """Run the enhancement chain on ``x``.

        Args:
            x: input signal, presumably (batch, 1, samples) -- TODO confirm.
            features: per-frame features; concatenated on the last dim with the
                pitch and numbits embeddings before the feature net.
            periods: per-frame period indices in [0, pitch_max] (the embedding
                table size). A trailing size-1 dim, if present, is squeezed away.
            numbits: per-frame values fed to the numbits ``ScaleEmbedding``.
            debug: forwarded to the comb and adaptive conv layers.

        Returns:
            The output of af4, the final adaptive combination.
        """
        periods         = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        y = self.cf1(x, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf2)

        y = self.af1(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af1)

        y = self._shape_and_combine(y, cf, self.tdshape1, self.af2, debug=debug)
        cf = self.feature_transform(cf, self.post_af2)

        y = self._shape_and_combine(y, cf, self.tdshape2, self.af3, debug=debug)
        cf = self.feature_transform(cf, self.post_af3)

        # last stage: no further feature refresh is needed
        y = self._shape_and_combine(y, cf, self.tdshape3, self.af4, debug=debug)

        return y
218        return y
219