/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/kernel.h"

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/string_type.h"

// Note: this is part of TF Lite's Flex delegation code which is to be
// completed soon.

// This is the TF Lite op that is created by the flex delegate to handle
// execution of a supported subgraph. The usual flow is that the delegate
// informs the interpreter of supported nodes in a graph, and each supported
// subgraph is replaced with one instance of this kernel.
//
// The kernel is initialized with TfLiteDelegateParams from which we retrieve
// the global EagerContext and BufferMap, as well as a list of inputs and
// outputs to the subgraph. Those are used to build the OpData, with a list of
// TensorFlow Ops that should be executed in order (which we call an OpNode).
//
// For each node included in the subgraph, we query the interpreter and
// retrieve the associated NodeDef, which is then used to configure the
// corresponding TensorFlow OpKernel.

using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeAndType;
using tensorflow::shape_inference::ShapeHandle;

namespace tflite {
namespace flex {

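// Op name that is special-cased when validating output shapes; see
// ValidateOutputTensorShapeConsistency() below.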
constexpr char kReadVariableOp[] = "ReadVariableOp";

struct OpNode;

// Represents the origin of a given tensor as a reference to the output
// of an upstream node.
struct TensorSource {
  OpNode* node;
  int node_output_index;
};

// A list of inputs of a given node of the TensorFlow graph.
class OpInputs {
 public:
  explicit OpInputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      inputs_.push_back(index);
    }
    forwardable_.resize(inputs_.size());
  }
  ~OpInputs() {}

  int Size() const { return inputs_.size(); }

  int TfLiteIndex(int i) const { return inputs_[i]; }

  // Given a map relating tensors to the node that originates them, populate a
  // list of sources for the tensors in this class.
  void InitializeTensorSources(
      const std::map<int, TensorSource>& tflite_tensor_sources) {
    sources_.clear();
    for (int i : inputs_) {
      auto it = tflite_tensor_sources.find(i);
      if (it == tflite_tensor_sources.end()) {
        sources_.push_back({nullptr, 0});
      } else {
        sources_.push_back(it->second);
      }
    }
  }

  void SetForwardable(int i, bool v) { forwardable_[i] = v; }

  bool IsForwardable(int i) const { return forwardable_[i]; }

  TensorSource GetTensorSource(int i) const { return sources_[i]; }

 private:
  std::vector<int> inputs_;
  std::vector<TensorSource> sources_;

  // List of tensors that can be used by TF in its forwarding optimization.
  // Doing so allows an input tensor to be modified and used as the output
  // tensor. The delegate takes care of not holding any references to tensors
  // in this list while the corresponding tensorflow::OpKernel is executed.
  std::vector<int> forwardable_;
};

// A list of outputs of a given node of the TensorFlow graph, along with
// the actual outputs of the tensorflow::OpKernel.
class OpOutputs {
 public:
  explicit OpOutputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      outputs_.push_back(index);
    }
    vector_.resize(outputs_.size());
  }
  ~OpOutputs() = default;

  // Stores information about which of the tensors in this class are also
  // outputs of the subgraph.
  void InitializeGraphOutputs(const std::set<int>& subgraph_outputs) {
    subgraph_outputs_.clear();
    for (int i : outputs_) {
      subgraph_outputs_.push_back(subgraph_outputs.count(i) > 0);
    }
  }

  // Returns true if the tensor given by index 'i' is an output of the entire
  // subgraph.
  bool IsSubgraphOutput(int i) const { return subgraph_outputs_[i]; }

  const tensorflow::Tensor& GetTensor(int i) const { return vector_[i]; }
  tensorflow::Tensor ReleaseTensor(int i) { return std::move(vector_[i]); }

  int Size() const { return outputs_.size(); }

  int TfLiteIndex(int i) const { return outputs_[i]; }

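  // Exposes the backing vector so the delegate kernel can move the
  // tensorflow::OpKernel outputs of this node directly into it.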
  tensorflow::gtl::InlinedVector<tensorflow::Tensor, 2>* GetTensors() {
    return &vector_;
  }

 private:
  std::vector<int> outputs_;
  std::vector<bool> subgraph_outputs_;
  tensorflow::gtl::InlinedVector<tensorflow::Tensor, 2> vector_;
};

// This struct holds information such as tensor lifecycle and BufferMap which
// needs to be shared between `OpNode` and DelegateKernel.
struct OpDataInfo {
  // Buffer map which stores the mapping from TfLiteTensor index to TF tensor.
  BufferMap* buffer_map;
  // Mapping from TfLiteTensor index to the last node which uses the tensor.
  std::map<int, int>* tensor_release_map;
  // For output tensors that don't need to be preserved in the BufferMap, we
  // copy them to TF Lite tensors and keep the tensor indexes in this set.
  std::set<int> already_transferred_outputs;
};

// A single node within the larger 'op'. Note that this kernel executes many
// TensorFlow ops within a single TF Lite op.
class OpNode {
 public:
  OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
      : inputs_(inputs), outputs_(outputs) {}
  ~OpNode() = default;

  const string& name() const { return name_; }
  void set_name(const string& name) { name_ = name; }

  int index() const { return index_; }
  void set_index(int index) { index_ = index; }

  const tensorflow::NodeDef& nodedef() const { return nodedef_; }
  const tensorflow::OpRegistrationData* op_reg_data() const {
    return op_reg_data_;
  }

  const OpInputs& inputs() const { return inputs_; }
  OpInputs* mutable_inputs() { return &inputs_; }

  const OpOutputs& outputs() const { return outputs_; }
  OpOutputs* mutable_outputs() { return &outputs_; }

  int NumInputs() const { return inputs_.Size(); }
  int NumOutputs() const { return outputs_.Size(); }

  const tensorflow::tfrt_stub::OpKernelRunner& op_kernel_runner() const {
    return op_kernel_runner_;
  }

  tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
                                       int custom_initial_data_size) {
    if (!custom_initial_data) {
      return tensorflow::errors::Internal(
          "Cannot convert empty data into a valid NodeDef");
    }
    // The flexbuffer contains a vector where the first element is the
    // op name and the second is a serialized NodeDef.
    const flexbuffers::Vector& v =
        flexbuffers::GetRoot(
            reinterpret_cast<const uint8_t*>(custom_initial_data),
            custom_initial_data_size)
            .AsVector();

    name_ = v[0].AsString().str();
    if (!nodedef_.ParseFromString(v[1].AsString().str())) {
      nodedef_.Clear();
      return tensorflow::errors::Internal(
          "Failed to parse data into a valid NodeDef");
    }

    // Fill NodeDef with defaults if it's a valid op.
    TF_RETURN_IF_ERROR(
        tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data_));
    AddDefaultsToNodeDef(op_reg_data_->op_def, &nodedef_);

    return ::tensorflow::OkStatus();
  }

  tensorflow::Status BuildOpKernelRunner(
      tensorflow::EagerContext* eager_context) {
    // Create tensorflow::OpKernel on host CPU.
    TF_ASSIGN_OR_RETURN(op_kernel_runner_,
                        tensorflow::tfrt_stub::OpKernelRunner::Create(
                            name_, inputs_.Size(), /*attr_builder=*/
                            [this](tensorflow::AttrValueMap* attr_value_map) {
                              *attr_value_map = nodedef_.attr();
                              return ::tensorflow::OkStatus();
                            },
                            *eager_context->pflr(),
                            eager_context->local_device_mgr()->HostCPU()));

    return ::tensorflow::OkStatus();
  }

  tensorflow::Status BuildOpKernelInputs(
      const BufferMap* buffer_map,
      tensorflow::tfrt_stub::OpKernelRunState* run_state) {
    run_state->input_tf_tensors.resize(inputs_.Size());
    run_state->input_tf_tensor_values.resize(inputs_.Size());

    for (int i = 0; i < inputs_.Size(); ++i) {
      int input_index = inputs_.TfLiteIndex(i);
      TensorSource s = inputs_.GetTensorSource(i);
      if (!s.node) {
        // This input is not produced by this TF subgraph (it could be a TF
        // Lite native buffer, or could be produced by a separate subgraph). We
        // need to fetch it from the delegate's buffer_map.
        if (!buffer_map->HasTensor(input_index)) {
          return tensorflow::errors::Internal(
              "Cannot read from invalid tensor index ", input_index);
        }
        run_state->input_tf_tensors[i] = buffer_map->GetTensor(input_index);
      } else {
        // If this is a forwardable tensor, we will remove it from the previous
        // op's list, giving TF the opportunity to reuse its buffer.
        if (inputs_.IsForwardable(i)) {
          run_state->input_tf_tensors[i] =
              s.node->outputs_.ReleaseTensor(s.node_output_index);
        } else {
          run_state->input_tf_tensors[i] =
              s.node->outputs_.GetTensor(s.node_output_index);
        }
      }
      run_state->input_tf_tensor_values[i].tensor =
          &run_state->input_tf_tensors[i];
    }
    return ::tensorflow::OkStatus();
  }

  // Returns whether an output tensor should be preserved in the buffer map by
  // checking its lifetime information.
  // The eager tensor doesn't need to be persisted in the buffer map if it has
  // no future uses in the graph.
  bool ShouldPersistTensorflowTensor(TfLiteContext* context,
                                     const OpDataInfo* shared_info,
                                     int tensor_index, int node_index) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    // Always persist variant|resource|string tensors since they have special
    // storage requirements.
    if (IsResourceOrVariant(tensor) || tensor->type == kTfLiteString) {
      return true;
    }

    auto it = shared_info->tensor_release_map->find(tensor_index);
    return it != shared_info->tensor_release_map->end() &&
           it->second > node_index;
  }

  // Copies the data of the TensorFlow tensor into the corresponding TfLite
  // tensor; after the copy is done, releases the original tensor so that its
  // memory can be reclaimed by the TF runtime.
  TfLiteStatus CopyToTfLiteTensor(TfLiteContext* context,
                                  OpDataInfo* shared_info, TfLiteTensor* tensor,
                                  tensorflow::Tensor* tf_tensor,
                                  int tensor_index) const {
    if (tensor->allocation_type == kTfLiteDynamic) {
      // For dynamic tensors, update the TfLite tensor's shape information from
      // the Tensorflow tensor.
      CopyShapeAndType(context, *tf_tensor, tensor);
    }
    tensorflow::StringPiece t_data = tf_tensor->tensor_data();
    if (tf_tensor->NumElements() != NumElements(tensor) ||
        tf_tensor->TotalBytes() != tensor->bytes) {
      TF_LITE_KERNEL_LOG(context,
                         "FlexDelegate: Tensor %s(%d) buffer size mismatch "
                         "%zu(%lld) != %ld(%ld)",
                         tensor->name, tensor_index, tf_tensor->TotalBytes(),
                         tf_tensor->NumElements(), tensor->bytes,
                         NumElements(tensor));
      return kTfLiteError;
    }
    // Copy TF tensor's data content into TfLiteTensor, and release the tensor.
    memcpy(tensor->data.raw, t_data.data(), t_data.size());
    *tf_tensor = {};
    shared_info->already_transferred_outputs.insert(tensor_index);
    return kTfLiteOk;
  }

  // TODO(b/204479285): Release tensors from BufferMap if they have no future
  // uses.
  tensorflow::Status MaybePersistTensorflowOutputs(TfLiteContext* context,
                                                   OpDataInfo* shared_info,
                                                   int node_index) {
    auto* tensors = outputs_.GetTensors();

    for (int i = 0; i < outputs_.Size(); ++i) {
      if (outputs_.IsSubgraphOutput(i)) {
        tensorflow::Tensor& tf_tensor = tensors->at(i);
        const int tflite_index = outputs_.TfLiteIndex(i);
        TfLiteTensor* tensor = &context->tensors[tflite_index];
        if (!ShouldPersistTensorflowTensor(context, shared_info, tflite_index,
                                           node_index)) {
          if (CopyToTfLiteTensor(context, shared_info, tensor, &tf_tensor,
                                 tflite_index) != kTfLiteOk) {
            return tensorflow::Status(tensorflow::error::INTERNAL,
                                      "failed to copy data from TF tensor");
          }
        } else {
          shared_info->buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i),
                                                     tf_tensor);
        }
      }
    }
    return ::tensorflow::OkStatus();
  }

 private:
  OpNode(const OpNode&) = delete;
  OpNode& operator=(const OpNode&) = delete;

  // The name of the TensorFlow op to execute.
  string name_;
  // Index of this node into TF Lite's operator list.
  int index_;
  // The corresponding NodeDef, containing the attributes for the op.
  tensorflow::NodeDef nodedef_;
  // The corresponding OpRegistrationData pointer.
  const tensorflow::OpRegistrationData* op_reg_data_;
  // List of inputs, as TF Lite tensor indices.
  OpInputs inputs_;
  // List of outputs, as TF Lite tensor indices.
  OpOutputs outputs_;

  tensorflow::tfrt_stub::OpKernelRunner op_kernel_runner_;
};

// The larger 'op', which contains all the nodes in a supported subgraph.
struct OpData {
  tensorflow::EagerContext* eager_context;
  tensorflow::CancellationManager* cancellation_manager;
  std::vector<std::unique_ptr<OpNode>> nodes;
  std::vector<int> subgraph_inputs;
  std::vector<int> subgraph_outputs;
  std::set<int>
      disable_reusing_buffer_tensors;  // Set of input tensor indexes whose
                                       // buffers should not be reused by
                                       // tensorflow::Tensor.
  OpDataInfo shared_info;
};

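// Runs a single TensorFlow op through its prebuilt OpKernelRunner: inputs are
// resolved from the BufferMap or from upstream OpNodes, the kernel is executed
// on the host CPU, and its outputs are stored on the node and then persisted
// or copied back to TF Lite as needed.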
tensorflow::Status DelegateKernel::ExecuteOpKernelRunner(
    tensorflow::tfrt_stub::OpKernelRunState* run_state, TfLiteContext* context,
    OpNode* node_data) {
  const auto& op_kernel_runner = node_data->op_kernel_runner();

  if (op_kernel_runner.op_kernel()->num_outputs() != node_data->NumOutputs()) {
    return tensorflow::errors::Internal(
        "Unexpected number of outputs from tensorflow::OpKernel");
  }

  TF_RETURN_IF_ERROR(node_data->BuildOpKernelInputs(
      op_data_->shared_info.buffer_map, run_state));

  run_state->params.inputs = run_state->input_tf_tensor_values;
  run_state->params.op_kernel = op_kernel_runner.op_kernel();
  run_state->params.input_alloc_attrs = op_kernel_runner.input_alloc_attrs();
  run_state->params.output_attr_array =
      op_kernel_runner.output_alloc_attrs().data();
  run_state->params.function_library =
      op_kernel_runner.function_library_runtime();

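  // With the parameters populated, run the kernel synchronously and collect
  // its outputs from the OpKernelContext below.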
  tensorflow::OpKernelContext tf_context(&run_state->params,
                                         node_data->NumOutputs());
  op_kernel_runner.Run(&tf_context);
  TF_RETURN_IF_ERROR(tf_context.status());

  auto& outputs = *node_data->mutable_outputs()->GetTensors();
  for (int i = 0; i < tf_context.num_outputs(); ++i) {
    outputs[i] = std::move(*tf_context.mutable_output(i));
  }

  return node_data->MaybePersistTensorflowOutputs(
      context, &(op_data_->shared_info), node_data->index());
}

DelegateKernel::DelegateKernel() : op_data_(new OpData) {}
DelegateKernel::~DelegateKernel() {}

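// Builds the OpData for this delegate partition: resolves the shared
// EagerContext and BufferMap, parses each replaced node's NodeDef, creates its
// OpKernelRunner, and records tensor lifetime information used for buffer
// reuse decisions.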
TfLiteStatus DelegateKernel::Init(TfLiteContext* context,
                                  const TfLiteDelegateParams* params) {
  auto* flex_delegate_data =
      reinterpret_cast<FlexDelegate*>(params->delegate->data_)->mutable_data();
  op_data_->eager_context = flex_delegate_data->GetEagerContext();
  op_data_->cancellation_manager = flex_delegate_data->GetCancellationManager();
  op_data_->shared_info.buffer_map = flex_delegate_data->GetBufferMap(context);
  op_data_->shared_info.tensor_release_map =
      flex_delegate_data->GetTensorReleaseMap(context);

  CHECK(params->output_tensors);
  std::set<int> output_set;
  for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
    op_data_->subgraph_outputs.push_back(tensor_index);
    output_set.insert(tensor_index);
  }

  CHECK(params->input_tensors);
  for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
    op_data_->subgraph_inputs.push_back(tensor_index);
  }
  std::set<int> subgraph_inputs(op_data_->subgraph_inputs.begin(),
                                op_data_->subgraph_inputs.end());

  op_data_->nodes.reserve(params->nodes_to_replace->size);

  CHECK(params->nodes_to_replace);
  tensorflow::Status status;

  // Now we explicitly disable reusing TFLite tensor buffers for certain TF
  // ops, since those ops might produce results which keep a reference to the
  // input tensors (buffer forwarding).
  auto check_if_op_reuses_input = [](const string& op_name) {
    return op_name == "TensorListPushBack" || op_name == "TensorListSetItem" ||
           op_name == "SparseReshape";
  };

  for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
    TfLiteNode* node;
    TfLiteRegistration* reg;
    context->GetNodeAndRegistration(context, node_index, &node, &reg);

    op_data_->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
    OpNode& node_data = *op_data_->nodes.back();

    node_data.set_index(node_index);
    node_data.set_name("");

    status = node_data.InitializeNodeDef(node->custom_initial_data,
                                         node->custom_initial_data_size);
    if (!status.ok()) break;
    status = node_data.BuildOpKernelRunner(op_data_->eager_context);
    if (!status.ok()) break;

    // For each node handled by this delegate partition, record the mapping
    // between each input tensor and the node index. The node index is the
    // index of the last node in execution order that uses this tensor, so the
    // tensor is no longer needed after that node is executed. Since we execute
    // in order, the maximum index is the index of the last node that needs
    // this tensor.
    for (auto tensor_index : TfLiteIntArrayView(node->inputs)) {
      int node_id = node_index;
      if (op_data_->shared_info.tensor_release_map->find(tensor_index) !=
          op_data_->shared_info.tensor_release_map->end()) {
        node_id =
            std::max(op_data_->shared_info.tensor_release_map->at(tensor_index),
                     node_index);
      }
      (*op_data_->shared_info.tensor_release_map)[tensor_index] = node_id;

      if (subgraph_inputs.count(tensor_index) &&
          check_if_op_reuses_input(node_data.nodedef().op())) {
        op_data_->disable_reusing_buffer_tensors.insert(tensor_index);
      }
    }
  }

  TF_LITE_ENSURE_STATUS(ConvertStatus(context, status));

  // Given a TfLite tensor index, return the OpNode that produces it,
  // along with its index into that OpNode's list of outputs.
  std::map<int, TensorSource> tflite_tensor_sources;

  // Find out how each tensor is produced. This does not account for
  // tensors that are not produced by tensorflow::OpKernels.
  for (auto& node_data : op_data_->nodes) {
    node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
    for (int i = 0; i < node_data->outputs().Size(); ++i) {
      int output_index = node_data->outputs().TfLiteIndex(i);
      tflite_tensor_sources[output_index] = TensorSource{node_data.get(), i};
    }
  }

  // For each node, resolve the inputs, so we can keep pointers to the nodes
  // that produce them.
  for (auto& node_data : op_data_->nodes) {
    node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
  }
  return kTfLiteOk;
}

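// Seeds the BufferMap with constant inputs, validates output shape
// consistency (marking outputs dynamic when validation fails), and computes
// which input tensors may be forwarded to TensorFlow based on reference
// counts.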
TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_MSG(
      context, op_data_->eager_context != nullptr,
      "Failed to initialize eager context. This often happens when a CPU "
      "device has not been registered, presumably because some symbols from "
      "tensorflow/core:core_cpu_impl were not linked into the binary.");

  // We will keep track of the number of references to each tensor in the
  // graph, so we can make them "forwardable" if there is only one reference.
  std::map<int, int> tensor_ref_count;

  // Whenever we find a constant tensor, insert it in the buffer map.
  BufferMap* buffer_map = op_data_->shared_info.buffer_map;
  for (auto tensor_index : op_data_->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (IsConstantTensor(tensor)) {
      if (!tensor->data_is_stale || !buffer_map->HasTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }

    // Input tensors should never be forwarded, so we increment their ref
    // counts twice: once for this graph and once more for the possibility of
    // them being used by another subgraph, or being an output of the full
    // graph.
    tensor_ref_count[tensor_index] += 2;
  }

  const bool shapes_are_valid =
      (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk);
  if (shapes_are_valid) {
    TFLITE_LOG(tflite::TFLITE_LOG_INFO,
               "FlexDelegate: All tensor shapes are consistent.");
  } else {
    TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
               "FlexDelegate: Some tensor shapes are inconsistent.");
  }

  // All output tensors are allocated by TensorFlow, so we mark them as
  // kTfLiteDynamic.
  for (auto tensor_index : op_data_->subgraph_outputs) {
    if (!shapes_are_valid) {
      SetTensorToDynamic(&context->tensors[tensor_index]);
    }
    ++tensor_ref_count[tensor_index];
  }

  for (const auto& node_data : op_data_->nodes) {
    if (node_data->nodedef().op().empty()) {
      TF_LITE_KERNEL_LOG(context, "Invalid NodeDef in Flex op '%s'",
                         node_data->name().c_str());
      return kTfLiteError;
    }
    TF_LITE_ENSURE(context, node_data->op_kernel_runner());

    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
    }
  }

  // All tensors that are referenced exactly once are marked as "forwardable",
  // meaning that we will allow TensorFlow to reuse its buffer as the output of
  // an op.
  for (auto& node_data : op_data_->nodes) {
    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
      node_data->mutable_inputs()->SetForwardable(i, f);
    }
  }

  return kTfLiteOk;
}

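// Runs TensorFlow's shape inference for every node in the subgraph and
// compares the inferred output shapes against the shapes recorded on the
// corresponding TF Lite tensors.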
TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
    TfLiteContext* context) const {
  for (const auto& node_data : op_data_->nodes) {
    auto op_name = node_data->name().c_str();
    // Create an InferenceContext object.
    auto num_inputs = node_data->inputs().Size();
    std::vector<const tensorflow::Tensor*> input_tensors_vector(num_inputs,
                                                                nullptr);
    InferenceContext c(
        TF_GRAPH_DEF_VERSION, node_data->nodedef(),
        node_data->op_reg_data()->op_def, std::vector<ShapeHandle>(num_inputs),
        input_tensors_vector, {},
        std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

    // Set input_shapes for ShapeInferenceFn.
    for (int i = 0; i < num_inputs; ++i) {
      const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
      TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
      // Provide constant input tensors since some ops (e.g. "RFFT") need them
      // to calculate the output shape.
      if (IsConstantTensor(tfl_tensor)) {
        input_tensors_vector[i] =
            op_data_->shared_info.buffer_map->GetTensorPtr(input_tensor_index);
      }
      const auto dims_array = tfl_tensor->dims;
      std::vector<DimensionHandle> dims(dims_array->size);
      for (int j = 0; j < dims_array->size; ++j) {
        dims[j] = c.MakeDim(dims_array->data[j]);
      }
      c.SetInput(i, c.MakeShape(dims));
    }
    c.set_input_tensors(input_tensors_vector);

    tensorflow::Status status = c.construction_status();
    if (!status.ok()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "Shape construction failed for op '%s'", op_name);
      return kTfLiteError;
    }

    // Run ShapeInferenceFn to calculate output shapes.
    if (node_data->op_reg_data()->shape_inference_fn == nullptr) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "No shape inference function exists for op '%s'", op_name);
      return kTfLiteError;
    }
    status = c.Run(node_data->op_reg_data()->shape_inference_fn);

    // Compare calculated output shapes with node_data->outputs.
    auto num_outputs = node_data->outputs().Size();
    if (num_outputs != c.num_outputs()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "Number of output tensors are mismatched for op '%s' %d != %d",
                 op_name, num_outputs, c.num_outputs());
      return kTfLiteError;
    }
    for (int i = 0; i < num_outputs; ++i) {
      const auto output_tensor_index = node_data->outputs().TfLiteIndex(i);
      TfLiteTensor* tfl_tensor = &context->tensors[output_tensor_index];
      // tfl_tensor->dims only has valid information if the given model is
      // converted by the MLIR converter. Also, when ResizeInputTensor() is
      // called, the dims information becomes invalid.
      const std::string tfl_shape_string =
          GetShapeDebugString(tfl_tensor->dims);
      const std::string calculated_shape_string = c.DebugString(c.output(i));
      // Getting a shape string via c.DebugString() is the easiest way to get
      // the shape information of the given ShapeHandle for now.
      // TODO(b/169017408): Find a better approach without using debug string.
      if (tfl_shape_string != calculated_shape_string) {
        if ((strcmp(op_name, kReadVariableOp) == 0) &&
            (tfl_tensor->dims->size > 0)) {
          // If ReadVariableOp has an output with valid shape, use it since
          // ShapeInferenceFn of ReadVariableOp doesn't work well without having
          // a valid resource handle.
          continue;
        }

        TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                   "op '%s' output%d tensor#%d shape mismatch for  %s != %s",
                   op_name, i, output_tensor_index, tfl_shape_string.c_str(),
                   calculated_shape_string.c_str());
        return kTfLiteError;
      }
    }
  }
  return kTfLiteOk;
}

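// Returns a process-wide default CancellationManager, used when the delegate
// was created without an explicit one.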
static tensorflow::CancellationManager* GetDefaultCancellationManager() {
  static auto* const cancellation_manager = new tensorflow::CancellationManager;
  return cancellation_manager;
}

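// Copies non-constant inputs into the BufferMap, executes the recorded
// TensorFlow ops in order, and then copies or maps the subgraph outputs back
// to the TF Lite tensors.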
TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) {
  BufferMap* buffer_map = op_data_->shared_info.buffer_map;

  // Insert a tensor in the buffer map for all inputs that are not constant.
  // Constants were handled in Prepare() already.
  for (auto tensor_index : op_data_->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (!IsConstantTensor(tensor)) {
      // If this tensor is part of an earlier TF subgraph we should not add it
      // to the BufferMap again, because TF already knows about it and its
      // contents are kept automatically up-to-date.
      if (!tensor->data_is_stale || !buffer_map->HasTensor(tensor_index)) {
        buffer_map->SetFromTfLite(
            tensor_index, tensor,
            !op_data_->disable_reusing_buffer_tensors.count(tensor_index));
      }
    }
  }

  auto& eager_context = *op_data_->eager_context;

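  // Scope the per-invocation kernel run state so any tensorflow::Tensor
  // references it holds are released before the subgraph outputs are copied
  // back to TF Lite below.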
  {
    tensorflow::tfrt_stub::OpKernelRunState run_state;

    run_state.params.step_container = eager_context.StepContainer();
    auto* device = eager_context.local_device_mgr()->HostCPU();
    run_state.params.device = device;
    run_state.params.resource_manager = device->resource_manager();
    run_state.params.runner = eager_context.runner();
    run_state.params.cancellation_manager =
        op_data_->cancellation_manager ? op_data_->cancellation_manager
                                       : GetDefaultCancellationManager();
    // TODO(b/179048776): Set up remaining params such as collective and
    // rendezvous.

    // Execute the TensorFlow Ops sequentially.
    for (auto& node_data : op_data_->nodes) {
      TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(
          reinterpret_cast<Profiler*>(context->profiler),
          node_data->name().c_str(), node_data->index());

      if (op_data_->cancellation_manager != nullptr &&
          op_data_->cancellation_manager->IsCancelled()) {
        TF_LITE_KERNEL_LOG(
            context, "Client requested cancel during DelegateKernel::Eval");
        return kTfLiteError;
      }

      auto status = ExecuteOpKernelRunner(&run_state, context, node_data.get());
      TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
    }
  }

  for (auto tensor_index : op_data_->subgraph_outputs) {
    if (op_data_->shared_info.already_transferred_outputs.count(tensor_index) !=
        0) {
      // Skip if a tensor output has already been copied to a TfLiteTensor.
      continue;
    }
    if (!buffer_map->HasTensor(tensor_index)) {
      TF_LITE_KERNEL_LOG(context, "Cannot write to invalid tensor index %d",
                         tensor_index);
      return kTfLiteError;
    }

    // Copy TF tensor data to the TFL-allocated buffer for non-dynamic tensors.
    // For dynamic tensors, copy the shape and set buffer_handle for the later
    // CopyFromBufferHandle() call.
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    const tensorflow::Tensor& tf_tensor = buffer_map->GetTensor(tensor_index);
    if (tensor->allocation_type == kTfLiteDynamic) {
      TF_LITE_ENSURE_OK(context, CopyShapeAndType(context, tf_tensor, tensor));
      tensor->buffer_handle = tensor_index;
      tensor->data_is_stale = true;
      continue;
    }
    // If the tensor isn't dynamic, we can copy data directly to the buffer of
    // the tensor. Before copying the data, check that the target buffer has
    // the expected size.
    if (tf_tensor.NumElements() != NumElements(tensor) ||
        tf_tensor.TotalBytes() != tensor->bytes) {
      TF_LITE_KERNEL_LOG(context,
                         "FlexDelegate: Tensor %s(%d) buffer size mismatch "
                         "%zu(%lld) != %ld(%ld)",
                         tensor->name, tensor_index, tf_tensor.TotalBytes(),
                         tf_tensor.NumElements(), tensor->bytes,
                         NumElements(tensor));
      return kTfLiteError;
    }
    tensorflow::StringPiece t_data = tf_tensor.tensor_data();
    memcpy(tensor->data.raw, t_data.data(), t_data.size());
  }

  return kTfLiteOk;
}

const std::map<int, int>& DelegateKernel::GetTensorReleaseMap() const {
  return *(op_data_->shared_info.tensor_release_map);
}

}  // namespace flex
}  // namespace tflite