xref: /aosp_15_r20/external/libopus/dnn/torch/osce/models/lace.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import torch
31from torch import nn
32import torch.nn.functional as F
33
34import numpy as np
35
36from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
37from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
38
39from models.nns_base import NNSBase
40from models.silk_feature_net_pl import SilkFeatureNetPL
41from models.silk_feature_net import SilkFeatureNet
42from .scale_embedding import ScaleEmbedding
43
44import sys
45sys.path.append('../dnntools')
46
47from dnntools.sparsification import create_sparsifier
48
49
class LACE(NNSBase):
    """ Linear-Adaptive Coding Enhancer

    Processing chain (see forward): a feature net turns the per-frame input
    features, concatenated with a pitch (period) embedding and a numbits
    embedding, into conditioning vectors. Those vectors drive two adaptive comb
    filters (cf1, cf2) followed by one adaptive convolution (af1) applied to the
    signal.
    """

    # samples per frame; flop_count derives the frame rate as rate / FRAME_SIZE
    FRAME_SIZE = 80

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=91,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=None,
                 conv_gain_limits_db=None,
                 numbits_range=None,
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 softquant=False,
                 sparsify=False,
                 sparsification_schedule=None,
                 sparsification_density=0.5,
                 apply_weight_norm=False):
        """ Build the LACE model.

        Args:
            num_features: size of the per-frame input feature vector.
            pitch_embedding_dim: size of the learned pitch (period) embedding.
            cond_dim: size of the conditioning vectors produced by the feature net.
            pitch_max: largest period index accepted by the pitch embedding (the
                table has pitch_max + 1 rows). The comb filters get
                max_lag = pitch_max + 1.
            kernel_size: kernel length of the comb filters and the adaptive conv.
            preemph: pre-emphasis coefficient, forwarded to NNSBase and stored.
            skip: forwarded to NNSBase and stored.
            comb_gain_limit_db: gain limit (dB) passed to both comb filters.
            global_gain_limits_db: global gain limits (dB) for the comb filters.
                None means the default [-6, 6].
            conv_gain_limits_db: gain limits (dB) for the adaptive convolution.
                None means the default [-6, 6].
            numbits_range: (min, max) passed to the log-scale numbits
                embedding. None means the default [50, 650].
            numbits_embedding_dim: size of the numbits embedding. The feature net
                input reserves 2 * numbits_embedding_dim entries for it.
            hidden_feature_dim: hidden size of SilkFeatureNetPL (only used when
                partial_lookahead is True).
            partial_lookahead: use SilkFeatureNetPL if True, else SilkFeatureNet.
            norm_p: forwarded to the comb filters and the adaptive convolution.
            softquant: forwarded to SilkFeatureNetPL and to the adaptive filters.
            sparsify: if True, attach a sparsifier from create_sparsifier. Also
                forwarded to SilkFeatureNetPL.
            sparsification_schedule: extra positional arguments for
                create_sparsifier. None means the default [10000, 30000, 100].
            sparsification_density: forwarded to SilkFeatureNetPL.
            apply_weight_norm: forwarded to SilkFeatureNetPL and to the adaptive
                filters.
        """

        # List defaults are created per call. A mutable default would be one
        # shared object across all instances (numbits_range is stored on self and
        # the gain limits are handed to sub-modules), so mutation would leak.
        if global_gain_limits_db is None:
            global_gain_limits_db = [-6, 6]
        if conv_gain_limits_db is None:
            conv_gain_limits_db = [-6, 6]
        if numbits_range is None:
            numbits_range = [50, 650]
        if sparsification_schedule is None:
            sparsification_schedule = [10000, 30000, 100]

        super().__init__(skip=skip, preemph=preemph)

        self.num_features           = num_features
        self.cond_dim               = cond_dim
        self.pitch_max              = pitch_max
        self.pitch_embedding_dim    = pitch_embedding_dim
        self.kernel_size            = kernel_size
        self.preemph                = preemph
        self.skip                   = skip
        self.numbits_range          = numbits_range
        self.numbits_embedding_dim  = numbits_embedding_dim
        self.hidden_feature_dim     = hidden_feature_dim
        self.partial_lookahead      = partial_lookahead

        # pitch embedding: one row per period index 0 .. pitch_max
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding (log-scaled over numbits_range)
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net input = raw features + pitch embedding + flattened numbits embedding
        feature_net_input_dim = num_features + pitch_embedding_dim + 2 * numbits_embedding_dim
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(
                feature_net_input_dim,
                cond_dim,
                hidden_feature_dim,
                softquant=softquant,
                sparsify=sparsify,
                sparsification_density=sparsification_density,
                apply_weight_norm=apply_weight_norm,
            )
        else:
            self.feature_net = SilkFeatureNet(feature_net_input_dim, cond_dim)

        # comb filters: kernel padding split into left half and right remainder
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad

        def make_comb_filter():
            # builds one comb filter; a fresh padding list per filter, as before
            return LimitedAdaptiveComb1d(
                self.kernel_size,
                cond_dim,
                frame_size=self.FRAME_SIZE,
                overlap_size=40,
                use_bias=False,
                padding=[left_pad, right_pad],
                max_lag=pitch_max + 1,
                gain_limit_db=comb_gain_limit_db,
                global_gain_limits_db=global_gain_limits_db,
                norm_p=norm_p,
                softquant=softquant,
                apply_weight_norm=apply_weight_norm,
            )

        # creation order (cf1 before cf2) is kept so parameter init order is unchanged
        self.cf1 = make_comb_filter()
        self.cf2 = make_comb_filter()

        # spectral shaping: padding=[kernel_size - 1, 0] (presumably left-only,
        # i.e. causal; depends on LimitedAdaptiveConv1d's padding convention)
        self.af1 = LimitedAdaptiveConv1d(
            1,
            1,
            self.kernel_size,
            cond_dim,
            frame_size=self.FRAME_SIZE,
            padding=[self.kernel_size - 1, 0],
            gain_limits_db=conv_gain_limits_db,
            norm_p=norm_p,
            softquant=softquant,
            apply_weight_norm=apply_weight_norm,
        )

        if sparsify:
            self.sparsifier = create_sparsifier(self, *sparsification_schedule)

    def flop_count(self, rate=16000, verbose=False):
        """ Estimate the model's compute in FLOPS at sample rate `rate`.

        Sums the sub-module estimates: the feature net runs at the frame rate
        (rate / FRAME_SIZE), the two comb filters and the adaptive convolution
        at the sample rate.

        Args:
            rate: sample rate in Hz.
            verbose: if True, print the per-component breakdown in MFLOPS.

        Returns:
            Total estimated FLOPS (sum of feature net, comb filters, adaptive conv).
        """

        frame_rate = rate / self.FRAME_SIZE

        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate)

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops

    def forward(self, x, features, periods, numbits, debug=False):
        """ Enhance signal `x` using the per-frame side information.

        Args:
            x: input signal. get_impulse_responses feeds it as
                (batch, 1, num_frames * FRAME_SIZE); other layouts unverified.
            features: per-frame features, concatenated with the embeddings along
                the last dimension.
            periods: integer period indices in [0, pitch_max]. A trailing
                singleton dimension is squeezed away before the embedding lookup.
            numbits: values for the numbits embedding. The embedding output is
                flattened from dim 2 on. Presumably two values per frame, given
                the feature net's 2 * numbits_embedding_dim input slot (TODO
                confirm).
            debug: forwarded to the adaptive filters.

        Returns:
            Output of the chain cf1 -> cf2 -> af1.
        """

        periods         = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        y = self.cf1(x, cf, periods, debug=debug)
        y = self.cf2(y, cf, periods, debug=debug)
        y = self.af1(y, cf, debug=debug)

        return y

    def get_impulse_responses(self, features, periods, numbits):
        """ Generate impulse responses at frame centers (input without batch dimension).

        A unit pulse is placed at the center of every frame. To keep pulses apart,
        they are interleaved over a batch of 32 rows: row b carries pulses at
        frames b, b + 32, b + 64, ... The side information is repeated for every
        row.

        Args:
            features: per-frame features; features.size(0) is the number of frames.
            periods: per-frame period indices, same leading dimension as features.
            numbits: per-frame numbits values, same leading dimension as features.

        Returns:
            numpy array of shape (num_responses, max_len) with
            max_len = 2 * (pitch_max + kernel_size) + 10. Row i is the response to
            the pulse at the center of frame i. Frames too close to the end to hold
            a full max_len response are dropped. For very short inputs this yields
            an empty (0, max_len) array.
        """

        num_frames = features.size(0)
        batch_size = 32
        max_len = 2 * (self.pitch_max + self.kernel_size) + 10

        # spread out pulses: row b gets a pulse at the center of every batch_size-th frame
        x = np.zeros((batch_size, 1, num_frames * self.FRAME_SIZE))
        for b in range(batch_size):
            x[b, :, self.FRAME_SIZE // 2 + b * self.FRAME_SIZE :: batch_size * self.FRAME_SIZE] = 1

        # prepare input: same side information for every row
        x = torch.from_numpy(x).float().to(features.device)
        features = torch.repeat_interleave(features.unsqueeze(0), batch_size, 0)
        periods = torch.repeat_interleave(periods.unsqueeze(0), batch_size, 0)
        numbits = torch.repeat_interleave(numbits.unsqueeze(0), batch_size, 0)

        # run network: identical computation to forward(), without gradient tracking
        with torch.no_grad():
            y = self.forward(x, features, periods, numbits, debug=False)

        # collect responses
        y = y.detach().squeeze().cpu().numpy()
        # drop trailing frames whose response window would run past the signal end
        cut_frames = (max_len + self.FRAME_SIZE - 1) // self.FRAME_SIZE
        # clamp: inputs shorter than cut_frames previously crashed in np.zeros
        num_responses = max(num_frames - cut_frames, 0)
        responses = np.zeros((num_responses, max_len))

        for i in range(num_responses):
            b = i % batch_size
            start = self.FRAME_SIZE // 2 + i * self.FRAME_SIZE
            responses[i, :] = y[b, start:start + max_len]

        return responses
190        return responses
191