xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/graph_constructor.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/graph_constructor.h"
17 
18 #include <algorithm>
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/string_view.h"
29 #include "tensorflow/core/common_runtime/shape_refiner.h"
30 #include "tensorflow/core/framework/function.h"
31 #include "tensorflow/core/framework/function.pb.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/tensor_shape.pb.h"
36 #include "tensorflow/core/framework/types.h"
37 #include "tensorflow/core/framework/versions.h"
38 #include "tensorflow/core/framework/versions.pb.h"
39 #include "tensorflow/core/graph/algorithm.h"
40 #include "tensorflow/core/graph/graph.h"
41 #include "tensorflow/core/graph/tensor_id.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/gtl/flatmap.h"
44 #include "tensorflow/core/lib/gtl/flatset.h"
45 #include "tensorflow/core/lib/gtl/inlined_vector.h"
46 #include "tensorflow/core/lib/strings/scanner.h"
47 #include "tensorflow/core/lib/strings/str_util.h"
48 #include "tensorflow/core/platform/errors.h"
49 #include "tensorflow/core/platform/logging.h"
50 #include "tensorflow/core/platform/macros.h"
51 #include "tensorflow/core/public/version.h"
52 
53 namespace tensorflow {
54 
55 namespace {
56 
// We remove duplicate control inputs before adding edges to the Graph, so we
// can skip expensive duplicates check in 'AddControlEdge'.
// (constexpr alone suffices: it implies const, and the enclosing anonymous
// namespace already gives internal linkage.)
constexpr bool kDoNotCheckDuplicates = true;
60 
IsMerge(const NodeDef & node_def)61 inline bool IsMerge(const NodeDef& node_def) {
62   return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
63          node_def.op() == "_XlaMerge";
64 }
65 
IsNextIteration(const NodeDef & node_def)66 inline bool IsNextIteration(const NodeDef& node_def) {
67   return node_def.op() == "NextIteration" ||
68          node_def.op() == "RefNextIteration";
69 }
70 
IsValidNodeName(StringPiece s,bool allow_internal_ops)71 bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
72   using ::tensorflow::strings::Scanner;
73   Scanner scanner(s);
74   scanner
75       .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
76                               : Scanner::LETTER_DIGIT_DOT)
77       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
78 
79   while (true) {
80     if (!scanner.GetResult())  // Some error in previous iteration.
81       return false;
82     if (scanner.empty())  // No error, but nothing left, good.
83       return true;
84 
85     // Absorb another piece, starting with a '>'
86     scanner.One(Scanner::RANGLE)
87         .One(Scanner::LETTER_DIGIT_DOT)
88         .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
89   }
90 }
91 
// Shared implementation behind ConvertGraphDefToGraph and ImportGraphDef:
// validates a set of NodeDefs (plus optional versions and function library)
// and adds the corresponding nodes and edges to an existing Graph `g`,
// handling input remapping, control dependencies, name prefixing/uniquifying
// and while-loop back edges.
//
// Abstract base class: subclasses (NodeDefCopyingGraphConstructor and
// NodeDefMovingGraphConstructor below) provide access to the underlying
// NodeDef storage via the virtual node_def accessors.
class GraphConstructor {
 public:
  // Unified option set, implicitly convertible from both
  // GraphConstructorOptions and ImportGraphDefOptions so one implementation
  // serves both entry points.
  struct Options {
    Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(in.allow_internal_ops),
          expect_device_spec(in.expect_device_spec),
          importing(false),
          validate_nodes(in.validate_nodes),
          validate_colocation_constraints(false),
          add_default_attributes(in.add_default_attributes) {}
    Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(false),
          expect_device_spec(false),
          // Normalize the prefix so a non-empty prefix always ends in '/'.
          prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
                     ? in.prefix
                     : in.prefix + "/"),
          uniquify_names(in.uniquify_names),
          uniquify_prefix(in.uniquify_prefix),
          input_map(in.input_map.begin(), in.input_map.end()),
          skip_mapped_nodes(in.skip_mapped_nodes),
          control_dependencies(in.control_dependencies),
          return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
          return_nodes(in.return_nodes),
          importing(true),
          validate_nodes(true),
          validate_colocation_constraints(in.validate_colocation_constraints),
          validate_shape(in.validate_shape),
          default_device(in.default_device) {}

    bool allow_internal_ops;
    bool expect_device_spec;

    // Either empty or a '/'-terminated prefix prepended to imported names.
    string prefix;
    bool uniquify_names;
    bool uniquify_prefix;
    std::map<TensorId, TensorId> input_map;
    bool skip_mapped_nodes;
    std::vector<string> control_dependencies;
    std::vector<TensorId> return_tensors;
    std::vector<string> return_nodes;

    // TODO(ashankar): This bool exists to separate out functionality required
    // to make ImportGraphDef a close equivalent of Python's import_graph_def
    // without affecting the behavior of ConvertGraphDefToGraph at the time
    // ImportGraphDef was added.
    //
    // That said, the functionality here (shape and op validation) seems
    // applicable to ConvertGraphDefToGraph as well, so make an attempt to
    // remove this.
    bool importing;
    // If true, validates that nodes being converted have all expected attrs
    // set and no unknown attrs set by calling ValidateNodeDef().
    // `validate_nodes` is always true when `importing` is set.
    bool validate_nodes;
    bool validate_colocation_constraints;
    bool validate_shape = true;

    // If true, GraphConstructor will add attributes with their default
    // value to the Node when they are missing from the NodeDef.
    bool add_default_attributes = true;

