1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "armnn/Types.hpp" 8 #include "armnn/NetworkFwd.hpp" 9 #include "armnn/Tensor.hpp" 10 #include "armnn/INetwork.hpp" 11 #include "armnn/Optional.hpp" 12 13 #include <memory> 14 #include <map> 15 #include <vector> 16 17 namespace armnnTfLiteParser 18 { 19 20 using BindingPointInfo = armnn::BindingPointInfo; 21 22 class TfLiteParserImpl; 23 class ITfLiteParser; 24 using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>; 25 26 class ITfLiteParser 27 { 28 public: 29 struct TfLiteParserOptions 30 { TfLiteParserOptionsarmnnTfLiteParser::ITfLiteParser::TfLiteParserOptions31 TfLiteParserOptions() 32 : m_AllowExpandedDims(false), 33 m_StandInLayerForUnsupported(false), 34 m_InferAndValidate(false) {} 35 36 bool m_AllowExpandedDims; 37 bool m_StandInLayerForUnsupported; 38 bool m_InferAndValidate; 39 }; 40 41 static ITfLiteParser* CreateRaw(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional()); 42 static ITfLiteParserPtr Create(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional()); 43 static void Destroy(ITfLiteParser* parser); 44 45 /// Create the network from a flatbuffers binary file on disk 46 armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile); 47 48 /// Create the network from a flatbuffers binary 49 armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent); 50 51 /// Retrieve binding info (layer id and tensor info) for the network input identified by 52 /// the given layer name and subgraph id 53 BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId, 54 const std::string& name) const; 55 56 /// Retrieve binding info (layer id and tensor info) for the network output identified by 57 /// the given layer name and subgraph id 58 BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId, 59 const std::string& name) const; 60 61 /// Return the number of subgraphs in the parsed model 62 size_t GetSubgraphCount() const; 63 64 /// Return the input tensor names for a given subgraph 65 std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const; 66 67 /// Return the output tensor names for a given subgraph 68 std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const; 69 70 private: 71 ITfLiteParser(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional()); 72 ~ITfLiteParser(); 73 74 std::unique_ptr<TfLiteParserImpl> pTfLiteParserImpl; 75 }; 76 77 } 78