xref: /aosp_15_r20/external/armnn/tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "../ImageTensorGenerator/ImageTensorGenerator.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "../InferenceTest.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "ModelAccuracyChecker.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "armnnDeserializer/IDeserializer.hpp"
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TContainer.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <map>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::test;
18*89c4ff92SAndroid Build Coastguard Worker 
/**
 * Collect the validation images to be evaluated, each paired with its ground-truth label.
 *
 * Image file names come from \p imageDirectoryPath; labels come from \p validationLabelPath.
 *
 * @pre \p validationLabelPath exists and is a valid regular file.
 * @pre \p imageDirectoryPath exists and is a valid directory.
 * @pre Labels in the validation file are listed in the same order as the image names
 *      sorted lexicographically.
 * @pre Image indices start at 1.
 * @pre The range [\p begIndex, \p endIndex] is inclusive at both ends.
 *
 * @param[in] validationLabelPath Path to the ground-truth validation label file.
 * @param[in] imageDirectoryPath  Path to the directory holding the validation images.
 * @param[in] begIndex            Index of the first image to load (inclusive).
 * @param[in] endIndex            Index of the last image to load (inclusive).
 * @param[in] excludelistPath     Path to the excludelist file; may be empty, as it is by default.
 * @return Map from each image file name to its ground-truth label.
 */
