xref: /aosp_15_r20/external/armnn/tests/InferenceModel.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Threadpool.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <common/include/IgnoreUnused.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #endif
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Timer.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TContainer.hpp>
23*89c4ff92SAndroid Build Coastguard Worker #include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker #include <common/include/ProfilingGuid.hpp>
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_SERIALIZER)
28*89c4ff92SAndroid Build Coastguard Worker #include "armnnDeserializer/IDeserializer.hpp"
29*89c4ff92SAndroid Build Coastguard Worker #endif
30*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
31*89c4ff92SAndroid Build Coastguard Worker #include <armnnTfLiteParser/ITfLiteParser.hpp>
32*89c4ff92SAndroid Build Coastguard Worker #endif
33*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
34*89c4ff92SAndroid Build Coastguard Worker #include <armnnOnnxParser/IOnnxParser.hpp>
35*89c4ff92SAndroid Build Coastguard Worker #endif
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp>
38*89c4ff92SAndroid Build Coastguard Worker #include <HeapProfiling.hpp>
39*89c4ff92SAndroid Build Coastguard Worker #include <TensorIOUtils.hpp>
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker #include "armnn/utility/StringUtils.hpp"
42*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp>
43*89c4ff92SAndroid Build Coastguard Worker #include "CxxoptsUtils.hpp"
44*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
45*89c4ff92SAndroid Build Coastguard Worker #include <mapbox/variant.hpp>
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
48*89c4ff92SAndroid Build Coastguard Worker #include <iterator>
49*89c4ff92SAndroid Build Coastguard Worker #include <fstream>
50*89c4ff92SAndroid Build Coastguard Worker #include <map>
51*89c4ff92SAndroid Build Coastguard Worker #include <string>
52*89c4ff92SAndroid Build Coastguard Worker #include <vector>
53*89c4ff92SAndroid Build Coastguard Worker #include <type_traits>
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker namespace InferenceModelInternal
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker using BindingPointInfo = armnn::BindingPointInfo;
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker using QuantizationParams = std::pair<float,int32_t>;
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker struct Params
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker     std::string                     m_ModelPath;
64*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string>        m_InputBindings;
65*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorShape> m_InputShapes;
66*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string>        m_OutputBindings;
67*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId>   m_ComputeDevices;
68*89c4ff92SAndroid Build Coastguard Worker     std::string                     m_DynamicBackendsPath;
69*89c4ff92SAndroid Build Coastguard Worker     size_t                          m_SubgraphId;
70*89c4ff92SAndroid Build Coastguard Worker     bool                            m_AllowExpandedDims;
71*89c4ff92SAndroid Build Coastguard Worker     bool                            m_IsModelBinary;
72*89c4ff92SAndroid Build Coastguard Worker     bool                            m_VisualizePostOptimizationModel;
73*89c4ff92SAndroid Build Coastguard Worker     bool                            m_EnableFp16TurboMode;
74*89c4ff92SAndroid Build Coastguard Worker     bool                            m_EnableBf16TurboMode;
75*89c4ff92SAndroid Build Coastguard Worker     bool                            m_PrintIntermediateLayers;
76*89c4ff92SAndroid Build Coastguard Worker     bool                            m_PrintIntermediateLayersToFile;
77*89c4ff92SAndroid Build Coastguard Worker     bool                            m_ParseUnsupported;
78*89c4ff92SAndroid Build Coastguard Worker     bool                            m_InferOutputShape;
79*89c4ff92SAndroid Build Coastguard Worker     bool                            m_EnableFastMath;
80*89c4ff92SAndroid Build Coastguard Worker     bool                            m_SaveCachedNetwork;
81*89c4ff92SAndroid Build Coastguard Worker     bool                            m_OutputDetailsToStdOut;
82*89c4ff92SAndroid Build Coastguard Worker     bool                            m_OutputDetailsOnlyToStdOut;
83*89c4ff92SAndroid Build Coastguard Worker     std::string                     m_CachedNetworkFilePath;
84*89c4ff92SAndroid Build Coastguard Worker     unsigned int                    m_NumberOfThreads;
85*89c4ff92SAndroid Build Coastguard Worker     std::string                     m_MLGOTuningFilePath;
86*89c4ff92SAndroid Build Coastguard Worker     bool                            m_AsyncEnabled;
87*89c4ff92SAndroid Build Coastguard Worker     size_t                          m_ThreadPoolSize;
88*89c4ff92SAndroid Build Coastguard Worker     bool                            m_ImportInputsIfAligned;
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker 
ParamsInferenceModelInternal::Params91*89c4ff92SAndroid Build Coastguard Worker     Params()
92*89c4ff92SAndroid Build Coastguard Worker         : m_ComputeDevices{}
93*89c4ff92SAndroid Build Coastguard Worker         , m_SubgraphId(0)
94*89c4ff92SAndroid Build Coastguard Worker         , m_AllowExpandedDims(false)
95*89c4ff92SAndroid Build Coastguard Worker         , m_IsModelBinary(true)
96*89c4ff92SAndroid Build Coastguard Worker         , m_VisualizePostOptimizationModel(false)
97*89c4ff92SAndroid Build Coastguard Worker         , m_EnableFp16TurboMode(false)
98*89c4ff92SAndroid Build Coastguard Worker         , m_EnableBf16TurboMode(false)
99*89c4ff92SAndroid Build Coastguard Worker         , m_PrintIntermediateLayers(false)
100*89c4ff92SAndroid Build Coastguard Worker         , m_PrintIntermediateLayersToFile(false)
101*89c4ff92SAndroid Build Coastguard Worker         , m_ParseUnsupported(false)
102*89c4ff92SAndroid Build Coastguard Worker         , m_InferOutputShape(false)
103*89c4ff92SAndroid Build Coastguard Worker         , m_EnableFastMath(false)
104*89c4ff92SAndroid Build Coastguard Worker         , m_SaveCachedNetwork(false)
105*89c4ff92SAndroid Build Coastguard Worker         , m_OutputDetailsToStdOut(false)
106*89c4ff92SAndroid Build Coastguard Worker         , m_OutputDetailsOnlyToStdOut(false)
107*89c4ff92SAndroid Build Coastguard Worker         , m_CachedNetworkFilePath("")
108*89c4ff92SAndroid Build Coastguard Worker         , m_NumberOfThreads(0)
109*89c4ff92SAndroid Build Coastguard Worker         , m_MLGOTuningFilePath("")
110*89c4ff92SAndroid Build Coastguard Worker         , m_AsyncEnabled(false)
111*89c4ff92SAndroid Build Coastguard Worker         , m_ThreadPoolSize(0)
112*89c4ff92SAndroid Build Coastguard Worker         , m_ImportInputsIfAligned(false)
113*89c4ff92SAndroid Build Coastguard Worker     {}
114*89c4ff92SAndroid Build Coastguard Worker };
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker } // namespace InferenceModelInternal
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker template <typename IParser>
119*89c4ff92SAndroid Build Coastguard Worker struct CreateNetworkImpl
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker public:
122*89c4ff92SAndroid Build Coastguard Worker     using Params = InferenceModelInternal::Params;
123*89c4ff92SAndroid Build Coastguard Worker 
CreateCreateNetworkImpl124*89c4ff92SAndroid Build Coastguard Worker     static armnn::INetworkPtr Create(const Params& params,
125*89c4ff92SAndroid Build Coastguard Worker                                      std::vector<armnn::BindingPointInfo>& inputBindings,
126*89c4ff92SAndroid Build Coastguard Worker                                      std::vector<armnn::BindingPointInfo>& outputBindings)
127*89c4ff92SAndroid Build Coastguard Worker     {
128*89c4ff92SAndroid Build Coastguard Worker         const std::string& modelPath = params.m_ModelPath;
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker         // Create a network from a file on disk
131*89c4ff92SAndroid Build Coastguard Worker         auto parser(IParser::Create());
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker         std::map<std::string, armnn::TensorShape> inputShapes;
134*89c4ff92SAndroid Build Coastguard Worker         if (!params.m_InputShapes.empty())
135*89c4ff92SAndroid Build Coastguard Worker         {
136*89c4ff92SAndroid Build Coastguard Worker             const size_t numInputShapes   = params.m_InputShapes.size();
137*89c4ff92SAndroid Build Coastguard Worker             const size_t numInputBindings = params.m_InputBindings.size();
138*89c4ff92SAndroid Build Coastguard Worker             if (numInputShapes < numInputBindings)
139*89c4ff92SAndroid Build Coastguard Worker             {
140*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception(fmt::format(
141*89c4ff92SAndroid Build Coastguard Worker                     "Not every input has its tensor shape specified: expected={0}, got={1}",
142*89c4ff92SAndroid Build Coastguard Worker                     numInputBindings, numInputShapes));
143*89c4ff92SAndroid Build Coastguard Worker             }
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker             for (size_t i = 0; i < numInputShapes; i++)
146*89c4ff92SAndroid Build Coastguard Worker             {
147*89c4ff92SAndroid Build Coastguard Worker                 inputShapes[params.m_InputBindings[i]] = params.m_InputShapes[i];
148*89c4ff92SAndroid Build Coastguard Worker             }
149*89c4ff92SAndroid Build Coastguard Worker         }
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::string> requestedOutputs = params.m_OutputBindings;
152*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker         {
155*89c4ff92SAndroid Build Coastguard Worker             ARMNN_SCOPED_HEAP_PROFILING("Parsing");
156*89c4ff92SAndroid Build Coastguard Worker             // Handle text and binary input differently by calling the corresponding parser function
157*89c4ff92SAndroid Build Coastguard Worker             network = (params.m_IsModelBinary ?
158*89c4ff92SAndroid Build Coastguard Worker                 parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes, requestedOutputs) :
159*89c4ff92SAndroid Build Coastguard Worker                 parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes, requestedOutputs));
160*89c4ff92SAndroid Build Coastguard Worker         }
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& inputLayerName : params.m_InputBindings)
163*89c4ff92SAndroid Build Coastguard Worker         {
164*89c4ff92SAndroid Build Coastguard Worker             inputBindings.push_back(parser->GetNetworkInputBindingInfo(inputLayerName));
165*89c4ff92SAndroid Build Coastguard Worker         }
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& outputLayerName : params.m_OutputBindings)
168*89c4ff92SAndroid Build Coastguard Worker         {
169*89c4ff92SAndroid Build Coastguard Worker             outputBindings.push_back(parser->GetNetworkOutputBindingInfo(outputLayerName));
170*89c4ff92SAndroid Build Coastguard Worker         }
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker         return network;
173*89c4ff92SAndroid Build Coastguard Worker     }
174*89c4ff92SAndroid Build Coastguard Worker };
175*89c4ff92SAndroid Build Coastguard Worker 
176*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_SERIALIZER)
177*89c4ff92SAndroid Build Coastguard Worker template <>
178*89c4ff92SAndroid Build Coastguard Worker struct CreateNetworkImpl<armnnDeserializer::IDeserializer>
179*89c4ff92SAndroid Build Coastguard Worker {
180*89c4ff92SAndroid Build Coastguard Worker public:
181*89c4ff92SAndroid Build Coastguard Worker     using IParser          = armnnDeserializer::IDeserializer;
182*89c4ff92SAndroid Build Coastguard Worker     using Params           = InferenceModelInternal::Params;
183*89c4ff92SAndroid Build Coastguard Worker 
CreateCreateNetworkImpl184*89c4ff92SAndroid Build Coastguard Worker     static armnn::INetworkPtr Create(const Params& params,
185*89c4ff92SAndroid Build Coastguard Worker                                      std::vector<armnn::BindingPointInfo>& inputBindings,
186*89c4ff92SAndroid Build Coastguard Worker                                      std::vector<armnn::BindingPointInfo>& outputBindings)
187*89c4ff92SAndroid Build Coastguard Worker     {
188*89c4ff92SAndroid Build Coastguard Worker         auto parser(IParser::Create());
189*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(parser);
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
192*89c4ff92SAndroid Build Coastguard Worker 
193*89c4ff92SAndroid Build Coastguard Worker         {
194*89c4ff92SAndroid Build Coastguard Worker             ARMNN_SCOPED_HEAP_PROFILING("Parsing");
195*89c4ff92SAndroid Build Coastguard Worker 
196*89c4ff92SAndroid Build Coastguard Worker             std::error_code errorCode;
197*89c4ff92SAndroid Build Coastguard Worker             fs::path pathToFile(params.m_ModelPath);
198*89c4ff92SAndroid Build Coastguard Worker             if (!fs::exists(pathToFile, errorCode))
199*89c4ff92SAndroid Build Coastguard Worker             {
200*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::FileNotFoundException(fmt::format("Cannot find the file ({0}) errorCode: {1} {2}",
201*89c4ff92SAndroid Build Coastguard Worker                                                    params.m_ModelPath,
202*89c4ff92SAndroid Build Coastguard Worker                                                    errorCode.message(),
203*89c4ff92SAndroid Build Coastguard Worker                                                    CHECK_LOCATION().AsString()));
204*89c4ff92SAndroid Build Coastguard Worker             }
205*89c4ff92SAndroid Build Coastguard Worker             std::ifstream file(params.m_ModelPath, std::ios::binary);
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker             network = parser->CreateNetworkFromBinary(file);
208*89c4ff92SAndroid Build Coastguard Worker         }
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker         unsigned int subgraphId = armnn::numeric_cast<unsigned int>(params.m_SubgraphId);
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& inputLayerName : params.m_InputBindings)
213*89c4ff92SAndroid Build Coastguard Worker         {
214*89c4ff92SAndroid Build Coastguard Worker             armnnDeserializer::BindingPointInfo inputBinding =
215*89c4ff92SAndroid Build Coastguard Worker                 parser->GetNetworkInputBindingInfo(subgraphId, inputLayerName);
216*89c4ff92SAndroid Build Coastguard Worker             inputBindings.push_back(std::make_pair(inputBinding.m_BindingId, inputBinding.m_TensorInfo));
217*89c4ff92SAndroid Build Coastguard Worker         }
218*89c4ff92SAndroid Build Coastguard Worker 
219*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& outputLayerName : params.m_OutputBindings)
220*89c4ff92SAndroid Build Coastguard Worker         {
221*89c4ff92SAndroid Build Coastguard Worker             armnnDeserializer::BindingPointInfo outputBinding =
222*89c4ff92SAndroid Build Coastguard Worker                 parser->GetNetworkOutputBindingInfo(subgraphId, outputLayerName);
223*89c4ff92SAndroid Build Coastguard Worker             outputBindings.push_back(std::make_pair(outputBinding.m_BindingId, outputBinding.m_TensorInfo));
224*89c4ff92SAndroid Build Coastguard Worker         }
225*89c4ff92SAndroid Build Coastguard Worker 
226*89c4ff92SAndroid Build Coastguard Worker         return network;
227*89c4ff92SAndroid Build Coastguard Worker     }
228*89c4ff92SAndroid Build Coastguard Worker };
229*89c4ff92SAndroid Build Coastguard Worker #endif
230*89c4ff92SAndroid Build Coastguard Worker 
231*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
232*89c4ff92SAndroid Build Coastguard Worker template <>
233*89c4ff92SAndroid Build Coastguard Worker struct CreateNetworkImpl<armnnTfLiteParser::ITfLiteParser>
234*89c4ff92SAndroid Build Coastguard Worker {
235*89c4ff92SAndroid Build Coastguard Worker public:
236*89c4ff92SAndroid Build Coastguard Worker     using IParser = armnnTfLiteParser::ITfLiteParser;
237*89c4ff92SAndroid Build Coastguard Worker     using Params = InferenceModelInternal::Params;
238*89c4ff92SAndroid Build Coastguard Worker 
CreateCreateNetworkImpl239*89c4ff92SAndroid Build Coastguard Worker     static armnn::INetworkPtr Create(const Params& params,
240*89c4ff92SAndroid Build Coastguard Worker                                      std::vector<armnn::BindingPointInfo>& inputBindings,
241*89c4ff92SAndroid Build Coastguard Worker                                      std::vector<armnn::BindingPointInfo>& outputBindings)
242*89c4ff92SAndroid Build Coastguard Worker     {
243*89c4ff92SAndroid Build Coastguard Worker         const std::string& modelPath = params.m_ModelPath;
244*89c4ff92SAndroid Build Coastguard Worker 
245*89c4ff92SAndroid Build Coastguard Worker         // Create a network from a file on disk
246*89c4ff92SAndroid Build Coastguard Worker         IParser::TfLiteParserOptions options;
247*89c4ff92SAndroid Build Coastguard Worker         options.m_AllowExpandedDims          = params.m_AllowExpandedDims;
248*89c4ff92SAndroid Build Coastguard Worker         options.m_StandInLayerForUnsupported = params.m_ParseUnsupported;
249*89c4ff92SAndroid Build Coastguard Worker         options.m_InferAndValidate           = params.m_InferOutputShape;
250*89c4ff92SAndroid Build Coastguard Worker         auto parser(IParser::Create(options));
251*89c4ff92SAndroid Build Coastguard Worker 
252*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker         {
255*89c4ff92SAndroid Build Coastguard Worker             ARMNN_SCOPED_HEAP_PROFILING("Parsing");
256*89c4ff92SAndroid Build Coastguard Worker             network = parser->CreateNetworkFromBinaryFile(modelPath.c_str());
257*89c4ff92SAndroid Build Coastguard Worker         }
258*89c4ff92SAndroid Build Coastguard Worker 
259*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& inputLayerName : params.m_InputBindings)
260*89c4ff92SAndroid Build Coastguard Worker         {
261*89c4ff92SAndroid Build Coastguard Worker             armnn::BindingPointInfo inputBinding =
262*89c4ff92SAndroid Build Coastguard Worker                 parser->GetNetworkInputBindingInfo(params.m_SubgraphId, inputLayerName);
263*89c4ff92SAndroid Build Coastguard Worker             inputBindings.push_back(inputBinding);
264*89c4ff92SAndroid Build Coastguard Worker         }
265*89c4ff92SAndroid Build Coastguard Worker 
266*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& outputLayerName : params.m_OutputBindings)
267*89c4ff92SAndroid Build Coastguard Worker         {
268*89c4ff92SAndroid Build Coastguard Worker             armnn::BindingPointInfo outputBinding =
269*89c4ff92SAndroid Build Coastguard Worker                 parser->GetNetworkOutputBindingInfo(params.m_SubgraphId, outputLayerName);
270*89c4ff92SAndroid Build Coastguard Worker             outputBindings.push_back(outputBinding);
271*89c4ff92SAndroid Build Coastguard Worker         }
272*89c4ff92SAndroid Build Coastguard Worker 
273*89c4ff92SAndroid Build Coastguard Worker         return network;
274*89c4ff92SAndroid Build Coastguard Worker     }
275*89c4ff92SAndroid Build Coastguard Worker };
276*89c4ff92SAndroid Build Coastguard Worker #endif
277*89c4ff92SAndroid Build Coastguard Worker 
278*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
279*89c4ff92SAndroid Build Coastguard Worker template <>
280*89c4ff92SAndroid Build Coastguard Worker struct CreateNetworkImpl<armnnOnnxParser::IOnnxParser>
281*89c4ff92SAndroid Build Coastguard Worker {
282*89c4ff92SAndroid Build Coastguard Worker public:
283*89c4ff92SAndroid Build Coastguard Worker     using IParser = armnnOnnxParser::IOnnxParser;
284*89c4ff92SAndroid Build Coastguard Worker     using Params = InferenceModelInternal::Params;
285*89c4ff92SAndroid Build Coastguard Worker     using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
286*89c4ff92SAndroid Build Coastguard Worker 
CreateCreateNetworkImpl287*89c4ff92SAndroid Build Coastguard Worker     static armnn::INetworkPtr Create(const Params& params,
288*89c4ff92SAndroid Build Coastguard Worker                                      std::vector<BindingPointInfo>& inputBindings,
289*89c4ff92SAndroid Build Coastguard Worker                                      std::vector<BindingPointInfo>& outputBindings)
290*89c4ff92SAndroid Build Coastguard Worker     {
291*89c4ff92SAndroid Build Coastguard Worker         const std::string& modelPath = params.m_ModelPath;
292*89c4ff92SAndroid Build Coastguard Worker 
293*89c4ff92SAndroid Build Coastguard Worker         // Create a network from a file on disk
294*89c4ff92SAndroid Build Coastguard Worker         auto parser(IParser::Create());
295*89c4ff92SAndroid Build Coastguard Worker 
296*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
297*89c4ff92SAndroid Build Coastguard Worker 
298*89c4ff92SAndroid Build Coastguard Worker         std::map<std::string, armnn::TensorShape> inputShapes;
299*89c4ff92SAndroid Build Coastguard Worker         if (!params.m_InputShapes.empty())
300*89c4ff92SAndroid Build Coastguard Worker         {
301*89c4ff92SAndroid Build Coastguard Worker             const size_t numInputShapes   = params.m_InputShapes.size();
302*89c4ff92SAndroid Build Coastguard Worker             const size_t numInputBindings = params.m_InputBindings.size();
303*89c4ff92SAndroid Build Coastguard Worker             if (numInputShapes < numInputBindings)
304*89c4ff92SAndroid Build Coastguard Worker             {
305*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception(fmt::format(
306*89c4ff92SAndroid Build Coastguard Worker                     "Not every input has its tensor shape specified: expected={0}, got={1}",
307*89c4ff92SAndroid Build Coastguard Worker                     numInputBindings, numInputShapes));
308*89c4ff92SAndroid Build Coastguard Worker             }
309*89c4ff92SAndroid Build Coastguard Worker 
310*89c4ff92SAndroid Build Coastguard Worker             for (size_t i = 0; i < numInputShapes; i++)
311*89c4ff92SAndroid Build Coastguard Worker             {
312*89c4ff92SAndroid Build Coastguard Worker                 inputShapes[params.m_InputBindings[i]] = params.m_InputShapes[i];
313*89c4ff92SAndroid Build Coastguard Worker             }
314*89c4ff92SAndroid Build Coastguard Worker 
315*89c4ff92SAndroid Build Coastguard Worker             {
316*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
317*89c4ff92SAndroid Build Coastguard Worker                 network = (params.m_IsModelBinary ?
318*89c4ff92SAndroid Build Coastguard Worker                     parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes) :
319*89c4ff92SAndroid Build Coastguard Worker                     parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes));
320*89c4ff92SAndroid Build Coastguard Worker             }
321*89c4ff92SAndroid Build Coastguard Worker         }
322*89c4ff92SAndroid Build Coastguard Worker 
323*89c4ff92SAndroid Build Coastguard Worker         else
324*89c4ff92SAndroid Build Coastguard Worker         {
325*89c4ff92SAndroid Build Coastguard Worker             ARMNN_SCOPED_HEAP_PROFILING("Parsing");
326*89c4ff92SAndroid Build Coastguard Worker             network = (params.m_IsModelBinary ?
327*89c4ff92SAndroid Build Coastguard Worker                 parser->CreateNetworkFromBinaryFile(modelPath.c_str()) :
328*89c4ff92SAndroid Build Coastguard Worker                 parser->CreateNetworkFromTextFile(modelPath.c_str()));
329*89c4ff92SAndroid Build Coastguard Worker         }
330*89c4ff92SAndroid Build Coastguard Worker 
331*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& inputLayerName : params.m_InputBindings)
332*89c4ff92SAndroid Build Coastguard Worker         {
333*89c4ff92SAndroid Build Coastguard Worker             BindingPointInfo inputBinding = parser->GetNetworkInputBindingInfo(inputLayerName);
334*89c4ff92SAndroid Build Coastguard Worker             inputBindings.push_back(inputBinding);
335*89c4ff92SAndroid Build Coastguard Worker         }
336*89c4ff92SAndroid Build Coastguard Worker 
337*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& outputLayerName : params.m_OutputBindings)
338*89c4ff92SAndroid Build Coastguard Worker         {
339*89c4ff92SAndroid Build Coastguard Worker             BindingPointInfo outputBinding = parser->GetNetworkOutputBindingInfo(outputLayerName);
340*89c4ff92SAndroid Build Coastguard Worker             outputBindings.push_back(outputBinding);
341*89c4ff92SAndroid Build Coastguard Worker         }
342*89c4ff92SAndroid Build Coastguard Worker 
343*89c4ff92SAndroid Build Coastguard Worker         return network;
344*89c4ff92SAndroid Build Coastguard Worker     }
345*89c4ff92SAndroid Build Coastguard Worker };
346*89c4ff92SAndroid Build Coastguard Worker #endif
347*89c4ff92SAndroid Build Coastguard Worker 
348*89c4ff92SAndroid Build Coastguard Worker 
349*89c4ff92SAndroid Build Coastguard Worker 
350*89c4ff92SAndroid Build Coastguard Worker template <typename IParser, typename TDataType>
351*89c4ff92SAndroid Build Coastguard Worker class InferenceModel
352*89c4ff92SAndroid Build Coastguard Worker {
353*89c4ff92SAndroid Build Coastguard Worker public:
354*89c4ff92SAndroid Build Coastguard Worker     using DataType           = TDataType;
355*89c4ff92SAndroid Build Coastguard Worker     using Params             = InferenceModelInternal::Params;
356*89c4ff92SAndroid Build Coastguard Worker     using QuantizationParams = InferenceModelInternal::QuantizationParams;
357*89c4ff92SAndroid Build Coastguard Worker 
358*89c4ff92SAndroid Build Coastguard Worker 
359*89c4ff92SAndroid Build Coastguard Worker     struct CommandLineOptions
360*89c4ff92SAndroid Build Coastguard Worker     {
361*89c4ff92SAndroid Build Coastguard Worker         std::string m_ModelDir;
362*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::string> m_ComputeDevices;
363*89c4ff92SAndroid Build Coastguard Worker         std::string m_DynamicBackendsPath;
364*89c4ff92SAndroid Build Coastguard Worker         bool m_VisualizePostOptimizationModel;
365*89c4ff92SAndroid Build Coastguard Worker         bool m_EnableFp16TurboMode;
366*89c4ff92SAndroid Build Coastguard Worker         bool m_EnableBf16TurboMode;
367*89c4ff92SAndroid Build Coastguard Worker         std::string m_Labels;
368*89c4ff92SAndroid Build Coastguard Worker 
GetComputeDevicesAsBackendIdsInferenceModel::CommandLineOptions369*89c4ff92SAndroid Build Coastguard Worker         std::vector<armnn::BackendId> GetComputeDevicesAsBackendIds()
370*89c4ff92SAndroid Build Coastguard Worker         {
371*89c4ff92SAndroid Build Coastguard Worker             std::vector<armnn::BackendId> backendIds;
372*89c4ff92SAndroid Build Coastguard Worker             std::copy(m_ComputeDevices.begin(), m_ComputeDevices.end(), std::back_inserter(backendIds));
373*89c4ff92SAndroid Build Coastguard Worker             return backendIds;
374*89c4ff92SAndroid Build Coastguard Worker         }
375*89c4ff92SAndroid Build Coastguard Worker     };
376*89c4ff92SAndroid Build Coastguard Worker 
AddCommandLineOptions(cxxopts::Options & options,CommandLineOptions & cLineOptions,std::vector<std::string> & required)377*89c4ff92SAndroid Build Coastguard Worker     static void AddCommandLineOptions(cxxopts::Options& options,
378*89c4ff92SAndroid Build Coastguard Worker                                       CommandLineOptions& cLineOptions, std::vector<std::string>& required)
379*89c4ff92SAndroid Build Coastguard Worker     {
380*89c4ff92SAndroid Build Coastguard Worker         const std::vector<std::string> defaultComputes = { "CpuAcc", "CpuRef" };
381*89c4ff92SAndroid Build Coastguard Worker 
382*89c4ff92SAndroid Build Coastguard Worker         const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
383*89c4ff92SAndroid Build Coastguard Worker                                           + armnn::BackendRegistryInstance().GetBackendIdsAsString();
384*89c4ff92SAndroid Build Coastguard Worker 
385*89c4ff92SAndroid Build Coastguard Worker         options
386*89c4ff92SAndroid Build Coastguard Worker             .allow_unrecognised_options()
387*89c4ff92SAndroid Build Coastguard Worker             .add_options()
388*89c4ff92SAndroid Build Coastguard Worker                 ("m,model-dir", "Path to directory containing model files (.prototxt/.tflite)",
389*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>(cLineOptions.m_ModelDir))
390*89c4ff92SAndroid Build Coastguard Worker                 ("c,compute", backendsMessage.c_str(),
391*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::vector<std::string>>(cLineOptions.m_ComputeDevices)->default_value("CpuRef"))
392*89c4ff92SAndroid Build Coastguard Worker                 ("b,dynamic-backends-path",
393*89c4ff92SAndroid Build Coastguard Worker                  "Path where to load any available dynamic backend from. "
394*89c4ff92SAndroid Build Coastguard Worker                  "If left empty (the default), dynamic backends will not be used.",
395*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value(cLineOptions.m_DynamicBackendsPath))
396*89c4ff92SAndroid Build Coastguard Worker                 ("l,labels",
397*89c4ff92SAndroid Build Coastguard Worker                  "Text file containing one image filename - correct label pair per line, "
398*89c4ff92SAndroid Build Coastguard Worker                  "used to test the accuracy of the network.", cxxopts::value<std::string>(cLineOptions.m_Labels))
399*89c4ff92SAndroid Build Coastguard Worker                 ("v,visualize-optimized-model",
400*89c4ff92SAndroid Build Coastguard Worker                  "Produce a dot file useful for visualizing the graph post optimization."
401*89c4ff92SAndroid Build Coastguard Worker                  "The file will have the same name as the model with the .dot extention.",
402*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<bool>(cLineOptions.m_VisualizePostOptimizationModel)->default_value("false"))
403*89c4ff92SAndroid Build Coastguard Worker                 ("fp16-turbo-mode",
404*89c4ff92SAndroid Build Coastguard Worker                  "If this option is enabled FP32 layers, weights and biases will be converted "
405*89c4ff92SAndroid Build Coastguard Worker                  "to FP16 where the backend supports it.",
406*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<bool>(cLineOptions.m_EnableFp16TurboMode)->default_value("false"))
407*89c4ff92SAndroid Build Coastguard Worker                 ("bf16-turbo-mode",
408*89c4ff92SAndroid Build Coastguard Worker                  "If this option is enabled FP32 layers, weights and biases will be converted "
409*89c4ff92SAndroid Build Coastguard Worker                  "to BF16 where the backend supports it.",
410*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<bool>(cLineOptions.m_EnableBf16TurboMode)->default_value("false"));
411*89c4ff92SAndroid Build Coastguard Worker 
412*89c4ff92SAndroid Build Coastguard Worker         required.emplace_back("model-dir");
413*89c4ff92SAndroid Build Coastguard Worker     }
414*89c4ff92SAndroid Build Coastguard Worker 
InferenceModel(const Params & params,bool enableProfiling,const std::string & dynamicBackendsPath,const std::shared_ptr<armnn::IRuntime> & runtime=nullptr)415*89c4ff92SAndroid Build Coastguard Worker     InferenceModel(const Params& params,
416*89c4ff92SAndroid Build Coastguard Worker                    bool enableProfiling,
417*89c4ff92SAndroid Build Coastguard Worker                    const std::string& dynamicBackendsPath,
418*89c4ff92SAndroid Build Coastguard Worker                    const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
419*89c4ff92SAndroid Build Coastguard Worker         : m_EnableProfiling(enableProfiling),
420*89c4ff92SAndroid Build Coastguard Worker           m_ProfilingDetailsMethod(armnn::ProfilingDetailsMethod::Undefined),
421*89c4ff92SAndroid Build Coastguard Worker           m_DynamicBackendsPath(dynamicBackendsPath),
422*89c4ff92SAndroid Build Coastguard Worker           m_ImportInputsIfAligned(params.m_ImportInputsIfAligned)
423*89c4ff92SAndroid Build Coastguard Worker     {
424*89c4ff92SAndroid Build Coastguard Worker         if (runtime)
425*89c4ff92SAndroid Build Coastguard Worker         {
426*89c4ff92SAndroid Build Coastguard Worker             m_Runtime = runtime;
427*89c4ff92SAndroid Build Coastguard Worker         }
428*89c4ff92SAndroid Build Coastguard Worker         else
429*89c4ff92SAndroid Build Coastguard Worker         {
430*89c4ff92SAndroid Build Coastguard Worker             armnn::IRuntime::CreationOptions options;
431*89c4ff92SAndroid Build Coastguard Worker             options.m_EnableGpuProfiling = m_EnableProfiling;
432*89c4ff92SAndroid Build Coastguard Worker             options.m_DynamicBackendsPath = m_DynamicBackendsPath;
433*89c4ff92SAndroid Build Coastguard Worker             m_Runtime = armnn::IRuntime::Create(options);
434*89c4ff92SAndroid Build Coastguard Worker         }
435*89c4ff92SAndroid Build Coastguard Worker 
436*89c4ff92SAndroid Build Coastguard Worker         // Configure the Profiler if the the profiling details are opted for
437*89c4ff92SAndroid Build Coastguard Worker         if (params.m_OutputDetailsOnlyToStdOut)
438*89c4ff92SAndroid Build Coastguard Worker             m_ProfilingDetailsMethod = armnn::ProfilingDetailsMethod::DetailsOnly;
439*89c4ff92SAndroid Build Coastguard Worker         else if (params.m_OutputDetailsToStdOut)
440*89c4ff92SAndroid Build Coastguard Worker             m_ProfilingDetailsMethod = armnn::ProfilingDetailsMethod::DetailsWithEvents;
441*89c4ff92SAndroid Build Coastguard Worker 
442*89c4ff92SAndroid Build Coastguard Worker         std::string invalidBackends;
443*89c4ff92SAndroid Build Coastguard Worker         if (!CheckRequestedBackendsAreValid(params.m_ComputeDevices, armnn::Optional<std::string&>(invalidBackends)))
444*89c4ff92SAndroid Build Coastguard Worker         {
445*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("Some backend IDs are invalid: " + invalidBackends);
446*89c4ff92SAndroid Build Coastguard Worker         }
447*89c4ff92SAndroid Build Coastguard Worker 
448*89c4ff92SAndroid Build Coastguard Worker         armnn::IOptimizedNetworkPtr optNet{nullptr, [](armnn::IOptimizedNetwork*){}};
449*89c4ff92SAndroid Build Coastguard Worker         {
450*89c4ff92SAndroid Build Coastguard Worker             const auto parsing_start_time = armnn::GetTimeNow();
451*89c4ff92SAndroid Build Coastguard Worker             armnn::INetworkPtr network = CreateNetworkImpl<IParser>::Create(params, m_InputBindings, m_OutputBindings);
452*89c4ff92SAndroid Build Coastguard Worker 
453*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(info) << "Network parsing time: " << std::setprecision(2)
454*89c4ff92SAndroid Build Coastguard Worker                             << std::fixed << armnn::GetTimeDuration(parsing_start_time).count() << " ms.";
455*89c4ff92SAndroid Build Coastguard Worker 
456*89c4ff92SAndroid Build Coastguard Worker             ARMNN_SCOPED_HEAP_PROFILING("Optimizing");
457*89c4ff92SAndroid Build Coastguard Worker 
458*89c4ff92SAndroid Build Coastguard Worker             armnn::OptimizerOptionsOpaque options;
459*89c4ff92SAndroid Build Coastguard Worker             options.SetReduceFp32ToFp16(params.m_EnableFp16TurboMode);
460*89c4ff92SAndroid Build Coastguard Worker             options.SetDebugEnabled(params.m_PrintIntermediateLayers);
461*89c4ff92SAndroid Build Coastguard Worker             options.SetDebugToFileEnabled(params.m_PrintIntermediateLayersToFile);
462*89c4ff92SAndroid Build Coastguard Worker             options.SetShapeInferenceMethod(params.m_InferOutputShape ?
463*89c4ff92SAndroid Build Coastguard Worker                     armnn::ShapeInferenceMethod::InferAndValidate : armnn::ShapeInferenceMethod::ValidateOnly);
464*89c4ff92SAndroid Build Coastguard Worker             options.SetProfilingEnabled(m_EnableProfiling);
465*89c4ff92SAndroid Build Coastguard Worker 
466*89c4ff92SAndroid Build Coastguard Worker             armnn::BackendOptions gpuAcc("GpuAcc",
467*89c4ff92SAndroid Build Coastguard Worker             {
468*89c4ff92SAndroid Build Coastguard Worker                 { "FastMathEnabled", params.m_EnableFastMath },
469*89c4ff92SAndroid Build Coastguard Worker                 { "SaveCachedNetwork", params.m_SaveCachedNetwork },
470*89c4ff92SAndroid Build Coastguard Worker                 { "CachedNetworkFilePath", params.m_CachedNetworkFilePath },
471*89c4ff92SAndroid Build Coastguard Worker                 { "MLGOTuningFilePath", params.m_MLGOTuningFilePath }
472*89c4ff92SAndroid Build Coastguard Worker             });
473*89c4ff92SAndroid Build Coastguard Worker 
474*89c4ff92SAndroid Build Coastguard Worker             armnn::BackendOptions cpuAcc("CpuAcc",
475*89c4ff92SAndroid Build Coastguard Worker             {
476*89c4ff92SAndroid Build Coastguard Worker                 { "FastMathEnabled", params.m_EnableFastMath },
477*89c4ff92SAndroid Build Coastguard Worker                 { "NumberOfThreads", params.m_NumberOfThreads }
478*89c4ff92SAndroid Build Coastguard Worker             });
479*89c4ff92SAndroid Build Coastguard Worker             options.AddModelOption(gpuAcc);
480*89c4ff92SAndroid Build Coastguard Worker             options.AddModelOption(cpuAcc);
481*89c4ff92SAndroid Build Coastguard Worker 
482*89c4ff92SAndroid Build Coastguard Worker             const auto optimization_start_time = armnn::GetTimeNow();
483*89c4ff92SAndroid Build Coastguard Worker             optNet = armnn::Optimize(*network, params.m_ComputeDevices, m_Runtime->GetDeviceSpec(), options);
484*89c4ff92SAndroid Build Coastguard Worker 
485*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(info) << "Optimization time: " << std::setprecision(2)
486*89c4ff92SAndroid Build Coastguard Worker                             << std::fixed << armnn::GetTimeDuration(optimization_start_time).count() << " ms.";
487*89c4ff92SAndroid Build Coastguard Worker 
488*89c4ff92SAndroid Build Coastguard Worker             if (!optNet)
489*89c4ff92SAndroid Build Coastguard Worker             {
490*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception("Optimize returned nullptr");
491*89c4ff92SAndroid Build Coastguard Worker             }
492*89c4ff92SAndroid Build Coastguard Worker 
493*89c4ff92SAndroid Build Coastguard Worker 
494*89c4ff92SAndroid Build Coastguard Worker         }
495*89c4ff92SAndroid Build Coastguard Worker 
496*89c4ff92SAndroid Build Coastguard Worker         if (params.m_VisualizePostOptimizationModel)
497*89c4ff92SAndroid Build Coastguard Worker         {
498*89c4ff92SAndroid Build Coastguard Worker             fs::path filename = params.m_ModelPath;
499*89c4ff92SAndroid Build Coastguard Worker             filename.replace_extension("dot");
500*89c4ff92SAndroid Build Coastguard Worker             std::fstream file(filename.c_str(), std::ios_base::out);
501*89c4ff92SAndroid Build Coastguard Worker             optNet->SerializeToDot(file);
502*89c4ff92SAndroid Build Coastguard Worker         }
503*89c4ff92SAndroid Build Coastguard Worker 
504*89c4ff92SAndroid Build Coastguard Worker         armnn::Status ret;
505*89c4ff92SAndroid Build Coastguard Worker         {
506*89c4ff92SAndroid Build Coastguard Worker             ARMNN_SCOPED_HEAP_PROFILING("LoadNetwork");
507*89c4ff92SAndroid Build Coastguard Worker 
508*89c4ff92SAndroid Build Coastguard Worker             const auto loading_start_time = armnn::GetTimeNow();
509*89c4ff92SAndroid Build Coastguard Worker             armnn::INetworkProperties networkProperties(params.m_AsyncEnabled,
510*89c4ff92SAndroid Build Coastguard Worker                                                         armnn::MemorySource::Undefined,
511*89c4ff92SAndroid Build Coastguard Worker                                                         armnn::MemorySource::Undefined,
512*89c4ff92SAndroid Build Coastguard Worker                                                         enableProfiling,
513*89c4ff92SAndroid Build Coastguard Worker                                                         m_ProfilingDetailsMethod);
514*89c4ff92SAndroid Build Coastguard Worker             std::string errorMessage;
515*89c4ff92SAndroid Build Coastguard Worker             ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties);
516*89c4ff92SAndroid Build Coastguard Worker 
517*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(info) << "Network loading time: " << std::setprecision(2)
518*89c4ff92SAndroid Build Coastguard Worker                             << std::fixed << armnn::GetTimeDuration(loading_start_time).count() << " ms.";
519*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
520*89c4ff92SAndroid Build Coastguard Worker             if (params.m_AsyncEnabled && params.m_ThreadPoolSize > 0)
521*89c4ff92SAndroid Build Coastguard Worker             {
522*89c4ff92SAndroid Build Coastguard Worker                 std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
523*89c4ff92SAndroid Build Coastguard Worker                 for (size_t i = 0; i < params.m_ThreadPoolSize; ++i)
524*89c4ff92SAndroid Build Coastguard Worker                 {
525*89c4ff92SAndroid Build Coastguard Worker                     memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier));
526*89c4ff92SAndroid Build Coastguard Worker                 }
527*89c4ff92SAndroid Build Coastguard Worker 
528*89c4ff92SAndroid Build Coastguard Worker                 m_Threadpool = std::make_unique<armnn::Threadpool>(params.m_ThreadPoolSize,
529*89c4ff92SAndroid Build Coastguard Worker                                                                    m_Runtime.get(),
530*89c4ff92SAndroid Build Coastguard Worker                                                                    memHandles);
531*89c4ff92SAndroid Build Coastguard Worker             }
532*89c4ff92SAndroid Build Coastguard Worker #endif
533*89c4ff92SAndroid Build Coastguard Worker         }
534*89c4ff92SAndroid Build Coastguard Worker 
535*89c4ff92SAndroid Build Coastguard Worker         if (ret == armnn::Status::Failure)
536*89c4ff92SAndroid Build Coastguard Worker         {
537*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("IRuntime::LoadNetwork failed");
538*89c4ff92SAndroid Build Coastguard Worker         }
539*89c4ff92SAndroid Build Coastguard Worker     }
540*89c4ff92SAndroid Build Coastguard Worker 
CheckInputIndexIsValid(unsigned int inputIndex) const541*89c4ff92SAndroid Build Coastguard Worker     void CheckInputIndexIsValid(unsigned int inputIndex) const
542*89c4ff92SAndroid Build Coastguard Worker     {
543*89c4ff92SAndroid Build Coastguard Worker         if (m_InputBindings.size() < inputIndex + 1)
544*89c4ff92SAndroid Build Coastguard Worker         {
545*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception(fmt::format("Input index out of range: {}", inputIndex));
546*89c4ff92SAndroid Build Coastguard Worker         }
547*89c4ff92SAndroid Build Coastguard Worker     }
548*89c4ff92SAndroid Build Coastguard Worker 
CheckOutputIndexIsValid(unsigned int outputIndex) const549*89c4ff92SAndroid Build Coastguard Worker     void CheckOutputIndexIsValid(unsigned int outputIndex) const
550*89c4ff92SAndroid Build Coastguard Worker     {
551*89c4ff92SAndroid Build Coastguard Worker         if (m_OutputBindings.size() < outputIndex + 1)
552*89c4ff92SAndroid Build Coastguard Worker         {
553*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception(fmt::format("Output index out of range: {}", outputIndex));
554*89c4ff92SAndroid Build Coastguard Worker         }
555*89c4ff92SAndroid Build Coastguard Worker     }
556*89c4ff92SAndroid Build Coastguard Worker 
GetInputSize(unsigned int inputIndex=0u) const557*89c4ff92SAndroid Build Coastguard Worker     unsigned int GetInputSize(unsigned int inputIndex = 0u) const
558*89c4ff92SAndroid Build Coastguard Worker     {
559*89c4ff92SAndroid Build Coastguard Worker         CheckInputIndexIsValid(inputIndex);
560*89c4ff92SAndroid Build Coastguard Worker         return m_InputBindings[inputIndex].second.GetNumElements();
561*89c4ff92SAndroid Build Coastguard Worker     }
562*89c4ff92SAndroid Build Coastguard Worker 
GetOutputSize(unsigned int outputIndex=0u) const563*89c4ff92SAndroid Build Coastguard Worker     unsigned int GetOutputSize(unsigned int outputIndex = 0u) const
564*89c4ff92SAndroid Build Coastguard Worker     {
565*89c4ff92SAndroid Build Coastguard Worker         CheckOutputIndexIsValid(outputIndex);
566*89c4ff92SAndroid Build Coastguard Worker         return m_OutputBindings[outputIndex].second.GetNumElements();
567*89c4ff92SAndroid Build Coastguard Worker     }
568*89c4ff92SAndroid Build Coastguard Worker 
Run(const std::vector<armnnUtils::TContainer> & inputContainers,std::vector<armnnUtils::TContainer> & outputContainers)569*89c4ff92SAndroid Build Coastguard Worker     std::chrono::duration<double, std::milli> Run(
570*89c4ff92SAndroid Build Coastguard Worker             const std::vector<armnnUtils::TContainer>& inputContainers,
571*89c4ff92SAndroid Build Coastguard Worker             std::vector<armnnUtils::TContainer>& outputContainers)
572*89c4ff92SAndroid Build Coastguard Worker     {
573*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < outputContainers.size(); ++i)
574*89c4ff92SAndroid Build Coastguard Worker         {
575*89c4ff92SAndroid Build Coastguard Worker             const unsigned int expectedOutputDataSize = GetOutputSize(i);
576*89c4ff92SAndroid Build Coastguard Worker 
577*89c4ff92SAndroid Build Coastguard Worker             mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value)
578*89c4ff92SAndroid Build Coastguard Worker             {
579*89c4ff92SAndroid Build Coastguard Worker                 const unsigned int actualOutputDataSize   = armnn::numeric_cast<unsigned int>(value.size());
580*89c4ff92SAndroid Build Coastguard Worker                 if (actualOutputDataSize < expectedOutputDataSize)
581*89c4ff92SAndroid Build Coastguard Worker                 {
582*89c4ff92SAndroid Build Coastguard Worker                     unsigned int outputIndex = i;
583*89c4ff92SAndroid Build Coastguard Worker                     throw armnn::Exception(
584*89c4ff92SAndroid Build Coastguard Worker                             fmt::format("Not enough data for output #{0}: expected "
585*89c4ff92SAndroid Build Coastguard Worker                             "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize));
586*89c4ff92SAndroid Build Coastguard Worker                 }
587*89c4ff92SAndroid Build Coastguard Worker             },
588*89c4ff92SAndroid Build Coastguard Worker             outputContainers[i]);
589*89c4ff92SAndroid Build Coastguard Worker         }
590*89c4ff92SAndroid Build Coastguard Worker 
591*89c4ff92SAndroid Build Coastguard Worker         std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier);
592*89c4ff92SAndroid Build Coastguard Worker 
593*89c4ff92SAndroid Build Coastguard Worker         // Start timer to record inference time in EnqueueWorkload (in milliseconds)
594*89c4ff92SAndroid Build Coastguard Worker         const auto start_time = armnn::GetTimeNow();
595*89c4ff92SAndroid Build Coastguard Worker 
596*89c4ff92SAndroid Build Coastguard Worker         armnn::Status ret;
597*89c4ff92SAndroid Build Coastguard Worker         if (m_ImportInputsIfAligned)
598*89c4ff92SAndroid Build Coastguard Worker         {
599*89c4ff92SAndroid Build Coastguard Worker             std::vector<armnn::ImportedInputId> importedInputIds = m_Runtime->ImportInputs(
600*89c4ff92SAndroid Build Coastguard Worker                 m_NetworkIdentifier, MakeInputTensors(inputContainers), armnn::MemorySource::Malloc);
601*89c4ff92SAndroid Build Coastguard Worker 
602*89c4ff92SAndroid Build Coastguard Worker             std::vector<armnn::ImportedOutputId> importedOutputIds = m_Runtime->ImportOutputs(
603*89c4ff92SAndroid Build Coastguard Worker                 m_NetworkIdentifier, MakeOutputTensors(outputContainers), armnn::MemorySource::Malloc);
604*89c4ff92SAndroid Build Coastguard Worker 
605*89c4ff92SAndroid Build Coastguard Worker             ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
606*89c4ff92SAndroid Build Coastguard Worker                                              MakeInputTensors(inputContainers),
607*89c4ff92SAndroid Build Coastguard Worker                                              MakeOutputTensors(outputContainers),
608*89c4ff92SAndroid Build Coastguard Worker                                              importedInputIds,
609*89c4ff92SAndroid Build Coastguard Worker                                              importedOutputIds);
610*89c4ff92SAndroid Build Coastguard Worker         }
611*89c4ff92SAndroid Build Coastguard Worker         else
612*89c4ff92SAndroid Build Coastguard Worker         {
613*89c4ff92SAndroid Build Coastguard Worker             ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
614*89c4ff92SAndroid Build Coastguard Worker                                              MakeInputTensors(inputContainers),
615*89c4ff92SAndroid Build Coastguard Worker                                              MakeOutputTensors(outputContainers));
616*89c4ff92SAndroid Build Coastguard Worker         }
617*89c4ff92SAndroid Build Coastguard Worker         const auto duration = armnn::GetTimeDuration(start_time);
618*89c4ff92SAndroid Build Coastguard Worker 
619*89c4ff92SAndroid Build Coastguard Worker         // if profiling is enabled print out the results
620*89c4ff92SAndroid Build Coastguard Worker         if (profiler && profiler->IsProfilingEnabled())
621*89c4ff92SAndroid Build Coastguard Worker         {
622*89c4ff92SAndroid Build Coastguard Worker             profiler->Print(std::cout);
623*89c4ff92SAndroid Build Coastguard Worker         }
624*89c4ff92SAndroid Build Coastguard Worker 
625*89c4ff92SAndroid Build Coastguard Worker         if (ret == armnn::Status::Failure)
626*89c4ff92SAndroid Build Coastguard Worker         {
627*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("IRuntime::EnqueueWorkload failed");
628*89c4ff92SAndroid Build Coastguard Worker         }
629*89c4ff92SAndroid Build Coastguard Worker         else
630*89c4ff92SAndroid Build Coastguard Worker         {
631*89c4ff92SAndroid Build Coastguard Worker             return duration;
632*89c4ff92SAndroid Build Coastguard Worker         }
633*89c4ff92SAndroid Build Coastguard Worker     }
634*89c4ff92SAndroid Build Coastguard Worker 
RunAsync(armnn::experimental::IWorkingMemHandle & workingMemHandleRef,const std::vector<armnnUtils::TContainer> & inputContainers,std::vector<armnnUtils::TContainer> & outputContainers,unsigned int inferenceID)635*89c4ff92SAndroid Build Coastguard Worker     std::tuple<unsigned int, std::chrono::duration<double, std::milli>> RunAsync(
636*89c4ff92SAndroid Build Coastguard Worker         armnn::experimental::IWorkingMemHandle& workingMemHandleRef,
637*89c4ff92SAndroid Build Coastguard Worker         const std::vector<armnnUtils::TContainer>& inputContainers,
638*89c4ff92SAndroid Build Coastguard Worker         std::vector<armnnUtils::TContainer>& outputContainers,
639*89c4ff92SAndroid Build Coastguard Worker         unsigned int inferenceID)
640*89c4ff92SAndroid Build Coastguard Worker     {
641*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < outputContainers.size(); ++i)
642*89c4ff92SAndroid Build Coastguard Worker         {
643*89c4ff92SAndroid Build Coastguard Worker             const unsigned int expectedOutputDataSize = GetOutputSize(i);
644*89c4ff92SAndroid Build Coastguard Worker 
645*89c4ff92SAndroid Build Coastguard Worker             mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value)
646*89c4ff92SAndroid Build Coastguard Worker             {
647*89c4ff92SAndroid Build Coastguard Worker                 const unsigned int actualOutputDataSize   = armnn::numeric_cast<unsigned int>(value.size());
648*89c4ff92SAndroid Build Coastguard Worker                 if (actualOutputDataSize < expectedOutputDataSize)
649*89c4ff92SAndroid Build Coastguard Worker                 {
650*89c4ff92SAndroid Build Coastguard Worker                     unsigned int outputIndex = i;
651*89c4ff92SAndroid Build Coastguard Worker                     throw armnn::Exception(
652*89c4ff92SAndroid Build Coastguard Worker                             fmt::format("Not enough data for output #{0}: expected "
653*89c4ff92SAndroid Build Coastguard Worker                             "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize));
654*89c4ff92SAndroid Build Coastguard Worker                 }
655*89c4ff92SAndroid Build Coastguard Worker             },
656*89c4ff92SAndroid Build Coastguard Worker             outputContainers[i]);
657*89c4ff92SAndroid Build Coastguard Worker         }
658*89c4ff92SAndroid Build Coastguard Worker 
659*89c4ff92SAndroid Build Coastguard Worker         std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier);
660*89c4ff92SAndroid Build Coastguard Worker 
661*89c4ff92SAndroid Build Coastguard Worker         // Start timer to record inference time in EnqueueWorkload (in milliseconds)
662*89c4ff92SAndroid Build Coastguard Worker         const auto start_time = armnn::GetTimeNow();
663*89c4ff92SAndroid Build Coastguard Worker 
664*89c4ff92SAndroid Build Coastguard Worker         armnn::Status ret = m_Runtime->Execute(workingMemHandleRef,
665*89c4ff92SAndroid Build Coastguard Worker                                                MakeInputTensors(inputContainers),
666*89c4ff92SAndroid Build Coastguard Worker                                                MakeOutputTensors(outputContainers));
667*89c4ff92SAndroid Build Coastguard Worker 
668*89c4ff92SAndroid Build Coastguard Worker         const auto duration = armnn::GetTimeDuration(start_time);
669*89c4ff92SAndroid Build Coastguard Worker 
670*89c4ff92SAndroid Build Coastguard Worker         // if profiling is enabled print out the results
671*89c4ff92SAndroid Build Coastguard Worker         if (profiler && profiler->IsProfilingEnabled())
672*89c4ff92SAndroid Build Coastguard Worker         {
673*89c4ff92SAndroid Build Coastguard Worker             profiler->Print(std::cout);
674*89c4ff92SAndroid Build Coastguard Worker         }
675*89c4ff92SAndroid Build Coastguard Worker 
676*89c4ff92SAndroid Build Coastguard Worker         if (ret == armnn::Status::Failure)
677*89c4ff92SAndroid Build Coastguard Worker         {
678*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception(
679*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("IRuntime::Execute asynchronously failed for network #{0} on inference #{1}",
680*89c4ff92SAndroid Build Coastguard Worker                             m_NetworkIdentifier, inferenceID));
681*89c4ff92SAndroid Build Coastguard Worker         }
682*89c4ff92SAndroid Build Coastguard Worker         else
683*89c4ff92SAndroid Build Coastguard Worker         {
684*89c4ff92SAndroid Build Coastguard Worker             return std::make_tuple(inferenceID, duration);
685*89c4ff92SAndroid Build Coastguard Worker         }
686*89c4ff92SAndroid Build Coastguard Worker     }
687*89c4ff92SAndroid Build Coastguard Worker 
RunAsync(const std::vector<armnnUtils::TContainer> & inputContainers,std::vector<armnnUtils::TContainer> & outputContainers,std::shared_ptr<armnn::IAsyncExecutionCallback> cb)688*89c4ff92SAndroid Build Coastguard Worker     void RunAsync(const std::vector<armnnUtils::TContainer>& inputContainers,
689*89c4ff92SAndroid Build Coastguard Worker                   std::vector<armnnUtils::TContainer>& outputContainers,
690*89c4ff92SAndroid Build Coastguard Worker                   std::shared_ptr<armnn::IAsyncExecutionCallback> cb)
691*89c4ff92SAndroid Build Coastguard Worker     {
692*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
693*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < outputContainers.size(); ++i)
694*89c4ff92SAndroid Build Coastguard Worker         {
695*89c4ff92SAndroid Build Coastguard Worker             const unsigned int expectedOutputDataSize = GetOutputSize(i);
696*89c4ff92SAndroid Build Coastguard Worker 
697*89c4ff92SAndroid Build Coastguard Worker             mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value)
698*89c4ff92SAndroid Build Coastguard Worker             {
699*89c4ff92SAndroid Build Coastguard Worker                 const unsigned int actualOutputDataSize   = armnn::numeric_cast<unsigned int>(value.size());
700*89c4ff92SAndroid Build Coastguard Worker                 if (actualOutputDataSize < expectedOutputDataSize)
701*89c4ff92SAndroid Build Coastguard Worker                 {
702*89c4ff92SAndroid Build Coastguard Worker                     unsigned int outputIndex = i;
703*89c4ff92SAndroid Build Coastguard Worker                     throw armnn::Exception(
704*89c4ff92SAndroid Build Coastguard Worker                             fmt::format("Not enough data for output #{0}: expected "
705*89c4ff92SAndroid Build Coastguard Worker                             "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize));
706*89c4ff92SAndroid Build Coastguard Worker                 }
707*89c4ff92SAndroid Build Coastguard Worker             },
708*89c4ff92SAndroid Build Coastguard Worker             outputContainers[i]);
709*89c4ff92SAndroid Build Coastguard Worker         }
710*89c4ff92SAndroid Build Coastguard Worker 
711*89c4ff92SAndroid Build Coastguard Worker         std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier);
712*89c4ff92SAndroid Build Coastguard Worker 
713*89c4ff92SAndroid Build Coastguard Worker         m_Threadpool->Schedule(m_NetworkIdentifier,
714*89c4ff92SAndroid Build Coastguard Worker                                MakeInputTensors(inputContainers),
715*89c4ff92SAndroid Build Coastguard Worker                                MakeOutputTensors(outputContainers),
716*89c4ff92SAndroid Build Coastguard Worker                                armnn::QosExecPriority::Medium,
717*89c4ff92SAndroid Build Coastguard Worker                                cb);
718*89c4ff92SAndroid Build Coastguard Worker 
719*89c4ff92SAndroid Build Coastguard Worker         // if profiling is enabled print out the results
720*89c4ff92SAndroid Build Coastguard Worker         if (profiler && profiler->IsProfilingEnabled())
721*89c4ff92SAndroid Build Coastguard Worker         {
722*89c4ff92SAndroid Build Coastguard Worker             profiler->Print(std::cout);
723*89c4ff92SAndroid Build Coastguard Worker         }
724*89c4ff92SAndroid Build Coastguard Worker #endif
725*89c4ff92SAndroid Build Coastguard Worker     }
726*89c4ff92SAndroid Build Coastguard Worker 
GetInputBindingInfo(unsigned int inputIndex=0u) const727*89c4ff92SAndroid Build Coastguard Worker     const armnn::BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const
728*89c4ff92SAndroid Build Coastguard Worker     {
729*89c4ff92SAndroid Build Coastguard Worker         CheckInputIndexIsValid(inputIndex);
730*89c4ff92SAndroid Build Coastguard Worker         return m_InputBindings[inputIndex];
731*89c4ff92SAndroid Build Coastguard Worker     }
732*89c4ff92SAndroid Build Coastguard Worker 
GetInputBindingInfos() const733*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::BindingPointInfo>& GetInputBindingInfos() const
734*89c4ff92SAndroid Build Coastguard Worker     {
735*89c4ff92SAndroid Build Coastguard Worker         return m_InputBindings;
736*89c4ff92SAndroid Build Coastguard Worker     }
737*89c4ff92SAndroid Build Coastguard Worker 
GetOutputBindingInfo(unsigned int outputIndex=0u) const738*89c4ff92SAndroid Build Coastguard Worker     const armnn::BindingPointInfo& GetOutputBindingInfo(unsigned int outputIndex = 0u) const
739*89c4ff92SAndroid Build Coastguard Worker     {
740*89c4ff92SAndroid Build Coastguard Worker         CheckOutputIndexIsValid(outputIndex);
741*89c4ff92SAndroid Build Coastguard Worker         return m_OutputBindings[outputIndex];
742*89c4ff92SAndroid Build Coastguard Worker     }
743*89c4ff92SAndroid Build Coastguard Worker 
GetOutputBindingInfos() const744*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::BindingPointInfo>& GetOutputBindingInfos() const
745*89c4ff92SAndroid Build Coastguard Worker     {
746*89c4ff92SAndroid Build Coastguard Worker         return m_OutputBindings;
747*89c4ff92SAndroid Build Coastguard Worker     }
748*89c4ff92SAndroid Build Coastguard Worker 
GetQuantizationParams(unsigned int outputIndex=0u) const749*89c4ff92SAndroid Build Coastguard Worker     QuantizationParams GetQuantizationParams(unsigned int outputIndex = 0u) const
750*89c4ff92SAndroid Build Coastguard Worker     {
751*89c4ff92SAndroid Build Coastguard Worker         CheckOutputIndexIsValid(outputIndex);
752*89c4ff92SAndroid Build Coastguard Worker         return std::make_pair(m_OutputBindings[outputIndex].second.GetQuantizationScale(),
753*89c4ff92SAndroid Build Coastguard Worker                               m_OutputBindings[outputIndex].second.GetQuantizationOffset());
754*89c4ff92SAndroid Build Coastguard Worker     }
755*89c4ff92SAndroid Build Coastguard Worker 
GetInputQuantizationParams(unsigned int inputIndex=0u) const756*89c4ff92SAndroid Build Coastguard Worker     QuantizationParams GetInputQuantizationParams(unsigned int inputIndex = 0u) const
757*89c4ff92SAndroid Build Coastguard Worker     {
758*89c4ff92SAndroid Build Coastguard Worker         CheckInputIndexIsValid(inputIndex);
759*89c4ff92SAndroid Build Coastguard Worker         return std::make_pair(m_InputBindings[inputIndex].second.GetQuantizationScale(),
760*89c4ff92SAndroid Build Coastguard Worker                               m_InputBindings[inputIndex].second.GetQuantizationOffset());
761*89c4ff92SAndroid Build Coastguard Worker     }
762*89c4ff92SAndroid Build Coastguard Worker 
GetAllQuantizationParams() const763*89c4ff92SAndroid Build Coastguard Worker     std::vector<QuantizationParams> GetAllQuantizationParams() const
764*89c4ff92SAndroid Build Coastguard Worker     {
765*89c4ff92SAndroid Build Coastguard Worker         std::vector<QuantizationParams> quantizationParams;
766*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0u; i < m_OutputBindings.size(); i++)
767*89c4ff92SAndroid Build Coastguard Worker         {
768*89c4ff92SAndroid Build Coastguard Worker             quantizationParams.push_back(GetQuantizationParams(i));
769*89c4ff92SAndroid Build Coastguard Worker         }
770*89c4ff92SAndroid Build Coastguard Worker         return quantizationParams;
771*89c4ff92SAndroid Build Coastguard Worker     }
772*89c4ff92SAndroid Build Coastguard Worker 
CreateWorkingMemHandle()773*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<armnn::experimental::IWorkingMemHandle> CreateWorkingMemHandle()
774*89c4ff92SAndroid Build Coastguard Worker     {
775*89c4ff92SAndroid Build Coastguard Worker         return m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier);
776*89c4ff92SAndroid Build Coastguard Worker     }
777*89c4ff92SAndroid Build Coastguard Worker 
778*89c4ff92SAndroid Build Coastguard Worker private:
    // Private state of the model wrapper.
    // Network id that is passed to the runtime when a working memory handle is created.
