1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017, 2023 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "JsonPrinterTestImpl.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "armnn/utility/StringUtils.hpp"
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <Profiling.hpp>
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker #include <sstream>
18*89c4ff92SAndroid Build Coastguard Worker #include <stack>
19*89c4ff92SAndroid Build Coastguard Worker #include <string>
20*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
21*89c4ff92SAndroid Build Coastguard Worker
// Returns true when `closing` is the bracket that terminates `opening`.
// Only curly braces and square brackets are recognised; any other opener yields false.
inline bool AreMatchingPair(const char opening, const char closing)
{
    switch (opening)
    {
        case '{':
            return closing == '}';
        case '[':
            return closing == ']';
        default:
            return false;
    }
}
26*89c4ff92SAndroid Build Coastguard Worker
AreParenthesesMatching(const std::string & exp)27*89c4ff92SAndroid Build Coastguard Worker bool AreParenthesesMatching(const std::string& exp)
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker std::stack<char> expStack;
30*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < exp.length(); ++i)
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker if (exp[i] == '{' || exp[i] == '[')
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker expStack.push(exp[i]);
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker else if (exp[i] == '}' || exp[i] == ']')
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker if (expStack.empty() || !AreMatchingPair(expStack.top(), exp[i]))
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker return false;
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker else
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker expStack.pop();
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker return expStack.empty();
49*89c4ff92SAndroid Build Coastguard Worker }
50*89c4ff92SAndroid Build Coastguard Worker
ExtractMeasurements(const std::string & exp)51*89c4ff92SAndroid Build Coastguard Worker std::vector<double> ExtractMeasurements(const std::string& exp)
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker std::vector<double> numbers;
54*89c4ff92SAndroid Build Coastguard Worker bool inArray = false;
55*89c4ff92SAndroid Build Coastguard Worker std::string numberString;
56*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < exp.size(); ++i)
57*89c4ff92SAndroid Build Coastguard Worker {
58*89c4ff92SAndroid Build Coastguard Worker if (exp[i] == '[')
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker inArray = true;
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker else if (exp[i] == ']' && inArray)
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker try
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker armnn::stringUtils::StringTrim(numberString, "\t,\n");
67*89c4ff92SAndroid Build Coastguard Worker numbers.push_back(std::stod(numberString));
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker catch (std::invalid_argument const&)
70*89c4ff92SAndroid Build Coastguard Worker {
71*89c4ff92SAndroid Build Coastguard Worker FAIL(("Could not convert measurements to double: " + numberString));
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker
74*89c4ff92SAndroid Build Coastguard Worker numberString.clear();
75*89c4ff92SAndroid Build Coastguard Worker inArray = false;
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker else if (exp[i] == ',' && inArray)
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker try
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker armnn::stringUtils::StringTrim(numberString, "\t,\n");
82*89c4ff92SAndroid Build Coastguard Worker numbers.push_back(std::stod(numberString));
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker catch (std::invalid_argument const&)
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker FAIL(("Could not convert measurements to double: " + numberString));
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker numberString.clear();
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker else if (exp[i] != '[' && inArray && exp[i] != ',' && exp[i] != ' ')
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker numberString += exp[i];
93*89c4ff92SAndroid Build Coastguard Worker }
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker return numbers;
96*89c4ff92SAndroid Build Coastguard Worker }
97*89c4ff92SAndroid Build Coastguard Worker
// Returns every balanced "{...}" substring of `exp`, ordered by the position of the
// closing brace (so inner sections precede the sections that contain them).
// Square brackets are not considered. An unmatched '}' is ignored rather than
// dereferencing an empty stack; unclosed '{' characters produce no section.
std::vector<std::string> ExtractSections(const std::string& exp)
{
    std::vector<std::string> sections;
    std::stack<size_t> openPositions;

    for (size_t i = 0; i < exp.size(); ++i)
    {
        const char c = exp[i];
        if (c == '{')
        {
            openPositions.push(i);
        }
        else if (c == '}')
        {
            // Calling top()/pop() on an empty std::stack is undefined behaviour.
            if (openPositions.empty())
            {
                continue;
            }
            const size_t from = openPositions.top();
            openPositions.pop();
            sections.push_back(exp.substr(from, i - from + 1));
        }
    }

    return sections;
}
119*89c4ff92SAndroid Build Coastguard Worker
GetSoftmaxProfilerJson(const std::vector<armnn::BackendId> & backends)120*89c4ff92SAndroid Build Coastguard Worker std::string GetSoftmaxProfilerJson(const std::vector<armnn::BackendId>& backends)
121*89c4ff92SAndroid Build Coastguard Worker {
122*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
123*89c4ff92SAndroid Build Coastguard Worker
124*89c4ff92SAndroid Build Coastguard Worker CHECK(!backends.empty());
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
127*89c4ff92SAndroid Build Coastguard Worker
128*89c4ff92SAndroid Build Coastguard Worker // Create runtime in which test will run
129*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options;
130*89c4ff92SAndroid Build Coastguard Worker options.m_EnableGpuProfiling = backends.front() == armnn::Compute::GpuAcc;
131*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options));
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker // build up the structure of the network
134*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create());
135*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input = net->AddInputLayer(0, "input");
136*89c4ff92SAndroid Build Coastguard Worker SoftmaxDescriptor softmaxDescriptor;
137*89c4ff92SAndroid Build Coastguard Worker // Set Axis to -1 if CL or Neon until further Axes are supported.
138*89c4ff92SAndroid Build Coastguard Worker if ( backends.front() == armnn::Compute::CpuAcc || backends.front() == armnn::Compute::GpuAcc)
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker softmaxDescriptor.m_Axis = -1;
141*89c4ff92SAndroid Build Coastguard Worker }
142*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* softmax = net->AddSoftmaxLayer(softmaxDescriptor, "softmax");
143*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output");
144*89c4ff92SAndroid Build Coastguard Worker
145*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot(0).Connect(softmax->GetInputSlot(0));
146*89c4ff92SAndroid Build Coastguard Worker softmax->GetOutputSlot(0).Connect(output->GetInputSlot(0));
147*89c4ff92SAndroid Build Coastguard Worker
148*89c4ff92SAndroid Build Coastguard Worker // set the tensors in the network
149*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(TensorShape({1, 5}), DataType::QAsymmU8);
150*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetQuantizationOffset(100);
151*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetQuantizationScale(10000.0f);
152*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
153*89c4ff92SAndroid Build Coastguard Worker
154*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(TensorShape({1, 5}), DataType::QAsymmU8);
155*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.SetQuantizationOffset(0);
156*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.SetQuantizationScale(1.0f / 256.0f);
157*89c4ff92SAndroid Build Coastguard Worker softmax->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
158*89c4ff92SAndroid Build Coastguard Worker
159*89c4ff92SAndroid Build Coastguard Worker // optimize the network
160*89c4ff92SAndroid Build Coastguard Worker armnn::OptimizerOptionsOpaque optOptions;
161*89c4ff92SAndroid Build Coastguard Worker optOptions.SetProfilingEnabled(true);
162*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
163*89c4ff92SAndroid Build Coastguard Worker if(!optNet)
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker FAIL("Error occurred during Optimization, Optimize() returned nullptr.");
166*89c4ff92SAndroid Build Coastguard Worker }
167*89c4ff92SAndroid Build Coastguard Worker // load it into the runtime
168*89c4ff92SAndroid Build Coastguard Worker NetworkId netId;
169*89c4ff92SAndroid Build Coastguard Worker auto error = runtime->LoadNetwork(netId, std::move(optNet));
170*89c4ff92SAndroid Build Coastguard Worker CHECK(error == Status::Success);
171*89c4ff92SAndroid Build Coastguard Worker
172*89c4ff92SAndroid Build Coastguard Worker // create structures for input & output
173*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> inputData
174*89c4ff92SAndroid Build Coastguard Worker {
175*89c4ff92SAndroid Build Coastguard Worker 1, 10, 3, 200, 5
176*89c4ff92SAndroid Build Coastguard Worker // one of inputs is sufficiently larger than the others to saturate softmax
177*89c4ff92SAndroid Build Coastguard Worker };
178*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> outputData(5);
179*89c4ff92SAndroid Build Coastguard Worker
180*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 0);
181*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo2.SetConstant(true);
182*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors inputTensors
183*89c4ff92SAndroid Build Coastguard Worker {
184*89c4ff92SAndroid Build Coastguard Worker {0, armnn::ConstTensor(inputTensorInfo2, inputData.data())}
185*89c4ff92SAndroid Build Coastguard Worker };
186*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors outputTensors
187*89c4ff92SAndroid Build Coastguard Worker {
188*89c4ff92SAndroid Build Coastguard Worker {0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
189*89c4ff92SAndroid Build Coastguard Worker };
190*89c4ff92SAndroid Build Coastguard Worker
191*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true);
192*89c4ff92SAndroid Build Coastguard Worker
193*89c4ff92SAndroid Build Coastguard Worker // do the inferences
194*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
195*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
196*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
197*89c4ff92SAndroid Build Coastguard Worker
198*89c4ff92SAndroid Build Coastguard Worker // retrieve the Profiler.Print() output
199*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss;
200*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker return ss.str();
203*89c4ff92SAndroid Build Coastguard Worker }
204*89c4ff92SAndroid Build Coastguard Worker
ValidateProfilerJson(std::string & result)205*89c4ff92SAndroid Build Coastguard Worker inline void ValidateProfilerJson(std::string& result)
206*89c4ff92SAndroid Build Coastguard Worker {
207*89c4ff92SAndroid Build Coastguard Worker // ensure all measurements are greater than zero
208*89c4ff92SAndroid Build Coastguard Worker std::vector<double> measurementsVector = ExtractMeasurements(result);
209*89c4ff92SAndroid Build Coastguard Worker CHECK(!measurementsVector.empty());
210*89c4ff92SAndroid Build Coastguard Worker
211*89c4ff92SAndroid Build Coastguard Worker // check sections contain raw and unit tags
212*89c4ff92SAndroid Build Coastguard Worker // first ensure Parenthesis are balanced
213*89c4ff92SAndroid Build Coastguard Worker if (AreParenthesesMatching(result))
214*89c4ff92SAndroid Build Coastguard Worker {
215*89c4ff92SAndroid Build Coastguard Worker // remove parent sections that will not have raw or unit tag
216*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> sectionVector = ExtractSections(result);
217*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < sectionVector.size(); ++i)
218*89c4ff92SAndroid Build Coastguard Worker {
219*89c4ff92SAndroid Build Coastguard Worker
220*89c4ff92SAndroid Build Coastguard Worker if (sectionVector[i].find("\"ArmNN\":") != std::string::npos
221*89c4ff92SAndroid Build Coastguard Worker || sectionVector[i].find("\"optimize_measurements\":") != std::string::npos
222*89c4ff92SAndroid Build Coastguard Worker || sectionVector[i].find("\"loaded_network_measurements\":") != std::string::npos
223*89c4ff92SAndroid Build Coastguard Worker || sectionVector[i].find("\"inference_measurements\":") != std::string::npos)
224*89c4ff92SAndroid Build Coastguard Worker {
225*89c4ff92SAndroid Build Coastguard Worker sectionVector.erase(sectionVector.begin() + static_cast<int>(i));
226*89c4ff92SAndroid Build Coastguard Worker }
227*89c4ff92SAndroid Build Coastguard Worker }
228*89c4ff92SAndroid Build Coastguard Worker CHECK(!sectionVector.empty());
229*89c4ff92SAndroid Build Coastguard Worker
230*89c4ff92SAndroid Build Coastguard Worker CHECK(std::all_of(sectionVector.begin(), sectionVector.end(),
231*89c4ff92SAndroid Build Coastguard Worker [](std::string i) { return (i.find("\"raw\":") != std::string::npos); }));
232*89c4ff92SAndroid Build Coastguard Worker
233*89c4ff92SAndroid Build Coastguard Worker CHECK(std::all_of(sectionVector.begin(), sectionVector.end(),
234*89c4ff92SAndroid Build Coastguard Worker [](std::string i) { return (i.find("\"unit\":") != std::string::npos); }));
235*89c4ff92SAndroid Build Coastguard Worker }
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Worker // remove the time measurements as they vary from test to test
238*89c4ff92SAndroid Build Coastguard Worker result.erase(std::remove_if (result.begin(),result.end(),
239*89c4ff92SAndroid Build Coastguard Worker [](char c) { return c == '.'; }), result.end());
240*89c4ff92SAndroid Build Coastguard Worker result.erase(std::remove_if (result.begin(), result.end(), &isdigit), result.end());
241*89c4ff92SAndroid Build Coastguard Worker result.erase(std::remove_if (result.begin(),result.end(),
242*89c4ff92SAndroid Build Coastguard Worker [](char c) { return c == '\t'; }), result.end());
243*89c4ff92SAndroid Build Coastguard Worker
244*89c4ff92SAndroid Build Coastguard Worker CHECK(result.find("ArmNN") != std::string::npos);
245*89c4ff92SAndroid Build Coastguard Worker CHECK(result.find("inference_measurements") != std::string::npos);
246*89c4ff92SAndroid Build Coastguard Worker
247*89c4ff92SAndroid Build Coastguard Worker // ensure no spare parenthesis present in print output
248*89c4ff92SAndroid Build Coastguard Worker CHECK(AreParenthesesMatching(result));
249*89c4ff92SAndroid Build Coastguard Worker }
250*89c4ff92SAndroid Build Coastguard Worker
RunSoftmaxProfilerJsonPrinterTest(const std::vector<armnn::BackendId> & backends)251*89c4ff92SAndroid Build Coastguard Worker void RunSoftmaxProfilerJsonPrinterTest(const std::vector<armnn::BackendId>& backends)
252*89c4ff92SAndroid Build Coastguard Worker {
253*89c4ff92SAndroid Build Coastguard Worker // setup the test fixture and obtain JSON Printer result
254*89c4ff92SAndroid Build Coastguard Worker std::string result = GetSoftmaxProfilerJson(backends);
255*89c4ff92SAndroid Build Coastguard Worker
256*89c4ff92SAndroid Build Coastguard Worker // validate the JSON Printer result
257*89c4ff92SAndroid Build Coastguard Worker ValidateProfilerJson(result);
258*89c4ff92SAndroid Build Coastguard Worker
259*89c4ff92SAndroid Build Coastguard Worker const armnn::BackendId& firstBackend = backends.at(0);
260*89c4ff92SAndroid Build Coastguard Worker if (firstBackend == armnn::Compute::GpuAcc)
261*89c4ff92SAndroid Build Coastguard Worker {
262*89c4ff92SAndroid Build Coastguard Worker CHECK(result.find("OpenClKernelTimer/: softmax_layer_max_shift_exp_sum_quantized_serial GWS[,,]")
263*89c4ff92SAndroid Build Coastguard Worker != std::string::npos);
264*89c4ff92SAndroid Build Coastguard Worker }
265*89c4ff92SAndroid Build Coastguard Worker else if (firstBackend == armnn::Compute::CpuAcc)
266*89c4ff92SAndroid Build Coastguard Worker {
267*89c4ff92SAndroid Build Coastguard Worker CHECK(result.find("NeonKernelTimer") != std::string::npos); // Validate backend
268*89c4ff92SAndroid Build Coastguard Worker
269*89c4ff92SAndroid Build Coastguard Worker bool softmaxCheck = ((result.find("softmax") != std::string::npos) || // Validate softmax
270*89c4ff92SAndroid Build Coastguard Worker (result.find("Softmax") != std::string::npos) ||
271*89c4ff92SAndroid Build Coastguard Worker (result.find("SoftMax") != std::string::npos));
272*89c4ff92SAndroid Build Coastguard Worker CHECK(softmaxCheck);
273*89c4ff92SAndroid Build Coastguard Worker
274*89c4ff92SAndroid Build Coastguard Worker }
275*89c4ff92SAndroid Build Coastguard Worker }
276