1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include "InferenceModel.hpp"
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TContainer.hpp>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker namespace armnn
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker
operator >>(std::istream & in,armnn::Compute & compute)23*89c4ff92SAndroid Build Coastguard Worker inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker std::string token;
26*89c4ff92SAndroid Build Coastguard Worker in >> token;
27*89c4ff92SAndroid Build Coastguard Worker compute = armnn::ParseComputeDevice(token.c_str());
28*89c4ff92SAndroid Build Coastguard Worker if (compute == armnn::Compute::Undefined)
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker in.setstate(std::ios_base::failbit);
31*89c4ff92SAndroid Build Coastguard Worker throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
32*89c4ff92SAndroid Build Coastguard Worker }
33*89c4ff92SAndroid Build Coastguard Worker return in;
34*89c4ff92SAndroid Build Coastguard Worker }
35*89c4ff92SAndroid Build Coastguard Worker
operator >>(std::istream & in,armnn::BackendId & backend)36*89c4ff92SAndroid Build Coastguard Worker inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker std::string token;
39*89c4ff92SAndroid Build Coastguard Worker in >> token;
40*89c4ff92SAndroid Build Coastguard Worker armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
41*89c4ff92SAndroid Build Coastguard Worker if (compute == armnn::Compute::Undefined)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker in.setstate(std::ios_base::failbit);
44*89c4ff92SAndroid Build Coastguard Worker throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker backend = compute;
47*89c4ff92SAndroid Build Coastguard Worker return in;
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker namespace test
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker
53*89c4ff92SAndroid Build Coastguard Worker class TestFrameworkException : public Exception
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker public:
56*89c4ff92SAndroid Build Coastguard Worker using Exception::Exception;
57*89c4ff92SAndroid Build Coastguard Worker };
58*89c4ff92SAndroid Build Coastguard Worker
/// Options common to every inference test. ParseCommandLine() fills these in through its outParams
/// argument. Default construction yields zero/false/empty values.
struct InferenceTestOptions
{
    /// Iteration count; 0 by default.
    unsigned int m_IterationCount = 0;
    /// Presumably a path where inference timings are written (empty = none). TODO confirm at use site.
    std::string m_InferenceTimesFile;
    /// Profiling switch; off by default.
    bool m_EnableProfiling = false;
    /// Path passed through for dynamic backends; empty by default.
    std::string m_DynamicBackendsPath;
};
72*89c4ff92SAndroid Build Coastguard Worker
/// Outcome reported by IInferenceTestCase::ProcessResult().
enum class TestCaseResult : int
{
    Ok = 0,      ///< The test completed without any errors.
    Failed = 1,  ///< The test failed (e.g. the prediction didn't match the validation file); the whole
                 ///< program will eventually fail, but the remaining test cases are still run.
    Abort = 2    ///< The test failed with a fatal error; the remaining tests will not be run.
};
83*89c4ff92SAndroid Build Coastguard Worker
84*89c4ff92SAndroid Build Coastguard Worker class IInferenceTestCase
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker public:
~IInferenceTestCase()87*89c4ff92SAndroid Build Coastguard Worker virtual ~IInferenceTestCase() {}
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker virtual void Run() = 0;
90*89c4ff92SAndroid Build Coastguard Worker virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
91*89c4ff92SAndroid Build Coastguard Worker };
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker class IInferenceTestCaseProvider
94*89c4ff92SAndroid Build Coastguard Worker {
95*89c4ff92SAndroid Build Coastguard Worker public:
~IInferenceTestCaseProvider()96*89c4ff92SAndroid Build Coastguard Worker virtual ~IInferenceTestCaseProvider() {}
97*89c4ff92SAndroid Build Coastguard Worker
AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)98*89c4ff92SAndroid Build Coastguard Worker virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
99*89c4ff92SAndroid Build Coastguard Worker {
100*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(options, required);
101*89c4ff92SAndroid Build Coastguard Worker };
ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)102*89c4ff92SAndroid Build Coastguard Worker virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
103*89c4ff92SAndroid Build Coastguard Worker {
104*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(commonOptions);
105*89c4ff92SAndroid Build Coastguard Worker return true;
106*89c4ff92SAndroid Build Coastguard Worker };
107*89c4ff92SAndroid Build Coastguard Worker virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
OnInferenceTestFinished()108*89c4ff92SAndroid Build Coastguard Worker virtual bool OnInferenceTestFinished() { return true; };
109*89c4ff92SAndroid Build Coastguard Worker };
110*89c4ff92SAndroid Build Coastguard Worker
111*89c4ff92SAndroid Build Coastguard Worker template <typename TModel>
112*89c4ff92SAndroid Build Coastguard Worker class InferenceModelTestCase : public IInferenceTestCase
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker public:
115*89c4ff92SAndroid Build Coastguard Worker
InferenceModelTestCase(TModel & model,unsigned int testCaseId,const std::vector<armnnUtils::TContainer> & inputs,const std::vector<unsigned int> & outputSizes)116*89c4ff92SAndroid Build Coastguard Worker InferenceModelTestCase(TModel& model,
117*89c4ff92SAndroid Build Coastguard Worker unsigned int testCaseId,
118*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnnUtils::TContainer>& inputs,
119*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& outputSizes)
120*89c4ff92SAndroid Build Coastguard Worker : m_Model(model)
121*89c4ff92SAndroid Build Coastguard Worker , m_TestCaseId(testCaseId)
122*89c4ff92SAndroid Build Coastguard Worker , m_Inputs(std::move(inputs))
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker // Initialize output vector
125*89c4ff92SAndroid Build Coastguard Worker const size_t numOutputs = outputSizes.size();
126*89c4ff92SAndroid Build Coastguard Worker m_Outputs.reserve(numOutputs);
127*89c4ff92SAndroid Build Coastguard Worker
128*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numOutputs; i++)
129*89c4ff92SAndroid Build Coastguard Worker {
130*89c4ff92SAndroid Build Coastguard Worker m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker }
133*89c4ff92SAndroid Build Coastguard Worker
Run()134*89c4ff92SAndroid Build Coastguard Worker virtual void Run() override
135*89c4ff92SAndroid Build Coastguard Worker {
136*89c4ff92SAndroid Build Coastguard Worker m_Model.Run(m_Inputs, m_Outputs);
137*89c4ff92SAndroid Build Coastguard Worker }
138*89c4ff92SAndroid Build Coastguard Worker
139*89c4ff92SAndroid Build Coastguard Worker protected:
GetTestCaseId() const140*89c4ff92SAndroid Build Coastguard Worker unsigned int GetTestCaseId() const { return m_TestCaseId; }
GetOutputs() const141*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnnUtils::TContainer>& GetOutputs() const { return m_Outputs; }
142*89c4ff92SAndroid Build Coastguard Worker
143*89c4ff92SAndroid Build Coastguard Worker private:
144*89c4ff92SAndroid Build Coastguard Worker TModel& m_Model;
145*89c4ff92SAndroid Build Coastguard Worker unsigned int m_TestCaseId;
146*89c4ff92SAndroid Build Coastguard Worker std::vector<armnnUtils::TContainer> m_Inputs;
147*89c4ff92SAndroid Build Coastguard Worker std::vector<armnnUtils::TContainer> m_Outputs;
148*89c4ff92SAndroid Build Coastguard Worker };
149*89c4ff92SAndroid Build Coastguard Worker
150*89c4ff92SAndroid Build Coastguard Worker template <typename TTestCaseDatabase, typename TModel>
151*89c4ff92SAndroid Build Coastguard Worker class ClassifierTestCase : public InferenceModelTestCase<TModel>
152*89c4ff92SAndroid Build Coastguard Worker {
153*89c4ff92SAndroid Build Coastguard Worker public:
154*89c4ff92SAndroid Build Coastguard Worker ClassifierTestCase(int& numInferencesRef,
155*89c4ff92SAndroid Build Coastguard Worker int& numCorrectInferencesRef,
156*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& validationPredictions,
157*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int>* validationPredictionsOut,
158*89c4ff92SAndroid Build Coastguard Worker TModel& model,
159*89c4ff92SAndroid Build Coastguard Worker unsigned int testCaseId,
160*89c4ff92SAndroid Build Coastguard Worker unsigned int label,
161*89c4ff92SAndroid Build Coastguard Worker std::vector<typename TModel::DataType> modelInput);
162*89c4ff92SAndroid Build Coastguard Worker
163*89c4ff92SAndroid Build Coastguard Worker virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
164*89c4ff92SAndroid Build Coastguard Worker
165*89c4ff92SAndroid Build Coastguard Worker private:
166*89c4ff92SAndroid Build Coastguard Worker unsigned int m_Label;
167*89c4ff92SAndroid Build Coastguard Worker InferenceModelInternal::QuantizationParams m_QuantizationParams;
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
170*89c4ff92SAndroid Build Coastguard Worker /// @{
171*89c4ff92SAndroid Build Coastguard Worker int& m_NumInferencesRef;
172*89c4ff92SAndroid Build Coastguard Worker int& m_NumCorrectInferencesRef;
173*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& m_ValidationPredictions;
174*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int>* m_ValidationPredictionsOut;
175*89c4ff92SAndroid Build Coastguard Worker /// @}
176*89c4ff92SAndroid Build Coastguard Worker };
177*89c4ff92SAndroid Build Coastguard Worker
178*89c4ff92SAndroid Build Coastguard Worker template <typename TDatabase, typename InferenceModel>
179*89c4ff92SAndroid Build Coastguard Worker class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
180*89c4ff92SAndroid Build Coastguard Worker {
181*89c4ff92SAndroid Build Coastguard Worker public:
182*89c4ff92SAndroid Build Coastguard Worker template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
183*89c4ff92SAndroid Build Coastguard Worker ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
184*89c4ff92SAndroid Build Coastguard Worker
185*89c4ff92SAndroid Build Coastguard Worker virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
186*89c4ff92SAndroid Build Coastguard Worker virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
187*89c4ff92SAndroid Build Coastguard Worker virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
188*89c4ff92SAndroid Build Coastguard Worker virtual bool OnInferenceTestFinished() override;
189*89c4ff92SAndroid Build Coastguard Worker
190*89c4ff92SAndroid Build Coastguard Worker private:
191*89c4ff92SAndroid Build Coastguard Worker void ReadPredictions();
192*89c4ff92SAndroid Build Coastguard Worker
193*89c4ff92SAndroid Build Coastguard Worker typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
194*89c4ff92SAndroid Build Coastguard Worker std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
195*89c4ff92SAndroid Build Coastguard Worker typename InferenceModel::CommandLineOptions)> m_ConstructModel;
196*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<InferenceModel> m_Model;
197*89c4ff92SAndroid Build Coastguard Worker
198*89c4ff92SAndroid Build Coastguard Worker std::string m_DataDir;
199*89c4ff92SAndroid Build Coastguard Worker std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
200*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<TDatabase> m_Database;
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker int m_NumInferences; // Referenced by test cases.
203*89c4ff92SAndroid Build Coastguard Worker int m_NumCorrectInferences; // Referenced by test cases.
204*89c4ff92SAndroid Build Coastguard Worker
205*89c4ff92SAndroid Build Coastguard Worker std::string m_ValidationFileIn;
206*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker std::string m_ValidationFileOut;
209*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
210*89c4ff92SAndroid Build Coastguard Worker };
211*89c4ff92SAndroid Build Coastguard Worker
212*89c4ff92SAndroid Build Coastguard Worker bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
213*89c4ff92SAndroid Build Coastguard Worker InferenceTestOptions& outParams);
214*89c4ff92SAndroid Build Coastguard Worker
/// Presumably checks that @p dir names a usable directory (definition not visible; confirm).
/// @p dir is taken by non-const reference, so the definition may rewrite it in place.
auto ValidateDirectory(std::string& dir) -> bool;
216*89c4ff92SAndroid Build Coastguard Worker
217*89c4ff92SAndroid Build Coastguard Worker bool InferenceTest(const InferenceTestOptions& params,
218*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& defaultTestCaseIds,
219*89c4ff92SAndroid Build Coastguard Worker IInferenceTestCaseProvider& testCaseProvider);
220*89c4ff92SAndroid Build Coastguard Worker
/// Generic entry point for an inference test program; returns a process exit code.
/// @p constructTestCaseProvider is a callable presumably producing the IInferenceTestCaseProvider
/// to use (definition not visible here; confirm).
template <typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
                      char* argv[],
                      const std::vector<unsigned int>& defaultTestCaseIds,
                      TConstructTestCaseProvider constructTestCaseProvider);
226*89c4ff92SAndroid Build Coastguard Worker
227*89c4ff92SAndroid Build Coastguard Worker template<typename TDatabase,
228*89c4ff92SAndroid Build Coastguard Worker typename TParser,
229*89c4ff92SAndroid Build Coastguard Worker typename TConstructDatabaseCallable>
230*89c4ff92SAndroid Build Coastguard Worker int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
231*89c4ff92SAndroid Build Coastguard Worker const char* inputBindingName, const char* outputBindingName,
232*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& defaultTestCaseIds,
233*89c4ff92SAndroid Build Coastguard Worker TConstructDatabaseCallable constructDatabase,
234*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape* inputTensorShape = nullptr);
235*89c4ff92SAndroid Build Coastguard Worker
236*89c4ff92SAndroid Build Coastguard Worker } // namespace test
237*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
238*89c4ff92SAndroid Build Coastguard Worker
239*89c4ff92SAndroid Build Coastguard Worker #include "InferenceTest.inl"
240