xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/translate/import_model.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
18 
19 #include <string>
20 
21 #include "absl/strings/string_view.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
25 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
26 #include "mlir/Support/LLVM.h"  // from @llvm-project
27 #include "tensorflow/cc/saved_model/bundle_v2.h"
28 #include "tensorflow/cc/saved_model/loader.h"
29 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
30 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
31 #include "tensorflow/core/framework/function.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/graph/graph.h"
34 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
35 #include "tensorflow/stream_executor/lib/statusor.h"
36 
37 namespace tensorflow {
38 
39 inline constexpr absl::string_view kImportModelDefaultGraphFuncName = "main";
40 
41 // Given a GraphDef, returns a MLIR module containing the graph, expressed with
42 // tf_executor dialect.
43 stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
44 ConvertGraphdefToMlir(const GraphDef& graphdef,
45                       const GraphDebugInfo& debug_info,
46                       const GraphImportConfig& specs,
47                       mlir::MLIRContext* context,
48                       bool add_default_attributes = true);
49 
50 // Given a Graph, returns a MLIR module containing the graph, expressed with
51 // tf_executor dialect.
52 stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
53 ConvertGraphToMlir(const Graph& graph, const GraphDebugInfo& debug_info,
54                    const FunctionLibraryDefinition& flib_def,
55                    const GraphImportConfig& specs, mlir::MLIRContext* context);
56 
57 // [Experimental]
58 // Given a Function, returns a MLIR module containing the graph, expressed with
59 // tf_executor dialect.
60 stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
61 ConvertFunctionToMlir(const FunctionBody* fbody,
62                       const FunctionLibraryDefinition& flib_def,
63                       mlir::MLIRContext* context);
64 
65 // Given a SavedModel, returns a MLIR module containing the functions, expressed
66 // with tf_executor dialect.
67 stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
68 ConvertSavedModelToMlir(SavedModelV2Bundle* saved_model,
69                         mlir::MLIRContext* context,
70                         absl::Span<std::string> exported_names,
71                         bool add_default_attributes = true,
72                         bool unconditionally_use_set_output_shapes = false);
73 
74 // Given a V1 SavedModel, returns a MLIR module containing the functions,
75 // expressed with tf_executor dialect.
76 stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
77 ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
78                           absl::Span<std::string> exported_names,
79                           mlir::MLIRContext* context, MLIRImportOptions options,
80                           bool lift_variables = true);
81 
82 // Given a V1 SavedModel, returns a MLIR module containing the functions,
83 // expressed with tf_executor dialect. It does not require a session to be
84 // created and it does not perform any graph transformation. If `exported_names`
85 // is std::nullopt, all signatures will be imported. Otherwise, only names
86 // in `exported_names` are imported.
87 //
88 // Note that the word `Lite` means it is a lighter version compared to
89 // ConvertSavedModelV1ToMlir(), and is not related to TFLite.
90 //
91 // TODO(b/179683149): Rename this class to avoid confusion with TFLite.
92 stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
93 ConvertSavedModelV1ToMlirLite(
94     const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
95     std::optional<absl::Span<const std::string>> exported_names,
96     mlir::MLIRContext* context, MLIRImportOptions options);
97 
98 // SavedModelMLIRImportInput is an adapter class for users to inject custom
99 // graph transformation logic on Tensorflow graphs before importing to MLIR. It
100 // serves as the source that provides the subgraphs requested by the savedmodel
101 // MLIR importer, and at the same time it allows the implementation of this
102 // class to transform the graph before feeding it to the importer.
103 class SavedModelMLIRImportInput {
104  public:
SavedModelMLIRImportInput(const MetaGraphDef * meta_graph_def,const GraphDebugInfo & debug_info)105   SavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
106                             const GraphDebugInfo& debug_info)
107       : meta_graph_def_(meta_graph_def), debug_info_(debug_info) {
108     DCHECK(meta_graph_def);
109   }
110 
111   virtual ~SavedModelMLIRImportInput();
112 
113   // The original MetaGraphDef of the savedmodel.
meta_graph_def()114   const MetaGraphDef& meta_graph_def() const { return *meta_graph_def_; }
115 
debug_info()116   const GraphDebugInfo& debug_info() const { return debug_info_; }
117 
118   // GetSubGraph() is expected to return a tensorflow::Graph that contains the
119   // node set specified in `specs`. The implementation is free to transform the
120   // graph in the original savedmodel as needed, as long as it produces the same
121   // results and effects. If the transformation requires some configs in `spec`
122   // (e.g., control_outputs) to be changed, they should be updated accordingly
123   // and remain valid for the graph.
124   // `name` is a unique identifier for this subgraph, so the implementation can
125   // use it for eg. debugging or caching compilation results.
126   virtual stream_executor::port::StatusOr<const Graph*> GetSubGraph(
127       absl::string_view name, GraphImportConfig& specs) = 0;
128 
129  private:
130   const MetaGraphDef* meta_graph_def_ = nullptr;
131   GraphDebugInfo debug_info_;
132 };
133 
134 // Given the SavedModelMLIRImportInput for a saved model, returns a MLIR module
135 // containing the functions, expressed with tf_executor dialect. It does not
136 // require a session to be created. If `exported_names` is std::nullopt, all
137 // signatures will be imported. Otherwise, only names in `exported_names` are
138 // imported.
139 
140 //
141 // Note that the word `Lite` means it is a lighter version compared to
142 // ConvertSavedModelV1ToMlir(), and is not related to TFLite.
143 //
144 // TODO(b/179683149): Rename this class to avoid confusion with TFLite.
145 stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
146 ConvertSavedModelV1ToMlirLite(
147     SavedModelMLIRImportInput& input,
148     std::optional<absl::Span<const std::string>> exported_names,
149     mlir::MLIRContext* context,
150     bool unconditionally_use_set_output_shapes = false);
151 
152 // Serialize a MLIR module to a string.
153 std::string MlirModuleToString(mlir::ModuleOp module,
154                                mlir::OpPrintingFlags flags);
155 std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false);
156 
157 }  // namespace tensorflow
158 
159 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
160