auto LoadValidationImageFilenamesAndLabels(const std::string& validationLabelPath,
                                           const std::string& imageDirectoryPath,
                                           std::size_t begIndex = 0,
                                           std::size_t endIndex = 0,
                                           const std::string& excludelistPath = "")
    -> std::map<std::string, std::string>;
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker /** Load model output labels from file
41*89c4ff92SAndroid Build Coastguard Worker  *
42*89c4ff92SAndroid Build Coastguard Worker  * @pre \p modelOutputLabelsPath exists and is a regular file
43*89c4ff92SAndroid Build Coastguard Worker  *
44*89c4ff92SAndroid Build Coastguard Worker  * @param[in] modelOutputLabelsPath path to model output labels file
45*89c4ff92SAndroid Build Coastguard Worker  * @return A vector of labels, which in turn is described by a list of category names
46*89c4ff92SAndroid Build Coastguard Worker  */
47*89c4ff92SAndroid Build Coastguard Worker std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath);
48*89c4ff92SAndroid Build Coastguard Worker 
main(int argc,char * argv[])49*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker     try
52*89c4ff92SAndroid Build Coastguard Worker     {
53*89c4ff92SAndroid Build Coastguard Worker         armnn::LogSeverity level = armnn::LogSeverity::Debug;
54*89c4ff92SAndroid Build Coastguard Worker         armnn::ConfigureLogging(true, true, level);
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker         std::string modelPath;
57*89c4ff92SAndroid Build Coastguard Worker         std::string modelFormat;
58*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::string> inputNames;
59*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::string> outputNames;
60*89c4ff92SAndroid Build Coastguard Worker         std::string dataDir;
61*89c4ff92SAndroid Build Coastguard Worker         std::string modelOutputLabelsPath;
62*89c4ff92SAndroid Build Coastguard Worker         std::string validationLabelPath;
63*89c4ff92SAndroid Build Coastguard Worker         std::string inputLayout;
64*89c4ff92SAndroid Build Coastguard Worker         std::vector<armnn::BackendId> computeDevice;
65*89c4ff92SAndroid Build Coastguard Worker         std::string validationRange;
66*89c4ff92SAndroid Build Coastguard Worker         std::string excludelistPath;
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker         const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
69*89c4ff92SAndroid Build Coastguard Worker                                             + armnn::BackendRegistryInstance().GetBackendIdsAsString();
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker         try
72*89c4ff92SAndroid Build Coastguard Worker         {
73*89c4ff92SAndroid Build Coastguard Worker             cxxopts::Options options("ModeAccuracyTool-Armnn","Options");
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker             options.add_options()
76*89c4ff92SAndroid Build Coastguard Worker                 ("h,help", "Display help messages")
77*89c4ff92SAndroid Build Coastguard Worker                 ("m,model-path",
78*89c4ff92SAndroid Build Coastguard Worker                     "Path to armnn format model file",
79*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::string>(modelPath))
80*89c4ff92SAndroid Build Coastguard Worker                 ("f,model-format",
81*89c4ff92SAndroid Build Coastguard Worker                     "The model format. Supported values: tflite",
82*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::string>(modelFormat))
83*89c4ff92SAndroid Build Coastguard Worker                 ("i,input-name",
84*89c4ff92SAndroid Build Coastguard Worker                     "Identifier of the input tensors in the network separated by comma with no space.",
85*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::vector<std::string>>(inputNames))
86*89c4ff92SAndroid Build Coastguard Worker                 ("o,output-name",
87*89c4ff92SAndroid Build Coastguard Worker                     "Identifier of the output tensors in the network separated by comma with no space.",
88*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::vector<std::string>>(outputNames))
89*89c4ff92SAndroid Build Coastguard Worker                 ("d,data-dir",
90*89c4ff92SAndroid Build Coastguard Worker                     "Path to directory containing the ImageNet test data",
91*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::string>(dataDir))
92*89c4ff92SAndroid Build Coastguard Worker                 ("p,model-output-labels",
93*89c4ff92SAndroid Build Coastguard Worker                     "Path to model output labels file.",
94*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::string>(modelOutputLabelsPath))
95*89c4ff92SAndroid Build Coastguard Worker                 ("v,validation-labels-path",
96*89c4ff92SAndroid Build Coastguard Worker                     "Path to ImageNet Validation Label file",
97*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::string>(validationLabelPath))
98*89c4ff92SAndroid Build Coastguard Worker                 ("l,data-layout",
99*89c4ff92SAndroid Build Coastguard Worker                     "Data layout. Supported value: NHWC, NCHW. Default: NHWC",
100*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::string>(inputLayout)->default_value("NHWC"))
101*89c4ff92SAndroid Build Coastguard Worker                 ("c,compute",
102*89c4ff92SAndroid Build Coastguard Worker                     backendsMessage.c_str(),
103*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::vector<armnn::BackendId>>(computeDevice)->default_value("CpuAcc,CpuRef"))
104*89c4ff92SAndroid Build Coastguard Worker                 ("r,validation-range",
105*89c4ff92SAndroid Build Coastguard Worker                     "The range of the images to be evaluated. Specified in the form <begin index>:<end index>."
106*89c4ff92SAndroid Build Coastguard Worker                     "The index starts at 1 and the range is inclusive."
107*89c4ff92SAndroid Build Coastguard Worker                     "By default the evaluation will be performed on all images.",
108*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::string>(validationRange)->default_value("1:0"))
109*89c4ff92SAndroid Build Coastguard Worker                 ("e,excludelist-path",
110*89c4ff92SAndroid Build Coastguard Worker                     "Path to a excludelist file where each line denotes the index of an image to be "
111*89c4ff92SAndroid Build Coastguard Worker                     "excluded from evaluation.",
112*89c4ff92SAndroid Build Coastguard Worker                     cxxopts::value<std::string>(excludelistPath)->default_value(""));
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker             auto result = options.parse(argc, argv);
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker             if (result.count("help") > 0)
117*89c4ff92SAndroid Build Coastguard Worker             {
118*89c4ff92SAndroid Build Coastguard Worker                 std::cout << options.help() << std::endl;
119*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
120*89c4ff92SAndroid Build Coastguard Worker             }
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker             // Check for mandatory single options.
123*89c4ff92SAndroid Build Coastguard Worker             std::string mandatorySingleParameters[] = { "model-path", "model-format", "input-name", "output-name",
124*89c4ff92SAndroid Build Coastguard Worker                                                         "data-dir", "model-output-labels", "validation-labels-path" };
125*89c4ff92SAndroid Build Coastguard Worker             for (auto param : mandatorySingleParameters)
126*89c4ff92SAndroid Build Coastguard Worker             {
127*89c4ff92SAndroid Build Coastguard Worker                 if (result.count(param) != 1)
128*89c4ff92SAndroid Build Coastguard Worker                 {
129*89c4ff92SAndroid Build Coastguard Worker                     std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
130*89c4ff92SAndroid Build Coastguard Worker                     return EXIT_FAILURE;
131*89c4ff92SAndroid Build Coastguard Worker                 }
132*89c4ff92SAndroid Build Coastguard Worker             }
133*89c4ff92SAndroid Build Coastguard Worker         }
134*89c4ff92SAndroid Build Coastguard Worker         catch (const cxxopts::OptionException& e)
135*89c4ff92SAndroid Build Coastguard Worker         {
136*89c4ff92SAndroid Build Coastguard Worker             std::cerr << e.what() << std::endl << std::endl;
137*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
138*89c4ff92SAndroid Build Coastguard Worker         }
139*89c4ff92SAndroid Build Coastguard Worker         catch (const std::exception& e)
140*89c4ff92SAndroid Build Coastguard Worker         {
141*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
142*89c4ff92SAndroid Build Coastguard Worker             std::cerr << "Fatal internal error: " << e.what() << std::endl;
143*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
144*89c4ff92SAndroid Build Coastguard Worker         }
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker         // Check if the requested backend are all valid
147*89c4ff92SAndroid Build Coastguard Worker         std::string invalidBackends;
148*89c4ff92SAndroid Build Coastguard Worker         if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
149*89c4ff92SAndroid Build Coastguard Worker         {
150*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
151*89c4ff92SAndroid Build Coastguard Worker                              << invalidBackends;
152*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
153*89c4ff92SAndroid Build Coastguard Worker         }
154*89c4ff92SAndroid Build Coastguard Worker         armnn::Status status;
155*89c4ff92SAndroid Build Coastguard Worker 
156*89c4ff92SAndroid Build Coastguard Worker         // Create runtime
157*89c4ff92SAndroid Build Coastguard Worker         armnn::IRuntime::CreationOptions options;
158*89c4ff92SAndroid Build Coastguard Worker         armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
159*89c4ff92SAndroid Build Coastguard Worker         std::ifstream file(modelPath);
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker         // Create Parser
162*89c4ff92SAndroid Build Coastguard Worker         using IParser = armnnDeserializer::IDeserializer;
163*89c4ff92SAndroid Build Coastguard Worker         auto armnnparser(IParser::Create());
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker         // Create a network
166*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network = armnnparser->CreateNetworkFromBinary(file);
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker         // Optimizes the network.
169*89c4ff92SAndroid Build Coastguard Worker         armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
170*89c4ff92SAndroid Build Coastguard Worker         try
171*89c4ff92SAndroid Build Coastguard Worker         {
172*89c4ff92SAndroid Build Coastguard Worker             optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
173*89c4ff92SAndroid Build Coastguard Worker         }
174*89c4ff92SAndroid Build Coastguard Worker         catch (const armnn::Exception& e)
175*89c4ff92SAndroid Build Coastguard Worker         {
176*89c4ff92SAndroid Build Coastguard Worker             std::stringstream message;
177*89c4ff92SAndroid Build Coastguard Worker             message << "armnn::Exception (" << e.what() << ") caught from optimize.";
178*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << message.str();
179*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
180*89c4ff92SAndroid Build Coastguard Worker         }
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker         // Loads the network into the runtime.
183*89c4ff92SAndroid Build Coastguard Worker         armnn::NetworkId networkId;
184*89c4ff92SAndroid Build Coastguard Worker         status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
185*89c4ff92SAndroid Build Coastguard Worker         if (status == armnn::Status::Failure)
186*89c4ff92SAndroid Build Coastguard Worker         {
187*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
188*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
189*89c4ff92SAndroid Build Coastguard Worker         }
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker         // Set up Network
192*89c4ff92SAndroid Build Coastguard Worker         using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
193*89c4ff92SAndroid Build Coastguard Worker 
194*89c4ff92SAndroid Build Coastguard Worker         // Handle inputNames and outputNames, there can be multiple.
195*89c4ff92SAndroid Build Coastguard Worker         std::vector<BindingPointInfo> inputBindings;
196*89c4ff92SAndroid Build Coastguard Worker         for(auto& input: inputNames)
197*89c4ff92SAndroid Build Coastguard Worker         {
198*89c4ff92SAndroid Build Coastguard Worker             const armnnDeserializer::BindingPointInfo&
199*89c4ff92SAndroid Build Coastguard Worker                     inputBindingInfo = armnnparser->GetNetworkInputBindingInfo(0, input);
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker             std::pair<armnn::LayerBindingId, armnn::TensorInfo>
202*89c4ff92SAndroid Build Coastguard Worker                     m_InputBindingInfo(inputBindingInfo.m_BindingId, inputBindingInfo.m_TensorInfo);
203*89c4ff92SAndroid Build Coastguard Worker             inputBindings.push_back(m_InputBindingInfo);
204*89c4ff92SAndroid Build Coastguard Worker         }
205*89c4ff92SAndroid Build Coastguard Worker 
206*89c4ff92SAndroid Build Coastguard Worker         std::vector<BindingPointInfo> outputBindings;
207*89c4ff92SAndroid Build Coastguard Worker         for(auto& output: outputNames)
208*89c4ff92SAndroid Build Coastguard Worker         {
209*89c4ff92SAndroid Build Coastguard Worker             const armnnDeserializer::BindingPointInfo&
210*89c4ff92SAndroid Build Coastguard Worker                     outputBindingInfo = armnnparser->GetNetworkOutputBindingInfo(0, output);
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker             std::pair<armnn::LayerBindingId, armnn::TensorInfo>
213*89c4ff92SAndroid Build Coastguard Worker                     m_OutputBindingInfo(outputBindingInfo.m_BindingId, outputBindingInfo.m_TensorInfo);
214*89c4ff92SAndroid Build Coastguard Worker             outputBindings.push_back(m_OutputBindingInfo);
215*89c4ff92SAndroid Build Coastguard Worker         }
216*89c4ff92SAndroid Build Coastguard Worker 
217*89c4ff92SAndroid Build Coastguard Worker         // Load model output labels
218*89c4ff92SAndroid Build Coastguard Worker         if (modelOutputLabelsPath.empty() || !fs::exists(modelOutputLabelsPath) ||
219*89c4ff92SAndroid Build Coastguard Worker             !fs::is_regular_file(modelOutputLabelsPath))
220*89c4ff92SAndroid Build Coastguard Worker         {
221*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "Invalid model output labels path at " << modelOutputLabelsPath;
222*89c4ff92SAndroid Build Coastguard Worker         }
223*89c4ff92SAndroid Build Coastguard Worker         const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
224*89c4ff92SAndroid Build Coastguard Worker             LoadModelOutputLabels(modelOutputLabelsPath);
225*89c4ff92SAndroid Build Coastguard Worker 
226*89c4ff92SAndroid Build Coastguard Worker         // Parse begin and end image indices
227*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::string> imageIndexStrs = armnnUtils::SplitBy(validationRange, ":");
228*89c4ff92SAndroid Build Coastguard Worker         size_t imageBegIndex;
229*89c4ff92SAndroid Build Coastguard Worker         size_t imageEndIndex;
230*89c4ff92SAndroid Build Coastguard Worker         if (imageIndexStrs.size() != 2)
231*89c4ff92SAndroid Build Coastguard Worker         {
232*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "Invalid validation range specification: Invalid format " << validationRange;
233*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
234*89c4ff92SAndroid Build Coastguard Worker         }
235*89c4ff92SAndroid Build Coastguard Worker         try
236*89c4ff92SAndroid Build Coastguard Worker         {
237*89c4ff92SAndroid Build Coastguard Worker             imageBegIndex = std::stoul(imageIndexStrs[0]);
238*89c4ff92SAndroid Build Coastguard Worker             imageEndIndex = std::stoul(imageIndexStrs[1]);
239*89c4ff92SAndroid Build Coastguard Worker         }
240*89c4ff92SAndroid Build Coastguard Worker         catch (const std::exception& e)
241*89c4ff92SAndroid Build Coastguard Worker         {
242*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "Invalid validation range specification: " << validationRange;
243*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
244*89c4ff92SAndroid Build Coastguard Worker         }
245*89c4ff92SAndroid Build Coastguard Worker 
246*89c4ff92SAndroid Build Coastguard Worker         // Validate  excludelist file if it's specified
247*89c4ff92SAndroid Build Coastguard Worker         if (!excludelistPath.empty() &&
248*89c4ff92SAndroid Build Coastguard Worker             !(fs::exists(excludelistPath) && fs::is_regular_file(excludelistPath)))
249*89c4ff92SAndroid Build Coastguard Worker         {
250*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "Invalid path to excludelist file at " << excludelistPath;
251*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
252*89c4ff92SAndroid Build Coastguard Worker         }
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker         fs::path pathToDataDir(dataDir);
255*89c4ff92SAndroid Build Coastguard Worker         const map<std::string, std::string> imageNameToLabel = LoadValidationImageFilenamesAndLabels(
256*89c4ff92SAndroid Build Coastguard Worker             validationLabelPath, pathToDataDir.string(), imageBegIndex, imageEndIndex, excludelistPath);
257*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::ModelAccuracyChecker checker(imageNameToLabel, modelOutputLabels);
258*89c4ff92SAndroid Build Coastguard Worker 
259*89c4ff92SAndroid Build Coastguard Worker         if (ValidateDirectory(dataDir))
260*89c4ff92SAndroid Build Coastguard Worker         {
261*89c4ff92SAndroid Build Coastguard Worker             InferenceModel<armnnDeserializer::IDeserializer, float>::Params params;
262*89c4ff92SAndroid Build Coastguard Worker 
263*89c4ff92SAndroid Build Coastguard Worker             params.m_ModelPath      = modelPath;
264*89c4ff92SAndroid Build Coastguard Worker             params.m_IsModelBinary  = true;
265*89c4ff92SAndroid Build Coastguard Worker             params.m_ComputeDevices = computeDevice;
266*89c4ff92SAndroid Build Coastguard Worker             // Insert inputNames and outputNames into params vector
267*89c4ff92SAndroid Build Coastguard Worker             params.m_InputBindings.insert(std::end(params.m_InputBindings),
268*89c4ff92SAndroid Build Coastguard Worker                                           std::begin(inputNames),
269*89c4ff92SAndroid Build Coastguard Worker                                           std::end(inputNames));
270*89c4ff92SAndroid Build Coastguard Worker             params.m_OutputBindings.insert(std::end(params.m_OutputBindings),
271*89c4ff92SAndroid Build Coastguard Worker                                            std::begin(outputNames),
272*89c4ff92SAndroid Build Coastguard Worker                                            std::end(outputNames));
273*89c4ff92SAndroid Build Coastguard Worker 
274*89c4ff92SAndroid Build Coastguard Worker             using TParser = armnnDeserializer::IDeserializer;
275*89c4ff92SAndroid Build Coastguard Worker             // If dynamicBackends is empty it will be disabled by default.
276*89c4ff92SAndroid Build Coastguard Worker             InferenceModel<TParser, float> model(params, false, "");
277*89c4ff92SAndroid Build Coastguard Worker 
278*89c4ff92SAndroid Build Coastguard Worker             // Get input tensor information
279*89c4ff92SAndroid Build Coastguard Worker             const armnn::TensorInfo& inputTensorInfo   = model.GetInputBindingInfo().second;
280*89c4ff92SAndroid Build Coastguard Worker             const armnn::TensorShape& inputTensorShape = inputTensorInfo.GetShape();
281*89c4ff92SAndroid Build Coastguard Worker             const armnn::DataType& inputTensorDataType = inputTensorInfo.GetDataType();
282*89c4ff92SAndroid Build Coastguard Worker             armnn::DataLayout inputTensorDataLayout;
283*89c4ff92SAndroid Build Coastguard Worker             if (inputLayout == "NCHW")
284*89c4ff92SAndroid Build Coastguard Worker             {
285*89c4ff92SAndroid Build Coastguard Worker                 inputTensorDataLayout = armnn::DataLayout::NCHW;
286*89c4ff92SAndroid Build Coastguard Worker             }
287*89c4ff92SAndroid Build Coastguard Worker             else if (inputLayout == "NHWC")
288*89c4ff92SAndroid Build Coastguard Worker             {
289*89c4ff92SAndroid Build Coastguard Worker                 inputTensorDataLayout = armnn::DataLayout::NHWC;
290*89c4ff92SAndroid Build Coastguard Worker             }
291*89c4ff92SAndroid Build Coastguard Worker             else
292*89c4ff92SAndroid Build Coastguard Worker             {
293*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(fatal) << "Invalid Data layout: " << inputLayout;
294*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
295*89c4ff92SAndroid Build Coastguard Worker             }
296*89c4ff92SAndroid Build Coastguard Worker             const unsigned int inputTensorWidth =
297*89c4ff92SAndroid Build Coastguard Worker                 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[3] : inputTensorShape[2];
298*89c4ff92SAndroid Build Coastguard Worker             const unsigned int inputTensorHeight =
299*89c4ff92SAndroid Build Coastguard Worker                 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[2] : inputTensorShape[1];
300*89c4ff92SAndroid Build Coastguard Worker             // Get output tensor info
301*89c4ff92SAndroid Build Coastguard Worker             const unsigned int outputNumElements = model.GetOutputSize();
302*89c4ff92SAndroid Build Coastguard Worker             // Check output tensor shape is valid
303*89c4ff92SAndroid Build Coastguard Worker             if (modelOutputLabels.size() != outputNumElements)
304*89c4ff92SAndroid Build Coastguard Worker             {
305*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(fatal) << "Number of output elements: " << outputNumElements
306*89c4ff92SAndroid Build Coastguard Worker                                          << " , mismatches the number of output labels: " << modelOutputLabels.size();
307*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
308*89c4ff92SAndroid Build Coastguard Worker             }
309*89c4ff92SAndroid Build Coastguard Worker 
310*89c4ff92SAndroid Build Coastguard Worker             const unsigned int batchSize = 1;
311*89c4ff92SAndroid Build Coastguard Worker             // Get normalisation parameters
312*89c4ff92SAndroid Build Coastguard Worker             SupportedFrontend modelFrontend;
313*89c4ff92SAndroid Build Coastguard Worker             if (modelFormat == "tflite")
314*89c4ff92SAndroid Build Coastguard Worker             {
315*89c4ff92SAndroid Build Coastguard Worker                 modelFrontend = SupportedFrontend::TFLite;
316*89c4ff92SAndroid Build Coastguard Worker             }
317*89c4ff92SAndroid Build Coastguard Worker             else
318*89c4ff92SAndroid Build Coastguard Worker             {
319*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(fatal) << "Unsupported frontend: " << modelFormat;
320*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
321*89c4ff92SAndroid Build Coastguard Worker             }
322*89c4ff92SAndroid Build Coastguard Worker             const NormalizationParameters& normParams = GetNormalizationParameters(modelFrontend, inputTensorDataType);
323*89c4ff92SAndroid Build Coastguard Worker             for (const auto& imageEntry : imageNameToLabel)
324*89c4ff92SAndroid Build Coastguard Worker             {
325*89c4ff92SAndroid Build Coastguard Worker                 const std::string imageName = imageEntry.first;
326*89c4ff92SAndroid Build Coastguard Worker                 std::cout << "Processing image: " << imageName << "\n";
327*89c4ff92SAndroid Build Coastguard Worker 
328*89c4ff92SAndroid Build Coastguard Worker                 vector<armnnUtils::TContainer> inputDataContainers;
329*89c4ff92SAndroid Build Coastguard Worker                 vector<armnnUtils::TContainer> outputDataContainers;
330*89c4ff92SAndroid Build Coastguard Worker 
331*89c4ff92SAndroid Build Coastguard Worker                 auto imagePath = pathToDataDir / fs::path(imageName);
332*89c4ff92SAndroid Build Coastguard Worker                 switch (inputTensorDataType)
333*89c4ff92SAndroid Build Coastguard Worker                 {
334*89c4ff92SAndroid Build Coastguard Worker                     case armnn::DataType::Signed32:
335*89c4ff92SAndroid Build Coastguard Worker                         inputDataContainers.push_back(
336*89c4ff92SAndroid Build Coastguard Worker                             PrepareImageTensor<int>(imagePath.string(),
337*89c4ff92SAndroid Build Coastguard Worker                             inputTensorWidth, inputTensorHeight,
338*89c4ff92SAndroid Build Coastguard Worker                             normParams,
339*89c4ff92SAndroid Build Coastguard Worker                             batchSize,
340*89c4ff92SAndroid Build Coastguard Worker                             inputTensorDataLayout));
341*89c4ff92SAndroid Build Coastguard Worker                         outputDataContainers = { vector<int>(outputNumElements) };
342*89c4ff92SAndroid Build Coastguard Worker                         break;
343*89c4ff92SAndroid Build Coastguard Worker                     case armnn::DataType::QAsymmU8:
344*89c4ff92SAndroid Build Coastguard Worker                         inputDataContainers.push_back(
345*89c4ff92SAndroid Build Coastguard Worker                             PrepareImageTensor<uint8_t>(imagePath.string(),
346*89c4ff92SAndroid Build Coastguard Worker                             inputTensorWidth, inputTensorHeight,
347*89c4ff92SAndroid Build Coastguard Worker                             normParams,
348*89c4ff92SAndroid Build Coastguard Worker                             batchSize,
349*89c4ff92SAndroid Build Coastguard Worker                             inputTensorDataLayout));
350*89c4ff92SAndroid Build Coastguard Worker                         outputDataContainers = { vector<uint8_t>(outputNumElements) };
351*89c4ff92SAndroid Build Coastguard Worker                         break;
352*89c4ff92SAndroid Build Coastguard Worker                     case armnn::DataType::Float32:
353*89c4ff92SAndroid Build Coastguard Worker                     default:
354*89c4ff92SAndroid Build Coastguard Worker                         inputDataContainers.push_back(
355*89c4ff92SAndroid Build Coastguard Worker                             PrepareImageTensor<float>(imagePath.string(),
356*89c4ff92SAndroid Build Coastguard Worker                             inputTensorWidth, inputTensorHeight,
357*89c4ff92SAndroid Build Coastguard Worker                             normParams,
358*89c4ff92SAndroid Build Coastguard Worker                             batchSize,
359*89c4ff92SAndroid Build Coastguard Worker                             inputTensorDataLayout));
360*89c4ff92SAndroid Build Coastguard Worker                         outputDataContainers = { vector<float>(outputNumElements) };
361*89c4ff92SAndroid Build Coastguard Worker                         break;
362*89c4ff92SAndroid Build Coastguard Worker                 }
363*89c4ff92SAndroid Build Coastguard Worker 
364*89c4ff92SAndroid Build Coastguard Worker                 status = runtime->EnqueueWorkload(networkId,
365*89c4ff92SAndroid Build Coastguard Worker                                                   armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
366*89c4ff92SAndroid Build Coastguard Worker                                                   armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
367*89c4ff92SAndroid Build Coastguard Worker 
368*89c4ff92SAndroid Build Coastguard Worker                 if (status == armnn::Status::Failure)
369*89c4ff92SAndroid Build Coastguard Worker                 {
370*89c4ff92SAndroid Build Coastguard Worker                     ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload for image: " << imageName;
371*89c4ff92SAndroid Build Coastguard Worker                 }
372*89c4ff92SAndroid Build Coastguard Worker 
373*89c4ff92SAndroid Build Coastguard Worker                 checker.AddImageResult<armnnUtils::TContainer>(imageName, outputDataContainers);
374*89c4ff92SAndroid Build Coastguard Worker             }
375*89c4ff92SAndroid Build Coastguard Worker         }
376*89c4ff92SAndroid Build Coastguard Worker         else
377*89c4ff92SAndroid Build Coastguard Worker         {
378*89c4ff92SAndroid Build Coastguard Worker             return EXIT_SUCCESS;
379*89c4ff92SAndroid Build Coastguard Worker         }
380*89c4ff92SAndroid Build Coastguard Worker 
381*89c4ff92SAndroid Build Coastguard Worker         for(unsigned int i = 1; i <= 5; ++i)
382*89c4ff92SAndroid Build Coastguard Worker         {
383*89c4ff92SAndroid Build Coastguard Worker             std::cout << "Top " << i <<  " Accuracy: " << checker.GetAccuracy(i) << "%" << "\n";
384*89c4ff92SAndroid Build Coastguard Worker         }
385*89c4ff92SAndroid Build Coastguard Worker 
386*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(info) << "Accuracy Tool ran successfully!";
387*89c4ff92SAndroid Build Coastguard Worker         return EXIT_SUCCESS;
388*89c4ff92SAndroid Build Coastguard Worker     }
389*89c4ff92SAndroid Build Coastguard Worker     catch (const armnn::Exception& e)
390*89c4ff92SAndroid Build Coastguard Worker     {
391*89c4ff92SAndroid Build Coastguard Worker         // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
392*89c4ff92SAndroid Build Coastguard Worker         // exception of type std::length_error.
393*89c4ff92SAndroid Build Coastguard Worker         // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
394*89c4ff92SAndroid Build Coastguard Worker         std::cerr << "Armnn Error: " << e.what() << std::endl;
395*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
396*89c4ff92SAndroid Build Coastguard Worker     }
397*89c4ff92SAndroid Build Coastguard Worker     catch (const std::exception& e)
398*89c4ff92SAndroid Build Coastguard Worker     {
399*89c4ff92SAndroid Build Coastguard Worker         // Coverity fix: various boost exceptions can be thrown by methods called by this test.
400*89c4ff92SAndroid Build Coastguard Worker         std::cerr << "WARNING: ModelAccuracyTool-Armnn: An error has occurred when running the "
401*89c4ff92SAndroid Build Coastguard Worker                      "Accuracy Tool: " << e.what() << std::endl;
402*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
403*89c4ff92SAndroid Build Coastguard Worker     }
404*89c4ff92SAndroid Build Coastguard Worker }
405*89c4ff92SAndroid Build Coastguard Worker 
LoadValidationImageFilenamesAndLabels(const string & validationLabelPath,const string & imageDirectoryPath,size_t begIndex,size_t endIndex,const string & excludelistPath)406*89c4ff92SAndroid Build Coastguard Worker map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
407*89c4ff92SAndroid Build Coastguard Worker                                                                     const string& imageDirectoryPath,
408*89c4ff92SAndroid Build Coastguard Worker                                                                     size_t begIndex,
409*89c4ff92SAndroid Build Coastguard Worker                                                                     size_t endIndex,
410*89c4ff92SAndroid Build Coastguard Worker                                                                     const string& excludelistPath)
411*89c4ff92SAndroid Build Coastguard Worker {
412*89c4ff92SAndroid Build Coastguard Worker     // Populate imageFilenames with names of all .JPEG, .PNG images
413*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> imageFilenames;
414*89c4ff92SAndroid Build Coastguard Worker     for (const auto& imageEntry : fs::directory_iterator(fs::path(imageDirectoryPath)))
415*89c4ff92SAndroid Build Coastguard Worker     {
416*89c4ff92SAndroid Build Coastguard Worker         fs::path imagePath = imageEntry.path();
417*89c4ff92SAndroid Build Coastguard Worker 
418*89c4ff92SAndroid Build Coastguard Worker         // Get extension and convert to uppercase
419*89c4ff92SAndroid Build Coastguard Worker         std::string imageExtension = imagePath.extension().string();
420*89c4ff92SAndroid Build Coastguard Worker         std::transform(imageExtension.begin(), imageExtension.end(), imageExtension.begin(), ::toupper);
421*89c4ff92SAndroid Build Coastguard Worker 
422*89c4ff92SAndroid Build Coastguard Worker         if (fs::is_regular_file(imagePath) && (imageExtension == ".JPEG" || imageExtension == ".PNG"))
423*89c4ff92SAndroid Build Coastguard Worker         {
424*89c4ff92SAndroid Build Coastguard Worker             imageFilenames.push_back(imagePath.filename().string());
425*89c4ff92SAndroid Build Coastguard Worker         }
426*89c4ff92SAndroid Build Coastguard Worker     }
427*89c4ff92SAndroid Build Coastguard Worker     if (imageFilenames.empty())
428*89c4ff92SAndroid Build Coastguard Worker     {
429*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("No image file (JPEG, PNG) found at " + imageDirectoryPath);
430*89c4ff92SAndroid Build Coastguard Worker     }
431*89c4ff92SAndroid Build Coastguard Worker 
432*89c4ff92SAndroid Build Coastguard Worker     // Sort the image filenames lexicographically
433*89c4ff92SAndroid Build Coastguard Worker     std::sort(imageFilenames.begin(), imageFilenames.end());
434*89c4ff92SAndroid Build Coastguard Worker 
435*89c4ff92SAndroid Build Coastguard Worker     std::cout << imageFilenames.size() << " images found at " << imageDirectoryPath << std::endl;
436*89c4ff92SAndroid Build Coastguard Worker 
437*89c4ff92SAndroid Build Coastguard Worker     // Get default end index
438*89c4ff92SAndroid Build Coastguard Worker     if (begIndex < 1 || endIndex > imageFilenames.size())
439*89c4ff92SAndroid Build Coastguard Worker     {
440*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("Invalid image index range");
441*89c4ff92SAndroid Build Coastguard Worker     }
442*89c4ff92SAndroid Build Coastguard Worker     endIndex = endIndex == 0 ? imageFilenames.size() : endIndex;
443*89c4ff92SAndroid Build Coastguard Worker     if (begIndex > endIndex)
444*89c4ff92SAndroid Build Coastguard Worker     {
445*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("Invalid image index range");
446*89c4ff92SAndroid Build Coastguard Worker     }
447*89c4ff92SAndroid Build Coastguard Worker 
448*89c4ff92SAndroid Build Coastguard Worker     // Load excludelist if there is one
449*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> excludelist;
450*89c4ff92SAndroid Build Coastguard Worker     if (!excludelistPath.empty())
451*89c4ff92SAndroid Build Coastguard Worker     {
452*89c4ff92SAndroid Build Coastguard Worker         std::ifstream excludelistFile(excludelistPath);
453*89c4ff92SAndroid Build Coastguard Worker         unsigned int index;
454*89c4ff92SAndroid Build Coastguard Worker         while (excludelistFile >> index)
455*89c4ff92SAndroid Build Coastguard Worker         {
456*89c4ff92SAndroid Build Coastguard Worker             excludelist.push_back(index);
457*89c4ff92SAndroid Build Coastguard Worker         }
458*89c4ff92SAndroid Build Coastguard Worker     }
459*89c4ff92SAndroid Build Coastguard Worker 
460*89c4ff92SAndroid Build Coastguard Worker     // Load ground truth labels and pair them with corresponding image names
461*89c4ff92SAndroid Build Coastguard Worker     std::string classification;
462*89c4ff92SAndroid Build Coastguard Worker     map<std::string, std::string> imageNameToLabel;
463*89c4ff92SAndroid Build Coastguard Worker     ifstream infile(validationLabelPath);
464*89c4ff92SAndroid Build Coastguard Worker     size_t imageIndex          = begIndex;
465*89c4ff92SAndroid Build Coastguard Worker     size_t excludelistIndexCount = 0;
466*89c4ff92SAndroid Build Coastguard Worker     while (std::getline(infile, classification))
467*89c4ff92SAndroid Build Coastguard Worker     {
468*89c4ff92SAndroid Build Coastguard Worker         if (imageIndex > endIndex)
469*89c4ff92SAndroid Build Coastguard Worker         {
470*89c4ff92SAndroid Build Coastguard Worker             break;
471*89c4ff92SAndroid Build Coastguard Worker         }
472*89c4ff92SAndroid Build Coastguard Worker         // If current imageIndex is included in excludelist, skip the current image
473*89c4ff92SAndroid Build Coastguard Worker         if (excludelistIndexCount < excludelist.size() && imageIndex == excludelist[excludelistIndexCount])
474*89c4ff92SAndroid Build Coastguard Worker         {
475*89c4ff92SAndroid Build Coastguard Worker             ++imageIndex;
476*89c4ff92SAndroid Build Coastguard Worker             ++excludelistIndexCount;
477*89c4ff92SAndroid Build Coastguard Worker             continue;
478*89c4ff92SAndroid Build Coastguard Worker         }
479*89c4ff92SAndroid Build Coastguard Worker         imageNameToLabel.insert(std::pair<std::string, std::string>(imageFilenames[imageIndex - 1], classification));
480*89c4ff92SAndroid Build Coastguard Worker         ++imageIndex;
481*89c4ff92SAndroid Build Coastguard Worker     }
482*89c4ff92SAndroid Build Coastguard Worker     std::cout << excludelistIndexCount << " images in excludelist" << std::endl;
483*89c4ff92SAndroid Build Coastguard Worker     std::cout << imageIndex - begIndex - excludelistIndexCount << " images to be loaded" << std::endl;
484*89c4ff92SAndroid Build Coastguard Worker     return imageNameToLabel;
485*89c4ff92SAndroid Build Coastguard Worker }
486*89c4ff92SAndroid Build Coastguard Worker 
LoadModelOutputLabels(const std::string & modelOutputLabelsPath)487*89c4ff92SAndroid Build Coastguard Worker std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath)
488*89c4ff92SAndroid Build Coastguard Worker {
489*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels;
490*89c4ff92SAndroid Build Coastguard Worker     ifstream modelOutputLablesFile(modelOutputLabelsPath);
491*89c4ff92SAndroid Build Coastguard Worker     std::string line;
492*89c4ff92SAndroid Build Coastguard Worker     while (std::getline(modelOutputLablesFile, line))
493*89c4ff92SAndroid Build Coastguard Worker     {
494*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::LabelCategoryNames tokens                  = armnnUtils::SplitBy(line, ":");
495*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::LabelCategoryNames predictionCategoryNames = armnnUtils::SplitBy(tokens.back(), ",");
496*89c4ff92SAndroid Build Coastguard Worker         std::transform(predictionCategoryNames.begin(), predictionCategoryNames.end(), predictionCategoryNames.begin(),
497*89c4ff92SAndroid Build Coastguard Worker                        [](const std::string& category) { return armnnUtils::Strip(category); });
498*89c4ff92SAndroid Build Coastguard Worker         modelOutputLabels.push_back(predictionCategoryNames);
499*89c4ff92SAndroid Build Coastguard Worker     }
500*89c4ff92SAndroid Build Coastguard Worker     return modelOutputLabels;
501*89c4ff92SAndroid Build Coastguard Worker }