1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #ifndef SPEECH_RECOGNITION_EXAMPLE_WAV2LETTERPREPROCESSOR_HPP 6*89c4ff92SAndroid Build Coastguard Worker #define SPEECH_RECOGNITION_EXAMPLE_WAV2LETTERPREPROCESSOR_HPP 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <numeric> 9*89c4ff92SAndroid Build Coastguard Worker #include "DataStructures.hpp" 10*89c4ff92SAndroid Build Coastguard Worker #include "SlidingWindow.hpp" 11*89c4ff92SAndroid Build Coastguard Worker #include "MFCC.hpp" 12*89c4ff92SAndroid Build Coastguard Worker #include "Wav2LetterMFCC.hpp" 13*89c4ff92SAndroid Build Coastguard Worker // Class to facilitate pre-processing calculation for Wav2Letter model for ASR 14*89c4ff92SAndroid Build Coastguard Worker using AudioWindow = SlidingWindow<const float>; 15*89c4ff92SAndroid Build Coastguard Worker 16*89c4ff92SAndroid Build Coastguard Worker class Wav2LetterPreprocessor 17*89c4ff92SAndroid Build Coastguard Worker { 18*89c4ff92SAndroid Build Coastguard Worker public: 19*89c4ff92SAndroid Build Coastguard Worker Wav2LetterPreprocessor(uint32_t windowLen, uint32_t windowStride, 20*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Wav2LetterMFCC> mfccInst); 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker /** 23*89c4ff92SAndroid Build Coastguard Worker * @brief Calculates the features required from audio data. This 24*89c4ff92SAndroid Build Coastguard Worker * includes MFCC, first and second order deltas, 25*89c4ff92SAndroid Build Coastguard Worker * normalisation and finally, quantisation. The tensor is 26*89c4ff92SAndroid Build Coastguard Worker * populated with feature from a given window placed along 27*89c4ff92SAndroid Build Coastguard Worker * in a single row. 28*89c4ff92SAndroid Build Coastguard Worker * @param[in] audioData pointer to the first element of audio data 29*89c4ff92SAndroid Build Coastguard Worker * @param[in] audioDataLen number of elements in the audio data 30*89c4ff92SAndroid Build Coastguard Worker * @param[in] tensor tensor to be populated 31*89c4ff92SAndroid Build Coastguard Worker * @return true if successful, false in case of error. 32*89c4ff92SAndroid Build Coastguard Worker */ 33*89c4ff92SAndroid Build Coastguard Worker bool Invoke(const float* audioData, uint32_t audioDataLen, std::vector<int8_t>& output, int quantOffset, 34*89c4ff92SAndroid Build Coastguard Worker float quantScale); 35*89c4ff92SAndroid Build Coastguard Worker 36*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<MFCC> m_mfcc; 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker // Actual buffers to be populated 39*89c4ff92SAndroid Build Coastguard Worker Array2d<float> m_mfccBuf; // Contiguous buffer 1D: MFCC 40*89c4ff92SAndroid Build Coastguard Worker Array2d<float> m_delta1Buf; // Contiguous buffer 1D: Delta 1 41*89c4ff92SAndroid Build Coastguard Worker Array2d<float> m_delta2Buf; // Contiguous buffer 1D: Delta 2 42*89c4ff92SAndroid Build Coastguard Worker 43*89c4ff92SAndroid Build Coastguard Worker uint32_t m_windowLen; // Window length for MFCC 44*89c4ff92SAndroid Build Coastguard Worker uint32_t m_windowStride; // Window stride len for MFCC 45*89c4ff92SAndroid Build Coastguard Worker AudioWindow m_window; // Sliding window 46*89c4ff92SAndroid Build Coastguard Worker 47*89c4ff92SAndroid Build Coastguard Worker protected: 48*89c4ff92SAndroid Build Coastguard Worker /** 49*89c4ff92SAndroid Build Coastguard Worker * @brief Computes the first and second order deltas for the 50*89c4ff92SAndroid Build Coastguard Worker * MFCC buffers - they are assumed to be populated. 51*89c4ff92SAndroid Build Coastguard Worker * 52*89c4ff92SAndroid Build Coastguard Worker * @param[in] mfcc MFCC buffers 53*89c4ff92SAndroid Build Coastguard Worker * @param[out] delta1 result of the first diff computation 54*89c4ff92SAndroid Build Coastguard Worker * @param[out] delta2 result of the second diff computation 55*89c4ff92SAndroid Build Coastguard Worker * 56*89c4ff92SAndroid Build Coastguard Worker * @return true if successful, false otherwise 57*89c4ff92SAndroid Build Coastguard Worker */ 58*89c4ff92SAndroid Build Coastguard Worker static bool ComputeDeltas(Array2d<float>& mfcc, 59*89c4ff92SAndroid Build Coastguard Worker Array2d<float>& delta1, 60*89c4ff92SAndroid Build Coastguard Worker Array2d<float>& delta2); 61*89c4ff92SAndroid Build Coastguard Worker 62*89c4ff92SAndroid Build Coastguard Worker protected: 63*89c4ff92SAndroid Build Coastguard Worker 64*89c4ff92SAndroid Build Coastguard Worker /** 65*89c4ff92SAndroid Build Coastguard Worker * @brief Given a 2D vector of floats, computes the mean 66*89c4ff92SAndroid Build Coastguard Worker * @param[in] vec vector of vector of floats 67*89c4ff92SAndroid Build Coastguard Worker * @return mean value 68*89c4ff92SAndroid Build Coastguard Worker */ 69*89c4ff92SAndroid Build Coastguard Worker static float GetMean(Array2d<float>& vec); 70*89c4ff92SAndroid Build Coastguard Worker 71*89c4ff92SAndroid Build Coastguard Worker /** 72*89c4ff92SAndroid Build Coastguard Worker * @brief Given a 2D vector of floats, computes the stddev 73*89c4ff92SAndroid Build Coastguard Worker * @param[in] vec vector of vector of floats 74*89c4ff92SAndroid Build Coastguard Worker * @param[in] mean mean value of the vector passed in 75*89c4ff92SAndroid Build Coastguard Worker * @return stddev value 76*89c4ff92SAndroid Build Coastguard Worker */ 77*89c4ff92SAndroid Build Coastguard Worker static float GetStdDev(Array2d<float>& vec, float mean); 78*89c4ff92SAndroid Build Coastguard Worker 79*89c4ff92SAndroid Build Coastguard Worker /** 80*89c4ff92SAndroid Build Coastguard Worker * @brief Given a 2D vector of floats, normalises it using 81*89c4ff92SAndroid Build Coastguard Worker * the mean and the stddev 82*89c4ff92SAndroid Build Coastguard Worker * @param[in/out] vec vector of vector of floats 83*89c4ff92SAndroid Build Coastguard Worker * @return 84*89c4ff92SAndroid Build Coastguard Worker */ 85*89c4ff92SAndroid Build Coastguard Worker static void NormaliseVec(Array2d<float>& vec); 86*89c4ff92SAndroid Build Coastguard Worker 87*89c4ff92SAndroid Build Coastguard Worker /** 88*89c4ff92SAndroid Build Coastguard Worker * @brief Normalises the MFCC and delta buffers 89*89c4ff92SAndroid Build Coastguard Worker * @return 90*89c4ff92SAndroid Build Coastguard Worker */ 91*89c4ff92SAndroid Build Coastguard Worker void Normalise(); 92*89c4ff92SAndroid Build Coastguard Worker 93*89c4ff92SAndroid Build Coastguard Worker /** 94*89c4ff92SAndroid Build Coastguard Worker * @brief Given the quantisation and data type limits, computes 95*89c4ff92SAndroid Build Coastguard Worker * the quantised values of a floating point input data. 96*89c4ff92SAndroid Build Coastguard Worker * @param[in] elem Element to be quantised 97*89c4ff92SAndroid Build Coastguard Worker * @param[in] quantScale Scale 98*89c4ff92SAndroid Build Coastguard Worker * @param[in] quantOffset Offset 99*89c4ff92SAndroid Build Coastguard Worker * @param[in] minVal Numerical limit - minimum 100*89c4ff92SAndroid Build Coastguard Worker * @param[in] maxVal Numerical limit - maximum 101*89c4ff92SAndroid Build Coastguard Worker * @return floating point quantised value 102*89c4ff92SAndroid Build Coastguard Worker */ 103*89c4ff92SAndroid Build Coastguard Worker static float GetQuantElem( 104*89c4ff92SAndroid Build Coastguard Worker float elem, 105*89c4ff92SAndroid Build Coastguard Worker float quantScale, 106*89c4ff92SAndroid Build Coastguard Worker int quantOffset, 107*89c4ff92SAndroid Build Coastguard Worker float minVal, 108*89c4ff92SAndroid Build Coastguard Worker float maxVal); 109*89c4ff92SAndroid Build Coastguard Worker 110*89c4ff92SAndroid Build Coastguard Worker /** 111*89c4ff92SAndroid Build Coastguard Worker * @brief Quantises the MFCC and delta buffers, and places them 112*89c4ff92SAndroid Build Coastguard Worker * in the output buffer. While doing so, it transposes 113*89c4ff92SAndroid Build Coastguard Worker * the data. Reason: Buffers in this class are arranged 114*89c4ff92SAndroid Build Coastguard Worker * for "time" axis to be row major. Primary reason for 115*89c4ff92SAndroid Build Coastguard Worker * this being the convolution speed up (as we can use 116*89c4ff92SAndroid Build Coastguard Worker * contiguous memory). The output, however, requires the 117*89c4ff92SAndroid Build Coastguard Worker * time axis to be in column major arrangement. 118*89c4ff92SAndroid Build Coastguard Worker * @param[in] outputBuf pointer to the output buffer 119*89c4ff92SAndroid Build Coastguard Worker * @param[in] outputBufSz output buffer's size 120*89c4ff92SAndroid Build Coastguard Worker * @param[in] quantScale quantisation scale 121*89c4ff92SAndroid Build Coastguard Worker * @param[in] quantOffset quantisation offset 122*89c4ff92SAndroid Build Coastguard Worker */ 123*89c4ff92SAndroid Build Coastguard Worker template<typename T> Quantise(T * outputBuf,int quantOffset,float quantScale)124*89c4ff92SAndroid Build Coastguard Worker bool Quantise(T*outputBuf, int quantOffset, float quantScale) 125*89c4ff92SAndroid Build Coastguard Worker { 126*89c4ff92SAndroid Build Coastguard Worker // Populate 127*89c4ff92SAndroid Build Coastguard Worker T* outputBufMfcc = outputBuf; 128*89c4ff92SAndroid Build Coastguard Worker T* outputBufD1 = outputBuf + this->m_mfcc->m_params.m_numMfccFeatures; 129*89c4ff92SAndroid Build Coastguard Worker T* outputBufD2 = outputBufD1 + this->m_mfcc->m_params.m_numMfccFeatures; 130*89c4ff92SAndroid Build Coastguard Worker const uint32_t ptrIncr = this->m_mfcc->m_params.m_numMfccFeatures * 2; // (3 vectors - 1 vector) 131*89c4ff92SAndroid Build Coastguard Worker 132*89c4ff92SAndroid Build Coastguard Worker const float minVal = std::numeric_limits<T>::min(); 133*89c4ff92SAndroid Build Coastguard Worker const float maxVal = std::numeric_limits<T>::max(); 134*89c4ff92SAndroid Build Coastguard Worker 135*89c4ff92SAndroid Build Coastguard Worker // We need to do a transpose while copying and concatenating the tensor 136*89c4ff92SAndroid Build Coastguard Worker for (uint32_t j = 0; j < this->m_mfcc->m_params.m_numMfccVectors; ++j) 137*89c4ff92SAndroid Build Coastguard Worker { 138*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < this->m_mfcc->m_params.m_numMfccFeatures; ++i) 139*89c4ff92SAndroid Build Coastguard Worker { 140*89c4ff92SAndroid Build Coastguard Worker *outputBufMfcc++ = static_cast<T>(Wav2LetterPreprocessor::GetQuantElem( 141*89c4ff92SAndroid Build Coastguard Worker this->m_mfccBuf(i, j), quantScale, 142*89c4ff92SAndroid Build Coastguard Worker quantOffset, minVal, maxVal)); 143*89c4ff92SAndroid Build Coastguard Worker *outputBufD1++ = static_cast<T>(Wav2LetterPreprocessor::GetQuantElem( 144*89c4ff92SAndroid Build Coastguard Worker this->m_delta1Buf(i, j), quantScale, 145*89c4ff92SAndroid Build Coastguard Worker quantOffset, minVal, maxVal)); 146*89c4ff92SAndroid Build Coastguard Worker *outputBufD2++ = static_cast<T>(Wav2LetterPreprocessor::GetQuantElem( 147*89c4ff92SAndroid Build Coastguard Worker this->m_delta2Buf(i, j), quantScale, 148*89c4ff92SAndroid Build Coastguard Worker quantOffset, minVal, maxVal)); 149*89c4ff92SAndroid Build Coastguard Worker } 150*89c4ff92SAndroid Build Coastguard Worker outputBufMfcc += ptrIncr; 151*89c4ff92SAndroid Build Coastguard Worker outputBufD1 += ptrIncr; 152*89c4ff92SAndroid Build Coastguard Worker outputBufD2 += ptrIncr; 153*89c4ff92SAndroid Build Coastguard Worker } 154*89c4ff92SAndroid Build Coastguard Worker return true; 155*89c4ff92SAndroid Build Coastguard Worker } 156*89c4ff92SAndroid Build Coastguard Worker }; 157*89c4ff92SAndroid Build Coastguard Worker 158*89c4ff92SAndroid Build Coastguard Worker #endif //SPEECH_RECOGNITION_EXAMPLE_WAV2LETTERPREPROCESSOR_HPP 159