xref: /aosp_15_r20/external/armnn/delegate/classic/include/armnn_delegate.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <DelegateOptions.hpp>
9 
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14 #include <tensorflow/lite/version.h>
15 
// TensorFlow Lite version feature gates, driven by the TF_MAJOR_VERSION /
// TF_MINOR_VERSION values from <tensorflow/lite/version.h>.
// ARMNN_POST_TFLITE_2_N is defined when the TfLite version is strictly newer
// than 2.N, i.e. 2.(N+1) or later within major version 2, or any major version above 2.
#if TF_MAJOR_VERSION > 2
    #define ARMNN_POST_TFLITE_2_3
    #define ARMNN_POST_TFLITE_2_4
    #define ARMNN_POST_TFLITE_2_5
#elif TF_MAJOR_VERSION == 2
    #if TF_MINOR_VERSION > 3
        #define ARMNN_POST_TFLITE_2_3
    #endif
    #if TF_MINOR_VERSION > 4
        #define ARMNN_POST_TFLITE_2_4
    #endif
    #if TF_MINOR_VERSION > 5
        #define ARMNN_POST_TFLITE_2_5
    #endif
#endif
27 
28 namespace armnnDelegate
29 {
30 
31 struct DelegateData
32 {
DelegateDataarmnnDelegate::DelegateData33     DelegateData(const std::vector<armnn::BackendId>& backends)
34             : m_Backends(backends)
35             , m_Network(nullptr, nullptr)
36     {}
37 
38     const std::vector<armnn::BackendId>       m_Backends;
39     armnn::INetworkPtr                        m_Network;
40     std::vector<armnn::IOutputSlot*>          m_OutputSlotForNode;
41 };
42 
43 // Forward decleration for functions initializing the ArmNN Delegate
44 DelegateOptions TfLiteArmnnDelegateOptionsDefault();
45 
46 TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options);
47 
48 void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate);
49 
50 TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate);
51 
52 /// ArmNN Delegate
53 class Delegate
54 {
55     friend class ArmnnSubgraph;
56 public:
57     explicit Delegate(armnnDelegate::DelegateOptions options);
58 
59     TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context);
60 
61     TfLiteDelegate* GetDelegate();
62 
63     /// Retrieve version in X.Y.Z form
64     static const std::string GetVersion();
65 
66 private:
67     /**
68      * Returns a pointer to the armnn::IRuntime* this will be shared by all armnn_delegates.
69      */
GetRuntime(const armnn::IRuntime::CreationOptions & options)70     armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
71     {
72         static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
73         // Instantiated on first use.
74         return instance.get();
75     }
76 
77     TfLiteDelegate m_Delegate = {
78             reinterpret_cast<void*>(this),  // .data_
79             DoPrepare,                      // .Prepare
80             nullptr,                        // .CopyFromBufferHandle
81             nullptr,                        // .CopyToBufferHandle
82             nullptr,                        // .FreeBufferHandle
83             kTfLiteDelegateFlagsNone,       // .flags
84             nullptr,                        // .opaque_delegate_builder
85     };
86 
87     /// ArmNN Runtime pointer
88     armnn::IRuntime* m_Runtime;
89     /// ArmNN Delegate Options
90     armnnDelegate::DelegateOptions m_Options;
91 };
92 
93 /// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
94 class ArmnnSubgraph
95 {
96 public:
97     static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
98                                  const TfLiteDelegateParams* parameters,
99                                  const Delegate* delegate);
100 
101     TfLiteStatus Prepare(TfLiteContext* tfLiteContext);
102 
103     TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode);
104 
105     static TfLiteStatus VisitNode(DelegateData& delegateData,
106                                   TfLiteContext* tfLiteContext,
107                                   TfLiteRegistration* tfLiteRegistration,
108                                   TfLiteNode* tfLiteNode,
109                                   int nodeIndex);
110 
111 private:
ArmnnSubgraph(armnn::NetworkId networkId,armnn::IRuntime * runtime,std::vector<armnn::BindingPointInfo> & inputBindings,std::vector<armnn::BindingPointInfo> & outputBindings)112     ArmnnSubgraph(armnn::NetworkId networkId,
113                   armnn::IRuntime* runtime,
114                   std::vector<armnn::BindingPointInfo>& inputBindings,
115                   std::vector<armnn::BindingPointInfo>& outputBindings)
116         : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings)
117     {}
118 
119     static TfLiteStatus AddInputLayer(DelegateData& delegateData,
120                                       TfLiteContext* tfLiteContext,
121                                       const TfLiteIntArray* inputs,
122                                       std::vector<armnn::BindingPointInfo>& inputBindings);
123 
124     static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
125                                        TfLiteContext* tfLiteContext,
126                                        const TfLiteIntArray* outputs,
127                                        std::vector<armnn::BindingPointInfo>& outputBindings);
128 
129 
130     /// The Network Id
131     armnn::NetworkId m_NetworkId;
132     /// ArmNN Runtime
133     armnn::IRuntime* m_Runtime;
134 
135     // Binding information for inputs and outputs
136     std::vector<armnn::BindingPointInfo> m_InputBindings;
137     std::vector<armnn::BindingPointInfo> m_OutputBindings;
138 
139 };
140 
141 } // armnnDelegate namespace