xref: /aosp_15_r20/external/armnn/include/armnnTfLiteParser/ITfLiteParser.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "armnn/Types.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "armnn/NetworkFwd.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "armnn/Tensor.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "armnn/INetwork.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include "armnn/Optional.hpp"
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <memory>
14*89c4ff92SAndroid Build Coastguard Worker #include <map>
15*89c4ff92SAndroid Build Coastguard Worker #include <vector>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker namespace armnnTfLiteParser
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker using BindingPointInfo = armnn::BindingPointInfo;
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker class TfLiteParserImpl;
23*89c4ff92SAndroid Build Coastguard Worker class ITfLiteParser;
24*89c4ff92SAndroid Build Coastguard Worker using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>;
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker class ITfLiteParser
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker public:
29*89c4ff92SAndroid Build Coastguard Worker     struct TfLiteParserOptions
30*89c4ff92SAndroid Build Coastguard Worker     {
TfLiteParserOptionsarmnnTfLiteParser::ITfLiteParser::TfLiteParserOptions31*89c4ff92SAndroid Build Coastguard Worker         TfLiteParserOptions()
32*89c4ff92SAndroid Build Coastguard Worker             : m_AllowExpandedDims(false),
33*89c4ff92SAndroid Build Coastguard Worker               m_StandInLayerForUnsupported(false),
34*89c4ff92SAndroid Build Coastguard Worker               m_InferAndValidate(false) {}
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker         bool m_AllowExpandedDims;
37*89c4ff92SAndroid Build Coastguard Worker         bool m_StandInLayerForUnsupported;
38*89c4ff92SAndroid Build Coastguard Worker         bool m_InferAndValidate;
39*89c4ff92SAndroid Build Coastguard Worker     };
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker     static ITfLiteParser* CreateRaw(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
42*89c4ff92SAndroid Build Coastguard Worker     static ITfLiteParserPtr Create(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
43*89c4ff92SAndroid Build Coastguard Worker     static void Destroy(ITfLiteParser* parser);
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker     /// Create the network from a flatbuffers binary file on disk
46*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     /// Create the network from a flatbuffers binary
49*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent);
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     /// Retrieve binding info (layer id and tensor info) for the network input identified by
52*89c4ff92SAndroid Build Coastguard Worker     /// the given layer name and subgraph id
53*89c4ff92SAndroid Build Coastguard Worker     BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
54*89c4ff92SAndroid Build Coastguard Worker                                                 const std::string& name) const;
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     /// Retrieve binding info (layer id and tensor info) for the network output identified by
57*89c4ff92SAndroid Build Coastguard Worker     /// the given layer name and subgraph id
58*89c4ff92SAndroid Build Coastguard Worker     BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
59*89c4ff92SAndroid Build Coastguard Worker                                                          const std::string& name) const;
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     /// Return the number of subgraphs in the parsed model
62*89c4ff92SAndroid Build Coastguard Worker     size_t GetSubgraphCount() const;
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     /// Return the input tensor names for a given subgraph
65*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const;
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     /// Return the output tensor names for a given subgraph
68*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const;
69*89c4ff92SAndroid Build Coastguard Worker 
70*89c4ff92SAndroid Build Coastguard Worker private:
71*89c4ff92SAndroid Build Coastguard Worker     ITfLiteParser(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
72*89c4ff92SAndroid Build Coastguard Worker     ~ITfLiteParser();
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<TfLiteParserImpl> pTfLiteParserImpl;
75*89c4ff92SAndroid Build Coastguard Worker };
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker }
78