xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
17 
18 #include <atomic>
19 #include <functional>
20 #include <iterator>
21 #include <memory>
22 #include <queue>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26 #include <unordered_set>
27 #include <utility>
28 #include <vector>
29 
30 #include "absl/algorithm/container.h"
31 #include "absl/base/thread_annotations.h"
32 #include "absl/container/flat_hash_map.h"
33 #include "absl/container/flat_hash_set.h"
34 #include "absl/container/inlined_vector.h"
35 #include "absl/strings/escaping.h"
36 #include "absl/strings/numbers.h"
37 #include "absl/strings/str_cat.h"
38 #include "absl/strings/str_join.h"
39 #include "absl/strings/string_view.h"
40 #include "absl/strings/strip.h"
41 #include "absl/synchronization/mutex.h"
42 #include "llvm/ADT/ArrayRef.h"
43 #include "llvm/ADT/DenseMap.h"
44 #include "llvm/ADT/DenseSet.h"
45 #include "llvm/ADT/STLExtras.h"
46 #include "llvm/ADT/ScopeExit.h"
47 #include "llvm/ADT/SetVector.h"
48 #include "llvm/ADT/SmallVector.h"
49 #include "llvm/ADT/StringRef.h"
50 #include "llvm/ADT/StringSet.h"
51 #include "llvm/ADT/Twine.h"
52 #include "llvm/Support/FormatVariadic.h"
53 #include "llvm/Support/SourceMgr.h"
54 #include "llvm/Support/raw_ostream.h"
55 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
56 #include "mlir/IR/Attributes.h"  // from @llvm-project
57 #include "mlir/IR/Builders.h"  // from @llvm-project
58 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
59 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
60 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
61 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
62 #include "mlir/IR/Location.h"  // from @llvm-project
63 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
64 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
65 #include "mlir/IR/Types.h"  // from @llvm-project
66 #include "mlir/IR/Verifier.h"  // from @llvm-project
67 #include "mlir/Pass/PassManager.h"  // from @llvm-project
68 #include "tensorflow/cc/saved_model/constants.h"
69 #include "tensorflow/cc/saved_model/loader_util.h"
70 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
71 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
72 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
73 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
74 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
75 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
76 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
77 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
78 #include "tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.h"
79 #include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
80 #include "tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.h"
81 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
82 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
83 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
84 #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
85 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
86 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
87 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
88 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
89 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
90 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
91 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
92 #include "tensorflow/compiler/xla/status_macros.h"
93 #include "tensorflow/core/common_runtime/function.h"
94 #include "tensorflow/core/common_runtime/graph_constructor.h"
95 #include "tensorflow/core/common_runtime/shape_refiner.h"
96 #include "tensorflow/core/framework/attr_value.pb.h"
97 #include "tensorflow/core/framework/function.pb.h"
98 #include "tensorflow/core/framework/graph.pb.h"
99 #include "tensorflow/core/framework/node_def.pb.h"
100 #include "tensorflow/core/framework/node_def_util.h"
101 #include "tensorflow/core/framework/op.h"
102 #include "tensorflow/core/framework/resource_var.h"
103 #include "tensorflow/core/framework/shape_inference.h"
104 #include "tensorflow/core/framework/tensor.pb.h"
105 #include "tensorflow/core/framework/types.h"
106 #include "tensorflow/core/framework/types.pb.h"
107 #include "tensorflow/core/framework/versions.pb.h"
108 #include "tensorflow/core/graph/algorithm.h"
109 #include "tensorflow/core/graph/graph.h"
110 #include "tensorflow/core/graph/node_builder.h"
111 #include "tensorflow/core/graph/tensor_id.h"
112 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
113 #include "tensorflow/core/lib/core/errors.h"
114 #include "tensorflow/core/lib/strings/str_util.h"
115 #include "tensorflow/core/platform/crash_analysis.h"
116 #include "tensorflow/core/platform/env.h"
117 #include "tensorflow/core/platform/errors.h"
118 #include "tensorflow/core/platform/fingerprint.h"
119 #include "tensorflow/core/platform/logging.h"
120 #include "tensorflow/core/platform/path.h"
121 #include "tensorflow/core/platform/protobuf.h"
122 #include "tensorflow/core/platform/threadpool.h"
123 #include "tensorflow/core/platform/types.h"
124 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
125 #include "tensorflow/core/protobuf/meta_graph.pb.h"
126 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
127 #include "tensorflow/core/protobuf/saver.pb.h"
128 #include "tensorflow/core/protobuf/struct.pb.h"
129 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
130 #include "tensorflow/core/util/device_name_utils.h"
131 #include "tensorflow/core/util/dump_graph.h"
132 #include "tensorflow/stream_executor/lib/statusor.h"
133 
StringRefToView(llvm::StringRef ref)134 static inline absl::string_view StringRefToView(llvm::StringRef ref) {
135   return {ref.data(), ref.size()};
136 }
137 
138 namespace tensorflow {
139 
// Thread-pool size used when converting SavedModel signature functions
// concurrently.
constexpr size_t kNumThreadToConvertSignatures = 10;
// NodeDef attribute name that carries per-output shape information.
constexpr absl::string_view kOutputShapesAttrName = "_output_shapes";

using mlir::NamedAttrList;
using mlir::TensorType;
using mlir::tf_saved_model::AssetOp;
using mlir::tf_saved_model::GlobalTensorOp;
using mlir::tf_saved_model::SessionInitializerOp;
using stream_executor::port::StatusOr;
149 
150 namespace {
151 
IsOutputShapesAttribute(const AttrValue & attr_value,llvm::StringRef attr_name)152 bool IsOutputShapesAttribute(const AttrValue& attr_value,
153                              llvm::StringRef attr_name) {
154   return attr_name.compare(kOutputShapesAttrName) == 0 &&
155          attr_value.value_case() == AttrValue::kList;
156 }
157 
IsResourceOutputShapesAttribute(const AttrValue & attr_value,llvm::StringRef attr_name)158 bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
159                                      llvm::StringRef attr_name) {
160   if (attr_name == "_handle_dtypes" || attr_name == "_handle_shapes")
161     return attr_value.value_case() == AttrValue::kList;
162   return false;
163 }
164 
LoadImporterDialects(mlir::MLIRContext & context)165 void LoadImporterDialects(mlir::MLIRContext& context) {
166   // Load dialects involved in the conversion
167   mlir::DialectRegistry registry;
168   mlir::RegisterAllTensorFlowDialects(registry);
169   context.appendDialectRegistry(registry);
170   for (llvm::StringRef name : registry.getDialectNames())
171     context.getOrLoadDialect(name);
172 }
173 
174 // This class is used to generate new MLIR function name strings that are both
175 // unique in the TF function library `flib_` and unique among the name strings
176 // generated by the class object during its lifetime.
177 //
178 // In theory, this class is not necessary because we should simply take
179 // the TF function name and use it as MLIR function name. However, for some
180 // unknown reasons (callout for investigation in b/142268695), keeping the
181 // function names unchanged in an MLIR roundtrip causes test failures.
182 // TODO(b/142268695) Re-evaluate whether we need this class v.s. directly using
183 // and TF function name as MLIR function name after b/142268695 is root caused.
184 class NameUniquifier : public OpOrArgNameMapper {
185  public:
NameUniquifier(const FunctionLibraryDefinition & flib)186   explicit NameUniquifier(const FunctionLibraryDefinition& flib)
187       : flib_(flib) {}
188 
189  private:
IsUnique(llvm::StringRef name)190   bool IsUnique(llvm::StringRef name) override {
191     return !flib_.Contains(std::string(name));
192   }
193 
GetName(OpOrVal op_or_val)194   std::string GetName(OpOrVal op_or_val) override {
195     DCHECK(false) << "Unimplemented";
196     return "";
197   }
198 
199   const FunctionLibraryDefinition& flib_;
200 };
201 
202 // Stateful helper class to import a TensorFlow model into an MLIR Module.
203 //
204 // This is the base class that contains common utilities shared between the
205 // GraphDef importer and SavedModel importer.
206 //
207 // A subclass is expected to call `PrepareConvert` first to perform necessary
208 // preparation over the graph and also certain internal bookkeeping data.
209 // Afterwards the other protected methods can be called.
class ImporterBase {
 protected:
  // Constructs an importer that emits into `module`. All reference/pointer
  // parameters must outlive this object; nothing is copied.
  explicit ImporterBase(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier,
      llvm::StringRef function_name_for_debug_info = "")
      : builder_(module.getContext()),
        module_(module),
        context_(module.getContext()),
        tf_name_to_mlir_name_(tf_name_to_mlir_name),
        graph_flib_(flib),
        specs_(specs),
        debug_info_(debug_info),
        function_name_for_debug_info_(function_name_for_debug_info),
        function_name_uniquifier_(function_name_uniquifier),
        error_handler_(module.getContext()) {
    // Log import config.
    if (VLOG_IS_ON(1)) {
      LOG(INFO) << "Importing with: " << specs.str();
      for (auto& it : *tf_name_to_mlir_name) {
        LOG(INFO) << "\t" << it.first << " -> " << it.second;
      }
    }
  }

  // Returns the inferred function signature of the given function body. Input
  // types are unranked tensor of the respective datatype in the function and
  // result types are inferred by the shape_refiner_. Result types need not be
  // unranked tensors and could be ranked tensors in cases where result type
  // depends on an op with static output shape like tf.Const.
  StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);

  // Extracts arg and ret nodes from FunctionBody.
  void GetArgsAndRetsFromFunctionBody(
      const FunctionBody& fbody,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);

  // Prepares converting the graph to an MLIR module. This step removes the
  // backedges of the graph, orders the nodes and infers the shapes.
  // PrepareConvert needs to ensure that the original `graph` is cloned prior
  // execution. The cloning procedure relies on the roundtrip through the
  // GraphDef. Graph to GraphDef def conversion is heavy, in case, `graph_def`
  // was obtained previously provide it to the PrepareConvert to reuse.
  Status PrepareConvert(const Graph& graph,
                        std::unique_ptr<GraphDef> graph_def = nullptr);

  // Converts the prepared graph to a Function and adds it to the module. A set
  // of nodes from the graph are given to converted to the arguments and returns
  // of the function.
  Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type,
                 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
                 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
                 const absl::InlinedVector<Node*, 4>& control_ret_nodes,
                 llvm::ArrayRef<mlir::NamedAttribute> attrs);

  // Finds out the function definition for the given function name from the
  // graph and converts it to a function of the module. This method is called
  // on demand because the graph flib_def does not provide an iterator
  // interface.
  Status ConvertLibFunction(llvm::StringRef func_name);

  // Returns the list of nodes in the graph. Nodes are presented in the reverse
  // order of a post-order depth-first visit starting from the graph's source
  // nodes.
  llvm::ArrayRef<Node*> GetOrderedNodes() const { return ordered_nodes_; }

  // Returns the inferred input type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferInputType(const Node& node, int idx,
                                      mlir::Builder builder);

  // Returns the inferred output type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferOutputType(const Node& node, int idx,
                                       mlir::Builder builder);

  // Convert deferred TF functions to the MLIR representation.
  // Conversion is deferred for efficiency reasons, e.g., to limit depth
  // of recursion and reduce stack size pressure.
  Status ConvertDeferredFunctions();

 private:
  // Most types with subtypes have only one subtype.
  using ElementSubtypes = llvm::SmallVector<TensorType, 1>;

  // Metadata used for deferred function conversion.
  struct DeferredConversionMetaData {
    DeferredConversionMetaData(
        const std::string& function_name,
        const std::vector<mlir::NamedAttribute>& attributes)
        : function_name(function_name), attributes(attributes) {}

    // TF function library name of the function awaiting conversion.
    std::string function_name;
    // Attributes to attach to the resulting MLIR function.
    std::vector<mlir::NamedAttribute> attributes;
  };

  // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all
  // data type and shape information is maintained by the shape_refiner_.
  // TODO(jpienaar): Remove once shape inference on import is removed.
  Status AddNodesToShapeRefiner(
      std::unordered_map<string, Node*>* node_name_map);

  // Prune nodes that do not feed into fetch nodes.
  Status PruneUnreachableNodes(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts feeds to Placeholder nodes.
  Status ConvertFeedsToPlaceholders(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // given element type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertDataTypeAndShape(
      DataType dtype, const shape_inference::ShapeHandle& handle,
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // given element type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertElementTypeAndShape(
      mlir::Type element_type, const shape_inference::ShapeHandle& handle,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred subtypes for an element type to corresponding MLIR
  // types in 'context'.
  StatusOr<ElementSubtypes> ConvertSubtypes(
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the tensor proto into an MLIR elements attribute.
  StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& value) {
    return ::tensorflow::ConvertTensorProto(value, &builder_);
  }

  // Converts func name in graphdef to mlir::SymbolRefAttribute.
  StatusOr<mlir::FlatSymbolRefAttr> ConvertFunctionCallName(
      const std::string& func_name);

  // Converts the given non-function-call AttrValue to an MLIR Attribute.
  StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value);

  // Converts the given function-call AttrValue to MLIR Attributes and pushes
  // them to the given attributes list. For example, if there is a kFunc
  // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to
  // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
  // {base_name.k2 : rfc}}.
  Status ConvertFunctionCallAttribute(const std::string& base_name,
                                      const AttrValue& value,
                                      NamedAttrList* attributes);

  // Helper to create either a tf_executor operation or a TF operation wrapped
  // in an island.
  mlir::Operation* CreateOperation(
      const Node& node, llvm::StringRef node_type_name,
      const mlir::OperationState& result,
      const llvm::SmallVectorImpl<mlir::Value>& control_operands);

  // Converts one NodeDef from the input GraphDef into an Operation and
  // inserts it into the MLIR module using builder_.
  Status ConvertNode(const Node& node);

  // If the input graph represents a while-loop, the edges pointing from a
  // "NextIteration" node to a "Merge" node add cyclic dependencies and make the
  // topological sorting impossible. We need to remove these edges from the
  // input graph to infer shapes and construct a Function. For each
  // "NextIteration" node, there are two operations, "NextIteration.source"
  // and "NextIteration.sink" are added to the MLIR module.
  using BackEdge = BackEdgeHelper::BackEdge;

  // Removes backedges from the input graph. The removed edges are added back to
  // to OpBuilder after the remaining graph is converted to the Function.
  Status RemoveBackedges();

  // Restores backedges removed during shape inference to the final Function.
  Status AddBackedges();

  // Restores a single backedge in the Function by adding a replicated
  // operation before the dst operation.
  Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                     int dst_input);

  // Adds the input arguments and return operation to the function. The
  // arguments are added as basic block argument. Also the argument types and
  // the id of the nodes from the input graph needs to be specified.
  Status ConvertFunctionArgAndRets(
      mlir::func::FuncOp func, mlir::tf_executor::GraphOp graph_op,
      llvm::ArrayRef<mlir::Type> arg_types,
      const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
      const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
      const absl::InlinedVector<Node*, 4>& control_ret_nodes);

  // Gets the location information of the given node. It uses the
  // "original_node_name" in the NodeDef to get the corresponding file location
  // (FileLineColLoc) from the input DebugInfo and returns an CallSiteLoc. If
  // there are multiple "original_node_names", a FusedLoc is returned. If the
  // node name couldn't be found in the input DebugInfo, a NameLoc is used as
  // the location.
  mlir::Location GetLocation(const Node& node);

  // Appends the location string for the node to the error message and returns
  // the combined error status.
  Status EmitErrorWithLocationStr(const Node& node, const Status& error_status);

  // Inserts a placeholder node in the graph to replace a feed output tensor,
  // and returns the new placeholder node and a boolean indicating if the
  // original input node was removed from the graph. Uses of the feed output
  // tensor are replaced with this placeholder node. If the feed output tensor
  // is of a single output node, the control dependencies are forwarded to the
  // the placeholder node, and the original node will be removed.
  // Note: This modifies the graph, and so any list of ordered nodes needs to be
  // reconstructed.
  StatusOr<std::pair<Node*, bool>> CreatePlaceholderNodeForFeed(
      const TensorShapeProto& shape, DataType dtype, Node* node, int index,
      const std::unordered_map<string, Node*>& node_name_map);

  // Gets the input and output nodes corresponding to the specified input and
  // output nodes in specs_. If there are no input or output nodes specified,
  // nodes will be empty.
  Status GetInputOutputNodes(
      const std::unordered_map<string, Node*>& node_name_map,
      std::unordered_set<const Node*>* nodes);

  // The input graph with backedges removed. The removed backedges are stored
  // in the back_edge_helper.
  BackEdgeHelper back_edge_helper_;
  // A map between node and output index, for each backedge.
  absl::flat_hash_map<const Node*, int> back_edge_node_output_;
  absl::flat_hash_map<const Node*, BackEdge> back_edge_dst_inputs_;
  // A map between sink and source operation of NextIteration
  absl::flat_hash_map<mlir::Operation*, mlir::Operation*>
      next_iteration_sink_source_;

  // All nodes and version information about the (copied) imported graph.
  std::unique_ptr<Graph> graph_;
  std::vector<Node*> ordered_nodes_;

  // Maps from a Node ID to a MLIR value.
  using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;

  // Builder used to create operations in module_.
  mlir::OpBuilder builder_;
  // The module being populated; not owned.
  mlir::ModuleOp module_;
  // Context of module_; not owned.
  mlir::MLIRContext* context_;
  // Maps TF function-library names to the MLIR function names chosen for
  // them; shared across nested importers. Not owned.
  std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
  // Function library of the graph being imported.
  const FunctionLibraryDefinition& graph_flib_;
  // Import configuration supplied by the caller.
  const GraphImportConfig& specs_;
  // Debug info consulted by GetLocation for source locations.
  const GraphDebugInfo& debug_info_;
  // Name of the function being imported; empty for the main graph.
  llvm::StringRef function_name_for_debug_info_;
  // Operations created so far, keyed by graph node id.
  NodeValueMap node_values_;
  // TODO(jpienaar): Remove once shape inference on import is removed.
  // The shape_refinner_ will be nullptr if shape inference on import is
  // not enabled.
  std::unique_ptr<ShapeRefiner> shape_refiner_ = nullptr;
  // Produces unique MLIR function names; shared across importers. Not owned.
  NameUniquifier* function_name_uniquifier_;
  // Captures diagnostics emitted into the MLIR context during import.
  mlir::StatusScopedDiagnosticHandler error_handler_;
  // All the TF ops encountered that aren't modelled in dialect.
  llvm::DenseSet<mlir::StringAttr> unmodelled_op_names_;

 protected:
  // Maps feed as TensorId to new Placeholder node name.
  absl::flat_hash_map<TensorId, absl::string_view> remapped_feeds_;
  // Keep track of functions required deferred conversion.
  std::queue<DeferredConversionMetaData> deferred_functions_;
};
477 
478 // Returns true if the node with given name has a non primary output that is
479 // used by some other node as an input. Returns false if no outputs are in use
480 // or only the first output is in use.
HasNonPrimaryOutputInUse(const GraphDef & graph_def,const std::string & node)481 bool HasNonPrimaryOutputInUse(const GraphDef& graph_def,
482                               const std::string& node) {
483   for (const auto& node_def : graph_def.node()) {
484     for (const auto& input : node_def.input()) {
485       if (absl::StartsWith(input, node + ":") && input != node + ":0") {
486         return true;
487       }
488     }
489   }
490   return false;
491 }
492 
493 // Updates the given LegacyFedInput node with Placeholder node if it is one of
494 // the inputs. Returns an error if non primary output of the LegacyFedInput node
495 // is in use and therefore can not be replaced by the Placeholder node that only
496 // has a single output.
UpdateLegacyFedInputNode(const GraphDef & graph_def,const GraphImportConfig::InputArrays & inputs,NodeDef * node)497 Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
498                                 const GraphImportConfig::InputArrays& inputs,
499                                 NodeDef* node) {
500   const std::string& node_name = node->name();
501   auto it = inputs.find(node_name);
502 
503   // Node is not an input.
504   if (it == inputs.end()) return OkStatus();
505 
506   if (HasNonPrimaryOutputInUse(graph_def, node_name)) {
507     return errors::InvalidArgument(
508         "LegacyFedInput node ", node->name(),
509         " has non primary output in use and can not be replaced with "
510         "Placeholder node");
511   }
512 
513   DataType dtype = it->second.imported_dtype;
514   // Uses the existing output type if it isn't specified by the user.
515   if (dtype == DT_INVALID) {
516     dtype = node->attr().at("output_types").list().type(0);
517   }
518   // Update op name, drop inputs and set attributes required by the Placeholder
519   // op.
520   *node->mutable_op() = "Placeholder";
521   node->clear_attr();
522   node->clear_input();
523   AddNodeAttr("dtype", dtype, node);
524   AddNodeAttr("shape", it->second.shape, node);
525   return OkStatus();
526 }
527 
528 // Preprocesses GraphDef before it can be converted to Graph by,
529 // - Adding the default attributes to each node def if they are missing from
530 //   the GraphDef.
531 // - Replacing LegacyFedInput nodes with Placeholder nodes if
532 //   convert_legacy_fed_inputs option is enabled.
PreprocessGraphDef(const GraphImportConfig * specs,GraphDef * graph_def)533 Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) {
534   for (auto& node_def : *graph_def->mutable_node()) {
535     // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One
536     // solution could be have a tool to let users upgrade old serialized graphs.
537     if (specs && specs->convert_legacy_fed_inputs &&
538         node_def.op() == "LegacyFedInput") {
539       TF_RETURN_IF_ERROR(
540           UpdateLegacyFedInputNode(*graph_def, specs->inputs, &node_def));
541     }
542 
543     const tensorflow::OpRegistrationData* op_reg_data =
544         tensorflow::OpRegistry::Global()->LookUp(node_def.op());
545     if (!op_reg_data) {
546       // This is likely a function call node, so we should continue.
547       continue;
548     }
549     ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def);
550   }
551   return OkStatus();
552 }
553 
// Mapping from node name to feed (index and ArrayInfo). Node name must outlive
// this map.
// The inner map is keyed by the node's output index; values point into the
// caller's `GraphImportConfig::InputArrays` entries.
using FeedsByNode = absl::flat_hash_map<
    absl::string_view,
    absl::flat_hash_map<int, const std::pair<std::string, ArrayInfo>*>>;
