xref: /aosp_15_r20/external/armnn/tests/InferenceTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include "InferenceTest.hpp"
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include "../src/armnn/Profiling.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <fstream>
14*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
15*89c4ff92SAndroid Build Coastguard Worker #include <iomanip>
16*89c4ff92SAndroid Build Coastguard Worker #include <array>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker using namespace std;
19*89c4ff92SAndroid Build Coastguard Worker using namespace std::chrono;
20*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::test;
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker namespace armnn
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker namespace test
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker /// Parse the command line of an ArmNN inference test program.
27*89c4ff92SAndroid Build Coastguard Worker /// \return false if any error occurred during options processing, otherwise true
ParseCommandLine(int argc,char ** argv,IInferenceTestCaseProvider & testCaseProvider,InferenceTestOptions & outParams)28*89c4ff92SAndroid Build Coastguard Worker bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
29*89c4ff92SAndroid Build Coastguard Worker     InferenceTestOptions& outParams)
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker     cxxopts::Options options("InferenceTest", "Inference iteration parameters");
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     try
34*89c4ff92SAndroid Build Coastguard Worker     {
35*89c4ff92SAndroid Build Coastguard Worker         // Adds generic options needed for all inference tests.
36*89c4ff92SAndroid Build Coastguard Worker         options
37*89c4ff92SAndroid Build Coastguard Worker             .allow_unrecognised_options()
38*89c4ff92SAndroid Build Coastguard Worker             .add_options()
39*89c4ff92SAndroid Build Coastguard Worker                 ("h,help", "Display help messages")
40*89c4ff92SAndroid Build Coastguard Worker                 ("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.",
41*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0"))
42*89c4ff92SAndroid Build Coastguard Worker                 ("inference-times-file",
43*89c4ff92SAndroid Build Coastguard Worker                  "If non-empty, each individual inference time will be recorded and output to this file",
44*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value(""))
45*89c4ff92SAndroid Build Coastguard Worker                 ("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.",
46*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0"));
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::string> required; //to be passed as reference to derived inference tests
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker         // Adds options specific to the ITestCaseProvider.
51*89c4ff92SAndroid Build Coastguard Worker         testCaseProvider.AddCommandLineOptions(options, required);
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker         auto result = options.parse(argc, argv);
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker         if (result.count("help"))
56*89c4ff92SAndroid Build Coastguard Worker         {
57*89c4ff92SAndroid Build Coastguard Worker             std::cout << options.help() << std::endl;
58*89c4ff92SAndroid Build Coastguard Worker             return false;
59*89c4ff92SAndroid Build Coastguard Worker         }
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker         CheckRequiredOptions(result, required);
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     }
64*89c4ff92SAndroid Build Coastguard Worker     catch (const cxxopts::OptionException& e)
65*89c4ff92SAndroid Build Coastguard Worker     {
66*89c4ff92SAndroid Build Coastguard Worker         std::cerr << e.what() << std::endl << options.help() << std::endl;
67*89c4ff92SAndroid Build Coastguard Worker         return false;
68*89c4ff92SAndroid Build Coastguard Worker     }
69*89c4ff92SAndroid Build Coastguard Worker     catch (const std::exception& e)
70*89c4ff92SAndroid Build Coastguard Worker     {
71*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
72*89c4ff92SAndroid Build Coastguard Worker         std::cerr << "Fatal internal error: " << e.what() << std::endl;
73*89c4ff92SAndroid Build Coastguard Worker         return false;
74*89c4ff92SAndroid Build Coastguard Worker     }
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     if (!testCaseProvider.ProcessCommandLineOptions(outParams))
77*89c4ff92SAndroid Build Coastguard Worker     {
78*89c4ff92SAndroid Build Coastguard Worker         return false;
79*89c4ff92SAndroid Build Coastguard Worker     }
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker     return true;
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker 
/// Check that @p dir names an existing directory and make sure it ends with '/'.
///
/// The separator is appended only after the checks succeed, so on failure @p dir is left as
/// given. Diagnostics are written to std::cerr.
///
/// @param dir Directory path to validate; on success it is normalised to end with '/'.
/// @return true if @p dir is non-empty, exists and is a directory; false otherwise.
bool ValidateDirectory(std::string& dir)
{
    if (dir.empty())
    {
        std::cerr << "No directory specified" << std::endl;
        return false;
    }

    // Probe the path exactly as given, before appending a separator. On POSIX, stat("file/")
    // fails with ENOTDIR, which the filesystem library reports as "not found". Checking
    // "file/" would therefore make a regular file look non-existent, and the "is not a
    // directory" diagnostic below could never be reached.
    if (!fs::exists(dir))
    {
        std::cerr << "Given directory " << dir << " does not exist" << std::endl;
        return false;
    }

    if (!fs::is_directory(dir))
    {
        std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
        return false;
    }

    // Presumably so callers can append file names directly onto the directory string.
    if (dir.back() != '/')
    {
        dir += '/';
    }

    return true;
}
111*89c4ff92SAndroid Build Coastguard Worker 
InferenceTest(const InferenceTestOptions & params,const std::vector<unsigned int> & defaultTestCaseIds,IInferenceTestCaseProvider & testCaseProvider)112*89c4ff92SAndroid Build Coastguard Worker bool InferenceTest(const InferenceTestOptions& params,
113*89c4ff92SAndroid Build Coastguard Worker     const std::vector<unsigned int>& defaultTestCaseIds,
114*89c4ff92SAndroid Build Coastguard Worker     IInferenceTestCaseProvider& testCaseProvider)
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker #if !defined (NDEBUG)
117*89c4ff92SAndroid Build Coastguard Worker     if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
118*89c4ff92SAndroid Build Coastguard Worker     {
119*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
120*89c4ff92SAndroid Build Coastguard Worker     }
121*89c4ff92SAndroid Build Coastguard Worker #endif
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     double totalTime = 0;
124*89c4ff92SAndroid Build Coastguard Worker     unsigned int nbProcessed = 0;
125*89c4ff92SAndroid Build Coastguard Worker     bool success = true;
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     // Opens the file to write inference times too, if needed.
128*89c4ff92SAndroid Build Coastguard Worker     ofstream inferenceTimesFile;
129*89c4ff92SAndroid Build Coastguard Worker     const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
130*89c4ff92SAndroid Build Coastguard Worker     if (recordInferenceTimes)
131*89c4ff92SAndroid Build Coastguard Worker     {
132*89c4ff92SAndroid Build Coastguard Worker         inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
133*89c4ff92SAndroid Build Coastguard Worker         if (!inferenceTimesFile.good())
134*89c4ff92SAndroid Build Coastguard Worker         {
135*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(error) << "Failed to open inference times file for writing: "
136*89c4ff92SAndroid Build Coastguard Worker                 << params.m_InferenceTimesFile;
137*89c4ff92SAndroid Build Coastguard Worker             return false;
138*89c4ff92SAndroid Build Coastguard Worker         }
139*89c4ff92SAndroid Build Coastguard Worker     }
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker     // Create a profiler and register it for the current thread.
142*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<IProfiler> profiler = std::make_unique<IProfiler>();
143*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker     // Enable profiling if requested.
146*89c4ff92SAndroid Build Coastguard Worker     profiler->EnableProfiling(params.m_EnableProfiling);
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker     // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
149*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
150*89c4ff92SAndroid Build Coastguard Worker     if (warmupTestCase == nullptr)
151*89c4ff92SAndroid Build Coastguard Worker     {
152*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << "Failed to load test case";
153*89c4ff92SAndroid Build Coastguard Worker         return false;
154*89c4ff92SAndroid Build Coastguard Worker     }
155*89c4ff92SAndroid Build Coastguard Worker 
156*89c4ff92SAndroid Build Coastguard Worker     try
157*89c4ff92SAndroid Build Coastguard Worker     {
158*89c4ff92SAndroid Build Coastguard Worker         warmupTestCase->Run();
159*89c4ff92SAndroid Build Coastguard Worker     }
160*89c4ff92SAndroid Build Coastguard Worker     catch (const TestFrameworkException& testError)
161*89c4ff92SAndroid Build Coastguard Worker     {
162*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << testError.what();
163*89c4ff92SAndroid Build Coastguard Worker         return false;
164*89c4ff92SAndroid Build Coastguard Worker     }
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker     const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
167*89c4ff92SAndroid Build Coastguard Worker         : static_cast<unsigned int>(defaultTestCaseIds.size());
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker     for (; nbProcessed < nbTotalToProcess; nbProcessed++)
170*89c4ff92SAndroid Build Coastguard Worker     {
171*89c4ff92SAndroid Build Coastguard Worker         const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
172*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker         if (testCase == nullptr)
175*89c4ff92SAndroid Build Coastguard Worker         {
176*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(error) << "Failed to load test case";
177*89c4ff92SAndroid Build Coastguard Worker             return false;
178*89c4ff92SAndroid Build Coastguard Worker         }
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker         time_point<high_resolution_clock> predictStart;
181*89c4ff92SAndroid Build Coastguard Worker         time_point<high_resolution_clock> predictEnd;
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker         TestCaseResult result = TestCaseResult::Ok;
184*89c4ff92SAndroid Build Coastguard Worker 
185*89c4ff92SAndroid Build Coastguard Worker         try
186*89c4ff92SAndroid Build Coastguard Worker         {
187*89c4ff92SAndroid Build Coastguard Worker             predictStart = high_resolution_clock::now();
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker             testCase->Run();
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker             predictEnd = high_resolution_clock::now();
192*89c4ff92SAndroid Build Coastguard Worker 
193*89c4ff92SAndroid Build Coastguard Worker             // duration<double> will convert the time difference into seconds as a double by default.
194*89c4ff92SAndroid Build Coastguard Worker             double timeTakenS = duration<double>(predictEnd - predictStart).count();
195*89c4ff92SAndroid Build Coastguard Worker             totalTime += timeTakenS;
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker             // Outputss inference times, if needed.
198*89c4ff92SAndroid Build Coastguard Worker             if (recordInferenceTimes)
199*89c4ff92SAndroid Build Coastguard Worker             {
200*89c4ff92SAndroid Build Coastguard Worker                 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
201*89c4ff92SAndroid Build Coastguard Worker             }
202*89c4ff92SAndroid Build Coastguard Worker 
203*89c4ff92SAndroid Build Coastguard Worker             result = testCase->ProcessResult(params);
204*89c4ff92SAndroid Build Coastguard Worker 
205*89c4ff92SAndroid Build Coastguard Worker         }
206*89c4ff92SAndroid Build Coastguard Worker         catch (const TestFrameworkException& testError)
207*89c4ff92SAndroid Build Coastguard Worker         {
208*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(error) << testError.what();
209*89c4ff92SAndroid Build Coastguard Worker             result = TestCaseResult::Abort;
210*89c4ff92SAndroid Build Coastguard Worker         }
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker         switch (result)
213*89c4ff92SAndroid Build Coastguard Worker         {
214*89c4ff92SAndroid Build Coastguard Worker         case TestCaseResult::Ok:
215*89c4ff92SAndroid Build Coastguard Worker             break;
216*89c4ff92SAndroid Build Coastguard Worker         case TestCaseResult::Abort:
217*89c4ff92SAndroid Build Coastguard Worker             return false;
218*89c4ff92SAndroid Build Coastguard Worker         case TestCaseResult::Failed:
219*89c4ff92SAndroid Build Coastguard Worker             // This test failed so we will fail the entire program eventually, but keep going for now.
220*89c4ff92SAndroid Build Coastguard Worker             success = false;
221*89c4ff92SAndroid Build Coastguard Worker             break;
222*89c4ff92SAndroid Build Coastguard Worker         default:
223*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
224*89c4ff92SAndroid Build Coastguard Worker             return false;
225*89c4ff92SAndroid Build Coastguard Worker         }
226*89c4ff92SAndroid Build Coastguard Worker     }
227*89c4ff92SAndroid Build Coastguard Worker 
228*89c4ff92SAndroid Build Coastguard Worker     const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
229*89c4ff92SAndroid Build Coastguard Worker 
230*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
231*89c4ff92SAndroid Build Coastguard Worker         "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
232*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
233*89c4ff92SAndroid Build Coastguard Worker         "Average time per test case: " << averageTimePerTestCaseMs << " ms";
234*89c4ff92SAndroid Build Coastguard Worker 
235*89c4ff92SAndroid Build Coastguard Worker     // if profiling is enabled print out the results
236*89c4ff92SAndroid Build Coastguard Worker     if (profiler && profiler->IsProfilingEnabled())
237*89c4ff92SAndroid Build Coastguard Worker     {
238*89c4ff92SAndroid Build Coastguard Worker         profiler->Print(std::cout);
239*89c4ff92SAndroid Build Coastguard Worker     }
240*89c4ff92SAndroid Build Coastguard Worker 
241*89c4ff92SAndroid Build Coastguard Worker     if (!success)
242*89c4ff92SAndroid Build Coastguard Worker     {
243*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << "One or more test cases failed";
244*89c4ff92SAndroid Build Coastguard Worker         return false;
245*89c4ff92SAndroid Build Coastguard Worker     }
246*89c4ff92SAndroid Build Coastguard Worker 
247*89c4ff92SAndroid Build Coastguard Worker     return testCaseProvider.OnInferenceTestFinished();
248*89c4ff92SAndroid Build Coastguard Worker }
249*89c4ff92SAndroid Build Coastguard Worker 
250*89c4ff92SAndroid Build Coastguard Worker } // namespace test
251*89c4ff92SAndroid Build Coastguard Worker 
252*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
253