xref: /aosp_15_r20/external/armnn/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "armnn/ArmNN.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "armnn/Utils.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "armnn/INetwork.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "armnnTfLiteParser/TfLiteParser.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "../Cifar10Database.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include "../InferenceTest.hpp"
12*89c4ff92SAndroid Build Coastguard Worker #include "../InferenceModel.hpp"
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
17*89c4ff92SAndroid Build Coastguard Worker #include <chrono>
18*89c4ff92SAndroid Build Coastguard Worker #include <vector>
19*89c4ff92SAndroid Build Coastguard Worker #include <array>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker using namespace std;
23*89c4ff92SAndroid Build Coastguard Worker using namespace std::chrono;
24*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::test;
25*89c4ff92SAndroid Build Coastguard Worker 
main(int argc,char * argv[])26*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker #ifdef NDEBUG
29*89c4ff92SAndroid Build Coastguard Worker     armnn::LogSeverity level = armnn::LogSeverity::Info;
30*89c4ff92SAndroid Build Coastguard Worker #else
31*89c4ff92SAndroid Build Coastguard Worker     armnn::LogSeverity level = armnn::LogSeverity::Debug;
32*89c4ff92SAndroid Build Coastguard Worker #endif
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker     try
35*89c4ff92SAndroid Build Coastguard Worker     {
36*89c4ff92SAndroid Build Coastguard Worker         // Configures logging for both the ARMNN library and this test program.
37*89c4ff92SAndroid Build Coastguard Worker         armnn::ConfigureLogging(true, true, level);
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker         std::vector<armnn::BackendId> computeDevice;
40*89c4ff92SAndroid Build Coastguard Worker         std::string modelDir;
41*89c4ff92SAndroid Build Coastguard Worker         std::string dataDir;
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker         const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
44*89c4ff92SAndroid Build Coastguard Worker                                           + armnn::BackendRegistryInstance().GetBackendIdsAsString();
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker         cxxopts::Options in_options("MultipleNetworksCifar10",
47*89c4ff92SAndroid Build Coastguard Worker                                     "Run multiple networks inference tests using Cifar-10 data.");
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker         try
50*89c4ff92SAndroid Build Coastguard Worker         {
51*89c4ff92SAndroid Build Coastguard Worker             // Adds generic options needed for all inference tests.
52*89c4ff92SAndroid Build Coastguard Worker             in_options.add_options()
53*89c4ff92SAndroid Build Coastguard Worker                 ("h,help", "Display help messages")
54*89c4ff92SAndroid Build Coastguard Worker                 ("m,model-dir", "Path to directory containing the Cifar10 model file",
55*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>(modelDir))
56*89c4ff92SAndroid Build Coastguard Worker                 ("c,compute", backendsMessage.c_str(),
57*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::vector<armnn::BackendId>>(computeDevice)->default_value("CpuAcc,CpuRef"))
58*89c4ff92SAndroid Build Coastguard Worker                 ("d,data-dir", "Path to directory containing the Cifar10 test data",
59*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>(dataDir));
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker             auto result = in_options.parse(argc, argv);
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker             if(result.count("help") > 0)
64*89c4ff92SAndroid Build Coastguard Worker             {
65*89c4ff92SAndroid Build Coastguard Worker                 std::cout << in_options.help() << std::endl;
66*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
67*89c4ff92SAndroid Build Coastguard Worker             }
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker             //ensure mandatory parameters given
70*89c4ff92SAndroid Build Coastguard Worker             std::string mandatorySingleParameters[] = {"model-dir", "data-dir"};
71*89c4ff92SAndroid Build Coastguard Worker             for (auto param : mandatorySingleParameters)
72*89c4ff92SAndroid Build Coastguard Worker             {
73*89c4ff92SAndroid Build Coastguard Worker                 if(result.count(param) > 0)
74*89c4ff92SAndroid Build Coastguard Worker                 {
75*89c4ff92SAndroid Build Coastguard Worker                     std::string dir = result[param].as<std::string>();
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker                     if(!ValidateDirectory(dir)) {
78*89c4ff92SAndroid Build Coastguard Worker                         return EXIT_FAILURE;
79*89c4ff92SAndroid Build Coastguard Worker                     }
80*89c4ff92SAndroid Build Coastguard Worker                 } else {
81*89c4ff92SAndroid Build Coastguard Worker                     std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
82*89c4ff92SAndroid Build Coastguard Worker                     return EXIT_FAILURE;
83*89c4ff92SAndroid Build Coastguard Worker                 }
84*89c4ff92SAndroid Build Coastguard Worker             }
85*89c4ff92SAndroid Build Coastguard Worker         }
86*89c4ff92SAndroid Build Coastguard Worker         catch (const cxxopts::OptionException& e)
87*89c4ff92SAndroid Build Coastguard Worker         {
88*89c4ff92SAndroid Build Coastguard Worker             std::cerr << e.what() << std::endl << in_options.help() << std::endl;
89*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
90*89c4ff92SAndroid Build Coastguard Worker         }
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker         fs::path modelPath = fs::path(modelDir + "/cifar10_tf.prototxt");
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker         // Create runtime
95*89c4ff92SAndroid Build Coastguard Worker         // This will also load dynamic backend in case that the dynamic backend path is specified
96*89c4ff92SAndroid Build Coastguard Worker         armnn::IRuntime::CreationOptions options;
97*89c4ff92SAndroid Build Coastguard Worker         armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker         // Check if the requested backend are all valid
100*89c4ff92SAndroid Build Coastguard Worker         std::string invalidBackends;
101*89c4ff92SAndroid Build Coastguard Worker         if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
102*89c4ff92SAndroid Build Coastguard Worker         {
103*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
104*89c4ff92SAndroid Build Coastguard Worker                              << invalidBackends;
105*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
106*89c4ff92SAndroid Build Coastguard Worker         }
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker         // Loads networks.
109*89c4ff92SAndroid Build Coastguard Worker         armnn::Status status;
110*89c4ff92SAndroid Build Coastguard Worker         struct Net
111*89c4ff92SAndroid Build Coastguard Worker         {
112*89c4ff92SAndroid Build Coastguard Worker             Net(armnn::NetworkId netId,
113*89c4ff92SAndroid Build Coastguard Worker                 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& in,
114*89c4ff92SAndroid Build Coastguard Worker                 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& out)
115*89c4ff92SAndroid Build Coastguard Worker             : m_Network(netId)
116*89c4ff92SAndroid Build Coastguard Worker             , m_InputBindingInfo(in)
117*89c4ff92SAndroid Build Coastguard Worker             , m_OutputBindingInfo(out)
118*89c4ff92SAndroid Build Coastguard Worker             {}
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker             armnn::NetworkId m_Network;
121*89c4ff92SAndroid Build Coastguard Worker             std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
122*89c4ff92SAndroid Build Coastguard Worker             std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
123*89c4ff92SAndroid Build Coastguard Worker         };
124*89c4ff92SAndroid Build Coastguard Worker         std::vector<Net> networks;
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker         armnnTfLiteParser::ITfLiteParserPtr parser(armnnTfLiteParser::ITfLiteParserPtr::Create());
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker         const int networksCount = 4;
129*89c4ff92SAndroid Build Coastguard Worker         for (int i = 0; i < networksCount; ++i)
130*89c4ff92SAndroid Build Coastguard Worker         {
131*89c4ff92SAndroid Build Coastguard Worker             // Creates a network from a file on the disk.
132*89c4ff92SAndroid Build Coastguard Worker             armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker             // Optimizes the network.
135*89c4ff92SAndroid Build Coastguard Worker             armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
136*89c4ff92SAndroid Build Coastguard Worker             try
137*89c4ff92SAndroid Build Coastguard Worker             {
138*89c4ff92SAndroid Build Coastguard Worker                 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
139*89c4ff92SAndroid Build Coastguard Worker             }
140*89c4ff92SAndroid Build Coastguard Worker             catch (const armnn::Exception& e)
141*89c4ff92SAndroid Build Coastguard Worker             {
142*89c4ff92SAndroid Build Coastguard Worker                 std::stringstream message;
143*89c4ff92SAndroid Build Coastguard Worker                 message << "armnn::Exception ("<<e.what()<<") caught from optimize.";
144*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(fatal) << message.str();
145*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
146*89c4ff92SAndroid Build Coastguard Worker             }
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker             // Loads the network into the runtime.
149*89c4ff92SAndroid Build Coastguard Worker             armnn::NetworkId networkId;
150*89c4ff92SAndroid Build Coastguard Worker             status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
151*89c4ff92SAndroid Build Coastguard Worker             if (status == armnn::Status::Failure)
152*89c4ff92SAndroid Build Coastguard Worker             {
153*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
154*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
155*89c4ff92SAndroid Build Coastguard Worker             }
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker             networks.emplace_back(networkId,
158*89c4ff92SAndroid Build Coastguard Worker                 parser->GetNetworkInputBindingInfo("data"),
159*89c4ff92SAndroid Build Coastguard Worker                 parser->GetNetworkOutputBindingInfo("prob"));
160*89c4ff92SAndroid Build Coastguard Worker         }
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker         // Loads a test case and tests inference.
163*89c4ff92SAndroid Build Coastguard Worker         if (!ValidateDirectory(dataDir))
164*89c4ff92SAndroid Build Coastguard Worker         {
165*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
166*89c4ff92SAndroid Build Coastguard Worker         }
167*89c4ff92SAndroid Build Coastguard Worker         Cifar10Database cifar10(dataDir);
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < 3; ++i)
170*89c4ff92SAndroid Build Coastguard Worker         {
171*89c4ff92SAndroid Build Coastguard Worker             // Loads test case data (including image data).
172*89c4ff92SAndroid Build Coastguard Worker             std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker             // Tests inference.
175*89c4ff92SAndroid Build Coastguard Worker             std::vector<TContainer> outputs;
176*89c4ff92SAndroid Build Coastguard Worker             outputs.reserve(networksCount);
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int j = 0; j < networksCount; ++j)
179*89c4ff92SAndroid Build Coastguard Worker             {
180*89c4ff92SAndroid Build Coastguard Worker                 outputs.push_back(std::vector<float>(10));
181*89c4ff92SAndroid Build Coastguard Worker             }
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int k = 0; k < networksCount; ++k)
184*89c4ff92SAndroid Build Coastguard Worker             {
185*89c4ff92SAndroid Build Coastguard Worker                 std::vector<armnn::BindingPointInfo> inputBindings  = { networks[k].m_InputBindingInfo  };
186*89c4ff92SAndroid Build Coastguard Worker                 std::vector<armnn::BindingPointInfo> outputBindings = { networks[k].m_OutputBindingInfo };
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker                 std::vector<TContainer> inputDataContainers = { testCaseData->m_InputImage };
189*89c4ff92SAndroid Build Coastguard Worker                 std::vector<TContainer> outputDataContainers = { outputs[k] };
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker                 status = runtime->EnqueueWorkload(networks[k].m_Network,
192*89c4ff92SAndroid Build Coastguard Worker                     armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
193*89c4ff92SAndroid Build Coastguard Worker                     armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
194*89c4ff92SAndroid Build Coastguard Worker                 if (status == armnn::Status::Failure)
195*89c4ff92SAndroid Build Coastguard Worker                 {
196*89c4ff92SAndroid Build Coastguard Worker                     ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload";
197*89c4ff92SAndroid Build Coastguard Worker                     return EXIT_FAILURE;
198*89c4ff92SAndroid Build Coastguard Worker                 }
199*89c4ff92SAndroid Build Coastguard Worker             }
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker             // Compares outputs.
202*89c4ff92SAndroid Build Coastguard Worker             std::vector<float> output0 = mapbox::util::get<std::vector<float>>(outputs[0]);
203*89c4ff92SAndroid Build Coastguard Worker 
204*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int k = 1; k < networksCount; ++k)
205*89c4ff92SAndroid Build Coastguard Worker             {
206*89c4ff92SAndroid Build Coastguard Worker                 std::vector<float> outputK = mapbox::util::get<std::vector<float>>(outputs[k]);
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker                 if (!std::equal(output0.begin(), output0.end(), outputK.begin(), outputK.end()))
209*89c4ff92SAndroid Build Coastguard Worker                 {
210*89c4ff92SAndroid Build Coastguard Worker                     ARMNN_LOG(error) << "Multiple networks inference failed!";
211*89c4ff92SAndroid Build Coastguard Worker                     return EXIT_FAILURE;
212*89c4ff92SAndroid Build Coastguard Worker                 }
213*89c4ff92SAndroid Build Coastguard Worker             }
214*89c4ff92SAndroid Build Coastguard Worker         }
215*89c4ff92SAndroid Build Coastguard Worker 
216*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(info) << "Multiple networks inference ran successfully!";
217*89c4ff92SAndroid Build Coastguard Worker         return EXIT_SUCCESS;
218*89c4ff92SAndroid Build Coastguard Worker     }
219*89c4ff92SAndroid Build Coastguard Worker     catch (const armnn::Exception& e)
220*89c4ff92SAndroid Build Coastguard Worker     {
221*89c4ff92SAndroid Build Coastguard Worker         // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
222*89c4ff92SAndroid Build Coastguard Worker         // exception of type std::length_error.
223*89c4ff92SAndroid Build Coastguard Worker         // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
224*89c4ff92SAndroid Build Coastguard Worker         std::cerr << "Armnn Error: " << e.what() << std::endl;
225*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
226*89c4ff92SAndroid Build Coastguard Worker     }
227*89c4ff92SAndroid Build Coastguard Worker     catch (const std::exception& e)
228*89c4ff92SAndroid Build Coastguard Worker     {
229*89c4ff92SAndroid Build Coastguard Worker         // Coverity fix: various boost exceptions can be thrown by methods called by this test.
230*89c4ff92SAndroid Build Coastguard Worker         std::cerr << "WARNING: MultipleNetworksCifar10: An error has occurred when running the "
231*89c4ff92SAndroid Build Coastguard Worker                      "multiple networks inference tests: " << e.what() << std::endl;
232*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
233*89c4ff92SAndroid Build Coastguard Worker     }
234*89c4ff92SAndroid Build Coastguard Worker }
235