559 
560 // Creates from a `GraphImportConfig::InputArrays` a mapping from a feeds output
561 // tensor name to index and ArrayInfo. Keys and values are backed by
562 // `GraphImportConfig::InputArrays`.
GetFeedsByNode(const GraphImportConfig::InputArrays & inputs)563 StatusOr<FeedsByNode> GetFeedsByNode(
564     const GraphImportConfig::InputArrays& inputs) {
565   FeedsByNode feeds_by_node;
566   feeds_by_node.reserve(inputs.size());
567 
568   for (const auto& input : inputs) {
569     TensorId tensor = ParseTensorName(input.first);
570     if (tensor.index() < 0)
571       return errors::FailedPrecondition(
572           "Feed output tensor must be a data output '", tensor.ToString(), "'");
573 
574     auto& node = feeds_by_node[tensor.node()];
575     if (!node.insert({tensor.index(), &input}).second)
576       return errors::FailedPrecondition(
577           "Multiple feeds for the same output tensor '", tensor.ToString(),
578           "'");
579   }
580 
581   return feeds_by_node;
582 }
583 
584 // Creates a unique name for a node that will be replacing a feed output tensor.
GetUniqueNodeName(absl::string_view node_name,int index,const std::unordered_map<string,Node * > & node_name_map)585 std::string GetUniqueNodeName(
586     absl::string_view node_name, int index,
587     const std::unordered_map<string, Node*>& node_name_map) {
588   std::string new_node_name_base = absl::StrCat(node_name, "_", index);
589   int count = 0;
590   std::string new_node_name = new_node_name_base;
591   while (node_name_map.find(new_node_name) != node_name_map.end()) {
592     new_node_name = absl::StrCat(new_node_name_base, "_", count++);
593   }
594   return new_node_name;
595 }
596 
Status ImporterBase::ConvertDeferredFunctions() {
  // Drain the queue; converting one function may enqueue more (handled at the
  // bottom of the loop body).
  while (!deferred_functions_.empty()) {
    auto conversion_metadata = deferred_functions_.front();
    deferred_functions_.pop();

    const FunctionDef* func_def =
        graph_flib_.Find(conversion_metadata.function_name);
    // NOTE(review): Find() returns nullptr for unknown names and func_def is
    // dereferenced unchecked below — presumably every deferred name was taken
    // from this library, so it always exists; confirm.
    // Converts the graph to an MLIR function and adds it to the module.
    // We populate the NodeSpec so that all the _Arg ops get their shape
    // added correctly.
    GraphImportConfig specs;
    specs.enable_shape_inference = specs_.enable_shape_inference;
    specs.unconditionally_use_set_output_shapes =
        specs_.unconditionally_use_set_output_shapes;
    for (const auto& name_and_value : func_def->attr()) {
      if (name_and_value.first == "_input_shapes") {
        auto& list = name_and_value.second.list();
        auto& signature = func_def->signature();
        // Some models have "_input_shapes" attribute, but with its value empty
        if (list.shape_size() > 0 &&
            list.shape_size() != signature.input_arg_size()) {
          return errors::FailedPrecondition(
              "Number of input arguments must be equal to the length of "
              "_input_shapes attribute in function '",
              StringRefToView(conversion_metadata.function_name), "'.");
        }
        for (int i = 0, e = signature.input_arg_size(); i < e; i++) {
          auto& input_arg = signature.input_arg(i);
          auto& array_info = specs.inputs[input_arg.name()];
          array_info.imported_dtype = input_arg.type();
          // set to unranked for empty "_input_shapes" attribute
          if (list.shape_size() > 0)
            array_info.shape = list.shape(i);
          else
            array_info.shape.set_unknown_rank(true);
        }
      }
    }

    // Nested importer shares the module, name maps and uniquifier so that the
    // converted function lands in the same module with a consistent name.
    ImporterBase importer(graph_flib_, debug_info_, specs, module_,
                          tf_name_to_mlir_name_, function_name_uniquifier_,
                          conversion_metadata.function_name);

    std::unique_ptr<FunctionBody> fbody;
    TF_RETURN_IF_ERROR(
        FunctionDefToBodyHelper(*func_def, AttrSlice(), &graph_flib_, &fbody));
    TF_RETURN_IF_ERROR(importer.PrepareConvert(*fbody->graph));

    TF_ASSIGN_OR_RETURN(auto func_type, importer.InferLibFunctionType(*fbody));

    absl::InlinedVector<OutputTensor, 4> arg_nodes;
    absl::InlinedVector<OutputTensor, 4> ret_nodes;
    absl::InlinedVector<Node*, 4> control_ret_nodes;
    importer.GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes,
                                            &control_ret_nodes);
    const std::string& mlir_func_name =
        (*tf_name_to_mlir_name_)[conversion_metadata.function_name];

    TF_RETURN_IF_ERROR(importer.Convert(mlir_func_name, func_type, arg_nodes,
                                        ret_nodes, control_ret_nodes,
                                        conversion_metadata.attributes));

    // Additional function bodies could be discovered during the deferred
    // loading of the current function. Add them to the working queue.
    while (!importer.deferred_functions_.empty()) {
      deferred_functions_.push(importer.deferred_functions_.front());
      importer.deferred_functions_.pop();
    }
  }

  return OkStatus();
}
669 
RemoveBackedges()670 Status ImporterBase::RemoveBackedges() {
671   // Remove all the backedges. So the nodes can be added to the shape refiner.
672   TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get()));
673   VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size())
674           << " backedges.";
675 
676   // Creates a map for quickly identifying whether a node output is a backedge.
677   for (const auto& edge : back_edge_helper_.RemovedEdges()) {
678     if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() &&
679         back_edge_node_output_[edge.src] != edge.src_output) {
680       return errors::FailedPrecondition(
681           "More than one of the src node outputs are backedges!");
682     }
683     back_edge_node_output_[edge.src] = edge.src_output;
684     // We expect a merge to receive a single backedge (multiple NextIteration
685     // nodes feeding into the same merge is unexpected here).
686     DCHECK(!back_edge_dst_inputs_.contains(edge.dst));
687     back_edge_dst_inputs_[edge.dst] = edge;
688   }
689 
690   // Obtains a RPO ordering, using node names as a tiebreak for stable sorting.
691   GetReversePostOrder(
692       *graph_, &ordered_nodes_,
693       [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
694   return OkStatus();
695 }
696 
CopyStackTraces(const Graph & from,Graph * to)697 Status CopyStackTraces(const Graph& from, Graph* to) {
698   // Copy over the stack traces.
699   // TODO(jpienaar): This really shouldn't be needed, copying the Graph above
700   // and then needing these traversals is unfortunate.
701   std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex();
702   for (Node* node : to->nodes()) {
703     if (const Node* old_node = node_map[node->name()]) {
704       if (const std::shared_ptr<AbstractStackTrace>& stack =
705               old_node->GetStackTrace()) {
706         DVLOG(2) << "Stack for " << node->name() << " "
707                  << old_node->GetStackTrace()->ToString(
708                         AbstractStackTrace::TracePrintingOptions());
709         node->SetStackTrace(stack);
710       } else {
711         DVLOG(1) << "No stack for " << node->name() << " (" << node
712                  << ") in Graph " << &from;
713       }
714     } else {
715       DVLOG(1) << "No stack for " << node->name() << " (" << node
716                << ") in Graph " << &from;
717     }
718   }
719 
720   return OkStatus();
721 }
722 
// Creates a Placeholder node (with the given shape/dtype) that stands in for
// output `index` of `node` when that output is used as a feed. Returns the
// new Placeholder and a bool indicating whether `node` itself was removed
// from the graph — which happens only when `node` has a single output, so the
// Placeholder can replace it wholesale under the same name.
StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
    const TensorShapeProto& shape, DataType dtype, Node* node, int index,
    const std::unordered_map<string, Node*>& node_name_map) {
  DCHECK_LT(index, node->num_outputs());
  // With exactly one output, the original node can be replaced in place;
  // otherwise only the edges for output `index` are redirected and the
  // original node remains in the graph.
  const bool update_inplace = node->num_outputs() == 1 && index == 0;
  // Reuse the original name when replacing in place; otherwise synthesize a
  // unique name so both nodes can coexist in the graph.
  std::string new_node_name =
      update_inplace ? node->name()
                     : GetUniqueNodeName(node->name(), index, node_name_map);

  Node* placeholder_node;
  NodeBuilder builder(new_node_name, "Placeholder");
  builder.Attr("shape", shape);
  builder.Attr("dtype", dtype);
  TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));

  // Update edges from original feed with Placeholder node.
  std::vector<const Edge*> data_edges;
  std::vector<const Edge*> control_edges;
  for (const tensorflow::Edge* edge : node->out_edges()) {
    if (edge->src_output() == index) {
      data_edges.push_back(edge);
    } else if (update_inplace && edge->IsControlEdge()) {
      // Control edges are only carried over when the original node is being
      // removed; otherwise it keeps its own control edges.
      control_edges.push_back(edge);
    }
  }

  for (const auto* edge : data_edges) {
    TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(),
                                          edge->dst_input()));
  }

  // TODO(lyandy): Preserve control dependencies properly by not forwarding
  // control dependencies to data outputs and not removing single output nodes.
  // When a data output is replaced as a feed, unless there is another non feed
  // data output or an explicit control output used by the same node, transitive
  // control dependencies are not to be executed. For single output nodes,
  // Placeholders can be converted to a NoOp if there are no uses, and
  // PlaceholderWithDefault can be converted to an Identity.
  for (const auto* edge : control_edges) {
    graph_->AddControlEdge(placeholder_node, edge->dst());
    graph_->RemoveControlEdge(edge);
  }

  if (update_inplace) {
    graph_->RemoveNode(node);
  }

  return std::pair<Node*, bool>(placeholder_node, update_inplace);
}
772 
GetInputOutputNodes(const std::unordered_map<string,Node * > & node_name_map,std::unordered_set<const Node * > * nodes)773 Status ImporterBase::GetInputOutputNodes(
774     const std::unordered_map<string, Node*>& node_name_map,
775     std::unordered_set<const Node*>* nodes) {
776   auto add_node = [&](absl::string_view name) {
777     auto it = node_name_map.find(std::string(name));
778     if (it == node_name_map.end()) {
779       return errors::FailedPrecondition(
780           absl::StrCat("Graph does not contain node: ", name));
781     }
782     nodes->insert(it->second);
783     return OkStatus();
784   };
785 
786   // Remap feeds and fetches to newly created Placeholder nodes.
787   for (const auto& input : specs_.inputs) {
788     TensorId tensor = ParseTensorName(input.first);
789     auto remapped_it = remapped_feeds_.find(tensor);
790     if (remapped_it != remapped_feeds_.end()) {
791       TF_RETURN_IF_ERROR(add_node(remapped_it->second));
792     } else {
793       TF_RETURN_IF_ERROR(add_node(tensor.node()));
794     }
795   }
796 
797   for (const auto& output : specs_.outputs) {
798     TensorId tensor = ParseTensorName(output);
799     auto remapped_it = remapped_feeds_.find(tensor);
800     if (remapped_it != remapped_feeds_.end()) {
801       TF_RETURN_IF_ERROR(add_node(remapped_it->second));
802     } else {
803       TF_RETURN_IF_ERROR(add_node(tensor.node()));
804     }
805   }
806 
807   for (const auto& control_output : specs_.control_outputs)
808     TF_RETURN_IF_ERROR(add_node(control_output));
809 
810   return OkStatus();
811 }
812 
// TODO(jpienaar): Remove this post shape inference on import flag is removed.
// Runs legacy on-import shape inference: adds every node of `graph_` to a
// ShapeRefiner (first rewriting feeds into Placeholder nodes carrying the
// user-specified type/shape), seeds _Arg output shapes, optionally prunes
// nodes unreachable from the requested outputs, and then iterates shape
// refinement until shapes stop changing (bounded by kMaxIterationCount).
Status ImporterBase::AddNodesToShapeRefiner(
    std::unordered_map<string, Node*>* node_name_map) {
  shape_refiner_ = std::make_unique<ShapeRefiner>(graph_->versions(),
                                                   graph_->op_registry());
  // Some operations (for example "TPUExecute") don't have shape inference
  // function defined, so we should set this to false for adding nodes with
  // these types of operations.
  shape_refiner_->set_require_shape_inference_fns(false);
  shape_refiner_->set_function_library_for_shape_inference(&graph_flib_);

  // Groups requested feeds by the node whose outputs they target.
  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));

  // First add all nodes to the refiner.
  for (Node* node : ordered_nodes_) {
    // We need to use a TensorFlow node to teach the shape refiner that user
    // specifies certain data type and shape for the inputs in the `specs_`.
    // This node shouldn't have any inputs, only have one output and its
    // output type/shape is only determined by its "named" attributes. (The
    // attributes should have fixed names so we can use the info from `specs_`
    // to set the value of them.) `Placeholder` satisfies these constraints.
    //
    // Therefore, if the input node isn't a `Placeholder`, we create one and use
    // it to replace the original input node, so the shape refiner can
    // successfully propagate the user's input type and shape to the rest of the
    // graph.
    bool node_added_to_shape_refiner = false;
    auto it = feeds_by_node.find(node->name());
    if (it != feeds_by_node.end()) {
      auto op_name = node->op_def().name();
      if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
          op_name != FunctionLibraryDefinition::kArgOp) {
        // Not Placeholder-like: create one Placeholder per fed output.
        for (const auto& output_tensor : it->second) {
          const int index = output_tensor.first;
          const ArrayInfo& array_info = output_tensor.second->second;

          DataType dtype = array_info.imported_dtype;
          // Uses the existing output type if it isn't specified by the user.
          if (dtype == DT_INVALID) {
            dtype = node->output_type(index);
          }

          TF_ASSIGN_OR_RETURN(
              auto placeholder_node_and_removed,
              CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
                                           *node_name_map));

          Node* placeholder_node = placeholder_node_and_removed.first;
          if (placeholder_node_and_removed.second) {
            // Original node has been removed from the graph.
            node = placeholder_node;
            node_added_to_shape_refiner = true;
          }
          remapped_feeds_[{it->first, index}] = placeholder_node->name();
          (*node_name_map)[placeholder_node->name()] = placeholder_node;
          // Add the new placeholder node to the shape refiner.
          Status status = shape_refiner_->AddNode(placeholder_node);
          if (!status.ok()) {
            return EmitErrorWithLocationStr(*placeholder_node, status);
          }
        }
      } else {
        // Already Placeholder-like: overwrite its shape/dtype attributes from
        // the feed spec instead of replacing the node.
        auto index_it = it->second.find(0);
        if (index_it == it->second.end()) {
          return errors::FailedPrecondition(
              "Missing feed output tensor at index 0 for node '", node->name(),
              "'");
        }
        node->AddAttr("shape", index_it->second->second.shape);
        DataType dtype = index_it->second->second.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(0);
        }
        node->AddAttr("dtype", dtype);
      }
    }
    if (!node_added_to_shape_refiner) {
      // Add the node to the shape refiner if the node hasn't been removed.
      Status status = shape_refiner_->AddNode(node);
      if (!status.ok()) {
        return EmitErrorWithLocationStr(*node, status);
      }
    }

    // Sets this node's output shapes from a shape-list attribute value.
    auto set_shape_from_list_attr = [&](const AttrValue* attr) {
      auto& list = attr->list();
      // This follows the same approach as in ValidateShape, but only flags
      // warning in case where there are mismatch in number of shapes and
      // outputs and in which case it just returns without attempting to refine.
      if (list.shape_size() != node->num_outputs()) {
        LOG(WARNING) << "Node '" << node->name() << "' has "
                     << node->num_outputs() << " outputs but the "
                     << kOutputShapesAttrName
                     << " attribute specifies shapes for " << list.shape_size()
                     << " outputs";
        return OkStatus();
      }

      for (const auto& shape : llvm::enumerate(list.shape())) {
        auto* node_context = shape_refiner_->GetContext(node);
        shape_inference::ShapeHandle handle;
        Status status =
            node_context->MakeShapeFromShapeProto(shape.value(), &handle);
        if (!status.ok()) {
          return EmitErrorWithLocationStr(*node, status);
        }
        node_context->set_output(shape.index(), handle);
      }
      return OkStatus();
    };

    // If it is the argument node, the shape handle is set explicitly, so it
    // can be propagated to the body nodes of the function.
    if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) {
      auto* node_context = shape_refiner_->GetContext(node);
      DCHECK(node_context != nullptr);
      if (const AttrValue* attr = node->attrs().Find("shape")) {
        shape_inference::ShapeHandle handle;
        Status status =
            node_context->MakeShapeFromShapeProto(attr->shape(), &handle);
        if (!status.ok()) {
          return EmitErrorWithLocationStr(*node, status);
        }
        node_context->set_output(0, handle);
      } else if (const AttrValue* attr =
                     node->attrs().Find(kOutputShapesAttrName)) {
        TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
      } else {
        // No shape information available; leave the output unconstrained.
        node_context->set_output(0, node_context->UnknownShape());
      }
    }

    // Following GraphConstructor::ValidateShape called from
    // GraphConstructor::Convert, override the shape if _output_shapes is set.
    if (specs_.unconditionally_use_set_output_shapes ||
        node->op_def().name() == "ReadVariableOp") {
      if (const AttrValue* attr = node->attrs().Find(kOutputShapesAttrName))
        TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
    }
  }

  // Since we might have inserted and removed nodes from the graph, fix
  // source/sink edges and reconstruct the RPO ordering of nodes
  FixupSourceAndSinkEdges(graph_.get());

  // Prune nodes in the graph that are not reachable from the output.
  if (specs_.prune_unused_nodes) {
    std::unordered_set<const Node*> prune_start;
    TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));
    if (!prune_start.empty()) {
      if (PruneForReverseReachability(graph_.get(), prune_start)) {
        VLOG(1) << "Pruned unused nodes in graphdef";
      } else {
        VLOG(1) << "No unused nodes in graphdef to prune";
      }
    } else {
      VLOG(1) << "No output nodes specified, skipping pruning";
    }
  } else {
    VLOG(1) << "Pruning unused nodes in graphdef is disabled";
  }

  // Re-initialize ordered_nodes_ since we might have modified the graph.
  GetReversePostOrder(
      *graph_, &ordered_nodes_,
      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });

  VLOG(1) << "Inferring graph shapes to fixpoint";

  // The "changed" information from UpdateNode can give false positives, so we
  // create a dedicated method to verify the shapes are not changed before and
  // after the shape refine.
  auto same_inferred_shape = [](shape_inference::InferenceContext* c,
                                shape_inference::ShapeHandle s0,
                                shape_inference::ShapeHandle s1) -> bool {
    if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) {
      return true;
    }
    if (c->Rank(s0) != c->Rank(s1)) {
      return false;
    }
    for (int i = 0; i < c->Rank(s0); ++i) {
      if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
        int64_t val0 = c->Value(c->Dim(s0, i));
        int64_t val1 = c->Value(c->Dim(s1, i));
        // Negative value is treated as unknown so all negative values indicate
        // the same dimension.
        if (val0 >= 0 && val1 >= 0 && val0 != val1) return false;
      }
    }
    return true;
  };

  // Iterate refinement over all nodes until no output shape changes, capped
  // at kMaxIterationCount full passes over the graph.
  bool changed = true;
  int i = 0;
  const int kMaxIterationCount = 2;
  while (changed && i != kMaxIterationCount) {
    changed = false;
    for (const Node* node : ordered_nodes_) {
      auto* shape_context = shape_refiner_->GetContext(node);
      DCHECK(shape_context != nullptr);
      // Snapshot the current output shapes to compare against after update.
      absl::InlinedVector<shape_inference::ShapeHandle, 4> existing;
      existing.reserve(shape_context->num_outputs());
      for (int o = 0; o < shape_context->num_outputs(); ++o) {
        existing.push_back(shape_context->output(o));
      }
      bool inferred = false;
      shape_inference::ShapeHandle handle;
      Status status =
          shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred);
      if (!status.ok()) {
        return EmitErrorWithLocationStr(*node, status);
      }
      for (int o = 0; o < shape_context->num_outputs(); ++o) {
        if (!same_inferred_shape(shape_context, shape_context->output(o),
                                 existing[o])) {
          changed = true;
          break;
        }
      }
    }
    ++i;
  }
  if (i >= kMaxIterationCount) {
    LOG(WARNING) << "Graph shapes did not converge to a fixpoint within "
                 << kMaxIterationCount
                 << " iterations. Graph shapes may be conservative.";
  }
  VLOG(1) << "Graph shapes were inferred with " << (i - 1)
          << " extra rounds of analysis to reach a fixpoint.";
  return OkStatus();
}
1046 
InferInputType(const Node & node,int idx,mlir::Builder builder)1047 StatusOr<mlir::Type> ImporterBase::InferInputType(const Node& node, int idx,
1048                                                   mlir::Builder builder) {
1049   if (specs_.enable_shape_inference) {
1050     // TODO(jpienaar): Remove this if shape inference on import flag is removed.
1051     ExtendedInferenceContext* shape_context =
1052         shape_refiner_->GetExtendedContext(&node);
1053     DataType dtype = shape_context->input_type(idx);
1054     auto* context = shape_context->get_context();
1055     return ConvertDataTypeAndShape(dtype, context->input(idx),
1056                                    context->input_handle_shapes_and_types(idx),
1057                                    context, builder);
1058   }
1059   DataType dtype = node.properties()->input_types[idx];
1060   mlir::Type element_type;
1061   TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));
1062   return mlir::UnrankedTensorType::get(element_type);
1063 }
1064 
// Infers the MLIR type of output `idx` of `node`. Uses the legacy shape
// refiner when on-import shape inference is enabled; otherwise derives the
// type from op attributes, special-casing TensorList init ops, While,
// iterator/infeed dequeue ops and _Arg nodes, and finally falls back to
// running the op's registered shape function (input-free ops only) or to a
// conservative unranked tensor type.
StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
                                                   mlir::Builder builder) {
  DataType dtype = node.properties()->output_types[idx];

  // Returns output type given inference context.
  auto shape_ic =
      [&](shape_inference::InferenceContext* c) -> StatusOr<mlir::Type> {
    // TODO(b/200093974): Post triage, consider following
    // GraphConstructor::ValidateShape in checking _output_shapes always.
    if (specs_.unconditionally_use_set_output_shapes) {
      if (const AttrValue* attr = node.attrs().Find(kOutputShapesAttrName)) {
        auto& list = attr->list();
        if (list.shape_size() > idx) {
          const TensorShapeProto& p = list.shape()[idx];
          shape_inference::ShapeHandle h;
          Status s = c->MakeShapeFromShapeProto(p, &h);
          if (!s.ok())
            return errors::InvalidArgument(
                "Node '", node.name(), " has an invalid ",
                kOutputShapesAttrName, " attribute (shape #", idx, " error:'",
                s.error_message(), "')");
          c->set_output(idx, h);
        }
      }
    }

    return ConvertDataTypeAndShape(dtype, c->output(idx),
                                   c->output_handle_shapes_and_types(idx), c,
                                   builder);
  };

  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove this if shape inference on import flag is removed.
    ExtendedInferenceContext* shape_context =
        shape_refiner_->GetExtendedContext(&node);
    return shape_ic(shape_context->get_context());
  }

  // Treat TensorList init ops specially here as the op requires knowing its
  // element dtype.
  // TODO(jpienaar): Reconsider post refactoring shape functions.
  if (node.type_string() == "TensorListReserve" ||
      node.type_string() == "EmptyTensorList") {
    mlir::Type etype;
    if (auto element_dtype = node.attrs().Find("element_dtype")) {
      TF_RETURN_IF_ERROR(
          ConvertDataType(element_dtype->type(), builder, &etype));
    }
    return mlir::RankedTensorType::get(
        {}, mlir::TF::VariantType::get({mlir::UnrankedTensorType::get(etype)},
                                       etype.getContext()));
  }

  if (node.IsWhileNode()) {
    // While carries per-output shapes in "output_shapes" and element types in
    // "T"; use them when the shape list is present and non-empty.
    auto* output_shapes = node.attrs().Find("output_shapes");
    auto* element_types = node.attrs().Find("T");
    if (output_shapes && !output_shapes->list().shape().empty()) {
      const auto& output_shape = output_shapes->list().shape(idx);
      const auto& element_type = element_types->list().type(idx);
      return ConvertToMlirTensorType(output_shape, element_type, &builder);
    }
  }

  // Builds a tensor type from a pair of list attributes holding the
  // per-output shapes and element types.
  auto type_from_array_attr = [&node, &idx, &builder](
                                  absl::string_view output_shape_attr,
                                  absl::string_view element_type_attr) {
    auto* output_shapes = node.attrs().Find(output_shape_attr);
    auto* element_types = node.attrs().Find(element_type_attr);
    const auto& output_shape = output_shapes->list().shape(idx);
    const auto& element_type = element_types->list().type(idx);
    return ConvertToMlirTensorType(output_shape, element_type, &builder);
  };

  if (node.type_string() == "IteratorGetNext" ||
      node.type_string() == "IteratorGetNextSync" ||
      node.type_string() == "MultiDeviceIteratorGetNextFromShard")
    return type_from_array_attr("output_shapes", "output_types");

  if (node.type_string() == "InfeedDequeueTuple")
    return type_from_array_attr("shapes", "dtypes");

  if (node.type_string() == "InfeedDequeue") {
    assert(idx == 0);
    const auto& output_shape = node.attrs().Find("shape")->shape();
    const auto& element_type = node.attrs().Find("dtype")->type();
    return ConvertToMlirTensorType(output_shape, element_type, &builder);
  }

  // Returns a simple, more conservative unranked tensor type.
  auto default_type = [&]() -> StatusOr<mlir::Type> {
    mlir::Type element_type;
    TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));

    // TODO(b/200093974): Post triage, consider following
    // GraphConstructor::ValidateShape in checking _output_shapes.
    if (specs_.unconditionally_use_set_output_shapes) {
      if (const AttrValue* attr = node.attrs().Find(kOutputShapesAttrName)) {
        auto& list = attr->list();
        if (list.shape_size() > idx) {
          llvm::SmallVector<int64_t, 4> shape;
          const TensorShapeProto& shape_proto = list.shape()[idx];
          if (shape_proto.unknown_rank())
            return mlir::UnrankedTensorType::get(element_type);
          TF_RETURN_IF_ERROR(ConvertToMlirShape(shape_proto, &shape));
          return mlir::RankedTensorType::get(shape, element_type);
        }
      }
    }

    return mlir::UnrankedTensorType::get(element_type);
  };

  // Below we only try and do some shape inference for "source" ops which have
  // no inputs.
  if (node.num_inputs() > 0) return default_type();

  // Do some simply inference here to get the function arguments correct for
  // this common case.
  // TODO(jpienaar): Reconsider post refactoring shape functions.
  if (node.IsArg()) {
    if (dtype == DT_RESOURCE) {
      // Resource arguments may encode the underlying variable's dtype/shape
      // via the "_handle_dtypes"/"_handle_shapes" attributes.
      const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes");
      const AttrValue* shape_attr = node.attrs().Find("_handle_shapes");
      if (dtype_attr && shape_attr) {
        if (dtype_attr->list().type().empty()) {
          return errors::InvalidArgument(
              "Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
              shape_attr->DebugString());
        }
        if (shape_attr->list().shape().empty()) {
          return errors::InvalidArgument(
              "Invalid \"_handle_shapes\" attribute value for _Arg node: ",
              shape_attr->DebugString());
        }
        DataType dtype = dtype_attr->list().type(0);
        const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
        TF_ASSIGN_OR_RETURN(
            auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder));
        return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get(
            {etype.cast<TensorType>()}, builder.getContext()));
      } else {
        return mlir::UnrankedTensorType::get(
            mlir::TF::ResourceType::get(builder.getContext()));
      }
    } else if (auto shape = node.attrs().Find("_output_shapes")) {
      if (shape->has_list() && shape->list().shape_size() == 1) {
        return ConvertToMlirTensorType(shape->list().shape().at(0), dtype,
                                       &builder);
      }
    }
  }

  const tensorflow::OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(
      graph_->op_registry()->LookUp(node.type_string(), &op_reg_data));
  if (!op_reg_data) {
    DVLOG(1) << "Skipping inference for unregistered op " << node.type_string();
    return default_type();
  }
  if (op_reg_data->shape_inference_fn == nullptr) {
    DVLOG(1) << "Skipping inference for op without shape function "
             << node.type_string();
    return default_type();
  }
  // Run the registered shape function with no inputs (only input-free ops
  // reach this point) and convert the resulting output shape.
  shape_inference::InferenceContext c(graph_->versions().producer(),
                                      node.attrs(), op_reg_data->op_def,
                                      std::vector<PartialTensorShape>{}, {},
                                      /*input_tensors_as_shapes=*/{}, {});
  TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
  return shape_ic(&c);
}
1236 
ConvertDataTypeAndShape(DataType dtype,const shape_inference::ShapeHandle & handle,const std::vector<shape_inference::ShapeAndType> * handle_subtypes,shape_inference::InferenceContext * context,mlir::Builder builder)1237 StatusOr<TensorType> ImporterBase::ConvertDataTypeAndShape(
1238     DataType dtype, const shape_inference::ShapeHandle& handle,
1239     const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
1240     shape_inference::InferenceContext* context, mlir::Builder builder) {
1241   TF_ASSIGN_OR_RETURN(auto subtypes,
1242                       ConvertSubtypes(handle_subtypes, context, builder));
1243 
1244   mlir::Type element_type;
1245   if (dtype == DT_VARIANT)
1246     element_type = mlir::TF::VariantType::get(subtypes, context_);
1247   else if (dtype == DT_RESOURCE)
1248     element_type = mlir::TF::ResourceType::get(subtypes, context_);
1249   else
1250     TF_RETURN_IF_ERROR(
1251         ::tensorflow::ConvertDataType(dtype, builder, &element_type));
1252 
1253   return ConvertElementTypeAndShape(element_type, handle, context, builder);
1254 }
1255 
ConvertElementTypeAndShape(mlir::Type element_type,const shape_inference::ShapeHandle & handle,shape_inference::InferenceContext * context,mlir::Builder builder)1256 StatusOr<TensorType> ImporterBase::ConvertElementTypeAndShape(
1257     mlir::Type element_type, const shape_inference::ShapeHandle& handle,
1258     shape_inference::InferenceContext* context, mlir::Builder builder) {
1259   if (!context->RankKnown(handle)) {
1260     return mlir::UnrankedTensorType::get(element_type);
1261   }
1262 
1263   // Sentinel for an unknown dimension size. getTensorType interprets any
1264   // negative value as an unknown dimension.
1265   // TODO(jmolloy): Ideally this shouldn't be a local sentinel.
1266   const int64_t kUnknownDim = -1;
1267 
1268   absl::InlinedVector<int64_t, 4> dimensions;
1269   int32_t rank = context->Rank(handle);
1270   dimensions.reserve(rank);
1271   for (int i = 0; i < rank; ++i) {
1272     auto dim_handle = context->Dim(handle, i);
1273     if (!context->ValueKnown(dim_handle))
1274       dimensions.push_back(kUnknownDim);
1275     else
1276       dimensions.push_back(context->Value(dim_handle));
1277   }
1278 
1279   return mlir::RankedTensorType::get(
1280       llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type);
1281 }
1282 
ConvertSubtypes(const std::vector<shape_inference::ShapeAndType> * handle_subtypes,shape_inference::InferenceContext * context,mlir::Builder builder)1283 StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
1284     const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
1285     shape_inference::InferenceContext* context, mlir::Builder builder) {
1286   ElementSubtypes subtypes;
1287   if (!handle_subtypes) return subtypes;
1288 
1289   subtypes.reserve(handle_subtypes->size());
1290   for (const auto& subtype : *handle_subtypes) {
1291     mlir::Type element_type;
1292     TF_RETURN_IF_ERROR(
1293         ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type));
1294     TF_ASSIGN_OR_RETURN(TensorType type,
1295                         ConvertElementTypeAndShape(element_type, subtype.shape,
1296                                                    context, builder));
1297     subtypes.push_back(type);
1298   }
1299   return subtypes;
1300 }
1301 
ConvertFunctionCallAttribute(const std::string & base_name,const AttrValue & value,NamedAttrList * attributes)1302 Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
1303                                                   const AttrValue& value,
1304                                                   NamedAttrList* attributes) {
1305   TF_ASSIGN_OR_RETURN(auto func_attr,
1306                       ConvertFunctionCallName(value.func().name()));
1307   if (!func_attr) return OkStatus();
1308   attributes->push_back(builder_.getNamedAttr(base_name, func_attr));
1309 
1310   for (const auto& it : value.func().attr()) {
1311     auto name = absl::StrCat(base_name, ".", it.first);
1312     TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second));
1313     attributes->push_back(builder_.getNamedAttr(name, value));
1314   }
1315   return OkStatus();
1316 }
1317 
ConvertFunctionCallName(const std::string & func_name)1318 StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName(
1319     const std::string& func_name) {
1320   // Some ops like XlaHostCompute op uses empty value to represent missing
1321   // functions. Such attribute values should be defined optional in MLIR
1322   // definition.
1323   if (func_name.empty()) return mlir::FlatSymbolRefAttr();
1324 
1325   TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
1326   auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
1327   return mlir::SymbolRefAttr::get(builder_.getContext(), mlir_func_name);
1328 }
1329 
// Converts a TensorFlow AttrValue proto into an MLIR attribute.
//
// Function-valued attributes (kFunc) and lists of functions get special
// handling so callee names become symbol references; every other case is
// delegated to ConvertNonFuncAttributeValue.
StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
    const AttrValue& value) {
  switch (value.value_case()) {
    case AttrValue::kFunc: {
      // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
      // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
      // will not use this representation. This also doesn't handle empty
      // function values like ConvertFunctionCallName method.
      // Recursively convert the call's attributes into a dictionary, then
      // wrap the callee name and dictionary in a TF::FuncAttr.
      NamedAttrList attrs;
      for (const auto& func_attr : value.func().attr()) {
        TF_ASSIGN_OR_RETURN(
            auto attr, ImporterBase::ConvertAttributeValue(func_attr.second));
        attrs.push_back(builder_.getNamedAttr(func_attr.first, attr));
      }
      auto func_attrs = builder_.getDictionaryAttr(attrs);
      return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs);
    }
    case AttrValue::kList: {
      if (!value.list().func().empty()) {
        // A list of functions converts to an ArrayAttr of symbol references.
        // Note: ConvertFunctionCallName is called before the attr_size check,
        // so the referenced function is registered for import either way.
        absl::InlinedVector<mlir::Attribute, 8> attrs;
        for (const auto& item : value.list().func()) {
          TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
          // Per-call attributes are not representable in this form yet.
          if (item.attr_size() != 0)
            return errors::Unimplemented(
                "func attributes with non-zero attr.size()");
          // Empty function names convert to a null attribute and are dropped.
          if (attr) attrs.push_back(attr);
        }
        return builder_.getArrayAttr(
            llvm::makeArrayRef(attrs.begin(), attrs.end()));
      }
      // Lists without functions use the generic conversion.
      return ConvertNonFuncAttributeValue(value, &builder_);
    }
    default:
      return ConvertNonFuncAttributeValue(value, &builder_);
  }
}
1366 
GetArgsAndRetsFromFunctionBody(const FunctionBody & fbody,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes,absl::InlinedVector<Node *,4> * control_ret_nodes)1367 void ImporterBase::GetArgsAndRetsFromFunctionBody(
1368     const FunctionBody& fbody, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
1369     absl::InlinedVector<OutputTensor, 4>* ret_nodes,
1370     absl::InlinedVector<Node*, 4>* control_ret_nodes) {
1371   arg_nodes->reserve(fbody.arg_nodes.size());
1372   ret_nodes->reserve(fbody.ret_nodes.size());
1373   for (auto arg : fbody.arg_nodes) {
1374     arg_nodes->emplace_back(arg, 0);
1375   }
1376   for (auto ret : fbody.ret_nodes) {
1377     ret_nodes->emplace_back(ret, 0);
1378   }
1379   *control_ret_nodes = fbody.control_ret_nodes;
1380 }
1381 
// Registers the library function `func_name` for deferred conversion:
// reserves its unique MLIR name, converts its definition attributes, and
// (transitively) imports its gradient function if one is registered.
// Idempotent: calling it again for an already-seen name is a no-op.
Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
  // If the library function has been converted already, nothing needs to be
  // done.
  if (tf_name_to_mlir_name_->find(std::string(func_name)) !=
      tf_name_to_mlir_name_->end())
    return OkStatus();

  // Reserve a unique MLIR name and register the mapping up front. Because the
  // "already converted" check above keys off this map, registering first also
  // terminates recursive imports (e.g. cycles through gradient functions).
  std::string mlir_func_name(
      function_name_uniquifier_->GetUniqueName(func_name));
  (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name;

  const auto& func_lib = graph_flib_;
  const auto* func_def = func_lib.Find(std::string(func_name));
  if (func_def == nullptr) {
    return errors::FailedPrecondition(
        absl::StrCat("Failed to find function '", StringRefToView(func_name),
                     "'. The imported TensorFlow GraphDef is ill-formed."));
  }

  // Converts the argument and return types to MLIR types.
  std::vector<mlir::NamedAttribute> attributes;
  attributes.reserve(func_def->attr_size());
  for (const auto& name_and_value : func_def->attr()) {
    // This is a function definition attribute, so it shouldn't contain
    // kFunc attribute and it is treated as normal one.
    TF_ASSIGN_OR_RETURN(auto attr,
                        ConvertAttributeValue(name_and_value.second));
    // Mangle the TF attribute name so it does not collide with MLIR names.
    std::string attr_name =
        mangling_util::MangleAttributeName(name_and_value.first);
    attributes.push_back(builder_.getNamedAttr(attr_name, attr));
  }

  // Checks opdef stateful attribute and import that as Function Attribute
  if (func_def->signature().is_stateful()) {
    auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
    attributes.push_back(
        builder_.getNamedAttr(stateful_str, builder_.getUnitAttr()));
  }

  // Checks for an associated custom gradient function. Adds it to the attribute
  // list of this function.
  auto grad_func_name = func_lib.FindGradient(std::string(func_name));
  if (!grad_func_name.empty()) {
    // Import the gradient first so its uniquified MLIR name is available.
    TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
    auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
    auto gradient_attr =
        mlir::SymbolRefAttr::get(builder_.getContext(), mlir_grad_func_name);
    auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
    attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
  }

  // Only the name and attributes are recorded here; the function body is
  // converted later when the deferred queue is drained.
  deferred_functions_.emplace(func_name.str(), attributes);
  return OkStatus();
}
1436 
PruneUnreachableNodes(std::unordered_map<string,Node * > * node_name_map)1437 Status ImporterBase::PruneUnreachableNodes(
1438     std::unordered_map<string, Node*>* node_name_map) {
1439   std::unordered_set<const Node*> prune_start;
1440   TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));
1441 
1442   if (!prune_start.empty()) {
1443     if (PruneForReverseReachability(graph_.get(), prune_start)) {
1444       VLOG(1) << "Pruned unused nodes in graphdef";
1445     } else {
1446       VLOG(1) << "No unused nodes in graphdef to prune";
1447     }
1448   } else {
1449     VLOG(1) << "No output nodes specified, skipping pruning";
1450   }
1451   return OkStatus();
1452 }
1453 
// Rewrites each feed in specs_.inputs into a dedicated single-output
// Placeholder node, so later conversion only deals with whole-node feeds.
// Updates `node_name_map` and remapped_feeds_ with the new nodes.
Status ImporterBase::ConvertFeedsToPlaceholders(
    std::unordered_map<string, Node*>* node_name_map) {
  // Feeds (edges) are converted into single-output placeholder nodes to
  // simplify the conversion process.
  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));
  for (const auto& it : feeds_by_node) {
    TensorId tensor = ParseTensorName(it.first);
    auto jt = node_name_map->find(std::string(tensor.node()));
    if (jt == node_name_map->end()) {
      return errors::FailedPrecondition(
          absl::StrCat("Graph does not contain node: ", tensor.node()));
    }

    Node* node = jt->second;
    auto op_name = node->op_def().name();
    // Nodes that already act as inputs (placeholders / function args) are fed
    // as-is and need no rewrite.
    if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
        op_name != FunctionLibraryDefinition::kArgOp) {
      for (const auto& output_tensor : it.second) {
        const int index = output_tensor.first;
        const ArrayInfo& array_info = output_tensor.second->second;

        DataType dtype = array_info.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(index);
        }

        TF_ASSIGN_OR_RETURN(
            auto placeholder_node_and_removed,
            CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
                                         *node_name_map));

        Node* placeholder_node = placeholder_node_and_removed.first;
        // Keep the new node connected to source/sink so graph traversals and
        // pruning still reach it.
        if (placeholder_node->in_edges().empty()) {
          graph_->AddControlEdge(graph_->source_node(), placeholder_node,
                                 true /* skip test for duplicates */);
        }
        if (placeholder_node->out_edges().empty()) {
          graph_->AddControlEdge(placeholder_node, graph_->sink_node(),
                                 true /* skip test for duplicates */);
        }
        // Record the feed -> placeholder mapping for later lookups.
        remapped_feeds_[{it.first, index}] = placeholder_node->name();
        (*node_name_map)[placeholder_node->name()] = placeholder_node;
      }
    }
  }
  return OkStatus();
}
1502 
// Clones `graph` into graph_ (round-tripping through GraphDef), removes
// back-edges, rewrites feeds (or runs shape inference), optionally prunes
// unreachable nodes, and computes the node ordering used by Convert().
// `graph_def`, when provided, must correspond to `graph`; otherwise it is
// produced from `graph` here.
Status ImporterBase::PrepareConvert(const Graph& graph,
                                    std::unique_ptr<GraphDef> graph_def) {
  // TODO(fengliuai): Converting to GraphDef and back is the easiest way to
  // clone a graph.
  // TODO(fengliuai): clone the graph without going to graph_def first.
  if (graph_def == nullptr) {
    graph_def = std::make_unique<GraphDef>();
    graph.ToGraphDef(graph_def.get());
  }
  graph_ = std::make_unique<Graph>(graph.flib_def());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.add_default_attributes = true;
  TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
      opts, std::move(*graph_def), graph_.get()));

  // Back-edges are removed here and restored after conversion (AddBackedges).
  TF_RETURN_IF_ERROR(RemoveBackedges());

  // Preserve stack traces from the original graph on the clone.
  TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get()));

  auto node_name_map = graph_->BuildNodeNameIndex();

  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove once infer shapes on import flag is removed.
    TF_RETURN_IF_ERROR(AddNodesToShapeRefiner(&node_name_map));
  } else {
    TF_RETURN_IF_ERROR(ConvertFeedsToPlaceholders(&node_name_map));
  }

  // Prune nodes in the graph that are not reachable from the output.
  if (specs_.prune_unused_nodes) {
    TF_RETURN_IF_ERROR(PruneUnreachableNodes(&node_name_map));
  }

  if (!specs_.enable_shape_inference) {
    // Re-initialize ordered_nodes_ since we might have modified the graph.
    // Ties in the reverse post-order are broken by node name so the ordering
    // stays deterministic across runs.
    GetReversePostOrder(
        *graph_, &ordered_nodes_,
        [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
  }

  return OkStatus();
}
1546 
// Builds the MLIR function `func_name` with type `func_type`: creates the
// FuncOp, wraps its body in a tf_executor.graph, converts every node in
// ordered_nodes_, restores back-edges, and wires arguments/results via
// ConvertFunctionArgAndRets.
Status ImporterBase::Convert(
    llvm::StringRef func_name, mlir::FunctionType func_type,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes,
    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // TODO(b/122040776): Uses debug info for FunctionDef.
  auto function = mlir::func::FuncOp::create(mlir::UnknownLoc::get(context_),
                                             func_name, func_type, attrs);

  module_.push_back(function);
  // Seeds the builder with an initial block.
  function.addEntryBlock();
  builder_ = mlir::OpBuilder(function.getBody());

  // Create the graph operation in which we will convert the individual nodes.
  auto graph = builder_.create<mlir::tf_executor::GraphOp>(
      function.getLoc(), func_type.getResults());
  builder_.createBlock(&graph.body());

  // ConvertNode relies on builder_ pointing into the graph body created above.
  for (const Node* node : ordered_nodes_) {
    TF_RETURN_IF_ERROR(ConvertNode(*node));
  }

  // Adds the backedges back to the function by creating the source and sink
  // pairs.
  TF_RETURN_IF_ERROR(AddBackedges());

  TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph,
                                               func_type.getInputs(), arg_nodes,
                                               ret_nodes, control_ret_nodes));

  // TODO(jpienaar): Update post removing shape_refinier_.
  if (!specs_.enable_shape_inference) {
    // Refine graph's type given more precise fetch.
    auto fetch = graph.GetFetch();
    bool all_equal = true;
    // Propagate the fetch operand types onto the graph op's results; track
    // whether anything actually changed.
    for (auto it :
         llvm::zip_first(graph.getResults(), fetch.getOperandTypes())) {
      auto rt = std::get<1>(it);
      if (rt == std::get<0>(it).getType()) continue;
      std::get<0>(it).setType(rt);
      all_equal = false;
    }
    // Keep the function signature in sync with the refined result types.
    if (!all_equal) {
      function.setType(mlir::FunctionType::get(function.getContext(),
                                               func_type.getInputs(),
                                               graph.getResultTypes()));
    }
  }

  return OkStatus();
}
1600 
// Rewires the converted tf_executor graph so that feed nodes become block
// arguments and fetch/control-fetch nodes become the graph's Fetch operands
// and the function's results. Also transfers '_'-prefixed node attributes and
// requested devices onto the corresponding arg/result attributes.
Status ImporterBase::ConvertFunctionArgAndRets(
    mlir::func::FuncOp func, mlir::tf_executor::GraphOp graph_op,
    llvm::ArrayRef<mlir::Type> arg_types,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes) {
  // Store the arg/return attributes as a list rather than uniqueuing during
  // construction.
  llvm::SmallVector<mlir::NamedAttrList, 4> arg_attrs;
  arg_attrs.resize(func.getNumArguments());
  llvm::SmallVector<mlir::NamedAttrList, 4> ret_attrs;
  ret_attrs.resize(func.getNumResults());

  // Copies a node's optional ('_'-prefixed) attributes into the arg/result
  // attribute list at `index`, skipping shape-inference bookkeeping attrs.
  auto set_attributes_on_func = [&](Node* node, int64_t index, bool is_arg) {
    for (const auto& node_attr : node->attrs()) {
      const auto& key = node_attr.first;
      // Only import optional attributes (e.g., those starting with an
      // underscore).
      if (key.empty() || key[0] != '_') continue;
      // Ignore shape inference attributes as shape information is already
      // populated in the result type.
      if (IsOutputShapesAttribute(node_attr.second, key) ||
          IsResourceOutputShapesAttribute(node_attr.second, key))
        continue;
      TF_ASSIGN_OR_RETURN(auto converted_attr,
                          ConvertAttributeValue(node_attr.second));
      std::string dialect_attribute = "tf." + key;
      if (is_arg) {
        arg_attrs[index].set(dialect_attribute, converted_attr);
      } else {
        // NOTE(review): this setResultAttr looks redundant — ret_attrs is
        // applied wholesale via setAllResultAttrs below; confirm before
        // simplifying.
        func.setResultAttr(index, dialect_attribute, converted_attr);
        ret_attrs[index].set(dialect_attribute, converted_attr);
      }
    }
    return OkStatus();
  };

  auto* bb = &func.front();
  // Maps each fed (node, output index) to the block argument that replaced
  // it, so fetches of fed tensors can reuse the argument below.
  llvm::SmallDenseMap<std::pair<Node*, int>, mlir::Value, 4>
      arg_nodes_to_values;
  for (int i = 0, e = arg_types.size(); i < e; ++i) {
    auto& arg_node = arg_nodes[i];
    // The lookup can't fail here: otherwise some nodes in the function haven't
    // be converted to mlir operations and don't have a mapping.
    mlir::Operation* island = node_values_.find(arg_node.node->id())->second;

    auto bb_arg = bb->getArgument(i);
    mlir::Value arg_def = bb_arg;

    // Expect exactly one data result plus the island's control result.
    if (island->getNumResults() != 2)
      return errors::InvalidArgument(
          "Only feed output tensors of single output nodes are supported");

    // Collect mapping of OutputTensor to associated block arg.
    arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def);
    island->getResult(0).replaceAllUsesWith(arg_def);
    // Erase control outputs from feed.
    auto control_uses = island->getResult(1).getUses();
    for (auto& control_use : llvm::make_early_inc_range(control_uses))
      control_use.getOwner()->eraseOperand(control_use.getOperandNumber());

    if (!arg_node.node->requested_device().empty())
      arg_attrs[i].set("tf.device", builder_.getStringAttr(
                                        arg_node.node->requested_device()));

    if (arg_node.node->IsArg()) {
      TF_RETURN_IF_ERROR(
          set_attributes_on_func(arg_node.node, i, /*is_arg=*/true));
    }

    // The feed island is no longer needed once its uses are rewired.
    island->dropAllReferences();
    island->erase();
  }

  llvm::SmallVector<mlir::Value, 8> inst_to_return;
  for (auto ret_and_idx : llvm::enumerate(ret_nodes)) {
    const auto& ret = ret_and_idx.value();
    auto* inst = node_values_[ret.node->id()];
    if (ret.node->IsRetval()) {
      if (!ret.node->requested_device().empty())
        ret_attrs[ret_and_idx.index()].set(
            "tf.device", builder_.getStringAttr(ret.node->requested_device()));
      TF_RETURN_IF_ERROR(set_attributes_on_func(ret.node, ret_and_idx.index(),
                                                /*is_arg=*/false));
      // Lookup the instruction inside the island
      auto island_op = llvm::cast<mlir::tf_executor::IslandOp>(inst);
      mlir::Operation* inner_op = &island_op.GetBody().front();
      // Remove kRetOp or kDeviceRetOp operation and return its operand.
      // kRetOp and kDeviceRetOp should have just one operand unless they have
      // control dependencies.
      if (inner_op->getNumOperands() != 1)
        return errors::Unimplemented("Return node with multiple inputs.");
      inst_to_return.push_back(inner_op->getOperand(0));
      inst->dropAllReferences();
      inst->erase();
    } else {
      // Lookup and use block arg if fetch is a feed.
      auto it = arg_nodes_to_values.find({ret.node, ret.index});
      if (it != arg_nodes_to_values.end())
        inst_to_return.push_back(it->second);
      else
        inst_to_return.push_back(inst->getResult(ret.index));
    }
  }

  // Control returns fetch each node's last result (the control result is
  // appended last when the operation is created).
  for (Node* control_ret : control_ret_nodes) {
    auto* inst = node_values_[control_ret->id()];
    inst_to_return.push_back(*std::prev(inst->result_end()));
  }

  // Terminate the function by adding a Fetch operation to terminate the graph
  // and a return operation to return the Graph results.
  builder_.setInsertionPointToEnd(&graph_op.body().front());
  builder_.create<mlir::tf_executor::FetchOp>(graph_op.getLoc(),
                                              inst_to_return);
  builder_.setInsertionPointToEnd(bb);
  builder_.create<mlir::func::ReturnOp>(mlir::UnknownLoc::get(context_),
                                        graph_op.getResults());

  // Materialize the accumulated attribute lists as dictionaries on the
  // function's args and results.
  func.setAllArgAttrs(
      llvm::to_vector<4>(llvm::map_range(arg_attrs, [&](NamedAttrList& list) {
        return list.getDictionary(context_);
      })));
  func.setAllResultAttrs(
      llvm::to_vector<4>(llvm::map_range(ret_attrs, [&](NamedAttrList& list) {
        return list.getDictionary(context_);
      })));

  return OkStatus();
}
1731 
// Builds the MLIR Location for `node`, encoding as much provenance as is
// available: the op type (as a "<type>:" NameLoc), the node name, and either
// Python stack frames or serialized debug info, fused into one location.
mlir::Location ImporterBase::GetLocation(const Node& node) {
  DVLOG(1) << "Getting location for " << node.name() << " " << &node;
  // TODO(b/142400497): What is the semantic contract for locations?
  const auto& debug_info = debug_info_.traces();

  // Create a location for node `name` in function `function_name`.
  auto create_location = [&](llvm::StringRef name,
                             llvm::StringRef function_name) -> mlir::Location {
    // Use the catenation of function and node names as the lookup key into the
    // debug info. This matches the way that the key is formed on the python
    // side.
    //
    // We also use this as the name for the NameLoc for ops in function, since
    // otherwise our names could collide across functions.
    // For ops in the main graph, we omit the "@function_name" (which, would be
    // just "@" since function_name would be empty) because some code seems to
    // depend on the name being this way for correctness.
    std::string debug_info_key = (name + "@" + function_name).str();
    std::string name_for_name_loc =
        function_name.empty() ? name.str() : debug_info_key;
    auto name_loc_id = mlir::StringAttr::get(context_, name_for_name_loc);

    llvm::SmallVector<mlir::Location, 4> locations;
    // Prefer stack traces if available, fallback to debug info if not, and then
    // finally to just name.
    if (auto stack_trace = node.GetStackTrace()) {
      DVLOG(1) << "Stack available for " << node.name();
      absl::Span<const StackFrame> frames = stack_trace->ToFrames();
      locations.reserve(frames.size());
      // Frames are appended in reverse of ToFrames() order.
      for (const StackFrame& frame : llvm::reverse(frames)) {
        auto file_name = mlir::StringAttr::get(context_, frame.file_name);
        // Use col 1 as there is no column info in StackTrace.
        auto file_line_loc =
            mlir::FileLineColLoc::get(file_name, frame.line_number, 1);
        locations.push_back(file_line_loc);
      }
    } else {
      DVLOG(1) << "No stack trace for " << node.name();
      const auto location_it = debug_info.find(debug_info_key);
      if (location_it != debug_info.end()) {
        DVLOG(1) << "Available serialized debug info for " << node.name();
        // Convert the stack trace to a chain of mlir::CallSiteLocs.
        const auto& trace = location_it->second;
        locations.reserve(trace.file_line_cols_size());
        for (const auto& location : trace.file_line_cols()) {
          const auto& file = debug_info_.files(location.file_index());
          auto file_name = mlir::StringAttr::get(context_, file);
          auto file_line_loc = mlir::FileLineColLoc::get(
              file_name, location.line(), location.col());
          locations.push_back(file_line_loc);
        }
      }
    }

    // If there are no locations in the stack trace, fall back to just a
    // NameLoc with no child.
    if (locations.empty()) return mlir::NameLoc::get(name_loc_id);

    // Use the front FileLineColLoc to generate a NameLoc.
    mlir::Location node_name_loc =
        mlir::NameLoc::get(name_loc_id, locations.front());

    // If there are more locations then generate a stack trace, otherwise just
    // return the name loc.
    auto callsite_locs = llvm::makeArrayRef(locations).drop_front();
    return callsite_locs.empty()
               ? node_name_loc
               : mlir::CallSiteLoc::get(node_name_loc, callsite_locs);
  };

  // Create a location for node `name` in function `function_name`.
  auto create_op_type_and_name_locations = [&]() {
    return mlir::FusedLoc::get(
        context_,
        // Add the type operation for the propagation of op_type metadata.
        {mlir::NameLoc::get(
             mlir::StringAttr::get(context_, node.type_string() + ":")),
         create_location(node.name(), function_name_for_debug_info_)});
  };

  // For NextIteration nodes, location is used to pair source and sink nodes.
  // Hence, we use node name as location to keep it unique.
  // TODO(prakalps): In future the plan is to use tokens to pair source/sink
  // nodes. Then NextIteration nodes would not need to be handled separately.
  if (node.type_string() == "NextIteration") {
    return create_op_type_and_name_locations();
  }

  const auto& node_def = node.def();
  auto original_nodes =
      node_def.experimental_debug_info().original_node_names();
  auto original_funcs =
      node_def.experimental_debug_info().original_func_names();

  if (original_nodes.empty()) {
    return create_op_type_and_name_locations();
  } else {
    // If the original nodes are defined, then we use them to get a list of
    // call sites, and then fuse them to a single fused location, with the name
    // of the node_def.
    llvm::SmallVector<mlir::Location, 4> node_locations;
    node_locations.reserve(original_nodes.size() + 2);
    // Add the type operation for the propagation of op_type metadata.
    node_locations.push_back(mlir::NameLoc::get(
        mlir::StringAttr::get(context_, node.type_string() + ":")));
    // Retrieve the names from the experimental_debug_info.
    // Missing func names (shorter original_funcs list) default to "".
    for (int i = 0, e = original_nodes.size(); i != e; ++i) {
      auto node_name = original_nodes[i];
      auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
      node_locations.push_back(create_location(node_name, func_name));
    }
    // Retrieve the name of the node_def.
    node_locations.push_back(
        create_location(node.name(), function_name_for_debug_info_));
    return mlir::FusedLoc::get(context_, node_locations);
  }
}
1849 
// Emits an MLIR error diagnostic at `node`'s source location (so the location
// string is captured by the diagnostic handler), then merges `error_status`
// with any diagnostics collected so far.
Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
                                              const Status& error_status) {
  mlir::emitError(GetLocation(node));
  return error_handler_.Combine(error_status);
}
1856 
CreateOperation(const Node & node,llvm::StringRef node_type_name,const mlir::OperationState & result,const llvm::SmallVectorImpl<mlir::Value> & control_operands)1857 mlir::Operation* ImporterBase::CreateOperation(
1858     const Node& node, llvm::StringRef node_type_name,
1859     const mlir::OperationState& result,
1860     const llvm::SmallVectorImpl<mlir::Value>& control_operands) {
1861   // For the tf.executor specific operations (not wrapped in an island), we
1862   // have an extra returned value for the control result, and we concatenate
1863   // control and non-control operands.
1864   mlir::SmallVector<mlir::Type, 4> types(result.types);
1865   types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext()));
1866   mlir::SmallVector<mlir::Value, 4> operands(result.operands);
1867   operands.append(control_operands.begin(), control_operands.end());
1868 
1869   auto loc = result.location;
1870   // Dispatch based on the name and create the appropriate operation.
1871   if (node.IsSwitch()) {
1872     // Switch and _SwitchN both are in switch class, differentiate based on
1873     // op name.
1874     if (node.op_def().name() == "_SwitchN") {
1875       return builder_.create<mlir::tf_executor::SwitchNOp>(loc, types, operands,
1876                                                            result.attributes);
1877     }
1878     return builder_.create<mlir::tf_executor::SwitchOp>(loc, types, operands,
1879                                                         result.attributes);
1880   }
1881   if (node.IsMerge()) {
1882     return builder_.create<mlir::tf_executor::MergeOp>(loc, types, operands,
1883                                                        result.attributes);
1884   }
1885   if (node.IsNextIteration()) {
1886     // NextIteration is a bit special, we create a pair of operations that are
1887     // linked together through a token returned by the source.
1888     // We make use of a separate builder to insert the source at the top of
1889     // the block.
1890     mlir::OpBuilder builder_at_begin(builder_.getBlock(),
1891                                      builder_.getBlock()->begin());
1892     auto source_op =
1893         builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
1894             loc, operands[0].getType(), result.attributes);
1895     return builder_.create<mlir::tf_executor::NextIterationSinkOp>(
1896         loc, source_op.token(), operands, result.attributes);
1897   }
1898   if (node.IsLoopCond()) {
1899     return builder_.create<mlir::tf_executor::LoopCondOp>(loc, types, operands,
1900                                                           result.attributes);
1901   }
1902   if (node.IsEnter()) {
1903     return builder_.create<mlir::tf_executor::EnterOp>(loc, types, operands,
1904                                                        result.attributes);
1905   }
1906   if (node.IsExit()) {
1907     return builder_.create<mlir::tf_executor::ExitOp>(loc, types, operands,
1908                                                       result.attributes);
1909   }
1910   if (node.IsControlTrigger()) {
1911     return builder_.create<mlir::tf_executor::ControlTriggerOp>(
1912         loc, operands, result.attributes);
1913   }
1914   // Regular TensorFlow operation are wrapped in a tf_executor.island.
1915   auto island = builder_.create<mlir::tf_executor::IslandOp>(
1916       result.location, types, control_operands,
1917       mlir::ArrayRef<mlir::NamedAttribute>{});
1918   island.body().push_back(new mlir::Block);
1919   mlir::OpBuilder island_builder =
1920       mlir::OpBuilder::atBlockEnd(&island.GetBody());
1921 
1922   // Create the operation inside the island now.
1923   mlir::Operation* inner_op = island_builder.create(result);
1924 
1925   // Sets operand_segment_sizes or result_segment_sizes attribute to the op.
1926   const auto set_segment_sizes_attr =
1927       [&](const NameRangeMap& arg_ranges,
1928           const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
1929           llvm::StringRef attr_name) {
1930         std::vector<int32_t> values;
1931         values.reserve(args.size());
1932         for (const auto& arg : args) {
1933           auto range = arg_ranges.at(arg.name());
1934           values.push_back(
1935               range.second - range.first);
1936         }
1937         auto attr_value = mlir::DenseI32ArrayAttr::get(inner_op->getContext(), values);
1938         inner_op->setAttr(attr_name, attr_value);
1939       };
1940 
1941   if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>() ||
1942       inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
1943     // The op has multiple variadic operands or results.
1944     // Calculate operand and result segment sizes using the OpDef.
1945     NameRangeMap input_ranges, output_ranges;
1946     // This will fail only if the OpDef is syntactically invalid.
1947     // TODO(jpienaar): Convert this CHECK into a properly propagated error.
1948     TF_CHECK_OK(
1949         NameRangesForNode(node, node.op_def(), &input_ranges, &output_ranges));
1950     if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
1951       // Add derived "operand_segment_sizes" attr to the created operation.
1952       // TODO(b/146937733): Don't use <void> here.
1953       set_segment_sizes_attr(input_ranges, node.op_def().input_arg(),
1954                              mlir::OpTrait::AttrSizedOperandSegments<
1955                                  void>::getOperandSegmentSizeAttr());
1956     }
1957 
1958     if (inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
1959       // Add derived "result_segment_sizes" attr to the created operation.
1960       // TODO(b/146937733): Don't use <void> here.
1961       set_segment_sizes_attr(output_ranges, node.op_def().output_arg(),
1962                              mlir::OpTrait::AttrSizedResultSegments<
1963                                  void>::getResultSegmentSizeAttr());
1964     }
1965   }
1966 
1967   if (VLOG_IS_ON(1)) {
1968     mlir::OperationName name = inner_op->getName();
1969     if (!name.isRegistered() &&
1970         // Skip unmodelled ops that are handled differently.
1971         (node_type_name != "_Arg" && node_type_name != "_Retval") &&
1972         !unmodelled_op_names_.count(name.getIdentifier())) {
1973       if (node.op_def().is_stateful()) {
1974         VLOG(1) << "[potentially conservative] Op type `" << node.type_string()
1975                 << "` is stateful but effects not modelled";
1976       } else {
1977         // See if any resource type is used.
1978         bool resource = false;
1979         std::function<bool(mlir::Type)> record_resource;
1980         record_resource = [&](mlir::Type type) {
1981           if (resource) return true;
1982           if (type.isa<mlir::TF::ResourceType>()) {
1983             resource = true;
1984             return true;
1985           }
1986           if (auto with_subtype =
1987                   type.dyn_cast<mlir::SubElementTypeInterface>()) {
1988             with_subtype.walkSubTypes(
1989                 [&](mlir::Type t) { record_resource(t); });
1990           }
1991           return resource;
1992         };
1993 
1994         for (mlir::Type t : inner_op->getResultTypes())
1995           if (record_resource(t)) break;
1996         for (mlir::Type t : inner_op->getOperandTypes())
1997           if (record_resource(t)) break;
1998         if (resource) {
1999           unmodelled_op_names_.insert(name.getIdentifier());
2000           VLOG(1) << "[potentially conservative] Op type `"
2001                   << node.type_string()
2002                   << "` has resource operands/results but effects not modelled";
2003         }
2004       }
2005     }
2006   }
2007 
2008   // Add the terminator for the island
2009   island_builder.create<mlir::tf_executor::YieldOp>(result.location,
2010                                                     inner_op->getResults());
2011   return island.getOperation();
2012 }
2013 
// Converts a single TF graph `node` into an MLIR operation in the function
// currently being built. Special executor nodes are handled by
// CreateOperation (called at the end); this method assembles the
// OperationState: result types, data/control operands, attributes, device,
// and the Case/If/While and LegacyCall rewrites.
Status ImporterBase::ConvertNode(const Node& node) {
  if (!node.IsOp()) {
    // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by
    // Graph and don't exist in GraphDef.
    return OkStatus();
  }

  // If it is a custom OP, its definition should be found in the library. We
  // create the MLIR function and insert it to the module if it doesn't exist.
  std::string node_type_name = node.type_string();
  const auto* func_def = graph_flib_.Find(node_type_name);
  bool convert_to_legacy_call = false;
  if (func_def) {
    TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
    node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
    convert_to_legacy_call = true;
  }

  // Prepends the "tf." dialect prefix to a TF op type name.
  auto get_full_op_name = [&](const std::string& op_name) {
    const char* kTfPrefix = "tf.";
    return kTfPrefix + op_name;
  };

  std::string op_name = get_full_op_name(node_type_name);
  if (back_edge_node_output_.contains(&node)) {
    // A node whose backedge was removed is imported as the ".sink" half of
    // the source/sink pair (see AddBackedge, which later reconnects it).
    op_name = op_name + ".sink";
  }

  mlir::OperationState result(GetLocation(node), op_name);
  for (int i = 0; i < node.num_outputs(); ++i) {
    // The backedge has been removed, so we shouldn't count the corresponding
    // output from the src node when converting to an operation.
    if (back_edge_node_output_.contains(&node) &&
        back_edge_node_output_[&node] == i) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_));
    result.types.push_back(type);
  }

  // Surprisingly input edges can be nondeterministically ordered. This
  // particularly seems to be the case for the control edges between _SOURCE
  // and _SINK that the Graph constructor inserts. Copy the input edges and
  // sort the edges, but only the control edges, not data edges!
  // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes.
  // They'll break roundtripping anyway unless we strip them when converting
  // back to graphdef.
  absl::InlinedVector<const Edge*, 8> in_edges(node.in_edges().size());
  absl::c_copy(node.in_edges(), in_edges.begin());
  absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
    // Data edges sort before control edges; control edges are ordered by
    // source node id, data edges by destination input index.
    if (e1->IsControlEdge() && !e2->IsControlEdge()) return false;
    if (!e1->IsControlEdge() && e2->IsControlEdge()) return true;
    if (e1->IsControlEdge() && e2->IsControlEdge())
      return e1->src()->id() < e2->src()->id();
    return e1->dst_input() < e2->dst_input();
  });

  result.operands.reserve(in_edges.size());

  // Collect the control operands separately, they will be held by the island.
  mlir::SmallVector<mlir::Value, 8> control_operands;

  for (const auto* input_edge : in_edges) {
    const Node& input_node = *input_edge->src();
    if (input_node.IsSource()) {
      if (in_edges.size() != 1) {
        return errors::FailedPrecondition(
            "The node has other inputs besides the _Source node");
      }
      // We don't import the _SOURCE node.
      continue;
    }
    if (input_node.IsArg() && input_edge->IsControlEdge()) {
      // Currently we have not reached consensus as to what TF function
      // semantics are (b/133509504). Here we assume that all arguments to a
      // function should be available before we start execution of any internal
      // node. This makes the control dependencies between function arguments
      // and internal nodes redundant, and so we do not import them. The TF
      // inliner however assumes no such dependency between function args and
      // internal nodes exists, unless explicitly stated. Since we drop control
      // dependencies here, it leads to loss of information. If the function is
      // inlined later, the inliner would not know of these explicit control
      // dependencies present in the original graph.
      continue;
    }
    if (node_values_.find(input_node.id()) == node_values_.end())
      return errors::FailedPrecondition(
          "Graph not traversed in reverse post order; use seen before def!");
    mlir::Operation* inst = node_values_[input_node.id()];
    if (input_edge->IsControlEdge())
      // The last result of an imported op is its control token.
      control_operands.push_back(inst->getResult(inst->getNumResults() - 1));
    else
      result.operands.push_back(inst->getResult(input_edge->src_output()));
  }

  using FuncPairType = std::pair<const std::string*, const AttrValue*>;
  std::vector<FuncPairType> funcs;
  // +2 leaves headroom for attributes appended below (e.g. "device", "f").
  result.attributes.reserve(node.attrs().size() + 2);
  auto abstract_op = result.name.getRegisteredInfo();
  auto derived_op =
      abstract_op
          ? abstract_op->getInterface<mlir::DerivedAttributeOpInterface>()
          : nullptr;
  for (const auto& name_and_value : node.attrs()) {
    const auto& attr_name = name_and_value.first;
    // Skip adding derived attributes to the generated op.
    if (derived_op && derived_op->isDerivedAttribute(attr_name)) continue;
    const AttrValue& attr_value = name_and_value.second;

    // Remove _output_shapes attribute that will be added by the exporter.
    if (IsOutputShapesAttribute(attr_value, attr_name)) continue;

    if (attr_value.value_case() == AttrValue::kFunc) {
      // Attribute iteration order is not defined for protocol buffer Map.
      // Process function attributes separately in the lexicographical order to
      // have deterministic order of functions in the constructed IR.
      funcs.emplace_back(&attr_name, &attr_value);
    } else {
      TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value));
      result.attributes.push_back(builder_.getNamedAttr(attr_name, attr));
    }
  }

  // Sort function-valued attributes by name for deterministic IR output.
  auto comparator = [](const FuncPairType& a, const FuncPairType& b) {
    return *a.first < *b.first;
  };
  std::sort(funcs.begin(), funcs.end(), comparator);
  for (const auto& func : funcs) {
    TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second,
                                                    &result.attributes));
  }

  const auto& node_def = node.def();
  // NodeDef can contain partial TF device names. In such cases, canonicalize
  // it. Note that in current TF, placer will place full device name to each
  // node.
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(node_def.device(), &parsed_name)) {
    return errors::InvalidArgument(
        "Op ", op_name, " has invalid device name: ", node_def.device());
  }
  // Keep the parsed name untouched if the device name is empty.
  if (!node_def.device().empty()) {
    if (!parsed_name.has_type) {
      parsed_name.type = "CPU";
      parsed_name.has_type = true;
    }
    if (!parsed_name.has_id) {
      parsed_name.id = 0;
      parsed_name.has_id = true;
    }
  }
  result.attributes.push_back(builder_.getNamedAttr(
      "device", builder_.getStringAttr(
                    DeviceNameUtils::ParsedNameToString(parsed_name))));

  // Map user function calls to LegacyCall ops and add the user function name
  // as an attribute.
  if (convert_to_legacy_call) {
    result.name = mlir::OperationName(get_full_op_name("LegacyCall"), context_);
    mlir::SymbolRefAttr val =
        mlir::SymbolRefAttr::get(builder_.getContext(), node_type_name);
    result.addAttribute("f", val);

    // Preserve a caller-provided value; default to false otherwise.
    if (!result.attributes.get("_disable_call_shape_inference")) {
      result.addAttribute("_disable_call_shape_inference",
                          builder_.getBoolAttr(false));
    }
  }

  // Rewrites the op name to the common MLIR control-flow op and records
  // statefulness in the "is_stateless" attribute.
  auto composite_control_flow_op = [&](const std::string& name) {
    result.name = mlir::OperationName(get_full_op_name(name), context_);
    bool stateless = absl::StartsWith(node_type_name, "Stateless");
    mlir::BoolAttr val = builder_.getBoolAttr(stateless);
    result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
  };

  // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common
  // Case/If/While op in MLIR and add the differentiating attribute.
  if (node.IsCaseNode()) composite_control_flow_op("Case");
  if (node.IsIfNode()) composite_control_flow_op("If");
  if (node.IsWhileNode()) {
    composite_control_flow_op("While");
    auto* output_shapes = node.attrs().Find("output_shapes");
    if (output_shapes && !output_shapes->list().shape().empty())
      result.attributes.push_back(
          builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
  }

  // Register the mapping between the TF node and the newly created operation.
  node_values_[node.id()] =
      CreateOperation(node, node_type_name, result, control_operands);
  return OkStatus();
}
2208 
2209 // Add the backedges to the CFG. Given a backedge, we replace the original
2210 // source and destination operations by two new operations. Most of the
2211 // fields of the replacements are copied from the original operations.
2212 // However,
2213 // - for the src operation, one output is inserted to the front of the output
2214 //   list. The type of the output is set to the type of the non-control result
2215 //   of the dst operation, and
2216 // - for the dst operation, one operand is inserted to the front of the
2217 //   operand list. This operand is using the first result of the src
2218 //   operation.
2219 // TODO(fengliuai): Preserve the order of the results and operands if
2220 // necessary.
AddBackedges()2221 Status ImporterBase::AddBackedges() {
2222   for (auto it : back_edge_dst_inputs_) {
2223     BackEdge& edge = it.second;
2224     if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) {
2225       return errors::FailedPrecondition(
2226           "Invalid backedge; should be from NextIteration to Merge!");
2227     }
2228     auto* sink = node_values_[edge.src->id()];
2229     auto* dst = node_values_[edge.dst->id()];
2230     TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input));
2231   }
2232   return OkStatus();
2233 }
2234 
AddBackedge(mlir::Operation * sink,mlir::Operation * dst,int dst_input)2235 Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
2236                                  int dst_input) {
2237   // Get the NextIteration.Source operation from the token operand of the sink.
2238   mlir::Operation* source = sink->getOperand(0).getDefiningOp();
2239 
2240   // Adds the "source" to the operands of the dst by creating a new dst
2241   // operation.
2242   mlir::OperationState state(dst->getLoc(), dst->getName());
2243   auto num_operands = dst->getNumOperands();
2244   state.operands.reserve(num_operands + 1);
2245   for (int input = 0, e = num_operands + 1; input != e; ++input) {
2246     if (input < dst_input) {
2247       state.operands.push_back(dst->getOperand(input));
2248     } else if (input == dst_input) {
2249       state.operands.push_back(source->getResult(0));
2250     } else {
2251       state.operands.push_back(dst->getOperand(input - 1));
2252     }
2253   }
2254   state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end());
2255   state.types.assign(dst->getResultTypes().begin(),
2256                      dst->getResultTypes().end());
2257   builder_.setInsertionPoint(dst);
2258   auto* new_dst = builder_.create(state);
2259 
2260   // Replaces the output uses of the old operation by the corresponding
2261   // result of the new operation, and deletes the old operation.
2262   for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
2263     auto new_output = new_dst->getResult(i);
2264     dst->getResult(i).replaceAllUsesWith(new_output);
2265   }
2266   dst->dropAllReferences();
2267   dst->erase();
2268   return OkStatus();
2269 }
2270 
InferLibFunctionType(const FunctionBody & fbody)2271 StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
2272     const FunctionBody& fbody) {
2273   mlir::Builder builder(context_);
2274 
2275   // The FunctionBody contains a graph with a single-output _Arg node for each
2276   // function argument and a single-input _Retval node for each function return
2277   // value.
2278   //
2279   // We already populated the ShapeRefiner with all the information about the
2280   // shapes of these graph edges, so we just query it to build the corresponding
2281   // MLIR function type signature.
2282 
2283   llvm::SmallVector<mlir::Type, 4> arg_types;
2284   if (specs_.inputs.empty()) {
2285     arg_types.reserve(fbody.arg_types.size());
2286     for (auto arg : fbody.arg_nodes) {
2287       // Find node in the graph using the node id instead of using `arg`
2288       // directly because the graph has been cloned.
2289       auto* node = graph_->FindNodeId(arg->id());
2290       TF_ASSIGN_OR_RETURN(auto type,
2291                           InferOutputType(*node, /*idx=*/0, builder));
2292       arg_types.push_back(type);
2293     }
2294   } else {
2295     arg_types.reserve(fbody.arg_types.size());
2296     for (const auto& it : llvm::enumerate(specs_.inputs)) {
2297       mlir::Type element_type;
2298       const auto& node_info = it.value().second;
2299       DataType dtype = node_info.imported_dtype;
2300       // Uses the existing output type of the arg node if the data type of the
2301       // the node isn't specified through the import configuration.
2302       if (dtype == DT_INVALID) {
2303         auto arg = fbody.arg_nodes[it.index()];
2304         auto* node = graph_->FindNodeId(arg->id());
2305         dtype = node->output_type(0);
2306         if (dtype == DT_INVALID) {
2307           return errors::InvalidArgument("Input ", it.index(),
2308                                          "has invalid data type");
2309         }
2310       }
2311       TF_RETURN_IF_ERROR(
2312           ::tensorflow::ConvertDataType(dtype, builder, &element_type));
2313       if (node_info.shape.unknown_rank()) {
2314         arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
2315       } else {
2316         llvm::SmallVector<int64_t, 4> shape;
2317         TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
2318         arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
2319       }
2320     }
2321   }
2322 
2323   llvm::SmallVector<mlir::Type, 4> ret_types;
2324   ret_types.reserve(fbody.ret_types.size());
2325   for (auto ret : fbody.ret_nodes) {
2326     // Find node in the graph using the node id instead of using `ret` directly
2327     // because the graph has been cloned.
2328     auto* node = graph_->FindNodeId(ret->id());
2329     TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder));
2330     ret_types.push_back(type);
2331   }
2332 
2333   return builder.getFunctionType(arg_types, ret_types);
2334 }
2335 
2336 // Stateful helper class to import a TensorFlow model expressed in GraphDef into
2337 // an MLIR Module.
2338 //
2339 // The nodes defined in the graph are converted to a function called
2340 // 'func_name'. All library function definitions are converted to MLIR functions
2341 // in the module.
class GraphDefImporter : public ImporterBase {
 public:
  // Main entry point: converts the given graph to an MLIR Module.
  // `debug_info` supplies source locations; `tf_name_to_mlir_name` is filled
  // with the mapping from TF library function names to the names of the MLIR
  // functions they were imported as.
  static StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Convert(
      mlir::MLIRContext* context, const Graph& graph,
      const GraphDebugInfo& debug_info,
      const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
      std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);

