# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""Class used to extract the Mel-frequency cepstral coefficients from a given audio frame."""

import numpy as np
import collections

MFCCParams = collections.namedtuple('MFCCParams', ['sampling_freq', 'num_fbank_bins', 'mel_lo_freq', 'mel_hi_freq',
                                                   'num_mfcc_feats', 'frame_len', 'use_htk_method', 'n_fft'])


class MFCC:
    """
    Computes MFCC features for single fixed-length audio frames.

    The mel filter bank and the DCT matrix are built once in ``__init__`` from ``mfcc_params``
    (an ``MFCCParams`` tuple) and reused by every call to ``mfcc_compute``.
    """

    def __init__(self, mfcc_params):
        self.mfcc_params = mfcc_params
        # Constants for the non-HTK mel scale used by mel_scale()/inv_mel_scale():
        # linear below MIN_LOG_HZ, logarithmic above it (LOG_STEP == ln(6.4) / 27).
        self.FREQ_STEP = 200.0 / 3
        self.MIN_LOG_HZ = 1000.0
        self.MIN_LOG_MEL = self.MIN_LOG_HZ / self.FREQ_STEP
        self.LOG_STEP = 1.8562979903656 / 27.0
        # Frame length rounded up to the next power of two.
        self._frame_len_padded = int(2 ** (np.ceil((np.log(self.mfcc_params.frame_len) / np.log(2.0)))))
        # The next three attributes (and __mel_energies below) are not read anywhere in this class;
        # they are kept so existing code that inspects them keeps working.
        self._filter_bank_initialised = False
        self.__frame = np.zeros(self._frame_len_padded)
        self.__buffer = np.zeros(self._frame_len_padded)
        # Index of the first / last FFT bin with a non-zero weight for each mel filter;
        # populated by create_mel_filter_bank().
        self._filter_bank_filter_first = np.zeros(self.mfcc_params.num_fbank_bins)
        self._filter_bank_filter_last = np.zeros(self.mfcc_params.num_fbank_bins)
        self.__mel_energies = np.zeros(self.mfcc_params.num_fbank_bins)
        self._dct_matrix = self.create_dct_matrix(self.mfcc_params.num_fbank_bins, self.mfcc_params.num_mfcc_feats)
        self.__mel_filter_bank = self.create_mel_filter_bank()
        # Dense (num_fbank_bins, n_fft // 2 + 1) matrix so the filter bank can be applied to the
        # rfft magnitude spectrum with a single dot product in mfcc_compute().
        # NOTE(review): filter indices are derived from _frame_len_padded while the spectrum uses
        # n_fft; the two only line up when n_fft == _frame_len_padded.
        self._np_mel_bank = np.zeros([self.mfcc_params.num_fbank_bins, int(self.mfcc_params.n_fft / 2) + 1])

        for i in range(self.mfcc_params.num_fbank_bins):
            first = int(self._filter_bank_filter_first[i])
            last = int(self._filter_bank_filter_last[i])
            for k, j in enumerate(range(first, last + 1)):
                self._np_mel_bank[i, j] = self.__mel_filter_bank[i][k]

    def mel_scale(self, freq, use_htk_method):
        """
        Gets the mel scale for a particular sample frequency.

        Args:
            freq: The frequency in Hz (scalar).
            use_htk_method: Boolean to set whether to use HTK method or not.

        Returns:
            the mel scale
        """
        if use_htk_method:
            return 1127.0 * np.log(1.0 + freq / 700.0)
        else:
            mel = freq / self.FREQ_STEP

            if freq >= self.MIN_LOG_HZ:
                mel = self.MIN_LOG_MEL + np.log(freq / self.MIN_LOG_HZ) / self.LOG_STEP
            return mel

    def inv_mel_scale(self, mel_freq, use_htk_method):
        """
        Gets the sample frequency for a particular mel (inverse of mel_scale).

        Args:
            mel_freq: The mel frequency (scalar).
            use_htk_method: Boolean to set whether to use HTK method or not.

        Returns:
            the frequency in Hz
        """
        if use_htk_method:
            return 700.0 * (np.exp(mel_freq / 1127.0) - 1.0)
        else:
            freq = self.FREQ_STEP * mel_freq

            if mel_freq >= self.MIN_LOG_MEL:
                freq = self.MIN_LOG_HZ * np.exp(self.LOG_STEP * (mel_freq - self.MIN_LOG_MEL))
            return freq

    def spectrum_calc(self, audio_data):
        """
        Magnitude spectrum of one frame.

        The frame is multiplied by a periodic Hann window (first frame_len points of a
        frame_len + 1 point np.hanning window) and transformed with an n_fft-point rfft.

        Returns:
            array of n_fft // 2 + 1 magnitudes
        """
        return np.abs(np.fft.rfft(np.hanning(self.mfcc_params.frame_len + 1)[0:self.mfcc_params.frame_len] * audio_data,
                                  self.mfcc_params.n_fft))

    def log_mel(self, mel_energy):
        """
        Natural log of the mel energies.

        A small offset (1e-10) is added first so that zero energies do not produce log(0) = -inf.
        The input is not modified.

        Args:
            mel_energy: mel filter bank energies (array-like).

        Returns:
            np.log(mel_energy + 1e-10)
        """
        return np.log(np.asarray(mel_energy) + 1e-10)

    def mfcc_compute(self, audio_data):
        """
        Extracts the MFCC for a single frame.

        Args:
            audio_data: The audio data to process; must contain exactly frame_len samples.

        Returns:
            the MFCC features (num_mfcc_feats values)

        Raises:
            ValueError: if len(audio_data) != frame_len.
        """
        if len(audio_data) != self.mfcc_params.frame_len:
            raise ValueError(
                f"audio_data buffer size {len(audio_data)} does not match frame length {self.mfcc_params.frame_len}")

        audio_data = np.array(audio_data)
        spec = self.spectrum_calc(audio_data)
        # Filter bank and spectrum are multiplied in float32.
        mel_energy = np.dot(self._np_mel_bank.astype(np.float32),
                            np.transpose(spec).astype(np.float32))
        log_mel_energy = self.log_mel(mel_energy)
        mfcc_feats = np.dot(self._dct_matrix, log_mel_energy)
        return mfcc_feats

    def create_dct_matrix(self, num_fbank_bins, num_mfcc_feats):
        """
        Creates the Discrete Cosine Transform matrix to be used in the compute function.

        Element [k, n] is sqrt(2 / num_fbank_bins) * cos(pi / num_fbank_bins * (n + 0.5) * k).

        Args:
            num_fbank_bins: The number of filter bank bins
            num_mfcc_feats: the number of MFCC features

        Returns:
            the DCT matrix, shape (num_mfcc_feats, num_fbank_bins)
        """
        n = np.arange(num_fbank_bins)
        k = np.arange(num_mfcc_feats)
        # Shaped from the arguments (not self.mfcc_params) so the method works for any sizes.
        angles = ((np.pi / num_fbank_bins) * (n + 0.5))[np.newaxis, :] * k[:, np.newaxis]
        return np.sqrt(2 / num_fbank_bins) * np.cos(angles)

    def mel_norm(self, weight, right_mel, left_mel):
        """
        Placeholder function over-ridden in child class; returns weight unchanged.
        """
        return weight

    def create_mel_filter_bank(self):
        """
        Creates the Mel filter bank.

        Builds num_fbank_bins triangular filters spaced evenly on the mel scale between
        mel_lo_freq and mel_hi_freq. Side effect: fills _filter_bank_filter_first /
        _filter_bank_filter_last with each filter's first / last non-zero FFT bin index.

        Returns:
            the mel filter bank: a list with one array of weights per filter, covering FFT bins
            first..last inclusive
        """
        # FFT calculations are greatly accelerated for frame lengths which are powers of 2
        # Frames are padded and FFT bin width/length calculated accordingly
        num_fft_bins = int(self._frame_len_padded / 2)
        fft_bin_width = self.mfcc_params.sampling_freq / self._frame_len_padded

        mel_low_freq = self.mel_scale(self.mfcc_params.mel_lo_freq, self.mfcc_params.use_htk_method)
        mel_high_freq = self.mel_scale(self.mfcc_params.mel_hi_freq, self.mfcc_params.use_htk_method)
        mel_freq_delta = (mel_high_freq - mel_low_freq) / (self.mfcc_params.num_fbank_bins + 1)

        this_bin = np.zeros(num_fft_bins)
        mel_fbank = [0] * self.mfcc_params.num_fbank_bins
        for bin_num in range(self.mfcc_params.num_fbank_bins):
            left_mel = mel_low_freq + bin_num * mel_freq_delta
            center_mel = mel_low_freq + (bin_num + 1) * mel_freq_delta
            right_mel = mel_low_freq + (bin_num + 2) * mel_freq_delta
            first_index = last_index = -1

            for i in range(num_fft_bins):
                freq = (fft_bin_width * i)
                mel = self.mel_scale(freq, self.mfcc_params.use_htk_method)
                this_bin[i] = 0.0

                # Triangle: rises from left_mel to center_mel, falls to right_mel (edges excluded).
                if (mel > left_mel) and (mel < right_mel):
                    if mel <= center_mel:
                        weight = (mel - left_mel) / (center_mel - left_mel)
                    else:
                        weight = (right_mel - mel) / (right_mel - center_mel)

                    this_bin[i] = self.mel_norm(weight, right_mel, left_mel)

                    if first_index == -1:
                        first_index = i
                    last_index = i

            self._filter_bank_filter_first[bin_num] = first_index
            self._filter_bank_filter_last[bin_num] = last_index
            # NOTE: if no FFT bin falls inside this filter, first/last stay -1 and the filter
            # becomes a single element copied from this_bin[-1] (which is 0.0 here).
            mel_fbank[bin_num] = np.zeros(last_index - first_index + 1)
            j = 0

            for i in range(first_index, last_index + 1):
                mel_fbank[bin_num][j] = this_bin[i]
                j += 1

        return mel_fbank


