1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
#include "TestBlocks.hpp"
#include "TestStrategy.hpp"

#include <IMemoryOptimizerStrategy.hpp>
#include <MemoryOptimizerStrategyLibrary.hpp>
#include <strategies/StrategyValidator.hpp>


#include <cxxopts.hpp>

#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <exception>
#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker std::vector<TestBlock> testBlocks
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker {"fsrcnn", fsrcnn},
22*89c4ff92SAndroid Build Coastguard Worker {"inceptionv4", inceptionv4},
23*89c4ff92SAndroid Build Coastguard Worker {"deeplabv3", deeplabv3},
24*89c4ff92SAndroid Build Coastguard Worker {"deepspeechv1", deepspeechv1},
25*89c4ff92SAndroid Build Coastguard Worker {"mobilebert", mobilebert},
26*89c4ff92SAndroid Build Coastguard Worker {"ssd_mobilenetv2", ssd_mobilenetv2},
27*89c4ff92SAndroid Build Coastguard Worker {"resnetv2", resnetv2},
28*89c4ff92SAndroid Build Coastguard Worker {"yolov3",yolov3}
29*89c4ff92SAndroid Build Coastguard Worker };
30*89c4ff92SAndroid Build Coastguard Worker
PrintModels()31*89c4ff92SAndroid Build Coastguard Worker void PrintModels()
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker std::cout << "Available models:\n";
34*89c4ff92SAndroid Build Coastguard Worker for (const auto& model : testBlocks)
35*89c4ff92SAndroid Build Coastguard Worker {
36*89c4ff92SAndroid Build Coastguard Worker std::cout << model.m_Name << "\n";
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker std::cout << "\n";
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker
GetMinPossibleMemorySize(const std::vector<armnn::MemBlock> & blocks)41*89c4ff92SAndroid Build Coastguard Worker size_t GetMinPossibleMemorySize(const std::vector<armnn::MemBlock>& blocks)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker unsigned int maxLifetime = 0;
44*89c4ff92SAndroid Build Coastguard Worker for (auto& block: blocks)
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker maxLifetime = std::max(maxLifetime, block.m_EndOfLife);
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker maxLifetime++;
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker std::vector<size_t> lifetimes(maxLifetime);
51*89c4ff92SAndroid Build Coastguard Worker for (const auto& block : blocks)
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker for (auto lifetime = block.m_StartOfLife; lifetime <= block.m_EndOfLife; ++lifetime)
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker lifetimes[lifetime] += block.m_MemSize;
56*89c4ff92SAndroid Build Coastguard Worker }
57*89c4ff92SAndroid Build Coastguard Worker }
58*89c4ff92SAndroid Build Coastguard Worker return *std::max_element(lifetimes.begin(), lifetimes.end());
59*89c4ff92SAndroid Build Coastguard Worker }
60*89c4ff92SAndroid Build Coastguard Worker
RunBenchmark(armnn::IMemoryOptimizerStrategy * strategy,std::vector<TestBlock> * models)61*89c4ff92SAndroid Build Coastguard Worker void RunBenchmark(armnn::IMemoryOptimizerStrategy* strategy, std::vector<TestBlock>* models)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker using Clock = std::chrono::high_resolution_clock;
64*89c4ff92SAndroid Build Coastguard Worker float avgEfficiency = 0;
65*89c4ff92SAndroid Build Coastguard Worker std::chrono::duration<double, std::milli> avgDuration{};
66*89c4ff92SAndroid Build Coastguard Worker std::cout << "\nMemory Strategy: " << strategy->GetName()<< "\n";
67*89c4ff92SAndroid Build Coastguard Worker std::cout << "===============================================\n";
68*89c4ff92SAndroid Build Coastguard Worker for (auto& model : *models)
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker auto now = Clock::now();
71*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::MemBin> result = strategy->Optimize(model.m_Blocks);
72*89c4ff92SAndroid Build Coastguard Worker auto duration = std::chrono::duration<double, std::milli>(Clock::now() - now);
73*89c4ff92SAndroid Build Coastguard Worker
74*89c4ff92SAndroid Build Coastguard Worker avgDuration += duration;
75*89c4ff92SAndroid Build Coastguard Worker size_t memoryUsage = 0;
76*89c4ff92SAndroid Build Coastguard Worker for (auto bin : result)
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker memoryUsage += bin.m_MemSize;
79*89c4ff92SAndroid Build Coastguard Worker }
80*89c4ff92SAndroid Build Coastguard Worker size_t minSize = GetMinPossibleMemorySize(model.m_Blocks);
81*89c4ff92SAndroid Build Coastguard Worker
82*89c4ff92SAndroid Build Coastguard Worker float efficiency = static_cast<float>(minSize) / static_cast<float>(memoryUsage);
83*89c4ff92SAndroid Build Coastguard Worker efficiency*=100;
84*89c4ff92SAndroid Build Coastguard Worker avgEfficiency += efficiency;
85*89c4ff92SAndroid Build Coastguard Worker std::cout << "\nModel: " << model.m_Name << "\n";
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker std::cout << "Strategy execution time: " << std::setprecision(4) << duration.count() << " milliseconds\n";
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker std::cout << "Memory usage: " << memoryUsage/1024 << " kb\n";
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker std::cout << "Minimum possible usage: " << minSize/1024 << " kb\n";
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker std::cout << "Memory efficiency: " << std::setprecision(3) << efficiency << "%\n";
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Worker avgDuration/= static_cast<double>(models->size());
97*89c4ff92SAndroid Build Coastguard Worker avgEfficiency/= static_cast<float>(models->size());
98*89c4ff92SAndroid Build Coastguard Worker
99*89c4ff92SAndroid Build Coastguard Worker std::cout << "\n===============================================\n";
100*89c4ff92SAndroid Build Coastguard Worker std::cout << "Average memory duration: " << std::setprecision(4) << avgDuration.count() << " milliseconds\n";
101*89c4ff92SAndroid Build Coastguard Worker std::cout << "Average memory efficiency: " << std::setprecision(3) << avgEfficiency << "%\n";
102*89c4ff92SAndroid Build Coastguard Worker }
103*89c4ff92SAndroid Build Coastguard Worker
/// Command-line settings for the memory benchmark, as produced by ParseOptions().
struct BenchmarkOptions
{
    std::string m_StrategyName;        ///< Value of --strategy; left empty when the default strategy is used.
    std::string m_ModelName;           ///< Value of --model; empty means every registered model is run.
    bool m_UseDefaultStrategy{false};  ///< Set when no --strategy option was supplied.
    bool m_Validate{false};            ///< Value of --validate: wrap the strategy in a StrategyValidator.
};
111*89c4ff92SAndroid Build Coastguard Worker
ParseOptions(int argc,char * argv[])112*89c4ff92SAndroid Build Coastguard Worker BenchmarkOptions ParseOptions(int argc, char* argv[])
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker cxxopts::Options options("Memory Benchmark", "Tests memory optimization strategies on different models");
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker options.add_options()
117*89c4ff92SAndroid Build Coastguard Worker ("s, strategy", "Strategy name, do not specify to use default strategy", cxxopts::value<std::string>())
118*89c4ff92SAndroid Build Coastguard Worker ("m, model", "Model name", cxxopts::value<std::string>())
119*89c4ff92SAndroid Build Coastguard Worker ("v, validate", "Validate strategy", cxxopts::value<bool>()->default_value("false")->implicit_value("true"))
120*89c4ff92SAndroid Build Coastguard Worker ("h,help", "Display usage information");
121*89c4ff92SAndroid Build Coastguard Worker
122*89c4ff92SAndroid Build Coastguard Worker auto result = options.parse(argc, argv);
123*89c4ff92SAndroid Build Coastguard Worker if (result.count("help"))
124*89c4ff92SAndroid Build Coastguard Worker {
125*89c4ff92SAndroid Build Coastguard Worker std::cout << options.help() << std::endl;
126*89c4ff92SAndroid Build Coastguard Worker PrintModels();
127*89c4ff92SAndroid Build Coastguard Worker
128*89c4ff92SAndroid Build Coastguard Worker std::cout << "\nAvailable strategies:\n";
129*89c4ff92SAndroid Build Coastguard Worker
130*89c4ff92SAndroid Build Coastguard Worker for (const auto& s :armnn::GetMemoryOptimizerStrategyNames())
131*89c4ff92SAndroid Build Coastguard Worker {
132*89c4ff92SAndroid Build Coastguard Worker std::cout << s << "\n";
133*89c4ff92SAndroid Build Coastguard Worker }
134*89c4ff92SAndroid Build Coastguard Worker exit(EXIT_SUCCESS);
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker
137*89c4ff92SAndroid Build Coastguard Worker BenchmarkOptions benchmarkOptions;
138*89c4ff92SAndroid Build Coastguard Worker
139*89c4ff92SAndroid Build Coastguard Worker if(result.count("strategy"))
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker benchmarkOptions.m_StrategyName = result["strategy"].as<std::string>();
142*89c4ff92SAndroid Build Coastguard Worker }
143*89c4ff92SAndroid Build Coastguard Worker else
144*89c4ff92SAndroid Build Coastguard Worker {
145*89c4ff92SAndroid Build Coastguard Worker std::cout << "No Strategy given, using default strategy";
146*89c4ff92SAndroid Build Coastguard Worker
147*89c4ff92SAndroid Build Coastguard Worker benchmarkOptions.m_UseDefaultStrategy = true;
148*89c4ff92SAndroid Build Coastguard Worker }
149*89c4ff92SAndroid Build Coastguard Worker
150*89c4ff92SAndroid Build Coastguard Worker if(result.count("model"))
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker benchmarkOptions.m_ModelName = result["model"].as<std::string>();
153*89c4ff92SAndroid Build Coastguard Worker }
154*89c4ff92SAndroid Build Coastguard Worker
155*89c4ff92SAndroid Build Coastguard Worker benchmarkOptions.m_Validate = result["validate"].as<bool>();
156*89c4ff92SAndroid Build Coastguard Worker
157*89c4ff92SAndroid Build Coastguard Worker return benchmarkOptions;
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker
main(int argc,char * argv[])160*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker BenchmarkOptions benchmarkOptions = ParseOptions(argc, argv);
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<armnn::IMemoryOptimizerStrategy> strategy;
165*89c4ff92SAndroid Build Coastguard Worker
166*89c4ff92SAndroid Build Coastguard Worker if (benchmarkOptions.m_UseDefaultStrategy)
167*89c4ff92SAndroid Build Coastguard Worker {
168*89c4ff92SAndroid Build Coastguard Worker strategy = std::make_shared<armnn::TestStrategy>();
169*89c4ff92SAndroid Build Coastguard Worker }
170*89c4ff92SAndroid Build Coastguard Worker else
171*89c4ff92SAndroid Build Coastguard Worker {
172*89c4ff92SAndroid Build Coastguard Worker strategy = armnn::GetMemoryOptimizerStrategy(benchmarkOptions.m_StrategyName);
173*89c4ff92SAndroid Build Coastguard Worker
174*89c4ff92SAndroid Build Coastguard Worker if (!strategy)
175*89c4ff92SAndroid Build Coastguard Worker {
176*89c4ff92SAndroid Build Coastguard Worker std::cout << "Strategy name not found\n";
177*89c4ff92SAndroid Build Coastguard Worker return 0;
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker }
180*89c4ff92SAndroid Build Coastguard Worker
181*89c4ff92SAndroid Build Coastguard Worker std::vector<TestBlock> model;
182*89c4ff92SAndroid Build Coastguard Worker std::vector<TestBlock>* modelsToTest = &testBlocks;
183*89c4ff92SAndroid Build Coastguard Worker if (benchmarkOptions.m_ModelName.size() != 0)
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker auto it = std::find_if(testBlocks.cbegin(), testBlocks.cend(), [&](const TestBlock testBlock)
186*89c4ff92SAndroid Build Coastguard Worker {
187*89c4ff92SAndroid Build Coastguard Worker return testBlock.m_Name == benchmarkOptions.m_ModelName;
188*89c4ff92SAndroid Build Coastguard Worker });
189*89c4ff92SAndroid Build Coastguard Worker
190*89c4ff92SAndroid Build Coastguard Worker if (it == testBlocks.end())
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker std::cout << "Model name not found\n";
193*89c4ff92SAndroid Build Coastguard Worker return 0;
194*89c4ff92SAndroid Build Coastguard Worker }
195*89c4ff92SAndroid Build Coastguard Worker else
196*89c4ff92SAndroid Build Coastguard Worker {
197*89c4ff92SAndroid Build Coastguard Worker model.push_back(*it);
198*89c4ff92SAndroid Build Coastguard Worker modelsToTest = &model;
199*89c4ff92SAndroid Build Coastguard Worker }
200*89c4ff92SAndroid Build Coastguard Worker }
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker if (benchmarkOptions.m_Validate)
203*89c4ff92SAndroid Build Coastguard Worker {
204*89c4ff92SAndroid Build Coastguard Worker armnn::StrategyValidator strategyValidator;
205*89c4ff92SAndroid Build Coastguard Worker
206*89c4ff92SAndroid Build Coastguard Worker strategyValidator.SetStrategy(strategy);
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker RunBenchmark(&strategyValidator, modelsToTest);
209*89c4ff92SAndroid Build Coastguard Worker }
210*89c4ff92SAndroid Build Coastguard Worker else
211*89c4ff92SAndroid Build Coastguard Worker {
212*89c4ff92SAndroid Build Coastguard Worker RunBenchmark(strategy.get(), modelsToTest);
213*89c4ff92SAndroid Build Coastguard Worker }
214*89c4ff92SAndroid Build Coastguard Worker
215*89c4ff92SAndroid Build Coastguard Worker }