xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_
17 #define TENSORFLOW_LITE_DELEGATES_UTILS_H_
18 
19 // Utility functions and classes for implementing delegates.
20 
21 #include <functional>
22 #include <limits>
23 #include <set>
24 #include <string>
25 #include <unordered_map>
26 #include <utility>
27 #include <vector>
28 
29 #include "tensorflow/lite/c/common.h"
30 #include "tensorflow/lite/util.h"
31 
32 namespace tflite {
33 namespace delegates {
34 
35 // Creates a new Read/Write tensor having the same shape as the original, but
36 // with a different type. Note that this might void existing references to
37 // tensors.
38 TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
39                                               const int original_tensor_index,
40                                               TfLiteType new_type,
41                                               TfLiteTensor** new_tensor,
42                                               int* new_tensor_index);
43 
44 using IsNodeSupportedFn =
45     std::function<bool(TfLiteContext*, TfLiteNode*, TfLiteRegistration*,
46                        std::string* unsupported_details)>;
47 
48 // A utility class to help model graph parition.
49 // Note the class *needs* to be used in TfLiteDelegate::Prepare.
50 class GraphPartitionHelper {
51  public:
GraphPartitionHelper(TfLiteContext * context,IsNodeSupportedFn is_node_supported_fn)52   GraphPartitionHelper(TfLiteContext* context,
53                        IsNodeSupportedFn is_node_supported_fn)
54       : context_(context), is_node_supported_fn_(is_node_supported_fn) {}
55 
GraphPartitionHelper(TfLiteContext * context,const std::vector<int> & supported_node_indices)56   GraphPartitionHelper(TfLiteContext* context,
57                        const std::vector<int>& supported_node_indices)
58       : context_(context),
59         num_total_nodes_(supported_node_indices.size()),
60         supported_nodes_(
61             ConvertVectorToTfLiteIntArray(supported_node_indices)) {}
62 
~GraphPartitionHelper()63   virtual ~GraphPartitionHelper() {
64     TfLiteIntArrayFree(supported_nodes_);
65     TfLiteIntArrayFree(original_execution_plan_);
66   }
67 
68   // Partition the graph into node subsets such that each subset could be
69   // replaced with one delegate kernel (i.e. a kTfLiteBuiltinDelegate op).
70   // If 'unsupported_nodes_info' is provided, it will be populated with
71   // information about all different unsupported nodes.
72   virtual TfLiteStatus Partition(std::set<std::string>* unsupported_nodes_info);
73 
74   // Returns the first n largest partitions or all if #partitions is less than
75   // 'n' and each parition has at least (>=) 'min_nodes_per_partition' nodes.
76   // Note that partitions are ranked according to the number of nodes that
77   // a partition has, and the returned TfLiteDelegateParams objects are *owned*
78   // by the TfLite runtime.
79   // TODO(b/156707497): remove this and use GetNodesOfFirstNLargestPartitions
80   std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions(
81       int n = std::numeric_limits<int>::max(),
82       int min_nodes_per_partition = 0) const;
83 
84   // Returns a list of node indices of all nodes from the first n largest
85   // partitions. If there are fewer paritions than n, all nodes will be
86   // returned. The partition is ranked according to the number of nodes.
87   std::vector<int> GetNodesOfFirstNLargestPartitions(
88       int n = std::numeric_limits<int>::max(),
89       int min_nodes_per_partition = 0) {
90     // Separated implementation that can be overrided, to preserve default value
91     return GetNodesOfFirstNLargestPartitionsImpl(n, min_nodes_per_partition);
92   }
93 
num_total_nodes()94   int num_total_nodes() const { return num_total_nodes_; }
num_supported_nodes()95   int num_supported_nodes() const { return num_supported_nodes_; }
num_partitions()96   int num_partitions() const { return partitions_.size(); }
97 
98  protected:
IsNodeSupported(TfLiteContext * context,TfLiteNode * node,TfLiteRegistration * registration,int node_id,std::string * unsupported_details)99   virtual bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
100                                TfLiteRegistration* registration, int node_id,
101                                std::string* unsupported_details) {
102     return is_node_supported_fn_(context, node, registration,
103                                  unsupported_details);
104   }
105   virtual std::vector<int> GetNodesOfFirstNLargestPartitionsImpl(
106       int n, int min_nodes_per_partition);
107 
108   TfLiteContext* const context_ = nullptr;
109 
110   // Doesn't own the memory of each TfLiteDelegateParams object as it's
111   // managed by the TfLite runtime itself. See
112   // TfLiteContext::PreviewDelegatePartitioning for details.
113   std::vector<TfLiteDelegateParams*> partitions_;
114 
115   // Copy of (pre-delegation) execution plan obtained from TfLiteContext in
116   // PrepareSupportedNodes
117   TfLiteIntArray* original_execution_plan_ = nullptr;
118 
119  private:
120   // Generate a list of supported nodes (i.e. populating 'supported_nodes_') by
121   // iterating over all nodes (i,e. those listed in the execution_plan
122   // associated w/ 'context_').
123   // If 'unsupported_nodes_info' is provided, it will be populated with
124   // information about all different unsupported nodes.
125   TfLiteStatus PrepareSupportedNodes(
126       std::set<std::string>* unsupported_nodes_info = nullptr);
127 
128   // The number of total nodes passed in for partitioning (i.e. the
129   // execution_plan size associated w/ 'context_')
130   int num_total_nodes_ = 0;
131 
132   int num_supported_nodes_ = 0;
133 
134   // Tells if a node is supported as it could be delegated.
135   const IsNodeSupportedFn is_node_supported_fn_ = nullptr;
136 
137   // Contains an array of supported node indices.
138   TfLiteIntArray* supported_nodes_ = nullptr;  // owns the memory
139 };
140 
141 // Specialized partitioner for graphs that possibly contain fp16 tensors.
142 //
143 // From nodes that accept fp16 inputs, this delegates the following:
144 // 1. All nodes (except DEQUANTIZE) that are supported with constant fp16 inputs
145 // by the delegate (in the TFLite graph, these nodes take in dequantized FP32
146 // outputs).
147 // 2. All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first*
148 // delegated partition. This is because TFLite's partitioning algorithm
149 // greedily puts all such nodes in the first partition.
150 class FP16GraphPartitionHelper : public GraphPartitionHelper {
151  public:
FP16GraphPartitionHelper(TfLiteContext * context,IsNodeSupportedFn is_node_supported_fn)152   FP16GraphPartitionHelper(TfLiteContext* context,
153                            IsNodeSupportedFn is_node_supported_fn)
154       : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}
155 
156  protected:
157   // Specialized function to handle fp16 nodes.
158   bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
159                        TfLiteRegistration* registration, int node_id,
160                        std::string* unsupported_details) override;
161 
162   // This will remap input tensors by removing FP16 to FP32 dequantized tensors.
163   std::vector<int> GetNodesOfFirstNLargestPartitionsImpl(
164       int n, int min_nodes_per_partition) override;
165 
166  private:
167   // This remaps fp32 inputs of the given node to their corresponding fp16
168   // version, if applicable. Can be summarized as:
169   // fp16 -> DEQUANTIZE -> fp32 -> OP -> output
170   // becomes
171   // fp16 -> OP -> output
172   void RemapFp16InputTensors(TfLiteNode* node,
173                              std::vector<int>* orig_inputs) const;
174 
175   // Performs the above remapping for all nodes in the given list, without
176   // tracking the original inputs.
177   void RemapFp16InputTensors(const std::vector<int>& nodes) const;
178 
179   // ('dequantize' here refers to fp16 DEQUANTIZE)
180   // Mapping of dequantize nodes' output tensor-id to its node id.
181   // TODO(b/156707497): Use absl hash_maps here.
182   std::unordered_map<int, int> constant_dequant_nodes_;
183   // Mapping of DEQUANTIZE node's output (fp32) to its input (fp16).
184   std::unordered_map<int, int> constant_dequant_map_;
185 };
186 
187 }  // namespace delegates
188 }  // namespace tflite
189 
190 #endif  // TENSORFLOW_LITE_DELEGATES_UTILS_H_
191