1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include "InferenceTest.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include "DeepSpeechV1Database.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp> 11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp> 12*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/FloatingPointComparison.hpp> 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker #include <vector> 15*89c4ff92SAndroid Build Coastguard Worker 16*89c4ff92SAndroid Build Coastguard Worker namespace 17*89c4ff92SAndroid Build Coastguard Worker { 18*89c4ff92SAndroid Build Coastguard Worker 19*89c4ff92SAndroid Build Coastguard Worker template<typename Model> 20*89c4ff92SAndroid Build Coastguard Worker class DeepSpeechV1TestCase : public InferenceModelTestCase<Model> 21*89c4ff92SAndroid Build Coastguard Worker { 22*89c4ff92SAndroid Build Coastguard Worker public: DeepSpeechV1TestCase(Model & model,unsigned int testCaseId,const DeepSpeechV1TestCaseData & testCaseData)23*89c4ff92SAndroid Build Coastguard Worker DeepSpeechV1TestCase(Model& model, 24*89c4ff92SAndroid Build Coastguard Worker unsigned int testCaseId, 25*89c4ff92SAndroid Build Coastguard Worker const DeepSpeechV1TestCaseData& testCaseData) 26*89c4ff92SAndroid Build Coastguard Worker : InferenceModelTestCase<Model>(model, 27*89c4ff92SAndroid Build Coastguard Worker testCaseId, 28*89c4ff92SAndroid Build Coastguard Worker { testCaseData.m_InputData.m_InputSeq, 29*89c4ff92SAndroid Build Coastguard Worker testCaseData.m_InputData.m_StateH, 30*89c4ff92SAndroid Build Coastguard Worker testCaseData.m_InputData.m_StateC}, 31*89c4ff92SAndroid Build Coastguard Worker { k_OutputSize1, k_OutputSize2, k_OutputSize3 }) 32*89c4ff92SAndroid Build Coastguard Worker , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH, 33*89c4ff92SAndroid Build Coastguard Worker testCaseData.m_ExpectedOutputData.m_StateC}) 34*89c4ff92SAndroid Build Coastguard Worker {} 35*89c4ff92SAndroid Build Coastguard Worker ProcessResult(const InferenceTestOptions & options)36*89c4ff92SAndroid Build Coastguard Worker TestCaseResult ProcessResult(const InferenceTestOptions& options) override 37*89c4ff92SAndroid Build Coastguard Worker { 38*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(options); 39*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& output1 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]); // logits 40*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(output1.size() == k_OutputSize1); 41*89c4ff92SAndroid Build Coastguard Worker 42*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& output2 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c 43*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(output2.size() == k_OutputSize2); 44*89c4ff92SAndroid Build Coastguard Worker 45*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& output3 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h 46*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(output3.size() == k_OutputSize3); 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker // Check each output to see whether it is the expected value 49*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0u; j < output1.size(); j++) 50*89c4ff92SAndroid Build Coastguard Worker { 51*89c4ff92SAndroid Build Coastguard Worker if(!armnnUtils::within_percentage_tolerance(output1[j], m_ExpectedOutputs.m_InputSeq[j])) 52*89c4ff92SAndroid Build Coastguard Worker { 53*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() << 54*89c4ff92SAndroid Build Coastguard Worker " is incorrect at" << j; 55*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 56*89c4ff92SAndroid Build Coastguard Worker } 57*89c4ff92SAndroid Build Coastguard Worker } 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0u; j < output2.size(); j++) 60*89c4ff92SAndroid Build Coastguard Worker { 61*89c4ff92SAndroid Build Coastguard Worker if(!armnnUtils::within_percentage_tolerance(output2[j], m_ExpectedOutputs.m_StateH[j])) 62*89c4ff92SAndroid Build Coastguard Worker { 63*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() << 64*89c4ff92SAndroid Build Coastguard Worker " is incorrect"; 65*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 66*89c4ff92SAndroid Build Coastguard Worker } 67*89c4ff92SAndroid Build Coastguard Worker } 68*89c4ff92SAndroid Build Coastguard Worker 69*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0u; j < output3.size(); j++) 70*89c4ff92SAndroid Build Coastguard Worker { 71*89c4ff92SAndroid Build Coastguard Worker if(!armnnUtils::within_percentage_tolerance(output3[j], m_ExpectedOutputs.m_StateC[j])) 72*89c4ff92SAndroid Build Coastguard Worker { 73*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() << 74*89c4ff92SAndroid Build Coastguard Worker " is incorrect"; 75*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 76*89c4ff92SAndroid Build Coastguard Worker } 77*89c4ff92SAndroid Build Coastguard Worker } 78*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Ok; 79*89c4ff92SAndroid Build Coastguard Worker } 80*89c4ff92SAndroid Build Coastguard Worker 81*89c4ff92SAndroid Build Coastguard Worker private: 82*89c4ff92SAndroid Build Coastguard Worker 83*89c4ff92SAndroid Build Coastguard Worker static constexpr unsigned int k_OutputSize1 = 464u; 84*89c4ff92SAndroid Build Coastguard Worker static constexpr unsigned int k_OutputSize2 = 2048u; 85*89c4ff92SAndroid Build Coastguard Worker static constexpr unsigned int k_OutputSize3 = 2048u; 86*89c4ff92SAndroid Build Coastguard Worker 87*89c4ff92SAndroid Build Coastguard Worker LstmInput m_ExpectedOutputs; 88*89c4ff92SAndroid Build Coastguard Worker }; 89*89c4ff92SAndroid Build Coastguard Worker 90*89c4ff92SAndroid Build Coastguard Worker template <typename Model> 91*89c4ff92SAndroid Build Coastguard Worker class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider 92*89c4ff92SAndroid Build Coastguard Worker { 93*89c4ff92SAndroid Build Coastguard Worker public: 94*89c4ff92SAndroid Build Coastguard Worker template <typename TConstructModelCallable> DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)95*89c4ff92SAndroid Build Coastguard Worker explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel) 96*89c4ff92SAndroid Build Coastguard Worker : m_ConstructModel(constructModel) 97*89c4ff92SAndroid Build Coastguard Worker {} 98*89c4ff92SAndroid Build Coastguard Worker AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)99*89c4ff92SAndroid Build Coastguard Worker virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override 100*89c4ff92SAndroid Build Coastguard Worker { 101*89c4ff92SAndroid Build Coastguard Worker options 102*89c4ff92SAndroid Build Coastguard Worker .allow_unrecognised_options() 103*89c4ff92SAndroid Build Coastguard Worker .add_options() 104*89c4ff92SAndroid Build Coastguard Worker ("s,input-seq-dir", "Path to directory containing test data for m_InputSeq", 105*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(m_InputSeqDir)) 106*89c4ff92SAndroid Build Coastguard Worker ("h,prev-state-h-dir", "Path to directory containing test data for m_PrevStateH", 107*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(m_PrevStateHDir)) 108*89c4ff92SAndroid Build Coastguard Worker ("c,prev-state-c-dir", "Path to directory containing test data for m_PrevStateC", 109*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(m_PrevStateCDir)) 110*89c4ff92SAndroid Build Coastguard Worker ("l,logits-dir", "Path to directory containing test data for m_Logits", 111*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(m_LogitsDir)) 112*89c4ff92SAndroid Build Coastguard Worker ("H,new-state-h-dir", "Path to directory containing test data for m_NewStateH", 113*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(m_NewStateHDir)) 114*89c4ff92SAndroid Build Coastguard Worker ("C,new-state-c-dir", "Path to directory containing test data for m_NewStateC", 115*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(m_NewStateCDir)); 116*89c4ff92SAndroid Build Coastguard Worker 117*89c4ff92SAndroid Build Coastguard Worker required.insert(required.end(), {"input-seq-dir", "prev-state-h-dir", "prev-state-c-dir", "logits-dir", 118*89c4ff92SAndroid Build Coastguard Worker "new-state-h-dir", "new-state-c-dir"}); 119*89c4ff92SAndroid Build Coastguard Worker 120*89c4ff92SAndroid Build Coastguard Worker Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required); 121*89c4ff92SAndroid Build Coastguard Worker } 122*89c4ff92SAndroid Build Coastguard Worker ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)123*89c4ff92SAndroid Build Coastguard Worker virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override 124*89c4ff92SAndroid Build Coastguard Worker { 125*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(m_InputSeqDir)) 126*89c4ff92SAndroid Build Coastguard Worker { 127*89c4ff92SAndroid Build Coastguard Worker return false; 128*89c4ff92SAndroid Build Coastguard Worker } 129*89c4ff92SAndroid Build Coastguard Worker 130*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(m_PrevStateCDir)) 131*89c4ff92SAndroid Build Coastguard Worker { 132*89c4ff92SAndroid Build Coastguard Worker return false; 133*89c4ff92SAndroid Build Coastguard Worker } 134*89c4ff92SAndroid Build Coastguard Worker 135*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(m_PrevStateHDir)) 136*89c4ff92SAndroid Build Coastguard Worker { 137*89c4ff92SAndroid Build Coastguard Worker return false; 138*89c4ff92SAndroid Build Coastguard Worker } 139*89c4ff92SAndroid Build Coastguard Worker 140*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(m_LogitsDir)) 141*89c4ff92SAndroid Build Coastguard Worker { 142*89c4ff92SAndroid Build Coastguard Worker return false; 143*89c4ff92SAndroid Build Coastguard Worker } 144*89c4ff92SAndroid Build Coastguard Worker 145*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(m_NewStateCDir)) 146*89c4ff92SAndroid Build Coastguard Worker { 147*89c4ff92SAndroid Build Coastguard Worker return false; 148*89c4ff92SAndroid Build Coastguard Worker } 149*89c4ff92SAndroid Build Coastguard Worker 150*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(m_NewStateHDir)) 151*89c4ff92SAndroid Build Coastguard Worker { 152*89c4ff92SAndroid Build Coastguard Worker return false; 153*89c4ff92SAndroid Build Coastguard Worker } 154*89c4ff92SAndroid Build Coastguard Worker 155*89c4ff92SAndroid Build Coastguard Worker m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions); 156*89c4ff92SAndroid Build Coastguard Worker if (!m_Model) 157*89c4ff92SAndroid Build Coastguard Worker { 158*89c4ff92SAndroid Build Coastguard Worker return false; 159*89c4ff92SAndroid Build Coastguard Worker } 160*89c4ff92SAndroid Build Coastguard Worker m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(), 161*89c4ff92SAndroid Build Coastguard Worker m_PrevStateCDir.c_str(), m_LogitsDir.c_str(), 162*89c4ff92SAndroid Build Coastguard Worker m_NewStateHDir.c_str(), m_NewStateCDir.c_str()); 163*89c4ff92SAndroid Build Coastguard Worker if (!m_Database) 164*89c4ff92SAndroid Build Coastguard Worker { 165*89c4ff92SAndroid Build Coastguard Worker return false; 166*89c4ff92SAndroid Build Coastguard Worker } 167*89c4ff92SAndroid Build Coastguard Worker 168*89c4ff92SAndroid Build Coastguard Worker return true; 169*89c4ff92SAndroid Build Coastguard Worker } 170*89c4ff92SAndroid Build Coastguard Worker GetTestCase(unsigned int testCaseId)171*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override 172*89c4ff92SAndroid Build Coastguard Worker { 173*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId); 174*89c4ff92SAndroid Build Coastguard Worker if (!testCaseData) 175*89c4ff92SAndroid Build Coastguard Worker { 176*89c4ff92SAndroid Build Coastguard Worker return nullptr; 177*89c4ff92SAndroid Build Coastguard Worker } 178*89c4ff92SAndroid Build Coastguard Worker 179*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData); 180*89c4ff92SAndroid Build Coastguard Worker } 181*89c4ff92SAndroid Build Coastguard Worker 182*89c4ff92SAndroid Build Coastguard Worker private: 183*89c4ff92SAndroid Build Coastguard Worker typename Model::CommandLineOptions m_ModelCommandLineOptions; 184*89c4ff92SAndroid Build Coastguard Worker std::function<std::unique_ptr<Model>(const InferenceTestOptions&, 185*89c4ff92SAndroid Build Coastguard Worker typename Model::CommandLineOptions)> m_ConstructModel; 186*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Model> m_Model; 187*89c4ff92SAndroid Build Coastguard Worker 188*89c4ff92SAndroid Build Coastguard Worker std::string m_InputSeqDir; 189*89c4ff92SAndroid Build Coastguard Worker std::string m_PrevStateCDir; 190*89c4ff92SAndroid Build Coastguard Worker std::string m_PrevStateHDir; 191*89c4ff92SAndroid Build Coastguard Worker std::string m_LogitsDir; 192*89c4ff92SAndroid Build Coastguard Worker std::string m_NewStateCDir; 193*89c4ff92SAndroid Build Coastguard Worker std::string m_NewStateHDir; 194*89c4ff92SAndroid Build Coastguard Worker 195*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<DeepSpeechV1Database> m_Database; 196*89c4ff92SAndroid Build Coastguard Worker }; 197*89c4ff92SAndroid Build Coastguard Worker 198*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace 199