xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/utils/simple_delegate.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/utils/simple_delegate.h"
16 
17 #include <limits>
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include "tensorflow/lite/builtin_ops.h"
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/context_util.h"
25 #include "tensorflow/lite/delegates/utils.h"
26 #include "tensorflow/lite/kernels/internal/compatibility.h"
27 #include "tensorflow/lite/minimal_logging.h"
28 
29 namespace tflite {
30 namespace {
GetDelegateKernelRegistration(SimpleDelegateInterface * delegate)31 TfLiteRegistration GetDelegateKernelRegistration(
32     SimpleDelegateInterface* delegate) {
33   TfLiteRegistration kernel_registration{};
34   kernel_registration.profiling_string = nullptr;
35   kernel_registration.builtin_code = kTfLiteBuiltinDelegate;
36   kernel_registration.custom_name = delegate->Name();
37   kernel_registration.version = 1;
38   kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void {
39     delete reinterpret_cast<SimpleDelegateKernelInterface*>(buffer);
40   };
41   kernel_registration.init = [](TfLiteContext* context, const char* buffer,
42                                 size_t length) -> void* {
43     const TfLiteDelegateParams* params =
44         reinterpret_cast<const TfLiteDelegateParams*>(buffer);
45     if (params == nullptr) {
46       TF_LITE_KERNEL_LOG(context, "NULL TfLiteDelegateParams passed.");
47       return nullptr;
48     }
49     auto* delegate =
50         reinterpret_cast<SimpleDelegateInterface*>(params->delegate->data_);
51     std::unique_ptr<SimpleDelegateKernelInterface> delegate_kernel(
52         delegate->CreateDelegateKernelInterface());
53     if (delegate_kernel->Init(context, params) != kTfLiteOk) {
54       return nullptr;
55     }
56     return delegate_kernel.release();
57   };
58   kernel_registration.prepare = [](TfLiteContext* context,
59                                    TfLiteNode* node) -> TfLiteStatus {
60     if (node->user_data == nullptr) {
61       TF_LITE_KERNEL_LOG(context, "Delegate kernel was not initialized");
62       return kTfLiteError;
63     }
64     SimpleDelegateKernelInterface* delegate_kernel =
65         reinterpret_cast<SimpleDelegateKernelInterface*>(node->user_data);
66     return delegate_kernel->Prepare(context, node);
67   };
68   kernel_registration.invoke = [](TfLiteContext* context,
69                                   TfLiteNode* node) -> TfLiteStatus {
70     SimpleDelegateKernelInterface* delegate_kernel =
71         reinterpret_cast<SimpleDelegateKernelInterface*>(node->user_data);
72     TFLITE_DCHECK(delegate_kernel != nullptr);
73     return delegate_kernel->Eval(context, node);
74   };
75 
76   return kernel_registration;
77 }
78 
DelegatePrepare(TfLiteContext * context,TfLiteDelegate * base_delegate)79 TfLiteStatus DelegatePrepare(TfLiteContext* context,
80                              TfLiteDelegate* base_delegate) {
81   auto* delegate =
82       reinterpret_cast<SimpleDelegateInterface*>(base_delegate->data_);
83   auto delegate_options = delegate->DelegateOptions();
84   if (delegate_options.max_delegated_partitions <= 0)
85     delegate_options.max_delegated_partitions = std::numeric_limits<int>::max();
86 
87   TF_LITE_ENSURE_STATUS(delegate->Initialize(context));
88   delegates::IsNodeSupportedFn node_supported_fn =
89       [=](TfLiteContext* context, TfLiteNode* node,
90           TfLiteRegistration* registration,
91           std::string* unsupported_details) -> bool {
92     return delegate->IsNodeSupportedByDelegate(registration, node, context);
93   };
94   // TODO(b/149484598): Update to have method that gets all supported nodes.
95   delegates::GraphPartitionHelper helper(context, node_supported_fn);
96   TF_LITE_ENSURE_STATUS(helper.Partition(nullptr));
97 
98   std::vector<int> supported_nodes = helper.GetNodesOfFirstNLargestPartitions(
99       delegate_options.max_delegated_partitions,
100       delegate_options.min_nodes_per_partition);
101 
102   TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
103                        "%s delegate: %d nodes delegated out of %d nodes with "
104                        "%d partitions.\n",
105                        delegate->Name(), supported_nodes.size(),
106                        helper.num_total_nodes(), helper.num_partitions());
107   TfLiteRegistration delegate_kernel_registration =
108       GetDelegateKernelRegistration(delegate);
109 
110   return context->ReplaceNodeSubsetsWithDelegateKernels(
111       context, delegate_kernel_registration,
112       BuildTfLiteIntArray(supported_nodes).get(), base_delegate);
113 }
114 }  // namespace
115 
CreateSimpleDelegate(std::unique_ptr<SimpleDelegateInterface> simple_delegate,int64_t flag)116 TfLiteDelegate* TfLiteDelegateFactory::CreateSimpleDelegate(
117     std::unique_ptr<SimpleDelegateInterface> simple_delegate, int64_t flag) {
118   if (simple_delegate == nullptr) {
119     return nullptr;
120   }
121   auto delegate = new TfLiteDelegate();
122   delegate->Prepare = &DelegatePrepare;
123   delegate->flags = flag;
124   delegate->CopyFromBufferHandle = nullptr;
125   delegate->CopyToBufferHandle = nullptr;
126   delegate->FreeBufferHandle = nullptr;
127   delegate->data_ = simple_delegate.release();
128   return delegate;
129 }
130 
DeleteSimpleDelegate(TfLiteDelegate * delegate)131 void TfLiteDelegateFactory::DeleteSimpleDelegate(TfLiteDelegate* delegate) {
132   if (!delegate) return;
133   SimpleDelegateInterface* simple_delegate =
134       reinterpret_cast<SimpleDelegateInterface*>(delegate->data_);
135   delete simple_delegate;
136   delete delegate;
137 }
138 }  // namespace tflite
139