/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <iostream>

#include "absl/strings/str_split.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/AsmState.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::func::FuncOp;
using stream_executor::port::StatusOr;

// Debugging flag to print function mapping in the flatbuffer.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> print_function_result_mapping(
    "print-function-result-mapping",
    llvm::cl::desc(
        "Print the mapping of function result to flatbuffer output buffer"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<std::string> weight_quantization(
    "weight_quantization",
    llvm::cl::desc("The type of the quantized weight buffer. Must be NONE, "
                   "INT8, or FLOAT16."),
    llvm::cl::init("NONE"));

enum TranslationStatus { kTrSuccess, kTrFailure };

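// Prints, for each subgraph in the serialized flatbuffer `result`, the name
// and buffer index of every input and output tensor, along with the MLIR
// source location of each return value when it can be recovered from
// `module`. Returns kTrSuccess, or kTrFailure if the buffer cannot be parsed.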
static int PrintFunctionResultMapping(const std::string &result,
                                      ModuleOp module) {
  // Build the model from the result string to extract the return values from
  // their source of truth.
  auto model =
      tflite::FlatBufferModel::BuildFromBuffer(result.data(), result.size());
  if (!model) return kTrFailure;

  // Use an unknown location for return values whose originating terminator
  // cannot be found.
  auto unknown_loc = mlir::UnknownLoc::get(module.getContext());

  auto print_buffer = [&](const tflite::SubGraph &subgraph, int id, int buffer,
                          std::function<mlir::Location(int)> loc) {
    const auto &output_tensor = (*subgraph.tensors())[buffer];
    std::cout << "\tname: '"
              << (output_tensor->name() ? output_tensor->name()->str()
                                        : "<<unnamed>>")
              << "' buffer: " << buffer;
    if (loc) std::cout << llvm::formatv(" {0}", loc(id)).str();
    std::cout << '\n';
  };

  // For every subgraph, print its name (if available) and, for each result,
  // the output buffer number and the location of the return value (if
  // available).
  for (auto *subgraph : *(*model)->subgraphs()) {
    std::string subgraph_name =
        subgraph->name() ? subgraph->name()->str() : "<<unnamed subgraph>>";

    std::cout << '\'' << subgraph_name << "' inputs:\n";
    int i = 0;
    for (auto input : *subgraph->inputs())
      print_buffer(*subgraph, i++, input, nullptr);

    std::cout << '\'' << subgraph_name << "' outputs:\n";
    mlir::Operation *terminator = nullptr;
    if (subgraph->name()) {
      if (auto fn = module.lookupSymbol<FuncOp>(subgraph->name()->str()))
        terminator = fn.back().getTerminator();
    }
    i = 0;
    for (auto output : *subgraph->outputs()) {
      print_buffer(*subgraph, i++, output, [&](int i) {
        return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
      });
    }
  }
  return kTrSuccess;
}

int main(int argc, char **argv) {
  // TODO(jpienaar): Revise the command line option parsing here.
  tensorflow::InitMlir y(&argc, &argv);

  // TODO(antiagainst): We are pulling in multiple transformations as follows.
  // Each transformation has its own set of command-line options; options of
  // one transformation can essentially be aliases to another. For example,
  // -tfl-annotate-inputs has -tfl-input-arrays, -tfl-input-data-types, and
  // -tfl-input-shapes, which are the same as the -graphdef-to-mlir
  // transformation's -tf_input_arrays, -tf_input_data_types, and
  // -tf_input_shapes, respectively. We need to disable duplicated ones to
  // provide a cleaner command-line option interface. That also means we need
  // to relay the value set in one option to all its aliases.
  mlir::registerAsmPrinterCLOptions();
  mlir::registerMLIRContextCLOptions();
  mlir::registerPassManagerCLOptions();
  llvm::cl::ParseCommandLineOptions(
      argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");

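  // Route MLIR diagnostics through the source manager so that errors emitted
  // during parsing and conversion point back at the input file.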
  MLIRContext context;
  llvm::SourceMgr source_mgr;
  mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);

  StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> module;
  std::unordered_set<std::string> tags;

  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = upgrade_legacy;
  specs.prune_unused_nodes = true;

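  // A user-provided list of select TF ops is only consulted on the select TF
  // ops path, so reject it when that path is disabled.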
  if (!select_user_tf_ops.empty() && !emit_select_tf_ops) {
    llvm::errs() << "You must specify `emit-select-tf-ops=true` when passing "
                    "the `select-user-tf-ops` flag.";
    return kTrFailure;
  }

  std::unique_ptr<tensorflow::SavedModelBundle> bundle;

  // TODO(b/147435528): We need to test the e2e behavior once the graph
  // freezing inside mlir is done.
  if ((import_saved_model_object_graph || import_saved_model_signature_defs) &&
      import_hlo) {
    llvm::errs() << "Saved model import and HLO import cannot both be set.";
    return kTrFailure;
  }

  if (import_saved_model_object_graph || import_saved_model_signature_defs) {
    // Saved model import path.
    int saved_model_version;
    if (import_saved_model_object_graph) {
      saved_model_version = 2;
    } else {
      saved_model_version = 1;
    }
    if (input_mlir)
      module = tensorflow::errors::InvalidArgument(
          "Importing saved model should not have input_mlir set");

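    // Tags and exported names arrive as comma-separated command-line values.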
    tags = absl::StrSplit(saved_model_tags, ',');
    std::vector<std::string> exported_names_vector =
        absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
    absl::Span<std::string> exported_names(exported_names_vector);

    std::vector<std::string> extra_opdefs(custom_opdefs.begin(),
                                          custom_opdefs.end());
    module = tensorflow::ImportSavedModel(
        input_file_name, saved_model_version, tags, extra_opdefs,
        exported_names, specs, /*enable_variable_lifting=*/true, &context,
        &bundle);
  } else if (import_hlo) {
    // HLO import path.
    std::string error;
    std::unique_ptr<llvm::MemoryBuffer> buffer =
        mlir::openInputFile(input_file_name, &error);
    if (buffer == nullptr) {
      llvm::errs() << "Cannot open input file: " << input_file_name << " "
                   << error;
      return kTrFailure;
    }

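    // Dispatch on the requested import format: HLO text, HLO proto, or MLIR
    // assembly.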
    auto content = buffer->getBuffer();
    if (hlo_import_type == HloImportType::hlotxt) {
      module = xla::HloTextToMlirHloTranslateFunction(content, &context, false);
    } else if (hlo_import_type == HloImportType::proto) {
      module = xla::HloToMlirHloTranslateFunction(content, &context, false);
    } else {
      module = mlir::OwningOpRef<mlir::ModuleOp>(
          mlir::parseSourceString<mlir::ModuleOp>(content, &context));
    }
  } else {
    // GraphDef import path.
    module = tensorflow::LoadFromGraphdefOrMlirSource(
        input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
        specs, debug_info_file, input_arrays, input_dtypes, input_shapes,
        output_arrays, control_output_arrays, &source_mgr, &context);
  }

  // If errors occur, the library calls above have already logged the error
  // message, so we can just return here.
  if (!module.ok()) return kTrFailure;

  // Set the quantization specifications from the command line flags.
  mlir::quant::QuantizationSpecs quant_specs;
  if (mlir::quant::ParseInputNodeQuantSpecs(
          input_arrays, min_values, max_values, inference_type, &quant_specs)) {
    llvm::errs() << "Failed to get input quant spec.";
    return kTrFailure;
  }
  if (weight_quantization != "NONE") {
    quant_specs.weight_quantization = true;
    if (weight_quantization == "INT8") {
      quant_specs.inference_type = tensorflow::DT_QINT8;
    } else if (weight_quantization == "FLOAT16") {
      quant_specs.inference_type = tensorflow::DT_HALF;
    } else {
      llvm::errs() << "Unknown weight quantization " << weight_quantization;
      return kTrFailure;
    }
  }
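  // Without quantize/dequantize adaptor ops at the model boundary, inputs are
  // expected to arrive already in the quantized inference type.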
  if (!emit_quant_adaptor_ops) {
    quant_specs.inference_input_type = quant_specs.inference_type;
  }

  if (!quant_stats_file_name.empty()) {
    std::string error_message;
    auto file = mlir::openInputFile(quant_stats_file_name, &error_message);
    if (!file) {
      llvm::errs() << "Failed to open quant stats file: "
                   << quant_stats_file_name;
      return kTrFailure;
    }
    quant_specs.serialized_quant_stats = file->getBuffer().str();
  }

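  // Bundle the lowering choices from the command line into the TFLite pass
  // configuration; runtime verification and TF While outlining are always on
  // for this tool.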
  mlir::TFL::PassConfig pass_config(quant_specs);
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.lower_tensor_list_ops = lower_tensor_list_ops;
  pass_config.unfold_batch_matmul = unfold_batchmatmul;
  pass_config.unfold_large_splat_constant = unfold_large_splat_constant;
  pass_config.guarantee_all_funcs_one_use = guarantee_all_funcs_one_use;
  pass_config.enable_dynamic_update_slice = enable_dynamic_update_slice;
  pass_config.runtime_verification = true;
  pass_config.outline_tf_while = true;
  pass_config.preserve_assert_op = preserve_assert_op;

  if (enable_hlo_to_tf_conversion) {
    pass_config.enable_hlo_to_tf_conversion = true;
  }

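  // Relay the converter flags to the TocoFlags proto consumed by the
  // conversion pipeline.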
  toco::TocoFlags toco_flags;
  toco_flags.set_force_select_tf_ops(!emit_builtin_tflite_ops);
  toco_flags.set_enable_select_tf_ops(emit_select_tf_ops);
  toco_flags.set_allow_custom_ops(emit_custom_ops);
  toco_flags.set_allow_all_select_tf_ops(allow_all_select_tf_ops);
  toco_flags.set_enable_dynamic_update_slice(enable_dynamic_update_slice);
  // Read the comma-separated list of user-selected TF ops.
  llvm::SmallVector<llvm::StringRef, 2> user_ops;
  (llvm::StringRef(select_user_tf_ops))
      .split(user_ops, ',', /*MaxSplit=*/-1,
             /*KeepEmpty=*/false);
  llvm::for_each(user_ops, [&toco_flags](llvm::StringRef op_name) {
    *(toco_flags.add_select_user_tf_ops()) = op_name.str();
  });

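  // If the module came from a saved model, pass its session along so the
  // converter can read variable values.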
  std::string result;
  llvm::Optional<tensorflow::Session *> session = llvm::None;
  if (bundle) session = bundle->GetSession();
  auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
      module.ValueOrDie().get(), output_mlir, toco_flags, pass_config, tags,
      /*saved_model_dir=*/"", session, &result);
  if (!status.ok()) return kTrFailure;

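  // Write the serialized flatbuffer (or MLIR text, when output_mlir is set)
  // to the requested output file.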
  std::string error_msg;
  auto output = mlir::openOutputFile(output_file_name, &error_msg);
  if (output == nullptr) {
    llvm::errs() << error_msg << '\n';
    return kTrFailure;
  }
  output->os() << result;
  output->keep();

  // Print out debugging info related to function mapping.
  if (print_function_result_mapping)
    return PrintFunctionResultMapping(result, module.ValueOrDie().get());
  return kTrSuccess;
}