xref: /aosp_15_r20/external/armnn/include/armnnTfLiteParser/ITfLiteParser.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "armnn/Types.hpp"
8 #include "armnn/NetworkFwd.hpp"
9 #include "armnn/Tensor.hpp"
10 #include "armnn/INetwork.hpp"
11 #include "armnn/Optional.hpp"
12 
13 #include <memory>
14 #include <map>
15 #include <vector>
16 
17 namespace armnnTfLiteParser
18 {
19 
20 using BindingPointInfo = armnn::BindingPointInfo;
21 
22 class TfLiteParserImpl;
23 class ITfLiteParser;
24 using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>;
25 
26 class ITfLiteParser
27 {
28 public:
29     struct TfLiteParserOptions
30     {
TfLiteParserOptionsarmnnTfLiteParser::ITfLiteParser::TfLiteParserOptions31         TfLiteParserOptions()
32             : m_AllowExpandedDims(false),
33               m_StandInLayerForUnsupported(false),
34               m_InferAndValidate(false) {}
35 
36         bool m_AllowExpandedDims;
37         bool m_StandInLayerForUnsupported;
38         bool m_InferAndValidate;
39     };
40 
41     static ITfLiteParser* CreateRaw(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
42     static ITfLiteParserPtr Create(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
43     static void Destroy(ITfLiteParser* parser);
44 
45     /// Create the network from a flatbuffers binary file on disk
46     armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
47 
48     /// Create the network from a flatbuffers binary
49     armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent);
50 
51     /// Retrieve binding info (layer id and tensor info) for the network input identified by
52     /// the given layer name and subgraph id
53     BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
54                                                 const std::string& name) const;
55 
56     /// Retrieve binding info (layer id and tensor info) for the network output identified by
57     /// the given layer name and subgraph id
58     BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
59                                                          const std::string& name) const;
60 
61     /// Return the number of subgraphs in the parsed model
62     size_t GetSubgraphCount() const;
63 
64     /// Return the input tensor names for a given subgraph
65     std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const;
66 
67     /// Return the output tensor names for a given subgraph
68     std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const;
69 
70 private:
71     ITfLiteParser(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
72     ~ITfLiteParser();
73 
74     std::unique_ptr<TfLiteParserImpl> pTfLiteParserImpl;
75 };
76 
77 }
78