xref: /aosp_15_r20/external/armnn/samples/common/src/Audio/MFCC.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "MFCC.hpp"
6 #include "MathUtils.hpp"
7 
8 #include <cfloat>
9 #include <cinttypes>
10 #include <cstring>
11 
MfccParams(const float samplingFreq,const int numFbankBins,const float melLoFreq,const float melHiFreq,const int numMfccFeats,const int frameLen,const bool useHtkMethod,const int numMfccVectors)12 MfccParams::MfccParams(
13         const float samplingFreq,
14         const int numFbankBins,
15         const float melLoFreq,
16         const float melHiFreq,
17         const int numMfccFeats,
18         const int frameLen,
19         const bool useHtkMethod,
20         const int numMfccVectors):
21         m_samplingFreq(samplingFreq),
22         m_numFbankBins(numFbankBins),
23         m_melLoFreq(melLoFreq),
24         m_melHiFreq(melHiFreq),
25         m_numMfccFeatures(numMfccFeats),
26         m_frameLen(frameLen),
27         m_numMfccVectors(numMfccVectors),
28         /* Smallest power of 2 >= frame length. */
29         m_frameLenPadded(pow(2, ceil((log(frameLen)/log(2))))),
30         m_useHtkMethod(useHtkMethod)
31 {}
32 
Str()33 std::string MfccParams::Str()
34 {
35     char strC[1024];
36     snprintf(strC, sizeof(strC) - 1, "\n   \
37             \n\t Sampling frequency:         %f\
38             \n\t Number of filter banks:     %u\
39             \n\t Mel frequency limit (low):  %f\
40             \n\t Mel frequency limit (high): %f\
41             \n\t Number of MFCC features:    %u\
42             \n\t Frame length:               %u\
43             \n\t Padded frame length:        %u\
44             \n\t Using HTK for Mel scale:    %s\n",
45              this->m_samplingFreq, this->m_numFbankBins, this->m_melLoFreq,
46              this->m_melHiFreq, this->m_numMfccFeatures, this->m_frameLen,
47              this->m_frameLenPadded, this->m_useHtkMethod ? "yes" : "no");
48     return std::string{strC};
49 }
50 
MFCC(const MfccParams & params)51 MFCC::MFCC(const MfccParams& params):
52     m_params(params),
53     m_filterBankInitialised(false)
54 {
55     this->m_buffer = std::vector<float>(
56             this->m_params.m_frameLenPadded, 0.0);
57     this->m_frame = std::vector<float>(
58             this->m_params.m_frameLenPadded, 0.0);
59     this->m_melEnergies = std::vector<float>(
60             this->m_params.m_numFbankBins, 0.0);
61 
62     this->m_windowFunc = std::vector<float>(this->m_params.m_frameLen);
63     const auto multiplier = static_cast<float>(2 * M_PI / this->m_params.m_frameLen);
64 
65     /* Create window function. */
66     for (size_t i = 0; i < this->m_params.m_frameLen; i++)
67     {
68         this->m_windowFunc[i] = (0.5 - (0.5 * cosf(static_cast<float>(i) * multiplier)));
69     }
70 
71 }
72 
Init()73 void MFCC::Init()
74 {
75     this->InitMelFilterBank();
76 }
77 
MelScale(const float freq,const bool useHTKMethod)78 float MFCC::MelScale(const float freq, const bool useHTKMethod)
79 {
80     if (useHTKMethod)
81     {
82         return 1127.0f * logf (1.0f + freq / 700.0f);
83     }
84     else
85     {
86         /* Slaney formula for mel scale. */
87         float mel = freq / ms_freqStep;
88 
89         if (freq >= ms_minLogHz)
90         {
91             mel = ms_minLogMel + logf(freq / ms_minLogHz) / ms_logStep;
92         }
93         return mel;
94     }
95 }
96 
InverseMelScale(const float melFreq,const bool useHTKMethod)97 float MFCC::InverseMelScale(const float melFreq, const bool useHTKMethod)
98 {
99     if (useHTKMethod) {
100         return 700.0f * (expf (melFreq / 1127.0f) - 1.0f);
101     }
102     else
103     {
104         /* Slaney formula for mel scale. */
105         float freq = ms_freqStep * melFreq;
106 
107         if (melFreq >= ms_minLogMel)
108         {
109             freq = ms_minLogHz * expf(ms_logStep * (melFreq - ms_minLogMel));
110         }
111         return freq;
112     }
113 }
114 
115 
ApplyMelFilterBank(std::vector<float> & fftVec,std::vector<std::vector<float>> & melFilterBank,std::vector<uint32_t> & filterBankFilterFirst,std::vector<uint32_t> & filterBankFilterLast,std::vector<float> & melEnergies)116 bool MFCC::ApplyMelFilterBank(
117         std::vector<float>&                 fftVec,
118         std::vector<std::vector<float>>&    melFilterBank,
119         std::vector<uint32_t>&              filterBankFilterFirst,
120         std::vector<uint32_t>&              filterBankFilterLast,
121         std::vector<float>&                 melEnergies)
122 {
123     const size_t numBanks = melEnergies.size();
124 
125     if (numBanks != filterBankFilterFirst.size() ||
126         numBanks != filterBankFilterLast.size())
127     {
128         printf("unexpected filter bank lengths\n");
129         return false;
130     }
131 
132     for (size_t bin = 0; bin < numBanks; ++bin)
133     {
134         auto filterBankIter = melFilterBank[bin].begin();
135         auto end = melFilterBank[bin].end();
136         float melEnergy = FLT_MIN;  /* Avoid log of zero at later stages */
137         const uint32_t firstIndex = filterBankFilterFirst[bin];
138         const uint32_t lastIndex = std::min<uint32_t>(filterBankFilterLast[bin], fftVec.size() - 1);
139 
140         for (uint32_t i = firstIndex; i <= lastIndex && filterBankIter != end; i++)
141         {
142             float energyRep = sqrt(fftVec[i]);
143             melEnergy += (*filterBankIter++ * energyRep);
144         }
145 
146         melEnergies[bin] = melEnergy;
147     }
148 
149     return true;
150 }
151 
ConvertToLogarithmicScale(std::vector<float> & melEnergies)152 void MFCC::ConvertToLogarithmicScale(std::vector<float>& melEnergies)
153 {
154     for (float& melEnergy : melEnergies)
155     {
156         melEnergy = logf(melEnergy);
157     }
158 }
159 
ConvertToPowerSpectrum()160 void MFCC::ConvertToPowerSpectrum()
161 {
162     const uint32_t halfDim = this->m_buffer.size() / 2;
163 
164     /* Handle this special case. */
165     float firstEnergy = this->m_buffer[0] * this->m_buffer[0];
166     float lastEnergy = this->m_buffer[1] * this->m_buffer[1];
167 
168     MathUtils::ComplexMagnitudeSquaredF32(
169             this->m_buffer.data(),
170             this->m_buffer.size(),
171             this->m_buffer.data(),
172             this->m_buffer.size()/2);
173 
174     this->m_buffer[0] = firstEnergy;
175     this->m_buffer[halfDim] = lastEnergy;
176 }
177 
CreateDCTMatrix(const int32_t inputLength,const int32_t coefficientCount)178 std::vector<float> MFCC::CreateDCTMatrix(
179                             const int32_t inputLength,
180                             const int32_t coefficientCount)
181 {
182     std::vector<float> dctMatrix(inputLength * coefficientCount);
183 
184     const float normalizer = sqrtf(2.0f/inputLength);
185     const float angleIncr = M_PI/inputLength;
186     float angle = 0;
187 
188     for (int32_t k = 0, m = 0; k < coefficientCount; k++, m += inputLength)
189     {
190         for (int32_t n = 0; n < inputLength; n++)
191         {
192             dctMatrix[m + n] = normalizer * cosf((n + 0.5f) * angle);
193         }
194         angle += angleIncr;
195     }
196 
197     return dctMatrix;
198 }
199 
GetMelFilterBankNormaliser(const float & leftMel,const float & rightMel,const bool useHTKMethod)200 float MFCC::GetMelFilterBankNormaliser(
201                 const float&    leftMel,
202                 const float&    rightMel,
203                 const bool      useHTKMethod)
204 {
205     /* By default, no normalisation => return 1 */
206     return 1.f;
207 }
208 
InitMelFilterBank()209 void MFCC::InitMelFilterBank()
210 {
211     if (!this->IsMelFilterBankInited())
212     {
213         this->m_melFilterBank = this->CreateMelFilterBank();
214         this->m_dctMatrix = this->CreateDCTMatrix(
215                                 this->m_params.m_numFbankBins,
216                                 this->m_params.m_numMfccFeatures);
217         this->m_filterBankInitialised = true;
218     }
219 }
220 
IsMelFilterBankInited() const221 bool MFCC::IsMelFilterBankInited() const
222 {
223     return this->m_filterBankInitialised;
224 }
225 
MfccComputePreFeature(const std::vector<float> & audioData)226 void MFCC::MfccComputePreFeature(const std::vector<float>& audioData)
227 {
228     this->InitMelFilterBank();
229 
230     auto size = std::min(std::min(this->m_frame.size(), audioData.size()),
231                          static_cast<size_t>(this->m_params.m_frameLen)) * sizeof(float);
232     std::memcpy(this->m_frame.data(), audioData.data(), size);
233 
234     /* Apply window function to input frame. */
235     for(size_t i = 0; i < this->m_params.m_frameLen; i++)
236     {
237         this->m_frame[i] *= this->m_windowFunc[i];
238     }
239 
240     /* Set remaining frame values to 0. */
241     std::fill(this->m_frame.begin() + this->m_params.m_frameLen,this->m_frame.end(), 0);
242 
243     /* Compute FFT. */
244     MathUtils::FftF32(this->m_frame, this->m_buffer);
245 
246     /* Convert to power spectrum. */
247     this->ConvertToPowerSpectrum();
248 
249     /* Apply mel filterbanks. */
250     if (!this->ApplyMelFilterBank(this->m_buffer,
251                                   this->m_melFilterBank,
252                                   this->m_filterBankFilterFirst,
253                                   this->m_filterBankFilterLast,
254                                   this->m_melEnergies))
255     {
256         printf("Failed to apply MEL filter banks\n");
257     }
258 
259     /* Convert to logarithmic scale. */
260     this->ConvertToLogarithmicScale(this->m_melEnergies);
261 }
262 
MfccCompute(const std::vector<float> & audioData)263 std::vector<float> MFCC::MfccCompute(const std::vector<float>& audioData)
264 {
265     this->MfccComputePreFeature(audioData);
266 
267     std::vector<float> mfccOut(this->m_params.m_numMfccFeatures);
268 
269     float * ptrMel = this->m_melEnergies.data();
270     float * ptrDct = this->m_dctMatrix.data();
271     float * ptrMfcc = mfccOut.data();
272 
273     /* Take DCT. Uses matrix mul. */
274     for (size_t i = 0, j = 0; i < mfccOut.size();
275                 ++i, j += this->m_params.m_numFbankBins)
276     {
277         *ptrMfcc++ = MathUtils::DotProductF32(
278                 ptrDct + j,
279                 ptrMel,
280                 this->m_params.m_numFbankBins);
281     }
282     return mfccOut;
283 }
284 
CreateMelFilterBank()285 std::vector<std::vector<float>> MFCC::CreateMelFilterBank()
286 {
287     size_t numFftBins = this->m_params.m_frameLenPadded / 2;
288     float fftBinWidth = static_cast<float>(this->m_params.m_samplingFreq) / this->m_params.m_frameLenPadded;
289 
290     float melLowFreq = MFCC::MelScale(this->m_params.m_melLoFreq,
291                                       this->m_params.m_useHtkMethod);
292     float melHighFreq = MFCC::MelScale(this->m_params.m_melHiFreq,
293                                        this->m_params.m_useHtkMethod);
294     float melFreqDelta = (melHighFreq - melLowFreq) / (this->m_params.m_numFbankBins + 1);
295 
296     std::vector<float> thisBin = std::vector<float>(numFftBins);
297     std::vector<std::vector<float>> melFilterBank(
298                                         this->m_params.m_numFbankBins);
299     this->m_filterBankFilterFirst =
300                     std::vector<uint32_t>(this->m_params.m_numFbankBins);
301     this->m_filterBankFilterLast =
302                     std::vector<uint32_t>(this->m_params.m_numFbankBins);
303 
304     for (size_t bin = 0; bin < this->m_params.m_numFbankBins; bin++)
305     {
306         float leftMel = melLowFreq + bin * melFreqDelta;
307         float centerMel = melLowFreq + (bin + 1) * melFreqDelta;
308         float rightMel = melLowFreq + (bin + 2) * melFreqDelta;
309 
310         uint32_t firstIndex = 0;
311         uint32_t lastIndex = 0;
312         bool firstIndexFound = false;
313         const float normaliser = this->GetMelFilterBankNormaliser(leftMel, rightMel, this->m_params.m_useHtkMethod);
314 
315         for (size_t i = 0; i < numFftBins; i++)
316         {
317             float freq = (fftBinWidth * i);  /* Center freq of this fft bin. */
318             float mel = MFCC::MelScale(freq, this->m_params.m_useHtkMethod);
319             thisBin[i] = 0.0;
320 
321             if (mel > leftMel && mel < rightMel)
322             {
323                 float weight;
324                 if (mel <= centerMel)
325                 {
326                     weight = (mel - leftMel) / (centerMel - leftMel);
327                 }
328                 else
329                 {
330                     weight = (rightMel - mel) / (rightMel - centerMel);
331                 }
332 
333                 thisBin[i] = weight * normaliser;
334                 if (!firstIndexFound)
335                 {
336                     firstIndex = i;
337                     firstIndexFound = true;
338                 }
339                 lastIndex = i;
340             }
341         }
342 
343         this->m_filterBankFilterFirst[bin] = firstIndex;
344         this->m_filterBankFilterLast[bin] = lastIndex;
345 
346         /* Copy the part we care about. */
347         for (uint32_t i = firstIndex; i <= lastIndex; i++)
348         {
349             melFilterBank[bin].push_back(thisBin[i]);
350         }
351     }
352 
353     return melFilterBank;
354 }
355