// xref: /aosp_15_r20/external/armnn/samples/common/include/Audio/MFCC.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once


#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>
13*89c4ff92SAndroid Build Coastguard Worker 
/**
 * @brief   MFCC's consolidated parameters.
 *          Plain bundle of public fields consumed by the MFCC class.
 *          The constructor and Str() are only declared here; their
 *          definitions live outside this header.
 */
class MfccParams
{
public:
    /* NOTE: member order is kept as-is; the out-of-line constructor may
     * depend on it for initialisation order. */
    float       m_samplingFreq;     /* Sampling frequency (presumably Hz - confirm against callers). */
    int         m_numFbankBins;     /* Number of filter bank bins. */
    float       m_melLoFreq;        /* Lower Mel frequency bound. */
    float       m_melHiFreq;        /* Upper Mel frequency bound. */
    int         m_numMfccFeatures;  /* Number of MFCC features produced per frame. */
    int         m_frameLen;         /* Frame length (presumably in samples - confirm). */
    int         m_frameLenPadded;   /* Not a constructor argument; presumably derived
                                     * from frameLen in the constructor definition. */
    bool        m_useHtkMethod;     /* Whether the HTK method is used for Mel calculations. */
    int         m_numMfccVectors;   /* Number of MFCC vectors. */

    /**
     * @brief       Constructor
     * @param[in]   samplingFreq     Value for m_samplingFreq.
     * @param[in]   numFbankBins     Value for m_numFbankBins.
     * @param[in]   melLoFreq        Value for m_melLoFreq.
     * @param[in]   melHiFreq        Value for m_melHiFreq.
     * @param[in]   numMfccFeats     Value for m_numMfccFeatures.
     * @param[in]   frameLen         Value for m_frameLen.
     * @param[in]   useHtkMethod     Value for m_useHtkMethod.
     * @param[in]   numMfccVectors   Value for m_numMfccVectors.
     */
    MfccParams(const float samplingFreq, const int numFbankBins,
               const float melLoFreq, const float melHiFreq,
               const int numMfccFeats, const int frameLen,
               const bool useHtkMethod, const int numMfccVectors);

    /* Delete the default constructor */
    MfccParams()  = delete;

    /* Default destructor */
    ~MfccParams() = default;

