/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file has utilities that facilitate creating new delegates.
// - SimpleDelegateKernelInterface: Represents a Kernel which handles a subgraph
// to be delegated. It has Init/Prepare/Eval which are going to be called
// during inference, similar to TFLite Kernels. Delegate owner should implement
// this interface to build/prepare/invoke the delegated subgraph.
// - SimpleDelegateInterface:
// This class wraps TfLiteDelegate and users need to implement the interface and
// then call TfLiteDelegateFactory::CreateSimpleDelegate(...) to get a
// TfLiteDelegate* that can be passed to ModifyGraphWithDelegate and free it via
// TfLiteDelegateFactory::DeleteSimpleDelegate(...),
// or call TfLiteDelegateFactory::Create(...) to get a std::unique_ptr
// TfLiteDelegate that can also be passed to ModifyGraphWithDelegate, in which
// case TfLite interpreter takes the memory ownership of the delegate.
29 #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ 30 #define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ 31 32 #include <memory> 33 34 #include "tensorflow/lite/c/common.h" 35 36 namespace tflite { 37 38 using TfLiteDelegateUniquePtr = 39 std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>; 40 41 // Users should inherit from this class and implement the interface below. 42 // Each instance represents a single part of the graph (subgraph). 43 class SimpleDelegateKernelInterface { 44 public: ~SimpleDelegateKernelInterface()45 virtual ~SimpleDelegateKernelInterface() {} 46 47 // Initializes a delegated subgraph. 48 // The nodes in the subgraph are inside TfLiteDelegateParams->nodes_to_replace 49 virtual TfLiteStatus Init(TfLiteContext* context, 50 const TfLiteDelegateParams* params) = 0; 51 52 // Will be called by the framework. Should handle any needed preparation 53 // for the subgraph e.g. allocating buffers, compiling model. 54 // Returns status, and signalling any errors. 55 virtual TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) = 0; 56 57 // Actual subgraph inference should happen on this call. 58 // Returns status, and signalling any errors. 59 // NOTE: Tensor data pointers (tensor->data) can change every inference, so 60 // the implementation of this method needs to take that into account. 61 virtual TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) = 0; 62 }; 63 64 // Pure Interface that clients should implement. 65 // The Interface represents a delegate's capabilities and provides a factory 66 // for SimpleDelegateKernelInterface. 67 // 68 // Clients should implement the following methods: 69 // - IsNodeSupportedByDelegate 70 // - Initialize 71 // - Name 72 // - CreateDelegateKernelInterface 73 // - DelegateOptions 74 class SimpleDelegateInterface { 75 public: 76 // Properties of a delegate. These are used by TfLiteDelegateFactory to 77 // help determine how to partition the graph, i.e. 
which nodes each delegate 78 // will get applied to. 79 struct Options { 80 // Maximum number of delegated subgraph, values <=0 means unlimited. 81 int max_delegated_partitions = 0; 82 83 // The minimum number of nodes allowed in a delegated graph, values <=0 84 // means unlimited. 85 int min_nodes_per_partition = 0; 86 }; 87 ~SimpleDelegateInterface()88 virtual ~SimpleDelegateInterface() {} 89 90 // Returns true if 'node' is supported by the delegate. False otherwise. 91 virtual bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration, 92 const TfLiteNode* node, 93 TfLiteContext* context) const = 0; 94 95 // Initialize the delegate before finding and replacing TfLite nodes with 96 // delegate kernels, for example, retrieving some TFLite settings from 97 // 'context'. 98 virtual TfLiteStatus Initialize(TfLiteContext* context) = 0; 99 100 // Returns a name that identifies the delegate. 101 // This name is used for debugging/logging/profiling. 102 virtual const char* Name() const = 0; 103 104 // Returns instance of an object that implements the interface 105 // SimpleDelegateKernelInterface. 106 // An instance of SimpleDelegateKernelInterface represents one subgraph to 107 // be delegated. 108 // Caller takes ownership of the returned object. 109 virtual std::unique_ptr<SimpleDelegateKernelInterface> 110 CreateDelegateKernelInterface() = 0; 111 112 // Returns SimpleDelegateInterface::Options which has delegate properties 113 // relevant for graph partitioning. 114 virtual SimpleDelegateInterface::Options DelegateOptions() const = 0; 115 }; 116 117 // Factory class that provides static methods to deal with SimpleDelegate 118 // creation and deletion. 119 class TfLiteDelegateFactory { 120 public: 121 // Creates TfLiteDelegate from the provided SimpleDelegateInterface. 122 // The returned TfLiteDelegate should be deleted using DeleteSimpleDelegate. 
123 // A simple usage of the flags bit mask: 124 // CreateSimpleDelegate(..., kTfLiteDelegateFlagsAllowDynamicTensors | 125 // kTfLiteDelegateFlagsRequirePropagatedShapes) 126 static TfLiteDelegate* CreateSimpleDelegate( 127 std::unique_ptr<SimpleDelegateInterface> simple_delegate, 128 int64_t flags = kTfLiteDelegateFlagsNone); 129 130 // Deletes 'delegate' the passed pointer must be the one returned 131 // from CreateSimpleDelegate. 132 // This function will destruct the SimpleDelegate object too. 133 static void DeleteSimpleDelegate(TfLiteDelegate* delegate); 134 135 // A convenient function wrapping the above two functions and returning a 136 // std::unique_ptr type for auto memory management. Create(std::unique_ptr<SimpleDelegateInterface> simple_delegate)137 inline static TfLiteDelegateUniquePtr Create( 138 std::unique_ptr<SimpleDelegateInterface> simple_delegate) { 139 return TfLiteDelegateUniquePtr( 140 CreateSimpleDelegate(std::move(simple_delegate)), DeleteSimpleDelegate); 141 } 142 }; 143 144 } // namespace tflite 145 146 #endif // TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ 147