/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
16
17 #include <memory>
18 #include <string>
19 #include <utility>
20
21 #include "absl/types/span.h"
22 #include "llvm/ADT/None.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/Support/ToolOutputFile.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
28 #include "mlir/IR/MLIRContext.h" // from @llvm-project
29 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
30 #include "mlir/Pass/Pass.h" // from @llvm-project
31 #include "mlir/Support/FileUtilities.h" // from @llvm-project
32 #include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
33 #include "tensorflow/cc/saved_model/loader.h"
34 #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
35 #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
36 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
37 #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
38 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
40 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
41 #include "tensorflow/core/framework/graph.pb.h"
42 #include "tensorflow/core/framework/types.pb.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/platform/status.h"
45 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
46 #include "tensorflow/lite/toco/model_flags.pb.h"
47 #include "tensorflow/lite/toco/toco_flags.pb.h"
48 #include "tensorflow/lite/toco/types.pb.h"
49 #include "tensorflow/stream_executor/lib/statusor.h"
50
51 namespace tensorflow {
52
HandleInputOutputArraysWithModule(const toco::ModelFlags & model_flags,mlir::OwningOpRef<mlir::ModuleOp> * module)53 Status HandleInputOutputArraysWithModule(
54 const toco::ModelFlags& model_flags,
55 mlir::OwningOpRef<mlir::ModuleOp>* module) {
56 mlir::func::FuncOp entry_function = nullptr;
57 for (auto func : module->get().getOps<mlir::func::FuncOp>()) {
58 if (auto tf_attrs =
59 func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
60 // TODO(b/184697652): There could be multiple entry functions. Let's
61 // handle such cases if there are any needs for that.
62 if (entry_function != nullptr) {
63 return errors::InvalidArgument(
64 "There should be only one tf.entry_function");
65 }
66 entry_function = func;
67 }
68 }
69 if (entry_function == nullptr) {
70 return errors::InvalidArgument("no tf.entry_function found");
71 }
72
73 // Get the list of input Op names from the function attribute.
74 mlir::DictionaryAttr tf_attrs =
75 entry_function->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
76 llvm::SmallVector<llvm::StringRef, 4> function_input_names;
77 function_input_names.reserve(model_flags.input_arrays().size());
78 auto input_attr = tf_attrs.get("inputs");
79 if (!input_attr) {
80 return errors::InvalidArgument("no inputs attribute found");
81 }
82 auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
83 input_names.split(function_input_names, ",", /*MaxSplit=*/-1,
84 /*KeepEmpty=*/false);
85 const int function_input_names_size = function_input_names.size();
86 if (function_input_names_size != model_flags.input_arrays().size()) {
87 return errors::InvalidArgument(
88 "input array size mismatch: got ", function_input_names.size(),
89 ", expected: ", model_flags.input_arrays().size());
90 }
91 llvm::StringSet<> function_input_names_set;
92 function_input_names_set.insert(function_input_names.begin(),
93 function_input_names.end());
94 for (const auto& input_array : model_flags.input_arrays()) {
95 if (function_input_names_set.count(input_array.name()) == 0) {
96 return errors::InvalidArgument("input array name (", input_array.name(),
97 ") does not exist in the given graph");
98 }
99 }
100
101 // Get the list of output Op names from the function attribute.
102 llvm::SmallVector<llvm::StringRef, 4> function_output_names;
103 function_output_names.reserve(model_flags.output_arrays().size());
104 auto output_attr = tf_attrs.get("outputs");
105 if (!output_attr) {
106 return errors::InvalidArgument("no outputs attribute found");
107 }
108 auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
109 output_names.split(function_output_names, ",", /*MaxSplit=*/-1,
110 /*KeepEmpty=*/false);
111 const int function_output_names_size = function_output_names.size();
112 if (function_output_names_size != model_flags.output_arrays().size()) {
113 return errors::InvalidArgument(
114 "output array size mismatch: got ", function_output_names.size(),
115 ", expected: ", model_flags.output_arrays().size());
116 }
117 llvm::StringSet<> function_output_names_set;
118 function_output_names_set.insert(function_output_names.begin(),
119 function_output_names.end());
120 for (const auto& output_array : model_flags.output_arrays()) {
121 if (function_output_names_set.count(output_array) == 0) {
122 return errors::InvalidArgument("output array name (", output_array,
123 ") does not exist in the given graph");
124 }
125 }
126 return OkStatus();
127 }
128
ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags & model_flags,const toco::TocoFlags & toco_flags,string * result)129 Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
130 const toco::TocoFlags& toco_flags,
131 string* result) {
132 mlir::MLIRContext context;
133 mlir::quant::QuantizationSpecs quant_specs;
134
135 // Parse input arrays.
136 std::vector<string> node_names;
137 std::vector<string> node_dtypes;
138 std::vector<llvm::Optional<std::vector<int>>> node_shapes;
139 std::vector<llvm::Optional<double>> node_mins;
140 std::vector<llvm::Optional<double>> node_maxs;
141
142 // Populate quantization specs.
143 TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
144 model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
145 &node_shapes, &node_mins, &node_maxs));
146
147 internal::WarningUnusedFlags(model_flags, toco_flags);
148
149 // Register all custom ops, including user-specified custom ops.
150 TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
151
152 auto& saved_model_tags = model_flags.saved_model_tags();
153 auto& saved_model_exported_names = model_flags.saved_model_exported_names();
154 std::unordered_set<std::string> tags(saved_model_tags.begin(),
155 saved_model_tags.end());
156 auto exported_names_in_vector = std::vector<std::string>(
157 saved_model_exported_names.begin(), saved_model_exported_names.end());
158 absl::Span<std::string> exported_names(exported_names_in_vector);
159
160 if (exported_names.empty()) {
161 return errors::Unimplemented("Need at least one exported name.");
162 }
163
164 tensorflow::GraphImportConfig specs;
165 specs.upgrade_legacy = true;
166
167 std::vector<std::string> custom_opdefs(toco_flags.custom_opdefs().begin(),
168 toco_flags.custom_opdefs().end());
169 auto bundle = std::make_unique<tensorflow::SavedModelBundle>();
170 TF_ASSIGN_OR_RETURN(
171 auto module,
172 ImportSavedModel(
173 model_flags.saved_model_dir(), model_flags.saved_model_version(),
174 tags, absl::MakeSpan(custom_opdefs), exported_names, specs,
175 !toco_flags.enable_tflite_resource_variables(), &context, &bundle));
176
177 if (!model_flags.input_arrays().empty() ||
178 !model_flags.output_arrays().empty()) {
179 TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
180 }
181
182 mlir::TFL::PassConfig pass_config(quant_specs);
183 bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
184 pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
185 pass_config.enable_tflite_variables =
186 toco_flags.enable_tflite_resource_variables();
187 pass_config.unfold_batch_matmul = toco_flags.unfold_batchmatmul();
188 pass_config.lower_tensor_list_ops = toco_flags.lower_tensor_list_ops();
189 // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
190 // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
191 if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
192 pass_config.unfold_batch_matmul = false;
193 }
194 pass_config.unfold_large_splat_constant =
195 toco_flags.unfold_large_splat_constant();
196 pass_config.enable_dynamic_update_slice =
197 toco_flags.enable_dynamic_update_slice();
198 pass_config.preserve_assert_op = toco_flags.preserve_assert_op();
199 pass_config.guarantee_all_funcs_one_use =
200 toco_flags.guarantee_all_funcs_one_use();
201
202 // TODO(b/153507667): Pass the session object when importing logic is removed.
203 auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
204 model_flags, toco_flags, std::move(module), pass_config, tags, result,
205 bundle ? bundle->GetSession() : nullptr);
206 return status;
207 }
208
209 } // namespace tensorflow
210