class AudioPreprocessor:
    """Slides an MFCC calculator over an audio buffer to build a model input feature matrix."""

    def __init__(self, mfcc, model_input_size, stride):
        """
        Args:
            mfcc: an MFCC instance used for per-frame feature extraction.
            model_input_size: number of frames (rows) in the feature matrix.
            stride: number of samples between the starts of consecutive frames.
        """
        self.model_input_size = model_input_size
        self.stride = stride
        self._mfcc_calc = mfcc

    def _normalize(self, values):
        """
        Normalize values to mean 0 and std 1

        NOTE: a constant input (std == 0) produces a division by zero (nan/inf).
        """
        ret_val = (values - np.mean(values)) / np.std(values)
        return ret_val

    def _get_features(self, features, mfcc_instance, audio_data):
        """
        Appends per-frame MFCCs to ``features`` (in place) until it holds
        model_input_size * num_mfcc_feats values, advancing by ``stride`` samples per frame.
        """
        idx = 0
        while len(features) < self.model_input_size * mfcc_instance.mfcc_params.num_mfcc_feats:
            current_frame_feats = mfcc_instance.mfcc_compute(audio_data[idx:idx + int(mfcc_instance.mfcc_params.frame_len)])
            features.extend(current_frame_feats)
            idx += self.stride

    def mfcc_delta_calc(self, features):
        """
        Placeholder function over-ridden in child class; returns features unchanged.
        """
        return features

    def extract_features(self, audio_data):
        """
        Extracts the MFCC features. Also calculates each feature's first and second order derivatives
        if the mfcc_delta_calc() function has been implemented by a child class.
        The matrix returned should be sized appropriately for input to the model, based
        on the model info specified in the MFCC instance.

        Args:
            audio_data: the audio data to be used for this calculation; must hold at least
                (model_input_size - 1) * stride + frame_len samples.
        Returns:
            the derived MFCC feature vector as float32, sized appropriately for inference
            (model_input_size x num_mfcc_feats unless mfcc_delta_calc() reshapes it)
        Raises:
            ValueError: if audio_data is too short.
        """

        num_samples_per_inference = ((self.model_input_size - 1)
                                     * self.stride) + self._mfcc_calc.mfcc_params.frame_len

        if len(audio_data) < num_samples_per_inference:
            raise ValueError("audio_data size for feature extraction is smaller than "
                             "the expected number of samples needed for inference")

        features = []
        self._get_features(features, self._mfcc_calc, np.asarray(audio_data))
        features = np.reshape(np.array(features), (self.model_input_size, self._mfcc_calc.mfcc_params.num_mfcc_feats))
        features = self.mfcc_delta_calc(features)
        return np.float32(features)