/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

#include <string>
#include <unordered_set>
#include <utility>

#include "absl/types/span.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/lite/tools/optimize/reduced_precision_support.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
namespace {
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::Operation;
using mlir::OwningOpRef;
using stream_executor::port::StatusOr;

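// Returns true if `op` is one of the TF executor dialect's Control Flow V1
// ops (Switch, Merge, Enter, Exit, NextIteration source/sink), which the
// TFLite converter does not support.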
bool IsControlFlowV1Op(Operation* op) {
  return mlir::isa<mlir::tf_executor::SwitchOp, mlir::tf_executor::MergeOp,
                   mlir::tf_executor::EnterOp, mlir::tf_executor::ExitOp,
                   mlir::tf_executor::NextIterationSinkOp,
                   mlir::tf_executor::NextIterationSourceOp>(op);
}

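// Checks that the module contains no Control Flow V1 ops. Emits an error
// tagged with ERROR_UNSUPPORTED_CONTROL_FLOW_V1 and returns failure if any
// are found.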
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
  auto result = module.walk([&](Operation* op) {
    return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
                                 : mlir::WalkResult::advance();
  });
  if (result.wasInterrupted()) {
    mlir::TFL::AttachErrorCode(
        module.emitError(
            "The graph has Control Flow V1 ops. TFLite converter doesn't "
            "support Control Flow V1 ops. Consider using Control Flow V2 ops "
            "instead. See https://www.tensorflow.org/api_docs/python/tf/compat/"
            "v1/enable_control_flow_v2."),
        tflite::metrics::ConverterErrorData::ERROR_UNSUPPORTED_CONTROL_FLOW_V1);
    return mlir::failure();
  }
  return mlir::success();
}

// Utility that registers `extra_tf_opdefs` with the TF global op registry.
// Returns OK on success, or an error status if registration fails; an
// illustrative OpDef text proto is sketched below.
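// A minimal sketch of the text-proto format this function parses, with a
// hypothetical op name and arguments:
//
//   name: "MyCustomOp"
//   input_arg { name: "x" type: DT_FLOAT }
//   output_arg { name: "y" type: DT_FLOAT }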
Status RegisterExtraTfOpDefs(absl::Span<const std::string> extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
      return errors::InvalidArgument("failed to parse extra OpDef");
    }
    // Register extra opdefs.
    // TODO(b/133770952): Support shape functions.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return OkStatus();
        });
  }
  return OkStatus();
}
}  // namespace

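// Loads a model from `input_filename`, either by parsing it as MLIR (when
// `input_mlir` is true) or by importing it as a GraphDef. Extra TF ops passed
// as OpDefs are registered first; `use_splatted_constant` selects the
// splatted-constant variant of the GraphDef importer.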
StatusOr<OwningOpRef<ModuleOp>> LoadFromGraphdefOrMlirSource(
    const std::string& input_filename, bool input_mlir,
    bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
    const GraphImportConfig& specs, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, llvm::SourceMgr* source_mgr,
    MLIRContext* context) {
  // Set up the input file.
  std::string error_message;
  auto file = mlir::openInputFile(input_filename, &error_message);
  if (!file) {
    llvm::errs() << error_message << "\n";
    return errors::InvalidArgument("failed to open input file");
  }

  if (input_mlir) {
    source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
    return OwningOpRef<ModuleOp>(
        mlir::parseSourceFile<mlir::ModuleOp>(*source_mgr, context));
  }

  // Register extra TF ops passed as OpDef.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (use_splatted_constant) {
    return tensorflow::GraphdefToSplattedMlirTranslateFunction(
        file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
        input_shapes, output_arrays, control_output_arrays,
        specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
        /*graph_as_function=*/false, specs.upgrade_legacy,
        /*enable_shape_inference=*/false,
        /*unconditionally_use_set_output_shapes=*/true, context);
  }
  return tensorflow::GraphdefToMlirTranslateFunction(
      file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
      input_shapes, output_arrays, control_output_arrays,
      specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
      /*graph_as_function=*/false, specs.upgrade_legacy,
      /*enable_shape_inference=*/false,
      /*unconditionally_use_set_output_shapes=*/true, context);
}

// Applies post-training dynamic-range quantization from the old TOCO
// quantizer to `translated_result` using `quant_specs`, and saves the final
// output in `result`.
Status ApplyDynamicRangeQuantizationFromOldQuantizer(
    const mlir::quant::QuantizationSpecs& quant_specs,
    std::string translated_result, std::string* result) {
  flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
  const uint8_t* buffer =
      reinterpret_cast<const uint8_t*>(translated_result.c_str());
  const ::tflite::Model* input_model = ::tflite::GetModel(buffer);

  ::tflite::optimize::BufferType quantized_type;
  switch (quant_specs.inference_type) {
    case tensorflow::DT_QINT8:
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8;
      break;
    case tensorflow::DT_HALF:
      quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16;
      break;
    default:
      return errors::InvalidArgument("Quantized type not supported");
  }

  bool use_updated_hybrid_scheme = !quant_specs.disable_per_channel;
  if (::tflite::optimize::QuantizeWeights(
          &q_builder, input_model, quantized_type, use_updated_hybrid_scheme,
          ::tflite::optimize::QuantizerType::OLD_QUANTIZER) != kTfLiteOk) {
    return errors::InvalidArgument("Quantize weights transformation failed.");
  }
  const uint8_t* q_buffer = q_builder.GetBufferPointer();
  *result =
      string(reinterpret_cast<const char*>(q_buffer), q_builder.GetSize());

  return OkStatus();
}

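// Converts the given TF executor dialect module to the TFLite dialect and
// serializes it: validates the graph, runs the pre-variable-freezing passes,
// freezes variables if a session is provided, runs the post-freezing passes,
// and then either prints the module as MLIR (when `export_to_mlir` is set) or
// exports it to a FlatBuffer in `result`, optionally applying dynamic-range
// weight quantization through the old TOCO quantizer.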
Status ConvertTFExecutorToTFLOrFlatbuffer(
    mlir::ModuleOp module, bool export_to_mlir,
    const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config,
    const std::unordered_set<std::string>& saved_model_tags,
    llvm::StringRef saved_model_dir,
    llvm::Optional<tensorflow::Session*> session, std::string* result) {
  // Explicitly disable dumping Op details on failures.
  module.getContext()->printOpOnDiagnostic(false);

  // Register a warning handler that only logs warnings; returning failure
  // lets the diagnostic propagate to other registered handlers as well.
  mlir::ScopedDiagnosticHandler s(
      module.getContext(), [](mlir::Diagnostic& diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) {
          for (auto& note : diag.getNotes()) {
            std::cout << note.str() << "\n";
            LOG(WARNING) << note.str() << "\n";
          }
        }
        return mlir::failure();
      });

  mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
                                                    /*propagate=*/true);

  if (failed(IsValidGraph(module))) {
    return statusHandler.ConsumeStatus();
  }

  mlir::PassManager pass_manager(module.getContext());
  mlir::registerPassManagerCLOptions();
  mlir::applyPassManagerCLOptions(pass_manager);
  pass_manager.addInstrumentation(
      std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
          pass_manager.getContext()));

  tensorflow::AddPreVariableFreezingTFToTFLConversionPasses(pass_config,
                                                            &pass_manager);
  if (failed(pass_manager.run(module))) {
    return statusHandler.ConsumeStatus();
  }

  // Freeze variables if a session is provided.
  if (session.has_value()) {
    mlir::TFL::ErrorCollectorInstrumentation collector(module.getContext());
    if (failed(mlir::tf_saved_model::FreezeVariables(module,
                                                     session.getValue()))) {
      auto status = statusHandler.ConsumeStatus();
      mlir::TFL::ErrorCollector* error_collector =
          mlir::TFL::ErrorCollector::GetErrorCollector();
      if (!error_collector->CollectedErrors().empty()) {
        // LINT.IfChange
        return errors::InvalidArgument(
            "Variable constant folding failed. Please consider enabling the "
            "`experimental_enable_resource_variables` flag in the TFLite "
            "converter object. For example, "
            "converter.experimental_enable_resource_variables = True");
        // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py)
      }
      return status;
    }
  }
  pass_manager.clear();
  tensorflow::AddPostVariableFreezingTFToTFLConversionPasses(
      saved_model_dir, toco_flags, pass_config, &pass_manager);
  if (failed(pass_manager.run(module))) {
    auto status = statusHandler.ConsumeStatus();
    mlir::TFL::ErrorCollector* collector =
        mlir::TFL::ErrorCollector::GetErrorCollector();
    for (const auto& error_data : collector->CollectedErrors()) {
      if (error_data.subcomponent() == "FreezeGlobalTensorsPass") {
        // LINT.IfChange
        return errors::InvalidArgument(
            "Variable constant folding failed. Please consider enabling the "
            "`experimental_enable_resource_variables` flag in the TFLite "
            "converter object. For example, "
            "converter.experimental_enable_resource_variables = True");
        // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py)
      }
    }
    return status;
  }

  if (export_to_mlir) {
    llvm::raw_string_ostream os(*result);
    module.print(os);
    return statusHandler.ConsumeStatus();
  }

  // Write the MLIR TFLite dialect into a FlatBuffer.
  const mlir::quant::QuantizationSpecs& quant_specs = pass_config.quant_specs;
  OpOrArgLocNameMapper op_or_arg_name_mapper;
  tflite::FlatbufferExportOptions options;
  std::string translated_result;
  options.toco_flags = toco_flags;
  options.saved_model_tags = saved_model_tags;
  options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
  if (quant_specs.support_mask !=
      tflite::optimize::ReducedPrecisionSupport::None) {
    options.metadata.insert(
        MetadataForReducedPrecisionSupport(quant_specs.support_mask));
  }
  if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
                                                 &translated_result)) {
    return statusHandler.ConsumeStatus();
  }

  // TODO(b/176267167): Quantize flex fallback in the MLIR pipeline.
  if (quant_specs.weight_quantization &&
      (!quant_specs.RunAndRewriteDynamicRangeQuantizationPasses() ||
       !pass_config.emit_builtin_tflite_ops)) {
    // Apply post-training dynamic range quantization from the old TOCO
    // quantizer. Once MLIR has support for this, we can remove this if
    // statement.
    auto status = ApplyDynamicRangeQuantizationFromOldQuantizer(
        quant_specs, translated_result, result);
    if (!status.ok()) return status;
  } else {
    *result = translated_result;
  }

  if (mlir::failed(module.verifyInvariants())) {
    return tensorflow::errors::Unknown("Final module is invalid");
  }
  return OkStatus();
}

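// Imports a SavedModel from `input_filename` into an MLIR module. Version 2
// models go through the object-graph importer; version 1 models go through
// the SignatureDef importer (populating `saved_model_bundle`); any other
// version is rejected. Extra TF ops passed as OpDefs are registered first.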
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
    const std::string& input_filename, const int saved_model_version,
    const std::unordered_set<std::string>& tags,
    absl::Span<const std::string> extra_tf_opdefs,
    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
    bool enable_variable_lifting, mlir::MLIRContext* context,
    std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle) {
  // Register extra TF ops passed as OpDef.
  auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs);
  if (!extra_opdefs_status.ok()) return extra_opdefs_status;

  if (saved_model_version == 2) {
    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
        input_filename, tags, exported_names, context,
        /*unconditionally_use_set_output_shapes=*/true);
    if (!module_or.status().ok()) return module_or.status();
    return std::move(module_or).value();
  } else if (saved_model_version == 1) {
    MLIRImportOptions options;
    options.upgrade_legacy = specs.upgrade_legacy;
    options.unconditionally_use_set_output_shapes = true;
    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
        input_filename, tags, exported_names, context, options,
        enable_variable_lifting, saved_model_bundle);

    if (!module_or.status().ok()) return module_or.status();
    return std::move(module_or).value();
  } else {
    return tensorflow::errors::InvalidArgument(
        "Should be either saved model v1 or v2");
  }
}
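
// A minimal sketch of how these entry points compose for a SavedModel,
// assuming the caller has already built `tags`, `extra_tf_opdefs`,
// `exported_names`, `specs`, `toco_flags`, `pass_config`, and `session`
// (all names here are illustrative, and error handling is elided):
//
//   mlir::MLIRContext context;
//   std::unique_ptr<tensorflow::SavedModelBundle> bundle;
//   auto module_or = ImportSavedModel(
//       "/path/to/saved_model", /*saved_model_version=*/2, tags,
//       extra_tf_opdefs, exported_names, specs,
//       /*enable_variable_lifting=*/true, &context, &bundle);
//   std::string flatbuffer;
//   auto status = ConvertTFExecutorToTFLOrFlatbuffer(
//       module_or.value().get(), /*export_to_mlir=*/false, toco_flags,
//       pass_config, /*saved_model_tags=*/tags, /*saved_model_dir=*/"",
//       session, &flatbuffer);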

}  // namespace tensorflow