 private:
  // Construction is private; use the static Convert() entry point above.
  explicit GraphDefImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}

  // Returns the function signature of the main function of converted MLIR
  // module, the input nodes and output nodes. The type and shape information
  // for the function arguments are read from `specs`, but the type and shape
  // information for the function returns are inferred by the shape refiner in
  // ImporterBase.
  StatusOr<mlir::FunctionType> InferMainFunctionType(
      const GraphImportConfig& specs, mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Returns the function signature of the main function, alongside input and
  // output nodes, for function graphs (specs.graph_as_function). Arguments
  // and return values are determined by node op type. Type and shape
  // information of the function are inferred by the shape refiner in
  // ImporterBase.
  StatusOr<mlir::FunctionType> GetArgsRetsAndTypesFromFunctionGraph(
      mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Finds the graph's target nodes/function's control ret nodes based on
  // supplied node names in `control_outputs`. If `control_outputs` are not
  // unique or a control ret node is missing, an error will be returned.
  Status GetControlRetsFromGraph(
      llvm::ArrayRef<std::string> control_outputs,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);
};
2386 
// Converts `graph` into a fresh MLIR module: the graph body becomes a public
// function (named by specs.graph_func_name or the default), all reachable
// library functions become private functions, and a "tf.entry_function"
// attribute records the feed/fetch/target mapping where applicable.
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GraphDefImporter::Convert(
    mlir::MLIRContext* context, const Graph& graph,
    const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
    const GraphImportConfig& specs,
    std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
  LoadImporterDialects(*context);
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  NameUniquifier function_name_uniquifier(flib_def);

  // importer.PrepareConvert below will attempt to clone the original `graph`
  // via conversion to the graph def first. Convert graph to graph_def here
  // first and avoid extra copies later.
  auto graph_def = std::make_unique<GraphDef>();
  graph.ToGraphDef(graph_def.get());

  // Register the graph and its reachable function library with the crash
  // analysis system so they are dumped if the process crashes mid-import.
  // The atomic counter keeps report filenames unique across concurrent
  // imports in the same process.
  static std::atomic<uint32> counter(0);
  uint32 current_file_prefix = counter++;
  const auto* graph_crash_handle = crash_analysis::ReportProtoDataOnCrash(
      absl::StrCat(current_file_prefix, "_mlir_import_graph.pbtxt"),
      *graph_def);
  auto reachable_flib = flib_def.ReachableDefinitions(*graph_def);
  const auto* flib_crash_handle = crash_analysis::ReportProtoDataOnCrash(
      absl::StrCat(current_file_prefix, "_mlir_import_flib.pbtxt"),
      reachable_flib.ToProto());

  // Unregister the crash reports on every exit path (success or error).
  auto scope_exit = llvm::make_scope_exit([&]() {
    crash_analysis::RemoveReportData(graph_crash_handle);
    crash_analysis::RemoveReportData(flib_crash_handle);
  });

  VLOG(1) << "Importing: "
          << ::tensorflow::DumpGraphToFile("tf_mlir_importer_base", graph,
                                           &flib_def);

  GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
                            &tf_name_to_mlir_name, &function_name_uniquifier);

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph, std::move(graph_def)));

  mlir::FunctionType func_type;
  absl::InlinedVector<OutputTensor, 4> arg_nodes;
  absl::InlinedVector<OutputTensor, 4> ret_nodes;
  absl::InlinedVector<Node*, 4> control_ret_nodes;
  llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
  if (specs.graph_as_function) {
    if (specs.prune_unused_nodes || !specs.inputs.empty() ||
        !specs.outputs.empty())
      return errors::InvalidArgument(
          "Pruning of graph is currently unsupported when the main graph is "
          "converted to a function.");

    // Arguments/returns come from _Arg/_Retval node op types, not from specs.
    TF_ASSIGN_OR_RETURN(func_type,
                        importer.GetArgsRetsAndTypesFromFunctionGraph(
                            context, &arg_nodes, &ret_nodes));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                        &control_ret_nodes));

    // Build comma-separated node-name lists for the tf.entry_function
    // attribute; `s.clear()` resets the backing string between fields while
    // `ss` keeps streaming into it.
    mlir::Builder b(context);
    std::string s;
    llvm::raw_string_ostream ss(s);
    auto node_name = [&](const OutputTensor& tensor) {
      ss << tensor.node->name();
    };
    llvm::interleave(arg_nodes, ss, node_name, ",");
    auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
    s.clear();
    llvm::interleave(ret_nodes, ss, node_name, ",");
    auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
    s.clear();
    llvm::interleave(specs.control_outputs, ss, ",");
    auto control_outputs =
        b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

    // Under `graph_as_function` mode, `tf.entry_function` is always set as it
    // is assumed feed, fetch, and target nodes are set correctly.
    attrs.push_back(b.getNamedAttr(
        "tf.entry_function",
        b.getDictionaryAttr({inputs, outputs, control_outputs})));
  } else {
    // Collects the argument and return nodes by looking up the node names
    // specified by the user.
    TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
                                       specs, context, &arg_nodes, &ret_nodes));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                        &control_ret_nodes));

    // TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and
    // decoding in a centralized place.
    // Record the input and output mapping.
    if (!specs.inputs.empty() || !specs.outputs.empty() ||
        !specs.control_outputs.empty()) {
      mlir::Builder b(context);
      std::string s;
      llvm::raw_string_ostream ss(s);
      llvm::interleave(
          specs.inputs, ss,
          [&](const std::pair<std::string, ArrayInfo>& v) { ss << v.first; },
          ",");
      auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
      s.clear();
      llvm::interleave(specs.outputs, ss, ",");
      auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
      s.clear();
      llvm::interleave(specs.control_outputs, ss, ",");
      auto control_outputs =
          b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

      attrs.push_back(b.getNamedAttr(
          "tf.entry_function",
          b.getDictionaryAttr({inputs, outputs, control_outputs})));
    }
  }

  // Record version info.
  PopulateTfVersions(module.get(), graph.versions());

  const llvm::StringRef& graph_func_name =
      specs.graph_func_name.empty() ? kImportModelDefaultGraphFuncName
                                    : specs.graph_func_name;
  TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(graph_func_name, func_type,
                                                    arg_nodes, ret_nodes,
                                                    control_ret_nodes, attrs));
  TF_RETURN_IF_ERROR(importer.ImporterBase::ConvertDeferredFunctions());

  // Mark main function public, others private.
  for (auto function : module.get().getOps<mlir::func::FuncOp>()) {
    auto visibility = function.getName() == graph_func_name
                          ? mlir::func::FuncOp::Visibility::Public
                          : mlir::func::FuncOp::Visibility::Private;
    function.setVisibility(visibility);
  }
  VLOG(1) << "Imported: "
          << tensorflow::DumpMlirOpToFile("tf_mlir_imported_base",
                                          module.get());
  return module;
}
2526 
InferMainFunctionType(const GraphImportConfig & specs,mlir::MLIRContext * context,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes)2527 StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
2528     const GraphImportConfig& specs, mlir::MLIRContext* context,
2529     absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2530     absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
2531   // Find all the input nodes and output nodes.
2532   // Feeds have been remapped to single output nodes (Placeholder), so an exact
2533   // name match is sufficient.
2534   absl::flat_hash_map<absl::string_view, int> inputs;
2535   for (auto input_and_idx : llvm::enumerate(specs.inputs)) {
2536     TensorId tensor = ParseTensorName(input_and_idx.value().first);
2537     auto remapped_it = remapped_feeds_.find(tensor);
2538     if (remapped_it != remapped_feeds_.end()) {
2539       inputs.insert({remapped_it->second, input_and_idx.index()});
2540     } else {
2541       inputs.insert({tensor.node(), input_and_idx.index()});
2542     }
2543   }
2544 
2545   absl::flat_hash_set<absl::string_view> output_node_names;
2546   std::vector<TensorId> outputs;
2547   output_node_names.reserve(specs.outputs.size());
2548   for (const auto& output : specs.outputs) {
2549     TensorId tensor = ParseTensorName(output);
2550     auto remapped_it = remapped_feeds_.find(tensor);
2551     if (remapped_it != remapped_feeds_.end()) {
2552       output_node_names.insert(remapped_it->second);
2553       outputs.push_back({remapped_it->second, 0});
2554     } else {
2555       output_node_names.insert(tensor.node());
2556       outputs.push_back(tensor);
2557     }
2558   }
2559 
2560   if (!inputs.empty() || !outputs.empty()) {
2561     arg_nodes->resize(inputs.size());
2562     ret_nodes->resize(outputs.size());
2563 
2564     for (Node* n : GetOrderedNodes()) {
2565       // Handle inputs/arguments.
2566       auto input_it = inputs.find(n->name());
2567       if (input_it != inputs.end()) {
2568         (*arg_nodes)[input_it->second] = {n, 0};
2569       }
2570 
2571       // Handle outputs/returns.
2572       if (output_node_names.contains(n->name())) {
2573         for (int i = 0, e = outputs.size(); i != e; ++i) {
2574           TensorId tensor = outputs[i];
2575           if (n->name() != tensor.node()) continue;
2576           (*ret_nodes)[i] = {n, tensor.index()};
2577         }
2578       }
2579     }
2580   }
2581 
2582   // Starts to construct the function type.
2583   mlir::Builder builder(context);
2584   llvm::SmallVector<mlir::Type, 4> arg_types;
2585   arg_types.reserve(specs.inputs.size());
2586   int i = 0;
2587   for (const auto& it : specs.inputs) {
2588     Node* arg_node = arg_nodes->at(i).node;
2589     if (arg_node == nullptr) {
2590       return errors::InvalidArgument("Input ", it.first,
2591                                      " was not found in graph");
2592     }
2593     mlir::Type element_type;
2594     const auto& node_info = it.second;
2595     DataType imported_dtype = node_info.imported_dtype;
2596     // Uses the existing output type of the arg node if the data type of the
2597     // the node isn't specified through the import configuration.
2598     if (imported_dtype == DT_INVALID) {
2599       imported_dtype = arg_node->output_type(0);
2600       if (imported_dtype == DT_INVALID) {
2601         return errors::InvalidArgument("Input ", i, "has invalid data type");
2602       }
2603     }
2604     // Check if we have subtypes first
2605     if (!node_info.subtypes.empty()) {
2606       std::vector<mlir::TensorType> subtypes;
2607       for (const auto& st : node_info.subtypes) {
2608         mlir::Type st_data_type;
2609         llvm::SmallVector<int64_t> shape;
2610         TF_RETURN_IF_ERROR(ConvertToMlirShape(st.shape, &shape));
2611         TF_RETURN_IF_ERROR(
2612             ConvertDataType(st.imported_dtype, builder, &st_data_type));
2613         subtypes.push_back(mlir::RankedTensorType::get(shape, st_data_type));
2614       }
2615       if (imported_dtype == DT_RESOURCE) {
2616         element_type =
2617             mlir::TF::ResourceType::get(subtypes, builder.getContext());
2618       } else if (imported_dtype == DT_VARIANT) {
2619         element_type =
2620             mlir::TF::VariantType::get(subtypes, builder.getContext());
2621       } else {
2622         return errors::InvalidArgument(DataType_Name(imported_dtype),
2623                                        " takes no subtypes.");
2624       }
2625     } else {
2626       TF_RETURN_IF_ERROR(
2627           ConvertDataType(imported_dtype, builder, &element_type));
2628     }
2629     if (node_info.shape.unknown_rank()) {
2630       arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
2631     } else {
2632       llvm::SmallVector<int64_t, 4> shape;
2633       TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
2634       arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
2635     }
2636     i++;
2637   }
2638 
2639   llvm::SmallVector<mlir::Type, 4> ret_types;
2640   ret_types.reserve(specs.outputs.size());
2641   for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
2642     if (ret_nodes->at(i).node == nullptr) {
2643       return errors::InvalidArgument("Output ", specs.outputs[i],
2644                                      " was not found in graph");
2645     }
2646   }
2647   for (const auto& ret : *ret_nodes) {
2648     if (ret.node->num_outputs() <= ret.index) {
2649       return errors::InvalidArgument("Invalid output index ", ret.index,
2650                                      " specified for node: ", ret.node->name());
2651     }
2652     TF_ASSIGN_OR_RETURN(auto type,
2653                         InferOutputType(*ret.node, ret.index, builder));
2654     ret_types.push_back(type);
2655   }
2656 
2657   return builder.getFunctionType(arg_types, ret_types);
2658 }
2659 
2660 StatusOr<mlir::FunctionType>
GetArgsRetsAndTypesFromFunctionGraph(mlir::MLIRContext * context,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes)2661 GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
2662     mlir::MLIRContext* context, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2663     absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
2664   auto add_node = [](Node* node, absl::InlinedVector<OutputTensor, 4>* nodes) {
2665     auto* attr = node->attrs().Find("index");
2666     if (!attr)
2667       return errors::InvalidArgument(node->type_string(), " node '",
2668                                      node->name(),
2669                                      "' is missing attribute 'index'");
2670 
2671     auto index = attr->i();
2672     const int num_nodes = nodes->size();
2673     if (num_nodes < index + 1) nodes->resize(index + 1);
2674 
2675     if ((*nodes)[index].node != nullptr)
2676       return errors::InvalidArgument(node->type_string(), " node '",
2677                                      node->name(), "' has attribute 'index' ",
2678                                      index, " that conflicts with node '",
2679                                      (*nodes)[index].node->name(), "'");
2680     (*nodes)[index] = {node, 0};
2681 
2682     return OkStatus();
2683   };
2684 
2685   // Collect arg and ret nodes from graph.
2686   for (auto* node : GetOrderedNodes())
2687     if (node->IsArg())
2688       TF_RETURN_IF_ERROR(add_node(node, arg_nodes));
2689     else if (node->IsRetval())
2690       TF_RETURN_IF_ERROR(add_node(node, ret_nodes));
2691 
2692   // Collect arg and ret types and create function type.
2693   mlir::Builder builder(context);
2694   llvm::SmallVector<mlir::Type, 4> arg_types;
2695   arg_types.reserve(arg_nodes->size());
2696   for (auto arg_node_and_idx : llvm::enumerate(*arg_nodes)) {
2697     auto& arg_node = arg_node_and_idx.value();
2698     if (arg_node.node == nullptr)
2699       return errors::InvalidArgument("Graph missing _Arg at index ",
2700                                      arg_node_and_idx.index());
2701 
2702     TF_ASSIGN_OR_RETURN(auto type,
2703                         InferOutputType(*arg_node.node, /*idx=*/0, builder));
2704     arg_types.push_back(type);
2705   }
2706 
2707   llvm::SmallVector<mlir::Type, 4> ret_types;
2708   ret_types.reserve(ret_nodes->size());
2709   for (auto ret_node_and_idx : llvm::enumerate(*ret_nodes)) {
2710     auto& ret_node = ret_node_and_idx.value();
2711     if (ret_node.node == nullptr)
2712       return errors::InvalidArgument("Graph missing _Retval at index ",
2713                                      ret_node_and_idx.index());
2714 
2715     TF_ASSIGN_OR_RETURN(auto type,
2716                         InferInputType(*ret_node.node, /*idx=*/0, builder));
2717     ret_types.push_back(type);
2718   }
2719 
2720   return builder.getFunctionType(arg_types, ret_types);
2721 }
2722 
GetControlRetsFromGraph(llvm::ArrayRef<std::string> control_outputs,absl::InlinedVector<Node *,4> * control_ret_nodes)2723 Status GraphDefImporter::GetControlRetsFromGraph(
2724     llvm::ArrayRef<std::string> control_outputs,
2725     absl::InlinedVector<Node*, 4>* control_ret_nodes) {
2726   if (control_outputs.empty()) return OkStatus();
2727 
2728   llvm::SmallDenseMap<llvm::StringRef, int32_t> controls_to_idx;
2729   for (auto control_and_idx : llvm::enumerate(control_outputs))
2730     controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()});
2731 
2732   if (controls_to_idx.size() != control_outputs.size())
2733     return errors::InvalidArgument("Control outputs must be unique");
2734 
2735   control_ret_nodes->resize(controls_to_idx.size());
2736 
2737   for (auto* node : GetOrderedNodes()) {
2738     auto it = controls_to_idx.find(node->name());
2739     if (it != controls_to_idx.end()) (*control_ret_nodes)[it->second] = node;
2740   }
2741 
2742   for (auto node_and_name : llvm::zip(*control_ret_nodes, control_outputs))
2743     if (std::get<0>(node_and_name) == nullptr)
2744       return errors::InvalidArgument(
2745           "Control output '", std::get<1>(node_and_name), "' is missing");
2746 
2747   return OkStatus();
2748 }
2749 
2750 // Stateful helper class to import a TensorFlow model expressed in SavedModel
2751 // into an MLIR Module.
class SavedModelObjectGraphImporter : public ImporterBase {
 public:
  // Main entry point: converts all functions in the given meta graph to an MLIR
  // Module.
  static StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Convert(
      SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
      mlir::MLIRContext* context, bool add_default_attributes,
      bool unconditionally_use_set_output_shapes);

