xref: /aosp_15_r20/external/armnn/tests/ExecuteNetwork/ArmNNExecutor.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "IExecutor.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "ExecuteNetworkProgramOptions.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include "armnn/utility/NumericCast.hpp"
12*89c4ff92SAndroid Build Coastguard Worker #include "armnn/utility/Timer.hpp"
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Threadpool.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Timer.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp>
23*89c4ff92SAndroid Build Coastguard Worker #include <HeapProfiling.hpp>
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_SERIALIZER)
28*89c4ff92SAndroid Build Coastguard Worker #include "armnnDeserializer/IDeserializer.hpp"
29*89c4ff92SAndroid Build Coastguard Worker #endif
30*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
31*89c4ff92SAndroid Build Coastguard Worker #include <armnnTfLiteParser/ITfLiteParser.hpp>
32*89c4ff92SAndroid Build Coastguard Worker #endif
33*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
34*89c4ff92SAndroid Build Coastguard Worker #include <armnnOnnxParser/IOnnxParser.hpp>
35*89c4ff92SAndroid Build Coastguard Worker #endif
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker class ArmNNExecutor : public IExecutor
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker public:
40*89c4ff92SAndroid Build Coastguard Worker     ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions);
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker     std::vector<const void* > Execute() override;
43*89c4ff92SAndroid Build Coastguard Worker     void PrintNetworkInfo() override;
44*89c4ff92SAndroid Build Coastguard Worker     void CompareAndPrintResult(std::vector<const void*> otherOutput) override;
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker private:
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     /**
49*89c4ff92SAndroid Build Coastguard Worker      * Returns a pointer to the armnn::IRuntime* this will be shared by all ArmNNExecutors.
50*89c4ff92SAndroid Build Coastguard Worker      */
GetRuntime(const armnn::IRuntime::CreationOptions & options)51*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
52*89c4ff92SAndroid Build Coastguard Worker     {
53*89c4ff92SAndroid Build Coastguard Worker         static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
54*89c4ff92SAndroid Build Coastguard Worker         // Instantiated on first use.
55*89c4ff92SAndroid Build Coastguard Worker         return instance.get();
56*89c4ff92SAndroid Build Coastguard Worker     }
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     struct IParser;
59*89c4ff92SAndroid Build Coastguard Worker     struct IOInfo;
60*89c4ff92SAndroid Build Coastguard Worker     struct IOStorage;
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker     using BindingPointInfo = armnn::BindingPointInfo;
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<IParser> CreateParser();
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     void ExecuteAsync();
67*89c4ff92SAndroid Build Coastguard Worker     void ExecuteSync();
68*89c4ff92SAndroid Build Coastguard Worker     void SetupInputsAndOutputs();
69*89c4ff92SAndroid Build Coastguard Worker 
70*89c4ff92SAndroid Build Coastguard Worker     IOInfo GetIOInfo(armnn::IOptimizedNetwork* optNet);
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration);
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network);
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     struct IOStorage
77*89c4ff92SAndroid Build Coastguard Worker     {
IOStorageArmNNExecutor::IOStorage78*89c4ff92SAndroid Build Coastguard Worker         IOStorage(size_t size)
79*89c4ff92SAndroid Build Coastguard Worker         {
80*89c4ff92SAndroid Build Coastguard Worker             m_Mem = operator new(size);
81*89c4ff92SAndroid Build Coastguard Worker         }
~IOStorageArmNNExecutor::IOStorage82*89c4ff92SAndroid Build Coastguard Worker         ~IOStorage()
83*89c4ff92SAndroid Build Coastguard Worker         {
84*89c4ff92SAndroid Build Coastguard Worker             operator delete(m_Mem);
85*89c4ff92SAndroid Build Coastguard Worker         }
IOStorageArmNNExecutor::IOStorage86*89c4ff92SAndroid Build Coastguard Worker         IOStorage(IOStorage&& rhs)
87*89c4ff92SAndroid Build Coastguard Worker         {
88*89c4ff92SAndroid Build Coastguard Worker             this->m_Mem = rhs.m_Mem;
89*89c4ff92SAndroid Build Coastguard Worker             rhs.m_Mem = nullptr;
90*89c4ff92SAndroid Build Coastguard Worker         }
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker         IOStorage(const IOStorage& rhs) = delete;
93*89c4ff92SAndroid Build Coastguard Worker         IOStorage& operator=(IOStorage& rhs) = delete;
94*89c4ff92SAndroid Build Coastguard Worker         IOStorage& operator=(IOStorage&& rhs) = delete;
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker         void* m_Mem;
97*89c4ff92SAndroid Build Coastguard Worker     };
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker     struct IOInfo
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::string> m_InputNames;
102*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::string> m_OutputNames;
103*89c4ff92SAndroid Build Coastguard Worker         std::map<std::string, armnn::BindingPointInfo> m_InputInfoMap;
104*89c4ff92SAndroid Build Coastguard Worker         std::map<std::string, armnn::BindingPointInfo> m_OutputInfoMap;
105*89c4ff92SAndroid Build Coastguard Worker     };
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     IOInfo m_IOInfo;
108*89c4ff92SAndroid Build Coastguard Worker     std::vector<IOStorage> m_InputStorage;
109*89c4ff92SAndroid Build Coastguard Worker     std::vector<IOStorage> m_OutputStorage;
110*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::InputTensors> m_InputTensorsVec;
111*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::OutputTensors> m_OutputTensorsVec;
112*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<unsigned int>> m_ImportedInputIds;
113*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<unsigned int>> m_ImportedOutputIds;
114*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntime* m_Runtime;
115*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId m_NetworkId;
116*89c4ff92SAndroid Build Coastguard Worker     ExecuteNetworkParams m_Params;
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker     struct IParser
119*89c4ff92SAndroid Build Coastguard Worker     {
120*89c4ff92SAndroid Build Coastguard Worker         virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0;
121*89c4ff92SAndroid Build Coastguard Worker         virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string& inputName) = 0;
122*89c4ff92SAndroid Build Coastguard Worker         virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string& outputName) = 0;
123*89c4ff92SAndroid Build Coastguard Worker 
~IParserArmNNExecutor::IParser124*89c4ff92SAndroid Build Coastguard Worker         virtual ~IParser(){};
125*89c4ff92SAndroid Build Coastguard Worker     };
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_SERIALIZER)
128*89c4ff92SAndroid Build Coastguard Worker     class ArmNNDeserializer : public IParser
129*89c4ff92SAndroid Build Coastguard Worker     {
130*89c4ff92SAndroid Build Coastguard Worker         public:
131*89c4ff92SAndroid Build Coastguard Worker         ArmNNDeserializer();
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
134*89c4ff92SAndroid Build Coastguard Worker         armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string& inputName) override;
135*89c4ff92SAndroid Build Coastguard Worker         armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string& outputName) override;
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker         private:
138*89c4ff92SAndroid Build Coastguard Worker         armnnDeserializer::IDeserializerPtr m_Parser;
139*89c4ff92SAndroid Build Coastguard Worker     };
140*89c4ff92SAndroid Build Coastguard Worker #endif
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
143*89c4ff92SAndroid Build Coastguard Worker     class TfliteParser : public IParser
144*89c4ff92SAndroid Build Coastguard Worker     {
145*89c4ff92SAndroid Build Coastguard Worker     public:
146*89c4ff92SAndroid Build Coastguard Worker         TfliteParser(const ExecuteNetworkParams& params);
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
149*89c4ff92SAndroid Build Coastguard Worker         armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
150*89c4ff92SAndroid Build Coastguard Worker         armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker     private:
__anondf7b68c10102()153*89c4ff92SAndroid Build Coastguard Worker         armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}};
154*89c4ff92SAndroid Build Coastguard Worker     };
155*89c4ff92SAndroid Build Coastguard Worker #endif
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
158*89c4ff92SAndroid Build Coastguard Worker     class OnnxParser : public IParser
159*89c4ff92SAndroid Build Coastguard Worker     {
160*89c4ff92SAndroid Build Coastguard Worker         public:
161*89c4ff92SAndroid Build Coastguard Worker         OnnxParser();
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
164*89c4ff92SAndroid Build Coastguard Worker         armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
165*89c4ff92SAndroid Build Coastguard Worker         armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker         private:
168*89c4ff92SAndroid Build Coastguard Worker         armnnOnnxParser::IOnnxParserPtr m_Parser;
169*89c4ff92SAndroid Build Coastguard Worker     };
170*89c4ff92SAndroid Build Coastguard Worker #endif
171*89c4ff92SAndroid Build Coastguard Worker };