xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <functional>
17 #include <iostream>
18 
19 #include "absl/strings/str_split.h"
20 #include "llvm/ADT/None.h"
21 #include "llvm/ADT/Optional.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/InitLLVM.h"
29 #include "llvm/Support/SourceMgr.h"
30 #include "llvm/Support/ToolOutputFile.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
33 #include "mlir/IR/AsmState.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
35 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
36 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
37 #include "mlir/Parser/Parser.h"  // from @llvm-project
38 #include "mlir/Pass/Pass.h"  // from @llvm-project
39 #include "mlir/Pass/PassManager.h"  // from @llvm-project
40 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
41 #include "mlir/Transforms/Passes.h"  // from @llvm-project
42 #include "tensorflow/cc/saved_model/loader.h"
43 #include "tensorflow/compiler/mlir/init_mlir.h"
44 #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
45 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
46 #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
47 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
48 #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
49 #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
50 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
51 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
52 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
53 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
54 #include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h"
55 #include "tensorflow/core/framework/types.pb.h"
56 #include "tensorflow/core/platform/errors.h"
57 #include "tensorflow/lite/model.h"
58 #include "tensorflow/lite/schema/schema_generated.h"
59 #include "tensorflow/stream_executor/lib/statusor.h"
60 
// Convenience aliases for the MLIR and StreamExecutor types used below.
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::func::FuncOp;
using stream_executor::port::StatusOr;

// Debugging flag to print function mapping in the flatbuffer.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> print_function_result_mapping(
    "print-function-result-mapping",
    llvm::cl::desc(
        "Print the mapping of function result to flatbuffer output buffer"),
    llvm::cl::init(false));

// Post-training weight quantization mode; validated and applied in main().
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> weight_quantization(
    "weight_quantization",
    llvm::cl::desc("The type of the quantized weight buffer. Must be NONE, "
                   "INT8, FLOAT16."),
    llvm::cl::init("NONE"));