779*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId m_NetworkIdentifier;
    // Shared runtime; CreateWorkingMemHandle() calls into it with the network id above.
780*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<armnn::IRuntime> m_Runtime;
    // The thread pool member only exists when ARMNN_DISABLE_THREADS is not defined.
781*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
782*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<armnn::Threadpool> m_Threadpool;
783*89c4ff92SAndroid Build Coastguard Worker #endif
784*89c4ff92SAndroid Build Coastguard Worker 
    // One binding per network input / output, addressed by the input / output index
    // taken by the Get*BindingInfo and quantization getters of this class.
785*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BindingPointInfo> m_InputBindings;
786*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BindingPointInfo> m_OutputBindings;
    // NOTE(review): the four settings below are not read in this part of the file;
    // their meaning is inferred from the names only - confirm at their use sites.
787*89c4ff92SAndroid Build Coastguard Worker     bool m_EnableProfiling;
788*89c4ff92SAndroid Build Coastguard Worker     armnn::ProfilingDetailsMethod m_ProfilingDetailsMethod;
789*89c4ff92SAndroid Build Coastguard Worker     std::string m_DynamicBackendsPath;
790*89c4ff92SAndroid Build Coastguard Worker     bool m_ImportInputsIfAligned;
791*89c4ff92SAndroid Build Coastguard Worker 
792*89c4ff92SAndroid Build Coastguard Worker     template<typename TContainer>
MakeInputTensors(const std::vector<TContainer> & inputDataContainers)793*89c4ff92SAndroid Build Coastguard Worker     armnn::InputTensors MakeInputTensors(const std::vector<TContainer>& inputDataContainers)
794*89c4ff92SAndroid Build Coastguard Worker     {
795*89c4ff92SAndroid Build Coastguard Worker         return armnnUtils::MakeInputTensors(m_InputBindings, inputDataContainers);
796*89c4ff92SAndroid Build Coastguard Worker     }
797*89c4ff92SAndroid Build Coastguard Worker 
798*89c4ff92SAndroid Build Coastguard Worker     template<typename TContainer>
MakeOutputTensors(std::vector<TContainer> & outputDataContainers)799*89c4ff92SAndroid Build Coastguard Worker     armnn::OutputTensors MakeOutputTensors(std::vector<TContainer>& outputDataContainers)
800*89c4ff92SAndroid Build Coastguard Worker     {
801*89c4ff92SAndroid Build Coastguard Worker         return armnnUtils::MakeOutputTensors(m_OutputBindings, outputDataContainers);
802*89c4ff92SAndroid Build Coastguard Worker     }
803*89c4ff92SAndroid Build Coastguard Worker };
804