/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"

#include <atomic>
#include <functional>
#include <iterator>
#include <memory>
#include <queue>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/synchronization/mutex.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Verifier.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/crash_analysis.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"

static inline absl::string_view StringRefToView(llvm::StringRef ref) {
  return {ref.data(), ref.size()};
}

namespace tensorflow {

constexpr size_t kNumThreadToConvertSignatures = 10;
constexpr absl::string_view kOutputShapesAttrName = "_output_shapes";

using mlir::NamedAttrList;
using mlir::TensorType;
using mlir::tf_saved_model::AssetOp;
using mlir::tf_saved_model::GlobalTensorOp;
using mlir::tf_saved_model::SessionInitializerOp;
using stream_executor::port::StatusOr;

namespace {

bool IsOutputShapesAttribute(const AttrValue& attr_value,
                             llvm::StringRef attr_name) {
  return attr_name.compare(kOutputShapesAttrName) == 0 &&
         attr_value.value_case() == AttrValue::kList;
}

bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
                                     llvm::StringRef attr_name) {
  if (attr_name == "_handle_dtypes" || attr_name == "_handle_shapes")
    return attr_value.value_case() == AttrValue::kList;
  return false;
}

void LoadImporterDialects(mlir::MLIRContext& context) {
  // Load dialects involved in the conversion
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  context.appendDialectRegistry(registry);
  for (llvm::StringRef name : registry.getDialectNames())
    context.getOrLoadDialect(name);
}

// This class is used to generate new MLIR function name strings that are both
// unique in the TF function library `flib_` and unique among the name strings
// generated by the class object during its lifetime.
//
// In theory, this class is not necessary because we should simply take
// the TF function name and use it as the MLIR function name. However, for some
// unknown reasons (callout for investigation in b/142268695), keeping the
// function names unchanged in an MLIR roundtrip causes test failures.
// TODO(b/142268695): Re-evaluate whether we need this class vs. directly using
// the TF function name as the MLIR function name after b/142268695 is root
// caused.
class NameUniquifier : public OpOrArgNameMapper {
 public:
  explicit NameUniquifier(const FunctionLibraryDefinition& flib)
      : flib_(flib) {}

 private:
  bool IsUnique(llvm::StringRef name) override {
    return !flib_.Contains(std::string(name));
  }

  std::string GetName(OpOrVal op_or_val) override {
    DCHECK(false) << "Unimplemented";
    return "";
  }

  const FunctionLibraryDefinition& flib_;
};
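
// Illustrative sketch (not part of the import flow): the uniquifier is
// typically driven through OpOrArgNameMapper::GetUniqueName, e.g.
//   NameUniquifier uniquifier(flib);
//   llvm::StringRef name = uniquifier.GetUniqueName("my_func");
// which returns "my_func" if the TF function library does not already contain
// it, and a uniquified variant otherwise.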

// Stateful helper class to import a TensorFlow model into an MLIR Module.
//
// This is the base class that contains common utilities shared between the
// GraphDef importer and SavedModel importer.
//
// A subclass is expected to call `PrepareConvert` first to perform necessary
// preparation over the graph and also certain internal bookkeeping data.
// Afterwards the other protected methods can be called.
class ImporterBase {
 protected:
  explicit ImporterBase(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier,
      llvm::StringRef function_name_for_debug_info = "")
      : builder_(module.getContext()),
        module_(module),
        context_(module.getContext()),
        tf_name_to_mlir_name_(tf_name_to_mlir_name),
        graph_flib_(flib),
        specs_(specs),
        debug_info_(debug_info),
        function_name_for_debug_info_(function_name_for_debug_info),
        function_name_uniquifier_(function_name_uniquifier),
        error_handler_(module.getContext()) {
    // Log import config.
    if (VLOG_IS_ON(1)) {
      LOG(INFO) << "Importing with: " << specs.str();
      for (auto& it : *tf_name_to_mlir_name) {
        LOG(INFO) << "\t" << it.first << " -> " << it.second;
      }
    }
  }

  // Returns the inferred function signature of the given function body. Input
  // types are unranked tensors of the respective datatype in the function and
  // result types are inferred by the shape_refiner_. Result types need not be
  // unranked tensors and could be ranked tensors in cases where the result
  // type depends on an op with static output shape like tf.Const.
  StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);

  // Extracts arg and ret nodes from FunctionBody.
  void GetArgsAndRetsFromFunctionBody(
      const FunctionBody& fbody,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);

  // Prepares converting the graph to an MLIR module. This step removes the
  // backedges of the graph, orders the nodes and infers the shapes.
  // PrepareConvert needs to ensure that the original `graph` is cloned prior
  // to execution. The cloning procedure relies on a roundtrip through the
  // GraphDef. Graph to GraphDef conversion is heavy, so if a `graph_def` was
  // obtained previously, provide it to PrepareConvert for reuse.
  Status PrepareConvert(const Graph& graph,
                        std::unique_ptr<GraphDef> graph_def = nullptr);

  // Converts the prepared graph to a Function and adds it to the module. A
  // set of nodes from the graph is given to be converted to the arguments
  // and results of the function.
  Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type,
                 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
                 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
                 const absl::InlinedVector<Node*, 4>& control_ret_nodes,
                 llvm::ArrayRef<mlir::NamedAttribute> attrs);

  // Finds out the function definition for the given function name from the
  // graph and converts it to a function of the module. This method is called
  // on demand because the graph flib_def does not provide an iterator
  // interface.
  Status ConvertLibFunction(llvm::StringRef func_name);

  // Returns the list of nodes in the graph. Nodes are presented in the reverse
  // order of a post-order depth-first visit starting from the graph's source
  // nodes.
  llvm::ArrayRef<Node*> GetOrderedNodes() const { return ordered_nodes_; }

  // Returns the inferred input type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferInputType(const Node& node, int idx,
                                      mlir::Builder builder);

  // Returns the inferred output type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::Type> InferOutputType(const Node& node, int idx,
                                       mlir::Builder builder);

  // Convert deferred TF functions to the MLIR representation.
  // Conversion is deferred for efficiency reasons, e.g., to limit depth
  // of recursion and reduce stack size pressure.
  Status ConvertDeferredFunctions();

 private:
  // Most types with subtypes have only one subtype.
  using ElementSubtypes = llvm::SmallVector<TensorType, 1>;

  // Metadata used for deferred function conversion.
  struct DeferredConversionMetaData {
    DeferredConversionMetaData(
        const std::string& function_name,
        const std::vector<mlir::NamedAttribute>& attributes)
        : function_name(function_name), attributes(attributes) {}

    std::string function_name;
    std::vector<mlir::NamedAttribute> attributes;
  };

  // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all
  // data type and shape information is maintained by the shape_refiner_.
  // TODO(jpienaar): Remove once shape inference on import is removed.
  Status AddNodesToShapeRefiner(
      std::unordered_map<string, Node*>* node_name_map);

  // Prune nodes that do not feed into fetch nodes.
  Status PruneUnreachableNodes(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts feeds to Placeholder nodes.
  Status ConvertFeedsToPlaceholders(
      std::unordered_map<string, Node*>* node_name_map);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // given element type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertDataTypeAndShape(
      DataType dtype, const shape_inference::ShapeHandle& handle,
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // given element type, and returns an MLIR tensor type.
  StatusOr<TensorType> ConvertElementTypeAndShape(
      mlir::Type element_type, const shape_inference::ShapeHandle& handle,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred subtypes for an element type to corresponding MLIR
  // types in 'context'.
  StatusOr<ElementSubtypes> ConvertSubtypes(
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the tensor proto into an MLIR elements attribute.
  StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& value) {
    return ::tensorflow::ConvertTensorProto(value, &builder_);
  }

  // Converts a func name in the graphdef to an mlir::FlatSymbolRefAttr.
  StatusOr<mlir::FlatSymbolRefAttr> ConvertFunctionCallName(
      const std::string& func_name);

  // Converts the given non-function-call AttrValue to an MLIR Attribute.
  StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value);

  // Converts the given function-call AttrValue to MLIR Attributes and pushes
  // them to the given attributes list. For example, if there is a kFunc
  // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it
  // to a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
  // {base_name.k2 : rfc}].
  Status ConvertFunctionCallAttribute(const std::string& base_name,
                                      const AttrValue& value,
                                      NamedAttrList* attributes);

  // Helper to create either a tf_executor operation or a TF operation wrapped
  // in an island.
  mlir::Operation* CreateOperation(
      const Node& node, llvm::StringRef node_type_name,
      const mlir::OperationState& result,
      const llvm::SmallVectorImpl<mlir::Value>& control_operands);

  // Converts one NodeDef from the input GraphDef into an Operation and
  // inserts it into the MLIR module using builder_.
  Status ConvertNode(const Node& node);

  // If the input graph represents a while-loop, the edges pointing from a
  // "NextIteration" node to a "Merge" node add cyclic dependencies and make
  // the topological sorting impossible. We need to remove these edges from
  // the input graph to infer shapes and construct a Function. For each
  // "NextIteration" node, two operations, "NextIteration.source" and
  // "NextIteration.sink", are added to the MLIR module.
  using BackEdge = BackEdgeHelper::BackEdge;

  // Removes backedges from the input graph. The removed edges are added back
  // to the OpBuilder after the remaining graph is converted to the Function.
  Status RemoveBackedges();

  // Restores backedges removed during shape inference to the final Function.
  Status AddBackedges();

  // Restores a single backedge in the Function by adding a replicated
  // operation before the dst operation.
  Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                     int dst_input);

  // Adds the input arguments and return operation to the function. The
  // arguments are added as basic block arguments. The argument types and the
  // ids of the nodes from the input graph also need to be specified.
  Status ConvertFunctionArgAndRets(
      mlir::func::FuncOp func, mlir::tf_executor::GraphOp graph_op,
      llvm::ArrayRef<mlir::Type> arg_types,
      const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
      const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
      const absl::InlinedVector<Node*, 4>& control_ret_nodes);

  // Gets the location information of the given node. It uses the
  // "original_node_name" in the NodeDef to get the corresponding file location
  // (FileLineColLoc) from the input DebugInfo and returns a CallSiteLoc. If
  // there are multiple "original_node_names", a FusedLoc is returned. If the
  // node name couldn't be found in the input DebugInfo, a NameLoc is used as
  // the location.
  mlir::Location GetLocation(const Node& node);

  // Appends the location string for the node to the error message and returns
  // the combined error status.
  Status EmitErrorWithLocationStr(const Node& node, const Status& error_status);

  // Inserts a placeholder node in the graph to replace a feed output tensor,
  // and returns the new placeholder node and a boolean indicating if the
  // original input node was removed from the graph. Uses of the feed output
  // tensor are replaced with this placeholder node. If the feed output tensor
  // is of a single output node, the control dependencies are forwarded to the
  // placeholder node, and the original node will be removed.
  // Note: This modifies the graph, and so any list of ordered nodes needs to
  // be reconstructed.
  StatusOr<std::pair<Node*, bool>> CreatePlaceholderNodeForFeed(
      const TensorShapeProto& shape, DataType dtype, Node* node, int index,
      const std::unordered_map<string, Node*>& node_name_map);

  // Gets the input and output nodes corresponding to the specified input and
  // output nodes in specs_. If there are no input or output nodes specified,
  // nodes will be empty.
  Status GetInputOutputNodes(
      const std::unordered_map<string, Node*>& node_name_map,
      std::unordered_set<const Node*>* nodes);

  // The input graph with backedges removed. The removed backedges are stored
  // in the back_edge_helper.
  BackEdgeHelper back_edge_helper_;
  // A map between node and output index, for each backedge.
  absl::flat_hash_map<const Node*, int> back_edge_node_output_;
  absl::flat_hash_map<const Node*, BackEdge> back_edge_dst_inputs_;
  // A map between the sink and source operation of each NextIteration.
  absl::flat_hash_map<mlir::Operation*, mlir::Operation*>
      next_iteration_sink_source_;

  // All nodes and version information about the (copied) imported graph.
  std::unique_ptr<Graph> graph_;
  std::vector<Node*> ordered_nodes_;

  // Maps from a Node ID to a MLIR value.
  using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;

  mlir::OpBuilder builder_;
  mlir::ModuleOp module_;
  mlir::MLIRContext* context_;
  std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
  const FunctionLibraryDefinition& graph_flib_;
  const GraphImportConfig& specs_;
  const GraphDebugInfo& debug_info_;
  llvm::StringRef function_name_for_debug_info_;
  NodeValueMap node_values_;
  // TODO(jpienaar): Remove once shape inference on import is removed.
  // The shape_refiner_ will be nullptr if shape inference on import is
  // not enabled.
  std::unique_ptr<ShapeRefiner> shape_refiner_ = nullptr;
  NameUniquifier* function_name_uniquifier_;
  mlir::StatusScopedDiagnosticHandler error_handler_;
  // All the TF ops encountered that aren't modelled in dialect.
  llvm::DenseSet<mlir::StringAttr> unmodelled_op_names_;

 protected:
  // Maps feed as TensorId to new Placeholder node name.
  absl::flat_hash_map<TensorId, absl::string_view> remapped_feeds_;
  // Keeps track of functions that require deferred conversion.
  std::queue<DeferredConversionMetaData> deferred_functions_;
};

// Returns true if the node with the given name has a non-primary output that
// is used by some other node as an input. Returns false if no outputs are in
// use or only the first output is in use.
bool HasNonPrimaryOutputInUse(const GraphDef& graph_def,
                              const std::string& node) {
  for (const auto& node_def : graph_def.node()) {
    for (const auto& input : node_def.input()) {
      if (absl::StartsWith(input, node + ":") && input != node + ":0") {
        return true;
      }
    }
  }
  return false;
}

// Updates the given LegacyFedInput node with a Placeholder node if it is one
// of the inputs. Returns an error if a non-primary output of the
// LegacyFedInput node is in use and it therefore cannot be replaced by a
// Placeholder node, which only has a single output.
Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
                                const GraphImportConfig::InputArrays& inputs,
                                NodeDef* node) {
  const std::string& node_name = node->name();
  auto it = inputs.find(node_name);

  // Node is not an input.
  if (it == inputs.end()) return OkStatus();

  if (HasNonPrimaryOutputInUse(graph_def, node_name)) {
    return errors::InvalidArgument(
        "LegacyFedInput node ", node->name(),
        " has non primary output in use and can not be replaced with "
        "Placeholder node");
  }

  DataType dtype = it->second.imported_dtype;
  // Uses the existing output type if it isn't specified by the user.
  if (dtype == DT_INVALID) {
    dtype = node->attr().at("output_types").list().type(0);
  }
  // Update op name, drop inputs and set attributes required by the Placeholder
  // op.
  *node->mutable_op() = "Placeholder";
  node->clear_attr();
  node->clear_input();
  AddNodeAttr("dtype", dtype, node);
  AddNodeAttr("shape", it->second.shape, node);
  return OkStatus();
}

// Preprocesses GraphDef before it can be converted to Graph by:
// - Adding the default attributes to each node def if they are missing from
//   the GraphDef.
// - Replacing LegacyFedInput nodes with Placeholder nodes if the
//   convert_legacy_fed_inputs option is enabled.
Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) {
  for (auto& node_def : *graph_def->mutable_node()) {
    // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One
    // solution could be to have a tool that lets users upgrade old serialized
    // graphs.
    if (specs && specs->convert_legacy_fed_inputs &&
        node_def.op() == "LegacyFedInput") {
      TF_RETURN_IF_ERROR(
          UpdateLegacyFedInputNode(*graph_def, specs->inputs, &node_def));
    }

    const tensorflow::OpRegistrationData* op_reg_data =
        tensorflow::OpRegistry::Global()->LookUp(node_def.op());
    if (!op_reg_data) {
      // This is likely a function call node, so we should continue.
      continue;
    }
    ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def);
  }
  return OkStatus();
}

// Mapping from node name to feed (index and ArrayInfo). Node name must outlive
// this map.
using FeedsByNode = absl::flat_hash_map<
    absl::string_view,
    absl::flat_hash_map<int, const std::pair<std::string, ArrayInfo>*>>;

// Creates, from a `GraphImportConfig::InputArrays`, a mapping from a feed's
// output tensor name to its index and ArrayInfo. Keys and values are backed
// by `GraphImportConfig::InputArrays`.
StatusOr<FeedsByNode> GetFeedsByNode(
    const GraphImportConfig::InputArrays& inputs) {
  FeedsByNode feeds_by_node;
  feeds_by_node.reserve(inputs.size());

  for (const auto& input : inputs) {
    TensorId tensor = ParseTensorName(input.first);
    if (tensor.index() < 0)
      return errors::FailedPrecondition(
          "Feed output tensor must be a data output '", tensor.ToString(), "'");

    auto& node = feeds_by_node[tensor.node()];
    if (!node.insert({tensor.index(), &input}).second)
      return errors::FailedPrecondition(
          "Multiple feeds for the same output tensor '", tensor.ToString(),
          "'");
  }

  return feeds_by_node;
}

// Creates a unique name for a node that will be replacing a feed output tensor.
std::string GetUniqueNodeName(
    absl::string_view node_name, int index,
    const std::unordered_map<string, Node*>& node_name_map) {
  std::string new_node_name_base = absl::StrCat(node_name, "_", index);
  int count = 0;
  std::string new_node_name = new_node_name_base;
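  // Probe "<node>_<index>", then "<node>_<index>_0", "<node>_<index>_1", ...
  // until a name not already present in the graph's node name map is found.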
  while (node_name_map.find(new_node_name) != node_name_map.end()) {
    new_node_name = absl::StrCat(new_node_name_base, "_", count++);
  }
  return new_node_name;
}

Status ImporterBase::ConvertDeferredFunctions() {
  while (!deferred_functions_.empty()) {
    auto conversion_metadata = deferred_functions_.front();
    deferred_functions_.pop();

    const FunctionDef* func_def =
        graph_flib_.Find(conversion_metadata.function_name);
    // Converts the graph to an MLIR function and adds it to the module.
    // We populate the NodeSpec so that all the _Arg ops get their shape
    // added correctly.
    GraphImportConfig specs;
    specs.enable_shape_inference = specs_.enable_shape_inference;
    specs.unconditionally_use_set_output_shapes =
        specs_.unconditionally_use_set_output_shapes;
    for (const auto& name_and_value : func_def->attr()) {
      if (name_and_value.first == "_input_shapes") {
        auto& list = name_and_value.second.list();
        auto& signature = func_def->signature();
        // Some models have an "_input_shapes" attribute, but with an empty
        // value.
        if (list.shape_size() > 0 &&
            list.shape_size() != signature.input_arg_size()) {
          return errors::FailedPrecondition(
              "Number of input arguments must be equal to the length of "
              "_input_shapes attribute in function '",
              StringRefToView(conversion_metadata.function_name), "'.");
        }
        for (int i = 0, e = signature.input_arg_size(); i < e; i++) {
          auto& input_arg = signature.input_arg(i);
          auto& array_info = specs.inputs[input_arg.name()];
          array_info.imported_dtype = input_arg.type();
          // Set the shape to unranked for an empty "_input_shapes" attribute.
          if (list.shape_size() > 0)
            array_info.shape = list.shape(i);
          else
            array_info.shape.set_unknown_rank(true);
        }
      }
    }

    ImporterBase importer(graph_flib_, debug_info_, specs, module_,
                          tf_name_to_mlir_name_, function_name_uniquifier_,
                          conversion_metadata.function_name);

    std::unique_ptr<FunctionBody> fbody;
    TF_RETURN_IF_ERROR(
        FunctionDefToBodyHelper(*func_def, AttrSlice(), &graph_flib_, &fbody));
    TF_RETURN_IF_ERROR(importer.PrepareConvert(*fbody->graph));

    TF_ASSIGN_OR_RETURN(auto func_type, importer.InferLibFunctionType(*fbody));

    absl::InlinedVector<OutputTensor, 4> arg_nodes;
    absl::InlinedVector<OutputTensor, 4> ret_nodes;
    absl::InlinedVector<Node*, 4> control_ret_nodes;
    importer.GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes,
                                            &control_ret_nodes);
    const std::string& mlir_func_name =
        (*tf_name_to_mlir_name_)[conversion_metadata.function_name];

    TF_RETURN_IF_ERROR(importer.Convert(mlir_func_name, func_type, arg_nodes,
                                        ret_nodes, control_ret_nodes,
                                        conversion_metadata.attributes));

    // Additional function bodies could be discovered during the deferred
    // loading of the current function. Add them to the working queue.
    while (!importer.deferred_functions_.empty()) {
      deferred_functions_.push(importer.deferred_functions_.front());
      importer.deferred_functions_.pop();
    }
  }

  return OkStatus();
}

Status ImporterBase::RemoveBackedges() {
  // Remove all the backedges so the nodes can be added to the shape refiner.
  TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get()));
  VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size())
          << " backedges.";

  // Creates a map for quickly identifying whether a node output is a backedge.
  for (const auto& edge : back_edge_helper_.RemovedEdges()) {
    if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() &&
        back_edge_node_output_[edge.src] != edge.src_output) {
      return errors::FailedPrecondition(
          "More than one of the src node outputs are backedges!");
    }
    back_edge_node_output_[edge.src] = edge.src_output;
    // We expect a merge to receive a single backedge (multiple NextIteration
    // nodes feeding into the same merge is unexpected here).
    DCHECK(!back_edge_dst_inputs_.contains(edge.dst));
    back_edge_dst_inputs_[edge.dst] = edge;
  }

  // Obtains an RPO ordering, using node names as a tiebreak for stable
  // sorting.
  GetReversePostOrder(
      *graph_, &ordered_nodes_,
      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
  return OkStatus();
}

Status CopyStackTraces(const Graph& from, Graph* to) {
  // Copy over the stack traces.
  // TODO(jpienaar): This really shouldn't be needed, copying the Graph above
  // and then needing these traversals is unfortunate.
  std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex();
  for (Node* node : to->nodes()) {
    if (const Node* old_node = node_map[node->name()]) {
      if (const std::shared_ptr<AbstractStackTrace>& stack =
              old_node->GetStackTrace()) {
        DVLOG(2) << "Stack for " << node->name() << " "
                 << old_node->GetStackTrace()->ToString(
                        AbstractStackTrace::TracePrintingOptions());
        node->SetStackTrace(stack);
      } else {
        DVLOG(1) << "No stack for " << node->name() << " (" << node
                 << ") in Graph " << &from;
      }
    } else {
      DVLOG(1) << "No stack for " << node->name() << " (" << node
               << ") in Graph " << &from;
    }
  }

  return OkStatus();
}

StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
    const TensorShapeProto& shape, DataType dtype, Node* node, int index,
    const std::unordered_map<string, Node*>& node_name_map) {
  DCHECK_LT(index, node->num_outputs());
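  // A feed on the sole output of a node can take over the node's name (the
  // original node is removed below); feeds on other outputs get a fresh,
  // uniquified placeholder name instead.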
  const bool update_inplace = node->num_outputs() == 1 && index == 0;
  std::string new_node_name =
      update_inplace ? node->name()
                     : GetUniqueNodeName(node->name(), index, node_name_map);

  Node* placeholder_node;
  NodeBuilder builder(new_node_name, "Placeholder");
  builder.Attr("shape", shape);
  builder.Attr("dtype", dtype);
  TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));

  // Update edges from original feed with Placeholder node.
  std::vector<const Edge*> data_edges;
  std::vector<const Edge*> control_edges;
  for (const tensorflow::Edge* edge : node->out_edges()) {
    if (edge->src_output() == index) {
      data_edges.push_back(edge);
    } else if (update_inplace && edge->IsControlEdge()) {
      control_edges.push_back(edge);
    }
  }

  for (const auto* edge : data_edges) {
    TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(),
                                          edge->dst_input()));
  }

  // TODO(lyandy): Preserve control dependencies properly by not forwarding
  // control dependencies to data outputs and not removing single output nodes.
  // When a data output is replaced as a feed, unless there is another non feed
  // data output or an explicit control output used by the same node, transitive
  // control dependencies are not to be executed. For single output nodes,
  // Placeholders can be converted to a NoOp if there are no uses, and
  // PlaceholderWithDefault can be converted to an Identity.
  for (const auto* edge : control_edges) {
    graph_->AddControlEdge(placeholder_node, edge->dst());
    graph_->RemoveControlEdge(edge);
  }

  if (update_inplace) {
    graph_->RemoveNode(node);
  }

  return std::pair<Node*, bool>(placeholder_node, update_inplace);
}

Status ImporterBase::GetInputOutputNodes(
    const std::unordered_map<string, Node*>& node_name_map,
    std::unordered_set<const Node*>* nodes) {
  auto add_node = [&](absl::string_view name) {
    auto it = node_name_map.find(std::string(name));
    if (it == node_name_map.end()) {
      return errors::FailedPrecondition(
          absl::StrCat("Graph does not contain node: ", name));
    }
    nodes->insert(it->second);
    return OkStatus();
  };

  // Remap feeds and fetches to newly created Placeholder nodes.
  for (const auto& input : specs_.inputs) {
    TensorId tensor = ParseTensorName(input.first);
    auto remapped_it = remapped_feeds_.find(tensor);
    if (remapped_it != remapped_feeds_.end()) {
      TF_RETURN_IF_ERROR(add_node(remapped_it->second));
    } else {
      TF_RETURN_IF_ERROR(add_node(tensor.node()));
    }
  }

  for (const auto& output : specs_.outputs) {
    TensorId tensor = ParseTensorName(output);
    auto remapped_it = remapped_feeds_.find(tensor);
    if (remapped_it != remapped_feeds_.end()) {
      TF_RETURN_IF_ERROR(add_node(remapped_it->second));
    } else {
      TF_RETURN_IF_ERROR(add_node(tensor.node()));
    }
  }

  for (const auto& control_output : specs_.control_outputs)
    TF_RETURN_IF_ERROR(add_node(control_output));

  return OkStatus();
}

// TODO(jpienaar): Remove this once the shape inference on import flag is
// removed.
Status ImporterBase::AddNodesToShapeRefiner(
    std::unordered_map<string, Node*>* node_name_map) {
  shape_refiner_ = std::make_unique<ShapeRefiner>(graph_->versions(),
                                                  graph_->op_registry());
  // Some operations (for example "TPUExecute") don't have a shape inference
  // function defined, so we set this to false to allow adding nodes with
  // such operations.
  shape_refiner_->set_require_shape_inference_fns(false);
  shape_refiner_->set_function_library_for_shape_inference(&graph_flib_);

  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));

  // First add all nodes to the refiner.
  for (Node* node : ordered_nodes_) {
    // We need to use a TensorFlow node to teach the shape refiner that the
    // user specified a certain data type and shape for the inputs in
    // `specs_`. This node shouldn't have any inputs, should have exactly one
    // output, and its output type/shape should be determined only by its
    // "named" attributes. (The attributes should have fixed names so we can
    // use the info from `specs_` to set their values.) `Placeholder`
    // satisfies these constraints.
    //
    // Therefore, if the input node isn't a `Placeholder`, we create one and
    // use it to replace the original input node, so the shape refiner can
    // successfully propagate the user's input type and shape to the rest of
    // the graph.
    bool node_added_to_shape_refiner = false;
    auto it = feeds_by_node.find(node->name());
    if (it != feeds_by_node.end()) {
      auto op_name = node->op_def().name();
      if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
          op_name != FunctionLibraryDefinition::kArgOp) {
        for (const auto& output_tensor : it->second) {
          const int index = output_tensor.first;
          const ArrayInfo& array_info = output_tensor.second->second;

          DataType dtype = array_info.imported_dtype;
          // Uses the existing output type if it isn't specified by the user.
          if (dtype == DT_INVALID) {
            dtype = node->output_type(index);
          }

          TF_ASSIGN_OR_RETURN(
              auto placeholder_node_and_removed,
              CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
                                           *node_name_map));

          Node* placeholder_node = placeholder_node_and_removed.first;
          if (placeholder_node_and_removed.second) {
            // Original node has been removed from the graph.
            node = placeholder_node;
            node_added_to_shape_refiner = true;
          }
          remapped_feeds_[{it->first, index}] = placeholder_node->name();
          (*node_name_map)[placeholder_node->name()] = placeholder_node;
          // Add the new placeholder node to the shape refiner.
          Status status = shape_refiner_->AddNode(placeholder_node);
          if (!status.ok()) {
            return EmitErrorWithLocationStr(*placeholder_node, status);
          }
        }
      } else {
        auto index_it = it->second.find(0);
        if (index_it == it->second.end()) {
          return errors::FailedPrecondition(
              "Missing feed output tensor at index 0 for node '", node->name(),
              "'");
        }
        node->AddAttr("shape", index_it->second->second.shape);
        DataType dtype = index_it->second->second.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(0);
        }
        node->AddAttr("dtype", dtype);
      }
    }
    if (!node_added_to_shape_refiner) {
      // Add the node to the shape refiner if the node hasn't been removed.
      Status status = shape_refiner_->AddNode(node);
      if (!status.ok()) {
        return EmitErrorWithLocationStr(*node, status);
      }
    }

    auto set_shape_from_list_attr = [&](const AttrValue* attr) {
      auto& list = attr->list();
      // This follows the same approach as in ValidateShape, but only flags a
      // warning when the number of shapes and the number of outputs mismatch,
      // in which case it just returns without attempting to refine.
      if (list.shape_size() != node->num_outputs()) {
        LOG(WARNING) << "Node '" << node->name() << "' has "
                     << node->num_outputs() << " outputs but the "
                     << kOutputShapesAttrName
                     << " attribute specifies shapes for " << list.shape_size()
                     << " outputs";
        return OkStatus();
      }

      for (const auto& shape : llvm::enumerate(list.shape())) {
        auto* node_context = shape_refiner_->GetContext(node);
        shape_inference::ShapeHandle handle;
        Status status =
            node_context->MakeShapeFromShapeProto(shape.value(), &handle);
        if (!status.ok()) {
          return EmitErrorWithLocationStr(*node, status);
        }
        node_context->set_output(shape.index(), handle);
      }
      return OkStatus();
    };

    // If it is the argument node, the shape handle is set explicitly, so it
    // can be propagated to the body nodes of the function.
    if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) {
      auto* node_context = shape_refiner_->GetContext(node);
      DCHECK(node_context != nullptr);
      if (const AttrValue* attr = node->attrs().Find("shape")) {
        shape_inference::ShapeHandle handle;
        Status status =
            node_context->MakeShapeFromShapeProto(attr->shape(), &handle);
        if (!status.ok()) {
          return EmitErrorWithLocationStr(*node, status);
        }
        node_context->set_output(0, handle);
      } else if (const AttrValue* attr =
                     node->attrs().Find(kOutputShapesAttrName)) {
        TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
      } else {
        node_context->set_output(0, node_context->UnknownShape());
      }
    }

    // Following GraphConstructor::ValidateShape called from
    // GraphConstructor::Convert, override the shape if _output_shapes is set.
    if (specs_.unconditionally_use_set_output_shapes ||
        node->op_def().name() == "ReadVariableOp") {
      if (const AttrValue* attr = node->attrs().Find(kOutputShapesAttrName))
        TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
    }
  }

  // Since we might have inserted and removed nodes from the graph, fix up
  // the source/sink edges and reconstruct the RPO ordering of nodes.
  FixupSourceAndSinkEdges(graph_.get());

  // Prune nodes in the graph that are not reachable from the output.
  if (specs_.prune_unused_nodes) {
    std::unordered_set<const Node*> prune_start;
    TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));
    if (!prune_start.empty()) {
      if (PruneForReverseReachability(graph_.get(), prune_start)) {
        VLOG(1) << "Pruned unused nodes in graphdef";
      } else {
        VLOG(1) << "No unused nodes in graphdef to prune";
      }
    } else {
      VLOG(1) << "No output nodes specified, skipping pruning";
    }
  } else {
    VLOG(1) << "Pruning unused nodes in graphdef is disabled";
  }

  // Re-initialize ordered_nodes_ since we might have modified the graph.
  GetReversePostOrder(
      *graph_, &ordered_nodes_,
      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });

  VLOG(1) << "Inferring graph shapes to fixpoint";

983 // The "changed" information from UpdateNode can give false positives, so we
984 // create a dedicated method to verify the shapes are not changed before and
985 // after the shape refine.
  auto same_inferred_shape = [](shape_inference::InferenceContext* c,
                                shape_inference::ShapeHandle s0,
                                shape_inference::ShapeHandle s1) -> bool {
    if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) {
      return true;
    }
    if (c->Rank(s0) != c->Rank(s1)) {
      return false;
    }
    for (int i = 0; i < c->Rank(s0); ++i) {
      if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
        int64_t val0 = c->Value(c->Dim(s0, i));
        int64_t val1 = c->Value(c->Dim(s1, i));
        // A negative value is treated as unknown, so any two negative values
        // indicate the same (unknown) dimension.
        if (val0 >= 0 && val1 >= 0 && val0 != val1) return false;
      }
    }
    return true;
  };

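  // Iterate shape refinement over all nodes until no output shape changes,
  // bounded by kMaxIterationCount passes over the graph.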
  bool changed = true;
  int i = 0;
  const int kMaxIterationCount = 2;
  while (changed && i != kMaxIterationCount) {
    changed = false;
    for (const Node* node : ordered_nodes_) {
      auto* shape_context = shape_refiner_->GetContext(node);
      DCHECK(shape_context != nullptr);
      absl::InlinedVector<shape_inference::ShapeHandle, 4> existing;
      existing.reserve(shape_context->num_outputs());
      for (int o = 0; o < shape_context->num_outputs(); ++o) {
        existing.push_back(shape_context->output(o));
      }
      bool inferred = false;
      shape_inference::ShapeHandle handle;
      Status status =
          shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred);
      if (!status.ok()) {
        return EmitErrorWithLocationStr(*node, status);
      }
      for (int o = 0; o < shape_context->num_outputs(); ++o) {
        if (!same_inferred_shape(shape_context, shape_context->output(o),
                                 existing[o])) {
          changed = true;
          break;
        }
      }
    }
    ++i;
  }
  if (i >= kMaxIterationCount) {
    LOG(WARNING) << "Graph shapes did not converge to a fixpoint within "
                 << kMaxIterationCount
                 << " iterations. Graph shapes may be conservative.";
  }
  VLOG(1) << "Graph shapes were inferred with " << (i - 1)
          << " extra rounds of analysis to reach a fixpoint.";
  return OkStatus();
}

StatusOr<mlir::Type> ImporterBase::InferInputType(const Node& node, int idx,
                                                  mlir::Builder builder) {
  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove this if shape inference on import flag is removed.
    ExtendedInferenceContext* shape_context =
        shape_refiner_->GetExtendedContext(&node);
    DataType dtype = shape_context->input_type(idx);
    auto* context = shape_context->get_context();
    return ConvertDataTypeAndShape(dtype, context->input(idx),
                                   context->input_handle_shapes_and_types(idx),
                                   context, builder);
  }
  DataType dtype = node.properties()->input_types[idx];
  mlir::Type element_type;
  TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));
  return mlir::UnrankedTensorType::get(element_type);
}

StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
                                                   mlir::Builder builder) {
  DataType dtype = node.properties()->output_types[idx];

  // Returns output type given inference context.
  auto shape_ic =
      [&](shape_inference::InferenceContext* c) -> StatusOr<mlir::Type> {
    // TODO(b/200093974): Post triage, consider following
    // GraphConstructor::ValidateShape in checking _output_shapes always.
    if (specs_.unconditionally_use_set_output_shapes) {
      if (const AttrValue* attr = node.attrs().Find(kOutputShapesAttrName)) {
        auto& list = attr->list();
        if (list.shape_size() > idx) {
          const TensorShapeProto& p = list.shape()[idx];
          shape_inference::ShapeHandle h;
          Status s = c->MakeShapeFromShapeProto(p, &h);
          if (!s.ok())
            return errors::InvalidArgument(
                "Node '", node.name(), "' has an invalid ",
                kOutputShapesAttrName, " attribute (shape #", idx, " error:'",
                s.error_message(), "')");
          c->set_output(idx, h);
        }
      }
    }

    return ConvertDataTypeAndShape(dtype, c->output(idx),
                                   c->output_handle_shapes_and_types(idx), c,
                                   builder);
  };

  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove this if shape inference on import flag is removed.
    ExtendedInferenceContext* shape_context =
        shape_refiner_->GetExtendedContext(&node);
    return shape_ic(shape_context->get_context());
  }

  // Treat TensorList init ops specially here as the op requires knowing its
  // element dtype.
  // TODO(jpienaar): Reconsider post refactoring shape functions.
  if (node.type_string() == "TensorListReserve" ||
      node.type_string() == "EmptyTensorList") {
    mlir::Type etype;
    if (auto element_dtype = node.attrs().Find("element_dtype")) {
      TF_RETURN_IF_ERROR(
          ConvertDataType(element_dtype->type(), builder, &etype));
    }
    return mlir::RankedTensorType::get(
        {}, mlir::TF::VariantType::get({mlir::UnrankedTensorType::get(etype)},
                                       etype.getContext()));
  }

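  // While ops carry per-result "output_shapes" and "T" list attributes; use
  // them to build ranked result types when the shapes are present.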
  if (node.IsWhileNode()) {
    auto* output_shapes = node.attrs().Find("output_shapes");
    auto* element_types = node.attrs().Find("T");
    if (output_shapes && !output_shapes->list().shape().empty()) {
      const auto& output_shape = output_shapes->list().shape(idx);
      const auto& element_type = element_types->list().type(idx);
      return ConvertToMlirTensorType(output_shape, element_type, &builder);
    }
  }

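  // Builds the result type at `idx` from a pair of parallel list attributes:
  // one holding per-output shapes and one holding per-output element types.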
  auto type_from_array_attr = [&node, &idx, &builder](
                                  absl::string_view output_shape_attr,
                                  absl::string_view element_type_attr) {
    auto* output_shapes = node.attrs().Find(output_shape_attr);
    auto* element_types = node.attrs().Find(element_type_attr);
    const auto& output_shape = output_shapes->list().shape(idx);
    const auto& element_type = element_types->list().type(idx);
    return ConvertToMlirTensorType(output_shape, element_type, &builder);
  };

  if (node.type_string() == "IteratorGetNext" ||
      node.type_string() == "IteratorGetNextSync" ||
      node.type_string() == "MultiDeviceIteratorGetNextFromShard")
    return type_from_array_attr("output_shapes", "output_types");

  if (node.type_string() == "InfeedDequeueTuple")
    return type_from_array_attr("shapes", "dtypes");

  if (node.type_string() == "InfeedDequeue") {
    assert(idx == 0);
    const auto& output_shape = node.attrs().Find("shape")->shape();
    const auto& element_type = node.attrs().Find("dtype")->type();
    return ConvertToMlirTensorType(output_shape, element_type, &builder);
  }

  // Returns a simple, more conservative unranked tensor type.
  auto default_type = [&]() -> StatusOr<mlir::Type> {
    mlir::Type element_type;
    TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type));

    // TODO(b/200093974): Post triage, consider following
    // GraphConstructor::ValidateShape in checking _output_shapes.
    if (specs_.unconditionally_use_set_output_shapes) {
      if (const AttrValue* attr = node.attrs().Find(kOutputShapesAttrName)) {
        auto& list = attr->list();
        if (list.shape_size() > idx) {
          llvm::SmallVector<int64_t, 4> shape;
          const TensorShapeProto& shape_proto = list.shape()[idx];
          if (shape_proto.unknown_rank())
            return mlir::UnrankedTensorType::get(element_type);
          TF_RETURN_IF_ERROR(ConvertToMlirShape(shape_proto, &shape));
          return mlir::RankedTensorType::get(shape, element_type);
        }
      }
    }

    return mlir::UnrankedTensorType::get(element_type);
  };

  // Below we only try and do some shape inference for "source" ops which have
  // no inputs.
  if (node.num_inputs() > 0) return default_type();

  // Do some simple inference here to get the function arguments correct for
  // this common case.
  // TODO(jpienaar): Reconsider post refactoring shape functions.
  if (node.IsArg()) {
    if (dtype == DT_RESOURCE) {
      const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes");
      const AttrValue* shape_attr = node.attrs().Find("_handle_shapes");
      if (dtype_attr && shape_attr) {
        if (dtype_attr->list().type().empty()) {
          return errors::InvalidArgument(
              "Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
              shape_attr->DebugString());
        }
        if (shape_attr->list().shape().empty()) {
          return errors::InvalidArgument(
              "Invalid \"_handle_shapes\" attribute value for _Arg node: ",
              shape_attr->DebugString());
        }
        DataType dtype = dtype_attr->list().type(0);
        const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
        TF_ASSIGN_OR_RETURN(
            auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder));
        return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get(
            {etype.cast<TensorType>()}, builder.getContext()));
      } else {
        return mlir::UnrankedTensorType::get(
            mlir::TF::ResourceType::get(builder.getContext()));
      }
    } else if (auto shape = node.attrs().Find("_output_shapes")) {
      if (shape->has_list() && shape->list().shape_size() == 1) {
        return ConvertToMlirTensorType(shape->list().shape().at(0), dtype,
                                       &builder);
      }
    }
  }

  const tensorflow::OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(
      graph_->op_registry()->LookUp(node.type_string(), &op_reg_data));
  if (!op_reg_data) {
    DVLOG(1) << "Skipping inference for unregistered op " << node.type_string();
    return default_type();
  }
  if (op_reg_data->shape_inference_fn == nullptr) {
    DVLOG(1) << "Skipping inference for op without shape function "
             << node.type_string();
    return default_type();
  }
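  // Run the op's registered shape function with no input shapes or tensors;
  // for source ops this can still produce concrete output shapes.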
1229 shape_inference::InferenceContext c(graph_->versions().producer(),
1230 node.attrs(), op_reg_data->op_def,
1231 std::vector<PartialTensorShape>{}, {},
1232 /*input_tensors_as_shapes=*/{}, {});
1233 TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
1234 return shape_ic(&c);
1235 }
1236
ConvertDataTypeAndShape(DataType dtype,const shape_inference::ShapeHandle & handle,const std::vector<shape_inference::ShapeAndType> * handle_subtypes,shape_inference::InferenceContext * context,mlir::Builder builder)1237 StatusOr<TensorType> ImporterBase::ConvertDataTypeAndShape(
1238 DataType dtype, const shape_inference::ShapeHandle& handle,
1239 const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
1240 shape_inference::InferenceContext* context, mlir::Builder builder) {
1241 TF_ASSIGN_OR_RETURN(auto subtypes,
1242 ConvertSubtypes(handle_subtypes, context, builder));
1243
1244 mlir::Type element_type;
1245 if (dtype == DT_VARIANT)
1246 element_type = mlir::TF::VariantType::get(subtypes, context_);
1247 else if (dtype == DT_RESOURCE)
1248 element_type = mlir::TF::ResourceType::get(subtypes, context_);
1249 else
1250 TF_RETURN_IF_ERROR(
1251 ::tensorflow::ConvertDataType(dtype, builder, &element_type));
1252
1253 return ConvertElementTypeAndShape(element_type, handle, context, builder);
1254 }
1255
StatusOr<TensorType> ImporterBase::ConvertElementTypeAndShape(
    mlir::Type element_type, const shape_inference::ShapeHandle& handle,
    shape_inference::InferenceContext* context, mlir::Builder builder) {
  if (!context->RankKnown(handle)) {
    return mlir::UnrankedTensorType::get(element_type);
  }

  // Sentinel for an unknown dimension size. getTensorType interprets any
  // negative value as an unknown dimension.
  // TODO(jmolloy): Ideally this shouldn't be a local sentinel.
  const int64_t kUnknownDim = -1;

  absl::InlinedVector<int64_t, 4> dimensions;
  int32_t rank = context->Rank(handle);
  dimensions.reserve(rank);
  for (int i = 0; i < rank; ++i) {
    auto dim_handle = context->Dim(handle, i);
    if (!context->ValueKnown(dim_handle))
      dimensions.push_back(kUnknownDim);
    else
      dimensions.push_back(context->Value(dim_handle));
  }

  return mlir::RankedTensorType::get(
      llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type);
}

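// Converts the subtypes of a resource or variant shape handle to a list of
// MLIR tensor types.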
StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
    const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
    shape_inference::InferenceContext* context, mlir::Builder builder) {
  ElementSubtypes subtypes;
  if (!handle_subtypes) return subtypes;

  subtypes.reserve(handle_subtypes->size());
  for (const auto& subtype : *handle_subtypes) {
    mlir::Type element_type;
    TF_RETURN_IF_ERROR(
        ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type));
    TF_ASSIGN_OR_RETURN(TensorType type,
                        ConvertElementTypeAndShape(element_type, subtype.shape,
                                                   context, builder));
    subtypes.push_back(type);
  }
  return subtypes;
}

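// Converts a kFunc attribute to a flat symbol reference named 'base_name' and
// imports the function's own attributes under a "base_name." prefix.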
Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
                                                  const AttrValue& value,
                                                  NamedAttrList* attributes) {
  TF_ASSIGN_OR_RETURN(auto func_attr,
                      ConvertFunctionCallName(value.func().name()));
  if (!func_attr) return OkStatus();
  attributes->push_back(builder_.getNamedAttr(base_name, func_attr));

  for (const auto& it : value.func().attr()) {
    auto name = absl::StrCat(base_name, ".", it.first);
    TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second));
    attributes->push_back(builder_.getNamedAttr(name, value));
  }
  return OkStatus();
}

StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName(
    const std::string& func_name) {
  // Some ops, like the XlaHostCompute op, use an empty value to represent a
  // missing function. Such attribute values should be declared optional in
  // the MLIR op definition.
  if (func_name.empty()) return mlir::FlatSymbolRefAttr();

  TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
  auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
  return mlir::SymbolRefAttr::get(builder_.getContext(), mlir_func_name);
}

StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
    const AttrValue& value) {
  switch (value.value_case()) {
    case AttrValue::kFunc: {
      // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
      // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
      // will not use this representation. This also doesn't handle empty
      // function values, unlike the ConvertFunctionCallName method.
      NamedAttrList attrs;
      for (const auto& func_attr : value.func().attr()) {
        TF_ASSIGN_OR_RETURN(
            auto attr, ImporterBase::ConvertAttributeValue(func_attr.second));
        attrs.push_back(builder_.getNamedAttr(func_attr.first, attr));
      }
      auto func_attrs = builder_.getDictionaryAttr(attrs);
      return mlir::TF::FuncAttr::get(context_, value.func().name(),
                                     func_attrs);
    }
    case AttrValue::kList: {
      if (!value.list().func().empty()) {
        absl::InlinedVector<mlir::Attribute, 8> attrs;
        for (const auto& item : value.list().func()) {
          TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
          if (item.attr_size() != 0)
            return errors::Unimplemented(
                "func attributes with non-zero attr.size()");
          if (attr) attrs.push_back(attr);
        }
        return builder_.getArrayAttr(
            llvm::makeArrayRef(attrs.begin(), attrs.end()));
      }
      return ConvertNonFuncAttributeValue(value, &builder_);
    }
    default:
      return ConvertNonFuncAttributeValue(value, &builder_);
  }
}

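// Collects the _Arg, _Retval, and control return nodes of a FunctionBody as
// the argument, result, and control result nodes of the function to build.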
void ImporterBase::GetArgsAndRetsFromFunctionBody(
    const FunctionBody& fbody, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
    absl::InlinedVector<OutputTensor, 4>* ret_nodes,
    absl::InlinedVector<Node*, 4>* control_ret_nodes) {
  arg_nodes->reserve(fbody.arg_nodes.size());
  ret_nodes->reserve(fbody.ret_nodes.size());
  for (auto arg : fbody.arg_nodes) {
    arg_nodes->emplace_back(arg, 0);
  }
  for (auto ret : fbody.ret_nodes) {
    ret_nodes->emplace_back(ret, 0);
  }
  *control_ret_nodes = fbody.control_ret_nodes;
}

Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
  // If the library function has been converted already, nothing needs to be
  // done.
  if (tf_name_to_mlir_name_->find(std::string(func_name)) !=
      tf_name_to_mlir_name_->end())
    return OkStatus();

  std::string mlir_func_name(
      function_name_uniquifier_->GetUniqueName(func_name));
  (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name;

  const auto& func_lib = graph_flib_;
  const auto* func_def = func_lib.Find(std::string(func_name));
  if (func_def == nullptr) {
    return errors::FailedPrecondition(
        absl::StrCat("Failed to find function '", StringRefToView(func_name),
                     "'. The imported TensorFlow GraphDef is ill-formed."));
  }

  // Converts the argument and return types to MLIR types.
  std::vector<mlir::NamedAttribute> attributes;
  attributes.reserve(func_def->attr_size());
  for (const auto& name_and_value : func_def->attr()) {
    // This is a function definition attribute, so it shouldn't contain a
    // kFunc attribute; it is treated as a normal attribute.
    TF_ASSIGN_OR_RETURN(auto attr,
                        ConvertAttributeValue(name_and_value.second));
    std::string attr_name =
        mangling_util::MangleAttributeName(name_and_value.first);
    attributes.push_back(builder_.getNamedAttr(attr_name, attr));
  }

  // Checks the OpDef's stateful attribute and imports it as a function
  // attribute.
  if (func_def->signature().is_stateful()) {
    auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
    attributes.push_back(
        builder_.getNamedAttr(stateful_str, builder_.getUnitAttr()));
  }

  // Checks for an associated custom gradient function and adds it to the
  // attribute list of this function.
  auto grad_func_name = func_lib.FindGradient(std::string(func_name));
  if (!grad_func_name.empty()) {
    TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
    auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
    auto gradient_attr =
        mlir::SymbolRefAttr::get(builder_.getContext(), mlir_grad_func_name);
    auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
    attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
  }

  deferred_functions_.emplace(func_name.str(), attributes);
  return OkStatus();
}

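// Prunes nodes that are not reverse reachable from the input and output nodes
// of the import configuration.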
Status ImporterBase::PruneUnreachableNodes(
    std::unordered_map<string, Node*>* node_name_map) {
  std::unordered_set<const Node*> prune_start;
  TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start));

  if (!prune_start.empty()) {
    if (PruneForReverseReachability(graph_.get(), prune_start)) {
      VLOG(1) << "Pruned unused nodes in graphdef";
    } else {
      VLOG(1) << "No unused nodes in graphdef to prune";
    }
  } else {
    VLOG(1) << "No output nodes specified, skipping pruning";
  }
  return OkStatus();
}

Status ImporterBase::ConvertFeedsToPlaceholders(
    std::unordered_map<string, Node*>* node_name_map) {
  // Feeds (edges) are converted into single-output placeholder nodes to
  // simplify the conversion process.
  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));
  for (const auto& it : feeds_by_node) {
    TensorId tensor = ParseTensorName(it.first);
    auto jt = node_name_map->find(std::string(tensor.node()));
    if (jt == node_name_map->end()) {
      return errors::FailedPrecondition(
          absl::StrCat("Graph does not contain node: ", tensor.node()));
    }

    Node* node = jt->second;
    auto op_name = node->op_def().name();
    if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
        op_name != FunctionLibraryDefinition::kArgOp) {
      for (const auto& output_tensor : it.second) {
        const int index = output_tensor.first;
        const ArrayInfo& array_info = output_tensor.second->second;

        DataType dtype = array_info.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(index);
        }

        TF_ASSIGN_OR_RETURN(
            auto placeholder_node_and_removed,
            CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
                                         *node_name_map));

        Node* placeholder_node = placeholder_node_and_removed.first;
        if (placeholder_node->in_edges().empty()) {
          graph_->AddControlEdge(graph_->source_node(), placeholder_node,
                                 true /* skip test for duplicates */);
        }
        if (placeholder_node->out_edges().empty()) {
          graph_->AddControlEdge(placeholder_node, graph_->sink_node(),
                                 true /* skip test for duplicates */);
        }
        remapped_feeds_[{it.first, index}] = placeholder_node->name();
        (*node_name_map)[placeholder_node->name()] = placeholder_node;
      }
    }
  }
  return OkStatus();
}

Status ImporterBase::PrepareConvert(const Graph& graph,
                                    std::unique_ptr<GraphDef> graph_def) {
  // TODO(fengliuai): Converting to GraphDef and back is the easiest way to
  // clone a graph.
  // TODO(fengliuai): clone the graph without going to graph_def first.
  if (graph_def == nullptr) {
    graph_def = std::make_unique<GraphDef>();
    graph.ToGraphDef(graph_def.get());
  }
  graph_ = std::make_unique<Graph>(graph.flib_def());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.add_default_attributes = true;
  TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
      opts, std::move(*graph_def), graph_.get()));

  TF_RETURN_IF_ERROR(RemoveBackedges());

  TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get()));

  auto node_name_map = graph_->BuildNodeNameIndex();

  if (specs_.enable_shape_inference) {
    // TODO(jpienaar): Remove once the infer-shapes-on-import flag is removed.
    TF_RETURN_IF_ERROR(AddNodesToShapeRefiner(&node_name_map));
  } else {
    TF_RETURN_IF_ERROR(ConvertFeedsToPlaceholders(&node_name_map));
  }

  // Prune nodes in the graph that are not reachable from the output.
  if (specs_.prune_unused_nodes) {
    TF_RETURN_IF_ERROR(PruneUnreachableNodes(&node_name_map));
  }

  if (!specs_.enable_shape_inference) {
    // Re-initialize ordered_nodes_ since we might have modified the graph.
    GetReversePostOrder(
        *graph_, &ordered_nodes_,
        [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
  }

  return OkStatus();
}

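// Converts the prepared graph to an MLIR function named 'func_name' with the
// given type and attributes: nodes are converted into a tf_executor.graph
// region in the precomputed node order, removed backedges are restored, and
// the function arguments and results are wired to the given arg/ret nodes.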
Status ImporterBase::Convert(
    llvm::StringRef func_name, mlir::FunctionType func_type,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes,
    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // TODO(b/122040776): Use debug info for FunctionDef.
  auto function = mlir::func::FuncOp::create(mlir::UnknownLoc::get(context_),
                                             func_name, func_type, attrs);

  module_.push_back(function);
  // Seeds the builder with an initial block.
  function.addEntryBlock();
  builder_ = mlir::OpBuilder(function.getBody());

  // Create the graph operation in which we will convert the individual nodes.
  auto graph = builder_.create<mlir::tf_executor::GraphOp>(
      function.getLoc(), func_type.getResults());
  builder_.createBlock(&graph.body());

  for (const Node* node : ordered_nodes_) {
    TF_RETURN_IF_ERROR(ConvertNode(*node));
  }

  // Adds the backedges back to the function by creating the source and sink
  // pairs.
  TF_RETURN_IF_ERROR(AddBackedges());

  TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph,
                                               func_type.getInputs(), arg_nodes,
                                               ret_nodes, control_ret_nodes));

  // TODO(jpienaar): Update post removing shape_refiner_.
  if (!specs_.enable_shape_inference) {
    // Refine the graph's result types given the more precise fetch types.
    auto fetch = graph.GetFetch();
    bool all_equal = true;
    for (auto it :
         llvm::zip_first(graph.getResults(), fetch.getOperandTypes())) {
      auto rt = std::get<1>(it);
      if (rt == std::get<0>(it).getType()) continue;
      std::get<0>(it).setType(rt);
      all_equal = false;
    }
    if (!all_equal) {
      function.setType(mlir::FunctionType::get(function.getContext(),
                                               func_type.getInputs(),
                                               graph.getResultTypes()));
    }
  }

  return OkStatus();
}

Status ImporterBase::ConvertFunctionArgAndRets(
    mlir::func::FuncOp func, mlir::tf_executor::GraphOp graph_op,
    llvm::ArrayRef<mlir::Type> arg_types,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes) {
  // Store the arg/return attributes as a list rather than uniquing during
  // construction.
  llvm::SmallVector<mlir::NamedAttrList, 4> arg_attrs;
  arg_attrs.resize(func.getNumArguments());
  llvm::SmallVector<mlir::NamedAttrList, 4> ret_attrs;
  ret_attrs.resize(func.getNumResults());

  auto set_attributes_on_func = [&](Node* node, int64_t index, bool is_arg) {
    for (const auto& node_attr : node->attrs()) {
      const auto& key = node_attr.first;
      // Only import optional attributes (i.e., those starting with an
      // underscore).
      if (key.empty() || key[0] != '_') continue;
      // Ignore shape inference attributes as shape information is already
      // populated in the result type.
      if (IsOutputShapesAttribute(node_attr.second, key) ||
          IsResourceOutputShapesAttribute(node_attr.second, key))
        continue;
      TF_ASSIGN_OR_RETURN(auto converted_attr,
                          ConvertAttributeValue(node_attr.second));
      std::string dialect_attribute = "tf." + key;
      if (is_arg) {
        arg_attrs[index].set(dialect_attribute, converted_attr);
      } else {
        func.setResultAttr(index, dialect_attribute, converted_attr);
        ret_attrs[index].set(dialect_attribute, converted_attr);
      }
    }
    return OkStatus();
  };

  auto* bb = &func.front();
  llvm::SmallDenseMap<std::pair<Node*, int>, mlir::Value, 4>
      arg_nodes_to_values;
  for (int i = 0, e = arg_types.size(); i < e; ++i) {
    auto& arg_node = arg_nodes[i];
    // The lookup can't fail here: otherwise some nodes in the function
    // haven't been converted to MLIR operations and don't have a mapping.
    mlir::Operation* island = node_values_.find(arg_node.node->id())->second;

    auto bb_arg = bb->getArgument(i);
    mlir::Value arg_def = bb_arg;

    if (island->getNumResults() != 2)
      return errors::InvalidArgument(
          "Only feed output tensors of single output nodes are supported");

    // Collect mapping of OutputTensor to associated block arg.
    arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def);
    island->getResult(0).replaceAllUsesWith(arg_def);
    // Erase control outputs from feed.
    auto control_uses = island->getResult(1).getUses();
    for (auto& control_use : llvm::make_early_inc_range(control_uses))
      control_use.getOwner()->eraseOperand(control_use.getOperandNumber());

    if (!arg_node.node->requested_device().empty())
      arg_attrs[i].set("tf.device", builder_.getStringAttr(
                                        arg_node.node->requested_device()));

    if (arg_node.node->IsArg()) {
      TF_RETURN_IF_ERROR(
          set_attributes_on_func(arg_node.node, i, /*is_arg=*/true));
    }

    island->dropAllReferences();
    island->erase();
  }

  llvm::SmallVector<mlir::Value, 8> inst_to_return;
  for (auto ret_and_idx : llvm::enumerate(ret_nodes)) {
    const auto& ret = ret_and_idx.value();
    auto* inst = node_values_[ret.node->id()];
    if (ret.node->IsRetval()) {
      if (!ret.node->requested_device().empty())
        ret_attrs[ret_and_idx.index()].set(
            "tf.device", builder_.getStringAttr(ret.node->requested_device()));
      TF_RETURN_IF_ERROR(set_attributes_on_func(ret.node, ret_and_idx.index(),
                                                /*is_arg=*/false));
      // Look up the instruction inside the island.
      auto island_op = llvm::cast<mlir::tf_executor::IslandOp>(inst);
      mlir::Operation* inner_op = &island_op.GetBody().front();
      // Remove the kRetOp or kDeviceRetOp operation and return its operand.
      // kRetOp and kDeviceRetOp should have just one operand unless they have
      // control dependencies.
      if (inner_op->getNumOperands() != 1)
        return errors::Unimplemented("Return node with multiple inputs.");
      inst_to_return.push_back(inner_op->getOperand(0));
      inst->dropAllReferences();
      inst->erase();
    } else {
      // Look up and use the block arg if the fetch is a feed.
      auto it = arg_nodes_to_values.find({ret.node, ret.index});
      if (it != arg_nodes_to_values.end())
        inst_to_return.push_back(it->second);
      else
        inst_to_return.push_back(inst->getResult(ret.index));
    }
  }

  for (Node* control_ret : control_ret_nodes) {
    auto* inst = node_values_[control_ret->id()];
    inst_to_return.push_back(*std::prev(inst->result_end()));
  }

  // Terminate the function by adding a Fetch operation to terminate the graph
  // and a return operation to return the graph results.
  builder_.setInsertionPointToEnd(&graph_op.body().front());
  builder_.create<mlir::tf_executor::FetchOp>(graph_op.getLoc(),
                                              inst_to_return);
  builder_.setInsertionPointToEnd(bb);
  builder_.create<mlir::func::ReturnOp>(mlir::UnknownLoc::get(context_),
                                        graph_op.getResults());

  func.setAllArgAttrs(
      llvm::to_vector<4>(llvm::map_range(arg_attrs, [&](NamedAttrList& list) {
        return list.getDictionary(context_);
      })));
  func.setAllResultAttrs(
      llvm::to_vector<4>(llvm::map_range(ret_attrs, [&](NamedAttrList& list) {
        return list.getDictionary(context_);
      })));

  return OkStatus();
}

mlir::Location ImporterBase::GetLocation(const Node& node) {
  DVLOG(1) << "Getting location for " << node.name() << " " << &node;
  // TODO(b/142400497): What is the semantic contract for locations?
  const auto& debug_info = debug_info_.traces();

  // Create a location for node `name` in function `function_name`.
  auto create_location = [&](llvm::StringRef name,
                             llvm::StringRef function_name) -> mlir::Location {
    // Use the concatenation of function and node names as the lookup key into
    // the debug info. This matches the way that the key is formed on the
    // Python side.
    //
    // We also use this as the name for the NameLoc for ops in functions, since
    // otherwise our names could collide across functions.
    // For ops in the main graph, we omit the "@function_name" (which would be
    // just "@" since function_name would be empty) because some code seems to
    // depend on the name being this way for correctness.
    std::string debug_info_key = (name + "@" + function_name).str();
    std::string name_for_name_loc =
        function_name.empty() ? name.str() : debug_info_key;
    auto name_loc_id = mlir::StringAttr::get(context_, name_for_name_loc);

    llvm::SmallVector<mlir::Location, 4> locations;
    // Prefer stack traces if available, fall back to debug info if not, and
    // finally to just the name.
    if (auto stack_trace = node.GetStackTrace()) {
      DVLOG(1) << "Stack available for " << node.name();
      absl::Span<const StackFrame> frames = stack_trace->ToFrames();
      locations.reserve(frames.size());
      for (const StackFrame& frame : llvm::reverse(frames)) {
        auto file_name = mlir::StringAttr::get(context_, frame.file_name);
        // Use col 1 as there is no column info in StackTrace.
        auto file_line_loc =
            mlir::FileLineColLoc::get(file_name, frame.line_number, 1);
        locations.push_back(file_line_loc);
      }
    } else {
      DVLOG(1) << "No stack trace for " << node.name();
      const auto location_it = debug_info.find(debug_info_key);
      if (location_it != debug_info.end()) {
        DVLOG(1) << "Available serialized debug info for " << node.name();
        // Convert the stack trace to a chain of mlir::CallSiteLocs.
        const auto& trace = location_it->second;
        locations.reserve(trace.file_line_cols_size());
        for (const auto& location : trace.file_line_cols()) {
          const auto& file = debug_info_.files(location.file_index());
          auto file_name = mlir::StringAttr::get(context_, file);
          auto file_line_loc = mlir::FileLineColLoc::get(
              file_name, location.line(), location.col());
          locations.push_back(file_line_loc);
        }
      }
    }

    // If there are no locations in the stack trace, fall back to just a
    // NameLoc with no child.
    if (locations.empty()) return mlir::NameLoc::get(name_loc_id);

    // Use the front FileLineColLoc to generate a NameLoc.
    mlir::Location node_name_loc =
        mlir::NameLoc::get(name_loc_id, locations.front());

    // If there are more locations, generate a stack trace; otherwise just
    // return the name loc.
    auto callsite_locs = llvm::makeArrayRef(locations).drop_front();
    return callsite_locs.empty()
               ? node_name_loc
               : mlir::CallSiteLoc::get(node_name_loc, callsite_locs);
  };

  // Create a fused location combining the op type and the node name location.
  auto create_op_type_and_name_locations = [&]() {
    return mlir::FusedLoc::get(
        context_,
        // Add the op type location for the propagation of op_type metadata.
        {mlir::NameLoc::get(
             mlir::StringAttr::get(context_, node.type_string() + ":")),
         create_location(node.name(), function_name_for_debug_info_)});
  };

  // For NextIteration nodes, location is used to pair source and sink nodes.
  // Hence, we use the node name as the location to keep it unique.
  // TODO(prakalps): In future the plan is to use tokens to pair source/sink
  // nodes. Then NextIteration nodes would not need to be handled separately.
  if (node.type_string() == "NextIteration") {
    return create_op_type_and_name_locations();
  }

  const auto& node_def = node.def();
  auto original_nodes =
      node_def.experimental_debug_info().original_node_names();
  auto original_funcs =
      node_def.experimental_debug_info().original_func_names();

  if (original_nodes.empty()) {
    return create_op_type_and_name_locations();
  } else {
    // If the original nodes are defined, then we use them to get a list of
    // call sites and fuse them into a single fused location, together with
    // the name of the node_def.
    llvm::SmallVector<mlir::Location, 4> node_locations;
    node_locations.reserve(original_nodes.size() + 2);
    // Add the op type location for the propagation of op_type metadata.
    node_locations.push_back(mlir::NameLoc::get(
        mlir::StringAttr::get(context_, node.type_string() + ":")));
    // Retrieve the names from the experimental_debug_info.
    for (int i = 0, e = original_nodes.size(); i != e; ++i) {
      auto node_name = original_nodes[i];
      auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
      node_locations.push_back(create_location(node_name, func_name));
    }
    // Retrieve the name of the node_def.
    node_locations.push_back(
        create_location(node.name(), function_name_for_debug_info_));
    return mlir::FusedLoc::get(context_, node_locations);
  }
}

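// Emits an error diagnostic at the node's location so its location string is
// captured by the diagnostic handler, then combines it with 'error_status'.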
Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
                                              const Status& error_status) {
  const mlir::Location location = GetLocation(node);
  mlir::emitError(location);
  return error_handler_.Combine(error_status);
}

mlir::Operation* ImporterBase::CreateOperation(
    const Node& node, llvm::StringRef node_type_name,
    const mlir::OperationState& result,
    const llvm::SmallVectorImpl<mlir::Value>& control_operands) {
  // For the tf.executor specific operations (not wrapped in an island), we
  // have an extra returned value for the control result, and we concatenate
  // control and non-control operands.
  mlir::SmallVector<mlir::Type, 4> types(result.types);
  types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext()));
  mlir::SmallVector<mlir::Value, 4> operands(result.operands);
  operands.append(control_operands.begin(), control_operands.end());

  auto loc = result.location;
  // Dispatch based on the name and create the appropriate operation.
  if (node.IsSwitch()) {
    // Switch and _SwitchN are both in the switch class; differentiate based
    // on the op name.
    if (node.op_def().name() == "_SwitchN") {
      return builder_.create<mlir::tf_executor::SwitchNOp>(loc, types, operands,
                                                           result.attributes);
    }
    return builder_.create<mlir::tf_executor::SwitchOp>(loc, types, operands,
                                                        result.attributes);
  }
  if (node.IsMerge()) {
    return builder_.create<mlir::tf_executor::MergeOp>(loc, types, operands,
                                                       result.attributes);
  }
  if (node.IsNextIteration()) {
    // NextIteration is a bit special: we create a pair of operations that are
    // linked together through a token returned by the source.
    // We make use of a separate builder to insert the source at the top of
    // the block.
    mlir::OpBuilder builder_at_begin(builder_.getBlock(),
                                     builder_.getBlock()->begin());
    auto source_op =
        builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
            loc, operands[0].getType(), result.attributes);
    return builder_.create<mlir::tf_executor::NextIterationSinkOp>(
        loc, source_op.token(), operands, result.attributes);
  }
  if (node.IsLoopCond()) {
    return builder_.create<mlir::tf_executor::LoopCondOp>(loc, types, operands,
                                                          result.attributes);
  }
  if (node.IsEnter()) {
    return builder_.create<mlir::tf_executor::EnterOp>(loc, types, operands,
                                                       result.attributes);
  }
  if (node.IsExit()) {
    return builder_.create<mlir::tf_executor::ExitOp>(loc, types, operands,
                                                      result.attributes);
  }
  if (node.IsControlTrigger()) {
    return builder_.create<mlir::tf_executor::ControlTriggerOp>(
        loc, operands, result.attributes);
  }
  // Regular TensorFlow operations are wrapped in a tf_executor.island.
  auto island = builder_.create<mlir::tf_executor::IslandOp>(
      result.location, types, control_operands,
      mlir::ArrayRef<mlir::NamedAttribute>{});
  island.body().push_back(new mlir::Block);
  mlir::OpBuilder island_builder =
      mlir::OpBuilder::atBlockEnd(&island.GetBody());

  // Create the operation inside the island now.
  mlir::Operation* inner_op = island_builder.create(result);

  // Sets the operand_segment_sizes or result_segment_sizes attribute on the
  // op.
  const auto set_segment_sizes_attr =
      [&](const NameRangeMap& arg_ranges,
          const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
          llvm::StringRef attr_name) {
        std::vector<int32_t> values;
        values.reserve(args.size());
        for (const auto& arg : args) {
          auto range = arg_ranges.at(arg.name());
          values.push_back(range.second - range.first);
        }
        auto attr_value =
            mlir::DenseI32ArrayAttr::get(inner_op->getContext(), values);
        inner_op->setAttr(attr_name, attr_value);
      };

  if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>() ||
      inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
    // The op has multiple variadic operands or results.
    // Calculate operand and result segment sizes using the OpDef.
    NameRangeMap input_ranges, output_ranges;
    // This will fail only if the OpDef is syntactically invalid.
    // TODO(jpienaar): Convert this CHECK into a properly propagated error.
    TF_CHECK_OK(
        NameRangesForNode(node, node.op_def(), &input_ranges, &output_ranges));
    if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
      // Add the derived "operand_segment_sizes" attr to the created operation.
      // TODO(b/146937733): Don't use <void> here.
      set_segment_sizes_attr(input_ranges, node.op_def().input_arg(),
                             mlir::OpTrait::AttrSizedOperandSegments<
                                 void>::getOperandSegmentSizeAttr());
    }

    if (inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
      // Add the derived "result_segment_sizes" attr to the created operation.
      // TODO(b/146937733): Don't use <void> here.
      set_segment_sizes_attr(output_ranges, node.op_def().output_arg(),
                             mlir::OpTrait::AttrSizedResultSegments<
                                 void>::getResultSegmentSizeAttr());
    }
  }

  if (VLOG_IS_ON(1)) {
    mlir::OperationName name = inner_op->getName();
    if (!name.isRegistered() &&
        // Skip unmodelled ops that are handled differently.
        (node_type_name != "_Arg" && node_type_name != "_Retval") &&
        !unmodelled_op_names_.count(name.getIdentifier())) {
      if (node.op_def().is_stateful()) {
        VLOG(1) << "[potentially conservative] Op type `" << node.type_string()
                << "` is stateful but effects not modelled";
      } else {
        // See if any resource type is used.
        bool resource = false;
        std::function<bool(mlir::Type)> record_resource;
        record_resource = [&](mlir::Type type) {
          if (resource) return true;
          if (type.isa<mlir::TF::ResourceType>()) {
            resource = true;
            return true;
          }
          if (auto with_subtype =
                  type.dyn_cast<mlir::SubElementTypeInterface>()) {
            with_subtype.walkSubTypes(
                [&](mlir::Type t) { record_resource(t); });
          }
          return resource;
        };

        for (mlir::Type t : inner_op->getResultTypes())
          if (record_resource(t)) break;
        for (mlir::Type t : inner_op->getOperandTypes())
          if (record_resource(t)) break;
        if (resource) {
          unmodelled_op_names_.insert(name.getIdentifier());
          VLOG(1) << "[potentially conservative] Op type `"
                  << node.type_string()
                  << "` has resource operands/results but effects not modelled";
        }
      }
    }
  }

  // Add the terminator for the island.
  island_builder.create<mlir::tf_executor::YieldOp>(result.location,
                                                    inner_op->getResults());
  return island.getOperation();
}

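// Converts a single TF graph node into a tf_executor op: most nodes become a
// TF dialect op wrapped in a tf_executor.island, while control-flow nodes map
// to dedicated tf_executor ops. The result is recorded in node_values_.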
Status ImporterBase::ConvertNode(const Node& node) {
  if (!node.IsOp()) {
    // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by
    // Graph and don't exist in GraphDef.
    return OkStatus();
  }

  // If it is a custom op, its definition should be found in the library. We
  // create the MLIR function and insert it into the module if it doesn't
  // exist.
  std::string node_type_name = node.type_string();
  const auto* func_def = graph_flib_.Find(node_type_name);
  bool convert_to_legacy_call = false;
  if (func_def) {
    TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
    node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
    convert_to_legacy_call = true;
  }

  auto get_full_op_name = [&](const std::string& op_name) {
    const char* kTfPrefix = "tf.";
    return kTfPrefix + op_name;
  };

  std::string op_name = get_full_op_name(node_type_name);
  if (back_edge_node_output_.contains(&node)) {
    op_name = op_name + ".sink";
  }

  mlir::OperationState result(GetLocation(node), op_name);
  for (int i = 0; i < node.num_outputs(); ++i) {
    // The backedge has been removed, so we shouldn't count the corresponding
    // output from the src node when converting to an operation.
    if (back_edge_node_output_.contains(&node) &&
        back_edge_node_output_[&node] == i) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_));
    result.types.push_back(type);
  }

  // Surprisingly, input edges can be nondeterministically ordered. This
  // particularly seems to be the case for the control edges between _SOURCE
  // and _SINK that the Graph constructor inserts. Copy the input edges and
  // sort them, but only the control edges, not the data edges!
  // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes.
  // They'll break roundtripping anyway unless we strip them when converting
  // back to graphdef.
  absl::InlinedVector<const Edge*, 8> in_edges(node.in_edges().size());
  absl::c_copy(node.in_edges(), in_edges.begin());
  absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
    if (e1->IsControlEdge() && !e2->IsControlEdge()) return false;
    if (!e1->IsControlEdge() && e2->IsControlEdge()) return true;
    if (e1->IsControlEdge() && e2->IsControlEdge())
      return e1->src()->id() < e2->src()->id();
    return e1->dst_input() < e2->dst_input();
  });

  result.operands.reserve(in_edges.size());

  // Collect the control operands separately; they will be held by the island.
  mlir::SmallVector<mlir::Value, 8> control_operands;

  for (const auto* input_edge : in_edges) {
    const Node& input_node = *input_edge->src();
    if (input_node.IsSource()) {
      if (in_edges.size() != 1) {
        return errors::FailedPrecondition(
            "The node has other inputs besides the _Source node");
      }
      // We don't import the _SOURCE node.
      continue;
    }
    if (input_node.IsArg() && input_edge->IsControlEdge()) {
      // Currently we have not reached consensus as to what TF function
      // semantics are (b/133509504). Here we assume that all arguments to a
      // function should be available before we start execution of any internal
      // node. This makes the control dependencies between function arguments
      // and internal nodes redundant, and so we do not import them. The TF
      // inliner however assumes no such dependency between function args and
      // internal nodes exists, unless explicitly stated. Since we drop control
      // dependencies here, it leads to loss of information. If the function is
      // inlined later, the inliner would not know of these explicit control
      // dependencies present in the original graph.
      continue;
    }
    if (node_values_.find(input_node.id()) == node_values_.end())
      return errors::FailedPrecondition(
          "Graph not traversed in reverse post order; use seen before def!");
    mlir::Operation* inst = node_values_[input_node.id()];
    if (input_edge->IsControlEdge())
      control_operands.push_back(inst->getResult(inst->getNumResults() - 1));
    else
      result.operands.push_back(inst->getResult(input_edge->src_output()));
  }

  using FuncPairType = std::pair<const std::string*, const AttrValue*>;
  std::vector<FuncPairType> funcs;
  result.attributes.reserve(node.attrs().size() + 2);
  auto abstract_op = result.name.getRegisteredInfo();
  auto derived_op =
      abstract_op
          ? abstract_op->getInterface<mlir::DerivedAttributeOpInterface>()
          : nullptr;
  for (const auto& name_and_value : node.attrs()) {
    const auto& attr_name = name_and_value.first;
    // Skip adding derived attributes to the generated op.
    if (derived_op && derived_op->isDerivedAttribute(attr_name)) continue;
    const AttrValue& attr_value = name_and_value.second;

    // Remove the _output_shapes attribute that will be added by the exporter.
    if (IsOutputShapesAttribute(attr_value, attr_name)) continue;

    if (attr_value.value_case() == AttrValue::kFunc) {
      // Attribute iteration order is not defined for protocol buffer Map.
      // Process function attributes separately in lexicographical order to
      // have a deterministic order of functions in the constructed IR.
      funcs.emplace_back(&attr_name, &attr_value);
    } else {
      TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value));
      result.attributes.push_back(builder_.getNamedAttr(attr_name, attr));
    }
  }

  auto comparator = [](const FuncPairType& a, const FuncPairType& b) {
    return *a.first < *b.first;
  };
  std::sort(funcs.begin(), funcs.end(), comparator);
  for (const auto& func : funcs) {
    TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second,
                                                    &result.attributes));
  }

  const auto& node_def = node.def();
  // A NodeDef can contain partial TF device names. In such cases, canonicalize
  // them. Note that in current TF, the placer assigns a full device name to
  // each node.
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(node_def.device(), &parsed_name)) {
    return errors::InvalidArgument(
        "Op ", op_name, " has invalid device name: ", node_def.device());
  }
  // Keep the parsed name untouched if the device name is empty.
  if (!node_def.device().empty()) {
    if (!parsed_name.has_type) {
      parsed_name.type = "CPU";
      parsed_name.has_type = true;
    }
    if (!parsed_name.has_id) {
      parsed_name.id = 0;
      parsed_name.has_id = true;
    }
  }
  result.attributes.push_back(builder_.getNamedAttr(
      "device", builder_.getStringAttr(
                    DeviceNameUtils::ParsedNameToString(parsed_name))));

  // Map user function calls to LegacyCall ops and add the user function name
  // as an attribute.
  if (convert_to_legacy_call) {
    result.name =
        mlir::OperationName(get_full_op_name("LegacyCall"), context_);
    mlir::SymbolRefAttr val =
        mlir::SymbolRefAttr::get(builder_.getContext(), node_type_name);
    result.addAttribute("f", val);

    if (!result.attributes.get("_disable_call_shape_inference")) {
      result.addAttribute("_disable_call_shape_inference",
                          builder_.getBoolAttr(false));
    }
  }

  auto composite_control_flow_op = [&](const std::string& name) {
    result.name = mlir::OperationName(get_full_op_name(name), context_);
    bool stateless = absl::StartsWith(node_type_name, "Stateless");
    mlir::BoolAttr val = builder_.getBoolAttr(stateless);
    result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
  };

  // Map the Case/If/While and StatelessCase/If/While ops in TensorFlow to the
  // common Case/If/While op in MLIR and add the differentiating attribute.
  if (node.IsCaseNode()) composite_control_flow_op("Case");
  if (node.IsIfNode()) composite_control_flow_op("If");
  if (node.IsWhileNode()) {
    composite_control_flow_op("While");
    auto* output_shapes = node.attrs().Find("output_shapes");
    if (output_shapes && !output_shapes->list().shape().empty())
      result.attributes.push_back(
          builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
  }

  // Register the mapping between the TF node and the newly created operation.
  node_values_[node.id()] =
      CreateOperation(node, node_type_name, result, control_operands);
  return OkStatus();
}

// Add the backedges to the CFG. Given a backedge, we replace the original
// source and destination operations with two new operations. Most of the
// fields of the replacements are copied from the original operations.
// However,
// - for the src operation, one output is inserted at the front of the output
//   list. The type of the output is set to the type of the non-control result
//   of the dst operation, and
// - for the dst operation, one operand is inserted at the front of the
//   operand list. This operand uses the first result of the src operation.
// TODO(fengliuai): Preserve the order of the results and operands if
// necessary.
Status ImporterBase::AddBackedges() {
  for (auto it : back_edge_dst_inputs_) {
    BackEdge& edge = it.second;
    if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) {
      return errors::FailedPrecondition(
          "Invalid backedge; should be from NextIteration to Merge!");
    }
    auto* sink = node_values_[edge.src->id()];
    auto* dst = node_values_[edge.dst->id()];
    TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input));
  }
  return OkStatus();
}

Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                                 int dst_input) {
  // Get the NextIteration.Source operation from the token operand of the sink.
  mlir::Operation* source = sink->getOperand(0).getDefiningOp();

  // Adds the "source" to the operands of the dst by creating a new dst
  // operation.
  mlir::OperationState state(dst->getLoc(), dst->getName());
  auto num_operands = dst->getNumOperands();
  state.operands.reserve(num_operands + 1);
  for (int input = 0, e = num_operands + 1; input != e; ++input) {
    if (input < dst_input) {
      state.operands.push_back(dst->getOperand(input));
    } else if (input == dst_input) {
      state.operands.push_back(source->getResult(0));
    } else {
      state.operands.push_back(dst->getOperand(input - 1));
    }
  }
  state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end());
  state.types.assign(dst->getResultTypes().begin(),
                     dst->getResultTypes().end());
  builder_.setInsertionPoint(dst);
  auto* new_dst = builder_.create(state);

  // Replaces the output uses of the old operation with the corresponding
  // result of the new operation, and deletes the old operation.
  for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
    auto new_output = new_dst->getResult(i);
    dst->getResult(i).replaceAllUsesWith(new_output);
  }
  dst->dropAllReferences();
  dst->erase();
  return OkStatus();
}

StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
    const FunctionBody& fbody) {
  mlir::Builder builder(context_);

  // The FunctionBody contains a graph with a single-output _Arg node for each
  // function argument and a single-input _Retval node for each function
  // return value.
  //
  // We already populated the ShapeRefiner with all the information about the
  // shapes of these graph edges, so we just query it to build the
  // corresponding MLIR function type signature.

  llvm::SmallVector<mlir::Type, 4> arg_types;
  if (specs_.inputs.empty()) {
    arg_types.reserve(fbody.arg_types.size());
    for (auto arg : fbody.arg_nodes) {
      // Find the node in the graph using the node id instead of using `arg`
      // directly, because the graph has been cloned.
      auto* node = graph_->FindNodeId(arg->id());
      TF_ASSIGN_OR_RETURN(auto type,
                          InferOutputType(*node, /*idx=*/0, builder));
      arg_types.push_back(type);
    }
  } else {
    arg_types.reserve(fbody.arg_types.size());
    for (const auto& it : llvm::enumerate(specs_.inputs)) {
      mlir::Type element_type;
      const auto& node_info = it.value().second;
      DataType dtype = node_info.imported_dtype;
      // Uses the existing output type of the arg node if the data type of the
      // node isn't specified through the import configuration.
      if (dtype == DT_INVALID) {
        auto arg = fbody.arg_nodes[it.index()];
        auto* node = graph_->FindNodeId(arg->id());
        dtype = node->output_type(0);
        if (dtype == DT_INVALID) {
          return errors::InvalidArgument("Input ", it.index(),
                                         " has invalid data type");
        }
      }
      TF_RETURN_IF_ERROR(
          ::tensorflow::ConvertDataType(dtype, builder, &element_type));
      if (node_info.shape.unknown_rank()) {
        arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
      } else {
        llvm::SmallVector<int64_t, 4> shape;
        TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
        arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
      }
    }
  }

  llvm::SmallVector<mlir::Type, 4> ret_types;
  ret_types.reserve(fbody.ret_types.size());
  for (auto ret : fbody.ret_nodes) {
    // Find the node in the graph using the node id instead of using `ret`
    // directly, because the graph has been cloned.
    auto* node = graph_->FindNodeId(ret->id());
    TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder));
    ret_types.push_back(type);
  }

  return builder.getFunctionType(arg_types, ret_types);
}

// Stateful helper class to import a TensorFlow model expressed in GraphDef
// into an MLIR Module.
//
// The nodes defined in the graph are converted to a function called
// 'func_name'. All library function definitions are converted to MLIR
// functions in the module.
class GraphDefImporter : public ImporterBase {
 public:
  // Main entry point: converts the given graph to an MLIR Module.
  static StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Convert(
      mlir::MLIRContext* context, const Graph& graph,
      const GraphDebugInfo& debug_info,
      const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
      std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);

 private:
  explicit GraphDefImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}

  // Returns the function signature of the main function of the converted MLIR
  // module, along with the input and output nodes. The type and shape
  // information for the function arguments is read from `specs`, but the type
  // and shape information for the function returns is inferred by the shape
  // refiner in ImporterBase.
  StatusOr<mlir::FunctionType> InferMainFunctionType(
      const GraphImportConfig& specs, mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Returns the function signature of the main function, alongside input and
  // output nodes, for function graphs. Arguments and return values are
  // determined by node op type. Type and shape information of the function is
  // inferred by the shape refiner in ImporterBase.
  StatusOr<mlir::FunctionType> GetArgsRetsAndTypesFromFunctionGraph(
      mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Finds the graph's target nodes/function's control ret nodes based on
  // supplied node names in `control_outputs`. If `control_outputs` are not
  // unique or a control ret node is missing, an error will be returned.
  Status GetControlRetsFromGraph(
      llvm::ArrayRef<std::string> control_outputs,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);
};

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GraphDefImporter::Convert(
    mlir::MLIRContext* context, const Graph& graph,
    const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
    const GraphImportConfig& specs,
    std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
  LoadImporterDialects(*context);
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  NameUniquifier function_name_uniquifier(flib_def);

  // importer.PrepareConvert below will attempt to clone the original `graph`
  // by converting it to a GraphDef first. Convert the graph to a GraphDef here
  // first to avoid extra copies later.
  auto graph_def = std::make_unique<GraphDef>();
  graph.ToGraphDef(graph_def.get());

  static std::atomic<uint32> counter(0);
  uint32 current_file_prefix = counter++;
  const auto* graph_crash_handle = crash_analysis::ReportProtoDataOnCrash(
      absl::StrCat(current_file_prefix, "_mlir_import_graph.pbtxt"),
      *graph_def);
  auto reachable_flib = flib_def.ReachableDefinitions(*graph_def);
  const auto* flib_crash_handle = crash_analysis::ReportProtoDataOnCrash(
      absl::StrCat(current_file_prefix, "_mlir_import_flib.pbtxt"),
      reachable_flib.ToProto());

  auto scope_exit = llvm::make_scope_exit([&]() {
    crash_analysis::RemoveReportData(graph_crash_handle);
    crash_analysis::RemoveReportData(flib_crash_handle);
  });

  VLOG(1) << "Importing: "
          << ::tensorflow::DumpGraphToFile("tf_mlir_importer_base", graph,
                                           &flib_def);

  GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
                            &tf_name_to_mlir_name, &function_name_uniquifier);

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph, std::move(graph_def)));

  mlir::FunctionType func_type;
  absl::InlinedVector<OutputTensor, 4> arg_nodes;
  absl::InlinedVector<OutputTensor, 4> ret_nodes;
  absl::InlinedVector<Node*, 4> control_ret_nodes;
  llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
  if (specs.graph_as_function) {
    if (specs.prune_unused_nodes || !specs.inputs.empty() ||
        !specs.outputs.empty())
      return errors::InvalidArgument(
          "Pruning of graph is currently unsupported when the main graph is "
          "converted to a function.");

    TF_ASSIGN_OR_RETURN(func_type,
                        importer.GetArgsRetsAndTypesFromFunctionGraph(
                            context, &arg_nodes, &ret_nodes));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                        &control_ret_nodes));

    mlir::Builder b(context);
    std::string s;
    llvm::raw_string_ostream ss(s);
    auto node_name = [&](const OutputTensor& tensor) {
      ss << tensor.node->name();
    };
    llvm::interleave(arg_nodes, ss, node_name, ",");
    auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
    s.clear();
    llvm::interleave(ret_nodes, ss, node_name, ",");
    auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
    s.clear();
    llvm::interleave(specs.control_outputs, ss, ",");
    auto control_outputs =
        b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

    // Under `graph_as_function` mode, `tf.entry_function` is always set, as it
    // is assumed feed, fetch, and target nodes are set correctly.
    attrs.push_back(b.getNamedAttr(
        "tf.entry_function",
        b.getDictionaryAttr({inputs, outputs, control_outputs})));
  } else {
    // Collects the argument and return nodes by looking up the node names
    // specified by the user.
    TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
                                       specs, context, &arg_nodes, &ret_nodes));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                        &control_ret_nodes));

    // TODO(prakalps): Refactor to keep the tf.entry_function attribute
    // encoding and decoding in a centralized place.
    // Record the input and output mapping.
    if (!specs.inputs.empty() || !specs.outputs.empty() ||
        !specs.control_outputs.empty()) {
      mlir::Builder b(context);
      std::string s;
      llvm::raw_string_ostream ss(s);
      llvm::interleave(
          specs.inputs, ss,
          [&](const std::pair<std::string, ArrayInfo>& v) { ss << v.first; },
          ",");
      auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
      s.clear();
      llvm::interleave(specs.outputs, ss, ",");
      auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
      s.clear();
      llvm::interleave(specs.control_outputs, ss, ",");
      auto control_outputs =
          b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

      attrs.push_back(b.getNamedAttr(
          "tf.entry_function",
          b.getDictionaryAttr({inputs, outputs, control_outputs})));
    }
  }

  // Record version info.
  PopulateTfVersions(module.get(), graph.versions());

  const llvm::StringRef& graph_func_name =
      specs.graph_func_name.empty() ? kImportModelDefaultGraphFuncName
                                    : specs.graph_func_name;
  TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(graph_func_name, func_type,
                                                    arg_nodes, ret_nodes,
                                                    control_ret_nodes, attrs));
  TF_RETURN_IF_ERROR(importer.ImporterBase::ConvertDeferredFunctions());

  // Mark the main function public and the others private.
  for (auto function : module.get().getOps<mlir::func::FuncOp>()) {
    auto visibility = function.getName() == graph_func_name
                          ? mlir::func::FuncOp::Visibility::Public
                          : mlir::func::FuncOp::Visibility::Private;
    function.setVisibility(visibility);
  }
  VLOG(1) << "Imported: "
          << tensorflow::DumpMlirOpToFile("tf_mlir_imported_base",
                                          module.get());
  return module;
}

InferMainFunctionType(const GraphImportConfig & specs,mlir::MLIRContext * context,absl::InlinedVector<OutputTensor,4> * arg_nodes,absl::InlinedVector<OutputTensor,4> * ret_nodes)2527 StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
2528 const GraphImportConfig& specs, mlir::MLIRContext* context,
2529 absl::InlinedVector<OutputTensor, 4>* arg_nodes,
2530 absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
  // Find all the input nodes and output nodes.
  // Feeds have been remapped to single output nodes (Placeholder), so an exact
  // name match is sufficient.
  absl::flat_hash_map<absl::string_view, int> inputs;
  for (auto input_and_idx : llvm::enumerate(specs.inputs)) {
    TensorId tensor = ParseTensorName(input_and_idx.value().first);
    auto remapped_it = remapped_feeds_.find(tensor);
    if (remapped_it != remapped_feeds_.end()) {
      inputs.insert({remapped_it->second, input_and_idx.index()});
    } else {
      inputs.insert({tensor.node(), input_and_idx.index()});
    }
  }

  absl::flat_hash_set<absl::string_view> output_node_names;
  std::vector<TensorId> outputs;
  output_node_names.reserve(specs.outputs.size());
  for (const auto& output : specs.outputs) {
    TensorId tensor = ParseTensorName(output);
    auto remapped_it = remapped_feeds_.find(tensor);
    if (remapped_it != remapped_feeds_.end()) {
      output_node_names.insert(remapped_it->second);
      outputs.push_back({remapped_it->second, 0});
    } else {
      output_node_names.insert(tensor.node());
      outputs.push_back(tensor);
    }
  }

  if (!inputs.empty() || !outputs.empty()) {
    arg_nodes->resize(inputs.size());
    ret_nodes->resize(outputs.size());

    for (Node* n : GetOrderedNodes()) {
      // Handle inputs/arguments.
      auto input_it = inputs.find(n->name());
      if (input_it != inputs.end()) {
        (*arg_nodes)[input_it->second] = {n, 0};
      }

      // Handle outputs/returns.
      if (output_node_names.contains(n->name())) {
        for (int i = 0, e = outputs.size(); i != e; ++i) {
          TensorId tensor = outputs[i];
          if (n->name() != tensor.node()) continue;
          (*ret_nodes)[i] = {n, tensor.index()};
        }
      }
    }
  }

  // Start constructing the function type.
  mlir::Builder builder(context);
  llvm::SmallVector<mlir::Type, 4> arg_types;
  arg_types.reserve(specs.inputs.size());
  int i = 0;
  for (const auto& it : specs.inputs) {
    Node* arg_node = arg_nodes->at(i).node;
    if (arg_node == nullptr) {
      return errors::InvalidArgument("Input ", it.first,
                                     " was not found in graph");
    }
    mlir::Type element_type;
    const auto& node_info = it.second;
    DataType imported_dtype = node_info.imported_dtype;
    // Uses the existing output type of the arg node if the data type of the
    // node isn't specified through the import configuration.
    if (imported_dtype == DT_INVALID) {
      imported_dtype = arg_node->output_type(0);
      if (imported_dtype == DT_INVALID) {
        return errors::InvalidArgument("Input ", i, " has invalid data type");
      }
    }
    // Check if we have subtypes first
    if (!node_info.subtypes.empty()) {
      std::vector<mlir::TensorType> subtypes;
      for (const auto& st : node_info.subtypes) {
        mlir::Type st_data_type;
        llvm::SmallVector<int64_t> shape;
        TF_RETURN_IF_ERROR(ConvertToMlirShape(st.shape, &shape));
        TF_RETURN_IF_ERROR(
            ConvertDataType(st.imported_dtype, builder, &st_data_type));
        subtypes.push_back(mlir::RankedTensorType::get(shape, st_data_type));
      }
      if (imported_dtype == DT_RESOURCE) {
        element_type =
            mlir::TF::ResourceType::get(subtypes, builder.getContext());
      } else if (imported_dtype == DT_VARIANT) {
        element_type =
            mlir::TF::VariantType::get(subtypes, builder.getContext());
      } else {
        return errors::InvalidArgument(DataType_Name(imported_dtype),
                                       " takes no subtypes.");
      }
    } else {
      TF_RETURN_IF_ERROR(
          ConvertDataType(imported_dtype, builder, &element_type));
    }
    if (node_info.shape.unknown_rank()) {
      arg_types.push_back(mlir::UnrankedTensorType::get(element_type));
    } else {
      llvm::SmallVector<int64_t, 4> shape;
      TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
      arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
    }
    i++;
  }

  llvm::SmallVector<mlir::Type, 4> ret_types;
  ret_types.reserve(specs.outputs.size());
  for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
    if (ret_nodes->at(i).node == nullptr) {
      return errors::InvalidArgument("Output ", specs.outputs[i],
                                     " was not found in graph");
    }
  }
  for (const auto& ret : *ret_nodes) {
    if (ret.node->num_outputs() <= ret.index) {
      return errors::InvalidArgument("Invalid output index ", ret.index,
                                     " specified for node: ", ret.node->name());
    }
    TF_ASSIGN_OR_RETURN(auto type,
                        InferOutputType(*ret.node, ret.index, builder));
    ret_types.push_back(type);
  }

  return builder.getFunctionType(arg_types, ret_types);
}

StatusOr<mlir::FunctionType>
GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
    mlir::MLIRContext* context, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
    absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
  auto add_node = [](Node* node, absl::InlinedVector<OutputTensor, 4>* nodes) {
    auto* attr = node->attrs().Find("index");
    if (!attr)
      return errors::InvalidArgument(node->type_string(), " node '",
                                     node->name(),
                                     "' is missing attribute 'index'");

    auto index = attr->i();
    const int num_nodes = nodes->size();
    if (num_nodes < index + 1) nodes->resize(index + 1);

    if ((*nodes)[index].node != nullptr)
      return errors::InvalidArgument(node->type_string(), " node '",
                                     node->name(), "' has attribute 'index' ",
                                     index, " that conflicts with node '",
                                     (*nodes)[index].node->name(), "'");
    (*nodes)[index] = {node, 0};

    return OkStatus();
  };

  // Collect arg and ret nodes from graph.
  for (auto* node : GetOrderedNodes())
    if (node->IsArg())
      TF_RETURN_IF_ERROR(add_node(node, arg_nodes));
    else if (node->IsRetval())
      TF_RETURN_IF_ERROR(add_node(node, ret_nodes));

  // Collect arg and ret types and create function type.
  mlir::Builder builder(context);
  llvm::SmallVector<mlir::Type, 4> arg_types;
  arg_types.reserve(arg_nodes->size());
  for (auto arg_node_and_idx : llvm::enumerate(*arg_nodes)) {
    auto& arg_node = arg_node_and_idx.value();
    if (arg_node.node == nullptr)
      return errors::InvalidArgument("Graph missing _Arg at index ",
                                     arg_node_and_idx.index());

    TF_ASSIGN_OR_RETURN(auto type,
                        InferOutputType(*arg_node.node, /*idx=*/0, builder));
    arg_types.push_back(type);
  }

  llvm::SmallVector<mlir::Type, 4> ret_types;
  ret_types.reserve(ret_nodes->size());
  for (auto ret_node_and_idx : llvm::enumerate(*ret_nodes)) {
    auto& ret_node = ret_node_and_idx.value();
    if (ret_node.node == nullptr)
      return errors::InvalidArgument("Graph missing _Retval at index ",
                                     ret_node_and_idx.index());

    TF_ASSIGN_OR_RETURN(auto type,
                        InferInputType(*ret_node.node, /*idx=*/0, builder));
    ret_types.push_back(type);
  }

  return builder.getFunctionType(arg_types, ret_types);
}

Status GraphDefImporter::GetControlRetsFromGraph(
    llvm::ArrayRef<std::string> control_outputs,
    absl::InlinedVector<Node*, 4>* control_ret_nodes) {
  if (control_outputs.empty()) return OkStatus();

  llvm::SmallDenseMap<llvm::StringRef, int32_t> controls_to_idx;
  for (auto control_and_idx : llvm::enumerate(control_outputs))
    controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()});

  if (controls_to_idx.size() != control_outputs.size())
    return errors::InvalidArgument("Control outputs must be unique");

  control_ret_nodes->resize(controls_to_idx.size());

  for (auto* node : GetOrderedNodes()) {
    auto it = controls_to_idx.find(node->name());
    if (it != controls_to_idx.end()) (*control_ret_nodes)[it->second] = node;
  }

  for (auto node_and_name : llvm::zip(*control_ret_nodes, control_outputs))
    if (std::get<0>(node_and_name) == nullptr)
      return errors::InvalidArgument(
          "Control output '", std::get<1>(node_and_name), "' is missing");

  return OkStatus();
}

// Stateful helper class to import a TensorFlow model expressed in SavedModel
// into an MLIR Module.
class SavedModelObjectGraphImporter : public ImporterBase {
 public:
  // Main entry point: converts all functions in the given meta graph to an MLIR
  // Module.
  static StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Convert(
      SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
      mlir::MLIRContext* context, bool add_default_attributes,
      bool unconditionally_use_set_output_shapes);

 private:
  explicit SavedModelObjectGraphImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}
};

// Determines the names used to reference objects in the SavedObjectGraph.
class ObjectNames {
 public:
  explicit ObjectNames(const SavedObjectGraph& object_graph,
                       absl::Span<std::string> exported_names);

  // Gets the names that external users of the SavedModel can use to refer to
  // this node.
  llvm::ArrayRef<llvm::StringRef> GetExportedNames(int node_id) const;

  // Gets the name in the module symbol table for this node.
  // This name is only used for internal IR references.
  llvm::StringRef GetSymbolTableName(int node_id) const;

 private:
  // In the absence of any other information, use this name as the symbol table
  // name for this node.
  std::string GetDefaultSymbolTableName(int node_id) const;
  // Determines if a name is exported.
  bool IsExported(const std::string& name);
  // Main object graph traversal function.
  void RecursivelyVisitObjectGraph(int node_id);
  // Gets a stable StringRef from a std::string.
  llvm::StringRef SaveString(const std::string& s) const;

  // The object graph we are traversing.
  const SavedObjectGraph& object_graph_;
  // The set of names to export. Empty means "export all".
  std::unordered_set<std::string> names_to_export_;

  // When we recursively follow the object graph tree structure from the root,
  // we track its path in the object graph by pushing and popping from here
  // during traversal.
  llvm::SmallVector<std::string, 8> path_segments_;
  // The set of node IDs that are on the current DFS stack.
  // For cyclic object graphs, this prevents infinite recursion.
  std::unordered_set<int> on_stack_nodes_;

  // Key: node_id.
  // Value: all object names that node_id appears as.
  // Each object name corresponds to a unique path from the root of the object
  // graph.
  // The common intuitive case is when there is only one name for a given
  // object, which corresponds to the object graph being a tree.
  //
  // But there are cases where the object graph is a general graph. For
  // example, this happens commonly in Keras models, where `foo.bar` is
  // also reachable via the name `keras_api.foo.bar`.
  // Cycles are possible too.
  absl::flat_hash_map<int, std::vector<std::string>> object_names_;

  // Key: node_id
  // Value: all names that this object is exported as
  absl::flat_hash_map<int, llvm::SmallVector<llvm::StringRef, 1>>
      exported_names_;
  // Key: node_id
  // Value: pretty symbol table name to use for internal references to this
  // object.
  absl::flat_hash_map<int, llvm::StringRef> pretty_symbol_table_name_;

  // Stable strings we can take StringRef's into. Used only by the SaveString
  // method.
  mutable std::unordered_set<std::string> saved_strings_;
};

ObjectNames::ObjectNames(const SavedObjectGraph& object_graph,
                         absl::Span<std::string> exported_names)
    : object_graph_(object_graph),
      names_to_export_(exported_names.begin(), exported_names.end()) {
  // Visit all reachable nodes from the root of the object graph.
  // This builds up object_names_ to contain all names like `foo.bar` that a
  // particular node in the graph can be reached from.
  RecursivelyVisitObjectGraph(/*node_id=*/0);

  // Populate the exported_names_ map.
  // TODO(silvasean): Diagnose typos in exported names?
  for (auto& kv : object_names_) {
    // Make object names map independent of our particular choice of object
    // graph traversal.
    std::sort(kv.second.begin(), kv.second.end(),
              [](absl::string_view a, absl::string_view b) {
                // The sort order here influences the "pretty name" we assign
                // below. We want the most debuggable name to be first.
                //
                // Debuggability heuristics:
                // 1. Names that end in digits are likely to be internal aliases
                // to the "real" names.
                // 2. Longer names are more likely to be internal aliases.
                //
                // Example set of object names created by Keras for the weight
                // matrix of a fully connected layer on a trivial FC mnist
                // model:
                // - `model.layer-1.kernel` (this is the "best" name)
                // - `model.keras_api.layers.1.kernel`
                // - `model.variables.0`
                // - `model.keras_api.layers.1.keras_api.trainable_variables.0`
                // - ... 10 more long aliases ending in digits ...
                return std::make_tuple(isdigit(a.back()), a.size(), a) <
                       std::make_tuple(isdigit(b.back()), b.size(), b);
              });
    for (const std::string& name : kv.second) {
      if (IsExported(name)) {
        exported_names_[kv.first].push_back(SaveString(name));
      }
    }
  }
  // Create "pretty" symbol table names for nodes where that is applicable.
  // We could make all symbol table names use the default, which is basically
  // just the node id. But for debugging purposes, it's nicer if we can mix in
  // a recognizable object name if we have the information to do so.
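  // For example (illustrative), node 15 with exported name "model.variables.0"
  // gets the pretty symbol table name "__sm_node15__model.variables.0".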
  for (auto& kv : object_names_) {
    int node_id = kv.first;
    std::string internal_name =
        absl::StrCat(GetDefaultSymbolTableName(node_id), "__");
    // If the object has an exported name, we prefer that since it is probably
    // the most recognizable. Otherwise, we grab some non-exported name of the
    // object.
    if (exported_names_.find(node_id) != exported_names_.end()) {
      internal_name += exported_names_[node_id][0].str();
    } else {
      internal_name += object_names_[node_id][0];
    }
    pretty_symbol_table_name_[node_id] = SaveString(internal_name);
  }
}

llvm::ArrayRef<llvm::StringRef> ObjectNames::GetExportedNames(
    int node_id) const {
  auto it = exported_names_.find(node_id);
  if (it != exported_names_.end()) {
    return it->second;
  }
  return {};
}

llvm::StringRef ObjectNames::GetSymbolTableName(int node_id) const {
  auto it = pretty_symbol_table_name_.find(node_id);
  if (it != pretty_symbol_table_name_.end()) {
    return it->second;
  }
  return SaveString(GetDefaultSymbolTableName(node_id));
}

std::string ObjectNames::GetDefaultSymbolTableName(int node_id) const {
  return absl::StrCat("__sm_node", node_id);
}

bool ObjectNames::IsExported(const std::string& name) {
  if (names_to_export_.empty()) {
    return true;
  }
  return names_to_export_.find(name) != names_to_export_.end();
}

void ObjectNames::RecursivelyVisitObjectGraph(int node_id) {
  const SavedObject& object = object_graph_.nodes(node_id);

  switch (object.kind_case()) {
    case SavedObject::kConstant:
    case SavedObject::kFunction:
    case SavedObject::kVariable: {
      object_names_[node_id].push_back(absl::StrJoin(path_segments_, "."));
      break;
    }
    default:
      break;
  }

  for (const auto& child_ref : object.children()) {
    bool on_stack = !on_stack_nodes_.insert(child_ref.node_id()).second;
    if (on_stack) {
      // This is a backedge. Don't traverse it.
      continue;
    }

    path_segments_.push_back(child_ref.local_name());
    RecursivelyVisitObjectGraph(child_ref.node_id());
    path_segments_.pop_back();

    on_stack_nodes_.erase(child_ref.node_id());
  }
}

llvm::StringRef ObjectNames::SaveString(const std::string& s) const {
  return llvm::StringRef(*saved_strings_.insert(s).first);
}

// Extracts a TensorProto for a Const op from a GraphDef, given an op_name.
// Returns nullptr on not found or other mismatch.
// This returns a pointer to the actual node within the graph_def so as to
// avoid expensive copies.
const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def,
                                               const std::string& op_name) {
  const NodeDef* match_node = nullptr;
  for (const auto& node : graph_def.node()) {
    if (node.name() == op_name) {
      match_node = &node;
    }
  }

  if (!match_node) {
    return nullptr;
  }

  auto value_it = match_node->attr().find("value");
  if (value_it == match_node->attr().end()) {
    return nullptr;
  }

  if (!value_it->second.has_tensor()) {
    return nullptr;
  }

  return &value_it->second.tensor();
}

const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(
    const TrackableObjectGraph::TrackableObject& trackable_object,
    StringPiece name) {
  for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
    if (maybe_serialized_tensor.name() == name) {
      return &maybe_serialized_tensor;
    }
  }
  return nullptr;
}

Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph,
                                         const ObjectNames& object_names) {
  for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
    const SavedObject& object = object_graph.nodes(node_id);
    if (object_names.GetExportedNames(node_id).empty()) {
      continue;
    }
    if (object.kind_case() == SavedObject::kFunction) {
      // We only allow a single input signature to each SavedFunction.
      // This assumption means we have a 1:1 correspondence between
      // tf.function <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef
      // This makes defining the ABI easier (or even well-defined at all).
      // TODO(silvasean): How to detect a function that doesn't have an
      // explicitly user-provided input signature, but happens to have been
      // traced exactly once?
      if (object.function().concrete_functions_size() != 1) {
        llvm::SmallVector<std::string, 4> names;
        for (llvm::StringRef s : object_names.GetExportedNames(node_id)) {
          names.push_back("'" + s.str() + "'");
        }
        return errors::InvalidArgument(
            "Exported function with exported name(s) ",
            absl::StrJoin(names, ", "),
            " with multiple concrete functions. Add "
            "@tf.function(input_signature=[...]) on this function, or use a "
            "narrower list of exported names that excludes this function.");
      }
    }
  }
  return OkStatus();
}

// Recursively traverses a StructuredValue, linearizing all the leaves.
//
// This currently only handles the subset of StructuredValue that is needed for
// signatures.
//
// Given a StructuredValue with structure [{"x": leaf0}], the "index path"
// needed to reach leaf0 is `[0, "x"]`, as it would be if you were operating on
// a Python object (`obj[0]["x"] is leaf0`). Each leaf corresponds to a
// linearized function argument or return on a FunctionDef, and hence to an
// mlir::func::FuncOp argument / return.
//
// This must match the linearization that happens in `tf.nest.flatten`.
// In particular, dict values should be linearized in sorted key order.
//
// The linearized index paths can be returned back to a structured
// representation (e.g. to emit C structs matching a signature) with a simple
// algorithm that recurses on each run of index paths with identical first
// elements.
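//
// For example (illustrative), linearizing the structure
//   ({"b": leaf0, "a": leaf1}, leaf2)
// yields the index paths [0, "a"], [0, "b"], [1] -- dict keys visited in
// sorted order, matching `tf.nest.flatten`.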
class StructuredValueLinearizer {
 public:
  StructuredValueLinearizer(const StructuredValue& value,
                            mlir::MLIRContext* context);

  // Returns the list of index paths to each leaf of the StructuredValue,
  // in a linearized order matching `tf.nest.flatten`.
  //
  // If an error occurred during the linearization process, an error message
  // with `error_context` prepended will be included in the returned status.
  StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> GetLeafIndexPaths(
      llvm::StringRef error_context) const;

 private:
  // Main function that recursively traverses the StructuredValue.
  void RecursivelyFindLeaves(const StructuredValue& value);

  mlir::Builder builder_;
  // The current index path. We push/pop this during recursive traversal of the
  // StructuredValue.
  llvm::SmallVector<mlir::Attribute, 4> current_index_path_;
  // The list of leaf index paths we have discovered so far.
  llvm::SmallVector<mlir::ArrayAttr, 4> leaf_index_paths_;
  // If non-empty, an error message to report.
  std::string error_message_;
};

StructuredValueLinearizer::StructuredValueLinearizer(
    const StructuredValue& value, mlir::MLIRContext* context)
    : builder_(context) {
  RecursivelyFindLeaves(value);
}

StatusOr<llvm::ArrayRef<mlir::ArrayAttr>>
StructuredValueLinearizer::GetLeafIndexPaths(
    llvm::StringRef error_context) const {
  if (error_message_.empty()) {
    return llvm::makeArrayRef(leaf_index_paths_);
  }
  return errors::InvalidArgument(
      error_context.str(), error_message_,
      "This likely means that you have @tf.function "
      "on an exported function instead of "
      "@tf.function(input_signature=[...]). Consider annotating an "
      "input_signature or narrowing your set of "
      "exported names to not include this function.");
}

void StructuredValueLinearizer::RecursivelyFindLeaves(
    const StructuredValue& value) {
  switch (value.kind_case()) {
    case StructuredValue::kDictValue: {
      // Dict values must be linearized in sorted order of keys.
      const DictValue& dict = value.dict_value();
      using FieldTy = protobuf::MapPair<std::string, StructuredValue>;
      llvm::SmallVector<const FieldTy*, 4> fields;
      for (auto& field : dict.fields()) {
        fields.push_back(&field);
      }
      llvm::sort(fields, [](const FieldTy* a, const FieldTy* b) {
        return a->first < b->first;
      });
      for (auto& field : fields) {
        current_index_path_.push_back(builder_.getStringAttr(field->first));
        RecursivelyFindLeaves(field->second);
        current_index_path_.pop_back();
      }
      return;
    }
    case StructuredValue::kTupleValue: {
      const TupleValue& tuple = value.tuple_value();
      for (int i = 0, e = tuple.values_size(); i < e; i++) {
        current_index_path_.push_back(builder_.getI64IntegerAttr(i));
        RecursivelyFindLeaves(tuple.values(i));
        current_index_path_.pop_back();
      }
      return;
    }
    // We don't differentiate between tuples and lists.
    case StructuredValue::kListValue: {
      const ListValue& list = value.list_value();
      for (int i = 0, e = list.values_size(); i < e; i++) {
        current_index_path_.push_back(builder_.getI64IntegerAttr(i));
        RecursivelyFindLeaves(list.values(i));
        current_index_path_.pop_back();
      }
      return;
    }
    case StructuredValue::kTensorSpecValue: {
      // Base case: record the current path stack as the index path needed to
      // get to this leaf.
      leaf_index_paths_.push_back(builder_.getArrayAttr(current_index_path_));
      return;
    }
    case StructuredValue::kNoneValue: {
      // Base case: do nothing.
      // This arises, for example, as the top-level object of an output
      // signature when there are no return values.
      return;
    }
    default: {
      llvm::raw_string_ostream os(error_message_);
      // TODO(silvasean): Use an enumerant name string instead of a number.
      os << "Unhandled structured value kind " << value.kind_case()
         << " at index path: <value>";
      for (auto path_element : current_index_path_) {
        os << ".";
        if (auto integer = path_element.dyn_cast<mlir::IntegerAttr>()) {
          os << integer.getValue();
        } else {
          auto str = path_element.cast<mlir::StringAttr>();
          os << str.getValue();
        }
      }
      os << "\n";
    }
  }
}

// For exported functions with bound inputs, rewrite the function
// signature to match the requirements of tf_saved_model bound input args.
//
// The raw imported functions have `tensor<*x!tf_type.resource>` as the type for
// mutable bound inputs and `tensor<...>` as the type for immutable
// bound inputs. Here we canonicalize both of them into
// `tensor<!tf_type.resource<tensor<...>>>`.
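//
// For example (illustrative), a mutable bound input backing a tensor<10xf32>
// global tensor becomes tensor<!tf_type.resource<tensor<10xf32>>>, and a
// tf.Cast back to the original type is inserted for existing uses; for an
// immutable bound input, a tf.ReadVariableOp plays that role instead.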
void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
  mlir::SymbolTable symbol_table(module);
  for (auto func : module.getOps<mlir::func::FuncOp>()) {
    if (!mlir::tf_saved_model::IsExported(func)) continue;
    mlir::OpBuilder builder(func.getBody());
    llvm::SmallVector<mlir::Type, 4> new_input_types;
    for (int i = 0, e = func.getNumArguments(); i < e; i++) {
      auto arg = func.getArgument(i);
      auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType<
          mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table);
      if (global_tensor) {
        auto old_type = arg.getType();
        auto new_type =
            mlir::tf_saved_model::GetBoundInputArgTypeFor(global_tensor);
        arg.setType(new_type);
        if (global_tensor.is_mutable()) {
          auto arg_with_original_type = builder.create<mlir::TF::CastOp>(
              global_tensor.getLoc(), old_type, arg,
              /*Truncate=*/builder.getBoolAttr(false));
          arg.replaceAllUsesWith(arg_with_original_type);
          // The RAUW replaces the arg with itself, so we need to set it back.
          arg_with_original_type.setOperand(arg);
        } else {
          auto arg_with_original_type =
              builder.create<mlir::TF::ReadVariableOp>(global_tensor.getLoc(),
                                                       old_type, arg);
          arg.replaceAllUsesWith(arg_with_original_type);
          // The RAUW replaces the arg with itself, so we need to set it back.
          arg_with_original_type.setOperand(arg);
        }
      }
      new_input_types.push_back(arg.getType());
    }
    func.setType(mlir::FunctionType::get(module.getContext(), new_input_types,
                                         func.getFunctionType().getResults()));
  }
}

// Marks the visibility of functions in the saved model module.
void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
  for (auto func : module.getOps<mlir::func::FuncOp>()) {
    auto visibility = mlir::tf_saved_model::IsExported(func)
                          ? mlir::func::FuncOp::Visibility::Public
                          : mlir::func::FuncOp::Visibility::Private;
    func.setVisibility(visibility);
  }
}

// Reorder the ops in the module to make testing easier and less dependent
// on implementation details such as the order of functions in the
// FunctionDefLibrary.
//
// The order this ensures is:
// 1. GlobalTensorOps.
// 2. FuncOps.
//
// Within each of 1. and 2., ops are sorted by exported name (if
// available, and only the first exported name is considered), followed by
// non-exported ops.
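//
// For example (illustrative), the resulting front-to-back module order is:
// the session initializer (if present), asset ops, global tensors sorted by
// exported name, exported funcs sorted by exported name, then private funcs
// sorted by symbol name.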
void SortSavedModelModule(mlir::ModuleOp module) {
  struct NamedGlobalTensor {
    llvm::StringRef name;
    GlobalTensorOp global_tensor;
  };
  llvm::SmallVector<NamedGlobalTensor, 8> named_global_tensors;
  for (auto global_tensor : module.getOps<GlobalTensorOp>()) {
    auto exported_names = mlir::tf_saved_model::GetExportedNames(global_tensor);
    // We use stable_sort, so duplicate empty names are fine here.
    named_global_tensors.push_back(
        {exported_names.empty() ? "" : exported_names.front(), global_tensor});
  }
  llvm::stable_sort(named_global_tensors,
                    [](const NamedGlobalTensor& a, const NamedGlobalTensor& b) {
                      return std::make_tuple(a.name.empty(), a.name) <
                             std::make_tuple(b.name.empty(), b.name);
                    });

  struct NamedFunc {
    llvm::StringRef name;
    mlir::func::FuncOp func;
  };
  llvm::SmallVector<NamedFunc, 8> named_funcs;
  llvm::SmallVector<mlir::func::FuncOp, 8> private_funcs;
  for (auto func : module.getOps<mlir::func::FuncOp>()) {
    auto exported_names = mlir::tf_saved_model::GetExportedNames(func);
    if (!exported_names.empty())
      named_funcs.push_back({exported_names.front(), func});
    else
      private_funcs.push_back(func);
  }
  llvm::stable_sort(named_funcs, [](const NamedFunc& a, const NamedFunc& b) {
    return a.name < b.name;
  });
  llvm::stable_sort(private_funcs,
                    [](mlir::func::FuncOp a, mlir::func::FuncOp b) {
                      return a.getName() < b.getName();
                    });

  struct NamedAsset {
    llvm::StringRef name;
    AssetOp asset;
  };
  llvm::SmallVector<NamedAsset, 4> assets;
  for (auto asset : module.getOps<AssetOp>()) {
    assets.push_back({asset.getName(), asset});
  }
  llvm::stable_sort(assets, [](const NamedAsset& a, const NamedAsset& b) {
    return a.name < b.name;
  });

  // Move ops to the front of the module in reverse of the final desired order.
  for (auto func : llvm::reverse(private_funcs)) {
    func.getOperation()->moveBefore(&module.getBody()->front());
  }
  for (auto named_func : llvm::reverse(named_funcs)) {
    named_func.func.getOperation()->moveBefore(&module.getBody()->front());
  }
  for (auto named_global_tensor : llvm::reverse(named_global_tensors)) {
    named_global_tensor.global_tensor.getOperation()->moveBefore(
        &module.getBody()->front());
  }

  for (auto asset : assets) {
    asset.asset.getOperation()->moveBefore(&module.getBody()->front());
  }

  auto initializers = module.getOps<SessionInitializerOp>();
  if (!initializers.empty()) {
    (*initializers.begin())
        .getOperation()
        ->moveBefore(&module.getBody()->front());
  }
}

Status CreateSavedModelIR(
    const ObjectNames& object_names, mlir::ModuleOp module,
    const SavedObjectGraph& object_graph,
    const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
    SavedModelV2Bundle* saved_model) {
  mlir::OpBuilder builder(module.getBodyRegion());
  mlir::SymbolTable symbol_table(module);

  // Create a side data-structure, indexed by the object_graph node_id to
  // a TrackableObject that is restorable.
  absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*>
      restored_objects;
  TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore(
      [&](int saved_node_id,
          const TrackableObjectGraph::TrackableObject& trackable_object) {
        restored_objects.insert(
            std::make_pair(saved_node_id, &trackable_object));
        return OkStatus();
      }));

  for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
    const SavedObject& object = object_graph.nodes(node_id);
    // For correctness, we cannot import functions that don't have exported
    // names, since they don't necessarily have a well-defined ABI (diagnosed
    // earlier).
    //
    // For variables/constants, pruning them is purely an optimization,
    // and more complicated since it requires use-def analysis of which
    // functions use which variables/constants, so we don't do anything
    // special for them here as part of our initial IR construction.
    if (object.kind_case() == SavedObject::kFunction) {
      if (object_names.GetExportedNames(node_id).empty()) {
        continue;
      }
      std::string error_context =
          "While importing SavedModel function '" +
          object_names.GetExportedNames(node_id)[0].str() + "': ";
      const SavedFunction& function = object.function();
      auto orig_func = symbol_table.lookup<mlir::func::FuncOp>(
          tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
      mlir::func::FuncOp func = orig_func;
      // If there are potentially references to this func from within the
      // module, create a wrapper around it and decorate the wrapper with the
      // tf_saved_model attributes instead.
      if (!mlir::SymbolTable::symbolKnownUseEmpty(orig_func.getSymNameAttr(),
                                                  &module.getBodyRegion())) {
        func = orig_func.cloneWithoutRegions();
        module.insert(module.getBody()->begin(), func);
        func.addEntryBlock();
        func.setName(builder.getStringAttr("__sm_exported_" +
                                           orig_func.getName().str()));
        llvm::SmallVector<mlir::Value, 4> args_as_values;
        for (auto block_argument : func.getArguments()) {
          args_as_values.push_back(block_argument);
        }
        mlir::OpBuilder body_builder(&func.getBody());
        auto call = body_builder.create<mlir::TF::StatefulPartitionedCallOp>(
            func.getLoc(), orig_func.getFunctionType().getResults(),
            args_as_values,
            mlir::SymbolRefAttr::get(builder.getContext(), orig_func.getName()),
            /*config=*/builder.getStringAttr(""),
            /*config_proto=*/builder.getStringAttr(""),
            /*executor_type=*/builder.getStringAttr(""));
        body_builder.create<mlir::func::ReturnOp>(func.getLoc(),
                                                  call.getResults());
      }
      func->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
      const SavedConcreteFunction& concrete_function =
          object_graph.concrete_functions().at(function.concrete_functions(0));

      // We do not handle the other element of this tuple, which corresponds to
      // Python kwonlyargs, since currently TensorFlow prohibits this in
      // combination with input_signature:
      // https://github.com/tensorflow/tensorflow/blob/8cb8627abb5ef83a6fba34f8fd0e4ee430562eb1/tensorflow/python/eager/function.py#L2027-L2030
      // Our SavedModel import requires input_signature on the tf.function, so
      // we never need to handle the kwonlyargs.
      auto positional_arg_structure =
          concrete_function.canonicalized_input_signature()
              .tuple_value()
              .values(0);
      StructuredValueLinearizer input_linearizer(positional_arg_structure,
                                                 builder.getContext());

      int bound_input_base =
          func.getNumArguments() - concrete_function.bound_inputs_size();
      TF_ASSIGN_OR_RETURN(auto input_index_paths,
                          input_linearizer.GetLeafIndexPaths(
                              error_context + "in input signature: "));
      const int input_index_paths_size = input_index_paths.size();
      if (bound_input_base != input_index_paths_size) {
        return errors::InvalidArgument(
            error_context,
            "Argument mismatch between concrete function input signature "
            "vs underlying FunctionDef for concrete function '",
            function.concrete_functions(0), "' (", input_index_paths.size(),
            " vs ", bound_input_base, ")");
      }
      for (auto index_path : llvm::enumerate(input_index_paths)) {
        func.setArgAttr(index_path.index(), "tf_saved_model.index_path",
                        index_path.value());
      }

      for (auto& bound_input :
           llvm::enumerate(concrete_function.bound_inputs())) {
        int arg_index = bound_input_base + bound_input.index();
        auto symbol_ref = mlir::SymbolRefAttr::get(
            builder.getContext(),
            object_names.GetSymbolTableName(bound_input.value()));
        func.setArgAttr(arg_index, "tf_saved_model.bound_input", symbol_ref);
      }

      StructuredValueLinearizer output_linearizer(
          concrete_function.output_signature(), builder.getContext());
      TF_ASSIGN_OR_RETURN(auto output_index_paths,
                          output_linearizer.GetLeafIndexPaths(
                              error_context + "in output signature: "));
      if (func.getNumResults() != output_index_paths.size()) {
        return errors::InvalidArgument(
            error_context,
            "Result mismatch between concrete function output signature "
            "vs underlying FunctionDef for concrete function '",
            function.concrete_functions(0), "' (", output_index_paths.size(),
            " vs ", func.getNumResults(), ")");
      }
      for (auto index_path : llvm::enumerate(output_index_paths)) {
        func.setResultAttr(index_path.index(), "tf_saved_model.index_path",
                           index_path.value());
      }
    } else if (object.kind_case() == SavedObject::kVariable) {
      const SavedVariable& variable = object.variable();
      // Find the trackable in the side data structure.
      auto variable_trackable_it = restored_objects.find(node_id);
      if (variable_trackable_it == restored_objects.end()) {
        return errors::FailedPrecondition("Could not restore saved variable: ",
                                          variable.name());
      }
      const auto* serialized_tensor_attr = FindSerializedTensorInTrackable(
          *variable_trackable_it->second, "VARIABLE_VALUE");
      if (!serialized_tensor_attr) {
        return errors::FailedPrecondition(
            "Could not find serialized tensor for saved variable: ",
            variable.name());
      }
      const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key();

      // Load it from the reader.
      Tensor value;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          saved_model->variable_reader()->Lookup(checkpoint_key, &value),
          "Could not read checkpoint key from variables bundle: ",
          checkpoint_key);
      TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
      // A variable can have a partially known type, such as tensor<?x27x?xf32>,
      // even if the initializer is a specific static shape.
      TF_ASSIGN_OR_RETURN(
          auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(),
                                             &builder));
      auto op = builder.create<GlobalTensorOp>(
          builder.getUnknownLoc(),
          builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
          value_attr,
          /*type=*/mlir::TypeAttr::get(type),
          /*is_mutable=*/builder.getUnitAttr());
      op->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
    } else if (object.kind_case() == SavedObject::kConstant) {
      const SavedConstant& constant = object.constant();
      const TensorProto* value = ExtractConstTensorFromGraph(
          saved_model->meta_graph_def().graph_def(), constant.operation());
      if (!value) {
        return errors::FailedPrecondition(
            "Unable to find const node referenced in object graph: ",
            constant.operation());
      }
      TF_ASSIGN_OR_RETURN(auto value_attr,
                          ConvertTensorProto(*value, &builder));
      auto op = builder.create<GlobalTensorOp>(
          builder.getUnknownLoc(),
          builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
          value_attr,
          /*type=*/mlir::TypeAttr::get(value_attr.getType()),
          /*is_mutable=*/nullptr);
      op->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
    }
  }
  AdjustBoundInputArgTypes(module);
  module->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
  SortSavedModelModule(module);
  MarkSavedModelFunctionVisibility(module);
  return OkStatus();
}

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelObjectGraphImporter::Convert(
    SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
    mlir::MLIRContext* context, bool add_default_attributes,
    // TODO(b/200093974): Remove post triage.
    bool unconditionally_use_set_output_shapes) {
  LoadImporterDialects(*context);
  GraphDebugInfo dummy_debug_info;
  const GraphDebugInfo& debug_info =
      saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;

  GraphImportConfig specs;
  specs.prune_unused_nodes = true;
  specs.unconditionally_use_set_output_shapes =
      unconditionally_use_set_output_shapes;
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;

  const auto& graphdef = saved_model->meta_graph_def().graph_def();
  PopulateTfVersions(module.get(), graphdef.versions());

  GraphConstructorOptions options;
  options.allow_internal_ops = true;
  options.add_default_attributes = add_default_attributes;
  Graph graph(OpRegistry::Global());

  GraphDef preprocessed_graphdef(graphdef);
  if (add_default_attributes) {
    TF_RETURN_IF_ERROR(PreprocessGraphDef(nullptr, &preprocessed_graphdef));
  }

  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));

  NameUniquifier function_name_uniquifier(graph.flib_def());
  SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs,
                                         module.get(), &tf_name_to_mlir_name,
                                         &function_name_uniquifier);

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));

  auto fn_names = graph.flib_def().ListFunctionNames();
  for (const auto& fn_name : fn_names) {
    TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
  }
  TF_RETURN_IF_ERROR(importer.ConvertDeferredFunctions());

  if (!saved_model->meta_graph_def().has_object_graph_def()) {
    return errors::InvalidArgument(
        "SavedModel does not have an object graph. Please use TF2.");
  }
  auto& object_graph = saved_model->meta_graph_def().object_graph_def();
  ObjectNames object_names(object_graph, exported_names);

  // Clean up a couple of funcs that always seem to be present when importing a
  // SavedModel. This is not strictly needed, as there is a separate pass that
  // will clean them up, but this makes staring at the raw IR of minimal
  // examples quite a bit nicer.
  for (auto func :
       llvm::make_early_inc_range(module->getOps<mlir::func::FuncOp>())) {
    if (func.getName().startswith("__inference__traced_save_") ||
        func.getName().startswith("__inference__traced_restore_") ||
        func.getName().startswith("__inference_signature_wrapper_")) {
      func.erase();
    }
  }

  // Diagnose SavedFunction's with multiple input signatures.
  TF_RETURN_IF_ERROR(
      DiagnoseMultipleConcreteFunctions(object_graph, object_names));

  // Construct the SavedModel IR.
  TF_RETURN_IF_ERROR(CreateSavedModelIR(object_names, module.get(),
                                        object_graph, tf_name_to_mlir_name,
                                        saved_model));
  assert(mlir::succeeded(mlir::verify(module.get())));

  return module;
}

class SimpleSavedModelMLIRImportInput : public SavedModelMLIRImportInput {
 public:
  static StatusOr<SimpleSavedModelMLIRImportInput> Create(
      const MLIRImportOptions& import_options,
      const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info) {
    DCHECK(meta_graph_def);
    GraphDef graph_def = meta_graph_def->graph_def();
    auto graph = std::make_unique<Graph>(OpRegistry::Global());

    if (import_options.upgrade_legacy) {
      TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
          graph_def, graph->flib_def().default_registry()));
    }

    GraphConstructorOptions graph_ctor_options;
    graph_ctor_options.allow_internal_ops = true;
    graph_ctor_options.add_default_attributes = true;
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph(graph_ctor_options, graph_def, graph.get()));

    if (import_options.upgrade_legacy) {
      // TODO(jpienaar): Remove need to const_cast.
      TF_RETURN_IF_ERROR(UpgradeLegacyGraph(
          graph.get(),
          const_cast<FunctionLibraryDefinition*>(&graph->flib_def()),
          /*restrict_functionalization_to_compiled_nodes=*/false));
    }

    return SimpleSavedModelMLIRImportInput(meta_graph_def, debug_info,
                                           std::move(graph));
  }

  SimpleSavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
                                  const GraphDebugInfo& debug_info,
                                  std::unique_ptr<Graph> graph)
      : SavedModelMLIRImportInput(meta_graph_def, debug_info),
        graph_(std::move(graph)) {}

  StatusOr<const Graph*> GetSubGraph(absl::string_view name,
                                     GraphImportConfig& specs) override {
    DCHECK(CheckGraphNameValidity(name));
    DCHECK(CheckGraphContainsFeedsAndFetches(specs));
    return graph_.get();
  }

 private:
  bool CheckGraphContainsFeedsAndFetches(const GraphImportConfig& specs) const {
    absl::flat_hash_set<std::string> feed_fetch_nodes;
    for (const auto& iter : specs.inputs) {
      TensorId tensor_id = ParseTensorName(iter.first);
      feed_fetch_nodes.insert(std::string(tensor_id.node()));
    }
    for (const auto& output : llvm::concat<const std::string>(
             specs.outputs, specs.control_outputs)) {
      TensorId tensor_id = ParseTensorName(output);
      feed_fetch_nodes.insert(std::string(tensor_id.node()));
    }

    for (Node* node : graph_->op_nodes()) {
      feed_fetch_nodes.erase(node->name());
    }

    return feed_fetch_nodes.empty();
  }

  bool CheckGraphNameValidity(absl::string_view name) const {
    // If it is one of the signature names, it is valid.
    const auto& signature_defs = meta_graph_def().signature_def();
    if (signature_defs.contains(std::string(name))) return true;

    // If it is the restore graph name, it is valid.
    if (meta_graph_def().has_saver_def() &&
        meta_graph_def().saver_def().restore_op_name() == name)
      return true;

    // If it is the init graph name, it is valid.
    std::string init_op_name;
    if (internal::GetInitOp("", meta_graph_def(), &init_op_name).ok()) {
      if (init_op_name == name) return true;
    }

    return false;
  }

  // `graph_` contains the entire graph in the original MetaGraphDef.
  std::unique_ptr<Graph> graph_;
};

static absl::flat_hash_set<std::string> GetOriginalTfFuncNamesFromGraphDef(
    const GraphDef& graph_def) {
  absl::flat_hash_set<std::string> original_func_tf_names;
  for (const auto& function : graph_def.library().function()) {
    original_func_tf_names.insert(function.signature().name());
  }
  return original_func_tf_names;
}

// A helper class to import a TensorFlow model expressed in SavedModel V1 into
// an MLIR Module in SavedModel dialect.
//
// TODO(b/179683149): Rename this class to avoid confusion with TFLite.
class SavedModelSignatureDefImporterLite {
 public:
  // Main entry point: converts all functions (specified by SignatureDefs) in
  // the given meta graph to an MLIR Module.
  //
  // `import_restore` is introduced to control whether the restore graph
  // is imported in e.g. SavedModelSignatureDefImporter. Ideally, we wouldn't
  // need this option, as the restore graph should always be imported.
  // However, right now, SavedModelSignatureDefImporter cannot handle the
  // restore graph correctly.
  //
  // TODO(chky): Remove import_restore once the restore graph is correctly
  // handled in SavedModelSignatureDefImporter.
  static StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Convert(
      SavedModelMLIRImportInput& input,
      std::optional<absl::Span<const std::string>> exported_names,
      mlir::MLIRContext* context, bool import_restore = true,
      bool unconditionally_use_set_output_shapes = false) {
    SavedModelSignatureDefImporterLite importer(
        input, exported_names, context, import_restore,
        unconditionally_use_set_output_shapes);
    return importer.ConvertSignatures();
  }

 private:
  SavedModelSignatureDefImporterLite(
      SavedModelMLIRImportInput& input,
      std::optional<absl::Span<const std::string>> exported_names,
      mlir::MLIRContext* context, bool import_restore,
      bool unconditionally_use_set_output_shapes)
      : input_(input),
        original_func_tf_names_(GetOriginalTfFuncNamesFromGraphDef(
            input.meta_graph_def().graph_def())),
        exported_names_(exported_names),
        module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))),
        symbol_table_(module_.get()),
        import_restore_(import_restore),
        unconditionally_use_set_output_shapes_(
            unconditionally_use_set_output_shapes) {}

  // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
  // for each signature.
  StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSignatures();
  Status ConvertSignature(const std::string& sig_def_key,
                          const SignatureDef& signature_def);

  struct AssetInfo {
    std::string tensor_name;
    mlir::tf_saved_model::AssetOp op;
  };
  StatusOr<std::vector<AssetInfo>> ConvertAssets();
  // Converts the initialization graph in the SavedModel to an MLIR function.
  Status ConvertInitializer(const std::string& target_node_name,
                            const std::vector<AssetInfo>& assets);

  // Converts a graph with feeds and fetches to an MLIR function.
  StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertGraph(
      const std::string& name,
      const std::vector<std::pair<std::string, TensorInfo>>& inputs,
      const std::vector<std::pair<std::string, TensorInfo>>& outputs,
      const std::vector<std::string> control_outputs,
      std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);

  // Moves the functions in `sub_module` to `module_` and skips the duplicate
  // functions.
  Status MoveConvertedFunctionsToModule(
      absl::string_view name, mlir::ModuleOp sub_module,
      const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);

  StatusOr<GraphImportConfig::InputArrays> ParseInputArrays(
      llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);

 private:
  SavedModelMLIRImportInput& input_;
  absl::flat_hash_set<std::string> original_func_tf_names_;
  std::optional<absl::Span<const std::string>> exported_names_;
  mlir::OwningOpRef<mlir::ModuleOp> module_;
  absl::Mutex symbol_table_mu_;
  mlir::SymbolTable symbol_table_ ABSL_GUARDED_BY(symbol_table_mu_);
  bool import_restore_ = true;
  bool unconditionally_use_set_output_shapes_ = false;
};

StatusOr<std::vector<SavedModelSignatureDefImporterLite::AssetInfo>>
SavedModelSignatureDefImporterLite::ConvertAssets() {
  std::vector<AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(
      internal::GetAssetFileDefs(input_.meta_graph_def(), &asset_file_defs));

  std::vector<AssetInfo> results;
  results.reserve(asset_file_defs.size());

  mlir::OpBuilder builder(module_->getBodyRegion());
  unsigned i = 0;  // Used to generate unique sym_name(s) for duplicate assets.
  for (const auto& asset : asset_file_defs) {
    auto asset_op = builder.create<mlir::tf_saved_model::AssetOp>(
        module_->getLoc(),
        /*sym_name=*/
        builder.getStringAttr(
            absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())),
        /*filename=*/
        builder.getStringAttr(
            io::JoinPath(kSavedModelAssetsDirectory, asset.filename())));

    results.push_back({asset.tensor_info().name(), asset_op});
  }

  return results;
}

Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
    absl::string_view name, mlir::ModuleOp sub_module,
    const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
  mlir::Builder builder(sub_module.getContext());
  mlir::SymbolTable sub_module_symbol_table(sub_module);

  // Functions originally from the graphdef library might have a different name
  // after conversion, so we build the set of the converted names.
  absl::flat_hash_set<std::string> original_func_mlir_names;
  for (const auto& kv : tf_name_to_mlir_name) {
    if (original_func_tf_names_.contains(kv.first))
      original_func_mlir_names.insert(kv.second);
  }

  // Prefix private functions with the unique signature name, so that they
  // cannot collide with private functions used in the other signatures.
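  // For example (illustrative), a private function @foo used by signature
  // "serving_default" is renamed to @serving_default/foo.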
  for (auto func : sub_module.getOps<mlir::func::FuncOp>()) {
    if (mlir::tf_saved_model::IsExported(func)) continue;

    // Skip the original functions from graphdef library
    if (original_func_mlir_names.count(func.getSymName().str())) continue;

    std::string new_sym_name = absl::StrCat(name, "/", func.getSymName().str());
    mlir::StringAttr new_sym_name_attr = builder.getStringAttr(new_sym_name);
    if (mlir::failed(sub_module_symbol_table.replaceAllSymbolUses(
            func, new_sym_name_attr, sub_module)))
      return tensorflow::errors::InvalidArgument(absl::StrCat(
          "SavedModelSignatureDefImporterLite: failed to assign a unique "
          "name to the private function used in a signature: ",
          func.getSymName().str()));

    mlir::SymbolTable::setSymbolName(func, new_sym_name);
  }

  // Copy all functions used by this signature to the final MLIR module.
  for (auto func : sub_module.getOps<mlir::func::FuncOp>()) {
    absl::MutexLock l(&symbol_table_mu_);
    // The insert here is a NO-OP if the function already exists.
    symbol_table_.insert(func.clone());
  }

  return OkStatus();
}

Status SavedModelSignatureDefImporterLite::ConvertInitializer(
    const std::string& target_node_name, const std::vector<AssetInfo>& assets) {
  std::vector<std::pair<std::string, TensorInfo>> inputs;
  inputs.reserve(assets.size());
  for (const auto& asset : assets) {
    TensorInfo tensor_info;
    tensor_info.set_name(asset.tensor_name);
    tensor_info.set_dtype(DT_STRING);
    tensor_info.mutable_tensor_shape();
    inputs.push_back({asset.tensor_name, tensor_info});
  }

  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  TF_ASSIGN_OR_RETURN(auto sub_module,
                      ConvertGraph(target_node_name, inputs, {},
                                   {target_node_name}, tf_name_to_mlir_name));

  mlir::SymbolTable sub_symbol_table(*sub_module);

  auto init_func_op =
      sub_symbol_table.lookup<mlir::func::FuncOp>(target_node_name);
  init_func_op->removeAttr("tf.entry_function");

  mlir::OpBuilder builder(module_->getBodyRegion());

  // Bind asset inputs to asset ops.
  DCHECK_EQ(init_func_op.getNumArguments(), assets.size());
  for (const auto& iter : llvm::enumerate(assets)) {
    auto asset_op = iter.value().op;
    init_func_op.setArgAttr(
        iter.index(), "tf_saved_model.bound_input",
        mlir::SymbolRefAttr::get(builder.getContext(), asset_op.getName()));
  }

3876 // Set the exported name of init function to an reserved name for
3877 // tf_saved_model.
3878 init_func_op->setAttr(
3879 "tf_saved_model.exported_names",
3880 builder.getStrArrayAttr({absl::StrCat(
3881 "__tf_saved_model_session_initializer_", target_node_name)}));
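  // For example (illustrative name), a target node "init_all_tables" is
  // exported as "__tf_saved_model_session_initializer_init_all_tables".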

  // Move the converted functions to the top-level MLIR module.
  return MoveConvertedFunctionsToModule(target_node_name, *sub_module,
                                        tf_name_to_mlir_name);
}

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelSignatureDefImporterLite::ConvertGraph(
    const std::string& name,
    const std::vector<std::pair<std::string, TensorInfo>>& inputs,
    const std::vector<std::pair<std::string, TensorInfo>>& outputs,
    const std::vector<std::string> control_outputs,
    std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
  VLOG(1) << "Importing Signature: " << name;

  GraphImportConfig specs;
  specs.graph_func_name = name;
  specs.prune_unused_nodes = true;
  TF_ASSIGN_OR_RETURN(specs.inputs, ParseInputArrays(inputs));
  for (auto& output : outputs) specs.outputs.push_back(output.second.name());
  specs.control_outputs = control_outputs;
  specs.enable_shape_inference = false;
  specs.unconditionally_use_set_output_shapes =
      unconditionally_use_set_output_shapes_;

  TF_ASSIGN_OR_RETURN(const auto* subgraph, input_.GetSubGraph(name, specs));

  // Convert sub-graph to MLIR module.
  return GraphDefImporter::Convert(module_->getContext(), *subgraph,
                                   input_.debug_info(), subgraph->flib_def(),
                                   specs, tf_name_to_mlir_name);
}

Status SavedModelSignatureDefImporterLite::ConvertSignature(
    const std::string& sig_def_key, const SignatureDef& signature_def) {
  // Create local vectors for the inputs and outputs and sort them for
  // determinism. Nobody should depend on this order; clients should look up
  // the argument/result mapping by attribute name instead. To discourage
  // accidental dependence on the order, we deliberately use an unintuitive
  // sort key: the Fingerprint64 of the name.
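  // (Illustrative: inputs named "a" and "b" may well end up ordered as
  // ("b", "a") rather than alphabetically.)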
  std::vector<std::pair<std::string, TensorInfo>> inputs(
      signature_def.inputs().begin(), signature_def.inputs().end());
  llvm::sort(inputs, [](const auto& lhs, const auto& rhs) {
    return tensorflow::Fingerprint64(lhs.first) <
           tensorflow::Fingerprint64(rhs.first);
  });
  std::vector<std::pair<std::string, TensorInfo>> outputs(
      signature_def.outputs().begin(), signature_def.outputs().end());
  llvm::sort(outputs, [](const auto& lhs, const auto& rhs) {
    return tensorflow::Fingerprint64(lhs.first) <
           tensorflow::Fingerprint64(rhs.first);
  });

  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;

  // Convert sub-graph to MLIR module.
  TF_ASSIGN_OR_RETURN(
      auto sub_module,
      ConvertGraph(sig_def_key, inputs, outputs, {}, tf_name_to_mlir_name));
  mlir::OpBuilder builder(sub_module->getBodyRegion());

  // Find the FuncOp which corresponds to the current SignatureDef.
  mlir::SymbolTable sub_symbol_table(*sub_module);
  auto func_op = sub_symbol_table.lookup<mlir::func::FuncOp>(sig_def_key);
  TF_RET_CHECK(func_op)
      << "Graphdef importer should have created a function named "
      << sig_def_key << ".";

  // Use the unique SignatureDef key as the exported name.
  func_op->setAttr("tf_saved_model.exported_names",
                   builder.getStrArrayAttr({sig_def_key}));

  // Transfer input and output parameter names to index_path attributes.
  for (auto input_and_idx : llvm::enumerate(inputs)) {
    func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path",
                       builder.getStrArrayAttr({input_and_idx.value().first}));
  }
  for (auto output_and_idx : llvm::enumerate(outputs)) {
    func_op.setResultAttr(
        output_and_idx.index(), "tf_saved_model.index_path",
        builder.getStrArrayAttr({output_and_idx.value().first}));
  }
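  // As an illustration (hypothetical signature name and tensor types), a
  // signature "serve" with input "x" and output "y" roughly yields:
  //
  //   func.func @serve(
  //       %arg0: tensor<?xf32> {tf_saved_model.index_path = ["x"]})
  //       -> (tensor<?xf32> {tf_saved_model.index_path = ["y"]})
  //       attributes {tf_saved_model.exported_names = ["serve"]}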

  // Move the converted functions to the top-level MLIR module.
  return MoveConvertedFunctionsToModule(sig_def_key, *sub_module,
                                        tf_name_to_mlir_name);
}
3968
3969 StatusOr<GraphImportConfig::InputArrays>
ParseInputArrays(llvm::ArrayRef<std::pair<std::string,TensorInfo>> inputs)3970 SavedModelSignatureDefImporterLite::ParseInputArrays(
3971 llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs) {
3972 GraphImportConfig::InputArrays results;
3973 for (const auto& iter : inputs) {
3974 const auto& tensor_info = iter.second;
3975
3976 // TODO(b/184675681): Support other encoding cases.
3977 //
3978 // TODO(b/184679394): Add unit test for this check.
3979 TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
3980 << "Only dense tensor is supported, but got encoding case "
3981 << tensor_info.encoding_case();
3982
3983 VLOG(1) << "Importing Signature Input: input_name = " << iter.first
3984 << ", tensor_info = " << tensor_info.DebugString();
3985
3986 ArrayInfo array_info;
3987 array_info.imported_dtype = tensor_info.dtype();
3988
3989 if (tensor_info.has_tensor_shape()) {
3990 array_info.shape = tensor_info.tensor_shape();
3991 } else {
3992 // If there is no tensor shape in the tensor info, conservatively set
3993 // unknown_rank to true.
3994 array_info.shape.set_unknown_rank(true);
3995 }
3996
3997 results.insert(std::pair<std::string, ArrayInfo>(tensor_info.name(),
3998 std::move(array_info)));
3999 }
4000 return results;
4001 }

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelSignatureDefImporterLite::ConvertSignatures() {
  LoadImporterDialects(*module_->getContext());

  const auto& signatures = input_.meta_graph_def().signature_def();
  PopulateTfVersions(module_.get(),
                     input_.meta_graph_def().graph_def().versions());

  llvm::DenseSet<llvm::StringRef> exported_name_set;
  bool import_all_signatures = !exported_names_.has_value();
  if (exported_names_.has_value()) {
    exported_name_set.insert(exported_names_->begin(), exported_names_->end());
  }

  absl::Mutex error_status_mu;  // Needed since `error_status` is non-atomic.
  tensorflow::Status error_status;
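  // Note that if several signatures fail concurrently, only the last status
  // written under the mutex is kept; surfacing a single failure is enough to
  // abort the import below.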
  {
    // Start a threadpool to convert signatures, since signature conversion can
    // be time-consuming, especially for large models. The threadpool
    // destructor blocks until all work is done.
    thread::ThreadPool thread_pool(Env::Default(), "ConvertSignatures",
                                   kNumThreadToConvertSignatures);
    for (const auto& key_and_signature_def : signatures) {
      const std::string& sig_def_key = key_and_signature_def.first;
      const SignatureDef& signature_def = key_and_signature_def.second;

      // It is safe to skip "__saved_model_init_op" since it is an internal
      // signature that is not user-accessible. This signature will be handled
      // in ConvertInitializer().
      if (sig_def_key == "__saved_model_init_op") {
        continue;
      }
      if (!import_all_signatures && exported_name_set.count(sig_def_key) == 0) {
        continue;
      }

      thread_pool.Schedule([&]() {
        auto status = ConvertSignature(sig_def_key, signature_def);
        if (!status.ok()) {
          absl::MutexLock l(&error_status_mu);
          error_status = std::move(status);
        }
      });
    }
  }
  TF_RETURN_IF_ERROR(error_status);

  TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets());

  mlir::OpBuilder builder(module_->getBodyRegion());
  llvm::SmallVector<mlir::Attribute, 2> init_sym_refs;

  if (import_restore_ && input_.meta_graph_def().has_saver_def()) {
    std::vector<AssetInfo> variable_and_assets;

    // Create an AssetOp for the variable checkpoint files. The relative
    // filename is used here.
    auto variable_filename_op = builder.create<mlir::tf_saved_model::AssetOp>(
        module_->getLoc(),
        /*sym_name=*/
        builder.getStringAttr("__tf_saved_model_variables"),
        /*filename=*/
        builder.getStringAttr(io::JoinPath(kSavedModelVariablesDirectory,
                                           kSavedModelVariablesFilename)));
    variable_and_assets.push_back(
        {input_.meta_graph_def().saver_def().filename_tensor_name(),
         variable_filename_op});
    variable_and_assets.insert(variable_and_assets.end(), assets.begin(),
                               assets.end());

    const auto& restore_op_name =
        input_.meta_graph_def().saver_def().restore_op_name();
    TF_RETURN_IF_ERROR(
        ConvertInitializer(restore_op_name, variable_and_assets));
    init_sym_refs.push_back(
        mlir::SymbolRefAttr::get(builder.getContext(), restore_op_name));
  }

  std::string init_op_name;
  TF_RETURN_IF_ERROR(
      internal::GetInitOp("", input_.meta_graph_def(), &init_op_name));
  if (!init_op_name.empty()) {
    TF_RETURN_IF_ERROR(ConvertInitializer(init_op_name, assets));
    init_sym_refs.push_back(
        mlir::SymbolRefAttr::get(builder.getContext(), init_op_name));
  }

  builder.create<mlir::tf_saved_model::SessionInitializerOp>(
      module_->getLoc(), builder.getArrayAttr(init_sym_refs));

  (*module_)->setAttr("tf_saved_model.semantics", builder.getUnitAttr());

  SortSavedModelModule(*module_);
  MarkSavedModelFunctionVisibility(*module_);

  return std::move(module_);
}

// A helper class to import a TensorFlow model expressed in SavedModel V1 into
// an MLIR Module in SavedModel dialect. In addition to importing the model, it
// performs a few graph transformations, including:
// 1) Convert read-only ref variables to resource variables
// 2) Lift resource variables to global_tensors by using a TF session.
class SavedModelSignatureDefImporter {
 public:
  // Main entry point: converts all functions (specified by SignatureDefs) in
  // the given meta graph to an MLIR Module.
  static StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Convert(
      const SavedModelBundle& bundle,
      std::optional<absl::Span<const std::string>> exported_names,
      mlir::MLIRContext* context, tensorflow::MLIRImportOptions options,
      bool lift_varhandle_ops_to_args = true) {
    // debug_info might not be loaded with loader_lite.
    GraphDebugInfo debug_info;
    if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info;

    TF_ASSIGN_OR_RETURN(auto input,
                        SimpleSavedModelMLIRImportInput::Create(
                            options, &bundle.meta_graph_def, debug_info));

    TF_ASSIGN_OR_RETURN(auto module,
                        SavedModelSignatureDefImporterLite::Convert(
                            input, exported_names, context,
                            /*import_restore=*/false));
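
    // Note: import_restore is disabled here because variable values are read
    // from the bundle's live session in LiftVariables() below, rather than
    // restored via checkpoint ops.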

    mlir::OpBuilder builder(module->getContext());
    (*module)->setAttr("tf_saved_model.under_construction",
                       builder.getUnitAttr());
    TF_RETURN_IF_ERROR(
        LiftVariables(bundle, *module, lift_varhandle_ops_to_args));
    (*module)->removeAttr("tf_saved_model.under_construction");

    return module;
  }

 private:
  // Lifts the variables in `module`.
  static Status LiftVariables(const SavedModelBundle& bundle,
                              mlir::ModuleOp module,
                              bool lift_varhandle_ops_to_args);
};

Status SavedModelSignatureDefImporter::LiftVariables(
    const SavedModelBundle& bundle, mlir::ModuleOp module,
    bool lift_varhandle_ops_to_args) {
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());

  mlir::PassManager pm(module.getContext());
  SetCrashReproducer(pm);
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::CreateExecutorDialectToFunctionalConversionPass());
  pm.addPass(
      mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::TF::
          CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
  if (mlir::failed(pm.run(module)))
    return diag_handler.Combine(
        errors::Internal("Failed to prepare to lift variables."));

  if (lift_varhandle_ops_to_args) {
    if (failed(mlir::tf_saved_model::MarkInitializedVariablesInFunction(
            module, bundle.GetSession())))
      return diag_handler.Combine(
          errors::Internal("Failed to prepare to mark initialized variables."));
    pm.clear();
    pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
    if (mlir::failed(pm.run(module)))
      return diag_handler.Combine(
          errors::Internal("Failed to promote var handles to args."));
    if (failed(
            mlir::tf_saved_model::LiftVariables(module, bundle.GetSession())))
      return diag_handler.Combine(
          errors::Internal("Failed to lift variables."));
  } else {
    if (failed(mlir::tf_saved_model::InitializeVariablesInSessionInitializer(
            module, bundle.GetSession())))
      return diag_handler.Combine(
          errors::Internal("Failed to initialize variables in session init."));
  }

  pm.clear();
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
  if (mlir::failed(pm.run(module)))
    return diag_handler.Combine(
        errors::Internal("Failed to dedup bound inputs."));

  return OkStatus();
}

}  // namespace

SavedModelMLIRImportInput::~SavedModelMLIRImportInput() {}

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertGraphdefToMlir(
    const GraphDef& graphdef, const GraphDebugInfo& debug_info,
    const GraphImportConfig& specs, mlir::MLIRContext* context,
    bool add_default_attributes) {
  GraphConstructorOptions options;
  options.allow_internal_ops = true;
  options.add_default_attributes = add_default_attributes;
  Graph graph(OpRegistry::Global());

  GraphDef preprocessed_graphdef(graphdef);
  if (add_default_attributes) {
    TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef));
  }
  if (specs.upgrade_legacy) {
    TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
        preprocessed_graphdef, graph.flib_def().default_registry()));
  }
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      options, std::move(preprocessed_graphdef), &graph));
  return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs,
                            context);
}
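
// The following sketch (illustrative only; variable names are hypothetical)
// shows how a GraphDef parsed from disk might be imported with the entry
// point above:
//
//   mlir::MLIRContext context;
//   GraphDef graph_def = ...;  // e.g. parsed from a frozen .pb file
//   GraphDebugInfo debug_info;
//   GraphImportConfig specs;
//   auto module_or = ConvertGraphdefToMlir(graph_def, debug_info, specs,
//                                          &context,
//                                          /*add_default_attributes=*/true);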

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertGraphToMlir(
    const Graph& graph, const GraphDebugInfo& debug_info,
    const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
    mlir::MLIRContext* context) {
  // TODO(jpienaar): Remove need to const_cast.
  if (specs.upgrade_legacy) {
    TF_RETURN_IF_ERROR(
        UpgradeLegacyGraph(const_cast<Graph*>(&graph),
                           const_cast<FunctionLibraryDefinition*>(&flib_def),
                           specs.restrict_functionalization_to_compiled_nodes));
  }
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
                                   tf_name_to_mlir_name);
}

stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
ConvertFunctionToMlir(const FunctionBody* fbody,
                      const FunctionLibraryDefinition& flib_def,
                      mlir::MLIRContext* context) {
  tensorflow::GraphDebugInfo dummy_debug_info;
  tensorflow::GraphImportConfig specs;
  specs.graph_func_name = fbody->fdef.signature().name();
  specs.enable_shape_inference = false;
  specs.graph_as_function = true;
  for (const auto* control_ret_node : fbody->control_ret_nodes)
    specs.control_outputs.push_back(control_ret_node->name());
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
                                   flib_def, specs, tf_name_to_mlir_name);
}

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelToMlir(
    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
    absl::Span<std::string> exported_names, bool add_default_attributes,
    bool unconditionally_use_set_output_shapes) {
  return SavedModelObjectGraphImporter::Convert(
      saved_model, exported_names, context, add_default_attributes,
      unconditionally_use_set_output_shapes);
}

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlir(
    const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
    mlir::MLIRContext* context, MLIRImportOptions options,
    bool lift_variables) {
  std::optional<absl::Span<const std::string>> optional_exported_names;
  // TODO(b/187062560): Change ConvertSavedModelV1ToMlir() to take an optional
  // `exported_names` so that it can be configured to import only restore/init
  // graphs.
  if (!exported_names.empty()) optional_exported_names = exported_names;
  return SavedModelSignatureDefImporter::Convert(
      saved_model, optional_exported_names, context, options, lift_variables);
}

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlirLite(
    const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
    std::optional<absl::Span<const std::string>> exported_names,
    mlir::MLIRContext* context, MLIRImportOptions options) {
  TF_ASSIGN_OR_RETURN(auto input, SimpleSavedModelMLIRImportInput::Create(
                                      options, &meta_graph_def, debug_info));
  return ConvertSavedModelV1ToMlirLite(
      input, exported_names, context,
      options.unconditionally_use_set_output_shapes);
}

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlirLite(
    SavedModelMLIRImportInput& input,
    std::optional<absl::Span<const std::string>> exported_names,
    mlir::MLIRContext* context, bool unconditionally_use_set_output_shapes) {
  return SavedModelSignatureDefImporterLite::Convert(
      input, exported_names, context,
      /*import_restore=*/true, unconditionally_use_set_output_shapes);
}

std::string MlirModuleToString(mlir::ModuleOp module,
                               mlir::OpPrintingFlags flags) {
  std::string txt_module;
  {
    llvm::raw_string_ostream os{txt_module};
    module.print(os, flags);
  }
  return txt_module;
}

std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
  mlir::OpPrintingFlags flags;
  if (show_debug_info) flags.enableDebugInfo();
  return MlirModuleToString(module, flags);
}
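
// Illustrative usage: MlirModuleToString(module, /*show_debug_info=*/true)
// returns the module's textual IR with source location info attached.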

}  // namespace tensorflow