 private:
  // Forwards all state to ImporterBase; this subclass adds no members of its
  // own and exists only to host the SavedModel-specific Convert() entry point.
  explicit SavedModelObjectGraphImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}
};
2770 
2771 // Determines the names used to reference objects in the SavedObjectGraph.
class ObjectNames {
 public:
  // Builds all name tables by traversing `object_graph` from the root
  // (node 0). `exported_names` is the set of names to export; empty means
  // "export all".
  explicit ObjectNames(const SavedObjectGraph& object_graph,
                       absl::Span<std::string> exported_names);

  // Gets the names that external users of the SavedModel can use to refer to
  // this node. Returns an empty list for non-exported nodes.
  llvm::ArrayRef<llvm::StringRef> GetExportedNames(int node_id) const;

  // Gets the name in the module symbol table for this node.
  // This name is only used for internal IR references.
  llvm::StringRef GetSymbolTableName(int node_id) const;

 private:
  // In the absence of any other information, use this name as the symbol table
  // name for this node.
  std::string GetDefaultSymbolTableName(int node_id) const;
  // Determines if a name is exported.
  bool IsExported(const std::string& name);
  // Main object graph traversal function.
  void RecursivelyVisitObjectGraph(int node_id);
  // Gets a stable StringRef from a std::string (the string is interned into
  // saved_strings_, so the returned view outlives the argument).
  llvm::StringRef SaveString(const std::string& s) const;

