// xref: /aosp_15_r20/external/armnn/tests/InferenceTest.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "InferenceModel.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TContainer.hpp>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker namespace armnn
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker 
operator >>(std::istream & in,armnn::Compute & compute)23*89c4ff92SAndroid Build Coastguard Worker inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker     std::string token;
26*89c4ff92SAndroid Build Coastguard Worker     in >> token;
27*89c4ff92SAndroid Build Coastguard Worker     compute = armnn::ParseComputeDevice(token.c_str());
28*89c4ff92SAndroid Build Coastguard Worker     if (compute == armnn::Compute::Undefined)
29*89c4ff92SAndroid Build Coastguard Worker     {
30*89c4ff92SAndroid Build Coastguard Worker         in.setstate(std::ios_base::failbit);
31*89c4ff92SAndroid Build Coastguard Worker         throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
32*89c4ff92SAndroid Build Coastguard Worker     }
33*89c4ff92SAndroid Build Coastguard Worker     return in;
34*89c4ff92SAndroid Build Coastguard Worker }
35*89c4ff92SAndroid Build Coastguard Worker 
operator >>(std::istream & in,armnn::BackendId & backend)36*89c4ff92SAndroid Build Coastguard Worker inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker     std::string token;
39*89c4ff92SAndroid Build Coastguard Worker     in >> token;
40*89c4ff92SAndroid Build Coastguard Worker     armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
41*89c4ff92SAndroid Build Coastguard Worker     if (compute == armnn::Compute::Undefined)
42*89c4ff92SAndroid Build Coastguard Worker     {
43*89c4ff92SAndroid Build Coastguard Worker         in.setstate(std::ios_base::failbit);
44*89c4ff92SAndroid Build Coastguard Worker         throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
45*89c4ff92SAndroid Build Coastguard Worker     }
46*89c4ff92SAndroid Build Coastguard Worker     backend = compute;
47*89c4ff92SAndroid Build Coastguard Worker     return in;
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker namespace test
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker class TestFrameworkException : public Exception
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker public:
56*89c4ff92SAndroid Build Coastguard Worker     using Exception::Exception;
57*89c4ff92SAndroid Build Coastguard Worker };
58*89c4ff92SAndroid Build Coastguard Worker 
/// Options common to every inference test. This is presumably populated by ParseCommandLine
/// (declared below; its definition is not visible here).
struct InferenceTestOptions
{
    unsigned int m_IterationCount;     ///< Iteration count; 0 by default (meaning defined by the test driver).
    std::string m_InferenceTimesFile;  ///< Presumably a path to write inference timings to; empty by default.
    bool m_EnableProfiling;            ///< Whether profiling is requested; false by default.
    std::string m_DynamicBackendsPath; ///< Path used to look for dynamic backends; empty by default.

