xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/flex/delegate_data.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate_data.h"

#include <cstring>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/util.h"
44 
45 namespace tflite {
46 namespace flex {
47 
48 namespace {
49 
50 // Builds a `FunctionDef` proto that contains two nodes:
51 // The first node is a constant node which has the value of the resource key,
52 // the second node is a `TfLiteSubgraphExecute` node which will take the
53 // resource key, and the subgraph's inputs as arguments. The function's return
54 // value is the return value of `TfLiteSubgraphExecute`.
BuildFunctionDefProto(const std::string & function_name,const Subgraph & subgraph,tensorflow::FunctionDef & fdef)55 void BuildFunctionDefProto(const std::string& function_name,
56                            const Subgraph& subgraph,
57                            tensorflow::FunctionDef& fdef) {
58   // Map inputs/outputs to types.
59   std::vector<std::string> inputs, outputs;
60   inputs.reserve(subgraph.inputs().size());
61   outputs.reserve(subgraph.outputs().size());
62   for (int i = 0; i < subgraph.inputs().size(); ++i) {
63     inputs.push_back(absl::StrCat(
64         "args_", i, ": ",
65         TfLiteTypeToTfTypeName(subgraph.tensor(subgraph.inputs()[i])->type)));
66   }
67   for (int i = 0; i < subgraph.outputs().size(); ++i) {
68     outputs.push_back(absl::StrCat(
69         "res_", i, ": ",
70         TfLiteTypeToTfTypeName(subgraph.tensor(subgraph.outputs()[i])->type)));
71   }
72   std::vector<tensorflow::FunctionDefHelper::Node> nodes;
73   // The first node is a constant node containing the string value for the
74   // resource name.
75   nodes.push_back(tensorflow::FunctionDefHelper::Const<tensorflow::tstring>(
76       "SubgraphResourceKey", function_name));
77   // Builds the `TfLiteSubgraphExecute` node.
78   tensorflow::FunctionDefHelper::Node execute_node;
79   execute_node.ret.push_back("InvokeTfLite");
80   execute_node.op = "TfLiteSubgraphExecute";
81   execute_node.arg.push_back("SubgraphResourceKey:output:0");
82   for (int i = 0; i < subgraph.inputs().size(); ++i) {
83     execute_node.arg.push_back(absl::StrCat("args_", i));
84   }
85   nodes.push_back(execute_node);
86 
87   std::vector<std::pair<std::string, std::string>> ret_def;
88   ret_def.reserve(subgraph.outputs().size());
89   for (int i = 0; i < subgraph.outputs().size(); ++i) {
90     ret_def.emplace_back(absl::StrCat("res_", i),
91                          absl::StrCat("InvokeTfLite:output:", i));
92   }
93   fdef = tensorflow::FunctionDefHelper::Create(function_name, inputs, outputs,
94                                                /*attr_def=*/{}, nodes, ret_def);
95   // Insert input/output type attrs.
96   tensorflow::AttrValue tin_attrs, tout_attrs;
97   for (int i = 0; i < subgraph.inputs().size(); ++i) {
98     TF_DataType dtype = tflite::flex::GetTensorFlowDataType(
99         subgraph.tensor(subgraph.inputs()[i])->type);
100     tin_attrs.mutable_list()->add_type(tensorflow::DataType(dtype));
101   }
102   for (int i = 0; i < subgraph.outputs().size(); ++i) {
103     TF_DataType dtype = tflite::flex::GetTensorFlowDataType(
104         subgraph.tensor(subgraph.outputs()[i])->type);
105     tout_attrs.mutable_list()->add_type(tensorflow::DataType(dtype));
106   }
107   fdef.mutable_node_def(1)->mutable_attr()->insert({"Tin", tin_attrs});
108   fdef.mutable_node_def(1)->mutable_attr()->insert({"Tout", tout_attrs});
109 }
110 
111 // Returns a list of subgraph names which have associated function attributes.
GetSubgraphNamesForFunctionExecution(const std::vector<std::unique_ptr<Subgraph>> & subgraphs,std::set<std::string> * result)112 tensorflow::Status GetSubgraphNamesForFunctionExecution(
113     const std::vector<std::unique_ptr<Subgraph>>& subgraphs,
114     std::set<std::string>* result) {
115   tensorflow::NodeDef node_def;
116   for (const auto& subgraph : subgraphs) {
117     for (const auto& node_and_reg : subgraph->nodes_and_registration()) {
118       if (node_and_reg.second.builtin_code != tflite::BuiltinOperator_CUSTOM) {
119         // If this isn't a custom op, skip.
120         continue;
121       }
122       const std::string custom_name = node_and_reg.second.custom_name;
123       if (custom_name.substr(0, strlen(tflite::kFlexCustomCodePrefix)) !=
124           tflite::kFlexCustomCodePrefix) {
125         // Skip if this is not a flex op.
126         continue;
127       }
128       // The flexbuffer contains a vector where the first elements is the
129       // op name and the second is a serialized NodeDef.
130       const flexbuffers::Vector& v =
131           flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(
132                                    node_and_reg.first.custom_initial_data),
133                                node_and_reg.first.custom_initial_data_size)
134               .AsVector();
135       // TODO(b/181352924): Use proto arena if we see performance regression.
136       if (!node_def.ParseFromString(v[1].AsString().str())) {
137         return tensorflow::Status(tensorflow::error::INTERNAL,
138                                   "could not parse NodeDef");
139       }
140       // Loop through all the attributes in this node to check if it has
141       // function attribute.
142       for (const auto& attr : node_def.attr()) {
143         if (attr.second.has_func()) {
144           result->insert(attr.second.func().name());
145         }
146       }
147     }
148   }
149   return ::tensorflow::OkStatus();
150 }
151 
152 }  // namespace
153 
RegisterFunctionDefForSubgraphs(Subgraph & main_subgraph,const std::function<tensorflow::Status (const std::vector<std::unique_ptr<Subgraph>> &,std::set<std::string> *)> & select_subgraphs_to_register,tensorflow::ResourceMgr * resource_mgr,tensorflow::EagerContext * eager_context,TfLiteDelegate * flex_delegate)154 tensorflow::Status RegisterFunctionDefForSubgraphs(
155     Subgraph& main_subgraph,
156     const std::function<tensorflow::Status(
157         const std::vector<std::unique_ptr<Subgraph>>&, std::set<std::string>*)>&
158         select_subgraphs_to_register,
159     tensorflow::ResourceMgr* resource_mgr,
160     tensorflow::EagerContext* eager_context, TfLiteDelegate* flex_delegate) {
161   std::vector<std::unique_ptr<Subgraph>>* subgraphs =
162       main_subgraph.GetSubgraphs();
163   if (!subgraphs) {
164     // If there are no subgraphs associated with the main subgraph, we will
165     // return ok status because no FunctionDef needs to be registered.
166     return ::tensorflow::OkStatus();
167   }
168   std::set<std::string> function_subgraphs;
169   TF_RETURN_IF_ERROR(
170       select_subgraphs_to_register(*subgraphs, &function_subgraphs));
171   for (int i = 0; i < subgraphs->size(); ++i) {
172     if (subgraphs->at(i)->GetName() == "main") {
173       continue;
174     }
175     const std::string subgraph_name = subgraphs->at(i)->GetName();
176     if (!function_subgraphs.count(subgraph_name)) {
177       continue;
178     }
179     // This is to ensure that we only register FunctionDefs for subgraphs that
180     // are used by TF ops to invoke functions.
181     auto* subgraph_resource =
182         new TFLiteSubgraphResource(*(subgraphs->at(i)), flex_delegate);
183     TF_RETURN_IF_ERROR(resource_mgr->Create<TFLiteSubgraphResource>(
184         "flex", subgraph_name, subgraph_resource));
185     tensorflow::FunctionDef fdef;
186     BuildFunctionDefProto(subgraph_name, *(subgraphs->at(i)), fdef);
187     TF_RETURN_IF_ERROR(eager_context->AddFunctionDef(fdef));
188   }
189   return ::tensorflow::OkStatus();
190 }
191 
DelegateData()192 DelegateData::DelegateData() {}
193 
~DelegateData()194 DelegateData::~DelegateData() {
195   if (eager_context_) {
196     // Notify the eager context to clean up the resource being held before
197     // destructing the `DelegateData`.
198     eager_context_->HostCPU()->ClearResourceMgr();
199     eager_context_->Unref();
200   }
201 }
202 
// Lazily creates the TF eager context used by the flex delegate.
//
// Once a call has succeeded, later calls are no-ops that return OK. When
// `main_subgraph` is non-null, FunctionDefs (and their backing resources) are
// also registered for the subgraphs that flex ops invoke as functions. In
// that case `flex_delegate` must be non-null, otherwise FAILED_PRECONDITION is
// returned.
//
// If function registration fails, the freshly created context is torn down
// again. The object therefore returns to its unprepared state, and a later
// call retries instead of returning OK early with functions missing.
tensorflow::Status DelegateData::Prepare(
    const tensorflow::SessionOptions& session_options, Subgraph* main_subgraph,
    TfLiteDelegate* flex_delegate) {
  if (eager_context_) {
    return ::tensorflow::OkStatus();
  }
  if (flex_delegate == nullptr && main_subgraph != nullptr) {
    return tensorflow::Status(
        tensorflow::error::FAILED_PRECONDITION,
        "flex_delegate must be non-null when main_subgraph is provided.");
  }

  std::vector<std::unique_ptr<tensorflow::Device>> devices;

  TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
      session_options, "/job:localhost/replica:0/task:0", &devices));

  auto device_mgr =
      std::make_unique<tensorflow::StaticDeviceMgr>(std::move(devices));
  // Note that Rendezvous is ref-counted so it will be automatically deleted.
  tensorflow::Rendezvous* rendezvous =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
  eager_context_ = new tensorflow::EagerContext(
      session_options,
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      /*async=*/false, device_mgr.release(), /*device_mgr_owned*/ true,
      rendezvous, nullptr);

  if (main_subgraph) {
    const tensorflow::Status status = RegisterFunctionDefForSubgraphs(
        *main_subgraph, GetSubgraphNamesForFunctionExecution,
        eager_context_->HostCPU()->resource_manager(), eager_context_,
        flex_delegate);
    if (!status.ok()) {
      // Roll back the partial initialization, using the same teardown as the
      // destructor. Otherwise the early `if (eager_context_)` return above
      // would make a retry report success.
      eager_context_->HostCPU()->ClearResourceMgr();
      eager_context_->Unref();
      eager_context_ = nullptr;
      return status;
    }
  }
  return ::tensorflow::OkStatus();
}
239 
240 }  // namespace flex
241 }  // namespace tflite
242