1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/encapsulate_util.h"
17
#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/compiler/jit/shape_inference.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/graph/node_builder.h"
29 #include "tensorflow/core/protobuf/error_codes.pb.h"
30 #include "tensorflow/stream_executor/lib/statusor.h"
31
32 using stream_executor::port::StatusOr;
33
34 namespace tensorflow {
35
36 namespace {
37
38 // Returns string attribute value for the node if the attribute is present,
39 // otherwise returns empty optional value.
GetStringAttr(const Node & n,const string & attr_name)40 std::optional<string> GetStringAttr(const Node& n, const string& attr_name) {
41 auto attr = n.attrs().Find(attr_name);
42 if (!attr) {
43 return std::nullopt;
44 } else {
45 return attr->s();
46 }
47 }
48
49 // Adds a value to the node's list attribute.
50 template <typename T>
AppendToListAttr(Node * n,const string & attr_name,const string & value)51 Status AppendToListAttr(Node* n, const string& attr_name, const string& value) {
52 std::vector<T> attr_value;
53 Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value);
54 if (!s.ok() && s.code() != error::NOT_FOUND) {
55 return s;
56 }
57
58 n->ClearAttr(attr_name);
59 attr_value.push_back(value);
60 n->AddAttr(attr_name, attr_value);
61 return OkStatus();
62 }
63
64 // Replaces attribute value.
65 template <typename T>
ReplaceAttr(Node * n,const string & attr_name,const T & value)66 void ReplaceAttr(Node* n, const string& attr_name, const T& value) {
67 n->ClearAttr(attr_name);
68 n->AddAttr(attr_name, value);
69 }
70
71 // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
72 // `PreprocessEdgesBetweenOutsideCompilations` for details.
PreprocessControlEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)73 Status PreprocessControlEdgesBetweenOutsideCompilations(
74 Graph* g, const string& outside_compilation_attr_name) {
75 // Gather edges to remove. We should not remove the edge while iterating.
76 std::vector<const Edge*> edges_to_remove;
77 for (const Edge* e : g->edges()) {
78 if (!e->IsControlEdge()) {
79 continue;
80 }
81
82 auto src_outside_compilation =
83 GetStringAttr(*e->src(), outside_compilation_attr_name);
84 auto dst_outside_compilation =
85 GetStringAttr(*e->dst(), outside_compilation_attr_name);
86
87 if (src_outside_compilation && dst_outside_compilation) {
88 if (*src_outside_compilation != *dst_outside_compilation) {
89 // Case 1a: outside compilation to outside compilation control edge.
90 edges_to_remove.push_back(e);
91
92 TF_RETURN_IF_ERROR(AppendToListAttr<string>(
93 e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName,
94 e->src()->name()));
95 }
96 } else if (src_outside_compilation && !dst_outside_compilation) {
97 // Case 1b: outside compilation to its XLA computation control edge.
98 ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
99 } else if (!src_outside_compilation && dst_outside_compilation) {
100 // Case 1b: XLA computation to outside compilation in it control edge.
101 ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
102 }
103 }
104
105 for (auto e : edges_to_remove) {
106 g->RemoveEdge(e);
107 }
108 return OkStatus();
109 }
110
111 // Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
112 // `PreprocessEdgesBetweenOutsideCompilations` for details.
PreprocessDataEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)113 Status PreprocessDataEdgesBetweenOutsideCompilations(
114 Graph* g, const string& outside_compilation_attr_name) {
115 // Gather edges between outside compilation and host computation. Notice that
116 // we do not store `Edge*` directly because we remove some nodes while adding
117 // Identity nodes, and those Edge pointers might be invalidated.
118 struct EdgeInfo {
119 int dst_input, dst_node_id;
120 };
121 std::vector<EdgeInfo> edges;
122 for (const Edge* e : g->edges()) {
123 if (e->IsControlEdge()) {
124 continue;
125 }
126
127 auto src_outside_compilation =
128 GetStringAttr(*e->src(), outside_compilation_attr_name);
129 auto dst_outside_compilation =
130 GetStringAttr(*e->dst(), outside_compilation_attr_name);
131
132 if (src_outside_compilation && dst_outside_compilation &&
133 *src_outside_compilation != *dst_outside_compilation) {
134 edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
135 VLOG(4) << "Oc -> oc edge: " << e->DebugString();
136 }
137 }
138
139 // Remove the edge from host to outside compilation. Add a placeholder as
140 // outside compilation node input.
141 std::map<std::pair<string, int>, Node*> placeholders;
142 for (int i = 0, end = edges.size(); i < end; i++) {
143 Node* dst = g->FindNodeId(edges[i].dst_node_id);
144 const Edge* e;
145 TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
146 Node* src = e->src();
147 int src_output = e->src_output(), dst_input = e->dst_input();
148 g->RemoveEdge(e);
149
150 // Find or create placeholder node.
151 string new_name =
152 absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
153 auto placeholder_index = std::make_pair(src->name(), src_output);
154 auto iter = placeholders.find(placeholder_index);
155 Node* placeholder_node;
156 if (iter == placeholders.end()) {
157 NodeDefBuilder placeholder_builder(new_name, "Placeholder");
158 placeholder_builder.Attr("dtype", src->output_type(src_output));
159 string outside_compilation_attr;
160 TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
161 outside_compilation_attr_name,
162 &outside_compilation_attr));
163 placeholder_builder.Attr(outside_compilation_attr_name,
164 outside_compilation_attr);
165 placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName,
166 src->name());
167 placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName,
168 src_output);
169 NodeDef placeholder_def;
170 TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
171 TF_ASSIGN_OR_RETURN(placeholder_node, g->AddNode(placeholder_def));
172 placeholders[placeholder_index] = placeholder_node;
173 } else {
174 placeholder_node = iter->second;
175 }
176 g->AddEdge(placeholder_node, 0, dst, dst_input);
177
178 // Replace `e->dst()` because its input node changed.
179 NodeDef new_def = dst->def();
180 *new_def.mutable_input(dst_input) = placeholder_node->name();
181 TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
182
183 // Other edge in `edges` might have `e->dst()` as src or dst
184 // node. Before removing `e->dst()`, replace those edges with
185 // corresponding edges for `dst_replace_node`.
186 for (int j = i + 1, end = edges.size(); j < end; j++) {
187 if (edges[j].dst_node_id == edges[i].dst_node_id) {
188 edges[j].dst_node_id = dst_replace_node->id();
189 }
190 }
191 }
192 return OkStatus();
193 }
194
195 // Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
196 // `PostprocessEdgesBetweenOutsideCompilations` for details.
PostprocessDataEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)197 Status PostprocessDataEdgesBetweenOutsideCompilations(
198 Graph* g, const string& outside_compilation_attr_name) {
199 // Gather all outside compilation to outside compilation nodes.
200 std::vector<Node*> placeholder_nodes;
201 for (Node* n : g->nodes()) {
202 if (n->type_string() == "Placeholder" &&
203 HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) {
204 placeholder_nodes.push_back(n);
205 }
206 }
207
208 // Remove the placeholder nodes, and reconnect original edge.
209 auto node_name_index = g->BuildNodeNameIndex();
210 for (auto n : placeholder_nodes) {
211 string node_name;
212 int node_src_output;
213 TF_RETURN_IF_ERROR(GetNodeAttr(
214 n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name));
215 TF_RETURN_IF_ERROR(GetNodeAttr(
216 n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output));
217 auto iter = node_name_index.find(node_name);
218 if (iter == node_name_index.end()) {
219 return errors::Internal(
220 "Cannot find original node for oc -> host placeholder node ",
221 node_name);
222 }
223
224 // Change all usage node to use the original node instead.
225 Node* original_node = iter->second;
226 std::vector<const Edge*> control_edges;
227 std::vector<OutEdgeInfo> data_edges;
228 for (auto e : n->out_edges()) {
229 if (e->IsControlEdge()) {
230 control_edges.push_back(e);
231 } else {
232 data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
233 }
234 }
235 for (const Edge* e : control_edges) {
236 g->AddControlEdge(original_node, e->dst());
237 g->RemoveEdge(e);
238 }
239 for (int i = 0, end = data_edges.size(); i < end; i++) {
240 Node* dst = data_edges[i].dst;
241 NodeDef new_def = dst->def();
242 int dst_input = data_edges[i].dst_input;
243 *new_def.mutable_input(dst_input) =
244 absl::StrCat(original_node->name(), ":", node_src_output);
245 TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
246
247 const Edge* edge_to_replace = nullptr;
248 TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
249 g->RemoveEdge(edge_to_replace);
250 g->AddEdge(original_node, node_src_output, replace_node, dst_input);
251
252 // Other edges might have `dst` as dst node. Update those edges with
253 // `replace_node`.
254 for (int j = i + 1, end = data_edges.size(); j < end; j++) {
255 if (data_edges[j].dst == dst) {
256 data_edges[j].dst = replace_node;
257 }
258 }
259
260 // Other placeholder node might have `dst` as original node. Update
261 // `node_name_index` with `replace_node`.
262 node_name_index[replace_node->name()] = replace_node;
263 }
264
265 // Remove placeholder node.
266 g->RemoveNode(n);
267 }
268 return OkStatus();
269 }
270
271 // Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
272 // `PostprocessEdgesBetweenOutsideCompilations` for details.
PostprocessControlEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)273 Status PostprocessControlEdgesBetweenOutsideCompilations(
274 Graph* g, const string& outside_compilation_attr_name) {
275 auto node_name_index = g->BuildNodeNameIndex();
276
277 // Reconnect outside compilation to outside compilation control edge.
278 for (Node* n : g->nodes()) {
279 std::vector<string> control_deps;
280 Status s =
281 GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName,
282 &control_deps);
283 if (!s.ok()) {
284 if (s.code() != error::NOT_FOUND) {
285 return s;
286 } else {
287 continue;
288 }
289 } else {
290 n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName);
291 for (const string& control_input : control_deps) {
292 auto iter = node_name_index.find(control_input);
293 if (iter == node_name_index.end()) {
294 return errors::Internal("Cannot find original node for ",
295 control_input);
296 }
297 g->AddControlEdge(iter->second, n);
298 }
299 }
300 }
301 return OkStatus();
302 }
303 } // namespace
304
// Node attribute names used by the encapsulation utilities. The
// declarations presumably live in encapsulate_util.h (included above) —
// NOTE(review): confirm the definitions stay in sync with that header.

