xref: /aosp_15_r20/external/armnn/tests/DeepSpeechV1Database.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "LstmCommon.hpp"
8 
9 #include <memory>
10 #include <string>
11 #include <vector>
12 
13 #include <armnn/TypesUtils.hpp>
14 #include <armnn/utility/NumericCast.hpp>
15 
16 #include <array>
17 #include <string>
18 
19 #include "InferenceTestImage.hpp"
20 
21 namespace
22 {
23 
24 template<typename T, typename TParseElementFunc>
ParseArrayImpl(std::istream & stream,TParseElementFunc parseElementFunc,const char * chars="\\t ,:")25 std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
26 {
27     std::vector<T> result;
28     // Processes line-by-line.
29     std::string line;
30     while (std::getline(stream, line))
31     {
32         std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
33         for (const std::string& token : tokens)
34         {
35             if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
36             {
37                 try
38                 {
39                     result.push_back(parseElementFunc(token));
40                 }
41                 catch (const std::exception&)
42                 {
43                     ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
44                 }
45             }
46         }
47     }
48 
49     return result;
50 }
51 
52 template<armnn::DataType NonQuantizedType>
53 auto ParseDataArray(std::istream & stream);
54 
55 template<armnn::DataType QuantizedType>
56 auto ParseDataArray(std::istream& stream,
57                     const float& quantizationScale,
58                     const int32_t& quantizationOffset);
59 
60 // NOTE: declaring the template specialisations inline to prevent them
61 //       being flagged as unused functions when -Werror=unused-function is in effect
62 template<>
ParseDataArray(std::istream & stream)63 inline auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
64 {
65     return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
66 }
67 
68 template<>
ParseDataArray(std::istream & stream)69 inline auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
70 {
71     return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
72 }
73 
74 template<>
ParseDataArray(std::istream & stream,const float & quantizationScale,const int32_t & quantizationOffset)75 inline auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
76                                                       const float& quantizationScale,
77                                                       const int32_t& quantizationOffset)
78 {
79     return ParseArrayImpl<uint8_t>(stream,
80                                    [&quantizationScale, &quantizationOffset](const std::string & s)
81                                    {
82                                        return armnn::numeric_cast<uint8_t>(
83                                                armnn::Quantize<uint8_t>(std::stof(s),
84                                                                          quantizationScale,
85                                                                          quantizationOffset));
86                                    });
87 }
88 
89 struct DeepSpeechV1TestCaseData
90 {
DeepSpeechV1TestCaseData__anond00d63bc0111::DeepSpeechV1TestCaseData91     DeepSpeechV1TestCaseData(
92         const LstmInput& inputData,
93         const LstmInput& expectedOutputData)
94         : m_InputData(inputData)
95         , m_ExpectedOutputData(expectedOutputData)
96     {}
97 
98     LstmInput m_InputData;
99     LstmInput m_ExpectedOutputData;
100 };
101 
102 class DeepSpeechV1Database
103 {
104 public:
105     explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
106                                   const std::string& prevStateCDir, const std::string& logitsDir,
107                                   const std::string& newStateHDir, const std::string& newStateCDir);
108 
109     std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId);
110 
111 private:
112     std::string m_InputSeqDir;
113     std::string m_PrevStateHDir;
114     std::string m_PrevStateCDir;
115     std::string m_LogitsDir;
116     std::string m_NewStateHDir;
117     std::string m_NewStateCDir;
118 };
119 
DeepSpeechV1Database(const std::string & inputSeqDir,const std::string & prevStateHDir,const std::string & prevStateCDir,const std::string & logitsDir,const std::string & newStateHDir,const std::string & newStateCDir)120 DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
121                                            const std::string& prevStateCDir, const std::string& logitsDir,
122                                            const std::string& newStateHDir, const std::string& newStateCDir)
123     : m_InputSeqDir(inputSeqDir)
124     , m_PrevStateHDir(prevStateHDir)
125     , m_PrevStateCDir(prevStateCDir)
126     , m_LogitsDir(logitsDir)
127     , m_NewStateHDir(newStateHDir)
128     , m_NewStateCDir(newStateCDir)
129 {}
130 
GetTestCaseData(unsigned int testCaseId)131 std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
132 {
133     // Load test case input
134     const std::string inputSeqPath   = m_InputSeqDir + "input_node_0_flat.txt";
135     const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt";
136     const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt";
137 
138     std::vector<float> inputSeqData;
139     std::vector<float> prevStateCData;
140     std::vector<float> prevStateHData;
141 
142     std::ifstream inputSeqFile(inputSeqPath);
143     std::ifstream prevStateCTensorFile(prevStateCPath);
144     std::ifstream prevStateHTensorFile(prevStateHPath);
145 
146     try
147     {
148         inputSeqData   = ParseDataArray<armnn::DataType::Float32>(inputSeqFile);
149         prevStateCData = ParseDataArray<armnn::DataType::Float32>(prevStateCTensorFile);
150         prevStateHData = ParseDataArray<armnn::DataType::Float32>(prevStateHTensorFile);
151     }
152     catch (const InferenceTestImageException& e)
153     {
154         ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
155         return nullptr;
156     }
157 
158     // Prepare test case expected output
159     const std::string logitsPath   = m_LogitsDir + "logits.txt";
160     const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt";
161     const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt";
162 
163     std::vector<float> logitsData;
164     std::vector<float> expectedNewStateCData;
165     std::vector<float> expectedNewStateHData;
166 
167     std::ifstream logitsTensorFile(logitsPath);
168     std::ifstream newStateCTensorFile(newStateCPath);
169     std::ifstream newStateHTensorFile(newStateHPath);
170 
171     try
172     {
173         logitsData     = ParseDataArray<armnn::DataType::Float32>(logitsTensorFile);
174         expectedNewStateCData = ParseDataArray<armnn::DataType::Float32>(newStateCTensorFile);
175         expectedNewStateHData = ParseDataArray<armnn::DataType::Float32>(newStateHTensorFile);
176     }
177     catch (const InferenceTestImageException& e)
178     {
179         ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
180         return nullptr;
181     }
182 
183     // use the struct for representing input and output data
184     LstmInput inputDataSingleTest(inputSeqData, prevStateHData, prevStateCData);
185 
186     LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateHData, expectedNewStateCData);
187 
188     return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
189 }
190 
191 } // anonymous namespace
192