1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "InferenceTest.hpp"
6
7 #include <armnn/utility/Assert.hpp>
8 #include <armnnUtils/Filesystem.hpp>
9
10 #include "../src/armnn/Profiling.hpp"
11 #include <cxxopts/cxxopts.hpp>
12
13 #include <fstream>
14 #include <iostream>
15 #include <iomanip>
16 #include <array>
17
18 using namespace std;
19 using namespace std::chrono;
20 using namespace armnn::test;
21
22 namespace armnn
23 {
24 namespace test
25 {
26 /// Parse the command line of an ArmNN inference test program.
27 /// \return false if any error occurred during options processing, otherwise true
ParseCommandLine(int argc,char ** argv,IInferenceTestCaseProvider & testCaseProvider,InferenceTestOptions & outParams)28 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
29 InferenceTestOptions& outParams)
30 {
31 cxxopts::Options options("InferenceTest", "Inference iteration parameters");
32
33 try
34 {
35 // Adds generic options needed for all inference tests.
36 options
37 .allow_unrecognised_options()
38 .add_options()
39 ("h,help", "Display help messages")
40 ("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.",
41 cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0"))
42 ("inference-times-file",
43 "If non-empty, each individual inference time will be recorded and output to this file",
44 cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value(""))
45 ("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.",
46 cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0"));
47
48 std::vector<std::string> required; //to be passed as reference to derived inference tests
49
50 // Adds options specific to the ITestCaseProvider.
51 testCaseProvider.AddCommandLineOptions(options, required);
52
53 auto result = options.parse(argc, argv);
54
55 if (result.count("help"))
56 {
57 std::cout << options.help() << std::endl;
58 return false;
59 }
60
61 CheckRequiredOptions(result, required);
62
63 }
64 catch (const cxxopts::OptionException& e)
65 {
66 std::cerr << e.what() << std::endl << options.help() << std::endl;
67 return false;
68 }
69 catch (const std::exception& e)
70 {
71 ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
72 std::cerr << "Fatal internal error: " << e.what() << std::endl;
73 return false;
74 }
75
76 if (!testCaseProvider.ProcessCommandLineOptions(outParams))
77 {
78 return false;
79 }
80
81 return true;
82 }
83
/// Checks that @p dir names an existing directory.
///
/// \param dir In/out directory path. If it is non-empty and does not already end with '/', a '/'
///            is appended in place. The append happens before the existence checks, so it occurs
///            even when validation then fails.
/// \return true if @p dir is non-empty, exists and is a directory. Otherwise the reason is printed
///         to std::cerr and false is returned.
bool ValidateDirectory(std::string& dir)
{
    if (dir.empty())
    {
        std::cerr << "No directory specified" << std::endl;
        return false;
    }

    // Ensure a trailing separator (presumably so callers can concatenate file names onto it).
    if (dir.back() != '/')
    {
        dir.push_back('/');
    }

    const bool pathExists = fs::exists(dir);
    if (!pathExists)
    {
        std::cerr << "Given directory " << dir << " does not exist" << std::endl;
        return false;
    }

    if (fs::is_directory(dir))
    {
        return true;
    }

    std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
    return false;
}
111
InferenceTest(const InferenceTestOptions & params,const std::vector<unsigned int> & defaultTestCaseIds,IInferenceTestCaseProvider & testCaseProvider)112 bool InferenceTest(const InferenceTestOptions& params,
113 const std::vector<unsigned int>& defaultTestCaseIds,
114 IInferenceTestCaseProvider& testCaseProvider)
115 {
116 #if !defined (NDEBUG)
117 if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
118 {
119 ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
120 }
121 #endif
122
123 double totalTime = 0;
124 unsigned int nbProcessed = 0;
125 bool success = true;
126
127 // Opens the file to write inference times too, if needed.
128 ofstream inferenceTimesFile;
129 const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
130 if (recordInferenceTimes)
131 {
132 inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
133 if (!inferenceTimesFile.good())
134 {
135 ARMNN_LOG(error) << "Failed to open inference times file for writing: "
136 << params.m_InferenceTimesFile;
137 return false;
138 }
139 }
140
141 // Create a profiler and register it for the current thread.
142 std::unique_ptr<IProfiler> profiler = std::make_unique<IProfiler>();
143 ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
144
145 // Enable profiling if requested.
146 profiler->EnableProfiling(params.m_EnableProfiling);
147
148 // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
149 std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
150 if (warmupTestCase == nullptr)
151 {
152 ARMNN_LOG(error) << "Failed to load test case";
153 return false;
154 }
155
156 try
157 {
158 warmupTestCase->Run();
159 }
160 catch (const TestFrameworkException& testError)
161 {
162 ARMNN_LOG(error) << testError.what();
163 return false;
164 }
165
166 const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
167 : static_cast<unsigned int>(defaultTestCaseIds.size());
168
169 for (; nbProcessed < nbTotalToProcess; nbProcessed++)
170 {
171 const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
172 std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
173
174 if (testCase == nullptr)
175 {
176 ARMNN_LOG(error) << "Failed to load test case";
177 return false;
178 }
179
180 time_point<high_resolution_clock> predictStart;
181 time_point<high_resolution_clock> predictEnd;
182
183 TestCaseResult result = TestCaseResult::Ok;
184
185 try
186 {
187 predictStart = high_resolution_clock::now();
188
189 testCase->Run();
190
191 predictEnd = high_resolution_clock::now();
192
193 // duration<double> will convert the time difference into seconds as a double by default.
194 double timeTakenS = duration<double>(predictEnd - predictStart).count();
195 totalTime += timeTakenS;
196
197 // Outputss inference times, if needed.
198 if (recordInferenceTimes)
199 {
200 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
201 }
202
203 result = testCase->ProcessResult(params);
204
205 }
206 catch (const TestFrameworkException& testError)
207 {
208 ARMNN_LOG(error) << testError.what();
209 result = TestCaseResult::Abort;
210 }
211
212 switch (result)
213 {
214 case TestCaseResult::Ok:
215 break;
216 case TestCaseResult::Abort:
217 return false;
218 case TestCaseResult::Failed:
219 // This test failed so we will fail the entire program eventually, but keep going for now.
220 success = false;
221 break;
222 default:
223 ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
224 return false;
225 }
226 }
227
228 const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
229
230 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
231 "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
232 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
233 "Average time per test case: " << averageTimePerTestCaseMs << " ms";
234
235 // if profiling is enabled print out the results
236 if (profiler && profiler->IsProfilingEnabled())
237 {
238 profiler->Print(std::cout);
239 }
240
241 if (!success)
242 {
243 ARMNN_LOG(error) << "One or more test cases failed";
244 return false;
245 }
246
247 return testCaseProvider.OnInferenceTestFinished();
248 }
249
250 } // namespace test
251
252 } // namespace armnn
253