1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #pragma once 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include "IExecutor.hpp" 9*89c4ff92SAndroid Build Coastguard Worker #include "NetworkExecutionUtils/NetworkExecutionUtils.hpp" 10*89c4ff92SAndroid Build Coastguard Worker #include "ExecuteNetworkProgramOptions.hpp" 11*89c4ff92SAndroid Build Coastguard Worker #include "armnn/utility/NumericCast.hpp" 12*89c4ff92SAndroid Build Coastguard Worker #include "armnn/utility/Timer.hpp" 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp> 15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Threadpool.hpp> 16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp> 17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Timer.hpp> 18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp> 19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp> 20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp> 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp> 23*89c4ff92SAndroid Build Coastguard Worker #include <HeapProfiling.hpp> 24*89c4ff92SAndroid Build Coastguard Worker 25*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h> 26*89c4ff92SAndroid Build Coastguard Worker 27*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_SERIALIZER) 28*89c4ff92SAndroid Build Coastguard Worker #include "armnnDeserializer/IDeserializer.hpp" 29*89c4ff92SAndroid Build Coastguard Worker #endif 30*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER) 31*89c4ff92SAndroid Build Coastguard Worker #include <armnnTfLiteParser/ITfLiteParser.hpp> 32*89c4ff92SAndroid Build Coastguard Worker #endif 33*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER) 34*89c4ff92SAndroid Build Coastguard Worker #include <armnnOnnxParser/IOnnxParser.hpp> 35*89c4ff92SAndroid Build Coastguard Worker #endif 36*89c4ff92SAndroid Build Coastguard Worker 37*89c4ff92SAndroid Build Coastguard Worker class ArmNNExecutor : public IExecutor 38*89c4ff92SAndroid Build Coastguard Worker { 39*89c4ff92SAndroid Build Coastguard Worker public: 40*89c4ff92SAndroid Build Coastguard Worker ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions); 41*89c4ff92SAndroid Build Coastguard Worker 42*89c4ff92SAndroid Build Coastguard Worker std::vector<const void* > Execute() override; 43*89c4ff92SAndroid Build Coastguard Worker void PrintNetworkInfo() override; 44*89c4ff92SAndroid Build Coastguard Worker void CompareAndPrintResult(std::vector<const void*> otherOutput) override; 45*89c4ff92SAndroid Build Coastguard Worker 46*89c4ff92SAndroid Build Coastguard Worker private: 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker /** 49*89c4ff92SAndroid Build Coastguard Worker * Returns a pointer to the armnn::IRuntime* this will be shared by all ArmNNExecutors. 50*89c4ff92SAndroid Build Coastguard Worker */ GetRuntime(const armnn::IRuntime::CreationOptions & options)51*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options) 52*89c4ff92SAndroid Build Coastguard Worker { 53*89c4ff92SAndroid Build Coastguard Worker static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options); 54*89c4ff92SAndroid Build Coastguard Worker // Instantiated on first use. 55*89c4ff92SAndroid Build Coastguard Worker return instance.get(); 56*89c4ff92SAndroid Build Coastguard Worker } 57*89c4ff92SAndroid Build Coastguard Worker 58*89c4ff92SAndroid Build Coastguard Worker struct IParser; 59*89c4ff92SAndroid Build Coastguard Worker struct IOInfo; 60*89c4ff92SAndroid Build Coastguard Worker struct IOStorage; 61*89c4ff92SAndroid Build Coastguard Worker 62*89c4ff92SAndroid Build Coastguard Worker using BindingPointInfo = armnn::BindingPointInfo; 63*89c4ff92SAndroid Build Coastguard Worker 64*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IParser> CreateParser(); 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker void ExecuteAsync(); 67*89c4ff92SAndroid Build Coastguard Worker void ExecuteSync(); 68*89c4ff92SAndroid Build Coastguard Worker void SetupInputsAndOutputs(); 69*89c4ff92SAndroid Build Coastguard Worker 70*89c4ff92SAndroid Build Coastguard Worker IOInfo GetIOInfo(armnn::IOptimizedNetwork* optNet); 71*89c4ff92SAndroid Build Coastguard Worker 72*89c4ff92SAndroid Build Coastguard Worker void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration); 73*89c4ff92SAndroid Build Coastguard Worker 74*89c4ff92SAndroid Build Coastguard Worker armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network); 75*89c4ff92SAndroid Build Coastguard Worker 76*89c4ff92SAndroid Build Coastguard Worker struct IOStorage 77*89c4ff92SAndroid Build Coastguard Worker { IOStorageArmNNExecutor::IOStorage78*89c4ff92SAndroid Build Coastguard Worker IOStorage(size_t size) 79*89c4ff92SAndroid Build Coastguard Worker { 80*89c4ff92SAndroid Build Coastguard Worker m_Mem = operator new(size); 81*89c4ff92SAndroid Build Coastguard Worker } ~IOStorageArmNNExecutor::IOStorage82*89c4ff92SAndroid Build Coastguard Worker ~IOStorage() 83*89c4ff92SAndroid Build Coastguard Worker { 84*89c4ff92SAndroid Build Coastguard Worker operator delete(m_Mem); 85*89c4ff92SAndroid Build Coastguard Worker } IOStorageArmNNExecutor::IOStorage86*89c4ff92SAndroid Build Coastguard Worker IOStorage(IOStorage&& rhs) 87*89c4ff92SAndroid Build Coastguard Worker { 88*89c4ff92SAndroid Build Coastguard Worker this->m_Mem = rhs.m_Mem; 89*89c4ff92SAndroid Build Coastguard Worker rhs.m_Mem = nullptr; 90*89c4ff92SAndroid Build Coastguard Worker } 91*89c4ff92SAndroid Build Coastguard Worker 92*89c4ff92SAndroid Build Coastguard Worker IOStorage(const IOStorage& rhs) = delete; 93*89c4ff92SAndroid Build Coastguard Worker IOStorage& operator=(IOStorage& rhs) = delete; 94*89c4ff92SAndroid Build Coastguard Worker IOStorage& operator=(IOStorage&& rhs) = delete; 95*89c4ff92SAndroid Build Coastguard Worker 96*89c4ff92SAndroid Build Coastguard Worker void* m_Mem; 97*89c4ff92SAndroid Build Coastguard Worker }; 98*89c4ff92SAndroid Build Coastguard Worker 99*89c4ff92SAndroid Build Coastguard Worker struct IOInfo 100*89c4ff92SAndroid Build Coastguard Worker { 101*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> m_InputNames; 102*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> m_OutputNames; 103*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, armnn::BindingPointInfo> m_InputInfoMap; 104*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, armnn::BindingPointInfo> m_OutputInfoMap; 105*89c4ff92SAndroid Build Coastguard Worker }; 106*89c4ff92SAndroid Build Coastguard Worker 107*89c4ff92SAndroid Build Coastguard Worker IOInfo m_IOInfo; 108*89c4ff92SAndroid Build Coastguard Worker std::vector<IOStorage> m_InputStorage; 109*89c4ff92SAndroid Build Coastguard Worker std::vector<IOStorage> m_OutputStorage; 110*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::InputTensors> m_InputTensorsVec; 111*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::OutputTensors> m_OutputTensorsVec; 112*89c4ff92SAndroid Build Coastguard Worker std::vector<std::vector<unsigned int>> m_ImportedInputIds; 113*89c4ff92SAndroid Build Coastguard Worker std::vector<std::vector<unsigned int>> m_ImportedOutputIds; 114*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* m_Runtime; 115*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId m_NetworkId; 116*89c4ff92SAndroid Build Coastguard Worker ExecuteNetworkParams m_Params; 117*89c4ff92SAndroid Build Coastguard Worker 118*89c4ff92SAndroid Build Coastguard Worker struct IParser 119*89c4ff92SAndroid Build Coastguard Worker { 120*89c4ff92SAndroid Build Coastguard Worker virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0; 121*89c4ff92SAndroid Build Coastguard Worker virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string& inputName) = 0; 122*89c4ff92SAndroid Build Coastguard Worker virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string& outputName) = 0; 123*89c4ff92SAndroid Build Coastguard Worker ~IParserArmNNExecutor::IParser124*89c4ff92SAndroid Build Coastguard Worker virtual ~IParser(){}; 125*89c4ff92SAndroid Build Coastguard Worker }; 126*89c4ff92SAndroid Build Coastguard Worker 127*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_SERIALIZER) 128*89c4ff92SAndroid Build Coastguard Worker class ArmNNDeserializer : public IParser 129*89c4ff92SAndroid Build Coastguard Worker { 130*89c4ff92SAndroid Build Coastguard Worker public: 131*89c4ff92SAndroid Build Coastguard Worker ArmNNDeserializer(); 132*89c4ff92SAndroid Build Coastguard Worker 133*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override; 134*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string& inputName) override; 135*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string& outputName) override; 136*89c4ff92SAndroid Build Coastguard Worker 137*89c4ff92SAndroid Build Coastguard Worker private: 138*89c4ff92SAndroid Build Coastguard Worker armnnDeserializer::IDeserializerPtr m_Parser; 139*89c4ff92SAndroid Build Coastguard Worker }; 140*89c4ff92SAndroid Build Coastguard Worker #endif 141*89c4ff92SAndroid Build Coastguard Worker 142*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER) 143*89c4ff92SAndroid Build Coastguard Worker class TfliteParser : public IParser 144*89c4ff92SAndroid Build Coastguard Worker { 145*89c4ff92SAndroid Build Coastguard Worker public: 146*89c4ff92SAndroid Build Coastguard Worker TfliteParser(const ExecuteNetworkParams& params); 147*89c4ff92SAndroid Build Coastguard Worker 148*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override; 149*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override; 150*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override; 151*89c4ff92SAndroid Build Coastguard Worker 152*89c4ff92SAndroid Build Coastguard Worker private: __anondf7b68c10102()153*89c4ff92SAndroid Build Coastguard Worker armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}}; 154*89c4ff92SAndroid Build Coastguard Worker }; 155*89c4ff92SAndroid Build Coastguard Worker #endif 156*89c4ff92SAndroid Build Coastguard Worker 157*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER) 158*89c4ff92SAndroid Build Coastguard Worker class OnnxParser : public IParser 159*89c4ff92SAndroid Build Coastguard Worker { 160*89c4ff92SAndroid Build Coastguard Worker public: 161*89c4ff92SAndroid Build Coastguard Worker OnnxParser(); 162*89c4ff92SAndroid Build Coastguard Worker 163*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override; 164*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override; 165*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override; 166*89c4ff92SAndroid Build Coastguard Worker 167*89c4ff92SAndroid Build Coastguard Worker private: 168*89c4ff92SAndroid Build Coastguard Worker armnnOnnxParser::IOnnxParserPtr m_Parser; 169*89c4ff92SAndroid Build Coastguard Worker }; 170*89c4ff92SAndroid Build Coastguard Worker #endif 171*89c4ff92SAndroid Build Coastguard Worker };