    /** @brief  String representation of parameters */
    std::string Str();
};
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker /**
41*89c4ff92SAndroid Build Coastguard Worker  * @brief   Class for MFCC feature extraction.
42*89c4ff92SAndroid Build Coastguard Worker  *          Based on https://github.com/ARM-software/ML-KWS-for-MCU/blob/master/Deployment/Source/MFCC/mfcc.cpp
43*89c4ff92SAndroid Build Coastguard Worker  *          This class is designed to be generic and self-sufficient but
44*89c4ff92SAndroid Build Coastguard Worker  *          certain calculation routines can be overridden to accommodate
45*89c4ff92SAndroid Build Coastguard Worker  *          use-case specific requirements.
46*89c4ff92SAndroid Build Coastguard Worker  */
47*89c4ff92SAndroid Build Coastguard Worker class MFCC {
48*89c4ff92SAndroid Build Coastguard Worker public:
49*89c4ff92SAndroid Build Coastguard Worker     /**
50*89c4ff92SAndroid Build Coastguard Worker      * @brief       Constructor
51*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   params   MFCC parameters
52*89c4ff92SAndroid Build Coastguard Worker     */
53*89c4ff92SAndroid Build Coastguard Worker     explicit MFCC(const MfccParams& params);
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     MFCC() = delete;
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     ~MFCC() = default;
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     /**
60*89c4ff92SAndroid Build Coastguard Worker     * @brief        Extract MFCC  features for one single small frame of
61*89c4ff92SAndroid Build Coastguard Worker     *               audio data e.g. 640 samples.
62*89c4ff92SAndroid Build Coastguard Worker     * @param[in]    audioData   Vector of audio samples to calculate
63*89c4ff92SAndroid Build Coastguard Worker     *                           features for.
64*89c4ff92SAndroid Build Coastguard Worker     * @return       Vector of extracted MFCC features.
65*89c4ff92SAndroid Build Coastguard Worker     **/
66*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> MfccCompute(const std::vector<float>& audioData);
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     /** @brief  Initialise. */
69*89c4ff92SAndroid Build Coastguard Worker     void Init();
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker    /**
72*89c4ff92SAndroid Build Coastguard Worker     * @brief        Extract MFCC features and quantise for one single small
73*89c4ff92SAndroid Build Coastguard Worker     *               frame of audio data e.g. 640 samples.
74*89c4ff92SAndroid Build Coastguard Worker     * @param[in]    audioData     Vector of audio samples to calculate
75*89c4ff92SAndroid Build Coastguard Worker     *                             features for.
76*89c4ff92SAndroid Build Coastguard Worker     * @param[in]    quantScale    Quantisation scale.
77*89c4ff92SAndroid Build Coastguard Worker     * @param[in]    quantOffset   Quantisation offset.
78*89c4ff92SAndroid Build Coastguard Worker     * @return       Vector of extracted quantised MFCC features.
79*89c4ff92SAndroid Build Coastguard Worker     **/
80*89c4ff92SAndroid Build Coastguard Worker     template<typename T>
MfccComputeQuant(const std::vector<float> & audioData,const float quantScale,const int quantOffset)81*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> MfccComputeQuant(const std::vector<float>& audioData,
82*89c4ff92SAndroid Build Coastguard Worker                                     const float quantScale,
83*89c4ff92SAndroid Build Coastguard Worker                                     const int quantOffset)
84*89c4ff92SAndroid Build Coastguard Worker     {
85*89c4ff92SAndroid Build Coastguard Worker         this->MfccComputePreFeature(audioData);
86*89c4ff92SAndroid Build Coastguard Worker         float minVal = std::numeric_limits<T>::min();
87*89c4ff92SAndroid Build Coastguard Worker         float maxVal = std::numeric_limits<T>::max();
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker         std::vector<T> mfccOut(this->m_params.m_numMfccFeatures);
90*89c4ff92SAndroid Build Coastguard Worker         const size_t numFbankBins = this->m_params.m_numFbankBins;
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker         /* Take DCT. Uses matrix mul. */
93*89c4ff92SAndroid Build Coastguard Worker         for (size_t i = 0, j = 0; i < mfccOut.size(); ++i, j += numFbankBins)
94*89c4ff92SAndroid Build Coastguard Worker         {
95*89c4ff92SAndroid Build Coastguard Worker             float sum = 0;
96*89c4ff92SAndroid Build Coastguard Worker             for (size_t k = 0; k < numFbankBins; ++k)
97*89c4ff92SAndroid Build Coastguard Worker             {
98*89c4ff92SAndroid Build Coastguard Worker                 sum += this->m_dctMatrix[j + k] * this->m_melEnergies[k];
99*89c4ff92SAndroid Build Coastguard Worker             }
100*89c4ff92SAndroid Build Coastguard Worker             /* Quantize to T. */
101*89c4ff92SAndroid Build Coastguard Worker             sum = std::round((sum / quantScale) + quantOffset);
102*89c4ff92SAndroid Build Coastguard Worker             mfccOut[i] = static_cast<T>(std::min<float>(std::max<float>(sum, minVal), maxVal));
103*89c4ff92SAndroid Build Coastguard Worker         }
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker         return mfccOut;
106*89c4ff92SAndroid Build Coastguard Worker     }
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker     MfccParams m_params;
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker     /* Constants */
111*89c4ff92SAndroid Build Coastguard Worker     static constexpr float ms_logStep = /*logf(6.4)*/ 1.8562979903656 / 27.0;
112*89c4ff92SAndroid Build Coastguard Worker     static constexpr float ms_freqStep = 200.0 / 3;
113*89c4ff92SAndroid Build Coastguard Worker     static constexpr float ms_minLogHz = 1000.0;
114*89c4ff92SAndroid Build Coastguard Worker     static constexpr float ms_minLogMel = ms_minLogHz / ms_freqStep;
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker protected:
117*89c4ff92SAndroid Build Coastguard Worker     /**
118*89c4ff92SAndroid Build Coastguard Worker      * @brief       Project input frequency to Mel Scale.
119*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   freq           Input frequency in floating point.
120*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   useHTKMethod   bool to signal if HTK method is to be
121*89c4ff92SAndroid Build Coastguard Worker      *                             used for calculation.
122*89c4ff92SAndroid Build Coastguard Worker      * @return      Mel transformed frequency in floating point.
123*89c4ff92SAndroid Build Coastguard Worker      **/
124*89c4ff92SAndroid Build Coastguard Worker     static float MelScale(float freq,
125*89c4ff92SAndroid Build Coastguard Worker                           bool  useHTKMethod = true);
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     /**
128*89c4ff92SAndroid Build Coastguard Worker      * @brief       Inverse Mel transform - convert MEL warped frequency
129*89c4ff92SAndroid Build Coastguard Worker      *              back to normal frequency.
130*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   melFreq        Mel frequency in floating point.
131*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   useHTKMethod   bool to signal if HTK method is to be
132*89c4ff92SAndroid Build Coastguard Worker      *                             used for calculation.
133*89c4ff92SAndroid Build Coastguard Worker      * @return      Real world frequency in floating point.
134*89c4ff92SAndroid Build Coastguard Worker      **/
135*89c4ff92SAndroid Build Coastguard Worker     static float InverseMelScale(float melFreq,
136*89c4ff92SAndroid Build Coastguard Worker                                  bool  useHTKMethod = true);
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker     /**
139*89c4ff92SAndroid Build Coastguard Worker      * @brief       Populates MEL energies after applying the MEL filter
140*89c4ff92SAndroid Build Coastguard Worker      *              bank weights and adding them up to be placed into
141*89c4ff92SAndroid Build Coastguard Worker      *              bins, according to the filter bank's first and last
142*89c4ff92SAndroid Build Coastguard Worker      *              indices (pre-computed for each filter bank element
143*89c4ff92SAndroid Build Coastguard Worker      *              by CreateMelFilterBank function).
144*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   fftVec                  Vector populated with FFT magnitudes.
145*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   melFilterBank           2D Vector with filter bank weights.
146*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   filterBankFilterFirst   Vector containing the first indices of filter bank
147*89c4ff92SAndroid Build Coastguard Worker      *                                      to be used for each bin.
148*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   filterBankFilterLast    Vector containing the last indices of filter bank
149*89c4ff92SAndroid Build Coastguard Worker      *                                      to be used for each bin.
150*89c4ff92SAndroid Build Coastguard Worker      * @param[out]  melEnergies             Pre-allocated vector of MEL energies to be
151*89c4ff92SAndroid Build Coastguard Worker      *                                      populated.
152*89c4ff92SAndroid Build Coastguard Worker      * @return      true if successful, false otherwise.
153*89c4ff92SAndroid Build Coastguard Worker      */
154*89c4ff92SAndroid Build Coastguard Worker     virtual bool ApplyMelFilterBank(
155*89c4ff92SAndroid Build Coastguard Worker         std::vector<float>&                 fftVec,
156*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::vector<float>>&    melFilterBank,
157*89c4ff92SAndroid Build Coastguard Worker         std::vector<uint32_t>&              filterBankFilterFirst,
158*89c4ff92SAndroid Build Coastguard Worker         std::vector<uint32_t>&              filterBankFilterLast,
159*89c4ff92SAndroid Build Coastguard Worker         std::vector<float>&                 melEnergies);
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker     /**
162*89c4ff92SAndroid Build Coastguard Worker      * @brief           Converts the Mel energies for logarithmic scale.
163*89c4ff92SAndroid Build Coastguard Worker      * @param[in,out]   melEnergies   1D vector of Mel energies.
164*89c4ff92SAndroid Build Coastguard Worker      **/
165*89c4ff92SAndroid Build Coastguard Worker     virtual void ConvertToLogarithmicScale(std::vector<float>& melEnergies);
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker     /**
168*89c4ff92SAndroid Build Coastguard Worker      * @brief       Create a matrix used to calculate Discrete Cosine
169*89c4ff92SAndroid Build Coastguard Worker      *              Transform.
170*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   inputLength        Input length of the buffer on which
171*89c4ff92SAndroid Build Coastguard Worker      *                                 DCT will be performed.
172*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   coefficientCount   Total coefficients per input length.
173*89c4ff92SAndroid Build Coastguard Worker      * @return      1D vector with inputLength x coefficientCount elements
174*89c4ff92SAndroid Build Coastguard Worker      *              populated with DCT coefficients.
175*89c4ff92SAndroid Build Coastguard Worker      */
176*89c4ff92SAndroid Build Coastguard Worker     virtual std::vector<float> CreateDCTMatrix(
177*89c4ff92SAndroid Build Coastguard Worker                                 int32_t inputLength,
178*89c4ff92SAndroid Build Coastguard Worker                                 int32_t coefficientCount);
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker     /**
181*89c4ff92SAndroid Build Coastguard Worker      * @brief       Given the low and high Mel values, get the normaliser
182*89c4ff92SAndroid Build Coastguard Worker      *              for weights to be applied when populating the filter
183*89c4ff92SAndroid Build Coastguard Worker      *              bank.
184*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   leftMel        Low Mel frequency value.
185*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   rightMel       High Mel frequency value.
186*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   useHTKMethod   bool to signal if HTK method is to be
187*89c4ff92SAndroid Build Coastguard Worker      *                             used for calculation.
188*89c4ff92SAndroid Build Coastguard Worker      * @return      Value to use for normalizing.
189*89c4ff92SAndroid Build Coastguard Worker      */
190*89c4ff92SAndroid Build Coastguard Worker     virtual float GetMelFilterBankNormaliser(
191*89c4ff92SAndroid Build Coastguard Worker                     const float&   leftMel,
192*89c4ff92SAndroid Build Coastguard Worker                     const float&   rightMel,
193*89c4ff92SAndroid Build Coastguard Worker                     bool     useHTKMethod);
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker private:
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>              m_frame;
198*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>              m_buffer;
199*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>              m_melEnergies;
200*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>              m_windowFunc;
201*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<float>> m_melFilterBank;
202*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>              m_dctMatrix;
203*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint32_t>           m_filterBankFilterFirst;
204*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint32_t>           m_filterBankFilterLast;
205*89c4ff92SAndroid Build Coastguard Worker     bool                            m_filterBankInitialised;
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker     /**
208*89c4ff92SAndroid Build Coastguard Worker      * @brief       Initialises the filter banks and the DCT matrix. **/
209*89c4ff92SAndroid Build Coastguard Worker     void InitMelFilterBank();
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker     /**
212*89c4ff92SAndroid Build Coastguard Worker      * @brief       Signals whether the instance of MFCC has had its
213*89c4ff92SAndroid Build Coastguard Worker      *              required buffers initialised.
214*89c4ff92SAndroid Build Coastguard Worker      * @return      true if initialised, false otherwise.
215*89c4ff92SAndroid Build Coastguard Worker      **/
216*89c4ff92SAndroid Build Coastguard Worker     bool IsMelFilterBankInited() const;
217*89c4ff92SAndroid Build Coastguard Worker 
218*89c4ff92SAndroid Build Coastguard Worker     /**
219*89c4ff92SAndroid Build Coastguard Worker      * @brief       Create mel filter banks for MFCC calculation.
220*89c4ff92SAndroid Build Coastguard Worker      * @return      2D vector of floats.
221*89c4ff92SAndroid Build Coastguard Worker      **/
222*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<float>> CreateMelFilterBank();
223*89c4ff92SAndroid Build Coastguard Worker 
224*89c4ff92SAndroid Build Coastguard Worker     /**
225*89c4ff92SAndroid Build Coastguard Worker      * @brief       Computes and populates internal memeber buffers used
226*89c4ff92SAndroid Build Coastguard Worker      *              in MFCC feature calculation
227*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   audioData   1D vector of 16-bit audio data.
228*89c4ff92SAndroid Build Coastguard Worker      */
229*89c4ff92SAndroid Build Coastguard Worker     void MfccComputePreFeature(const std::vector<float>& audioData);
230*89c4ff92SAndroid Build Coastguard Worker 
231*89c4ff92SAndroid Build Coastguard Worker     /** @brief       Computes the magnitude from an interleaved complex array. */
232*89c4ff92SAndroid Build Coastguard Worker     void ConvertToPowerSpectrum();
233*89c4ff92SAndroid Build Coastguard Worker 
234*89c4ff92SAndroid Build Coastguard Worker };
235