1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import torch
from torch import nn
import numpy as np

from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
from utils.layers import DualFC
from utils.misc import get_pdf_from_tree


class LPCNet(nn.Module):
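    """LPCNet vocoder model.

    A frame rate network turns acoustic features and embedded pitch
    periods into per-frame conditioning vectors, and an autoregressive
    sample rate network (GRU A, GRU B and a DualFC output layer)
    predicts a distribution over the u-law encoded excitation of each
    output sample.
    """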
    def __init__(self, config):
        super().__init__()

        # input layout and feature context window
        self.input_layout       = config['input_layout']
        self.feature_history    = config['feature_history']
        self.feature_lookahead  = config['feature_lookahead']

        # frame rate network parameters
        self.feature_dimension          = config['feature_dimension']
        self.period_embedding_dim       = config['period_embedding_dim']
        self.period_levels              = config['period_levels']
        self.feature_channels           = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim   = config['feature_conditioning_dim']
        self.feature_conv_kernel_size   = config['feature_conv_kernel_size']

        # frame rate network layers
        self.period_embedding   = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1      = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2      = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1     = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2     = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)

        # sample rate network parameters
        self.frame_size             = config['frame_size']
        self.signal_levels          = config['signal_levels']
        self.signal_embedding_dim   = config['signal_embedding_dim']
        self.gru_a_units            = config['gru_a_units']
        self.gru_b_units            = config['gru_b_units']
        self.output_levels          = config['output_levels']
        self.hsampling              = config.get('hsampling', False)

        self.gru_a_input_dim        = len(self.input_layout['signals']) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim        = self.gru_a_units + self.feature_conditioning_dim

        # sample rate network layers
        self.signal_embedding   = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a              = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b              = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
        self.dual_fc            = DualFC(self.gru_b_units, self.output_levels)

        # sparsification
        self.sparsifier = []

        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config = config['sparsification']['gru_a']
            task_list = [(self.gru_a, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent']))
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                                                                     gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)

        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config = config['sparsification']['gru_b']
            task_list = [(self.gru_b, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent']))
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                                                                     gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)

        # inference parameters
        self.lpc_gamma = config.get('lpc_gamma', 1)

    def sparsify(self):
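        """Runs one step of all registered GRU sparsifiers."""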
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def get_gflops(self, fs, verbose=False):
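        """Estimates the inference complexity in GFLOPS at sampling rate fs."""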
        gflops = 0

        # frame rate network: two 'valid' convolutions and two dense layers,
        # counting one multiply-add as two flops (the constants below assume
        # a convolution kernel size of 3)
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a
        gru_a_rate = fs
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # gru b
        gru_b_rate = fs
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity

        # dual fc: two dense branches plus element-wise activations
        fc = self.dual_fc
        rate = fs
        input_size = fc.dense1.in_features
        output_size = fc.dense1.out_features
        dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
        if self.hsampling:
            # hierarchical sampling evaluates only a fraction of the output layer
            dual_fc_complexity /= 8
        if verbose:
            print(f"dual_fc: {dual_fc_complexity} GFLOPS")
        gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops

    def frame_rate_network(self, features, periods):
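        """Computes per-frame conditioning vectors from features and pitch periods.

        Both convolutions use 'valid' padding, so the returned sequence is
        2 * (feature_conv_kernel_size - 1) frames shorter than the input.
        """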

        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # convert to channels first and calculate conditioning vector
        c = torch.permute(features, [0, 2, 1])

        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def sample_rate_network(self, signals, c, gru_states):
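        """Runs the sample rate network on a full sequence of input signals.

        The conditioning vectors c are repeated frame_size times along the
        time axis to match the sample rate. Returns log-probabilities over
        the output levels together with the updated GRU states.
        """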
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
        c_upsampled      = torch.repeat_interleave(c, self.frame_size, dim=1)

        y = torch.concat((embedded_signals, c_upsampled), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c_upsampled), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        if self.hsampling:
            y = torch.sigmoid(y)
            log_probs = torch.log(get_pdf_from_tree(y) + 1e-6)
        else:
            log_probs = torch.log_softmax(y, dim=-1)

        return log_probs, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
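        """Sample rate network step used for autoregressive inference.

        Identical to sample_rate_network except that the conditioning c is
        expected at sample rate already and probabilities are returned
        instead of log-probabilities.
        """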
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        if self.hsampling:
            y = torch.sigmoid(y)
            probs = get_pdf_from_tree(y)
        else:
            probs = torch.softmax(y, dim=-1)

        return probs, (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):
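        """Computes per-sample log-probabilities for a full sequence (training path).

        The final GRU states are discarded.
        """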

        c            = self.frame_rate_network(features, periods)
        log_probs, _ = self.sample_rate_network(signals, c, gru_states)

        return log_probs

    def generate(self, features, periods, lpcs):
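        """Synthesizes audio from features, pitch periods and LPC coefficients.

        Expects unbatched inputs: features of shape (frames, feature_dimension),
        periods of shape (frames, 1) and lpcs of shape (frames, lpc_order).
        Returns an int16 tensor with
        (frames - feature_history - feature_lookahead) * frame_size samples.
        """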

        with torch.no_grad():
            device = next(self.parameters()).device

            num_frames          = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order           = lpcs.shape[-1]
            num_input_signals   = len(self.input_layout['signals'])
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers
            pcm    = torch.zeros(num_frames * self.frame_size + lpc_order)
            output = torch.zeros(num_frames * self.frame_size, dtype=torch.int16)
            mem = 0

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units))
            gru_b_state = torch.zeros((1, 1, self.gru_b_units))
            gru_states = [gru_a_state, gru_b_state]

            input_signals = torch.zeros((1, 1, num_input_signals), dtype=torch.long) + 128

            # push data to device
            features = features.to(device)
            periods  = periods.to(device)
            lpcs     = lpcs.to(device)

            # lpc weighting (bandwidth expansion by lpc_gamma)
            weights = torch.FloatTensor([self.lpc_gamma ** (i + 1) for i in range(lpc_order)]).to(device)
            lpcs    = lpcs * weights

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr  = features[frame_index + self.feature_history, pitch_corr_position]
                a           = -torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c   = c[:, frame_index : frame_index + 1, :]

                for i in range(self.frame_size):
                    pcm_position    = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # prepare input: linear prediction from the last lpc_order samples
                    pred = torch.sum(pcm[pcm_position - lpc_order : pcm_position] * a)
                    if 'prediction' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['prediction']] = lin2ulawq(pred)

                    # run single step of sample rate network
                    probs, gru_states = self.decoder(input_signals, current_c, gru_states)

                    # sample from output
                    exc_ulaw = sample_excitation(probs, pitch_corr)

                    # signal generation: excitation plus prediction, then
                    # de-emphasis (inverse of the 0.85 pre-emphasis filter)
                    exc = ulaw2lin(exc_ulaw)
                    sig = exc + pred
                    pcm[pcm_position] = sig
                    mem = 0.85 * mem + float(sig)
                    output[output_position] = clip_to_int16(round(mem))

                    # buffer update
                    if 'last_signal' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['last_signal']] = lin2ulawq(sig)

                    if 'last_error' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['last_error']] = lin2ulawq(exc)

        return output
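

if __name__ == "__main__":
    # Smoke test / usage sketch. All hyper-parameter values below are
    # illustrative assumptions chosen only to satisfy the constructor;
    # real configurations are supplied by the training setup.
    config = {
        'input_layout': {
            'signals': {'prediction': 0, 'last_signal': 1, 'last_error': 2},
            'features': {'pitch_corr': [19]},   # position is an assumption
        },
        'feature_history': 2,
        'feature_lookahead': 2,
        'feature_dimension': 20,
        'period_embedding_dim': 64,
        'period_levels': 257,
        'feature_conditioning_dim': 128,
        'feature_conv_kernel_size': 3,
        'frame_size': 160,
        'signal_levels': 256,
        'signal_embedding_dim': 128,
        'gru_a_units': 384,
        'gru_b_units': 64,
        'output_levels': 256,
        'sparsification': {},
    }

    model = LPCNet(config)
    model.get_gflops(16000, verbose=True)

    # One full-sequence forward pass on random data. The two 'valid'
    # convolutions in the frame rate network drop 2 * (kernel_size - 1)
    # frames, so the signal sequence must match the shortened frame axis.
    frames = 16
    out_frames = frames - 2 * (config['feature_conv_kernel_size'] - 1)
    features = torch.randn(1, frames, config['feature_dimension'])
    periods = torch.randint(0, config['period_levels'], (1, frames, 1))
    signals = torch.randint(0, config['signal_levels'],
                            (1, out_frames * config['frame_size'],
                             len(config['input_layout']['signals'])))
    states = (torch.zeros(1, 1, config['gru_a_units']),
              torch.zeros(1, 1, config['gru_b_units']))
    log_probs = model(features, periods, signals, states)
    print(log_probs.shape)  # (1, out_frames * frame_size, output_levels)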