1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
17 
18 #include <algorithm>
19 #include <atomic>
20 #include <deque>
21 #include <limits>
22 #include <unordered_map>
23 #include <unordered_set>
24 
25 #include "absl/base/call_once.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/compiler/jit/compilability_check_util.h"
30 #include "tensorflow/compiler/jit/deadness_analysis.h"
31 #include "tensorflow/compiler/jit/defs.h"
32 #include "tensorflow/compiler/jit/device_util.h"
33 #include "tensorflow/compiler/jit/flags.h"
34 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
35 #include "tensorflow/compiler/jit/xla_cluster_util.h"
36 #include "tensorflow/compiler/tf2xla/const_analysis.h"
37 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
38 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
39 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
40 #include "tensorflow/compiler/xla/statusor.h"
41 #include "tensorflow/compiler/xla/union_find.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/common_runtime/function.h"
44 #include "tensorflow/core/common_runtime/graph_constructor.h"
45 #include "tensorflow/core/framework/bounds_check.h"
46 #include "tensorflow/core/framework/graph_def_util.h"
47 #include "tensorflow/core/framework/memory_types.h"
48 #include "tensorflow/core/framework/node_def.pb.h"
49 #include "tensorflow/core/framework/op_kernel.h"
50 #include "tensorflow/core/framework/tensor.pb.h"
51 #include "tensorflow/core/framework/types.h"
52 #include "tensorflow/core/graph/algorithm.h"
53 #include "tensorflow/core/graph/control_flow.h"
54 #include "tensorflow/core/lib/gtl/cleanup.h"
55 #include "tensorflow/core/lib/gtl/flatmap.h"
56 #include "tensorflow/core/lib/strings/stringprintf.h"
57 #include "tensorflow/core/platform/errors.h"
58 #include "tensorflow/core/platform/mutex.h"
59 #include "tensorflow/core/platform/statusor.h"
60 #include "tensorflow/core/platform/types.h"
61 #include "tensorflow/core/public/version.h"
62 #include "tensorflow/core/util/dump_graph.h"
63 
64 namespace tensorflow {
65 
66 namespace {
67 using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
68 using jit::DeviceId;
69 using jit::DeviceSet;
70 
71 // The clusters we create here are eventually lowered into an
72 // _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the
73 // PartitionedCall op to execute the cluster in the regular graph executor if
74 // need be.  PartitionedCall, however, reruns the entire TF graph optimization
75 // pipeline over the cluster which includes this mark for compilation pass.  To
76 // avoid endlessly recursing we tag nodes that we've already visited with this
77 // attribute so that we can bail out if we see them a second time.
78 //
79 // TODO(sanjoy): This method is not robust since it is possible that the
80 // optimizations run by PartitionedCall can mutate the cluster arbitrarily,
81 // dropping the kXlaAlreadyClustered attributes from all nodes in the process.
82 // The correct fix is to use the ConfigProto to pass in some sort of flag into
83 // the PartitionedCall kernel that tells it to not rerun auto-clustering on the
84 // cluster.
85 const char* kXlaAlreadyClustered = "_XlaAlreadyClustered";
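// Illustrative sketch (not taken from a real graph dump): after this pass has
// visited a node, its NodeDef carries the marker roughly as
//
//   attr { key: "_XlaAlreadyClustered" value { b: true } }
//
// so a nested run of this pass (via PartitionedCall) can bail out on it.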
86 
87 class MarkForCompilationPassImpl {
88  public:
89   struct DebugOptions {
90     // If true, do not respect the results of deadness analysis.
91     bool ignore_deadness_checks;
92 
93     // If true, do not do safety checks to preserve TensorFlow's resource
94     // variable concurrency semantics.
95     bool ignore_resource_variable_checks;
96 
97     // If true, do not respect the _XlaCompile=false attribute.
98     bool ignore_xla_compile_attr;
99 
100     // If true, compute the cluster name in a deterministic way so that it is
101     // stable from run to run.
102     bool deterministic_cluster_names;
103 
104     int max_cluster_size;
105     int min_cluster_size;
106 
107     // Compiler fuel for the auto-clustering algorithm.
108     //
109     // We decrement this value by one every time we choose a compilation
110     // candidate and we stop clustering when it hits zero.  This means the
111     // initial value for this variable (via --tf_xla_clustering_fuel=N)
112     // effectively acts as a "cap" for how much we cluster and we can bisect
113     // over this initial value to discover clustering decisions that cause a
114     // miscompile or a performance regression.
115     std::atomic<int64_t>* fuel;
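    // For example (a sketch; exact flag plumbing depends on how TensorFlow is
    // invoked): running with TF_XLA_FLAGS="--tf_xla_clustering_fuel=100"
    // limits this pass to at most 100 clustering decisions, and bisecting on
    // that number can isolate a single bad decision.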
116 
117     bool dump_graphs;
118   };
119 
120   MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph,
121                              FunctionLibraryDefinition* flib_def, Env* env,
122                              OptimizerOptions::GlobalJitLevel global_jit_level,
123                              bool cpu_global_jit)
124       : debug_options_(debug_options),
125         graph_(graph),
126         graph_fingerprint_(0),
127         flib_def_(flib_def),
128         env_(env),
129         global_jit_level_(global_jit_level),
130         cpu_global_jit_(cpu_global_jit) {}
131 
132   Status Run();
133 
134  private:
135   // Represents a "cluster" or a connected subgraph of a TensorFlow graph.
136   class Cluster {
137    public:
138     // Constructs a trivial cluster representing a single TF node.
139     Cluster(int tf_graph_node_id, int effective_cluster_size,
140             bool has_functional_control_flow, DeviceSet devices,
141             std::optional<DeviceId> resource_op_device,
142             std::optional<int> resource_var_operation_node_id,
143             std::optional<DeadnessPredicate> deadness_predicate,
144             bool is_xla_compile_attr_true, std::optional<string> xla_scope)
145         : cycles_graph_node_id_(tf_graph_node_id),
146           effective_cluster_size_(effective_cluster_size),
147           has_functional_control_flow_(has_functional_control_flow),
148           devices_(std::move(devices)),
149           resource_op_device_(resource_op_device),
150           deadness_predicate_(deadness_predicate),
151           is_xla_compile_attr_true_(is_xla_compile_attr_true),
152           xla_scope_(std::move(xla_scope)) {
153       if (resource_var_operation_node_id.has_value()) {
154         resource_var_operation_node_ids_.push_back(
155             *resource_var_operation_node_id);
156       }
157     }
158 
159     // Merges `other` into this cluster, and clears `other`.  This method is
160     // closely tied with the implementation of `MarkForCompilationPassImpl`.
161     void Merge(Cluster* other);
162 
163     // If this is a trivial cluster containing only one node then return the ID
164     // of that node.  May not be called otherwise.
165     int GetIdOfOnlyNode() const {
166       DCHECK_EQ(cluster_size(), 1);
167       return cycles_graph_node_id();
168     }
169 
170     // The number of TF nodes in this cluster.
171     int cluster_size() const { return cluster_size_; }
172 
173     // The ID of the cluster as represented in `cycles_graph_`.
174     int cycles_graph_node_id() const { return cycles_graph_node_id_; }
175 
176     // Sets the ID of the cluster as represented in `cycles_graph_`.
177     void set_cycles_graph_node_id(int cycles_graph_node_id) {
178       cycles_graph_node_id_ = cycles_graph_node_id;
179     }
180 
181     // The size of the cluster excluding constant and identity nodes.
182     int effective_cluster_size() const { return effective_cluster_size_; }
183 
184     // True if the cluster has functional control flow like `If` and `While`.
185     bool has_functional_control_flow() const {
186       return has_functional_control_flow_;
187     }
188 
189     // The set of devices that nodes in the cluster are placed on.
190     const DeviceSet& devices() const { return devices_; }
191 
192     // If the cluster has a resource operation then the device the resource
193     // operation is placed on.  A cluster may have resource ops placed only on a
194     // single device.
195     const std::optional<DeviceId>& resource_op_device() const {
196       return resource_op_device_;
197     }
198 
199     // If not nullopt, a predicate that is true iff the cluster is alive.
200     // Otherwise the user has (unsafely) disabled deadness analysis.  If this is
201     // unset on a single Cluster instance then it is unset on all Cluster
202     // instances.
203     const std::optional<DeadnessPredicate>& deadness_predicate() const {
204       return deadness_predicate_;
205     }
206 
207     // If true then the cluster has an _XlaCompile=true attribute on one of its
208     // nodes.
209     bool is_xla_compile_attr_true() const { return is_xla_compile_attr_true_; }
210 
211     // If not nullopt then all nodes in the cluster either do not have the
212     // XlaScope attribute set or have it set to the value returned.
213     const std::optional<string>& xla_scope() const { return xla_scope_; }
214 
215     // Returns the TF graph node IDs for the resource variable operations in
216     // this cluster.
217     absl::Span<const int> resource_var_operation_node_ids() const {
218       return resource_var_operation_node_ids_;
219     }
220 
221     string DebugString(const Graph& graph) const {
222       Node* node = graph.FindNodeId(cycles_graph_node_id());
223       if (!node) {
224         // This should never happen but we try to be resilient because this is a
225         // debugging aid.
226         return absl::StrCat("NULL NODE IN #", cycles_graph_node_id());
227       }
228 
229       if (cluster_size() == 1) {
230         return absl::StrCat("<", node->name(), " #", cycles_graph_node_id(),
231                             ">");
232       }
233 
234       return absl::StrCat("<", node->name(), " + ", cluster_size() - 1,
235                           " others #", cycles_graph_node_id(), ">");
236     }
237 
238    private:
239     int cluster_size_ = 1;
240     int cycles_graph_node_id_;
241     int effective_cluster_size_;
242     bool has_functional_control_flow_;
243     DeviceSet devices_;
244     std::optional<DeviceId> resource_op_device_;
245     std::optional<DeadnessPredicate> deadness_predicate_;
246     bool is_xla_compile_attr_true_;
247     std::optional<string> xla_scope_;
248     std::vector<int> resource_var_operation_node_ids_;
249 
250     TF_DISALLOW_COPY_AND_ASSIGN(Cluster);
251   };
252 
253   // If `cluster` has only a single node then returns that, otherwise returns
254   // nullptr.
255   Node* GetOnlyNodeIn(const Cluster& cluster);
256 
257   // Returns true if `cluster` is a trivial cluster containing a "sink like"
258   // node -- a NoOp node that only the Sink node control depends on.
259   bool IsSinkLike(const Cluster& cluster);
260 
261   // Returns true if `cluster` looks like an "i++" operation on an integer
262   // scalar resource variable.
263   bool IsScalarIntegerResourceOperation(const Cluster& cluster);
264 
265   // ---------------------------------------------------------------------------
266   // The pass proceeds in five steps, out of which `RunEdgeContractionLoop` and
267   // `CreateClusters` do most of the heavy lifting.
268 
269   // Initializes some internal data structures.
270   //
271   // If this returns false then Initialize exited early (either because there is
272   // nothing to do or we saw a graph that we can't handle) and not all the
273   // fields in this MarkForCompilationPassImpl instance are set up.
274   StatusOr<bool> Initialize();
275 
276   // Runs through the entire cluster graph in post-order and calls `fn(from,
277   // to)` on each edge.  `fn(from, to)` is expected to return true if it was
278   // able to contract `from`->`to`.
279   //
280   // Returns true if `fn` returned true for any edge.
281   template <typename FnTy>
282   StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn);
283 
284   // Contracts as many edges as possible to create XLA clusters.  After this
285   // finishes the clustering decisions made are implicitly stored in
286   // `clusters_`.
287   Status RunEdgeContractionLoop();
288 
289   // "Fixes up" clusters by removing some modes.
290   //
291   // Autoclustering can sometimes be overeager.  For example, clustering large
292   // constants (or large broadcasts of constants) can increase the live range
293   // of those constants, and increase overall memory usage.
294   //
295   // This function removes "obviously bad" cases like these.
296   Status DeclusterNodes();
297 
298   // Manifests the clustering decisions into the TF graph by tagging nodes with
299   // an `_XlaCluster` attribute.  Also some basic filter logic, like
300   // tf_xla_min_cluster_size, are applied here.
301   Status CreateClusters();
302 
303   Status DumpDebugInfo();
304 
305   bool IsCompilationCandidate(Node* n) const {
306     return compilation_candidates_.find(n) != compilation_candidates_.end();
307   }
308 
309   // Tries to contract the edge from cluster `from` to cluster `to`.  Returns
310   // true if successful.
311   StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);
312 
313   // Nodes that XLA can compile are put in `compilation_candidates_`.
314   Status FindCompilationCandidates();
315 
316   bool CompilationDisallowedByXlaCompileAttr(Node* node);
317 
318   // Populates `clusters_`.
319   Status BuildInitialClusterSet();
320 
321   StatusOr<bool> ShouldCompileClusterImpl(const Cluster& cluster);
322 
323   StatusOr<bool> ShouldCompileCluster(const Cluster& cluster);
324 
325   StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
326       const Cluster& from, const Cluster& to);
327 
328   // Returns true if the devices in `cluster_a` and `cluster_b` are compatible
329   // and therefore not a hindrance for combining the two clusters into a larger
330   // cluster.
331   StatusOr<bool> AreDevicesCompatible(const Cluster& cluster_a,
332                                       const Cluster& cluster_b);
333 
334   void DumpPostClusteringGraphs();
335   void VLogClusteringSummary();
336 
337   Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
338                           bool has_functional_control_flow,
339                           const DeviceSet& device_set,
340                           std::optional<DeviceId> resource_op_device,
341                           std::optional<int> resource_var_operation_node_id,
342                           std::optional<DeadnessPredicate> deadness_predicate,
343                           bool is_xla_compile_attr_true,
344                           std::optional<string> xla_scope) {
345     cluster_storage_.push_back(std::make_unique<Cluster>(
346         cycles_graph_node_id, effective_cluster_size,
347         has_functional_control_flow, device_set, resource_op_device,
348         resource_var_operation_node_id, deadness_predicate,
349         is_xla_compile_attr_true, xla_scope));
350     return cluster_storage_.back().get();
351   }
352 
353   std::optional<string> GetXlaScope(Node* n);
354 
355   // Returns the cluster for node `n`.  If two nodes, N1 and N2, are placed in
356   // the same cluster by the clustering algorithm then this function will return
357   // the same Cluster instance for N1 and N2.
358   //
359   // Returns nullptr if `n` is not a compilation candidate.
360   Cluster* GetClusterForNode(Node* n) {
361     return cluster_for_node_[n->id()].Get();
362   }
363 
364   // Returns the cluster for a node in `cycles_graph_`.  This uses the same
365   // underlying map because of how we set things up, but we can do an additional
366   // CHECK in this accessor.
367   //
368   // Returns nullptr if `node_id` is not a compilation candidate.
369   Cluster* GetClusterForCyclesGraphNode(int node_id) {
370     // We have to check `graph_->FindNodeId(node) == nullptr` because we add all
371     // nodes in [0, graph_->num_node_ids()) to the cycle detection graph but the
372     // TF graph may be missing some node ids.
373     if (node_id >= graph_->num_node_ids() ||
374         graph_->FindNodeId(node_id) == nullptr) {
375       return nullptr;
376     }
377     Cluster* cluster = cluster_for_node_[node_id].Get();
378     if (cluster) {
379       DCHECK_EQ(cluster->cycles_graph_node_id(), node_id);
380     }
381     return cluster;
382   }
383 
384   bool LogNotContractableAndReturnFalse(Cluster* from, Cluster* to,
385                                         absl::string_view reason);
386 
387   // Finds a path in `cycles_graph_` from `from` to `to` that is not a direct
388   // edge from `from` to `to`.
389   //
390   // Tries to find a path that contains at least one unclusterable node.
391   std::vector<int> FindAlternatePathForDebugging(int from, int to);
392 
393   // Returns a string representing `cycles_graph_node_id`.  If the node is
394   // unclusterable (either it is a phantom "frame" node or is not a compilation
395   // candidate) then set `*found_unclustered` to true.
396   string DebugStringForCyclesGraphNode(int node_id, bool* found_unclustered);
397 
398   // We could not contract the edge from `from` to `to`.  Return a string
399   // describing an alternate path from `from` to `to` (besides the direct edge
400   // from `from` to `to`) which would have created a cycle had we contracted the
401   // edge.
402   //
403   // Tries (if possible) to find a path that contains at least one unclusterable
404   // node as it is surprising to the user if we print "A->B could not be
405   // contracted because of the path [P,Q,R]" where P, Q and R are all clusters
406   // since in that case a natural question is why we could not form a {A, P, Q,
407   // R, B} cluster.
408   string DescribePotentialCycle(int from, int to);
409 
410   // Merge the clusters `cluster_from` and `cluster_to`. After this step the
411   // larger combined cluster is represented by `cluster_from`, but can have
412   // `cycles_graph_`'s ID of either `cluster_from` or `cluster_to` depending on
413   // which way requires fewer operations.
414   bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
415     int from = cluster_from->cycles_graph_node_id();
416     int to = cluster_to->cycles_graph_node_id();
417 
418     auto optional_merged_node = cycles_graph_.ContractEdge(from, to);
419     if (!optional_merged_node.has_value()) {
420       VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_)
421               << " -> " << cluster_to->DebugString(*graph_)
422               << " because contracting the edge would create a cycle via "
423               << DescribePotentialCycle(from, to) << ".";
424       return false;
425     }
426 
427     // Merge the clusters.
428     cluster_from->Merge(cluster_to);
429     // Update `cycle_graph_`'s ID.
430     cluster_from->set_cycles_graph_node_id(optional_merged_node.value());
431 
432     // Merge the UnionFind<Cluster*>.
433     cluster_for_node_[from].Merge(&cluster_for_node_[to]);
434 
435     return true;
436   }
437 
438   string EdgeContractionFailureMsg(Cluster* from, Cluster* to,
439                                    absl::string_view reason) {
440     return absl::StrCat("Could not contract ", from->DebugString(*graph_),
441                         " -> ", to->DebugString(*graph_), " because ", reason,
442                         ".");
443   }
444 
445   DebugOptions debug_options_;
446   Graph* graph_;
447   uint64 graph_fingerprint_;
448   FunctionLibraryDefinition* flib_def_;
449   Env* env_;
450   OptimizerOptions::GlobalJitLevel global_jit_level_;
451   bool cpu_global_jit_;
452   absl::flat_hash_map<const Cluster*, bool> should_compile_cluster_cache_;
453   jit::DeviceInfoCache device_info_cache_;
454 
455   bool initialized_ = false;
456   bool edges_contracted_ = false;
457   bool clusters_created_ = false;
458 
459   std::vector<std::unique_ptr<Cluster>> cluster_storage_;
460   std::vector<UnionFind<Cluster*>> cluster_for_node_;
461   absl::flat_hash_set<const Node*> declustered_nodes_;
462   GraphCycles cycles_graph_;
463   OrderedNodeSet compilation_candidates_;
464   std::unique_ptr<DeadnessAnalysis> deadness_analysis_;
465   int64_t iteration_count_ = 0;
466   absl::flat_hash_set<std::pair<int, int>> unsafe_resource_deps_;
467 };
468 
469 std::vector<int> MarkForCompilationPassImpl::FindAlternatePathForDebugging(
470     int from, int to) {
471   std::vector<int> rpo = cycles_graph_.AllNodesInPostOrder();
472   absl::c_reverse(rpo);
473 
474   // best_pred_for_node[n] contains a predecessor of `n` that has an
475   // unclusterable node in some path from `from` to itself.
476   // best_pred_for_node[n] is unpopulated for nodes that are not reachable from
477   // `from`.  We build this table up inductively by traversing the cycles graph
478   // in RPO.
479   absl::flat_hash_map<int, int> best_pred_for_node;
480   best_pred_for_node[from] = -1;
481 
482   int rpo_index = 0, current_rpo_node;
483   do {
484     current_rpo_node = rpo[rpo_index++];
485     std::optional<int> some_pred, preferred_pred;
486     for (int pred : cycles_graph_.Predecessors(current_rpo_node)) {
487       if (!best_pred_for_node.contains(pred)) {
488         continue;
489       }
490 
491       // Ignore the from->to edge since we're trying to find an alternate path.
492       if (current_rpo_node == to && pred == from) {
493         continue;
494       }
495 
496       some_pred = pred;
497       if (GetClusterForCyclesGraphNode(pred) == nullptr) {
498         preferred_pred = pred;
499       }
500     }
501 
502     if (some_pred || preferred_pred) {
503       best_pred_for_node[current_rpo_node] =
504           preferred_pred.has_value() ? *preferred_pred : *some_pred;
505     }
506   } while (current_rpo_node != to);
507 
508   auto get_best_pred = [&](int n) {
509     auto it = best_pred_for_node.find(n);
510     CHECK(it != best_pred_for_node.end());
511     return it->second;
512   };
513 
514   std::vector<int> path;
515   int current_path_node = get_best_pred(to);
516   while (current_path_node != from) {
517     path.push_back(current_path_node);
518     current_path_node = get_best_pred(current_path_node);
519   }
520 
521   absl::c_reverse(path);
522   return path;
523 }
524 
525 string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode(
526     int cycles_graph_node_id, bool* found_unclustered) {
527   Cluster* cluster = GetClusterForCyclesGraphNode(cycles_graph_node_id);
528   if (cluster) {
529     return cluster->DebugString(*graph_);
530   }
531 
532   *found_unclustered = true;
533   if (cycles_graph_node_id >= graph_->num_node_ids()) {
534     return absl::StrCat("<oob #", cycles_graph_node_id, ">");
535   }
536 
537   Node* node = graph_->FindNodeId(cycles_graph_node_id);
538   if (!node) {
539     return absl::StrCat("<bad #", cycles_graph_node_id, ">");
540   }
541 
542   return node->name();
543 }
544 
545 string MarkForCompilationPassImpl::DescribePotentialCycle(int from, int to) {
546   std::vector<string> path_str;
547   bool found_unclustered = false;
548   absl::c_transform(FindAlternatePathForDebugging(from, to),
549                     std::back_inserter(path_str), [&](int node_id) {
550                       return DebugStringForCyclesGraphNode(node_id,
551                                                            &found_unclustered);
552                     });
553   return absl::StrCat(!found_unclustered ? "(all clusters) " : "", "[",
554                       absl::StrJoin(path_str, ","), "]");
555 }
556 
557 void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
558   // We keep our own cycles_graph_node_id_ to mirror what GraphCycles does.
559 
560   // Clearing out data structures in `other` is just a memory saving
561   // optimization and not needed for correctness.
562 
563   cluster_size_ += other->cluster_size_;
564   effective_cluster_size_ += other->effective_cluster_size_;
565   has_functional_control_flow_ |= other->has_functional_control_flow_;
566 
567   devices_.UnionWith(other->devices_);
568 
569   DCHECK(!(resource_op_device_.has_value() &&
570            other->resource_op_device_.has_value()) ||
571          *resource_op_device_ == *other->resource_op_device_)
572       << "AreDevicesCompatible should have returned false otherwise!";
573 
574   if (!resource_op_device_.has_value()) {
575     resource_op_device_ = other->resource_op_device_;
576   }
577 
578   is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
579 
580   if (!xla_scope_.has_value()) {
581     xla_scope_ = std::move(other->xla_scope_);
582   }
583 
584   resource_var_operation_node_ids_.reserve(
585       resource_var_operation_node_ids_.size() +
586       other->resource_var_operation_node_ids_.size());
587   absl::c_copy(other->resource_var_operation_node_ids_,
588                std::back_inserter(resource_var_operation_node_ids_));
589   other->resource_var_operation_node_ids_.clear();
590 }
591 
592 Status IgnoreResourceOpForSafetyAnalysis(
593     jit::DeviceInfoCache* device_info_cache, const Node& n, bool* ignore) {
594   // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
595   // ignore it during resource operation safety analysis.  We need this hack
596   // because of two reasons:
597   //
598   //  1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
599   //  2. We don't support live-out values of type DT_RESOURCE and live-in values
600   //     of type DT_RESOURCE that are not resource variables.
601   //
602   // Together these imply we cannot let resource variable safety analysis
603   // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
604   // clusters: both of them will have to be clustered because of (1) and we
605   // won't be able to keep the edge between the two as neither the input to the
606   // second XLA cluster nor the output from the first XLA cluster are supported
607   // because of (2).
608   //
609   // TODO(b/113100872): This can be fixed if the TensorFlow representation for
610   // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
611   // (2) would no longer hold.
612 
613   if (n.assigned_device_name().empty()) {
614     *ignore = false;
615     return OkStatus();
616   }
617 
618   TF_ASSIGN_OR_RETURN(
619       const XlaOpRegistry::DeviceRegistration* registration,
620       device_info_cache->GetCompilationDevice(n.assigned_device_name()));
621 
622   if (!registration) {
623     *ignore = true;
624   } else {
625     *ignore = registration->cluster_resource_variable_ops_unsafely;
626   }
627   return OkStatus();
628 }
629 
630 StatusOr<bool> MarkForCompilationPassImpl::Initialize() {
631   TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_);
632   initialized_ = true;
633 
634   TF_RETURN_IF_ERROR(FindCompilationCandidates());
635 
636   if (compilation_candidates_.empty()) {
637     VLOG(2) << "No compilable candidates";
638     return false;
639   }
640 
641   TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
642                       CreateCycleDetectionGraph(graph_, &cycles_graph_));
643   if (!cycle_detection_graph_ok) {
644     // TODO(sanjoy): This should be logged via the XLA activity listener.
645     VLOG(2) << "Could not form cycle detection graph";
646     return false;
647   }
648 
649   if (!debug_options_.ignore_deadness_checks) {
650     XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
651     TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis_));
652   }
653 
654   // If the user is requesting deterministic cluster names compute a hash of the
655   // input graph to provide a stable but unique prefix for the name.
656   if (debug_options_.deterministic_cluster_names) {
657     TF_ASSIGN_OR_RETURN(graph_fingerprint_, FingerprintGraph(*graph_));
658   }
659 
660   // Each compilation candidate belongs to a cluster. The cluster's
661   // representative names the node in the 'cycles' graph that represents the
662   // cluster.
663   TF_RETURN_IF_ERROR(BuildInitialClusterSet());
664   return true;
665 }
666 
667 template <typename FnTy>
668 StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) {
669   bool changed = false;
670   for (int32_t node : cycles_graph_.AllNodesInPostOrder()) {
671     Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
672     if (!cluster_from) {
673       continue;
674     }
675 
676     // Make a copy of the set of successors because we may modify the graph in
677     // TryToContractEdge.
678     std::vector<int32> successors_copy =
679         cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
680 
681     for (int to : successors_copy) {
682       iteration_count_++;
683 
684       Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
685       if (!cluster_to) {
686         continue;
687       }
688 
689       TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to));
690       changed |= contracted_edge;
691     }
692   }
693 
694   return changed;
695 }
696 
697 Node* MarkForCompilationPassImpl::GetOnlyNodeIn(const Cluster& cluster) {
698   return cluster.cluster_size() == 1
699              ? graph_->FindNodeId(cluster.GetIdOfOnlyNode())
700              : nullptr;
701 }
702 
703 bool MarkForCompilationPassImpl::IsSinkLike(const Cluster& cluster) {
704   if (Node* n = GetOnlyNodeIn(cluster)) {
705     return n->type_string() == "NoOp" && n->out_edges().size() == 1 &&
706            (*n->out_edges().begin())->dst()->IsSink();
707   }
708 
709   return false;
710 }
711 
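// A minimal sketch of the pattern matched below (node names are hypothetical):
//
//   digraph {
//     iteration_var -> AssignAddVariableOp   // integer resource variable
//     Const_1       -> AssignAddVariableOp   // scalar increment
//   }
//
// i.e. an "i++"-style update of a scalar integer resource variable.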
712 bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation(
713     const Cluster& cluster) {
714   Node* n = GetOnlyNodeIn(cluster);
715   if (!n) {
716     return false;
717   }
718 
719   if (n->type_string() != "AssignAddVariableOp" &&
720       n->type_string() != "AssignSubVariableOp") {
721     return false;
722   }
723 
724   DataType dtype;
725   if (!TryGetNodeAttr(n->def(), "dtype", &dtype) || !DataTypeIsInteger(dtype)) {
726     return false;
727   }
728 
729   Node* const_input = nullptr;
730   for (const Edge* e : n->in_edges()) {
731     if (!e->IsControlEdge() && e->src()->IsConstant()) {
732       const_input = e->src();
733       break;
734     }
735   }
736 
737   if (!const_input) {
738     return false;
739   }
740 
741   const TensorProto* proto = nullptr;
742   if (!TryGetNodeAttr(const_input->def(), "value", &proto)) {
743     return false;
744   }
745 
746   return TensorShapeUtils::IsScalar(proto->tensor_shape());
747 }
748 
749 Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
750   TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_);
751   edges_contracted_ = true;
752 
753   // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
754   // example, from the Grappler fusion pass).
755 
756   // In general there are multiple maximal clusterings, but they are not all
757   // equally performant.  Some clustering decisions are likely to improve
758   // performance much more than others, and we cannot order contractions on this
759   // cost function, nor can we look at global information while deciding on
760   // individual edges to contract.  Instead, we will make decisions on these
761   // important edges first and then on all other edges, giving the most
762   // important edges the highest chance of being contracted.
763   //
764   // An example of where this might occur is with a digraph:
765   // {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
766   // not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
767   // should be clustered with A because it will prevent a potentially large
768   // tensor from A being computed and copied.
769   //
770   // To choose better maximal clusterings we make multiple iterations over the
771   // graph in post-order, where each such iteration is called a "phase".
772 
773   // Phase 0: contract metadata operations with their producer.
774 
775   VLOG(4) << "Running phase 0";
776   TF_RETURN_IF_ERROR(
777       ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
778         // Shape consuming operations are desirable to cluster with their
779         // operands because they return a small set of scalar values after
780         // consuming a large amount of data.  For example, given a graph X -> Y
781         // -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or
782         // [{X, Y}, {Size, Z}], the better clustering is Size with Y because the
783         // output of size will be a small tensor while Y is a potentially large
784   // tensor that must be computed and possibly transposed/copied before
785         // the second cluster executes.
786         Node* n = GetOnlyNodeIn(*to);
787         bool is_shape_consumer_op = n && IsShapeConsumerOp(*n);
788         if (!is_shape_consumer_op) {
789           return false;
790         }
791 
792         return TryToContractEdge(from, to);
793       }).status());
794 
795   // Phase 1: apply a heuristic to ensure that we don't mess up clustering due
796   // to "group_deps".  After this phase most edges should have been contracted.
797 
798   VLOG(4) << "Running phase 1";
799   TF_RETURN_IF_ERROR(
800       ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
801         // We split out this phase to get good clustering in the presence of a
802         // specific pattern seen in some graphs:
803         //
804         // digraph {
805         //   ApplyWeightUpdates_0 -> "iteration++"
806         //   ApplyWeightUpdates_1 -> "iteration++"
807         //   ApplyWeightUpdates_2 -> "iteration++"
808         //   ApplyWeightUpdates_0 -> Computation_A
809         //   ApplyWeightUpdates_1 -> Computation_B
810         //   ApplyWeightUpdates_2 -> Computation_C
811         //   Computation_A -> NoOp
812         //   Computation_B -> NoOp
813         //   Computation_C -> NoOp
814         //   "iteration++" -> NoOp
815         // }
816         //
817         // In the graph above we can't cluster iteration++ with any of the
818         // gradient update operations since that will break the TF resource
819         // variable memory model.  Given that constraint the ideal clustering
820         // would be to put all the gradient updates and all of the Computation_*
821         // nodes in one cluster, and leave iteration++ and NoOp unclustered.
822         //
823         // A naive post-order traversal would not create this good clustering,
824         // however.  Instead it will first create a cluster that puts
825         // Computation_* nodes, the NoOp and iteration++ node in a single
826         // cluster, after which it will fail to put any of the
827         // ApplyWeightUpdates_* nodes into this cluster. To avoid this fate we
828         // instead run a pass that avoids contracting edges _into_ NoOps like
829         // the above, and avoid clustering edges _from_ "iteration++" like the
830         // above.  Then we run a second pass that contracts the edges we could
831         // not contract the first time around.
832 
833         if (IsSinkLike(*to)) {
834           return false;
835         }
836 
837         if (IsScalarIntegerResourceOperation(*from)) {
838           return false;
839         }
840 
841         return TryToContractEdge(from, to);
842       }).status());
843 
844   // Phase 2: contract any remaining edges.  After this phase we should have a
845   // maximal clustering:
846   //
847   // A. We visit a cluster only after maximally clustering all its children.
848   // B. By the time we're done with a node all of its children that could have
849   //    been absorbed into the node have been absorbed.
850   // C. We have an invariant that making a cluster larger does not make edges
851   //    leaving it more contractable. That is, if we have
852   //    digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
853   //    to contract Y->Z if Y->Z was not contractible originally.
854   VLOG(4) << "Running phase 2";
855   TF_RETURN_IF_ERROR(ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
856                        return TryToContractEdge(from, to);
857                      }).status());
858 
859   // Check that the conclusion made above (that iterating over the graph once in
860   // post order gives a maximal clustering) holds.  Once the linear time
861   // post-order scheme has been battle tested we can move this to happen only in
862   // debug builds.
863   VLOG(2) << "Checking idempotence";
864   TF_ASSIGN_OR_RETURN(bool changed,
865                       ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
866                         return TryToContractEdge(from, to);
867                       }));
868   TF_RET_CHECK(!changed);
869 
870   return OkStatus();
871 }
872 
873 Status MarkForCompilationPassImpl::DeclusterNodes() {
874   for (Node* n : compilation_candidates_) {
875     Cluster* cluster = GetClusterForNode(n);
876     if (cluster == nullptr) {
877       continue;
878     }
879 
880     // De-cluster Fill ops that are
881     //  - used at least once outside the cluster, and
882     //  - not used inside the cluster.
883     //
884     // In this case, using XLA for the op can only make peak memory usage worse.
885     // If we don't cluster the Fill, it can be materialized right before it's
886     // used in the TF graph.  Whereas if we do cluster it, the Fill must be live
887     // starting at the end of the XLA cluster, potentially significantly
888     // increasing its live range.
889     //
890     // See b/221997940 for a real-world example of this.
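    // Illustrative sketch (hypothetical node names): if the graph looks like
    //
    //   digraph {
    //     Shape -> Fill
    //     Fill  -> SomeUnclusteredConsumer
    //   }
    //
    // and every consumer of Fill ended up outside Fill's cluster, we decluster
    // the Fill so TF can materialize it right before its consumer runs.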
891     if (n->op_def().name() == "Fill" &&
892         n->out_nodes().begin() != n->out_nodes().end() &&
893         absl::c_all_of(n->out_nodes(), [&](Node* user) {
894           return GetClusterForNode(user) != cluster;
895         })) {
896       declustered_nodes_.insert(n);
897     }
898   }
899 
900   return OkStatus();
901 }
902 
903 // Tracks monotonic sequence numbers for graphs.
904 class ClusterSequenceNumberGenerator {
905  public:
906   void Reset() {
907     mutex_lock lock(mu_);
908     sequence_numbers_.clear();
909   }
910 
911   int64 GetNext(uint64 key) {
912     mutex_lock lock(mu_);
913     return sequence_numbers_[key]++;
914   }
915 
916   static ClusterSequenceNumberGenerator& Global() {
917     static ClusterSequenceNumberGenerator* gen =
918         new ClusterSequenceNumberGenerator;
919     return *gen;
920   }
921 
922  private:
923   mutex mu_;
924   absl::flat_hash_map<uint64, int64> sequence_numbers_;
925 };
926 
927 // Gets a monotonic sequence number for a graph identified by its `fingerprint`.
928 // The sequence number is necessary to disambiguate clusters extracted from the
929 // same graph and when duplicate graphs exist within the same process.
930 int64_t GetNextClusterSequenceNumber(uint64 fingerprint) {
931   return ClusterSequenceNumberGenerator::Global().GetNext(fingerprint);
932 }
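// As a sketch of the resulting names (see CreateClusters below): with
// deterministic cluster names requested they look like
// "cluster_<fingerprint>_0", "cluster_<fingerprint>_1", ...; otherwise they
// are simply "cluster_0", "cluster_1", ...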
933 
934 Status MarkForCompilationPassImpl::CreateClusters() {
935   TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
936   clusters_created_ = true;
937 
938   // Names for each cluster.
939   std::unordered_map<int, string> cluster_names;
940 
941   if (debug_options_.dump_graphs) {
942     DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
943   }
944 
945   // Mark clusters for compilation that:
946   // * are placed on a device that requires compilation (an XlaDevice),
947   // * are explicitly marked for compilation (_XlaCompile=true), or
948   // * have at least debug_options_.min_cluster_size elements (applicable
949   //   only if compilation is enabled, otherwise there will be no such
950   //   candidates).
951   for (Node* n : compilation_candidates_) {
952     Cluster* cluster = GetClusterForNode(n);
953     TF_ASSIGN_OR_RETURN(bool should_compile_cluster,
954                         ShouldCompileCluster(*cluster));
955     if (!should_compile_cluster || declustered_nodes_.contains(n)) {
956       continue;
957     }
958 
959     // We assume that functional If and While nodes have at least
960     // min_cluster_size non-trivial nodes in them.  It would be more principled
961     // to (recursively) verify this fact, but that's probably not worth the
962     // trouble.
963 
964     if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size ||
965         cluster->has_functional_control_flow() ||
966         cluster->is_xla_compile_attr_true()) {
967       string& name = cluster_names[cluster->cycles_graph_node_id()];
968 
969       if (name.empty()) {
970         if (debug_options_.deterministic_cluster_names) {
971           name = absl::StrCat("cluster_", graph_fingerprint_, "_",
972                               GetNextClusterSequenceNumber(graph_fingerprint_));
973         } else {
974           name = absl::StrCat("cluster_",
975                               GetNextClusterSequenceNumber(graph_fingerprint_));
976         }
977       }
978 
979       n->AddAttr(kXlaClusterAttr, name);
980       n->AddAttr(kXlaAlreadyClustered, true);
981       VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
982     }
983   }
984 
985   return OkStatus();
986 }
987 
988 Status MarkForCompilationPassImpl::DumpDebugInfo() {
989   TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_);
990 
991   if (debug_options_.dump_graphs) {
992     DumpPostClusteringGraphs();
993   }
994 
995   VLogClusteringSummary();
996 
997   return OkStatus();
998 }
999 
1000 StatusOr<bool>
1001 MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
1002     const Cluster& cluster_from, const Cluster& cluster_to) {
1003   // If any of the consumer's producers are on a different device, do not
1004   // cluster these nodes. This prevents other work on this device from being
1005   // delayed by work on other devices. We consider predecessors of the entire
1006   // cluster rather than just the inputs to the node to prevent the cluster
1007   // still being combined in cases where the 'to' cluster has multiple
1008   // dependencies on the 'from' cluster and another dependency leads to a
1009   // merging of the clusters.
1010   //
1011   // TODO(b/117085735): We probably want to handle the reciprocal of this case
1012   // where a cluster is producing data for multiple devices.
1013   for (const auto& in_id :
1014        cycles_graph_.Predecessors(cluster_to.cycles_graph_node_id())) {
1015     const Cluster* cluster_in = GetClusterForCyclesGraphNode(in_id);
1016     if (cluster_in) {
1017       TF_ASSIGN_OR_RETURN(bool devices_compatible,
1018                           AreDevicesCompatible(cluster_to, *cluster_in));
1019       if (!devices_compatible) {
1020         return true;
1021       }
1022       TF_ASSIGN_OR_RETURN(devices_compatible,
1023                           AreDevicesCompatible(cluster_from, *cluster_in));
1024       if (!devices_compatible) {
1025         return true;
1026       }
1027     }
1028   }
1029 
1030   return false;
1031 }
1032 
1033 std::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
1034   // Look for either _XlaScope or _XlaInternalScope on both nodes to guide
1035   // clustering.  If both nodes have a scope and the scopes do not match, do
1036   // not cluster along this edge.  If even one of the nodes lacks a scope
1037   // attribute, then it is treated as a "bridge" and a cluster may be created
1038   // along it.
1039   //
1040   // The difference between _XlaScope and _XlaInternalScope is that _XlaScope is
1041   // provided by users through jit_scope APIs, while _XlaInternalScope is
1042   // automatically generated by the ClusterScopingPass when auto_jit is on.  As
1043   // such, we respect _XlaScope only when auto_jit is off, while respecting
1044   // _XlaInternalScope only when auto_jit is on.
1045   //
1046   // We may want to restrict the _XlaScope behavior to require all nodes marked
1047   // with _XlaCompile=true to also have a _XlaScope property set (and raise an
1048   // error otherwise); but for now we don't do this.
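  // As a sketch (the scope value shown is hypothetical), a node carrying a
  // user-provided scope would have something like
  //
  //   attr { key: "_XlaScope" value { s: "jit_scope_0" } }
  //
  // in its NodeDef; kXlaInternalScopeAttr is the analogous attribute generated
  // automatically by the ClusterScopingPass.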
1049 
1050   if (global_jit_level_ != OptimizerOptions::OFF) {
1051     // If global_jit_level_ is ON, respect only _XlaInternalScope.
1052     const string& scope =
1053         GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr);
1054     if (!scope.empty()) {
1055       return scope;
1056     }
1057   } else {
1058     // If global_jit_level_ is OFF, respect only _XlaScope.
1059     const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
1060     if (!scope.empty()) {
1061       return scope;
1062     }
1063   }
1064 
1065   return std::nullopt;
1066 }
1067 
1068 // Returns true iff the attribute `attr_name` is attached to either the node or
1069 // to its callee.
1070 static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
1071                               const char* attr_name) {
1072   bool out = false;
1073   bool attr_value;
1074   if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
1075     out |= attr_value;
1076   }
1077 
1078   if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
1079     out |= attr_value;
1080   }
1081   return out;
1082 }
1083 
1084 Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
1085   auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
1086     return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
1087   };
1088 
1089   std::vector<std::pair<int, int>> unsafe_resource_deps_vect;
1090   TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
1091       *graph_, flib_def_, ignore_resource_ops, &unsafe_resource_deps_vect));
1092   absl::c_copy(
1093       unsafe_resource_deps_vect,
1094       std::inserter(unsafe_resource_deps_, unsafe_resource_deps_.begin()));
1095 
1096   cluster_for_node_.resize(graph_->num_node_ids());
1097   for (Node* node : graph_->nodes()) {
1098     if (!IsCompilationCandidate(node)) {
1099       cluster_for_node_[node->id()].Get() = nullptr;
1100       continue;
1101     }
1102 
1103     // We want clusters to be big enough that the benefit from XLA's
1104     // optimizations offsets XLA related overhead (for instance we add some
1105     // Switch/Merge nodes into the graph to implement lazy compilation).  To
1106     // this end, we don't count Identity and Constant nodes because they do not
1107     // enable interesting optimizations by themselves.
1108     int effective_cluster_size =
1109         (node->IsIdentity() || node->IsConstant()) ? 0 : 1;
1110 
1111     bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode();
1112 
1113     std::optional<DeadnessPredicate> deadness_predicate;
1114     if (deadness_analysis_) {
1115       TF_ASSIGN_OR_RETURN(
1116           deadness_predicate,
1117           deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
1118     }
1119 
1120     const string& device_name_str = !node->assigned_device_name().empty()
1121                                         ? node->assigned_device_name()
1122                                         : node->requested_device();
1123     TF_ASSIGN_OR_RETURN(DeviceId device,
1124                         device_info_cache_.GetIdFor(device_name_str));
1125 
1126     bool is_resource_op = HasResourceInputOrOutput(*node);
1127     std::optional<DeviceId> resource_op_device;
1128     if (is_resource_op) {
1129       resource_op_device = device;
1130     }
1131 
1132     std::optional<int> resource_var_operation_node_id;
1133     if (is_resource_op || MayCallFunction(*node, flib_def_)) {
1134       resource_var_operation_node_id = node->id();
1135     }
1136 
1137     bool is_xla_compile_attr_true =
1138         GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
1139         (global_jit_level_ != OptimizerOptions::OFF &&
1140          GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr));
1141 
1142     DeviceSet devices;
1143     devices.Insert(device);
1144 
1145     Cluster* new_cluster = MakeNewCluster(
1146         /*cycles_graph_node_id=*/node->id(),
1147         /*effective_cluster_size=*/effective_cluster_size,
1148         /*has_functional_control_flow=*/has_functional_control_flow, devices,
1149         resource_op_device, resource_var_operation_node_id, deadness_predicate,
1150         /*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
1151         GetXlaScope(node));
1152 
1153     cluster_for_node_[node->id()].Get() = new_cluster;
1154   }
1155 
1156   return OkStatus();
1157 }
1158 
1159 StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
1160   if (!node->IsIdentity()) {
1161     return false;
1162   }
1163 
1164   // Check if the Identity is driven by a Switch on its true path.
1165   auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) {
1166     return e->src()->IsSwitch() && e->src_output() == 1;
1167   });
1168   if (it == node->in_edges().end()) {
1169     return false;
1170   }
1171   const Node* switch_node = (*it)->src();
1172 
1173   // Check if the Switch is driven by LoopCond.
1174   const Node* maybe_loop_cond;
1175   TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond));
1176   if (!maybe_loop_cond->IsLoopCond()) {
1177     return false;
1178   }
1179 
1180   // Check if the Identity is driving any const nodes through a control edge.
1181   bool driving_any_consts =
1182       absl::c_any_of(node->out_edges(), [](const Edge* e) {
1183         return e->dst()->IsConstant() && e->IsControlEdge();
1184       });
1185   if (!driving_any_consts) {
1186     return false;
1187   }
1188 
1189   return true;
1190 }
1191 
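// Parses --tf_xla_cluster_exclude_ops (a comma-separated list of TF op names)
// into a set.  For example (a sketch), running with
// TF_XLA_FLAGS="--tf_xla_cluster_exclude_ops=Where" keeps Where ops out of
// auto-clustering via the filter handling below.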
1192 absl::flat_hash_set<string> GetOrCreateClusterExcludeList() {
1193   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1194   absl::flat_hash_set<string> excludelist;
1195   for (auto s : absl::StrSplit(flags->tf_xla_cluster_exclude_ops, ',')) {
1196     if (!s.empty()) {
1197       excludelist.insert(string(s));
1198     }
1199   }
1200   if (VLOG_IS_ON(2) && !excludelist.empty()) {
1201     std::vector<string> vexcludelist(excludelist.begin(), excludelist.end());
1202     absl::c_sort(vexcludelist);
1203     VLOG(2) << "XLA clustering will exclude following TF operations from auto "
1204                "clustering: "
1205             << absl::StrJoin(vexcludelist, " ");
1206   }
1207   return excludelist;
1208 }
1209 
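// Parses --tf_xla_ops_to_cluster into the set of ops auto-clustering may
// consider.  For example (a sketch), TF_XLA_FLAGS="--tf_xla_ops_to_cluster=FUSIBLE"
// pulls in every op listed in the allowlist table, while a plain op name such
// as "MatMul" restricts clustering to that single op (assuming XLA supports it).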
1210 absl::flat_hash_set<string> GetOrCreateAllowlist() {
1211   absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
1212       tensorflow::GetAllowlistTable();
1213   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1214   absl::flat_hash_set<string> allowlist;
1215 
1216   for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
1217     if (s == "FUSIBLE") {
1218       for (auto pair : *allowlist_table) {
1219         allowlist.insert(pair.second.begin(), pair.second.end());
1220       }
1221     } else if (allowlist_table->contains(s)) {
1222       auto v = allowlist_table->at(s);
1223       allowlist.insert(v.begin(), v.end());
1224     } else if (!s.empty()) {
1225       // Should be a user provided TF operation.
1226       allowlist.insert(string(s));
1227     }
1228   }
1229 
1230   if (VLOG_IS_ON(2) && !allowlist.empty()) {
1231     std::vector<string> vallowlist(allowlist.begin(), allowlist.end());
1232     absl::c_sort(vallowlist);
1233     VLOG(2) << "XLA clustering will only consider the following TF operations: "
1234             << absl::StrJoin(vallowlist, " ");
1235   }
1236   return allowlist;
1237 }
1238 
1239 Status MarkForCompilationPassImpl::FindCompilationCandidates() {
1240   OptimizerOptions opts;
1241   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
1242       new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr,
1243                                         TF_GRAPH_DEF_VERSION, flib_def_, opts));
1244   FunctionLibraryRuntime* lib_runtime =
1245       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
1246   std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false);
1247   TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
1248       *graph_, /*compile_time_const_arg_indices=*/nullptr,
1249       &compile_time_const_nodes, lib_runtime));
1250   // Iterate over nodes in sorted order so that compiler fuel is deterministic.
1251   // We can't simply pass op_nodes().begin() and op_nodes().end() to the
1252   // std::vector constructor because they're not proper iterators, with
1253   // iterator_traits defined and so on.
1254   std::vector<Node*> sorted_nodes;
1255   for (Node* node : graph_->op_nodes()) {
1256     sorted_nodes.push_back(node);
1257   }
1258   std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
1259 
1260   if (*debug_options_.fuel >= std::numeric_limits<int64_t>::max() / 2) {
1261     // The assumption is that if fuel started out as INT64_MAX, it will forever
1262     // stay greater than INT64_MAX / 2.
1263     VLOG(2) << "Starting fuel: infinity";
1264   } else {
1265     VLOG(2) << "Starting fuel: " << *debug_options_.fuel;
1266   }
1267 
1268   VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
1269 
1270   auto allowlist = GetOrCreateAllowlist();
1271 
1272   std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
1273   absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
1274   // Check that each user-provided TF operation really exists.
1275   for (const auto& s : allowlist) {
1276     if (!all_ops.contains(s)) {
1277       return errors::InvalidArgument(
1278           "The operation '", s,
1279           "' passed to --tf_xla_ops_to_cluster is not supported by XLA.");
1280     }
1281   }
1282 
1283   for (Node* node : sorted_nodes) {
1284     if (*debug_options_.fuel <= 0) {
1285       VLOG(1)
1286           << "Hit fuel limit; not marking any remaining ops as clusterable.";
1287       break;
1288     }
1289 
1290     TF_ASSIGN_OR_RETURN(
1291         const DeviceType& device_type,
1292         device_info_cache_.GetDeviceTypeFor(node->assigned_device_name()));
1293     VLOG(4) << "Device type for " << node->name() << ": "
1294             << device_type.type_string();
1295 
1296     if (CompilationDisallowedByXlaCompileAttr(node)) {
1297       VLOG(2) << "Not clustering " << node->name()
1298               << ": disallowed by _XlaCompile attribute";
1299       continue;
1300     }
1301 
1302     const XlaOpRegistry::DeviceRegistration* registration;
1303     if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
1304                                              &registration)) {
1305       VLOG(2) << "Rejecting " << node->name()
1306               << ": could not find JIT device for " << device_type.type();
1307       continue;
1308     }
1309 
1310     auto cluster_exclude_op_list = GetOrCreateClusterExcludeList();
1311     RecursiveCompilabilityChecker::OperationFilter filter =
1312         CreateOperationFilter(*registration);
1313     filter.require_always_compilable = true;
1314     filter.allow_string_consts = false;
1315     filter.allow_collective_reduce_v2 = false;
1316     filter.allow_unique_op = false;
1317     filter.allow_where_op = true;
1318 
1319     for (const auto& s : cluster_exclude_op_list) {
1320       if (s == "Where") {
1321         filter.allow_where_op = false;
1322       } else {
1323         return errors::InvalidArgument(
1324             "The operation '", s,
1325             "' passed to --tf_xla_cluster_exclude_ops is not supported by "
1326             "XLA.");
1327       }
1328     }
1329 
1330     RecursiveCompilabilityChecker checker(
1331         filter, DeviceType{registration->compilation_device_name});
1332 
1333     if (!checker.IsCompilableNode(*node, lib_runtime)) {
1334       continue;
1335     }
1336 
1337     if (node->type_string() == "Const") {
1338       // Skip Const op with type DT_STRING, since XLA autoclustering doesn't
1339       // support it.
1340       const AttrValue* attr = node->attrs().Find("dtype");
1341       if (attr != nullptr && attr->type() == DT_STRING) {
1342         continue;
1343       }
1344     }
1345 
1346     if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
1347       VLOG(1) << "Rejecting TF operation " << node->def().op()
1348               << " as it is not listed in --tf_xla_ops_to_cluster.";
1349       continue;
1350     }
1351 
1352     if (compile_time_const_nodes[node->id()]) {
1353       const OpDef* op_def;
1354       TF_RETURN_IF_ERROR(
1355           graph_->op_registry()->LookUpOpDef(node->type_string(), &op_def));
1356       if (op_def->is_stateful()) {
1357         // It is easiest to demonstrate the problem we're trying to solve with
1358         // an example.  Say we have this graph:
1359         //
1360         //   shape = RandomUniformInt();
1361         //   reshape = Reshape(input, shape)
1362         //
1363         // Both RandomUniformInt and Reshape are compilable by XLA so, absent
1364         // any other reason, we will try to put both shape and reshape in the
1365         // same cluster.  However, since XLA only supports statically shaped
1366         // values, it will expect to be able to constant fold `shape` to get a
1367         // static shape for `reshape`.  This is a problem because side-effecting
1368         // ops like RandomUniformInt() cannot be constant folded.  We fix this
1369         // by putting `shape` and `reshape` in different clusters, which results
1370         // in us recompiling `reshape`'s cluster for every new value of `shape`,
1371         // making `reshape` statically sized within each compilation.  We
1372         // simplify the solution even further by disallowing operations like
1373         // `shape` from being part of *any* non-trivial cluster.  They're either
1374         // not compiled by XLA altogether or, if assigned to an XLA_* device
1375         // with "must compile" semantics, compiled into a trivial single-op
1376         // cluster.  This approach leaves some room for improvement, and we can
1377         // consider implementing a more aggressive data-flow-analysis based
1378         // solution in the future if needed.
1379         //
1380         // One ugly problem we have to contend with: certain sets of ops *have*
1381         // to be in the same cluster because values flowing between them have
1382         // types that can't be live-in or live-out of a cluster.  These ops are:
1383         //
1384         //  - TensorArray ops operating on the same TensorArray instance.
1385         //  - Stack ops operating on the same Stack instance.
1386         //
1387         // To work around this we avoid isolating these specific ops.  Because
1388         // of this concession it is unsound to auto-cluster them because then
1389         // we'd create clusters we could not compile (because we can't constant
1390         // fold, say, a TensorArrayRead or a StackPopV2).  But we don't
1391         // auto-cluster these operations today so we're good for now.
1392         const XlaResourceOpInfo* op_info =
1393             GetResourceOpInfoForOp(node->type_string());
1394         bool is_tensor_array_or_stack_op =
1395             op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
1396         if (!is_tensor_array_or_stack_op) {
1397           VLOG(2) << "Isolating " << node->name()
1398                   << ": must-be-constant stateful op";
1399           continue;
1400         }
1401       }
1402     }
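    // A minimal sketch of the RandomUniformInt/Reshape pattern discussed
    // above, written with the TF C++ ops API (assumed available here for
    // illustration only; this pass does not build graphs this way):
    //
    //   Scope root = Scope::NewRootScope();
    //   auto input = ops::Placeholder(root, DT_FLOAT);
    //   auto shape = ops::RandomUniformInt(root, ops::Const(root, {2}),
    //                                      ops::Const(root, 1),
    //                                      ops::Const(root, 10));
    //   auto reshape = ops::Reshape(root, input, shape);
    //
    // Here `shape` is a must-be-constant input of `reshape` but is produced by
    // a stateful op, so the loop above isolates `shape` instead of clustering
    // it together with `reshape`.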
1403 
1404     // This is a heuristic to avoid creating dependency between while loop
1405     // condition and body computations.  Dependency between them can be created
1406     // if a special Identity node in the following pattern is clustered in.
1407     // That is, an Identity node in the loop cond computation is used to drive
1408     // const nodes consumed by the loop body.  If this Identity node goes into
1409     // the same cluster with nodes from the loop body, extra dependency is
1410     // created between the loop cond and body computations and it hinders the
1411     // progression of the loop cond computation at runtime with significant
1412     // overhead.  Specifically, we look for the below pattern and do not cluster
1413     // in this Identity to avoid the described issue.  Since Identity has low
1414     // execution cost in native TF, the fact that this heuristic gives up these
1415     // special Identity nodes as candidates should not hurt performance.  If
1416     // other considerations emerge in the future, we can revisit the heuristic
1417     // and only disallow these Identities to go into the cluster with nodes from
1418     // the loop body but still consider them candidates.
1419     //
1420     // LoopCond ->
1421     // Merge    -> Switch -> Identity -> i++ -> ... -> NextIteration
1422     //                               ..> Const -> LoopBody
1423     //                            (control edge)
1424     TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop,
1425                         IsIdentityDrivingConstsInLoop(node));
1426     if (is_identity_driving_consts_in_loop) {
1427       VLOG(2) << "Rejecting " << node->name()
1428               << ": including it can create dependencies between while loop "
1429                  "condition and body computations with runtime overhead.";
1430       continue;
1431     }
1432 
1433     compilation_candidates_.insert(node);
1434     --(*debug_options_.fuel);
1435   }
1436 
1437   VLOG(2) << "compilation_candidates_.size() = "
1438           << compilation_candidates_.size();
1439   return OkStatus();
1440 }
1441 
1442 bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
1443     Node* node) {
1444   if (debug_options_.ignore_xla_compile_attr) {
1445     return false;
1446   }
1447 
1448   // If there is a _XlaCompile annotation, use its value.
1449   bool compile = false;
1450   Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
1451   if (status.ok()) {
1452     if (!compile) {
1453       VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
1454               << kXlaCompileAttr << ") is false.";
1455     }
1456     return !compile;
1457   }
1458 
1459   status = flib_def_->GetAttr(*node, kXlaCompileAttr, &compile);
1460   if (status.ok()) {
1461     if (!compile) {
1462       VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
1463               << kXlaCompileAttr << ") on callee is false.";
1464     }
1465     return !compile;
1466   }
1467 
1468   return false;
1469 }
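// Illustrative sketch (hypothetical node, not part of the pass): the attribute
// consulted above can be attached directly to a node, e.g.
//
//   Node* n = ...;                        // some node in the graph
//   n->AddAttr(kXlaCompileAttr, false);   // never auto-cluster this node
//   n->AddAttr(kXlaCompileAttr, true);    // explicitly request compilation
//
// In practice the attribute is usually set by higher-level constructs such as
// experimental_jit_scope rather than by hand.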
1470 
1471 bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse(
1472     Cluster* from, Cluster* to, absl::string_view reason) {
1473   VLOG(3) << EdgeContractionFailureMsg(from, to, reason);
1474   return false;
1475 }
1476 
1477 StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
1478                                                              Cluster* to) {
1479   DCHECK(from->deadness_predicate().has_value() ==
1480          to->deadness_predicate().has_value());
1481   if (from->deadness_predicate() != to->deadness_predicate()) {
1482     VLOG(3) << EdgeContractionFailureMsg(
1483         from, to,
1484         absl::StrCat(
1485             "the two nodes have mismatching deadness: ",
1486             deadness_analysis_->DebugString(*from->deadness_predicate()),
1487             " and ",
1488             deadness_analysis_->DebugString(*to->deadness_predicate())));
1489     return false;
1490   }
1491 
1492   TF_ASSIGN_OR_RETURN(bool devices_compatible,
1493                       AreDevicesCompatible(*from, *to));
1494   if (!devices_compatible) {
1495     return LogNotContractableAndReturnFalse(
1496         from, to, "the two nodes have incompatible devices");
1497   }
1498 
1499   if (from->xla_scope().has_value() && to->xla_scope().has_value() &&
1500       *from->xla_scope() != *to->xla_scope()) {
1501     return LogNotContractableAndReturnFalse(
1502         from, to, "the two nodes have mismatching XLA scopes");
1503   }
1504 
1505   // Don't exceed the maximum cluster size.
1506   if (from->cluster_size() + to->cluster_size() >
1507       debug_options_.max_cluster_size) {
1508     return LogNotContractableAndReturnFalse(
1509         from, to, "the new cluster will be larger than the max cluster size");
1510   }
1511 
1512   TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency,
1513                       ClusteringWillIntroduceInterDeviceDependency(*from, *to));
1514 
1515   if (will_introduce_cross_device_dependency) {
1516     return LogNotContractableAndReturnFalse(
1517         from, to, "the new cluster will introduce a cross device dependency");
1518   }
1519 
1520   // Check if contracting this edge will break the resource variable concurrency
1521   // semantics.  In theory this is quadratic in the number of nodes, but seems
1522   // to not be a problem in practice so far.
1523   if (!debug_options_.ignore_resource_variable_checks) {
1524     for (int resource_var_from : from->resource_var_operation_node_ids()) {
1525       for (int resource_var_to : to->resource_var_operation_node_ids()) {
1526         // If unsafe_resource_deps_ contains {A, B} then
1527         //
1528         //  a. A and B are resource operations.
1529         //  b. A and B cannot be placed in the same cluster.
1530         //  c. There is no path from B to A in the cycles graph (but there may
1531         //     be a path from A to B).
1532         //
1533         // So check the legality of the edge contraction by checking if any of
1534         // the n^2 pairs of resource variable operations are forbidden.
1535         if (unsafe_resource_deps_.contains(
1536                 {resource_var_from, resource_var_to})) {
1537           return LogNotContractableAndReturnFalse(
1538               from, to,
1539               "the new cluster would break resource variable semantics");
1540         }
1541       }
1542     }
1543   }
1544 
1545   return MergeClusters(from, to);
1546 }
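// Worked example (illustrative numbers): with --tf_xla_max_cluster_size=100,
// contracting an edge between a 60-node cluster and a 50-node cluster is
// rejected above because 60 + 50 = 110 > 100; both clusters remain separate
// but may still grow through other, smaller contractions.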
1547 
1548 Status MarkForCompilationPassImpl::Run() {
1549   // Make sure that kernels have been registered on the JIT device.
1550   XlaOpRegistry::RegisterCompilationKernels();
1551 
1552   // Start the timer after XlaOpRegistry::RegisterCompilationKernels which does
1553   // some one-time work.
1554   XLA_SCOPED_LOGGING_TIMER_LEVEL("MarkForCompilationPassImpl::Run", 1);
1555 
1556   TF_ASSIGN_OR_RETURN(bool initialized, Initialize());
1557   if (!initialized) {
1558     // Initialization exited early which means this instance of
1559     // MarkForCompilationPassImpl is not set up to run the subsequent phases.
1560     return OkStatus();
1561   }
1562 
1563   TF_RETURN_IF_ERROR(RunEdgeContractionLoop());
1564   TF_RETURN_IF_ERROR(DeclusterNodes());
1565   TF_RETURN_IF_ERROR(CreateClusters());
1566   TF_RETURN_IF_ERROR(DumpDebugInfo());
1567 
1568   return OkStatus();
1569 }
1570 
1571 void MarkForCompilationPassImpl::DumpPostClusteringGraphs() {
1572   DumpGraphToFile("mark_for_compilation", *graph_, flib_def_);
1573 
1574   // We also dump out an annotated version of the TF graph where the node
1575   // names are prefixed with the cluster names.  This can help visualize the
1576   // clustering decisions in TensorBoard.
1577   Graph new_graph(graph_->op_registry());
1578   CopyGraph(*graph_, &new_graph);
1579 
1580   for (Node* n : new_graph.nodes()) {
1581     if (std::optional<absl::string_view> cluster_name =
1582             GetXlaClusterForNode(*n)) {
1583       n->set_name(absl::StrCat(*cluster_name, "/", n->name()));
1584     } else if (n->type_string() == "VarHandleOp") {
1585       n->set_name(absl::StrCat("varhandle/", n->name()));
1586     } else {
1587       // There is room for improvement here.  In particular, it may help to
1588       // split these unclustered nodes into classes where every node in a
1589       // specific class has edges to and from the same set of clusters.
1590       n->set_name(absl::StrCat("unclustered/", n->name()));
1591     }
1592   }
1593 
1594   DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_);
1595 }
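// Illustrative usage (assuming the standard dump_graph behavior): graph dumps
// are written to the directory named by TF_DUMP_GRAPH_PREFIX when dumping is
// enabled, e.g.
//
//   TF_DUMP_GRAPH_PREFIX=/tmp/xla_dumps
//   TF_XLA_FLAGS="--tf_xla_clustering_debug"
//
// which produces "mark_for_compilation*.pbtxt" files whose node names encode
// the chosen clusters.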
1596 
1597 string RatioToString(int numerator, int denominator) {
1598   return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator,
1599                          (100.0 * numerator) / denominator);
1600 }
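// A quick usage note for the helper above:
//   RatioToString(25, 100) returns "25 / 100 (25.00%)".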
1601 
1602 void MarkForCompilationPassImpl::VLogClusteringSummary() {
1603   if (!VLOG_IS_ON(2)) {
1604     return;
1605   }
1606 
1607   XlaAutoClusteringSummary auto_clustering_info =
1608       GetXlaAutoClusteringSummary(*graph_);
1609 
1610   VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
1611   VLOG(2) << " Built " << auto_clustering_info.clusters_size()
1612           << " clusters, size "
1613           << RatioToString(auto_clustering_info.clustered_node_count(),
1614                            graph_->num_nodes());
1615 
1616   for (const XlaAutoClusteringSummary::Cluster& cluster :
1617        auto_clustering_info.clusters()) {
1618     absl::string_view cluster_name = cluster.name();
1619     int size = cluster.size();
1620     VLOG(2) << "  " << cluster_name << " "
1621             << RatioToString(size, graph_->num_nodes());
1622     for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1623          cluster.op_histogram()) {
1624       VLOG(3) << "   " << op_count.op() << ": " << op_count.count()
1625               << " instances";
1626     }
1627   }
1628 
1629   if (!auto_clustering_info.unclustered_op_histogram().empty()) {
1630     VLOG(2) << " Unclustered nodes: "
1631             << RatioToString(auto_clustering_info.unclustered_node_count(),
1632                              graph_->num_nodes());
1633     for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1634          auto_clustering_info.unclustered_op_histogram()) {
1635       VLOG(3) << "  " << op_count.op() << ": " << op_count.count()
1636               << " instances";
1637     }
1638   }
1639 
1640   struct EdgeInfo {
1641     absl::string_view node_name;
1642     std::optional<absl::string_view> cluster_name;
1643 
1644     absl::string_view GetClusterName() const {
1645       return cluster_name ? *cluster_name : "[none]";
1646     }
1647 
1648     std::pair<absl::string_view, std::optional<absl::string_view>> AsPair()
1649         const {
1650       return {node_name, cluster_name};
1651     }
1652 
1653     bool operator<(const EdgeInfo& other) const {
1654       return AsPair() < other.AsPair();
1655     }
1656   };
1657 
1658   using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64_t>>;
1659 
1660   EdgeInfoMap incoming_edge_infos;
1661   EdgeInfoMap outgoing_edge_infos;
1662 
1663   std::set<absl::string_view> cluster_names_to_print;
1664 
1665   for (const Edge* e : graph_->edges()) {
1666     const Node* from = e->src();
1667     std::optional<absl::string_view> from_cluster_name =
1668         GetXlaClusterForNode(*from);
1669 
1670     const Node* to = e->dst();
1671     std::optional<absl::string_view> to_cluster_name =
1672         GetXlaClusterForNode(*to);
1673 
1674     if (to_cluster_name == from_cluster_name) {
1675       continue;
1676     }
1677 
1678     if (to_cluster_name) {
1679       incoming_edge_infos[*to_cluster_name]
1680                          [EdgeInfo{from->name(), from_cluster_name}]++;
1681       cluster_names_to_print.insert(*to_cluster_name);
1682     }
1683 
1684     if (from_cluster_name) {
1685       outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++;
1686       cluster_names_to_print.insert(*from_cluster_name);
1687     }
1688   }
1689 
1690   VLOG(4) << "*** Inter-Cluster edges:";
1691   if (cluster_names_to_print.empty()) {
1692     VLOG(4) << "   [none]";
1693   }
1694 
1695   auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
1696                                              const EdgeInfoMap& edge_info_map,
1697                                              absl::string_view desc) {
1698     auto it = edge_info_map.find(cluster_name);
1699     if (it != edge_info_map.end()) {
1700       VLOG(4) << "  " << it->second.size() << " " << desc << " edges";
1701       for (const auto& edge_info_count_pair : it->second) {
1702         VLOG(4) << "   " << edge_info_count_pair.first.GetClusterName() << " "
1703                 << edge_info_count_pair.first.node_name << " # "
1704                 << edge_info_count_pair.second;
1705       }
1706     } else {
1707       VLOG(4) << "  No " << desc << " edges.";
1708     }
1709   };
1710 
1711   for (absl::string_view cluster_name : cluster_names_to_print) {
1712     VLOG(4) << " ** Cluster " << cluster_name;
1713     print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
1714                                     "incoming");
1715     print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos,
1716                                     "outgoing");
1717   }
1718 }
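// To actually see this summary, run with verbose logging for this module, e.g.
// --vmodule=mark_for_compilation_pass=2 (level 3 adds the per-op histograms
// and level 4 adds the inter-cluster edge listing above).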
1719 
1720 StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
1721     const Cluster& cluster_a, const Cluster& cluster_b) {
1722   DeviceSet devices = cluster_a.devices();
1723   devices.UnionWith(cluster_b.devices());
1724 
1725   TF_ASSIGN_OR_RETURN(
1726       std::optional<jit::DeviceId> maybe_chosen_device,
1727       MaybePickDeviceForXla(device_info_cache_, devices,
1728                             /*allow_mixing_unknown_and_cpu=*/false));
1729   if (!maybe_chosen_device.has_value()) {
1730     return false;
1731   }
1732 
1733   jit::DeviceId chosen_device = *maybe_chosen_device;
1734 
1735   // If we are able to pick a device `chosen_device` for the larger cluster, the
1736   // resource operations in `cluster_a` and `cluster_b` must be placed on the
1737   // same device as `chosen_device`.  This is because the _XlaCompile and
1738   // _XlaRun kernels will run on `chosen_device` and therefore try to access
1739   // the resource variables from there, which is an error if the resource
1740   // variables are placed on some other device.
1741   auto resource_op_device_ok = [&](std::optional<DeviceId> resource_op_device) {
1742     return !resource_op_device.has_value() ||
1743            *resource_op_device == chosen_device;
1744   };
1745 
1746   return resource_op_device_ok(cluster_a.resource_op_device()) &&
1747          resource_op_device_ok(cluster_b.resource_op_device());
1748 }
1749 
1750 // Returns `true` iff we should compile `cluster`.
1751 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
1752     const Cluster& cluster) {
1753   TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
1754                       PickDeviceForXla(device_info_cache_, cluster.devices(),
1755                                        /*allow_mixing_unknown_and_cpu=*/false));
1756 
1757   const DeviceType& device_type =
1758       device_info_cache_.GetDeviceTypeFor(chosen_device);
1759   const XlaOpRegistry::DeviceRegistration* registration =
1760       device_info_cache_.GetCompilationDevice(chosen_device);
1761   TF_RET_CHECK(registration)
1762       << "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
1763       << "; device type = " << device_type.type() << "; devices ("
1764       << device_info_cache_.DebugString(cluster.devices());
1765 
1766   auto policy = registration->autoclustering_policy;
1767   bool should_compile =
1768       cluster.is_xla_compile_attr_true() ||
1769       policy == XlaOpRegistry::AutoclusteringPolicy::kAlways ||
1770       (policy == XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
1771        global_jit_level_ != OptimizerOptions::OFF) ||
1772       (device_type.type_string() == DEVICE_CPU &&
1773        policy == XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested &&
1774        cpu_global_jit_);
1775 
1776   if (!should_compile && device_type.type_string() == DEVICE_CPU &&
1777       global_jit_level_ > OptimizerOptions::OFF) {
1778     static absl::once_flag once;
1779     absl::call_once(once, [] {
1780       LOG(WARNING) << R"((One-time warning): Not using XLA:CPU for cluster.
1781 
1782 If you want XLA:CPU, do one of the following:
1783 
1784  - set the TF_XLA_FLAGS to include "--tf_xla_cpu_global_jit", or
1785  - set cpu_global_jit to true on this session's OptimizerOptions, or
1786  - use experimental_jit_scope, or
1787  - use tf.function(jit_compile=True).
1788 
1789 To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a
1790 proper command-line flag, not via TF_XLA_FLAGS).)";
1791 
1792       MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1793       if (flags->tf_xla_cpu_global_jit) {
1794         LOG(WARNING)
1795             << "(Although the tf_xla_cpu_global_jit flag is currently enabled, "
1796                "perhaps it wasn't enabled at process startup?)";
1797       }
1798     });
1799   }
1800 
1801   VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
1802           << " cluster with device "
1803           << device_info_cache_.GetNameFor(chosen_device);
1804 
1805   return should_compile;
1806 }
1807 
1808 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileCluster(
1809     const Cluster& cluster) {
1810   auto it = should_compile_cluster_cache_.find(&cluster);
1811   if (it != should_compile_cluster_cache_.end()) {
1812     return it->second;
1813   }
1814 
1815   TF_ASSIGN_OR_RETURN(bool should_compile, ShouldCompileClusterImpl(cluster));
1816   should_compile_cluster_cache_.insert({&cluster, should_compile});
1817   return should_compile;
1818 }
1819 
1820 Status MarkForCompilation(
1821     const GraphOptimizationPassOptions& options,
1822     const MarkForCompilationPassImpl::DebugOptions& debug_options) {
1823   Graph* graph = options.graph->get();
1824   FunctionLibraryDefinition* flib_def = options.flib_def;
1825 
1826   // Deadness analysis expects a graph with source and sink edges properly
1827   // connected but sometimes the incoming graph does not follow this invariant.
1828   // So fix up the source and sink edges before calling into deadness analysis.
1829   FixupSourceAndSinkEdges(graph);
1830 
1831   for (Node* n : graph->nodes()) {
1832     // See explanation on `kXlaAlreadyClustered`.
1833     if (n->attrs().Find(kXlaAlreadyClustered)) {
1834       return OkStatus();
1835     }
1836     // Skip the pass if we found TPUExecute or TPUExecuteAndUpdateVariables ops
1837     // in the graph, which indicates the graph was produced by the TPU TF-XLA
1838     // bridge and doesn't require auto-clustering.
1839     if (n->type_string() == "TPUExecute" ||
1840         n->type_string() == "TPUExecuteAndUpdateVariables") {
1841       return OkStatus();
1842     }
1843   }
1844 
1845   return MarkForCompilationPassImpl{
1846       debug_options,
1847       graph,
1848       flib_def,
1849       options.session_options != nullptr ? options.session_options->env
1850                                          : Env::Default(),
1851       GetGlobalJitLevelForGraph(options),
1852       options.session_options->config.graph_options()
1853           .optimizer_options()
1854           .cpu_global_jit()}
1855       .Run();
1856 }
1857 
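// Note: the function-local static below means that every invocation of the
// pass in this process shares a single fuel counter, seeded from the
// initial_value of the first call; later initial_value arguments are ignored.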
1858 std::atomic<int64_t>* GetPointerToFuel(int64_t initial_value) {
1859   static std::atomic<int64_t>* fuel = [&]() {
1860     std::atomic<int64_t>* fuel = new std::atomic<int64_t>;
1861     *fuel = initial_value;
1862     return fuel;
1863   }();
1864 
1865   return fuel;
1866 }
1867 }  // anonymous namespace
1868 
1869 Status MarkForCompilationPass::Run(
1870     const GraphOptimizationPassOptions& options) {
1871   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1872 
1873   MarkForCompilationPassImpl::DebugOptions debug_options;
1874   debug_options.ignore_deadness_checks =
1875       flags->tf_xla_disable_deadness_safety_checks_for_debugging;
1876   debug_options.ignore_resource_variable_checks =
1877       flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1878   debug_options.ignore_xla_compile_attr = false;
1879   debug_options.deterministic_cluster_names =
1880       flags->tf_xla_deterministic_cluster_names;
1881   debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1882   debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1883   debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1884   debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1885 
1886   return MarkForCompilation(options, debug_options);
1887 }
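// Illustrative flag usage (values are arbitrary): the flag-driven debug
// options above can be set through TF_XLA_FLAGS, e.g.
//
//   TF_XLA_FLAGS="--tf_xla_max_cluster_size=200 --tf_xla_min_cluster_size=4 \
//                 --tf_xla_clustering_fuel=1000 --tf_xla_clustering_debug"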
1888 
1889 Status MarkForCompilationPass::RunForTest(
1890     const GraphOptimizationPassOptions& options, bool disable_deadness_analysis,
1891     bool deterministic_cluster_names) {
1892   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1893 
1894   MarkForCompilationPassImpl::DebugOptions debug_options;
1895   debug_options.ignore_deadness_checks = disable_deadness_analysis;
1896   debug_options.ignore_resource_variable_checks =
1897       flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1898   debug_options.ignore_xla_compile_attr = true;
1899   debug_options.deterministic_cluster_names = deterministic_cluster_names;
1900   debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1901   debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1902   debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1903   debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1904 
1905   return MarkForCompilation(options, debug_options);
1906 }
1907 
1908 absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
1909   // Table format: category name: {list of TF operations in that category}
1910   static absl::flat_hash_map<string, std::vector<string>>* result =
1911       new absl::flat_hash_map<string, std::vector<string>>{
1912           // Unary
1913           {"PW",
1914            {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
1915             "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
1916             "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
1917             "Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round",
1918             "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
1919             "Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv",
1920             "Lgamma", "Digamma",
1921             // Binary
1922             "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
1923             "MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
1924             "BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
1925             "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
1926             "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
1927             "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
1928             "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
1929             "TanhGrad", "Pow", "SquaredDifference", "ApproximateEqual",
1930             // Others
1931             "AddN", "Bitcast", "Cast", "ClipByValue", "Const", "Empty",
1932             "Identity", "IdentityN", "Relu", "Relu6", "ReluGrad", "Relu6Grad",
1933             "LeakyReluGrad", "Elu", "EluGrad", "Selu", "SeluGrad", "Select",
1934             "SelectV2", "Transpose", "ConjugateTranspose",
1935             "_UnaryOpsComposition", "CollectiveReduceV2",
1936             "CollectiveAssignGroupV2",
1937             // The following 5 operations are converted to identity
1938             "PlaceholderWithDefault", "PreventGradient", "StopGradient",
1939             "Snapshot", "_EagerConst"}},
1940           // clang-format off
1941     {"RED",
1942      {"All", "Any", "Min", "Max", "Mean", "Prod", "Sum"}},
1943           // clang-format on
1944           {"PWRED",
1945            {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1946             "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1947           {"REDUCEWINDOW",
1948            {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1949             "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1950           {"REDUCEWINDOWPW", {"BiasAddGrad", "LRN", "LRNGrad"}},
1951           {"BN",
1952            {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
1953             "_FusedBatchNormEx", "FusedBatchNormGrad", "FusedBatchNormGradV2",
1954             "FusedBatchNormGradV3"}},
1955           {"SORT", {"TopKV2"}},  // XLA version much faster then TF version.
1956           {"MISC",
1957            // clang-format off
1958      {"ApproxTopK", "BroadcastTo", "ExpandDims", "Fill", "NoOp",
1959       "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze",
1960       "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/,
1961       "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2",
1962       "ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad",
1963       "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split",
1964       "SplitV", "StridedSlice", "StridedSliceGrad",
1965       "ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation",
1966       "Unpack", "DeviceIndex", "TensorStridedSliceUpdate", "XlaConcatND",
1967       "XlaSplitND",
1968      }}};
1969   // clang-format on
1970   return result;
1971 }
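// Minimal usage sketch (illustrative): look up one category and collect its
// ops, mirroring what GetOrCreateAllowlist() does for the FUSIBLE keyword.
//
//   const auto* table = tensorflow::GetAllowlistTable();
//   const std::vector<string>& pointwise = table->at("PW");
//   // `pointwise` contains e.g. "Add", "Tanh", "Relu", ...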
1972 
1973 namespace testing {
1974 void ResetClusterSequenceNumber() {
1975   ClusterSequenceNumberGenerator::Global().Reset();
1976 }
1977 
1978 absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
1979   absl::flat_hash_set<string> result{
1980       "AdjustContrastv2",
1981       "AdjustHue",
1982       "AdjustSaturation",
1983       "Asinh",
1984       "Assert",
1985       "AssignAddVariableOp",
1986       "AssignSubVariableOp",
1987       "AssignVariableOp",
1988       "AssignVariableXlaConcatND",
1989       "AvgPool",
1990       "AvgPool3D",
1991       "AvgPool3DGrad",
1992       "AvgPoolGrad",
1993       "BatchMatMul",
1994       "BatchMatMulV2",
1995       "BatchMatMulV3",
1996       "BatchToSpace",
1997       "BatchToSpaceND",
1998       "BesselI0e",
1999       "BesselI1e",
2000       "Betainc",
2001       "BiasAddV1",
2002       "Bincount",
2003       "Bucketize",
2004       "Case",
2005       "CheckNumerics",
2006       "Cholesky",
2007       "ControlTrigger",
2008       "Conv2D",
2009       "Conv2DBackpropFilter",
2010       "Conv2DBackpropInput",
2011       "Conv3D",
2012       "Conv3DBackpropFilterV2",
2013       "Conv3DBackpropInputV2",
2014       "Cross",
2015       "Cumprod",
2016       "Cumsum",
2017       "DenseBincount",
2018       "DataFormatDimMap",
2019       "DataFormatVecPermute",
2020       "DepthToSpace",
2021       "DepthwiseConv2dNative",
2022       "DepthwiseConv2dNativeBackpropFilter",
2023       "DepthwiseConv2dNativeBackpropInput",
2024       "Dequantize",
2025       "Diag",
2026       "DynamicStitch",
2027       "DynamicPartition",
2028       "Einsum",
2029       "EmptyTensorList",
2030       "EnsureShape",
2031       "ExtractImagePatches",
2032       "Igamma",
2033       "IgammaGradA",
2034       "RandomGammaGrad",
2035       "Igammac",
2036       "FFT",
2037       "FFT2D",
2038       "FFT3D",
2039       "FakeParam",
2040       "FakeQuantWithMinMaxArgs",
2041       "FakeQuantWithMinMaxArgsGradient",
2042       "FakeQuantWithMinMaxVars",
2043       "FakeQuantWithMinMaxVarsGradient",
2044       "FakeQuantWithMinMaxVarsPerChannel",
2045       "FakeQuantWithMinMaxVarsPerChannelGradient",
2046       "Gather",
2047       "GatherNd",
2048       "GatherV2",
2049       "HSVToRGB",
2050       "IFFT",
2051       "IFFT2D",
2052       "IFFT3D",
2053       "IRFFT",
2054       "IRFFT2D",
2055       "IRFFT3D",
2056       "If",
2057       "InTopKV2",
2058       "L2Loss",
2059       "LeakyRelu",
2060       "LinSpace",
2061       "ListDiff",
2062       "LogMatrixDeterminant",
2063       "LowerBound",
2064       "MatMul",
2065       "MatrixBandPart",
2066       "MatrixDiag",
2067       "MatrixDiagPart",
2068       "MatrixDiagPartV2",
2069       "MatrixDiagPartV3",
2070       "MatrixDiagV2",
2071       "MatrixDiagV3",
2072       "MatrixInverse",
2073       "MatrixSetDiag",
2074       "MatrixSetDiagV2",
2075       "MatrixSetDiagV3",
2076       "MatrixSolve",
2077       "MatrixTriangularSolve",
2078       "MaxPool",
2079       "MaxPool3D",
2080       "MaxPool3DGrad",
2081       "MaxPool3DGradGrad",
2082       "MaxPoolGrad",
2083       "MaxPoolGradGrad",
2084       "MaxPoolGradGradV2",
2085       "MaxPoolGradV2",
2086       "MaxPoolV2",
2087       "Multinomial",
2088       "NextAfter",
2089       "NonMaxSuppressionV3",
2090       "NonMaxSuppressionV4",
2091       "ParallelDynamicStitch",
2092       "ParameterizedTruncatedNormal",
2093       "PartitionedCall",
2094       "Polygamma",
2095       "PopulationCount",
2096       "Qr",
2097       "QuantizeAndDequantizeV2",
2098       "QuantizeAndDequantizeV3",
2099       "QuantizeAndDequantizeV4",
2100       "RFFT",
2101       "RFFT2D",
2102       "RFFT3D",
2103       "RGBToHSV",
2104       "RandomShuffle",
2105       "RandomStandardNormal",
2106       "RandomUniform",
2107       "RandomUniformInt",
2108       "ReadVariableOp",
2109       "ReadVariableXlaSplitND",
2110       "ResizeBilinear",
2111       "ResizeBilinearGrad",
2112       "ResizeNearestNeighbor",
2113       "ResourceApplyAdaMax",
2114       "ResourceApplyAdadelta",
2115       "ResourceApplyAdagrad",
2116       "ResourceApplyAdagradDA",
2117       "ResourceApplyAdagradV2",
2118       "ResourceApplyAdam",
2119       "ResourceApplyAddSign",
2120       "ResourceApplyCenteredRMSProp",
2121       "ResourceApplyFtrl",
2122       "ResourceApplyFtrlV2",
2123       "ResourceApplyGradientDescent",
2124       "ResourceApplyKerasMomentum",
2125       "ResourceApplyMomentum",
2126       "ResourceApplyPowerSign",
2127       "ResourceApplyProximalAdagrad",
2128       "ResourceApplyProximalGradientDescent",
2129       "ResourceApplyRMSProp",
2130       "ResourceGather",
2131       "ResourceScatterAdd",
2132       "ResourceScatterDiv",
2133       "ResourceScatterMax",
2134       "ResourceScatterMin",
2135       "ResourceScatterMul",
2136       "ResourceScatterNdAdd",
2137       "ResourceScatterNdSub",
2138       "ResourceScatterNdUpdate",
2139       "ResourceScatterSub",
2140       "ResourceScatterUpdate",
2141       "RngReadAndSkip",
2142       "RngSkip",
2143       "Roll",
2144       "ScatterNd",
2145       "SelfAdjointEigV2",
2146       "SoftmaxCrossEntropyWithLogits",
2147       "SpaceToBatch",
2148       "SpaceToBatchND",
2149       "SpaceToDepth",
2150       "SparseMatMul",
2151       "SparseToDense",
2152       "StackCloseV2",
2153       "StackPopV2",
2154       "StackPushV2",
2155       "StackV2",
2156       "StatefulPartitionedCall",
2157       "StatefulStandardNormalV2",
2158       "StatefulTruncatedNormal",
2159       "StatefulUniform",
2160       "StatefulUniformFullInt",
2161       "StatefulUniformInt",
2162       "StatelessCase",
2163       "StatelessIf",
2164       "StatelessMultinomial",
2165       "StatelessParameterizedTruncatedNormal",
2166       "StatelessRandomGetAlg",
2167       "StatelessRandomGetKeyCounter",
2168       "StatelessRandomGetKeyCounterAlg",
2169       "StatelessRandomNormal",
2170       "StatelessRandomNormalV2",
2171       "StatelessRandomUniform",
2172       "StatelessRandomUniformV2",
2173       "StatelessRandomUniformInt",
2174       "StatelessRandomUniformIntV2",
2175       "StatelessRandomUniformFullInt",
2176       "StatelessRandomUniformFullIntV2",
2177       "StatelessTruncatedNormal",
2178       "StatelessTruncatedNormalV2",
2179       "StatelessWhile",
2180       "Svd",
2181       "SymbolicGradient",
2182       "TensorArrayCloseV3",
2183       "TensorArrayConcatV3",
2184       "TensorArrayGatherV3",
2185       "TensorArrayGradV3",
2186       "TensorArrayReadV3",
2187       "TensorArrayScatterV3",
2188       "TensorArraySizeV3",
2189       "TensorArraySplitV3",
2190       "TensorArrayV3",
2191       "TensorArrayWriteV3",
2192       "TensorListConcatV2",
2193       "TensorListElementShape",
2194       "TensorListFromTensor",
2195       "TensorListGather",
2196       "TensorListGetItem",
2197       "TensorListLength",
2198       "TensorListPopBack",
2199       "TensorListPushBack",
2200       "TensorListReserve",
2201       "TensorListSetItem",
2202       "TensorListSplit",
2203       "TensorListStack",
2204       "TensorScatterAdd",
2205       "TensorScatterMax",
2206       "TensorScatterMin",
2207       "TensorScatterSub",
2208       "TensorScatterUpdate",
2209       "ToBool",
2210       "TridiagonalSolve",
2211       "TridiagonalMatMul",
2212       "TruncatedNormal",
2213       "Unique",
2214       "UniqueV2",
2215       "UpperBound",
2216       "UnsortedSegmentMax",
2217       "UnsortedSegmentMin",
2218       "UnsortedSegmentProd",
2219       "UnsortedSegmentSum",
2220       "VarIsInitializedOp",
2221       "VariableShape",
2222       "Where",
2223       "While",
2224       "XlaBroadcastHelper",
2225       "XlaCallModule",
2226       "XlaConcatND",
2227       "XlaConv",
2228       "XlaConvV2",
2229       "XlaCustomCall",
2230       "XlaDequantize",
2231       "XlaDot",
2232       "XlaDotV2",
2233       "XlaDynamicSlice",
2234       "XlaDynamicUpdateSlice",
2235       "XlaEinsum",
2236       "XlaGather",
2237       "XlaIf",
2238       "XlaKeyValueSort",
2239       "XlaOptimizationBarrier",
2240       "XlaPad",
2241       "XlaRecv",
2242       "XlaReduce",
2243       "XlaReducePrecision",
2244       "XlaReduceWindow",
2245       "XlaRemoveDynamicDimensionSize",
2246       "XlaReplicaId",
2247       "XlaRngBitGenerator",
2248       "XlaScatter",
2249       "XlaSelectAndScatter",
2250       "XlaSelfAdjointEig",
2251       "XlaSend",
2252       "XlaSetBound",
2253       "XlaSetDynamicDimensionSize",
2254       "XlaSharding",
2255       "XlaSort",
2256       "XlaSplitND",
2257       "XlaSpmdFullToShardShape",
2258       "XlaSpmdShardToFullShape",
2259       "XlaSvd",
2260       "XlaVariadicReduce",
2261       "XlaVariadicReduceV2",
2262       "XlaVariadicSort",
2263       "XlaWhile",
2264       "Zeta",
2265       "_Arg",
2266       "_ArrayToList",
2267       "_ListToArray",
2268       "_Retval"};
2269   return result;
2270 }
2271 
2272 }  // namespace testing
2273 }  // namespace tensorflow
2274