xref: /aosp_15_r20/external/armnn/tests/InferenceTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "InferenceTest.hpp"
6 
7 #include <armnn/utility/Assert.hpp>
8 #include <armnnUtils/Filesystem.hpp>
9 
10 #include "../src/armnn/Profiling.hpp"
11 #include <cxxopts/cxxopts.hpp>
12 
13 #include <fstream>
14 #include <iostream>
15 #include <iomanip>
16 #include <array>
17 
18 using namespace std;
19 using namespace std::chrono;
20 using namespace armnn::test;
21 
22 namespace armnn
23 {
24 namespace test
25 {
26 /// Parse the command line of an ArmNN inference test program.
27 /// \return false if any error occurred during options processing, otherwise true
ParseCommandLine(int argc,char ** argv,IInferenceTestCaseProvider & testCaseProvider,InferenceTestOptions & outParams)28 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
29     InferenceTestOptions& outParams)
30 {
31     cxxopts::Options options("InferenceTest", "Inference iteration parameters");
32 
33     try
34     {
35         // Adds generic options needed for all inference tests.
36         options
37             .allow_unrecognised_options()
38             .add_options()
39                 ("h,help", "Display help messages")
40                 ("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.",
41                  cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0"))
42                 ("inference-times-file",
43                  "If non-empty, each individual inference time will be recorded and output to this file",
44                  cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value(""))
45                 ("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.",
46                  cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0"));
47 
48         std::vector<std::string> required; //to be passed as reference to derived inference tests
49 
50         // Adds options specific to the ITestCaseProvider.
51         testCaseProvider.AddCommandLineOptions(options, required);
52 
53         auto result = options.parse(argc, argv);
54 
55         if (result.count("help"))
56         {
57             std::cout << options.help() << std::endl;
58             return false;
59         }
60 
61         CheckRequiredOptions(result, required);
62 
63     }
64     catch (const cxxopts::OptionException& e)
65     {
66         std::cerr << e.what() << std::endl << options.help() << std::endl;
67         return false;
68     }
69     catch (const std::exception& e)
70     {
71         ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
72         std::cerr << "Fatal internal error: " << e.what() << std::endl;
73         return false;
74     }
75 
76     if (!testCaseProvider.ProcessCommandLineOptions(outParams))
77     {
78         return false;
79     }
80 
81     return true;
82 }
83 
/// Checks that \p dir names an existing directory.
///
/// A non-empty \p dir that does not already end in '/' has one appended in place before
/// the filesystem is consulted, so the caller's string is modified even when validation
/// subsequently fails. A diagnostic is written to std::cerr for every failure.
///
/// \param dir Directory path to validate; may be extended with a trailing '/'.
/// \return true if the (normalised) path exists and is a directory, otherwise false
bool ValidateDirectory(std::string& dir)
{
    if (dir.empty())
    {
        std::cerr << "No directory specified" << std::endl;
        return false;
    }

    // Make sure the path carries a trailing separator.
    if (dir.back() != '/')
    {
        dir.push_back('/');
    }

    // Only query the entry type once the path is known to exist.
    const bool pathExists = fs::exists(dir);
    if (pathExists && fs::is_directory(dir))
    {
        return true;
    }

    if (pathExists)
    {
        std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
    }
    else
    {
        std::cerr << "Given directory " << dir << " does not exist" << std::endl;
    }
    return false;
}
111 
InferenceTest(const InferenceTestOptions & params,const std::vector<unsigned int> & defaultTestCaseIds,IInferenceTestCaseProvider & testCaseProvider)112 bool InferenceTest(const InferenceTestOptions& params,
113     const std::vector<unsigned int>& defaultTestCaseIds,
114     IInferenceTestCaseProvider& testCaseProvider)
115 {
116 #if !defined (NDEBUG)
117     if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
118     {
119         ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
120     }
121 #endif
122 
123     double totalTime = 0;
124     unsigned int nbProcessed = 0;
125     bool success = true;
126 
127     // Opens the file to write inference times too, if needed.
128     ofstream inferenceTimesFile;
129     const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
130     if (recordInferenceTimes)
131     {
132         inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
133         if (!inferenceTimesFile.good())
134         {
135             ARMNN_LOG(error) << "Failed to open inference times file for writing: "
136                 << params.m_InferenceTimesFile;
137             return false;
138         }
139     }
140 
141     // Create a profiler and register it for the current thread.
142     std::unique_ptr<IProfiler> profiler = std::make_unique<IProfiler>();
143     ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
144 
145     // Enable profiling if requested.
146     profiler->EnableProfiling(params.m_EnableProfiling);
147 
148     // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
149     std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
150     if (warmupTestCase == nullptr)
151     {
152         ARMNN_LOG(error) << "Failed to load test case";
153         return false;
154     }
155 
156     try
157     {
158         warmupTestCase->Run();
159     }
160     catch (const TestFrameworkException& testError)
161     {
162         ARMNN_LOG(error) << testError.what();
163         return false;
164     }
165 
166     const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
167         : static_cast<unsigned int>(defaultTestCaseIds.size());
168 
169     for (; nbProcessed < nbTotalToProcess; nbProcessed++)
170     {
171         const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
172         std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
173 
174         if (testCase == nullptr)
175         {
176             ARMNN_LOG(error) << "Failed to load test case";
177             return false;
178         }
179 
180         time_point<high_resolution_clock> predictStart;
181         time_point<high_resolution_clock> predictEnd;
182 
183         TestCaseResult result = TestCaseResult::Ok;
184 
185         try
186         {
187             predictStart = high_resolution_clock::now();
188 
189             testCase->Run();
190 
191             predictEnd = high_resolution_clock::now();
192 
193             // duration<double> will convert the time difference into seconds as a double by default.
194             double timeTakenS = duration<double>(predictEnd - predictStart).count();
195             totalTime += timeTakenS;
196 
197             // Outputss inference times, if needed.
198             if (recordInferenceTimes)
199             {
200                 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
201             }
202 
203             result = testCase->ProcessResult(params);
204 
205         }
206         catch (const TestFrameworkException& testError)
207         {
208             ARMNN_LOG(error) << testError.what();
209             result = TestCaseResult::Abort;
210         }
211 
212         switch (result)
213         {
214         case TestCaseResult::Ok:
215             break;
216         case TestCaseResult::Abort:
217             return false;
218         case TestCaseResult::Failed:
219             // This test failed so we will fail the entire program eventually, but keep going for now.
220             success = false;
221             break;
222         default:
223             ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
224             return false;
225         }
226     }
227 
228     const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
229 
230     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
231         "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
232     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
233         "Average time per test case: " << averageTimePerTestCaseMs << " ms";
234 
235     // if profiling is enabled print out the results
236     if (profiler && profiler->IsProfilingEnabled())
237     {
238         profiler->Print(std::cout);
239     }
240 
241     if (!success)
242     {
243         ARMNN_LOG(error) << "One or more test cases failed";
244         return false;
245     }
246 
247     return testCaseProvider.OnInferenceTestFinished();
248 }
249 
250 } // namespace test
251 
252 } // namespace armnn
253