//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <string>
#include <map>
#include <vector>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>

namespace asr
{
/**
 * @brief Class used to decode the output of the ASR inference into text.
 *
 */
class Decoder
{
public:
    /// Map from class index to label string. DecodeOutput() looks labels up with
    /// at(), so a missing index throws std::out_of_range, and it uses only the
    /// first character of each label.
    std::map<int, std::string> m_labels;

    /**
     * @brief Constructor
     * @param[in] labels - map of labels to be used for decoding to text
     *                     (presumably stored into m_labels — confirm in the out-of-line definition).
     */
    Decoder(std::map<int, std::string>& labels);

    /**
     * @brief Function to decode the output into a text string.
     *
     * The input is treated as consecutive rows of @p rowLength scores. For each
     * complete row, the index of the highest score (first one on ties) is mapped
     * through m_labels and the first character of that label is collected. The
     * collected characters are then passed through FilterCharacters(). A trailing
     * partial row (when the size is not a multiple of @p rowLength) is ignored.
     *
     * @param[in] contextToProcess - the output vector to decode.
     * @param[in] rowLength        - number of scores per row; defaults to 29, the value
     *                               this decoder has always used. A value of 0 decodes
     *                               no rows.
     * @return The filtered decoded text.
     * @throws std::out_of_range if an argmax index has no entry in m_labels.
     */
    template<typename T>
    std::string DecodeOutput(const std::vector<T>& contextToProcess, std::size_t rowLength = 29)
    {
        std::vector<char> unfilteredText;

        // Guard against division by zero below; zero-width rows yield no characters.
        if (rowLength == 0)
        {
            return FilterCharacters(unfilteredText);
        }

        const std::size_t rowCount = contextToProcess.size() / rowLength;
        unfilteredText.reserve(rowCount);

        for (std::size_t row = 0; row < rowCount; ++row)
        {
            const auto rowBegin = contextToProcess.begin() + static_cast<std::ptrdiff_t>(row * rowLength);
            const auto rowEnd   = rowBegin + static_cast<std::ptrdiff_t>(rowLength);

            // Argmax on the original values. Casting to a narrower integer first would
            // truncate floating-point scores (and is undefined when out of range).
            const auto maxIndex =
                static_cast<int>(std::distance(rowBegin, std::max_element(rowBegin, rowEnd)));

            unfilteredText.emplace_back(m_labels.at(maxIndex)[0]);
        }

        return FilterCharacters(unfilteredText);
    }

    /**
     * @brief Function to filter out unwanted characters.
     * @param[in] unfiltered - the unfiltered output to be processed.
     * @return The filtered text (implemented out of line).
     */
    std::string FilterCharacters(std::vector<char>& unfiltered);
};
} // namespace asr