1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
16 
17 #include <unordered_map>
18 #include <unordered_set>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/types/optional.h"
22 #include "tensorflow/core/framework/function.h"
23 #include "tensorflow/core/framework/node_def_builder.h"
24 #include "tensorflow/core/graph/graph.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status.h"
27 
28 namespace tensorflow {
29 namespace {
30 
IsFunctionCall(const Node & node)31 bool IsFunctionCall(const Node& node) {
32   // TODO(iga): Handle non-PCO functions when we add multi-device support
33   // to regular function calls. Also, the GetFunctionDefAndAttrs assumes that
34   // the function name is stored in the `f` attribute of the node. That code
35   // will need to change as well.
36   const string& op_type = node.op_def().name();
37   return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
38 }
39 
40 // Utility to set node's value in `cache` and `is_deep` to `value`.
Set(const Node & node,bool value,bool * is_deep,std::vector<absl::optional<bool>> * cache)41 Status Set(const Node& node, bool value, bool* is_deep,
42            std::vector<absl::optional<bool>>* cache) {
43   *is_deep = value;
44   (*cache)[node.id()] = value;
45   return OkStatus();
46 }
47 
48 }  // namespace
49 
// Initializes `graph_` and `flib_def_` from the dereferenced `graph` and
// `flib_def` pointers (so neither may be null) and sizes the result cache to
// one slot per node id currently in the graph.
PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
    const Graph* graph, const FunctionLibraryDefinition* flib_def)
    : graph_(*graph), flib_def_(*flib_def) {
  // Slots are indexed by node id; IsPlacerInspectionRequired treats a slot
  // without a value as "not computed yet".
  cache_.resize(graph_.num_node_ids());
}
55 
IsPlacerInspectionRequired(const Node & node,bool * is_deep)56 Status PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired(
57     const Node& node, bool* is_deep) {
58   if (cache_[node.id()].has_value()) {
59     *is_deep = cache_[node.id()].value();
60     return OkStatus();
61   }
62 
63   if (!IsFunctionCall(node)) {
64     return Set(node, false, is_deep, &cache_);
65   }
66   const FunctionDef* fdef;
67   NameAttrList func;
68   TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
69   DataTypeVector types;
70   TF_RETURN_IF_ERROR(
71       OutputTypesForNode(AttrSlice(&func.attr()), fdef->signature(), &types));
72   for (DataType type : types) {
73     if (type == DT_RESOURCE) {
74       return Set(node, true, is_deep, &cache_);
75     }
76   }
77   return Set(node, false, is_deep, &cache_);
78 }
79 
GetFunctionDefAndAttrs(const FunctionLibraryDefinition & flib_def,const Node & node,const FunctionDef ** fdef,NameAttrList * func)80 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
81                               const Node& node, const FunctionDef** fdef,
82                               NameAttrList* func) {
83   TF_RETURN_IF_ERROR(GetNodeAttr(node.def(), "f", func));
84   const string& function_name = func->name();
85   *fdef = flib_def.Find(function_name);
86   if (*fdef == nullptr) {
87     return errors::InvalidArgument(
88         "Failed to find function \"", function_name,
89         "\" in function library: ", flib_def.ToProto().DebugString());
90   }
91   return OkStatus();
92 }
93 
// Creates a stack whose current function is `function_name`. No frames are
// added here; Push() is what appends them.
FunctionStack::FunctionStack(const string& function_name)
    : current_function_name_(function_name) {}
