xref: /aosp_15_r20/external/armnn/samples/SpeechRecognition/include/Decoder.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <string>
7*89c4ff92SAndroid Build Coastguard Worker #include <map>
8*89c4ff92SAndroid Build Coastguard Worker #include <vector>
9*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
10*89c4ff92SAndroid Build Coastguard Worker #include <cmath>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker # pragma once
13*89c4ff92SAndroid Build Coastguard Worker 
namespace asr
{
/**
* @brief Class used to decode the output of the ASR inference into text.
*
*/
    class Decoder
    {
    public:
        /// Label table indexed by the position of the winning score within an output row.
        /// DecodeOutput uses only the first character of each label.
        std::map<int, std::string> m_labels;

        /**
        * @brief Constructor
        * @param[in] labels - map of labels to be used for decoding to text.
        */
        Decoder(std::map<int, std::string>& labels);

        /**
        * @brief Function to decode the output into a text string.
        *
        * The input is treated as consecutive rows of 29 scores. For every complete row the
        * index of the highest score is looked up in m_labels and the first character of
        * that label is collected. The collected characters are then passed through
        * FilterCharacters. Trailing elements that do not form a complete row are ignored.
        *
        * @param[in] contextToProcess - the flattened output vector to decode.
        * @return the decoded text after FilterCharacters has been applied.
        * @throws std::out_of_range if a winning index has no entry in m_labels.
        */
        template<typename T>
        std::string DecodeOutput(std::vector<T>& contextToProcess)
        {
            using SizeType = typename std::vector<T>::size_type;
            using DiffType = typename std::vector<T>::difference_type;

            // Number of scores per output row.
            constexpr SizeType rowLength = 29;
            const SizeType numRows = contextToProcess.size() / rowLength;

            std::vector<char> unfilteredText;
            unfilteredText.reserve(numRows);

            for (SizeType row = 0; row < numRows; ++row)
            {
                const auto rowBegin = contextToProcess.cbegin() + static_cast<DiffType>(row * rowLength);
                const auto rowEnd = rowBegin + static_cast<DiffType>(rowLength);

                // Compare the scores in their own type: narrowing them to int16_t first would
                // truncate floating-point scores (or wrap wide integers) and could select the
                // wrong label.
                const auto maxIndex = static_cast<int>(std::max_element(rowBegin, rowEnd) - rowBegin);
                unfilteredText.emplace_back(this->m_labels.at(maxIndex)[0]);
            }

            return FilterCharacters(unfilteredText);
        }

        /**
        * @brief Function to filter out unwanted characters
        * @param[in] unfiltered - the unfiltered output to be processed.
        * @return the filtered text.
        */
        std::string FilterCharacters(std::vector<char>& unfiltered);
    };
} // namespace asr
64