xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/hash_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/data/hash_utils.h"

#include <array>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/work_sharder.h"
44 
45 namespace tensorflow {
46 namespace data {
47 namespace {
48 
// Names of the op inputs that carry seed information. `ShouldIgnoreInput`
// compares the registered input-arg name of an op against these.
constexpr char kSeedInputName[] = "seed";
constexpr char kSeed2InputName[] = "seed2";
constexpr char kSeedGeneratorInputName[] = "seed_generator";

// Op types (matched against any version of the op) whose seed inputs are
// looked up and excluded by `ShouldIgnoreInput`.
// clang-format off
constexpr std::array<const char*, 3> kOpsWithSeed = {
    "AnonymousRandomSeedGenerator",
    "ShuffleDataset",
    "ShuffleAndRepeatDataset"
};
// clang-format on
59 
60 template <std::size_t SIZE>
IsNodeOfType(const NodeDef & node,const std::array<const char *,SIZE> & op_types)61 bool IsNodeOfType(const NodeDef& node,
62                   const std::array<const char*, SIZE>& op_types) {
63   for (const auto& type : op_types) {
64     if (MatchesAnyVersion(type, node.op())) {
65       return true;
66     }
67   }
68   return false;
69 }
70 
GetSink(const GraphDef & graph_def,const NodeDef ** sink)71 Status GetSink(const GraphDef& graph_def, const NodeDef** sink) {
72   for (auto& node : graph_def.node()) {
73     if (node.op() == "_Retval") {
74       *sink = &node;
75       break;
76     }
77   }
78 
79   if (sink == nullptr) {
80     return errors::Internal("Cannot find sink node for dataset graph.");
81   }
82   return OkStatus();
83 }
84 
ShouldIgnoreInput(const NodeDef & node,int i,bool * result)85 Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) {
86   *result = false;
87   if (IsNodeOfType(node, kOpsWithSeed)) {
88     const OpRegistrationData* reg;
89     auto status = OpRegistry::Global()->LookUp(node.op(), &reg);
90 
91     if (status.ok()) {
92       if (reg->op_def.input_arg_size() > i) {
93         const std::string input_arg_name = reg->op_def.input_arg(i).name();
94         if (input_arg_name == kSeedInputName ||
95             input_arg_name == kSeed2InputName ||
96             input_arg_name == kSeedGeneratorInputName) {
97           VLOG(2) << "Ignoring arg: " << input_arg_name
98                   << " from node: " << node.name();
99           *result = true;
100           return OkStatus();
101         }
102       }
103     } else if (errors::IsNotFound(status)) {
104       LOG(WARNING) << "Cannot find " << node.op()
105                    << " in global op registry, so cannot determine which "
106                       "inputs are seeds.";
107     } else {
108       return status;
109     }
110   }
111   return OkStatus();
112 }
113 
ParseInputNodeName(absl::string_view input_name,absl::string_view * node_name,absl::string_view * suffix,bool * is_control_input)114 Status ParseInputNodeName(absl::string_view input_name,
115                           absl::string_view* node_name,
116                           absl::string_view* suffix, bool* is_control_input) {
117   if (input_name[0] == '^') {
118     *node_name = input_name.substr(1);
119     *is_control_input = true;
120     return OkStatus();
121   }
122   std::pair<absl::string_view, absl::string_view> node_spec =
123       absl::StrSplit(input_name, absl::MaxSplits(':', 1));
124   *node_name = node_spec.first;
125   *suffix = node_spec.second;
126   *is_control_input = false;
127   return OkStatus();
128 }
129 
130 // Given a graph_def and a root_node, this class computes a fingerprint that
131 // tries to capture the structure of the graph rooted at the provided node.
132 // It does not at any point rely on the names of the nodes in the graph and
133 // just relies on the connections between different nodes. In the presence of
134 // multiple cycles in the graph, there is a non-zero possibility that two
135 // graphs with different structure might end up with the same fingerprint
136 // as in order to break cycles we prune away some edges (in a deterministic
137 // fashion though). Idea for this algorithm was borrowed from:
138 // https://stackoverflow.com/questions/11338746/directed-graphs-with-a-given-root-node-match-another-directed-graph-for-equali
139 class GraphHasher {
140   using NodeCache = absl::flat_hash_map<const NodeDef*, uint64>;
141   using FunctionCache = absl::flat_hash_map<const FunctionDef*, uint64>;
142   using AttrCache =
143       absl::flat_hash_map<std::pair<const NodeDef*, bool>, uint64>;
144 
145  public:
146   // `GraphHasher` does not take ownership of `graph_def`, `root_node`, or
147   // `flib_def`.
GraphHasher(const GraphDef * graph,const NodeDef * root,const FunctionLibraryDefinition * flib)148   explicit GraphHasher(const GraphDef* graph, const NodeDef* root,
149                        const FunctionLibraryDefinition* flib)
150       : graph_(graph), root_(root), flib_(flib) {
151     node_cache_ = std::make_shared<NodeCache>();
152     function_cache_ = std::make_shared<FunctionCache>();
153     attr_cache_ = std::make_shared<AttrCache>();
154   }
GraphHasher(const GraphDef * graph,const NodeDef * root,const FunctionLibraryDefinition * flib,std::shared_ptr<NodeCache> node_cache,std::shared_ptr<FunctionCache> function_cache,std::shared_ptr<AttrCache> attr_cache)155   explicit GraphHasher(const GraphDef* graph, const NodeDef* root,
156                        const FunctionLibraryDefinition* flib,
157                        std::shared_ptr<NodeCache> node_cache,
158                        std::shared_ptr<FunctionCache> function_cache,
159                        std::shared_ptr<AttrCache> attr_cache)
160       : graph_(graph),
161         root_(root),
162         flib_(flib),
163         node_cache_(node_cache),
164         function_cache_(function_cache),
165         attr_cache_(attr_cache) {}
166 
Init()167   Status Init() {
168     // Construct a map of name -> NodeDef to avoid repeated linear searches.
169     absl::flat_hash_map<absl::string_view, const NodeDef*> node_def_by_name;
170     node_def_by_name.reserve(graph_->node_size());
171     for (const auto& node : graph_->node()) {
172       auto result = node_def_by_name.emplace(node.name(), &node);
173       if (TF_PREDICT_FALSE(!result.second)) {
174         auto node_name_formatter =
175             [](std::string* out,
176                const decltype(node_def_by_name)::value_type& item) {
177               absl::StrAppend(out, "'", item.first, "'");
178             };
179         return errors::Internal(
180             "Encountered graph with duplicate node name '", node.name(),
181             "' in [", absl::StrJoin(node_def_by_name, ",", node_name_formatter),
182             "]");
183       }
184     }
185     // Pre-process the graph to do a BFS and prune away cycles that might cause
186     // problems.
187     absl::flat_hash_set<absl::string_view> visited;
188     std::queue<const NodeDef*> bfs_queue;
189     bfs_queue.push(root_);
190     while (!bfs_queue.empty()) {
191       const NodeDef* node = bfs_queue.front();
192       bfs_queue.pop();
193       if (visited.contains(node->name())) {
194         continue;
195       }
196       visited.insert(node->name());
197       NodeRep node_rep;
198       for (int i = 0; i < node->input_size(); ++i) {
199         DCHECK_GT(node->input(i).length(), 0);
200 
201         // We skip trying to take the hash of the seeds of any ops, as they
202         // are irrelevant to the hash of the graph and may vary from run to run.
203         bool should_ignore_input = false;
204         TF_RETURN_IF_ERROR(ShouldIgnoreInput(*node, i, &should_ignore_input));
205         if (should_ignore_input) continue;
206 
207         absl::string_view node_name, suffix;
208         bool is_control_input;
209         TF_RETURN_IF_ERROR(ParseInputNodeName(node->input(i), &node_name,
210                                               &suffix, &is_control_input));
211 
212         auto* input_node = gtl::FindPtrOrNull(node_def_by_name, node_name);
213         if (input_node == nullptr) {
214           return errors::Internal("Graph node [", node->name(), "] has input [",
215                                   node_name, "] that doesn't exist in graph");
216         }
217 
218         // If we've already seen this node before, skip it and don't add it to
219         // the queue.
220         if (visited.contains(node_name)) {
221           EdgeRep cycle_edge(node, input_node);
222           cycle_forming_edges_.insert(cycle_edge.GetHash());
223           continue;
224         }
225         if (is_control_input) {
226           node_rep.node_control_inputs.push_back(input_node);
227         } else {
228           node_rep.node_inputs.push_back(std::make_pair(input_node, suffix));
229           bfs_queue.push(input_node);
230         }
231       }
232       nodes_[node] = node_rep;
233     }
234     return OkStatus();
235   }
236 
HashRoot(uint64 * hash)237   Status HashRoot(uint64* hash) { return HashNode(root_, hash); }
238 
CheckEqual(GraphHasher * that)239   Status CheckEqual(GraphHasher* that) {
240     return CheckNodesEqual(root_, that, that->root_);
241   }
242 
243  private:
HashNode(const NodeDef * node,uint64 * hash)244   Status HashNode(const NodeDef* node, uint64* hash) {
245     auto it = node_cache_->find(node);
246     if (it != node_cache_->end()) {
247       *hash = it->second;
248       return OkStatus();
249     }
250 
251     NodeRep* node_rep = gtl::FindOrNull(nodes_, node);
252     if (node_rep == nullptr) {
253       return errors::InvalidArgument("Could not find node: ", node->name());
254     }
255 
256     uint64 non_input_hash;
257     TF_RETURN_IF_ERROR(
258         HashNodeNonInput(node, /*hash_functions=*/true, &non_input_hash));
259 
260     uint64 control_inputs_hash;
261     TF_RETURN_IF_ERROR(
262         HashControlInputs(node_rep->node_control_inputs, &control_inputs_hash));
263 
264     // Hash regular inputs. We combine them in an ordered fashion.
265     uint64 inputs_hash = 0;
266     for (const auto& input : node_rep->node_inputs) {
267       uint64 node_hash = 0;
268       EdgeRep edge(node, input.first);
269       // If the edge was pruned we get the non input node hash to avoid cycles.
270       if (cycle_forming_edges_.contains(edge.GetHash())) {
271         TF_RETURN_IF_ERROR(
272             HashNodeNonInput(input.first, /*hash_functions=*/true, &node_hash));
273       } else {
274         TF_RETURN_IF_ERROR(HashNode(input.first, &node_hash));
275       }
276       inputs_hash = Hash64Combine(
277           inputs_hash, Hash64Combine(node_hash, Hash64(input.second.data(),
278                                                        input.second.size())));
279     }
280 
281     *hash = Hash64Combine(non_input_hash,
282                           Hash64Combine(control_inputs_hash, inputs_hash));
283     auto result = node_cache_->emplace(node, *hash);
284     if (!result.second) {
285       return errors::Internal(absl::StrCat("Computed the hash for node ",
286                                            node->DebugString(), " twice!"));
287     }
288     return OkStatus();
289   }
290 
CheckNodesEqual(const NodeDef * this_node,GraphHasher * that,const NodeDef * that_node)291   Status CheckNodesEqual(const NodeDef* this_node, GraphHasher* that,
292                          const NodeDef* that_node) {
293     Status s = CheckNodesEqualHelper(this_node, that, that_node);
294     if (!s.ok()) {
295       return errors::FailedPrecondition("Nodes ", this_node->name(), " and ",
296                                         that_node->name(),
297                                         " are not the same:\n", s);
298     }
299     return s;
300   }
301 
CheckNodesEqualHelper(const NodeDef * this_node,GraphHasher * that,const NodeDef * that_node)302   Status CheckNodesEqualHelper(const NodeDef* this_node, GraphHasher* that,
303                                const NodeDef* that_node) {
304     TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_node, that, that_node,
305                                                /*compare_functions=*/true));
306 
307     TF_RETURN_IF_ERROR(
308         CheckControlInputsEqual(nodes_[this_node].node_control_inputs, that,
309                                 that->nodes_[that_node].node_control_inputs));
310 
311     auto& this_node_inputs = nodes_[this_node].node_inputs;
312     auto& that_node_inputs = that->nodes_[that_node].node_inputs;
313     if (this_node_inputs.size() != that_node_inputs.size()) {
314       return errors::FailedPrecondition(
315           "Nodes have different numbers of node inputs: ",
316           this_node_inputs.size(), " vs ", that_node_inputs.size());
317     }
318     for (int i = 0; i < this_node_inputs.size(); ++i) {
319       const NodeDef* this_input = this_node_inputs[i].first;
320       const NodeDef* that_input = that_node_inputs[i].first;
321       if (is_cycle_forming_edge(this_node, this_input)) {
322         TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_input, that, that_input,
323                                                    /*compare_functions=*/true));
324       } else {
325         TF_RETURN_IF_ERROR(CheckNodesEqual(this_input, that, that_input));
326       }
327       absl::string_view this_input_suffix = this_node_inputs[i].second;
328       absl::string_view that_input_suffix = that_node_inputs[i].second;
329       if (this_input_suffix != that_input_suffix) {
330         return errors::FailedPrecondition(
331             "Node inputs ", this_input->name(), " and ", that_input->name(),
332             " have different suffixes: ", this_input_suffix, " vs ",
333             that_input_suffix);
334       }
335     }
336     return OkStatus();
337   }
338 
HashNodeNonInput(const NodeDef * node,bool hash_functions,uint64 * hash)339   Status HashNodeNonInput(const NodeDef* node, bool hash_functions,
340                           uint64* hash) {
341     auto iter = attr_cache_->find(std::make_pair(node, hash_functions));
342     if (iter != attr_cache_->end()) {
343       *hash = iter->second;
344       return OkStatus();
345     }
346     // Hash Attrs. We get the list of attrs from the op registry and then look
347     // up their values in the NodeDef attr map. This avoids looping over
348     // a map which is non-deterministic.
349     uint64 attrs_hash = 0;
350     const OpRegistrationData* reg;
351     TF_RETURN_IF_ERROR(flib_->LookUp(node->op(), &reg));
352     uint64 op_hash = 0;
353     if (reg->is_function_op) {
354       if (hash_functions) {
355         TF_RETURN_IF_ERROR(HashFunction(node->op(), node->attr(), &op_hash));
356       }
357     } else {
358       op_hash = Hash64(node->op());
359     }
360 
361     for (const auto& attr : reg->op_def.attr()) {
362       const auto& attr_key = attr.name();
363       // Ignore "metadata" attribute of tf.data operations.
364       if (DatasetOpKernel::IsDatasetOp(reg->op_def) && attr_key == "metadata")
365         continue;
366       auto node_attr_iter = node->attr().find(attr_key);
367       if (node_attr_iter == node->attr().end()) {
368         continue;
369       }
370       const auto& attr_value = node_attr_iter->second;
371       if (attr_key == kColocationAttrName ||
372           attr_key == kColocationGroupPrefix) {
373         continue;
374       }
375       uint64 attr_hash = 0;
376       TF_RETURN_IF_ERROR(
377           HashAttr(attr_key, attr_value, hash_functions, &attr_hash));
378       attrs_hash = Hash64Combine(attrs_hash, attr_hash);
379     }
380 
381     // Hash Device.
382     uint64 device_hash = Hash64(node->device());
383 
384     *hash = Hash64Combine(op_hash, Hash64Combine(attrs_hash, device_hash));
385 
386     auto result =
387         attr_cache_->emplace(std::make_pair(node, hash_functions), *hash);
388     if (!result.second) {
389       return errors::Internal(absl::StrCat(
390           "Computed the hash for non-input node: ", node->DebugString(),
391           " and hash function bool: ", hash_functions, "twice!"));
392     }
393     return OkStatus();
394   }
395 
CheckNodesEqualNonInput(const NodeDef * this_node,GraphHasher * that,const NodeDef * that_node,bool compare_functions)396   Status CheckNodesEqualNonInput(const NodeDef* this_node, GraphHasher* that,
397                                  const NodeDef* that_node,
398                                  bool compare_functions) {
399     // We get the list of attrs from the op registry and then look
400     // up their values in the NodeDef attr map. This avoids looping over
401     // a map which is non-deterministic.
402     const OpRegistrationData* reg;
403     TF_RETURN_IF_ERROR(flib_->LookUp(this_node->op(), &reg));
404     if (reg->is_function_op) {
405       if (compare_functions) {
406         TF_RETURN_IF_ERROR(
407             CheckFunctionsEqual(this_node->op(), this_node->attr(), that,
408                                 that_node->op(), that_node->attr()));
409       }
410     } else {
411       if (this_node->op() != that_node->op()) {
412         return errors::FailedPrecondition(
413             "ops for nodes ", this_node->name(), " and ", that_node->name(),
414             " are different: ", this_node->op(), " != ", that_node->op());
415       }
416     }
417 
418     for (const auto& attr : reg->op_def.attr()) {
419       const auto& attr_key = attr.name();
420       const bool this_has_attr = this_node->attr().contains(attr_key);
421       const bool that_has_attr = that_node->attr().contains(attr_key);
422       if (this_has_attr != that_has_attr) {
423         return errors::FailedPrecondition(
424             "attr with key ", attr_key, " is different for nodes ",
425             this_node->name(), " and ", that_node->name(),
426             ". Present in former: ", this_has_attr,
427             ". Present in latter: ", that_has_attr);
428       }
429       if (!this_has_attr) {
430         continue;
431       }
432       if (attr_key == kColocationAttrName ||
433           attr_key == kColocationGroupPrefix) {
434         continue;
435       }
436       const auto& this_attr = this_node->attr().at(attr_key);
437       const auto& that_attr = that_node->attr().at(attr_key);
438       TF_RETURN_IF_ERROR(CheckAttrsEqual(attr_key, this_attr, that, that_attr,
439                                          compare_functions));
440     }
441 
442     if (this_node->device() != that_node->device()) {
443       return errors::FailedPrecondition(
444           "Devices are different for nodes ", this_node->name(), " and ",
445           that_node->name(), ": ", this_node->device(), " vs ",
446           that_node->device());
447     }
448     return OkStatus();
449   }
450 
HashAttr(const std::string & attr_name,const AttrValue & attr_value,bool hash_functions,uint64 * hash)451   Status HashAttr(const std::string& attr_name, const AttrValue& attr_value,
452                   bool hash_functions, uint64* hash) {
453     uint64 value_hash = 0;
454     if (attr_value.has_func()) {
455       if (hash_functions) {
456         TF_RETURN_IF_ERROR(HashFunction(attr_value.func(), &value_hash));
457       }
458     } else if (attr_value.has_list() && attr_value.list().func_size() > 0) {
459       if (hash_functions) {
460         for (auto& func : attr_value.list().func()) {
461           uint64 func_hash;
462           TF_RETURN_IF_ERROR(HashFunction(func, &func_hash));
463           value_hash = Hash64Combine(value_hash, func_hash);
464         }
465       }
466     } else {
467       value_hash = DeterministicProtoHash64(attr_value);
468     }
469     *hash = Hash64Combine(Hash64(attr_name), value_hash);
470     return OkStatus();
471   }
472 
CheckAttrsEqual(const std::string & attr_name,const AttrValue & this_attr,GraphHasher * that,const AttrValue & that_attr,bool compare_functions)473   Status CheckAttrsEqual(const std::string& attr_name,
474                          const AttrValue& this_attr, GraphHasher* that,
475                          const AttrValue& that_attr, bool compare_functions) {
476     if (this_attr.has_func() != that_attr.has_func()) {
477       return errors::FailedPrecondition(
478           "AttrValues are of different types: ", this_attr.DebugString(),
479           " vs ", that_attr.DebugString());
480     }
481     if (this_attr.has_func()) {
482       if (compare_functions) {
483         TF_RETURN_IF_ERROR(
484             CheckFunctionsEqual(this_attr.func(), that, that_attr.func()));
485       }
486       return OkStatus();
487     }
488     if (this_attr.has_list() != that_attr.has_list()) {
489       return errors::FailedPrecondition(
490           "AttrValues are of different types: ", this_attr.DebugString(),
491           " vs ", that_attr.DebugString());
492     }
493     if (this_attr.has_list()) {
494       if (this_attr.list().func_size() != that_attr.list().func_size()) {
495         return errors::FailedPrecondition(
496             "AttrValues have func lists of different sizes: ",
497             this_attr.DebugString(), " vs ", that_attr.DebugString());
498       }
499       if (compare_functions) {
500         for (int i = 0; i < this_attr.list().func_size(); ++i) {
501           TF_RETURN_IF_ERROR(CheckFunctionsEqual(this_attr.list().func(i), that,
502                                                  that_attr.list().func(i)));
503         }
504       }
505       return OkStatus();
506     }
507     uint64 this_hash, that_hash;
508     TF_RETURN_IF_ERROR(
509         HashAttr(attr_name, this_attr, /*hash_functions=*/true, &this_hash));
510     TF_RETURN_IF_ERROR(that->HashAttr(attr_name, that_attr,
511                                       /*hash_functions=*/true, &that_hash));
512     if (this_hash != that_hash) {
513       return errors::FailedPrecondition(
514           "AttrValues are different: ", this_attr.DebugString(), " vs ",
515           that_attr.DebugString());
516     }
517     return OkStatus();
518   }
519 
HashFunction(const NameAttrList & func,uint64 * hash)520   Status HashFunction(const NameAttrList& func, uint64* hash) {
521     return HashFunction(func.name(), func.attr(), hash);
522   }
523 
HashFunction(const std::string & name,const AttrValueMap & attrs,uint64 * hash)524   Status HashFunction(const std::string& name, const AttrValueMap& attrs,
525                       uint64* hash) {
526     const FunctionDef* fdef = flib_->Find(name);
527     auto it = function_cache_->find(fdef);
528     if (it != function_cache_->end()) {
529       *hash = it->second;
530       return OkStatus();
531     }
532 
533     // Convert to a GraphDef.
534     std::unique_ptr<FunctionBody> fbody;
535     TF_RETURN_IF_ERROR(
536         FunctionDefToBodyHelper(*fdef, AttrSlice(&attrs), flib_, &fbody));
537     GraphDef graph_def = fbody->graph->ToGraphDefDebug();
538 
539     // For each return node, we create a new GraphHasher to compute a hash.
540     // We then combine these hashes to produce the hash ordered.
541     uint64 ret_nodes_hash = 0;
542     for (const auto& ret_node : fbody->ret_nodes) {
543       uint64 ret_node_hash = 0;
544       GraphHasher hasher(&graph_def, &ret_node->def(), flib_, node_cache_,
545                          function_cache_, attr_cache_);
546       TF_RETURN_IF_ERROR(hasher.Init());
547       TF_RETURN_IF_ERROR(hasher.HashRoot(&ret_node_hash));
548       ret_nodes_hash = Hash64Combine(ret_nodes_hash, ret_node_hash);
549     }
550 
551     std::vector<const NodeDef*> control_rets;
552     control_rets.reserve(fbody->control_ret_nodes.size());
553     for (const auto& control_ret_node : fbody->control_ret_nodes) {
554       control_rets.push_back(&control_ret_node->def());
555     }
556     uint64 control_ret_nodes_hash = 0;
557     TF_RETURN_IF_ERROR(
558         HashControlInputs(control_rets, &control_ret_nodes_hash));
559 
560     *hash = Hash64Combine(ret_nodes_hash, control_ret_nodes_hash);
561     auto result = function_cache_->emplace(fdef, *hash);
562     if (!result.second) {
563       return errors::Internal(
564           absl::StrCat("Computed the hash for function ", name, " twice!"));
565     }
566     return OkStatus();
567   }
568 
CheckFunctionsEqual(const NameAttrList & this_func,GraphHasher * that,const NameAttrList & that_func)569   Status CheckFunctionsEqual(const NameAttrList& this_func, GraphHasher* that,
570                              const NameAttrList& that_func) {
571     return CheckFunctionsEqual(this_func.name(), this_func.attr(), that,
572                                that_func.name(), that_func.attr());
573   }
CheckFunctionsEqual(const std::string & this_name,const AttrValueMap & this_attrs,GraphHasher * that,const std::string & that_name,const AttrValueMap & that_attrs)574   Status CheckFunctionsEqual(const std::string& this_name,
575                              const AttrValueMap& this_attrs, GraphHasher* that,
576                              const std::string& that_name,
577                              const AttrValueMap& that_attrs) {
578     Status s = CheckFunctionsEqualHelper(this_name, this_attrs, that, that_name,
579                                          that_attrs);
580     if (!s.ok()) {
581       return errors::FailedPrecondition("Functions ", this_name, " and ",
582                                         that_name, " are not the same:\n", s);
583     }
584     return s;
585   }
586 
CheckFunctionsEqualHelper(const std::string & this_name,const AttrValueMap & this_attrs,GraphHasher * that,const std::string & that_name,const AttrValueMap & that_attrs)587   Status CheckFunctionsEqualHelper(const std::string& this_name,
588                                    const AttrValueMap& this_attrs,
589                                    GraphHasher* that,
590                                    const std::string& that_name,
591                                    const AttrValueMap& that_attrs) {
592     const FunctionDef* this_fdef = flib_->Find(this_name);
593     const FunctionDef* that_fdef = that->flib_->Find(that_name);
594 
595     // Convert to GraphDefs.
596     std::unique_ptr<FunctionBody> this_fbody;
597     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
598         *this_fdef, AttrSlice(&this_attrs), flib_, &this_fbody));
599     GraphDef this_graph_def = this_fbody->graph->ToGraphDefDebug();
600     std::unique_ptr<FunctionBody> that_fbody;
601     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
602         *that_fdef, AttrSlice(&that_attrs), that->flib_, &that_fbody));
603     GraphDef that_graph_def = that_fbody->graph->ToGraphDefDebug();
604 
605     if (this_fbody->ret_nodes.size() != that_fbody->ret_nodes.size()) {
606       return errors::FailedPrecondition(
607           "Different numbers of ret nodes for functions ", this_name, " and ",
608           that_name, ": ", this_fbody->ret_nodes.size(), " vs ",
609           that_fbody->ret_nodes.size());
610     }
611     for (int i = 0; i < this_fbody->ret_nodes.size(); ++i) {
612       const NodeDef* this_root = &this_fbody->ret_nodes[i]->def();
613       const NodeDef* that_root = &that_fbody->ret_nodes[i]->def();
614       GraphHasher this_hasher(&this_graph_def, this_root, flib_, node_cache_,
615                               function_cache_, attr_cache_);
616       TF_RETURN_IF_ERROR(this_hasher.Init());
617       GraphHasher that_hasher(&that_graph_def, that_root, that->flib_,
618                               node_cache_, function_cache_, attr_cache_);
619       TF_RETURN_IF_ERROR(that_hasher.Init());
620       TF_RETURN_IF_ERROR(this_hasher.CheckEqual(&that_hasher));
621     }
622 
623     std::vector<const NodeDef*> this_control_rets;
624     this_control_rets.reserve(this_fbody->control_ret_nodes.size());
625     for (const auto& control_ret_node : this_fbody->control_ret_nodes) {
626       this_control_rets.push_back(&control_ret_node->def());
627     }
628     std::vector<const NodeDef*> that_control_rets;
629     that_control_rets.reserve(that_fbody->control_ret_nodes.size());
630     for (const auto& control_ret_node : that_fbody->control_ret_nodes) {
631       that_control_rets.push_back(&control_ret_node->def());
632     }
633     TF_RETURN_IF_ERROR(
634         CheckControlInputsEqual(this_control_rets, that, that_control_rets));
635     return OkStatus();
636   }
637 
HashControlInputs(const std::vector<const NodeDef * > & inputs,uint64 * hash)638   Status HashControlInputs(const std::vector<const NodeDef*>& inputs,
639                            uint64* hash) {
640     *hash = 0;
641     for (const NodeDef* input : inputs) {
642       uint64 node_hash = 0;
643       TF_RETURN_IF_ERROR(
644           HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
645       *hash = Hash64CombineUnordered(*hash, node_hash);
646     }
647     return OkStatus();
648   }
649 
CheckControlInputsEqual(const std::vector<const NodeDef * > & this_inputs,GraphHasher * that,const std::vector<const NodeDef * > & that_inputs)650   Status CheckControlInputsEqual(
651       const std::vector<const NodeDef*>& this_inputs, GraphHasher* that,
652       const std::vector<const NodeDef*>& that_inputs) {
653     absl::flat_hash_map<uint64, const NodeDef*> this_hashes;
654     for (const NodeDef* input : this_inputs) {
655       uint64 node_hash = 0;
656       TF_RETURN_IF_ERROR(
657           HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
658       this_hashes[node_hash] = input;
659     }
660     absl::flat_hash_map<uint64, const NodeDef*> that_hashes;
661     for (const NodeDef* input : that_inputs) {
662       uint64 node_hash = 0;
663       TF_RETURN_IF_ERROR(
664           HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
665       auto this_iter = this_hashes.find(node_hash);
666       if (this_iter != this_hashes.end()) {
667         this_hashes.erase(this_iter);
668       } else {
669         that_hashes[node_hash] = input;
670       }
671     }
672     if (!this_hashes.empty()) {
673       auto formatter = [](string* out,
674                           const decltype(this_hashes)::value_type& item) {
675         out->append(item.second->name());
676       };
677       return errors::FailedPrecondition(
678           "Control dependencies are different. One node has dependencies [",
679           absl::StrJoin(this_hashes, ", ", formatter),
680           "], which don't match any of the other node's dependencies [",
681           absl::StrJoin(that_hashes, ", ", formatter), "]");
682     }
683     return OkStatus();
684   }
685 
686  private:
is_cycle_forming_edge(const NodeDef * start,const NodeDef * end)687   bool is_cycle_forming_edge(const NodeDef* start, const NodeDef* end) {
688     EdgeRep edge(start, end);
689     return cycle_forming_edges_.contains(edge.GetHash());
690   }
691 
692   struct NodeRep {
693     std::vector<const NodeDef*> node_control_inputs;
694     std::vector<std::pair<const NodeDef*, absl::string_view>> node_inputs;
695   };
696 
697   struct EdgeRep {
698     const NodeDef* start_node;
699     const NodeDef* end_node;
700 
EdgeReptensorflow::data::__anon4acced740111::GraphHasher::EdgeRep701     EdgeRep(const NodeDef* start, const NodeDef* end)
702         : start_node(start), end_node(end) {}
703 
GetHashtensorflow::data::__anon4acced740111::GraphHasher::EdgeRep704     uint64 GetHash() {
705       return Hash64Combine(absl::Hash<const NodeDef*>()(start_node),
706                            absl::Hash<const NodeDef*>()(end_node));
707     }
708   };
709   const GraphDef* const graph_;                  // Not owned.
710   const NodeDef* const root_;                    // Not owned.
711   const FunctionLibraryDefinition* const flib_;  // Not owned.
712   // Edges that need to be pruned as their presence will cause cycles.
713   absl::flat_hash_set<uint64> cycle_forming_edges_;
714   absl::flat_hash_map<const NodeDef*, NodeRep> nodes_;
715   std::shared_ptr<NodeCache> node_cache_;
716   std::shared_ptr<FunctionCache> function_cache_;
717   std::shared_ptr<AttrCache> attr_cache_;
718 };
719 
720 }  // anonymous namespace
721 
HashTensor(const Tensor & tensor,uint64 * hash)722 Status HashTensor(const Tensor& tensor, uint64* hash) {
723   const tstring* s = nullptr;
724   // Hash tensor type.
725   *hash = Hash64Combine(0, tensor.dtype());
726   // Hash tensor shape.
727   for (int i = 0; i < tensor.shape().dims(); ++i) {
728     *hash = Hash64Combine(*hash, tensor.shape().dim_size(i));
729   }
730   // Hash tensor data.
731   switch (tensor.dtype()) {
732     case DT_RESOURCE:
733     case DT_VARIANT:
734       return errors::Unimplemented("Hashing ", DataTypeString(tensor.dtype()),
735                                    " is not supported.");
736     case DT_STRING:
737       s = tensor.flat<tstring>().data();
738       for (int i = 0; i < tensor.NumElements(); ++i, ++s) {
739         *hash = Hash64Combine(*hash, Hash64(s->data(), s->size()));
740       }
741       break;
742     default:
743       *hash = Hash64(tensor.tensor_data().data(), tensor.tensor_data().size());
744   }
745   return OkStatus();
746 }
747 
HashNode(const GraphDef & graph,const NodeDef & node,uint64 * hash)748 Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) {
749   const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
750                                            graph.library());
751   return HashNode(graph, node, flib_def, hash);
752 }
753 
HashNode(const GraphDef & graph,const NodeDef & node,const FunctionLibraryDefinition & flib_def,uint64 * hash)754 Status HashNode(const GraphDef& graph, const NodeDef& node,
755                 const FunctionLibraryDefinition& flib_def, uint64* hash) {
756   GraphHasher hasher(&graph, &node, &flib_def);
757   TF_RETURN_IF_ERROR(hasher.Init());
758   return hasher.HashRoot(hash);
759 }
760 
HashGraph(const GraphDef & graph_def,uint64 * hash)761 Status HashGraph(const GraphDef& graph_def, uint64* hash) {
762   const NodeDef* sink = nullptr;
763   TF_RETURN_IF_ERROR(GetSink(graph_def, &sink));
764   return HashNode(graph_def, *sink, hash);
765 }
766 
CheckGraphsEqual(const GraphDef & a,const GraphDef & b)767 Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b) {
768   const NodeDef* sink_a;
769   TF_RETURN_IF_ERROR(GetSink(a, &sink_a));
770   const NodeDef* sink_b;
771   TF_RETURN_IF_ERROR(GetSink(b, &sink_b));
772   return CheckSubgraphsEqual(a, sink_a, b, sink_b);
773 }
774 
CheckSubgraphsEqual(const GraphDef & a,const NodeDef * node_a,const GraphDef & b,const NodeDef * node_b)775 Status CheckSubgraphsEqual(const GraphDef& a, const NodeDef* node_a,
776                            const GraphDef& b, const NodeDef* node_b) {
777   const FunctionLibraryDefinition flib_def_a(OpRegistry::Global(), a.library());
778   GraphHasher hasher_a(&a, node_a, &flib_def_a);
779   TF_RETURN_IF_ERROR(hasher_a.Init());
780 
781   const FunctionLibraryDefinition flib_def_b(OpRegistry::Global(), b.library());
782   GraphHasher hasher_b(&b, node_b, &flib_def_b);
783   TF_RETURN_IF_ERROR(hasher_b.Init());
784 
785   return hasher_a.CheckEqual(&hasher_b);
786 }
787 
788 }  // namespace data
789 }  // namespace tensorflow
790