xref: /aosp_15_r20/external/armnn/samples/SpeechRecognition/include/Wav2LetterPreprocessor.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #ifndef SPEECH_RECOGNITION_EXAMPLE_WAV2LETTERPREPROCESSOR_HPP
6*89c4ff92SAndroid Build Coastguard Worker #define SPEECH_RECOGNITION_EXAMPLE_WAV2LETTERPREPROCESSOR_HPP
7*89c4ff92SAndroid Build Coastguard Worker 
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <vector>

#include "DataStructures.hpp"
#include "SlidingWindow.hpp"
#include "MFCC.hpp"
#include "Wav2LetterMFCC.hpp"
13*89c4ff92SAndroid Build Coastguard Worker // Class to facilitate pre-processing calculation for Wav2Letter model for ASR
14*89c4ff92SAndroid Build Coastguard Worker using AudioWindow = SlidingWindow<const float>;
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker class Wav2LetterPreprocessor
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker public:
19*89c4ff92SAndroid Build Coastguard Worker     Wav2LetterPreprocessor(uint32_t windowLen, uint32_t windowStride,
20*89c4ff92SAndroid Build Coastguard Worker                            std::unique_ptr<Wav2LetterMFCC> mfccInst);
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker     /**
23*89c4ff92SAndroid Build Coastguard Worker      * @brief       Calculates the features required from audio data. This
24*89c4ff92SAndroid Build Coastguard Worker      *              includes MFCC, first and second order deltas,
25*89c4ff92SAndroid Build Coastguard Worker      *              normalisation and finally, quantisation. The tensor is
26*89c4ff92SAndroid Build Coastguard Worker      *              populated with feature from a given window placed along
27*89c4ff92SAndroid Build Coastguard Worker      *              in a single row.
28*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   audioData     pointer to the first element of audio data
29*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   audioDataLen  number of elements in the audio data
30*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   tensor        tensor to be populated
31*89c4ff92SAndroid Build Coastguard Worker      * @return      true if successful, false in case of error.
32*89c4ff92SAndroid Build Coastguard Worker      */
33*89c4ff92SAndroid Build Coastguard Worker     bool Invoke(const float* audioData, uint32_t audioDataLen, std::vector<int8_t>& output, int quantOffset,
34*89c4ff92SAndroid Build Coastguard Worker                 float quantScale);
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<MFCC> m_mfcc;
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     // Actual buffers to be populated
39*89c4ff92SAndroid Build Coastguard Worker     Array2d<float> m_mfccBuf;         // Contiguous buffer 1D: MFCC
40*89c4ff92SAndroid Build Coastguard Worker     Array2d<float> m_delta1Buf;       // Contiguous buffer 1D: Delta 1
41*89c4ff92SAndroid Build Coastguard Worker     Array2d<float> m_delta2Buf;       // Contiguous buffer 1D: Delta 2
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     uint32_t m_windowLen;       // Window length for MFCC
44*89c4ff92SAndroid Build Coastguard Worker     uint32_t m_windowStride;    // Window stride len for MFCC
45*89c4ff92SAndroid Build Coastguard Worker     AudioWindow m_window;       // Sliding window
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker protected:
48*89c4ff92SAndroid Build Coastguard Worker     /**
49*89c4ff92SAndroid Build Coastguard Worker      * @brief Computes the first and second order deltas for the
50*89c4ff92SAndroid Build Coastguard Worker      *        MFCC buffers - they are assumed to be populated.
51*89c4ff92SAndroid Build Coastguard Worker      *
52*89c4ff92SAndroid Build Coastguard Worker      * @param[in]  mfcc   MFCC buffers
53*89c4ff92SAndroid Build Coastguard Worker      * @param[out] delta1 result of the first diff computation
54*89c4ff92SAndroid Build Coastguard Worker      * @param[out] delta2 result of the second diff computation
55*89c4ff92SAndroid Build Coastguard Worker      *
56*89c4ff92SAndroid Build Coastguard Worker      * @return true if successful, false otherwise
57*89c4ff92SAndroid Build Coastguard Worker      */
58*89c4ff92SAndroid Build Coastguard Worker     static bool ComputeDeltas(Array2d<float>& mfcc,
59*89c4ff92SAndroid Build Coastguard Worker                               Array2d<float>& delta1,
60*89c4ff92SAndroid Build Coastguard Worker                               Array2d<float>& delta2);
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker protected:
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     /**
65*89c4ff92SAndroid Build Coastguard Worker      * @brief      Given a 2D vector of floats, computes the mean
66*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   vec      vector of vector of floats
67*89c4ff92SAndroid Build Coastguard Worker      * @return      mean value
68*89c4ff92SAndroid Build Coastguard Worker      */
69*89c4ff92SAndroid Build Coastguard Worker     static float GetMean(Array2d<float>& vec);
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker     /**
72*89c4ff92SAndroid Build Coastguard Worker      * @brief       Given a 2D vector of floats, computes the stddev
73*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   vec   vector of vector of floats
74*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   mean     mean value of the vector passed in
75*89c4ff92SAndroid Build Coastguard Worker      * @return      stddev value
76*89c4ff92SAndroid Build Coastguard Worker      */
77*89c4ff92SAndroid Build Coastguard Worker     static float GetStdDev(Array2d<float>& vec, float mean);
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker     /**
80*89c4ff92SAndroid Build Coastguard Worker      * @brief           Given a 2D vector of floats, normalises it using
81*89c4ff92SAndroid Build Coastguard Worker      *                  the mean and the stddev
82*89c4ff92SAndroid Build Coastguard Worker      * @param[in/out]   vec      vector of vector of floats
83*89c4ff92SAndroid Build Coastguard Worker      * @return
84*89c4ff92SAndroid Build Coastguard Worker      */
85*89c4ff92SAndroid Build Coastguard Worker     static void NormaliseVec(Array2d<float>& vec);
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     /**
88*89c4ff92SAndroid Build Coastguard Worker      * @brief       Normalises the MFCC and delta buffers
89*89c4ff92SAndroid Build Coastguard Worker      * @return
90*89c4ff92SAndroid Build Coastguard Worker      */
91*89c4ff92SAndroid Build Coastguard Worker     void Normalise();
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker     /**
94*89c4ff92SAndroid Build Coastguard Worker      * @brief       Given the quantisation and data type limits, computes
95*89c4ff92SAndroid Build Coastguard Worker      *              the quantised values of a floating point input data.
96*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   elem            Element to be quantised
97*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   quantScale      Scale
98*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   quantOffset     Offset
99*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   minVal          Numerical limit - minimum
100*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   maxVal          Numerical limit - maximum
101*89c4ff92SAndroid Build Coastguard Worker      * @return      floating point quantised value
102*89c4ff92SAndroid Build Coastguard Worker      */
103*89c4ff92SAndroid Build Coastguard Worker     static float GetQuantElem(
104*89c4ff92SAndroid Build Coastguard Worker             float elem,
105*89c4ff92SAndroid Build Coastguard Worker             float quantScale,
106*89c4ff92SAndroid Build Coastguard Worker             int quantOffset,
107*89c4ff92SAndroid Build Coastguard Worker             float minVal,
108*89c4ff92SAndroid Build Coastguard Worker             float maxVal);
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker     /**
111*89c4ff92SAndroid Build Coastguard Worker      * @brief       Quantises the MFCC and delta buffers, and places them
112*89c4ff92SAndroid Build Coastguard Worker      *              in the output buffer. While doing so, it transposes
113*89c4ff92SAndroid Build Coastguard Worker      *              the data. Reason: Buffers in this class are arranged
114*89c4ff92SAndroid Build Coastguard Worker      *              for "time" axis to be row major. Primary reason for
115*89c4ff92SAndroid Build Coastguard Worker      *              this being the convolution speed up (as we can use
116*89c4ff92SAndroid Build Coastguard Worker      *              contiguous memory). The output, however, requires the
117*89c4ff92SAndroid Build Coastguard Worker      *              time axis to be in column major arrangement.
118*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   outputBuf       pointer to the output buffer
119*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   outputBufSz     output buffer's size
120*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   quantScale      quantisation scale
121*89c4ff92SAndroid Build Coastguard Worker      * @param[in]   quantOffset     quantisation offset
122*89c4ff92SAndroid Build Coastguard Worker      */
123*89c4ff92SAndroid Build Coastguard Worker     template<typename T>
Quantise(T * outputBuf,int quantOffset,float quantScale)124*89c4ff92SAndroid Build Coastguard Worker     bool Quantise(T*outputBuf, int quantOffset, float quantScale)
125*89c4ff92SAndroid Build Coastguard Worker     {
126*89c4ff92SAndroid Build Coastguard Worker         // Populate
127*89c4ff92SAndroid Build Coastguard Worker         T* outputBufMfcc = outputBuf;
128*89c4ff92SAndroid Build Coastguard Worker         T* outputBufD1 = outputBuf + this->m_mfcc->m_params.m_numMfccFeatures;
129*89c4ff92SAndroid Build Coastguard Worker         T* outputBufD2 = outputBufD1 + this->m_mfcc->m_params.m_numMfccFeatures;
130*89c4ff92SAndroid Build Coastguard Worker         const uint32_t ptrIncr = this->m_mfcc->m_params.m_numMfccFeatures * 2; // (3 vectors - 1 vector)
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker         const float minVal = std::numeric_limits<T>::min();
133*89c4ff92SAndroid Build Coastguard Worker         const float maxVal = std::numeric_limits<T>::max();
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker         // We need to do a transpose while copying and concatenating the tensor
136*89c4ff92SAndroid Build Coastguard Worker         for (uint32_t j = 0; j < this->m_mfcc->m_params.m_numMfccVectors; ++j)
137*89c4ff92SAndroid Build Coastguard Worker         {
138*89c4ff92SAndroid Build Coastguard Worker             for (uint32_t i = 0; i < this->m_mfcc->m_params.m_numMfccFeatures; ++i)
139*89c4ff92SAndroid Build Coastguard Worker             {
140*89c4ff92SAndroid Build Coastguard Worker                 *outputBufMfcc++ = static_cast<T>(Wav2LetterPreprocessor::GetQuantElem(
141*89c4ff92SAndroid Build Coastguard Worker                         this->m_mfccBuf(i, j), quantScale,
142*89c4ff92SAndroid Build Coastguard Worker                         quantOffset, minVal, maxVal));
143*89c4ff92SAndroid Build Coastguard Worker                 *outputBufD1++ = static_cast<T>(Wav2LetterPreprocessor::GetQuantElem(
144*89c4ff92SAndroid Build Coastguard Worker                         this->m_delta1Buf(i, j), quantScale,
145*89c4ff92SAndroid Build Coastguard Worker                         quantOffset, minVal, maxVal));
146*89c4ff92SAndroid Build Coastguard Worker                 *outputBufD2++ = static_cast<T>(Wav2LetterPreprocessor::GetQuantElem(
147*89c4ff92SAndroid Build Coastguard Worker                         this->m_delta2Buf(i, j), quantScale,
148*89c4ff92SAndroid Build Coastguard Worker                         quantOffset, minVal, maxVal));
149*89c4ff92SAndroid Build Coastguard Worker             }
150*89c4ff92SAndroid Build Coastguard Worker             outputBufMfcc += ptrIncr;
151*89c4ff92SAndroid Build Coastguard Worker             outputBufD1 += ptrIncr;
152*89c4ff92SAndroid Build Coastguard Worker             outputBufD2 += ptrIncr;
153*89c4ff92SAndroid Build Coastguard Worker         }
154*89c4ff92SAndroid Build Coastguard Worker         return true;
155*89c4ff92SAndroid Build Coastguard Worker     }
156*89c4ff92SAndroid Build Coastguard Worker };
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker #endif //SPEECH_RECOGNITION_EXAMPLE_WAV2LETTERPREPROCESSOR_HPP
159