xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/utils/simple_delegate.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file has utilities that facilitate creating new delegates.
// - SimpleDelegateKernelInterface: Represents a kernel which handles a
// subgraph to be delegated. It has Init/Prepare/Eval methods which are called
// during inference, similar to TFLite kernels. The delegate owner should
// implement this interface to build/prepare/invoke the delegated subgraph.
// - SimpleDelegateInterface:
// Users implement this interface and then call
// TfLiteDelegateFactory::CreateSimpleDelegate(...), which wraps it in a
// TfLiteDelegate* that can be passed to ModifyGraphWithDelegate and must later
// be freed via TfLiteDelegateFactory::DeleteSimpleDelegate(...).
// Alternatively, call TfLiteDelegateFactory::Create(...) to get a
// TfLiteDelegateUniquePtr (a std::unique_ptr) that can also be passed to
// ModifyGraphWithDelegate, in which case the TfLite interpreter takes
// ownership of the delegate's memory.
29 #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
30 #define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
31 
#include <cstdint>
#include <memory>
#include <utility>

#include "tensorflow/lite/c/common.h"
35 
36 namespace tflite {
37 
38 using TfLiteDelegateUniquePtr =
39     std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
40 
41 // Users should inherit from this class and implement the interface below.
42 // Each instance represents a single part of the graph (subgraph).
43 class SimpleDelegateKernelInterface {
44  public:
~SimpleDelegateKernelInterface()45   virtual ~SimpleDelegateKernelInterface() {}
46 
47   // Initializes a delegated subgraph.
48   // The nodes in the subgraph are inside TfLiteDelegateParams->nodes_to_replace
49   virtual TfLiteStatus Init(TfLiteContext* context,
50                             const TfLiteDelegateParams* params) = 0;
51 
52   // Will be called by the framework. Should handle any needed preparation
53   // for the subgraph e.g. allocating buffers, compiling model.
54   // Returns status, and signalling any errors.
55   virtual TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) = 0;
56 
57   // Actual subgraph inference should happen on this call.
58   // Returns status, and signalling any errors.
59   // NOTE: Tensor data pointers (tensor->data) can change every inference, so
60   // the implementation of this method needs to take that into account.
61   virtual TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) = 0;
62 };
63 
64 // Pure Interface that clients should implement.
65 // The Interface represents a delegate's capabilities and provides a factory
66 // for SimpleDelegateKernelInterface.
67 //
68 // Clients should implement the following methods:
69 // - IsNodeSupportedByDelegate
70 // - Initialize
71 // - Name
72 // - CreateDelegateKernelInterface
73 // - DelegateOptions
74 class SimpleDelegateInterface {
75  public:
76   // Properties of a delegate.  These are used by TfLiteDelegateFactory to
77   // help determine how to partition the graph, i.e. which nodes each delegate
78   // will get applied to.
79   struct Options {
80     // Maximum number of delegated subgraph, values <=0 means unlimited.
81     int max_delegated_partitions = 0;
82 
83     // The minimum number of nodes allowed in a delegated graph, values <=0
84     // means unlimited.
85     int min_nodes_per_partition = 0;
86   };
87 
~SimpleDelegateInterface()88   virtual ~SimpleDelegateInterface() {}
89 
90   // Returns true if 'node' is supported by the delegate. False otherwise.
91   virtual bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
92                                          const TfLiteNode* node,
93                                          TfLiteContext* context) const = 0;
94 
95   // Initialize the delegate before finding and replacing TfLite nodes with
96   // delegate kernels, for example, retrieving some TFLite settings from
97   // 'context'.
98   virtual TfLiteStatus Initialize(TfLiteContext* context) = 0;
99 
100   // Returns a name that identifies the delegate.
101   // This name is used for debugging/logging/profiling.
102   virtual const char* Name() const = 0;
103 
104   // Returns instance of an object that implements the interface
105   // SimpleDelegateKernelInterface.
106   // An instance of SimpleDelegateKernelInterface represents one subgraph to
107   // be delegated.
108   // Caller takes ownership of the returned object.
109   virtual std::unique_ptr<SimpleDelegateKernelInterface>
110   CreateDelegateKernelInterface() = 0;
111 
112   // Returns SimpleDelegateInterface::Options which has delegate properties
113   // relevant for graph partitioning.
114   virtual SimpleDelegateInterface::Options DelegateOptions() const = 0;
115 };
116 
117 // Factory class that provides static methods to deal with SimpleDelegate
118 // creation and deletion.
119 class TfLiteDelegateFactory {
120  public:
121   // Creates TfLiteDelegate from the provided SimpleDelegateInterface.
122   // The returned TfLiteDelegate should be deleted using DeleteSimpleDelegate.
123   // A simple usage of the flags bit mask:
124   // CreateSimpleDelegate(..., kTfLiteDelegateFlagsAllowDynamicTensors |
125   // kTfLiteDelegateFlagsRequirePropagatedShapes)
126   static TfLiteDelegate* CreateSimpleDelegate(
127       std::unique_ptr<SimpleDelegateInterface> simple_delegate,
128       int64_t flags = kTfLiteDelegateFlagsNone);
129 
130   // Deletes 'delegate' the passed pointer must be the one returned
131   // from CreateSimpleDelegate.
132   // This function will destruct the SimpleDelegate object too.
133   static void DeleteSimpleDelegate(TfLiteDelegate* delegate);
134 
135   // A convenient function wrapping the above two functions and returning a
136   // std::unique_ptr type for auto memory management.
Create(std::unique_ptr<SimpleDelegateInterface> simple_delegate)137   inline static TfLiteDelegateUniquePtr Create(
138       std::unique_ptr<SimpleDelegateInterface> simple_delegate) {
139     return TfLiteDelegateUniquePtr(
140         CreateSimpleDelegate(std::move(simple_delegate)), DeleteSimpleDelegate);
141   }
142 };
143 
144 }  // namespace tflite
145 
146 #endif  // TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
147