//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "armnn/ArmNN.hpp"
#include "armnn/Utils.hpp"
#include "armnn/INetwork.hpp"
#include "armnnTfLiteParser/TfLiteParser.hpp"
#include "../Cifar10Database.hpp"
#include "../InferenceTest.hpp"
#include "../InferenceModel.hpp"

#include <cxxopts/cxxopts.hpp>

#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>


using namespace std;
using namespace std::chrono;
using namespace armnn::test;

main(int argc,char * argv[])26*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker #ifdef NDEBUG
29*89c4ff92SAndroid Build Coastguard Worker armnn::LogSeverity level = armnn::LogSeverity::Info;
30*89c4ff92SAndroid Build Coastguard Worker #else
31*89c4ff92SAndroid Build Coastguard Worker armnn::LogSeverity level = armnn::LogSeverity::Debug;
32*89c4ff92SAndroid Build Coastguard Worker #endif
33*89c4ff92SAndroid Build Coastguard Worker
34*89c4ff92SAndroid Build Coastguard Worker try
35*89c4ff92SAndroid Build Coastguard Worker {
36*89c4ff92SAndroid Build Coastguard Worker // Configures logging for both the ARMNN library and this test program.
37*89c4ff92SAndroid Build Coastguard Worker armnn::ConfigureLogging(true, true, level);
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> computeDevice;
40*89c4ff92SAndroid Build Coastguard Worker std::string modelDir;
41*89c4ff92SAndroid Build Coastguard Worker std::string dataDir;
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
44*89c4ff92SAndroid Build Coastguard Worker + armnn::BackendRegistryInstance().GetBackendIdsAsString();
45*89c4ff92SAndroid Build Coastguard Worker
46*89c4ff92SAndroid Build Coastguard Worker cxxopts::Options in_options("MultipleNetworksCifar10",
47*89c4ff92SAndroid Build Coastguard Worker "Run multiple networks inference tests using Cifar-10 data.");
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker try
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker // Adds generic options needed for all inference tests.
52*89c4ff92SAndroid Build Coastguard Worker in_options.add_options()
53*89c4ff92SAndroid Build Coastguard Worker ("h,help", "Display help messages")
54*89c4ff92SAndroid Build Coastguard Worker ("m,model-dir", "Path to directory containing the Cifar10 model file",
55*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(modelDir))
56*89c4ff92SAndroid Build Coastguard Worker ("c,compute", backendsMessage.c_str(),
57*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::vector<armnn::BackendId>>(computeDevice)->default_value("CpuAcc,CpuRef"))
58*89c4ff92SAndroid Build Coastguard Worker ("d,data-dir", "Path to directory containing the Cifar10 test data",
59*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(dataDir));
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker auto result = in_options.parse(argc, argv);
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker if(result.count("help") > 0)
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker std::cout << in_options.help() << std::endl;
66*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
67*89c4ff92SAndroid Build Coastguard Worker }
68*89c4ff92SAndroid Build Coastguard Worker
69*89c4ff92SAndroid Build Coastguard Worker //ensure mandatory parameters given
70*89c4ff92SAndroid Build Coastguard Worker std::string mandatorySingleParameters[] = {"model-dir", "data-dir"};
71*89c4ff92SAndroid Build Coastguard Worker for (auto param : mandatorySingleParameters)
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker if(result.count(param) > 0)
74*89c4ff92SAndroid Build Coastguard Worker {
75*89c4ff92SAndroid Build Coastguard Worker std::string dir = result[param].as<std::string>();
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker if(!ValidateDirectory(dir)) {
78*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
79*89c4ff92SAndroid Build Coastguard Worker }
80*89c4ff92SAndroid Build Coastguard Worker } else {
81*89c4ff92SAndroid Build Coastguard Worker std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
82*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker catch (const cxxopts::OptionException& e)
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker std::cerr << e.what() << std::endl << in_options.help() << std::endl;
89*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker
92*89c4ff92SAndroid Build Coastguard Worker fs::path modelPath = fs::path(modelDir + "/cifar10_tf.prototxt");
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker // Create runtime
95*89c4ff92SAndroid Build Coastguard Worker // This will also load dynamic backend in case that the dynamic backend path is specified
96*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime::CreationOptions options;
97*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
98*89c4ff92SAndroid Build Coastguard Worker
99*89c4ff92SAndroid Build Coastguard Worker // Check if the requested backend are all valid
100*89c4ff92SAndroid Build Coastguard Worker std::string invalidBackends;
101*89c4ff92SAndroid Build Coastguard Worker if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
104*89c4ff92SAndroid Build Coastguard Worker << invalidBackends;
105*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker
108*89c4ff92SAndroid Build Coastguard Worker // Loads networks.
109*89c4ff92SAndroid Build Coastguard Worker armnn::Status status;
110*89c4ff92SAndroid Build Coastguard Worker struct Net
111*89c4ff92SAndroid Build Coastguard Worker {
112*89c4ff92SAndroid Build Coastguard Worker Net(armnn::NetworkId netId,
113*89c4ff92SAndroid Build Coastguard Worker const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& in,
114*89c4ff92SAndroid Build Coastguard Worker const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& out)
115*89c4ff92SAndroid Build Coastguard Worker : m_Network(netId)
116*89c4ff92SAndroid Build Coastguard Worker , m_InputBindingInfo(in)
117*89c4ff92SAndroid Build Coastguard Worker , m_OutputBindingInfo(out)
118*89c4ff92SAndroid Build Coastguard Worker {}
119*89c4ff92SAndroid Build Coastguard Worker
120*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId m_Network;
121*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
122*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
123*89c4ff92SAndroid Build Coastguard Worker };
124*89c4ff92SAndroid Build Coastguard Worker std::vector<Net> networks;
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker armnnTfLiteParser::ITfLiteParserPtr parser(armnnTfLiteParser::ITfLiteParserPtr::Create());
127*89c4ff92SAndroid Build Coastguard Worker
128*89c4ff92SAndroid Build Coastguard Worker const int networksCount = 4;
129*89c4ff92SAndroid Build Coastguard Worker for (int i = 0; i < networksCount; ++i)
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker // Creates a network from a file on the disk.
132*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
133*89c4ff92SAndroid Build Coastguard Worker
134*89c4ff92SAndroid Build Coastguard Worker // Optimizes the network.
135*89c4ff92SAndroid Build Coastguard Worker armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
136*89c4ff92SAndroid Build Coastguard Worker try
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
139*89c4ff92SAndroid Build Coastguard Worker }
140*89c4ff92SAndroid Build Coastguard Worker catch (const armnn::Exception& e)
141*89c4ff92SAndroid Build Coastguard Worker {
142*89c4ff92SAndroid Build Coastguard Worker std::stringstream message;
143*89c4ff92SAndroid Build Coastguard Worker message << "armnn::Exception ("<<e.what()<<") caught from optimize.";
144*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << message.str();
145*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker
148*89c4ff92SAndroid Build Coastguard Worker // Loads the network into the runtime.
149*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId networkId;
150*89c4ff92SAndroid Build Coastguard Worker status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
151*89c4ff92SAndroid Build Coastguard Worker if (status == armnn::Status::Failure)
152*89c4ff92SAndroid Build Coastguard Worker {
153*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
154*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
155*89c4ff92SAndroid Build Coastguard Worker }
156*89c4ff92SAndroid Build Coastguard Worker
157*89c4ff92SAndroid Build Coastguard Worker networks.emplace_back(networkId,
158*89c4ff92SAndroid Build Coastguard Worker parser->GetNetworkInputBindingInfo("data"),
159*89c4ff92SAndroid Build Coastguard Worker parser->GetNetworkOutputBindingInfo("prob"));
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker
162*89c4ff92SAndroid Build Coastguard Worker // Loads a test case and tests inference.
163*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(dataDir))
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
166*89c4ff92SAndroid Build Coastguard Worker }
167*89c4ff92SAndroid Build Coastguard Worker Cifar10Database cifar10(dataDir);
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < 3; ++i)
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker // Loads test case data (including image data).
172*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
173*89c4ff92SAndroid Build Coastguard Worker
174*89c4ff92SAndroid Build Coastguard Worker // Tests inference.
175*89c4ff92SAndroid Build Coastguard Worker std::vector<TContainer> outputs;
176*89c4ff92SAndroid Build Coastguard Worker outputs.reserve(networksCount);
177*89c4ff92SAndroid Build Coastguard Worker
178*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0; j < networksCount; ++j)
179*89c4ff92SAndroid Build Coastguard Worker {
180*89c4ff92SAndroid Build Coastguard Worker outputs.push_back(std::vector<float>(10));
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker
183*89c4ff92SAndroid Build Coastguard Worker for (unsigned int k = 0; k < networksCount; ++k)
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> inputBindings = { networks[k].m_InputBindingInfo };
186*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> outputBindings = { networks[k].m_OutputBindingInfo };
187*89c4ff92SAndroid Build Coastguard Worker
188*89c4ff92SAndroid Build Coastguard Worker std::vector<TContainer> inputDataContainers = { testCaseData->m_InputImage };
189*89c4ff92SAndroid Build Coastguard Worker std::vector<TContainer> outputDataContainers = { outputs[k] };
190*89c4ff92SAndroid Build Coastguard Worker
191*89c4ff92SAndroid Build Coastguard Worker status = runtime->EnqueueWorkload(networks[k].m_Network,
192*89c4ff92SAndroid Build Coastguard Worker armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
193*89c4ff92SAndroid Build Coastguard Worker armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
194*89c4ff92SAndroid Build Coastguard Worker if (status == armnn::Status::Failure)
195*89c4ff92SAndroid Build Coastguard Worker {
196*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload";
197*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
198*89c4ff92SAndroid Build Coastguard Worker }
199*89c4ff92SAndroid Build Coastguard Worker }
200*89c4ff92SAndroid Build Coastguard Worker
201*89c4ff92SAndroid Build Coastguard Worker // Compares outputs.
202*89c4ff92SAndroid Build Coastguard Worker std::vector<float> output0 = mapbox::util::get<std::vector<float>>(outputs[0]);
203*89c4ff92SAndroid Build Coastguard Worker
204*89c4ff92SAndroid Build Coastguard Worker for (unsigned int k = 1; k < networksCount; ++k)
205*89c4ff92SAndroid Build Coastguard Worker {
206*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputK = mapbox::util::get<std::vector<float>>(outputs[k]);
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker if (!std::equal(output0.begin(), output0.end(), outputK.begin(), outputK.end()))
209*89c4ff92SAndroid Build Coastguard Worker {
210*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Multiple networks inference failed!";
211*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
212*89c4ff92SAndroid Build Coastguard Worker }
213*89c4ff92SAndroid Build Coastguard Worker }
214*89c4ff92SAndroid Build Coastguard Worker }
215*89c4ff92SAndroid Build Coastguard Worker
216*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Multiple networks inference ran successfully!";
217*89c4ff92SAndroid Build Coastguard Worker return EXIT_SUCCESS;
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker catch (const armnn::Exception& e)
220*89c4ff92SAndroid Build Coastguard Worker {
221*89c4ff92SAndroid Build Coastguard Worker // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
222*89c4ff92SAndroid Build Coastguard Worker // exception of type std::length_error.
223*89c4ff92SAndroid Build Coastguard Worker // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
224*89c4ff92SAndroid Build Coastguard Worker std::cerr << "Armnn Error: " << e.what() << std::endl;
225*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
226*89c4ff92SAndroid Build Coastguard Worker }
227*89c4ff92SAndroid Build Coastguard Worker catch (const std::exception& e)
228*89c4ff92SAndroid Build Coastguard Worker {
229*89c4ff92SAndroid Build Coastguard Worker // Coverity fix: various boost exceptions can be thrown by methods called by this test.
230*89c4ff92SAndroid Build Coastguard Worker std::cerr << "WARNING: MultipleNetworksCifar10: An error has occurred when running the "
231*89c4ff92SAndroid Build Coastguard Worker "multiple networks inference tests: " << e.what() << std::endl;
232*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
233*89c4ff92SAndroid Build Coastguard Worker }
234*89c4ff92SAndroid Build Coastguard Worker }
235