xref: /aosp_15_r20/external/armnn/python/pyarmnn/examples/common/mfcc.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
2*89c4ff92SAndroid Build Coastguard Worker# SPDX-License-Identifier: MIT
3*89c4ff92SAndroid Build Coastguard Worker
4*89c4ff92SAndroid Build Coastguard Worker"""Class used to extract the Mel-frequency cepstral coefficients from a given audio frame."""
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Workerimport numpy as np
7*89c4ff92SAndroid Build Coastguard Workerimport collections
8*89c4ff92SAndroid Build Coastguard Worker
# Immutable MFCC configuration record. Fields, in order:
#   sampling_freq   - audio sampling frequency in Hz
#   num_fbank_bins  - number of mel filter-bank filters
#   mel_lo_freq     - lower edge of the mel filter bank (Hz)
#   mel_hi_freq     - upper edge of the mel filter bank (Hz)
#   num_mfcc_feats  - number of cepstral coefficients produced per frame
#   frame_len       - number of audio samples in one frame
#   use_htk_method  - True to use the HTK mel formula
#   n_fft           - FFT length used for the magnitude spectrum
MFCCParams = collections.namedtuple(
    "MFCCParams",
    "sampling_freq num_fbank_bins mel_lo_freq mel_hi_freq num_mfcc_feats frame_len use_htk_method n_fft",
)
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker
class MFCC:
    """
    Computes Mel-frequency cepstral coefficients (MFCCs) for single audio frames.

    The mel filter bank and the DCT matrix are built once in the constructor from
    ``mfcc_params``. This is an ``MFCCParams``-style object providing ``sampling_freq``,
    ``num_fbank_bins``, ``mel_lo_freq``, ``mel_hi_freq``, ``num_mfcc_feats``,
    ``frame_len``, ``use_htk_method`` and ``n_fft``. Subclasses may override
    ``mel_norm`` to apply a per-filter weighting.
    """

    def __init__(self, mfcc_params):
        self.mfcc_params = mfcc_params
        # Constants of the non-HTK mel scale (see mel_scale): linear below MIN_LOG_HZ,
        # logarithmic above it.
        self.FREQ_STEP = 200.0 / 3
        self.MIN_LOG_HZ = 1000.0
        self.MIN_LOG_MEL = self.MIN_LOG_HZ / self.FREQ_STEP
        self.LOG_STEP = 1.8562979903656 / 27.0
        # Smallest power of two >= frame_len. Integer arithmetic is used because the float
        # expression ceil(log(n) / log(2)) can round some exact powers of two up
        # (e.g. n = 2**29), which would silently double the padded length.
        frame_len = int(np.ceil(self.mfcc_params.frame_len))
        self._frame_len_padded = 1 << max(frame_len - 1, 0).bit_length()
        self._filter_bank_initialised = False  # not read anywhere in this class
        self.__frame = np.zeros(self._frame_len_padded)
        self.__buffer = np.zeros(self._frame_len_padded)
        # Index (into the FFT bins) of the first / last non-zero weight of each filter.
        # Filled in by create_mel_filter_bank(); -1 marks a filter covering no FFT bin.
        self._filter_bank_filter_first = np.zeros(self.mfcc_params.num_fbank_bins)
        self._filter_bank_filter_last = np.zeros(self.mfcc_params.num_fbank_bins)
        self.__mel_energies = np.zeros(self.mfcc_params.num_fbank_bins)
        self._dct_matrix = self.create_dct_matrix(self.mfcc_params.num_fbank_bins, self.mfcc_params.num_mfcc_feats)
        self.__mel_filter_bank = self.create_mel_filter_bank()
        # Dense (num_fbank_bins, n_fft // 2 + 1) filter-bank matrix, so the bank can be
        # applied to an rfft magnitude spectrum with a single matrix product.
        self._np_mel_bank = np.zeros([self.mfcc_params.num_fbank_bins, int(self.mfcc_params.n_fft / 2) + 1])

        for i in range(self.mfcc_params.num_fbank_bins):
            first = int(self._filter_bank_filter_first[i])
            last = int(self._filter_bank_filter_last[i])
            if first < 0:
                continue  # empty filter: its row stays all zeros
            self._np_mel_bank[i, first:last + 1] = self.__mel_filter_bank[i]

    def mel_scale(self, freq, use_htk_method):
        """
        Converts a frequency in Hz to the mel scale.

        Args:
            freq: The frequency in Hz to convert. Must be a scalar, because the non-HTK
                branch uses a scalar comparison.
            use_htk_method: If True, use the HTK formula ``1127 * ln(1 + f / 700)``.
                Otherwise use the piecewise scale that is linear below ``MIN_LOG_HZ``
                and logarithmic above it.

        Returns:
            the mel value corresponding to ``freq``
        """
        if use_htk_method:
            return 1127.0 * np.log(1.0 + freq / 700.0)
        if freq >= self.MIN_LOG_HZ:
            return self.MIN_LOG_MEL + np.log(freq / self.MIN_LOG_HZ) / self.LOG_STEP
        return freq / self.FREQ_STEP

    def inv_mel_scale(self, mel_freq, use_htk_method):
        """
        Converts a mel value back to a frequency in Hz. This is the inverse of mel_scale.

        Args:
            mel_freq: The mel value to convert (scalar).
            use_htk_method: Boolean to set whether to use the HTK method or not.

        Returns:
            the frequency in Hz
        """
        if use_htk_method:
            return 700.0 * (np.exp(mel_freq / 1127.0) - 1.0)
        if mel_freq >= self.MIN_LOG_MEL:
            return self.MIN_LOG_HZ * np.exp(self.LOG_STEP * (mel_freq - self.MIN_LOG_MEL))
        return self.FREQ_STEP * mel_freq

    def spectrum_calc(self, audio_data):
        """
        Returns the magnitude spectrum of one frame.

        The frame is multiplied by a periodic Hann window of length ``frame_len``.
        ``np.fft.rfft`` then pads or crops it to ``n_fft`` points, so the result has
        ``n_fft // 2 + 1`` values.
        """
        return np.abs(np.fft.rfft(np.hanning(self.mfcc_params.frame_len + 1)[0:self.mfcc_params.frame_len] * audio_data,
                                  self.mfcc_params.n_fft))

    def log_mel(self, mel_energy):
        """
        Returns the natural log of the mel energies.

        A small epsilon is added to avoid ``log(0)``. A new array is returned and the
        caller's ``mel_energy`` is left unmodified.
        """
        return np.log(mel_energy + 1e-10)

    def mfcc_compute(self, audio_data):
        """
        Extracts the MFCC for a single frame.

        Args:
            audio_data: The audio data to process; must contain exactly ``frame_len`` samples.

        Returns:
            the MFCC features, an array of ``num_mfcc_feats`` values

        Raises:
            ValueError: if ``len(audio_data)`` differs from ``frame_len``.
        """
        if len(audio_data) != self.mfcc_params.frame_len:
            raise ValueError(
                f"audio_data buffer size {len(audio_data)} does not match frame length {self.mfcc_params.frame_len}")

        audio_data = np.array(audio_data)
        spec = self.spectrum_calc(audio_data)
        # Filter-bank energies, computed in float32.
        mel_energy = np.dot(self._np_mel_bank.astype(np.float32),
                            np.transpose(spec).astype(np.float32))
        log_mel_energy = self.log_mel(mel_energy)
        mfcc_feats = np.dot(self._dct_matrix, log_mel_energy)
        return mfcc_feats

    def create_dct_matrix(self, num_fbank_bins, num_mfcc_feats):
        """
        Creates the Discrete Cosine Transform matrix to be used in the compute function.

        Entry ``[k, n]`` is ``sqrt(2 / num_fbank_bins) * cos(pi / num_fbank_bins * (n + 0.5) * k)``.

        Args:
            num_fbank_bins: The number of filter bank bins (matrix columns)
            num_mfcc_feats: The number of MFCC features (matrix rows)

        Returns:
            the DCT matrix, shape ``(num_mfcc_feats, num_fbank_bins)``
        """
        if num_fbank_bins == 0 or num_mfcc_feats == 0:
            return np.zeros((num_mfcc_feats, num_fbank_bins))
        k = np.arange(num_mfcc_feats).reshape(-1, 1)
        n = np.arange(num_fbank_bins).reshape(1, -1)
        return np.sqrt(2 / num_fbank_bins) * np.cos((np.pi / num_fbank_bins) * (n + 0.5) * k)

    def mel_norm(self, weight, right_mel, left_mel):
        """
        Placeholder function over-ridden in child class.

        It receives each triangular filter weight together with the filter's mel edges.
        This base implementation returns the weight unchanged.
        """
        return weight

    def create_mel_filter_bank(self):
        """
        Creates the Mel filter bank.

        Each filter is a triangle in the mel domain, spanning three consecutive points
        spaced evenly between ``mel_lo_freq`` and ``mel_hi_freq``. As a side effect,
        the index range of each filter's non-zero weights is stored in
        ``_filter_bank_filter_first`` / ``_filter_bank_filter_last`` (-1 for a filter
        that covers no FFT bin).

        Returns:
            the mel filter bank: one array per filter, holding the weights of FFT bins
            ``first..last`` inclusive (an empty array for a filter covering no bin)
        """
        # FFT calculations are greatly accelerated for frame lengths which are powers of 2.
        # Frames are padded and FFT bin width/length are calculated accordingly.
        num_fft_bins = int(self._frame_len_padded / 2)
        fft_bin_width = self.mfcc_params.sampling_freq / self._frame_len_padded
        use_htk = self.mfcc_params.use_htk_method
        num_filters = self.mfcc_params.num_fbank_bins

        mel_low_freq = self.mel_scale(self.mfcc_params.mel_lo_freq, use_htk)
        mel_high_freq = self.mel_scale(self.mfcc_params.mel_hi_freq, use_htk)
        mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_filters + 1)

        # The mel value of each FFT bin does not depend on the filter, so compute it once
        # instead of once per (filter, FFT bin) pair.
        fft_bin_mels = [self.mel_scale(fft_bin_width * i, use_htk) for i in range(num_fft_bins)]

        mel_fbank = [0] * num_filters
        for bin_num in range(num_filters):
            left_mel = mel_low_freq + bin_num * mel_freq_delta
            center_mel = mel_low_freq + (bin_num + 1) * mel_freq_delta
            right_mel = mel_low_freq + (bin_num + 2) * mel_freq_delta

            this_bin = np.zeros(num_fft_bins)
            first_index = last_index = -1
            for i, mel in enumerate(fft_bin_mels):
                if not left_mel < mel < right_mel:
                    continue
                if mel <= center_mel:
                    weight = (mel - left_mel) / (center_mel - left_mel)
                else:
                    weight = (right_mel - mel) / (right_mel - center_mel)
                this_bin[i] = self.mel_norm(weight, right_mel, left_mel)
                if first_index == -1:
                    first_index = i
                last_index = i

            self._filter_bank_filter_first[bin_num] = first_index
            self._filter_bank_filter_last[bin_num] = last_index
            if first_index == -1:
                mel_fbank[bin_num] = np.zeros(0)
            else:
                mel_fbank[bin_num] = this_bin[first_index:last_index + 1].copy()

        return mel_fbank
