/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_

#include <optional>
#include <string>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

inline constexpr absl::string_view kImportModelDefaultGraphFuncName = "main";

// Given a GraphDef, returns an MLIR module containing the graph, expressed
// with the tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
ConvertGraphdefToMlir(const GraphDef& graphdef,
                      const GraphDebugInfo& debug_info,
                      const GraphImportConfig& specs,
                      mlir::MLIRContext* context,
                      bool add_default_attributes = true);

// Given a Graph, returns an MLIR module containing the graph, expressed with
// the tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
ConvertGraphToMlir(const Graph& graph, const GraphDebugInfo& debug_info,
                   const FunctionLibraryDefinition& flib_def,
                   const GraphImportConfig& specs, mlir::MLIRContext* context);

// [Experimental]
// Given a FunctionBody, returns an MLIR module containing the function's
// graph, expressed with the tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
ConvertFunctionToMlir(const FunctionBody* fbody,
                      const FunctionLibraryDefinition& flib_def,
                      mlir::MLIRContext* context);

// Given a V2 SavedModel, returns an MLIR module containing the functions,
// expressed with the tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
ConvertSavedModelToMlir(SavedModelV2Bundle* saved_model,
                        mlir::MLIRContext* context,
                        absl::Span<std::string> exported_names,
                        bool add_default_attributes = true,
                        bool unconditionally_use_set_output_shapes = false);
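// Example usage (a minimal sketch; the path and variable names below are
// illustrative, not part of this API):
//
//   SavedModelV2Bundle bundle;
//   Status status = SavedModelV2Bundle::Load("/path/to/saved_model", &bundle);
//   if (!status.ok()) { /* handle the load error */ }
//   mlir::MLIRContext context;
//   std::vector<std::string> exported_names = {"serving_default"};
//   auto module_or = ConvertSavedModelToMlir(&bundle, &context,
//                                            absl::MakeSpan(exported_names));
//   if (module_or.ok()) {
//     // *module_or is an mlir::OwningOpRef<mlir::ModuleOp> that owns the
//     // imported module.
//   }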
// Given a V1 SavedModel, returns an MLIR module containing the functions,
// expressed with the tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
                          absl::Span<std::string> exported_names,
                          mlir::MLIRContext* context, MLIRImportOptions options,
                          bool lift_variables = true);

// Given a V1 SavedModel, returns an MLIR module containing the functions,
// expressed with the tf_executor dialect. It does not require a session to be
// created and it does not perform any graph transformation. If
// `exported_names` is std::nullopt, all signatures will be imported.
// Otherwise, only names in `exported_names` are imported.
//
// Note that the word `Lite` means it is a lighter version compared to
// ConvertSavedModelV1ToMlir(), and is not related to TFLite.
//
// TODO(b/179683149): Rename this function to avoid confusion with TFLite.
stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
ConvertSavedModelV1ToMlirLite(
    const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
    std::optional<absl::Span<const std::string>> exported_names,
    mlir::MLIRContext* context, MLIRImportOptions options);

// SavedModelMLIRImportInput is an adapter class for users to inject custom
// graph transformation logic on TensorFlow graphs before importing them to
// MLIR. It serves as the source that provides the subgraphs requested by the
// SavedModel MLIR importer, and at the same time it allows the implementation
// of this class to transform the graph before feeding it to the importer.
class SavedModelMLIRImportInput {
 public:
  SavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
                            const GraphDebugInfo& debug_info)
      : meta_graph_def_(meta_graph_def), debug_info_(debug_info) {
    DCHECK(meta_graph_def);
  }

  virtual ~SavedModelMLIRImportInput();

  // The original MetaGraphDef of the SavedModel.
  const MetaGraphDef& meta_graph_def() const { return *meta_graph_def_; }

  const GraphDebugInfo& debug_info() const { return debug_info_; }

  // GetSubGraph() is expected to return a tensorflow::Graph that contains the
  // node set specified in `specs`. The implementation is free to transform
  // the graph in the original SavedModel as needed, as long as it produces
  // the same results and effects. If the transformation requires some configs
  // in `specs` (e.g. control_outputs) to be changed, they should be updated
  // accordingly and remain valid for the graph.
  // `name` is a unique identifier for this subgraph, so the implementation
  // can use it, e.g., for debugging or caching compilation results.
  virtual stream_executor::port::StatusOr<const Graph*> GetSubGraph(
      absl::string_view name, GraphImportConfig& specs) = 0;

 private:
  const MetaGraphDef* meta_graph_def_ = nullptr;
  GraphDebugInfo debug_info_;
};
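// A minimal sketch of implementing this interface; `PassThroughImportInput`
// is hypothetical and applies no transformation, serving the same pre-built
// graph for every request:
//
//   class PassThroughImportInput : public SavedModelMLIRImportInput {
//    public:
//     PassThroughImportInput(const MetaGraphDef* meta_graph_def,
//                            const GraphDebugInfo& debug_info,
//                            std::unique_ptr<Graph> graph)
//         : SavedModelMLIRImportInput(meta_graph_def, debug_info),
//           graph_(std::move(graph)) {}
//
//     stream_executor::port::StatusOr<const Graph*> GetSubGraph(
//         absl::string_view name, GraphImportConfig& specs) override {
//       // No transformation is applied, so `specs` is left untouched.
//       return graph_.get();
//     }
//
//    private:
//     std::unique_ptr<Graph> graph_;
//   };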
// Given a SavedModelMLIRImportInput for a SavedModel, returns an MLIR module
// containing the functions, expressed with the tf_executor dialect. It does
// not require a session to be created. If `exported_names` is std::nullopt,
// all signatures will be imported. Otherwise, only names in `exported_names`
// are imported.
//
// Note that the word `Lite` means it is a lighter version compared to
// ConvertSavedModelV1ToMlir(), and is not related to TFLite.
//
// TODO(b/179683149): Rename this function to avoid confusion with TFLite.
stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
ConvertSavedModelV1ToMlirLite(
    SavedModelMLIRImportInput& input,
    std::optional<absl::Span<const std::string>> exported_names,
    mlir::MLIRContext* context,
    bool unconditionally_use_set_output_shapes = false);

// Serializes an MLIR module to a string.
std::string MlirModuleToString(mlir::ModuleOp module,
                               mlir::OpPrintingFlags flags);
std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_