/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_cluster_util.h"

#include <string>
#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/xla_config_registry.h"

namespace tensorflow {

const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaCompileTimeConstantInputsAttr =
    "_XlaCompileTimeConstantInputs";

namespace {
// Returns a string describing how an edge from src to dst would
// create a cycle.
string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
                     int dst) {
  int32_t max_path_size = graph.num_node_ids() + 1;
  std::vector<int32> path(max_path_size);
  int32_t path_size = cycles->FindPath(dst, src, max_path_size, path.data());
  if (path_size == 0) {
    return "";
  }

  auto node_name = [&graph](int node_id) {
    if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
      return string("(null)");
    }
    auto* node = graph.FindNodeId(node_id);
    if (node == nullptr) {
      return string("(null)");
    }
    return node->name();
  };

  string description;
  absl::StrAppend(&description, "Edge from ", node_name(src), " to ",
                  node_name(dst), " would create a cycle.\n");
  path.resize(path_size);
  for (int32_t node_id : path) {
    string ascii_art;
    if (node_id == dst) {
      ascii_art = "+-> ";
    } else if (node_id != src) {
      ascii_art = "|   ";
    } else {
      ascii_art = "+-- ";
    }
    absl::StrAppend(&description, ascii_art, node_name(node_id), "\n");
  }
  return description;
}

bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }

}  // namespace

bool HasForwardedRefInput(const Node& node) {
  if (AlwaysForwardsRefInput(node)) {
    for (const Edge* incoming_edge : node.in_edges()) {
      if (incoming_edge->IsControlEdge()) {
        continue;
      }

      Node* incoming_node = incoming_edge->src();
      if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
        VLOG(2) << "Node " << node.def().ShortDebugString()
                << " has ref input " << incoming_node->name() << " "
                << incoming_node->type_string();
        return true;
      }
    }
  }
  return false;
}

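// A rough sketch of how a clustering pass might drive the function below
// (illustrative only; the variable names and surrounding logic are
// hypothetical, not a verbatim excerpt from any caller):
//
//   GraphCycles cycles;
//   TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
//                       CreateCycleDetectionGraph(graph, &cycles));
//   if (!cycle_detection_graph_ok) {
//     // The graph's control flow could not be modeled; skip clustering.
//   }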
StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
                                         GraphCycles* cycles) {
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    // We rely on the node IDs in the cycle detection graph being consecutive
    // integers starting from 0.
    CHECK_EQ(i, cycles->NewNode());
  }

  // Compute the loop structure of the graph.
  std::vector<ControlFlowInfo> control_flow_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));

  // The clustering code must avoid adding cycles to the graph to prevent
  // deadlock. However, the graph may contain loops, which would trigger the
  // cycle detection code. To handle loops, we alter the structure of the cycle
  // detection graph, disconnecting each loop from the enclosing graph.
  // Specifically, we:
  // * add a new "frame" node for each loop.
  // * replace edges to "Enter" nodes and edges from "Exit" nodes with edges
  //   to/from the corresponding frame node. In essence, we collapse the loop
  //   into a single node for the purpose of cycle detection in the enclosing
  //   graph.
  // * the body of the loop should now be disconnected from the rest of the
  //   graph; we make it acyclic by breaking loop back-edges (edges outgoing
  //   from "NextIteration" nodes).

  // Map from frame name strings to node IDs in the cycle detection graph.
  std::unordered_map<string, int> frame_nodes;

  // Get the cycle graph node ID for frame 'frame_name', or add one if none
  // exists.
  auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) {
    int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
    if (frame_id < 0) {
      // The emplace succeeded; we have not allocated a frame node yet.
      frame_id = cycles->NewNode();
    }
    return frame_id;
  };

  for (Edge const* edge : graph->edges()) {
    if (edge->dst()->IsEnter() || edge->src()->IsExit()) {
      const char* src_type = "pre-enter";
      const char* dst_type = "post-exit";
      int src = edge->src()->id();
      int dst = edge->dst()->id();

      if (edge->dst()->IsEnter()) {
        // Lift edges to an "Enter" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->dst()->id()].frame_name;
        dst = GetOrAddFrameNodeId(frame_name);
        dst_type = "frame";
      }

      if (edge->src()->IsExit()) {
        // Lift edges from an "Exit" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->src()->id()].frame_name;
        src = GetOrAddFrameNodeId(frame_name);
        src_type = "frame";
      }

      if (!cycles->InsertEdge(src, dst)) {
        // TODO(b/127521408): We can probably handle this situation with a more
        // sophisticated SCC based algorithm, but for now we bail out.
        VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type
                << " edge: " << DescribeCycle(cycles, *graph, src, dst);
        return false;
      }
      // Drop the original edge.
      continue;
    }
    if (edge->src()->IsNextIteration()) {
      // Break loop back-edges.
      continue;
    }
    if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) {
      // This should never happen. All cycles in the graph should contain
      // a control flow operator.
      return errors::Internal(
          "Found cycle in graph without control flow operator during XLA "
          "compilation: ",
          DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
    }
  }

  return true;
}

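// Note: the kXlaClusterAttr attribute is a plain string naming the cluster;
// a clustered node's NodeDef would carry, e.g. (hypothetical cluster name):
//
//   attr { key: "_XlaCluster" value { s: "cluster_0" } }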
std::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {
  const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
  if (attr_value == nullptr) {
    return std::nullopt;
  }
  Status s = AttrValueHasType(*attr_value, "string");
  if (!s.ok()) {
    return std::nullopt;
  }
  return attr_value->s();
}

bool HasResourceInputOrOutput(const Node& node) {
  return std::find(node.input_types().begin(), node.input_types().end(),
                   DT_RESOURCE) != node.input_types().end() ||
         std::find(node.output_types().begin(), node.output_types().end(),
                   DT_RESOURCE) != node.output_types().end();
}

void RemoveFromXlaCluster(NodeDef* node_def) {
  node_def->mutable_attr()->erase(kXlaClusterAttr);
}

void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }

namespace {
typedef xla_config_registry::XlaGlobalJitLevel XlaGlobalJitLevel;

XlaGlobalJitLevel GetXlaGlobalJitLevel(
    const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
  XlaGlobalJitLevel result;

  if (jit_level_in_session_opts == OptimizerOptions::DEFAULT) {
    // To set compilation to be on by default, change the following line.
    result.single_gpu = result.general = OptimizerOptions::OFF;
  } else {
    result.single_gpu = result.general = jit_level_in_session_opts;
  }

  // If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides
  // the setting in ConfigProto.
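  //
  // For example (hypothetical values): with global_jit_level = ON_1 in the
  // session options and the flag overriding only the single-GPU level to
  // ON_2, the result is single_gpu = ON_2 and general = ON_1.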
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  if (flags->xla_auto_jit_flag.optimization_level_single_gpu !=
      OptimizerOptions::DEFAULT) {
    result.single_gpu = static_cast<OptimizerOptions::GlobalJitLevel>(
        flags->xla_auto_jit_flag.optimization_level_single_gpu);
  }
  if (flags->xla_auto_jit_flag.optimization_level_general !=
      OptimizerOptions::DEFAULT) {
    result.general = static_cast<OptimizerOptions::GlobalJitLevel>(
        flags->xla_auto_jit_flag.optimization_level_general);
  }

  return result;
}

int GetGpuNumber(const string& device_name) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return -1;
  }

  return parsed_name.type == DEVICE_GPU ? parsed_name.id : -1;
}
}  // namespace

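// Returns true only if exactly one distinct GPU device is assigned among the
// graph's op nodes; nodes assigned to non-GPU devices (or with no parseable
// assigned device name) do not affect the result.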
bool IsSingleGpuGraph(const Graph& g) {
  int gpus_seen = 0;
  absl::flat_hash_set<string> devices_seen;

  for (Node* n : g.op_nodes()) {
    if (devices_seen.contains(n->assigned_device_name())) {
      continue;
    }

    int gpu_number = GetGpuNumber(n->assigned_device_name());
    if (gpu_number != -1) {
      if (++gpus_seen > 1) {
        return false;
      }
    }

    devices_seen.insert(n->assigned_device_name());
  }

  return gpus_seen == 1;
}

OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
    const GraphOptimizationPassOptions& options) {
  OptimizerOptions::GlobalJitLevel jit_level_in_session_opts =
      options.session_options->config.graph_options()
          .optimizer_options()
          .global_jit_level();
  XlaGlobalJitLevel xla_global_jit_level =
      GetXlaGlobalJitLevel(jit_level_in_session_opts);
  if (xla_global_jit_level.single_gpu == xla_global_jit_level.general) {
    VLOG(4) << "GetGlobalJitLevelForGraph returning "
            << xla_global_jit_level.single_gpu;
    return xla_global_jit_level.single_gpu;
  }
  OptimizerOptions::GlobalJitLevel result =
      IsSingleGpuGraph(**options.graph) ? xla_global_jit_level.single_gpu
                                        : xla_global_jit_level.general;
  VLOG(4) << "GetGlobalJitLevelForGraph returning " << result;
  return result;
}

bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def) {
  if (flib_def->Contains(n.type_string())) {
    return true;
  }

  // This is a conservative check: there may be nodes with a `func`
  // attribute that do not make function calls.
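  // Functional control flow ops such as "While" and "If", for instance,
  // reference their loop bodies and branches through `func` attrs, so they
  // are caught by this check.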
  return absl::c_any_of(n.def().attr(),
                        [](const std::pair<string, AttrValue>& name_attr_pair) {
                          return name_attr_pair.second.has_func();
                        });
}

bool IsShapeConsumerOp(const Node& node) {
  return node.type_string() == "Shape" || node.type_string() == "Rank" ||
         node.type_string() == "Size";
}

namespace {
struct ClusterInfo {
  int size;

  // Maps op names to the number of times they appear in the cluster.
  absl::flat_hash_map<absl::string_view, int> op_histogram;
};

void HistogramMapToRepeatedOpAndCount(
    protobuf::RepeatedPtrField<XlaAutoClusteringSummary::OpAndCount>* result,
    const absl::flat_hash_map<absl::string_view, int>& histogram) {
  for (const auto& pair : histogram) {
    XlaAutoClusteringSummary::OpAndCount* new_entry = result->Add();
    new_entry->set_op(std::string(pair.first));
    new_entry->set_count(pair.second);
  }

  absl::c_sort(*result, [](const XlaAutoClusteringSummary::OpAndCount& a,
                           const XlaAutoClusteringSummary::OpAndCount& b) {
    return a.op() < b.op();
  });
}

void ClusterInfoToProtobuf(XlaAutoClusteringSummary::Cluster* result,
                           absl::string_view name, const ClusterInfo& info) {
  result->set_name(std::string(name));
  result->set_size(info.size);
  HistogramMapToRepeatedOpAndCount(result->mutable_op_histogram(),
                                   info.op_histogram);
}
}  // namespace

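// Illustrative shape of the summary produced below, in textproto form, for a
// hypothetical graph with three clustered MatMuls and one unclustered NoOp:
//
//   clustered_node_count: 3
//   unclustered_node_count: 1
//   clusters {
//     name: "cluster_0"
//     size: 3
//     op_histogram { op: "MatMul" count: 3 }
//   }
//   unclustered_op_histogram { op: "NoOp" count: 1 }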
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
  absl::flat_hash_map<absl::string_view, ClusterInfo> cluster_name_to_info;
  XlaAutoClusteringSummary result;

  absl::flat_hash_map<absl::string_view, int> unclustered_op_histogram;

  for (Node* n : graph.nodes()) {
    std::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
    if (cluster_name) {
      result.set_clustered_node_count(result.clustered_node_count() + 1);
      ClusterInfo* info = &cluster_name_to_info[*cluster_name];
      info->size++;
      info->op_histogram[n->type_string()]++;
    } else {
      result.set_unclustered_node_count(result.unclustered_node_count() + 1);
      unclustered_op_histogram[n->type_string()]++;
    }
  }

  for (const auto& pair : cluster_name_to_info) {
    XlaAutoClusteringSummary::Cluster* new_cluster = result.add_clusters();
    ClusterInfoToProtobuf(new_cluster, pair.first, pair.second);
  }

  absl::c_sort(*result.mutable_clusters(),
               [&](const XlaAutoClusteringSummary::Cluster& a,
                   const XlaAutoClusteringSummary::Cluster& b) {
                 return a.name() < b.name();
               });

  HistogramMapToRepeatedOpAndCount(result.mutable_unclustered_op_histogram(),
                                   unclustered_op_histogram);

  return result;
}

namespace {
using CallTargetListTy = absl::InlinedVector<NameAttrList, 2>;

CallTargetListTy GetCallTargetListFromNode(
    const Node& n, FunctionLibraryRuntime* lib_runtime) {
  const FunctionLibraryDefinition& flib_def =
      *lib_runtime->GetFunctionLibraryDefinition();
  if (flib_def.Find(n.type_string())) {
    NameAttrList callee;
    callee.set_name(n.type_string());
    *callee.mutable_attr() = n.def().attr();
    return {callee};
  }

  CallTargetListTy result;
  for (const auto& name_attr_pair : n.attrs()) {
    const AttrValue& attr_value = name_attr_pair.second;
    if (attr_value.value_case() == AttrValue::kFunc) {
      result.push_back(attr_value.func());
    } else if (attr_value.value_case() == AttrValue::kList) {
      result.insert(result.end(), attr_value.list().func().begin(),
                    attr_value.list().func().end());
    }
  }

  return result;
}

enum class Direction { kForward, kBackward };

Status GetNodesRelatedToRefVariablesInDirection(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
    Direction direction, int depth, absl::flat_hash_set<Node*>* result);

StatusOr<bool> DoesAnyCalleeHaveRefNodes(
    const CallTargetListTy& call_target_list,
    FunctionLibraryRuntime* lib_runtime, Direction direction, int depth) {
  const int kMaxDepth = 10;

  if (depth == kMaxDepth && !call_target_list.empty()) {
    // Conservative answer to avoid recursing too much.
    return true;
  }

  absl::flat_hash_set<Node*> callee_ref_nodes;
  for (const NameAttrList& call_target : call_target_list) {
    const OpRegistrationData* op_reg;
    if (OpRegistry::Global()->LookUp(call_target.name(), &op_reg).ok()) {
      const OpDef& op = op_reg->op_def;
      if (absl::c_any_of(op.output_arg(), [](const OpDef::ArgDef arg) {
            return arg.is_ref();
          })) {
        return true;
      }
      continue;
    }

    callee_ref_nodes.clear();
    FunctionLibraryRuntime::Handle handle;
    if (!lib_runtime
             ->Instantiate(call_target.name(), AttrSlice(&call_target.attr()),
                           &handle)
             .ok()) {
      VLOG(2) << "Could not find " << call_target.name()
              << " in the function library.";
      // Since we don't know the semantics of `n` we don't know if this is an
      // error. We return true to signal a conservative answer.
      return true;
    }

    auto release_handle_on_return = gtl::MakeCleanup(
        [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });

    const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
    TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
        *fbody->graph, lib_runtime, direction, depth + 1, &callee_ref_nodes));

    // We could possibly use something cheaper than
    // GetNodesRelatedToRefVariablesInDirection since we only care about the
    // size of `callee_ref_nodes`, but for now we don't care.
    if (!callee_ref_nodes.empty()) {
      return true;
    }
  }

  return false;
}

// Helper for GetNodesRelatedToRefVariables that traverses the graph in one
// direction.
Status GetNodesRelatedToRefVariablesInDirection(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
    Direction direction, int depth, absl::flat_hash_set<Node*>* result) {
  std::vector<Node*> nodes_in_order;
  if (direction == Direction::kForward) {
    GetReversePostOrder(graph, &nodes_in_order,
                        /*stable_comparator=*/NodeComparatorName());
  } else {
    GetPostOrder(graph, &nodes_in_order,
                 /*stable_comparator=*/NodeComparatorName());
  }

  size_t old_result_size;
  int iterations = 0;

  const int kMaxIterations = 10 * 1000;

  std::vector<bool> callee_has_ref_nodes_cache;
  callee_has_ref_nodes_cache.resize(graph.num_node_ids());

  auto does_callee_have_ref_nodes = [&](Node* n) -> StatusOr<bool> {
    if (iterations == 1) {
      TF_ASSIGN_OR_RETURN(
          bool callee_has_ref_nodes,
          DoesAnyCalleeHaveRefNodes(GetCallTargetListFromNode(*n, lib_runtime),
                                    lib_runtime, direction, depth));
      callee_has_ref_nodes_cache[n->id()] = callee_has_ref_nodes;
      return callee_has_ref_nodes;
    } else {
      return {callee_has_ref_nodes_cache[n->id()]};
    }
  };

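  // Iterate to a fixed point: keep sweeping the nodes in (reverse) post order
  // until a sweep no longer grows `result`.  The callee analysis above is
  // only computed during the first sweep (iterations == 1) and cached in
  // `callee_has_ref_nodes_cache`, since its answer for a node does not change
  // between sweeps.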
  do {
    TF_RET_CHECK(iterations++ < kMaxIterations) << "infinite loop?";

    old_result_size = result->size();
    for (Node* n : nodes_in_order) {
      if (n->IsSource() || n->IsSink()) {
        continue;
      }

      bool inserted_n = false;
      const EdgeSet& edges =
          direction == Direction::kForward ? n->in_edges() : n->out_edges();
      for (const Edge* e : edges) {
        if (result->contains(direction == Direction::kForward ? e->src()
                                                               : e->dst())) {
          result->insert(n);
          inserted_n = true;
          break;
        }
      }

      if (inserted_n) {
        continue;
      }

      if (direction == Direction::kForward &&
          absl::c_any_of(n->output_types(), IsRefType)) {
        result->insert(n);
        continue;
      }

      TF_ASSIGN_OR_RETURN(bool callee_has_ref_nodes,
                          does_callee_have_ref_nodes(n));
      if (callee_has_ref_nodes) {
        result->insert(n);
        continue;
      }
    }

    // Loop until convergence.
  } while (result->size() != old_result_size);

  VLOG(2) << "# iterations = " << iterations;

  return OkStatus();
}

// Sorts control inputs of a graphdef so that they are deterministically
// ordered.
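// For example (illustrative), the input list {"x:0", "^b", "^a"} becomes
// {"x:0", "^a", "^b"}: data inputs keep their relative order and precede
// control inputs, and control inputs are sorted by name.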
void SortControlInputs(GraphDef* gdef) {
  int64_t num_nodes = gdef->node_size();
  for (int64_t i = 0; i < num_nodes; ++i) {
    NodeDef* node = gdef->mutable_node(i);
    // Stable sort control inputs and leave the order of data inputs unchanged.
    std::stable_sort(node->mutable_input()->begin(),
                     node->mutable_input()->end(),
                     [](const string& a, const string& b) {
                       bool a_is_control = absl::StartsWith(a, "^");
                       bool b_is_control = absl::StartsWith(b, "^");
                       return (!a_is_control && b_is_control) ||
                              (a_is_control && b_is_control && a < b);
                     });
  }
}
}  // namespace

StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime) {
  absl::flat_hash_set<Node*> result;
  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
      graph, lib_runtime, Direction::kForward, 0, &result));
  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
      graph, lib_runtime, Direction::kBackward, 0, &result));

  VLOG(1) << "GetNodesRelatedToRefVariables() found " << result.size()
          << " nodes";
  return result;
}

StatusOr<std::string> SerializeGraphDeterministic(const Graph& graph) {
  GraphDef def;
  graph.ToGraphDef(&def);

  // Before serialization, sort each node's control inputs to achieve
  // determinism. Sorting control inputs helps produce (but does not by itself
  // guarantee) a deterministic serialization and fingerprint; other sources of
  // nondeterminism, such as unstable node ordering, may remain.
  SortControlInputs(&def);

  std::string s;
  if (!SerializeToStringDeterministic(def, &s)) {
    return errors::Internal("Failed to serialize graphdef.");
  }
  return s;
}

StatusOr<uint64> FingerprintGraph(const Graph& graph) {
  TF_ASSIGN_OR_RETURN(std::string serialized,
                      SerializeGraphDeterministic(graph));
  return Hash64(serialized.data(), serialized.size());
}

// Register a callback for querying XlaGlobalJitLevel.
REGISTER_XLA_CONFIG_GETTER(GetXlaGlobalJitLevel);

}  // namespace tensorflow