// Process exit codes returned from main(): kTrSuccess (0) / kTrFailure (1).
enum TranslationStatus { kTrSuccess, kTrFailure };
82 
PrintFunctionResultMapping(const std::string & result,ModuleOp module)83 static int PrintFunctionResultMapping(const std::string &result,
84                                       ModuleOp module) {
85   // Build model from the resultant string to extract the return values from
86   // their source of truth.
87   auto model =
88       tflite::FlatBufferModel::BuildFromBuffer(result.data(), result.size());
89   if (!model) return kTrFailure;
90 
91   // Get an unknown location for where we don't have a terminator to get the
92   // location of the return value from.
93   auto unknown_loc = mlir::UnknownLoc::get(module.getContext());
94 
95   auto print_buffer = [&](const tflite::SubGraph &subgraph, int id, int buffer,
96                           std::function<mlir::Location(int)> loc) {
97     const auto &output_tensor = (*subgraph.tensors())[buffer];
98     std::cout << "\tname: '"
99               << (output_tensor->name() ? output_tensor->name()->str()
100                                         : "<<unnamed>>")
101               << "' buffer: " << buffer;
102     if (loc) std::cout << llvm::formatv(" {0}", loc(id)).str();
103     std::cout << '\n';
104   };
105 
106   // For every subgraph print out the name (if available), each result's output
107   // buffer number and location of the return value (if available).
108   for (auto *subgraph : *(*model)->subgraphs()) {
109     std::string subgraph_name =
110         subgraph->name() ? subgraph->name()->str() : "<<unnamed subgraph>>";
111 
112     std::cout << '\'' << subgraph_name << "' inputs:\n";
113     int i = 0;
114     for (auto input : *subgraph->inputs())
115       print_buffer(*subgraph, i++, input, nullptr);
116 
117     std::cout << '\'' << subgraph_name << "' outputs:\n";
118     mlir::Operation *terminator = nullptr;
119     if (subgraph->name()) {
120       if (auto fn = module.lookupSymbol<FuncOp>(subgraph->name()->str()))
121         terminator = fn.back().getTerminator();
122     }
123     i = 0;
124     for (auto output : *subgraph->outputs()) {
125       print_buffer(*subgraph, i, output, [&](int i) {
126         return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
127       });
128     }
129   }
130   return kTrSuccess;
131 }
132 
// Tool entry point: imports a model (SavedModel, HLO/HLO-text, GraphDef, or
// MLIR, selected by command-line flags), applies the configured quantization
// and pass options, and writes the converted TFLite flatbuffer (or MLIR) to
// the output file. Returns kTrSuccess or kTrFailure.
int main(int argc, char **argv) {
  // TODO(jpienaar): Revise the command line option parsing here.
  tensorflow::InitMlir y(&argc, &argv);

  // TODO(antiagainst): We are pulling in multiple transformations as follows.
  // Each transformation has its own set of command-line options; options of one
  // transformation can essentially be aliases to another. For example, the
  // -tfl-annotate-inputs has -tfl-input-arrays, -tfl-input-data-types, and
  // -tfl-input-shapes, which are the same as -graphdef-to-mlir transformation's
  // -tf_input_arrays, -tf_input_data_types, and -tf_input_shapes, respectively.
  // We need to disable duplicated ones to provide a cleaner command-line option
  // interface. That also means we need to relay the value set in one option to
  // all its aliases.
  mlir::registerAsmPrinterCLOptions();
  mlir::registerMLIRContextCLOptions();
  mlir::registerPassManagerCLOptions();
  llvm::cl::ParseCommandLineOptions(
      argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");

  MLIRContext context;
  llvm::SourceMgr source_mgr;
  // Route MLIR diagnostics through the source manager so errors carry input
  // file locations.
  mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);

  // Populated by exactly one of the three import branches below.
  StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> module;
  std::unordered_set<std::string> tags;

  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = upgrade_legacy;
  specs.prune_unused_nodes = true;

  // `select-user-tf-ops` only makes sense when select TF ops are emitted.
  if (!select_user_tf_ops.empty() && !emit_select_tf_ops) {
    llvm::errs() << "You must specify `emit-select-tf-ops=true` when passing "
                    "`select-user-tf-ops` flag.";
    return kTrFailure;
  }

  std::unique_ptr<tensorflow::SavedModelBundle> bundle;

  // TODO(b/147435528): We need to test the e2e behavior once the graph freezing
  // inside mlir is done.
  if ((import_saved_model_object_graph || import_saved_model_signature_defs) &&
      import_hlo) {
    llvm::errs() << "Import saved model and import hlo cannot be both set.";
    return kTrFailure;
  }

  if (import_saved_model_object_graph || import_saved_model_signature_defs) {
    // Saved model import path.
    // Object-graph import corresponds to a TF2 (v2) SavedModel; signature-defs
    // import to a TF1 (v1) SavedModel.
    int saved_model_version;
    if (import_saved_model_object_graph) {
      saved_model_version = 2;
    } else {
      saved_model_version = 1;
    }
    if (input_mlir)
      module = tensorflow::errors::InvalidArgument(
          "Importing saved model should not have input_mlir set");

    tags = absl::StrSplit(saved_model_tags, ',');
    std::vector<std::string> exported_names_vector =
        absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
    absl::Span<std::string> exported_names(exported_names_vector);

    std::vector<std::string> extra_opdefs(custom_opdefs.begin(),
                                          custom_opdefs.end());
    module = tensorflow::ImportSavedModel(
        input_file_name, saved_model_version, tags, extra_opdefs,
        exported_names, specs, /*enable_variable_lifting=*/true, &context,
        &bundle);
  } else if (import_hlo) {
    // HLO import path.
    std::string error;
    std::unique_ptr<llvm::MemoryBuffer> buffer =
        mlir::openInputFile(input_file_name, &error);
    if (buffer == nullptr) {
      llvm::errs() << "Cannot open input file: " << input_file_name << " "
                   << error;
      return kTrFailure;
    }

    // Dispatch on the declared HLO input format: textual HLO, HloProto, or
    // (fallback) MLIR source.
    auto content = buffer->getBuffer();
    if (hlo_import_type == HloImportType::hlotxt) {
      module = xla::HloTextToMlirHloTranslateFunction(content, &context, false);
    } else if (hlo_import_type == HloImportType::proto) {
      module = xla::HloToMlirHloTranslateFunction(content, &context, false);
    } else {
      module = mlir::OwningOpRef<mlir::ModuleOp>(
          mlir::parseSourceString<mlir::ModuleOp>(content, &context));
    }
  } else {
    // Graphdef import path.
    module = tensorflow::LoadFromGraphdefOrMlirSource(
        input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
        specs, debug_info_file, input_arrays, input_dtypes, input_shapes,
        output_arrays, control_output_arrays, &source_mgr, &context);
  }

  // If errors occur, the library call in the above already logged the error
  // message. So we can just return here.
  if (!module.ok()) return kTrFailure;

  // Set the quantization specifications from the command line flags.
  mlir::quant::QuantizationSpecs quant_specs;
  if (mlir::quant::ParseInputNodeQuantSpecs(
          input_arrays, min_values, max_values, inference_type, &quant_specs)) {
    llvm::errs() << "Failed to get input quant spec.";
    return kTrFailure;
  }
  // Translate the -weight_quantization flag (NONE/INT8/FLOAT16) into the
  // quant spec's inference type; any other value is an error.
  if (weight_quantization != "NONE") {
    quant_specs.weight_quantization = true;
    if (weight_quantization == "INT8") {
      quant_specs.inference_type = tensorflow::DT_QINT8;
    } else if (weight_quantization == "FLOAT16") {
      quant_specs.inference_type = tensorflow::DT_HALF;
    } else {
      llvm::errs() << "Unknown weight quantization " << weight_quantization;
      return kTrFailure;
    }
  }
  // Without adaptor ops the model's input type must match the inference type.
  if (!emit_quant_adaptor_ops) {
    quant_specs.inference_input_type = quant_specs.inference_type;
  }

  // Optionally load serialized calibration statistics for quantization.
  if (!quant_stats_file_name.empty()) {
    std::string error_message;
    auto file = mlir::openInputFile(quant_stats_file_name, &error_message);
    if (!file) {
      llvm::errs() << "fail to open quant stats file: "
                   << quant_stats_file_name;
      return kTrFailure;
    }
    quant_specs.serialized_quant_stats = file->getBuffer().str();
  }

  // Build the pass pipeline configuration from the remaining flags.
  mlir::TFL::PassConfig pass_config(quant_specs);
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.lower_tensor_list_ops = lower_tensor_list_ops;
  pass_config.unfold_batch_matmul = unfold_batchmatmul;
  pass_config.unfold_large_splat_constant = unfold_large_splat_constant;
  pass_config.guarantee_all_funcs_one_use = guarantee_all_funcs_one_use;
  pass_config.enable_dynamic_update_slice = enable_dynamic_update_slice;
  pass_config.runtime_verification = true;
  pass_config.outline_tf_while = true;
  pass_config.preserve_assert_op = preserve_assert_op;

  if (enable_hlo_to_tf_conversion) {
    pass_config.enable_hlo_to_tf_conversion = true;
  }

  toco::TocoFlags toco_flags;
  toco_flags.set_force_select_tf_ops(!emit_builtin_tflite_ops);
  toco_flags.set_enable_select_tf_ops(emit_select_tf_ops);
  toco_flags.set_allow_custom_ops(emit_custom_ops);
  toco_flags.set_allow_all_select_tf_ops(allow_all_select_tf_ops);
  toco_flags.set_enable_dynamic_update_slice(enable_dynamic_update_slice);
  // Read list of user select ops.
  llvm::SmallVector<llvm::StringRef, 2> user_ops;
  (llvm::StringRef(select_user_tf_ops))
      .split(user_ops, ',', /*MaxSplit=*/-1,
             /*KeepEmpty=*/false);
  llvm::for_each(user_ops, [&toco_flags](llvm::StringRef op_name) {
    *(toco_flags.add_select_user_tf_ops()) = op_name.str();
  });

  // Run the conversion; a session is only available on the SavedModel path.
  std::string result;
  llvm::Optional<tensorflow::Session *> session = llvm::None;
  if (bundle) session = bundle->GetSession();
  auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
      module.ValueOrDie().get(), output_mlir, toco_flags, pass_config, tags,
      /*saved_model_dir=*/"", session, &result);
  if (!status.ok()) return kTrFailure;

  // Write the serialized result to the output file.
  std::string error_msg;
  auto output = mlir::openOutputFile(output_file_name, &error_msg);
  if (output == nullptr) {
    llvm::errs() << error_msg << '\n';
    return kTrFailure;
  }
  output->os() << result;
  output->keep();

  // Print out debugging info related to function mapping.
  if (print_function_result_mapping)
    return PrintFunctionResultMapping(result, module.ValueOrDie().get());
  return kTrSuccess;
}
319