1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "InferenceModel.hpp"
8
9 #include <armnn/ArmNN.hpp>
10 #include <armnn/Logging.hpp>
11 #include <armnn/TypesUtils.hpp>
12 #include <armnn/utility/IgnoreUnused.hpp>
13
14 #include <armnnUtils/TContainer.hpp>
15
16 #include <cxxopts/cxxopts.hpp>
17 #include <fmt/format.h>
18
19
20 namespace armnn
21 {
22
operator >>(std::istream & in,armnn::Compute & compute)23 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
24 {
25 std::string token;
26 in >> token;
27 compute = armnn::ParseComputeDevice(token.c_str());
28 if (compute == armnn::Compute::Undefined)
29 {
30 in.setstate(std::ios_base::failbit);
31 throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
32 }
33 return in;
34 }
35
operator >>(std::istream & in,armnn::BackendId & backend)36 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
37 {
38 std::string token;
39 in >> token;
40 armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
41 if (compute == armnn::Compute::Undefined)
42 {
43 in.setstate(std::ios_base::failbit);
44 throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
45 }
46 backend = compute;
47 return in;
48 }
49
50 namespace test
51 {
52
/// Exception type thrown for errors raised by the inference test framework
/// itself (as opposed to errors coming out of the Arm NN runtime).
class TestFrameworkException : public Exception
{
public:
    // Inherit all of armnn::Exception's constructors (e.g. the message-taking one).
    using Exception::Exception;
};
58
/// Options common to every inference test, populated from the command line
/// by ParseCommandLine().
struct InferenceTestOptions
{
    /// Number of times the test set is iterated (0 until set from the command line).
    unsigned int m_IterationCount = 0;
    /// File path to write per-inference timings to; empty means "do not write".
    /// NOTE(review): exact format is defined by the InferenceTest implementation - confirm there.
    std::string m_InferenceTimesFile;
    /// Whether Arm NN profiling output is enabled.
    /// (Fix: previously initialised with the integer literal 0 instead of false.)
    bool m_EnableProfiling = false;
    /// Directory to load dynamic backends from; empty means none.
    std::string m_DynamicBackendsPath;

    InferenceTestOptions() = default;
};
72
/// Outcome of a single test case; decides whether the remaining cases still run.
enum class TestCaseResult
{
    /// The test completed without any errors.
    Ok,
    /// The test failed (e.g. the prediction didn't match the validation file).
    /// This will eventually fail the whole program but the remaining test cases will still be run.
    Failed,
    /// The test failed with a fatal error. The remaining tests will not be run.
    Abort
};
83
84 class IInferenceTestCase
85 {
86 public:
~IInferenceTestCase()87 virtual ~IInferenceTestCase() {}
88
89 virtual void Run() = 0;
90 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
91 };
92
93 class IInferenceTestCaseProvider
94 {
95 public:
~IInferenceTestCaseProvider()96 virtual ~IInferenceTestCaseProvider() {}
97
AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)98 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
99 {
100 IgnoreUnused(options, required);
101 };
ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)102 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
103 {
104 IgnoreUnused(commonOptions);
105 return true;
106 };
107 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
OnInferenceTestFinished()108 virtual bool OnInferenceTestFinished() { return true; };
109 };
110
111 template <typename TModel>
112 class InferenceModelTestCase : public IInferenceTestCase
113 {
114 public:
115
InferenceModelTestCase(TModel & model,unsigned int testCaseId,const std::vector<armnnUtils::TContainer> & inputs,const std::vector<unsigned int> & outputSizes)116 InferenceModelTestCase(TModel& model,
117 unsigned int testCaseId,
118 const std::vector<armnnUtils::TContainer>& inputs,
119 const std::vector<unsigned int>& outputSizes)
120 : m_Model(model)
121 , m_TestCaseId(testCaseId)
122 , m_Inputs(std::move(inputs))
123 {
124 // Initialize output vector
125 const size_t numOutputs = outputSizes.size();
126 m_Outputs.reserve(numOutputs);
127
128 for (size_t i = 0; i < numOutputs; i++)
129 {
130 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
131 }
132 }
133
Run()134 virtual void Run() override
135 {
136 m_Model.Run(m_Inputs, m_Outputs);
137 }
138
139 protected:
GetTestCaseId() const140 unsigned int GetTestCaseId() const { return m_TestCaseId; }
GetOutputs() const141 const std::vector<armnnUtils::TContainer>& GetOutputs() const { return m_Outputs; }
142
143 private:
144 TModel& m_Model;
145 unsigned int m_TestCaseId;
146 std::vector<armnnUtils::TContainer> m_Inputs;
147 std::vector<armnnUtils::TContainer> m_Outputs;
148 };
149
/// Test case for classification models: runs one inference and compares the
/// predicted label against the expected one. Member definitions live in
/// InferenceTest.inl.
template <typename TTestCaseDatabase, typename TModel>
class ClassifierTestCase : public InferenceModelTestCase<TModel>
{
public:
    ClassifierTestCase(int& numInferencesRef,
        int& numCorrectInferencesRef,
        const std::vector<unsigned int>& validationPredictions,
        std::vector<unsigned int>* validationPredictionsOut,
        TModel& model,
        unsigned int testCaseId,
        unsigned int label,
        std::vector<typename TModel::DataType> modelInput);

    virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;

private:
    /// Expected (ground-truth) label for this test case's input.
    unsigned int m_Label;
    InferenceModelInternal::QuantizationParams m_QuantizationParams;

    /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
    /// @{
    int& m_NumInferencesRef;
    int& m_NumCorrectInferencesRef;
    const std::vector<unsigned int>& m_ValidationPredictions;
    std::vector<unsigned int>* m_ValidationPredictionsOut;
    /// @}
};
177
/// Test case provider for classification tests: constructs the model and the
/// test database from user-supplied callables and hands out one
/// ClassifierTestCase per test case id. Member definitions live in
/// InferenceTest.inl.
template <typename TDatabase, typename InferenceModel>
class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
{
public:
    template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
    ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);

    virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
    virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
    virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
    virtual bool OnInferenceTestFinished() override;

private:
    // Loads the expected predictions from m_ValidationFileIn.
    void ReadPredictions();

    typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
    // Factory invoked with the parsed options to build the model.
    std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
                                                  typename InferenceModel::CommandLineOptions)> m_ConstructModel;
    std::unique_ptr<InferenceModel> m_Model;

    // Directory containing the test data, set from the command line.
    std::string m_DataDir;
    std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
    std::unique_ptr<TDatabase> m_Database;

    int m_NumInferences; // Referenced by test cases.
    int m_NumCorrectInferences; // Referenced by test cases.

    // Optional file of expected predictions to validate against.
    std::string m_ValidationFileIn;
    std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.

    // Optional file to which the actual predictions are written.
    std::string m_ValidationFileOut;
    std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
};
211
/// Parses the common command line options (plus any added by the provider via
/// AddCommandLineOptions) into outParams. Returns false on failure.
/// NOTE(review): implemented outside this header - confirm exact failure semantics there.
bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
                      InferenceTestOptions& outParams);

/// Checks that 'dir' names a usable directory; takes the string by non-const
/// reference, so the implementation may normalise it. Returns false if invalid.
bool ValidateDirectory(std::string& dir);

/// Runs the given test cases (defaultTestCaseIds) using the supplied provider
/// and options. Returns false if any test failed.
bool InferenceTest(const InferenceTestOptions& params,
                   const std::vector<unsigned int>& defaultTestCaseIds,
                   IInferenceTestCaseProvider& testCaseProvider);
220
/// Convenience entry point: builds the provider via constructTestCaseProvider,
/// parses the command line and runs InferenceTest(). Returns a process exit
/// code. Defined in InferenceTest.inl.
template<typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
                      char* argv[],
                      const std::vector<unsigned int>& defaultTestCaseIds,
                      TConstructTestCaseProvider constructTestCaseProvider);

/// Classification-specific entry point: wires a ClassifierTestCaseProvider up
/// with the given model file/bindings and database factory, then delegates to
/// InferenceTestMain(). Returns a process exit code. Defined in InferenceTest.inl.
template<typename TDatabase,
         typename TParser,
         typename TConstructDatabaseCallable>
int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
                                const char* inputBindingName, const char* outputBindingName,
                                const std::vector<unsigned int>& defaultTestCaseIds,
                                TConstructDatabaseCallable constructDatabase,
                                const armnn::TensorShape* inputTensorShape = nullptr);
235
236 } // namespace test
237 } // namespace armnn
238
239 #include "InferenceTest.inl"
240