    InferenceTestOptions()
        : m_IterationCount(0)
        , m_InferenceTimesFile()
        , m_EnableProfiling(false) // bool member: initialise with false, not the integer literal 0
        , m_DynamicBackendsPath()
    {}
};
72*89c4ff92SAndroid Build Coastguard Worker 
/// Outcome reported by IInferenceTestCase::ProcessResult for one test case.
enum class TestCaseResult
{
    Ok = 0,     ///< The test completed without any errors.
    Failed = 1, ///< The test failed (e.g. the prediction didn't match the validation file). This will
                ///< eventually fail the whole program, but the remaining test cases are still run.
    Abort = 2   ///< The test failed with a fatal error; the remaining tests are not run.
};
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker class IInferenceTestCase
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker public:
~IInferenceTestCase()87*89c4ff92SAndroid Build Coastguard Worker     virtual ~IInferenceTestCase() {}
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     virtual void Run() = 0;
90*89c4ff92SAndroid Build Coastguard Worker     virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
91*89c4ff92SAndroid Build Coastguard Worker };
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker class IInferenceTestCaseProvider
94*89c4ff92SAndroid Build Coastguard Worker {
95*89c4ff92SAndroid Build Coastguard Worker public:
~IInferenceTestCaseProvider()96*89c4ff92SAndroid Build Coastguard Worker     virtual ~IInferenceTestCaseProvider() {}
97*89c4ff92SAndroid Build Coastguard Worker 
AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)98*89c4ff92SAndroid Build Coastguard Worker     virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
99*89c4ff92SAndroid Build Coastguard Worker     {
100*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(options, required);
101*89c4ff92SAndroid Build Coastguard Worker     };
ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)102*89c4ff92SAndroid Build Coastguard Worker     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
103*89c4ff92SAndroid Build Coastguard Worker     {
104*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(commonOptions);
105*89c4ff92SAndroid Build Coastguard Worker         return true;
106*89c4ff92SAndroid Build Coastguard Worker     };
107*89c4ff92SAndroid Build Coastguard Worker     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
OnInferenceTestFinished()108*89c4ff92SAndroid Build Coastguard Worker     virtual bool OnInferenceTestFinished() { return true; };
109*89c4ff92SAndroid Build Coastguard Worker };
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker template <typename TModel>
112*89c4ff92SAndroid Build Coastguard Worker class InferenceModelTestCase : public IInferenceTestCase
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker public:
115*89c4ff92SAndroid Build Coastguard Worker 
InferenceModelTestCase(TModel & model,unsigned int testCaseId,const std::vector<armnnUtils::TContainer> & inputs,const std::vector<unsigned int> & outputSizes)116*89c4ff92SAndroid Build Coastguard Worker     InferenceModelTestCase(TModel& model,
117*89c4ff92SAndroid Build Coastguard Worker                            unsigned int testCaseId,
118*89c4ff92SAndroid Build Coastguard Worker                            const std::vector<armnnUtils::TContainer>& inputs,
119*89c4ff92SAndroid Build Coastguard Worker                            const std::vector<unsigned int>& outputSizes)
120*89c4ff92SAndroid Build Coastguard Worker         : m_Model(model)
121*89c4ff92SAndroid Build Coastguard Worker         , m_TestCaseId(testCaseId)
122*89c4ff92SAndroid Build Coastguard Worker         , m_Inputs(std::move(inputs))
123*89c4ff92SAndroid Build Coastguard Worker     {
124*89c4ff92SAndroid Build Coastguard Worker         // Initialize output vector
125*89c4ff92SAndroid Build Coastguard Worker         const size_t numOutputs = outputSizes.size();
126*89c4ff92SAndroid Build Coastguard Worker         m_Outputs.reserve(numOutputs);
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker         for (size_t i = 0; i < numOutputs; i++)
129*89c4ff92SAndroid Build Coastguard Worker         {
130*89c4ff92SAndroid Build Coastguard Worker             m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
131*89c4ff92SAndroid Build Coastguard Worker         }
132*89c4ff92SAndroid Build Coastguard Worker     }
133*89c4ff92SAndroid Build Coastguard Worker 
Run()134*89c4ff92SAndroid Build Coastguard Worker     virtual void Run() override
135*89c4ff92SAndroid Build Coastguard Worker     {
136*89c4ff92SAndroid Build Coastguard Worker         m_Model.Run(m_Inputs, m_Outputs);
137*89c4ff92SAndroid Build Coastguard Worker     }
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker protected:
GetTestCaseId() const140*89c4ff92SAndroid Build Coastguard Worker     unsigned int GetTestCaseId() const { return m_TestCaseId; }
GetOutputs() const141*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnnUtils::TContainer>& GetOutputs() const { return m_Outputs; }
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker private:
144*89c4ff92SAndroid Build Coastguard Worker     TModel&                         m_Model;
145*89c4ff92SAndroid Build Coastguard Worker     unsigned int                    m_TestCaseId;
146*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnnUtils::TContainer>  m_Inputs;
147*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnnUtils::TContainer>  m_Outputs;
148*89c4ff92SAndroid Build Coastguard Worker };
149*89c4ff92SAndroid Build Coastguard Worker 
/// Test case for a classification model. It runs one input through the model, and
/// ProcessResult() then evaluates the output against the expected label.
/// NOTE(review): the constructor and ProcessResult() are defined in InferenceTest.inl (not visible
/// here). The description of the check is inferred from the member names; confirm it there.
///
/// @tparam TTestCaseDatabase  Database type the test cases come from (not used in this declaration).
/// @tparam TModel             Model type; see InferenceModelTestCase for its requirements.
template <typename TTestCaseDatabase, typename TModel>
class ClassifierTestCase : public InferenceModelTestCase<TModel>
{
public:
    /// @param numInferencesRef          Inference counter owned by the provider; bound by reference.
    /// @param numCorrectInferencesRef   Correct-inference counter owned by the provider; bound by reference.
    /// @param validationPredictions     Expected predictions owned by the provider; bound by reference,
    ///                                  so it must outlive this test case.
    /// @param validationPredictionsOut  Provider-owned sink for predictions; presumably nullptr when
    ///                                  no output validation file was requested (confirm in .inl).
    /// @param model                     Model to run; held by reference in the base class.
    /// @param testCaseId                Identifier of this test case.
    /// @param label                     Expected class label for this input.
    /// @param modelInput                Input data for the model, taken by value.
    ClassifierTestCase(int& numInferencesRef,
        int& numCorrectInferencesRef,
        const std::vector<unsigned int>& validationPredictions,
        std::vector<unsigned int>* validationPredictionsOut,
        TModel& model,
        unsigned int testCaseId,
        unsigned int label,
        std::vector<typename TModel::DataType> modelInput);

