1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include "InferenceTest.hpp"
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include "../src/armnn/Profiling.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <fstream>
14*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
15*89c4ff92SAndroid Build Coastguard Worker #include <iomanip>
16*89c4ff92SAndroid Build Coastguard Worker #include <array>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker using namespace std;
19*89c4ff92SAndroid Build Coastguard Worker using namespace std::chrono;
20*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::test;
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker namespace armnn
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker namespace test
25*89c4ff92SAndroid Build Coastguard Worker {
/// Parse the command line of an ArmNN inference test program.
/// \return false if help was requested, if any error occurred during options processing, or if the
///         test case provider rejected the parsed options; otherwise true
ParseCommandLine(int argc,char ** argv,IInferenceTestCaseProvider & testCaseProvider,InferenceTestOptions & outParams)28*89c4ff92SAndroid Build Coastguard Worker bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
29*89c4ff92SAndroid Build Coastguard Worker InferenceTestOptions& outParams)
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker cxxopts::Options options("InferenceTest", "Inference iteration parameters");
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker try
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker // Adds generic options needed for all inference tests.
36*89c4ff92SAndroid Build Coastguard Worker options
37*89c4ff92SAndroid Build Coastguard Worker .allow_unrecognised_options()
38*89c4ff92SAndroid Build Coastguard Worker .add_options()
39*89c4ff92SAndroid Build Coastguard Worker ("h,help", "Display help messages")
40*89c4ff92SAndroid Build Coastguard Worker ("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.",
41*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0"))
42*89c4ff92SAndroid Build Coastguard Worker ("inference-times-file",
43*89c4ff92SAndroid Build Coastguard Worker "If non-empty, each individual inference time will be recorded and output to this file",
44*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value(""))
45*89c4ff92SAndroid Build Coastguard Worker ("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.",
46*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0"));
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> required; //to be passed as reference to derived inference tests
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker // Adds options specific to the ITestCaseProvider.
51*89c4ff92SAndroid Build Coastguard Worker testCaseProvider.AddCommandLineOptions(options, required);
52*89c4ff92SAndroid Build Coastguard Worker
53*89c4ff92SAndroid Build Coastguard Worker auto result = options.parse(argc, argv);
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker if (result.count("help"))
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker std::cout << options.help() << std::endl;
58*89c4ff92SAndroid Build Coastguard Worker return false;
59*89c4ff92SAndroid Build Coastguard Worker }
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker CheckRequiredOptions(result, required);
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker }
64*89c4ff92SAndroid Build Coastguard Worker catch (const cxxopts::OptionException& e)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker std::cerr << e.what() << std::endl << options.help() << std::endl;
67*89c4ff92SAndroid Build Coastguard Worker return false;
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker catch (const std::exception& e)
70*89c4ff92SAndroid Build Coastguard Worker {
71*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
72*89c4ff92SAndroid Build Coastguard Worker std::cerr << "Fatal internal error: " << e.what() << std::endl;
73*89c4ff92SAndroid Build Coastguard Worker return false;
74*89c4ff92SAndroid Build Coastguard Worker }
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker if (!testCaseProvider.ProcessCommandLineOptions(outParams))
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker return false;
79*89c4ff92SAndroid Build Coastguard Worker }
80*89c4ff92SAndroid Build Coastguard Worker
81*89c4ff92SAndroid Build Coastguard Worker return true;
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker
/// Checks that \p dir names an existing directory and normalises it to end with '/'.
///
/// Diagnostics for every failure are written to std::cerr. When \p dir is non-empty the trailing '/'
/// is appended in place even if validation then fails.
/// NOTE(review): fs::exists / fs::is_directory may still throw for errors other than "not found"
/// (e.g. permission problems), assuming fs follows std::filesystem semantics — confirm if callers care.
///
/// \param dir Directory path; updated in place to end with '/' when non-empty.
/// \return true if \p dir is non-empty, exists and is a directory; false otherwise.
bool ValidateDirectory(std::string& dir)
{
    if (dir.empty())
    {
        std::cerr << "No directory specified" << std::endl;
        return false;
    }

    // Probe the path exactly as supplied. Probing "regular_file/" makes POSIX stat() fail with ENOTDIR,
    // which fs::exists() reports as "not found", so the is_directory() diagnostic below would never be
    // reached for an ordinary file if the separator were appended first.
    const std::string givenPath = dir;

    if (dir.back() != '/')
    {
        dir += '/';
    }

    if (!fs::exists(givenPath))
    {
        std::cerr << "Given directory " << dir << " does not exist" << std::endl;
        return false;
    }

    if (!fs::is_directory(givenPath))
    {
        std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
        return false;
    }

    return true;
}
111*89c4ff92SAndroid Build Coastguard Worker
InferenceTest(const InferenceTestOptions & params,const std::vector<unsigned int> & defaultTestCaseIds,IInferenceTestCaseProvider & testCaseProvider)112*89c4ff92SAndroid Build Coastguard Worker bool InferenceTest(const InferenceTestOptions& params,
113*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& defaultTestCaseIds,
114*89c4ff92SAndroid Build Coastguard Worker IInferenceTestCaseProvider& testCaseProvider)
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker #if !defined (NDEBUG)
117*89c4ff92SAndroid Build Coastguard Worker if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
118*89c4ff92SAndroid Build Coastguard Worker {
119*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker #endif
122*89c4ff92SAndroid Build Coastguard Worker
123*89c4ff92SAndroid Build Coastguard Worker double totalTime = 0;
124*89c4ff92SAndroid Build Coastguard Worker unsigned int nbProcessed = 0;
125*89c4ff92SAndroid Build Coastguard Worker bool success = true;
126*89c4ff92SAndroid Build Coastguard Worker
127*89c4ff92SAndroid Build Coastguard Worker // Opens the file to write inference times too, if needed.
128*89c4ff92SAndroid Build Coastguard Worker ofstream inferenceTimesFile;
129*89c4ff92SAndroid Build Coastguard Worker const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
130*89c4ff92SAndroid Build Coastguard Worker if (recordInferenceTimes)
131*89c4ff92SAndroid Build Coastguard Worker {
132*89c4ff92SAndroid Build Coastguard Worker inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
133*89c4ff92SAndroid Build Coastguard Worker if (!inferenceTimesFile.good())
134*89c4ff92SAndroid Build Coastguard Worker {
135*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Failed to open inference times file for writing: "
136*89c4ff92SAndroid Build Coastguard Worker << params.m_InferenceTimesFile;
137*89c4ff92SAndroid Build Coastguard Worker return false;
138*89c4ff92SAndroid Build Coastguard Worker }
139*89c4ff92SAndroid Build Coastguard Worker }
140*89c4ff92SAndroid Build Coastguard Worker
141*89c4ff92SAndroid Build Coastguard Worker // Create a profiler and register it for the current thread.
142*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IProfiler> profiler = std::make_unique<IProfiler>();
143*89c4ff92SAndroid Build Coastguard Worker ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
144*89c4ff92SAndroid Build Coastguard Worker
145*89c4ff92SAndroid Build Coastguard Worker // Enable profiling if requested.
146*89c4ff92SAndroid Build Coastguard Worker profiler->EnableProfiling(params.m_EnableProfiling);
147*89c4ff92SAndroid Build Coastguard Worker
148*89c4ff92SAndroid Build Coastguard Worker // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
149*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
150*89c4ff92SAndroid Build Coastguard Worker if (warmupTestCase == nullptr)
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Failed to load test case";
153*89c4ff92SAndroid Build Coastguard Worker return false;
154*89c4ff92SAndroid Build Coastguard Worker }
155*89c4ff92SAndroid Build Coastguard Worker
156*89c4ff92SAndroid Build Coastguard Worker try
157*89c4ff92SAndroid Build Coastguard Worker {
158*89c4ff92SAndroid Build Coastguard Worker warmupTestCase->Run();
159*89c4ff92SAndroid Build Coastguard Worker }
160*89c4ff92SAndroid Build Coastguard Worker catch (const TestFrameworkException& testError)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << testError.what();
163*89c4ff92SAndroid Build Coastguard Worker return false;
164*89c4ff92SAndroid Build Coastguard Worker }
165*89c4ff92SAndroid Build Coastguard Worker
166*89c4ff92SAndroid Build Coastguard Worker const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
167*89c4ff92SAndroid Build Coastguard Worker : static_cast<unsigned int>(defaultTestCaseIds.size());
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker for (; nbProcessed < nbTotalToProcess; nbProcessed++)
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
172*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
173*89c4ff92SAndroid Build Coastguard Worker
174*89c4ff92SAndroid Build Coastguard Worker if (testCase == nullptr)
175*89c4ff92SAndroid Build Coastguard Worker {
176*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Failed to load test case";
177*89c4ff92SAndroid Build Coastguard Worker return false;
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker
180*89c4ff92SAndroid Build Coastguard Worker time_point<high_resolution_clock> predictStart;
181*89c4ff92SAndroid Build Coastguard Worker time_point<high_resolution_clock> predictEnd;
182*89c4ff92SAndroid Build Coastguard Worker
183*89c4ff92SAndroid Build Coastguard Worker TestCaseResult result = TestCaseResult::Ok;
184*89c4ff92SAndroid Build Coastguard Worker
185*89c4ff92SAndroid Build Coastguard Worker try
186*89c4ff92SAndroid Build Coastguard Worker {
187*89c4ff92SAndroid Build Coastguard Worker predictStart = high_resolution_clock::now();
188*89c4ff92SAndroid Build Coastguard Worker
189*89c4ff92SAndroid Build Coastguard Worker testCase->Run();
190*89c4ff92SAndroid Build Coastguard Worker
191*89c4ff92SAndroid Build Coastguard Worker predictEnd = high_resolution_clock::now();
192*89c4ff92SAndroid Build Coastguard Worker
193*89c4ff92SAndroid Build Coastguard Worker // duration<double> will convert the time difference into seconds as a double by default.
194*89c4ff92SAndroid Build Coastguard Worker double timeTakenS = duration<double>(predictEnd - predictStart).count();
195*89c4ff92SAndroid Build Coastguard Worker totalTime += timeTakenS;
196*89c4ff92SAndroid Build Coastguard Worker
197*89c4ff92SAndroid Build Coastguard Worker // Outputss inference times, if needed.
198*89c4ff92SAndroid Build Coastguard Worker if (recordInferenceTimes)
199*89c4ff92SAndroid Build Coastguard Worker {
200*89c4ff92SAndroid Build Coastguard Worker inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
201*89c4ff92SAndroid Build Coastguard Worker }
202*89c4ff92SAndroid Build Coastguard Worker
203*89c4ff92SAndroid Build Coastguard Worker result = testCase->ProcessResult(params);
204*89c4ff92SAndroid Build Coastguard Worker
205*89c4ff92SAndroid Build Coastguard Worker }
206*89c4ff92SAndroid Build Coastguard Worker catch (const TestFrameworkException& testError)
207*89c4ff92SAndroid Build Coastguard Worker {
208*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << testError.what();
209*89c4ff92SAndroid Build Coastguard Worker result = TestCaseResult::Abort;
210*89c4ff92SAndroid Build Coastguard Worker }
211*89c4ff92SAndroid Build Coastguard Worker
212*89c4ff92SAndroid Build Coastguard Worker switch (result)
213*89c4ff92SAndroid Build Coastguard Worker {
214*89c4ff92SAndroid Build Coastguard Worker case TestCaseResult::Ok:
215*89c4ff92SAndroid Build Coastguard Worker break;
216*89c4ff92SAndroid Build Coastguard Worker case TestCaseResult::Abort:
217*89c4ff92SAndroid Build Coastguard Worker return false;
218*89c4ff92SAndroid Build Coastguard Worker case TestCaseResult::Failed:
219*89c4ff92SAndroid Build Coastguard Worker // This test failed so we will fail the entire program eventually, but keep going for now.
220*89c4ff92SAndroid Build Coastguard Worker success = false;
221*89c4ff92SAndroid Build Coastguard Worker break;
222*89c4ff92SAndroid Build Coastguard Worker default:
223*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
224*89c4ff92SAndroid Build Coastguard Worker return false;
225*89c4ff92SAndroid Build Coastguard Worker }
226*89c4ff92SAndroid Build Coastguard Worker }
227*89c4ff92SAndroid Build Coastguard Worker
228*89c4ff92SAndroid Build Coastguard Worker const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
229*89c4ff92SAndroid Build Coastguard Worker
230*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
231*89c4ff92SAndroid Build Coastguard Worker "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
232*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
233*89c4ff92SAndroid Build Coastguard Worker "Average time per test case: " << averageTimePerTestCaseMs << " ms";
234*89c4ff92SAndroid Build Coastguard Worker
235*89c4ff92SAndroid Build Coastguard Worker // if profiling is enabled print out the results
236*89c4ff92SAndroid Build Coastguard Worker if (profiler && profiler->IsProfilingEnabled())
237*89c4ff92SAndroid Build Coastguard Worker {
238*89c4ff92SAndroid Build Coastguard Worker profiler->Print(std::cout);
239*89c4ff92SAndroid Build Coastguard Worker }
240*89c4ff92SAndroid Build Coastguard Worker
241*89c4ff92SAndroid Build Coastguard Worker if (!success)
242*89c4ff92SAndroid Build Coastguard Worker {
243*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "One or more test cases failed";
244*89c4ff92SAndroid Build Coastguard Worker return false;
245*89c4ff92SAndroid Build Coastguard Worker }
246*89c4ff92SAndroid Build Coastguard Worker
247*89c4ff92SAndroid Build Coastguard Worker return testCaseProvider.OnInferenceTestFinished();
248*89c4ff92SAndroid Build Coastguard Worker }
249*89c4ff92SAndroid Build Coastguard Worker
250*89c4ff92SAndroid Build Coastguard Worker } // namespace test
251*89c4ff92SAndroid Build Coastguard Worker
252*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
253