  // The object graph we are traversing.
  const SavedObjectGraph& object_graph_;
  // The set of names to export. Empty means "export all".
  std::unordered_set<std::string> names_to_export_;

  // When we recursively follow the object graph tree structure from the root,
  // we track its path in the object graph by pushing and popping from here
  // during traversal.
  llvm::SmallVector<std::string, 8> path_segments_;
  // The set of node IDs that are on the current DFS stack.
  // For cyclic object graphs, this prevents infinite recursion.
  std::unordered_set<int> on_stack_nodes_;

  // Key: node_id.
  // Value: all object names that node_id appears as.
  // Each object name corresponds to a unique path from the root of the object
  // graph.
  // The common intuitive case is when there is only one name for a given
  // object, which corresponds to the object graph being a tree.
  //
  // But, there cases where the object graph is a general graph. For
  // example, this happens commonly in Keras models, where `foo.bar` is
  // also reachable via the name `keras_api.foo.bar`.
  // Cycles are possible too.
  absl::flat_hash_map<int, std::vector<std::string>> object_names_;

  // Key: node_id
  // Value: all names that this object is exported as
  absl::flat_hash_map<int, llvm::SmallVector<llvm::StringRef, 1>>
      exported_names_;
  // Key: node_id
  // Value: pretty symbol table name to use for internal references to this
  // object.
  absl::flat_hash_map<int, llvm::StringRef> pretty_symbol_table_name_;

