/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/kernel.h"

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/string_type.h"

// Note: this is part of TF Lite's Flex delegation code which is to be
// completed soon.

// This is the TF Lite op that is created by the flex delegate to handle
// execution of a supported subgraph. The usual flow is that the delegate
// informs the interpreter of supported nodes in a graph, and each supported
// subgraph is replaced with one instance of this kernel.
//
// The kernel is initialized with TfLiteDelegateParams from which we retrieve
// the global EagerContext and BufferMap, as well as a list of inputs and
// outputs to the subgraph. Those are used to build the OpData, with a list of
// TensorFlow Ops that should be executed in order (which we call an OpNode).
//
// For each node included in the subgraph, we query the interpreter and
// retrieve the associated NodeDef, which is then used to configure the
// corresponding TensorFlow OpKernel.
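//
// For orientation, a minimal sketch of how this kernel typically comes into
// play from the client side (illustrative only; not part of this file's API):
//
//   // std::unique_ptr<tflite::Interpreter> interpreter = ...;
//   // auto delegate = tflite::FlexDelegate::Create();
//   // interpreter->ModifyGraphWithDelegate(std::move(delegate));
//
// Applying the delegate triggers the partitioning described above, and each
// resulting partition is then executed by one DelegateKernel instance.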

using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeAndType;
using tensorflow::shape_inference::ShapeHandle;

namespace tflite {
namespace flex {

constexpr char kReadVariableOp[] = "ReadVariableOp";

struct OpNode;

// Represents the origin of a given tensor as a reference to the output
// of an upstream node.
struct TensorSource {
  OpNode* node;
  int node_output_index;
};

// A list of inputs of a given node of the TensorFlow graph.
class OpInputs {
 public:
  explicit OpInputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      inputs_.push_back(index);
    }
    forwardable_.resize(inputs_.size());
  }
  ~OpInputs() {}

  int Size() const { return inputs_.size(); }

  int TfLiteIndex(int i) const { return inputs_[i]; }

  // Given a map relating tensors to the node that originates them, populate a
  // list of sources for the tensors in this class.
  void InitializeTensorSources(
      const std::map<int, TensorSource>& tflite_tensor_sources) {
    sources_.clear();
    for (int i : inputs_) {
      auto it = tflite_tensor_sources.find(i);
      if (it == tflite_tensor_sources.end()) {
        sources_.push_back({nullptr, 0});
      } else {
        sources_.push_back(it->second);
      }
    }
  }

  void SetForwardable(int i, bool v) { forwardable_[i] = v; }

  bool IsForwardable(int i) const { return forwardable_[i]; }

  TensorSource GetTensorSource(int i) const { return sources_[i]; }

 private:
  std::vector<int> inputs_;
  std::vector<TensorSource> sources_;

  // List of tensors that can be used by TF in its forwarding optimization.
  // Doing so allows an input tensor to be modified and used as the output
  // tensor. The delegate takes care of not holding any references to tensors
  // in this list while the corresponding tensorflow::OpKernel is executed.
  std::vector<int> forwardable_;
};

// A list of outputs of a given node of the TensorFlow graph, along with
// the actual outputs of the tensorflow::OpKernel.
class OpOutputs {
 public:
  explicit OpOutputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      outputs_.push_back(index);
    }
    vector_.resize(outputs_.size());
  }
  ~OpOutputs() = default;

  // Stores information about which of the tensors in this class are also
  // outputs of the subgraph.
  void InitializeGraphOutputs(const std::set<int>& subgraph_outputs) {
    subgraph_outputs_.clear();
    for (int i : outputs_) {
      subgraph_outputs_.push_back(subgraph_outputs.count(i) > 0);
    }
  }

  // Returns true if the tensor given by index 'i' is an output of the entire
  // subgraph.
  bool IsSubgraphOutput(int i) const { return subgraph_outputs_[i]; }

  const tensorflow::Tensor& GetTensor(int i) const { return vector_[i]; }
  tensorflow::Tensor ReleaseTensor(int i) { return std::move(vector_[i]); }

  int Size() const { return outputs_.size(); }

  int TfLiteIndex(int i) const { return outputs_[i]; }

  tensorflow::gtl::InlinedVector<tensorflow::Tensor, 2>* GetTensors() {
    return &vector_;
  }

 private:
  std::vector<int> outputs_;
  std::vector<bool> subgraph_outputs_;
  tensorflow::gtl::InlinedVector<tensorflow::Tensor, 2> vector_;
};

// This struct holds information, such as tensor lifecycle and the BufferMap,
// that needs to be shared between `OpNode` and DelegateKernel.
struct OpDataInfo {
  // Buffer map which stores the mapping between a TfLiteTensor index and the
  // corresponding TF tensor.
  BufferMap* buffer_map;
  // Maps a TfLiteTensor index to the last node (in execution order) that uses
  // the tensor.
  std::map<int, int>* tensor_release_map;
  // For output tensors that don't need to be preserved in the BufferMap, we
  // copy them to TF Lite tensors and keep the tensor indexes in this set.
  std::set<int> already_transferred_outputs;
};

// A single node within the larger 'op'. Note that this kernel executes many
// TensorFlow ops within a single TF Lite op.
class OpNode {
 public:
  OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
      : inputs_(inputs), outputs_(outputs) {}
  ~OpNode() = default;

  const string& name() const { return name_; }
  void set_name(const string& name) { name_ = name; }

  int index() const { return index_; }
  void set_index(int index) { index_ = index; }

  const tensorflow::NodeDef& nodedef() const { return nodedef_; }
  const tensorflow::OpRegistrationData* op_reg_data() const {
    return op_reg_data_;
  }

  const OpInputs& inputs() const { return inputs_; }
  OpInputs* mutable_inputs() { return &inputs_; }

  const OpOutputs& outputs() const { return outputs_; }
  OpOutputs* mutable_outputs() { return &outputs_; }

  int NumInputs() const { return inputs_.Size(); }
  int NumOutputs() const { return outputs_.Size(); }

  const tensorflow::tfrt_stub::OpKernelRunner& op_kernel_runner() const {
    return op_kernel_runner_;
  }

  tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
                                       int custom_initial_data_size) {
    if (!custom_initial_data) {
      return tensorflow::errors::Internal(
          "Cannot convert empty data into a valid NodeDef");
    }
    // The flexbuffer contains a vector where the first element is the
    // op name and the second is a serialized NodeDef.
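    //
    // For illustration only, a hypothetical writer producing that layout with
    // the flexbuffers API could look like the following (the real data is
    // emitted by the TFLite converter, not by this file):
    //
    //   // flexbuffers::Builder fbb;
    //   // fbb.Vector([&]() {
    //   //   fbb.String(node_def.op());
    //   //   fbb.String(node_def.SerializeAsString());
    //   // });
    //   // fbb.Finish();  // fbb.GetBuffer() becomes custom_initial_data.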
    const flexbuffers::Vector& v =
        flexbuffers::GetRoot(
            reinterpret_cast<const uint8_t*>(custom_initial_data),
            custom_initial_data_size)
            .AsVector();

    name_ = v[0].AsString().str();
    if (!nodedef_.ParseFromString(v[1].AsString().str())) {
      nodedef_.Clear();
      return tensorflow::errors::Internal(
          "Failed to parse data into a valid NodeDef");
    }

    // Fill NodeDef with defaults if it's a valid op.
    TF_RETURN_IF_ERROR(
        tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data_));
    AddDefaultsToNodeDef(op_reg_data_->op_def, &nodedef_);

    return ::tensorflow::OkStatus();
  }

  tensorflow::Status BuildOpKernelRunner(
      tensorflow::EagerContext* eager_context) {
    // Create tensorflow::OpKernel on host CPU.
    TF_ASSIGN_OR_RETURN(op_kernel_runner_,
                        tensorflow::tfrt_stub::OpKernelRunner::Create(
                            name_, inputs_.Size(), /*attr_builder=*/
                            [this](tensorflow::AttrValueMap* attr_value_map) {
                              *attr_value_map = nodedef_.attr();
                              return ::tensorflow::OkStatus();
                            },
                            *eager_context->pflr(),
                            eager_context->local_device_mgr()->HostCPU()));

    return ::tensorflow::OkStatus();
  }

  tensorflow::Status BuildOpKernelInputs(
      const BufferMap* buffer_map,
      tensorflow::tfrt_stub::OpKernelRunState* run_state) {
    run_state->input_tf_tensors.resize(inputs_.Size());
    run_state->input_tf_tensor_values.resize(inputs_.Size());

    for (int i = 0; i < inputs_.Size(); ++i) {
      int input_index = inputs_.TfLiteIndex(i);
      TensorSource s = inputs_.GetTensorSource(i);
      if (!s.node) {
        // This input is not produced by this TF subgraph (it could be a TF
        // Lite native buffer, or it could be produced by a separate subgraph).
        // We need to fetch it from the delegate's buffer_map.
        if (!buffer_map->HasTensor(input_index)) {
          return tensorflow::errors::Internal(
              "Cannot read from invalid tensor index ", input_index);
        }
        run_state->input_tf_tensors[i] = buffer_map->GetTensor(input_index);
      } else {
        // If this is a forwardable tensor, we will remove it from the previous
        // op's list, giving TF the opportunity to reuse its buffer.
        if (inputs_.IsForwardable(i)) {
          run_state->input_tf_tensors[i] =
              s.node->outputs_.ReleaseTensor(s.node_output_index);
        } else {
          run_state->input_tf_tensors[i] =
              s.node->outputs_.GetTensor(s.node_output_index);
        }
      }
      run_state->input_tf_tensor_values[i].tensor =
          &run_state->input_tf_tensors[i];
    }
    return ::tensorflow::OkStatus();
  }

  // Returns whether an output tensor should be preserved in the buffer map by
  // checking its lifetime information.
  // The eager tensor doesn't need to be persisted in the buffer map if it has
  // no future uses in the graph.
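  //
  // For example (illustrative indexes): if `tensor_release_map` records that
  // tensor #7's last consumer is node 12 and we are currently finishing node
  // 12, the TF tensor is copied into the TfLite tensor and dropped rather
  // than being stored in the BufferMap; a later consumer (say node 15) would
  // instead cause it to be persisted.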
  bool ShouldPersistTensorflowTensor(TfLiteContext* context,
                                     const OpDataInfo* shared_info,
                                     int tensor_index, int node_index) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    // Always persist variant|resource|string tensors since they have special
    // storage requirements.
    if (IsResourceOrVariant(tensor) || tensor->type == kTfLiteString) {
      return true;
    }

    auto it = shared_info->tensor_release_map->find(tensor_index);
    return it != shared_info->tensor_release_map->end() &&
           it->second > node_index;
  }

  // Copies the data of the TensorFlow tensor into the corresponding TfLite
  // tensor; once the copy is done, releases the original tensor so that its
  // memory can be reclaimed by the TF runtime.
  TfLiteStatus CopyToTfLiteTensor(TfLiteContext* context,
                                  OpDataInfo* shared_info, TfLiteTensor* tensor,
                                  tensorflow::Tensor* tf_tensor,
                                  int tensor_index) const {
    if (tensor->allocation_type == kTfLiteDynamic) {
      // For dynamic tensors, update the TfLite tensor's shape information from
      // the TensorFlow tensor.
      CopyShapeAndType(context, *tf_tensor, tensor);
    }
    tensorflow::StringPiece t_data = tf_tensor->tensor_data();
    if (tf_tensor->NumElements() != NumElements(tensor) ||
        tf_tensor->TotalBytes() != tensor->bytes) {
      TF_LITE_KERNEL_LOG(context,
                         "FlexDelegate: Tensor %s(%d) buffer size mismatch "
                         "%zu(%lld) != %ld(%ld)",
                         tensor->name, tensor_index, tf_tensor->TotalBytes(),
                         tf_tensor->NumElements(), tensor->bytes,
                         NumElements(tensor));
      return kTfLiteError;
    }
    // Copy TF tensor's data content into TfLiteTensor, and release the tensor.
    memcpy(tensor->data.raw, t_data.data(), t_data.size());
    *tf_tensor = {};
    shared_info->already_transferred_outputs.insert(tensor_index);
    return kTfLiteOk;
  }

  // TODO(b/204479285): Release tensors from BufferMap if they have no future
  // uses.
  tensorflow::Status MaybePersistTensorflowOutputs(TfLiteContext* context,
                                                   OpDataInfo* shared_info,
                                                   int node_index) {
    auto* tensors = outputs_.GetTensors();

    for (int i = 0; i < outputs_.Size(); ++i) {
      if (outputs_.IsSubgraphOutput(i)) {
        tensorflow::Tensor& tf_tensor = tensors->at(i);
        const int tflite_index = outputs_.TfLiteIndex(i);
        TfLiteTensor* tensor = &context->tensors[tflite_index];
        if (!ShouldPersistTensorflowTensor(context, shared_info, tflite_index,
                                           node_index)) {
          if (CopyToTfLiteTensor(context, shared_info, tensor, &tf_tensor,
                                 tflite_index) != kTfLiteOk) {
            return tensorflow::Status(tensorflow::error::INTERNAL,
                                      "failed to copy data from TF tensor");
          }
        } else {
          shared_info->buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i),
                                                     tf_tensor);
        }
      }
    }
    return ::tensorflow::OkStatus();
  }

 private:
  OpNode(const OpNode&) = delete;
  OpNode& operator=(const OpNode&) = delete;

  // The name of the TensorFlow op to execute.
  string name_;
  // Index of this node into TF Lite's operator list.
  int index_;
  // The corresponding NodeDef, containing the attributes for the op.
  tensorflow::NodeDef nodedef_;
  // The corresponding OpRegistrationData pointer.
  const tensorflow::OpRegistrationData* op_reg_data_;
  // List of inputs, as TF Lite tensor indices.
  OpInputs inputs_;
  // List of outputs, as TF Lite tensor indices.
  OpOutputs outputs_;

  tensorflow::tfrt_stub::OpKernelRunner op_kernel_runner_;
};

// The larger 'op', which contains all the nodes in a supported subgraph.
struct OpData {
  tensorflow::EagerContext* eager_context;
  tensorflow::CancellationManager* cancellation_manager;
  std::vector<std::unique_ptr<OpNode>> nodes;
  std::vector<int> subgraph_inputs;
  std::vector<int> subgraph_outputs;
  std::set<int>
      disable_reusing_buffer_tensors;  // A set of input tensor indexes whose
                                       // buffers should not be reused by
                                       // tensorflow::Tensor.
  OpDataInfo shared_info;
};

tensorflow::Status DelegateKernel::ExecuteOpKernelRunner(
    tensorflow::tfrt_stub::OpKernelRunState* run_state, TfLiteContext* context,
    OpNode* node_data) {
  const auto& op_kernel_runner = node_data->op_kernel_runner();

  if (op_kernel_runner.op_kernel()->num_outputs() != node_data->NumOutputs()) {
    return tensorflow::errors::Internal(
        "Unexpected number of outputs from tensorflow::OpKernel");
  }

  TF_RETURN_IF_ERROR(node_data->BuildOpKernelInputs(
      op_data_->shared_info.buffer_map, run_state));

  run_state->params.inputs = run_state->input_tf_tensor_values;
  run_state->params.op_kernel = op_kernel_runner.op_kernel();
  run_state->params.input_alloc_attrs = op_kernel_runner.input_alloc_attrs();
  run_state->params.output_attr_array =
      op_kernel_runner.output_alloc_attrs().data();
  run_state->params.function_library =
      op_kernel_runner.function_library_runtime();

  tensorflow::OpKernelContext tf_context(&run_state->params,
                                         node_data->NumOutputs());
  op_kernel_runner.Run(&tf_context);
  TF_RETURN_IF_ERROR(tf_context.status());

  auto& outputs = *node_data->mutable_outputs()->GetTensors();
  for (int i = 0; i < tf_context.num_outputs(); ++i) {
    outputs[i] = std::move(*tf_context.mutable_output(i));
  }

  return node_data->MaybePersistTensorflowOutputs(
      context, &(op_data_->shared_info), node_data->index());
}

DelegateKernel::DelegateKernel() : op_data_(new OpData) {}
DelegateKernel::~DelegateKernel() {}

TfLiteStatus DelegateKernel::Init(TfLiteContext* context,
                                  const TfLiteDelegateParams* params) {
  auto* flex_delegate_data =
      reinterpret_cast<FlexDelegate*>(params->delegate->data_)->mutable_data();
  op_data_->eager_context = flex_delegate_data->GetEagerContext();
  op_data_->cancellation_manager = flex_delegate_data->GetCancellationManager();
  op_data_->shared_info.buffer_map = flex_delegate_data->GetBufferMap(context);
  op_data_->shared_info.tensor_release_map =
      flex_delegate_data->GetTensorReleaseMap(context);

  CHECK(params->output_tensors);
  std::set<int> output_set;
  for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
    op_data_->subgraph_outputs.push_back(tensor_index);
    output_set.insert(tensor_index);
  }

  CHECK(params->input_tensors);
  for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
    op_data_->subgraph_inputs.push_back(tensor_index);
  }
  std::set<int> subgraph_inputs(op_data_->subgraph_inputs.begin(),
                                op_data_->subgraph_inputs.end());

  op_data_->nodes.reserve(params->nodes_to_replace->size);

  CHECK(params->nodes_to_replace);
  tensorflow::Status status;

  // Now we explicitly disable reusing TFLite tensor buffers for certain TF
  // ops, since those ops might produce results that keep references to the
  // input tensors (buffer forwarding).
  auto check_if_op_reuses_input = [](const string& op_name) {
    return op_name == "TensorListPushBack" || op_name == "TensorListSetItem" ||
           op_name == "SparseReshape";
  };

  for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
    TfLiteNode* node;
    TfLiteRegistration* reg;
    context->GetNodeAndRegistration(context, node_index, &node, &reg);

    op_data_->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
    OpNode& node_data = *op_data_->nodes.back();

    node_data.set_index(node_index);
    node_data.set_name("");

    status = node_data.InitializeNodeDef(node->custom_initial_data,
                                         node->custom_initial_data_size);
    if (!status.ok()) break;
    status = node_data.BuildOpKernelRunner(op_data_->eager_context);
    if (!status.ok()) break;

    // For each node handled by this delegate partition, record the mapping
    // information between each input tensor and the node index. The node index
    // is the index of the last node in execution order that uses this tensor,
    // so the tensor is no longer needed after that node is executed.
    // Since we execute in order, the maximum index is the index of the
    // last node that needs this tensor.
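    //
    // Worked example (illustrative indexes): if tensor #5 feeds nodes 3 and 7
    // in this partition, the loop below leaves tensor_release_map[5] == 7, so
    // the tensor can be dropped once node 7 has run.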
    for (auto tensor_index : TfLiteIntArrayView(node->inputs)) {
      int node_id = node_index;
      if (op_data_->shared_info.tensor_release_map->find(tensor_index) !=
          op_data_->shared_info.tensor_release_map->end()) {
        node_id =
            std::max(op_data_->shared_info.tensor_release_map->at(tensor_index),
                     node_index);
      }
      (*op_data_->shared_info.tensor_release_map)[tensor_index] = node_id;

      if (subgraph_inputs.count(tensor_index) &&
          check_if_op_reuses_input(node_data.nodedef().op())) {
        op_data_->disable_reusing_buffer_tensors.insert(tensor_index);
      }
    }
  }

  TF_LITE_ENSURE_STATUS(ConvertStatus(context, status));

  // Given a TfLite tensor index, this map returns the OpNode that produces it,
  // along with its index into that OpNode's list of outputs.
  std::map<int, TensorSource> tflite_tensor_sources;

  // Find out how each tensor is produced. This does not account for
  // tensors that are not produced by tensorflow::OpKernels.
  for (auto& node_data : op_data_->nodes) {
    node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
    for (int i = 0; i < node_data->outputs().Size(); ++i) {
      int output_index = node_data->outputs().TfLiteIndex(i);
      tflite_tensor_sources[output_index] = TensorSource{node_data.get(), i};
    }
  }

  // For each node, resolve the inputs, so we can keep pointers to the nodes
  // that produce them.
  for (auto& node_data : op_data_->nodes) {
    node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
  }
  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_MSG(
      context, op_data_->eager_context != nullptr,
      "Failed to initialize eager context. This often happens when a CPU "
      "device has not been registered, presumably because some symbols from "
      "tensorflow/core:core_cpu_impl were not linked into the binary.");

  // We will keep track of the number of references to each tensor in the
  // graph, so we can make them "forwardable" if there is only one reference.
  std::map<int, int> tensor_ref_count;

  // Whenever we find a constant tensor, insert it in the buffer map.
  BufferMap* buffer_map = op_data_->shared_info.buffer_map;
  for (auto tensor_index : op_data_->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (IsConstantTensor(tensor)) {
      if (!tensor->data_is_stale || !buffer_map->HasTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }

    // Input tensors should never be forwarded, so we increment their ref
    // counts twice: once for this graph and once for the possibility of them
    // being used by another subgraph, or being an output of the full graph.
    tensor_ref_count[tensor_index] += 2;
  }

  const bool shapes_are_valid =
      (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk);
  if (shapes_are_valid) {
    TFLITE_LOG(tflite::TFLITE_LOG_INFO,
               "FlexDelegate: All tensor shapes are consistent.");
  } else {
    TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
               "FlexDelegate: Some tensor shapes are inconsistent.");
  }

  // All output tensors are allocated by TensorFlow, so we mark them as
  // kTfLiteDynamic.
  for (auto tensor_index : op_data_->subgraph_outputs) {
    if (!shapes_are_valid) {
      SetTensorToDynamic(&context->tensors[tensor_index]);
    }
    ++tensor_ref_count[tensor_index];
  }

  for (const auto& node_data : op_data_->nodes) {
    if (node_data->nodedef().op().empty()) {
      TF_LITE_KERNEL_LOG(context, "Invalid NodeDef in Flex op '%s'",
                         node_data->name().c_str());
      return kTfLiteError;
    }
    TF_LITE_ENSURE(context, node_data->op_kernel_runner());

    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
    }
  }

  // All tensors that are referenced exactly once are marked as "forwardable",
  // meaning that we will allow TensorFlow to reuse its buffer as the output of
  // an op.
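  //
  // For instance (illustrative), an intermediate tensor consumed by exactly
  // one downstream node in this partition gets a ref count of 1; during Eval
  // its tensorflow::Tensor is moved (via OpOutputs::ReleaseTensor) into the
  // consumer's inputs, letting TF reuse the underlying buffer in place.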
  for (auto& node_data : op_data_->nodes) {
    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
      node_data->mutable_inputs()->SetForwardable(i, f);
    }
  }

  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
    TfLiteContext* context) const {
  for (const auto& node_data : op_data_->nodes) {
    auto op_name = node_data->name().c_str();
    // Create an InferenceContext object.
    auto num_inputs = node_data->inputs().Size();
    std::vector<const tensorflow::Tensor*> input_tensors_vector(num_inputs,
                                                                nullptr);
    InferenceContext c(
        TF_GRAPH_DEF_VERSION, node_data->nodedef(),
        node_data->op_reg_data()->op_def, std::vector<ShapeHandle>(num_inputs),
        input_tensors_vector, {},
        std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

    // Set input_shapes for ShapeInferenceFn.
    for (int i = 0; i < num_inputs; ++i) {
      const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
      TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
      // Provide constant input tensors since some ops (e.g. "RFFT") need them
      // to calculate the output shape.
      if (IsConstantTensor(tfl_tensor)) {
        input_tensors_vector[i] =
            op_data_->shared_info.buffer_map->GetTensorPtr(input_tensor_index);
      }
      const auto dims_array = tfl_tensor->dims;
      std::vector<DimensionHandle> dims(dims_array->size);
      for (int j = 0; j < dims_array->size; ++j) {
        dims[j] = c.MakeDim(dims_array->data[j]);
      }
      c.SetInput(i, c.MakeShape(dims));
    }
    c.set_input_tensors(input_tensors_vector);

    tensorflow::Status status = c.construction_status();
    if (!status.ok()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "Shape construction failed for op '%s'", op_name);
      return kTfLiteError;
    }

    // Run ShapeInferenceFn to calculate output shapes.
    if (node_data->op_reg_data()->shape_inference_fn == nullptr) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "No shape inference function exists for op '%s'", op_name);
      return kTfLiteError;
    }
    status = c.Run(node_data->op_reg_data()->shape_inference_fn);

    // Compare the calculated output shapes with node_data->outputs.
    auto num_outputs = node_data->outputs().Size();
    if (num_outputs != c.num_outputs()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "Number of output tensors are mismatched for op '%s' %d != %d",
                 op_name, num_outputs, c.num_outputs());
      return kTfLiteError;
    }
    for (int i = 0; i < num_outputs; ++i) {
      const auto output_tensor_index = node_data->outputs().TfLiteIndex(i);
      TfLiteTensor* tfl_tensor = &context->tensors[output_tensor_index];
      // tfl_tensor->dims only has valid information if the given model is
      // converted by the MLIR converter. Also, when ResizeInputTensor() is
      // called, the dims information becomes invalid.
      const std::string tfl_shape_string =
          GetShapeDebugString(tfl_tensor->dims);
      const std::string calculated_shape_string = c.DebugString(c.output(i));
      // Getting a shape string via c.DebugString() is the easiest way to get
      // the shape information of the given ShapeHandle for now.
      // TODO(b/169017408): Find a better approach without using debug string.
      if (tfl_shape_string != calculated_shape_string) {
        if ((strcmp(op_name, kReadVariableOp) == 0) &&
            (tfl_tensor->dims->size > 0)) {
          // If ReadVariableOp has an output with a valid shape, use it, since
          // the ShapeInferenceFn of ReadVariableOp doesn't work well without a
          // valid resource handle.
          continue;
        }

        TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                   "op '%s' output%d tensor#%d shape mismatch for %s != %s",
                   op_name, i, output_tensor_index, tfl_shape_string.c_str(),
                   calculated_shape_string.c_str());
        return kTfLiteError;
      }
    }
  }
  return kTfLiteOk;
}

static tensorflow::CancellationManager* GetDefaultCancellationManager() {
  static auto* const cancellation_manager = new tensorflow::CancellationManager;
  return cancellation_manager;
}

TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) {
  BufferMap* buffer_map = op_data_->shared_info.buffer_map;

  // Insert a tensor in the buffer map for all inputs that are not constant.
  // Constants were handled in Prepare() already.
  for (auto tensor_index : op_data_->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (!IsConstantTensor(tensor)) {
      // If this tensor is part of an earlier TF subgraph we should not add it
      // to the BufferMap again, because TF already knows about it and its
      // contents are kept automatically up-to-date.
      if (!tensor->data_is_stale || !buffer_map->HasTensor(tensor_index)) {
        buffer_map->SetFromTfLite(
            tensor_index, tensor,
            !op_data_->disable_reusing_buffer_tensors.count(tensor_index));
      }
    }
  }

  auto& eager_context = *op_data_->eager_context;

  {
    tensorflow::tfrt_stub::OpKernelRunState run_state;

    run_state.params.step_container = eager_context.StepContainer();
    auto* device = eager_context.local_device_mgr()->HostCPU();
    run_state.params.device = device;
    run_state.params.resource_manager = device->resource_manager();
    run_state.params.runner = eager_context.runner();
    run_state.params.cancellation_manager =
        op_data_->cancellation_manager ? op_data_->cancellation_manager
                                       : GetDefaultCancellationManager();
    // TODO(b/179048776): Set up remaining params such as collective and
    // rendezvous.

    // Execute the TensorFlow Ops sequentially.
    for (auto& node_data : op_data_->nodes) {
      TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(
          reinterpret_cast<Profiler*>(context->profiler),
          node_data->name().c_str(), node_data->index());

      if (op_data_->cancellation_manager != nullptr &&
          op_data_->cancellation_manager->IsCancelled()) {
        TF_LITE_KERNEL_LOG(
            context, "Client requested cancel during DelegateKernel::Eval");
        return kTfLiteError;
      }

      auto status = ExecuteOpKernelRunner(&run_state, context, node_data.get());
      TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
    }
  }

  for (auto tensor_index : op_data_->subgraph_outputs) {
    if (op_data_->shared_info.already_transferred_outputs.count(tensor_index) !=
        0) {
      // Skip if a tensor output has already been copied to a TfLiteTensor.
      continue;
    }
    if (!buffer_map->HasTensor(tensor_index)) {
      TF_LITE_KERNEL_LOG(context, "Cannot write to invalid tensor index %d",
                         tensor_index);
      return kTfLiteError;
    }

    // Copy TF tensor data to the TFL-allocated buffer for non-dynamic tensors.
    // For dynamic tensors, copy the shape and set buffer_handle for the later
    // CopyFromBufferHandle() call.
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    const tensorflow::Tensor& tf_tensor = buffer_map->GetTensor(tensor_index);
    if (tensor->allocation_type == kTfLiteDynamic) {
      TF_LITE_ENSURE_OK(context, CopyShapeAndType(context, tf_tensor, tensor));
      tensor->buffer_handle = tensor_index;
      tensor->data_is_stale = true;
      continue;
    }
    // If the tensor isn't dynamic, we can copy data directly to the buffer of
    // the tensor. Before copying the data, check that the target buffer has
    // the expected size.
    if (tf_tensor.NumElements() != NumElements(tensor) ||
        tf_tensor.TotalBytes() != tensor->bytes) {
      TF_LITE_KERNEL_LOG(context,
                         "FlexDelegate: Tensor %s(%d) buffer size mismatch "
                         "%zu(%lld) != %ld(%ld)",
                         tensor->name, tensor_index, tf_tensor.TotalBytes(),
                         tf_tensor.NumElements(), tensor->bytes,
                         NumElements(tensor));
      return kTfLiteError;
    }
    tensorflow::StringPiece t_data = tf_tensor.tensor_data();
    memcpy(tensor->data.raw, t_data.data(), t_data.size());
  }

  return kTfLiteOk;
}

const std::map<int, int>& DelegateKernel::GetTensorReleaseMap() const {
  return *(op_data_->shared_info.tensor_release_map);
}

}  // namespace flex
}  // namespace tflite