/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

#include <string>
#include <unordered_set>
#include <utility>

#include "absl/types/span.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/lite/tools/optimize/reduced_precision_support.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
namespace {
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::Operation;
using mlir::OwningOpRef;
using stream_executor::port::StatusOr;
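
// Returns true if 'op' is one of the tf_executor Control Flow V1 ops, which
// the TFLite converter does not support.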
bool IsControlFlowV1Op(Operation* op) {
  return mlir::isa<mlir::tf_executor::SwitchOp, mlir::tf_executor::MergeOp,
                   mlir::tf_executor::EnterOp, mlir::tf_executor::ExitOp,
                   mlir::tf_executor::NextIterationSinkOp,
                   mlir::tf_executor::NextIterationSourceOp>(op);
}

mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
  auto result = module.walk([&](Operation* op) {
    return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
                                 : mlir::WalkResult::advance();
  });
  if (result.wasInterrupted()) {
    mlir::TFL::AttachErrorCode(
        module.emitError(
            "The graph has Control Flow V1 ops. TFLite converter doesn't "
            "support Control Flow V1 ops. Consider using Control Flow V2 ops "
            "instead. See https://www.tensorflow.org/api_docs/python/tf/compat/"
            "v1/enable_control_flow_v2."),
        tflite::metrics::ConverterErrorData::ERROR_UNSUPPORTED_CONTROL_FLOW_V1);
    return mlir::failure();
  }
  return mlir::success();
}

// Registers 'extra_tf_opdefs' with the TF global op registry. Returns OK on
// success, or an error if any OpDef fails to parse.
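//
// For reference, each entry in 'extra_tf_opdefs' is an OpDef in text-proto
// form. A minimal sketch for a hypothetical custom op (the op name, arguments,
// and attr below are illustrative, not part of this API):
//
//   name: "MyCustomOp"
//   input_arg { name: "x" type: DT_FLOAT }
//   output_arg { name: "y" type: DT_FLOAT }
//   attr { name: "alpha" type: "float" }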
Status RegisterExtraTfOpDefs(absl::Span<const std::string> extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
      return errors::InvalidArgument("failed to parse extra OpDef");
    }
    // Register extra opdefs.
    // TODO(b/133770952): Support shape functions.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return OkStatus();
        });
  }
  return OkStatus();
}
}  // namespace
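
// Loads a TF model from 'input_filename' into an MLIR module: parses the file
// directly as MLIR when 'input_mlir' is set, otherwise imports it as a
// GraphDef using the remaining import options.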
StatusOr<OwningOpRef<ModuleOp>> LoadFromGraphdefOrMlirSource(
    const std::string& input_filename, bool input_mlir,
    bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
    const GraphImportConfig& specs, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, llvm::SourceMgr* source_mgr,
    MLIRContext* context) {
  // Set up the input file.
  std::string error_message;
  auto file = mlir::openInputFile(input_filename, &error_message);
  if (!file) {
    llvm::errs() << error_message << "\n";
    return errors::InvalidArgument("failed to open input file");
  }

  if (input_mlir) {
    source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
    return OwningOpRef<ModuleOp>(
        mlir::parseSourceFile<mlir::ModuleOp>(*source_mgr, context));
  }

  // Register extra TF ops passed as OpDef.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (use_splatted_constant) {
    return tensorflow::GraphdefToSplattedMlirTranslateFunction(
        file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
        input_shapes, output_arrays, control_output_arrays,
        specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
        /*graph_as_function=*/false, specs.upgrade_legacy,
        /*enable_shape_inference=*/false,
        /*unconditionally_use_set_output_shapes=*/true, context);
  }
  return tensorflow::GraphdefToMlirTranslateFunction(
      file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
      input_shapes, output_arrays, control_output_arrays,
      specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
      /*graph_as_function=*/false, specs.upgrade_legacy,
      /*enable_shape_inference=*/false,
      /*unconditionally_use_set_output_shapes=*/true, context);
}

// Applies post-training dynamic-range quantization from the old TOCO
// quantizer to 'translated_result' using 'quant_specs', and saves the final
// output in 'result'.
Status ApplyDynamicRangeQuantizationFromOldQuantizer(
    const mlir::quant::QuantizationSpecs& quant_specs,
    std::string translated_result, std::string* result) {
  flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
  const uint8_t* buffer =
      reinterpret_cast<const uint8_t*>(translated_result.c_str());
  const ::tflite::Model* input_model = ::tflite::GetModel(buffer);
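
  // Map the requested inference type to the old quantizer's buffer type; only
  // int8 and float16 dynamic-range quantization are supported here.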
  ::tflite::optimize::BufferType quantized_type;
  switch (quant_specs.inference_type) {
    case tensorflow::DT_QINT8:
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8;
      break;
    case tensorflow::DT_HALF:
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16;
      break;
    default:
      return errors::InvalidArgument("Quantized type not supported");
  }

  bool use_updated_hybrid_scheme = !quant_specs.disable_per_channel;
  if (::tflite::optimize::QuantizeWeights(
          &q_builder, input_model, quantized_type, use_updated_hybrid_scheme,
          ::tflite::optimize::QuantizerType::OLD_QUANTIZER) != kTfLiteOk) {
    return errors::InvalidArgument("Quantize weights transformation failed.");
  }
  const uint8_t* q_buffer = q_builder.GetBufferPointer();
  *result =
      string(reinterpret_cast<const char*>(q_buffer), q_builder.GetSize());

  return OkStatus();
}
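
// Lowers 'module' from the TF executor dialect to the TFLite dialect and
// writes the output to 'result': textual MLIR when 'export_to_mlir' is set,
// otherwise a serialized TFLite FlatBuffer. When a 'session' is provided,
// resource variables are frozen into constants before the final export.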
Status ConvertTFExecutorToTFLOrFlatbuffer(
    mlir::ModuleOp module, bool export_to_mlir,
    const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config,
    const std::unordered_set<std::string>& saved_model_tags,
    llvm::StringRef saved_model_dir,
    llvm::Optional<tensorflow::Session*> session, std::string* result) {
  // Explicitly disable dumping Op details on failures.
  module.getContext()->printOpOnDiagnostic(false);

  // Register a warning handler that only logs to stdout.
  mlir::ScopedDiagnosticHandler s(
      module.getContext(), [](mlir::Diagnostic& diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) {
          for (auto& note : diag.getNotes()) {
            std::cout << note.str() << "\n";
            LOG(WARNING) << note.str() << "\n";
          }
        }
        return mlir::failure();
      });

  mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
                                                    /*propagate=*/true);

  if (failed(IsValidGraph(module))) {
    return statusHandler.ConsumeStatus();
  }

  mlir::PassManager pass_manager(module.getContext());
  mlir::registerPassManagerCLOptions();
  mlir::applyPassManagerCLOptions(pass_manager);
  pass_manager.addInstrumentation(
      std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
          pass_manager.getContext()));
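
  // Run the first segment of the conversion pipeline: all passes that must
  // run before resource variables are frozen.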
  tensorflow::AddPreVariableFreezingTFToTFLConversionPasses(pass_config,
                                                            &pass_manager);
  if (failed(pass_manager.run(module))) {
    return statusHandler.ConsumeStatus();
  }

  // Freeze variables if a session is provided.
  if (session.has_value()) {
    mlir::TFL::ErrorCollectorInstrumentation collector(module.getContext());
    if (failed(mlir::tf_saved_model::FreezeVariables(module,
                                                     session.getValue()))) {
      auto status = statusHandler.ConsumeStatus();
      mlir::TFL::ErrorCollector* error_collector =
          mlir::TFL::ErrorCollector::GetErrorCollector();
      if (!error_collector->CollectedErrors().empty()) {
        // LINT.IfChange
        return errors::InvalidArgument(
            "Variable constant folding is failed. Please consider using "
            "enabling `experimental_enable_resource_variables` flag in the "
            "TFLite converter object. For example, "
            "converter.experimental_enable_resource_variables = True");
        // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py)
      }
      return status;
    }
  }
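
  // Rebuild the pass manager with the remainder of the pipeline, which runs
  // after any variable freezing.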
  pass_manager.clear();
  tensorflow::AddPostVariableFreezingTFToTFLConversionPasses(
      saved_model_dir, toco_flags, pass_config, &pass_manager);
  if (failed(pass_manager.run(module))) {
    auto status = statusHandler.ConsumeStatus();
    mlir::TFL::ErrorCollector* collector =
        mlir::TFL::ErrorCollector::GetErrorCollector();
    for (const auto& error_data : collector->CollectedErrors()) {
      if (error_data.subcomponent() == "FreezeGlobalTensorsPass") {
        // LINT.IfChange
        return errors::InvalidArgument(
            "Variable constant folding is failed. Please consider using "
            "enabling `experimental_enable_resource_variables` flag in the "
            "TFLite converter object. For example, "
            "converter.experimental_enable_resource_variables = True");
        // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py)
      }
    }
    return status;
  }

  if (export_to_mlir) {
    llvm::raw_string_ostream os(*result);
    module.print(os);
    return statusHandler.ConsumeStatus();
  }

  // Write the MLIR TFLite dialect into a FlatBuffer.
  const mlir::quant::QuantizationSpecs& quant_specs = pass_config.quant_specs;
  OpOrArgLocNameMapper op_or_arg_name_mapper;
  tflite::FlatbufferExportOptions options;
  std::string translated_result;
  options.toco_flags = toco_flags;
  options.saved_model_tags = saved_model_tags;
  options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
  if (quant_specs.support_mask !=
      tflite::optimize::ReducedPrecisionSupport::None) {
    options.metadata.insert(
        MetadataForReducedPrecisionSupport(quant_specs.support_mask));
  }
  if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
                                                 &translated_result)) {
    return statusHandler.ConsumeStatus();
  }

  // TODO(b/176267167): Quantize flex fallback in the MLIR pipeline.
  if (quant_specs.weight_quantization &&
      (!quant_specs.RunAndRewriteDynamicRangeQuantizationPasses() ||
       !pass_config.emit_builtin_tflite_ops)) {
    // Apply post-training dynamic range quantization from the old TOCO
    // quantizer. Once MLIR has support for this, we can remove this if
    // statement.
    auto status = ApplyDynamicRangeQuantizationFromOldQuantizer(
        quant_specs, translated_result, result);
    if (!status.ok()) return status;
  } else {
    *result = translated_result;
  }

  if (mlir::failed(module.verifyInvariants())) {
    return tensorflow::errors::Unknown("Final module is invalid");
  }
  return OkStatus();
}
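
// A minimal sketch of how the entry points above compose, for reference only:
// the caller, file paths, array names, and flag values below are hypothetical,
// and error handling is elided.
//
//   mlir::MLIRContext context;
//   llvm::SourceMgr source_mgr;
//   GraphImportConfig specs;
//   auto module_or = LoadFromGraphdefOrMlirSource(
//       "/path/to/model.pb", /*input_mlir=*/false,
//       /*use_splatted_constant=*/false, /*extra_tf_opdefs=*/{}, specs,
//       /*debug_info_file=*/"", /*input_arrays=*/"input",
//       /*input_dtypes=*/"DT_FLOAT", /*input_shapes=*/"1,224,224,3",
//       /*output_arrays=*/"output", /*control_output_arrays=*/"",
//       &source_mgr, &context);
//   if (module_or.ok()) {
//     std::string flatbuffer;
//     Status status = ConvertTFExecutorToTFLOrFlatbuffer(
//         module_or.value().get(), /*export_to_mlir=*/false,
//         toco::TocoFlags(),
//         mlir::TFL::PassConfig(mlir::quant::QuantizationSpecs()),
//         /*saved_model_tags=*/{}, /*saved_model_dir=*/"",
//         /*session=*/llvm::None, &flatbuffer);
//   }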
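
// For reference, a hypothetical SavedModel import; the path, tags, and
// exported names below are illustrative, and 'specs' and 'context' are
// assumed to be set up as in the sketch above.
//
//   std::unique_ptr<tensorflow::SavedModelBundle> bundle;
//   std::vector<std::string> exported_names = {"serving_default"};
//   auto module_or = ImportSavedModel(
//       "/path/to/saved_model", /*saved_model_version=*/2, /*tags=*/{"serve"},
//       /*extra_tf_opdefs=*/{}, absl::MakeSpan(exported_names), specs,
//       /*enable_variable_lifting=*/true, &context, &bundle);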
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
    const std::string& input_filename, const int saved_model_version,
    const std::unordered_set<std::string>& tags,
    absl::Span<const std::string> extra_tf_opdefs,
    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
    bool enable_variable_lifting, mlir::MLIRContext* context,
    std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle) {
  // Register extra TF ops passed as OpDef.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (saved_model_version == 2) {
    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
        input_filename, tags, exported_names, context,
        /*unconditionally_use_set_output_shapes=*/true);
    if (!module_or.status().ok()) return module_or.status();
    return std::move(module_or).value();
  } else if (saved_model_version == 1) {
    MLIRImportOptions options;
    options.upgrade_legacy = specs.upgrade_legacy;
    options.unconditionally_use_set_output_shapes = true;
    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
        input_filename, tags, exported_names, context, options,
        enable_variable_lifting, saved_model_bundle);

    if (!module_or.status().ok()) return module_or.status();
    return std::move(module_or).value();
  } else {
    return tensorflow::errors::InvalidArgument(
        "Should be either saved model v1 or v2");
  }
}

}  // namespace tensorflow