// Set by `PerformStaticShapeInferenceBeforeEncapsulation` to the list of
// inferred output shapes of a node.
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";

// Set to true on an outside compilation node that has a control edge to a
// node outside any outside compilation cluster.
const char kXlaConnectedToXlaComputationAttrName[] =
    "_xla_connected_to_xla_computation";
// Set to true on an outside compilation node that has a control edge from a
// node outside any outside compilation cluster.
const char kXlaConnectedFromXlaComputationAttrName[] =
    "_xla_connected_from_xla_computation";
// On oc -> oc placeholder nodes: name of the original source node.
const char kOutsideCompilationOriginalNodeAttrName[] =
    "_xla_oc_to_oc_node_name";
// On oc -> oc placeholder nodes: output index of the original source node.
const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";
// List of node names whose control edges into this node were removed during
// preprocessing; the edges are re-created in postprocessing.
const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
    "_xla_control_dependencies_within_xla_cluster";
// The following attribute names are not referenced in this part of the file;
// presumably they are consumed by other encapsulation passes — TODO confirm.
const char kXlaIsLiftedArgAttrName[] = "_xla_is_lifted_arg";
const char kXlaLiftedArgOutsideCompilationAttrName[] = "_xla_lifted_arg_oc";
const char kXlaOutsideCompilationInputsAttrName[] = "_xla_oc_inputs";
const char kXlaIsPlaceholderForArg[] = "_xla_is_placeholder_for_arg";
320
PerformStaticShapeInferenceBeforeEncapsulation(Graph * g)321 Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
322 // Perform shape inference.
323 std::map<int, InferredShape> arg_shapes;
324 GraphShapeInfo shape_info;
325 TF_RETURN_IF_ERROR(
326 InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info));
327
328 // Add attribute for output shapes.
329 auto node_name_index = g->BuildNodeNameIndex();
330 for (auto iter : shape_info) {
331 std::vector<PartialTensorShape> output_shapes;
332 std::transform(iter.second.begin(), iter.second.end(),
333 std::back_inserter(output_shapes),
334 [](const InferredShape& inferred_shape) {
335 return inferred_shape.shape;
336 });
337 Node* n = node_name_index[iter.first];
338 n->AddAttr(kXlaInferredShapesAttrName, output_shapes);
339 }
340
341 return OkStatus();
342 }
343
344 StatusOr<std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(const Graph * g,const string & outside_compilation_attr_name)345 OutsideCompilationClusterDependencies(
346 const Graph* g, const string& outside_compilation_attr_name) {
347 auto cluster_deps = std::make_unique<
348 absl::flat_hash_map<string, absl::flat_hash_set<string>>>();
349
350 for (const Edge* e : g->edges()) {
351 auto src_outside_compilation =
352 GetStringAttr(*e->src(), outside_compilation_attr_name);
353 auto dst_outside_compilation =
354 GetStringAttr(*e->dst(), outside_compilation_attr_name);
355
356 if (src_outside_compilation && dst_outside_compilation &&
357 *src_outside_compilation != *dst_outside_compilation) {
358 auto dst_deps_it = cluster_deps->find(*dst_outside_compilation);
359 if (dst_deps_it == cluster_deps->end()) {
360 cluster_deps->insert(std::make_pair(
361 *dst_outside_compilation,
362 absl::flat_hash_set<string>({*src_outside_compilation})));
363 } else {
364 dst_deps_it->second.insert(*src_outside_compilation);
365 }
366 }
367 }
368
369 auto cluster_deps_ordered =
370 std::make_unique<absl::flat_hash_map<string, std::vector<string>>>();
371
372 for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) {
373 std::vector<string> ordered_deps(it->second.begin(), it->second.end());
374 std::sort(ordered_deps.begin(), ordered_deps.end());
375 cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps));
376 }
377
378 return std::move(cluster_deps_ordered);
379 }
380
PreprocessEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)381 Status PreprocessEdgesBetweenOutsideCompilations(
382 Graph* g, const string& outside_compilation_attr_name) {
383 // Remove edges from source node to outside compilation nodes, and edges
384 // from outside compilation nodes to sink node.
385 std::vector<const Edge*> edges_to_remove;
386 for (const Edge* e : g->source_node()->out_edges()) {
387 if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
388 edges_to_remove.push_back(e);
389 }
390 }
391 for (const Edge* e : g->sink_node()->in_edges()) {
392 if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) {
393 edges_to_remove.push_back(e);
394 }
395 }
396 for (auto e : edges_to_remove) {
397 g->RemoveEdge(e);
398 }
399
400 TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations(
401 g, outside_compilation_attr_name));
402 TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations(
403 g, outside_compilation_attr_name));
404 return OkStatus();
405 }
406
PostprocessEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)407 Status PostprocessEdgesBetweenOutsideCompilations(
408 Graph* g, const string& outside_compilation_attr_name) {
409 TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations(
410 g, outside_compilation_attr_name));
411 TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations(
412 g, outside_compilation_attr_name));
413 return OkStatus();
414 }
415
416 } // namespace tensorflow
417