1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""


import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.noise_shaper import NoiseShaper
from utils.complexity import _conv1d_flop_count
from utils.endoscopy import write_data

from models.nns_base import NNSBase
from models.lpcnet_feature_net import LPCNetFeatureNet
from .scale_embedding import ScaleEmbedding


def print_channels(y, prefix="", name="", rate=16000):
    """Dump each channel of the first batch element of y as normalized
    16-bit PCM via the endoscopy debug writer."""
    num_channels = y.size(1)
    for i in range(num_channels):
        channel_name = f"{prefix}_c{i:02d}"
        if len(name) > 0:
            channel_name += "_" + name
        ch = y[0, i, :].detach().cpu().numpy()
        # normalize by the peak magnitude so values fit into int16
        # (scaling by np.max alone misbehaves on negative-dominant signals)
        ch = ((2**14) * ch / np.abs(ch).max()).astype(np.int16)
        write_data(channel_name, ch, rate)


class LaVoce(nn.Module):
    """ Linear-Adaptive VOCodEr """
    FEATURE_FRAME_SIZE = 160
    FRAME_SIZE = 80
    def __init__(self,
                 num_features=20,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=300,
                 kernel_size=15,
                 preemph=0.85,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 norm_p=2,
                 avg_pool_k=4,
                 pulses=False,
                 innovate1=True,
                 innovate2=False,
                 innovate3=False,
                 ftrans_k=2):

        super().__init__()

        self.num_features           = num_features
        self.cond_dim               = cond_dim
        self.pitch_max              = pitch_max
        self.pitch_embedding_dim    = pitch_embedding_dim
        self.kernel_size            = kernel_size
        self.preemph                = preemph
        self.pulses                 = pulses
        self.ftrans_k               = ftrans_k

        assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
        self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # feature net
        self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)

        # noise shaper
        self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)

        # comb filters
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)

        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # spectral shaping
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # non-linear transforms
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate1)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate2)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate3)

        # combinators
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # feature transforms
        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)

    def create_phase_signals(self, periods):
        """Generate a two-channel excitation (sin/cos pair, or rectified
        pulse pair when self.pulses is set) that stays phase-continuous
        across subframes."""

        batch_size = periods.size(0)
        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            if self.pulses:
                alpha = torch.cos(f).view(batch_size, 1, 1)
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)

                chunk = torch.cat((pulse_a, pulse_b), dim=1)
            else:
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)

                chunk = torch.cat((chunk_sin, chunk_cos), dim=1)

            # carry the accumulated phase over so the next subframe continues seamlessly
            phase0 = phase0 + self.FRAME_SIZE * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=-1)

        return phase_signals

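    # The math behind create_phase_signals, spelled out: a pitch period of T
    # samples gives an angular step f = 2*pi/T per sample, so sample n of a
    # subframe carries phase phi(n) = phase0 + f*n, and phase0 advances by
    # FRAME_SIZE*f after each subframe, keeping the oscillator continuous when
    # the pitch changes. In pulse mode, sin(phi) is thresholded at cos(f), so
    # only samples within about one step of each positive (pulse_a) or
    # negative (pulse_b) peak survive, with the peak normalized to 1.
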
    def flop_count(self, rate=16000, verbose=False):

        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops

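    # Note on units (assuming the submodule flop_count helpers return
    # operations per second at the given rate): the totals above scale
    # linearly with the sample rate, and the verbose printout is in MFLOPS.
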
    def feature_transform(self, f, layer):
        # causal width-ftrans_k conv over the frame axis with tanh squashing:
        # (batch, frames, cond_dim) -> (batch, cond_dim, frames) -> conv -> back
        f = f.permute(0, 2, 1)
        f = F.pad(f, [self.ftrans_k - 1, 0])
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

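    # forward() wires up the stages built in __init__: a phase-continuous
    # harmonic excitation (create_phase_signals) is prescaled by af_prescale
    # and mixed with shaped noise by af_mix; tdshape1/af2 and tdshape2/af3
    # apply temporal shaping; cf1 and cf2 comb-filter along the pitch lag;
    # af1 applies spectral shaping; and tdshape3/af4 make a final temporal
    # envelope adjustment. After most stages, a post_* conv updates the
    # conditioning features cf.
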
    def forward(self, features, periods, debug=False):
        """Synthesize a waveform from acoustic features and pitch periods.

        features: (batch, num_frames, num_features); periods: integer pitch
        lags in [0, pitch_max], shape (batch, num_frames, 1) or (batch, num_frames).
        """

        periods         = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)

        full_features = torch.cat((features, pitch_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # upsample periods
        periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)

        # pre-net
        ref_phase = torch.tanh(self.create_phase_signals(periods))
        if debug: print_channels(ref_phase, prefix="lavoce_01", name="pulse")
        x = self.af_prescale(ref_phase, cf)
        noise = self.noise_shaper(cf)
        if debug: print_channels(torch.cat((x, noise), dim=1), prefix="lavoce_02", name="inputs")
        y = self.af_mix(torch.cat((x, noise), dim=1), cf)
        if debug: print_channels(y, prefix="lavoce_03", name="postselect1")

        # temporal shaping + innovating
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_04", name="postshape1")
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_05", name="postselect2")
        cf = self.feature_transform(cf, self.post_af2)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_06", name="postshape2")
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_07", name="postmix1")
        cf = self.feature_transform(cf, self.post_af3)

        # spectral shaping
        y = self.cf1(y, cf, periods, debug=debug)
        if debug: print_channels(y, prefix="lavoce_08", name="postcomb1")
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        if debug: print_channels(y, prefix="lavoce_09", name="postcomb2")
        cf = self.feature_transform(cf, self.post_cf2)

        y = self.af1(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_10", name="postselect3")
        cf = self.feature_transform(cf, self.post_af1)

        # final temporal env adjustment
        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_11", name="postshape3")
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_12", name="postmix2")

        return y

    def process(self, features, periods, debug=False):

        self.eval()
        device = next(self.parameters()).device
        with torch.no_grad():

            # run model
            f = features.unsqueeze(0).to(device)
            p = periods.unsqueeze(0).to(device)

            y = self.forward(f, p, debug=debug).squeeze()

            # de-emphasis: invert the pre-emphasis filter 1 - preemph * z^-1
            if self.preemph > 0:
                for i in range(len(y) - 1):
                    y[i + 1] += self.preemph * y[i]

            # scale to 16-bit PCM and clip to valid range
            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()

        return out
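

# Minimal smoke-test sketch (illustrative only, with assumed shapes): LaVoce
# consumes LPCNet-style feature frames (num_features=20 per 10 ms frame at
# 16 kHz) plus integer pitch lags below pitch_max; the random inputs below
# exercise tensor shapes, not audio quality.
if __name__ == "__main__":
    model = LaVoce()
    model.flop_count(verbose=True)

    num_frames = 50
    features = torch.randn(num_frames, 20)             # (frames, num_features)
    periods = torch.randint(32, 300, (num_frames, 1))  # (frames, 1) integer pitch lags

    pcm = model.process(features, periods)             # int16 waveform at 16 kHz
    print(pcm.shape, pcm.dtype)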