1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30""" Dataset for LPCNet training """ 31import os 32 33import yaml 34import torch 35import numpy as np 36from torch.utils.data import Dataset 37 38 39scale = 255.0/32768.0 40scale_1 = 32768.0/255.0 41def ulaw2lin(u): 42 u = u - 128 43 s = np.sign(u) 44 u = np.abs(u) 45 return s*scale_1*(np.exp(u/128.*np.log(256))-1) 46 47 48def lin2ulaw(x): 49 s = np.sign(x) 50 x = np.abs(x) 51 u = (s*(128*np.log(1+scale*x)/np.log(256))) 52 u = np.clip(128 + np.round(u), 0, 255) 53 return u 54 55 56def run_lpc(signal, lpcs, frame_length=160): 57 num_frames, lpc_order = lpcs.shape 58 59 prediction = np.concatenate( 60 [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)] 61 ) 62 error = signal[lpc_order :] - prediction 63 64 return prediction, error 65 66class LPCNetDataset(Dataset): 67 def __init__(self, 68 path_to_dataset, 69 features=['cepstrum', 'periods', 'pitch_corr'], 70 input_signals=['last_signal', 'prediction', 'last_error'], 71 target='error', 72 frames_per_sample=15, 73 feature_history=2, 74 feature_lookahead=2, 75 lpc_gamma=1): 76 77 super(LPCNetDataset, self).__init__() 78 79 # load dataset info 80 self.path_to_dataset = path_to_dataset 81 with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f: 82 dataset = yaml.load(f, yaml.FullLoader) 83 84 # dataset version 85 self.version = dataset['version'] 86 if self.version == 1: 87 self.getitem = self.getitem_v1 88 elif self.version == 2: 89 self.getitem = self.getitem_v2 90 else: 91 raise ValueError(f"dataset version {self.version} unknown") 92 93 # features 94 self.feature_history = feature_history 95 self.feature_lookahead = feature_lookahead 96 self.frame_offset = 1 + self.feature_history 97 self.frames_per_sample = frames_per_sample 98 self.input_features = features 99 self.feature_frame_layout = dataset['feature_frame_layout'] 100 self.lpc_gamma = lpc_gamma 101 102 # load feature file 103 self.feature_file = os.path.join(path_to_dataset, dataset['feature_file']) 104 self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype']) 105 self.feature_frame_length = dataset['feature_frame_length'] 106 107 
class LPCNetDataset(Dataset):
    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 input_signals=['last_signal', 'prediction', 'last_error'],
                 target='error',
                 frames_per_sample=15,
                 feature_history=2,
                 feature_lookahead=2,
                 lpc_gamma=1):

        super(LPCNetDataset, self).__init__()

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        self.frame_offset = 1 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma = lpc_gamma

        # load feature file
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.input_signals = input_signals
        self.target = target

        # load signals
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length

    def __getitem__(self, index):
        return self.getitem(index)

    def getitem_v2(self, index):
        sample = dict()

        # extract features, including feature_history frames of past and
        # feature_lookahead frames of future context
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients are present and the prediction is not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC weighting (bandwidth expansion): a_k <- gamma^k * a_k
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal positions (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals (drop the extra leading frame again)
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]

            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(lin2ulaw(sample[key])).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(lin2ulaw(sample[self.target]))
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
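
    # Illustrative helper (the name is hypothetical, nothing in the training
    # code calls it): decodes the µ-law tensors of an item returned by
    # getitem_v2 back to linear amplitudes, e.g. for plotting or listening.
    def decode_item_to_linear(self, item):
        signals = ulaw2lin(item['signals'].numpy().astype(np.float64))
        target = ulaw2lin(item['target'].numpy().astype(np.float64))
        return signals, target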

    def getitem_v1(self, index):
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        for signal_name, signal_index in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, signal_index]

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        return self.dataset_length
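
# Minimal usage sketch, assuming an existing dataset directory that contains
# the 'info.yml', feature file and signal file described above; the path and
# batch size are placeholder values.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = LPCNetDataset('path/to/dataset')  # placeholder path
    loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

    batch = next(iter(loader))
    # features: per-frame conditioning features (history and lookahead frames included)
    # periods : integer pitch periods per frame
    # signals : µ-law codes, one column per entry in input_signals
    # target  : µ-law codes of the target signal
    print({key: tuple(value.shape) for key, value in batch.items()})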