1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30
31import torch
32from torch import nn
33import torch.nn.functional as F
34
35import numpy as np
36
37from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
38from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
39from utils.layers.td_shaper import TDShaper
40from utils.layers.noise_shaper import NoiseShaper
41from utils.complexity import _conv1d_flop_count
42from utils.endoscopy import write_data
43
44from models.nns_base import NNSBase
45from models.lpcnet_feature_net import LPCNetFeatureNet
46from .scale_embedding import ScaleEmbedding
47
48class LaVoce400(nn.Module):
49    """ Linear-Adaptive VOCodEr """
50    FEATURE_FRAME_SIZE=160
51    FRAME_SIZE=40
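    # 160 and 40 samples correspond to 10 ms feature frames and 2.5 ms synthesis
    # sub-frames at the 16 kHz rate assumed by flop_count() and the debug dumps in forward()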

    def __init__(self,
                 num_features=20,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=300,
                 kernel_size=15,
                 preemph=0.85,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 norm_p=2,
                 avg_pool_k=4,
                 pulses=False):

        super().__init__()


        self.num_features           = num_features
        self.cond_dim               = cond_dim
        self.pitch_max              = pitch_max
        self.pitch_embedding_dim    = pitch_embedding_dim
        self.kernel_size            = kernel_size
        self.preemph                = preemph
        self.pulses                 = pulses

        assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
        self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # feature net
        self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)

        # noise shaper
        self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)

        # comb filters
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
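        # both comb filters are conditioned on cf and take the per-sub-frame pitch
        # period as lag in forward(); max_lag matches the pitch embedding range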


        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # spectral shaping
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # non-linear transforms
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=True)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)

        # combinators
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # feature transforms
        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)


    def create_phase_signals(self, periods):
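        """ Generates a two-channel excitation from integer pitch periods.

        For each sub-frame this produces either a sine/cosine pair at the pitch
        frequency or, if self.pulses is set, a pair of rectified pulse trains
        derived from the sine. Shapes below are inferred from the code:
        periods is (batch, sub-frames) and the result is
        (batch, 2, sub-frames * FRAME_SIZE).
        """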

        batch_size = periods.size(0)
        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            if self.pulses:
                alpha = torch.cos(f).view(batch_size, 1, 1)
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)

                chunk = torch.cat((pulse_a, pulse_b), dim=1)
            else:
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)

                chunk = torch.cat((chunk_sin, chunk_cos), dim=1)

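            # advance the start phase by one sub-frame so consecutive chunks stay
            # phase-continuous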
            phase0 = phase0 + self.FRAME_SIZE * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=-1)

        return phase_signals

    def flop_count(self, rate=16000, verbose=False):
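        """ Rough complexity estimate in FLOPS at the given sample rate.

        Sums the counts reported by the adaptive filtering layers, the feature
        net and the kernel-size-2 feature-transform convolutions; optionally
        prints a per-component breakdown in MFLOPS.
        """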

        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops

    def feature_transform(self, f, layer):
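        """ Applies a causal kernel-size-2 convolution plus tanh to the conditioning features. """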
        f = f.permute(0, 2, 1)
        f = F.pad(f, [1, 0])
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def forward(self, features, periods, debug=False):
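        """ Synthesizes speech from acoustic features and integer pitch periods.

        Shapes are inferred from the layer definitions above: features is
        (batch, frames, num_features), periods holds integer pitch lags in
        [0, pitch_max] with shape (batch, frames) or (batch, frames, 1), and the
        output is (batch, 1, frames * FEATURE_FRAME_SIZE).
        """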

        periods         = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)

        full_features = torch.cat((features, pitch_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # upsample periods
        periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)

        # pre-net
        ref_phase = torch.tanh(self.create_phase_signals(periods))
        x = self.af_prescale(ref_phase, cf)
        noise = self.noise_shaper(cf)
        y = self.af_mix(torch.cat((x, noise), dim=1), cf)

        if debug:
            ch0 = y[0, 0, :].detach().cpu().numpy()
            ch1 = y[0, 1, :].detach().cpu().numpy()
            ch0 = (2**15 * ch0 / np.max(ch0)).astype(np.int16)
            ch1 = (2**15 * ch1 / np.max(ch1)).astype(np.int16)
            write_data('prior_channel0', ch0, 16000)
            write_data('prior_channel1', ch1, 16000)

        # temporal shaping + innovating
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af2)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af3)

        # spectral shaping
        y = self.cf1(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf2)

        y = self.af1(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af1)

        # final temporal env adjustment
        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)

        return y

    def process(self, features, periods, debug=False):
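        """ Convenience inference wrapper for a single, unbatched utterance.

        Runs forward() under torch.no_grad(), applies de-emphasis with the IIR
        filter 1 / (1 - preemph * z^-1) and returns the result scaled and
        clipped to int16 PCM.
        """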

        self.eval()
        device = next(iter(self.parameters())).device
        with torch.no_grad():

            # run model
            f = features.unsqueeze(0).to(device)
            p = periods.unsqueeze(0).to(device)

            y = self.forward(f, p, debug=debug).squeeze()

            # deemphasis
            if self.preemph > 0:
                for i in range(len(y) - 1):
                    y[i + 1] += self.preemph * y[i]

            # clip to valid range
            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()

        return out
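

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): instantiates the model
    # with default hyperparameters and runs it on random features, purely to
    # illustrate the tensor shapes expected by process(). The feature values are
    # meaningless, so the output is not intelligible speech.
    model = LaVoce400()

    num_frames = 50
    features = torch.randn(num_frames, 20)              # one 20-dim feature vector per 10 ms frame
    periods = torch.randint(32, 300, (num_frames, 1))   # integer pitch lags in samples, <= pitch_max

    pcm = model.process(features, periods)               # int16 tensor with num_frames * 160 samples
    print(pcm.shape)

    model.flop_count(verbose=True)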