1 // 2 // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <DelegateOptions.hpp> 9 10 #include <tensorflow/lite/builtin_ops.h> 11 #include <tensorflow/lite/c/builtin_op_data.h> 12 #include <tensorflow/lite/c/common.h> 13 #include <tensorflow/lite/minimal_logging.h> 14 #include <tensorflow/lite/version.h> 15 16 #if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3) 17 #define ARMNN_POST_TFLITE_2_3 18 #endif 19 20 #if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 4) 21 #define ARMNN_POST_TFLITE_2_4 22 #endif 23 24 #if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 5) 25 #define ARMNN_POST_TFLITE_2_5 26 #endif 27 28 namespace armnnDelegate 29 { 30 31 struct DelegateData 32 { DelegateDataarmnnDelegate::DelegateData33 DelegateData(const std::vector<armnn::BackendId>& backends) 34 : m_Backends(backends) 35 , m_Network(nullptr, nullptr) 36 {} 37 38 const std::vector<armnn::BackendId> m_Backends; 39 armnn::INetworkPtr m_Network; 40 std::vector<armnn::IOutputSlot*> m_OutputSlotForNode; 41 }; 42 43 // Forward decleration for functions initializing the ArmNN Delegate 44 DelegateOptions TfLiteArmnnDelegateOptionsDefault(); 45 46 TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options); 47 48 void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate); 49 50 TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate); 51 52 /// ArmNN Delegate 53 class Delegate 54 { 55 friend class ArmnnSubgraph; 56 public: 57 explicit Delegate(armnnDelegate::DelegateOptions options); 58 59 TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context); 60 61 TfLiteDelegate* GetDelegate(); 62 63 /// Retrieve version in X.Y.Z form 64 static const std::string GetVersion(); 65 66 private: 67 /** 68 * Returns a pointer to the armnn::IRuntime* this will be shared by all armnn_delegates. 69 */ GetRuntime(const armnn::IRuntime::CreationOptions & options)70 armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options) 71 { 72 static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options); 73 // Instantiated on first use. 74 return instance.get(); 75 } 76 77 TfLiteDelegate m_Delegate = { 78 reinterpret_cast<void*>(this), // .data_ 79 DoPrepare, // .Prepare 80 nullptr, // .CopyFromBufferHandle 81 nullptr, // .CopyToBufferHandle 82 nullptr, // .FreeBufferHandle 83 kTfLiteDelegateFlagsNone, // .flags 84 nullptr, // .opaque_delegate_builder 85 }; 86 87 /// ArmNN Runtime pointer 88 armnn::IRuntime* m_Runtime; 89 /// ArmNN Delegate Options 90 armnnDelegate::DelegateOptions m_Options; 91 }; 92 93 /// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph 94 class ArmnnSubgraph 95 { 96 public: 97 static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext, 98 const TfLiteDelegateParams* parameters, 99 const Delegate* delegate); 100 101 TfLiteStatus Prepare(TfLiteContext* tfLiteContext); 102 103 TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode); 104 105 static TfLiteStatus VisitNode(DelegateData& delegateData, 106 TfLiteContext* tfLiteContext, 107 TfLiteRegistration* tfLiteRegistration, 108 TfLiteNode* tfLiteNode, 109 int nodeIndex); 110 111 private: ArmnnSubgraph(armnn::NetworkId networkId,armnn::IRuntime * runtime,std::vector<armnn::BindingPointInfo> & inputBindings,std::vector<armnn::BindingPointInfo> & outputBindings)112 ArmnnSubgraph(armnn::NetworkId networkId, 113 armnn::IRuntime* runtime, 114 std::vector<armnn::BindingPointInfo>& inputBindings, 115 std::vector<armnn::BindingPointInfo>& outputBindings) 116 : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings) 117 {} 118 119 static TfLiteStatus AddInputLayer(DelegateData& delegateData, 120 TfLiteContext* tfLiteContext, 121 const TfLiteIntArray* inputs, 122 std::vector<armnn::BindingPointInfo>& inputBindings); 123 124 static TfLiteStatus AddOutputLayer(DelegateData& delegateData, 125 TfLiteContext* tfLiteContext, 126 const TfLiteIntArray* outputs, 127 std::vector<armnn::BindingPointInfo>& outputBindings); 128 129 130 /// The Network Id 131 armnn::NetworkId m_NetworkId; 132 /// ArmNN Runtime 133 armnn::IRuntime* m_Runtime; 134 135 // Binding information for inputs and outputs 136 std::vector<armnn::BindingPointInfo> m_InputBindings; 137 std::vector<armnn::BindingPointInfo> m_OutputBindings; 138 139 }; 140 141 } // armnnDelegate namespace