xref: /aosp_15_r20/external/libopus/dnn/torch/fargan/dataset.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1import torch
2import numpy as np
3import fargan
4
class FARGANDataset(torch.utils.data.Dataset):
    """Memory-mapped dataset of overlapping feature/audio windows.

    Item ``i`` covers ``2 * sequence_length`` frames of audio starting
    ``i * sequence_length`` frames into the (offset) signal, so consecutive
    items overlap by ``sequence_length`` frames. The matching feature window
    is ``2 * sequence_length + 4`` frames long: it starts ``4 - lookahead``
    frames before the audio window and ends ``lookahead`` frames after it.

    Args:
        feature_file: raw native-endian float32 file, read as rows of
            ``nb_features`` values (one row per frame).
        signal_file: raw native-endian int16 sample file.
        frame_size: audio samples per feature frame.
        sequence_length: hop, in frames, between consecutive items (each item
            spans twice this many frames of audio).
        lookahead: feature frames extending past the end of the audio window;
            must be in ``[0, 4]``.
        nb_used_features: number of leading columns returned as ``features``;
            the remaining columns are returned as ``lpc``.
        nb_features: total float32 values per frame in ``feature_file``.

    Raises:
        ValueError: if ``lookahead`` is out of range or either file is too
            short for the requested windows.
    """

    # Feature frames each item carries beyond its 2 * sequence_length audio frames.
    _CONTEXT_FRAMES = 4

    def __init__(self,
                 feature_file,
                 signal_file,
                 frame_size=160,
                 sequence_length=15,
                 lookahead=1,
                 nb_used_features=20,
                 nb_features=36):
        context = self._CONTEXT_FRAMES
        if not 0 <= lookahead <= context:
            # Outside this range the audio offset and the lpc slice below
            # silently misalign or come out empty.
            raise ValueError(f"lookahead must be in [0, {context}], got {lookahead}")

        self.frame_size = frame_size
        self.sequence_length = sequence_length
        self.lookahead = lookahead
        self.nb_features = nb_features
        self.nb_used_features = nb_used_features
        pcm_chunk_size = self.frame_size*self.sequence_length

        pcm = np.memmap(signal_file, dtype='int16', mode='r')
        # Keep `context` chunks of margin so the initial offset and the
        # double-length final window still fit inside the signal.
        self.nb_sequences = len(pcm)//pcm_chunk_size - context
        if self.nb_sequences < 0:
            raise ValueError(f"signal file too short: {len(pcm)} samples, "
                             f"need at least {context*pcm_chunk_size}")

        # Shift the audio so each window lines up with feature frame
        # (context - lookahead) of the corresponding feature window.
        pcm = pcm[(context - self.lookahead)*self.frame_size:]
        self.data = self._overlapping_windows(pcm, self.nb_sequences, pcm_chunk_size,
                                              2*pcm_chunk_size, "signal")

        frames = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features))
        self.features = self._overlapping_windows(frames, self.nb_sequences, self.sequence_length,
                                                  2*self.sequence_length + context, "feature")
        # Integer period per frame from feature column nb_used_features-2
        # (presumably a log-scaled pitch feature -- confirm), clipped to [32, 255].
        self.periods = np.round(np.clip(256./2**(self.features[:, :, self.nb_used_features-2]+1.5), 32, 255)).astype('int')

        self.lpc = self.features[:, :, self.nb_used_features:]
        self.features = self.features[:, :, :self.nb_used_features]
        print("lpc_size:", self.lpc.shape)

    @staticmethod
    def _overlapping_windows(array, nb_windows, hop, window, name):
        """Return a (nb_windows, window, *array.shape[1:]) view whose rows start every `hop` entries of `array`.

        Raises ValueError when the last window would run past the end of
        `array`; `as_strided` itself does no bounds checking and would read
        outside the underlying buffer.
        """
        needed = (nb_windows - 1)*hop + window
        if nb_windows > 0 and needed > array.shape[0]:
            raise ValueError(f"{name} file too short: {array.shape[0]} entries available, "
                             f"{needed} needed")
        return np.lib.stride_tricks.as_strided(array, shape=(nb_windows, window) + array.shape[1:],
                                               strides=(hop*array.strides[0],) + array.strides)

    def __len__(self):
        """Return the number of (overlapping) sequences in the dataset."""
        return self.nb_sequences

    def __getitem__(self, index):
        """Return ``(features, periods, data, lpc)`` for sequence ``index``.

        ``features`` and ``periods`` cover the full feature window
        (``2 * sequence_length + 4`` frames). ``data`` is the audio window
        scaled from int16 to float32 in [-1, 1). ``lpc`` holds the trailing
        feature columns for the ``2 * sequence_length`` frames aligned with
        ``data``. All returned arrays are copies, safe to modify.
        """
        features = self.features[index, :, :].copy()
        # The audio window starts (context - lookahead) frames into the
        # feature window; lpc is sliced to cover exactly the same frames.
        start = self._CONTEXT_FRAMES - self.lookahead
        lpc = self.lpc[index, start:start + 2*self.sequence_length, :].copy()
        data = self.data[index, :].astype(np.float32) / 2**15
        periods = self.periods[index, :].copy()
        return features, periods, data, lpc
62