96 
Push(const Node * node_in_current_function,const string & new_current_function) const97 FunctionStack FunctionStack::Push(const Node* node_in_current_function,
98                                   const string& new_current_function) const {
99   FunctionStack new_stack(new_current_function);
100   new_stack.frames_ = frames_;
101   new_stack.frames_.emplace_back(current_function_name_,
102                                  node_in_current_function);
103   return new_stack;
104 }
105 
HasFunction(const string & function_name) const106 bool FunctionStack::HasFunction(const string& function_name) const {
107   if (current_function_name_ == function_name) {
108     return true;
109   }
110   for (const Frame& frame : frames_) {
111     if (frame.function_name == function_name) {
112       return true;
113     }
114   }
115   return false;
116 }
117 
FormatForError() const118 string FunctionStack::FormatForError() const {
119   std::vector<string> msgs;
120   for (int i = 0; i < frames_.size(); ++i) {
121     if (frames_[i].function_name.empty()) {
122       // Empty function body should only happen at the top level, i.e. i = 0.
123       // All internal frames should have valid function names.
124       msgs.push_back(absl::StrCat("Graph contains node ",
125                                   FormatNodeForError(*frames_[i].node)));
126 
127     } else {
128       msgs.push_back(absl::StrCat(
129           "Function ", errors::FormatFunctionForError(frames_[i].function_name),
130           " contains node ", FormatNodeForError(*frames_[i].node)));
131     }
132     const string& fname = (i + 1 < frames_.size())
133                               ? frames_[i + 1].function_name
134                               : current_function_name_;
135     msgs.push_back(absl::StrCat("Node ", FormatNodeForError(*frames_[i].node),
136                                 " calls function ",
137                                 errors::FormatFunctionForError(fname)));
138   }
139   return absl::StrJoin(msgs, "\n  ");
140 }
141 
142 namespace {
143 
// Nested vector of edge pointers.
// NOTE(review): this alias is not referenced anywhere in the visible part of
// this file; confirm whether it is still needed.
using OutputEdgeMap = std::vector<std::vector<const Edge*>>;

// Op name passed to NodeDefBuilder for the identity nodes created by
// AddInputIdentity and AddOutputIdentities.
constexpr char kIdentityOp[] = "Identity";
147 
// Returns a name based on `candidate_name` that was not yet in `*node_names`
// and records the returned name in the set.
//
// If `candidate_name` itself is unused it is returned as is. Otherwise the
// first unused name among `candidate_name`_0, `candidate_name`_1, ... is
// returned.
string Uniquify(const string& candidate_name,
                std::unordered_set<string>* node_names) {
  // insert() reports through `.second` whether the name was newly added, so
  // one call both tests for and claims the name.
  if (node_names->insert(candidate_name).second) {
    return candidate_name;
  }

  int suffix = 0;
  while (true) {
    string attempt = absl::StrCat(candidate_name, "_", suffix);
    if (node_names->insert(attempt).second) {
      return attempt;
    }
    ++suffix;
  }
}
163 
AddInputIdentity(Node * node,int input_idx,Graph * graph,std::unordered_set<string> * node_names)164 Status AddInputIdentity(Node* node, int input_idx, Graph* graph,
165                         std::unordered_set<string>* node_names) {
166   const Edge* edge;
167   TF_RETURN_IF_ERROR(node->input_edge(input_idx, &edge));
168 
169   string identity_name = Uniquify(
170       absl::StrCat(edge->src()->name(), "_", node->name()), node_names);
171 
172   NodeDefBuilder builder(identity_name, kIdentityOp);
173   builder.Attr("T", node->input_type(input_idx));
174   NodeDefBuilder::NodeOut input(edge->src()->name(), edge->src_output(),
175                                 node->input_type(input_idx));
176   builder.Input(input);
177   NodeDef identity_def;
178   TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
179   MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
180 
181   VLOG(6) << "Adding identity into " << edge->src()->name() << ":"
182           << edge->src_output() << " -> " << edge->dst()->name() << ":"
183           << input_idx << " \n"
184           << identity_def.DebugString();
185 
186   TF_ASSIGN_OR_RETURN(Node * identity_node, graph->AddNode(identity_def));
187   graph->AddEdge(edge->src(), edge->src_output(), identity_node, 0);
188 
189   // Replace node's `input_idx` input with the new identity's 0'th output
190   TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, node, input_idx));
191 
192   VLOG(6) << "Successfully inserted identity. Modified node: \n"
193           << node->DebugString();
194   return OkStatus();
195 }
196 
197 struct EdgePtrCompare {
operator ()tensorflow::__anond03c3ac50211::EdgePtrCompare198   bool operator()(const Edge* lhs, const Edge* rhs) const {
199     return lhs->id() < rhs->id();
200   }
201 };
202 
AddOutputIdentities(Node * node,Graph * graph,std::unordered_set<string> * node_names)203 Status AddOutputIdentities(Node* node, Graph* graph,
204                            std::unordered_set<string>* node_names) {
205   auto add_identity = [&](int src_output, const string& identity_name,
206                           Node** identity_node) {
207     NodeDefBuilder builder(identity_name, kIdentityOp);
208     builder.Attr("T", node->output_type(src_output));
209     NodeDefBuilder::NodeOut input(node->name(), src_output,
210                                   node->output_type(src_output));
211     builder.Input(input);
212     NodeDef identity_def;
213     TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
214     MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
215 
216     TF_ASSIGN_OR_RETURN(*identity_node, graph->AddNode(identity_def));
217     graph->AddEdge(node, src_output, *identity_node, 0);
218     return OkStatus();
219   };
220 
221   // output_used[i] == true iff `node`'s i'th output is used
222   // in this graph
223   std::vector<bool> output_used(node->num_outputs(), false);
224   // Copy the set of edges since EdgeSet does not allow modifications
225   // to graph edges during iteration.
226   const EdgeSet& out_edges = node->out_edges();
227   std::vector<const Edge*> edge_vector(out_edges.begin(), out_edges.end());
228   std::sort(edge_vector.begin(), edge_vector.end(), EdgePtrCompare());
229   for (const Edge* edge : edge_vector) {
230     if (edge->IsControlEdge()) {
231       continue;
232     }
233     output_used[edge->src_output()] = true;
234 
235     Node* dst = edge->dst();
236     int dst_input = edge->dst_input();
237     int src_output = edge->src_output();
238     string identity_name =
239         Uniquify(absl::StrCat(node->name(), "_", dst->name()), node_names);
240     Node* identity_node;
241     TF_RETURN_IF_ERROR(add_identity(src_output, identity_name, &identity_node));
242     VLOG(6) << "Adding identity into " << node->name() << ":" << src_output
243             << " -> " << dst->name() << ":" << dst_input << " \n"
244             << identity_node->DebugString();
245 
246     // Make original dst node consume the new identity's output instead of
247     // `node`'s output.
248     TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, dst, dst_input));
249   }
250 
251   for (int output_idx = 0; output_idx < node->num_outputs(); ++output_idx) {
252     if (output_used[output_idx]) {
253       continue;
254     }
255     // The output is unused in the graph. Just add an identity
256     // consuming it.
257     string identity_name = Uniquify(node->name(), node_names);
258     Node* identity_node;
259     TF_RETURN_IF_ERROR(add_identity(output_idx, identity_name, &identity_node));
260     VLOG(6) << "Added identity into " << node->name() << ":" << output_idx
261             << " -> <no consumer>: \n"
262             << identity_node->DebugString();
263   }
264   return OkStatus();
265 }
266 
IsolateNode(Node * node,Graph * graph)267 Status IsolateNode(Node* node, Graph* graph) {
268   // We use `node_names` to make sure we pick unique names.
269   // We don't use graph->NewName() because it produces verbose names and
270   // does not actually ensure that they are unique (it assumes all names
271   // are generated using it, which is not true today).
272   std::unordered_set<string> node_names(graph->num_nodes());
273   for (Node* n : graph->nodes()) {
274     node_names.insert(n->name());
275   }
276 
277   for (int i = 0; i < node->num_inputs(); ++i) {
278     TF_RETURN_IF_ERROR(AddInputIdentity(node, i, graph, &node_names));
279   }
280   TF_RETURN_IF_ERROR(AddOutputIdentities(node, graph, &node_names));
281   return OkStatus();
282 }
283 
284 }  // namespace
285 
IsolatePlacerInspectionRequiredOps(const FunctionLibraryDefinition & flib_def,Graph * graph)286 Status IsolatePlacerInspectionRequiredOps(
287     const FunctionLibraryDefinition& flib_def, Graph* graph) {
288   PlacerInspectionRequiredOpChecker checker(graph, &flib_def);
289   // It is OK to add nodes to the graph during iteration.
290   // New nodes will get ids above current ids. The loop
291   // will loop over current nodes only because the op_nodes()
292   // iterator uses node ids to iterate.
293   // Because the new nodes will be higher ids, the caching in
294   // the checker will also work fine as new nodes are added.
295   for (Node* node : graph->op_nodes()) {
296     bool should_be_isolated = false;
297     TF_RETURN_IF_ERROR(
298         checker.IsPlacerInspectionRequired(*node, &should_be_isolated));
299     if (!should_be_isolated) {
300       continue;
301     }
302     TF_RETURN_IF_ERROR(IsolateNode(node, graph));
303   }
304 
305   return OkStatus();
306 }
307 
308 }  // namespace tensorflow
309