  // Stable strings we can take StringRef's into. Used only by the SaveString
  // method. Mutable because SaveString is called from const accessors.
  mutable std::unordered_set<std::string> saved_strings_;
};
2835 
ObjectNames(const SavedObjectGraph & object_graph,absl::Span<std::string> exported_names)2836 ObjectNames::ObjectNames(const SavedObjectGraph& object_graph,
2837                          absl::Span<std::string> exported_names)
2838     : object_graph_(object_graph),
2839       names_to_export_(exported_names.begin(), exported_names.end()) {
2840   // Visit all reachable nodes from the root of the object graph.
2841   // This builds up object_names_ to contain all names like `foo.bar` that a
2842   // particular node in the graph can be reached from.
2843   RecursivelyVisitObjectGraph(/*node_id=*/0);
2844 
2845   // Populate the exported_names_ map.
2846   // TODO(silvasean): Diagnose typos in exported names?
2847   for (auto& kv : object_names_) {
2848     // Make object names map independent of our particular choice of object
2849     // graph traversal.
2850     std::sort(kv.second.begin(), kv.second.end(),
2851               [](absl::string_view a, absl::string_view b) {
2852                 // The sort order here influences the "pretty name" we assign
2853                 // below. We want the most debuggable name to be first.
2854                 //
2855                 // Debuggability heuristics:
2856                 // 1. Names that end in digits are likely to be internal aliases
2857                 // to the "real" names.
2858                 // 2. Longer names are more likely to be internal aliases.
2859                 //
2860                 // Example set of object names created by Keras for the weight
2861                 // matrix of a fully connected layer on a trivial FC mnist
2862                 // model:
2863                 // - `model.layer-1.kernel` (this is the "best" name)
2864                 // - `model.keras_api.layers.1.kernel`
2865                 // - `model.variables.0`
2866                 // - `model.keras_api.layers.1.keras_api.trainable_variables.0`
2867                 // - ... 10 more long aliases ending in digits ...
2868                 return std::make_tuple(isdigit(a.back()), a.size(), a) <
2869                        std::make_tuple(isdigit(b.back()), b.size(), b);
2870               });
2871     for (const std::string& name : kv.second) {
2872       if (IsExported(name)) {
2873         exported_names_[kv.first].push_back(SaveString(name));
2874       }
2875     }
2876   }
2877   // Create "pretty" symbol table names for nodes where that is applicable.
2878   // We could make all symbol table names use the default, which is basically
2879   // just the node id. But for debugging purposes, it's nicer if we can mix in
2880   // a recognizable object name if we have the information to do so.
2881   for (auto& kv : object_names_) {
2882     int node_id = kv.first;
2883     std::string internal_name =
2884         absl::StrCat(GetDefaultSymbolTableName(node_id), "__");
2885     // If the object has an exported name, we prefer that since it is probably
2886     // the most recognizable. Otherwise, we grab some non-exported name of the
2887     // object.
2888     if (exported_names_.find(node_id) != exported_names_.end()) {
2889       internal_name += exported_names_[node_id][0].str();
2890     } else {
2891       internal_name += object_names_[node_id][0];
2892     }
2893     pretty_symbol_table_name_[node_id] = SaveString(internal_name);
2894   }
2895 }
2896 
GetExportedNames(int node_id) const2897 llvm::ArrayRef<llvm::StringRef> ObjectNames::GetExportedNames(
2898     int node_id) const {
2899   auto it = exported_names_.find(node_id);
2900   if (it != exported_names_.end()) {
2901     return it->second;
2902   }
2903   return {};
2904 }
2905 
GetSymbolTableName(int node_id) const2906 llvm::StringRef ObjectNames::GetSymbolTableName(int node_id) const {
2907   auto it = pretty_symbol_table_name_.find(node_id);
2908   if (it != pretty_symbol_table_name_.end()) {
2909     return it->second;
2910   }
2911   return SaveString(GetDefaultSymbolTableName(node_id));
2912 }
2913 
GetDefaultSymbolTableName(int node_id) const2914 std::string ObjectNames::GetDefaultSymbolTableName(int node_id) const {
2915   return absl::StrCat("__sm_node", node_id);
2916 }
2917 
IsExported(const std::string & name)2918 bool ObjectNames::IsExported(const std::string& name) {
2919   if (names_to_export_.empty()) {
2920     return true;
2921   }
2922   return names_to_export_.find(name) != names_to_export_.end();
2923 }
2924 
RecursivelyVisitObjectGraph(int node_id)2925 void ObjectNames::RecursivelyVisitObjectGraph(int node_id) {
2926   const SavedObject& object = object_graph_.nodes(node_id);
2927 
2928   switch (object.kind_case()) {
2929     case SavedObject::kConstant:
2930     case SavedObject::kFunction:
2931     case SavedObject::kVariable: {
2932       object_names_[node_id].push_back(absl::StrJoin(path_segments_, "."));
2933       break;
2934     }
2935     default:
2936       break;
2937   }
2938 
2939   for (const auto& child_ref : object.children()) {
2940     bool on_stack = !on_stack_nodes_.insert(child_ref.node_id()).second;
2941     if (on_stack) {
2942       // This is a backedge. Don't traverse it.
2943       continue;
2944     }
2945 
2946     path_segments_.push_back(child_ref.local_name());
2947     RecursivelyVisitObjectGraph(child_ref.node_id());
2948     path_segments_.pop_back();
2949 
2950     on_stack_nodes_.erase(child_ref.node_id());
2951   }
2952 }
2953 
SaveString(const std::string & s) const2954 llvm::StringRef ObjectNames::SaveString(const std::string& s) const {
2955   return llvm::StringRef(*saved_strings_.insert(s).first);
2956 }
2957 
2958 // Extracts a TensorProto for a Const op from a GraphDef, given an op_name.
2959 // Returns nullptr on not found or other mismatch.
2960 // This returns a pointer to the actual node within the graph_def so as to
2961 // avoid expensive copies.
ExtractConstTensorFromGraph(const GraphDef & graph_def,const std::string & op_name)2962 const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def,
2963                                                const std::string& op_name) {
2964   const NodeDef* match_node = nullptr;
2965   for (const auto& node : graph_def.node()) {
2966     if (node.name() == op_name) {
2967       match_node = &node;
2968     }
2969   }
2970 
2971   if (!match_node) {
2972     return nullptr;
2973   }
2974 
2975   auto value_it = match_node->attr().find("value");
2976   if (value_it == match_node->attr().end()) {
2977     return nullptr;
2978   }
2979 
2980   if (!value_it->second.has_tensor()) {
2981     return nullptr;
2982   }
2983 
2984   return &value_it->second.tensor();
2985 }
2986 
2987 const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(const TrackableObjectGraph::TrackableObject & trackable_object,StringPiece name)2988 FindSerializedTensorInTrackable(
2989     const TrackableObjectGraph::TrackableObject& trackable_object,
2990     StringPiece name) {
2991   for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
2992     if (maybe_serialized_tensor.name() == name) {
2993       return &maybe_serialized_tensor;
2994     }
2995   }
2996   return nullptr;
2997 }
2998 
DiagnoseMultipleConcreteFunctions(const SavedObjectGraph & object_graph,const ObjectNames & object_names)2999 Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph,
3000                                          const ObjectNames& object_names) {
3001   for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
3002     const SavedObject& object = object_graph.nodes(node_id);
3003     if (object_names.GetExportedNames(node_id).empty()) {
3004       continue;
3005     }
3006     if (object.kind_case() == SavedObject::kFunction) {
3007       // We only allow a single input signature to each SavedFunction.
3008       // This assumption means we have a 1:1 correspondence between
3009       // tf.function <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef
3010       // This makes defining the ABI easier (or even well-defined at all).
3011       // TODO(silvasean): How to detect a function that doesn't have an
3012       // explicitly user-provided input signature, but happens to have been
3013       // traced exactly once?
3014       if (object.function().concrete_functions_size() != 1) {
3015         llvm::SmallVector<std::string, 4> names;
3016         for (llvm::StringRef s : object_names.GetExportedNames(node_id)) {
3017           names.push_back("'" + s.str() + "'");
3018         }
3019         return errors::InvalidArgument(
3020             "Exported function with exported name(s) ",
3021             absl::StrJoin(names, ", "),
3022             " with multiple concrete functions. Add "
3023             "@tf.function(input_signature=[...]) on this function, or use a "
3024             "narrower list of exported names that excludes this function.");
3025       }
3026     }
3027   }
3028   return OkStatus();
3029 }
3030 
3031 // Recursively traverses a StructuredValue, linearizing all the leaves.
3032 //
3033 // This currently only handles the subset of StructuredValue that is needed for
3034 // signatures.
3035 //
3036 // Given a StructuredValue with structure [{"x": leaf0}], the "index path"
3037 // needed to reach leaf0 is `[0, "x"]`, as it would be if you were operating on
3038 // a Python object (`obj[0]["x"] is leaf0`). Each leaf corresponds to a
3039 // linearized function argument or return on a FunctionDef, and hence to an
3040 // mlir::func::FuncOp argument / return.
3041 //
3042 // This must match the linearization that happens in `tf.nest.flatten`.
3043 // In particular, dict values should be linearized in sorted key order.
3044 //
3045 // The linearized index paths can be returned back to a structured
3046 // representation (e.g. to emit C structs matching a signature) with a simple
3047 // algorithm that recurses on each run of index paths with identical first
3048 // elements.
class StructuredValueLinearizer {
 public:
  // Traverses `value` immediately on construction; any traversal problem is
  // recorded and surfaced later via GetLeafIndexPaths().
  StructuredValueLinearizer(const StructuredValue& value,
                            mlir::MLIRContext* context);

  // Returns the list of index paths to each leaf of the StructuredValue,
  // in a linearized order matching `tf.nest.flatten`.
  //
  // If an error occurred during the linearization process, an error message
  // with `error_context` prepended will be included in the returned status.
  StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> GetLeafIndexPaths(
      llvm::StringRef error_context) const;

 private:
  // Main function that recursively traverses the StructuredValue.
  void RecursivelyFindLeaves(const StructuredValue& value);

