xref: /aosp_15_r20/external/armnn/tests/MemoryStrategyBenchmark/MemoryStrategyBenchmark.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
#include "TestBlocks.hpp"
#include "TestStrategy.hpp"

#include <IMemoryOptimizerStrategy.hpp>
#include <MemoryOptimizerStrategyLibrary.hpp>
#include <strategies/StrategyValidator.hpp>

#include <cxxopts.hpp>

#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

19*89c4ff92SAndroid Build Coastguard Worker std::vector<TestBlock> testBlocks
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker     {"fsrcnn", fsrcnn},
22*89c4ff92SAndroid Build Coastguard Worker     {"inceptionv4", inceptionv4},
23*89c4ff92SAndroid Build Coastguard Worker     {"deeplabv3", deeplabv3},
24*89c4ff92SAndroid Build Coastguard Worker     {"deepspeechv1", deepspeechv1},
25*89c4ff92SAndroid Build Coastguard Worker     {"mobilebert", mobilebert},
26*89c4ff92SAndroid Build Coastguard Worker     {"ssd_mobilenetv2", ssd_mobilenetv2},
27*89c4ff92SAndroid Build Coastguard Worker     {"resnetv2", resnetv2},
28*89c4ff92SAndroid Build Coastguard Worker     {"yolov3",yolov3}
29*89c4ff92SAndroid Build Coastguard Worker };
30*89c4ff92SAndroid Build Coastguard Worker 
PrintModels()31*89c4ff92SAndroid Build Coastguard Worker void PrintModels()
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker     std::cout << "Available models:\n";
34*89c4ff92SAndroid Build Coastguard Worker     for (const auto& model : testBlocks)
35*89c4ff92SAndroid Build Coastguard Worker     {
36*89c4ff92SAndroid Build Coastguard Worker         std::cout << model.m_Name << "\n";
37*89c4ff92SAndroid Build Coastguard Worker     }
38*89c4ff92SAndroid Build Coastguard Worker     std::cout << "\n";
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker 
GetMinPossibleMemorySize(const std::vector<armnn::MemBlock> & blocks)41*89c4ff92SAndroid Build Coastguard Worker size_t GetMinPossibleMemorySize(const std::vector<armnn::MemBlock>& blocks)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker     unsigned int maxLifetime = 0;
44*89c4ff92SAndroid Build Coastguard Worker     for (auto& block: blocks)
45*89c4ff92SAndroid Build Coastguard Worker     {
46*89c4ff92SAndroid Build Coastguard Worker         maxLifetime = std::max(maxLifetime, block.m_EndOfLife);
47*89c4ff92SAndroid Build Coastguard Worker     }
48*89c4ff92SAndroid Build Coastguard Worker     maxLifetime++;
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker    std::vector<size_t> lifetimes(maxLifetime);
51*89c4ff92SAndroid Build Coastguard Worker    for (const auto& block : blocks)
52*89c4ff92SAndroid Build Coastguard Worker    {
53*89c4ff92SAndroid Build Coastguard Worker        for (auto lifetime = block.m_StartOfLife; lifetime <= block.m_EndOfLife; ++lifetime)
54*89c4ff92SAndroid Build Coastguard Worker        {
55*89c4ff92SAndroid Build Coastguard Worker            lifetimes[lifetime] += block.m_MemSize;
56*89c4ff92SAndroid Build Coastguard Worker        }
57*89c4ff92SAndroid Build Coastguard Worker    }
58*89c4ff92SAndroid Build Coastguard Worker    return *std::max_element(lifetimes.begin(), lifetimes.end());
59*89c4ff92SAndroid Build Coastguard Worker }
60*89c4ff92SAndroid Build Coastguard Worker 
RunBenchmark(armnn::IMemoryOptimizerStrategy * strategy,std::vector<TestBlock> * models)61*89c4ff92SAndroid Build Coastguard Worker void RunBenchmark(armnn::IMemoryOptimizerStrategy* strategy, std::vector<TestBlock>* models)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker     using Clock = std::chrono::high_resolution_clock;
64*89c4ff92SAndroid Build Coastguard Worker     float avgEfficiency = 0;
65*89c4ff92SAndroid Build Coastguard Worker     std::chrono::duration<double, std::milli> avgDuration{};
66*89c4ff92SAndroid Build Coastguard Worker     std::cout << "\nMemory Strategy: "  << strategy->GetName()<< "\n";
67*89c4ff92SAndroid Build Coastguard Worker     std::cout << "===============================================\n";
68*89c4ff92SAndroid Build Coastguard Worker     for (auto& model : *models)
69*89c4ff92SAndroid Build Coastguard Worker     {
70*89c4ff92SAndroid Build Coastguard Worker         auto now = Clock::now();
71*89c4ff92SAndroid Build Coastguard Worker         const std::vector<armnn::MemBin> result = strategy->Optimize(model.m_Blocks);
72*89c4ff92SAndroid Build Coastguard Worker         auto duration = std::chrono::duration<double, std::milli>(Clock::now() - now);
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker         avgDuration += duration;
75*89c4ff92SAndroid Build Coastguard Worker         size_t memoryUsage = 0;
76*89c4ff92SAndroid Build Coastguard Worker         for (auto bin : result)
77*89c4ff92SAndroid Build Coastguard Worker         {
78*89c4ff92SAndroid Build Coastguard Worker             memoryUsage += bin.m_MemSize;
79*89c4ff92SAndroid Build Coastguard Worker         }
80*89c4ff92SAndroid Build Coastguard Worker         size_t minSize = GetMinPossibleMemorySize(model.m_Blocks);
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker         float efficiency = static_cast<float>(minSize) / static_cast<float>(memoryUsage);
83*89c4ff92SAndroid Build Coastguard Worker         efficiency*=100;
84*89c4ff92SAndroid Build Coastguard Worker         avgEfficiency += efficiency;
85*89c4ff92SAndroid Build Coastguard Worker         std::cout << "\nModel: " << model.m_Name << "\n";
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker         std::cout << "Strategy execution time: " << std::setprecision(4) << duration.count() << " milliseconds\n";
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker         std::cout << "Memory usage: " << memoryUsage/1024 << " kb\n";
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker         std::cout << "Minimum possible usage: " << minSize/1024 << " kb\n";
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker         std::cout << "Memory efficiency: " << std::setprecision(3) << efficiency << "%\n";
94*89c4ff92SAndroid Build Coastguard Worker     }
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     avgDuration/= static_cast<double>(models->size());
97*89c4ff92SAndroid Build Coastguard Worker     avgEfficiency/= static_cast<float>(models->size());
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker     std::cout << "\n===============================================\n";
100*89c4ff92SAndroid Build Coastguard Worker     std::cout << "Average memory duration: " << std::setprecision(4) << avgDuration.count() << " milliseconds\n";
101*89c4ff92SAndroid Build Coastguard Worker     std::cout << "Average memory efficiency: " << std::setprecision(3) << avgEfficiency << "%\n";
102*89c4ff92SAndroid Build Coastguard Worker }
103*89c4ff92SAndroid Build Coastguard Worker 
/// Settings collected from the command line by ParseOptions().
struct BenchmarkOptions
{
    /// Value of --strategy; left empty when the option was not given.
    std::string m_StrategyName;
    /// Value of --model; empty means every model in testBlocks is benchmarked.
    std::string m_ModelName;
    /// Set when no --strategy was given; main() then falls back to armnn::TestStrategy.
    bool m_UseDefaultStrategy{false};
    /// Set by --validate; main() then benchmarks through armnn::StrategyValidator.
    bool m_Validate{false};
};
111*89c4ff92SAndroid Build Coastguard Worker 
ParseOptions(int argc,char * argv[])112*89c4ff92SAndroid Build Coastguard Worker BenchmarkOptions ParseOptions(int argc, char* argv[])
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker     cxxopts::Options options("Memory Benchmark", "Tests memory optimization strategies on different models");
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     options.add_options()
117*89c4ff92SAndroid Build Coastguard Worker         ("s, strategy", "Strategy name, do not specify to use default strategy", cxxopts::value<std::string>())
118*89c4ff92SAndroid Build Coastguard Worker         ("m, model", "Model name", cxxopts::value<std::string>())
119*89c4ff92SAndroid Build Coastguard Worker         ("v, validate", "Validate strategy", cxxopts::value<bool>()->default_value("false")->implicit_value("true"))
120*89c4ff92SAndroid Build Coastguard Worker         ("h,help", "Display usage information");
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker     auto result = options.parse(argc, argv);
123*89c4ff92SAndroid Build Coastguard Worker     if (result.count("help"))
124*89c4ff92SAndroid Build Coastguard Worker     {
125*89c4ff92SAndroid Build Coastguard Worker         std::cout << options.help() << std::endl;
126*89c4ff92SAndroid Build Coastguard Worker         PrintModels();
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker         std::cout << "\nAvailable strategies:\n";
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker         for (const auto& s :armnn::GetMemoryOptimizerStrategyNames())
131*89c4ff92SAndroid Build Coastguard Worker         {
132*89c4ff92SAndroid Build Coastguard Worker             std::cout << s << "\n";
133*89c4ff92SAndroid Build Coastguard Worker         }
134*89c4ff92SAndroid Build Coastguard Worker         exit(EXIT_SUCCESS);
135*89c4ff92SAndroid Build Coastguard Worker     }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker     BenchmarkOptions benchmarkOptions;
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker     if(result.count("strategy"))
140*89c4ff92SAndroid Build Coastguard Worker     {
141*89c4ff92SAndroid Build Coastguard Worker         benchmarkOptions.m_StrategyName = result["strategy"].as<std::string>();
142*89c4ff92SAndroid Build Coastguard Worker     }
143*89c4ff92SAndroid Build Coastguard Worker     else
144*89c4ff92SAndroid Build Coastguard Worker     {
145*89c4ff92SAndroid Build Coastguard Worker         std::cout << "No Strategy given, using default strategy";
146*89c4ff92SAndroid Build Coastguard Worker 
147*89c4ff92SAndroid Build Coastguard Worker         benchmarkOptions.m_UseDefaultStrategy = true;
148*89c4ff92SAndroid Build Coastguard Worker     }
149*89c4ff92SAndroid Build Coastguard Worker 
150*89c4ff92SAndroid Build Coastguard Worker     if(result.count("model"))
151*89c4ff92SAndroid Build Coastguard Worker     {
152*89c4ff92SAndroid Build Coastguard Worker         benchmarkOptions.m_ModelName = result["model"].as<std::string>();
153*89c4ff92SAndroid Build Coastguard Worker     }
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     benchmarkOptions.m_Validate = result["validate"].as<bool>();
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker     return benchmarkOptions;
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker 
main(int argc,char * argv[])160*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker     BenchmarkOptions benchmarkOptions = ParseOptions(argc, argv);
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<armnn::IMemoryOptimizerStrategy> strategy;
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker     if (benchmarkOptions.m_UseDefaultStrategy)
167*89c4ff92SAndroid Build Coastguard Worker     {
168*89c4ff92SAndroid Build Coastguard Worker         strategy = std::make_shared<armnn::TestStrategy>();
169*89c4ff92SAndroid Build Coastguard Worker     }
170*89c4ff92SAndroid Build Coastguard Worker     else
171*89c4ff92SAndroid Build Coastguard Worker     {
172*89c4ff92SAndroid Build Coastguard Worker         strategy = armnn::GetMemoryOptimizerStrategy(benchmarkOptions.m_StrategyName);
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker         if (!strategy)
175*89c4ff92SAndroid Build Coastguard Worker         {
176*89c4ff92SAndroid Build Coastguard Worker             std::cout << "Strategy name not found\n";
177*89c4ff92SAndroid Build Coastguard Worker             return 0;
178*89c4ff92SAndroid Build Coastguard Worker         }
179*89c4ff92SAndroid Build Coastguard Worker     }
180*89c4ff92SAndroid Build Coastguard Worker 
181*89c4ff92SAndroid Build Coastguard Worker     std::vector<TestBlock> model;
182*89c4ff92SAndroid Build Coastguard Worker     std::vector<TestBlock>* modelsToTest = &testBlocks;
183*89c4ff92SAndroid Build Coastguard Worker     if (benchmarkOptions.m_ModelName.size() != 0)
184*89c4ff92SAndroid Build Coastguard Worker     {
185*89c4ff92SAndroid Build Coastguard Worker         auto it = std::find_if(testBlocks.cbegin(), testBlocks.cend(), [&](const TestBlock testBlock)
186*89c4ff92SAndroid Build Coastguard Worker         {
187*89c4ff92SAndroid Build Coastguard Worker             return testBlock.m_Name == benchmarkOptions.m_ModelName;
188*89c4ff92SAndroid Build Coastguard Worker         });
189*89c4ff92SAndroid Build Coastguard Worker 
190*89c4ff92SAndroid Build Coastguard Worker         if (it == testBlocks.end())
191*89c4ff92SAndroid Build Coastguard Worker         {
192*89c4ff92SAndroid Build Coastguard Worker             std::cout << "Model name not found\n";
193*89c4ff92SAndroid Build Coastguard Worker             return 0;
194*89c4ff92SAndroid Build Coastguard Worker         }
195*89c4ff92SAndroid Build Coastguard Worker         else
196*89c4ff92SAndroid Build Coastguard Worker         {
197*89c4ff92SAndroid Build Coastguard Worker             model.push_back(*it);
198*89c4ff92SAndroid Build Coastguard Worker             modelsToTest = &model;
199*89c4ff92SAndroid Build Coastguard Worker         }
200*89c4ff92SAndroid Build Coastguard Worker     }
201*89c4ff92SAndroid Build Coastguard Worker 
202*89c4ff92SAndroid Build Coastguard Worker     if (benchmarkOptions.m_Validate)
203*89c4ff92SAndroid Build Coastguard Worker     {
204*89c4ff92SAndroid Build Coastguard Worker         armnn::StrategyValidator strategyValidator;
205*89c4ff92SAndroid Build Coastguard Worker 
206*89c4ff92SAndroid Build Coastguard Worker         strategyValidator.SetStrategy(strategy);
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker         RunBenchmark(&strategyValidator, modelsToTest);
209*89c4ff92SAndroid Build Coastguard Worker     }
210*89c4ff92SAndroid Build Coastguard Worker     else
211*89c4ff92SAndroid Build Coastguard Worker     {
212*89c4ff92SAndroid Build Coastguard Worker         RunBenchmark(strategy.get(), modelsToTest);
213*89c4ff92SAndroid Build Coastguard Worker     }
214*89c4ff92SAndroid Build Coastguard Worker 
215*89c4ff92SAndroid Build Coastguard Worker }