xref: /aosp_15_r20/external/armnn/tests/ExecuteNetwork/ArmNNExecutor.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "IExecutor.hpp"
9 #include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
10 #include "ExecuteNetworkProgramOptions.hpp"
11 #include "armnn/utility/NumericCast.hpp"
12 #include "armnn/utility/Timer.hpp"
13 
14 #include <armnn/ArmNN.hpp>
15 #include <armnn/Threadpool.hpp>
16 #include <armnn/Logging.hpp>
17 #include <armnn/utility/Timer.hpp>
18 #include <armnn/BackendRegistry.hpp>
19 #include <armnn/utility/Assert.hpp>
20 #include <armnn/utility/NumericCast.hpp>
21 
22 #include <armnnUtils/Filesystem.hpp>
23 #include <HeapProfiling.hpp>
24 
25 #include <fmt/format.h>
26 
27 #if defined(ARMNN_SERIALIZER)
28 #include "armnnDeserializer/IDeserializer.hpp"
29 #endif
30 #if defined(ARMNN_TF_LITE_PARSER)
31 #include <armnnTfLiteParser/ITfLiteParser.hpp>
32 #endif
33 #if defined(ARMNN_ONNX_PARSER)
34 #include <armnnOnnxParser/IOnnxParser.hpp>
35 #endif
36 
37 class ArmNNExecutor : public IExecutor
38 {
39 public:
40     ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions);
41 
42     std::vector<const void* > Execute() override;
43     void PrintNetworkInfo() override;
44     void CompareAndPrintResult(std::vector<const void*> otherOutput) override;
45 
46 private:
47 
48     /**
49      * Returns a pointer to the armnn::IRuntime* this will be shared by all ArmNNExecutors.
50      */
GetRuntime(const armnn::IRuntime::CreationOptions & options)51     armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
52     {
53         static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
54         // Instantiated on first use.
55         return instance.get();
56     }
57 
58     struct IParser;
59     struct IOInfo;
60     struct IOStorage;
61 
62     using BindingPointInfo = armnn::BindingPointInfo;
63 
64     std::unique_ptr<IParser> CreateParser();
65 
66     void ExecuteAsync();
67     void ExecuteSync();
68     void SetupInputsAndOutputs();
69 
70     IOInfo GetIOInfo(armnn::IOptimizedNetwork* optNet);
71 
72     void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration);
73 
74     armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network);
75 
76     struct IOStorage
77     {
IOStorageArmNNExecutor::IOStorage78         IOStorage(size_t size)
79         {
80             m_Mem = operator new(size);
81         }
~IOStorageArmNNExecutor::IOStorage82         ~IOStorage()
83         {
84             operator delete(m_Mem);
85         }
IOStorageArmNNExecutor::IOStorage86         IOStorage(IOStorage&& rhs)
87         {
88             this->m_Mem = rhs.m_Mem;
89             rhs.m_Mem = nullptr;
90         }
91 
92         IOStorage(const IOStorage& rhs) = delete;
93         IOStorage& operator=(IOStorage& rhs) = delete;
94         IOStorage& operator=(IOStorage&& rhs) = delete;
95 
96         void* m_Mem;
97     };
98 
99     struct IOInfo
100     {
101         std::vector<std::string> m_InputNames;
102         std::vector<std::string> m_OutputNames;
103         std::map<std::string, armnn::BindingPointInfo> m_InputInfoMap;
104         std::map<std::string, armnn::BindingPointInfo> m_OutputInfoMap;
105     };
106 
107     IOInfo m_IOInfo;
108     std::vector<IOStorage> m_InputStorage;
109     std::vector<IOStorage> m_OutputStorage;
110     std::vector<armnn::InputTensors> m_InputTensorsVec;
111     std::vector<armnn::OutputTensors> m_OutputTensorsVec;
112     std::vector<std::vector<unsigned int>> m_ImportedInputIds;
113     std::vector<std::vector<unsigned int>> m_ImportedOutputIds;
114     armnn::IRuntime* m_Runtime;
115     armnn::NetworkId m_NetworkId;
116     ExecuteNetworkParams m_Params;
117 
118     struct IParser
119     {
120         virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0;
121         virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string& inputName) = 0;
122         virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string& outputName) = 0;
123 
~IParserArmNNExecutor::IParser124         virtual ~IParser(){};
125     };
126 
127 #if defined(ARMNN_SERIALIZER)
128     class ArmNNDeserializer : public IParser
129     {
130         public:
131         ArmNNDeserializer();
132 
133         armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
134         armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string& inputName) override;
135         armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string& outputName) override;
136 
137         private:
138         armnnDeserializer::IDeserializerPtr m_Parser;
139     };
140 #endif
141 
142 #if defined(ARMNN_TF_LITE_PARSER)
143     class TfliteParser : public IParser
144     {
145     public:
146         TfliteParser(const ExecuteNetworkParams& params);
147 
148         armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
149         armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
150         armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
151 
152     private:
__anondf7b68c10102()153         armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}};
154     };
155 #endif
156 
157 #if defined(ARMNN_ONNX_PARSER)
158     class OnnxParser : public IParser
159     {
160         public:
161         OnnxParser();
162 
163         armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
164         armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
165         armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
166 
167         private:
168         armnnOnnxParser::IOnnxParserPtr m_Parser;
169     };
170 #endif
171 };