1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include "MathUtils.hpp"
6*89c4ff92SAndroid Build Coastguard Worker #include <cstring>
7*89c4ff92SAndroid Build Coastguard Worker #include <cmath>
8*89c4ff92SAndroid Build Coastguard Worker #include <numeric>
9*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
10*89c4ff92SAndroid Build Coastguard Worker #include <memory>
11*89c4ff92SAndroid Build Coastguard Worker #include "Wav2LetterPreprocessor.hpp"
12*89c4ff92SAndroid Build Coastguard Worker #include "Wav2LetterMFCC.hpp"
13*89c4ff92SAndroid Build Coastguard Worker
GetMean(Array2d<float> & vec)14*89c4ff92SAndroid Build Coastguard Worker float Wav2LetterPreprocessor::GetMean(Array2d<float>& vec)
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker return MathUtils::MeanF32(vec.begin(), vec.totalSize());
17*89c4ff92SAndroid Build Coastguard Worker }
18*89c4ff92SAndroid Build Coastguard Worker
GetStdDev(Array2d<float> & vec,const float mean)19*89c4ff92SAndroid Build Coastguard Worker float Wav2LetterPreprocessor::GetStdDev(Array2d<float>& vec, const float mean)
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker return MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
22*89c4ff92SAndroid Build Coastguard Worker }
23*89c4ff92SAndroid Build Coastguard Worker
NormaliseVec(Array2d<float> & vec)24*89c4ff92SAndroid Build Coastguard Worker void Wav2LetterPreprocessor::NormaliseVec(Array2d<float>& vec)
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker auto mean = Wav2LetterPreprocessor::GetMean(vec);
27*89c4ff92SAndroid Build Coastguard Worker auto stddev = Wav2LetterPreprocessor::GetStdDev(vec, mean);
28*89c4ff92SAndroid Build Coastguard Worker
29*89c4ff92SAndroid Build Coastguard Worker if (stddev == 0)
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker std::fill(vec.begin(), vec.end(), 0);
32*89c4ff92SAndroid Build Coastguard Worker }
33*89c4ff92SAndroid Build Coastguard Worker else
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker const float stddevInv = 1.f/stddev;
36*89c4ff92SAndroid Build Coastguard Worker const float normalisedMean = mean/stddev;
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker auto NormalisingFunction = [=](float &value) {
39*89c4ff92SAndroid Build Coastguard Worker value = value * stddevInv - normalisedMean;
40*89c4ff92SAndroid Build Coastguard Worker };
41*89c4ff92SAndroid Build Coastguard Worker std::for_each(vec.begin(), vec.end(), NormalisingFunction);
42*89c4ff92SAndroid Build Coastguard Worker }
43*89c4ff92SAndroid Build Coastguard Worker }
44*89c4ff92SAndroid Build Coastguard Worker
Normalise()45*89c4ff92SAndroid Build Coastguard Worker void Wav2LetterPreprocessor::Normalise()
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker Wav2LetterPreprocessor::NormaliseVec(this->m_mfccBuf);
48*89c4ff92SAndroid Build Coastguard Worker Wav2LetterPreprocessor::NormaliseVec(this->m_delta1Buf);
49*89c4ff92SAndroid Build Coastguard Worker Wav2LetterPreprocessor::NormaliseVec(this->m_delta2Buf);
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
GetQuantElem(const float elem,const float quantScale,const int quantOffset,const float minVal,const float maxVal)52*89c4ff92SAndroid Build Coastguard Worker float Wav2LetterPreprocessor::GetQuantElem(
53*89c4ff92SAndroid Build Coastguard Worker const float elem,
54*89c4ff92SAndroid Build Coastguard Worker const float quantScale,
55*89c4ff92SAndroid Build Coastguard Worker const int quantOffset,
56*89c4ff92SAndroid Build Coastguard Worker const float minVal,
57*89c4ff92SAndroid Build Coastguard Worker const float maxVal)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker float val = std::round((elem/quantScale) + quantOffset);
60*89c4ff92SAndroid Build Coastguard Worker float returnVal = std::min<float>(std::max<float>(val, minVal), maxVal);
61*89c4ff92SAndroid Build Coastguard Worker return returnVal;
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker
Invoke(const float * audioData,const uint32_t audioDataLen,std::vector<int8_t> & output,int quantOffset,float quantScale)64*89c4ff92SAndroid Build Coastguard Worker bool Wav2LetterPreprocessor::Invoke(const float* audioData, const uint32_t audioDataLen, std::vector<int8_t>& output,
65*89c4ff92SAndroid Build Coastguard Worker int quantOffset, float quantScale)
66*89c4ff92SAndroid Build Coastguard Worker {
67*89c4ff92SAndroid Build Coastguard Worker this->m_window = SlidingWindow<const float>(
68*89c4ff92SAndroid Build Coastguard Worker audioData, audioDataLen,
69*89c4ff92SAndroid Build Coastguard Worker this->m_windowLen, this->m_windowStride);
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker uint32_t mfccBufIdx = 0;
72*89c4ff92SAndroid Build Coastguard Worker
73*89c4ff92SAndroid Build Coastguard Worker // Init buffers with 0
74*89c4ff92SAndroid Build Coastguard Worker std::fill(m_mfccBuf.begin(), m_mfccBuf.end(), 0.f);
75*89c4ff92SAndroid Build Coastguard Worker std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f);
76*89c4ff92SAndroid Build Coastguard Worker std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f);
77*89c4ff92SAndroid Build Coastguard Worker
78*89c4ff92SAndroid Build Coastguard Worker // While we can slide over the window
79*89c4ff92SAndroid Build Coastguard Worker while (this->m_window.HasNext())
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker const float* mfccWindow = this->m_window.Next();
82*89c4ff92SAndroid Build Coastguard Worker auto mfccAudioData = std::vector<float>(
83*89c4ff92SAndroid Build Coastguard Worker mfccWindow,
84*89c4ff92SAndroid Build Coastguard Worker mfccWindow + this->m_windowLen);
85*89c4ff92SAndroid Build Coastguard Worker
86*89c4ff92SAndroid Build Coastguard Worker auto mfcc = this->m_mfcc->MfccCompute(mfccAudioData);
87*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i)
88*89c4ff92SAndroid Build Coastguard Worker {
89*89c4ff92SAndroid Build Coastguard Worker this->m_mfccBuf(i, mfccBufIdx) = mfcc[i];
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker ++mfccBufIdx;
92*89c4ff92SAndroid Build Coastguard Worker }
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker // Pad MFCC if needed by repeating last feature vector
95*89c4ff92SAndroid Build Coastguard Worker while (mfccBufIdx != this->m_mfcc->m_params.m_numMfccVectors)
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker memcpy(&this->m_mfccBuf(0, mfccBufIdx),
98*89c4ff92SAndroid Build Coastguard Worker &this->m_mfccBuf(0, mfccBufIdx - 1), sizeof(float) * this->m_mfcc->m_params.m_numMfccFeatures);
99*89c4ff92SAndroid Build Coastguard Worker ++mfccBufIdx;
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker // Compute first and second order deltas from MFCCs
103*89c4ff92SAndroid Build Coastguard Worker Wav2LetterPreprocessor::ComputeDeltas(this->m_mfccBuf,
104*89c4ff92SAndroid Build Coastguard Worker this->m_delta1Buf,
105*89c4ff92SAndroid Build Coastguard Worker this->m_delta2Buf);
106*89c4ff92SAndroid Build Coastguard Worker
107*89c4ff92SAndroid Build Coastguard Worker // Normalise
108*89c4ff92SAndroid Build Coastguard Worker this->Normalise();
109*89c4ff92SAndroid Build Coastguard Worker
110*89c4ff92SAndroid Build Coastguard Worker return this->Quantise<int8_t>(output.data(), quantOffset, quantScale);
111*89c4ff92SAndroid Build Coastguard Worker }
112*89c4ff92SAndroid Build Coastguard Worker
ComputeDeltas(Array2d<float> & mfcc,Array2d<float> & delta1,Array2d<float> & delta2)113*89c4ff92SAndroid Build Coastguard Worker bool Wav2LetterPreprocessor::ComputeDeltas(Array2d<float>& mfcc,
114*89c4ff92SAndroid Build Coastguard Worker Array2d<float>& delta1,
115*89c4ff92SAndroid Build Coastguard Worker Array2d<float>& delta2)
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker const std::vector <float> delta1Coeffs =
118*89c4ff92SAndroid Build Coastguard Worker {6.66666667e-02, 5.00000000e-02, 3.33333333e-02,
119*89c4ff92SAndroid Build Coastguard Worker 1.66666667e-02, -3.46944695e-18, -1.66666667e-02,
120*89c4ff92SAndroid Build Coastguard Worker -3.33333333e-02, -5.00000000e-02, -6.66666667e-02};
121*89c4ff92SAndroid Build Coastguard Worker
122*89c4ff92SAndroid Build Coastguard Worker const std::vector <float> delta2Coeffs =
123*89c4ff92SAndroid Build Coastguard Worker {0.06060606, 0.01515152, -0.01731602,
124*89c4ff92SAndroid Build Coastguard Worker -0.03679654, -0.04329004, -0.03679654,
125*89c4ff92SAndroid Build Coastguard Worker -0.01731602, 0.01515152, 0.06060606};
126*89c4ff92SAndroid Build Coastguard Worker
127*89c4ff92SAndroid Build Coastguard Worker if (delta1.size(0) == 0 || delta2.size(0) != delta1.size(0) ||
128*89c4ff92SAndroid Build Coastguard Worker mfcc.size(0) == 0 || mfcc.size(1) == 0)
129*89c4ff92SAndroid Build Coastguard Worker {
130*89c4ff92SAndroid Build Coastguard Worker return false;
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker // Get the middle index; coeff vec len should always be odd
134*89c4ff92SAndroid Build Coastguard Worker const size_t coeffLen = delta1Coeffs.size();
135*89c4ff92SAndroid Build Coastguard Worker const size_t fMidIdx = (coeffLen - 1)/2;
136*89c4ff92SAndroid Build Coastguard Worker const size_t numFeatures = mfcc.size(0);
137*89c4ff92SAndroid Build Coastguard Worker const size_t numFeatVectors = mfcc.size(1);
138*89c4ff92SAndroid Build Coastguard Worker
139*89c4ff92SAndroid Build Coastguard Worker // iterate through features in MFCC vector
140*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numFeatures; ++i)
141*89c4ff92SAndroid Build Coastguard Worker {
142*89c4ff92SAndroid Build Coastguard Worker /* for each feature, iterate through time (t) samples representing feature evolution and
143*89c4ff92SAndroid Build Coastguard Worker * calculate d/dt and d^2/dt^2, using 1d convolution with differential kernels.
144*89c4ff92SAndroid Build Coastguard Worker * Convolution padding = valid, result size is `time length - kernel length + 1`.
145*89c4ff92SAndroid Build Coastguard Worker * The result is padded with 0 from both sides to match the size of initial time samples data.
146*89c4ff92SAndroid Build Coastguard Worker *
147*89c4ff92SAndroid Build Coastguard Worker * For the small filter, conv1d implementation as a simple loop is efficient enough.
148*89c4ff92SAndroid Build Coastguard Worker * Filters of a greater size would need CMSIS-DSP functions to be used, like arm_fir_f32.
149*89c4ff92SAndroid Build Coastguard Worker */
150*89c4ff92SAndroid Build Coastguard Worker
151*89c4ff92SAndroid Build Coastguard Worker for (size_t j = fMidIdx; j < numFeatVectors - fMidIdx; ++j)
152*89c4ff92SAndroid Build Coastguard Worker {
153*89c4ff92SAndroid Build Coastguard Worker float d1 = 0;
154*89c4ff92SAndroid Build Coastguard Worker float d2 = 0;
155*89c4ff92SAndroid Build Coastguard Worker const size_t mfccStIdx = j - fMidIdx;
156*89c4ff92SAndroid Build Coastguard Worker
157*89c4ff92SAndroid Build Coastguard Worker for (size_t k = 0, m = coeffLen - 1; k < coeffLen; ++k, --m)
158*89c4ff92SAndroid Build Coastguard Worker {
159*89c4ff92SAndroid Build Coastguard Worker
160*89c4ff92SAndroid Build Coastguard Worker d1 += mfcc(i,mfccStIdx + k) * delta1Coeffs[m];
161*89c4ff92SAndroid Build Coastguard Worker d2 += mfcc(i,mfccStIdx + k) * delta2Coeffs[m];
162*89c4ff92SAndroid Build Coastguard Worker }
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker delta1(i,j) = d1;
165*89c4ff92SAndroid Build Coastguard Worker delta2(i,j) = d2;
166*89c4ff92SAndroid Build Coastguard Worker }
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker return true;
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker
Wav2LetterPreprocessor(const uint32_t windowLen,const uint32_t windowStride,std::unique_ptr<Wav2LetterMFCC> mfccInst)172*89c4ff92SAndroid Build Coastguard Worker Wav2LetterPreprocessor::Wav2LetterPreprocessor(const uint32_t windowLen,
173*89c4ff92SAndroid Build Coastguard Worker const uint32_t windowStride,
174*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Wav2LetterMFCC> mfccInst):
175*89c4ff92SAndroid Build Coastguard Worker m_mfcc(std::move(mfccInst)),
176*89c4ff92SAndroid Build Coastguard Worker m_mfccBuf(m_mfcc->m_params.m_numMfccFeatures, m_mfcc->m_params.m_numMfccVectors),
177*89c4ff92SAndroid Build Coastguard Worker m_delta1Buf(m_mfcc->m_params.m_numMfccFeatures, m_mfcc->m_params.m_numMfccVectors),
178*89c4ff92SAndroid Build Coastguard Worker m_delta2Buf(m_mfcc->m_params.m_numMfccFeatures, m_mfcc->m_params.m_numMfccVectors),
179*89c4ff92SAndroid Build Coastguard Worker m_windowLen(windowLen),
180*89c4ff92SAndroid Build Coastguard Worker m_windowStride(windowStride)
181*89c4ff92SAndroid Build Coastguard Worker {
182*89c4ff92SAndroid Build Coastguard Worker if (m_mfcc->m_params.m_numMfccFeatures > 0 && windowLen > 0)
183*89c4ff92SAndroid Build Coastguard Worker {
184*89c4ff92SAndroid Build Coastguard Worker this->m_mfcc->Init();
185*89c4ff92SAndroid Build Coastguard Worker }
186*89c4ff92SAndroid Build Coastguard Worker std::fill(m_mfccBuf.begin(), m_mfccBuf.end(), 0.f);
187*89c4ff92SAndroid Build Coastguard Worker }