1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

""" Dataset for LPCNet training """
import os

import yaml
import torch
import numpy as np
from torch.utils.data import Dataset


# scaling between the 16-bit PCM amplitude range and the 255 steps of 8-bit mu-law
scale = 255.0/32768.0
scale_1 = 32768.0/255.0
def ulaw2lin(u):
    """ Converts 8-bit mu-law codes in [0, 255] to linear 16-bit PCM amplitudes. """
    u = u - 128
    s = np.sign(u)
    u = np.abs(u)
    return s*scale_1*(np.exp(u/128.*np.log(256))-1)


def lin2ulaw(x):
    """ Converts linear 16-bit PCM amplitudes to 8-bit mu-law codes in [0, 255]. """
    s = np.sign(x)
    x = np.abs(x)
    u = (s*(128*np.log(1+scale*x)/np.log(256)))
    u = np.clip(128 + np.round(u), 0, 255)
    return u
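
# Illustrative sketch (not used elsewhere in this file): lin2ulaw maps 16-bit PCM
# amplitudes to 8-bit mu-law codes and ulaw2lin approximately inverts it; the
# round trip is lossy due to the 8-bit quantization.
def _ulaw_roundtrip_demo():
    x = np.array([-32768.0, -1000.0, 0.0, 1000.0, 32767.0])
    u = lin2ulaw(x)      # codes in [0, 255]; amplitude 0 maps to code 128
    x_hat = ulaw2lin(u)  # approximately x, up to mu-law quantization error
    return x, u, x_hat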


def run_lpc(signal, lpcs, frame_length=160):
    """ Applies the per-frame LPC analysis filters in lpcs to signal.

        Expects len(signal) == num_frames * frame_length + lpc_order, i.e. lpc_order
        samples of history; returns prediction and error, both of length
        num_frames * frame_length and aligned with signal[lpc_order:].
    """
    num_frames, lpc_order = lpcs.shape

    prediction = np.concatenate(
        [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)]
    )
    error = signal[lpc_order :] - prediction

    return prediction, error
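
# Illustrative sketch of run_lpc's contract: with lpc_order extra history samples,
# prediction and error both have num_frames * frame_length samples and line up
# with signal[lpc_order:].
def _run_lpc_demo():
    rng = np.random.default_rng(0)
    num_frames, frame_length, lpc_order = 3, 160, 16
    lpcs = rng.normal(scale=0.05, size=(num_frames, lpc_order))
    signal = rng.normal(size=num_frames * frame_length + lpc_order)
    prediction, error = run_lpc(signal, lpcs, frame_length=frame_length)
    assert prediction.shape == error.shape == (num_frames * frame_length,)
    assert np.allclose(signal[lpc_order:], prediction + error)
    return prediction, error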


class LPCNetDataset(Dataset):
    """ Dataset for LPCNet training; memory-maps the pre-computed feature and
        signal files described by the info.yml file in the dataset folder. """
    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 input_signals=['last_signal', 'prediction', 'last_error'],
                 target='error',
                 frames_per_sample=15,
                 feature_history=2,
                 feature_lookahead=2,
                 lpc_gamma=1):

        super().__init__()

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history      = feature_history
        self.feature_lookahead    = feature_lookahead
        self.frame_offset         = 1 + self.feature_history
        self.frames_per_sample    = frames_per_sample
        self.input_features       = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma            = lpc_gamma

        # load feature file
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset (each sample consumes frames_per_sample
        # feature frames plus history and lookahead context)
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1) // self.frames_per_sample

        # signals
        self.frame_length               = dataset['frame_length']
        self.signal_frame_layout        = dataset['signal_frame_layout']
        self.input_signals              = input_signals
        self.target                     = target

        # load signals
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length
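
        # For reference, the fields read above imply an info.yml of roughly the
        # following shape (an illustrative sketch; all concrete values here are
        # hypothetical):
        #
        #   version: 2
        #   feature_file: features.f32
        #   feature_dtype: float32
        #   feature_frame_length: 36
        #   feature_frame_layout:
        #     cepstrum: [0, 18]
        #     periods: [18, 19]
        #     pitch_corr: [19, 20]
        #     lpc: [20, 36]
        #   frame_length: 160
        #   signal_file: data.s16
        #   signal_dtype: int16
        #   signal_frame_length: 2
        #   signal_frame_layout:
        #     last_signal: 0
        #     signal: 1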

    def __getitem__(self, index):
        return self.getitem(index)

    def getitem_v2(self, index):
        """ Loads one sample from a version-2 dataset: linear-domain signals are
            read from file, prediction and error are derived from the LPCs on
            the fly, and all signals are mu-law encoded here. """
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index       * self.frames_per_sample - self.feature_history
        frame_stop  = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert normalized pitch features to integer periods (the 0.1 offset
        # guards against float truncation in the int16 cast)
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index       * self.frames_per_sample) * self.frame_length
        signal_stop  = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # lpc coefficients with one frame lookahead
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop  = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC weighting (bandwidth expansion by lpc_gamma)
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop  = frame_stop  * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
            clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]  # currently unused

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals (drop the one extra frame of history again)
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]

            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(lin2ulaw(sample[key])).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target  = torch.LongTensor(lin2ulaw(sample[self.target]))
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
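
    # With the default constructor arguments (frames_per_sample=15, two frames of
    # history and lookahead) and assuming frame_length == 160 and a single period
    # column, a v2 sample would have roughly these shapes (illustrative only):
    #   features: (19, C)     # 15 + 2 + 2 feature frames, C feature channels
    #   periods : (19, 1)
    #   signals : (2400, 3)   # 15 * 160 samples, one column per input signal
    #   target  : (2400,)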

    def getitem_v1(self, index):
        """ Loads one sample from a version-1 dataset, where all input signals
            (including prediction and error) are stored pre-computed in the
            signal file, apparently already mu-law encoded, since no conversion
            is applied below. """
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index       * self.frames_per_sample - self.feature_history
        frame_stop  = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert normalized pitch features to integer periods (the 0.1 offset
        # guards against float truncation in the int16 cast)
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index       * self.frames_per_sample) * self.frame_length
        signal_stop  = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # all signals in the layout are read directly from the signal file
        for signal_name, signal_index in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, signal_index]

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target  = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        return self.dataset_length
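

if __name__ == '__main__':
    # Minimal usage sketch: '/path/to/dataset' is a placeholder for a directory
    # containing info.yml and the feature/signal files it references.
    from torch.utils.data import DataLoader

    dataset = LPCNetDataset('/path/to/dataset')
    loader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)
    batch = next(iter(loader))
    print({key: tuple(value.shape) for key, value in batch.items()})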