xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/xla_cluster_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Contains utilities for clustering compilable graph nodes via XLA.
17 
18 #ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
19 #define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
20 
21 #include <string>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/compiler/jit/xla_activity.pb.h"
27 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/core/common_runtime/optimization_registry.h"
30 #include "tensorflow/core/graph/algorithm.h"
31 #include "tensorflow/stream_executor/lib/statusor.h"
32 
33 namespace tensorflow {
34 
35 // Name of the attribute that marks nodes to be grouped into functions by the
36 // encapsulate subgraphs pass.
37 extern const char* const kXlaClusterAttr;
38 
39 // Name of the attribute that marks certain inputs to a Node as required to
40 // be a constant at compile time.  If this attribute is present then the
41 // CompileTimeConstantInput information in the corresponding XlaOpKernel is
42 // ignored.
43 //
44 // The value for this attribute, if present, has to be a list of strings naming
45 // the inputs to the node that must be constant.
46 extern const char* const kXlaCompileTimeConstantInputsAttr;
47 
48 using OrderedNodeSet = std::set<Node*, NodeComparatorID>;  // Ordered by NodeComparatorID.
49 
50 // Returns true if `node` forwards one of its ref tensor inputs to its output.
51 bool HasForwardedRefInput(const Node& node);
52 
53 // Creates a graph representation to enable cycle detection when clustering.
54 // This representation handles loops in `graph` by disconnecting each loop from
55 // the enclosing graph.
56 //
57 // Returns true for success and false for valid graphs that we can't handle yet
58 // (b/127521408).
59 StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
60                                          GraphCycles* cycles);
61 
62 // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster,
63 // otherwise returns std::nullopt.
64 std::optional<absl::string_view> GetXlaClusterForNode(const Node& node);
65 
66 // Removes `node_def` from its XLA cluster (clears its _XlaCluster attribute).
67 void RemoveFromXlaCluster(NodeDef* node_def);
68 
69 // Removes `node` from its XLA cluster (by clearing its _XlaCluster attribute).
70 void RemoveFromXlaCluster(Node* node);
71 
72 // Returns true if `node` has a DT_RESOURCE-typed input or output.
73 bool HasResourceInputOrOutput(const Node& node);
74 
75 // Determines the global JIT level based on GraphOptimizationPassOptions,
76 // --tf_xla_auto_jit and whether the graph is a single-GPU graph.
77 OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
78     const GraphOptimizationPassOptions& options);
79 
80 // Returns true if `g` is a single-GPU graph, i.e. a graph that uses exactly
81 // one GPU (and any number of CPUs).
82 bool IsSingleGpuGraph(const Graph& g);
83 
84 // Returns true if `n` may call a function; a true result does not guarantee
85 // that it does.
86 bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def);
87 
88 // Returns true if `node` is an operator that consumes only the shape of its
89 // input, not the data itself.
90 bool IsShapeConsumerOp(const Node& node);
91 
92 // Computes a clustering summary for `graph`.  See the documentation on
93 // `XlaAutoClusteringSummary` for details.
94 XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph);
95 
96 // Returns the set of nodes that have a path to or from nodes that may have ref
97 // variables as input or output.
98 //
99 // We assume each node has a trivial path to itself, so the returned set
100 // includes all of the nodes that have ref variables as input or output.
101 StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
102     const Graph& graph, FunctionLibraryRuntime* lib_runtime);
103 
104 // Deterministically serializes the graph to a byte string.
105 StatusOr<std::string> SerializeGraphDeterministic(const Graph& graph);
106 
107 // Computes a fingerprint of the given `graph`. The fingerprint can be used to
108 // check if two graphs are likely the same but should not be relied on for
109 // determining if the graphs are identical.
110 StatusOr<uint64> FingerprintGraph(const Graph& graph);
111 
112 }  // namespace tensorflow
113 
114 #endif  // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
115