xref: /aosp_15_r20/external/armnn/tests/InferenceTest.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "InferenceModel.hpp"
8 
9 #include <armnn/ArmNN.hpp>
10 #include <armnn/Logging.hpp>
11 #include <armnn/TypesUtils.hpp>
12 #include <armnn/utility/IgnoreUnused.hpp>
13 
14 #include <armnnUtils/TContainer.hpp>
15 
16 #include <cxxopts/cxxopts.hpp>
17 #include <fmt/format.h>
18 
19 
20 namespace armnn
21 {
22 
operator >>(std::istream & in,armnn::Compute & compute)23 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
24 {
25     std::string token;
26     in >> token;
27     compute = armnn::ParseComputeDevice(token.c_str());
28     if (compute == armnn::Compute::Undefined)
29     {
30         in.setstate(std::ios_base::failbit);
31         throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
32     }
33     return in;
34 }
35 
operator >>(std::istream & in,armnn::BackendId & backend)36 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
37 {
38     std::string token;
39     in >> token;
40     armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
41     if (compute == armnn::Compute::Undefined)
42     {
43         in.setstate(std::ios_base::failbit);
44         throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
45     }
46     backend = compute;
47     return in;
48 }
49 
50 namespace test
51 {
52 
53 class TestFrameworkException : public Exception
54 {
55 public:
56     using Exception::Exception;
57 };
58 
/// Options common to every inference test (ParseCommandLine fills an instance of this struct).
///
/// All members have in-class defaults: zero iterations, profiling disabled and empty paths.
struct InferenceTestOptions
{
    unsigned int m_IterationCount = 0;
    std::string m_InferenceTimesFile;
    bool m_EnableProfiling = false; // was initialised with the integer literal 0
    std::string m_DynamicBackendsPath;

    InferenceTestOptions() = default;
};
72 
/// Outcome reported by IInferenceTestCase::ProcessResult for a single test case.
enum class TestCaseResult
{
    Ok = 0,     ///< The test completed without any errors.
    Failed = 1, ///< The test failed (e.g. the prediction didn't match the validation file).
                ///< This will eventually fail the whole program but the remaining test cases will still be run.
    Abort = 2   ///< The test failed with a fatal error. The remaining tests will not be run.
};
83 
84 class IInferenceTestCase
85 {
86 public:
~IInferenceTestCase()87     virtual ~IInferenceTestCase() {}
88 
89     virtual void Run() = 0;
90     virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
91 };
92 
93 class IInferenceTestCaseProvider
94 {
95 public:
~IInferenceTestCaseProvider()96     virtual ~IInferenceTestCaseProvider() {}
97 
AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)98     virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
99     {
100         IgnoreUnused(options, required);
101     };
ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)102     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
103     {
104         IgnoreUnused(commonOptions);
105         return true;
106     };
107     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
OnInferenceTestFinished()108     virtual bool OnInferenceTestFinished() { return true; };
109 };
110 
111 template <typename TModel>
112 class InferenceModelTestCase : public IInferenceTestCase
113 {
114 public:
115 
InferenceModelTestCase(TModel & model,unsigned int testCaseId,const std::vector<armnnUtils::TContainer> & inputs,const std::vector<unsigned int> & outputSizes)116     InferenceModelTestCase(TModel& model,
117                            unsigned int testCaseId,
118                            const std::vector<armnnUtils::TContainer>& inputs,
119                            const std::vector<unsigned int>& outputSizes)
120         : m_Model(model)
121         , m_TestCaseId(testCaseId)
122         , m_Inputs(std::move(inputs))
123     {
124         // Initialize output vector
125         const size_t numOutputs = outputSizes.size();
126         m_Outputs.reserve(numOutputs);
127 
128         for (size_t i = 0; i < numOutputs; i++)
129         {
130             m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
131         }
132     }
133 
Run()134     virtual void Run() override
135     {
136         m_Model.Run(m_Inputs, m_Outputs);
137     }
138 
139 protected:
GetTestCaseId() const140     unsigned int GetTestCaseId() const { return m_TestCaseId; }
GetOutputs() const141     const std::vector<armnnUtils::TContainer>& GetOutputs() const { return m_Outputs; }
142 
143 private:
144     TModel&                         m_Model;
145     unsigned int                    m_TestCaseId;
146     std::vector<armnnUtils::TContainer>  m_Inputs;
147     std::vector<armnnUtils::TContainer>  m_Outputs;
148 };
149 
150 template <typename TTestCaseDatabase, typename TModel>
151 class ClassifierTestCase : public InferenceModelTestCase<TModel>
152 {
153 public:
154     ClassifierTestCase(int& numInferencesRef,
155         int& numCorrectInferencesRef,
156         const std::vector<unsigned int>& validationPredictions,
157         std::vector<unsigned int>* validationPredictionsOut,
158         TModel& model,
159         unsigned int testCaseId,
160         unsigned int label,
161         std::vector<typename TModel::DataType> modelInput);
162 
163     virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
164 
165 private:
166     unsigned int m_Label;
167     InferenceModelInternal::QuantizationParams m_QuantizationParams;
168 
169     /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
170     /// @{
171     int& m_NumInferencesRef;
172     int& m_NumCorrectInferencesRef;
173     const std::vector<unsigned int>& m_ValidationPredictions;
174     std::vector<unsigned int>* m_ValidationPredictionsOut;
175     /// @}
176 };
177 
178 template <typename TDatabase, typename InferenceModel>
179 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
180 {
181 public:
182     template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
183     ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
184 
185     virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
186     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
187     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
188     virtual bool OnInferenceTestFinished() override;
189 
190 private:
191     void ReadPredictions();
192 
193     typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
194     std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
195                                                   typename InferenceModel::CommandLineOptions)> m_ConstructModel;
196     std::unique_ptr<InferenceModel> m_Model;
197 
198     std::string m_DataDir;
199     std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
200     std::unique_ptr<TDatabase> m_Database;
201 
202     int m_NumInferences; // Referenced by test cases.
203     int m_NumCorrectInferences; // Referenced by test cases.
204 
205     std::string m_ValidationFileIn;
206     std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
207 
208     std::string m_ValidationFileOut;
209     std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
210 };
211 
212 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
213     InferenceTestOptions& outParams);
214 
215 bool ValidateDirectory(std::string& dir);
216 
217 bool InferenceTest(const InferenceTestOptions& params,
218     const std::vector<unsigned int>& defaultTestCaseIds,
219     IInferenceTestCaseProvider& testCaseProvider);
220 
221 template<typename TConstructTestCaseProvider>
222 int InferenceTestMain(int argc,
223     char* argv[],
224     const std::vector<unsigned int>& defaultTestCaseIds,
225     TConstructTestCaseProvider constructTestCaseProvider);
226 
227 template<typename TDatabase,
228     typename TParser,
229     typename TConstructDatabaseCallable>
230 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
231     const char* inputBindingName, const char* outputBindingName,
232     const std::vector<unsigned int>& defaultTestCaseIds,
233     TConstructDatabaseCallable constructDatabase,
234     const armnn::TensorShape* inputTensorShape = nullptr);
235 
236 } // namespace test
237 } // namespace armnn
238 
239 #include "InferenceTest.inl"
240