185*89c4ff92SAndroid Build Coastguard Worker
186*89c4ff92SAndroid Build Coastguard Worker
class AudioPreprocessor:
    """
    Turns a window of raw audio into the MFCC feature matrix fed to a model.

    Args:
        mfcc: MFCC instance used to compute the features of each frame.
        model_input_size: number of frames (rows) in the returned feature matrix.
        stride: number of samples between the starts of consecutive frames.
    """

    def __init__(self, mfcc, model_input_size, stride):
        self.model_input_size = model_input_size
        self.stride = stride
        self._mfcc_calc = mfcc

    def _normalize(self, values):
        """
        Normalize values to mean 0 and standard deviation 1.

        A constant input has a standard deviation of 0, and dividing by it would produce
        NaNs. Such input is returned mean-centred, i.e. all zeros.
        """
        mean = np.mean(values)
        std = np.std(values)
        if std == 0:
            return values - mean
        return (values - mean) / std

    def _get_features(self, features, mfcc_instance, audio_data):
        """
        Appends per-frame MFCCs to ``features`` (in place) until it holds
        ``model_input_size * num_mfcc_feats`` values.

        Frames of ``frame_len`` samples are taken from ``audio_data`` starting at offset 0
        and advancing by ``stride`` samples each time.
        """
        frame_len = int(mfcc_instance.mfcc_params.frame_len)
        target_len = self.model_input_size * mfcc_instance.mfcc_params.num_mfcc_feats
        start = 0
        while len(features) < target_len:
            features.extend(mfcc_instance.mfcc_compute(audio_data[start:start + frame_len]))
            start += self.stride

    def mfcc_delta_calc(self, features):
        """
        Placeholder function over-ridden in child class.

        This base implementation returns the features unchanged.
        """
        return features

    def extract_features(self, audio_data):
        """
        Extracts the MFCC features. Also calculates each feature's first and second order
        derivatives if the mfcc_delta_calc() function has been implemented by a child class.
        The matrix returned should be sized appropriately for input to the model, based
        on the model info specified in the MFCC instance.

        Args:
            audio_data: the audio data to be used for this calculation
        Returns:
            the derived MFCC feature vector, sized appropriately for inference, as float32
        Raises:
            ValueError: if audio_data holds fewer samples than one inference needs
                (``(model_input_size - 1) * stride + frame_len``)
        """

        num_samples_per_inference = ((self.model_input_size - 1)
                                     * self.stride) + self._mfcc_calc.mfcc_params.frame_len

        if len(audio_data) < num_samples_per_inference:
            raise ValueError("audio_data size for feature extraction is smaller than "
                             "the expected number of samples needed for inference")

        features = []
        self._get_features(features, self._mfcc_calc, np.asarray(audio_data))
        # One row per frame, one column per MFCC coefficient.
        features = np.reshape(np.array(features), (self.model_input_size, self._mfcc_calc.mfcc_params.num_mfcc_feats))
        features = self.mfcc_delta_calc(features)
        return np.float32(features)
239