1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
16
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
27
28 namespace tensorflow {
29 namespace {
30
IsFunctionCall(const Node & node)31 bool IsFunctionCall(const Node& node) {
32 // TODO(iga): Handle non-PCO functions when we add multi-device support
33 // to regular function calls. Also, the GetFunctionDefAndAttrs assumes that
34 // the function name is stored in the `f` attribute of the node. That code
35 // will need to change as well.
36 const string& op_type = node.op_def().name();
37 return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
38 }
39
40 // Utility to set node's value in `cache` and `is_deep` to `value`.
Set(const Node & node,bool value,bool * is_deep,std::vector<absl::optional<bool>> * cache)41 Status Set(const Node& node, bool value, bool* is_deep,
42 std::vector<absl::optional<bool>>* cache) {
43 *is_deep = value;
44 (*cache)[node.id()] = value;
45 return OkStatus();
46 }
47
48 } // namespace
49
PlacerInspectionRequiredOpChecker(const Graph * graph,const FunctionLibraryDefinition * flib_def)50 PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
51 const Graph* graph, const FunctionLibraryDefinition* flib_def)
52 : graph_(*graph), flib_def_(*flib_def) {
53 cache_.resize(graph_.num_node_ids());
54 }
55
IsPlacerInspectionRequired(const Node & node,bool * is_deep)56 Status PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired(
57 const Node& node, bool* is_deep) {
58 if (cache_[node.id()].has_value()) {
59 *is_deep = cache_[node.id()].value();
60 return OkStatus();
61 }
62
63 if (!IsFunctionCall(node)) {
64 return Set(node, false, is_deep, &cache_);
65 }
66 const FunctionDef* fdef;
67 NameAttrList func;
68 TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
69 DataTypeVector types;
70 TF_RETURN_IF_ERROR(
71 OutputTypesForNode(AttrSlice(&func.attr()), fdef->signature(), &types));
72 for (DataType type : types) {
73 if (type == DT_RESOURCE) {
74 return Set(node, true, is_deep, &cache_);
75 }
76 }
77 return Set(node, false, is_deep, &cache_);
78 }
79
GetFunctionDefAndAttrs(const FunctionLibraryDefinition & flib_def,const Node & node,const FunctionDef ** fdef,NameAttrList * func)80 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
81 const Node& node, const FunctionDef** fdef,
82 NameAttrList* func) {
83 TF_RETURN_IF_ERROR(GetNodeAttr(node.def(), "f", func));
84 const string& function_name = func->name();
85 *fdef = flib_def.Find(function_name);
86 if (*fdef == nullptr) {
87 return errors::InvalidArgument(
88 "Failed to find function \"", function_name,
89 "\" in function library: ", flib_def.ToProto().DebugString());
90 }
91 return OkStatus();
92 }
93
FunctionStack(const string & function_name)94 FunctionStack::FunctionStack(const string& function_name)
95 : current_function_name_(function_name) {}
96
Push(const Node * node_in_current_function,const string & new_current_function) const97 FunctionStack FunctionStack::Push(const Node* node_in_current_function,
98 const string& new_current_function) const {
99 FunctionStack new_stack(new_current_function);
100 new_stack.frames_ = frames_;
101 new_stack.frames_.emplace_back(current_function_name_,
102 node_in_current_function);
103 return new_stack;
104 }
105
HasFunction(const string & function_name) const106 bool FunctionStack::HasFunction(const string& function_name) const {
107 if (current_function_name_ == function_name) {
108 return true;
109 }
110 for (const Frame& frame : frames_) {
111 if (frame.function_name == function_name) {
112 return true;
113 }
114 }
115 return false;
116 }
117
FormatForError() const118 string FunctionStack::FormatForError() const {
119 std::vector<string> msgs;
120 for (int i = 0; i < frames_.size(); ++i) {
121 if (frames_[i].function_name.empty()) {
122 // Empty function body should only happen at the top level, i.e. i = 0.
123 // All internal frames should have valid function names.
124 msgs.push_back(absl::StrCat("Graph contains node ",
125 FormatNodeForError(*frames_[i].node)));
126
127 } else {
128 msgs.push_back(absl::StrCat(
129 "Function ", errors::FormatFunctionForError(frames_[i].function_name),
130 " contains node ", FormatNodeForError(*frames_[i].node)));
131 }
132 const string& fname = (i + 1 < frames_.size())
133 ? frames_[i + 1].function_name
134 : current_function_name_;
135 msgs.push_back(absl::StrCat("Node ", FormatNodeForError(*frames_[i].node),
136 " calls function ",
137 errors::FormatFunctionForError(fname)));
138 }
139 return absl::StrJoin(msgs, "\n ");
140 }
141
142 namespace {
143
// Op type of the nodes inserted around placer-inspection-required ops.
// (The unused `OutputEdgeMap` alias that used to live here was dead code.)
constexpr char kIdentityOp[] = "Identity";
147
// Returns a name not already present in `*node_names` and records it there.
// That is `candidate_name` itself when it is free; otherwise it is the first of
// "<candidate_name>_0", "<candidate_name>_1", ... that is free.
string Uniquify(const string& candidate_name,
                std::unordered_set<string>* node_names) {
  // insert() is a no-op on an existing key and reports whether it inserted,
  // so it doubles as the membership test.
  if (node_names->insert(candidate_name).second) {
    return candidate_name;
  }
  int suffix = 0;
  while (true) {
    string attempt = absl::StrCat(candidate_name, "_", suffix);
    if (node_names->insert(attempt).second) {
      return attempt;
    }
    ++suffix;
  }
}
163
AddInputIdentity(Node * node,int input_idx,Graph * graph,std::unordered_set<string> * node_names)164 Status AddInputIdentity(Node* node, int input_idx, Graph* graph,
165 std::unordered_set<string>* node_names) {
166 const Edge* edge;
167 TF_RETURN_IF_ERROR(node->input_edge(input_idx, &edge));
168
169 string identity_name = Uniquify(
170 absl::StrCat(edge->src()->name(), "_", node->name()), node_names);
171
172 NodeDefBuilder builder(identity_name, kIdentityOp);
173 builder.Attr("T", node->input_type(input_idx));
174 NodeDefBuilder::NodeOut input(edge->src()->name(), edge->src_output(),
175 node->input_type(input_idx));
176 builder.Input(input);
177 NodeDef identity_def;
178 TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
179 MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
180
181 VLOG(6) << "Adding identity into " << edge->src()->name() << ":"
182 << edge->src_output() << " -> " << edge->dst()->name() << ":"
183 << input_idx << " \n"
184 << identity_def.DebugString();
185
186 TF_ASSIGN_OR_RETURN(Node * identity_node, graph->AddNode(identity_def));
187 graph->AddEdge(edge->src(), edge->src_output(), identity_node, 0);
188
189 // Replace node's `input_idx` input with the new identity's 0'th output
190 TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, node, input_idx));
191
192 VLOG(6) << "Successfully inserted identity. Modified node: \n"
193 << node->DebugString();
194 return OkStatus();
195 }
196
197 struct EdgePtrCompare {
operator ()tensorflow::__anond03c3ac50211::EdgePtrCompare198 bool operator()(const Edge* lhs, const Edge* rhs) const {
199 return lhs->id() < rhs->id();
200 }
201 };
202
AddOutputIdentities(Node * node,Graph * graph,std::unordered_set<string> * node_names)203 Status AddOutputIdentities(Node* node, Graph* graph,
204 std::unordered_set<string>* node_names) {
205 auto add_identity = [&](int src_output, const string& identity_name,
206 Node** identity_node) {
207 NodeDefBuilder builder(identity_name, kIdentityOp);
208 builder.Attr("T", node->output_type(src_output));
209 NodeDefBuilder::NodeOut input(node->name(), src_output,
210 node->output_type(src_output));
211 builder.Input(input);
212 NodeDef identity_def;
213 TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
214 MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
215
216 TF_ASSIGN_OR_RETURN(*identity_node, graph->AddNode(identity_def));
217 graph->AddEdge(node, src_output, *identity_node, 0);
218 return OkStatus();
219 };
220
221 // output_used[i] == true iff `node`'s i'th output is used
222 // in this graph
223 std::vector<bool> output_used(node->num_outputs(), false);
224 // Copy the set of edges since EdgeSet does not allow modifications
225 // to graph edges during iteration.
226 const EdgeSet& out_edges = node->out_edges();
227 std::vector<const Edge*> edge_vector(out_edges.begin(), out_edges.end());
228 std::sort(edge_vector.begin(), edge_vector.end(), EdgePtrCompare());
229 for (const Edge* edge : edge_vector) {
230 if (edge->IsControlEdge()) {
231 continue;
232 }
233 output_used[edge->src_output()] = true;
234
235 Node* dst = edge->dst();
236 int dst_input = edge->dst_input();
237 int src_output = edge->src_output();
238 string identity_name =
239 Uniquify(absl::StrCat(node->name(), "_", dst->name()), node_names);
240 Node* identity_node;
241 TF_RETURN_IF_ERROR(add_identity(src_output, identity_name, &identity_node));
242 VLOG(6) << "Adding identity into " << node->name() << ":" << src_output
243 << " -> " << dst->name() << ":" << dst_input << " \n"
244 << identity_node->DebugString();
245
246 // Make original dst node consume the new identity's output instead of
247 // `node`'s output.
248 TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, dst, dst_input));
249 }
250
251 for (int output_idx = 0; output_idx < node->num_outputs(); ++output_idx) {
252 if (output_used[output_idx]) {
253 continue;
254 }
255 // The output is unused in the graph. Just add an identity
256 // consuming it.
257 string identity_name = Uniquify(node->name(), node_names);
258 Node* identity_node;
259 TF_RETURN_IF_ERROR(add_identity(output_idx, identity_name, &identity_node));
260 VLOG(6) << "Added identity into " << node->name() << ":" << output_idx
261 << " -> <no consumer>: \n"
262 << identity_node->DebugString();
263 }
264 return OkStatus();
265 }
266
IsolateNode(Node * node,Graph * graph)267 Status IsolateNode(Node* node, Graph* graph) {
268 // We use `node_names` to make sure we pick unique names.
269 // We don't use graph->NewName() because it produces verbose names and
270 // does not actually ensure that they are unique (it assumes all names
271 // are generated using it, which is not true today).
272 std::unordered_set<string> node_names(graph->num_nodes());
273 for (Node* n : graph->nodes()) {
274 node_names.insert(n->name());
275 }
276
277 for (int i = 0; i < node->num_inputs(); ++i) {
278 TF_RETURN_IF_ERROR(AddInputIdentity(node, i, graph, &node_names));
279 }
280 TF_RETURN_IF_ERROR(AddOutputIdentities(node, graph, &node_names));
281 return OkStatus();
282 }
283
284 } // namespace
285
IsolatePlacerInspectionRequiredOps(const FunctionLibraryDefinition & flib_def,Graph * graph)286 Status IsolatePlacerInspectionRequiredOps(
287 const FunctionLibraryDefinition& flib_def, Graph* graph) {
288 PlacerInspectionRequiredOpChecker checker(graph, &flib_def);
289 // It is OK to add nodes to the graph during iteration.
290 // New nodes will get ids above current ids. The loop
291 // will loop over current nodes only because the op_nodes()
292 // iterator uses node ids to iterate.
293 // Because the new nodes will be higher ids, the caching in
294 // the checker will also work fine as new nodes are added.
295 for (Node* node : graph->op_nodes()) {
296 bool should_be_isolated = false;
297 TF_RETURN_IF_ERROR(
298 checker.IsPlacerInspectionRequired(*node, &should_be_isolated));
299 if (!should_be_isolated) {
300 continue;
301 }
302 TF_RETURN_IF_ERROR(IsolateNode(node, graph));
303 }
304
305 return OkStatus();
306 }
307
308 } // namespace tensorflow
309