  // Used to create the string/integer attributes that make up index paths.
  mlir::Builder builder_;
  // The current index path. We push/pop this during recursive traversal of the
  // StructuredValue.
  llvm::SmallVector<mlir::Attribute, 4> current_index_path_;
  // The list of leaf index paths we have discovered so far.
  llvm::SmallVector<mlir::ArrayAttr, 4> leaf_index_paths_;
  // If non-empty, an error message to report.
  std::string error_message_;
};
3075 
StructuredValueLinearizer(const StructuredValue & value,mlir::MLIRContext * context)3076 StructuredValueLinearizer::StructuredValueLinearizer(
3077     const StructuredValue& value, mlir::MLIRContext* context)
3078     : builder_(context) {
3079   RecursivelyFindLeaves(value);
3080 }
3081 
3082 StatusOr<llvm::ArrayRef<mlir::ArrayAttr>>
GetLeafIndexPaths(llvm::StringRef error_context) const3083 StructuredValueLinearizer::GetLeafIndexPaths(
3084     llvm::StringRef error_context) const {
3085   if (error_message_.empty()) {
3086     return llvm::makeArrayRef(leaf_index_paths_);
3087   }
3088   return errors::InvalidArgument(
3089       error_context.str(), error_message_,
3090       "This likely means that you have @tf.function "
3091       "on an exported function instead of "
3092       "@tf.function(input_signature=[...]). Consider annotating an "
3093       "input_signature or narrowing your set of "
3094       "exported names to not include this function.");
3095 }
3096 
// Recursive traversal of `value` that appends the index path of every
// TensorSpec leaf to leaf_index_paths_. current_index_path_ acts as the DFS
// path stack; unhandled kinds append a diagnostic to error_message_.
void StructuredValueLinearizer::RecursivelyFindLeaves(
    const StructuredValue& value) {
  switch (value.kind_case()) {
    case StructuredValue::kDictValue: {
      // Dict values must be linearized in sorted order of keys.
      const DictValue& dict = value.dict_value();
      using FieldTy = protobuf::MapPair<std::string, StructuredValue>;
      llvm::SmallVector<const FieldTy*, 4> fields;
      for (auto& field : dict.fields()) {
        fields.push_back(&field);
      }
      // Proto map iteration order is unspecified, so sort by key to match
      // tf.nest.flatten's sorted-key linearization.
      llvm::sort(fields, [](const FieldTy* a, const FieldTy* b) {
        return a->first < b->first;
      });
      for (auto& field : fields) {
        current_index_path_.push_back(builder_.getStringAttr(field->first));
        RecursivelyFindLeaves(field->second);
        current_index_path_.pop_back();
      }
      return;
    }
    case StructuredValue::kTupleValue: {
      // Tuple elements are indexed by position.
      const TupleValue& tuple = value.tuple_value();
      for (int i = 0, e = tuple.values_size(); i < e; i++) {
        current_index_path_.push_back(builder_.getI64IntegerAttr(i));
        RecursivelyFindLeaves(tuple.values(i));
        current_index_path_.pop_back();
      }
      return;
    }
    // We don't differentiate between tuples and lists.
    case StructuredValue::kListValue: {
      const ListValue& list = value.list_value();
      for (int i = 0, e = list.values_size(); i < e; i++) {
        current_index_path_.push_back(builder_.getI64IntegerAttr(i));
        RecursivelyFindLeaves(list.values(i));
        current_index_path_.pop_back();
      }
      return;
    }
    case StructuredValue::kTensorSpecValue: {
      // Base case: record the current path stack as the index path needed to
      // get to this leaf.
      leaf_index_paths_.push_back(builder_.getArrayAttr(current_index_path_));
      return;
    }
    case StructuredValue::kNoneValue: {
      // Base case: do nothing.
      // This arises, for example, as the top-level object of an output
      // signature when there are no return values.
      return;
    }
    default: {
      // Unhandled kind: record a human-readable description of where in the
      // structure it occurred. Note we fall out of the switch rather than
      // return, but there is nothing after the switch, so traversal simply
      // stops for this subtree.
      llvm::raw_string_ostream os(error_message_);
      // TODO(silvasean): Use an enumerant name string instead of a number.
      os << "Unhandled structured value kind " << value.kind_case()
         << " at index path: <value>";
      for (auto path_element : current_index_path_) {
        os << ".";
        if (auto integer = path_element.dyn_cast<mlir::IntegerAttr>()) {
          os << integer.getValue();
        } else {
          auto str = path_element.cast<mlir::StringAttr>();
          os << str.getValue();
        }
      }
      os << "\n";
    }
  }
}
3167 
3168 // For exported functions with bound inputs, rewrite the function
3169 // signature to match the requirements of tf_saved_model bound input args.
3170 //
3171 // The raw imported functions have `tensor<*x!tf_type.resource>` as the type for
3172 // mutable bound inputs and `tensor<...>` as the type for immutable
3173 // bound inputs. Here we canonicalize both of them into
3174 // `tensor<!tf_type.resource<tensor<...>>>`.
AdjustBoundInputArgTypes(mlir::ModuleOp module)3175 void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
3176   mlir::SymbolTable symbol_table(module);
3177   for (auto func : module.getOps<mlir::func::FuncOp>()) {
3178     if (!mlir::tf_saved_model::IsExported(func)) continue;
3179     mlir::OpBuilder builder(func.getBody());
3180     llvm::SmallVector<mlir::Type, 4> new_input_types;
3181     for (int i = 0, e = func.getNumArguments(); i < e; i++) {
3182       auto arg = func.getArgument(i);
3183       auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType<
3184           mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table);
3185       if (global_tensor) {
3186         auto old_type = arg.getType();
3187         auto new_type =
3188             mlir::tf_saved_model::GetBoundInputArgTypeFor(global_tensor);
3189         arg.setType(new_type);
3190         if (global_tensor.is_mutable()) {
3191           auto arg_with_original_type = builder.create<mlir::TF::CastOp>(
3192               global_tensor.getLoc(), old_type, arg,
3193               /*Truncate=*/builder.getBoolAttr(false));
3194           arg.replaceAllUsesWith(arg_with_original_type);
3195           // The RAUW replaces the arg with itself, so we need to set it back.
3196           arg_with_original_type.setOperand(arg);
3197         } else {
3198           auto arg_with_original_type =
3199               builder.create<mlir::TF::ReadVariableOp>(global_tensor.getLoc(),
3200                                                        old_type, arg);
3201           arg.replaceAllUsesWith(arg_with_original_type);
3202           // The RAUW replaces the arg with itself, so we need to set it back.
3203           arg_with_original_type.setOperand(arg);
3204         }
3205       }
3206       new_input_types.push_back(arg.getType());
3207     }
3208     func.setType(mlir::FunctionType::get(module.getContext(), new_input_types,
3209                                          func.getFunctionType().getResults()));
3210   }
3211 }
3212 
3213 // Marks the visibility of functions in the saved model module.
MarkSavedModelFunctionVisibility(mlir::ModuleOp module)3214 void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
3215   for (auto func : module.getOps<mlir::func::FuncOp>()) {
3216     auto visibility = mlir::tf_saved_model::IsExported(func)
3217                           ? mlir::func::FuncOp::Visibility::Public
3218                           : mlir::func::FuncOp::Visibility::Private;
3219     func.setVisibility(visibility);
3220   }
3221 }
3222 
3223 // Reorder the ops in the module to make testing easier and less dependent
3224 // on implementation details such as the order of functions in the
3225 // FunctionDefLibrary.
3226 //
3227 // The order this ensures is:
3228 // 1. GlobalTensorOp's
3229 // 2. FuncOps's.
3230 //
3231 // Within each of 1. and 2., ops are sorted by exported name (if
3232 // available, and only the first exported name is considered), followed by
3233 // non-exported ops.
void SortSavedModelModule(mlir::ModuleOp module) {
  // Pair each global tensor with its first exported name (or "" if none) so
  // exported tensors sort before non-exported ones, alphabetically.
  struct NamedGlobalTensor {
    llvm::StringRef name;
    GlobalTensorOp global_tensor;
  };
  llvm::SmallVector<NamedGlobalTensor, 8> named_global_tensors;
  for (auto global_tensor : module.getOps<GlobalTensorOp>()) {
    auto exported_names = mlir::tf_saved_model::GetExportedNames(global_tensor);
    // We use stable_sort, so duplicate empty names are fine here.
    named_global_tensors.push_back(
        {exported_names.empty() ? "" : exported_names.front(), global_tensor});
  }
  // Sort key (empty?, name) puts all named tensors first, in name order.
  llvm::stable_sort(named_global_tensors,
                    [](const NamedGlobalTensor& a, const NamedGlobalTensor& b) {
                      return std::make_tuple(a.name.empty(), a.name) <
                             std::make_tuple(b.name.empty(), b.name);
                    });

  // Split functions into exported (sorted by first exported name) and
  // private (sorted by symbol name).
  struct NamedFunc {
    llvm::StringRef name;
    mlir::func::FuncOp func;
  };
  llvm::SmallVector<NamedFunc, 8> named_funcs;
  llvm::SmallVector<mlir::func::FuncOp, 8> private_funcs;
  for (auto func : module.getOps<mlir::func::FuncOp>()) {
    auto exported_names = mlir::tf_saved_model::GetExportedNames(func);
    if (!exported_names.empty())
      named_funcs.push_back({exported_names.front(), func});
    else
      private_funcs.push_back(func);
  }
  llvm::stable_sort(named_funcs, [](const NamedFunc& a, const NamedFunc& b) {
    return a.name < b.name;
  });
  llvm::stable_sort(private_funcs,
                    [](mlir::func::FuncOp a, mlir::func::FuncOp b) {
                      return a.getName() < b.getName();
                    });

  // Assets are sorted by symbol name.
  struct NamedAsset {
    llvm::StringRef name;
    AssetOp asset;
  };
  llvm::SmallVector<NamedAsset, 4> assets;
  for (auto asset : module.getOps<AssetOp>()) {
    assets.push_back({asset.getName(), asset});
  }
  llvm::stable_sort(assets, [](const NamedAsset& a, const NamedAsset& b) {
    return a.name < b.name;
  });

  // Move onto the front of the module in reverse of the final desired order.
  // (Each moveBefore prepends, so iterating a category in reverse leaves that
  // category in forward sorted order.)
  for (auto func : llvm::reverse(private_funcs)) {
    func.getOperation()->moveBefore(&module.getBody()->front());
  }
  for (auto named_func : llvm::reverse(named_funcs)) {
    named_func.func.getOperation()->moveBefore(&module.getBody()->front());
  }
  for (auto named_global_tensor : llvm::reverse(named_global_tensors)) {
    named_global_tensor.global_tensor.getOperation()->moveBefore(
        &module.getBody()->front());
  }

  // NOTE(review): unlike the categories above, this iterates `assets` in
  // forward order while prepending, which leaves assets in *reverse* sorted
  // order at the front of the module — confirm whether that is intended.
  for (auto asset : assets) {
    asset.asset.getOperation()->moveBefore(&module.getBody()->front());
  }

  // Finally, the session initializer (at most one is moved) goes first.
  auto initializers = module.getOps<SessionInitializerOp>();
  if (!initializers.empty()) {
    (*initializers.begin())
        .getOperation()
        ->moveBefore(&module.getBody()->front());
  }
}
3308 
// Walks the SavedObjectGraph and materializes tf_saved_model constructs in
// the already-imported MLIR `module`:
//   - SavedFunction nodes: decorates the corresponding FuncOp (or a freshly
//     created wrapper around it) with exported names, per-argument/result
//     index paths, and bound-input symbol references.
//   - SavedVariable nodes: creates a mutable GlobalTensorOp whose value is
//     restored from the checkpoint via the bundle's variable reader.
//   - SavedConstant nodes: creates an immutable GlobalTensorOp from the
//     const node's TensorProto in the GraphDef.
// Afterwards adjusts bound-input argument types, tags the module with the
// "tf_saved_model.semantics" attribute, sorts the module into a canonical
// order, and marks function visibility.
//
// `tf_name_to_mlir_name` maps TF function names to the (possibly uniquified)
// MLIR symbol names produced during function import. Returns a non-OK status
// if signatures mismatch the underlying FunctionDefs or checkpoint data is
// missing.
Status CreateSavedModelIR(
    const ObjectNames& object_names, mlir::ModuleOp module,
    const SavedObjectGraph& object_graph,
    const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
    SavedModelV2Bundle* saved_model) {
  mlir::OpBuilder builder(module.getBodyRegion());
  mlir::SymbolTable symbol_table(module);

  // Create a side data-structure, indexed by the object_graph node_id to
  // a TrackableObject that is restorable.
  absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*>
      restored_objects;
  TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore(
      [&](int saved_node_id,
          const TrackableObjectGraph::TrackableObject& trackable_object) {
        restored_objects.insert(
            std::make_pair(saved_node_id, &trackable_object));
        return OkStatus();
      }));

  for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
    const SavedObject& object = object_graph.nodes(node_id);
    // For correctness, we cannot import functions that don't have exported
    // names, since they don't necessarily have a well-defined ABI (diagnosed
    // earlier).
    //
    // For variables/constants, pruning them is purely an optimization,
    // and more complicated since it requires use-def analysis of which
    // functions use which variables/constants, so we don't do anything
    // special for them here as part of our initial IR construction.
    if (object.kind_case() == SavedObject::kFunction) {
      if (object_names.GetExportedNames(node_id).empty()) {
        continue;
      }
      std::string error_context =
          "While importing SavedModel function '" +
          object_names.GetExportedNames(node_id)[0].str() + "': ";
      const SavedFunction& function = object.function();
      auto orig_func = symbol_table.lookup<mlir::func::FuncOp>(
          tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
      mlir::func::FuncOp func = orig_func;
      // If there are potentially references to this func from within the
      // module, create a wrapper around it and decorate the wrapper with the
      // tf_saved_model attributes instead.
      if (!mlir::SymbolTable::symbolKnownUseEmpty(orig_func.getSymNameAttr(),
                                                  &module.getBodyRegion())) {
        func = orig_func.cloneWithoutRegions();
        module.insert(module.getBody()->begin(), func);
        func.addEntryBlock();
        func.setName(builder.getStringAttr("__sm_exported_" +
                                           orig_func.getName().str()));
        llvm::SmallVector<mlir::Value, 4> args_as_values;
        for (auto block_argument : func.getArguments()) {
          args_as_values.push_back(block_argument);
        }
        // The wrapper's body simply forwards its arguments to the original
        // function through a stateful call and returns the call's results.
        mlir::OpBuilder body_builder(&func.getBody());
        auto call = body_builder.create<mlir::TF::StatefulPartitionedCallOp>(
            func.getLoc(), orig_func.getFunctionType().getResults(),
            args_as_values,
            mlir::SymbolRefAttr::get(builder.getContext(), orig_func.getName()),
            /*config=*/builder.getStringAttr(""),
            /*config_proto=*/builder.getStringAttr(""),
            /*executor_type=*/builder.getStringAttr(""));
        body_builder.create<mlir::func::ReturnOp>(func.getLoc(),
                                                  call.getResults());
      }
      func->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
      const SavedConcreteFunction& concrete_function =
          object_graph.concrete_functions().at(function.concrete_functions(0));

      // We do not handle the other element of this tuple, which corresponds to
      // Python kwonlyargs, since currently TensorFlow prohibits this in
      // combination with input_signature:
      // https://github.com/tensorflow/tensorflow/blob/8cb8627abb5ef83a6fba34f8fd0e4ee430562eb1/tensorflow/python/eager/function.py#L2027-L2030
      // Our SavedModel import requires input_signature on the tf.function, so
      // we never need to handle the kwonlyargs.
      auto positional_arg_structure =
          concrete_function.canonicalized_input_signature()
              .tuple_value()
              .values(0);
      StructuredValueLinearizer input_linearizer(positional_arg_structure,
                                                 builder.getContext());

      // Bound inputs occupy the trailing arguments of the FunctionDef, so
      // the user-visible positional arguments end at `bound_input_base`.
      int bound_input_base =
          func.getNumArguments() - concrete_function.bound_inputs_size();
      TF_ASSIGN_OR_RETURN(auto input_index_paths,
                          input_linearizer.GetLeafIndexPaths(
                              error_context + "in input signature: "));
      const int input_index_paths_size = input_index_paths.size();
      if (bound_input_base != input_index_paths_size) {
        return errors::InvalidArgument(
            error_context,
            "Argument mismatch between concrete function input signature "
            "vs underlying FunctionDef for concrete function '",
            function.concrete_functions(0), "' (", input_index_paths.size(),
            " vs ", bound_input_base, ")");
      }
      for (auto index_path : llvm::enumerate(input_index_paths)) {
        func.setArgAttr(index_path.index(), "tf_saved_model.index_path",
                        index_path.value());
      }

      // Attach a symbol reference for each bound input (captured
      // variable/constant/asset) to its trailing argument.
      for (auto& bound_input :
           llvm::enumerate(concrete_function.bound_inputs())) {
        int arg_index = bound_input_base + bound_input.index();
        auto symbol_ref = mlir::SymbolRefAttr::get(
            builder.getContext(),
            object_names.GetSymbolTableName(bound_input.value()));
        func.setArgAttr(arg_index, "tf_saved_model.bound_input", symbol_ref);
      }

      StructuredValueLinearizer output_linearizer(
          concrete_function.output_signature(), builder.getContext());
      TF_ASSIGN_OR_RETURN(auto output_index_paths,
                          output_linearizer.GetLeafIndexPaths(
                              error_context + "in output signature: "));
      if (func.getNumResults() != output_index_paths.size()) {
        return errors::InvalidArgument(
            error_context,
            "Result mismatch between concrete function output signature "
            "vs underlying FunctionDef for concrete function '",
            function.concrete_functions(0), "' (", output_index_paths.size(),
            " vs ", func.getNumResults(), ")");
      }
      for (auto index_path : llvm::enumerate(output_index_paths)) {
        func.setResultAttr(index_path.index(), "tf_saved_model.index_path",
                           index_path.value());
      }
    } else if (object.kind_case() == SavedObject::kVariable) {
      const SavedVariable& variable = object.variable();
      // Find the trackable in the side data structure.
      auto variable_trackable_it = restored_objects.find(node_id);
      if (variable_trackable_it == restored_objects.end()) {
        return errors::FailedPrecondition("Could not restore saved variable: ",
                                          variable.name());
      }
      const auto* serialized_tensor_attr = FindSerializedTensorInTrackable(
          *variable_trackable_it->second, "VARIABLE_VALUE");
      if (!serialized_tensor_attr) {
        return errors::FailedPrecondition(
            "Could not find serialized tensor for saved variable: ",
            variable.name());
      }
      const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key();

      // Load it from the reader.
      Tensor value;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          saved_model->variable_reader()->Lookup(checkpoint_key, &value),
          "Could not read checkpoint key from variables bundle: ",
          checkpoint_key);
      TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
      // A variable can have a partially known type, such as tensor<?x27x?xf32>,
      // even if the initializer is a specific static shape.
      TF_ASSIGN_OR_RETURN(
          auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(),
                                             &builder));
      auto op = builder.create<GlobalTensorOp>(
          builder.getUnknownLoc(),
          builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
          value_attr,
          /*type=*/mlir::TypeAttr::get(type),
          /*is_mutable=*/builder.getUnitAttr());
      op->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
    } else if (object.kind_case() == SavedObject::kConstant) {
      const SavedConstant& constant = object.constant();
      const TensorProto* value = ExtractConstTensorFromGraph(
          saved_model->meta_graph_def().graph_def(), constant.operation());
      if (!value) {
        return errors::FailedPrecondition(
            "Unable to find const node referenced in object graph: ",
            constant.operation());
      }
      TF_ASSIGN_OR_RETURN(auto value_attr,
                          ConvertTensorProto(*value, &builder));
      // Constants become immutable global tensors: no is_mutable attribute,
      // and the type is taken directly from the converted value.
      auto op = builder.create<GlobalTensorOp>(
          builder.getUnknownLoc(),
          builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
          value_attr,
          /*type=*/mlir::TypeAttr::get(value_attr.getType()),
          /*is_mutable=*/nullptr);
      op->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
    }
  }
  AdjustBoundInputArgTypes(module);
  module->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
  SortSavedModelModule(module);
  MarkSavedModelFunctionVisibility(module);
  return OkStatus();
}
3505 
// Converts a TF2 (object-graph based) SavedModel bundle into an MLIR module
// in the tf_saved_model dialect: imports the GraphDef and its function
// library, prunes a few always-present internal functions, and then attaches
// SavedModel semantics via CreateSavedModelIR. Fails with InvalidArgument if
// the MetaGraphDef has no object graph (i.e. it is a TF1 SavedModel).
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelObjectGraphImporter::Convert(
    SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
    mlir::MLIRContext* context, bool add_default_attributes,
    // TODO(b/200093974): Remove post triage.
    bool unconditionally_use_set_output_shapes) {
  LoadImporterDialects(*context);
  // Fall back to an empty GraphDebugInfo when the bundle carries none.
  GraphDebugInfo dummy_debug_info;
  const GraphDebugInfo& debug_info =
      saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;

  GraphImportConfig specs;
  specs.prune_unused_nodes = true;
  specs.unconditionally_use_set_output_shapes =
      unconditionally_use_set_output_shapes;
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;

  const auto& graphdef = saved_model->meta_graph_def().graph_def();
  PopulateTfVersions(module.get(), graphdef.versions());

  GraphConstructorOptions options;
  options.allow_internal_ops = true;
  options.add_default_attributes = add_default_attributes;
  Graph graph(OpRegistry::Global());

  // Work on a copy of the GraphDef so preprocessing doesn't mutate the
  // bundle's MetaGraphDef.
  GraphDef preprocessed_graphdef(graphdef);
  if (add_default_attributes) {
    TF_RETURN_IF_ERROR(PreprocessGraphDef(nullptr, &preprocessed_graphdef));
  }

  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));

  NameUniquifier function_name_uniquifier(graph.flib_def());
  SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs,
                                         module.get(), &tf_name_to_mlir_name,
                                         &function_name_uniquifier);

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));

  // Convert every function in the library up front, then any functions whose
  // conversion was deferred during that pass.
  auto fn_names = graph.flib_def().ListFunctionNames();
  for (const auto& fn_name : fn_names) {
    TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
  }
  TF_RETURN_IF_ERROR(importer.ConvertDeferredFunctions());

  if (!saved_model->meta_graph_def().has_object_graph_def()) {
    return errors::InvalidArgument(
        "SavedModel does not have an object graph. Please use TF2.");
  }
  auto& object_graph = saved_model->meta_graph_def().object_graph_def();
  ObjectNames object_names(object_graph, exported_names);

  // Clean up a couple func's that always seem to be present when importing a
  // SavedModel. This is not strictly needed, as there is a separate pass that
  // will clean them up, but this makes staring at the raw IR of minimal
  // examples quite a bit nicer.
  for (auto func :
       llvm::make_early_inc_range(module->getOps<mlir::func::FuncOp>())) {
    if (func.getName().startswith("__inference__traced_save_") ||
        func.getName().startswith("__inference__traced_restore_") ||
        func.getName().startswith("__inference_signature_wrapper_")) {
      func.erase();
    }
  }

  // Diagnose SavedFunction's with multiple input signatures.
  TF_RETURN_IF_ERROR(
      DiagnoseMultipleConcreteFunctions(object_graph, object_names));

  // Construct the SavedModel IR.
  TF_RETURN_IF_ERROR(CreateSavedModelIR(object_names, module.get(),
                                        object_graph, tf_name_to_mlir_name,
                                        saved_model));
  // NOTE: verification only runs in debug builds (assert compiles away in
  // NDEBUG builds).
  assert(mlir::succeeded(mlir::verify(module.get())));

  return module;
}
3586 
// A SavedModelMLIRImportInput that eagerly converts the MetaGraphDef's whole
// GraphDef into a single tensorflow::Graph at creation time and returns that
// same graph for every GetSubGraph() request (pruning to the requested feeds
// and fetches is left to the downstream importer).
class SimpleSavedModelMLIRImportInput : public SavedModelMLIRImportInput {
 public:
  // Builds the Graph from `meta_graph_def`. When
  // `import_options.upgrade_legacy` is set, also generates shared names for
  // resources and functionalizes legacy control flow in the graph.
  static StatusOr<SimpleSavedModelMLIRImportInput> Create(
      const MLIRImportOptions& import_options,
      const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info) {
    DCHECK(meta_graph_def);
    GraphDef graph_def = meta_graph_def->graph_def();
    auto graph = std::make_unique<Graph>(OpRegistry::Global());

    if (import_options.upgrade_legacy) {
      TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
          graph_def, graph->flib_def().default_registry()));
    }

    GraphConstructorOptions graph_ctor_options;
    graph_ctor_options.allow_internal_ops = true;
    graph_ctor_options.add_default_attributes = true;
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph(graph_ctor_options, graph_def, graph.get()));

    if (import_options.upgrade_legacy) {
      // TODO(jpienaar): Remove need to const_cast.
      TF_RETURN_IF_ERROR(UpgradeLegacyGraph(
          graph.get(),
          const_cast<FunctionLibraryDefinition*>(&graph->flib_def()),
          /*restrict_functionalization_to_compiled_nodes=*/false));
    }

    return SimpleSavedModelMLIRImportInput(meta_graph_def, debug_info,
                                           std::move(graph));
  }

  SimpleSavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
                                  const GraphDebugInfo& debug_info,
                                  std::unique_ptr<Graph> graph)
      : SavedModelMLIRImportInput(meta_graph_def, debug_info),
        graph_(std::move(graph)) {}

  // Always returns the whole pre-built graph; the DCHECKs validate (debug
  // builds only) that `name` and the feeds/fetches in `specs` are plausible.
  StatusOr<const Graph*> GetSubGraph(absl::string_view name,
                                     GraphImportConfig& specs) override {
    DCHECK(CheckGraphNameValidity(name));
    DCHECK(CheckGraphContainsFeedsAndFetches(specs));
    return graph_.get();
  }

 private:
  // Returns true if every node named by the feeds, fetches and control
  // outputs in `specs` exists in `graph_`.
  bool CheckGraphContainsFeedsAndFetches(const GraphImportConfig& specs) const {
    absl::flat_hash_set<std::string> feed_fetch_nodes;
    for (const auto& iter : specs.inputs) {
      TensorId tensor_id = ParseTensorName(iter.first);
      feed_fetch_nodes.insert(std::string(tensor_id.node()));
    }
    for (const auto& output : llvm::concat<const std::string>(
             specs.outputs, specs.control_outputs)) {
      TensorId tensor_id = ParseTensorName(output);
      feed_fetch_nodes.insert(std::string(tensor_id.node()));
    }

    // Erase every node that is present in the graph; anything left over in
    // the set is a requested node the graph does not contain.
    for (Node* node : graph_->op_nodes()) {
      feed_fetch_nodes.erase(node->name());
    }

    return feed_fetch_nodes.empty();
  }

  // Returns true if `name` is a known signature name, the saver's restore op
  // name, or the init op name of the MetaGraphDef.
  bool CheckGraphNameValidity(absl::string_view name) const {
    // If it is one of the signature name, it is valid.
    const auto& signature_defs = meta_graph_def().signature_def();
    if (signature_defs.contains(std::string(name))) return true;

    // If it is the restore graph name, it is valid.
    if (meta_graph_def().has_saver_def() &&
        meta_graph_def().saver_def().restore_op_name() == name)
      return true;

    // If it is the init graph name, it is valid.
    std::string init_op_name;
    if (internal::GetInitOp("", meta_graph_def(), &init_op_name).ok()) {
      if (init_op_name == name) return true;
    }

    return false;
  }

  // `graph_` contains the entire graph in the original MetaGraphDef.
  std::unique_ptr<Graph> graph_;
};
3674 
GetOriginalTfFuncNamesFromGraphDef(const GraphDef & graph_def)3675 static absl::flat_hash_set<std::string> GetOriginalTfFuncNamesFromGraphDef(
3676     const GraphDef& graph_def) {
3677   absl::flat_hash_set<std::string> original_func_tf_names;
3678   for (const auto& function : graph_def.library().function()) {
3679     original_func_tf_names.insert(function.signature().name());
3680   }
3681   return original_func_tf_names;
3682 }
3683 
// A helper class to import a TensorFlow model expressed in SavedModel V1 into
// an MLIR Module in SavedModel dialect.
//
// TODO(b/179683149): Rename this class to avoid confusion with TFLite.
class SavedModelSignatureDefImporterLite {
 public:
  // Main entry point: converts all functions (specified by SignatureDefs) in
  // the given meta graph to an MLIR Module.
  //
  // `import_restore` is introduced to control whether restore graph
  // is imported in eg. SavedModelSignatureDefImporter. Ideally, we don't need
  // this option to control this as restore graph should be always imported.
  // However, right now, SavedModelSignatureDefImporter cannot handle restore
  // graph correctly.
  //
  // TODO(chky): Remove import_restore once the restore graph is correctly
  // handled in SavedModelSignatureDefImporter.
  static StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Convert(
      SavedModelMLIRImportInput& input,
      std::optional<absl::Span<const std::string>> exported_names,
      mlir::MLIRContext* context, bool import_restore = true,
      bool unconditionally_use_set_output_shapes = false) {
    SavedModelSignatureDefImporterLite importer(
        input, exported_names, context, import_restore,
        unconditionally_use_set_output_shapes);
    return importer.ConvertSignatures();
  }

