// xref: /aosp_15_r20/external/armnn/samples/SpeechRecognition/include/Wav2LetterMFCC.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "MFCC.hpp"

#include <cstdint>
#include <vector>
8 
9 /* Class to provide Wav2Letter specific MFCC calculation requirements. */
10 class Wav2LetterMFCC : public MFCC
11 {
12 
13 public:
Wav2LetterMFCC(const MfccParams & params)14     explicit Wav2LetterMFCC(const MfccParams& params)
15         :  MFCC(params)
16     {}
17 
18     Wav2LetterMFCC()  = delete;
19     ~Wav2LetterMFCC() = default;
20 
21 protected:
22 
23     /**
24      * @brief       Overrides base class implementation of this function.
25      * @param[in]   fftVec                  Vector populated with FFT magnitudes
26      * @param[in]   melFilterBank           2D Vector with filter bank weights
27      * @param[in]   filterBankFilterFirst   Vector containing the first indices of filter bank
28      *                                      to be used for each bin.
29      * @param[in]   filterBankFilterLast    Vector containing the last indices of filter bank
30      *                                      to be used for each bin.
31      * @param[out]  melEnergies             Pre-allocated vector of MEL energies to be
32      *                                      populated.
33      * @return      true if successful, false otherwise
34      */
35     bool ApplyMelFilterBank(
36         std::vector<float>&                 fftVec,
37         std::vector<std::vector<float>>&    melFilterBank,
38         std::vector<uint32_t>&              filterBankFilterFirst,
39         std::vector<uint32_t>&              filterBankFilterLast,
40         std::vector<float>&                 melEnergies) override;
41 
42     /**
43      * @brief           Override for the base class implementation convert mel
44      *                  energies to logarithmic scale. The difference from
45      *                  default behaviour is that the power is converted to dB
46      *                  and subsequently clamped.
47      * @param[in,out]   melEnergies   1D vector of Mel energies
48      **/
49     void ConvertToLogarithmicScale(std::vector<float>& melEnergies) override;
50 
51     /**
52      * @brief       Create a matrix used to calculate Discrete Cosine
53      *              Transform. Override for the base class' default
54      *              implementation as the first and last elements
55      *              use a different normaliser.
56      * @param[in]   inputLength        input length of the buffer on which
57      *                                 DCT will be performed
58      * @param[in]   coefficientCount   Total coefficients per input length.
59      * @return      1D vector with inputLength x coefficientCount elements
60      *              populated with DCT coefficients.
61      */
62     std::vector<float> CreateDCTMatrix(int32_t inputLength,
63                                        int32_t coefficientCount) override;
64 
65     /**
66      * @brief       Given the low and high Mel values, get the normaliser
67      *              for weights to be applied when populating the filter
68      *              bank. Override for the base class implementation.
69      * @param[in]   leftMel        Low Mel frequency value.
70      * @param[in]   rightMel       High Mel frequency value.
71      * @param[in]   useHTKMethod   bool to signal if HTK method is to be
72      *                             used for calculation.
73      * @return      Value to use for normalising.
74      */
75     float GetMelFilterBankNormaliser(const float&   leftMel,
76                                      const float&   rightMel,
77                                      bool     useHTKMethod) override;
78 };