1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #pragma once 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp> 10*89c4ff92SAndroid Build Coastguard Worker 11*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS) 12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Threadpool.hpp> 13*89c4ff92SAndroid Build Coastguard Worker #include <common/include/IgnoreUnused.hpp> 14*89c4ff92SAndroid Build Coastguard Worker #endif 15*89c4ff92SAndroid Build Coastguard Worker 16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp> 17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Timer.hpp> 18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp> 19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp> 20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp> 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TContainer.hpp> 23*89c4ff92SAndroid Build Coastguard Worker #include "NetworkExecutionUtils/NetworkExecutionUtils.hpp" 24*89c4ff92SAndroid Build Coastguard Worker 25*89c4ff92SAndroid Build Coastguard Worker #include <common/include/ProfilingGuid.hpp> 26*89c4ff92SAndroid Build Coastguard Worker 27*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_SERIALIZER) 28*89c4ff92SAndroid Build Coastguard Worker #include "armnnDeserializer/IDeserializer.hpp" 29*89c4ff92SAndroid Build Coastguard Worker #endif 30*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER) 31*89c4ff92SAndroid Build Coastguard Worker #include <armnnTfLiteParser/ITfLiteParser.hpp> 32*89c4ff92SAndroid Build Coastguard Worker #endif 33*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER) 34*89c4ff92SAndroid Build Coastguard Worker #include <armnnOnnxParser/IOnnxParser.hpp> 35*89c4ff92SAndroid Build Coastguard Worker #endif 36*89c4ff92SAndroid Build Coastguard Worker 37*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp> 38*89c4ff92SAndroid Build Coastguard Worker #include <HeapProfiling.hpp> 39*89c4ff92SAndroid Build Coastguard Worker #include <TensorIOUtils.hpp> 40*89c4ff92SAndroid Build Coastguard Worker 41*89c4ff92SAndroid Build Coastguard Worker #include "armnn/utility/StringUtils.hpp" 42*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp> 43*89c4ff92SAndroid Build Coastguard Worker #include "CxxoptsUtils.hpp" 44*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h> 45*89c4ff92SAndroid Build Coastguard Worker #include <mapbox/variant.hpp> 46*89c4ff92SAndroid Build Coastguard Worker 47*89c4ff92SAndroid Build Coastguard Worker #include <algorithm> 48*89c4ff92SAndroid Build Coastguard Worker #include <iterator> 49*89c4ff92SAndroid Build Coastguard Worker #include <fstream> 50*89c4ff92SAndroid Build Coastguard Worker #include <map> 51*89c4ff92SAndroid Build Coastguard Worker #include <string> 52*89c4ff92SAndroid Build Coastguard Worker #include <vector> 53*89c4ff92SAndroid Build Coastguard Worker #include <type_traits> 54*89c4ff92SAndroid Build Coastguard Worker 55*89c4ff92SAndroid Build Coastguard Worker namespace InferenceModelInternal 56*89c4ff92SAndroid Build Coastguard Worker { 57*89c4ff92SAndroid Build Coastguard Worker using BindingPointInfo = armnn::BindingPointInfo; 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker using QuantizationParams = std::pair<float,int32_t>; 60*89c4ff92SAndroid Build Coastguard Worker 61*89c4ff92SAndroid Build Coastguard Worker struct Params 62*89c4ff92SAndroid Build Coastguard Worker { 63*89c4ff92SAndroid Build Coastguard Worker std::string m_ModelPath; 64*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> m_InputBindings; 65*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorShape> m_InputShapes; 66*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> m_OutputBindings; 67*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> m_ComputeDevices; 68*89c4ff92SAndroid Build Coastguard Worker std::string m_DynamicBackendsPath; 69*89c4ff92SAndroid Build Coastguard Worker size_t m_SubgraphId; 70*89c4ff92SAndroid Build Coastguard Worker bool m_AllowExpandedDims; 71*89c4ff92SAndroid Build Coastguard Worker bool m_IsModelBinary; 72*89c4ff92SAndroid Build Coastguard Worker bool m_VisualizePostOptimizationModel; 73*89c4ff92SAndroid Build Coastguard Worker bool m_EnableFp16TurboMode; 74*89c4ff92SAndroid Build Coastguard Worker bool m_EnableBf16TurboMode; 75*89c4ff92SAndroid Build Coastguard Worker bool m_PrintIntermediateLayers; 76*89c4ff92SAndroid Build Coastguard Worker bool m_PrintIntermediateLayersToFile; 77*89c4ff92SAndroid Build Coastguard Worker bool m_ParseUnsupported; 78*89c4ff92SAndroid Build Coastguard Worker bool m_InferOutputShape; 79*89c4ff92SAndroid Build Coastguard Worker bool m_EnableFastMath; 80*89c4ff92SAndroid Build Coastguard Worker bool m_SaveCachedNetwork; 81*89c4ff92SAndroid Build Coastguard Worker bool m_OutputDetailsToStdOut; 82*89c4ff92SAndroid Build Coastguard Worker bool m_OutputDetailsOnlyToStdOut; 83*89c4ff92SAndroid Build Coastguard Worker std::string m_CachedNetworkFilePath; 84*89c4ff92SAndroid Build Coastguard Worker unsigned int m_NumberOfThreads; 85*89c4ff92SAndroid Build Coastguard Worker std::string m_MLGOTuningFilePath; 86*89c4ff92SAndroid Build Coastguard Worker bool m_AsyncEnabled; 87*89c4ff92SAndroid Build Coastguard Worker size_t m_ThreadPoolSize; 88*89c4ff92SAndroid Build Coastguard Worker bool m_ImportInputsIfAligned; 89*89c4ff92SAndroid Build Coastguard Worker 90*89c4ff92SAndroid Build Coastguard Worker ParamsInferenceModelInternal::Params91*89c4ff92SAndroid Build Coastguard Worker Params() 92*89c4ff92SAndroid Build Coastguard Worker : m_ComputeDevices{} 93*89c4ff92SAndroid Build Coastguard Worker , m_SubgraphId(0) 94*89c4ff92SAndroid Build Coastguard Worker , m_AllowExpandedDims(false) 95*89c4ff92SAndroid Build Coastguard Worker , m_IsModelBinary(true) 96*89c4ff92SAndroid Build Coastguard Worker , m_VisualizePostOptimizationModel(false) 97*89c4ff92SAndroid Build Coastguard Worker , m_EnableFp16TurboMode(false) 98*89c4ff92SAndroid Build Coastguard Worker , m_EnableBf16TurboMode(false) 99*89c4ff92SAndroid Build Coastguard Worker , m_PrintIntermediateLayers(false) 100*89c4ff92SAndroid Build Coastguard Worker , m_PrintIntermediateLayersToFile(false) 101*89c4ff92SAndroid Build Coastguard Worker , m_ParseUnsupported(false) 102*89c4ff92SAndroid Build Coastguard Worker , m_InferOutputShape(false) 103*89c4ff92SAndroid Build Coastguard Worker , m_EnableFastMath(false) 104*89c4ff92SAndroid Build Coastguard Worker , m_SaveCachedNetwork(false) 105*89c4ff92SAndroid Build Coastguard Worker , m_OutputDetailsToStdOut(false) 106*89c4ff92SAndroid Build Coastguard Worker , m_OutputDetailsOnlyToStdOut(false) 107*89c4ff92SAndroid Build Coastguard Worker , m_CachedNetworkFilePath("") 108*89c4ff92SAndroid Build Coastguard Worker , m_NumberOfThreads(0) 109*89c4ff92SAndroid Build Coastguard Worker , m_MLGOTuningFilePath("") 110*89c4ff92SAndroid Build Coastguard Worker , m_AsyncEnabled(false) 111*89c4ff92SAndroid Build Coastguard Worker , m_ThreadPoolSize(0) 112*89c4ff92SAndroid Build Coastguard Worker , m_ImportInputsIfAligned(false) 113*89c4ff92SAndroid Build Coastguard Worker {} 114*89c4ff92SAndroid Build Coastguard Worker }; 115*89c4ff92SAndroid Build Coastguard Worker 116*89c4ff92SAndroid Build Coastguard Worker } // namespace InferenceModelInternal 117*89c4ff92SAndroid Build Coastguard Worker 118*89c4ff92SAndroid Build Coastguard Worker template <typename IParser> 119*89c4ff92SAndroid Build Coastguard Worker struct CreateNetworkImpl 120*89c4ff92SAndroid Build Coastguard Worker { 121*89c4ff92SAndroid Build Coastguard Worker public: 122*89c4ff92SAndroid Build Coastguard Worker using Params = InferenceModelInternal::Params; 123*89c4ff92SAndroid Build Coastguard Worker CreateCreateNetworkImpl124*89c4ff92SAndroid Build Coastguard Worker static armnn::INetworkPtr Create(const Params& params, 125*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& inputBindings, 126*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& outputBindings) 127*89c4ff92SAndroid Build Coastguard Worker { 128*89c4ff92SAndroid Build Coastguard Worker const std::string& modelPath = params.m_ModelPath; 129*89c4ff92SAndroid Build Coastguard Worker 130*89c4ff92SAndroid Build Coastguard Worker // Create a network from a file on disk 131*89c4ff92SAndroid Build Coastguard Worker auto parser(IParser::Create()); 132*89c4ff92SAndroid Build Coastguard Worker 133*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, armnn::TensorShape> inputShapes; 134*89c4ff92SAndroid Build Coastguard Worker if (!params.m_InputShapes.empty()) 135*89c4ff92SAndroid Build Coastguard Worker { 136*89c4ff92SAndroid Build Coastguard Worker const size_t numInputShapes = params.m_InputShapes.size(); 137*89c4ff92SAndroid Build Coastguard Worker const size_t numInputBindings = params.m_InputBindings.size(); 138*89c4ff92SAndroid Build Coastguard Worker if (numInputShapes < numInputBindings) 139*89c4ff92SAndroid Build Coastguard Worker { 140*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format( 141*89c4ff92SAndroid Build Coastguard Worker "Not every input has its tensor shape specified: expected={0}, got={1}", 142*89c4ff92SAndroid Build Coastguard Worker numInputBindings, numInputShapes)); 143*89c4ff92SAndroid Build Coastguard Worker } 144*89c4ff92SAndroid Build Coastguard Worker 145*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numInputShapes; i++) 146*89c4ff92SAndroid Build Coastguard Worker { 147*89c4ff92SAndroid Build Coastguard Worker inputShapes[params.m_InputBindings[i]] = params.m_InputShapes[i]; 148*89c4ff92SAndroid Build Coastguard Worker } 149*89c4ff92SAndroid Build Coastguard Worker } 150*89c4ff92SAndroid Build Coastguard Worker 151*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> requestedOutputs = params.m_OutputBindings; 152*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}}; 153*89c4ff92SAndroid Build Coastguard Worker 154*89c4ff92SAndroid Build Coastguard Worker { 155*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("Parsing"); 156*89c4ff92SAndroid Build Coastguard Worker // Handle text and binary input differently by calling the corresponding parser function 157*89c4ff92SAndroid Build Coastguard Worker network = (params.m_IsModelBinary ? 158*89c4ff92SAndroid Build Coastguard Worker parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes, requestedOutputs) : 159*89c4ff92SAndroid Build Coastguard Worker parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes, requestedOutputs)); 160*89c4ff92SAndroid Build Coastguard Worker } 161*89c4ff92SAndroid Build Coastguard Worker 162*89c4ff92SAndroid Build Coastguard Worker for (const std::string& inputLayerName : params.m_InputBindings) 163*89c4ff92SAndroid Build Coastguard Worker { 164*89c4ff92SAndroid Build Coastguard Worker inputBindings.push_back(parser->GetNetworkInputBindingInfo(inputLayerName)); 165*89c4ff92SAndroid Build Coastguard Worker } 166*89c4ff92SAndroid Build Coastguard Worker 167*89c4ff92SAndroid Build Coastguard Worker for (const std::string& outputLayerName : params.m_OutputBindings) 168*89c4ff92SAndroid Build Coastguard Worker { 169*89c4ff92SAndroid Build Coastguard Worker outputBindings.push_back(parser->GetNetworkOutputBindingInfo(outputLayerName)); 170*89c4ff92SAndroid Build Coastguard Worker } 171*89c4ff92SAndroid Build Coastguard Worker 172*89c4ff92SAndroid Build Coastguard Worker return network; 173*89c4ff92SAndroid Build Coastguard Worker } 174*89c4ff92SAndroid Build Coastguard Worker }; 175*89c4ff92SAndroid Build Coastguard Worker 176*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_SERIALIZER) 177*89c4ff92SAndroid Build Coastguard Worker template <> 178*89c4ff92SAndroid Build Coastguard Worker struct CreateNetworkImpl<armnnDeserializer::IDeserializer> 179*89c4ff92SAndroid Build Coastguard Worker { 180*89c4ff92SAndroid Build Coastguard Worker public: 181*89c4ff92SAndroid Build Coastguard Worker using IParser = armnnDeserializer::IDeserializer; 182*89c4ff92SAndroid Build Coastguard Worker using Params = InferenceModelInternal::Params; 183*89c4ff92SAndroid Build Coastguard Worker CreateCreateNetworkImpl184*89c4ff92SAndroid Build Coastguard Worker static armnn::INetworkPtr Create(const Params& params, 185*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& inputBindings, 186*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& outputBindings) 187*89c4ff92SAndroid Build Coastguard Worker { 188*89c4ff92SAndroid Build Coastguard Worker auto parser(IParser::Create()); 189*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(parser); 190*89c4ff92SAndroid Build Coastguard Worker 191*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}}; 192*89c4ff92SAndroid Build Coastguard Worker 193*89c4ff92SAndroid Build Coastguard Worker { 194*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("Parsing"); 195*89c4ff92SAndroid Build Coastguard Worker 196*89c4ff92SAndroid Build Coastguard Worker std::error_code errorCode; 197*89c4ff92SAndroid Build Coastguard Worker fs::path pathToFile(params.m_ModelPath); 198*89c4ff92SAndroid Build Coastguard Worker if (!fs::exists(pathToFile, errorCode)) 199*89c4ff92SAndroid Build Coastguard Worker { 200*89c4ff92SAndroid Build Coastguard Worker throw armnn::FileNotFoundException(fmt::format("Cannot find the file ({0}) errorCode: {1} {2}", 201*89c4ff92SAndroid Build Coastguard Worker params.m_ModelPath, 202*89c4ff92SAndroid Build Coastguard Worker errorCode.message(), 203*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString())); 204*89c4ff92SAndroid Build Coastguard Worker } 205*89c4ff92SAndroid Build Coastguard Worker std::ifstream file(params.m_ModelPath, std::ios::binary); 206*89c4ff92SAndroid Build Coastguard Worker 207*89c4ff92SAndroid Build Coastguard Worker network = parser->CreateNetworkFromBinary(file); 208*89c4ff92SAndroid Build Coastguard Worker } 209*89c4ff92SAndroid Build Coastguard Worker 210*89c4ff92SAndroid Build Coastguard Worker unsigned int subgraphId = armnn::numeric_cast<unsigned int>(params.m_SubgraphId); 211*89c4ff92SAndroid Build Coastguard Worker 212*89c4ff92SAndroid Build Coastguard Worker for (const std::string& inputLayerName : params.m_InputBindings) 213*89c4ff92SAndroid Build Coastguard Worker { 214*89c4ff92SAndroid Build Coastguard Worker armnnDeserializer::BindingPointInfo inputBinding = 215*89c4ff92SAndroid Build Coastguard Worker parser->GetNetworkInputBindingInfo(subgraphId, inputLayerName); 216*89c4ff92SAndroid Build Coastguard Worker inputBindings.push_back(std::make_pair(inputBinding.m_BindingId, inputBinding.m_TensorInfo)); 217*89c4ff92SAndroid Build Coastguard Worker } 218*89c4ff92SAndroid Build Coastguard Worker 219*89c4ff92SAndroid Build Coastguard Worker for (const std::string& outputLayerName : params.m_OutputBindings) 220*89c4ff92SAndroid Build Coastguard Worker { 221*89c4ff92SAndroid Build Coastguard Worker armnnDeserializer::BindingPointInfo outputBinding = 222*89c4ff92SAndroid Build Coastguard Worker parser->GetNetworkOutputBindingInfo(subgraphId, outputLayerName); 223*89c4ff92SAndroid Build Coastguard Worker outputBindings.push_back(std::make_pair(outputBinding.m_BindingId, outputBinding.m_TensorInfo)); 224*89c4ff92SAndroid Build Coastguard Worker } 225*89c4ff92SAndroid Build Coastguard Worker 226*89c4ff92SAndroid Build Coastguard Worker return network; 227*89c4ff92SAndroid Build Coastguard Worker } 228*89c4ff92SAndroid Build Coastguard Worker }; 229*89c4ff92SAndroid Build Coastguard Worker #endif 230*89c4ff92SAndroid Build Coastguard Worker 231*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER) 232*89c4ff92SAndroid Build Coastguard Worker template <> 233*89c4ff92SAndroid Build Coastguard Worker struct CreateNetworkImpl<armnnTfLiteParser::ITfLiteParser> 234*89c4ff92SAndroid Build Coastguard Worker { 235*89c4ff92SAndroid Build Coastguard Worker public: 236*89c4ff92SAndroid Build Coastguard Worker using IParser = armnnTfLiteParser::ITfLiteParser; 237*89c4ff92SAndroid Build Coastguard Worker using Params = InferenceModelInternal::Params; 238*89c4ff92SAndroid Build Coastguard Worker CreateCreateNetworkImpl239*89c4ff92SAndroid Build Coastguard Worker static armnn::INetworkPtr Create(const Params& params, 240*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& inputBindings, 241*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& outputBindings) 242*89c4ff92SAndroid Build Coastguard Worker { 243*89c4ff92SAndroid Build Coastguard Worker const std::string& modelPath = params.m_ModelPath; 244*89c4ff92SAndroid Build Coastguard Worker 245*89c4ff92SAndroid Build Coastguard Worker // Create a network from a file on disk 246*89c4ff92SAndroid Build Coastguard Worker IParser::TfLiteParserOptions options; 247*89c4ff92SAndroid Build Coastguard Worker options.m_AllowExpandedDims = params.m_AllowExpandedDims; 248*89c4ff92SAndroid Build Coastguard Worker options.m_StandInLayerForUnsupported = params.m_ParseUnsupported; 249*89c4ff92SAndroid Build Coastguard Worker options.m_InferAndValidate = params.m_InferOutputShape; 250*89c4ff92SAndroid Build Coastguard Worker auto parser(IParser::Create(options)); 251*89c4ff92SAndroid Build Coastguard Worker 252*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}}; 253*89c4ff92SAndroid Build Coastguard Worker 254*89c4ff92SAndroid Build Coastguard Worker { 255*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("Parsing"); 256*89c4ff92SAndroid Build Coastguard Worker network = parser->CreateNetworkFromBinaryFile(modelPath.c_str()); 257*89c4ff92SAndroid Build Coastguard Worker } 258*89c4ff92SAndroid Build Coastguard Worker 259*89c4ff92SAndroid Build Coastguard Worker for (const std::string& inputLayerName : params.m_InputBindings) 260*89c4ff92SAndroid Build Coastguard Worker { 261*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo inputBinding = 262*89c4ff92SAndroid Build Coastguard Worker parser->GetNetworkInputBindingInfo(params.m_SubgraphId, inputLayerName); 263*89c4ff92SAndroid Build Coastguard Worker inputBindings.push_back(inputBinding); 264*89c4ff92SAndroid Build Coastguard Worker } 265*89c4ff92SAndroid Build Coastguard Worker 266*89c4ff92SAndroid Build Coastguard Worker for (const std::string& outputLayerName : params.m_OutputBindings) 267*89c4ff92SAndroid Build Coastguard Worker { 268*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo outputBinding = 269*89c4ff92SAndroid Build Coastguard Worker parser->GetNetworkOutputBindingInfo(params.m_SubgraphId, outputLayerName); 270*89c4ff92SAndroid Build Coastguard Worker outputBindings.push_back(outputBinding); 271*89c4ff92SAndroid Build Coastguard Worker } 272*89c4ff92SAndroid Build Coastguard Worker 273*89c4ff92SAndroid Build Coastguard Worker return network; 274*89c4ff92SAndroid Build Coastguard Worker } 275*89c4ff92SAndroid Build Coastguard Worker }; 276*89c4ff92SAndroid Build Coastguard Worker #endif 277*89c4ff92SAndroid Build Coastguard Worker 278*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER) 279*89c4ff92SAndroid Build Coastguard Worker template <> 280*89c4ff92SAndroid Build Coastguard Worker struct CreateNetworkImpl<armnnOnnxParser::IOnnxParser> 281*89c4ff92SAndroid Build Coastguard Worker { 282*89c4ff92SAndroid Build Coastguard Worker public: 283*89c4ff92SAndroid Build Coastguard Worker using IParser = armnnOnnxParser::IOnnxParser; 284*89c4ff92SAndroid Build Coastguard Worker using Params = InferenceModelInternal::Params; 285*89c4ff92SAndroid Build Coastguard Worker using BindingPointInfo = InferenceModelInternal::BindingPointInfo; 286*89c4ff92SAndroid Build Coastguard Worker CreateCreateNetworkImpl287*89c4ff92SAndroid Build Coastguard Worker static armnn::INetworkPtr Create(const Params& params, 288*89c4ff92SAndroid Build Coastguard Worker std::vector<BindingPointInfo>& inputBindings, 289*89c4ff92SAndroid Build Coastguard Worker std::vector<BindingPointInfo>& outputBindings) 290*89c4ff92SAndroid Build Coastguard Worker { 291*89c4ff92SAndroid Build Coastguard Worker const std::string& modelPath = params.m_ModelPath; 292*89c4ff92SAndroid Build Coastguard Worker 293*89c4ff92SAndroid Build Coastguard Worker // Create a network from a file on disk 294*89c4ff92SAndroid Build Coastguard Worker auto parser(IParser::Create()); 295*89c4ff92SAndroid Build Coastguard Worker 296*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}}; 297*89c4ff92SAndroid Build Coastguard Worker 298*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, armnn::TensorShape> inputShapes; 299*89c4ff92SAndroid Build Coastguard Worker if (!params.m_InputShapes.empty()) 300*89c4ff92SAndroid Build Coastguard Worker { 301*89c4ff92SAndroid Build Coastguard Worker const size_t numInputShapes = params.m_InputShapes.size(); 302*89c4ff92SAndroid Build Coastguard Worker const size_t numInputBindings = params.m_InputBindings.size(); 303*89c4ff92SAndroid Build Coastguard Worker if (numInputShapes < numInputBindings) 304*89c4ff92SAndroid Build Coastguard Worker { 305*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format( 306*89c4ff92SAndroid Build Coastguard Worker "Not every input has its tensor shape specified: expected={0}, got={1}", 307*89c4ff92SAndroid Build Coastguard Worker numInputBindings, numInputShapes)); 308*89c4ff92SAndroid Build Coastguard Worker } 309*89c4ff92SAndroid Build Coastguard Worker 310*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numInputShapes; i++) 311*89c4ff92SAndroid Build Coastguard Worker { 312*89c4ff92SAndroid Build Coastguard Worker inputShapes[params.m_InputBindings[i]] = params.m_InputShapes[i]; 313*89c4ff92SAndroid Build Coastguard Worker } 314*89c4ff92SAndroid Build Coastguard Worker 315*89c4ff92SAndroid Build Coastguard Worker { 316*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("Parsing"); 317*89c4ff92SAndroid Build Coastguard Worker network = (params.m_IsModelBinary ? 318*89c4ff92SAndroid Build Coastguard Worker parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes) : 319*89c4ff92SAndroid Build Coastguard Worker parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes)); 320*89c4ff92SAndroid Build Coastguard Worker } 321*89c4ff92SAndroid Build Coastguard Worker } 322*89c4ff92SAndroid Build Coastguard Worker 323*89c4ff92SAndroid Build Coastguard Worker else 324*89c4ff92SAndroid Build Coastguard Worker { 325*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("Parsing"); 326*89c4ff92SAndroid Build Coastguard Worker network = (params.m_IsModelBinary ? 327*89c4ff92SAndroid Build Coastguard Worker parser->CreateNetworkFromBinaryFile(modelPath.c_str()) : 328*89c4ff92SAndroid Build Coastguard Worker parser->CreateNetworkFromTextFile(modelPath.c_str())); 329*89c4ff92SAndroid Build Coastguard Worker } 330*89c4ff92SAndroid Build Coastguard Worker 331*89c4ff92SAndroid Build Coastguard Worker for (const std::string& inputLayerName : params.m_InputBindings) 332*89c4ff92SAndroid Build Coastguard Worker { 333*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo inputBinding = parser->GetNetworkInputBindingInfo(inputLayerName); 334*89c4ff92SAndroid Build Coastguard Worker inputBindings.push_back(inputBinding); 335*89c4ff92SAndroid Build Coastguard Worker } 336*89c4ff92SAndroid Build Coastguard Worker 337*89c4ff92SAndroid Build Coastguard Worker for (const std::string& outputLayerName : params.m_OutputBindings) 338*89c4ff92SAndroid Build Coastguard Worker { 339*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo outputBinding = parser->GetNetworkOutputBindingInfo(outputLayerName); 340*89c4ff92SAndroid Build Coastguard Worker outputBindings.push_back(outputBinding); 341*89c4ff92SAndroid Build Coastguard Worker } 342*89c4ff92SAndroid Build Coastguard Worker 343*89c4ff92SAndroid Build Coastguard Worker return network; 344*89c4ff92SAndroid Build Coastguard Worker } 345*89c4ff92SAndroid Build Coastguard Worker }; 346*89c4ff92SAndroid Build Coastguard Worker #endif 347*89c4ff92SAndroid Build Coastguard Worker 348*89c4ff92SAndroid Build Coastguard Worker 349*89c4ff92SAndroid Build Coastguard Worker 350*89c4ff92SAndroid Build Coastguard Worker template <typename IParser, typename TDataType> 351*89c4ff92SAndroid Build Coastguard Worker class InferenceModel 352*89c4ff92SAndroid Build Coastguard Worker { 353*89c4ff92SAndroid Build Coastguard Worker public: 354*89c4ff92SAndroid Build Coastguard Worker using DataType = TDataType; 355*89c4ff92SAndroid Build Coastguard Worker using Params = InferenceModelInternal::Params; 356*89c4ff92SAndroid Build Coastguard Worker using QuantizationParams = InferenceModelInternal::QuantizationParams; 357*89c4ff92SAndroid Build Coastguard Worker 358*89c4ff92SAndroid Build Coastguard Worker 359*89c4ff92SAndroid Build Coastguard Worker struct CommandLineOptions 360*89c4ff92SAndroid Build Coastguard Worker { 361*89c4ff92SAndroid Build Coastguard Worker std::string m_ModelDir; 362*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> m_ComputeDevices; 363*89c4ff92SAndroid Build Coastguard Worker std::string m_DynamicBackendsPath; 364*89c4ff92SAndroid Build Coastguard Worker bool m_VisualizePostOptimizationModel; 365*89c4ff92SAndroid Build Coastguard Worker bool m_EnableFp16TurboMode; 366*89c4ff92SAndroid Build Coastguard Worker bool m_EnableBf16TurboMode; 367*89c4ff92SAndroid Build Coastguard Worker std::string m_Labels; 368*89c4ff92SAndroid Build Coastguard Worker GetComputeDevicesAsBackendIdsInferenceModel::CommandLineOptions369*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> GetComputeDevicesAsBackendIds() 370*89c4ff92SAndroid Build Coastguard Worker { 371*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backendIds; 372*89c4ff92SAndroid Build Coastguard Worker std::copy(m_ComputeDevices.begin(), m_ComputeDevices.end(), std::back_inserter(backendIds)); 373*89c4ff92SAndroid Build Coastguard Worker return backendIds; 374*89c4ff92SAndroid Build Coastguard Worker } 375*89c4ff92SAndroid Build Coastguard Worker }; 376*89c4ff92SAndroid Build Coastguard Worker AddCommandLineOptions(cxxopts::Options & options,CommandLineOptions & cLineOptions,std::vector<std::string> & required)377*89c4ff92SAndroid Build Coastguard Worker static void AddCommandLineOptions(cxxopts::Options& options, 378*89c4ff92SAndroid Build Coastguard Worker CommandLineOptions& cLineOptions, std::vector<std::string>& required) 379*89c4ff92SAndroid Build Coastguard Worker { 380*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::string> defaultComputes = { "CpuAcc", "CpuRef" }; 381*89c4ff92SAndroid Build Coastguard Worker 382*89c4ff92SAndroid Build Coastguard Worker const std::string backendsMessage = "Which device to run layers on by default. Possible choices: " 383*89c4ff92SAndroid Build Coastguard Worker + armnn::BackendRegistryInstance().GetBackendIdsAsString(); 384*89c4ff92SAndroid Build Coastguard Worker 385*89c4ff92SAndroid Build Coastguard Worker options 386*89c4ff92SAndroid Build Coastguard Worker .allow_unrecognised_options() 387*89c4ff92SAndroid Build Coastguard Worker .add_options() 388*89c4ff92SAndroid Build Coastguard Worker ("m,model-dir", "Path to directory containing model files (.prototxt/.tflite)", 389*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(cLineOptions.m_ModelDir)) 390*89c4ff92SAndroid Build Coastguard Worker ("c,compute", backendsMessage.c_str(), 391*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::vector<std::string>>(cLineOptions.m_ComputeDevices)->default_value("CpuRef")) 392*89c4ff92SAndroid Build Coastguard Worker ("b,dynamic-backends-path", 393*89c4ff92SAndroid Build Coastguard Worker "Path where to load any available dynamic backend from. " 394*89c4ff92SAndroid Build Coastguard Worker "If left empty (the default), dynamic backends will not be used.", 395*89c4ff92SAndroid Build Coastguard Worker cxxopts::value(cLineOptions.m_DynamicBackendsPath)) 396*89c4ff92SAndroid Build Coastguard Worker ("l,labels", 397*89c4ff92SAndroid Build Coastguard Worker "Text file containing one image filename - correct label pair per line, " 398*89c4ff92SAndroid Build Coastguard Worker "used to test the accuracy of the network.", cxxopts::value<std::string>(cLineOptions.m_Labels)) 399*89c4ff92SAndroid Build Coastguard Worker ("v,visualize-optimized-model", 400*89c4ff92SAndroid Build Coastguard Worker "Produce a dot file useful for visualizing the graph post optimization." 401*89c4ff92SAndroid Build Coastguard Worker "The file will have the same name as the model with the .dot extention.", 402*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<bool>(cLineOptions.m_VisualizePostOptimizationModel)->default_value("false")) 403*89c4ff92SAndroid Build Coastguard Worker ("fp16-turbo-mode", 404*89c4ff92SAndroid Build Coastguard Worker "If this option is enabled FP32 layers, weights and biases will be converted " 405*89c4ff92SAndroid Build Coastguard Worker "to FP16 where the backend supports it.", 406*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<bool>(cLineOptions.m_EnableFp16TurboMode)->default_value("false")) 407*89c4ff92SAndroid Build Coastguard Worker ("bf16-turbo-mode", 408*89c4ff92SAndroid Build Coastguard Worker "If this option is enabled FP32 layers, weights and biases will be converted " 409*89c4ff92SAndroid Build Coastguard Worker "to BF16 where the backend supports it.", 410*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<bool>(cLineOptions.m_EnableBf16TurboMode)->default_value("false")); 411*89c4ff92SAndroid Build Coastguard Worker 412*89c4ff92SAndroid Build Coastguard Worker required.emplace_back("model-dir"); 413*89c4ff92SAndroid Build Coastguard Worker } 414*89c4ff92SAndroid Build Coastguard Worker InferenceModel(const Params & params,bool enableProfiling,const std::string & dynamicBackendsPath,const std::shared_ptr<armnn::IRuntime> & runtime=nullptr)415*89c4ff92SAndroid Build Coastguard Worker InferenceModel(const Params& params, 416*89c4ff92SAndroid Build Coastguard Worker bool enableProfiling, 417*89c4ff92SAndroid Build Coastguard Worker const std::string& dynamicBackendsPath, 418*89c4ff92SAndroid Build Coastguard Worker const std::shared_ptr<armnn::IRuntime>& runtime = nullptr) 419*89c4ff92SAndroid Build Coastguard Worker : m_EnableProfiling(enableProfiling), 420*89c4ff92SAndroid Build Coastguard Worker m_ProfilingDetailsMethod(armnn::ProfilingDetailsMethod::Undefined), 421*89c4ff92SAndroid Build Coastguard Worker m_DynamicBackendsPath(dynamicBackendsPath), 422*89c4ff92SAndroid Build Coastguard Worker m_ImportInputsIfAligned(params.m_ImportInputsIfAligned) 423*89c4ff92SAndroid Build Coastguard Worker { 424*89c4ff92SAndroid Build Coastguard Worker if (runtime) 425*89c4ff92SAndroid Build Coastguard Worker { 426*89c4ff92SAndroid Build Coastguard Worker m_Runtime = runtime; 427*89c4ff92SAndroid Build Coastguard Worker } 428*89c4ff92SAndroid Build Coastguard Worker else 429*89c4ff92SAndroid Build Coastguard Worker { 430*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime::CreationOptions options; 431*89c4ff92SAndroid Build Coastguard Worker options.m_EnableGpuProfiling = m_EnableProfiling; 432*89c4ff92SAndroid Build Coastguard Worker options.m_DynamicBackendsPath = m_DynamicBackendsPath; 433*89c4ff92SAndroid Build Coastguard Worker m_Runtime = armnn::IRuntime::Create(options); 434*89c4ff92SAndroid Build Coastguard Worker } 435*89c4ff92SAndroid Build Coastguard Worker 436*89c4ff92SAndroid Build Coastguard Worker // Configure the Profiler if the the profiling details are opted for 437*89c4ff92SAndroid Build Coastguard Worker if (params.m_OutputDetailsOnlyToStdOut) 438*89c4ff92SAndroid Build Coastguard Worker m_ProfilingDetailsMethod = armnn::ProfilingDetailsMethod::DetailsOnly; 439*89c4ff92SAndroid Build Coastguard Worker else if (params.m_OutputDetailsToStdOut) 440*89c4ff92SAndroid Build Coastguard Worker m_ProfilingDetailsMethod = armnn::ProfilingDetailsMethod::DetailsWithEvents; 441*89c4ff92SAndroid Build Coastguard Worker 442*89c4ff92SAndroid Build Coastguard Worker std::string invalidBackends; 443*89c4ff92SAndroid Build Coastguard Worker if (!CheckRequestedBackendsAreValid(params.m_ComputeDevices, armnn::Optional<std::string&>(invalidBackends))) 444*89c4ff92SAndroid Build Coastguard Worker { 445*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("Some backend IDs are invalid: " + invalidBackends); 446*89c4ff92SAndroid Build Coastguard Worker } 447*89c4ff92SAndroid Build Coastguard Worker 448*89c4ff92SAndroid Build Coastguard Worker armnn::IOptimizedNetworkPtr optNet{nullptr, [](armnn::IOptimizedNetwork*){}}; 449*89c4ff92SAndroid Build Coastguard Worker { 450*89c4ff92SAndroid Build Coastguard Worker const auto parsing_start_time = armnn::GetTimeNow(); 451*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = CreateNetworkImpl<IParser>::Create(params, m_InputBindings, m_OutputBindings); 452*89c4ff92SAndroid Build Coastguard Worker 453*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Network parsing time: " << std::setprecision(2) 454*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(parsing_start_time).count() << " ms."; 455*89c4ff92SAndroid Build Coastguard Worker 456*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("Optimizing"); 457*89c4ff92SAndroid Build Coastguard Worker 458*89c4ff92SAndroid Build Coastguard Worker armnn::OptimizerOptionsOpaque options; 459*89c4ff92SAndroid Build Coastguard Worker options.SetReduceFp32ToFp16(params.m_EnableFp16TurboMode); 460*89c4ff92SAndroid Build Coastguard Worker options.SetDebugEnabled(params.m_PrintIntermediateLayers); 461*89c4ff92SAndroid Build Coastguard Worker options.SetDebugToFileEnabled(params.m_PrintIntermediateLayersToFile); 462*89c4ff92SAndroid Build Coastguard Worker options.SetShapeInferenceMethod(params.m_InferOutputShape ? 463*89c4ff92SAndroid Build Coastguard Worker armnn::ShapeInferenceMethod::InferAndValidate : armnn::ShapeInferenceMethod::ValidateOnly); 464*89c4ff92SAndroid Build Coastguard Worker options.SetProfilingEnabled(m_EnableProfiling); 465*89c4ff92SAndroid Build Coastguard Worker 466*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions gpuAcc("GpuAcc", 467*89c4ff92SAndroid Build Coastguard Worker { 468*89c4ff92SAndroid Build Coastguard Worker { "FastMathEnabled", params.m_EnableFastMath }, 469*89c4ff92SAndroid Build Coastguard Worker { "SaveCachedNetwork", params.m_SaveCachedNetwork }, 470*89c4ff92SAndroid Build Coastguard Worker { "CachedNetworkFilePath", params.m_CachedNetworkFilePath }, 471*89c4ff92SAndroid Build Coastguard Worker { "MLGOTuningFilePath", params.m_MLGOTuningFilePath } 472*89c4ff92SAndroid Build Coastguard Worker }); 473*89c4ff92SAndroid Build Coastguard Worker 474*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions cpuAcc("CpuAcc", 475*89c4ff92SAndroid Build Coastguard Worker { 476*89c4ff92SAndroid Build Coastguard Worker { "FastMathEnabled", params.m_EnableFastMath }, 477*89c4ff92SAndroid Build Coastguard Worker { "NumberOfThreads", params.m_NumberOfThreads } 478*89c4ff92SAndroid Build Coastguard Worker }); 479*89c4ff92SAndroid Build Coastguard Worker options.AddModelOption(gpuAcc); 480*89c4ff92SAndroid Build Coastguard Worker options.AddModelOption(cpuAcc); 481*89c4ff92SAndroid Build Coastguard Worker 482*89c4ff92SAndroid Build Coastguard Worker const auto optimization_start_time = armnn::GetTimeNow(); 483*89c4ff92SAndroid Build Coastguard Worker optNet = armnn::Optimize(*network, params.m_ComputeDevices, m_Runtime->GetDeviceSpec(), options); 484*89c4ff92SAndroid Build Coastguard Worker 485*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Optimization time: " << std::setprecision(2) 486*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(optimization_start_time).count() << " ms."; 487*89c4ff92SAndroid Build Coastguard Worker 488*89c4ff92SAndroid Build Coastguard Worker if (!optNet) 489*89c4ff92SAndroid Build Coastguard Worker { 490*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("Optimize returned nullptr"); 491*89c4ff92SAndroid Build Coastguard Worker } 492*89c4ff92SAndroid Build Coastguard Worker 493*89c4ff92SAndroid Build Coastguard Worker 494*89c4ff92SAndroid Build Coastguard Worker } 495*89c4ff92SAndroid Build Coastguard Worker 496*89c4ff92SAndroid Build Coastguard Worker if (params.m_VisualizePostOptimizationModel) 497*89c4ff92SAndroid Build Coastguard Worker { 498*89c4ff92SAndroid Build Coastguard Worker fs::path filename = params.m_ModelPath; 499*89c4ff92SAndroid Build Coastguard Worker filename.replace_extension("dot"); 500*89c4ff92SAndroid Build Coastguard Worker std::fstream file(filename.c_str(), std::ios_base::out); 501*89c4ff92SAndroid Build Coastguard Worker optNet->SerializeToDot(file); 502*89c4ff92SAndroid Build Coastguard Worker } 503*89c4ff92SAndroid Build Coastguard Worker 504*89c4ff92SAndroid Build Coastguard Worker armnn::Status ret; 505*89c4ff92SAndroid Build Coastguard Worker { 506*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("LoadNetwork"); 507*89c4ff92SAndroid Build Coastguard Worker 508*89c4ff92SAndroid Build Coastguard Worker const auto loading_start_time = armnn::GetTimeNow(); 509*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkProperties networkProperties(params.m_AsyncEnabled, 510*89c4ff92SAndroid Build Coastguard Worker armnn::MemorySource::Undefined, 511*89c4ff92SAndroid Build Coastguard Worker armnn::MemorySource::Undefined, 512*89c4ff92SAndroid Build Coastguard Worker enableProfiling, 513*89c4ff92SAndroid Build Coastguard Worker m_ProfilingDetailsMethod); 514*89c4ff92SAndroid Build Coastguard Worker std::string errorMessage; 515*89c4ff92SAndroid Build Coastguard Worker ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties); 516*89c4ff92SAndroid Build Coastguard Worker 517*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Network loading time: " << std::setprecision(2) 518*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(loading_start_time).count() << " ms."; 519*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS) 520*89c4ff92SAndroid Build Coastguard Worker if (params.m_AsyncEnabled && params.m_ThreadPoolSize > 0) 521*89c4ff92SAndroid Build Coastguard Worker { 522*89c4ff92SAndroid Build Coastguard Worker std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles; 523*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < params.m_ThreadPoolSize; ++i) 524*89c4ff92SAndroid Build Coastguard Worker { 525*89c4ff92SAndroid Build Coastguard Worker memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier)); 526*89c4ff92SAndroid Build Coastguard Worker } 527*89c4ff92SAndroid Build Coastguard Worker 528*89c4ff92SAndroid Build Coastguard Worker m_Threadpool = std::make_unique<armnn::Threadpool>(params.m_ThreadPoolSize, 529*89c4ff92SAndroid Build Coastguard Worker m_Runtime.get(), 530*89c4ff92SAndroid Build Coastguard Worker memHandles); 531*89c4ff92SAndroid Build Coastguard Worker } 532*89c4ff92SAndroid Build Coastguard Worker #endif 533*89c4ff92SAndroid Build Coastguard Worker } 534*89c4ff92SAndroid Build Coastguard Worker 535*89c4ff92SAndroid Build Coastguard Worker if (ret == armnn::Status::Failure) 536*89c4ff92SAndroid Build Coastguard Worker { 537*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("IRuntime::LoadNetwork failed"); 538*89c4ff92SAndroid Build Coastguard Worker } 539*89c4ff92SAndroid Build Coastguard Worker } 540*89c4ff92SAndroid Build Coastguard Worker CheckInputIndexIsValid(unsigned int inputIndex) const541*89c4ff92SAndroid Build Coastguard Worker void CheckInputIndexIsValid(unsigned int inputIndex) const 542*89c4ff92SAndroid Build Coastguard Worker { 543*89c4ff92SAndroid Build Coastguard Worker if (m_InputBindings.size() < inputIndex + 1) 544*89c4ff92SAndroid Build Coastguard Worker { 545*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format("Input index out of range: {}", inputIndex)); 546*89c4ff92SAndroid Build Coastguard Worker } 547*89c4ff92SAndroid Build Coastguard Worker } 548*89c4ff92SAndroid Build Coastguard Worker CheckOutputIndexIsValid(unsigned int outputIndex) const549*89c4ff92SAndroid Build Coastguard Worker void CheckOutputIndexIsValid(unsigned int outputIndex) const 550*89c4ff92SAndroid Build Coastguard Worker { 551*89c4ff92SAndroid Build Coastguard Worker if (m_OutputBindings.size() < outputIndex + 1) 552*89c4ff92SAndroid Build Coastguard Worker { 553*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format("Output index out of range: {}", outputIndex)); 554*89c4ff92SAndroid Build Coastguard Worker } 555*89c4ff92SAndroid Build Coastguard Worker } 556*89c4ff92SAndroid Build Coastguard Worker GetInputSize(unsigned int inputIndex=0u) const557*89c4ff92SAndroid Build Coastguard Worker unsigned int GetInputSize(unsigned int inputIndex = 0u) const 558*89c4ff92SAndroid Build Coastguard Worker { 559*89c4ff92SAndroid Build Coastguard Worker CheckInputIndexIsValid(inputIndex); 560*89c4ff92SAndroid Build Coastguard Worker return m_InputBindings[inputIndex].second.GetNumElements(); 561*89c4ff92SAndroid Build Coastguard Worker } 562*89c4ff92SAndroid Build Coastguard Worker GetOutputSize(unsigned int outputIndex=0u) const563*89c4ff92SAndroid Build Coastguard Worker unsigned int GetOutputSize(unsigned int outputIndex = 0u) const 564*89c4ff92SAndroid Build Coastguard Worker { 565*89c4ff92SAndroid Build Coastguard Worker CheckOutputIndexIsValid(outputIndex); 566*89c4ff92SAndroid Build Coastguard Worker return m_OutputBindings[outputIndex].second.GetNumElements(); 567*89c4ff92SAndroid Build Coastguard Worker } 568*89c4ff92SAndroid Build Coastguard Worker Run(const std::vector<armnnUtils::TContainer> & inputContainers,std::vector<armnnUtils::TContainer> & outputContainers)569*89c4ff92SAndroid Build Coastguard Worker std::chrono::duration<double, std::milli> Run( 570*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnnUtils::TContainer>& inputContainers, 571*89c4ff92SAndroid Build Coastguard Worker std::vector<armnnUtils::TContainer>& outputContainers) 572*89c4ff92SAndroid Build Coastguard Worker { 573*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < outputContainers.size(); ++i) 574*89c4ff92SAndroid Build Coastguard Worker { 575*89c4ff92SAndroid Build Coastguard Worker const unsigned int expectedOutputDataSize = GetOutputSize(i); 576*89c4ff92SAndroid Build Coastguard Worker 577*89c4ff92SAndroid Build Coastguard Worker mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value) 578*89c4ff92SAndroid Build Coastguard Worker { 579*89c4ff92SAndroid Build Coastguard Worker const unsigned int actualOutputDataSize = armnn::numeric_cast<unsigned int>(value.size()); 580*89c4ff92SAndroid Build Coastguard Worker if (actualOutputDataSize < expectedOutputDataSize) 581*89c4ff92SAndroid Build Coastguard Worker { 582*89c4ff92SAndroid Build Coastguard Worker unsigned int outputIndex = i; 583*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception( 584*89c4ff92SAndroid Build Coastguard Worker fmt::format("Not enough data for output #{0}: expected " 585*89c4ff92SAndroid Build Coastguard Worker "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize)); 586*89c4ff92SAndroid Build Coastguard Worker } 587*89c4ff92SAndroid Build Coastguard Worker }, 588*89c4ff92SAndroid Build Coastguard Worker outputContainers[i]); 589*89c4ff92SAndroid Build Coastguard Worker } 590*89c4ff92SAndroid Build Coastguard Worker 591*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier); 592*89c4ff92SAndroid Build Coastguard Worker 593*89c4ff92SAndroid Build Coastguard Worker // Start timer to record inference time in EnqueueWorkload (in milliseconds) 594*89c4ff92SAndroid Build Coastguard Worker const auto start_time = armnn::GetTimeNow(); 595*89c4ff92SAndroid Build Coastguard Worker 596*89c4ff92SAndroid Build Coastguard Worker armnn::Status ret; 597*89c4ff92SAndroid Build Coastguard Worker if (m_ImportInputsIfAligned) 598*89c4ff92SAndroid Build Coastguard Worker { 599*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::ImportedInputId> importedInputIds = m_Runtime->ImportInputs( 600*89c4ff92SAndroid Build Coastguard Worker m_NetworkIdentifier, MakeInputTensors(inputContainers), armnn::MemorySource::Malloc); 601*89c4ff92SAndroid Build Coastguard Worker 602*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::ImportedOutputId> importedOutputIds = m_Runtime->ImportOutputs( 603*89c4ff92SAndroid Build Coastguard Worker m_NetworkIdentifier, MakeOutputTensors(outputContainers), armnn::MemorySource::Malloc); 604*89c4ff92SAndroid Build Coastguard Worker 605*89c4ff92SAndroid Build Coastguard Worker ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier, 606*89c4ff92SAndroid Build Coastguard Worker MakeInputTensors(inputContainers), 607*89c4ff92SAndroid Build Coastguard Worker MakeOutputTensors(outputContainers), 608*89c4ff92SAndroid Build Coastguard Worker importedInputIds, 609*89c4ff92SAndroid Build Coastguard Worker importedOutputIds); 610*89c4ff92SAndroid Build Coastguard Worker } 611*89c4ff92SAndroid Build Coastguard Worker else 612*89c4ff92SAndroid Build Coastguard Worker { 613*89c4ff92SAndroid Build Coastguard Worker ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier, 614*89c4ff92SAndroid Build Coastguard Worker MakeInputTensors(inputContainers), 615*89c4ff92SAndroid Build Coastguard Worker MakeOutputTensors(outputContainers)); 616*89c4ff92SAndroid Build Coastguard Worker } 617*89c4ff92SAndroid Build Coastguard Worker const auto duration = armnn::GetTimeDuration(start_time); 618*89c4ff92SAndroid Build Coastguard Worker 619*89c4ff92SAndroid Build Coastguard Worker // if profiling is enabled print out the results 620*89c4ff92SAndroid Build Coastguard Worker if (profiler && profiler->IsProfilingEnabled()) 621*89c4ff92SAndroid Build Coastguard Worker { 622*89c4ff92SAndroid Build Coastguard Worker profiler->Print(std::cout); 623*89c4ff92SAndroid Build Coastguard Worker } 624*89c4ff92SAndroid Build Coastguard Worker 625*89c4ff92SAndroid Build Coastguard Worker if (ret == armnn::Status::Failure) 626*89c4ff92SAndroid Build Coastguard Worker { 627*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("IRuntime::EnqueueWorkload failed"); 628*89c4ff92SAndroid Build Coastguard Worker } 629*89c4ff92SAndroid Build Coastguard Worker else 630*89c4ff92SAndroid Build Coastguard Worker { 631*89c4ff92SAndroid Build Coastguard Worker return duration; 632*89c4ff92SAndroid Build Coastguard Worker } 633*89c4ff92SAndroid Build Coastguard Worker } 634*89c4ff92SAndroid Build Coastguard Worker RunAsync(armnn::experimental::IWorkingMemHandle & workingMemHandleRef,const std::vector<armnnUtils::TContainer> & inputContainers,std::vector<armnnUtils::TContainer> & outputContainers,unsigned int inferenceID)635*89c4ff92SAndroid Build Coastguard Worker std::tuple<unsigned int, std::chrono::duration<double, std::milli>> RunAsync( 636*89c4ff92SAndroid Build Coastguard Worker armnn::experimental::IWorkingMemHandle& workingMemHandleRef, 637*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnnUtils::TContainer>& inputContainers, 638*89c4ff92SAndroid Build Coastguard Worker std::vector<armnnUtils::TContainer>& outputContainers, 639*89c4ff92SAndroid Build Coastguard Worker unsigned int inferenceID) 640*89c4ff92SAndroid Build Coastguard Worker { 641*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < outputContainers.size(); ++i) 642*89c4ff92SAndroid Build Coastguard Worker { 643*89c4ff92SAndroid Build Coastguard Worker const unsigned int expectedOutputDataSize = GetOutputSize(i); 644*89c4ff92SAndroid Build Coastguard Worker 645*89c4ff92SAndroid Build Coastguard Worker mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value) 646*89c4ff92SAndroid Build Coastguard Worker { 647*89c4ff92SAndroid Build Coastguard Worker const unsigned int actualOutputDataSize = armnn::numeric_cast<unsigned int>(value.size()); 648*89c4ff92SAndroid Build Coastguard Worker if (actualOutputDataSize < expectedOutputDataSize) 649*89c4ff92SAndroid Build Coastguard Worker { 650*89c4ff92SAndroid Build Coastguard Worker unsigned int outputIndex = i; 651*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception( 652*89c4ff92SAndroid Build Coastguard Worker fmt::format("Not enough data for output #{0}: expected " 653*89c4ff92SAndroid Build Coastguard Worker "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize)); 654*89c4ff92SAndroid Build Coastguard Worker } 655*89c4ff92SAndroid Build Coastguard Worker }, 656*89c4ff92SAndroid Build Coastguard Worker outputContainers[i]); 657*89c4ff92SAndroid Build Coastguard Worker } 658*89c4ff92SAndroid Build Coastguard Worker 659*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier); 660*89c4ff92SAndroid Build Coastguard Worker 661*89c4ff92SAndroid Build Coastguard Worker // Start timer to record inference time in EnqueueWorkload (in milliseconds) 662*89c4ff92SAndroid Build Coastguard Worker const auto start_time = armnn::GetTimeNow(); 663*89c4ff92SAndroid Build Coastguard Worker 664*89c4ff92SAndroid Build Coastguard Worker armnn::Status ret = m_Runtime->Execute(workingMemHandleRef, 665*89c4ff92SAndroid Build Coastguard Worker MakeInputTensors(inputContainers), 666*89c4ff92SAndroid Build Coastguard Worker MakeOutputTensors(outputContainers)); 667*89c4ff92SAndroid Build Coastguard Worker 668*89c4ff92SAndroid Build Coastguard Worker const auto duration = armnn::GetTimeDuration(start_time); 669*89c4ff92SAndroid Build Coastguard Worker 670*89c4ff92SAndroid Build Coastguard Worker // if profiling is enabled print out the results 671*89c4ff92SAndroid Build Coastguard Worker if (profiler && profiler->IsProfilingEnabled()) 672*89c4ff92SAndroid Build Coastguard Worker { 673*89c4ff92SAndroid Build Coastguard Worker profiler->Print(std::cout); 674*89c4ff92SAndroid Build Coastguard Worker } 675*89c4ff92SAndroid Build Coastguard Worker 676*89c4ff92SAndroid Build Coastguard Worker if (ret == armnn::Status::Failure) 677*89c4ff92SAndroid Build Coastguard Worker { 678*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception( 679*89c4ff92SAndroid Build Coastguard Worker fmt::format("IRuntime::Execute asynchronously failed for network #{0} on inference #{1}", 680*89c4ff92SAndroid Build Coastguard Worker m_NetworkIdentifier, inferenceID)); 681*89c4ff92SAndroid Build Coastguard Worker } 682*89c4ff92SAndroid Build Coastguard Worker else 683*89c4ff92SAndroid Build Coastguard Worker { 684*89c4ff92SAndroid Build Coastguard Worker return std::make_tuple(inferenceID, duration); 685*89c4ff92SAndroid Build Coastguard Worker } 686*89c4ff92SAndroid Build Coastguard Worker } 687*89c4ff92SAndroid Build Coastguard Worker RunAsync(const std::vector<armnnUtils::TContainer> & inputContainers,std::vector<armnnUtils::TContainer> & outputContainers,std::shared_ptr<armnn::IAsyncExecutionCallback> cb)688*89c4ff92SAndroid Build Coastguard Worker void RunAsync(const std::vector<armnnUtils::TContainer>& inputContainers, 689*89c4ff92SAndroid Build Coastguard Worker std::vector<armnnUtils::TContainer>& outputContainers, 690*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<armnn::IAsyncExecutionCallback> cb) 691*89c4ff92SAndroid Build Coastguard Worker { 692*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS) 693*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < outputContainers.size(); ++i) 694*89c4ff92SAndroid Build Coastguard Worker { 695*89c4ff92SAndroid Build Coastguard Worker const unsigned int expectedOutputDataSize = GetOutputSize(i); 696*89c4ff92SAndroid Build Coastguard Worker 697*89c4ff92SAndroid Build Coastguard Worker mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value) 698*89c4ff92SAndroid Build Coastguard Worker { 699*89c4ff92SAndroid Build Coastguard Worker const unsigned int actualOutputDataSize = armnn::numeric_cast<unsigned int>(value.size()); 700*89c4ff92SAndroid Build Coastguard Worker if (actualOutputDataSize < expectedOutputDataSize) 701*89c4ff92SAndroid Build Coastguard Worker { 702*89c4ff92SAndroid Build Coastguard Worker unsigned int outputIndex = i; 703*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception( 704*89c4ff92SAndroid Build Coastguard Worker fmt::format("Not enough data for output #{0}: expected " 705*89c4ff92SAndroid Build Coastguard Worker "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize)); 706*89c4ff92SAndroid Build Coastguard Worker } 707*89c4ff92SAndroid Build Coastguard Worker }, 708*89c4ff92SAndroid Build Coastguard Worker outputContainers[i]); 709*89c4ff92SAndroid Build Coastguard Worker } 710*89c4ff92SAndroid Build Coastguard Worker 711*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier); 712*89c4ff92SAndroid Build Coastguard Worker 713*89c4ff92SAndroid Build Coastguard Worker m_Threadpool->Schedule(m_NetworkIdentifier, 714*89c4ff92SAndroid Build Coastguard Worker MakeInputTensors(inputContainers), 715*89c4ff92SAndroid Build Coastguard Worker MakeOutputTensors(outputContainers), 716*89c4ff92SAndroid Build Coastguard Worker armnn::QosExecPriority::Medium, 717*89c4ff92SAndroid Build Coastguard Worker cb); 718*89c4ff92SAndroid Build Coastguard Worker 719*89c4ff92SAndroid Build Coastguard Worker // if profiling is enabled print out the results 720*89c4ff92SAndroid Build Coastguard Worker if (profiler && profiler->IsProfilingEnabled()) 721*89c4ff92SAndroid Build Coastguard Worker { 722*89c4ff92SAndroid Build Coastguard Worker profiler->Print(std::cout); 723*89c4ff92SAndroid Build Coastguard Worker } 724*89c4ff92SAndroid Build Coastguard Worker #endif 725*89c4ff92SAndroid Build Coastguard Worker } 726*89c4ff92SAndroid Build Coastguard Worker GetInputBindingInfo(unsigned int inputIndex=0u) const727*89c4ff92SAndroid Build Coastguard Worker const armnn::BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const 728*89c4ff92SAndroid Build Coastguard Worker { 729*89c4ff92SAndroid Build Coastguard Worker CheckInputIndexIsValid(inputIndex); 730*89c4ff92SAndroid Build Coastguard Worker return m_InputBindings[inputIndex]; 731*89c4ff92SAndroid Build Coastguard Worker } 732*89c4ff92SAndroid Build Coastguard Worker GetInputBindingInfos() const733*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::BindingPointInfo>& GetInputBindingInfos() const 734*89c4ff92SAndroid Build Coastguard Worker { 735*89c4ff92SAndroid Build Coastguard Worker return m_InputBindings; 736*89c4ff92SAndroid Build Coastguard Worker } 737*89c4ff92SAndroid Build Coastguard Worker GetOutputBindingInfo(unsigned int outputIndex=0u) const738*89c4ff92SAndroid Build Coastguard Worker const armnn::BindingPointInfo& GetOutputBindingInfo(unsigned int outputIndex = 0u) const 739*89c4ff92SAndroid Build Coastguard Worker { 740*89c4ff92SAndroid Build Coastguard Worker CheckOutputIndexIsValid(outputIndex); 741*89c4ff92SAndroid Build Coastguard Worker return m_OutputBindings[outputIndex]; 742*89c4ff92SAndroid Build Coastguard Worker } 743*89c4ff92SAndroid Build Coastguard Worker GetOutputBindingInfos() const744*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::BindingPointInfo>& GetOutputBindingInfos() const 745*89c4ff92SAndroid Build Coastguard Worker { 746*89c4ff92SAndroid Build Coastguard Worker return m_OutputBindings; 747*89c4ff92SAndroid Build Coastguard Worker } 748*89c4ff92SAndroid Build Coastguard Worker GetQuantizationParams(unsigned int outputIndex=0u) const749*89c4ff92SAndroid Build Coastguard Worker QuantizationParams GetQuantizationParams(unsigned int outputIndex = 0u) const 750*89c4ff92SAndroid Build Coastguard Worker { 751*89c4ff92SAndroid Build Coastguard Worker CheckOutputIndexIsValid(outputIndex); 752*89c4ff92SAndroid Build Coastguard Worker return std::make_pair(m_OutputBindings[outputIndex].second.GetQuantizationScale(), 753*89c4ff92SAndroid Build Coastguard Worker m_OutputBindings[outputIndex].second.GetQuantizationOffset()); 754*89c4ff92SAndroid Build Coastguard Worker } 755*89c4ff92SAndroid Build Coastguard Worker GetInputQuantizationParams(unsigned int inputIndex=0u) const756*89c4ff92SAndroid Build Coastguard Worker QuantizationParams GetInputQuantizationParams(unsigned int inputIndex = 0u) const 757*89c4ff92SAndroid Build Coastguard Worker { 758*89c4ff92SAndroid Build Coastguard Worker CheckInputIndexIsValid(inputIndex); 759*89c4ff92SAndroid Build Coastguard Worker return std::make_pair(m_InputBindings[inputIndex].second.GetQuantizationScale(), 760*89c4ff92SAndroid Build Coastguard Worker m_InputBindings[inputIndex].second.GetQuantizationOffset()); 761*89c4ff92SAndroid Build Coastguard Worker } 762*89c4ff92SAndroid Build Coastguard Worker GetAllQuantizationParams() const763*89c4ff92SAndroid Build Coastguard Worker std::vector<QuantizationParams> GetAllQuantizationParams() const 764*89c4ff92SAndroid Build Coastguard Worker { 765*89c4ff92SAndroid Build Coastguard Worker std::vector<QuantizationParams> quantizationParams; 766*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < m_OutputBindings.size(); i++) 767*89c4ff92SAndroid Build Coastguard Worker { 768*89c4ff92SAndroid Build Coastguard Worker quantizationParams.push_back(GetQuantizationParams(i)); 769*89c4ff92SAndroid Build Coastguard Worker } 770*89c4ff92SAndroid Build Coastguard Worker return quantizationParams; 771*89c4ff92SAndroid Build Coastguard Worker } 772*89c4ff92SAndroid Build Coastguard Worker CreateWorkingMemHandle()773*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<armnn::experimental::IWorkingMemHandle> CreateWorkingMemHandle() 774*89c4ff92SAndroid Build Coastguard Worker { 775*89c4ff92SAndroid Build Coastguard Worker return m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier); 776*89c4ff92SAndroid Build Coastguard Worker } 777*89c4ff92SAndroid Build Coastguard Worker 778*89c4ff92SAndroid Build Coastguard Worker private: 779*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId m_NetworkIdentifier; 780*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<armnn::IRuntime> m_Runtime; 781*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS) 782*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<armnn::Threadpool> m_Threadpool; 783*89c4ff92SAndroid Build Coastguard Worker #endif 784*89c4ff92SAndroid Build Coastguard Worker 785*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> m_InputBindings; 786*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> m_OutputBindings; 787*89c4ff92SAndroid Build Coastguard Worker bool m_EnableProfiling; 788*89c4ff92SAndroid Build Coastguard Worker armnn::ProfilingDetailsMethod m_ProfilingDetailsMethod; 789*89c4ff92SAndroid Build Coastguard Worker std::string m_DynamicBackendsPath; 790*89c4ff92SAndroid Build Coastguard Worker bool m_ImportInputsIfAligned; 791*89c4ff92SAndroid Build Coastguard Worker 792*89c4ff92SAndroid Build Coastguard Worker template<typename TContainer> MakeInputTensors(const std::vector<TContainer> & inputDataContainers)793*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors MakeInputTensors(const std::vector<TContainer>& inputDataContainers) 794*89c4ff92SAndroid Build Coastguard Worker { 795*89c4ff92SAndroid Build Coastguard Worker return armnnUtils::MakeInputTensors(m_InputBindings, inputDataContainers); 796*89c4ff92SAndroid Build Coastguard Worker } 797*89c4ff92SAndroid Build Coastguard Worker 798*89c4ff92SAndroid Build Coastguard Worker template<typename TContainer> MakeOutputTensors(std::vector<TContainer> & outputDataContainers)799*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors MakeOutputTensors(std::vector<TContainer>& outputDataContainers) 800*89c4ff92SAndroid Build Coastguard Worker { 801*89c4ff92SAndroid Build Coastguard Worker return armnnUtils::MakeOutputTensors(m_OutputBindings, outputDataContainers); 802*89c4ff92SAndroid Build Coastguard Worker } 803*89c4ff92SAndroid Build Coastguard Worker }; 804