xref: /aosp_15_r20/external/armnn/samples/SpeechRecognition/src/Wav2LetterMFCC.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "Wav2LetterMFCC.hpp"
6 #include "MathUtils.hpp"
7 
8 #include <cfloat>
9 
ApplyMelFilterBank(std::vector<float> & fftVec,std::vector<std::vector<float>> & melFilterBank,std::vector<uint32_t> & filterBankFilterFirst,std::vector<uint32_t> & filterBankFilterLast,std::vector<float> & melEnergies)10 bool Wav2LetterMFCC::ApplyMelFilterBank(
11         std::vector<float>&                 fftVec,
12         std::vector<std::vector<float>>&    melFilterBank,
13         std::vector<uint32_t>&               filterBankFilterFirst,
14         std::vector<uint32_t>&               filterBankFilterLast,
15         std::vector<float>&                 melEnergies)
16 {
17     const size_t numBanks = melEnergies.size();
18 
19     if (numBanks != filterBankFilterFirst.size() ||
20             numBanks != filterBankFilterLast.size())
21     {
22         printf("Unexpected filter bank lengths\n");
23         return false;
24     }
25 
26     for (size_t bin = 0; bin < numBanks; ++bin)
27     {
28         auto filterBankIter = melFilterBank[bin].begin();
29         auto end = melFilterBank[bin].end();
30         // Avoid log of zero at later stages, same value used in librosa.
31         // The number was used during our default wav2letter model training.
32         float melEnergy = 1e-10;
33         const uint32_t firstIndex = filterBankFilterFirst[bin];
34         const uint32_t lastIndex = std::min<uint32_t>(filterBankFilterLast[bin], fftVec.size() - 1);
35 
36         for (uint32_t i = firstIndex; i <= lastIndex && filterBankIter != end; ++i)
37         {
38             melEnergy += (*filterBankIter++ * fftVec[i]);
39         }
40 
41         melEnergies[bin] = melEnergy;
42     }
43 
44     return true;
45 }
46 
ConvertToLogarithmicScale(std::vector<float> & melEnergies)47 void Wav2LetterMFCC::ConvertToLogarithmicScale(std::vector<float>& melEnergies)
48 {
49     float maxMelEnergy = -FLT_MAX;
50 
51     // Container for natural logarithms of mel energies.
52     std::vector <float> vecLogEnergies(melEnergies.size(), 0.f);
53 
54     // Because we are taking natural logs, we need to multiply by log10(e).
55     // Also, for wav2letter model, we scale our log10 values by 10.
56     constexpr float multiplier = 10.0 *  // Default scalar.
57                                   0.4342944819032518;  // log10f(std::exp(1.0))
58 
59     // Take log of the whole vector.
60     MathUtils::VecLogarithmF32(melEnergies, vecLogEnergies);
61 
62     // Scale the log values and get the max.
63     for (auto iterM = melEnergies.begin(), iterL = vecLogEnergies.begin();
64               iterM != melEnergies.end() && iterL != vecLogEnergies.end(); ++iterM, ++iterL)
65     {
66 
67         *iterM = *iterL * multiplier;
68 
69         // Save the max mel energy.
70         if (*iterM > maxMelEnergy)
71         {
72             maxMelEnergy = *iterM;
73         }
74     }
75 
76     // Clamp the mel energies.
77     constexpr float maxDb = 80.0;
78     const float clampLevelLowdB = maxMelEnergy - maxDb;
79     for (float& melEnergy : melEnergies)
80     {
81         melEnergy = std::max(melEnergy, clampLevelLowdB);
82     }
83 }
84 
CreateDCTMatrix(const int32_t inputLength,const int32_t coefficientCount)85 std::vector<float> Wav2LetterMFCC::CreateDCTMatrix(
86                                     const int32_t inputLength,
87                                     const int32_t coefficientCount)
88 {
89     std::vector<float> dctMatix(inputLength * coefficientCount);
90 
91     // Orthonormal normalization.
92     const float normalizerK0 = 2 * sqrtf(1.0f /
93                                     static_cast<float>(4 * inputLength));
94     const float normalizer = 2 * sqrtf(1.0f /
95                                     static_cast<float>(2 * inputLength));
96 
97     const float angleIncr = M_PI / inputLength;
98     float angle = angleIncr;  // We start using it at k = 1 loop.
99 
100     // First row of DCT will use normalizer K0.
101     for (int32_t n = 0; n < inputLength; ++n)
102     {
103         dctMatix[n] = normalizerK0;  // cos(0) = 1
104     }
105 
106     // Second row (index = 1) onwards, we use standard normalizer.
107     for (int32_t k = 1, m = inputLength; k < coefficientCount; ++k, m += inputLength)
108     {
109         for (int32_t n = 0; n < inputLength; ++n)
110         {
111             dctMatix[m+n] = normalizer * cosf((n + 0.5f) * angle);
112         }
113         angle += angleIncr;
114     }
115     return dctMatix;
116 }
117 
GetMelFilterBankNormaliser(const float & leftMel,const float & rightMel,const bool useHTKMethod)118 float Wav2LetterMFCC::GetMelFilterBankNormaliser(
119                                 const float&    leftMel,
120                                 const float&    rightMel,
121                                 const bool      useHTKMethod)
122 {
123     // Slaney normalization for mel weights.
124     return (2.0f / (MFCC::InverseMelScale(rightMel, useHTKMethod) -
125             MFCC::InverseMelScale(leftMel, useHTKMethod)));
126 }
127