1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include "armnn/Types.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include "armnn/NetworkFwd.hpp" 9*89c4ff92SAndroid Build Coastguard Worker #include "armnn/Tensor.hpp" 10*89c4ff92SAndroid Build Coastguard Worker #include "armnn/INetwork.hpp" 11*89c4ff92SAndroid Build Coastguard Worker #include "armnn/Optional.hpp" 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker #include <memory> 14*89c4ff92SAndroid Build Coastguard Worker #include <map> 15*89c4ff92SAndroid Build Coastguard Worker #include <vector> 16*89c4ff92SAndroid Build Coastguard Worker 17*89c4ff92SAndroid Build Coastguard Worker namespace armnnTfLiteParser 18*89c4ff92SAndroid Build Coastguard Worker { 19*89c4ff92SAndroid Build Coastguard Worker 20*89c4ff92SAndroid Build Coastguard Worker using BindingPointInfo = armnn::BindingPointInfo; 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker class TfLiteParserImpl; 23*89c4ff92SAndroid Build Coastguard Worker class ITfLiteParser; 24*89c4ff92SAndroid Build Coastguard Worker using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>; 25*89c4ff92SAndroid Build Coastguard Worker 26*89c4ff92SAndroid Build Coastguard Worker class ITfLiteParser 27*89c4ff92SAndroid Build Coastguard Worker { 28*89c4ff92SAndroid Build Coastguard Worker public: 29*89c4ff92SAndroid Build Coastguard Worker struct TfLiteParserOptions 30*89c4ff92SAndroid Build Coastguard Worker { TfLiteParserOptionsarmnnTfLiteParser::ITfLiteParser::TfLiteParserOptions31*89c4ff92SAndroid Build Coastguard Worker TfLiteParserOptions() 32*89c4ff92SAndroid Build Coastguard Worker : m_AllowExpandedDims(false), 33*89c4ff92SAndroid Build Coastguard Worker m_StandInLayerForUnsupported(false), 34*89c4ff92SAndroid Build Coastguard Worker m_InferAndValidate(false) {} 35*89c4ff92SAndroid Build Coastguard Worker 36*89c4ff92SAndroid Build Coastguard Worker bool m_AllowExpandedDims; 37*89c4ff92SAndroid Build Coastguard Worker bool m_StandInLayerForUnsupported; 38*89c4ff92SAndroid Build Coastguard Worker bool m_InferAndValidate; 39*89c4ff92SAndroid Build Coastguard Worker }; 40*89c4ff92SAndroid Build Coastguard Worker 41*89c4ff92SAndroid Build Coastguard Worker static ITfLiteParser* CreateRaw(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional()); 42*89c4ff92SAndroid Build Coastguard Worker static ITfLiteParserPtr Create(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional()); 43*89c4ff92SAndroid Build Coastguard Worker static void Destroy(ITfLiteParser* parser); 44*89c4ff92SAndroid Build Coastguard Worker 45*89c4ff92SAndroid Build Coastguard Worker /// Create the network from a flatbuffers binary file on disk 46*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile); 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker /// Create the network from a flatbuffers binary 49*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent); 50*89c4ff92SAndroid Build Coastguard Worker 51*89c4ff92SAndroid Build Coastguard Worker /// Retrieve binding info (layer id and tensor info) for the network input identified by 52*89c4ff92SAndroid Build Coastguard Worker /// the given layer name and subgraph id 53*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId, 54*89c4ff92SAndroid Build Coastguard Worker const std::string& name) const; 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker /// Retrieve binding info (layer id and tensor info) for the network output identified by 57*89c4ff92SAndroid Build Coastguard Worker /// the given layer name and subgraph id 58*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId, 59*89c4ff92SAndroid Build Coastguard Worker const std::string& name) const; 60*89c4ff92SAndroid Build Coastguard Worker 61*89c4ff92SAndroid Build Coastguard Worker /// Return the number of subgraphs in the parsed model 62*89c4ff92SAndroid Build Coastguard Worker size_t GetSubgraphCount() const; 63*89c4ff92SAndroid Build Coastguard Worker 64*89c4ff92SAndroid Build Coastguard Worker /// Return the input tensor names for a given subgraph 65*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const; 66*89c4ff92SAndroid Build Coastguard Worker 67*89c4ff92SAndroid Build Coastguard Worker /// Return the output tensor names for a given subgraph 68*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const; 69*89c4ff92SAndroid Build Coastguard Worker 70*89c4ff92SAndroid Build Coastguard Worker private: 71*89c4ff92SAndroid Build Coastguard Worker ITfLiteParser(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional()); 72*89c4ff92SAndroid Build Coastguard Worker ~ITfLiteParser(); 73*89c4ff92SAndroid Build Coastguard Worker 74*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<TfLiteParserImpl> pTfLiteParserImpl; 75*89c4ff92SAndroid Build Coastguard Worker }; 76*89c4ff92SAndroid Build Coastguard Worker 77*89c4ff92SAndroid Build Coastguard Worker } 78