/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/partially_decluster_pass.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {
namespace {

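// Edge filter shared by the traversals below: ignoring NextIteration back
// edges lets GetPostOrder/GetReversePostOrder treat the graph as acyclic.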
bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }

namespace reduce_device_to_host_copies {
Status FindNodesToDecluster(const Graph& graph,
                            absl::flat_hash_set<Node*>* result,
                            absl::Span<Node* const> post_order) {
  // Find nodes that have at least one user outside their cluster that expects
  // hostmem output. These nodes should be cloned to outside the cluster to
  // avoid the device-host copy we'd otherwise need.

  MemoryTypeVector input_mtypes, output_mtypes;

  for (Node* n : post_order) {
    std::optional<absl::string_view> from_cluster = GetXlaClusterForNode(*n);
    if (!from_cluster) {
      continue;
    }

    // Assume the benefit of not outputting a larger tensor outweighs the
    // benefit of this check.
    // TODO(tpopp): Only apply this if the value being consumed is not output
    // from the cluster to another consumer.
    // TODO(tpopp): See if XlaRun can be modified to avoid this issue
    // completely.
    if (IsShapeConsumerOp(*n)) {
      continue;
    }
    // We assume the only XLA-auto-clusterable operations with side effects are
    // resource variable updates. We can't execute these twice.
    if (HasResourceInputOrOutput(*n)) {
      continue;
    }

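    // Look up the memory types of `n`'s inputs and outputs on its assigned
    // device; output_mtypes tells us whether each output of the TensorFlow op
    // lives in host or device memory.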
    DeviceType device_type("");
    TF_RETURN_IF_ERROR(
        DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
    TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
                                          n->def(), &input_mtypes,
                                          &output_mtypes));
    for (const Edge* e : n->out_edges()) {
      Node* dst = e->dst();

      if (e->IsControlEdge()) {
        continue;
      }

      bool edge_incurs_extra_device_to_host_copy;
      if (output_mtypes[e->src_output()] == DEVICE_MEMORY) {
        // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then
        // keep the node clustered -- XLA will also produce the output in
        // device memory and we will get some benefit from clustering.
        edge_incurs_extra_device_to_host_copy = false;
      } else {
        MemoryTypeVector dst_input_mtypes, dst_output_mtypes;
        DeviceType dst_device_type("");
        TF_RETURN_IF_ERROR(DeviceNameToDeviceType(dst->assigned_device_name(),
                                                  &dst_device_type));
        TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(),
                                              dst_device_type, dst->def(),
                                              &dst_input_mtypes,
                                              &dst_output_mtypes));
        edge_incurs_extra_device_to_host_copy =
            dst_input_mtypes[e->dst_input()] == HOST_MEMORY;
      }

      if (!edge_incurs_extra_device_to_host_copy) {
        continue;
      }

      // Check if `dst` is in a different cluster, unclustered, or about to be
      // partially declustered (here we rely on the post-order traversal
      // order). If yes, decluster `n` to avoid the device-to-host memcpy.
      std::optional<absl::string_view> dst_cluster =
          result->count(dst) ? std::nullopt : GetXlaClusterForNode(*dst);
      if (from_cluster != dst_cluster) {
        CHECK(result->insert(n).second);
        break;
      }
    }
  }
  return OkStatus();
}

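// Clones the clustered node `n` to a new, unclustered node and moves the data
// edges that leave `n`'s cluster over to the clone. If `n` is left with no
// out edges afterwards it is removed from the graph entirely.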
Status PartiallyDeclusterNode(Graph* graph, Node* n) {
  absl::string_view cluster_name = *GetXlaClusterForNode(*n);
  absl::InlinedVector<const Edge*, 6> out_edges_to_clone;
  for (const Edge* out_edge : n->out_edges()) {
    if (out_edge->IsControlEdge()) {
      continue;
    }

    Node* dst = out_edge->dst();
    std::optional<absl::string_view> dst_cluster_name =
        GetXlaClusterForNode(*dst);
    if (dst_cluster_name != cluster_name) {
      out_edges_to_clone.push_back(out_edge);
    }
  }

  CHECK(!out_edges_to_clone.empty()) << n->DebugString();

  NodeDef ndef = n->def();
  ndef.set_name(absl::StrCat(n->name(), "/declustered"));
  MergeDebugInfo(NodeDebugInfo(n->def()), &ndef);
  RemoveFromXlaCluster(&ndef);
  TF_ASSIGN_OR_RETURN(Node * cloned_node, graph->AddNode(ndef));
  cloned_node->set_assigned_device_name(n->assigned_device_name());

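  // Give the clone the same inputs (including control inputs) as the original
  // node.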
  for (const Edge* in_edge : n->in_edges()) {
    graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node,
                   in_edge->dst_input());
  }

  for (const Edge* out_edge_to_clone : out_edges_to_clone) {
    graph->AddEdge(cloned_node, out_edge_to_clone->src_output(),
                   out_edge_to_clone->dst(), out_edge_to_clone->dst_input());
    graph->RemoveEdge(out_edge_to_clone);
  }

  if (n->out_edges().empty()) {
    graph->RemoveNode(n);
  }

  return OkStatus();
}

// Clones nodes to outside their cluster to avoid device-to-host copies. For
// instance, converts this:
//
// .....
// |
// v
// A_Clustered ====> C_Unclustered
// |
// v
// B_Clustered
//
// to:
//
// .....
// |   |
// |   +-------------+
// |                 |
// v                 v
// A_Clustered       A_Unclustered ====> C_Unclustered
// |
// v
// B_Clustered
//
// where the ===> arrow has a hostmem source and destination and would entail a
// device to host copy if the source and destination were not in the same XLA
// cluster.
Status PartiallyDeclusterGraph(Graph* graph) {
  // When deciding whether to decluster a particular node, we base our
  // decision on whether we've decided that some of its consumers have to be
  // declustered too. Iterating the graph in post-order guarantees that
  // consumers have been visited before producers.
  std::vector<Node*> post_order;
  GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
               /*edge_filter=*/NotBackedge);

  absl::flat_hash_set<Node*> nodes_to_partially_decluster;
  TF_RETURN_IF_ERROR(
      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));

  if (VLOG_IS_ON(3)) {
    for (Node* n : post_order) {
      if (nodes_to_partially_decluster.count(n)) {
        VLOG(3) << n->DebugString();
      }
    }
  }

  for (Node* n : post_order) {
    if (nodes_to_partially_decluster.count(n)) {
      TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n));
    }
  }

  // Recompute post order since PartiallyDeclusterNode may have deleted nodes.
  post_order.clear();
  GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
               /*edge_filter=*/NotBackedge);
  nodes_to_partially_decluster.clear();
  TF_RETURN_IF_ERROR(
      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
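  // Sanity check: the declustering above should have reached a fixed point, so
  // re-running the analysis on the rewritten graph must not find anything new.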
  CHECK(nodes_to_partially_decluster.empty());

  return OkStatus();
}
}  // namespace reduce_device_to_host_copies

namespace reduce_recompilation {
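// Returns true if both endpoints of `edge` are assigned to the same XLA
// cluster. Used below to restrict the backwards const analysis to edges that
// stay inside a cluster.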
bool IsIntraClusterEdge(const Edge& edge) {
  std::optional<absl::string_view> src_cluster_name =
      GetXlaClusterForNode(*edge.src());
  std::optional<absl::string_view> dst_cluster_name =
      GetXlaClusterForNode(*edge.dst());
  return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
}

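// Returns true if nodes assigned to `device_type` are always compiled with
// XLA, i.e. the registered autoclustering policy for the device is kAlways.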
bool IsMustCompileDevice(const DeviceType& device_type) {
  const XlaOpRegistry::DeviceRegistration* registration;
  if (XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
    return registration->autoclustering_policy ==
           XlaOpRegistry::AutoclusteringPolicy::kAlways;
  }

  return false;
}

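// Sets `*must_compile` to true if `n` has to be compiled with XLA: either its
// device's autoclustering policy is kAlways, or there is no TensorFlow kernel
// registered for it on that device.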
Status MustCompileNode(const Node* n, bool* must_compile) {
  DeviceType device_type("");
  TF_RETURN_IF_ERROR(
      DeviceNameToDeviceType(n->assigned_device_name(), &device_type));

  if (IsMustCompileDevice(device_type)) {
    *must_compile = true;
    return OkStatus();
  }

  // We must compile `n` if it does not have a TensorFlow kernel.
  *must_compile = !FindKernelDef(device_type, n->def(), nullptr, nullptr).ok();
  return OkStatus();
}

// Declusters nodes to reduce the number of times we think we need to recompile
// a TensorFlow graph.
//
// Abstractly, if we have a cluster of this form:
//
//   x0 = arg0
//   x1 = arg1
//   ...
//   shape = f(x0, x1, ...)
//   result = Reshape(input=<something>, new_shape=shape)
//
// then pulling `f` out of the cluster may reduce the number of compilations
// and will never increase the number of compilations.
//
// We may reduce the number of compilations if f is many to one. For instance
// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
// compilations if f is in the cluster but only one compilation if f is outside
// the cluster.
//
// Declustering f will increase the number of compilations only if f is a
// one-to-many "function" i.e. isn't a function at all. RNG is one possible
// example, depending on how we look at it. But we never create clusters where
// such f's would be marked as must-be-constant.
//
// We assume here that the extra repeated (repeated compared to a clustered f
// where it will always be constant folded) host-side computation of f does not
// regress performance in any significant manner. We will have to revisit this
// algorithm with a more complex cost model if this assumption turns out to be
// incorrect.
Status PartiallyDeclusterGraph(Graph* graph,
                               const FunctionLibraryDefinition* flib_def,
                               Env* env) {
  std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
  OptimizerOptions opts;
  auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
      nullptr, env, /*config=*/nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts);
  FunctionLibraryRuntime* lib_runtime =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*graph, nullptr,
                                            &compile_time_const_nodes,
                                            lib_runtime, IsIntraClusterEdge));

  std::vector<Node*> rpo;
  GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
                      /*edge_filter=*/NotBackedge);
  for (Node* n : rpo) {
    if (!compile_time_const_nodes[n->id()]) {
      continue;
    }

    absl::string_view cluster_name = *GetXlaClusterForNode(*n);
    bool node_on_cluster_edge =
        absl::c_all_of(n->in_edges(), [&](const Edge* e) {
          std::optional<absl::string_view> incoming_cluster =
              GetXlaClusterForNode(*e->src());
          return !incoming_cluster || *incoming_cluster != cluster_name;
        });

    // We don't want to decluster F in a graph like
    //
    //   Input -> OP -> Shape -> F -> Reshape
    //
    // Doing so will break up the cluster. Even if we were okay with breaking
    // up the cluster we will at least have to relabel the two clusters to
    // have different cluster names.
    //
    // We may want to revisit this in the future: we may have cases where OP
    // is a small computation that does not benefit from XLA while XLA can
    // optimize everything that follows the Reshape. In these cases it may be
    // wise to remove Input, OP, Shape and F from the cluster, if F is a
    // many-to-one function.
    //
    // Note that we do the right thing for graphs like:
    //
    //   Input -> F0 -> F1 -> Reshape
    //
    // Since we iterate in RPO, we'll first encounter F0, decluster it, then
    // encounter F1, decluster it and so on.
    if (node_on_cluster_edge) {
      bool must_compile_node;
      TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
      if (!must_compile_node) {
        if (n->IsConstant()) {
          // We must decluster Const nodes that have an input control edge from
          // a different device, because this node may be part of the
          // co-ordination of while loops between devices.
          for (auto it : n->in_edges()) {
            if (!it->src()->assigned_device_name().empty() &&
                it->src()->assigned_device_name() !=
                    n->assigned_device_name()) {
              VLOG(3) << "Declustering Const with cross-device control input "
                      << n->name();
              RemoveFromXlaCluster(n);
              break;
            }
          }
        } else {
          VLOG(3) << "Declustering must-be-constant node " << n->name();
          RemoveFromXlaCluster(n);
        }
      }
    }
  }

  return OkStatus();
}
}  // namespace reduce_recompilation

namespace decluster_root_shape_consumers {

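// Removes from their clusters shape-consuming ops (as identified by
// IsShapeConsumerOp) none of whose inputs come from the same cluster. Because
// we visit nodes in reverse post-order, declustering one such op can expose
// its shape-consuming users as roots, so a whole chain of them can be peeled
// off the cluster in a single pass.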
Status PartiallyDeclusterGraph(Graph* graph) {
  std::vector<Node*> reverse_post_order;
  GetReversePostOrder(*graph, &reverse_post_order,
                      /*stable_comparator=*/NodeComparatorName(),
                      /*edge_filter=*/NotBackedge);

  for (Node* n : reverse_post_order) {
    if (!IsShapeConsumerOp(*n)) {
      continue;
    }

    std::optional<absl::string_view> cluster = GetXlaClusterForNode(*n);
    if (!cluster.has_value()) {
      continue;
    }

    auto input_belongs_to_same_cluster = [&](const Edge* e) {
      return cluster == GetXlaClusterForNode(*e->src());
    };

    if (absl::c_any_of(n->in_edges(), input_belongs_to_same_cluster)) {
      continue;
    }

    VLOG(2) << "Declustering " << n->name()
            << " because it is a root shape consumer";
    RemoveFromXlaCluster(n);
  }
  return OkStatus();
}
}  // namespace decluster_root_shape_consumers
}  // namespace

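// Runs the three declustering stages in order: first clone nodes out of
// clusters to avoid device-to-host copies, then decluster must-be-constant
// nodes on cluster edges to reduce recompilations, and finally decluster root
// shape consumers.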
Status PartiallyDeclusterPass::Run(
    const GraphOptimizationPassOptions& options) {
  // NB! In this pass we assume the only XLA-auto-clusterable operations that
  // may have side effects are resource variable operations, so we don't
  // cluster those. The pass will have to be updated if this assumption
  // becomes invalid.

  Graph* graph = options.graph->get();

  TF_RETURN_IF_ERROR(
      reduce_device_to_host_copies::PartiallyDeclusterGraph(graph));
  if (options.flib_def == nullptr) {
    return errors::InvalidArgument(
        "GraphOptimizationPassOptions::flib_def must be set for "
        "PartiallyDeclusterPass.");
  }
  if (options.session_options == nullptr ||
      options.session_options->env == nullptr) {
    return errors::InvalidArgument(
        "GraphOptimizationPassOptions::session_options::env must be set for "
        "PartiallyDeclusterPass.");
  }
  TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(
      graph, options.flib_def, options.session_options->env));

  TF_RETURN_IF_ERROR(
      decluster_root_shape_consumers::PartiallyDeclusterGraph(graph));

  return OkStatus();
}
}  // namespace tensorflow
436