/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "llvm/Support/CommandLine.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/ToolOutputFile.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
23 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
24 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
25 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
26 #include "mlir/IR/Attributes.h" // from @llvm-project
27 #include "mlir/IR/Builders.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30 #include "mlir/IR/Location.h" // from @llvm-project
31 #include "mlir/IR/MLIRContext.h" // from @llvm-project
32 #include "mlir/IR/Operation.h" // from @llvm-project
33 #include "mlir/IR/Types.h" // from @llvm-project
34 #include "mlir/IR/Value.h" // from @llvm-project
35 #include "mlir/Support/LogicalResult.h" // from @llvm-project
36 #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project
37 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
38 #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
39 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
41 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
44
45 using llvm::cl::opt;
46
47 // Commandline flag to enable the control of flatbuffer import.
48 bool use_external_constant;
49
50 // Commandline flag to enable graph pruning.
51 bool experimental_prune_unreachable_nodes_unconditionally;
52
53 // NOLINTNEXTLINE
54 static opt<bool, true> use_external_constant_flag(
55 "use-external-constant",
56 llvm::cl::desc("Use external constant during flatbuffer import"),
57 llvm::cl::location(use_external_constant), llvm::cl::init(false));
58
59 // TODO(b/147111261): After the importer supports generic custom ops, we should
60 // change the flag to a more lightwise flag, e.g.
61 // "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune
62 // the operations.
63 // NOLINTNEXTLINE
64 static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
65 "experimental-prune-unreachable-nodes-unconditionally",
66 llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
67 llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
68 llvm::cl::init(false));
69
70 // NOLINTNEXTLINE
71 static opt<std::string> input_arrays_flag(
72 "input-arrays",
73 llvm::cl::desc(
74 "List of input tensors, if different from the default inputs"),
75 llvm::cl::init(""));
76
77 // NOLINTNEXTLINE
78 static opt<std::string> output_arrays_flag(
79 "output-arrays",
80 llvm::cl::desc(
81 "List of output tensors, if different from the default outputs"),
82 llvm::cl::init(""));
83 using llvm::cl::opt;
84
85 // These command line flags enable control of the translation implementation.
86 bool emit_builtin_tflite_ops;
87 bool emit_custom_ops;
88 bool emit_select_tf_ops;
89 bool lower_tensor_list_ops;
90 bool strip_debug_info;
91
92 // NOLINTNEXTLINE
93 static opt<bool, true> emit_builtin_tflite_ops_flag(
94 "emit-builtin-tflite-ops",
95 llvm::cl::desc(
96 "Emit TFLite built in operations in the generated TFLite model"),
97 llvm::cl::location(emit_builtin_tflite_ops), llvm::cl::init(true));
98
99 // NOLINTNEXTLINE
100 static opt<bool, true> emit_select_tf_ops_flag(
101 "emit-select-tf-ops",
102 llvm::cl::desc(
103 "Emit Select TF operations (Flex ops) in the generated TFLite model"),
104 llvm::cl::location(emit_select_tf_ops), llvm::cl::init(false));
105
106 // NOLINTNEXTLINE
107 static opt<bool, true> emit_custom_ops_flag(
108 "emit-custom-ops",
109 llvm::cl::desc("Emit Custom operations in the generated TFLite model"),
110 llvm::cl::location(emit_custom_ops), llvm::cl::init(false));
111
112 // NOLINTNEXTLINE
113 static opt<bool, true> lower_tensor_list_ops_flag(
114 "lower-tensor-list-ops",
115 llvm::cl::desc("Lower the TensorList ops within the TFLite dialect"),
116 llvm::cl::location(lower_tensor_list_ops), llvm::cl::init(false));
117
118 // NOLINTNEXTLINE
119 static opt<bool, true> strip_debug_info_flag(
120 "strip-debug-info", llvm::cl::desc("Strip debug info during export"),
121 llvm::cl::location(strip_debug_info), llvm::cl::init(false));
122
123 namespace mlir {
124 namespace {
FlatBufferFileToMlirTrans(llvm::SourceMgr * source_mgr,MLIRContext * context,bool use_external_constant,bool experimental_prune_unreachable_nodes_unconditionally)125 static OwningOpRef<mlir::ModuleOp> FlatBufferFileToMlirTrans(
126 llvm::SourceMgr* source_mgr, MLIRContext* context,
127 bool use_external_constant,
128 bool experimental_prune_unreachable_nodes_unconditionally) {
129 const llvm::MemoryBuffer* input =
130 source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
131 std::string error;
132 auto loc =
133 mlir::FileLineColLoc::get(context, input->getBufferIdentifier(), 0, 0);
134
135 // Parses input/output names from command line options.
136 std::vector<std::string> inputs;
137 std::vector<std::string> outputs;
138 // Use output parser since we only have tensor names.
139 if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
140 return emitError(loc, "parsing input array info failed ")
141 << input_arrays_flag,
142 nullptr;
143 }
144 if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
145 return emitError(loc, "parsing output array info failed ")
146 << output_arrays_flag,
147 nullptr;
148 }
149 return tflite::FlatBufferToMlir(
150 absl::string_view(input->getBufferStart(), input->getBufferSize()),
151 context, loc, use_external_constant, inputs, outputs,
152 experimental_prune_unreachable_nodes_unconditionally);
153 }
154
MlirToFlatBufferFileTranslateFunction(ModuleOp module,llvm::raw_ostream & output)155 static LogicalResult MlirToFlatBufferFileTranslateFunction(
156 ModuleOp module, llvm::raw_ostream& output) {
157 std::string serialized_flatbuffer;
158 std::unique_ptr<tensorflow::OpOrArgNameMapper> op_or_arg_name_mapper;
159 if (strip_debug_info) {
160 op_or_arg_name_mapper =
161 std::make_unique<tensorflow::OpOrArgStripNameMapper>();
162 } else {
163 op_or_arg_name_mapper =
164 std::make_unique<tensorflow::OpOrArgLocNameMapper>();
165 }
166 tflite::FlatbufferExportOptions options;
167 options.toco_flags.set_force_select_tf_ops(!emit_builtin_tflite_ops);
168 options.toco_flags.set_enable_select_tf_ops(emit_select_tf_ops);
169 options.toco_flags.set_allow_custom_ops(emit_custom_ops);
170 options.op_or_arg_name_mapper = op_or_arg_name_mapper.get();
171 if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
172 &serialized_flatbuffer))
173 return mlir::failure();
174
175 output << serialized_flatbuffer;
176 return success();
177 }
178 } // namespace
179
180 static TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
181 "tflite-flatbuffer-to-mlir",
__anona6bed5d70202(llvm::SourceMgr& source_mgr, MLIRContext* context) 182 [](llvm::SourceMgr& source_mgr, MLIRContext* context) {
183 return FlatBufferFileToMlirTrans(
184 &source_mgr, context, use_external_constant,
185 experimental_prune_unreachable_nodes_unconditionally);
186 });
187
188 static TranslateFromMLIRRegistration MLIRToFlatBufferTranslate(
189 "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction,
__anona6bed5d70302(DialectRegistry& registry) 190 [](DialectRegistry& registry) {
191 registry.insert<quant::QuantizationDialect,
192 quantfork::QuantizationForkDialect>();
193 mlir::RegisterAllTensorFlowDialects(registry);
194 registry.insert<TFL::TensorFlowLiteDialect>();
195 registry.insert<arith::ArithmeticDialect>();
196 registry.insert<func::FuncDialect>();
197 });
198 } // namespace mlir
199