/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/utils.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <set>
#include <string>
#include <vector>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace delegates {

TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
                                              const int original_tensor_index,
                                              TfLiteType new_type,
                                              TfLiteTensor** new_tensor,
                                              int* new_tensor_index) {
  TF_LITE_ENSURE_STATUS(context->AddTensors(context, 1, new_tensor_index));
  const TfLiteTensor& original_tensor = context->tensors[original_tensor_index];
  *new_tensor = &context->tensors[*new_tensor_index];
  (*new_tensor)->type = new_type;
  (*new_tensor)->allocation_type = kTfLiteArenaRw;
  // Copy the original tensor's shape; ResizeTensor takes ownership of 'dims'.
  const auto* original_dims = original_tensor.dims;
  TfLiteIntArray* dims = TfLiteIntArrayCreate(original_dims->size);
  for (int i = 0; i < original_dims->size; ++i) {
    dims->data[i] = original_dims->data[i];
  }
  if (context->ResizeTensor(context, *new_tensor, dims) != kTfLiteOk) {
    TF_LITE_KERNEL_LOG(context, "Could not resize new delegate tensor");
    return kTfLiteError;
  }
  return kTfLiteOk;
}

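// Example (illustrative sketch): a delegate kernel that needs a float32
// shadow of an existing tensor could use the helper as follows;
// 'tensor_index' is a hypothetical index into 'context->tensors'.
//
//   TfLiteTensor* fp32_tensor = nullptr;
//   int fp32_tensor_index = -1;
//   TF_LITE_ENSURE_STATUS(CreateNewTensorWithDifferentType(
//       context, tensor_index, kTfLiteFloat32, &fp32_tensor,
//       &fp32_tensor_index));
//   // 'fp32_tensor' now has the original tensor's shape and arena-managed
//   // memory, and can serve as a scratch or intermediate tensor.
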
TfLiteStatus GraphPartitionHelper::Partition(
    std::set<std::string>* unsupported_nodes_info) {
  const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info);
  if (prepare_status != kTfLiteOk) return prepare_status;

  TfLiteDelegateParams* partition_params_array = nullptr;
  int num_partitions = 0;
  if (context_->PreviewDelegatePartitioning(context_, supported_nodes_,
                                            &partition_params_array,
                                            &num_partitions) != kTfLiteOk) {
    TF_LITE_KERNEL_LOG(context_, "Unable to preview delegate partition.\n");
    return kTfLiteError;
  }

  for (int i = 0; i < num_partitions; ++i) {
    partitions_.push_back(partition_params_array + i);
  }

  return kTfLiteOk;
}

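// Example (illustrative sketch): a delegate's Prepare step can drive the
// helper like this; the lambda is a hypothetical IsNodeSupportedFn that only
// accepts ADD nodes.
//
//   IsNodeSupportedFn is_add_supported =
//       [](TfLiteContext* context, TfLiteNode* node,
//          TfLiteRegistration* registration,
//          std::string* unsupported_details) -> bool {
//         return registration->builtin_code == kTfLiteBuiltinAdd;
//       };
//   GraphPartitionHelper helper(context, is_add_supported);
//   std::set<std::string> unsupported_info;
//   if (helper.Partition(&unsupported_info) != kTfLiteOk) return kTfLiteError;
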
std::vector<TfLiteDelegateParams*>
GraphPartitionHelper::GetFirstNLargestPartitions(
    int n, int min_nodes_per_partition) const {
  // The number of partitions produced for a delegate is rarely large enough
  // for sorting to cause latency issues, and partitioning is a one-time cost,
  // so we unconditionally sort the partitions by size here.
  std::vector<TfLiteDelegateParams*> sorted_partitions(partitions_);
  std::sort(sorted_partitions.begin(), sorted_partitions.end(),
            [](TfLiteDelegateParams* left, TfLiteDelegateParams* right) {
              // Descending order of partition size.
              return left->nodes_to_replace->size >
                     right->nodes_to_replace->size;
            });

  std::vector<TfLiteDelegateParams*> results;
  auto p_it = sorted_partitions.begin();
  const int total = sorted_partitions.size();
  for (int i = 0; i < std::min(total, n); ++i, ++p_it) {
    auto* p = (*p_it);
    if (p->nodes_to_replace->size < min_nodes_per_partition) {
      break;
    }
    results.push_back(p);
  }
  return results;
}

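// Example (illustrative): after a successful Partition(), a delegate that
// only wants the single largest partition containing at least 3 nodes would
// call:
//
//   auto largest = helper.GetFirstNLargestPartitions(
//       /*n=*/1, /*min_nodes_per_partition=*/3);
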
std::vector<int> GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
    int n, int min_nodes_per_partition) {
  auto first_n_partitions =
      GetFirstNLargestPartitions(n, min_nodes_per_partition);
  std::vector<int> ops_to_replace;
  for (const auto* p : first_n_partitions) {
    auto nodes = p->nodes_to_replace;
    ops_to_replace.insert(ops_to_replace.end(), nodes->data,
                          nodes->data + nodes->size);
  }
  return ops_to_replace;
}

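// Example (illustrative): if the two largest qualifying partitions hold nodes
// {4, 5, 6} and {9, 10}, the flattened result is {4, 5, 6, 9, 10}.
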
TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
    std::set<std::string>* unsupported_nodes_info) {
  if (!is_node_supported_fn_) return kTfLiteOk;

  TfLiteIntArray* execution_plan = nullptr;
  auto status = context_->GetExecutionPlan(context_, &execution_plan);
  if (status != kTfLiteOk) {
    TF_LITE_KERNEL_LOG(context_, "Unable to get graph execution plan.\n");
    return status;
  }
  // context->GetExecutionPlan invalidates memory obtained from previous calls,
  // which is dangerous if a delegate's IsNodeSupportedFn uses it anywhere.
  // So we store a copy to ensure validity.
  num_total_nodes_ = execution_plan->size;
  original_execution_plan_ = TfLiteIntArrayCreate(execution_plan->size);
  std::memcpy(original_execution_plan_->data, execution_plan->data,
              num_total_nodes_ * sizeof(int32_t));

  supported_nodes_ = TfLiteIntArrayCreate(num_total_nodes_);
  supported_nodes_->size = 0;
  for (int node_id : TfLiteIntArrayView(original_execution_plan_)) {
    TfLiteNode* node;
    TfLiteRegistration* registration;

    status = context_->GetNodeAndRegistration(context_, node_id, &node,
                                              &registration);
    if (status != kTfLiteOk) {
      TF_LITE_KERNEL_LOG(context_,
                         "Couldn't get node and registration info for op: %d\n",
                         node_id);
      supported_nodes_->size = 0;
      return status;
    }

    std::string unsupported_details;
    if (IsNodeSupported(context_, node, registration, node_id,
                        &unsupported_details)) {
      supported_nodes_->data[supported_nodes_->size++] = node_id;
    } else if (unsupported_nodes_info) {
      std::string node_info = GetOpNameByRegistration(*registration);
      node_info.append(": ");
      node_info.append(unsupported_details);
      unsupported_nodes_info->insert(node_info);
    }
  }

  num_supported_nodes_ = supported_nodes_->size;
  return kTfLiteOk;
}

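// Example (illustrative): for an execution plan {0, 1, 2, 3} where only nodes
// 0 and 2 pass IsNodeSupported, supported_nodes_ ends up with size == 2 and
// data beginning {0, 2} (the array keeps its allocation of 4 ints; only
// 'size' is trimmed). Each rejected node contributes an entry of the form
// "<op name>: <unsupported details>" to 'unsupported_nodes_info'.
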
std::vector<int>
FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
    int n, int min_nodes_per_partition) {
  std::vector<int> ops_to_replace;

  if (num_supported_nodes() + constant_dequant_nodes_.size() ==
      num_total_nodes()) {
    // Scenario 1: Full Delegation.
    // We delegate all nodes in this case to avoid unnecessary partitions due
    // to FP16 DEQUANT nodes. This is safe to do since no non-delegated op
    // consumes the output of such a DEQUANT.
    for (int node_id : TfLiteIntArrayView(original_execution_plan_)) {
      ops_to_replace.push_back(node_id);
    }
  } else {
    // Scenario 2: Partial Delegation.
    // Here we only select the top 'n' eligible node subsets to delegate, and
    // exclude all FP16 DEQUANT ops from them. Handling those ops is tricky in
    // partial-delegation cases and causes edge cases when non-delegated nodes
    // consume their outputs, so we keep all of them on the CPU.
    auto first_n_partitions =
        GetFirstNLargestPartitions(n, min_nodes_per_partition);
    if (first_n_partitions.empty()) return ops_to_replace;
    for (size_t i = 0; i < first_n_partitions.size(); ++i) {
      auto nodes = first_n_partitions[i]->nodes_to_replace;
      ops_to_replace.insert(ops_to_replace.end(), nodes->data,
                            nodes->data + nodes->size);
    }
  }

  // Modify the inputs of relevant ops that support fp16 constants.
  RemapFp16InputTensors(ops_to_replace);
  return ops_to_replace;
}

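// Illustration (an assumed typical graph shape, sketched for clarity): the
// constant-fp16 pattern handled by this class looks like
//
//   [fp16 const] -> DEQUANTIZE -> [fp32 tensor] -> SOME_OP
//
// Under full delegation, the DEQUANTIZE nodes are delegated along with
// everything else. Under partial delegation they stay on the CPU, and
// supported consumers are remapped (see RemapFp16InputTensors below) to read
// the fp16 constant directly:
//
//   [fp16 const] -> SOME_OP
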
bool FP16GraphPartitionHelper::IsNodeSupported(
    TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
    int node_id, std::string* unsupported_details) {
  if (registration->builtin_code == kTfLiteBuiltinDequantize) {
    auto& dequantize_input = context_->tensors[node->inputs->data[0]];
    if (dequantize_input.type == kTfLiteFloat16 &&
        IsConstantTensor(&dequantize_input)) {
      // Update mappings if this node is a fp16 DEQUANTIZE node that
      // works on a **constant** input tensor.
      // If the input is not a constant, the remapping we do here would
      // cause bugs due to preceding ops such as DENSIFY.
      constant_dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
      constant_dequant_nodes_[node->outputs->data[0]] = node_id;
      // We do not accept these ops right now.
      // This is done to support use-cases where a DEQUANTIZE output might be
      // consumed by a CPU op.
      return false;
    }
  }

  // To check if a (possibly) FP16 node is supported, we temporarily point the
  // node's inputs to the original fp16 tensors. This 'mutated' node is then
  // passed to the base IsNodeSupported function for checking. After the check,
  // we restore the original node inputs, so that the TFLite graph remains the
  // same.
  std::vector<int> orig_inputs;
  if (!constant_dequant_nodes_.empty()) {
    RemapFp16InputTensors(node, &orig_inputs);
  }

  const auto is_supported = GraphPartitionHelper::IsNodeSupported(
      context, node, registration, node_id, unsupported_details);

  if (!orig_inputs.empty() && node->inputs->size == orig_inputs.size()) {
    // Remapping happened. Restore original inputs.
    for (int j = 0; j < node->inputs->size; ++j) {
      node->inputs->data[j] = orig_inputs[j];
    }
  }
  return is_supported;
}

void FP16GraphPartitionHelper::RemapFp16InputTensors(
    const std::vector<int>& nodes) const {
  for (int node_id : nodes) {
    TfLiteNode* node;
    TfLiteRegistration* registration;
    TfLiteStatus status = context_->GetNodeAndRegistration(
        context_, node_id, &node, &registration);
    if (status != kTfLiteOk) {
      TF_LITE_KERNEL_LOG(context_,
                         "Couldn't get node and registration info for op: %d\n",
                         node_id);
      // Skip this node: 'node' is not valid if the lookup failed.
      continue;
    }
    RemapFp16InputTensors(node, /*orig_inputs=*/nullptr);
  }
}

void FP16GraphPartitionHelper::RemapFp16InputTensors(
    TfLiteNode* node, std::vector<int>* orig_inputs) const {
  TfLiteIntArray* inputs = node->inputs;
  auto inputs_view = TfLiteIntArrayView(inputs);
  // Prepopulate 'orig_inputs' first and clear it if there's no input from a
  // dequant op.
  if (orig_inputs) {
    orig_inputs->clear();
    orig_inputs->reserve(inputs->size);
    for (auto tid : inputs_view) {
      orig_inputs->push_back(tid);
    }
  }
  // Fix this node's inputs (i.e. prune out the preceding dequantize node) in
  // order to test if it is supported.
  bool is_remapped = false;
  for (int j = 0; j < inputs->size; ++j) {
    const int input_tid = inputs->data[j];
    const auto it = constant_dequant_map_.find(input_tid);
    if (it != constant_dequant_map_.end()) {
      inputs->data[j] = it->second;
      is_remapped = true;
    }
  }
  if (!is_remapped && orig_inputs) orig_inputs->clear();
}

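// Example (illustrative): suppose constant_dequant_map_ maps tensor 7 (a
// DEQUANTIZE output) to tensor 3 (its fp16 constant input). For a node with
// inputs {7, 8}, the inputs become {3, 8} and, if provided, *orig_inputs is
// {7, 8}. For a node with inputs {8, 9}, nothing is remapped and *orig_inputs
// is cleared.
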
}  // namespace delegates
}  // namespace tflite