1 // 2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "IExecutor.hpp" 9 #include "NetworkExecutionUtils/NetworkExecutionUtils.hpp" 10 #include "ExecuteNetworkProgramOptions.hpp" 11 #include "armnn/utility/NumericCast.hpp" 12 #include "armnn/utility/Timer.hpp" 13 14 #include <armnn/ArmNN.hpp> 15 #include <armnn/Threadpool.hpp> 16 #include <armnn/Logging.hpp> 17 #include <armnn/utility/Timer.hpp> 18 #include <armnn/BackendRegistry.hpp> 19 #include <armnn/utility/Assert.hpp> 20 #include <armnn/utility/NumericCast.hpp> 21 22 #include <armnnUtils/Filesystem.hpp> 23 #include <HeapProfiling.hpp> 24 25 #include <fmt/format.h> 26 27 #if defined(ARMNN_SERIALIZER) 28 #include "armnnDeserializer/IDeserializer.hpp" 29 #endif 30 #if defined(ARMNN_TF_LITE_PARSER) 31 #include <armnnTfLiteParser/ITfLiteParser.hpp> 32 #endif 33 #if defined(ARMNN_ONNX_PARSER) 34 #include <armnnOnnxParser/IOnnxParser.hpp> 35 #endif 36 37 class ArmNNExecutor : public IExecutor 38 { 39 public: 40 ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions); 41 42 std::vector<const void* > Execute() override; 43 void PrintNetworkInfo() override; 44 void CompareAndPrintResult(std::vector<const void*> otherOutput) override; 45 46 private: 47 48 /** 49 * Returns a pointer to the armnn::IRuntime* this will be shared by all ArmNNExecutors. 50 */ GetRuntime(const armnn::IRuntime::CreationOptions & options)51 armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options) 52 { 53 static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options); 54 // Instantiated on first use. 55 return instance.get(); 56 } 57 58 struct IParser; 59 struct IOInfo; 60 struct IOStorage; 61 62 using BindingPointInfo = armnn::BindingPointInfo; 63 64 std::unique_ptr<IParser> CreateParser(); 65 66 void ExecuteAsync(); 67 void ExecuteSync(); 68 void SetupInputsAndOutputs(); 69 70 IOInfo GetIOInfo(armnn::IOptimizedNetwork* optNet); 71 72 void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration); 73 74 armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network); 75 76 struct IOStorage 77 { IOStorageArmNNExecutor::IOStorage78 IOStorage(size_t size) 79 { 80 m_Mem = operator new(size); 81 } ~IOStorageArmNNExecutor::IOStorage82 ~IOStorage() 83 { 84 operator delete(m_Mem); 85 } IOStorageArmNNExecutor::IOStorage86 IOStorage(IOStorage&& rhs) 87 { 88 this->m_Mem = rhs.m_Mem; 89 rhs.m_Mem = nullptr; 90 } 91 92 IOStorage(const IOStorage& rhs) = delete; 93 IOStorage& operator=(IOStorage& rhs) = delete; 94 IOStorage& operator=(IOStorage&& rhs) = delete; 95 96 void* m_Mem; 97 }; 98 99 struct IOInfo 100 { 101 std::vector<std::string> m_InputNames; 102 std::vector<std::string> m_OutputNames; 103 std::map<std::string, armnn::BindingPointInfo> m_InputInfoMap; 104 std::map<std::string, armnn::BindingPointInfo> m_OutputInfoMap; 105 }; 106 107 IOInfo m_IOInfo; 108 std::vector<IOStorage> m_InputStorage; 109 std::vector<IOStorage> m_OutputStorage; 110 std::vector<armnn::InputTensors> m_InputTensorsVec; 111 std::vector<armnn::OutputTensors> m_OutputTensorsVec; 112 std::vector<std::vector<unsigned int>> m_ImportedInputIds; 113 std::vector<std::vector<unsigned int>> m_ImportedOutputIds; 114 armnn::IRuntime* m_Runtime; 115 armnn::NetworkId m_NetworkId; 116 ExecuteNetworkParams m_Params; 117 118 struct IParser 119 { 120 virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0; 121 virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string& inputName) = 0; 122 virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string& outputName) = 0; 123 ~IParserArmNNExecutor::IParser124 virtual ~IParser(){}; 125 }; 126 127 #if defined(ARMNN_SERIALIZER) 128 class ArmNNDeserializer : public IParser 129 { 130 public: 131 ArmNNDeserializer(); 132 133 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override; 134 armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string& inputName) override; 135 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string& outputName) override; 136 137 private: 138 armnnDeserializer::IDeserializerPtr m_Parser; 139 }; 140 #endif 141 142 #if defined(ARMNN_TF_LITE_PARSER) 143 class TfliteParser : public IParser 144 { 145 public: 146 TfliteParser(const ExecuteNetworkParams& params); 147 148 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override; 149 armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override; 150 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override; 151 152 private: __anondf7b68c10102()153 armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}}; 154 }; 155 #endif 156 157 #if defined(ARMNN_ONNX_PARSER) 158 class OnnxParser : public IParser 159 { 160 public: 161 OnnxParser(); 162 163 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override; 164 armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override; 165 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override; 166 167 private: 168 armnnOnnxParser::IOnnxParserPtr m_Parser; 169 }; 170 #endif 171 };