1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "InferenceTest.hpp" 8 #include "DeepSpeechV1Database.hpp" 9 10 #include <armnn/utility/Assert.hpp> 11 #include <armnn/utility/IgnoreUnused.hpp> 12 #include <armnnUtils/FloatingPointComparison.hpp> 13 14 #include <vector> 15 16 namespace 17 { 18 19 template<typename Model> 20 class DeepSpeechV1TestCase : public InferenceModelTestCase<Model> 21 { 22 public: DeepSpeechV1TestCase(Model & model,unsigned int testCaseId,const DeepSpeechV1TestCaseData & testCaseData)23 DeepSpeechV1TestCase(Model& model, 24 unsigned int testCaseId, 25 const DeepSpeechV1TestCaseData& testCaseData) 26 : InferenceModelTestCase<Model>(model, 27 testCaseId, 28 { testCaseData.m_InputData.m_InputSeq, 29 testCaseData.m_InputData.m_StateH, 30 testCaseData.m_InputData.m_StateC}, 31 { k_OutputSize1, k_OutputSize2, k_OutputSize3 }) 32 , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH, 33 testCaseData.m_ExpectedOutputData.m_StateC}) 34 {} 35 ProcessResult(const InferenceTestOptions & options)36 TestCaseResult ProcessResult(const InferenceTestOptions& options) override 37 { 38 armnn::IgnoreUnused(options); 39 const std::vector<float>& output1 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]); // logits 40 ARMNN_ASSERT(output1.size() == k_OutputSize1); 41 42 const std::vector<float>& output2 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c 43 ARMNN_ASSERT(output2.size() == k_OutputSize2); 44 45 const std::vector<float>& output3 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h 46 ARMNN_ASSERT(output3.size() == k_OutputSize3); 47 48 // Check each output to see whether it is the expected value 49 for (unsigned int j = 0u; j < output1.size(); j++) 50 { 51 if(!armnnUtils::within_percentage_tolerance(output1[j], m_ExpectedOutputs.m_InputSeq[j])) 52 { 53 ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() << 54 " is incorrect at" << j; 55 return TestCaseResult::Failed; 56 } 57 } 58 59 for (unsigned int j = 0u; j < output2.size(); j++) 60 { 61 if(!armnnUtils::within_percentage_tolerance(output2[j], m_ExpectedOutputs.m_StateH[j])) 62 { 63 ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() << 64 " is incorrect"; 65 return TestCaseResult::Failed; 66 } 67 } 68 69 for (unsigned int j = 0u; j < output3.size(); j++) 70 { 71 if(!armnnUtils::within_percentage_tolerance(output3[j], m_ExpectedOutputs.m_StateC[j])) 72 { 73 ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() << 74 " is incorrect"; 75 return TestCaseResult::Failed; 76 } 77 } 78 return TestCaseResult::Ok; 79 } 80 81 private: 82 83 static constexpr unsigned int k_OutputSize1 = 464u; 84 static constexpr unsigned int k_OutputSize2 = 2048u; 85 static constexpr unsigned int k_OutputSize3 = 2048u; 86 87 LstmInput m_ExpectedOutputs; 88 }; 89 90 template <typename Model> 91 class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider 92 { 93 public: 94 template <typename TConstructModelCallable> DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)95 explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel) 96 : m_ConstructModel(constructModel) 97 {} 98 AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)99 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override 100 { 101 options 102 .allow_unrecognised_options() 103 .add_options() 104 ("s,input-seq-dir", "Path to directory containing test data for m_InputSeq", 105 cxxopts::value<std::string>(m_InputSeqDir)) 106 ("h,prev-state-h-dir", "Path to directory containing test data for m_PrevStateH", 107 cxxopts::value<std::string>(m_PrevStateHDir)) 108 ("c,prev-state-c-dir", "Path to directory containing test data for m_PrevStateC", 109 cxxopts::value<std::string>(m_PrevStateCDir)) 110 ("l,logits-dir", "Path to directory containing test data for m_Logits", 111 cxxopts::value<std::string>(m_LogitsDir)) 112 ("H,new-state-h-dir", "Path to directory containing test data for m_NewStateH", 113 cxxopts::value<std::string>(m_NewStateHDir)) 114 ("C,new-state-c-dir", "Path to directory containing test data for m_NewStateC", 115 cxxopts::value<std::string>(m_NewStateCDir)); 116 117 required.insert(required.end(), {"input-seq-dir", "prev-state-h-dir", "prev-state-c-dir", "logits-dir", 118 "new-state-h-dir", "new-state-c-dir"}); 119 120 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required); 121 } 122 ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)123 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override 124 { 125 if (!ValidateDirectory(m_InputSeqDir)) 126 { 127 return false; 128 } 129 130 if (!ValidateDirectory(m_PrevStateCDir)) 131 { 132 return false; 133 } 134 135 if (!ValidateDirectory(m_PrevStateHDir)) 136 { 137 return false; 138 } 139 140 if (!ValidateDirectory(m_LogitsDir)) 141 { 142 return false; 143 } 144 145 if (!ValidateDirectory(m_NewStateCDir)) 146 { 147 return false; 148 } 149 150 if (!ValidateDirectory(m_NewStateHDir)) 151 { 152 return false; 153 } 154 155 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions); 156 if (!m_Model) 157 { 158 return false; 159 } 160 m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(), 161 m_PrevStateCDir.c_str(), m_LogitsDir.c_str(), 162 m_NewStateHDir.c_str(), m_NewStateCDir.c_str()); 163 if (!m_Database) 164 { 165 return false; 166 } 167 168 return true; 169 } 170 GetTestCase(unsigned int testCaseId)171 std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override 172 { 173 std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId); 174 if (!testCaseData) 175 { 176 return nullptr; 177 } 178 179 return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData); 180 } 181 182 private: 183 typename Model::CommandLineOptions m_ModelCommandLineOptions; 184 std::function<std::unique_ptr<Model>(const InferenceTestOptions&, 185 typename Model::CommandLineOptions)> m_ConstructModel; 186 std::unique_ptr<Model> m_Model; 187 188 std::string m_InputSeqDir; 189 std::string m_PrevStateCDir; 190 std::string m_PrevStateHDir; 191 std::string m_LogitsDir; 192 std::string m_NewStateCDir; 193 std::string m_NewStateHDir; 194 195 std::unique_ptr<DeepSpeechV1Database> m_Database; 196 }; 197 198 } // anonymous namespace 199