1 //
2 // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <DelegateOptions.hpp>
9
10 #include <tensorflow/lite/c/c_api_opaque.h>
11 #include <tensorflow/lite/core/experimental/acceleration/configuration/c/stable_delegate.h>
12
13 namespace armnnOpaqueDelegate
14 {
15
16 struct DelegateData
17 {
DelegateDataarmnnOpaqueDelegate::DelegateData18 DelegateData(const std::vector<armnn::BackendId>& backends)
19 : m_Backends(backends)
20 , m_Network(nullptr, nullptr)
21 {}
22
23 const std::vector<armnn::BackendId> m_Backends;
24 armnn::INetworkPtr m_Network;
25 std::vector<armnn::IOutputSlot*> m_OutputSlotForNode;
26 };
27
28 /// Forward declaration for functions initializing the ArmNN Delegate
29 ::armnnDelegate::DelegateOptions TfLiteArmnnDelegateOptionsDefault();
30
31 TfLiteOpaqueDelegate* TfLiteArmnnOpaqueDelegateCreate(const void* settings);
32
33 void TfLiteArmnnOpaqueDelegateDelete(TfLiteOpaqueDelegate* tfLiteDelegate);
34
35 TfLiteStatus DoPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueDelegate* delegate, void* data);
36
37 /// ArmNN Opaque Delegate
38 class ArmnnOpaqueDelegate
39 {
40 friend class ArmnnSubgraph;
41 public:
42 explicit ArmnnOpaqueDelegate(armnnDelegate::DelegateOptions options);
43
44 TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteOpaqueContext* context);
45
GetDelegateBuilder()46 TfLiteOpaqueDelegateBuilder* GetDelegateBuilder() { return &m_Builder; }
47
48 /// Retrieve version in X.Y.Z form
49 static const std::string GetVersion();
50
51 private:
52 /**
53 * Returns a pointer to the armnn::IRuntime* this will be shared by all armnn_delegates.
54 */
GetRuntime(const armnn::IRuntime::CreationOptions & options)55 armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
56 {
57 static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
58 /// Instantiated on first use.
59 return instance.get();
60 }
61
62 TfLiteOpaqueDelegateBuilder m_Builder =
63 {
64 reinterpret_cast<void*>(this), // .data_
65 DoPrepare, // .Prepare
66 nullptr, // .CopyFromBufferHandle
67 nullptr, // .CopyToBufferHandle
68 nullptr, // .FreeBufferHandle
69 kTfLiteDelegateFlagsNone, // .flags
70 };
71
72 /// ArmNN Runtime pointer
73 armnn::IRuntime* m_Runtime;
74 /// ArmNN Delegate Options
75 armnnDelegate::DelegateOptions m_Options;
76 };
77
TfLiteArmnnOpaqueDelegateErrno(TfLiteOpaqueDelegate * delegate)78 static int TfLiteArmnnOpaqueDelegateErrno(TfLiteOpaqueDelegate* delegate) { return 0; }
79
80 /// In order for the delegate to be loaded by TfLite
81 const TfLiteOpaqueDelegatePlugin* GetArmnnDelegatePluginApi();
82
83 extern const TfLiteStableDelegate TFL_TheStableDelegate;
84
85 /// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
86 class ArmnnSubgraph
87 {
88 public:
89 static ArmnnSubgraph* Create(TfLiteOpaqueContext* tfLiteContext,
90 const TfLiteOpaqueDelegateParams* parameters,
91 const ArmnnOpaqueDelegate* delegate);
92
93 TfLiteStatus Prepare(TfLiteOpaqueContext* tfLiteContext);
94
95 TfLiteStatus Invoke(TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode);
96
97 static TfLiteStatus VisitNode(DelegateData& delegateData,
98 TfLiteOpaqueContext* tfLiteContext,
99 TfLiteRegistrationExternal* tfLiteRegistration,
100 TfLiteOpaqueNode* tfLiteNode,
101 int nodeIndex);
102 private:
ArmnnSubgraph(armnn::NetworkId networkId,armnn::IRuntime * runtime,std::vector<armnn::BindingPointInfo> & inputBindings,std::vector<armnn::BindingPointInfo> & outputBindings)103 ArmnnSubgraph(armnn::NetworkId networkId,
104 armnn::IRuntime* runtime,
105 std::vector<armnn::BindingPointInfo>& inputBindings,
106 std::vector<armnn::BindingPointInfo>& outputBindings)
107 : m_NetworkId(networkId)
108 , m_Runtime(runtime)
109 , m_InputBindings(inputBindings)
110 , m_OutputBindings(outputBindings)
111 {}
112 static TfLiteStatus AddInputLayer(DelegateData& delegateData,
113 TfLiteOpaqueContext* tfLiteContext,
114 const TfLiteIntArray* inputs,
115 std::vector<armnn::BindingPointInfo>& inputBindings);
116 static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
117 TfLiteOpaqueContext* tfLiteContext,
118 const TfLiteIntArray* outputs,
119 std::vector<armnn::BindingPointInfo>& outputBindings);
120 /// The Network Id
121 armnn::NetworkId m_NetworkId;
122 /// ArmNN Runtime
123 armnn::IRuntime* m_Runtime;
124 /// Binding information for inputs and outputs
125 std::vector<armnn::BindingPointInfo> m_InputBindings;
126 std::vector<armnn::BindingPointInfo> m_OutputBindings;
127 };
128
129 } // armnnOpaqueDelegate namespace