    string default_device;
  };

  typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;

  // versions and library may be nullptr
  static Status Construct(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

  // Overload that takes ownership of `graph_def`, allowing NodeDefs to be
  // moved (rather than copied) into the constructed Graph.
  static Status Construct(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

 protected:
  GraphConstructor(const Options& opts, Graph* g, ShapeRefiner* refiner,
                   std::vector<std::pair<Node*, int>>* return_tensors,
                   std::vector<Node*>* return_nodes,
                   std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : opts_(opts),
        g_(g),
        // Snapshot of g's versions, so Undo()/UpdateVersionDef can work
        // relative to the pre-import state.
        original_versions_(g->versions()),
        prefix_(opts.prefix),
        refiner_(refiner),
        return_tensors_(return_tensors),
        return_nodes_(return_nodes),
        missing_unused_input_map_keys_(missing_unused_input_map_keys) {}

  virtual ~GraphConstructor() {}

  // Runs the full import pipeline. On error, the caller is expected to invoke
  // Undo() to roll back any partial changes to g_.
  Status TryImport() {
    TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
    TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
    TF_RETURN_IF_ERROR(BuildNodeIndex());
    TF_RETURN_IF_ERROR(InitFromEdges());

    // NOTE: Convert() invokes `consume_node_def()` on each node in the input
    // graph, so `get_node_def()` is no longer usable once it is called.
    TF_RETURN_IF_ERROR(Convert());

    TF_RETURN_IF_ERROR(AddBackEdges());
    TF_RETURN_IF_ERROR(UpdateVersionDef());
    TF_RETURN_IF_ERROR(PopulateReturnTensors());
    TF_RETURN_IF_ERROR(PopulateReturnNodes());
    TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
    UpdateUniquifiedColocationNames();
    FixupSourceAndSinkEdges(g_);
    return OkStatus();
  }

 private:
  Status EnsureNoNameCollisions();
  Status ValidateInputMapAndControlDependencies();
  Status BuildNodeIndex();
  Status InitFromEdges();
  Status Convert();
  Status AddBackEdges();
  Status UpdateVersionDef();
  Status PopulateReturnTensors();
  Status PopulateReturnNodes();
  Status PopulateMissingUnusedInputMapKeys();

  // Rolls g_ back to its state before this import (used after a failed
  // TryImport()).
  void Undo();

  // Prints cycles in the graph.
  void PrintCycles();
  // Performs DFS starting at `cur_node` and prints any cycles found.
  void DFS(int cur_node, std::vector<int>* cur_branch,
           std::vector<bool>* is_on_cur_branch,
           absl::flat_hash_set<int>* unvisited,
           const std::vector<absl::string_view>& node_names);
  Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
  Status ValidateColocationConstraints(const NodeDef& node_def);
  Status MakeNode(NodeDef&& node_def, Node** node);
  Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
  Status ValidateShape(Node* node);
  Status ModifyNodeDefForImport(NodeDef* node_def);
  // Modifies node_def's inputs according to opts_.input_map.
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will mark inputs that are remapped to
  // true.
  void RemapNodeDefInputs(NodeDef* node_def,
                          std::vector<bool>* input_already_exists);
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will add and mark control inputs as
  // true.
  void AddControlDependencies(NodeDef* node_def,
                              std::vector<bool>* input_already_exists);
  void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
                          NodeDef* node_def);

  // Modifies `node_def` if its name isn't unique, or if any of its inputs'
  // names have been uniquified. This must be called in topological order on all
  // nodes.
  void UniquifyNames(const std::vector<bool>& input_already_exists,
                     NodeDef* node_def);

  // Updates any constructed nodes' colocation group names if the name has been
  // updated by UniquifyNames. This is called after all the nodes have been
  // constructed so all the names have been uniquified if necessary.
  void UpdateUniquifiedColocationNames();

  // Returns true if `name` already exists in `g_` (either as a node name or
  // prefix).
  bool NameExistsInGraph(StringPiece name);

  // Returns true if `name` already exists in the GraphDef being imported
  // (either as a node name or prefix).
  bool NameExistsInGraphDef(StringPiece name);

  // Returns a unique version of `original_name`, or `original_name` if it's
  // already unique in the graph.
  string FindUniqueName(StringPiece original_name);

  // Decrement pending count for users of `processed` and add the ones that now
  // have all of their pending inputs satisfied to `ready_`.
  void UpdatePendingCountAndReady(int processed, bool is_next_iteration);

  // Subclasses override the following virtual methods to provide efficient
  // access to the original protocol buffer-based graph.

  // Returns the number of nodes in the graph.
  virtual size_t node_def_count() const = 0;
  // Returns the i^th node in the graph. Must not be called after
  // consume_node_def(i).
  virtual const NodeDef& get_node_def(int i) const = 0;
  // Destructively reads the i^th node in the graph, avoiding a copy if
  // possible. After calling this method, the result of get_node_def(i) is
  // undefined.
  virtual NodeDef consume_node_def(int i) = 0;
  // Returns the version information for the graph, or nullptr if none is
  // available.
  virtual const VersionDef* versions() const = 0;
  // Returns the function information for the graph, or nullptr if none is
  // available.
  virtual const FunctionDefLibrary* library() const = 0;

  // From constructor
  const Options opts_;
  Graph* g_;
  const VersionDef original_versions_;

  // A copy of opts_.prefix, possibly uniquified.
  string prefix_;

  ShapeRefiner* refiner_;

  // May be null. Not owned.
  std::vector<std::pair<Node*, int>>* return_tensors_;

  // May be null. Not owned.
  std::vector<Node*>* return_nodes_;

  // May be null. Not owned.
  std::vector<SafeTensorId>* missing_unused_input_map_keys_;

  // Intermediate datastructure used to populate
  // `missing_unused_input_map_keys_`.
  std::set<TensorId> used_input_map_keys_;

  // Intermediate datastructure used to track the destinations of back edges.
  absl::flat_hash_set<int> merge_node_indices_;

  // Mapping from node name to the index within node_defs_.
  struct NodeInfo {
    explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
    // Containers require that we have a default constructor.
    NodeInfo() : NodeInfo(-1) {}
    int gdef_index;
    Node* node;  // nullptr until the NodeDef is converted to a Node.
  };
  absl::flat_hash_map<std::string, NodeInfo> gdef_nodes_;

  // Prefixes already used in the GraphDef being imported.
  absl::flat_hash_set<StringPiece> gdef_prefixes_;

  // Mapping from node name to the existing node in g_.
  absl::flat_hash_map<StringPiece, Node*> existing_nodes_;

  // Prefixes already used in the graph.
  absl::flat_hash_set<StringPiece> existing_prefixes_;

  // Imported node names that have been uniquified. The key is the original
  // name, the value is the new unique name.
  gtl::FlatMap<string, string> uniquified_names_;

  // Index of NodeDefs in node_defs_ with all inputs already converted. We use a
  // (sorted) set so nodes are created in the order defined in the GraphDef.
  std::set<int> ready_;

  // Mapping between index within node_defs_ and the number of inputs that
  // still need to be converted.
  std::vector<int> pending_count_;

  // Mapping between index within node_defs_ and the index within node_defs_ of
  // all nodes it outputs to.
  std::vector<gtl::InlinedVector<int, 4>> outputs_;

  // Used in the conversion from node_defs_ to g_ to represent the ith input
  // of a node.
  struct InputInfo {
    explicit InputInfo(const string& node_name, Node* n, int i)
        : name(node_name), node(n), index(i) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string name;
    Node* node;
    int index;

    static bool IsControlInput(const InputInfo& input) {
      return input.index == Graph::kControlSlot;
    }
    static int CompareName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name < rhs.name;
    }
    static bool IsSameName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name == rhs.name;
    }
  };

  // Used in the conversion from node_defs_ to g_ to represent an edge from
  // the node named 'name' to node 'n'.
  struct EdgeInfo {
    explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string src_name;
    int src_index;
    Node* dst_node;
    int dst_index;
  };
  // Back edges (NextIteration->Merge) deferred until AddBackEdges().
  std::vector<EdgeInfo> back_edges_;

  TF_DISALLOW_COPY_AND_ASSIGN(GraphConstructor);
};
392 
// Implementation of GraphConstructor that does not take ownership of the
// input NodeDef messages and thus copies the nodes into the constructed Graph*.
//
// NOTE(mrry): Whenever possible, use NodeDefMovingGraphConstructor, which
// avoids copying each NodeDef into the constructed Graph*.
class NodeDefCopyingGraphConstructor : public GraphConstructor {
 public:
  NodeDefCopyingGraphConstructor(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        node_defs_(node_defs),
        versions_(versions),
        library_(library) {}

 private:
  size_t node_def_count() const override { return node_defs_.size(); }
  const NodeDef& get_node_def(int i) const override { return *node_defs_[i]; }
  // "Consuming" copies here, since the NodeDefs are borrowed, not owned.
  NodeDef consume_node_def(int i) override { return *node_defs_[i]; }
  const VersionDef* versions() const override { return versions_; }
  const FunctionDefLibrary* library() const override { return library_; }

  // All three are non-owning views; the caller must keep them alive for the
  // duration of the import.
  const NodeDefSlice node_defs_;
  const VersionDef* const versions_;
  const FunctionDefLibrary* const library_;
};
423 
// Implementation of GraphConstructor that takes ownership of the input
// GraphDef, and can perform destructive reads.
class NodeDefMovingGraphConstructor : public GraphConstructor {
 public:
  NodeDefMovingGraphConstructor(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        graph_def_(std::move(graph_def)),
        // NOTE: reads the member graph_def_, not the moved-from parameter.
        // This is safe because members are initialized in declaration order
        // and graph_def_ is declared before is_consumed_.
        is_consumed_(graph_def_.node_size(), false) {}

 private:
  size_t node_def_count() const override { return graph_def_.node().size(); }
  const NodeDef& get_node_def(int i) const override {
    // Accessing a node after consume_node_def(i) would observe a moved-from
    // proto, so crash loudly instead.
    CHECK(!is_consumed_[i])
        << "NodeDef " << i << " accessed after it was consumed.";
    return graph_def_.node(i);
  }
  NodeDef consume_node_def(int i) override {
    CHECK(!is_consumed_[i]) << "NodeDef " << i << " consumed twice.";
    is_consumed_[i] = true;
    return std::move(*graph_def_.mutable_node(i));
  }
  const VersionDef* versions() const override { return &graph_def_.versions(); }
  const FunctionDefLibrary* library() const override {
    return &graph_def_.library();
  }

  GraphDef graph_def_;
  // is_consumed_[i] is true iff node i has been destructively read.
  std::vector<bool> is_consumed_;
};
458 
ForwardCompatibilityWindowPassed(const VersionDef & versions)459 bool ForwardCompatibilityWindowPassed(const VersionDef& versions) {
460   // TF_GRAPH_DEF_VERSION is incremented daily.
461   // TF has a 3 week forward compatibility guarantee.
462   return (versions.producer() - TF_GRAPH_DEF_VERSION) > 21;
463 }
464 
// If the GraphDef producer is outside the forward compatibility window,
// returns a copy of `import_status` with a version-skew explanation appended;
// otherwise returns `import_status` unchanged. `versions` may be null.
Status MaybeAppendVersionWarning(const VersionDef* versions,
                                 const Status& import_status) {
  if (versions == nullptr || !ForwardCompatibilityWindowPassed(*versions)) {
    return import_status;
  }
  string message = absl::StrCat(
      "Converting GraphDef to Graph has failed with an error: '",
      import_status.error_message(),
      "' The binary trying to import the GraphDef was built when "
      "GraphDef version was ",
      TF_GRAPH_DEF_VERSION,
      ". The GraphDef was produced by a binary built when GraphDef "
      "version was ",
      versions->producer(),
      ". The difference between these versions is larger than "
      "TensorFlow's forward compatibility guarantee, and might be the "
      "root cause for failing to import the GraphDef.");
  // Keep the original error code; only the message gains the warning.
  return Status(import_status.code(), message);
}
485 
Construct(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)486 /* static */ Status GraphConstructor::Construct(
487     const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
488     const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
489     std::vector<std::pair<Node*, int>>* return_tensors,
490     std::vector<Node*>* return_nodes,
491     std::vector<SafeTensorId>* missing_unused_input_map_keys) {
492   if (versions) {
493     TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
494                                      TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
495                                      "GraphDef", "graph"));
496   }
497   NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g,
498                                    refiner, return_tensors, return_nodes,
499                                    missing_unused_input_map_keys);
500   Status s = c.TryImport();
501   if (!s.ok()) {
502     c.Undo();
503     s = MaybeAppendVersionWarning(versions, s);
504   }
505   return s;
506 }
507 
Construct(const Options & opts,GraphDef && graph_def,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)508 /* static */ Status GraphConstructor::Construct(
509     const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner,
510     std::vector<std::pair<Node*, int>>* return_tensors,
511     std::vector<Node*>* return_nodes,
512     std::vector<SafeTensorId>* missing_unused_input_map_keys) {
513   TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION,
514                                    TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
515                                    "GraphDef", "graph"));
516   VersionDef version_def = graph_def.versions();
517   NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner,
518                                   return_tensors, return_nodes,
519                                   missing_unused_input_map_keys);
520   Status s = c.TryImport();
521   if (!s.ok()) {
522     c.Undo();
523     s = MaybeAppendVersionWarning(&version_def, s);
524   }
525   return s;
526 }
527 
UpdatePendingCountAndReady(int processed,bool is_next_iteration)528 void GraphConstructor::UpdatePendingCountAndReady(int processed,
529                                                   bool is_next_iteration) {
530   for (size_t i = 0; i < outputs_[processed].size(); ++i) {
531     const int output = outputs_[processed][i];
532     // We didn't consider NextIteration->Merge edges when computing
533     // pending_counts_ so we should not have to consider it here either.
534     bool is_next_iteration_to_merge_edge =
535         is_next_iteration && merge_node_indices_.count(output) == 1;
536     if (!is_next_iteration_to_merge_edge) {
537       int* current_pending_count = &pending_count_[output];
538       CHECK_GT(*current_pending_count, 0);
539       (*current_pending_count)--;
540       if (*current_pending_count == 0) {
541         ready_.insert(output);
542       }
543     }
544   }
545 }
546 
547 // This could be expensive but we don't expect to call it often, if at all (only
548 // if there are multiple nodes in g_ with the same name)
NodeNameInValues(const std::map<TensorId,TensorId> & input_map,const StringPiece & node_name)549 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
550                       const StringPiece& node_name) {
551   for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
552     if (iter->second.first == node_name) return true;
553   }
554   return false;
555 }
556 
NodeNameInValues(const std::vector<string> & control_dependencies,const StringPiece & node_name)557 bool NodeNameInValues(const std::vector<string>& control_dependencies,
558                       const StringPiece& node_name) {
559   return std::find(control_dependencies.begin(), control_dependencies.end(),
560                    node_name) != control_dependencies.end();
561 }
562 
563 // Adds any prefixes of `node_name` (not including the full name itself) to
564 // `prefixes`.
AddPrefixes(StringPiece node_name,absl::flat_hash_set<StringPiece> * prefixes)565 void AddPrefixes(StringPiece node_name,
566                  absl::flat_hash_set<StringPiece>* prefixes) {
567   size_t idx = -1;
568   while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
569     prefixes->insert(node_name.substr(0, idx));
570   }
571 }
572 
// Populates existing_nodes_/existing_prefixes_ from g_, then verifies that the
// import will not produce ambiguous or invalid names: duplicate existing names
// may not be referenced by input_map/control_dependencies, imported names may
// not collide with existing ones (unless a prefix or uniquification is used),
// and a non-empty prefix must itself be a valid node name.
Status GraphConstructor::EnsureNoNameCollisions() {
  existing_nodes_.reserve(g_->num_nodes());
  // Populate existing_nodes_ and existing_prefixes_.
  for (Node* n : g_->nodes()) {
    bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
    if (already_exists) {
      // A duplicated name is only fatal if the caller tries to address it via
      // input_map or control_dependencies, since the reference is ambiguous.
      if (NodeNameInValues(opts_.input_map, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve input_map because multiple nodes exist with name '",
            n->name(), "'");
      }
      if (NodeNameInValues(opts_.control_dependencies, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve control_dependencies because multiple nodes exist "
            "with name '",
            n->name(), "'");
      }
    }
    AddPrefixes(n->name(), &existing_prefixes_);
  }
  if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
    // No prefix and no uniquification: every imported name must be fresh.
    for (size_t i = 0; i < node_def_count(); ++i) {
      const string& name = get_node_def(i).name();
      if (NameExistsInGraph(name)) {
        return errors::InvalidArgument("Node name '", name,
                                       "' already exists in the Graph");
      }
    }
  } else if (!prefix_.empty()) {
    // prefix_ is guaranteed to end in '/' (see Options), so stripping the
    // trailing slash yields the bare prefix name.
    StringPiece prefix_no_slash(prefix_);
    prefix_no_slash.remove_suffix(1);
    if (!IsValidNodeName(prefix_no_slash, false)) {
      return errors::InvalidArgument("Imported node name prefix '", prefix_,
                                     "' would lead to invalid node names");
    }
    if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
      prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
    }
  }
  return OkStatus();
}
614 
// Verifies that every input_map destination and every control dependency
// refers to a node that already exists in g_, and that input_map never maps a
// data edge to a control edge or vice versa. Requires existing_nodes_ to be
// populated (i.e. EnsureNoNameCollisions() must run first).
Status GraphConstructor::ValidateInputMapAndControlDependencies() {
  for (const auto& mapping : opts_.input_map) {
    TensorId src = mapping.first;
    TensorId dst = mapping.second;
    if (existing_nodes_.count(dst.first) == 0) {
      return errors::InvalidArgument(
          "node '", dst.first, "' in input_map does not exist in graph ",
          "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
    }
    // Both sides must be the same kind of edge: control (kControlSlot) or
    // data.
    if ((src.second == Graph::kControlSlot) !=
        (dst.second == Graph::kControlSlot)) {
      return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
                                     dst.ToString(), " between ",
                                     "control edge and non-control edge");
    }
  }
  for (const string& node : opts_.control_dependencies) {
    if (existing_nodes_.count(node) == 0) {
      return errors::InvalidArgument(
          "node '", node,
          "' in control_dependencies does not exist in "
          "graph");
    }
  }
  return OkStatus();
}
641 
// Validates each imported NodeDef (name syntax, uniqueness, non-empty op,
// device spec if required, control inputs ordered last) and builds the
// name->index map gdef_nodes_, the prefix set gdef_prefixes_, and the set of
// Merge node indices used later for back-edge handling.
Status GraphConstructor::BuildNodeIndex() {
  // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "': Node name contains invalid characters");
    }
    if (!gdef_nodes_.insert(std::make_pair(node_def.name(), NodeInfo(n)))
             .second) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is not unique");
    }
    // Validate the operation's type.
    if (node_def.op().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' does not specify an operation");
    }
    if (opts_.expect_device_spec && node_def.device().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is missing a device specification");
    }
    // Remember Merge nodes: they are potential back-edge destinations.
    if (IsMerge(node_def)) {
      merge_node_indices_.insert(n);
    }
    // Validate control edges at end
    bool in_control_dependence = false;
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      // Control inputs are spelled "^name"; once one is seen, no regular
      // (data) input may follow.
      if (!input_name.empty() && absl::StartsWith(input_name, "^")) {
        in_control_dependence = true;
      } else if (in_control_dependence) {
        return errors::InvalidArgument(
            "Node '", node_def.name(),
            "': Control dependencies must come after regular dependencies");
      }
    }
    // Update gdef_prefixes_.
    AddPrefixes(node_def.name(), &gdef_prefixes_);
  }
  return OkStatus();
}
685 
// Initializes the topological-traversal bookkeeping used by Convert():
//   - pending_count_[n]: number of inputs of node n not yet satisfied
//     (inputs remapped through opts_.input_map count as already satisfied).
//   - outputs_[n]: indices of the nodes consuming an output of node n.
//   - ready_: nodes with zero pending inputs (the traversal roots).
// Merge nodes fed by a NextIteration node (i.e. while-loop entry points) are
// considered ready after a single non-control input, so that loop cycles do
// not stall the traversal.
Status GraphConstructor::InitFromEdges() {
  const int num_nodes = node_def_count();
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);
  // First pass: collect the names of all NextIteration nodes so the loop
  // below can recognize loop-back edges into Merge nodes.
  gtl::FlatSet<string> next_iteration_nodes;
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (IsNextIteration(node_def)) {
      next_iteration_nodes.insert(node_def.name());
    }
  }

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def = get_node_def(n);
    int pending_count = node_def.input_size();
    if (IsMerge(node_def)) {
      // Cycles in the graph are only allowed for while loops. A while loop is
      // identified by an edge from a NextIteration node to a Merge node. For
      // such Merge nodes, only wait for one non-control input before
      // considering the node ready to process in Convert().
      int32_t num_control_edges = 0;
      bool has_loop_back_edge = false;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (absl::StartsWith(input_name, "^")) {
          num_control_edges++;
        } else {
          TensorId id(ParseTensorName(input_name));
          if (next_iteration_nodes.find(string(id.first)) !=
              next_iteration_nodes.end()) {
            has_loop_back_edge = true;
          }
        }
      }
      if (has_loop_back_edge) {
        // Wait for all control inputs plus one data input; the loop-back
        // data edge is connected later by AddBackEdges().
        pending_count = num_control_edges + 1;
      }
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      TensorId id(ParseTensorName(input_name));
      if (opts_.input_map.count(id) == 0) {
        // If an input is not mapped, then the input should appear in the graph
        // being imported.
        auto iter = gdef_nodes_.find(id.first);
        if (iter == gdef_nodes_.end()) {
          return errors::InvalidArgument("Node '", node_def.name(),
                                         "': Unknown input node '",
                                         node_def.input(i), "'");
        }
        outputs_[iter->second.gdef_index].push_back(n);
      } else {
        // This input is mapped to an existing edge. Therefore this input is
        // as good as being already processed.
        --pending_count;
        DCHECK_GE(pending_count, 0);
      }
    }
    if (pending_count == 0) {
      ready_.insert(n);
    }
    pending_count_.push_back(pending_count);
  }
  return OkStatus();
}
752 
ValidateColocationConstraints(const NodeDef & node_def)753 Status GraphConstructor::ValidateColocationConstraints(
754     const NodeDef& node_def) {
755   if (!opts_.validate_colocation_constraints || !opts_.importing)
756     return OkStatus();
757   const auto iter = node_def.attr().find(kColocationAttrName);
758   if (iter == node_def.attr().end()) return OkStatus();
759   for (const string& c : iter->second.list().s()) {
760     StringPiece s(c);
761     if (absl::ConsumePrefix(&s, kColocationGroupPrefix) &&
762         gdef_nodes_.find(s) == gdef_nodes_.end()) {
763       return errors::InvalidArgument(
764           "Node '", node_def.name(),
765           "' expects to be colocated with unknown node '", s, "'");
766     }
767   }
768   return OkStatus();
769 }
770 
// Adds `node_def` to the graph g_, returning the created Node in `*node`.
// When opts_.expect_device_spec is set, the requested device string is also
// recorded as the node's assigned device.
Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) {
  // Add the node to the graph.
  Status status;
  *node = g_->AddNode(std::move(node_def), &status);
  if (!status.ok()) return status;
  if (opts_.expect_device_spec) {
    (*node)->set_assigned_device_name((*node)->def().device());
  }
  return OkStatus();
}
781 
ValidateShape(Node * node)782 Status GraphConstructor::ValidateShape(Node* node) {
783   if (!opts_.importing || !opts_.validate_shape) return OkStatus();
784   TF_RETURN_IF_ERROR(refiner_->AddNode(node));
785   // For nodes with the _output_shapes attribute, override the shape.
786   std::vector<const TensorShapeProto*> shape_attrs;
787   const char* kAttrName = "_output_shapes";
788   if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) {
789     // No _output_shapes attribute, the AddNode call above was sufficient.
790     return OkStatus();
791   }
792   auto* ic = refiner_->GetContext(node);
793   DCHECK(ic != nullptr)
794       << "ShapeRefiner::AddNode() should have created the InferenceContext";
795   if (shape_attrs.size() < node->num_outputs()) {
796     return errors::InvalidArgument(
797         "Node '", node->name(), "' has ", node->num_outputs(),
798         " outputs but the ", kAttrName, " attribute specifies shapes for ",
799         shape_attrs.size(), " outputs");
800   }
801   // NOTE(skyewm): we don't raise an error here because some users depend on
802   // this behavior, even though it's unsafe.
803   // TODO(b/74619486): raise an error.
804   if (shape_attrs.size() > node->num_outputs()) {
805     LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs()
806                  << " outputs but the " << kAttrName
807                  << " attribute specifies shapes for " << shape_attrs.size()
808                  << " outputs. Output shapes may be inaccurate.";
809   }
810   for (int i = 0; i < node->num_outputs(); ++i) {
811     const TensorShapeProto& p = *shape_attrs[i];
812     shape_inference::ShapeHandle h;
813     Status s = ic->MakeShapeFromShapeProto(p, &h);
814     if (!s.ok()) {
815       return errors::InvalidArgument("Node '", node->name(), " has an invalid ",
816                                      kAttrName, " attribute (shape #", i,
817                                      " error:'", s.error_message(), "'");
818     }
819     s = refiner_->SetShape(node, i, h);
820     if (!s.ok()) {
821       return errors::InvalidArgument(
822           "Node '", node->name(), "' has an ", kAttrName,
823           " attribute inconsistent with the GraphDef for output #", i, ": ",
824           s.error_message());
825     }
826   }
827   node->ClearAttr(kAttrName);
828   return OkStatus();
829 }
830 
// Prepares an imported NodeDef for insertion into g_: fills in default attr
// values from its OpDef, validates the NodeDef against the OpDef, and (when
// the GraphDef carries version info) checks that the op is not deprecated at
// the producer's version.
Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
  AddDefaultsToNodeDef(*op_def, node_def);
  TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
  if (versions()) {
    TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions()->producer()));
  }
  return OkStatus();
}
841 
RemoveInputs(const std::vector<int> & inputs_to_remove,NodeDef * node_def,std::vector<bool> * input_already_exists)842 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
843                   std::vector<bool>* input_already_exists) {
844   // Remove 'inputs_to_remove' from 'node_def'
845   NodeDef copy;
846   copy.mutable_input()->Reserve(node_def->input_size() -
847                                 inputs_to_remove.size());
848   for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
849     if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
850       ++j;
851     } else {
852       copy.add_input()->swap(*node_def->mutable_input(i));
853     }
854   }
855   node_def->mutable_input()->Swap(copy.mutable_input());
856   // Remove 'inputs_to_remove' from 'input_already_exists'
857   for (int idx : inputs_to_remove) {
858     input_already_exists->erase(input_already_exists->begin() + idx);
859   }
860   DCHECK_EQ(input_already_exists->size(), node_def->input_size());
861 }
862 
RemapNodeDefInputs(NodeDef * node_def,std::vector<bool> * input_already_exists)863 void GraphConstructor::RemapNodeDefInputs(
864     NodeDef* node_def, std::vector<bool>* input_already_exists) {
865   DCHECK_EQ(input_already_exists->size(), node_def->input_size());
866   std::set<TensorId> control_inputs;
867   std::vector<int> inputs_to_remove;
868 
869   for (int i = 0; i < node_def->input_size(); ++i) {
870     auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
871     if (iter == opts_.input_map.end()) continue;
872     used_input_map_keys_.insert(iter->first);
873 
874     TensorId new_input = iter->second;
875     if (new_input.second == Graph::kControlSlot) {
876       // Check if we've already remapped a different input to new_input, and if
877       // so remove this input.
878       if (control_inputs.count(new_input) > 0) {
879         inputs_to_remove.push_back(i);
880         continue;
881       }
882       control_inputs.insert(new_input);
883     }
884     node_def->set_input(i, new_input.ToString());
885     (*input_already_exists)[i] = true;
886   }
887   if (!inputs_to_remove.empty()) {
888     RemoveInputs(inputs_to_remove, node_def, input_already_exists);
889   }
890 }
891 
AddControlDependencies(NodeDef * node_def,std::vector<bool> * input_already_exists)892 void GraphConstructor::AddControlDependencies(
893     NodeDef* node_def, std::vector<bool>* input_already_exists) {
894   // To avoid adding redundant control dependencies to every imported node, skip
895   // nodes that will inherit the dependencies from another imported node.
896   bool inherits_deps = false;
897   for (int i = 0; i < node_def->input_size(); ++i) {
898     // Assume we won't inherit dependencies from remapped inputs that already
899     // exist in the graph. Even if we're wrong, we'll only add redundant
900     // dependencies.
901     if ((*input_already_exists)[i]) continue;
902 
903     // If this input is a backedge, assume we won't inherit the dependencies.
904     // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
905     // worth optimizing these.
906     TensorId id(ParseTensorName(node_def->input(i)));
907     auto iter = gdef_nodes_.find(id.first);
908     DCHECK(iter != gdef_nodes_.end()) << id.first;
909     if (iter->second.node == nullptr) {
910       // Input hasn't been created yet, indicating it's a backedge.
911       continue;
912     }
913     inherits_deps = true;
914   }
915   if (inherits_deps) return;
916 
917   // node_def either has no inputs or all remapped inputs, add the control
918   // dependencies
919   for (const string& control_dep : opts_.control_dependencies) {
920     string input = TensorId(control_dep, Graph::kControlSlot).ToString();
921     bool found = false;
922     for (int i = node_def->input_size() - 1; i >= 0; --i) {
923       const string& node_input = node_def->input(i);
924       if (node_input[0] != '^') {
925         // Control inputs are at the end. Break when we reach the non-control
926         // inputs.
927         break;
928       }
929       if (node_input == input) {
930         // Control dependency already exists
931         found = true;
932         break;
933       }
934     }
935     if (found) {
936       continue;
937     }
938     node_def->add_input(input);
939     input_already_exists->push_back(true);
940   }
941 }
942 
AddPrefixToNodeDef(const std::vector<bool> & input_already_exists,NodeDef * node_def)943 void GraphConstructor::AddPrefixToNodeDef(
944     const std::vector<bool>& input_already_exists, NodeDef* node_def) {
945   if (prefix_.empty()) return;
946   node_def->set_name(strings::StrCat(prefix_, node_def->name()));
947   // Update names of input nodes
948   for (int i = 0; i < node_def->input_size(); ++i) {
949     // Skip remapped inputs (which already exist in g_ and are not being
950     // imported).
951     if (input_already_exists[i]) continue;
952     StringPiece input(node_def->input(i));
953     if (absl::ConsumePrefix(&input, "^")) {
954       node_def->set_input(i, strings::StrCat("^", prefix_, input));
955     } else {
956       node_def->set_input(i, strings::StrCat(prefix_, input));
957     }
958   }
959   // Update names of colocation groups
960   if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
961     auto* list =
962         node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
963     for (int i = 0; i < list->s_size(); ++i) {
964       StringPiece v(list->s(i));
965       if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) {
966         list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
967       }
968     }
969   }
970 }
971 
UniquifyNames(const std::vector<bool> & input_already_exists,NodeDef * node_def)972 void GraphConstructor::UniquifyNames(
973     const std::vector<bool>& input_already_exists, NodeDef* node_def) {
974   if (NameExistsInGraph(node_def->name())) {
975     string old_name = node_def->name();
976     node_def->set_name(FindUniqueName(node_def->name()));
977     uniquified_names_[old_name] = node_def->name();
978     // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
979     // `name` because we guarantee the original NodeDef names are unique,
980     // meaning we won't generate this name again.
981   }
982   for (int i = 0; i < node_def->input_size(); ++i) {
983     // Skip remapped inputs (which already exist in g_ and are not being
984     // imported).
985     if (input_already_exists[i]) continue;
986     TensorId id = ParseTensorName(node_def->input(i));
987     // We require that UniquifyNames() is called on all NodeDefs in topological
988     // order. This guarantees that node_def's inputs will already be uniquified
989     // if necessary.
990     auto iter = uniquified_names_.find(string(id.first));
991     if (iter == uniquified_names_.end()) continue;
992     id.first = iter->second;
993     node_def->set_input(i, id.ToString());
994   }
995 }
996 
UpdateUniquifiedColocationNames()997 void GraphConstructor::UpdateUniquifiedColocationNames() {
998   for (const auto& pair : gdef_nodes_) {
999     Node* node = pair.second.node;
1000     if (node == nullptr) continue;
1001     std::vector<string> coloc_values;
1002     if (!TryGetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values))
1003       continue;
1004     bool updated = false;
1005     for (size_t i = 0; i < coloc_values.size(); ++i) {
1006       StringPiece val(coloc_values[i]);
1007       if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
1008         auto name_pair = uniquified_names_.find(string(val));
1009         if (name_pair == uniquified_names_.end()) continue;
1010         updated = true;
1011         coloc_values[i] =
1012             strings::StrCat(kColocationGroupPrefix, name_pair->second);
1013       }
1014     }
1015     if (updated) {
1016       node->AddAttr(kColocationAttrName, std::move(coloc_values));
1017     }
1018   }
1019 }
1020 
NameExistsInGraph(StringPiece name)1021 bool GraphConstructor::NameExistsInGraph(StringPiece name) {
1022   if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
1023   if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
1024   return false;
1025 }
1026 
NameExistsInGraphDef(StringPiece name)1027 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
1028   if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
1029   if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
1030   return false;
1031 }
1032 
FindUniqueName(StringPiece original_name)1033 string GraphConstructor::FindUniqueName(StringPiece original_name) {
1034   string name(original_name);
1035   int count = 0;
1036   // Check that any generated names don't collide with imported NodeDefs (as
1037   // well as nodes in g_).
1038   while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
1039     name = strings::StrCat(original_name, "_", ++count);
1040   }
1041   return name;
1042 }
1043 
IsNodeFullyMapped(const NodeDef & node_def,bool * is_node_mapped)1044 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
1045                                            bool* is_node_mapped) {
1046   const OpDef* op_def;
1047   TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1048   for (int i = 0; i < op_def->output_arg_size(); ++i) {
1049     if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
1050       *is_node_mapped = false;
1051       return OkStatus();
1052     }
1053   }
1054   *is_node_mapped = true;
1055   return OkStatus();
1056 }
1057 
// Diagnostic-only depth-first traversal over outputs_ (the NodeDef adjacency
// built by InitFromEdges()). When an edge leads back to a node already on
// the current DFS branch, the nodes forming that cycle are logged. Nodes are
// erased from `unvisited` after all their outputs are explored, so
// PrintCycles() can restart the DFS in components not yet reached.
// `node_names` maps NodeDef index -> name for readable log output.
void GraphConstructor::DFS(int cur_node, std::vector<int>* cur_branch,
                           std::vector<bool>* is_on_cur_branch,
                           absl::flat_hash_set<int>* unvisited,
                           const std::vector<absl::string_view>& node_names) {
  // Push cur_node onto the active branch.
  cur_branch->push_back(cur_node);
  is_on_cur_branch->at(cur_node) = true;
  for (auto next_node : outputs_[cur_node]) {
    if (unvisited->find(next_node) != unvisited->end()) {
      if (is_on_cur_branch->at(next_node)) {
        // next_node is an ancestor on the current branch: the cycle consists
        // of the branch nodes from next_node onward.
        auto iter =
            std::find(cur_branch->begin(), cur_branch->end(), next_node);
        LOG(WARNING) << "Cycle detected:";
        while (iter != cur_branch->end()) {
          const absl::string_view name = node_names[*iter];
          DCHECK(!name.empty());
          LOG(WARNING) << "node id=" << *iter << ", name=" << name;
          ++iter;
        }
        LOG(WARNING) << "End of cycle";
      } else {
        DFS(next_node, cur_branch, is_on_cur_branch, unvisited, node_names);
      }
    }
  }
  // Backtrack: remove cur_node from the branch and mark it fully visited.
  cur_branch->pop_back();
  is_on_cur_branch->at(cur_node) = false;
  unvisited->erase(cur_node);
}
1086 
PrintCycles()1087 void GraphConstructor::PrintCycles() {
1088   int num_nodes = outputs_.size();
1089 
1090   std::vector<absl::string_view> node_names;
1091   node_names.resize(num_nodes);
1092   for (const auto& named_node : gdef_nodes_) {
1093     DCHECK_GE(named_node.second.gdef_index, 0);
1094     DCHECK_LT(named_node.second.gdef_index, num_nodes);
1095     node_names[named_node.second.gdef_index] = named_node.first;
1096   }
1097 
1098   absl::flat_hash_set<int> unvisited;
1099   for (int i = 0; i < num_nodes; i++) {
1100     unvisited.insert(i);
1101   }
1102 
1103   while (!unvisited.empty()) {
1104     int cur_node = *unvisited.begin();
1105     // Nodes on the current branch of DFS in traversal order. This is used for
1106     // printing the nodes in the cycle.
1107     std::vector<int> cur_branch;
1108     // This is just to make lookups O(1).
1109     // is_on_cur_branch[i] ==
1110     //   (std::find(cur_branch.start(),
1111     //              cur_branch.end(), i) != cur_branch.end())
1112     std::vector<bool> is_on_cur_branch(num_nodes, false);
1113     DFS(cur_node, &cur_branch, &is_on_cur_branch, &unvisited, node_names);
1114   }
1115 }
1116 
// Main conversion loop: creates a Node in g_ for every NodeDef, in
// topological order driven by ready_/pending_count_/outputs_ (prepared by
// InitFromEdges()). Applies import-time rewrites (input remapping, extra
// control dependencies, name prefixing/uniquification), connects edges, and
// defers back edges (inputs whose producer Node does not exist yet) to
// AddBackEdges(). Fails with InvalidArgument if some nodes never become
// ready, which indicates a cycle that is not a valid while loop.
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library()) {
    // TODO(b/135705010): Add rvalue overloads into the function library, to
    // avoid unnecessarily copying `*library()` here.
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library()));
  }

  std::vector<InputInfo> inputs;
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_counts_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    NodeDef node_def = consume_node_def(o);

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_).  Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(node_def.input_size(), false);

    // Save the original name before prefixing/uniquification below, so the
    // gdef_nodes_ entry can be updated under its original key.
    std::string node_name = node_def.name();

    if (opts_.importing) {
      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(o, IsNextIteration(node_def));
          continue;
        }
      }

      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && node_def.device().empty()) {
        node_def.set_device(opts_.default_device);
      }
    }

    DCHECK_EQ(node_def.input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(node_def));
    // Resolve each input to a (source node, output index) pair.
    for (int i = 0; i < node_def.input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node_def.input(i));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(tensor_id.node());
        DCHECK(iter != gdef_nodes_.end()) << tensor_id.node();
        src_node = iter->second.node;
        src_index = tensor_id.index();
        // A null node means the producer hasn't been created yet, i.e. this
        // is a while-loop back edge.
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexistng node in graph
        auto iter = existing_nodes_.find(tensor_id.node());
        DCHECK(iter != existing_nodes_.end()) << tensor_id.node();
        src_node = iter->second;
        src_index = tensor_id.index();
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        std::ostringstream out;
        out << "Node '" << node_def.name() << "': Connecting to invalid output "
            << tensor_id.index() << " of source node " << tensor_id.node()
            << " which has " << src_node->num_outputs() << " outputs.";

        if (src_node->type_string() == "If" ||
            src_node->type_string() == "StatelessIf" ||
            src_node->type_string() == "While" ||
            src_node->type_string() == "StatelessWhile") {
          out << " Try using "
              << "tf.compat.v1.experimental.output_all_intermediates(True).";
        }
        return errors::InvalidArgument(out.str());
      }

      inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
    }

    if (has_data_back_edge && !IsMerge(node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &node_def);
      }
    }

    if (opts_.importing) {
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
    } else {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(
          g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
      if (opts_.add_default_attributes) {
        AddDefaultsToNodeDef(*op_def, &node_def);
      }
      if (opts_.validate_nodes) {
        TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
      }
    }

    TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node));

    gdef_nodes_[node_name].node = node;

    // Remove duplicate control inputs before adding edges to the graph. It
    // will allow us to skip expensive duplicates check in 'AddControlEdge'.
    auto first_control = absl::c_find_if(inputs, &InputInfo::IsControlInput);
    auto first_control_copy = first_control;
    std::sort(first_control, inputs.end(), &InputInfo::CompareName);
    inputs.erase(
        std::unique(first_control_copy, inputs.end(), &InputInfo::IsSameName),
        inputs.end());

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node, kDoNotCheckDuplicates);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o, node->IsNextIteration());
  }

  if (processed < node_def_count()) {
    // Some nodes never became ready: there is a cycle that is not a valid
    // while loop. Log diagnostics before failing.
    LOG(WARNING) << "IN " << __func__ << " " << (node_def_count() - processed)
                 << " NODES IN A CYCLE";
    for (int64_t i = 0; i < node_def_count(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(get_node_def(i))
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    PrintCycles();
    return errors::InvalidArgument(node_def_count() - processed,
                                   " nodes in a cycle");
  }

  return OkStatus();
}
1298 
AddBackEdges()1299 Status GraphConstructor::AddBackEdges() {
1300   // Add the back edges after all nodes are created.
1301   for (const auto& e : back_edges_) {
1302     Node* src_node = gdef_nodes_[e.src_name].node;
1303     if (e.src_index == Graph::kControlSlot) {
1304       g_->AddControlEdge(src_node, e.dst_node, kDoNotCheckDuplicates);
1305     } else {
1306       TF_RETURN_IF_ERROR(
1307           MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
1308     }
1309 
1310     VLOG(2) << "Add back edge: " << src_node->name() << " -> "
1311             << e.dst_node->name();
1312   }
1313   return OkStatus();
1314 }
1315 
UpdateVersionDef()1316 Status GraphConstructor::UpdateVersionDef() {
1317   if (versions() == nullptr) return OkStatus();
1318 
1319   if (!opts_.importing) {
1320     g_->set_versions(*versions());
1321     return OkStatus();
1322   }
1323   VersionDef g_versions = g_->versions();
1324   g_versions.set_producer(
1325       std::min(g_versions.producer(), versions()->producer()));
1326   g_versions.set_min_consumer(
1327       std::max(g_versions.min_consumer(), versions()->min_consumer()));
1328   if (versions()->bad_consumers_size() > 0) {
1329     std::set<int> bad(g_versions.bad_consumers().begin(),
1330                       g_versions.bad_consumers().end());
1331     bad.insert(versions()->bad_consumers().begin(),
1332                versions()->bad_consumers().end());
1333     g_versions.clear_bad_consumers();
1334     for (int v : bad) {
1335       g_versions.add_bad_consumers(v);
1336     }
1337   }
1338   g_->set_versions(g_versions);
1339   return OkStatus();
1340 }
1341 
PopulateReturnTensors()1342 Status GraphConstructor::PopulateReturnTensors() {
1343   if (opts_.return_tensors.empty()) return OkStatus();
1344   for (const TensorId& id : opts_.return_tensors) {
1345     auto iter = opts_.input_map.find(id);
1346     if (iter == opts_.input_map.end()) {
1347       // Locate id in imported nodes
1348       auto iter = gdef_nodes_.find(id.first);
1349       if (iter == gdef_nodes_.end()) {
1350         return errors::InvalidArgument("Requested return tensor '",
1351                                        id.ToString(),
1352                                        "' not found in graph def");
1353       }
1354       int num_outputs = iter->second.node->num_outputs();
1355       if ((id.second < 0 || id.second >= num_outputs) &&
1356           id.second != Graph::kControlSlot) {
1357         return errors::InvalidArgument("Invalid return output ", id.second,
1358                                        " of node '", id.first, "', which has ",
1359                                        num_outputs, " output(s)");
1360       }
1361       return_tensors_->push_back({iter->second.node, id.second});
1362     } else {
1363       // id was remapped to existing node
1364       TensorId remapped_id = iter->second;
1365       DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
1366       Node* node = existing_nodes_[remapped_id.first];
1367       return_tensors_->push_back({node, remapped_id.second});
1368     }
1369   }
1370   return OkStatus();
1371 }
1372 
PopulateReturnNodes()1373 Status GraphConstructor::PopulateReturnNodes() {
1374   if (opts_.return_nodes.empty()) return OkStatus();
1375   for (StringPiece name : opts_.return_nodes) {
1376     auto iter = gdef_nodes_.find(name);
1377     if (iter == gdef_nodes_.end()) {
1378       return errors::InvalidArgument("Requested return node '", name,
1379                                      "' not found in graph def");
1380     }
1381     return_nodes_->push_back(iter->second.node);
1382   }
1383   return OkStatus();
1384 }
1385 
PopulateMissingUnusedInputMapKeys()1386 Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
1387   if (missing_unused_input_map_keys_ == nullptr) return OkStatus();
1388   for (const auto& input_map_pair : opts_.input_map) {
1389     TensorId key = input_map_pair.first;
1390     if (used_input_map_keys_.count(key) > 0) continue;
1391 
1392     auto pair = gdef_nodes_.find(key.first);
1393     if (pair == gdef_nodes_.end()) {
1394       // key's node doesn't exist in GraphDef
1395       missing_unused_input_map_keys_->push_back(key);
1396       continue;
1397     }
1398 
1399     // Check that key's index is in bounds. Get the number of outputs from the
1400     // NodeDef, rather than the imported Node, since the Node may not exist if
1401     // opts_.skip_mapped_nodes is true.
1402     const NodeDef& node_def = get_node_def(pair->second.gdef_index);
1403     const OpDef* op_def;
1404     TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1405     int num_outputs;
1406     TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, *op_def, &num_outputs));
1407     if (key.second >= num_outputs) {
1408       // key's index out of bounds
1409       missing_unused_input_map_keys_->push_back(key);
1410     }
1411   }
1412   return OkStatus();
1413 }
1414 
Undo()1415 void GraphConstructor::Undo() {
1416   for (const auto& iter : gdef_nodes_) {
1417     if (iter.second.node != nullptr) {
1418       g_->RemoveNode(iter.second.node);
1419     }
1420   }
1421   g_->set_versions(original_versions_);
1422 }
1423 
MakeEdge(Node * src,int output_index,Node * dst,int input_index)1424 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
1425                                   int input_index) {
1426   if (output_index >= src->num_outputs()) {
1427     return errors::InvalidArgument(
1428         "Output ", output_index, " of node ", src->name(),
1429         " does not exist. Node only has ", src->num_outputs(), " outputs.");
1430   }
1431   if (input_index >= dst->num_inputs()) {
1432     return errors::InvalidArgument(
1433         "Input ", input_index, " of node ", dst->name(),
1434         " does not exist. Node only has ", dst->num_inputs(), " inputs.");
1435   }
1436 
1437   DataType src_out = src->output_type(output_index);
1438   DataType dst_in = dst->input_type(input_index);
1439   if (!TypesCompatible(dst_in, src_out)) {
1440     return errors::InvalidArgument(
1441         "Input ", input_index, " of node ", dst->name(), " was passed ",
1442         DataTypeString(src_out), " from ", src->name(), ":", output_index,
1443         " incompatible with expected ", DataTypeString(dst_in), ".");
1444   }
1445   g_->AddEdge(src, output_index, dst, input_index);
1446   return OkStatus();
1447 }
1448 
1449 }  // namespace
1450 
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,const GraphDef & gdef,Graph * g)1451 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1452                               const GraphDef& gdef, Graph* g) {
1453   ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1454   return GraphConstructor::Construct(
1455       opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
1456       /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
1457       /*missing_unused_input_map_keys=*/nullptr);
1458 }
1459 
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,GraphDef && gdef,Graph * g)1460 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1461                               GraphDef&& gdef, Graph* g) {
1462   ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1463   return GraphConstructor::Construct(opts, std::move(gdef), g, &refiner,
1464                                      /*return_tensors=*/nullptr,
1465                                      /*return_nodes=*/nullptr,
1466                                      /*missing_unused_input_map_keys=*/nullptr);
1467 }
1468 
ConvertNodeDefsToGraph(const GraphConstructorOptions & opts,gtl::ArraySlice<NodeDef> nodes,Graph * g)1469 Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
1470                               gtl::ArraySlice<NodeDef> nodes, Graph* g) {
1471   ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
1472   // TODO(irving): Copy will go away once NodeInfo exists
1473   std::vector<const NodeDef*> node_defs;
1474   node_defs.reserve(nodes.size());
1475   for (const auto& n : nodes) {
1476     node_defs.push_back(&n);
1477   }
1478   return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
1479                                      &refiner, /*return_tensors=*/nullptr,
1480                                      /*return_nodes=*/nullptr,
1481                                      /*missing_unused_input_map_keys=*/nullptr);
1482 }
1483 
// Imports `gdef` into the existing graph `g`, validating the combination of
// `opts` and `results` up front and then delegating to
// GraphConstructor::Construct. The validation checks run in a fixed order,
// so the first violated precondition determines which error the caller sees.
// `refiner` may be null, in which case a local default refiner is used.
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
                      Graph* g, ShapeRefiner* refiner,
                      ImportGraphDefResults* results) {
  // Requested return tensors need a `results` struct to be written into.
  if (!opts.return_tensors.empty()) {
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_tensors is non-empty");
    }
  }

  // Same for return nodes; additionally, skip_mapped_nodes is incompatible
  // with returning nodes (a mapped node may not exist after import).
  if (!opts.return_nodes.empty()) {
    if (opts.skip_mapped_nodes) {
      return errors::InvalidArgument(
          "Requesting return_nodes with skip_mapped_nodes set is not currently "
          "supported");
    }
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_nodes is non-empty");
    }
  }

  // A provided `results` must start out empty; import appends into it.
  if (results != nullptr) {
    if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
        !results->missing_unused_input_map_keys.empty()) {
      return errors::InvalidArgument(
          "All fields in results argument to ImportGraphDef() must be empty.");
    }
  }

  ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
  if (refiner == nullptr) {
    refiner = &default_refiner;
  } else {
    // Log a warning if we are importing a GraphDef at an older
    // producer version after already having added non-source/sink
    // nodes to the graph in the past.
    // (g->num_nodes() > 2 excludes the implicit SOURCE and SINK nodes.)
    if (gdef.versions().producer() > 0 &&
        gdef.versions().producer() < refiner->graph_def_version() &&
        g->num_nodes() > 2) {
      LOG(WARNING) << "Importing a graph with a lower producer version "
                   << gdef.versions().producer()
                   << " into an existing graph with producer version "
                   << refiner->graph_def_version() << ". Shape inference will "
                   << "have run different parts of the graph with different "
                   << "producer versions.";
    }
  }

  // Set the graph def version of the refiner as the min of the
  // current value and the version from the graph we are about to
  // import.
  //
  // Note: to match Run() semantics, we should re-run shape inference
  // on the entire graph if the producer version has changed.  For now
  // we log the warning above.
  refiner->set_graph_def_version(
      std::min(refiner->graph_def_version(), gdef.versions().producer()));

  // Only wire up the result-output pointers when the caller asked for them.
  if (results == nullptr) {
    return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
                                       &gdef.library(), g, refiner, nullptr,
                                       nullptr, nullptr);
  } else {
    return GraphConstructor::Construct(
        opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
        &results->return_tensors, &results->return_nodes,
        &results->missing_unused_input_map_keys);
  }
}
1556 
CopyGraph(const Graph & src,Graph * dest)1557 void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); }
1558 
1559 }  // namespace tensorflow
1560