xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/partially_decluster_pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/jit/device_util.h"
22 #include "tensorflow/compiler/jit/xla_cluster_util.h"
23 #include "tensorflow/compiler/tf2xla/const_analysis.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
25 #include "tensorflow/core/common_runtime/function.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/memory_types.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/graph/graph_node_util.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/public/version.h"
33 
34 namespace tensorflow {
35 namespace {
36 
NotBackedge(const Edge & edge)37 bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
38 
39 namespace reduce_device_to_host_copies {
FindNodesToDecluster(const Graph & graph,absl::flat_hash_set<Node * > * result,absl::Span<Node * const> post_order)40 Status FindNodesToDecluster(const Graph& graph,
41                             absl::flat_hash_set<Node*>* result,
42                             absl::Span<Node* const> post_order) {
43   // Find nodes that have at least one user outside their cluster that expects
44   // hostmem output.  These nodes should be cloned to outside the cluster to
45   // avoid the device-host copy we'd otherwise need.
46 
47   MemoryTypeVector input_mtypes, output_mtypes;
48 
49   for (Node* n : post_order) {
50     std::optional<absl::string_view> from_cluster = GetXlaClusterForNode(*n);
51     if (!from_cluster) {
52       continue;
53     }
54 
55     // Assume the benefit of not outputting a larger tensor outweighs the
56     // benefit of this check.
57     // TODO(tpopp): Only apply this if the value being consumed is not output
58     // from the cluster to another consumer.
59     // TODO(tpopp): See if XlaRun can be modified to avoid this issue
60     // completely.
61     if (IsShapeConsumerOp(*n)) {
62       continue;
63     }
64     // We assume the only XLA-auto-clusterable operations with side effects are
65     // resource variable updates.  We can't execute these twice.
66     if (HasResourceInputOrOutput(*n)) {
67       continue;
68     }
69 
70     DeviceType device_type("");
71     TF_RETURN_IF_ERROR(
72         DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
73     TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
74                                           n->def(), &input_mtypes,
75                                           &output_mtypes));
76     for (const Edge* e : n->out_edges()) {
77       Node* dst = e->dst();
78 
79       if (e->IsControlEdge()) {
80         continue;
81       }
82 
83       bool edge_incurs_extra_device_to_host_copy;
84       if (output_mtypes[e->src_output()] == DEVICE_MEMORY) {
85         // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then
86         // keep the node clustered -- XLA will also produce the output in device
87         // memory and we will get some benefit from clustering.
88         edge_incurs_extra_device_to_host_copy = false;
89       } else {
90         MemoryTypeVector dst_input_mtypes, dst_output_mtypes;
91         DeviceType dst_device_type("");
92         TF_RETURN_IF_ERROR(DeviceNameToDeviceType(dst->assigned_device_name(),
93                                                   &dst_device_type));
94         TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
95                                               dst->def(), &dst_input_mtypes,
96                                               &dst_output_mtypes));
97         edge_incurs_extra_device_to_host_copy =
98             dst_input_mtypes[e->dst_input()] == HOST_MEMORY;
99       }
100 
101       if (!edge_incurs_extra_device_to_host_copy) {
102         continue;
103       }
104 
105       // Check if `dst` is in a different cluster, unclustered, or about to be
106       // partially declustered (here we rely on the post-order traversal order).
107       // If yes, decluster `n` to avoid the device-to-host memcpy.
108       std::optional<absl::string_view> dst_cluster =
109           result->count(dst) ? std::nullopt : GetXlaClusterForNode(*dst);
110       if (from_cluster != dst_cluster) {
111         CHECK(result->insert(n).second);
112         break;
113       }
114     }
115   }
116   return OkStatus();
117 }
118 
PartiallyDeclusterNode(Graph * graph,Node * n)119 Status PartiallyDeclusterNode(Graph* graph, Node* n) {
120   absl::string_view cluster_name = *GetXlaClusterForNode(*n);
121   absl::InlinedVector<const Edge*, 6> out_edges_to_clone;
122   for (const Edge* out_edge : n->out_edges()) {
123     if (out_edge->IsControlEdge()) {
124       continue;
125     }
126 
127     Node* dst = out_edge->dst();
128     std::optional<absl::string_view> dst_cluster_name =
129         GetXlaClusterForNode(*dst);
130     if (dst_cluster_name != cluster_name) {
131       out_edges_to_clone.push_back(out_edge);
132     }
133   }
134 
135   CHECK(!out_edges_to_clone.empty()) << n->DebugString();
136 
137   NodeDef ndef = n->def();
138   ndef.set_name(absl::StrCat(n->name(), "/declustered"));
139   MergeDebugInfo(NodeDebugInfo(n->def()), &ndef);
140   RemoveFromXlaCluster(&ndef);
141   TF_ASSIGN_OR_RETURN(Node * cloned_node, graph->AddNode(ndef));
142   cloned_node->set_assigned_device_name(n->assigned_device_name());
143 
144   for (const Edge* in_edge : n->in_edges()) {
145     graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node,
146                    in_edge->dst_input());
147   }
148 
149   for (const Edge* out_edge_to_clone : out_edges_to_clone) {
150     graph->AddEdge(cloned_node, out_edge_to_clone->src_output(),
151                    out_edge_to_clone->dst(), out_edge_to_clone->dst_input());
152     graph->RemoveEdge(out_edge_to_clone);
153   }
154 
155   if (n->out_edges().empty()) {
156     graph->RemoveNode(n);
157   }
158 
159   return OkStatus();
160 }
161 
162 // Clones nodes to outside their cluster to avoid device-to-host copies.  For
163 // instance, converts this:
164 //
165 //         .....
166 //           |
167 //           v
168 //      A_Clustered ====> C_Unclustered
169 //           |
170 //           v
171 //      B_Clustered
172 //
173 // to:
174 //
175 //         .....
176 //          | |
177 //          | +-------------+
178 //          |               |
179 //          v               v
180 //      A_Clustered   A_Unclustered ====> C_Unclustered
181 //           |
182 //           v
183 //      B_Clustered
184 //
185 // where the ===> arrow has a hostmem source and destination and would entail a
186 // device to host copy if the source and destination were not in the same XLA
187 // cluster.
PartiallyDeclusterGraph(Graph * graph)188 Status PartiallyDeclusterGraph(Graph* graph) {
189   // When deciding whether to decluster a particular node, we base our decision
190   // on if we've decided that some of its consumers have to be declustered too.
191   // Iterating the graph in post-order guarantees that consumers have been
192   // visited before producers.
193   std::vector<Node*> post_order;
194   GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
195                /*edge_filter=*/NotBackedge);
196 
197   absl::flat_hash_set<Node*> nodes_to_partially_decluster;
198   TF_RETURN_IF_ERROR(
199       FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
200 
201   if (VLOG_IS_ON(3)) {
202     for (Node* n : post_order) {
203       if (nodes_to_partially_decluster.count(n)) {
204         VLOG(3) << n->DebugString();
205       }
206     }
207   }
208 
209   for (Node* n : post_order) {
210     if (nodes_to_partially_decluster.count(n)) {
211       TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n));
212     }
213   }
214 
215   // Recompute post order since PartiallyDeclusterNode may have deleted nodes.
216   post_order.clear();
217   GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
218                /*edge_filter=*/NotBackedge);
219   nodes_to_partially_decluster.clear();
220   TF_RETURN_IF_ERROR(
221       FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
222   CHECK(nodes_to_partially_decluster.empty());
223 
224   return OkStatus();
225 }
226 }  // namespace reduce_device_to_host_copies
227 
228 namespace reduce_recompilation {
IsIntraClusterEdge(const Edge & edge)229 bool IsIntraClusterEdge(const Edge& edge) {
230   std::optional<absl::string_view> src_cluster_name =
231       GetXlaClusterForNode(*edge.src());
232   std::optional<absl::string_view> dst_cluster_name =
233       GetXlaClusterForNode(*edge.dst());
234   return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
235 }
236 
IsMustCompileDevice(const DeviceType & device_type)237 bool IsMustCompileDevice(const DeviceType& device_type) {
238   const XlaOpRegistry::DeviceRegistration* registration;
239   if (XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
240     return registration->autoclustering_policy ==
241            XlaOpRegistry::AutoclusteringPolicy::kAlways;
242   }
243 
244   return false;
245 }
246 
MustCompileNode(const Node * n,bool * must_compile)247 Status MustCompileNode(const Node* n, bool* must_compile) {
248   DeviceType device_type("");
249   TF_RETURN_IF_ERROR(
250       DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
251 
252   if (IsMustCompileDevice(device_type)) {
253     *must_compile = true;
254     return OkStatus();
255   }
256 
257   // We must compile `n` if it does not have a TensorFlow kernel.
258   *must_compile = !FindKernelDef(device_type, n->def(), nullptr, nullptr).ok();
259   return OkStatus();
260 }
261 
262 // Declusters nodes to reduce the number of times we think we need to recompile
263 // a TensorFlow graph.
264 //
265 // Abstractly, if we have a cluster of this form:
266 //
267 //   x0 = arg0
268 //   x1 = arg1
269 //     ...
270 //   shape = f(x0, x1, ...)
271 //   result = Reshape(input=<something>, new_shape=shape)
272 //
273 // then pulling `f` out of the cluster may reduce the number of compilations and
274 // will never increase the number of compilations.
275 //
276 // We may reduce the number of compilations if f is many to one.  For instance
277 // if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
278 // compilations if f is in the cluster but only one compilation if f is outside
279 // the cluster.
280 //
281 // Declustering f will increase the number of compilations only if f is a
282 // one-to-many "function" i.e. isn't a function at all.  RNG is one possible
283 // example, depending on how we look at it.  But we never create clusters where
284 // such f's would be marked as must-be-constant.
285 //
286 // We assume here that the extra repeated (repeated compared to a clustered f
287 // where it will always be constant folded) host-side computation of f does not
288 // regress performance in any significant manner.  We will have to revisit this
289 // algorithm with a more complex cost model if this assumption turns out to be
290 // incorrect.
PartiallyDeclusterGraph(Graph * graph,const FunctionLibraryDefinition * flib_def,Env * env)291 Status PartiallyDeclusterGraph(Graph* graph,
292                                const FunctionLibraryDefinition* flib_def,
293                                Env* env) {
294   std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
295   OptimizerOptions opts;
296   auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
297       nullptr, env, /*config=*/nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts);
298   FunctionLibraryRuntime* lib_runtime =
299       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
300   TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*graph, nullptr,
301                                             &compile_time_const_nodes,
302                                             lib_runtime, IsIntraClusterEdge));
303 
304   std::vector<Node*> rpo;
305   GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
306                       /*edge_filter=*/NotBackedge);
307   for (Node* n : rpo) {
308     if (!compile_time_const_nodes[n->id()]) {
309       continue;
310     }
311 
312     absl::string_view cluster_name = *GetXlaClusterForNode(*n);
313     bool node_on_cluster_edge =
314         absl::c_all_of(n->in_edges(), [&](const Edge* e) {
315           std::optional<absl::string_view> incoming_cluster =
316               GetXlaClusterForNode(*e->src());
317           return !incoming_cluster || *incoming_cluster != cluster_name;
318         });
319 
320     // We don't want to decluster F in a graph like
321     //
322     //   Input -> OP -> Shape -> F -> Reshape
323     //
324     // Doing so will break up the cluster.  Even if we were okay with breaking
325     // up the cluster we will at least have to relabel the two clusters to have
326     // different cluster names.
327     //
328     // We may want to revisit this in the future: we may have cases where OP is
329     // a small computation that does not benefit from XLA while XLA can optimize
330     // everything that follows the Reshape.  In these cases it may be wise to
331     // remove Input, OP, Shape and F from the cluster, if F is a many-to-one
332     // function.
333     //
334     // Note that we do the right thing for graphs like:
335     //
336     //   Input -> F0 -> F1 -> Reshape
337     //
338     // Since we iterate in RPO, we'll first encounter F0, decluster it, then
339     // encounter F1, decluster it and so on.
340     if (node_on_cluster_edge) {
341       bool must_compile_node;
342       TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
343       if (!must_compile_node) {
344         if (n->IsConstant()) {
345           // We must decluster Const nodes that have an input control edge from
346           // a different device, because this node may be part of the
347           // co-ordination of while loops between devices.
348           for (auto it : n->in_edges()) {
349             if (!it->src()->assigned_device_name().empty() &&
350                 it->src()->assigned_device_name() !=
351                     n->assigned_device_name()) {
352               VLOG(3) << "Declustering Const with cross-device control input "
353                       << n->name();
354               RemoveFromXlaCluster(n);
355               break;
356             }
357           }
358         } else {
359           VLOG(3) << "Declustering must-be-constant node " << n->name();
360           RemoveFromXlaCluster(n);
361         }
362       }
363     }
364   }
365 
366   return OkStatus();
367 }
368 }  // namespace reduce_recompilation
369 
370 namespace decluster_root_shape_consumers {
371 
PartiallyDeclusterGraph(Graph * graph)372 Status PartiallyDeclusterGraph(Graph* graph) {
373   std::vector<Node*> reverse_post_order;
374   GetReversePostOrder(*graph, &reverse_post_order,
375                       /*stable_comparator=*/NodeComparatorName(),
376                       /*edge_filter=*/NotBackedge);
377 
378   for (Node* n : reverse_post_order) {
379     if (!IsShapeConsumerOp(*n)) {
380       continue;
381     }
382 
383     std::optional<absl::string_view> cluster = GetXlaClusterForNode(*n);
384     if (!cluster.has_value()) {
385       continue;
386     }
387 
388     auto input_belongs_to_same_cluster = [&](const Edge* e) {
389       return cluster == GetXlaClusterForNode(*e->src());
390     };
391 
392     if (absl::c_any_of(n->in_edges(), input_belongs_to_same_cluster)) {
393       continue;
394     }
395 
396     VLOG(2) << "Declustering " << n->name()
397             << " because it is a root shape consumer";
398     RemoveFromXlaCluster(n);
399   }
400   return OkStatus();
401 }
402 }  // namespace decluster_root_shape_consumers
403 }  // namespace
404 
Run(const GraphOptimizationPassOptions & options)405 Status PartiallyDeclusterPass::Run(
406     const GraphOptimizationPassOptions& options) {
407   // NB!  In this pass we assume the only XLA-auto-clusterable operations that
408   // may have side effects are resource variable operations so we don't cluster
409   // those.  The pass will have to be updated if this assumption becomes
410   // invalid.
411 
412   Graph* graph = options.graph->get();
413 
414   TF_RETURN_IF_ERROR(
415       reduce_device_to_host_copies::PartiallyDeclusterGraph(graph));
416   if (options.flib_def == nullptr) {
417     return errors::InvalidArgument(
418         "GraphOptimizationPassOptions::flib_def must be set for "
419         "PartiallyDeclusterPass.");
420   }
421   if (options.session_options == nullptr ||
422       options.session_options->env == nullptr) {
423     return errors::InvalidArgument(
424         "GraphOptimizationPassOptions::session_options::env must be set for "
425         "PartiallyDeclusterPass.");
426   }
427   TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(
428       graph, options.flib_def, options.session_options->env));
429 
430   TF_RETURN_IF_ERROR(
431       decluster_root_shape_consumers::PartiallyDeclusterGraph(graph));
432 
433   return OkStatus();
434 }
435 }  // namespace tensorflow
436