/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_cluster_util.h"

#include <string>
#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/xla_config_registry.h"

namespace tensorflow {

const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaCompileTimeConstantInputsAttr =
    "_XlaCompileTimeConstantInputs";

namespace {
// Returns a string describing how an edge from src to dst would
// create a cycle.
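// For example, for a proposed edge from node `c` to node `a` when a path
// a -> b -> c already exists, the description looks roughly like this
// (illustrative only; actual node names depend on the graph):
//
//   Edge from c to a would create a cycle.
//   +-> a
//   |   b
//   +-- c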
string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
                     int dst) {
  int32_t max_path_size = graph.num_node_ids() + 1;
  std::vector<int32> path(max_path_size);
  int32_t path_size = cycles->FindPath(dst, src, max_path_size, path.data());
  if (path_size == 0) {
    return "";
  }

  auto node_name = [&graph](int node_id) {
    if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
      return string("(null)");
    }
    auto* node = graph.FindNodeId(node_id);
    if (node == nullptr) {
      return string("(null)");
    }
    return node->name();
  };

  string description;
  absl::StrAppend(&description, "Edge from ", node_name(src), " to ",
                  node_name(dst), " would create a cycle.\n");
  path.resize(path_size);
  for (int32_t node_id : path) {
    string ascii_art;
    if (node_id == dst) {
      ascii_art = "+-> ";
    } else if (node_id != src) {
      ascii_art = "|   ";
    } else {
      ascii_art = "+-- ";
    }
    absl::StrAppend(&description, ascii_art, node_name(node_id), "\n");
  }
  return description;
}

bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }

}  // namespace

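// Returns true if `node` always forwards one of its inputs to its output
// (currently only Identity nodes do) and at least one of its non-control
// inputs is a ref-typed tensor.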
bool HasForwardedRefInput(const Node& node) {
  if (AlwaysForwardsRefInput(node)) {
    for (const Edge* incoming_edge : node.in_edges()) {
      if (incoming_edge->IsControlEdge()) {
        continue;
      }

      Node* incoming_node = incoming_edge->src();
      if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
        VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input "
                << incoming_node->name() << " " << incoming_node->type_string();
        return true;
      }
    }
  }
  return false;
}

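// Builds a cycle-detection graph in `cycles` that mirrors `graph`, collapsing
// each while loop into a single "frame" node as described in the comments
// below. Returns false for otherwise valid graphs that this transformation
// cannot yet handle (see the TODO below).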
StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
                                         GraphCycles* cycles) {
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    // We rely on the node IDs in the cycle detection graph being consecutive
    // integers starting from 0.
    CHECK_EQ(i, cycles->NewNode());
  }

  // Compute the loop structure of the graph.
  std::vector<ControlFlowInfo> control_flow_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));

  // The clustering code must avoid adding cycles to the graph to prevent
  // deadlock. However, the graph may contain loops, which would trigger the
  // cycle detection code. To handle loops, we alter the structure of the cycle
  // detection graph, disconnecting each loop from the enclosing graph.
  // Specifically, we:
  // * add a new "frame" node for each loop.
  // * replace edges to "Enter" nodes, and edges from "Exit" nodes, with edges
  //   to/from the corresponding frame node. In essence, we collapse the loop
  //   into a single node for the purpose of cycle detection in the enclosing
  //   graph.
  // * the body of the loop should now be disconnected from the rest of the
  //   graph; we make it acyclic by breaking loop backedges (edges outgoing
  //   from "NextIteration" nodes).
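  //
  // As a rough illustration (node names are hypothetical): for a while loop
  //
  //   x -> Enter -> <loop body, with a NextIteration back-edge> -> Exit -> y
  //
  // the cycle detection graph instead contains
  //
  //   x -> frame -> y
  //
  // plus the loop body as a separate island, made acyclic by dropping the
  // NextIteration back-edge.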

  // Map from frame name strings to node IDs in the cycle detection graph.
  std::unordered_map<string, int> frame_nodes;

  // Get the cycle graph node ID for frame 'frame_name', or add one if none
  // exists.
  auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) {
    int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
    if (frame_id < 0) {
      // The emplace succeeded; we have not allocated a frame node yet.
      frame_id = cycles->NewNode();
    }
    return frame_id;
  };

  for (Edge const* edge : graph->edges()) {
    if (edge->dst()->IsEnter() || edge->src()->IsExit()) {
      const char* src_type = "pre-enter";
      const char* dst_type = "post-exit";
      int src = edge->src()->id();
      int dst = edge->dst()->id();

      if (edge->dst()->IsEnter()) {
        // Lift edges to an "Enter" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->dst()->id()].frame_name;
        dst = GetOrAddFrameNodeId(frame_name);
        dst_type = "frame";
      }

      if (edge->src()->IsExit()) {
        // Lift edges from an "Exit" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->src()->id()].frame_name;
        src = GetOrAddFrameNodeId(frame_name);
        src_type = "frame";
      }

      if (!cycles->InsertEdge(src, dst)) {
        // TODO(b/127521408): We can probably handle this situation with a more
        // sophisticated SCC based algorithm, but for now we bail out.
        VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type
                << " edge: " << DescribeCycle(cycles, *graph, src, dst);
        return false;
      }
      // Drop the original edge.
      continue;
    }
    if (edge->src()->IsNextIteration()) {
      // Break loop back-edges.
      continue;
    }
    if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) {
      // This should never happen. All cycles in the graph should contain
      // a control flow operator.
      return errors::Internal(
          "Found cycle in graph without control flow operator during XLA "
          "compilation: ",
          DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
    }
  }

  return true;
}

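// Returns the name of the XLA cluster `node` has been assigned to via the
// kXlaClusterAttr attribute, or std::nullopt if the attribute is absent or is
// not a string.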
std::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {
  const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
  if (attr_value == nullptr) {
    return std::nullopt;
  }
  Status s = AttrValueHasType(*attr_value, "string");
  if (!s.ok()) {
    return std::nullopt;
  }
  return attr_value->s();
}

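// Returns true if `node` has at least one DT_RESOURCE input or output.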
bool HasResourceInputOrOutput(const Node& node) {
  return std::find(node.input_types().begin(), node.input_types().end(),
                   DT_RESOURCE) != node.input_types().end() ||
         std::find(node.output_types().begin(), node.output_types().end(),
                   DT_RESOURCE) != node.output_types().end();
}

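// Removes the XLA clustering attribute from `node_def` / `node`, effectively
// taking the node out of its XLA cluster.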
void RemoveFromXlaCluster(NodeDef* node_def) {
  node_def->mutable_attr()->erase(kXlaClusterAttr);
}

void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }

namespace {
typedef xla_config_registry::XlaGlobalJitLevel XlaGlobalJitLevel;

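// Computes the effective global jit level for single-GPU and general graphs
// from the session-options setting, letting the tf_xla_auto_jit flag override
// that setting when it is explicitly set to a non-DEFAULT value.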
XlaGlobalJitLevel GetXlaGlobalJitLevel(
    const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
  XlaGlobalJitLevel result;

  if (jit_level_in_session_opts == OptimizerOptions::DEFAULT) {
    // To set compilation to be on by default, change the following line.
    result.single_gpu = result.general = OptimizerOptions::OFF;
  } else {
    result.single_gpu = result.general = jit_level_in_session_opts;
  }

  // If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides
  // the setting in ConfigProto.
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  if (flags->xla_auto_jit_flag.optimization_level_single_gpu !=
      OptimizerOptions::DEFAULT) {
    result.single_gpu = static_cast<OptimizerOptions::GlobalJitLevel>(
        flags->xla_auto_jit_flag.optimization_level_single_gpu);
  }
  if (flags->xla_auto_jit_flag.optimization_level_general !=
      OptimizerOptions::DEFAULT) {
    result.general = static_cast<OptimizerOptions::GlobalJitLevel>(
        flags->xla_auto_jit_flag.optimization_level_general);
  }

  return result;
}

int GetGpuNumber(const string& device_name) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return -1;
  }

  return parsed_name.type == DEVICE_GPU ? parsed_name.id : -1;
}
}  // namespace

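// Returns true if the nodes of `g` are assigned to exactly one GPU device;
// nodes assigned to CPU or other device types are ignored.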
bool IsSingleGpuGraph(const Graph& g) {
  int gpus_seen = 0;
  absl::flat_hash_set<string> devices_seen;

  for (Node* n : g.op_nodes()) {
    if (devices_seen.contains(n->assigned_device_name())) {
      continue;
    }

    int gpu_number = GetGpuNumber(n->assigned_device_name());
    if (gpu_number != -1) {
      if (++gpus_seen > 1) {
        return false;
      }
    }

    devices_seen.insert(n->assigned_device_name());
  }

  return gpus_seen == 1;
}

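// Returns the global jit level to use for the graph in `options`: the
// single-GPU level if the graph runs on a single GPU, and the general level
// otherwise.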
OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
    const GraphOptimizationPassOptions& options) {
  OptimizerOptions::GlobalJitLevel jit_level_in_session_opts =
      options.session_options->config.graph_options()
          .optimizer_options()
          .global_jit_level();
  XlaGlobalJitLevel xla_global_jit_level =
      GetXlaGlobalJitLevel(jit_level_in_session_opts);
  if (xla_global_jit_level.single_gpu == xla_global_jit_level.general) {
    VLOG(4) << "GetGlobalJitLevelForGraph returning "
            << xla_global_jit_level.single_gpu;
    return xla_global_jit_level.single_gpu;
  }
  OptimizerOptions::GlobalJitLevel result =
      IsSingleGpuGraph(**options.graph) ? xla_global_jit_level.single_gpu
                                        : xla_global_jit_level.general;
  VLOG(4) << "GetGlobalJitLevelForGraph returning " << result;
  return result;
}

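// Returns true if `n` may call a function: either its op type names a
// function in `flib_def`, or one of its attributes holds a function.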
bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def) {
  if (flib_def->Contains(n.type_string())) {
    return true;
  }

  // This is a conservative check: there may be nodes with a `func`
  // attribute that do not make function calls.
  return absl::c_any_of(n.def().attr(),
                        [](const std::pair<string, AttrValue>& name_attr_pair) {
                          return name_attr_pair.second.has_func();
                        });
}
bool IsShapeConsumerOp(const Node& node) {
  return node.type_string() == "Shape" || node.type_string() == "Rank" ||
         node.type_string() == "Size";
}

namespace {
struct ClusterInfo {
  int size;

  // Maps op names to the number of times they appear in the cluster.
  absl::flat_hash_map<absl::string_view, int> op_histogram;
};

void HistogramMapToRepeatedOpAndCount(
    protobuf::RepeatedPtrField<XlaAutoClusteringSummary::OpAndCount>* result,
    const absl::flat_hash_map<absl::string_view, int>& histogram) {
  for (const auto& pair : histogram) {
    XlaAutoClusteringSummary::OpAndCount* new_entry = result->Add();
    new_entry->set_op(std::string(pair.first));
    new_entry->set_count(pair.second);
  }

  absl::c_sort(*result, [](const XlaAutoClusteringSummary::OpAndCount& a,
                           const XlaAutoClusteringSummary::OpAndCount& b) {
    return a.op() < b.op();
  });
}

void ClusterInfoToProtobuf(XlaAutoClusteringSummary::Cluster* result,
                           absl::string_view name, const ClusterInfo& info) {
  result->set_name(std::string(name));
  result->set_size(info.size);
  HistogramMapToRepeatedOpAndCount(result->mutable_op_histogram(),
                                   info.op_histogram);
}
}  // namespace

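// Builds an XlaAutoClusteringSummary for `graph`: per-cluster node counts and
// op histograms for clustered nodes, plus an op histogram for unclustered
// nodes, with clusters and histogram entries sorted by name for determinism.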
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
  absl::flat_hash_map<absl::string_view, ClusterInfo> cluster_name_to_info;
  XlaAutoClusteringSummary result;

  absl::flat_hash_map<absl::string_view, int> unclustered_op_histogram;

  for (Node* n : graph.nodes()) {
    std::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
    if (cluster_name) {
      result.set_clustered_node_count(result.clustered_node_count() + 1);
      ClusterInfo* info = &cluster_name_to_info[*cluster_name];
      info->size++;
      info->op_histogram[n->type_string()]++;
    } else {
      result.set_unclustered_node_count(result.unclustered_node_count() + 1);
      unclustered_op_histogram[n->type_string()]++;
    }
  }

  for (const auto& pair : cluster_name_to_info) {
    XlaAutoClusteringSummary::Cluster* new_cluster = result.add_clusters();
    ClusterInfoToProtobuf(new_cluster, pair.first, pair.second);
  }

  absl::c_sort(*result.mutable_clusters(),
               [&](const XlaAutoClusteringSummary::Cluster& a,
                   const XlaAutoClusteringSummary::Cluster& b) {
                 return a.name() < b.name();
               });

  HistogramMapToRepeatedOpAndCount(result.mutable_unclustered_op_histogram(),
                                   unclustered_op_histogram);

  return result;
}

namespace {
using CallTargetListTy = absl::InlinedVector<NameAttrList, 2>;

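// Returns the functions `n` may call: the node's own op if it names a
// function in the function library, otherwise every function held in the
// node's attributes (both single-function and function-list attributes).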
CallTargetListTy GetCallTargetListFromNode(
    const Node& n, FunctionLibraryRuntime* lib_runtime) {
  const FunctionLibraryDefinition& flib_def =
      *lib_runtime->GetFunctionLibraryDefinition();
  if (flib_def.Find(n.type_string())) {
    NameAttrList callee;
    callee.set_name(n.type_string());
    *callee.mutable_attr() = n.def().attr();
    return {callee};
  }

  CallTargetListTy result;
  for (const auto& name_attr_pair : n.attrs()) {
    const AttrValue& attr_value = name_attr_pair.second;
    if (attr_value.value_case() == AttrValue::kFunc) {
      result.push_back(attr_value.func());
    } else if (attr_value.value_case() == AttrValue::kList) {
      result.insert(result.end(), attr_value.list().func().begin(),
                    attr_value.list().func().end());
    }
  }

  return result;
}

enum class Direction { kForward, kBackward };

Status GetNodesRelatedToRefVariablesInDirection(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
    Direction direction, int depth, absl::flat_hash_set<Node*>* result);

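// Returns true if any function in `call_target_list` (followed transitively,
// up to kMaxDepth) involves ref-typed tensors. Answers true conservatively
// when a callee cannot be instantiated or the recursion depth limit is
// reached.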
StatusOr<bool> DoesAnyCalleeHaveRefNodes(
    const CallTargetListTy& call_target_list,
    FunctionLibraryRuntime* lib_runtime, Direction direction, int depth) {
  const int kMaxDepth = 10;

  if (depth == kMaxDepth && !call_target_list.empty()) {
    // Conservative answer to avoid recursing too much.
    return true;
  }

  absl::flat_hash_set<Node*> callee_ref_nodes;
  for (const NameAttrList& call_target : call_target_list) {
    const OpRegistrationData* op_reg;
    if (OpRegistry::Global()->LookUp(call_target.name(), &op_reg).ok()) {
      const OpDef& op = op_reg->op_def;
      if (absl::c_any_of(op.output_arg(), [](const OpDef::ArgDef arg) {
            return arg.is_ref();
          })) {
        return true;
      }
      continue;
    }

    callee_ref_nodes.clear();
    FunctionLibraryRuntime::Handle handle;
    if (!lib_runtime
             ->Instantiate(call_target.name(), AttrSlice(&call_target.attr()),
                           &handle)
             .ok()) {
      VLOG(2) << "Could not find " << call_target.name()
              << " in the function library.";
      // Since we don't know the semantics of `n` we don't know if this is an
      // error.  We return true to signal a conservative answer.
      return true;
    }

    auto release_handle_on_return = gtl::MakeCleanup(
        [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });

    const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
    TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
        *fbody->graph, lib_runtime, direction, depth + 1, &callee_ref_nodes));

    // We could possibly use something cheaper than
    // GetNodesRelatedToRefVariablesInDirection since we only care about the
    // size of `callee_ref_nodes`, but for now we don't care.
    if (!callee_ref_nodes.empty()) {
      return true;
    }
  }

  return false;
}

// Helper for GetNodesRelatedToRefVariables that traverses the graph in one
// direction.
Status GetNodesRelatedToRefVariablesInDirection(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
    Direction direction, int depth, absl::flat_hash_set<Node*>* result) {
  std::vector<Node*> nodes_in_order;
  if (direction == Direction::kForward) {
    GetReversePostOrder(graph, &nodes_in_order,
                        /*stable_comparator=*/NodeComparatorName());
  } else {
    GetPostOrder(graph, &nodes_in_order,
                 /*stable_comparator=*/NodeComparatorName());
  }

  size_t old_result_size;
  int iterations = 0;

  const int kMaxIterations = 10 * 1000;

  std::vector<bool> callee_has_ref_nodes_cache;
  callee_has_ref_nodes_cache.resize(graph.num_node_ids());

  auto does_callee_have_ref_nodes = [&](Node* n) -> StatusOr<bool> {
    if (iterations == 1) {
      TF_ASSIGN_OR_RETURN(
          bool callee_has_ref_nodes,
          DoesAnyCalleeHaveRefNodes(GetCallTargetListFromNode(*n, lib_runtime),
                                    lib_runtime, direction, depth));
      callee_has_ref_nodes_cache[n->id()] = callee_has_ref_nodes;
      return callee_has_ref_nodes;
    } else {
      return {callee_has_ref_nodes_cache[n->id()]};
    }
  };

  do {
    TF_RET_CHECK(iterations++ < kMaxIterations) << "infinite loop?";

    old_result_size = result->size();
    for (Node* n : nodes_in_order) {
      if (n->IsSource() || n->IsSink()) {
        continue;
      }

      bool inserted_n = false;
      const EdgeSet& edges =
          direction == Direction::kForward ? n->in_edges() : n->out_edges();
      for (const Edge* e : edges) {
        if (result->contains(direction == Direction::kForward ? e->src()
                                                              : e->dst())) {
          result->insert(n);
          inserted_n = true;
          break;
        }
      }

      if (inserted_n) {
        continue;
      }

      if (direction == Direction::kForward &&
          absl::c_any_of(n->output_types(), IsRefType)) {
        result->insert(n);
        continue;
      }

      TF_ASSIGN_OR_RETURN(bool callee_has_ref_nodes,
                          does_callee_have_ref_nodes(n));
      if (callee_has_ref_nodes) {
        result->insert(n);
        continue;
      }
    }

    // Loop until convergence.
  } while (result->size() != old_result_size);

  VLOG(2) << "# iterations = " << iterations;

  return OkStatus();
}

// Sorts control inputs of a graphdef so that they are deterministically
// ordered.
void SortControlInputs(GraphDef* gdef) {
  int64_t num_nodes = gdef->node_size();
  for (int64_t i = 0; i < num_nodes; ++i) {
    NodeDef* node = gdef->mutable_node(i);
    // Stable sort control inputs and leave the order of data inputs unchanged.
    std::stable_sort(node->mutable_input()->begin(),
                     node->mutable_input()->end(),
                     [](const string& a, const string& b) {
                       bool a_is_control = absl::StartsWith(a, "^");
                       bool b_is_control = absl::StartsWith(b, "^");
                       return (!a_is_control && b_is_control) ||
                              (a_is_control && b_is_control && a < b);
                     });
  }
}
}  // namespace

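// Returns the nodes in `graph` that either produce ref-typed outputs (or call
// functions that involve ref variables), or that transitively feed or are fed
// by such nodes. Computed by a forward pass followed by a backward pass over
// the graph.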
StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime) {
  absl::flat_hash_set<Node*> result;
  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
      graph, lib_runtime, Direction::kForward, 0, &result));
  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
      graph, lib_runtime, Direction::kBackward, 0, &result));

  VLOG(1) << "GetNodesRelatedToRefVariables() found " << result.size()
          << " nodes";
  return result;
}

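// Serializes `graph` to a GraphDef string as deterministically as possible,
// so the result is suitable for fingerprinting.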
StatusOr<std::string> SerializeGraphDeterministic(const Graph& graph) {
  GraphDef def;
  graph.ToGraphDef(&def);

  // Before serialization, sort each node's control inputs to achieve
  // determinism. Sorting control inputs helps create a deterministic
  // serialization and fingerprint, but does not guarantee it; unstable node
  // ordering is another source of nondeterminism.
  SortControlInputs(&def);

  std::string s;
  if (!SerializeToStringDeterministic(def, &s)) {
    return errors::Internal("Failed to serialize graphdef.");
  }
  return s;
}

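// Returns a 64-bit fingerprint of `graph`, computed by hashing its
// deterministic serialization.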
StatusOr<uint64> FingerprintGraph(const Graph& graph) {
  TF_ASSIGN_OR_RETURN(std::string serialized,
                      SerializeGraphDeterministic(graph));
  return Hash64(serialized.data(), serialized.size());
}

// Register a callback for querying XlaGlobalJitLevel.
REGISTER_XLA_CONFIG_GETTER(GetXlaGlobalJitLevel);

}  // namespace tensorflow