xref: /aosp_15_r20/external/armnn/tests/DeepSpeechV1InferenceTest.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "InferenceTest.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "DeepSpeechV1Database.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/FloatingPointComparison.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <vector>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker namespace
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker template<typename Model>
20*89c4ff92SAndroid Build Coastguard Worker class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker public:
DeepSpeechV1TestCase(Model & model,unsigned int testCaseId,const DeepSpeechV1TestCaseData & testCaseData)23*89c4ff92SAndroid Build Coastguard Worker     DeepSpeechV1TestCase(Model& model,
24*89c4ff92SAndroid Build Coastguard Worker                          unsigned int testCaseId,
25*89c4ff92SAndroid Build Coastguard Worker                          const DeepSpeechV1TestCaseData& testCaseData)
26*89c4ff92SAndroid Build Coastguard Worker         : InferenceModelTestCase<Model>(model,
27*89c4ff92SAndroid Build Coastguard Worker                                         testCaseId,
28*89c4ff92SAndroid Build Coastguard Worker                                         { testCaseData.m_InputData.m_InputSeq,
29*89c4ff92SAndroid Build Coastguard Worker                                           testCaseData.m_InputData.m_StateH,
30*89c4ff92SAndroid Build Coastguard Worker                                           testCaseData.m_InputData.m_StateC},
31*89c4ff92SAndroid Build Coastguard Worker                                         { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
32*89c4ff92SAndroid Build Coastguard Worker         , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
33*89c4ff92SAndroid Build Coastguard Worker                              testCaseData.m_ExpectedOutputData.m_StateC})
34*89c4ff92SAndroid Build Coastguard Worker     {}
35*89c4ff92SAndroid Build Coastguard Worker 
ProcessResult(const InferenceTestOptions & options)36*89c4ff92SAndroid Build Coastguard Worker     TestCaseResult ProcessResult(const InferenceTestOptions& options) override
37*89c4ff92SAndroid Build Coastguard Worker     {
38*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(options);
39*89c4ff92SAndroid Build Coastguard Worker         const std::vector<float>& output1 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]); // logits
40*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(output1.size() == k_OutputSize1);
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker         const std::vector<float>& output2 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
43*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(output2.size() == k_OutputSize2);
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker         const std::vector<float>& output3 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
46*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(output3.size() == k_OutputSize3);
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker         // Check each output to see whether it is the expected value
49*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int j = 0u; j < output1.size(); j++)
50*89c4ff92SAndroid Build Coastguard Worker         {
51*89c4ff92SAndroid Build Coastguard Worker             if(!armnnUtils::within_percentage_tolerance(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
52*89c4ff92SAndroid Build Coastguard Worker             {
53*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
54*89c4ff92SAndroid Build Coastguard Worker                                          " is incorrect at" << j;
55*89c4ff92SAndroid Build Coastguard Worker                 return TestCaseResult::Failed;
56*89c4ff92SAndroid Build Coastguard Worker             }
57*89c4ff92SAndroid Build Coastguard Worker         }
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int j = 0u; j < output2.size(); j++)
60*89c4ff92SAndroid Build Coastguard Worker         {
61*89c4ff92SAndroid Build Coastguard Worker             if(!armnnUtils::within_percentage_tolerance(output2[j], m_ExpectedOutputs.m_StateH[j]))
62*89c4ff92SAndroid Build Coastguard Worker             {
63*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() <<
64*89c4ff92SAndroid Build Coastguard Worker                                          " is incorrect";
65*89c4ff92SAndroid Build Coastguard Worker                 return TestCaseResult::Failed;
66*89c4ff92SAndroid Build Coastguard Worker             }
67*89c4ff92SAndroid Build Coastguard Worker         }
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int j = 0u; j < output3.size(); j++)
70*89c4ff92SAndroid Build Coastguard Worker         {
71*89c4ff92SAndroid Build Coastguard Worker             if(!armnnUtils::within_percentage_tolerance(output3[j], m_ExpectedOutputs.m_StateC[j]))
72*89c4ff92SAndroid Build Coastguard Worker             {
73*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() <<
74*89c4ff92SAndroid Build Coastguard Worker                                          " is incorrect";
75*89c4ff92SAndroid Build Coastguard Worker                 return TestCaseResult::Failed;
76*89c4ff92SAndroid Build Coastguard Worker             }
77*89c4ff92SAndroid Build Coastguard Worker         }
78*89c4ff92SAndroid Build Coastguard Worker         return TestCaseResult::Ok;
79*89c4ff92SAndroid Build Coastguard Worker     }
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker private:
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker     static constexpr unsigned int k_OutputSize1 = 464u;
84*89c4ff92SAndroid Build Coastguard Worker     static constexpr unsigned int k_OutputSize2 = 2048u;
85*89c4ff92SAndroid Build Coastguard Worker     static constexpr unsigned int k_OutputSize3 = 2048u;
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     LstmInput m_ExpectedOutputs;
88*89c4ff92SAndroid Build Coastguard Worker };
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker template <typename Model>
91*89c4ff92SAndroid Build Coastguard Worker class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker public:
94*89c4ff92SAndroid Build Coastguard Worker     template <typename TConstructModelCallable>
DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)95*89c4ff92SAndroid Build Coastguard Worker     explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
96*89c4ff92SAndroid Build Coastguard Worker         : m_ConstructModel(constructModel)
97*89c4ff92SAndroid Build Coastguard Worker     {}
98*89c4ff92SAndroid Build Coastguard Worker 
AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)99*89c4ff92SAndroid Build Coastguard Worker     virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         options
102*89c4ff92SAndroid Build Coastguard Worker             .allow_unrecognised_options()
103*89c4ff92SAndroid Build Coastguard Worker             .add_options()
104*89c4ff92SAndroid Build Coastguard Worker                 ("s,input-seq-dir", "Path to directory containing test data for m_InputSeq",
105*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>(m_InputSeqDir))
106*89c4ff92SAndroid Build Coastguard Worker                 ("h,prev-state-h-dir", "Path to directory containing test data for m_PrevStateH",
107*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>(m_PrevStateHDir))
108*89c4ff92SAndroid Build Coastguard Worker                 ("c,prev-state-c-dir", "Path to directory containing test data for m_PrevStateC",
109*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>(m_PrevStateCDir))
110*89c4ff92SAndroid Build Coastguard Worker                 ("l,logits-dir", "Path to directory containing test data for m_Logits",
111*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>(m_LogitsDir))
112*89c4ff92SAndroid Build Coastguard Worker                 ("H,new-state-h-dir", "Path to directory containing test data for m_NewStateH",
113*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>(m_NewStateHDir))
114*89c4ff92SAndroid Build Coastguard Worker                 ("C,new-state-c-dir", "Path to directory containing test data for m_NewStateC",
115*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>(m_NewStateCDir));
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker         required.insert(required.end(), {"input-seq-dir", "prev-state-h-dir", "prev-state-c-dir", "logits-dir",
118*89c4ff92SAndroid Build Coastguard Worker                                          "new-state-h-dir", "new-state-c-dir"});
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker         Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
121*89c4ff92SAndroid Build Coastguard Worker     }
122*89c4ff92SAndroid Build Coastguard Worker 
ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)123*89c4ff92SAndroid Build Coastguard Worker     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
124*89c4ff92SAndroid Build Coastguard Worker     {
125*89c4ff92SAndroid Build Coastguard Worker         if (!ValidateDirectory(m_InputSeqDir))
126*89c4ff92SAndroid Build Coastguard Worker         {
127*89c4ff92SAndroid Build Coastguard Worker             return false;
128*89c4ff92SAndroid Build Coastguard Worker         }
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker         if (!ValidateDirectory(m_PrevStateCDir))
131*89c4ff92SAndroid Build Coastguard Worker         {
132*89c4ff92SAndroid Build Coastguard Worker             return false;
133*89c4ff92SAndroid Build Coastguard Worker         }
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker         if (!ValidateDirectory(m_PrevStateHDir))
136*89c4ff92SAndroid Build Coastguard Worker         {
137*89c4ff92SAndroid Build Coastguard Worker             return false;
138*89c4ff92SAndroid Build Coastguard Worker         }
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker         if (!ValidateDirectory(m_LogitsDir))
141*89c4ff92SAndroid Build Coastguard Worker         {
142*89c4ff92SAndroid Build Coastguard Worker             return false;
143*89c4ff92SAndroid Build Coastguard Worker         }
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker         if (!ValidateDirectory(m_NewStateCDir))
146*89c4ff92SAndroid Build Coastguard Worker         {
147*89c4ff92SAndroid Build Coastguard Worker             return false;
148*89c4ff92SAndroid Build Coastguard Worker         }
149*89c4ff92SAndroid Build Coastguard Worker 
150*89c4ff92SAndroid Build Coastguard Worker         if (!ValidateDirectory(m_NewStateHDir))
151*89c4ff92SAndroid Build Coastguard Worker         {
152*89c4ff92SAndroid Build Coastguard Worker             return false;
153*89c4ff92SAndroid Build Coastguard Worker         }
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker         m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
156*89c4ff92SAndroid Build Coastguard Worker         if (!m_Model)
157*89c4ff92SAndroid Build Coastguard Worker         {
158*89c4ff92SAndroid Build Coastguard Worker             return false;
159*89c4ff92SAndroid Build Coastguard Worker         }
160*89c4ff92SAndroid Build Coastguard Worker         m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
161*89c4ff92SAndroid Build Coastguard Worker                                                             m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
162*89c4ff92SAndroid Build Coastguard Worker                                                             m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
163*89c4ff92SAndroid Build Coastguard Worker         if (!m_Database)
164*89c4ff92SAndroid Build Coastguard Worker         {
165*89c4ff92SAndroid Build Coastguard Worker             return false;
166*89c4ff92SAndroid Build Coastguard Worker         }
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker         return true;
169*89c4ff92SAndroid Build Coastguard Worker     }
170*89c4ff92SAndroid Build Coastguard Worker 
GetTestCase(unsigned int testCaseId)171*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
172*89c4ff92SAndroid Build Coastguard Worker     {
173*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
174*89c4ff92SAndroid Build Coastguard Worker         if (!testCaseData)
175*89c4ff92SAndroid Build Coastguard Worker         {
176*89c4ff92SAndroid Build Coastguard Worker             return nullptr;
177*89c4ff92SAndroid Build Coastguard Worker         }
178*89c4ff92SAndroid Build Coastguard Worker 
179*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
180*89c4ff92SAndroid Build Coastguard Worker     }
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker private:
183*89c4ff92SAndroid Build Coastguard Worker     typename Model::CommandLineOptions m_ModelCommandLineOptions;
184*89c4ff92SAndroid Build Coastguard Worker     std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
185*89c4ff92SAndroid Build Coastguard Worker                                          typename Model::CommandLineOptions)> m_ConstructModel;
186*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<Model> m_Model;
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker     std::string m_InputSeqDir;
189*89c4ff92SAndroid Build Coastguard Worker     std::string m_PrevStateCDir;
190*89c4ff92SAndroid Build Coastguard Worker     std::string m_PrevStateHDir;
191*89c4ff92SAndroid Build Coastguard Worker     std::string m_LogitsDir;
192*89c4ff92SAndroid Build Coastguard Worker     std::string m_NewStateCDir;
193*89c4ff92SAndroid Build Coastguard Worker     std::string m_NewStateHDir;
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<DeepSpeechV1Database> m_Database;
196*89c4ff92SAndroid Build Coastguard Worker };
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
199