xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
17 #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
18 
19 #include <string>
20 #include <unordered_set>
21 
22 #include "absl/types/span.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
27 #include "mlir/Pass/PassManager.h"  // from @llvm-project
28 #include "tensorflow/cc/saved_model/loader.h"
29 #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
30 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
31 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
32 #include "tensorflow/lite/toco/toco_flags.pb.h"
33 #include "tensorflow/stream_executor/lib/statusor.h"
34 
namespace tensorflow {

// Loads a TF model from a GraphDef definition or a TF control flow dialect
// MLIR source into an MLIR module. If `input_mlir` is true, loads from an
// MLIR source file; otherwise, loads from a GraphDef.
// Setting `prune_unused_nodes` (in `specs`) to true prunes unreachable nodes
// if `output_arrays` is specified.
// Returns the imported module on success, or an error status on failure.
stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
LoadFromGraphdefOrMlirSource(
    const std::string& input_filename, bool input_mlir,
    bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
    const GraphImportConfig& specs, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, llvm::SourceMgr* source_mgr,
    mlir::MLIRContext* context);

// Loads a SavedModel (either v1 or v2) into MLIR.
// `saved_model_version` selects between the two formats; `tags` identifies
// the MetaGraph to load. `saved_model_bundle` is initialized if a V1 model
// was loaded.
// Returns the imported module on success, or an error status on failure.
stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
ImportSavedModel(
    const std::string& input_filename, const int saved_model_version,
    const std::unordered_set<std::string>& tags,
    absl::Span<const std::string> extra_tf_opdefs,
    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
    bool enable_variable_lifting, mlir::MLIRContext* context,
    std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle);

// Takes an MLIR module in the TF executor dialect and a set of parameters,
// applies a set of passes (configured according to the provided
// `pass_config`) to convert the module to the TF Lite dialect, and serializes
// the result into `*result`. Depending on an attribute in the module's main
// function, full integer quantization is applied.
// * The quantized buffer type (INT8 or FLOAT16, carried in `toco_flags`) can
//   trigger the corresponding weight quantization.
// * `export_to_mlir` enables exporting to MLIR text format; otherwise the
//   result is exported as a flatbuffer.
// * The `session` pointer, if provided, is used to freeze resource
//   variables. If the `saved_model_dir` directory path is provided, the
//   `tf_saved_model.asset` ops will be frozen.
Status ConvertTFExecutorToTFLOrFlatbuffer(
    mlir::ModuleOp module, bool export_to_mlir,
    const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config,
    const std::unordered_set<std::string>& saved_model_tags,
    llvm::StringRef saved_model_dir,
    llvm::Optional<tensorflow::Session*> session, std::string* result);
}  // namespace tensorflow
82 
83 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
84