/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate_data.h"

#include <cstring>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace flex {

namespace {

// Builds a `FunctionDef` proto that contains two nodes:
// the first node is a constant node holding the value of the resource key, and
// the second node is a `TfLiteSubgraphExecute` node that takes the resource
// key and the subgraph's inputs as arguments. The function's return value is
// the return value of `TfLiteSubgraphExecute`.
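//
// For illustration, a subgraph named "add" with a single float input and a
// single float output would roughly produce:
//
//   add(args_0: float) -> (res_0: float) {
//     SubgraphResourceKey = Const[dtype=DT_STRING, value="add"]()
//     InvokeTfLite = TfLiteSubgraphExecute[Tin=[DT_FLOAT], Tout=[DT_FLOAT]](
//         SubgraphResourceKey:output:0, args_0)
//     res_0 = InvokeTfLite:output:0
//   }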
void BuildFunctionDefProto(const std::string& function_name,
                           const Subgraph& subgraph,
                           tensorflow::FunctionDef& fdef) {
  // Map inputs/outputs to types.
  std::vector<std::string> inputs, outputs;
  inputs.reserve(subgraph.inputs().size());
  outputs.reserve(subgraph.outputs().size());
  for (int i = 0; i < subgraph.inputs().size(); ++i) {
    inputs.push_back(absl::StrCat(
        "args_", i, ": ",
        TfLiteTypeToTfTypeName(subgraph.tensor(subgraph.inputs()[i])->type)));
  }
  for (int i = 0; i < subgraph.outputs().size(); ++i) {
    outputs.push_back(absl::StrCat(
        "res_", i, ": ",
        TfLiteTypeToTfTypeName(subgraph.tensor(subgraph.outputs()[i])->type)));
  }
  std::vector<tensorflow::FunctionDefHelper::Node> nodes;
  // The first node is a constant node containing the string value for the
  // resource name.
  nodes.push_back(tensorflow::FunctionDefHelper::Const<tensorflow::tstring>(
      "SubgraphResourceKey", function_name));
  // Builds the `TfLiteSubgraphExecute` node.
  tensorflow::FunctionDefHelper::Node execute_node;
  execute_node.ret.push_back("InvokeTfLite");
  execute_node.op = "TfLiteSubgraphExecute";
  execute_node.arg.push_back("SubgraphResourceKey:output:0");
  for (int i = 0; i < subgraph.inputs().size(); ++i) {
    execute_node.arg.push_back(absl::StrCat("args_", i));
  }
  nodes.push_back(execute_node);

  std::vector<std::pair<std::string, std::string>> ret_def;
  ret_def.reserve(subgraph.outputs().size());
  for (int i = 0; i < subgraph.outputs().size(); ++i) {
    ret_def.emplace_back(absl::StrCat("res_", i),
                         absl::StrCat("InvokeTfLite:output:", i));
  }
  fdef = tensorflow::FunctionDefHelper::Create(function_name, inputs, outputs,
                                               /*attr_def=*/{}, nodes, ret_def);
  // Insert input/output type attrs.
  tensorflow::AttrValue tin_attrs, tout_attrs;
  for (int i = 0; i < subgraph.inputs().size(); ++i) {
    TF_DataType dtype = tflite::flex::GetTensorFlowDataType(
        subgraph.tensor(subgraph.inputs()[i])->type);
    tin_attrs.mutable_list()->add_type(tensorflow::DataType(dtype));
  }
  for (int i = 0; i < subgraph.outputs().size(); ++i) {
    TF_DataType dtype = tflite::flex::GetTensorFlowDataType(
        subgraph.tensor(subgraph.outputs()[i])->type);
    tout_attrs.mutable_list()->add_type(tensorflow::DataType(dtype));
  }
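  // `node_def(1)` is the `TfLiteSubgraphExecute` node built above (index 0 is
  // the constant node holding the resource key).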
  fdef.mutable_node_def(1)->mutable_attr()->insert({"Tin", tin_attrs});
  fdef.mutable_node_def(1)->mutable_attr()->insert({"Tout", tout_attrs});
}

// Collects the names of subgraphs that are referenced by function attributes
// of Flex nodes and stores them in `result`.
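// Such function attributes typically come from TF ops that invoke functions,
// e.g. control-flow or partitioned-call ops.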
tensorflow::Status GetSubgraphNamesForFunctionExecution(
    const std::vector<std::unique_ptr<Subgraph>>& subgraphs,
    std::set<std::string>* result) {
  tensorflow::NodeDef node_def;
  for (const auto& subgraph : subgraphs) {
    for (const auto& node_and_reg : subgraph->nodes_and_registration()) {
      if (node_and_reg.second.builtin_code != tflite::BuiltinOperator_CUSTOM) {
        // If this isn't a custom op, skip.
        continue;
      }
      const std::string custom_name = node_and_reg.second.custom_name;
      if (custom_name.substr(0, strlen(tflite::kFlexCustomCodePrefix)) !=
          tflite::kFlexCustomCodePrefix) {
        // Skip if this is not a flex op.
        continue;
      }
      // The flexbuffer contains a vector where the first element is the op
      // name and the second is a serialized NodeDef.
      const flexbuffers::Vector& v =
          flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(
                                   node_and_reg.first.custom_initial_data),
                               node_and_reg.first.custom_initial_data_size)
              .AsVector();
      // TODO(b/181352924): Use proto arena if we see performance regression.
      if (!node_def.ParseFromString(v[1].AsString().str())) {
        return tensorflow::Status(tensorflow::error::INTERNAL,
                                  "could not parse NodeDef");
      }
      // Loop through all the attributes in this node to check whether it has
      // a function attribute.
      for (const auto& attr : node_def.attr()) {
        if (attr.second.has_func()) {
          result->insert(attr.second.func().name());
        }
      }
    }
  }
  return ::tensorflow::OkStatus();
}

}  // namespace

tensorflow::Status RegisterFunctionDefForSubgraphs(
    Subgraph& main_subgraph,
    const std::function<tensorflow::Status(
        const std::vector<std::unique_ptr<Subgraph>>&, std::set<std::string>*)>&
        select_subgraphs_to_register,
    tensorflow::ResourceMgr* resource_mgr,
    tensorflow::EagerContext* eager_context, TfLiteDelegate* flex_delegate) {
  std::vector<std::unique_ptr<Subgraph>>* subgraphs =
      main_subgraph.GetSubgraphs();
  if (!subgraphs) {
    // If there are no subgraphs associated with the main subgraph, return OK
    // because there is no FunctionDef to register.
    return ::tensorflow::OkStatus();
  }
  std::set<std::string> function_subgraphs;
  TF_RETURN_IF_ERROR(
      select_subgraphs_to_register(*subgraphs, &function_subgraphs));
  for (int i = 0; i < subgraphs->size(); ++i) {
    if (subgraphs->at(i)->GetName() == "main") {
      continue;
    }
    const std::string subgraph_name = subgraphs->at(i)->GetName();
    if (!function_subgraphs.count(subgraph_name)) {
      continue;
    }
    // This is to ensure that we only register FunctionDefs for subgraphs that
    // are used by TF ops to invoke functions.
    auto* subgraph_resource =
        new TFLiteSubgraphResource(*(subgraphs->at(i)), flex_delegate);
    TF_RETURN_IF_ERROR(resource_mgr->Create<TFLiteSubgraphResource>(
        "flex", subgraph_name, subgraph_resource));
    tensorflow::FunctionDef fdef;
    BuildFunctionDefProto(subgraph_name, *(subgraphs->at(i)), fdef);
    TF_RETURN_IF_ERROR(eager_context->AddFunctionDef(fdef));
  }
  return ::tensorflow::OkStatus();
}

DelegateData::DelegateData() {}

DelegateData::~DelegateData() {
  if (eager_context_) {
    // Notify the eager context to clean up the resources it holds before
    // destroying the `DelegateData`.
    eager_context_->HostCPU()->ClearResourceMgr();
    eager_context_->Unref();
  }
}

tensorflow::Status DelegateData::Prepare(
    const tensorflow::SessionOptions& session_options, Subgraph* main_subgraph,
    TfLiteDelegate* flex_delegate) {
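  // The eager context is created lazily; if `Prepare` has already been called,
  // this is a no-op.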
  if (eager_context_) {
    return tensorflow::Status();
  }
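  // The flex delegate is stored inside each `TFLiteSubgraphResource` created
  // during registration, so it must be valid whenever a main subgraph is
  // provided.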
  if (flex_delegate == nullptr && main_subgraph != nullptr) {
    return tensorflow::Status(
        tensorflow::error::FAILED_PRECONDITION,
        "flex_delegate must be non-null when main_subgraph is provided.");
  }

  std::vector<std::unique_ptr<tensorflow::Device>> devices;

  TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
      session_options, "/job:localhost/replica:0/task:0", &devices));

  auto device_mgr =
      std::make_unique<tensorflow::StaticDeviceMgr>(std::move(devices));
  // Note that Rendezvous is ref-counted so it will be automatically deleted.
  tensorflow::Rendezvous* rendezvous =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
  eager_context_ = new tensorflow::EagerContext(
      session_options,
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      /*async=*/false, device_mgr.release(), /*device_mgr_owned=*/true,
      rendezvous, nullptr);

  if (main_subgraph) {
    TF_RETURN_IF_ERROR(RegisterFunctionDefForSubgraphs(
        *main_subgraph, GetSubgraphNamesForFunctionExecution,
        eager_context_->HostCPU()->resource_manager(), eager_context_,
        flex_delegate));
  }
  return tensorflow::Status();
}

}  // namespace flex
}  // namespace tflite