xref: /aosp_15_r20/external/libopus/dnn/torch/fargan/dataset.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Liimport torch
2*a58d3d2aSXin Liimport numpy as np
3*a58d3d2aSXin Liimport fargan
4*a58d3d2aSXin Li
class FARGANDataset(torch.utils.data.Dataset):
    """Overlapping (features, PCM) training windows read from raw memmapped files.

    Window ``i`` starts at feature frame ``i * sequence_length`` and spans
    ``2 * sequence_length + 4`` feature frames, so consecutive windows overlap
    by one chunk of ``sequence_length`` frames. The PCM part of the window
    holds ``2 * sequence_length * frame_size`` samples aligned with feature
    frame ``4 - lookahead`` of the window. That leaves ``4 - lookahead``
    context frames before the audio and ``lookahead`` frames after it.

    Args:
        feature_file: raw float32 file, reshaped to rows of ``nb_features``
            values (one row per frame).
        signal_file: raw int16 PCM samples.
        frame_size: PCM samples per feature frame.
        sequence_length: frames per chunk (also the stride between windows).
        lookahead: feature frames available after the PCM segment; must be
            in the range 0..4.
        nb_used_features: number of leading columns returned as ``features``;
            the remaining columns are returned as ``lpc``.
        nb_features: total float32 values per frame in ``feature_file``.

    Raises:
        ValueError: if ``lookahead`` is outside 0..4, or if ``feature_file``
            does not contain a whole number of ``nb_features``-wide rows.
    """

    def __init__(self,
                 feature_file,
                 signal_file,
                 frame_size=160,
                 sequence_length=15,
                 lookahead=1,
                 nb_used_features=20,
                 nb_features=36):

        # The window carries 4 extra feature frames split as
        # (4 - lookahead) before / lookahead after the audio. Outside [0, 4]
        # the PCM offset below would be negative and slice from the file end.
        if not 0 <= lookahead <= 4:
            raise ValueError(f"lookahead must be in [0, 4], got {lookahead}")

        self.frame_size = frame_size
        self.sequence_length = sequence_length
        self.lookahead = lookahead
        self.nb_features = nb_features
        self.nb_used_features = nb_used_features
        pcm_chunk_size = self.frame_size*self.sequence_length
        window_frames = 2*self.sequence_length + 4
        pcm_offset = (4 - self.lookahead)*self.frame_size

        pcm = np.memmap(signal_file, dtype='int16', mode='r')
        features = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features))

        # as_strided does no bounds checking, so the last window must fit in
        # BOTH files. Window i reads PCM samples
        # [pcm_offset + i*C, pcm_offset + (i+2)*C) and feature frames
        # [i*S, i*S + 2*S + 4). The first term is the original heuristic;
        # the other two only bind for short/mismatched files or tiny S.
        nb_sequences = min(len(pcm)//pcm_chunk_size - 4,
                           (len(pcm) - pcm_offset)//pcm_chunk_size - 1,
                           (len(features) - 4)//self.sequence_length - 1)
        self.nb_sequences = max(nb_sequences, 0)

        # Keep exactly the samples the strided view reads: nb_sequences
        # windows of 2 chunks each, advancing by one chunk.
        pcm = pcm[pcm_offset:pcm_offset + (self.nb_sequences + 1)*pcm_chunk_size]
        sizeof = pcm.strides[-1]
        self.data = np.lib.stride_tricks.as_strided(pcm, shape=(self.nb_sequences, pcm_chunk_size*2),
                                                    strides=(pcm_chunk_size*sizeof, sizeof))

        sizeof = features.strides[-1]
        features = np.lib.stride_tricks.as_strided(features, shape=(self.nb_sequences, window_frames, nb_features),
                                                   strides=(self.sequence_length*self.nb_features*sizeof,
                                                            self.nb_features*sizeof, sizeof))

        # Column nb_used_features-2 is mapped to an integer period clipped to
        # [32, 255] (presumably a log-scale pitch feature -- confirm upstream).
        self.periods = np.round(np.clip(256./2**(features[:, :, self.nb_used_features-2]+1.5), 32, 255)).astype('int')

        self.lpc = features[:, :, self.nb_used_features:]
        self.features = features[:, :, :self.nb_used_features]
        print("lpc_size:", self.lpc.shape)

    def __len__(self):
        """Return the number of windows that fit in both input files."""
        return self.nb_sequences

    def __getitem__(self, index):
        """Return ``(features, periods, data, lpc)`` for window ``index``.

        With S = sequence_length:
            features: (2*S + 4, nb_used_features) float32
            periods:  (2*S + 4,) int
            data:     (2*S*frame_size,) float32, int16 PCM scaled by 1/2**15
            lpc:      (2*S, nb_features - nb_used_features) float32, the
                      frames aligned with ``data``
        """
        # lpc frames aligned with the PCM segment: skip the (4 - lookahead)
        # context frames and take 2*S frames.
        start = 4 - self.lookahead
        stop = start + 2*self.sequence_length

        features = self.features[index, :, :].copy()
        lpc = self.lpc[index, start:stop, :].copy()
        data = self.data[index, :].astype(np.float32) / 2**15  # astype copies
        periods = self.periods[index, :].copy()

        return features, periods, data, lpc
62