1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
17
18 #include <algorithm>
19 #include <atomic>
20 #include <deque>
21 #include <limits>
22 #include <unordered_map>
23 #include <unordered_set>
24
25 #include "absl/base/call_once.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/compiler/jit/compilability_check_util.h"
30 #include "tensorflow/compiler/jit/deadness_analysis.h"
31 #include "tensorflow/compiler/jit/defs.h"
32 #include "tensorflow/compiler/jit/device_util.h"
33 #include "tensorflow/compiler/jit/flags.h"
34 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
35 #include "tensorflow/compiler/jit/xla_cluster_util.h"
36 #include "tensorflow/compiler/tf2xla/const_analysis.h"
37 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
38 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
39 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
40 #include "tensorflow/compiler/xla/statusor.h"
41 #include "tensorflow/compiler/xla/union_find.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/common_runtime/function.h"
44 #include "tensorflow/core/common_runtime/graph_constructor.h"
45 #include "tensorflow/core/framework/bounds_check.h"
46 #include "tensorflow/core/framework/graph_def_util.h"
47 #include "tensorflow/core/framework/memory_types.h"
48 #include "tensorflow/core/framework/node_def.pb.h"
49 #include "tensorflow/core/framework/op_kernel.h"
50 #include "tensorflow/core/framework/tensor.pb.h"
51 #include "tensorflow/core/framework/types.h"
52 #include "tensorflow/core/graph/algorithm.h"
53 #include "tensorflow/core/graph/control_flow.h"
54 #include "tensorflow/core/lib/gtl/cleanup.h"
55 #include "tensorflow/core/lib/gtl/flatmap.h"
56 #include "tensorflow/core/lib/strings/stringprintf.h"
57 #include "tensorflow/core/platform/errors.h"
58 #include "tensorflow/core/platform/mutex.h"
59 #include "tensorflow/core/platform/statusor.h"
60 #include "tensorflow/core/platform/types.h"
61 #include "tensorflow/core/public/version.h"
62 #include "tensorflow/core/util/dump_graph.h"
63
64 namespace tensorflow {
65
66 namespace {
67 using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
68 using jit::DeviceId;
69 using jit::DeviceSet;
70
71 // The clusters we create here are eventually lowered into an
72 // _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the
73 // PartitionedCall op to execute the cluster in the regular graph executor if
74 // need be. PartitionedCall, however, reruns the entire TF graph optimization
75 // pipeline over the cluster which includes this mark for compilation pass. To
76 // avoid endlessly recursing we tag nodes that we've already visited with this
77 // attribute so that we can bail out if we see them a second time.
78 //
79 // TODO(sanjoy): This method is not robust since it is possible that the
80 // optimizations run by PartitionedCall can mutate the cluster arbitrarily,
81 // dropping the kXlaAlreadyClustered attributes from all nodes in the process.
82 // The correct fix is to use the ConfigProto to pass in some sort of flag into
83 // the PartitionedCall kernel that tells it to not rerun auto-clustering on the
84 // cluster.
85 const char* kXlaAlreadyClustered = "_XlaAlreadyClustered";
86
87 class MarkForCompilationPassImpl {
88 public:
89 struct DebugOptions {
90 // If true, do not respect the results of deadness analysis.
91 bool ignore_deadness_checks;
92
93 // If true, do not do safety checks to preserve TensorFlow's resource
94 // variable concurrency semantics.
95 bool ignore_resource_variable_checks;
96
97 // If true, do not respect the _XlaCompile=false attribute.
98 bool ignore_xla_compile_attr;
99
100 // If true, compute the cluster name in a deterministic way so that it is
101 // stable from run to run.
102 bool deterministic_cluster_names;
103
104 int max_cluster_size;
105 int min_cluster_size;
106
107 // Compiler fuel for the auto-clustering algorithm.
108 //
109 // We decrement this value by one every time we choose a compilation
110 // candidate and we stop clustering when it hits zero. This means the
111 // initial value for this variable (via --tf_xla_clustering_fuel=N)
112 // effectively acts as a "cap" for how much we cluster and we can bisect
113 // over this initial value to discover clustering decisions that cause a
114 // miscompile or a performance regression.
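// An illustrative (hypothetical) bisection: if a miscompile reproduces with
// --tf_xla_clustering_fuel=64 but not with --tf_xla_clustering_fuel=32, the
// offending decision is among clustering choices 33..64, and the interval can
// be halved again from there.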
115 std::atomic<int64_t>* fuel;
116
117 bool dump_graphs;
118 };
119
120 MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph,
121 FunctionLibraryDefinition* flib_def, Env* env,
122 OptimizerOptions::GlobalJitLevel global_jit_level,
123 bool cpu_global_jit)
124 : debug_options_(debug_options),
125 graph_(graph),
126 graph_fingerprint_(0),
127 flib_def_(flib_def),
128 env_(env),
129 global_jit_level_(global_jit_level),
130 cpu_global_jit_(cpu_global_jit) {}
131
132 Status Run();
133
134 private:
135 // Represents a "cluster" or a connected subgraph of a TensorFlow graph.
136 class Cluster {
137 public:
138 // Constructs a trivial cluster representing a single TF node.
139 Cluster(int tf_graph_node_id, int effective_cluster_size,
140 bool has_functional_control_flow, DeviceSet devices,
141 std::optional<DeviceId> resource_op_device,
142 std::optional<int> resource_var_operation_node_id,
143 std::optional<DeadnessPredicate> deadness_predicate,
144 bool is_xla_compile_attr_true, std::optional<string> xla_scope)
145 : cycles_graph_node_id_(tf_graph_node_id),
146 effective_cluster_size_(effective_cluster_size),
147 has_functional_control_flow_(has_functional_control_flow),
148 devices_(std::move(devices)),
149 resource_op_device_(resource_op_device),
150 deadness_predicate_(deadness_predicate),
151 is_xla_compile_attr_true_(is_xla_compile_attr_true),
152 xla_scope_(std::move(xla_scope)) {
153 if (resource_var_operation_node_id.has_value()) {
154 resource_var_operation_node_ids_.push_back(
155 *resource_var_operation_node_id);
156 }
157 }
158
159 // Merges `other` into this cluster, and clears `other`. This method is
160 // closely tied with the implementation of `MarkForCompilationPassImpl`.
161 void Merge(Cluster* other);
162
163 // If this is a trivial cluster containing only one node then return the ID
164 // of that node. May not be called otherwise.
165 int GetIdOfOnlyNode() const {
166 DCHECK_EQ(cluster_size(), 1);
167 return cycles_graph_node_id();
168 }
169
170 // The number of TF nodes in this cluster.
171 int cluster_size() const { return cluster_size_; }
172
173 // The ID of the cluster as represented in `cycles_graph_`.
174 int cycles_graph_node_id() const { return cycles_graph_node_id_; }
175
176 // Sets the ID of the cluster as represented in `cycles_graph_`.
177 void set_cycles_graph_node_id(int cycles_graph_node_id) {
178 cycles_graph_node_id_ = cycles_graph_node_id;
179 }
180
181 // The size of the cluster excluding constant and identity nodes.
182 int effective_cluster_size() const { return effective_cluster_size_; }
183
184 // True if the cluster has functional control flow like `If` and `While`.
185 bool has_functional_control_flow() const {
186 return has_functional_control_flow_;
187 }
188
189 // The set of devices nodes in the cluster are placed on.
190 const DeviceSet& devices() const { return devices_; }
191
192 // If the cluster has a resource operation then the device the resource
193 // operation is placed on. A cluster may have resource ops placed only on a
194 // single device.
195 const std::optional<DeviceId>& resource_op_device() const {
196 return resource_op_device_;
197 }
198
199 // If not nullopt then this is a predicate that is true iff the cluster is alive.
200 // Otherwise the user has (unsafely) disabled deadness analysis. If this is
201 // unset on a single Cluster instance then it is unset on all Cluster
202 // instances.
203 const std::optional<DeadnessPredicate>& deadness_predicate() const {
204 return deadness_predicate_;
205 }
206
207 // If true then the cluster has a XlaCompile=true attribute on one of its
208 // nodes.
209 bool is_xla_compile_attr_true() const { return is_xla_compile_attr_true_; }
210
211 // If not nullopt then all nodes in the cluster either do not have the
212 // XlaScope attribute set or have it set to the value returned.
213 const std::optional<string>& xla_scope() const { return xla_scope_; }
214
215 // Returns the TF graph node IDs for the resource variable operations in
216 // this cluster.
217 absl::Span<const int> resource_var_operation_node_ids() const {
218 return resource_var_operation_node_ids_;
219 }
220
221 string DebugString(const Graph& graph) const {
222 Node* node = graph.FindNodeId(cycles_graph_node_id());
223 if (!node) {
224 // This should never happen but we try to be resilient because this is a
225 // debugging aid.
226 return absl::StrCat("NULL NODE IN #", cycles_graph_node_id());
227 }
228
229 if (cluster_size() == 1) {
230 return absl::StrCat("<", node->name(), " #", cycles_graph_node_id(),
231 ">");
232 }
233
234 return absl::StrCat("<", node->name(), " + ", cluster_size() - 1,
235 " others #", cycles_graph_node_id(), ">");
236 }
237
238 private:
239 int cluster_size_ = 1;
240 int cycles_graph_node_id_;
241 int effective_cluster_size_;
242 bool has_functional_control_flow_;
243 DeviceSet devices_;
244 std::optional<DeviceId> resource_op_device_;
245 std::optional<DeadnessPredicate> deadness_predicate_;
246 bool is_xla_compile_attr_true_;
247 std::optional<string> xla_scope_;
248 std::vector<int> resource_var_operation_node_ids_;
249
250 TF_DISALLOW_COPY_AND_ASSIGN(Cluster);
251 };
252
253 // If `cluster` has only a single node then returns that, otherwise returns
254 // nullptr.
255 Node* GetOnlyNodeIn(const Cluster& cluster);
256
257 // Returns true if `cluster` is a trivial cluster containing a "sink like"
258 // node -- a NoOp node that only the Sink node control depends on.
259 bool IsSinkLike(const Cluster& cluster);
260
261 // Returns true if `cluster` looks like an "i++" operation on an integer
262 // scalar resource variable.
263 bool IsScalarIntegerResourceOperation(const Cluster& cluster);
264
265 // ---------------------------------------------------------------------------
266 // The pass proceeds in five steps, out of which `RunEdgeContractionLoop` and
267 // `CreateClusters` do most of the heavy lifting.
268
269 // Initializes some internal data structures.
270 //
271 // If this returns false then Initialize exited early (either because there is
272 // nothing to do or we saw a graph that we can't handle) and not all the
273 // fields in this MarkForCompilationPassImpl instance are set up.
274 StatusOr<bool> Initialize();
275
276 // Runs through the entire cluster graph in post-order and calls `fn(from,
277 // to)` on each edge. `fn(from, to)` is expected to return true if it was
278 // able to contract `from`->`to`.
279 //
280 // Returns true if `fn` returned true for any edge.
281 template <typename FnTy>
282 StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn);
283
284 // Contracts as many edges as possible to create XLA clusters. After this
285 // finishes the clustering decisions made are implicitly stored in
286 // `clusters_`.
287 Status RunEdgeContractionLoop();
288
289 // "Fixes up" clusters by removing some nodes.
290 //
291 // Autoclustering can sometimes be overeager. For example, clustering large
292 // constants (or large broadcasts of constants) can increase the live range
293 // of those constants, and increase overall memory usage.
294 //
295 // This function removes "obviously bad" cases like these.
296 Status DeclusterNodes();
297
298 // Manifests the clustering decisions into the TF graph by tagging nodes with
299 // an `_XlaCluster` attribute. Also, some basic filtering logic, like
300 // tf_xla_min_cluster_size, is applied here.
301 Status CreateClusters();
302
303 Status DumpDebugInfo();
304
305 bool IsCompilationCandidate(Node* n) const {
306 return compilation_candidates_.find(n) != compilation_candidates_.end();
307 }
308
309 // Tries to contract the edge from cluster `from` to cluster `to`. Returns
310 // true if successful.
311 StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);
312
313 // Nodes that XLA can compile are put in `compilation_candidates_`.
314 Status FindCompilationCandidates();
315
316 bool CompilationDisallowedByXlaCompileAttr(Node* node);
317
318 // Populates `clusters_`.
319 Status BuildInitialClusterSet();
320
321 StatusOr<bool> ShouldCompileClusterImpl(const Cluster& cluster);
322
323 StatusOr<bool> ShouldCompileCluster(const Cluster& cluster);
324
325 StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
326 const Cluster& from, const Cluster& to);
327
328 // Returns true if the devices in `cluster_a` and `cluster_b` are compatible
329 // and therefore not a hindrance for combining the two clusters into a larger
330 // cluster.
331 StatusOr<bool> AreDevicesCompatible(const Cluster& cluster_a,
332 const Cluster& cluster_b);
333
334 void DumpPostClusteringGraphs();
335 void VLogClusteringSummary();
336
337 Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
338 bool has_functional_control_flow,
339 const DeviceSet& device_set,
340 std::optional<DeviceId> resource_op_device,
341 std::optional<int> resource_var_operation_node_id,
342 std::optional<DeadnessPredicate> deadness_predicate,
343 bool is_xla_compile_attr_true,
344 std::optional<string> xla_scope) {
345 cluster_storage_.push_back(std::make_unique<Cluster>(
346 cycles_graph_node_id, effective_cluster_size,
347 has_functional_control_flow, device_set, resource_op_device,
348 resource_var_operation_node_id, deadness_predicate,
349 is_xla_compile_attr_true, xla_scope));
350 return cluster_storage_.back().get();
351 }
352
353 std::optional<string> GetXlaScope(Node* n);
354
355 // Returns the cluster for node `n`. If two nodes, N1 and N2, are placed in
356 // the same cluster by the clustering algorithm then this function will return
357 // the same Cluster instance for N1 and N2.
358 //
359 // Returns nullptr if `n` is not a compilation candidate.
360 Cluster* GetClusterForNode(Node* n) {
361 return cluster_for_node_[n->id()].Get();
362 }
363
364 // Returns the cluster for a node in `cycles_graph_`. This uses the same
365 // underlying map because of how we set things up, but we can do an additional
366 // CHECK in this accessor.
367 //
368 // Returns nullptr if `node_id` is not a compilation candidate.
369 Cluster* GetClusterForCyclesGraphNode(int node_id) {
370 // We have to check `graph_->FindNodeId(node) == nullptr` because we add all
371 // nodes in [0, graph_->num_node_ids()) to the cycle detection graph but the
372 // TF graph may be missing some node ids.
373 if (node_id >= graph_->num_node_ids() ||
374 graph_->FindNodeId(node_id) == nullptr) {
375 return nullptr;
376 }
377 Cluster* cluster = cluster_for_node_[node_id].Get();
378 if (cluster) {
379 DCHECK_EQ(cluster->cycles_graph_node_id(), node_id);
380 }
381 return cluster;
382 }
383
384 bool LogNotContractableAndReturnFalse(Cluster* from, Cluster* to,
385 absl::string_view reason);
386
387 // Finds a path in `cycles_graph_` from `from` to `to` that is not a direct
388 // edge from `from` to `to`.
389 //
390 // Tries to find a path that contains at least one unclusterable node.
391 std::vector<int> FindAlternatePathForDebugging(int from, int to);
392
393 // Returns a string representing `cycles_graph_node_id`. If the node is
394 // unclusterable (either it is a phantom "frame" node or is not a compilation
395 // candidate) then set `*found_unclustered` to true.
396 string DebugStringForCyclesGraphNode(int node_id, bool* found_unclustered);
397
398 // We could not contract the edge from `from` to `to`. Return a string
399 // describing an alternate path from `from` to `to` (besides the direct edge
400 // from `from` to `to`) which would have created a cycle had we contracted the
401 // edge.
402 //
403 // Tries (if possible) to find a path that contains at least one unclusterable
404 // node as it is surprising to the user if we print "A->B could not be
405 // contracted because of the path [P,Q,R]" where P, Q and R are all clusters
406 // since in that case a natural question is why we could not form a {A, P, Q,
407 // R, B} cluster.
408 string DescribePotentialCycle(int from, int to);
409
410 // Merge the clusters `cluster_from` and `cluster_to`. After this step the
411 // larger combined cluster is represented by `cluster_from`, but can have
412 // `cycles_graph_`'s ID of either `cluster_from` or `cluster_to` depending on
413 // which way requires fewer operations.
414 bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
415 int from = cluster_from->cycles_graph_node_id();
416 int to = cluster_to->cycles_graph_node_id();
417
418 auto optional_merged_node = cycles_graph_.ContractEdge(from, to);
419 if (!optional_merged_node.has_value()) {
420 VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_)
421 << " -> " << cluster_to->DebugString(*graph_)
422 << " because contracting the edge would create a cycle via "
423 << DescribePotentialCycle(from, to) << ".";
424 return false;
425 }
426
427 // Merge the clusters.
428 cluster_from->Merge(cluster_to);
429 // Update `cycles_graph_`'s ID.
430 cluster_from->set_cycles_graph_node_id(optional_merged_node.value());
431
432 // Merge the UnionFind<Cluster*>.
433 cluster_for_node_[from].Merge(&cluster_for_node_[to]);
434
435 return true;
436 }
437
438 string EdgeContractionFailureMsg(Cluster* from, Cluster* to,
439 absl::string_view reason) {
440 return absl::StrCat("Could not contract ", from->DebugString(*graph_),
441 " -> ", to->DebugString(*graph_), " because ", reason,
442 ".");
443 }
444
445 DebugOptions debug_options_;
446 Graph* graph_;
447 uint64 graph_fingerprint_;
448 FunctionLibraryDefinition* flib_def_;
449 Env* env_;
450 OptimizerOptions::GlobalJitLevel global_jit_level_;
451 bool cpu_global_jit_;
452 absl::flat_hash_map<const Cluster*, bool> should_compile_cluster_cache_;
453 jit::DeviceInfoCache device_info_cache_;
454
455 bool initialized_ = false;
456 bool edges_contracted_ = false;
457 bool clusters_created_ = false;
458
459 std::vector<std::unique_ptr<Cluster>> cluster_storage_;
460 std::vector<UnionFind<Cluster*>> cluster_for_node_;
461 absl::flat_hash_set<const Node*> declustered_nodes_;
462 GraphCycles cycles_graph_;
463 OrderedNodeSet compilation_candidates_;
464 std::unique_ptr<DeadnessAnalysis> deadness_analysis_;
465 int64_t iteration_count_ = 0;
466 absl::flat_hash_set<std::pair<int, int>> unsafe_resource_deps_;
467 };
468
469 std::vector<int> MarkForCompilationPassImpl::FindAlternatePathForDebugging(
470 int from, int to) {
471 std::vector<int> rpo = cycles_graph_.AllNodesInPostOrder();
472 absl::c_reverse(rpo);
473
474 // best_pred_for_node[n] contains a predecessor of `n` that has an
475 // unclusterable node in some path from `from` to itself.
476 // best_pred_for_node[n] is unpopulated for nodes that are not reachable from
477 // `from`. We build this table up inductively by traversing the cycles graph
478 // in RPO.
479 absl::flat_hash_map<int, int> best_pred_for_node;
480 best_pred_for_node[from] = -1;
481
482 int rpo_index = 0, current_rpo_node;
483 do {
484 current_rpo_node = rpo[rpo_index++];
485 std::optional<int> some_pred, preferred_pred;
486 for (int pred : cycles_graph_.Predecessors(current_rpo_node)) {
487 if (!best_pred_for_node.contains(pred)) {
488 continue;
489 }
490
491 // Ignore the from->to edge since we're trying to find an alternate path.
492 if (current_rpo_node == to && pred == from) {
493 continue;
494 }
495
496 some_pred = pred;
497 if (GetClusterForCyclesGraphNode(pred) == nullptr) {
498 preferred_pred = pred;
499 }
500 }
501
502 if (some_pred || preferred_pred) {
503 best_pred_for_node[current_rpo_node] =
504 preferred_pred.has_value() ? *preferred_pred : *some_pred;
505 }
506 } while (current_rpo_node != to);
507
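// Reconstruct the alternate path by walking best_pred_for_node backwards from
// `to`; the direct `from`->`to` edge was skipped above, so the result never
// degenerates into that edge.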
508 auto get_best_pred = [&](int n) {
509 auto it = best_pred_for_node.find(n);
510 CHECK(it != best_pred_for_node.end());
511 return it->second;
512 };
513
514 std::vector<int> path;
515 int current_path_node = get_best_pred(to);
516 while (current_path_node != from) {
517 path.push_back(current_path_node);
518 current_path_node = get_best_pred(current_path_node);
519 }
520
521 absl::c_reverse(path);
522 return path;
523 }
524
525 string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode(
526 int cycles_graph_node_id, bool* found_unclustered) {
527 Cluster* cluster = GetClusterForCyclesGraphNode(cycles_graph_node_id);
528 if (cluster) {
529 return cluster->DebugString(*graph_);
530 }
531
532 *found_unclustered = true;
533 if (cycles_graph_node_id >= graph_->num_node_ids()) {
534 return absl::StrCat("<oob #", cycles_graph_node_id, ">");
535 }
536
537 Node* node = graph_->FindNodeId(cycles_graph_node_id);
538 if (!node) {
539 return absl::StrCat("<bad #", cycles_graph_node_id, ">");
540 }
541
542 return node->name();
543 }
544
545 string MarkForCompilationPassImpl::DescribePotentialCycle(int from, int to) {
546 std::vector<string> path_str;
547 bool found_unclustered = false;
548 absl::c_transform(FindAlternatePathForDebugging(from, to),
549 std::back_inserter(path_str), [&](int node_id) {
550 return DebugStringForCyclesGraphNode(node_id,
551 &found_unclustered);
552 });
553 return absl::StrCat(!found_unclustered ? "(all clusters) " : "", "[",
554 absl::StrJoin(path_str, ","), "]");
555 }
556
557 void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
558 // We keep our own cycles_graph_node_id_ to mirror what GraphCycles does.
559
560 // Clearing out data structures in `other` is just a memory saving
561 // optimization and not needed for correctness.
562
563 cluster_size_ += other->cluster_size_;
564 effective_cluster_size_ += other->effective_cluster_size_;
565 has_functional_control_flow_ |= other->has_functional_control_flow_;
566
567 devices_.UnionWith(other->devices_);
568
569 DCHECK(!(resource_op_device_.has_value() &&
570 other->resource_op_device_.has_value()) ||
571 *resource_op_device_ == *other->resource_op_device_)
572 << "AreDevicesCompatible should have returned false otherwise!";
573
574 if (!resource_op_device_.has_value()) {
575 resource_op_device_ = other->resource_op_device_;
576 }
577
578 is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
579
580 if (!xla_scope_.has_value()) {
581 xla_scope_ = std::move(other->xla_scope_);
582 }
583
584 resource_var_operation_node_ids_.reserve(
585 resource_var_operation_node_ids_.size() +
586 other->resource_var_operation_node_ids_.size());
587 absl::c_copy(other->resource_var_operation_node_ids_,
588 std::back_inserter(resource_var_operation_node_ids_));
589 other->resource_var_operation_node_ids_.clear();
590 }
591
592 Status IgnoreResourceOpForSafetyAnalysis(
593 jit::DeviceInfoCache* device_info_cache, const Node& n, bool* ignore) {
594 // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
595 // ignore it during resource operation safety analysis. We need this hack
596 // for two reasons:
597 //
598 // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
599 // 2. We don't support live-out values of type DT_RESOURCE and live-in values
600 // of type DT_RESOURCE that are not resource variables.
601 //
602 // Together these imply we cannot let resource variable safety analysis
603 // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
604 // clusters: both of them will have to be clustered because of (1) and we
605 // won't be able to keep the edge between the two as neither the input to the
606 // second XLA cluster nor the output from the first XLA cluster are supported
607 // because of (2).
608 //
609 // TODO(b/113100872): This can be fixed if the TensorFlow representation for
610 // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
611 // (2) would no longer hold.
612
613 if (n.assigned_device_name().empty()) {
614 *ignore = false;
615 return OkStatus();
616 }
617
618 TF_ASSIGN_OR_RETURN(
619 const XlaOpRegistry::DeviceRegistration* registration,
620 device_info_cache->GetCompilationDevice(n.assigned_device_name()));
621
622 if (!registration) {
623 *ignore = true;
624 } else {
625 *ignore = registration->cluster_resource_variable_ops_unsafely;
626 }
627 return OkStatus();
628 }
629
630 StatusOr<bool> MarkForCompilationPassImpl::Initialize() {
631 TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_);
632 initialized_ = true;
633
634 TF_RETURN_IF_ERROR(FindCompilationCandidates());
635
636 if (compilation_candidates_.empty()) {
637 VLOG(2) << "No compilable candidates";
638 return false;
639 }
640
641 TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
642 CreateCycleDetectionGraph(graph_, &cycles_graph_));
643 if (!cycle_detection_graph_ok) {
644 // TODO(sanjoy): This should be logged via the XLA activity listener.
645 VLOG(2) << "Could not form cycle detection graph";
646 return false;
647 }
648
649 if (!debug_options_.ignore_deadness_checks) {
650 XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
651 TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis_));
652 }
653
654 // If the user is requesting deterministic cluster names, compute a hash of the
655 // input graph to provide a stable but unique prefix for the name.
656 if (debug_options_.deterministic_cluster_names) {
657 TF_ASSIGN_OR_RETURN(graph_fingerprint_, FingerprintGraph(*graph_));
658 }
659
660 // Each compilation candidate belongs to a cluster. The cluster's
661 // representative names the node in the 'cycles' graph that represents the
662 // cluster.
663 TF_RETURN_IF_ERROR(BuildInitialClusterSet());
664 return true;
665 }
666
667 template <typename FnTy>
668 StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) {
669 bool changed = false;
670 for (int32_t node : cycles_graph_.AllNodesInPostOrder()) {
671 Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
672 if (!cluster_from) {
673 continue;
674 }
675
676 // Make a copy of the set of successors because we may modify the graph in
677 // TryToContractEdge.
678 std::vector<int32> successors_copy =
679 cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
680
681 for (int to : successors_copy) {
682 iteration_count_++;
683
684 Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
685 if (!cluster_to) {
686 continue;
687 }
688
689 TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to));
690 changed |= contracted_edge;
691 }
692 }
693
694 return changed;
695 }
696
697 Node* MarkForCompilationPassImpl::GetOnlyNodeIn(const Cluster& cluster) {
698 return cluster.cluster_size() == 1
699 ? graph_->FindNodeId(cluster.GetIdOfOnlyNode())
700 : nullptr;
701 }
702
703 bool MarkForCompilationPassImpl::IsSinkLike(const Cluster& cluster) {
704 if (Node* n = GetOnlyNodeIn(cluster)) {
705 return n->type_string() == "NoOp" && n->out_edges().size() == 1 &&
706 (*n->out_edges().begin())->dst()->IsSink();
707 }
708
709 return false;
710 }
711
712 bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation(
713 const Cluster& cluster) {
714 Node* n = GetOnlyNodeIn(cluster);
715 if (!n) {
716 return false;
717 }
718
719 if (n->type_string() != "AssignAddVariableOp" &&
720 n->type_string() != "AssignSubVariableOp") {
721 return false;
722 }
723
724 DataType dtype;
725 if (!TryGetNodeAttr(n->def(), "dtype", &dtype) || !DataTypeIsInteger(dtype)) {
726 return false;
727 }
728
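// Require the increment to come from a Const operand so that we can inspect
// its shape below; only a scalar constant increment counts as "i++"-like.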
729 Node* const_input = nullptr;
730 for (const Edge* e : n->in_edges()) {
731 if (!e->IsControlEdge() && e->src()->IsConstant()) {
732 const_input = e->src();
733 break;
734 }
735 }
736
737 if (!const_input) {
738 return false;
739 }
740
741 const TensorProto* proto = nullptr;
742 if (!TryGetNodeAttr(const_input->def(), "value", &proto)) {
743 return false;
744 }
745
746 return TensorShapeUtils::IsScalar(proto->tensor_shape());
747 }
748
749 Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
750 TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_);
751 edges_contracted_ = true;
752
753 // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
754 // example, from the Grappler fusion pass).
755
756 // In general there are multiple maximal clusterings, but they are not all
757 // equally performant. Some clustering decisions are likely to improve
758 // performance much more than others, but we cannot order contractions by this
759 // cost function, nor can we look at global information while deciding on
760 // individual edges to contract. Instead, we first make decisions on the
761 // important edges and then on all other edges, which maximizes the chance
762 // that the most important edges are contracted.
763 //
764 // An example of where this might occur is with a digraph:
765 // {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
766 // not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
767 // should be clustered with A because it will prevent a potentially large
768 // tensor from A being computed and copied.
769 //
770 // To choose better maximal clusterings we make multiple iterations over the
771 // graph in post-order, where each such iteration is called a "phase".
772
773 // Phase 0: contract metadata operations with their producer.
774
775 VLOG(4) << "Running phase 0";
776 TF_RETURN_IF_ERROR(
777 ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
778 // Shape consuming operations are desirable to cluster with their
779 // operands because they return a small set of scalar values after
780 // consuming a large amount of data. For example, given a graph X -> Y
781 // -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or
782 // [{X, Y}, {Size, Z}], the better clustering is Size with Y because the
783 // output of Size will be a small tensor while Y is a potentially large
784 // tensor that must be computed and possibly transposed/copied before
785 // the second cluster executes.
786 Node* n = GetOnlyNodeIn(*to);
787 bool is_shape_consumer_op = n && IsShapeConsumerOp(*n);
788 if (!is_shape_consumer_op) {
789 return false;
790 }
791
792 return TryToContractEdge(from, to);
793 }).status());
794
795 // Phase 1: apply a heuristic to ensure that we don't mess up clustering due
796 // to "group_deps". After this phase most edges should have been contracted.
797
798 VLOG(4) << "Running phase 1";
799 TF_RETURN_IF_ERROR(
800 ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
801 // We split out this phase to get good clustering in the presence of a
802 // specific pattern seen in some graphs:
803 //
804 // digraph {
805 // ApplyWeightUpdates_0 -> "iteration++"
806 // ApplyWeightUpdates_1 -> "iteration++"
807 // ApplyWeightUpdates_2 -> "iteration++"
808 // ApplyWeightUpdates_0 -> Computation_A
809 // ApplyWeightUpdates_1 -> Computation_B
810 // ApplyWeightUpdates_2 -> Computation_C
811 // Computation_A -> NoOp
812 // Computation_B -> NoOp
813 // Computation_C -> NoOp
814 // "iteration++" -> NoOp
815 // }
816 //
817 // In the graph above we can't cluster iteration++ with any of the
818 // gradient update operations since that will break the TF resource
819 // variable memory model. Given that constraint the ideal clustering
820 // would be to put all the gradient updates and all of the Computation_*
821 // nodes in one cluster, and leave iteration++ and NoOp unclustered.
822 //
823 // A naive post-order traversal would not create this good clustering,
824 // however. Instead it will first create a cluster that puts
825 // Computation_* nodes, the NoOp and iteration++ node in a single
826 // cluster, after which it will fail to put any of the
827 // ApplyWeightUpdates_* nodes into this cluster. To avoid this fate we
828 // instead run a pass that avoids contracting edges _into_ NoOps like
829 // the above, and avoid clustering edges _from_ "iteration++" like the
830 // above. Then we run a second pass that contracts the edges we could
831 // not contract the first time around.
832
833 if (IsSinkLike(*to)) {
834 return false;
835 }
836
837 if (IsScalarIntegerResourceOperation(*from)) {
838 return false;
839 }
840
841 return TryToContractEdge(from, to);
842 }).status());
843
844 // Phase 2: contract any remaining edges. After this phase we should have a
845 // maximal clustering:
846 //
847 // A. We visit a cluster only after maximally clustering all its children.
848 // B. By the time we're done with a node all of its children that could have
849 // been absorbed into the node have been absorbed.
850 // C. We have an invariant that making a cluster larger does not make edges
851 // leaving it more contractable. That is, if we have
852 // digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
853 // to contract Y->Z if Y->Z was not contractible originally.
854 VLOG(4) << "Running phase 2";
855 TF_RETURN_IF_ERROR(ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
856 return TryToContractEdge(from, to);
857 }).status());
858
859 // Check that the conclusion made above (that iterating over the graph once in
860 // post order gives a maximal clustering) holds. Once the linear time
861 // post-order scheme has been battle tested we can move this to happen only in
862 // debug builds.
863 VLOG(2) << "Checking idempotence";
864 TF_ASSIGN_OR_RETURN(bool changed,
865 ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
866 return TryToContractEdge(from, to);
867 }));
868 TF_RET_CHECK(!changed);
869
870 return OkStatus();
871 }
872
873 Status MarkForCompilationPassImpl::DeclusterNodes() {
874 for (Node* n : compilation_candidates_) {
875 Cluster* cluster = GetClusterForNode(n);
876 if (cluster == nullptr) {
877 continue;
878 }
879
880 // De-cluster Fill ops that are
881 // - used at least once outside the cluster, and
882 // - not used inside the cluster.
883 //
884 // In this case, using XLA for the op can only make peak memory usage worse.
885 // If we don't cluster the Fill, it can be materialized right before it's
886 // used in the TF graph. Whereas if we do cluster it, the Fill must be live
887 // starting at the end of the XLA cluster, potentially significantly
888 // increasing its live range.
889 //
890 // See b/221997940 for a real-world example of this.
891 if (n->op_def().name() == "Fill" &&
892 n->out_nodes().begin() != n->out_nodes().end() &&
893 absl::c_all_of(n->out_nodes(), [&](Node* user) {
894 return GetClusterForNode(user) != cluster;
895 })) {
896 declustered_nodes_.insert(n);
897 }
898 }
899
900 return OkStatus();
901 }
902
903 // Tracks monotonic sequence numbers for graphs.
904 class ClusterSequenceNumberGenerator {
905 public:
906 void Reset() {
907 mutex_lock lock(mu_);
908 sequence_numbers_.clear();
909 }
910
911 int64 GetNext(uint64 key) {
912 mutex_lock lock(mu_);
913 return sequence_numbers_[key]++;
914 }
915
916 static ClusterSequenceNumberGenerator& Global() {
917 static ClusterSequenceNumberGenerator* gen =
918 new ClusterSequenceNumberGenerator;
919 return *gen;
920 }
921
922 private:
923 mutex mu_;
924 absl::flat_hash_map<uint64, int64> sequence_numbers_;
925 };
926
927 // Gets a monotonic sequence number for a graph identified by its `fingerprint`.
928 // The sequence number is needed to disambiguate clusters extracted from the
929 // same graph, and to handle duplicate graphs within the same process.
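// For example, with deterministic cluster names enabled, two clusters cut from
// a graph whose fingerprint happens to be 12345 would be named along the lines
// of "cluster_12345_0" and "cluster_12345_1" (illustrative values only; see
// CreateClusters below for the exact formatting).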
930 int64_t GetNextClusterSequenceNumber(uint64 fingerprint) {
931 return ClusterSequenceNumberGenerator::Global().GetNext(fingerprint);
932 }
933
934 Status MarkForCompilationPassImpl::CreateClusters() {
935 TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
936 clusters_created_ = true;
937
938 // Names for each cluster.
939 std::unordered_map<int, string> cluster_names;
940
941 if (debug_options_.dump_graphs) {
942 DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
943 }
944
945 // Mark clusters for compilation that:
946 // * are placed on a device that requires compilation (an XlaDevice),
947 // * are explicitly marked for compilation (_XlaCompile=true), or
948 // * have more than debug_options_.min_cluster_size elements (applicable
949 // only if compilation is enabled, otherwise there will be no such
950 // candidates).
951 for (Node* n : compilation_candidates_) {
952 Cluster* cluster = GetClusterForNode(n);
953 TF_ASSIGN_OR_RETURN(bool should_compile_cluster,
954 ShouldCompileCluster(*cluster));
955 if (!should_compile_cluster || declustered_nodes_.contains(n)) {
956 continue;
957 }
958
959 // We assume that functional If and While nodes have at least
960 // min_cluster_size non-trivial nodes in them. It would be more principled
961 // to (recursively) verify this fact, but that's probably not worth the
962 // trouble.
963
964 if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size ||
965 cluster->has_functional_control_flow() ||
966 cluster->is_xla_compile_attr_true()) {
967 string& name = cluster_names[cluster->cycles_graph_node_id()];
968
969 if (name.empty()) {
970 if (debug_options_.deterministic_cluster_names) {
971 name = absl::StrCat("cluster_", graph_fingerprint_, "_",
972 GetNextClusterSequenceNumber(graph_fingerprint_));
973 } else {
974 name = absl::StrCat("cluster_",
975 GetNextClusterSequenceNumber(graph_fingerprint_));
976 }
977 }
978
979 n->AddAttr(kXlaClusterAttr, name);
980 n->AddAttr(kXlaAlreadyClustered, true);
981 VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
982 }
983 }
984
985 return OkStatus();
986 }
987
988 Status MarkForCompilationPassImpl::DumpDebugInfo() {
989 TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_);
990
991 if (debug_options_.dump_graphs) {
992 DumpPostClusteringGraphs();
993 }
994
995 VLogClusteringSummary();
996
997 return OkStatus();
998 }
999
1000 StatusOr<bool>
1001 MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
1002 const Cluster& cluster_from, const Cluster& cluster_to) {
1003 // If any of the consumer's producers are on a different device, do not
1004 // cluster these nodes. This prevents other work on this device from being
1005 // delayed by work on other devices. We consider predecessors of the entire
1006 // cluster rather than just the inputs to the node to prevent the clusters from
1007 // still being combined in cases where the 'to' cluster has multiple
1008 // dependencies on the 'from' cluster and another dependency leads to a
1009 // merging of the clusters.
1010 //
1011 // TODO(b/117085735): We probably want to handle the reciprocal of this case
1012 // where a cluster is producing data for multiple devices.
1013 for (const auto& in_id :
1014 cycles_graph_.Predecessors(cluster_to.cycles_graph_node_id())) {
1015 const Cluster* cluster_in = GetClusterForCyclesGraphNode(in_id);
1016 if (cluster_in) {
1017 TF_ASSIGN_OR_RETURN(bool devices_compatible,
1018 AreDevicesCompatible(cluster_to, *cluster_in));
1019 if (!devices_compatible) {
1020 return true;
1021 }
1022 TF_ASSIGN_OR_RETURN(devices_compatible,
1023 AreDevicesCompatible(cluster_from, *cluster_in));
1024 if (!devices_compatible) {
1025 return true;
1026 }
1027 }
1028 }
1029
1030 return false;
1031 }
1032
1033 std::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
1034 // Look for either _XlaScope or _XlaInternalScope on both nodes to guide
1035 // clustering. If both nodes have a scope and the scopes do not match, do
1036 // not cluster along this edge. If even one of the nodes lacks a scope
1037 // attribute, then it is treated as a "bridge" and a cluster may be created
1038 // along it.
1039 //
1040 // The difference between _XlaScope and _XlaInternalScope is that _XlaScope is
1041 // provided by users through jit_scope APIs, while _XlaInternalScope is
1042 // automatically generated by the ClusterScopingPass when auto_jit is on. As
1043 // such, we respect _XlaScope only when auto_jit is off, while respecting
1044 // _XlaInternalScope only when auto_jit is on.
1045 //
1046 // We may want to restrict the _XlaScope behavior to require all nodes marked
1047 // with _XlaCompile=true to also have a _XlaScope property set (and raise an
1048 // error otherwise); but for now we don't do this.
1049
1050 if (global_jit_level_ != OptimizerOptions::OFF) {
1051 // If global_jit_level_ is ON, respect only _XlaInternalScope.
1052 const string& scope =
1053 GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr);
1054 if (!scope.empty()) {
1055 return scope;
1056 }
1057 } else {
1058 // If global_jit_level_ is OFF, respect only _XlaScope.
1059 const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
1060 if (!scope.empty()) {
1061 return scope;
1062 }
1063 }
1064
1065 return std::nullopt;
1066 }
1067
1068 // Returns true iff the attribute `attr_name` is attached to either the node or
1069 // to its callee.
1070 static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
1071 const char* attr_name) {
1072 bool out = false;
1073 bool attr_value;
1074 if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
1075 out |= attr_value;
1076 }
1077
1078 if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
1079 out |= attr_value;
1080 }
1081 return out;
1082 }
1083
1084 Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
1085 auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
1086 return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
1087 };
1088
1089 std::vector<std::pair<int, int>> unsafe_resource_deps_vect;
1090 TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
1091 *graph_, flib_def_, ignore_resource_ops, &unsafe_resource_deps_vect));
1092 absl::c_copy(
1093 unsafe_resource_deps_vect,
1094 std::inserter(unsafe_resource_deps_, unsafe_resource_deps_.begin()));
1095
1096 cluster_for_node_.resize(graph_->num_node_ids());
1097 for (Node* node : graph_->nodes()) {
1098 if (!IsCompilationCandidate(node)) {
1099 cluster_for_node_[node->id()].Get() = nullptr;
1100 continue;
1101 }
1102
1103 // We want clusters to be big enough that the benefit from XLA's
1104 // optimizations offsets XLA related overhead (for instance we add some
1105 // Switch/Merge nodes into the graph to implement lazy compilation). To
1106 // this end, we don't count Identity and Constant nodes because they do not
1107 // enable interesting optimizations by themselves.
1108 int effective_cluster_size =
1109 (node->IsIdentity() || node->IsConstant()) ? 0 : 1;
1110
1111 bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode();
1112
1113 std::optional<DeadnessPredicate> deadness_predicate;
1114 if (deadness_analysis_) {
1115 TF_ASSIGN_OR_RETURN(
1116 deadness_predicate,
1117 deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
1118 }
1119
1120 const string& device_name_str = !node->assigned_device_name().empty()
1121 ? node->assigned_device_name()
1122 : node->requested_device();
1123 TF_ASSIGN_OR_RETURN(DeviceId device,
1124 device_info_cache_.GetIdFor(device_name_str));
1125
1126 bool is_resource_op = HasResourceInputOrOutput(*node);
1127 std::optional<DeviceId> resource_op_device;
1128 if (is_resource_op) {
1129 resource_op_device = device;
1130 }
1131
1132 std::optional<int> resource_var_operation_node_id;
1133 if (is_resource_op || MayCallFunction(*node, flib_def_)) {
1134 resource_var_operation_node_id = node->id();
1135 }
1136
1137 bool is_xla_compile_attr_true =
1138 GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
1139 (global_jit_level_ != OptimizerOptions::OFF &&
1140 GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr));
1141
1142 DeviceSet devices;
1143 devices.Insert(device);
1144
1145 Cluster* new_cluster = MakeNewCluster(
1146 /*cycles_graph_node_id=*/node->id(),
1147 /*effective_cluster_size=*/effective_cluster_size,
1148 /*has_functional_control_flow=*/has_functional_control_flow, devices,
1149 resource_op_device, resource_var_operation_node_id, deadness_predicate,
1150 /*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
1151 GetXlaScope(node));
1152
1153 cluster_for_node_[node->id()].Get() = new_cluster;
1154 }
1155
1156 return OkStatus();
1157 }
1158
1159 StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
1160 if (!node->IsIdentity()) {
1161 return false;
1162 }
1163
1164 // Check if the Identity is driven by a Switch on its true path.
1165 auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) {
1166 return e->src()->IsSwitch() && e->src_output() == 1;
1167 });
1168 if (it == node->in_edges().end()) {
1169 return false;
1170 }
1171 const Node* switch_node = (*it)->src();
1172
1173 // Check if the Switch is driven by LoopCond.
1174 const Node* maybe_loop_cond;
1175 TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond));
1176 if (!maybe_loop_cond->IsLoopCond()) {
1177 return false;
1178 }
1179
1180 // Check if the Identity is driving any const nodes through a control edge.
1181 bool driving_any_consts =
1182 absl::c_any_of(node->out_edges(), [](const Edge* e) {
1183 return e->dst()->IsConstant() && e->IsControlEdge();
1184 });
1185 if (!driving_any_consts) {
1186 return false;
1187 }
1188
1189 return true;
1190 }
1191
1192 absl::flat_hash_set<string> GetOrCreateClusterExcludeList() {
1193 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1194 absl::flat_hash_set<string> excludelist;
1195 for (auto s : absl::StrSplit(flags->tf_xla_cluster_exclude_ops, ',')) {
1196 if (!s.empty()) {
1197 excludelist.insert(string(s));
1198 }
1199 }
1200 if (VLOG_IS_ON(2) && !excludelist.empty()) {
1201 std::vector<string> vexcludelist(excludelist.begin(), excludelist.end());
1202 absl::c_sort(vexcludelist);
1203 VLOG(2) << "XLA clustering will exclude following TF operations from auto "
1204 "clustering: "
1205 << absl::StrJoin(vexcludelist, " ");
1206 }
1207 return excludelist;
1208 }
1209
1210 absl::flat_hash_set<string> GetOrCreateAllowlist() {
1211 absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
1212 tensorflow::GetAllowlistTable();
1213 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1214 absl::flat_hash_set<string> allowlist;
1215
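// As an illustrative (hypothetical) flag value, passing
// --tf_xla_ops_to_cluster="FUSIBLE,Sin" would expand the FUSIBLE shorthand via
// the allowlist table and additionally allow the Sin op; unknown op names are
// rejected later in FindCompilationCandidates.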
1216 for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
1217 if (s == "FUSIBLE") {
1218 for (auto pair : *allowlist_table) {
1219 allowlist.insert(pair.second.begin(), pair.second.end());
1220 }
1221 } else if (allowlist_table->contains(s)) {
1222 auto v = allowlist_table->at(s);
1223 allowlist.insert(v.begin(), v.end());
1224 } else if (!s.empty()) {
1225 // Should be a user provided TF operation.
1226 allowlist.insert(string(s));
1227 }
1228 }
1229
1230 if (VLOG_IS_ON(2) && !allowlist.empty()) {
1231 std::vector<string> vallowlist(allowlist.begin(), allowlist.end());
1232 absl::c_sort(vallowlist);
1233 VLOG(2) << "XLA clustering will only consider the following TF operations: "
1234 << absl::StrJoin(vallowlist, " ");
1235 }
1236 return allowlist;
1237 }
1238
1239 Status MarkForCompilationPassImpl::FindCompilationCandidates() {
1240 OptimizerOptions opts;
1241 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
1242 new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr,
1243 TF_GRAPH_DEF_VERSION, flib_def_, opts));
1244 FunctionLibraryRuntime* lib_runtime =
1245 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
1246 std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false);
1247 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
1248 *graph_, /*compile_time_const_arg_indices=*/nullptr,
1249 &compile_time_const_nodes, lib_runtime));
1250 // Iterate over nodes in sorted order so that compiler fuel is deterministic.
1251 // We can't simply pass op_nodes().begin() and op_nodes().end() to the
1252 // std::vector constructor because they're not proper iterators, with
1253 // iterator_traits defined and so on.
1254 std::vector<Node*> sorted_nodes;
1255 for (Node* node : graph_->op_nodes()) {
1256 sorted_nodes.push_back(node);
1257 }
1258 std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
1259
1260 if (*debug_options_.fuel >= std::numeric_limits<int64_t>::max() / 2) {
1261 // The assumption is that if fuel started out as INT64_MAX, it will forever
1262 // stay greater than INT64_MAX / 2.
1263 VLOG(2) << "Starting fuel: infinity";
1264 } else {
1265 VLOG(2) << "Starting fuel: " << *debug_options_.fuel;
1266 }
1267
1268 VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
1269
1270 auto allowlist = GetOrCreateAllowlist();
1271
1272 std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
1273 absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
1274 // Check that the user-provided TF operations really exist.
1275 for (const auto& s : allowlist) {
1276 if (!all_ops.contains(s)) {
1277 return errors::InvalidArgument(
1278 "The operation '", s,
1279 "' passed to --tf_xla_ops_to_cluster is not supported by XLA.");
1280 }
1281 }
1282
1283 for (Node* node : sorted_nodes) {
1284 if (*debug_options_.fuel <= 0) {
1285 VLOG(1)
1286 << "Hit fuel limit; not marking any remaining ops as clusterable.";
1287 break;
1288 }
1289
1290 TF_ASSIGN_OR_RETURN(
1291 const DeviceType& device_type,
1292 device_info_cache_.GetDeviceTypeFor(node->assigned_device_name()));
1293 VLOG(4) << "Device type for " << node->name() << ": "
1294 << device_type.type_string();
1295
1296 if (CompilationDisallowedByXlaCompileAttr(node)) {
1297 VLOG(2) << "Not clustering " << node->name()
1298 << ": disallowed by _XlaCompile attribute";
1299 continue;
1300 }
1301
1302 const XlaOpRegistry::DeviceRegistration* registration;
1303 if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
1304 ®istration)) {
1305 VLOG(2) << "Rejecting " << node->name()
1306 << ": could not find JIT device for " << device_type.type();
1307 continue;
1308 }
1309
1310 auto cluster_exclude_op_list = GetOrCreateClusterExcludeList();
1311 RecursiveCompilabilityChecker::OperationFilter filter =
1312 CreateOperationFilter(*registration);
1313 filter.require_always_compilable = true;
1314 filter.allow_string_consts = false;
1315 filter.allow_collective_reduce_v2 = false;
1316 filter.allow_unique_op = false;
1317 filter.allow_where_op = true;
1318
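// Note that the exclude list can only switch off ops that are otherwise
// allowed by default; at the moment only "Where" is recognized here (e.g. via
// a hypothetical --tf_xla_cluster_exclude_ops=Where invocation), and any other
// name is rejected with an error below.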
1319 for (const auto& s : cluster_exclude_op_list) {
1320 if (s == "Where") {
1321 filter.allow_where_op = false;
1322 } else {
1323 return errors::InvalidArgument(
1324 "The operation '", s,
1325 "' passed to --tf_xla_cluster_exclude_ops is not supported by "
1326 "XLA.");
1327 }
1328 }
1329
1330 RecursiveCompilabilityChecker checker(
1331 filter, DeviceType{registration->compilation_device_name});
1332
1333 if (!checker.IsCompilableNode(*node, lib_runtime)) {
1334 continue;
1335 }
1336
1337 if (node->type_string() == "Const") {
1338 // Skip Const op with type DT_STRING, since XLA autoclustering doesn't
1339 // support it.
1340 const AttrValue* attr = node->attrs().Find("dtype");
1341 if (attr != nullptr && attr->type() == DT_STRING) {
1342 continue;
1343 }
1344 }
1345
1346 if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
1347 VLOG(1) << "Rejecting TF operation " << node->def().op()
1348 << " as it is not listed in --tf_xla_ops_to_cluster.";
1349 continue;
1350 }
1351
1352 if (compile_time_const_nodes[node->id()]) {
1353 const OpDef* op_def;
1354 TF_RETURN_IF_ERROR(
1355 graph_->op_registry()->LookUpOpDef(node->type_string(), &op_def));
1356 if (op_def->is_stateful()) {
1357 // It is easiest to demonstrate the problem we're trying to solve with
1358 // an example. Say we have this graph:
1359 //
1360 // shape = RandomUniformInt();
1361 // reshape = Reshape(input, shape)
1362 //
1363 // Both RandomUniformInt and Reshape are compilable by XLA so, absent
1364 // any other reason, we will try to put both shape and reshape in the
1365 // same cluster. However, since XLA only supports statically shaped
1366 // values, it will expect to be able to constant fold `shape` to get a
1367 // static shape for `reshape`. This is a problem because side-effecting
1368 // ops like RandomUniformInt() cannot be constant folded. We fix this
1369 // by putting `shape` and `reshape` in different clusters, which results
1370 // in us recompiling `reshape`'s cluster for every new value of `shape`,
1371 // making `reshape` statically sized within each compilation. We
1372 // simplify the solution even further by disallowing operations like
1373 // `shape` from being part of *any* non-trivial cluster. They're either
1374 // not compiled by XLA altogether or, if assigned to an XLA_* device
1375 // with "must compile" semantics, compiled into a trivial single-op
1376 // cluster. This approach leaves some room for improvement, and we can
1377 // consider implementing a more aggressive data-flow-analysis based
1378 // solution in the future if needed.
1379 //
1380 // One ugly problem we have to contend with: certain sets of ops *have*
1381 // to be in the same cluster because values flowing between them have
1382 // types that can't be live-in or live-out of a cluster. These ops are:
1383 //
1384 // - TensorArray ops operating on the same TensorArray instance.
1385 // - Stack ops operating on the same Stack instance.
1386 //
1387 // To work around this we avoid isolating these specific ops. Because
1388 // of this concession it is unsound to auto-cluster them because then
1389 // we'd create clusters we could not compile (because we can't constant
1390 // fold, say, a TensorArrayRead or a StackPopV2). But we don't
1391 // auto-cluster these operations today so we're good for now.
1392 const XlaResourceOpInfo* op_info =
1393 GetResourceOpInfoForOp(node->type_string());
1394 bool is_tensor_array_or_stack_op =
1395 op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
1396 if (!is_tensor_array_or_stack_op) {
1397 VLOG(2) << "Isolating " << node->name()
1398 << ": must-be-constant stateful op";
1399 continue;
1400 }
1401 }
1402 }
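// To make the isolation above concrete (an illustrative sketch only): in the
// shape/reshape example from the comment, `shape` is skipped here and stays
// unclustered (or becomes a trivial single-op cluster on an XLA_* device),
// while `reshape` may still be clustered and simply gets recompiled whenever
// `shape` produces a new value.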
1403
1404 // This is a heuristic to avoid creating a dependency between the while loop
1405 // condition and body computations. Such a dependency can be created if a
1406 // special Identity node in the following pattern is clustered in.
1407 // That is, an Identity node in the loop cond computation is used to drive
1408 // const nodes consumed by the loop body. If this Identity node goes into
1409 // the same cluster as nodes from the loop body, an extra dependency is
1410 // created between the loop cond and body computations, which hinders the
1411 // progression of the loop cond computation at runtime and adds significant
1412 // overhead. Specifically, we look for the pattern below and do not cluster
1413 // in this Identity to avoid the described issue. Since Identity has low
1414 // execution cost in native TF, having this heuristic give up these special
1415 // Identity nodes as candidates should not hurt performance. If
1416 // other considerations emerge in the future, we can revisit the heuristic
1417 // and only disallow these Identities from going into the same cluster as
1418 // nodes from the loop body, while still considering them candidates.
1419 //
1420 // LoopCond ->
1421 // Merge -> Switch -> Identity -> i++ -> ... -> NextIteration
1422 // ..> Const -> LoopBody
1423 // (control edge)
1424 TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop,
1425 IsIdentityDrivingConstsInLoop(node));
1426 if (is_identity_driving_consts_in_loop) {
1427 VLOG(2) << "Rejecting " << node->name()
1428 << ": including it can create dependencies between while loop "
1429 "condition and body computations with runtime overhead.";
1430 continue;
1431 }
1432
1433 compilation_candidates_.insert(node);
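// Each accepted candidate consumes one unit of clustering fuel, a debugging
// knob wired up from --tf_xla_clustering_fuel via GetPointerToFuel() below.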
1434 --(*debug_options_.fuel);
1435 }
1436
1437 VLOG(2) << "compilation_candidates_.size() = "
1438 << compilation_candidates_.size();
1439 return OkStatus();
1440 }
1441
1442 bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
1443 Node* node) {
1444 if (debug_options_.ignore_xla_compile_attr) {
1445 return false;
1446 }
1447
1448 // If there is a _XlaCompile annotation, use its value.
1449 bool compile = false;
1450 Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
1451 if (status.ok()) {
1452 if (!compile) {
1453 VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
1454 << kXlaCompileAttr << ") is false.";
1455 }
1456 return !compile;
1457 }
1458
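// Otherwise, check whether the function this node calls (if any) carries the
// attribute; an explicit _XlaCompile=false on the callee also disables
// compilation for this node.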
1459 status = flib_def_->GetAttr(*node, kXlaCompileAttr, &compile);
1460 if (status.ok()) {
1461 if (!compile) {
1462 VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
1463 << kXlaCompileAttr << ") on callee is false.";
1464 }
1465 return !compile;
1466 }
1467
1468 return false;
1469 }
1470
1471 bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse(
1472 Cluster* from, Cluster* to, absl::string_view reason) {
1473 VLOG(3) << EdgeContractionFailureMsg(from, to, reason);
1474 return false;
1475 }
1476
1477 StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
1478 Cluster* to) {
1479 DCHECK(from->deadness_predicate().has_value() ==
1480 to->deadness_predicate().has_value());
1481 if (from->deadness_predicate() != to->deadness_predicate()) {
1482 VLOG(3) << EdgeContractionFailureMsg(
1483 from, to,
1484 absl::StrCat(
1485 "the two nodes have mismatching deadness: ",
1486 deadness_analysis_->DebugString(*from->deadness_predicate()),
1487 " and ",
1488 deadness_analysis_->DebugString(*to->deadness_predicate())));
1489 return false;
1490 }
1491
1492 TF_ASSIGN_OR_RETURN(bool devices_compatible,
1493 AreDevicesCompatible(*from, *to));
1494 if (!devices_compatible) {
1495 return LogNotContractableAndReturnFalse(
1496 from, to, "the two nodes have incompatible devices");
1497 }
1498
1499 if (from->xla_scope().has_value() && to->xla_scope().has_value() &&
1500 *from->xla_scope() != *to->xla_scope()) {
1501 return LogNotContractableAndReturnFalse(
1502 from, to, "the two nodes have mismatching XLA scopes");
1503 }
1504
1505 // Don't exceed the maximum cluster size.
1506 if (from->cluster_size() + to->cluster_size() >
1507 debug_options_.max_cluster_size) {
1508 return LogNotContractableAndReturnFalse(
1509 from, to, "the new cluster will be larger than the max cluster size");
1510 }
1511
1512 TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency,
1513 ClusteringWillIntroduceInterDeviceDependency(*from, *to));
1514
1515 if (will_introduce_cross_device_dependency) {
1516 return LogNotContractableAndReturnFalse(
1517 from, to, "the new cluster will introduce a cross device dependency");
1518 }
1519
1520 // Check if contracting this edge will break the resource variable concurrency
1521 // semantics. In theory this is quadratic in the number of nodes, but it does
1522 // not seem to be a problem in practice so far.
1523 if (!debug_options_.ignore_resource_variable_checks) {
1524 for (int resource_var_from : from->resource_var_operation_node_ids()) {
1525 for (int resource_var_to : to->resource_var_operation_node_ids()) {
1526 // If unsafe_resource_deps_ contains {A, B} then
1527 //
1528 // a. A and B are resource operations.
1529 // b. A and B cannot be placed in the same cluster.
1530 // c. There is no path from B to A in the cycles graph (but there may
1531 // be a path from A to B).
1532 //
1533 // So check the legality of the edge contraction by checking if any of
1534 // the n^2 pairs of resource variable operations are forbidden.
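// For example (illustrative): if A is a variable write in `from` and B is a
// read of the same variable in `to`, and the safety analysis recorded {A, B}
// as unsafe, merging the two clusters could reorder the operations, so the
// contraction is rejected below.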
1535 if (unsafe_resource_deps_.contains(
1536 {resource_var_from, resource_var_to})) {
1537 return LogNotContractableAndReturnFalse(
1538 from, to,
1539 "the new cluster would break resource variable semantics");
1540 }
1541 }
1542 }
1543 }
1544
1545 return MergeClusters(from, to);
1546 }
1547
1548 Status MarkForCompilationPassImpl::Run() {
1549 // Make sure that kernels have been registered on the JIT device.
1550 XlaOpRegistry::RegisterCompilationKernels();
1551
1552 // Start the timer after XlaOpRegistry::RegisterCompilationKernels which does
1553 // some one-time work.
1554 XLA_SCOPED_LOGGING_TIMER_LEVEL("MarkForCompilationPassImpl::Run", 1);
1555
1556 TF_ASSIGN_OR_RETURN(bool initialized, Initialize());
1557 if (!initialized) {
1558 // Initialization exited early which means this instance of
1559 // MarkForCompilationPassImpl is not set up to run the subsequent phases.
1560 return OkStatus();
1561 }
1562
1563 TF_RETURN_IF_ERROR(RunEdgeContractionLoop());
1564 TF_RETURN_IF_ERROR(DeclusterNodes());
1565 TF_RETURN_IF_ERROR(CreateClusters());
1566 TF_RETURN_IF_ERROR(DumpDebugInfo());
1567
1568 return OkStatus();
1569 }
1570
1571 void MarkForCompilationPassImpl::DumpPostClusteringGraphs() {
1572 DumpGraphToFile("mark_for_compilation", *graph_, flib_def_);
1573
1574 // We also dump out an annotated version of the TF graph where the node
1575 // names are prefixed with the cluster names. This can help visualize the
1576 // clustering decisions on TensorBoard.
1577 Graph new_graph(graph_->op_registry());
1578 CopyGraph(*graph_, &new_graph);
1579
1580 for (Node* n : new_graph.nodes()) {
1581 if (std::optional<absl::string_view> cluster_name =
1582 GetXlaClusterForNode(*n)) {
1583 n->set_name(absl::StrCat(*cluster_name, "/", n->name()));
1584 } else if (n->type_string() == "VarHandleOp") {
1585 n->set_name(absl::StrCat("varhandle/", n->name()));
1586 } else {
1587 // There is room for improvement here. In particular, it may help to
1588 // split these unclustered nodes into classes where every node in a
1589 // specific class has edges to and from the same set of clusters.
1590 n->set_name(absl::StrCat("unclustered/", n->name()));
1591 }
1592 }
1593
1594 DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_);
1595 }
1596
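// For example, RatioToString(3, 12) yields "3 / 12 (25.00%)".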
1597 string RatioToString(int numerator, int denominator) {
1598 return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator,
1599 (100.0 * numerator) / denominator);
1600 }
1601
1602 void MarkForCompilationPassImpl::VLogClusteringSummary() {
1603 if (!VLOG_IS_ON(2)) {
1604 return;
1605 }
1606
1607 XlaAutoClusteringSummary auto_clustering_info =
1608 GetXlaAutoClusteringSummary(*graph_);
1609
1610 VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
1611 VLOG(2) << " Built " << auto_clustering_info.clusters_size()
1612 << " clusters, size "
1613 << RatioToString(auto_clustering_info.clustered_node_count(),
1614 graph_->num_nodes());
1615
1616 for (const XlaAutoClusteringSummary::Cluster& cluster :
1617 auto_clustering_info.clusters()) {
1618 absl::string_view cluster_name = cluster.name();
1619 int size = cluster.size();
1620 VLOG(2) << " " << cluster_name << " "
1621 << RatioToString(size, graph_->num_nodes());
1622 for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1623 cluster.op_histogram()) {
1624 VLOG(3) << " " << op_count.op() << ": " << op_count.count()
1625 << " instances";
1626 }
1627 }
1628
1629 if (!auto_clustering_info.unclustered_op_histogram().empty()) {
1630 VLOG(2) << " Unclustered nodes: "
1631 << RatioToString(auto_clustering_info.unclustered_node_count(),
1632 graph_->num_nodes());
1633 for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1634 auto_clustering_info.unclustered_op_histogram()) {
1635 VLOG(3) << " " << op_count.op() << ": " << op_count.count()
1636 << " instances";
1637 }
1638 }
1639
1640 struct EdgeInfo {
1641 absl::string_view node_name;
1642 std::optional<absl::string_view> cluster_name;
1643
1644 absl::string_view GetClusterName() const {
1645 return cluster_name ? *cluster_name : "[none]";
1646 }
1647
1648 std::pair<absl::string_view, std::optional<absl::string_view>> AsPair()
1649 const {
1650 return {node_name, cluster_name};
1651 }
1652
1653 bool operator<(const EdgeInfo& other) const {
1654 return AsPair() < other.AsPair();
1655 }
1656 };
1657
1658 using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64_t>>;
1659
1660 EdgeInfoMap incoming_edge_infos;
1661 EdgeInfoMap outgoing_edge_infos;
1662
1663 std::set<absl::string_view> cluster_names_to_print;
1664
1665 for (const Edge* e : graph_->edges()) {
1666 const Node* from = e->src();
1667 std::optional<absl::string_view> from_cluster_name =
1668 GetXlaClusterForNode(*from);
1669
1670 const Node* to = e->dst();
1671 std::optional<absl::string_view> to_cluster_name =
1672 GetXlaClusterForNode(*to);
1673
1674 if (to_cluster_name == from_cluster_name) {
1675 continue;
1676 }
1677
1678 if (to_cluster_name) {
1679 incoming_edge_infos[*to_cluster_name]
1680 [EdgeInfo{from->name(), from_cluster_name}]++;
1681 cluster_names_to_print.insert(*to_cluster_name);
1682 }
1683
1684 if (from_cluster_name) {
1685 outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++;
1686 cluster_names_to_print.insert(*from_cluster_name);
1687 }
1688 }
1689
1690 VLOG(4) << "*** Inter-Cluster edges:";
1691 if (cluster_names_to_print.empty()) {
1692 VLOG(4) << " [none]";
1693 }
1694
1695 auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
1696 const EdgeInfoMap& edge_info_map,
1697 absl::string_view desc) {
1698 auto it = edge_info_map.find(cluster_name);
1699 if (it != edge_info_map.end()) {
1700 VLOG(4) << " " << it->second.size() << " " << desc << " edges";
1701 for (const auto& edge_info_count_pair : it->second) {
1702 VLOG(4) << " " << edge_info_count_pair.first.GetClusterName() << " "
1703 << edge_info_count_pair.first.node_name << " # "
1704 << edge_info_count_pair.second;
1705 }
1706 } else {
1707 VLOG(4) << " No " << desc << " edges.";
1708 }
1709 };
1710
1711 for (absl::string_view cluster_name : cluster_names_to_print) {
1712 VLOG(4) << " ** Cluster " << cluster_name;
1713 print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
1714 "incoming");
1715 print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos,
1716 "outgoing");
1717 }
1718 }
1719
1720 StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
1721 const Cluster& cluster_a, const Cluster& cluster_b) {
1722 DeviceSet devices = cluster_a.devices();
1723 devices.UnionWith(cluster_b.devices());
1724
1725 TF_ASSIGN_OR_RETURN(
1726 std::optional<jit::DeviceId> maybe_chosen_device,
1727 MaybePickDeviceForXla(device_info_cache_, devices,
1728 /*allow_mixing_unknown_and_cpu=*/false));
1729 if (!maybe_chosen_device.has_value()) {
1730 return false;
1731 }
1732
1733 jit::DeviceId chosen_device = *maybe_chosen_device;
1734
1735 // If we are able to pick a device `chosen_device` for the larger cluster, the
1736 // resource operations in `cluster_a` and `cluster_b` must be placed on the
1737 // same device as `chosen_device`. This is because the _XlaCompile and
1738 // _XlaRun kernels are going to run on and therefore try to access the
1739 // resource variables from `chosen_device`, which will be an error if the
1740 // resource variables are placed on some other device.
1741 auto resource_op_device_ok = [&](std::optional<DeviceId> resource_op_device) {
1742 return !resource_op_device.has_value() ||
1743 *resource_op_device == chosen_device;
1744 };
1745
1746 return resource_op_device_ok(cluster_a.resource_op_device()) &&
1747 resource_op_device_ok(cluster_b.resource_op_device());
1748 }
1749
1750 // Returns `true` iff we should compile `cluster`.
1751 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
1752 const Cluster& cluster) {
1753 TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
1754 PickDeviceForXla(device_info_cache_, cluster.devices(),
1755 /*allow_mixing_unknown_and_cpu=*/false));
1756
1757 const DeviceType& device_type =
1758 device_info_cache_.GetDeviceTypeFor(chosen_device);
1759 const XlaOpRegistry::DeviceRegistration* registration =
1760 device_info_cache_.GetCompilationDevice(chosen_device);
1761 TF_RET_CHECK(registration)
1762 << "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
1763 << "; device type = " << device_type.type() << "; devices ("
1764 << device_info_cache_.DebugString(cluster.devices()) << ")";
1765
1766 auto policy = registration->autoclustering_policy;
1767 bool should_compile =
1768 cluster.is_xla_compile_attr_true() ||
1769 policy == XlaOpRegistry::AutoclusteringPolicy::kAlways ||
1770 (policy == XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
1771 global_jit_level_ != OptimizerOptions::OFF) ||
1772 (device_type.type_string() == DEVICE_CPU &&
1773 policy == XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested &&
1774 cpu_global_jit_);
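// In other words: compile if the user explicitly asked for it via _XlaCompile,
// if the device's registration always autoclusters (e.g. the XLA_* devices),
// if autoclustering is enabled globally, or if this is a CPU device and
// --tf_xla_cpu_global_jit was requested.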
1775
1776 if (!should_compile && device_type.type_string() == DEVICE_CPU &&
1777 global_jit_level_ > OptimizerOptions::OFF) {
1778 static absl::once_flag once;
1779 absl::call_once(once, [] {
1780 LOG(WARNING) << R"((One-time warning): Not using XLA:CPU for cluster.
1781
1782 If you want XLA:CPU, do one of the following:
1783
1784 - set the TF_XLA_FLAGS environment variable to include "--tf_xla_cpu_global_jit", or
1785 - set cpu_global_jit to true on this session's OptimizerOptions, or
1786 - use experimental_jit_scope, or
1787 - use tf.function(jit_compile=True).
1788
1789 To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a
1790 proper command-line flag, not via TF_XLA_FLAGS).)";
1791
1792 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1793 if (flags->tf_xla_cpu_global_jit) {
1794 LOG(WARNING)
1795 << "(Although the tf_xla_cpu_global_jit flag is currently enabled, "
1796 "perhaps it wasn't enabled at process startup?)";
1797 }
1798 });
1799 }
1800
1801 VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
1802 << " cluster with device "
1803 << device_info_cache_.GetNameFor(chosen_device);
1804
1805 return should_compile;
1806 }
1807
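// Memoized wrapper around ShouldCompileClusterImpl(): the decision for a given
// Cluster object is computed once and cached, since the same cluster may be
// queried multiple times.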
1808 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileCluster(
1809 const Cluster& cluster) {
1810 auto it = should_compile_cluster_cache_.find(&cluster);
1811 if (it != should_compile_cluster_cache_.end()) {
1812 return it->second;
1813 }
1814
1815 TF_ASSIGN_OR_RETURN(bool should_compile, ShouldCompileClusterImpl(cluster));
1816 should_compile_cluster_cache_.insert({&cluster, should_compile});
1817 return should_compile;
1818 }
1819
1820 Status MarkForCompilation(
1821 const GraphOptimizationPassOptions& options,
1822 const MarkForCompilationPassImpl::DebugOptions& debug_options) {
1823 Graph* graph = options.graph->get();
1824 FunctionLibraryDefinition* flib_def = options.flib_def;
1825
1826 // Deadness analysis expects a graph with source and sink edges properly
1827 // connected but sometimes the incoming graph does not follow this invariant.
1828 // So fix up the source and sink edges before calling into deadness analysis.
1829 FixupSourceAndSinkEdges(graph);
1830
1831 for (Node* n : graph->nodes()) {
1832 // See explanation on `kXlaAlreadyClustered`.
1833 if (n->attrs().Find(kXlaAlreadyClustered)) {
1834 return OkStatus();
1835 }
1836 // Skip the pass if we found TPUExecute or TPUExecuteAndUpdateVariables ops
1837 // in the graph, which indicates the graph was produced by the TPU TF-XLA
1838 // bridge and doesn't require auto clustering.
1839 if (n->type_string() == "TPUExecute" ||
1840 n->type_string() == "TPUExecuteAndUpdateVariables") {
1841 return OkStatus();
1842 }
1843 }
1844
1845 return MarkForCompilationPassImpl{
1846 debug_options,
1847 graph,
1848 flib_def,
1849 options.session_options != nullptr ? options.session_options->env
1850 : Env::Default(),
1851 GetGlobalJitLevelForGraph(options),
1852 options.session_options != nullptr &&
1853 options.session_options->config.graph_options()
1854 .optimizer_options().cpu_global_jit()}
1855 .Run();
1856 }
1857
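// Returns a process-wide fuel counter. The atomic is intentionally leaked and
// initialized exactly once, so only the `initial_value` passed on the first
// call takes effect; later calls return the same pointer regardless of their
// argument.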
1858 std::atomic<int64_t>* GetPointerToFuel(int64_t initial_value) {
1859 static std::atomic<int64_t>* fuel = [&]() {
1860 std::atomic<int64_t>* fuel = new std::atomic<int64_t>;
1861 *fuel = initial_value;
1862 return fuel;
1863 }();
1864
1865 return fuel;
1866 }
1867 } // anonymous namespace
1868
1869 Status MarkForCompilationPass::Run(
1870 const GraphOptimizationPassOptions& options) {
1871 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1872
1873 MarkForCompilationPassImpl::DebugOptions debug_options;
1874 debug_options.ignore_deadness_checks =
1875 flags->tf_xla_disable_deadness_safety_checks_for_debugging;
1876 debug_options.ignore_resource_variable_checks =
1877 flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1878 debug_options.ignore_xla_compile_attr = false;
1879 debug_options.deterministic_cluster_names =
1880 flags->tf_xla_deterministic_cluster_names;
1881 debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1882 debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1883 debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1884 debug_options.dump_graphs = flags->tf_xla_clustering_debug;
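// For reference, these flags normally arrive through TF_XLA_FLAGS, e.g. a
// hypothetical
//   TF_XLA_FLAGS="--tf_xla_max_cluster_size=50 --tf_xla_clustering_debug=true"
// would populate max_cluster_size and dump_graphs above.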
1885
1886 return MarkForCompilation(options, debug_options);
1887 }
1888
1889 Status MarkForCompilationPass::RunForTest(
1890 const GraphOptimizationPassOptions& options, bool disable_deadness_analysis,
1891 bool deterministic_cluster_names) {
1892 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1893
1894 MarkForCompilationPassImpl::DebugOptions debug_options;
1895 debug_options.ignore_deadness_checks = disable_deadness_analysis;
1896 debug_options.ignore_resource_variable_checks =
1897 flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1898 debug_options.ignore_xla_compile_attr = true;
1899 debug_options.deterministic_cluster_names = deterministic_cluster_names;
1900 debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1901 debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1902 debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1903 debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1904
1905 return MarkForCompilation(options, debug_options);
1906 }
1907
1908 absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
1909 // Table format: category name: {list of TF operations in that category}
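// The category names can also be used as shorthand values for
// --tf_xla_ops_to_cluster; e.g. passing a hypothetical "PW,RED" selects every
// op listed under those two categories.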
1910 static absl::flat_hash_map<string, std::vector<string>>* result =
1911 new absl::flat_hash_map<string, std::vector<string>>{
1912 // Unary
1913 {"PW",
1914 {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
1915 "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
1916 "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
1917 "Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round",
1918 "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
1919 "Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv",
1920 "Lgamma", "Digamma",
1921 // Binary
1922 "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
1923 "MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
1924 "BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
1925 "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
1926 "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
1927 "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
1928 "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
1929 "TanhGrad", "Pow", "SquaredDifference", "ApproximateEqual",
1930 // Others
1931 "AddN", "Bitcast", "Cast", "ClipByValue", "Const", "Empty",
1932 "Identity", "IdentityN", "Relu", "Relu6", "ReluGrad", "Relu6Grad",
1933 "LeakyReluGrad", "Elu", "EluGrad", "Selu", "SeluGrad", "Select",
1934 "SelectV2", "Transpose", "ConjugateTranspose",
1935 "_UnaryOpsComposition", "CollectiveReduceV2",
1936 "CollectiveAssignGroupV2",
1937 // The following 5 operations are converted to identity
1938 "PlaceholderWithDefault", "PreventGradient", "StopGradient",
1939 "Snapshot", "_EagerConst"}},
1940 // clang-format off
1941 {"RED",
1942 {"All", "Any", "Min", "Max", "Mean", "Prod", "Sum"}},
1943 // clang-format on
1944 {"PWRED",
1945 {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1946 "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1947 {"REDUCEWINDOW",
1948 {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1949 "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1950 {"REDUCEWINDOWPW", {"BiasAddGrad", "LRN", "LRNGrad"}},
1951 {"BN",
1952 {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
1953 "_FusedBatchNormEx", "FusedBatchNormGrad", "FusedBatchNormGradV2",
1954 "FusedBatchNormGradV3"}},
1955 {"SORT", {"TopKV2"}}, // XLA version much faster then TF version.
1956 {"MISC",
1957 // clang-format off
1958 {"ApproxTopK", "BroadcastTo", "ExpandDims", "Fill", "NoOp",
1959 "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze",
1960 "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/,
1961 "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2",
1962 "ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad",
1963 "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split",
1964 "SplitV", "StridedSlice", "StridedSliceGrad",
1965 "ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation",
1966 "Unpack", "DeviceIndex", "TensorStridedSliceUpdate", "XlaConcatND",
1967 "XlaSplitND",
1968 }}};
1969 // clang-format on
1970 return result;
1971 }
1972
1973 namespace testing {
1974 void ResetClusterSequenceNumber() {
1975 ClusterSequenceNumberGenerator::Global().Reset();
1976 }
1977
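// A snapshot of TF ops known to be compilable by XLA at the time this list was
// last updated; tests use it as a reference so that changes to the set of
// auto-clusterable ops are made deliberately.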
1978 absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
1979 absl::flat_hash_set<string> result{
1980 "AdjustContrastv2",
1981 "AdjustHue",
1982 "AdjustSaturation",
1983 "Asinh",
1984 "Assert",
1985 "AssignAddVariableOp",
1986 "AssignSubVariableOp",
1987 "AssignVariableOp",
1988 "AssignVariableXlaConcatND",
1989 "AvgPool",
1990 "AvgPool3D",
1991 "AvgPool3DGrad",
1992 "AvgPoolGrad",
1993 "BatchMatMul",
1994 "BatchMatMulV2",
1995 "BatchMatMulV3",
1996 "BatchToSpace",
1997 "BatchToSpaceND",
1998 "BesselI0e",
1999 "BesselI1e",
2000 "Betainc",
2001 "BiasAddV1",
2002 "Bincount",
2003 "Bucketize",
2004 "Case",
2005 "CheckNumerics",
2006 "Cholesky",
2007 "ControlTrigger",
2008 "Conv2D",
2009 "Conv2DBackpropFilter",
2010 "Conv2DBackpropInput",
2011 "Conv3D",
2012 "Conv3DBackpropFilterV2",
2013 "Conv3DBackpropInputV2",
2014 "Cross",
2015 "Cumprod",
2016 "Cumsum",
2017 "DenseBincount",
2018 "DataFormatDimMap",
2019 "DataFormatVecPermute",
2020 "DepthToSpace",
2021 "DepthwiseConv2dNative",
2022 "DepthwiseConv2dNativeBackpropFilter",
2023 "DepthwiseConv2dNativeBackpropInput",
2024 "Dequantize",
2025 "Diag",
2026 "DynamicStitch",
2027 "DynamicPartition",
2028 "Einsum",
2029 "EmptyTensorList",
2030 "EnsureShape",
2031 "ExtractImagePatches",
2032 "Igamma",
2033 "IgammaGradA",
2034 "RandomGammaGrad",
2035 "Igammac",
2036 "FFT",
2037 "FFT2D",
2038 "FFT3D",
2039 "FakeParam",
2040 "FakeQuantWithMinMaxArgs",
2041 "FakeQuantWithMinMaxArgsGradient",
2042 "FakeQuantWithMinMaxVars",
2043 "FakeQuantWithMinMaxVarsGradient",
2044 "FakeQuantWithMinMaxVarsPerChannel",
2045 "FakeQuantWithMinMaxVarsPerChannelGradient",
2046 "Gather",
2047 "GatherNd",
2048 "GatherV2",
2049 "HSVToRGB",
2050 "IFFT",
2051 "IFFT2D",
2052 "IFFT3D",
2053 "IRFFT",
2054 "IRFFT2D",
2055 "IRFFT3D",
2056 "If",
2057 "InTopKV2",
2058 "L2Loss",
2059 "LeakyRelu",
2060 "LinSpace",
2061 "ListDiff",
2062 "LogMatrixDeterminant",
2063 "LowerBound",
2064 "MatMul",
2065 "MatrixBandPart",
2066 "MatrixDiag",
2067 "MatrixDiagPart",
2068 "MatrixDiagPartV2",
2069 "MatrixDiagPartV3",
2070 "MatrixDiagV2",
2071 "MatrixDiagV3",
2072 "MatrixInverse",
2073 "MatrixSetDiag",
2074 "MatrixSetDiagV2",
2075 "MatrixSetDiagV3",
2076 "MatrixSolve",
2077 "MatrixTriangularSolve",
2078 "MaxPool",
2079 "MaxPool3D",
2080 "MaxPool3DGrad",
2081 "MaxPool3DGradGrad",
2082 "MaxPoolGrad",
2083 "MaxPoolGradGrad",
2084 "MaxPoolGradGradV2",
2085 "MaxPoolGradV2",
2086 "MaxPoolV2",
2087 "Multinomial",
2088 "NextAfter",
2089 "NonMaxSuppressionV3",
2090 "NonMaxSuppressionV4",
2091 "ParallelDynamicStitch",
2092 "ParameterizedTruncatedNormal",
2093 "PartitionedCall",
2094 "Polygamma",
2095 "PopulationCount",
2096 "Qr",
2097 "QuantizeAndDequantizeV2",
2098 "QuantizeAndDequantizeV3",
2099 "QuantizeAndDequantizeV4",
2100 "RFFT",
2101 "RFFT2D",
2102 "RFFT3D",
2103 "RGBToHSV",
2104 "RandomShuffle",
2105 "RandomStandardNormal",
2106 "RandomUniform",
2107 "RandomUniformInt",
2108 "ReadVariableOp",
2109 "ReadVariableXlaSplitND",
2110 "ResizeBilinear",
2111 "ResizeBilinearGrad",
2112 "ResizeNearestNeighbor",
2113 "ResourceApplyAdaMax",
2114 "ResourceApplyAdadelta",
2115 "ResourceApplyAdagrad",
2116 "ResourceApplyAdagradDA",
2117 "ResourceApplyAdagradV2",
2118 "ResourceApplyAdam",
2119 "ResourceApplyAddSign",
2120 "ResourceApplyCenteredRMSProp",
2121 "ResourceApplyFtrl",
2122 "ResourceApplyFtrlV2",
2123 "ResourceApplyGradientDescent",
2124 "ResourceApplyKerasMomentum",
2125 "ResourceApplyMomentum",
2126 "ResourceApplyPowerSign",
2127 "ResourceApplyProximalAdagrad",
2128 "ResourceApplyProximalGradientDescent",
2129 "ResourceApplyRMSProp",
2130 "ResourceGather",
2131 "ResourceScatterAdd",
2132 "ResourceScatterDiv",
2133 "ResourceScatterMax",
2134 "ResourceScatterMin",
2135 "ResourceScatterMul",
2136 "ResourceScatterNdAdd",
2137 "ResourceScatterNdSub",
2138 "ResourceScatterNdUpdate",
2139 "ResourceScatterSub",
2140 "ResourceScatterUpdate",
2141 "RngReadAndSkip",
2142 "RngSkip",
2143 "Roll",
2144 "ScatterNd",
2145 "SelfAdjointEigV2",
2146 "SoftmaxCrossEntropyWithLogits",
2147 "SpaceToBatch",
2148 "SpaceToBatchND",
2149 "SpaceToDepth",
2150 "SparseMatMul",
2151 "SparseToDense",
2152 "StackCloseV2",
2153 "StackPopV2",
2154 "StackPushV2",
2155 "StackV2",
2156 "StatefulPartitionedCall",
2157 "StatefulStandardNormalV2",
2158 "StatefulTruncatedNormal",
2159 "StatefulUniform",
2160 "StatefulUniformFullInt",
2161 "StatefulUniformInt",
2162 "StatelessCase",
2163 "StatelessIf",
2164 "StatelessMultinomial",
2165 "StatelessParameterizedTruncatedNormal",
2166 "StatelessRandomGetAlg",
2167 "StatelessRandomGetKeyCounter",
2168 "StatelessRandomGetKeyCounterAlg",
2169 "StatelessRandomNormal",
2170 "StatelessRandomNormalV2",
2171 "StatelessRandomUniform",
2172 "StatelessRandomUniformV2",
2173 "StatelessRandomUniformInt",
2174 "StatelessRandomUniformIntV2",
2175 "StatelessRandomUniformFullInt",
2176 "StatelessRandomUniformFullIntV2",
2177 "StatelessTruncatedNormal",
2178 "StatelessTruncatedNormalV2",
2179 "StatelessWhile",
2180 "Svd",
2181 "SymbolicGradient",
2182 "TensorArrayCloseV3",
2183 "TensorArrayConcatV3",
2184 "TensorArrayGatherV3",
2185 "TensorArrayGradV3",
2186 "TensorArrayReadV3",
2187 "TensorArrayScatterV3",
2188 "TensorArraySizeV3",
2189 "TensorArraySplitV3",
2190 "TensorArrayV3",
2191 "TensorArrayWriteV3",
2192 "TensorListConcatV2",
2193 "TensorListElementShape",
2194 "TensorListFromTensor",
2195 "TensorListGather",
2196 "TensorListGetItem",
2197 "TensorListLength",
2198 "TensorListPopBack",
2199 "TensorListPushBack",
2200 "TensorListReserve",
2201 "TensorListSetItem",
2202 "TensorListSplit",
2203 "TensorListStack",
2204 "TensorScatterAdd",
2205 "TensorScatterMax",
2206 "TensorScatterMin",
2207 "TensorScatterSub",
2208 "TensorScatterUpdate",
2209 "ToBool",
2210 "TridiagonalSolve",
2211 "TridiagonalMatMul",
2212 "TruncatedNormal",
2213 "Unique",
2214 "UniqueV2",
2215 "UpperBound",
2216 "UnsortedSegmentMax",
2217 "UnsortedSegmentMin",
2218 "UnsortedSegmentProd",
2219 "UnsortedSegmentSum",
2220 "VarIsInitializedOp",
2221 "VariableShape",
2222 "Where",
2223 "While",
2224 "XlaBroadcastHelper",
2225 "XlaCallModule",
2226 "XlaConcatND",
2227 "XlaConv",
2228 "XlaConvV2",
2229 "XlaCustomCall",
2230 "XlaDequantize",
2231 "XlaDot",
2232 "XlaDotV2",
2233 "XlaDynamicSlice",
2234 "XlaDynamicUpdateSlice",
2235 "XlaEinsum",
2236 "XlaGather",
2237 "XlaIf",
2238 "XlaKeyValueSort",
2239 "XlaOptimizationBarrier",
2240 "XlaPad",
2241 "XlaRecv",
2242 "XlaReduce",
2243 "XlaReducePrecision",
2244 "XlaReduceWindow",
2245 "XlaRemoveDynamicDimensionSize",
2246 "XlaReplicaId",
2247 "XlaRngBitGenerator",
2248 "XlaScatter",
2249 "XlaSelectAndScatter",
2250 "XlaSelfAdjointEig",
2251 "XlaSend",
2252 "XlaSetBound",
2253 "XlaSetDynamicDimensionSize",
2254 "XlaSharding",
2255 "XlaSort",
2256 "XlaSplitND",
2257 "XlaSpmdFullToShardShape",
2258 "XlaSpmdShardToFullShape",
2259 "XlaSvd",
2260 "XlaVariadicReduce",
2261 "XlaVariadicReduceV2",
2262 "XlaVariadicSort",
2263 "XlaWhile",
2264 "Zeta",
2265 "_Arg",
2266 "_ArrayToList",
2267 "_ListToArray",
2268 "_Retval"};
2269 return result;
2270 }
2271
2272 } // namespace testing
2273 } // namespace tensorflow
2274