    /// Evaluates the model output for this test case (definition in InferenceTest.inl).
    virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;

private:
    // NOTE: declaration order fixes the initialisation order in the out-of-line constructor.
    unsigned int m_Label;
    InferenceModelInternal::QuantizationParams m_QuantizationParams;

    /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
    /// @{
    int& m_NumInferencesRef;
    int& m_NumCorrectInferencesRef;
    const std::vector<unsigned int>& m_ValidationPredictions;
    std::vector<unsigned int>* m_ValidationPredictionsOut;
    /// @}
};
177*89c4ff92SAndroid Build Coastguard Worker 
/// Test-case provider for classification models. It builds the model and the test-case database
/// from command-line options and hands out ClassifierTestCase instances.
/// NOTE(review): all member functions are defined in InferenceTest.inl (not visible here). The
/// descriptions below are inferred from names and types; confirm them there.
///
/// @tparam TDatabase       Test-case database type, built through the constructDatabase callable.
/// @tparam InferenceModel  Model type; must provide a nested CommandLineOptions type.
template <typename TDatabase, typename InferenceModel>
class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
{
public:
    /// @param constructDatabase  Callable stored in m_ConstructDatabase, with signature
    ///                           TDatabase(const char*, const InferenceModel&). The string is
    ///                           presumably the data directory.
    /// @param constructModel     Callable stored in m_ConstructModel, building the model from the
    ///                           common options and the model's CommandLineOptions.
    template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
    ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);

    virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
    virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
    virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
    virtual bool OnInferenceTestFinished() override;

private:
    /// Presumably loads m_ValidationPredictions from m_ValidationFileIn.
    void ReadPredictions();

    // NOTE: declaration order fixes the initialisation order in the out-of-line constructor.
    typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
    std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
                                                  typename InferenceModel::CommandLineOptions)> m_ConstructModel;
    std::unique_ptr<InferenceModel> m_Model;

    std::string m_DataDir;
    std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
    std::unique_ptr<TDatabase> m_Database;

    int m_NumInferences; // Referenced by test cases.
    int m_NumCorrectInferences; // Referenced by test cases.

    std::string m_ValidationFileIn;
    std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.

    std::string m_ValidationFileOut;
    std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
};
211*89c4ff92SAndroid Build Coastguard Worker 
/// Parses @p argc / @p argv into @p outParams and lets @p testCaseProvider register and process
/// its own options. Presumably returns false if parsing or validation fails (definition in
/// InferenceTest.inl; not visible here).
bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
    InferenceTestOptions& outParams);

/// Checks that @p dir names a usable directory. @p dir is passed by non-const reference, so the
/// function may modify it. NOTE(review): e.g. normalising a trailing separator; confirm in .inl.
bool ValidateDirectory(std::string& dir);

/// Runs test cases obtained from @p testCaseProvider. @p defaultTestCaseIds presumably selects
/// which cases to run when no other selection applies. Returns overall success as a bool.
bool InferenceTest(const InferenceTestOptions& params,
    const std::vector<unsigned int>& defaultTestCaseIds,
    IInferenceTestCaseProvider& testCaseProvider);

/// Entry-point helper: builds a provider with @p constructTestCaseProvider, then presumably parses
/// the command line and runs the tests. Returns a process exit code.
template<typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
    char* argv[],
    const std::vector<unsigned int>& defaultTestCaseIds,
    TConstructTestCaseProvider constructTestCaseProvider);

/// Entry-point helper for classifier tests. It loads the model from @p modelFilename using parser
/// type TParser and builds the test-case database with @p constructDatabase. It then runs the
/// cases against @p inputBindingName / @p outputBindingName.
/// @param inputTensorShape  Optional input shape; nullptr (the default) presumably means "use the
///                          shape stored in the model".
template<typename TDatabase,
    typename TParser,
    typename TConstructDatabaseCallable>
int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
    const char* inputBindingName, const char* outputBindingName,
    const std::vector<unsigned int>& defaultTestCaseIds,
    TConstructDatabaseCallable constructDatabase,
    const armnn::TensorShape* inputTensorShape = nullptr);
235*89c4ff92SAndroid Build Coastguard Worker 
236*89c4ff92SAndroid Build Coastguard Worker } // namespace test
237*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
238*89c4ff92SAndroid Build Coastguard Worker 
239*89c4ff92SAndroid Build Coastguard Worker #include "InferenceTest.inl"
240