 private:
  SavedModelSignatureDefImporterLite(
      SavedModelMLIRImportInput& input,
      std::optional<absl::Span<const std::string>> exported_names,
      mlir::MLIRContext* context, bool import_restore,
      bool unconditionally_use_set_output_shapes)
      : input_(input),
        original_func_tf_names_(GetOriginalTfFuncNamesFromGraphDef(
            input.meta_graph_def().graph_def())),
        exported_names_(exported_names),
        module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))),
        symbol_table_(module_.get()),
        import_restore_(import_restore),
        unconditionally_use_set_output_shapes_(
            unconditionally_use_set_output_shapes) {}

  // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
  // for each signature.
  StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSignatures();
  // Converts a single SignatureDef (keyed by `sig_def_key`) to an MLIR
  // function in `module_`.
  Status ConvertSignature(const std::string& sig_def_key,
                          const SignatureDef& signature_def);

  // Pairs the name of an asset's input tensor in the initializer graph with
  // the AssetOp created for it.
  struct AssetInfo {
    std::string tensor_name;
    mlir::tf_saved_model::AssetOp op;
  };
  // Creates an AssetOp in `module_` for every AssetFileDef of the model.
  StatusOr<std::vector<AssetInfo>> ConvertAssets();
  // Converts the initialization graph in the SavedModel to an MLIR function.
  Status ConvertInitializer(const std::string& target_node_name,
                            const std::vector<AssetInfo>& assets);

  // Converts a graph with feeds and fetches to an MLIR function.
  StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertGraph(
      const std::string& name,
      const std::vector<std::pair<std::string, TensorInfo>>& inputs,
      const std::vector<std::pair<std::string, TensorInfo>>& outputs,
      const std::vector<std::string> control_outputs,
      std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);

  // Moves the functions in `sub_module` to `module_` and skips the duplicate
  // functions.
  Status MoveConvertedFunctionsToModule(
      absl::string_view name, mlir::ModuleOp sub_module,
      const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);

  StatusOr<GraphImportConfig::InputArrays> ParseInputArrays(
      llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);

 private:
  // Source of the MetaGraphDef, debug info and (sub-)graphs to import.
  SavedModelMLIRImportInput& input_;
  // TF function names present in the original GraphDef library; used by
  // MoveConvertedFunctionsToModule to recognize converted library functions.
  absl::flat_hash_set<std::string> original_func_tf_names_;
  // Optional set of names to export; presumably restricts which signatures
  // get converted in ConvertSignatures() (definition not shown here).
  std::optional<absl::Span<const std::string>> exported_names_;
  // The final MLIR module being assembled.
  mlir::OwningOpRef<mlir::ModuleOp> module_;
  // Guards symbol_table_, which is mutated when merging converted
  // sub-modules.
  absl::Mutex symbol_table_mu_;
  mlir::SymbolTable symbol_table_ ABSL_GUARDED_BY(symbol_table_mu_);
  bool import_restore_ = true;
  bool unconditionally_use_set_output_shapes_ = false;
};
3770 
3771 StatusOr<std::vector<SavedModelSignatureDefImporterLite::AssetInfo>>
ConvertAssets()3772 SavedModelSignatureDefImporterLite::ConvertAssets() {
3773   std::vector<AssetFileDef> asset_file_defs;
3774   TF_RETURN_IF_ERROR(
3775       internal::GetAssetFileDefs(input_.meta_graph_def(), &asset_file_defs));
3776 
3777   std::vector<AssetInfo> results;
3778   results.reserve(asset_file_defs.size());
3779 
3780   mlir::OpBuilder builder(module_->getBodyRegion());
3781   unsigned i = 0;  // Use to generate unique sym_name(s) for duplicate assets.
3782   for (const auto& asset : asset_file_defs) {
3783     auto asset_op = builder.create<mlir::tf_saved_model::AssetOp>(
3784         module_->getLoc(),
3785         /*sym_name=*/
3786         builder.getStringAttr(
3787             absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())),
3788         /*filename=*/
3789         builder.getStringAttr(
3790             io::JoinPath(kSavedModelAssetsDirectory, asset.filename())));
3791 
3792     results.push_back({asset.tensor_info().name(), asset_op});
3793   }
3794 
3795   return results;
3796 }
3797 
// Clones the functions of `sub_module` (produced while converting one
// signature, identified by `name`) into the final module `module_`. Private
// functions that did not originate from the GraphDef library are first
// renamed to "<name>/<original sym>" so they cannot collide with private
// functions generated for other signatures; functions already present in
// `module_` are skipped by the symbol-table insert.
Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
    absl::string_view name, mlir::ModuleOp sub_module,
    const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
  mlir::Builder builder(sub_module.getContext());
  mlir::SymbolTable sub_module_symbol_table(sub_module);

  // Functions originally from graphdef library might have a different name
  // after conversion, we build the set of the converted names
  absl::flat_hash_set<std::string> original_func_mlir_names;
  for (const auto& kv : tf_name_to_mlir_name) {
    if (original_func_tf_names_.contains(kv.first))
      original_func_mlir_names.insert(kv.second);
  }

  // Prefix private functions with the unique signature name, so that it cannot
  // collide with private functions used in the other signatures.
  for (auto func : sub_module.getOps<mlir::func::FuncOp>()) {
    if (mlir::tf_saved_model::IsExported(func)) continue;

    // Skip the original functions from graphdef library
    if (original_func_mlir_names.count(func.getSymName().str())) continue;

    std::string new_sym_name = absl::StrCat(name, "/", func.getSymName().str());
    mlir::StringAttr new_sym_name_attr = builder.getStringAttr(new_sym_name);
    // Rewrite all uses of the old symbol within the sub-module before
    // renaming the function itself.
    if (mlir::failed(sub_module_symbol_table.replaceAllSymbolUses(
            func, new_sym_name_attr, sub_module)))
      return tensorflow::errors::InvalidArgument(absl::StrCat(
          "SavedModelSignatureDefImporterLite: failed to assign a unique "
          "name to the private function used in a signature: ",
          func.getSymName().str()));

    mlir::SymbolTable::setSymbolName(func, new_sym_name);
  }

  // Copy all functions used by this signature to the final MLIR module.
  for (auto func : sub_module.getOps<mlir::func::FuncOp>()) {
    absl::MutexLock l(&symbol_table_mu_);
    // The insert here is a NO-OP if the function already exists.
    symbol_table_.insert(func.clone());
  }

  return OkStatus();
}
3841 
// Imports the initialization sub-graph rooted at `target_node_name` as an
// MLIR function: feeds each asset's tensor as a DT_STRING scalar input,
// binds those arguments to the corresponding AssetOp's, tags the function
// with the reserved session-initializer exported name, and merges the result
// into `module_`.
Status SavedModelSignatureDefImporterLite::ConvertInitializer(
    const std::string& target_node_name, const std::vector<AssetInfo>& assets) {
  // Each asset contributes one DT_STRING scalar feed (its tensor name).
  std::vector<std::pair<std::string, TensorInfo>> inputs;
  inputs.reserve(assets.size());
  for (const auto& asset : assets) {
    TensorInfo tensor_info;
    tensor_info.set_name(asset.tensor_name);
    tensor_info.set_dtype(DT_STRING);
    tensor_info.mutable_tensor_shape();
    inputs.push_back({asset.tensor_name, tensor_info});
  }

  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  // No fetches: the target node itself is the single control output.
  TF_ASSIGN_OR_RETURN(auto sub_module,
                      ConvertGraph(target_node_name, inputs, {},
                                   {target_node_name}, tf_name_to_mlir_name));

  mlir::SymbolTable sub_symbol_table(*sub_module);

  auto init_func_op =
      sub_symbol_table.lookup<mlir::func::FuncOp>(target_node_name);
  // The initializer is not a user-callable entry point.
  init_func_op->removeAttr("tf.entry_function");

  mlir::OpBuilder builder(module_->getBodyRegion());

  // Bind asset inputs to asset ops.
  DCHECK_EQ(init_func_op.getNumArguments(), assets.size());
  for (const auto& iter : llvm::enumerate(assets)) {
    auto asset_op = iter.value().op;
    init_func_op.setArgAttr(
        iter.index(), "tf_saved_model.bound_input",
        mlir::SymbolRefAttr::get(builder.getContext(), asset_op.getName()));
  }

  // Set the exported name of init function to an reserved name for
  // tf_saved_model.
  init_func_op->setAttr(
      "tf_saved_model.exported_names",
      builder.getStrArrayAttr({absl::StrCat(
          "__tf_saved_model_session_initializer_", target_node_name)}));

  // Move the converted functions to top level MLIR module.
  return MoveConvertedFunctionsToModule(target_node_name, *sub_module,
                                        tf_name_to_mlir_name);
}
3887 
// Imports the sub-graph identified by `name` — with the given feeds
// (`inputs`), fetches (`outputs`) and `control_outputs` — into its own MLIR
// module. `tf_name_to_mlir_name` is filled with the TF-to-MLIR function name
// mapping produced during the import.
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelSignatureDefImporterLite::ConvertGraph(
    const std::string& name,
    const std::vector<std::pair<std::string, TensorInfo>>& inputs,
    const std::vector<std::pair<std::string, TensorInfo>>& outputs,
    const std::vector<std::string> control_outputs,
    std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
  VLOG(1) << "Importing Signature: " << name;

  // Build the import config: prune to the requested feeds/fetches and skip
  // shape inference at this stage.
  GraphImportConfig specs;
  specs.graph_func_name = name;
  specs.prune_unused_nodes = true;
  TF_ASSIGN_OR_RETURN(specs.inputs, ParseInputArrays(inputs));
  for (auto& output : outputs) specs.outputs.push_back(output.second.name());
  specs.control_outputs = control_outputs;
  specs.enable_shape_inference = false;
  specs.unconditionally_use_set_output_shapes =
      unconditionally_use_set_output_shapes_;

  TF_ASSIGN_OR_RETURN(const auto* subgraph, input_.GetSubGraph(name, specs));

  // Convert sub-graph to MLIR module.
  return GraphDefImporter::Convert(module_->getContext(), *subgraph,
                                   input_.debug_info(), subgraph->flib_def(),
                                   specs, tf_name_to_mlir_name);
}
3914 
ConvertSignature(const std::string & sig_def_key,const SignatureDef & signature_def)3915 Status SavedModelSignatureDefImporterLite::ConvertSignature(
3916     const std::string& sig_def_key, const SignatureDef& signature_def) {
3917   // Create local vectors for the input and output and sort them to be
3918   // deterministic. We don't want anyone to really depend on the order, client
3919   // should lookup argument/result mapping by attribute name.
3920   // To avoid accidentally depending on the order we use an unintuitive sorting.
3921   std::vector<std::pair<std::string, TensorInfo>> inputs(
3922       signature_def.inputs().begin(), signature_def.inputs().end());
3923   llvm::sort(inputs, [](const auto& lhs, const auto& rhs) {
3924     return tensorflow::Fingerprint64(lhs.first) <
3925            tensorflow::Fingerprint64(rhs.first);
3926   });
3927   std::vector<std::pair<std::string, TensorInfo>> outputs(
3928       signature_def.outputs().begin(), signature_def.outputs().end());
3929   llvm::sort(outputs, [](const auto& lhs, const auto& rhs) {
3930     return tensorflow::Fingerprint64(lhs.first) <
3931            tensorflow::Fingerprint64(rhs.first);
3932   });
3933 
3934   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
3935 
3936   // Convert sub-graph to MLIR module.
3937   TF_ASSIGN_OR_RETURN(
3938       auto sub_module,
3939       ConvertGraph(sig_def_key, inputs, outputs, {}, tf_name_to_mlir_name));
3940   mlir::OpBuilder builder(sub_module->getBodyRegion());
3941 
3942   // Find the FuncOp which corresponds to current SignatureDef.
3943   mlir::SymbolTable sub_symbol_table(*sub_module);
3944   auto func_op = sub_symbol_table.lookup<mlir::func::FuncOp>(sig_def_key);
3945   TF_RET_CHECK(func_op)
3946       << "Graphdef importer should have created a function named "
3947       << sig_def_key << ".";
3948 
3949   // Use unique SignatureDef key as exported name.
3950   func_op->setAttr("tf_saved_model.exported_names",
3951                    builder.getStrArrayAttr({sig_def_key}));
3952 
3953   // Transfer input and output parameter names to index_path attributes.
3954   for (auto input_and_idx : llvm::enumerate(inputs)) {
3955     func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path",
3956                        builder.getStrArrayAttr({input_and_idx.value().first}));
3957   }
3958   for (auto output_and_idx : llvm::enumerate(outputs)) {
3959     func_op.setResultAttr(
3960         output_and_idx.index(), "tf_saved_model.index_path",
3961         builder.getStrArrayAttr({output_and_idx.value().first}));
3962   }
3963 
3964   // Move the converted functions to top level MLIR module.
3965   return MoveConvertedFunctionsToModule(sig_def_key, *sub_module,
3966                                         tf_name_to_mlir_name);
3967 }
3968 
3969 StatusOr<GraphImportConfig::InputArrays>
ParseInputArrays(llvm::ArrayRef<std::pair<std::string,TensorInfo>> inputs)3970 SavedModelSignatureDefImporterLite::ParseInputArrays(
3971     llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs) {
3972   GraphImportConfig::InputArrays results;
3973   for (const auto& iter : inputs) {
3974     const auto& tensor_info = iter.second;
3975 
3976     // TODO(b/184675681): Support other encoding cases.
3977     //
3978     // TODO(b/184679394): Add unit test for this check.
3979     TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
3980         << "Only dense tensor is supported, but got encoding case "
3981         << tensor_info.encoding_case();
3982 
3983     VLOG(1) << "Importing Signature Input: input_name = " << iter.first
3984             << ", tensor_info = " << tensor_info.DebugString();
3985 
3986     ArrayInfo array_info;
3987     array_info.imported_dtype = tensor_info.dtype();
3988 
3989     if (tensor_info.has_tensor_shape()) {
3990       array_info.shape = tensor_info.tensor_shape();
3991     } else {
3992       // If there is no tensor shape in the tensor info, conservatively set
3993       // unknown_rank to true.
3994       array_info.shape.set_unknown_rank(true);
3995     }
3996 
3997     results.insert(std::pair<std::string, ArrayInfo>(tensor_info.name(),
3998                                                      std::move(array_info)));
3999   }
4000   return results;
4001 }
4002 
// Converts all requested SignatureDefs (plus the restore and init graphs, when
// present) into functions of the top-level tf_saved_model MLIR module, then
// finalizes the module's saved-model attributes and returns it.
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelSignatureDefImporterLite::ConvertSignatures() {
  LoadImporterDialects(*module_->getContext());

  const auto& signatures = input_.meta_graph_def().signature_def();
  PopulateTfVersions(module_.get(),
                     input_.meta_graph_def().graph_def().versions());

  // If `exported_names_` is set, only the listed signatures are imported;
  // otherwise every signature in the meta graph is imported.
  llvm::DenseSet<llvm::StringRef> exported_name_set;
  bool import_all_signatures = !exported_names_.has_value();
  if (exported_names_.has_value()) {
    exported_name_set.insert(exported_names_->begin(), exported_names_->end());
  }

  absl::Mutex error_status_mu;  // Needed since `error_status` is non-atomic.
  tensorflow::Status error_status;
  {
    // Start a threadpool to convert signatures, since signature conversion can
    // be time consuming especially for large models. Threadpool destructor
    // blocks until all work is done.
    thread::ThreadPool thread_pool(Env::Default(), "ConvertSignatures",
                                   kNumThreadToConvertSignatures);
    for (const auto& key_and_signature_def : signatures) {
      const std::string& sig_def_key = key_and_signature_def.first;
      const SignatureDef& signature_def = key_and_signature_def.second;

      // It is safe to skip "__saved_model_init_op" since it is an internal
      // signature that is not user-accessible. This signature will be handled
      // in ConvertInitializer().
      if (sig_def_key == "__saved_model_init_op") {
        continue;
      }
      if (!import_all_signatures && exported_name_set.count(sig_def_key) == 0) {
        continue;
      }

      thread_pool.Schedule([&]() {
        auto status = ConvertSignature(sig_def_key, signature_def);
        if (!status.ok()) {
          // If several conversions fail, the last recorded error wins; any
          // failure aborts the import below.
          absl::MutexLock l(&error_status_mu);
          error_status = std::move(status);
        }
      });
    }
  }
  // Propagate any failure recorded by the worker threads (the pool has been
  // joined at this point).
  TF_RETURN_IF_ERROR(error_status);

  TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets());

  mlir::OpBuilder builder(module_->getBodyRegion());
  llvm::SmallVector<mlir::Attribute, 2> init_sym_refs;

  if (import_restore_ && input_.meta_graph_def().has_saver_def()) {
    std::vector<AssetInfo> variable_and_assets;

    // Create an AssetOp for the variable checkpoint files. The relative
    // filename is used here.
    auto variable_filename_op = builder.create<mlir::tf_saved_model::AssetOp>(
        module_->getLoc(),
        /*sym_name=*/
        builder.getStringAttr("__tf_saved_model_variables"),
        /*filename=*/
        builder.getStringAttr(io::JoinPath(kSavedModelVariablesDirectory,
                                           kSavedModelVariablesFilename)));
    variable_and_assets.push_back(
        {input_.meta_graph_def().saver_def().filename_tensor_name(),
         variable_filename_op});
    variable_and_assets.insert(variable_and_assets.end(), assets.begin(),
                               assets.end());

    // Import the saver's restore graph as an initializer bound to the
    // checkpoint-filename asset plus the regular assets.
    const auto& restore_op_name =
        input_.meta_graph_def().saver_def().restore_op_name();
    TF_RETURN_IF_ERROR(
        ConvertInitializer(restore_op_name, variable_and_assets));
    init_sym_refs.push_back(
        mlir::SymbolRefAttr::get(builder.getContext(), restore_op_name));
  }

  // Import the init graph (if the meta graph names one) as another
  // initializer function.
  std::string init_op_name;
  TF_RETURN_IF_ERROR(
      internal::GetInitOp("", input_.meta_graph_def(), &init_op_name));
  if (!init_op_name.empty()) {
    TF_RETURN_IF_ERROR(ConvertInitializer(init_op_name, assets));
    init_sym_refs.push_back(
        mlir::SymbolRefAttr::get(builder.getContext(), init_op_name));
  }

  // Record the initializer functions on a session_initializer op and mark the
  // module as carrying tf_saved_model semantics.
  builder.create<mlir::tf_saved_model::SessionInitializerOp>(
      module_->getLoc(), builder.getArrayAttr(init_sym_refs));

  (*module_)->setAttr("tf_saved_model.semantics", builder.getUnitAttr());

  SortSavedModelModule(*module_);
  MarkSavedModelFunctionVisibility(*module_);

  return std::move(module_);
}
4100 
4101 // A helper class to import a TensorFlow model expressed in SavedModel V1 into
4102 // an MLIR Module in SavedModel dialect. In addition to importing the model, it
4103 // performs a few graph transformations, including:
4104 //  1) Convert read-only ref variables to resource variables
4105 //  2) Lift resource variables to global_tensors by using a TF session.
4106 class SavedModelSignatureDefImporter {
4107  public:
4108   // Main entry point: converts all functions (specified by SignatureDefs) in
4109   // the given meta graph to an MLIR Module.
Convert(const SavedModelBundle & bundle,std::optional<absl::Span<const std::string>> exported_names,mlir::MLIRContext * context,tensorflow::MLIRImportOptions options,bool lift_varhandle_ops_to_args=true)4110   static StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Convert(
4111       const SavedModelBundle& bundle,
4112       std::optional<absl::Span<const std::string>> exported_names,
4113       mlir::MLIRContext* context, tensorflow::MLIRImportOptions options,
4114       bool lift_varhandle_ops_to_args = true) {
4115     // debug_info might not be loaded with loader_lite.
4116     GraphDebugInfo debug_info;
4117     if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info;
4118 
4119     TF_ASSIGN_OR_RETURN(auto input,
4120                         SimpleSavedModelMLIRImportInput::Create(
4121                             options, &bundle.meta_graph_def, debug_info));
4122 
4123     TF_ASSIGN_OR_RETURN(auto module,
4124                         SavedModelSignatureDefImporterLite::Convert(
4125                             input, exported_names, context,
4126                             /*import_restore=*/false));
4127 
4128     mlir::OpBuilder builder(module->getContext());
4129     (*module)->setAttr("tf_saved_model.under_construction",
4130                        builder.getUnitAttr());
4131     TF_RETURN_IF_ERROR(
4132         LiftVariables(bundle, *module, lift_varhandle_ops_to_args));
4133     (*module)->removeAttr("tf_saved_model.under_construction");
4134 
4135     return module;
4136   }
4137 
4138  private:
4139   // Lifts the variables in `module`.
4140   static Status LiftVariables(const SavedModelBundle& bundle,
4141                               mlir::ModuleOp module,
4142                               bool lift_varhandle_ops_to_args);
4143 };
4144 
LiftVariables(const SavedModelBundle & bundle,mlir::ModuleOp module,bool lift_varhandle_ops_to_args)4145 Status SavedModelSignatureDefImporter::LiftVariables(
4146     const SavedModelBundle& bundle, mlir::ModuleOp module,
4147     bool lift_varhandle_ops_to_args) {
4148   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
4149 
4150   mlir::PassManager pm(module.getContext());
4151   SetCrashReproducer(pm);
4152   pm.addNestedPass<mlir::func::FuncOp>(
4153       mlir::tf_executor::CreateTFExecutorGraphPruningPass());
4154   pm.addNestedPass<mlir::func::FuncOp>(
4155       mlir::CreateExecutorDialectToFunctionalConversionPass());
4156   pm.addPass(
4157       mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
4158   pm.addNestedPass<mlir::func::FuncOp>(
4159       mlir::TF::
4160           CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
4161   if (mlir::failed(pm.run(module)))
4162     return diag_handler.Combine(
4163         errors::Internal("Failed to prepare to lift variables."));
4164 
4165   if (lift_varhandle_ops_to_args) {
4166     if (failed(mlir::tf_saved_model::MarkInitializedVariablesInFunction(
4167             module, bundle.GetSession())))
4168       return diag_handler.Combine(
4169           errors::Internal("Failed to prepare to mark initialized variables."));
4170     pm.clear();
4171     pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
4172     if (mlir::failed(pm.run(module)))
4173       return diag_handler.Combine(
4174           errors::Internal("Failed to promote var handles to args."));
4175     if (failed(
4176             mlir::tf_saved_model::LiftVariables(module, bundle.GetSession())))
4177       return diag_handler.Combine(
4178           errors::Internal("Failed to lift variables."));
4179   } else {
4180     if (failed(mlir::tf_saved_model::InitializeVariablesInSessionInitializer(
4181             module, bundle.GetSession())))
4182       return diag_handler.Combine(
4183           errors::Internal("Failed to initialize variables in session init."));
4184   }
4185 
4186   pm.clear();
4187   pm.addNestedPass<mlir::func::FuncOp>(
4188       mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
4189   if (mlir::failed(pm.run(module)))
4190     return diag_handler.Combine(
4191         errors::Internal("Failed to dedup bound inputs."));
4192 
4193   return OkStatus();
4194 }
4195 
4196 }  // namespace
4197 
~SavedModelMLIRImportInput()4198 SavedModelMLIRImportInput::~SavedModelMLIRImportInput() {}
4199 
ConvertGraphdefToMlir(const GraphDef & graphdef,const GraphDebugInfo & debug_info,const GraphImportConfig & specs,mlir::MLIRContext * context,bool add_default_attributes)4200 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertGraphdefToMlir(
4201     const GraphDef& graphdef, const GraphDebugInfo& debug_info,
4202     const GraphImportConfig& specs, mlir::MLIRContext* context,
4203     bool add_default_attributes) {
4204   GraphConstructorOptions options;
4205   options.allow_internal_ops = true;
4206   options.add_default_attributes = add_default_attributes;
4207   Graph graph(OpRegistry::Global());
4208 
4209   GraphDef preprocessed_graphdef(graphdef);
4210   if (add_default_attributes) {
4211     TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef));
4212   }
4213   if (specs.upgrade_legacy) {
4214     TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
4215         preprocessed_graphdef, graph.flib_def().default_registry()));
4216   }
4217   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
4218       options, std::move(preprocessed_graphdef), &graph));
4219   return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs,
4220                             context);
4221 }
4222 
ConvertGraphToMlir(const Graph & graph,const GraphDebugInfo & debug_info,const FunctionLibraryDefinition & flib_def,const GraphImportConfig & specs,mlir::MLIRContext * context)4223 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertGraphToMlir(
4224     const Graph& graph, const GraphDebugInfo& debug_info,
4225     const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
4226     mlir::MLIRContext* context) {
4227   // TODO(jpienaar): Remove need to const_cast.
4228   if (specs.upgrade_legacy) {
4229     TF_RETURN_IF_ERROR(
4230         UpgradeLegacyGraph(const_cast<Graph*>(&graph),
4231                            const_cast<FunctionLibraryDefinition*>(&flib_def),
4232                            specs.restrict_functionalization_to_compiled_nodes));
4233   }
4234   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
4235   return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
4236                                    tf_name_to_mlir_name);
4237 }
4238 
4239 stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
ConvertFunctionToMlir(const FunctionBody * fbody,const FunctionLibraryDefinition & flib_def,mlir::MLIRContext * context)4240 ConvertFunctionToMlir(const FunctionBody* fbody,
4241                       const FunctionLibraryDefinition& flib_def,
4242                       mlir::MLIRContext* context) {
4243   tensorflow::GraphDebugInfo dummy_debug_info;
4244   tensorflow::GraphImportConfig specs;
4245   specs.graph_func_name = fbody->fdef.signature().name();
4246   specs.enable_shape_inference = false;
4247   specs.graph_as_function = true;
4248   for (const auto* control_ret_node : fbody->control_ret_nodes)
4249     specs.control_outputs.push_back(control_ret_node->name());
4250   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
4251   return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
4252                                    flib_def, specs, tf_name_to_mlir_name);
4253 }
4254 
// Converts a SavedModel V2 bundle to an MLIR module. Thin wrapper that
// delegates directly to SavedModelObjectGraphImporter::Convert.
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelToMlir(
    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
    absl::Span<std::string> exported_names, bool add_default_attributes,
    bool unconditionally_use_set_output_shapes) {
  return SavedModelObjectGraphImporter::Convert(
      saved_model, exported_names, context, add_default_attributes,
      unconditionally_use_set_output_shapes);
}
4263 
ConvertSavedModelV1ToMlir(const SavedModelBundle & saved_model,absl::Span<std::string> exported_names,mlir::MLIRContext * context,MLIRImportOptions options,bool lift_variables)4264 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlir(
4265     const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
4266     mlir::MLIRContext* context, MLIRImportOptions options,
4267     bool lift_variables) {
4268   std::optional<absl::Span<const std::string>> optional_exported_names;
4269   // TODO(b/187062560): Change ConvertSavedModelV1ToMlir() to take an optional
4270   // `exported_names` so that it can be configured to import only restore/init
4271   // graphs.
4272   if (!exported_names.empty()) optional_exported_names = exported_names;
4273   return SavedModelSignatureDefImporter::Convert(
4274       saved_model, optional_exported_names, context, options, lift_variables);
4275 }
4276 
ConvertSavedModelV1ToMlirLite(const MetaGraphDef & meta_graph_def,const GraphDebugInfo & debug_info,std::optional<absl::Span<const std::string>> exported_names,mlir::MLIRContext * context,MLIRImportOptions options)4277 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlirLite(
4278     const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
4279     std::optional<absl::Span<const std::string>> exported_names,
4280     mlir::MLIRContext* context, MLIRImportOptions options) {
4281   TF_ASSIGN_OR_RETURN(auto input, SimpleSavedModelMLIRImportInput::Create(
4282                                       options, &meta_graph_def, debug_info));
4283   return ConvertSavedModelV1ToMlirLite(
4284       input, exported_names, context,
4285       options.unconditionally_use_set_output_shapes);
4286 }
4287 
// Converts a SavedModel V1 meta graph to an MLIR module. Thin wrapper that
// delegates to SavedModelSignatureDefImporterLite::Convert with
// import_restore enabled.
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlirLite(
    SavedModelMLIRImportInput& input,
    std::optional<absl::Span<const std::string>> exported_names,
    mlir::MLIRContext* context, bool unconditionally_use_set_output_shapes) {
  return SavedModelSignatureDefImporterLite::Convert(
      input, exported_names, context,
      /*import_restore=*/true, unconditionally_use_set_output_shapes);
}
4296 
MlirModuleToString(mlir::ModuleOp module,mlir::OpPrintingFlags flags)4297 std::string MlirModuleToString(mlir::ModuleOp module,
4298                                mlir::OpPrintingFlags flags) {
4299   std::string txt_module;
4300   {
4301     llvm::raw_string_ostream os{txt_module};
4302     module.print(os, flags);
4303   }
4304   return txt_module;
4305 }
4306 
MlirModuleToString(mlir::ModuleOp module,bool show_debug_info)4307 std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
4308   mlir::OpPrintingFlags flags;
4309   if (show_debug_info) flags.enableDebugInfo();
4310   return MlirModuleToString(module, flags);
4311 }
4312 
4313 }  // namespace tensorflow
4314